From c6cfdc4dd40b34e6c5cc44ad8142f9202506159d Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Mon, 31 Aug 2020 21:07:58 +0200
Subject: [PATCH] Add support for splitting the username from the domain to
 enable smaller tokens

---
 README.md               |  3 +++
 api/web.go              | 29 +++++++++++++++++------------
 config/configuration.go |  1 +
 main.go                 |  1 +
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 590c25b..5d6fe7d 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,9 @@ client:
   networkAutoDetect: 0
   bandwidthAutoDetect: 1
   ConnectionType: 6
+  # If true puts splits "user@domain.com" into the user and domain component so that
+  # domain gets set in the rdp file and the domain name is stripped from the username
+  SplitUserDomain: false
 security:
   # a random string of at least 32 characters to secure cookies on the client
   # make sure to share this amongst different pods
diff --git a/api/web.go b/api/web.go
index 6e51195..8a35069 100644
--- a/api/web.go
+++ b/api/web.go
@@ -42,6 +42,7 @@ type Config struct {
 	NetworkAutoDetect    int
 	BandwidthAutoDetect  int
 	ConnectionType       int
+	SplitUserDomain		 bool
 }
 
 func (c *Config) NewApi() {
@@ -157,17 +158,23 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 	host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
 
 	// split the username into user and domain
-	creds := strings.SplitN(userName, "@", 2)
-	user := creds[0]
+	var user string
 	var domain string
-	if len(creds) > 1 {
-		domain = creds[1]
+	if c.SplitUserDomain {
+		creds := strings.SplitN(userName, "@", 2)
+		user = creds[0]
+		if len(creds) > 1 {
+			domain = creds[1]
+		}
+	} else {
+		user = userName
 	}
 
+	render := user
 	if c.UsernameTemplate != "" {
-		c.UsernameTemplate = fmt.Sprintf(c.UsernameTemplate)
-		user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
-		if c.UsernameTemplate == user {
+		render = fmt.Sprintf(c.UsernameTemplate)
+		render = strings.Replace(render, "{{ username }}", user, 1)
+		if c.UsernameTemplate == render {
 			log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
 			http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
 			return
@@ -180,17 +187,15 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
 	}
 
-	userToken := user
 	if c.EnableUserToken {
-		userToken, err = c.UserTokenGenerator(ctx, user)
+		userToken, err := c.UserTokenGenerator(ctx, user)
 		if err != nil {
 			log.Printf("Cannot generate token for user %s due to %s", user, err)
 			http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
 		}
+		render = strings.Replace(render, "{{ token }}", userToken, 1)
 	}
 
-	user = strings.Replace(user,"{{ token }}", userToken, 1)
-
 	// authenticated
 	seed := make([]byte, 16)
 	rand.Read(seed)
@@ -207,7 +212,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 		"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
 		"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
 		"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
-		"username:s:"+user+"\r\n"+
+		"username:s:"+render+"\r\n"+
 		"domain:s:"+domain+"\r\n"+
 		"bitmapcachesize:i:32000\r\n"
 
diff --git a/config/configuration.go b/config/configuration.go
index 1a6c4e8..03de469 100644
--- a/config/configuration.go
+++ b/config/configuration.go
@@ -57,6 +57,7 @@ type ClientConfig struct {
 	BandwidthAutoDetect int
 	ConnectionType      int
 	UsernameTemplate    string
+	SplitUserDomain     bool
 }
 
 func init() {
diff --git a/main.go b/main.go
index 226bef2..733e555 100644
--- a/main.go
+++ b/main.go
@@ -76,6 +76,7 @@ func main() {
 		UsernameTemplate:     conf.Client.UsernameTemplate,
 		BandwidthAutoDetect:  conf.Client.BandwidthAutoDetect,
 		ConnectionType:       conf.Client.ConnectionType,
+		SplitUserDomain:      conf.Client.SplitUserDomain,
 	}
 	api.NewApi()
 
-- 
GitLab