From 1f191a5e41b32e9e38ffc32f8616df02ec3f81f8 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Mon, 13 Jul 2020 15:38:25 +0200 Subject: [PATCH] Use standard HandleFunc pattern --- main.go | 4 ++-- rdg.go | 42 ++++++++++++++++++------------------------ 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/main.go b/main.go index 9cb7531..ba3da3d 100644 --- a/main.go +++ b/main.go @@ -41,10 +41,10 @@ func main() { cfg.Certificates = append(cfg.Certificates, cert) server := http.Server{ Addr: ":" + strconv.Itoa(*port), - Handler: Upgrade(nil), TLSConfig: cfg, } - + http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) + err = server.ListenAndServeTLS("", "") if err != nil { log.Fatal("ListenAndServe: ", err) diff --git a/rdg.go b/rdg.go index 7222e2b..4d8526a 100644 --- a/rdg.go +++ b/rdg.go @@ -110,10 +110,6 @@ var ErrNotHijacker = RejectConnectionError( var DefaultSession RdgSession -func Upgrade(next http.Handler) http.Handler { - return handleGatewayProtocol(next) -} - func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { log.Print("Accept connection") hj, ok := w.(http.Hijacker) @@ -132,29 +128,27 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) -func handleGatewayProtocol(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == MethodRDGOUT { - if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - handleLegacyProtocol(w, r) - return - } - r.Method = "GET" // force - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("Cannot upgrade falling back to old protocol: %s", err) - return - } - defer conn.Close() - - handleWebsocketProtocol(conn) - } else if r.Method == MethodRDGIN { +func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { + if r.Method == MethodRDGOUT { + if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { handleLegacyProtocol(w, r) + return + } + r.Method = "GET" // force + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Cannot upgrade falling back to old protocol: %s", err) + return } - }) + defer conn.Close() + + handleWebsocketProtocol(conn) + } else if r.Method == MethodRDGIN { + handleLegacyProtocol(w, r) + } } -func handleWebsocketProtocol(conn *websocket.Conn) { +func handleWebsocketProtocol(conn *websocket.Conn) { fragment := false buf := make([]byte, 4096) index := 0 @@ -375,7 +369,7 @@ func handshakeResponse(major byte, minor byte, auth uint16) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code buf.Write([]byte{major, minor}) - binary.Write(buf, binary.LittleEndian, uint16(0)) // server version + binary.Write(buf, binary.LittleEndian, uint16(0)) // server version binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) -- GitLab