From 184ff320b81f618b3fa52aa138ffd92566599609 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Fri, 26 Aug 2022 11:59:46 +0200
Subject: [PATCH] Fix checking host from list

---
 cmd/rdpgw/protocol/common.go | 10 ++++++----
 cmd/rdpgw/security/basic.go  |  9 ++++++---
 cmd/rdpgw/security/jwt.go    | 13 ++++++-------
 3 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go
index ad7c9b0..7a263bc 100644
--- a/cmd/rdpgw/protocol/common.go
+++ b/cmd/rdpgw/protocol/common.go
@@ -24,7 +24,7 @@ type RedirectFlags struct {
 
 type SessionInfo struct {
 	// The connection-id (RDG-ConnID) as reported by the client
-	ConnId           string
+	ConnId string
 	// The underlying incoming transport being either websocket or legacy http
 	// in case of websocket TransportOut will equal TransportIn
 	TransportIn transport.Transport
@@ -32,9 +32,11 @@ type SessionInfo struct {
 	// in case of websocket TransportOut will equal TransportOut
 	TransportOut transport.Transport
 	// The remote desktop server (rdp, vnc etc) the clients intends to connect to
-	RemoteServer	 string
+	RemoteServer string
 	// The obtained client ip address
-	ClientIp		 string
+	ClientIp string
+	// User
+	UserName string
 }
 
 // readMessage parses and defragments a packet from a Transport. It returns
@@ -145,4 +147,4 @@ func wrapSyscallError(name string, err error) error {
 		err = os.NewSyscallError(name, err)
 	}
 	return err
-}
\ No newline at end of file
+}
diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go
index 5a06611..c7e6f96 100644
--- a/cmd/rdpgw/security/basic.go
+++ b/cmd/rdpgw/security/basic.go
@@ -22,10 +22,13 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
 		return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
 	case "roundrobin", "unsigned":
 		log.Printf("Checking host")
-		username := ctx.Value("preferred_username").(string)
+		s := getSessionInfo(ctx)
+		if s == nil {
+			return false, errors.New("no valid session info found in context")
+		}
 		for _, h := range Hosts {
-			if username != "" {
-				h = strings.Replace(h, "{{ preferred_username }}", username, 1)
+			if s.UserName != "" {
+				h = strings.Replace(h, "{{ preferred_username }}", s.UserName, 1)
 			}
 			if h == host {
 				return true, nil
diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go
index 324c873..9bd5f4a 100644
--- a/cmd/rdpgw/security/jwt.go
+++ b/cmd/rdpgw/security/jwt.go
@@ -95,19 +95,18 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
 	}
 
 	// validate the access token
-	if custom.AccessToken != "EMPTY" {
-		tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
-		_, err = OIDCProvider.UserInfo(ctx, tokenSource)
-		if err != nil {
-			log.Printf("Cannot get user info for access token: %s", err)
-			return false, err
-		}
+	tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
+	user, err := OIDCProvider.UserInfo(ctx, tokenSource)
+	if err != nil {
+		log.Printf("Cannot get user info for access token: %s", err)
+		return false, err
 	}
 
 	s := getSessionInfo(ctx)
 
 	s.RemoteServer = custom.RemoteServer
 	s.ClientIp = custom.ClientIP
+	s.UserName = user.Subject
 
 	return true, nil
 }
-- 
GitLab