From 8bc3e25f83cf074e76ac5303a1a086029b69aaad Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Wed, 17 Aug 2022 10:48:14 +0200
Subject: [PATCH] Allow host query parameter

the host query parameter can now be used
dependent on the `hostselection` config.
---
 README.md                         |  9 ++--
 cmd/rdpgw/api/web.go              | 86 ++++++++++++++++++++++++-------
 cmd/rdpgw/api/web_test.go         | 58 +++++++++++++++++++++
 cmd/rdpgw/config/configuration.go | 19 ++++---
 cmd/rdpgw/main.go                 |  1 +
 5 files changed, 142 insertions(+), 31 deletions(-)
 create mode 100644 cmd/rdpgw/api/web_test.go

diff --git a/README.md b/README.md
index 2f64d6a..79c6434 100644
--- a/README.md
+++ b/README.md
@@ -59,10 +59,13 @@ Server:
  Hosts:
   - localhost:3389
   - my-{{ preferred_username }}-host:3389
-  # Allow the user to connect to any host (insecure)
-  - any 
  # if true the server randomly selects a host to connect to
- RoundRobin: false 
+ # valid options are: 
+ #  - roundrobin, which selects a random host from the list (default)
+ #  - signed, a listed host specified in the signed query parameter
+ #  - unsigned, a listed host specified in the query parameter
+ #  - any, insecurely allow any host specified in the query parameter
+ HostSelection: roundrobin 
  # a random strings of at least 32 characters to secure cookies on the client
  # make sure to share this across the different pods
  SessionKey: thisisasessionkeyreplacethisjetzt
diff --git a/cmd/rdpgw/api/web.go b/cmd/rdpgw/api/web.go
index ac804e2..4522bdb 100644
--- a/cmd/rdpgw/api/web.go
+++ b/cmd/rdpgw/api/web.go
@@ -13,6 +13,7 @@ import (
 	"log"
 	"math/rand"
 	"net/http"
+	"net/url"
 	"os"
 	"strconv"
 	"strings"
@@ -21,7 +22,7 @@ import (
 
 const (
 	RdpGwSession = "RDPGWSESSION"
-	MaxAge 		 = 120
+	MaxAge       = 120
 )
 
 type TokenGeneratorFunc func(context.Context, string, string) (string, error)
@@ -30,7 +31,7 @@ type UserTokenGeneratorFunc func(context.Context, string) (string, error)
 type Config struct {
 	SessionKey           []byte
 	SessionEncryptionKey []byte
-	SessionStore		 string
+	SessionStore         string
 	PAATokenGenerator    TokenGeneratorFunc
 	UserTokenGenerator   UserTokenGeneratorFunc
 	EnableUserToken      bool
@@ -39,13 +40,14 @@ type Config struct {
 	OIDCTokenVerifier    *oidc.IDTokenVerifier
 	stateStore           *cache.Cache
 	Hosts                []string
+	HostSelection        string
 	GatewayAddress       string
 	UsernameTemplate     string
 	NetworkAutoDetect    int
 	BandwidthAutoDetect  int
 	ConnectionType       int
-	SplitUserDomain		 bool
-	DefaultDomain		 string
+	SplitUserDomain      bool
+	DefaultDomain        string
 }
 
 func (c *Config) NewApi() {
@@ -151,6 +153,47 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
 	})
 }
 
+func (c *Config) selectRandomHost() string {
+	rand.Seed(time.Now().Unix())
+	host := c.Hosts[rand.Intn(len(c.Hosts))]
+	return host
+}
+
+func (c *Config) getHost(u *url.URL) (string, error) {
+	var host string
+	switch c.HostSelection {
+	case "roundrobin":
+		host = c.selectRandomHost()
+	case "signed":
+	case "unsigned":
+		hosts, ok := u.Query()["host"]
+		if !ok {
+			return "", errors.New("invalid query parameter")
+		}
+		found := false
+		for _, check := range c.Hosts {
+			if check == hosts[0] {
+				host = hosts[0]
+				found = true
+				break
+			}
+		}
+		if !found {
+			log.Printf("Invalid host %s specified in client request", hosts[0])
+			return "", errors.New("invalid host specified in query parameter")
+		}
+	case "any":
+		hosts, ok := u.Query()["host"]
+		if !ok {
+			return "", errors.New("invalid query parameter")
+		}
+		host = hosts[0]
+	default:
+		host = c.selectRandomHost()
+	}
+	return host, nil
+}
+
 func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 	userName, ok := ctx.Value("preferred_username").(string)
@@ -161,9 +204,12 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// do a round robin selection for now
-	rand.Seed(time.Now().Unix())
-	host := c.Hosts[rand.Intn(len(c.Hosts))]
+	// determine host to connect to
+	host, err := c.getHost(r.URL)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
 	host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
 
 	// split the username into user and domain
@@ -210,19 +256,19 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Disposition", "attachment; filename="+fn)
 	w.Header().Set("Content-Type", "application/x-rdp")
-	data := "full address:s:"+host+"\r\n"+
-		"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+
-		"gatewaycredentialssource:i:5\r\n"+
-		"gatewayusagemethod:i:1\r\n"+
-		"gatewayprofileusagemethod:i:1\r\n"+
-		"gatewayaccesstoken:s:"+token+"\r\n"+
-		"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
-		"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
-		"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
-		"username:s:"+render+"\r\n"+
-		"domain:s:"+domain+"\r\n"+
-		"bitmapcachesize:i:32000\r\n"+
-	        "smart sizing:i:1\r\n"
+	data := "full address:s:" + host + "\r\n" +
+		"gatewayhostname:s:" + c.GatewayAddress + "\r\n" +
+		"gatewaycredentialssource:i:5\r\n" +
+		"gatewayusagemethod:i:1\r\n" +
+		"gatewayprofileusagemethod:i:1\r\n" +
+		"gatewayaccesstoken:s:" + token + "\r\n" +
+		"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
+		"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
+		"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
+		"username:s:" + render + "\r\n" +
+		"domain:s:" + domain + "\r\n" +
+		"bitmapcachesize:i:32000\r\n" +
+		"smart sizing:i:1\r\n"
 
 	http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
 }
diff --git a/cmd/rdpgw/api/web_test.go b/cmd/rdpgw/api/web_test.go
new file mode 100644
index 0000000..168111f
--- /dev/null
+++ b/cmd/rdpgw/api/web_test.go
@@ -0,0 +1,58 @@
+package api
+
+import (
+	"net/url"
+	"testing"
+)
+
+var (
+	hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
+)
+
+func contains(needle string, haystack []string) bool {
+	for _, val := range haystack {
+		if val == needle {
+			return true
+		}
+	}
+	return false
+}
+
+func TestGetHost(t *testing.T) {
+	c := Config{
+		HostSelection: "roundrobin",
+		Hosts:         hosts,
+	}
+	u := &url.URL{
+		Host: "example.com",
+	}
+	vals := u.Query()
+
+	host, err := c.getHost(u)
+	if err != nil {
+		t.Fatalf("#{err}")
+	}
+	if !contains(host, hosts) {
+		t.Fatalf("host %s is not in hosts list", host)
+	}
+
+	// check unsigned
+	c.HostSelection = "unsigned"
+	vals.Set("host", "in.valid.host")
+	u.RawQuery = vals.Encode()
+	host, err = c.getHost(u)
+	if err == nil {
+		t.Fatalf("Accepted host %s is not in hosts list", host)
+	}
+
+	vals.Set("host", hosts[0])
+	u.RawQuery = vals.Encode()
+	host, err = c.getHost(u)
+	if err != nil {
+		t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
+	}
+	if host != hosts[0] {
+		t.Fatalf("host %s is not equal to input %s", host, hosts[0])
+	}
+
+}
diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go
index 9a485e9..10d4ea8 100644
--- a/cmd/rdpgw/config/configuration.go
+++ b/cmd/rdpgw/config/configuration.go
@@ -25,7 +25,7 @@ type ServerConfig struct {
 	CertFile             string   `koanf:"certfile"`
 	KeyFile              string   `koanf:"keyfile"`
 	Hosts                []string `koanf:"hosts"`
-	RoundRobin           bool     `koanf:"roundrobin"`
+	HostSelection        string   `koanf:"hostselection"`
 	SessionKey           string   `koanf:"sessionkey"`
 	SessionEncryptionKey string   `koanf:"sessionencryptionkey"`
 	SessionStore         string   `koanf:"sessionstore"`
@@ -118,6 +118,7 @@ func Load(configFile string) Configuration {
 		"Server.TlsDisabled":         false,
 		"Server.Port":                443,
 		"Server.SessionStore":        "cookie",
+		"Server.HostSelection":       "roundrobin",
 		"Client.NetworkAutoDetect":   1,
 		"Client.BandwidthAutoDetect": 1,
 		"Security.VerifyClientIp":    true,
@@ -153,14 +154,16 @@ func Load(configFile string) Configuration {
 		log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random")
 	}
 
-	if len(Conf.Security.UserTokenEncryptionKey) != 32 {
-		Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
-		log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
-	}
+	if Conf.Security.EnableUserToken {
+		if len(Conf.Security.UserTokenEncryptionKey) != 32 {
+			Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
+			log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
+		}
 
-	if len(Conf.Security.UserTokenSigningKey) != 32 {
-		Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
-		log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
+		if len(Conf.Security.UserTokenSigningKey) != 32 {
+			Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
+			log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
+		}
 	}
 
 	if len(Conf.Server.SessionKey) != 32 {
diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index d07bff7..436893f 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -79,6 +79,7 @@ func main() {
 		SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
 		SessionStore:         conf.Server.SessionStore,
 		Hosts:                conf.Server.Hosts,
+		HostSelection:        conf.Server.HostSelection,
 		NetworkAutoDetect:    conf.Client.NetworkAutoDetect,
 		UsernameTemplate:     conf.Client.UsernameTemplate,
 		BandwidthAutoDetect:  conf.Client.BandwidthAutoDetect,
-- 
GitLab