diff --git a/cmd/rdpgw/common/identity.go b/cmd/rdpgw/common/identity.go new file mode 100644 index 0000000000000000000000000000000000000000..baa8b36f3a7963cfb14b83fdd4d7a3990880ca58 --- /dev/null +++ b/cmd/rdpgw/common/identity.go @@ -0,0 +1,160 @@ +package common + +import ( + "context" + "github.com/google/uuid" + "net/http" + "time" +) + +const ( + CTXKey = "github.com/bolkedebruin/rdpgw/common/identity" + + AttrRemoteAddr = "remoteAddr" + AttrClientIp = "clientIp" + AttrProxies = "proxyAddresses" + AttrAccessToken = "accessToken" // todo remove for security reasons +) + +type Identity interface { + UserName() string + SetUserName(string) + DisplayName() string + SetDisplayName(string) + Domain() string + SetDomain(string) + Authenticated() bool + SetAuthenticated(bool) + AuthTime() time.Time + SetAuthTime(time2 time.Time) + SessionId() string + SetAttribute(string, interface{}) + GetAttribute(string) interface{} + Attributes() map[string]interface{} + DelAttribute(string) + Email() string + SetEmail(string) + Expiry() time.Time + SetExpiry(time.Time) +} + +func AddToRequestCtx(id Identity, r *http.Request) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, CTXKey, id) + return r.WithContext(ctx) +} + +func FromRequestCtx(r *http.Request) Identity { + return FromCtx(r.Context()) +} + +func FromCtx(ctx context.Context) Identity { + if id, ok := ctx.Value(CTXKey).(Identity); ok { + return id + } + return nil +} + +type User struct { + authenticated bool + domain string + userName string + displayName string + email string + authTime time.Time + sessionId string + expiry time.Time + attributes map[string]interface{} + groupMembership map[string]bool +} + +func NewUser() *User { + uuid := uuid.New().String() + return &User{ + attributes: make(map[string]interface{}), + groupMembership: make(map[string]bool), + sessionId: uuid, + } +} + +func (u *User) UserName() string { + return u.userName +} + +func (u *User) SetUserName(s string) { + u.userName = s +} + +func (u *User) DisplayName() string { + if u.displayName == "" { + return u.userName + } + return u.displayName +} + +func (u *User) SetDisplayName(s string) { + u.displayName = s +} + +func (u *User) Domain() string { + return u.domain +} + +func (u *User) SetDomain(s string) { + u.domain = s +} + +func (u *User) Authenticated() bool { + return u.authenticated +} + +func (u *User) SetAuthenticated(b bool) { + u.authenticated = b +} + +func (u *User) AuthTime() time.Time { + return u.authTime +} + +func (u *User) SetAuthTime(t time.Time) { + u.authTime = t +} + +func (u *User) SessionId() string { + return u.sessionId +} + +func (u *User) SetAttribute(s string, i interface{}) { + u.attributes[s] = i +} + +func (u *User) GetAttribute(s string) interface{} { + if found, ok := u.attributes[s]; ok { + return found + } + return nil +} + +func (u *User) Attributes() map[string]interface{} { + return u.attributes +} + +func (u *User) DelAttribute(s string) { + delete(u.attributes, s) +} + +func (u *User) Email() string { + return u.email +} + +func (u *User) SetEmail(s string) { + u.email = s +} + +func (u *User) Expiry() time.Time { + return u.expiry +} + +func (u *User) SetExpiry(t time.Time) { + u.expiry = t +} diff --git a/cmd/rdpgw/common/remote.go b/cmd/rdpgw/common/remote.go index f835e6e12f495e4e7ebf01acfa6ba424b2af48eb..bb73a7225923358c0bec6791f84356509ab54fe9 100644 --- a/cmd/rdpgw/common/remote.go +++ b/cmd/rdpgw/common/remote.go @@ -10,16 +10,17 @@ import ( ) const ( - ClientIPCtx = "ClientIP" - ProxyAddressesCtx = "ProxyAddresses" - RemoteAddressCtx = "RemoteAddress" - TunnelCtx = "TUNNEL" - UsernameCtx = "preferred_username" + CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token" ) func EnrichContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() + id := FromRequestCtx(r) + if id == nil { + id = NewUser() + } + log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t", + id.SessionId(), id.UserName(), id.Authenticated()) h := r.Header.Get("X-Forwarded-For") if h != "" { @@ -32,41 +33,36 @@ func EnrichContext(next http.Handler) http.Handler { if len(ips) > 1 { proxies = ips[1:] } - ctx = context.WithValue(ctx, ClientIPCtx, clientIp) - ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies) + id.SetAttribute(AttrClientIp, clientIp) + id.SetAttribute(AttrProxies, proxies) } - ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr) + id.SetAttribute(AttrRemoteAddr, r.RemoteAddr) if h == "" { clientIp, _, _ := net.SplitHostPort(r.RemoteAddr) - ctx = context.WithValue(ctx, ClientIPCtx, clientIp) + id.SetAttribute(AttrClientIp, clientIp) } - next.ServeHTTP(w, r.WithContext(ctx)) + next.ServeHTTP(w, AddToRequestCtx(id, r)) }) } func FixKerberosContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - id := goidentity.FromHTTPRequestContext(r) - if id != nil { - ctx = context.WithValue(ctx, UsernameCtx, id.UserName()) + gid := goidentity.FromHTTPRequestContext(r) + if gid != nil { + id := FromRequestCtx(r) + id.SetUserName(gid.UserName()) + id.SetAuthenticated(gid.Authenticated()) + id.SetDomain(gid.Domain()) + id.SetAuthTime(gid.AuthTime()) + r = AddToRequestCtx(id, r) } - next.ServeHTTP(w, r.WithContext(ctx)) + next.ServeHTTP(w, r) }) } -func GetClientIp(ctx context.Context) string { - s, ok := ctx.Value(ClientIPCtx).(string) - if !ok { - return "" - } - return s -} - func GetAccessToken(ctx context.Context) string { - token, ok := ctx.Value("access_token").(string) + token, ok := ctx.Value(CtxAccessToken).(string) if !ok { log.Printf("cannot get access token from context") return "" diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index d7a78fc77df25f4728218ac8b7b4b2216e015daf..5f02cffcbf47a08020794b25f7f3d378dd370cf0 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -226,7 +226,7 @@ func main() { oidc := initOIDC(url, store) http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload)))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) - http.HandleFunc("/callback", oidc.HandleCallback) + http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback))) } http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/tokeninfo", web.TokenInfo) diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 2a4646bb29d0ef4e48fdf38d258983154f482d84..b2b42c4d52893678bbc3bd376354b1d31c47120a 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -61,24 +61,20 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) var t *Tunnel ctx := r.Context() + id := common.FromRequestCtx(r) connId := r.Header.Get(rdgConnectionIdKey) x, found := c.Get(connId) if !found { t = &Tunnel{ RDGId: connId, - RemoteAddr: ctx.Value(common.ClientIPCtx).(string), - } - // username can be nil with openid & kerberos as it's only available later - // todo grab kerberos principal now? - username := ctx.Value(common.UsernameCtx) - if username != nil { - t.UserName = username.(string) + RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string), + User: id, } } else { t = x.(*Tunnel) } - ctx = context.WithValue(ctx, common.TunnelCtx, t) + ctx = context.WithValue(ctx, CtxTunnel, t) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { @@ -187,13 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) + id := common.FromRequestCtx(r) if r.Method == MethodRDGOUT { out, err := transport.NewLegacy(w) if err != nil { log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) return } - log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) + log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp)) t.transportOut = out out.SendAccept(true) @@ -215,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t t.transportIn = in c.Set(t.RDGId, t, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) + log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp)) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) + log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp)) handler := NewProcessor(g, t) RegisterTunnel(t, handler) defer RemoveTunnel(t) diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index 3cfa9fcae1d9230f2838b13ff73b8298d2eca19d..9dc4788f1e404565de9cae30d7a8dd1e75e331b4 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) + log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) if p.state != SERVER_STATE_INITIALIZED { log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) @@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error { _, cookie := p.tunnelRequest(pkt) if p.gw.CheckPAACookie != nil { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) + log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) p.tunnel.Write(msg) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) @@ -180,9 +180,9 @@ func (p *Processor) Process(ctx context.Context) error { } } -// Creates a packet the is a response to a handshakeRequest request +// Creates a packet and is a response to a handshakeRequest request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux -// but could be in Windows. However the NTLM protocol is insecure +// but could be in Windows. However, the NTLM protocol is insecure func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index bb1cb6908d54cac347eb1d7c30aa33f5c5044e45..cba7e1e0a497ae22235101707564c3ba15f9ee54 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -1,11 +1,16 @@ package protocol import ( + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "net" "time" ) +const ( + CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel" +) + type Tunnel struct { // Id identifies the connection in the server Id string @@ -22,7 +27,7 @@ type Tunnel struct { // The obtained client ip address RemoteAddr string // User - UserName string + User common.Identity // rwc is the underlying connection to the remote desktop server. // It is of the type *net.TCPConn diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index ca7d2c009a895031d187a1b92bde6484f3232083..8ec88db2b87b2c6ab707cc14993f118c92cb479f 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "log" "strings" ) @@ -22,23 +21,14 @@ func CheckHost(ctx context.Context, host string) (bool, error) { // todo get from context? return false, errors.New("cannot verify host in 'signed' mode as token data is missing") case "roundrobin", "unsigned": - var username string - s := getTunnel(ctx) - if s == nil || s.UserName == "" { - var ok bool - username, ok = ctx.Value(common.UsernameCtx).(string) - if !ok { - return false, errors.New("no valid session info or username found in context") - } - } else { - username = s.UserName + if s.User.UserName() == "" { + return false, errors.New("no valid session info or username found in context") } - log.Printf("Checking host for user %s", username) + + log.Printf("Checking host for user %s", s.User.UserName()) for _, h := range Hosts { - if username != "" { - h = strings.Replace(h, "{{ preferred_username }}", username, 1) - } + h = strings.Replace(h, "{{ preferred_username }}", s.User.UserName(), 1) if h == host { return true, nil } diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go index d1b6a7c9d9b1d62da17b3d7dce9d65de925b965f..1026a6ccb702d92eea726b4468c6c7f7b462f3c1 100644 --- a/cmd/rdpgw/security/basic_test.go +++ b/cmd/rdpgw/security/basic_test.go @@ -12,14 +12,16 @@ var ( RDGId: "myid", TargetServer: "my.remote.server", RemoteAddr: "10.0.0.1", - UserName: "Frank", } hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"} ) func TestCheckHost(t *testing.T) { - ctx := context.WithValue(context.Background(), common.TunnelCtx, &info) + info.User = common.NewUser() + info.User.SetUserName("MYNAME") + + ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info) Hosts = hosts @@ -40,14 +42,7 @@ func TestCheckHost(t *testing.T) { t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err) } - host = "my-Frank-host:3389" - if ok, err := CheckHost(ctx, host); !ok { - t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) - } - - info.UserName = "" - ctx = context.WithValue(ctx, "preferred_username", "dummy") - host = "my-dummy-host:3389" + host = "my-MYNAME-host:3389" if ok, err := CheckHost(ctx, host); !ok { t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) } diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index c8654ec49f60fd642d731a16e69eb4f9da9663ff..6f21d17260446735f156f90f4649d2189ea8bc75 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -35,19 +35,21 @@ type customClaims struct { func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { return func(ctx context.Context, host string) (bool, error) { - s := getTunnel(ctx) - if s == nil { + tunnel := getTunnel(ctx) + if tunnel == nil { return false, errors.New("no valid session info found in context") } - if s.TargetServer != host { - log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer) + if tunnel.TargetServer != host { + log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer) return false, nil } - if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) { + // use identity from context rather then set by tunnel + id := common.FromCtx(ctx) + if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) { log.Printf("Current client ip address %s does not match token client ip %s", - common.GetClientIp(ctx), s.RemoteAddr) + id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr) return false, nil } return next(ctx, host) @@ -106,7 +108,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { tunnel.TargetServer = custom.RemoteServer tunnel.RemoteAddr = custom.ClientIP - tunnel.UserName = user.Subject + tunnel.User.SetUserName(user.Subject) return true, nil } @@ -127,10 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri Subject: username, } + id := common.FromCtx(ctx) private := customClaims{ RemoteServer: server, - ClientIP: common.GetClientIp(ctx), - AccessToken: common.GetAccessToken(ctx), + ClientIP: id.GetAttribute(common.AttrClientIp).(string), + AccessToken: id.GetAttribute(common.AttrAccessToken).(string), } if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil { @@ -289,7 +292,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin } func getTunnel(ctx context.Context) *protocol.Tunnel { - s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel) + s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel) if !ok { log.Printf("cannot get session info from context") return nil diff --git a/cmd/rdpgw/web/basic.go b/cmd/rdpgw/web/basic.go index 5c1443eb5e1322e81d263b737002664ad9421af0..a8ef8077326727bbe142fc05c8656bc8d0d1debe 100644 --- a/cmd/rdpgw/web/basic.go +++ b/cmd/rdpgw/web/basic.go @@ -52,8 +52,12 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { if !res.Authenticated { log.Printf("User %s is not authenticated for this service", username) } else { - ctx := context.WithValue(r.Context(), common.UsernameCtx, username) - next.ServeHTTP(w, r.WithContext(ctx)) + log.Printf("User %s authenticated", username) + id := common.FromRequestCtx(r) + id.SetUserName(username) + id.SetAuthenticated(true) + id.SetAuthTime(time.Now()) + next.ServeHTTP(w, common.AddToRequestCtx(id, r)) return } diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index 93b9945fde9baff84d30f39f1ea460d68a3d47b7..34104e55ac20d6e51761cb1c3df4efb67094fb50 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -1,7 +1,6 @@ package web import ( - "context" "encoding/hex" "encoding/json" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" @@ -15,8 +14,10 @@ import ( ) const ( - CacheExpiration = time.Minute * 2 - CleanupInterval = time.Minute * 5 + CacheExpiration = time.Minute * 2 + CleanupInterval = time.Minute * 5 + sessionKeyAuthenticated = "authenticated" + oidcKeyUserName = "preferred_username" ) type OIDC struct { @@ -90,10 +91,14 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { return } + id := common.FromRequestCtx(r) + id.SetUserName(data[oidcKeyUserName].(string)) + id.SetAuthenticated(true) + id.SetAuthTime(time.Now()) + id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken) + session.Options.MaxAge = MaxAge - session.Values["preferred_username"] = data["preferred_username"] - session.Values["authenticated"] = true - session.Values["access_token"] = oauth2Token.AccessToken + session.Values[common.CTXKey] = id if err = session.Save(r, w); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -110,8 +115,8 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler { return } - found := session.Values["authenticated"] - if found == nil || !found.(bool) { + id := session.Values[common.CTXKey].(common.Identity) + if id == nil { seed := make([]byte, 16) rand.Read(seed) state := hex.EncodeToString(seed) @@ -120,9 +125,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler { return } - ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"]) - ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) - - next.ServeHTTP(w, r.WithContext(ctx)) + // replace the identity with the one from the sessions + next.ServeHTTP(w, common.AddToRequestCtx(id, r)) }) }