From 1f191a5e41b32e9e38ffc32f8616df02ec3f81f8 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Mon, 13 Jul 2020 15:38:25 +0200
Subject: [PATCH] Use standard HandleFunc pattern

---
 main.go |  4 ++--
 rdg.go  | 42 ++++++++++++++++++------------------------
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/main.go b/main.go
index 9cb7531..ba3da3d 100644
--- a/main.go
+++ b/main.go
@@ -41,10 +41,10 @@ func main() {
 	cfg.Certificates = append(cfg.Certificates, cert)
 	server := http.Server{
 		Addr:      ":" + strconv.Itoa(*port),
-		Handler:   Upgrade(nil),
 		TLSConfig: cfg,
 	}
-
+	http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
+	
 	err = server.ListenAndServeTLS("", "")
 	if err != nil {
 		log.Fatal("ListenAndServe: ", err)
diff --git a/rdg.go b/rdg.go
index 7222e2b..4d8526a 100644
--- a/rdg.go
+++ b/rdg.go
@@ -110,10 +110,6 @@ var ErrNotHijacker = RejectConnectionError(
 
 var DefaultSession RdgSession
 
-func Upgrade(next http.Handler) http.Handler {
-	return handleGatewayProtocol(next)
-}
-
 func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
 	log.Print("Accept connection")
 	hj, ok := w.(http.Hijacker)
@@ -132,29 +128,27 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err
 var upgrader = websocket.Upgrader{}
 var c = cache.New(5*time.Minute, 10*time.Minute)
 
-func handleGatewayProtocol(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if r.Method == MethodRDGOUT {
-			if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
-				handleLegacyProtocol(w, r)
-				return
-			}
-			r.Method = "GET" // force
-			conn, err := upgrader.Upgrade(w, r, nil)
-			if err != nil {
-				log.Printf("Cannot upgrade falling back to old protocol: %s", err)
-				return
-			}
-			defer conn.Close()
-
-			handleWebsocketProtocol(conn)
-		} else if r.Method == MethodRDGIN {
+func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
+	if r.Method == MethodRDGOUT {
+		if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
 			handleLegacyProtocol(w, r)
+			return
+		}
+		r.Method = "GET" // force
+		conn, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			log.Printf("Cannot upgrade falling back to old protocol: %s", err)
+			return
 		}
-	})
+		defer conn.Close()
+
+		handleWebsocketProtocol(conn)
+	} else if r.Method == MethodRDGIN {
+		handleLegacyProtocol(w, r)
+	}
 }
 
-func handleWebsocketProtocol(conn *websocket.Conn)  {
+func handleWebsocketProtocol(conn *websocket.Conn) {
 	fragment := false
 	buf := make([]byte, 4096)
 	index := 0
@@ -375,7 +369,7 @@ func handshakeResponse(major byte, minor byte, auth uint16) []byte {
 	buf := new(bytes.Buffer)
 	binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
 	buf.Write([]byte{major, minor})
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                         // server version
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                            // server version
 	binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth
 
 	return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
-- 
GitLab