From 1ac36df8676ae16772a3085826365b4d933e36fc Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Wed, 10 Aug 2022 22:20:49 +0200
Subject: [PATCH] Return proper error if caps don't match

---
 cmd/rdpgw/protocol/server.go | 38 ++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/server.go
index 11e231e..2a1c4c8 100644
--- a/cmd/rdpgw/protocol/server.go
+++ b/cmd/rdpgw/protocol/server.go
@@ -74,12 +74,19 @@ func (s *Server) Process(ctx context.Context) error {
 			log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
 			if s.State != SERVER_STATE_INITIALIZED {
 				log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIALIZED)
-				msg := s.handshakeResponse(0x0, 0x0, E_PROXY_INTERNALERROR)
+				msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
 				s.Session.TransportOut.WritePacket(msg)
 				return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
 			}
-			major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
-			msg := s.handshakeResponse(major, minor, ERROR_SUCCESS)
+			major, minor, _, auth := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
+			caps, err := s.matchAuth(auth)
+			if err != nil {
+				log.Println(err)
+				msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
+				s.Session.TransportOut.WritePacket(msg)
+				return err
+			}
+			msg := s.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
 			s.Session.TransportOut.WritePacket(msg)
 			s.State = SERVER_STATE_HANDSHAKE
 		case PKT_TYPE_TUNNEL_CREATE:
@@ -196,15 +203,7 @@ func (s *Server) Process(ctx context.Context) error {
 // Creates a packet the is a response to a handshakeRequest request
 // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
 // but could be in Windows. However the NTLM protocol is insecure
-func (s *Server) handshakeResponse(major byte, minor byte, errorCode int) []byte {
-	var caps uint16
-	if s.SmartCardAuth {
-		caps = caps | HTTP_EXTENDED_AUTH_SC
-	}
-	if s.TokenAuth {
-		caps = caps | HTTP_EXTENDED_AUTH_PAA
-	}
-
+func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
 	buf := new(bytes.Buffer)
 	binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
 	buf.Write([]byte{major, minor})
@@ -225,6 +224,21 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version
 	return
 }
 
+func (s *Server) matchAuth(extAuth uint16) (caps uint16, err error) {
+	if s.SmartCardAuth && extAuth & HTTP_EXTENDED_AUTH_SC == 1 {
+		caps = caps | HTTP_EXTENDED_AUTH_SC
+	}
+	if s.TokenAuth && extAuth & HTTP_EXTENDED_AUTH_PAA == 1 {
+		caps = caps | HTTP_EXTENDED_AUTH_PAA
+	}
+
+	if caps & extAuth == 0 {
+		return 0, fmt.Errorf("%x has no matching capability configured (%x). Did you configure caps? ", extAuth, caps)
+	}
+
+	return caps, nil
+}
+
 func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
 	var fields uint16
 
-- 
GitLab