From cb8b26947828c21475b68983cfce7fd211d42b21 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Wed, 17 Aug 2022 19:12:28 +0200 Subject: [PATCH] Enable signed hosts provied in query parameters --- cmd/rdpgw/api/web.go | 41 +++++++++++++++------- cmd/rdpgw/api/web_test.go | 41 ++++++++++++++++++++-- cmd/rdpgw/config/configuration.go | 6 ++++ cmd/rdpgw/main.go | 3 ++ cmd/rdpgw/security/jwt.go | 56 +++++++++++++++++++++++++++++++ 5 files changed, 132 insertions(+), 15 deletions(-) diff --git a/cmd/rdpgw/api/web.go b/cmd/rdpgw/api/web.go index 4522bdb..acf1720 100644 --- a/cmd/rdpgw/api/web.go +++ b/cmd/rdpgw/api/web.go @@ -27,6 +27,7 @@ const ( type TokenGeneratorFunc func(context.Context, string, string) (string, error) type UserTokenGeneratorFunc func(context.Context, string) (string, error) +type QueryInfoFunc func(context.Context, string, string) (string, error) type Config struct { SessionKey []byte @@ -34,6 +35,8 @@ type Config struct { SessionStore string PAATokenGenerator TokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc + QueryInfo QueryInfoFunc + QueryTokenIssuer string EnableUserToken bool OAuth2Config *oauth2.Config store sessions.Store @@ -159,39 +162,53 @@ func (c *Config) selectRandomHost() string { return host } -func (c *Config) getHost(u *url.URL) (string, error) { - var host string +func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) { switch c.HostSelection { case "roundrobin": - host = c.selectRandomHost() + return c.selectRandomHost(), nil case "signed": - case "unsigned": hosts, ok := u.Query()["host"] if !ok { return "", errors.New("invalid query parameter") } + host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer) + if err != nil { + return "", err + } found := false for _, check := range c.Hosts { - if check == hosts[0] { - host = hosts[0] + if check == host { found = true break } } if !found { - log.Printf("Invalid host %s specified in client request", hosts[0]) - return "", errors.New("invalid host specified in query parameter") + log.Printf("Invalid host %s specified in token", hosts[0]) + return "", errors.New("invalid host specified in query token") + } + return host, nil + case "unsigned": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + for _, check := range c.Hosts { + if check == hosts[0] { + return hosts[0], nil + } } + // not found + log.Printf("Invalid host %s specified in client request", hosts[0]) + return "", errors.New("invalid host specified in query parameter") case "any": hosts, ok := u.Query()["host"] if !ok { return "", errors.New("invalid query parameter") } - host = hosts[0] + return hosts[0], nil default: - host = c.selectRandomHost() + return c.selectRandomHost(), nil } - return host, nil } func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { @@ -205,7 +222,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { } // determine host to connect to - host, err := c.getHost(r.URL) + host, err := c.getHost(ctx, r.URL) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/cmd/rdpgw/api/web_test.go b/cmd/rdpgw/api/web_test.go index 168111f..c1e4e5c 100644 --- a/cmd/rdpgw/api/web_test.go +++ b/cmd/rdpgw/api/web_test.go @@ -1,12 +1,15 @@ package api import ( + "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "net/url" "testing" ) var ( hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"} + key = []byte("thisisasessionkeyreplacethisjetzt") ) func contains(needle string, haystack []string) bool { @@ -19,6 +22,7 @@ func contains(needle string, haystack []string) bool { } func TestGetHost(t *testing.T) { + ctx := context.Background() c := Config{ HostSelection: "roundrobin", Hosts: hosts, @@ -28,7 +32,7 @@ func TestGetHost(t *testing.T) { } vals := u.Query() - host, err := c.getHost(u) + host, err := c.getHost(ctx, u) if err != nil { t.Fatalf("#{err}") } @@ -40,14 +44,14 @@ func TestGetHost(t *testing.T) { c.HostSelection = "unsigned" vals.Set("host", "in.valid.host") u.RawQuery = vals.Encode() - host, err = c.getHost(u) + host, err = c.getHost(ctx, u) if err == nil { t.Fatalf("Accepted host %s is not in hosts list", host) } vals.Set("host", hosts[0]) u.RawQuery = vals.Encode() - host, err = c.getHost(u) + host, err = c.getHost(ctx, u) if err != nil { t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) } @@ -55,4 +59,35 @@ func TestGetHost(t *testing.T) { t.Fatalf("host %s is not equal to input %s", host, hosts[0]) } + // check any + c.HostSelection = "any" + test := "bla.bla.com" + vals.Set("host", test) + u.RawQuery = vals.Encode() + host, err = c.getHost(ctx, u) + if err != nil { + t.Fatalf("%s is not accepted", host) + } + if test != host { + t.Fatalf("Returned host %s is not equal to input host %s", host, test) + } + + // check signed + c.HostSelection = "signed" + c.QueryInfo = security.QueryInfo + issuer := "rdpgwtest" + security.QuerySigningKey = key + queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer) + if err != nil { + t.Fatalf("cannot generate token") + } + vals.Set("host", queryToken) + u.RawQuery = vals.Encode() + host, err = c.getHost(ctx, u) + if err != nil { + t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) + } + if host != hosts[0] { + t.Fatalf("%s does not equal %s", host, hosts[0]) + } } diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 10d4ea8..43c9445 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -58,6 +58,8 @@ type SecurityConfig struct { PAATokenSigningKey string `koanf:"paatokensigningkey"` UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"` UserTokenSigningKey string `koanf:"usertokensigningkey"` + QueryTokenSigningKey string `koanf:"querytokensigningkey"` + QueryTokenIssuer string `koanf:"querytokenissuer"` VerifyClientIp bool `koanf:"verifyclientip"` EnableUserToken bool `koanf:"enableusertoken"` } @@ -176,6 +178,10 @@ func Load(configFile string) Configuration { log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random") } + if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 { + log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") + } + return Conf } diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 436893f..567c677 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -40,6 +40,7 @@ func main() { security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) + security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey) // set oidc config provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) @@ -74,6 +75,8 @@ func main() { OIDCTokenVerifier: verifier, PAATokenGenerator: security.GeneratePAAToken, UserTokenGenerator: security.GenerateUserToken, + QueryInfo: security.QueryInfo, + QueryTokenIssuer: conf.Security.QueryTokenIssuer, EnableUserToken: conf.Security.EnableUserToken, SessionKey: []byte(conf.Server.SessionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index cc3fb5a..8deef42 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -19,6 +19,7 @@ var ( EncryptionKey []byte UserSigningKey []byte UserEncryptionKey []byte + QuerySigningKey []byte OIDCProvider *oidc.Provider Oauth2Config oauth2.Config ) @@ -221,6 +222,61 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) { return standard, nil } +func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) { + standard := jwt.Claims{} + token, err := jwt.ParseSigned(tokenString) + if err != nil { + log.Printf("Cannot get token %s", err) + return "", errors.New("cannot get token") + } + if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil { + log.Printf("signature validation failure: %s", err) + return "", errors.New("signature validation failure") + } + err = token.Claims(QuerySigningKey, &standard) + if err = token.Claims(QuerySigningKey, &standard); err != nil { + log.Printf("cannot verify signature %s", err) + return "", errors.New("cannot verify signature") + } + + // go-jose doesnt verify the expiry + err = standard.Validate(jwt.Expected{ + Issuer: issuer, + Time: time.Now(), + }) + + if err != nil { + log.Printf("token validation failed due to %s", err) + return "", fmt.Errorf("token validation failed due to %s", err) + } + + return standard.Subject, nil +} + +// GenerateQueryToken this is a helper function for testing +func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) { + if len(QuerySigningKey) < 32 { + return "", errors.New("query token encryption key not long enough or not specified") + } + + claims := jwt.Claims{ + Subject: query, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)), + Issuer: issuer, + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey}, + (&jose.SignerOptions{}).WithBase64(true)) + + if err != nil { + log.Printf("Cannot encrypt user token due to %s", err) + return "", err + } + + token, err := jwt.Signed(sig).Claims(claims).CompactSerialize() + return token, err +} + func getSessionInfo(ctx context.Context) *protocol.SessionInfo { s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) if !ok { -- GitLab