diff --git a/README.md b/README.md index d1c2040bdce018dd676d5231962a18c2c9e4ab9b..7fbed77a846eb6df7f03586813c3547e16c8ab08 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Connect integration enabled by default. Cookies are encrypted and signed on the on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens) are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go) signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by -the server. +the server. Finally, the client's ip address needs to match the one it obtained the token with. ## How to build ```bash diff --git a/client/remote.go b/client/remote.go new file mode 100644 index 0000000000000000000000000000000000000000..f7e70b097e514f1f761e690e9a57c307f61f39d5 --- /dev/null +++ b/client/remote.go @@ -0,0 +1,49 @@ +package client + +import ( + "context" + "net/http" + "strings" +) + +const ( + ClientIPCtx = "ClientIP" + ProxyAddressesCtx = "ProxyAddresses" + RemoteAddressCtx = "RemoteAddress" +) + +func EnrichContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + h := r.Header.Get("X-Forwarded-For") + if h != "" { + var proxies []string + ips := strings.Split(h, ",") + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + } + clientIp := ips[0] + if len(ips) > 1 { + proxies = ips[1:] + } + ctx = context.WithValue(ctx, ClientIPCtx, clientIp) + ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies) + } + + remote := r.Header.Get("REMOTE_ADDR") + ctx = context.WithValue(ctx, RemoteAddressCtx, remote) + if h == "" { + ctx = context.WithValue(ctx, ClientIPCtx, remote) + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func GetClientIp(ctx context.Context) string { + s, ok := ctx.Value(ClientIPCtx).(string) + if !ok { + return "" + } + return s +} diff --git a/main.go b/main.go index 98fed5c963a500ae073faee331e80c81c15675ca..0d9203fe670dcff2ff2819853cb18410fccd72d6 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "github.com/bolkedebruin/rdpgw/api" + "github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/security" @@ -122,8 +123,8 @@ func main() { HandlerConf: &handlerConfig, } - http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol) - http.Handle("/connect", api.Authenticated(http.HandlerFunc(api.HandleDownload))) + http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) + http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/callback", api.HandleCallback) diff --git a/protocol/handler.go b/protocol/handler.go index 138b6057b165bfe9990fdd77e554743baf3db73b..dc55fac54a47addd40833141575b098a96195a37 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "errors" + "github.com/bolkedebruin/rdpgw/client" "io" "log" "net" @@ -96,7 +97,7 @@ func (h *Handler) Process(ctx context.Context) error { _, cookie := readCreateTunnelRequest(pkt) if h.VerifyTunnelCreate != nil { if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received") + log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx)) return errors.New("invalid PAA cookie") } } diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index 12bff90ca536db0aaaba59894a06bf74583c3fd0..f3e321ed016f7850d02c1efcaad28a3f4295e6d6 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -2,6 +2,7 @@ package protocol import ( "context" + "github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" @@ -48,9 +49,8 @@ type SessionInfo struct { ConnId string TransportIn transport.Transport TransportOut transport.Transport - RemoteAddress string - ProxyAddress string RemoteServer string + ClientIp string } var upgrader = websocket.Upgrader{} @@ -118,7 +118,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) return } - log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String()) + log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context())) s.TransportOut = out out.SendAccept(true) @@ -139,13 +139,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s s.TransportIn = in c.Set(s.ConnId, s, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String()) + log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context())) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) + log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context())) handler := NewHandler(s, g.HandlerConf) handler.Process(r.Context()) } diff --git a/security/jwt.go b/security/jwt.go index e90038b0ff9544f2063117b1a55e23695ba31d8f..836245d94b31beb6589232ac098298556260bdeb 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/protocol" "github.com/dgrijalva/jwt-go/v4" "log" @@ -15,6 +16,7 @@ var ExpiryTime time.Duration = 5 type customClaims struct { RemoteServer string `json:"remoteServer"` + ClientIP string `json:"clientIp"` jwt.StandardClaims } @@ -34,6 +36,7 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { if c, ok := token.Claims.(*customClaims); ok && token.Valid { s := getSessionInfo(ctx) s.RemoteServer = c.RemoteServer + s.ClientIp = client.GetClientIp(ctx) return true, nil } @@ -48,7 +51,13 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) { } if s.RemoteServer != host { - log.Printf("Client host %s does not match token host %s", host, s.RemoteServer) + log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer) + return false, nil + } + + if s.ClientIp != client.GetClientIp(ctx) { + log.Printf("Current client ip address %s does not match token client ip %s", + client.GetClientIp(ctx), s.ClientIp) return false, nil }