diff --git a/README.md b/README.md index af90c67a03b74158db095ea097619d42b1e4ee99..7379b800ac0243409af81a5a93d1ddc8d08da528 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,8 @@ client: security: # a random string of at least 32 characters to secure cookies on the client # make sure to share this amongst different pods - tokenSigningKey: thisisasessionkeyreplacethisjetzt + PAATokenSigningKey: thisisasessionkeyreplacethisjetzt + PAATokenEncryptionKey: thisisasessionkeyreplacethisjetzt ``` ## Testing locally A convenience docker-compose allows you to test the RDPGW locally. It uses [Keycloak](http://www.keycloak.org) diff --git a/api/web.go b/api/web.go index 464a92a90446762c2c5360a81dfbc4422cd93d99..3fd34fbac9d6235bd974ed453b041cfec693a81a 100644 --- a/api/web.go +++ b/api/web.go @@ -23,14 +23,16 @@ const ( ) type TokenGeneratorFunc func(context.Context, string, string) (string, error) +type UserTokenGeneratorFunc func(context.Context, string) (string, error) type Config struct { SessionKey []byte SessionEncryptionKey []byte - TokenGenerator TokenGeneratorFunc + PAATokenGenerator TokenGeneratorFunc + UserTokenGenerator UserTokenGeneratorFunc OAuth2Config *oauth2.Config store *sessions.CookieStore - TokenVerifier *oidc.IDTokenVerifier + OIDCTokenVerifier *oidc.IDTokenVerifier stateStore *cache.Cache Hosts []string GatewayAddress string @@ -72,7 +74,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) return } - idToken, err := c.TokenVerifier.Verify(ctx, rawIDToken) + idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken) if err != nil { http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) return @@ -103,6 +105,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) { session.Options.MaxAge = MaxAge session.Values["preferred_username"] = data["preferred_username"] session.Values["authenticated"] = true + session.Values["access_token"] = oauth2Token.AccessToken if err = session.Save(r, w); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -130,6 +133,8 @@ func (c *Config) Authenticated(next http.Handler) http.Handler { } ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) + ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) + next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -159,7 +164,13 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { } } - token, err := c.TokenGenerator(ctx, user, host) + token, err := c.PAATokenGenerator(ctx, user, host) + if err != nil { + log.Printf("Cannot generate PAA token for user %s due to %s", user, err) + http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) + } + + userToken, err := c.UserTokenGenerator(ctx, user) if err != nil { log.Printf("Cannot generate token for user %s due to %s", user, err) http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) @@ -182,6 +193,6 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { "networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+ "bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+ "connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+ - "username:s:"+token+"\r\n"+ + "username:s:"+userToken+"\r\n"+ "bitmapcachesize:i:32000\r\n")) } diff --git a/common/remote.go b/common/remote.go index 39ad7267d8dea16f0a8e31632c4a0f2f98e9c9e3..d6d97ee6b0132a71c7024a054914698a6f5c20c1 100644 --- a/common/remote.go +++ b/common/remote.go @@ -2,6 +2,7 @@ package common import ( "context" + "log" "net" "net/http" "strings" @@ -48,3 +49,12 @@ func GetClientIp(ctx context.Context) string { } return s } + +func GetAccessToken(ctx context.Context) string { + token, ok := ctx.Value("access_token").(string) + if !ok { + log.Printf("cannot get access token from context") + return "" + } + return token +} \ No newline at end of file diff --git a/config/configuration.go b/config/configuration.go index 55da5ea867f74ddf835658c19d0be414a36c269e..0f921b0e44ab56b6c8af29758c7225acb4c41cdf 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -44,9 +44,10 @@ type RDGCapsConfig struct { } type SecurityConfig struct { - EnableOpenId bool - TokenSigningKey string - PassTokenAsPassword bool + PAATokenEncryptionKey string + PAATokenSigningKey string + UserTokenEncryptionKey string + UserTokenSigningKey string } type ClientConfig struct { @@ -82,7 +83,7 @@ func Load(configFile string) Configuration { log.Fatalf("Cannot unmarshal the config file; %s", err) } - if len(conf.Security.TokenSigningKey) < 32 { + if len(conf.Security.PAATokenSigningKey) < 32 { log.Fatalf("Token signing key not long enough") } diff --git a/go.mod b/go.mod index c19ccd6ce12145915c05fdd4b6444d9157d14f98..77902d2133a8add191e820e6893cf1a66295c00c 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,12 @@ go 1.14 require ( github.com/coreos/go-oidc/v3 v3.0.0-alpha.1 - github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 github.com/gorilla/sessions v1.2.0 github.com/gorilla/websocket v1.4.2 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.7.1 github.com/spf13/cobra v1.0.0 github.com/spf13/viper v1.7.0 + github.com/square/go-jose/v3 v3.0.0-20200630053402-0a67ce9b0693 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d ) diff --git a/main.go b/main.go index e89a29365d686844c27d3452522783c60108fe75..26c0d092fd08799f1ce74bf5e6f401f7c3df6a5a 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,10 @@ func main() { conf = config.Load(configFile) // set security keys - security.SigningKey = []byte(conf.Security.TokenSigningKey) + security.SigningKey = []byte(conf.Security.PAATokenSigningKey) + security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) + security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) + security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) // set oidc config ctx := context.Background() @@ -57,17 +60,18 @@ func main() { } api := &api.Config{ - GatewayAddress: conf.Server.GatewayAddress, - OAuth2Config: &oauthConfig, - TokenVerifier: verifier, - TokenGenerator: security.GeneratePAAToken, - SessionKey: []byte(conf.Server.SessionKey), + GatewayAddress: conf.Server.GatewayAddress, + OAuth2Config: &oauthConfig, + OIDCTokenVerifier: verifier, + PAATokenGenerator: security.GeneratePAAToken, + UserTokenGenerator: security.GenerateUserToken, + SessionKey: []byte(conf.Server.SessionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), - Hosts: conf.Server.Hosts, - NetworkAutoDetect: conf.Client.NetworkAutoDetect, - UsernameTemplate: conf.Client.UsernameTemplate, - BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, - ConnectionType: conf.Client.ConnectionType, + Hosts: conf.Server.Hosts, + NetworkAutoDetect: conf.Client.NetworkAutoDetect, + UsernameTemplate: conf.Client.UsernameTemplate, + BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, + ConnectionType: conf.Client.ConnectionType, } api.NewApi() diff --git a/security/jwt.go b/security/jwt.go index 4c119b8d6fea0077ac436d2538666618a572ac6e..97511624b0c6e829ab6fd805ade46474d0d75980 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -6,42 +6,64 @@ import ( "fmt" "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/protocol" - "github.com/dgrijalva/jwt-go/v4" + "github.com/square/go-jose/v3" + "github.com/square/go-jose/v3/jwt" "log" "time" ) -var SigningKey []byte +var ( + SigningKey []byte + EncryptionKey []byte + UserSigningKey []byte + UserEncryptionKey []byte +) + var ExpiryTime time.Duration = 5 type customClaims struct { RemoteServer string `json:"remoteServer"` - ClientIP string `json:"clientIp"` - jwt.StandardClaims + ClientIP string `json:"clientIp"` + AccessToken string `json:"accessToken"` } func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { - token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + token, err := jwt.ParseSigned(tokenString) + + // check if the signing algo matches what we expect + for _, header := range token.Headers { + if header.Algorithm != string(jose.HS256) { + return false, fmt.Errorf("unexpected signing method: %v", header.Algorithm) } + } - return SigningKey, nil - }) + standard := jwt.Claims{} + custom := customClaims{} + // Claims automagically checks the signature... + err = token.Claims(SigningKey, &standard, &custom) if err != nil { + log.Printf("token signature validation failed due to %s", err) return false, err } - if c, ok := token.Claims.(*customClaims); ok && token.Valid { - s := getSessionInfo(ctx) - s.RemoteServer = c.RemoteServer - s.ClientIp = c.ClientIP - return true, nil + // ...but doesn't check the expiry claim :/ + err = standard.Validate(jwt.Expected{ + Issuer: "rdpgw", + Time: time.Now(), + }) + + if err != nil { + log.Printf("token validation failed due to %s", err) + return false, err } - log.Printf("token validation failed: %s", err) - return false, err + s := getSessionInfo(ctx) + + s.RemoteServer = custom.RemoteServer + s.ClientIp = custom.ClientIP + + return true, nil } func VerifyServerFunc(ctx context.Context, host string) (bool, error) { @@ -68,32 +90,59 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri if len(SigningKey) < 32 { return "", errors.New("token signing key not long enough or not specified") } - - exp := &jwt.Time{ - Time: time.Now().Add(time.Minute * 5), + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil) + if err != nil { + log.Printf("Cannot obtain signer %s", err) + return "", err } - now := &jwt.Time{ - Time: time.Now(), + + standard := jwt.Claims{ + Issuer: "rdpgw", + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Subject: username, } - c := customClaims{ + private := customClaims{ RemoteServer: server, ClientIP: common.GetClientIp(ctx), - StandardClaims: jwt.StandardClaims{ - ExpiresAt: exp, - IssuedAt: now, - Issuer: "rdpgw", - Subject: username, - }, + AccessToken: common.GetAccessToken(ctx), } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, c) - if ss, err := token.SignedString(SigningKey); err != nil { + if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil { log.Printf("Cannot sign PAA token %s", err) return "", err } else { - return ss, nil + return token, nil + } +} + +func GenerateUserToken(ctx context.Context, userName string) (string, error) { + if len(UserEncryptionKey) < 32 { + return "", errors.New("user token encryption key not long enough or not specified") + } + + claims := jwt.Claims{ + Subject: userName, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Issuer: "rdpgw", } + + enc, err := jose.NewEncrypter( + jose.A128CBC_HS256, + jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey}, + (&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"), + ) + + if err != nil { + log.Printf("Cannot encrypt user token due to %s", err) + return "", err + } + + // this makes the token bigger and we deal with a limited space of 511 characters + // sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil) + // token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize() + token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize() + return token, err } func getSessionInfo(ctx context.Context) *protocol.SessionInfo { @@ -103,4 +152,4 @@ func getSessionInfo(ctx context.Context) *protocol.SessionInfo { return nil } return s -} \ No newline at end of file +}