From 9d2dc57e9009e7ea6677293ca59c9e966ad49886 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Thu, 25 Aug 2022 11:22:23 +0200
Subject: [PATCH] Check valid host from list

---
 cmd/rdpgw/main.go            | 12 ++++++++----
 cmd/rdpgw/protocol/server.go |  1 +
 cmd/rdpgw/security/basic.go  | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 45 insertions(+), 4 deletions(-)
 create mode 100644 cmd/rdpgw/security/basic.go

diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index 96ce034..e89e77d 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -41,6 +41,8 @@ func main() {
 	security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
 	security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
 	security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
+	security.HostSelection = conf.Server.HostSelection
+	security.Hosts = conf.Server.Hosts
 
 	// configure api
 	api := &api.Config{
@@ -136,7 +138,7 @@ func main() {
 	}
 
 	// create the gateway
-	handlerConfig := protocol.ServerConf{
+	gwConfig := protocol.ServerConf{
 		IdleTimeout:   conf.Caps.IdleTimeout,
 		TokenAuth:     conf.Caps.TokenAuth,
 		SmartCardAuth: conf.Caps.SmartCardAuth,
@@ -153,11 +155,13 @@ func main() {
 		ReceiveBuf: conf.Server.ReceiveBuf,
 	}
 	if conf.Caps.TokenAuth {
-		handlerConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken
-		handlerConfig.VerifyServerFunc = security.VerifyServerFunc
+		gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken
+		gwConfig.VerifyServerFunc = security.VerifyServerFunc
+	} else {
+		gwConfig.VerifyServerFunc = security.BasicVerifyServer
 	}
 	gw := protocol.Gateway{
-		ServerConf: &handlerConfig,
+		ServerConf: &gwConfig,
 	}
 
 	if conf.Server.Authentication == "local" {
diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/server.go
index 6571ece..3f07333 100644
--- a/cmd/rdpgw/protocol/server.go
+++ b/cmd/rdpgw/protocol/server.go
@@ -143,6 +143,7 @@ func (s *Server) Process(ctx context.Context) error {
 			server, port := s.channelRequest(pkt)
 			host := net.JoinHostPort(server, strconv.Itoa(int(port)))
 			if s.VerifyServerFunc != nil {
+				log.Printf("Verifying %s host connection", host)
 				if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
 					log.Printf("Not allowed to connect to %s by policy handler", host)
 					msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED)
diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go
new file mode 100644
index 0000000..64b5db0
--- /dev/null
+++ b/cmd/rdpgw/security/basic.go
@@ -0,0 +1,36 @@
+package security
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log"
+)
+
+var (
+	Hosts         []string
+	HostSelection string
+)
+
+func BasicVerifyServer(ctx context.Context, host string) (bool, error) {
+	if HostSelection == "any" {
+		return true, nil
+	}
+
+	if HostSelection == "signed" {
+		// todo get from context
+		return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
+	}
+
+	if HostSelection == "roundrobin" || HostSelection == "unsigned" {
+		log.Printf("Checking host")
+		for _, h := range Hosts {
+			if h == host {
+				return true, nil
+			}
+		}
+		return false, fmt.Errorf("invalid host %s", host)
+	}
+
+	return false, errors.New("unrecognized host selection criteria")
+}
-- 
GitLab