From 46e1e9b9f43c0b054c96a40d8563b921926db852 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Fri, 24 Jul 2020 11:54:43 +0200 Subject: [PATCH] Switch to jwt tokens and allow some extra rdp settings --- README.md | 12 +++++++ api/web.go | 77 +++++++++++++++++++++++------------------ config/configuration.go | 34 ++++++++++++++---- go.mod | 1 + main.go | 19 +++++----- security/jwt.go | 70 +++++++++++++++++++++++++++++++++++++ security/simple.go | 21 ----------- 7 files changed, 163 insertions(+), 71 deletions(-) create mode 100644 security/jwt.go delete mode 100644 security/simple.go diff --git a/README.md b/README.md index 6485247..9e8e425 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ server: # if true the server randomly selects a host to connect to roundRobin: false # a random string of at least 32 characters to secure cookies on the client + # make sure to share this across the different pods sessionKey: thisisasessionkeyreplacethisjetzt # Open ID Connect specific settings openId: @@ -60,6 +61,17 @@ caps: enablePnp: true enableDrive: true enableClipboard: true +client: + usernameTemplate: "{{ username }}@bla.com" + # rdp file settings see: + # https://docs.microsoft.com/en-us/windows-server/remote/remote-desktop-services/clients/rdp-files + networkAutoDetect: 0 + bandwidthAutoDetect: 1 + ConnectionType: 6 +security: + # a random string of at least 32 characters to secure cookies on the client + # make sure to share this amongst different pods + tokenSigningKey: thisisasessionkeyreplacethisjetzt ``` ## Use diff --git a/api/web.go b/api/web.go index 181a2ae..27abe64 100644 --- a/api/web.go +++ b/api/web.go @@ -12,24 +12,30 @@ import ( "log" "math/rand" "net/http" + "strconv" "strings" "time" ) const ( RdpGwSession = "RDPGWSESSION" - PAAToken = "PAAToken" ) +type TokenGeneratorFunc func(string, string) (string, error) + type Config struct { - SessionKey []byte - TokenCache *cache.Cache - OAuth2Config *oauth2.Config - store *sessions.CookieStore - TokenVerifier *oidc.IDTokenVerifier - stateStore *cache.Cache - Hosts []string - GatewayAddress string + SessionKey []byte + TokenGenerator TokenGeneratorFunc + OAuth2Config *oauth2.Config + store *sessions.CookieStore + TokenVerifier *oidc.IDTokenVerifier + stateStore *cache.Cache + Hosts []string + GatewayAddress string + UsernameTemplate string + NetworkAutoDetect int + BandwidthAutoDetect int + ConnectionType int } func (c *Config) NewApi() { @@ -86,22 +92,18 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) { return } - seed := make([]byte, 16) - rand.Read(seed) - token := hex.EncodeToString(seed) - session, err := c.store.Get(r, RdpGwSession) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - session.Values[PAAToken] = token + session.Values["preferred_username"] = data["preferred_username"] + session.Values["authenticated"] = true if err = session.Save(r, w); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - c.TokenCache.Set(token, data, cache.DefaultExpiration) http.Redirect(w, r, url, http.StatusFound) } @@ -114,13 +116,8 @@ func (c *Config) Authenticated(next http.Handler) http.Handler { return } - found := false - token := session.Values[PAAToken] - if token != nil { - _, found = c.TokenCache.Get(token.(string)) - } - - if !found { + found := session.Values["authenticated"] + if found == nil || !found.(bool) { seed := make([]byte, 16) rand.Read(seed) state := hex.EncodeToString(seed) @@ -140,24 +137,35 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { return } - token := session.Values[PAAToken].(string) - data, found := c.TokenCache.Get(token) - if found == false { + userName := session.Values["preferred_username"] + if userName == nil || userName.(string) == "" { // This shouldnt happen if the Authenticated handler is used to wrap this func - log.Printf("Found expired or non existent session: %s", token) - http.Error(w, errors.New("cannot find token").Error(), http.StatusInternalServerError) + log.Printf("Found expired or non existent session") + http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError) return } // do a round robin selection for now rand.Seed(time.Now().Unix()) - var host = c.Hosts[rand.Intn(len(c.Hosts))] - for k, v := range data.(map[string]interface{}) { - if val, ok := v.(string); ok == true { - host = strings.Replace(host, "{{ "+k+" }}", val, 1) + host := c.Hosts[rand.Intn(len(c.Hosts))] + host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1) + + user := userName.(string) + if c.UsernameTemplate != "" { + user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1) + if c.UsernameTemplate == user { + log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user) + http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError) + return } } + token, err := c.TokenGenerator(user, host) + if err != nil { + log.Printf("Cannot generate token for user %s due to %s", user, err) + http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) + } + // authenticated seed := make([]byte, 16) rand.Read(seed) @@ -172,7 +180,8 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { "gatewayusagemethod:i:1\r\n"+ "gatewayprofileusagemethod:i:1\r\n"+ "gatewayaccesstoken:s:"+token+"\r\n"+ - "networkautodetect:i:0\r\n"+ - "bandwidthautodetect:i:1\r\n"+ - "connection type:i:6\r\n")) + "networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+ + "bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+ + "connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+ + "username:s:"+user+"\r\n")) } diff --git a/config/configuration.go b/config/configuration.go index 45271ba..db75289 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -6,9 +6,11 @@ import ( ) type Configuration struct { - Server ServerConfig - OpenId OpenIDConfig - Caps RDGCapsConfig + Server ServerConfig + OpenId OpenIDConfig + Caps RDGCapsConfig + Security SecurityConfig + Client ClientConfig } type ServerConfig struct { @@ -17,8 +19,8 @@ type ServerConfig struct { CertFile string KeyFile string Hosts []string - RoundRobin bool - SessionKey string + RoundRobin bool + SessionKey string } type OpenIDConfig struct { @@ -40,10 +42,26 @@ type RDGCapsConfig struct { EnableDrive bool } +type SecurityConfig struct { + EnableOpenId bool + TokenSigningKey string + PassTokenAsPassword bool +} + +type ClientConfig struct { + NetworkAutoDetect int + BandwidthAutoDetect int + ConnectionType int + UsernameTemplate string +} + func init() { viper.SetDefault("server.certFile", "server.pem") viper.SetDefault("server.keyFile", "key.pem") viper.SetDefault("server.port", 443) + viper.SetDefault("security.enableOpenId", true) + viper.SetDefault("client.networkAutoDetect", 1) + viper.SetDefault("client.bandwidthAutoDetect", 1) } func Load(configFile string) Configuration { @@ -56,12 +74,16 @@ func Load(configFile string) Configuration { viper.AutomaticEnv() if err := viper.ReadInConfig(); err != nil { - log.Printf("No config file found (%s). Using defaults", err) + log.Fatalf("No config file found (%s)", err) } if err := viper.Unmarshal(&conf); err != nil { log.Fatalf("Cannot unmarshal the config file; %s", err) } + if len(conf.Security.TokenSigningKey) < 32 { + log.Fatalf("Token signing key not long enough") + } + return conf } diff --git a/go.mod b/go.mod index 74c06d0..c19ccd6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/coreos/go-oidc/v3 v3.0.0-alpha.1 + github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 github.com/gorilla/sessions v1.2.0 github.com/gorilla/websocket v1.4.2 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/main.go b/main.go index ea66074..9cc8381 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/security" "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "golang.org/x/oauth2" @@ -16,7 +15,6 @@ import ( "net/http" "os" "strconv" - "time" ) var cmd = &cobra.Command{ @@ -28,7 +26,6 @@ var ( configFile string ) -var tokens = cache.New(time.Minute *5, 10*time.Minute) var conf config.Configuration func main() { @@ -36,6 +33,9 @@ func main() { cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") conf = config.Load(configFile) + // set security keys + security.SigningKey = []byte(conf.Security.TokenSigningKey) + // set oidc config ctx := context.Background() provider, err := oidc.NewProvider(ctx, conf.OpenId.ProviderUrl) @@ -59,9 +59,13 @@ func main() { GatewayAddress: conf.Server.GatewayAddress, OAuth2Config: &oauthConfig, TokenVerifier: verifier, - TokenCache: tokens, + TokenGenerator: security.GeneratePAAToken, SessionKey: []byte(conf.Server.SessionKey), Hosts: conf.Server.Hosts, + NetworkAutoDetect: conf.Client.NetworkAutoDetect, + UsernameTemplate: conf.Client.UsernameTemplate, + BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, + ConnectionType: conf.Client.ConnectionType, } api.NewApi() @@ -96,11 +100,6 @@ func main() { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } - // setup security - securityConfig := &security.Config{ - Store: tokens, - } - // create the gateway handlerConfig := protocol.HandlerConf{ IdleTimeout: conf.Caps.IdleTimeout, @@ -115,7 +114,7 @@ func main() { DisableAll: conf.Caps.DisableRedirect, EnableAll: conf.Caps.RedirectAll, }, - VerifyTunnelCreate: securityConfig.VerifyPAAToken, + VerifyTunnelCreate: security.VerifyPAAToken, } gw := protocol.Gateway{ HandlerConf: &handlerConfig, diff --git a/security/jwt.go b/security/jwt.go new file mode 100644 index 0000000..e083009 --- /dev/null +++ b/security/jwt.go @@ -0,0 +1,70 @@ +package security + +import ( + "errors" + "fmt" + "github.com/bolkedebruin/rdpgw/protocol" + "github.com/dgrijalva/jwt-go/v4" + "log" + "time" +) + +var SigningKey []byte +var ExpiryTime time.Duration = 5 + +type customClaims struct { + RemoteServer string `json:"remoteServer"` + jwt.StandardClaims +} + +func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) { + token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return SigningKey, nil + }) + + if err != nil { + return false, err + } + + if _, ok := token.Claims.(*customClaims); ok && token.Valid { + return true, nil + } + + log.Printf("token validation failed: %s", err) + return false, err +} + +func GeneratePAAToken(username string, server string) (string, error) { + if len(SigningKey) < 32 { + return "", errors.New("token signing key not long enough or not specified") + } + + exp := &jwt.Time{ + Time: time.Now().Add(time.Minute * 5), + } + now := &jwt.Time{ + Time: time.Now(), + } + + c := customClaims{ + RemoteServer: server, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: exp, + IssuedAt: now, + Issuer: "rdpgw", + Subject: username, + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS512, c) + if ss, err := token.SignedString(SigningKey); err != nil { + log.Printf("Cannot sign PAA token %s", err) + return "", err + } else { + return ss, nil + } +} \ No newline at end of file diff --git a/security/simple.go b/security/simple.go deleted file mode 100644 index 2ab4260..0000000 --- a/security/simple.go +++ /dev/null @@ -1,21 +0,0 @@ -package security - -import ( - "github.com/bolkedebruin/rdpgw/protocol" - "github.com/patrickmn/go-cache" - "log" -) - -type Config struct { - Store *cache.Cache -} - -func (c *Config) VerifyPAAToken(s *protocol.SessionInfo, token string) (bool, error) { - _, found := c.Store.Get(token) - if !found { - log.Printf("PAA Token %s not found", token) - return false, nil - } - - return true, nil -} \ No newline at end of file -- GitLab