From 8bc3e25f83cf074e76ac5303a1a086029b69aaad Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Wed, 17 Aug 2022 10:48:14 +0200 Subject: [PATCH] Allow host query parameter the host query parameter can now be used dependent on the `hostselection` config. --- README.md | 9 ++-- cmd/rdpgw/api/web.go | 86 ++++++++++++++++++++++++------- cmd/rdpgw/api/web_test.go | 58 +++++++++++++++++++++ cmd/rdpgw/config/configuration.go | 19 ++++--- cmd/rdpgw/main.go | 1 + 5 files changed, 142 insertions(+), 31 deletions(-) create mode 100644 cmd/rdpgw/api/web_test.go diff --git a/README.md b/README.md index 2f64d6a..79c6434 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,13 @@ Server: Hosts: - localhost:3389 - my-{{ preferred_username }}-host:3389 - # Allow the user to connect to any host (insecure) - - any # if true the server randomly selects a host to connect to - RoundRobin: false + # valid options are: + # - roundrobin, which selects a random host from the list (default) + # - signed, a listed host specified in the signed query parameter + # - unsigned, a listed host specified in the query parameter + # - any, insecurely allow any host specified in the query parameter + HostSelection: roundrobin # a random strings of at least 32 characters to secure cookies on the client # make sure to share this across the different pods SessionKey: thisisasessionkeyreplacethisjetzt diff --git a/cmd/rdpgw/api/web.go b/cmd/rdpgw/api/web.go index ac804e2..4522bdb 100644 --- a/cmd/rdpgw/api/web.go +++ b/cmd/rdpgw/api/web.go @@ -13,6 +13,7 @@ import ( "log" "math/rand" "net/http" + "net/url" "os" "strconv" "strings" @@ -21,7 +22,7 @@ import ( const ( RdpGwSession = "RDPGWSESSION" - MaxAge = 120 + MaxAge = 120 ) type TokenGeneratorFunc func(context.Context, string, string) (string, error) @@ -30,7 +31,7 @@ type UserTokenGeneratorFunc func(context.Context, string) (string, error) type Config struct { SessionKey []byte SessionEncryptionKey []byte - SessionStore string + SessionStore string PAATokenGenerator TokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc EnableUserToken bool @@ -39,13 +40,14 @@ type Config struct { OIDCTokenVerifier *oidc.IDTokenVerifier stateStore *cache.Cache Hosts []string + HostSelection string GatewayAddress string UsernameTemplate string NetworkAutoDetect int BandwidthAutoDetect int ConnectionType int - SplitUserDomain bool - DefaultDomain string + SplitUserDomain bool + DefaultDomain string } func (c *Config) NewApi() { @@ -151,6 +153,47 @@ func (c *Config) Authenticated(next http.Handler) http.Handler { }) } +func (c *Config) selectRandomHost() string { + rand.Seed(time.Now().Unix()) + host := c.Hosts[rand.Intn(len(c.Hosts))] + return host +} + +func (c *Config) getHost(u *url.URL) (string, error) { + var host string + switch c.HostSelection { + case "roundrobin": + host = c.selectRandomHost() + case "signed": + case "unsigned": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + found := false + for _, check := range c.Hosts { + if check == hosts[0] { + host = hosts[0] + found = true + break + } + } + if !found { + log.Printf("Invalid host %s specified in client request", hosts[0]) + return "", errors.New("invalid host specified in query parameter") + } + case "any": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + host = hosts[0] + default: + host = c.selectRandomHost() + } + return host, nil +} + func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { ctx := r.Context() userName, ok := ctx.Value("preferred_username").(string) @@ -161,9 +204,12 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { return } - // do a round robin selection for now - rand.Seed(time.Now().Unix()) - host := c.Hosts[rand.Intn(len(c.Hosts))] + // determine host to connect to + host, err := c.getHost(r.URL) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } host = strings.Replace(host, "{{ preferred_username }}", userName, 1) // split the username into user and domain @@ -210,19 +256,19 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Disposition", "attachment; filename="+fn) w.Header().Set("Content-Type", "application/x-rdp") - data := "full address:s:"+host+"\r\n"+ - "gatewayhostname:s:"+c.GatewayAddress+"\r\n"+ - "gatewaycredentialssource:i:5\r\n"+ - "gatewayusagemethod:i:1\r\n"+ - "gatewayprofileusagemethod:i:1\r\n"+ - "gatewayaccesstoken:s:"+token+"\r\n"+ - "networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+ - "bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+ - "connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+ - "username:s:"+render+"\r\n"+ - "domain:s:"+domain+"\r\n"+ - "bitmapcachesize:i:32000\r\n"+ - "smart sizing:i:1\r\n" + data := "full address:s:" + host + "\r\n" + + "gatewayhostname:s:" + c.GatewayAddress + "\r\n" + + "gatewaycredentialssource:i:5\r\n" + + "gatewayusagemethod:i:1\r\n" + + "gatewayprofileusagemethod:i:1\r\n" + + "gatewayaccesstoken:s:" + token + "\r\n" + + "networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" + + "bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" + + "connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" + + "username:s:" + render + "\r\n" + + "domain:s:" + domain + "\r\n" + + "bitmapcachesize:i:32000\r\n" + + "smart sizing:i:1\r\n" http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data)) } diff --git a/cmd/rdpgw/api/web_test.go b/cmd/rdpgw/api/web_test.go new file mode 100644 index 0000000..168111f --- /dev/null +++ b/cmd/rdpgw/api/web_test.go @@ -0,0 +1,58 @@ +package api + +import ( + "net/url" + "testing" +) + +var ( + hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"} +) + +func contains(needle string, haystack []string) bool { + for _, val := range haystack { + if val == needle { + return true + } + } + return false +} + +func TestGetHost(t *testing.T) { + c := Config{ + HostSelection: "roundrobin", + Hosts: hosts, + } + u := &url.URL{ + Host: "example.com", + } + vals := u.Query() + + host, err := c.getHost(u) + if err != nil { + t.Fatalf("#{err}") + } + if !contains(host, hosts) { + t.Fatalf("host %s is not in hosts list", host) + } + + // check unsigned + c.HostSelection = "unsigned" + vals.Set("host", "in.valid.host") + u.RawQuery = vals.Encode() + host, err = c.getHost(u) + if err == nil { + t.Fatalf("Accepted host %s is not in hosts list", host) + } + + vals.Set("host", hosts[0]) + u.RawQuery = vals.Encode() + host, err = c.getHost(u) + if err != nil { + t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) + } + if host != hosts[0] { + t.Fatalf("host %s is not equal to input %s", host, hosts[0]) + } + +} diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 9a485e9..10d4ea8 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -25,7 +25,7 @@ type ServerConfig struct { CertFile string `koanf:"certfile"` KeyFile string `koanf:"keyfile"` Hosts []string `koanf:"hosts"` - RoundRobin bool `koanf:"roundrobin"` + HostSelection string `koanf:"hostselection"` SessionKey string `koanf:"sessionkey"` SessionEncryptionKey string `koanf:"sessionencryptionkey"` SessionStore string `koanf:"sessionstore"` @@ -118,6 +118,7 @@ func Load(configFile string) Configuration { "Server.TlsDisabled": false, "Server.Port": 443, "Server.SessionStore": "cookie", + "Server.HostSelection": "roundrobin", "Client.NetworkAutoDetect": 1, "Client.BandwidthAutoDetect": 1, "Security.VerifyClientIp": true, @@ -153,14 +154,16 @@ func Load(configFile string) Configuration { log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") } - if len(Conf.Security.UserTokenEncryptionKey) != 32 { - Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") - } + if Conf.Security.EnableUserToken { + if len(Conf.Security.UserTokenEncryptionKey) != 32 { + Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") + } - if len(Conf.Security.UserTokenSigningKey) != 32 { - Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") + if len(Conf.Security.UserTokenSigningKey) != 32 { + Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") + } } if len(Conf.Server.SessionKey) != 32 { diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index d07bff7..436893f 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -79,6 +79,7 @@ func main() { SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), SessionStore: conf.Server.SessionStore, Hosts: conf.Server.Hosts, + HostSelection: conf.Server.HostSelection, NetworkAutoDetect: conf.Client.NetworkAutoDetect, UsernameTemplate: conf.Client.UsernameTemplate, BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, -- GitLab