From 9c6d056d693bd9bbf69a15da54d83a1fb15f6dc6 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Fri, 12 Apr 2024 12:33:46 +0200
Subject: [PATCH] Use jose v4 and make clearer and fix signing/encryption

---
 README.md                      |  2 +
 cmd/rdpgw/security/jwt.go      | 74 +++++++++++++--------------------
 cmd/rdpgw/security/jwt_test.go | 76 ++++++++++++++++++++++++++++++++++
 go.mod                         |  5 ++-
 4 files changed, 110 insertions(+), 47 deletions(-)
 create mode 100644 cmd/rdpgw/security/jwt_test.go

diff --git a/README.md b/README.md
index 4c9ad36..860cdea 100644
--- a/README.md
+++ b/README.md
@@ -299,6 +299,8 @@ Security:
   # PAATokenEncryptionKey: thisisasessionkeyreplacethisjetzt
   # a random string of 32 characters to secure cookies on the client
   UserTokenEncryptionKey: thisisasessionkeyreplacethisjetzt
+  # Signing makes the token bigger and we are limited to 511 characters
+  # UserTokenSigningKey: thisisasessionkeyreplacethisjetzt
   # if you want to enable token generation for the user
   # if true the username will be set to a jwt with the username embedded into it
   EnableUserToken: true
diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go
index cd0f3a8..40ffade 100644
--- a/cmd/rdpgw/security/jwt.go
+++ b/cmd/rdpgw/security/jwt.go
@@ -7,8 +7,8 @@ import (
 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
 	"github.com/coreos/go-oidc/v3/oidc"
-	"github.com/go-jose/go-jose/v3"
-	"github.com/go-jose/go-jose/v3/jwt"
+	"github.com/go-jose/go-jose/v4"
+	"github.com/go-jose/go-jose/v4/jwt"
 	"golang.org/x/oauth2"
 	"log"
 	"time"
@@ -62,9 +62,9 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
 		return false, errors.New("no token to parse")
 	}
 
-	token, err := jwt.ParseSigned(tokenString)
+	token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
 	if err != nil {
-		log.Printf("cannot parse token due to: %tunnel", err)
+		log.Printf("cannot parse token due to: %t", err)
 		return false, err
 	}
 
@@ -136,7 +136,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
 		AccessToken:  id.GetAttribute(identity.AttrAccessToken).(string),
 	}
 
-	if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
+	if token, err := jwt.Signed(sig).Claims(standard).Claims(private).Serialize(); err != nil {
 		log.Printf("Cannot sign PAA token %s", err)
 		return "", err
 	} else {
@@ -157,7 +157,10 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
 
 	enc, err := jose.NewEncrypter(
 		jose.A128CBC_HS256,
-		jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey},
+		jose.Recipient{
+			Algorithm: jose.DIRECT,
+			Key:       UserEncryptionKey,
+		},
 		(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
 	)
 
@@ -167,16 +170,29 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
 	}
 
 	// this makes the token bigger and we deal with a limited space of 511 characters
-	// sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
-	// token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize()
-	token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize()
+	if len(UserSigningKey) > 0 {
+		sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: UserSigningKey}, nil)
+		token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).Serialize()
+		if len(token) > 511 {
+			log.Printf("WARNING: token too long: len %d > 511", len(token))
+		}
+		return token, err
+	}
+
+	// no signature
+	token, err := jwt.Encrypted(enc).Claims(claims).Serialize()
 	return token, err
 }
 
 func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
 	standard := jwt.Claims{}
 	if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
-		enc, err := jwt.ParseSignedAndEncrypted(token)
+		enc, err := jwt.ParseSignedAndEncrypted(
+			token,
+			[]jose.KeyAlgorithm{jose.DIRECT},
+			[]jose.ContentEncryption{jose.A128CBC_HS256},
+			[]jose.SignatureAlgorithm{jose.HS256},
+		)
 		if err != nil {
 			log.Printf("Cannot get token %s", err)
 			return standard, errors.New("cannot get token")
@@ -186,16 +202,12 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
 			log.Printf("Cannot decrypt token %s", err)
 			return standard, errors.New("cannot decrypt token")
 		}
-		if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
-			log.Printf("signature validation failure: %s", err)
-			return standard, errors.New("signature validation failure")
-		}
 		if err = token.Claims(UserSigningKey, &standard); err != nil {
 			log.Printf("cannot verify signature %s", err)
 			return standard, errors.New("cannot verify signature")
 		}
 	} else if len(UserSigningKey) == 0 {
-		token, err := jwt.ParseEncrypted(token)
+		token, err := jwt.ParseEncrypted(token, []jose.KeyAlgorithm{jose.DIRECT}, []jose.ContentEncryption{jose.A128CBC_HS256})
 		if err != nil {
 			log.Printf("Cannot get token %s", err)
 			return standard, errors.New("cannot get token")
@@ -205,21 +217,6 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
 			log.Printf("Cannot decrypt token %s", err)
 			return standard, errors.New("cannot decrypt token")
 		}
-	} else {
-		token, err := jwt.ParseSigned(token)
-		if err != nil {
-			log.Printf("Cannot get token %s", err)
-			return standard, errors.New("cannot get token")
-		}
-		if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
-			log.Printf("signature validation failure: %s", err)
-			return standard, errors.New("signature validation failure")
-		}
-		err = token.Claims(UserSigningKey, &standard)
-		if err = token.Claims(UserSigningKey, &standard); err != nil {
-			log.Printf("cannot verify signature %s", err)
-			return standard, errors.New("cannot verify signature")
-		}
 	}
 
 	// go-jose doesnt verify the expiry
@@ -238,15 +235,11 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
 
 func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
 	standard := jwt.Claims{}
-	token, err := jwt.ParseSigned(tokenString)
+	token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
 	if err != nil {
 		log.Printf("Cannot get token %s", err)
 		return "", errors.New("cannot get token")
 	}
-	if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
-		log.Printf("signature validation failure: %s", err)
-		return "", errors.New("signature validation failure")
-	}
 	err = token.Claims(QuerySigningKey, &standard)
 	if err = token.Claims(QuerySigningKey, &standard); err != nil {
 		log.Printf("cannot verify signature %s", err)
@@ -287,7 +280,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
 		return "", err
 	}
 
-	token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
+	token, err := jwt.Signed(sig).Claims(claims).Serialize()
 	return token, err
 }
 
@@ -299,12 +292,3 @@ func getTunnel(ctx context.Context) *protocol.Tunnel {
 	}
 	return s
 }
-
-func verifyAlg(headers []jose.Header, alg string) (bool, error) {
-	for _, header := range headers {
-		if header.Algorithm != alg {
-			return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
-		}
-	}
-	return true, nil
-}
diff --git a/cmd/rdpgw/security/jwt_test.go b/cmd/rdpgw/security/jwt_test.go
new file mode 100644
index 0000000..84dc876
--- /dev/null
+++ b/cmd/rdpgw/security/jwt_test.go
@@ -0,0 +1,76 @@
+package security
+
+import (
+	"context"
+	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
+	"testing"
+)
+
+func TestGenerateUserToken(t *testing.T) {
+	cases := []struct {
+		SigningKey    []byte
+		EncryptionKey []byte
+		name          string
+		username      string
+	}{
+		{
+			SigningKey:    []byte("5aa3a1568fe8421cd7e127d5ace28d2d"),
+			EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
+			name:          "sign_and_encrypt",
+			username:      "test_sign_and_encrypt",
+		},
+		{
+			SigningKey:    nil,
+			EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
+			name:          "encrypt_only",
+			username:      "test_encrypt_only",
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			SigningKey = tc.SigningKey
+			UserEncryptionKey = tc.EncryptionKey
+			token, err := GenerateUserToken(context.Background(), tc.username)
+			if err != nil {
+				t.Fatalf("GenerateUserToken failed: %s", err)
+			}
+			claims, err := UserInfo(context.Background(), token)
+			if err != nil {
+				t.Fatalf("UserInfo failed: %s", err)
+			}
+			if claims.Subject != tc.username {
+				t.Fatalf("Expected %s, got %s", tc.username, claims.Subject)
+			}
+		})
+	}
+
+}
+
+func TestPAACookie(t *testing.T) {
+	SigningKey = []byte("5aa3a1568fe8421cd7e127d5ace28d2d")
+	EncryptionKey = []byte("d3ecd7e565e56e37e2f2e95b584d8c0c")
+
+	username := "test_paa_cookie"
+	attr_client_ip := "127.0.0.1"
+	attr_access_token := "aabbcc"
+
+	id := identity.NewUser()
+	id.SetUserName(username)
+	id.SetAttribute(identity.AttrClientIp, attr_client_ip)
+	id.SetAttribute(identity.AttrAccessToken, attr_access_token)
+
+	ctx := context.Background()
+	ctx = context.WithValue(ctx, identity.CTXKey, id)
+
+	_, err := GeneratePAAToken(ctx, "test_paa_cookie", "host.does.not.exist")
+	if err != nil {
+		t.Fatalf("GeneratePAAToken failed: %s", err)
+	}
+	/*ok, err := CheckPAACookie(ctx, token)
+	if err != nil {
+		t.Fatalf("CheckPAACookie failed: %s", err)
+	}
+	if !ok {
+		t.Fatalf("CheckPAACookie failed")
+	}*/
+}
diff --git a/go.mod b/go.mod
index 2d63b30..8c378e3 100644
--- a/go.mod
+++ b/go.mod
@@ -6,7 +6,8 @@ require (
 	github.com/bolkedebruin/gokrb5/v8 v8.5.0
 	github.com/coreos/go-oidc/v3 v3.9.0
 	github.com/fatih/structs v1.1.0
-	github.com/go-jose/go-jose/v3 v3.0.3
+	github.com/go-jose/go-jose/v4 v4.0.1
+	github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/mux v1.8.1
 	github.com/gorilla/sessions v1.2.2
@@ -34,7 +35,7 @@ require (
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/fsnotify/fsnotify v1.7.0 // indirect
-	github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect
+	github.com/go-jose/go-jose/v3 v3.0.1 // indirect
 	github.com/golang/protobuf v1.5.4 // indirect
 	github.com/gorilla/securecookie v1.1.2 // indirect
 	github.com/hashicorp/go-uuid v1.0.3 // indirect
-- 
GitLab