From f72613c2baf7329055e0338f0f2d799cb4a55765 Mon Sep 17 00:00:00 2001
From: ryanblenis <ryan.blenis@gmail.com>
Date: Sat, 16 Dec 2023 15:07:37 -0500
Subject: [PATCH] Add BasicAuthTimeout setting versus static 5 seconds (#90)

---
 README.md                         | 2 ++
 cmd/rdpgw/config/configuration.go | 2 ++
 cmd/rdpgw/main.go                 | 2 +-
 cmd/rdpgw/web/basic.go            | 3 ++-
 4 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 74ac10f..9f3bd29 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,8 @@ Server:
  # The socket to connect to if using local auth. Ensure rdpgw auth is configured to
  # use the same socket.
  AuthSocket: /tmp/rdpgw-auth.sock
+ # Basic auth timeout (in seconds). Useful if you're planning on waiting for MFA
+ BasicAuthTimeout: 5
  # The default option 'auto' uses a certificate file if provided and found otherwise
  # it uses letsencrypt to obtain a certificate, the latter requires that the host is reachable
  # from letsencrypt servers. If TLS termination happens somewhere else (e.g. a load balancer)
diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go
index 525158b..52a3018 100644
--- a/cmd/rdpgw/config/configuration.go
+++ b/cmd/rdpgw/config/configuration.go
@@ -51,6 +51,7 @@ type ServerConfig struct {
 	Tls                  string   `koanf:"tls"`
 	Authentication       []string `koanf:"authentication"`
 	AuthSocket           string   `koanf:"authsocket"`
+	BasicAuthTimeout     int      `koanf:"basicauthtimeout"`
 }
 
 type KerberosConfig struct {
@@ -143,6 +144,7 @@ func Load(configFile string) Configuration {
 		"Server.HostSelection":       "roundrobin",
 		"Server.Authentication":      "openid",
 		"Server.AuthSocket":          "/tmp/rdpgw-auth.sock",
+		"Server.BasicAuthTimeout":    5,
 		"Client.NetworkAutoDetect":   1,
 		"Client.BandwidthAutoDetect": 1,
 		"Security.VerifyClientIp":    true,
diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index bf44b7b..8b8892f 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -232,7 +232,7 @@ func main() {
 	// basic auth
 	if conf.Server.BasicAuthEnabled() {
 		log.Printf("enabling basic authentication")
-		q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
+		q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout}
 		rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
 		auth.Register(`Basic realm="restricted", charset="UTF-8"`)
 	}
diff --git a/cmd/rdpgw/web/basic.go b/cmd/rdpgw/web/basic.go
index 84724e3..9f829f6 100644
--- a/cmd/rdpgw/web/basic.go
+++ b/cmd/rdpgw/web/basic.go
@@ -18,6 +18,7 @@ const (
 
 type BasicAuthHandler struct {
 	SocketAddress string
+	Timeout       int
 }
 
 func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
@@ -38,7 +39,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
 			defer conn.Close()
 
 			c := auth.NewAuthenticateClient(conn)
-			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(h.Timeout))
 			defer cancel()
 
 			req := &auth.UserPass{Username: username, Password: password}
-- 
GitLab