From 184ff320b81f618b3fa52aa138ffd92566599609 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Fri, 26 Aug 2022 11:59:46 +0200 Subject: [PATCH] Fix checking host from list --- cmd/rdpgw/protocol/common.go | 10 ++++++---- cmd/rdpgw/security/basic.go | 9 ++++++--- cmd/rdpgw/security/jwt.go | 13 ++++++------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index ad7c9b0..7a263bc 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -24,7 +24,7 @@ type RedirectFlags struct { type SessionInfo struct { // The connection-id (RDG-ConnID) as reported by the client - ConnId string + ConnId string // The underlying incoming transport being either websocket or legacy http // in case of websocket TransportOut will equal TransportIn TransportIn transport.Transport @@ -32,9 +32,11 @@ type SessionInfo struct { // in case of websocket TransportOut will equal TransportOut TransportOut transport.Transport // The remote desktop server (rdp, vnc etc) the clients intends to connect to - RemoteServer string + RemoteServer string // The obtained client ip address - ClientIp string + ClientIp string + // User + UserName string } // readMessage parses and defragments a packet from a Transport. It returns @@ -145,4 +147,4 @@ func wrapSyscallError(name string, err error) error { err = os.NewSyscallError(name, err) } return err -} \ No newline at end of file +} diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index 5a06611..c7e6f96 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -22,10 +22,13 @@ func CheckHost(ctx context.Context, host string) (bool, error) { return false, errors.New("cannot verify host in 'signed' mode as token data is missing") case "roundrobin", "unsigned": log.Printf("Checking host") - username := ctx.Value("preferred_username").(string) + s := getSessionInfo(ctx) + if s == nil { + return false, errors.New("no valid session info found in context") + } for _, h := range Hosts { - if username != "" { - h = strings.Replace(h, "{{ preferred_username }}", username, 1) + if s.UserName != "" { + h = strings.Replace(h, "{{ preferred_username }}", s.UserName, 1) } if h == host { return true, nil diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 324c873..9bd5f4a 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -95,19 +95,18 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { } // validate the access token - if custom.AccessToken != "EMPTY" { - tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) - _, err = OIDCProvider.UserInfo(ctx, tokenSource) - if err != nil { - log.Printf("Cannot get user info for access token: %s", err) - return false, err - } + tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) + user, err := OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + log.Printf("Cannot get user info for access token: %s", err) + return false, err } s := getSessionInfo(ctx) s.RemoteServer = custom.RemoteServer s.ClientIp = custom.ClientIP + s.UserName = user.Subject return true, nil } -- GitLab