diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index ad7c9b0854e73fecf038eb83f4b88d0efc6dd6f5..7a263bc67708a903f011424b5dedec1294fdd1dc 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -24,7 +24,7 @@ type RedirectFlags struct { type SessionInfo struct { // The connection-id (RDG-ConnID) as reported by the client - ConnId string + ConnId string // The underlying incoming transport being either websocket or legacy http // in case of websocket TransportOut will equal TransportIn TransportIn transport.Transport @@ -32,9 +32,11 @@ type SessionInfo struct { // in case of websocket TransportOut will equal TransportOut TransportOut transport.Transport // The remote desktop server (rdp, vnc etc) the clients intends to connect to - RemoteServer string + RemoteServer string // The obtained client ip address - ClientIp string + ClientIp string + // User + UserName string } // readMessage parses and defragments a packet from a Transport. It returns @@ -145,4 +147,4 @@ func wrapSyscallError(name string, err error) error { err = os.NewSyscallError(name, err) } return err -} \ No newline at end of file +} diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index 5a066119c271c6bb58f6d4b49831f3c7ce116558..c7e6f9622cacc2120db03f6a2a2aa7d8b8cf8033 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -22,10 +22,13 @@ func CheckHost(ctx context.Context, host string) (bool, error) { return false, errors.New("cannot verify host in 'signed' mode as token data is missing") case "roundrobin", "unsigned": log.Printf("Checking host") - username := ctx.Value("preferred_username").(string) + s := getSessionInfo(ctx) + if s == nil { + return false, errors.New("no valid session info found in context") + } for _, h := range Hosts { - if username != "" { - h = strings.Replace(h, "{{ preferred_username }}", username, 1) + if s.UserName != "" { + h = strings.Replace(h, "{{ preferred_username }}", s.UserName, 1) } if h == host { return true, nil diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 324c8737bb2eacc673d4653be92618de142221b1..9bd5f4a66130cf098cc2d0208728f9b8ebc0e3c4 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -95,19 +95,18 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { } // validate the access token - if custom.AccessToken != "EMPTY" { - tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) - _, err = OIDCProvider.UserInfo(ctx, tokenSource) - if err != nil { - log.Printf("Cannot get user info for access token: %s", err) - return false, err - } + tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) + user, err := OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + log.Printf("Cannot get user info for access token: %s", err) + return false, err } s := getSessionInfo(ctx) s.RemoteServer = custom.RemoteServer s.ClientIp = custom.ClientIP + s.UserName = user.Subject return true, nil }