From 80d11598ec188d7064164a014418e1e2e287259d Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Thu, 9 Jul 2020 10:15:27 +0200
Subject: [PATCH] Working websockets

---
 go.mod  |   8 +
 main.go | 215 ++++----------------------
 rdg.go  | 456 +++++++++++++++++++++++++++++++++++++-------------------
 3 files changed, 340 insertions(+), 339 deletions(-)
 create mode 100644 go.mod

diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..7b6b0d6
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,8 @@
+module github.com/bolkedebruin/rdpgw
+
+go 1.14
+
+require (
+	github.com/gorilla/websocket v1.4.2
+	github.com/patrickmn/go-cache v2.1.0+incompatible
+)
diff --git a/main.go b/main.go
index 035182f..9cb7531 100644
--- a/main.go
+++ b/main.go
@@ -2,212 +2,51 @@ package main
 
 import (
 	"crypto/tls"
-	"net/http/httputil"
-	"os"
-	//"time"
-
-	//"bytes"
-	"fmt"
+	"flag"
 	"log"
-	//"strings"
-	// "io"
 	"net/http"
-	//"net/http/httputil"
-	//"math/rand"
-	//"encoding/binary"
-	//"encoding/base64"
+	"os"
+	"strconv"
 )
 
+func main() {
+	port := flag.Int("port", 443, "port to listen on for incoming connections")
+	certFile := flag.String("certfile", "server.pem", "public key certificate file")
+	keyFile := flag.String("keyfile", "key.pem", "private key file")
 
-/*
-func handleConnection(s *MySession) {
-	inData := make([]byte, 4096)
+	flag.Parse()
 
-	for {
-		size, err := s.buffIn.Read(inData)
-		if err != nil {
-			s.inConn.Close()
-			s.outConn.Close()
-			fmt.Println(err)
-		}
-		fmt.Printf("Bytes read on IN %d\n", size)
+	if *certFile == "" || *keyFile == "" {
+		log.Fatal("Both certfile and keyfile need to be specified")
 	}
-}*/
-
-/*
-func MethodOverride(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		fmt.Println(r.Method)
-		dump, _ := httputil.DumpRequest(r, false)
-		fmt.Printf("%q\n", dump)
-
-		headerKey := "Rdg-Connection-Id"
-		connId := r.Header.Get(headerKey)
-		if connId != "" {
-			s.guid = connId
-		}
-		auth := r.Header.Get("Authorization")
-		fmt.Printf("Connection ID: %s\n", s.guid)
-
-		if strings.Contains(auth,"NTLMX") {
-			/*var msg_req_b []byte
-			base64.StdEncoding.Decode(msg_req_b, []byte(auth[strings.Index(auth,"NTLM")+6:]))
-
-			msg_type := binary.LittleEndian.Uint32(msg_req_b[0:4])
-			fmt.Printf("Message type %v\n", msg_type)
-			if msg_type == 1 {
-				var nonce [8]byte
-				r := make([]byte, 8)
-				rand.Read(r)
-				copy(nonce[:], r)
-
-				sig_buf := new(bytes.Buffer)
-				var signature [8]byte
-				binary.Write(sig_buf, binary.LittleEndian, "NTLMSSP\000")
-				copy(signature[:], sig_buf.Bytes())
 
-				zero := make([]byte, 7)
-				pad := make([]byte, 2)
-
-				rand.Read(nonce)
-
-				buf := new(bytes.Buffer)
-				msg := NtlmChallenge{
-					signature,
-					uint32(0x02),
-					0,
-					0,
-					0,
-					[]byte(),
-					nonce,
-					0,
-					0
-				}
-				_ := binary.Write(buf, binary.LittleEndian, msg)
-				header := "NTLM" + base64.StdEncoding.EncodeToString(buf.Bytes())
-				w.Header().Set("WWW-Authenticate", header)
-				w.WriteHeader(401)
-				w.Write([]byte("Unauthorized.\n"))
-				fmt.Println("Unauthorized")
-				return
-			}
-		} else {
-			_, _, ok := r.BasicAuth()
-
-			if !ok && !s.hasIn {
-				w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`)
-				w.WriteHeader(401)
-				w.Write([]byte("Unauthorized.\n"))
-				fmt.Println("Unauthorized")
-				return
-			}
-		}
+	//mux := http.NewServeMux()
+	//mux.HandleFunc("*", HelloServer)
 
-		if r.Method == "RDG_OUT_DATA" {
-			fmt.Println("Hijacking OUT")
-			hj, ok := w.(http.Hijacker)
-			if !ok {
-				http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
-				return
-			}
-			conn, bufrw, err := hj.Hijack()
-			if err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			s.outConn = conn
-			s.buffOut = bufrw
-
-			if !s.hasOut {
-				fmt.Printf("Creating OUT and sending seed\n")
-				s.hasOut = true
-				seed := make([]byte, 100)
-				rand.Read(seed)
-				bufrw.WriteString("HTTP/1.1 200 OK\r\n")
-				fmt.Fprintf(bufrw, "Date: %s\r\n", time.Now().Format(time.RFC1123))
-				bufrw.WriteString("Content-Type: application/octet-stream\r\n")
-				bufrw.WriteString("Content-Length: 0\r\n")
-				bufrw.WriteString(crlf)
-				bufrw.Write(seed)
-				bufrw.Flush()
-				return
-			} else {
-				fmt.Printf("Handle OUT\n")
-				handleConnection(s)
-				return
-			}
-		}
-
-		if r.Method == "RDG_IN_DATA" {
-			fmt.Println("Hijacking IN")
-			hj, ok := w.(http.Hijacker)
-			if !ok {
-				http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
-				return
-			}
-			conn, bufrw, err := hj.Hijack()
-			if err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			s.inConn = conn
-			s.buffIn = bufrw
-
-			if !s.hasIn {
-				fmt.Printf("Creating IN and sending seed\n")
-				s.hasIn = true
-				seed := make([]byte, 100)
-				rand.Read(seed)
-				bufrw.WriteString("HTTP/1.1 200 OK\r\n")
-				fmt.Fprintf(bufrw, "Date: %s\r\n", time.Now().Format(time.RFC1123))
-				bufrw.WriteString("Content-Type: application/octet-stream\r\n")
-				bufrw.WriteString("Content-Length: 0\r\n")
-				bufrw.WriteString(crlf)
-				bufrw.Write(seed)
-				bufrw.Flush()
-				return
-			} else {
-				fmt.Printf("Handle IN\n")
-
-				handleConnection(s)
-				return
-			}
+	log.Printf("Starting remote desktop gateway server")
+	cfg := &tls.Config{}
+	tlsDebug := os.Getenv("SSLKEYLOGFILE")
+	if tlsDebug != "" {
+		w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+		if err != nil {
+			log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err)
 		}
-
-		next.ServeHTTP(w, r)
-	})
-}
-*/
-
-func HelloServer(w http.ResponseWriter, req *http.Request) {
-	dump, _ := httputil.DumpRequest(req, true)
-	fmt.Println(dump)
-	w.Header().Set("Content-Type", "text/plain")
-	w.Write([]byte("This is an example server.\n"))
-	// io.WriteString(w, "This is an example server.\n")
-}
-
-func main() {
-	fmt.Println("Hello!")
-	mux := http.NewServeMux()
-	mux.HandleFunc("*", HelloServer)
-
-	w, err := os.OpenFile("tls-secrets.txt", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
-	cfg := &tls.Config{
-		KeyLogWriter: w,
+		log.Printf("Key log file set to: %s", tlsDebug)
+		cfg.KeyLogWriter = w
 	}
-	cert, err := tls.LoadX509KeyPair("server.pem", "key.pem")
+	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
 	if err != nil {
 		log.Fatal(err)
 	}
 	cfg.Certificates = append(cfg.Certificates, cert)
 	server := http.Server{
-		Addr: ":8000",
-		Handler: Upgrade(mux),
+		Addr:      ":" + strconv.Itoa(*port),
+		Handler:   Upgrade(nil),
 		TLSConfig: cfg,
 	}
-	err = server.ListenAndServeTLS("","")
+
+	err = server.ListenAndServeTLS("", "")
 	if err != nil {
 		log.Fatal("ListenAndServe: ", err)
 	}
-}
\ No newline at end of file
+}
diff --git a/rdg.go b/rdg.go
index 2d4ebcd..40b018a 100644
--- a/rdg.go
+++ b/rdg.go
@@ -4,22 +4,25 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
+	"github.com/patrickmn/go-cache"
 	"io"
 	"log"
 	"math/rand"
 	"net"
 	"net/http"
+	//"net/http/httputil"
 	"strconv"
 	"time"
 	"unicode/utf16"
 	"unicode/utf8"
+	"github.com/gorilla/websocket"
 )
 
 const (
-	crlf      = "\r\n"
+	crlf               = "\r\n"
 	rdgConnectionIdKey = "Rdg-Connection-Id"
-	HANDSHAKE = 1
 )
 
 const (
@@ -40,6 +43,13 @@ const (
 	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
 )
 
+const (
+	HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID   = 0x01
+	HTTP_TUNNEL_RESPONSE_FIELD_CAPS        = 0x02
+	HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ     = 0x04
+	HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
+)
+
 const (
 	HTTP_EXTENDED_AUTH_NONE      = 0x0
 	HTTP_EXTENDED_AUTH_SC        = 0x1  /* Smart card authentication. */
@@ -47,6 +57,28 @@ const (
 	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
 )
 
+const (
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS  = 0x01
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
+)
+
+const (
+	HTTP_TUNNEL_REDIR_ENABLE_ALL        = 0x80000000
+	HTTP_TUNNEL_REDIR_DISABLE_ALL       = 0x40000000
+	HTTP_TUNNEL_REDIR_DISABLE_DRIVE     = 0x01
+	HTTP_TUNNEL_REDIR_DISABLE_PRINTER   = 0x02
+	HTTP_TUNNEL_REDIR_DISABLE_PORT      = 0x03
+	HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
+	HTTP_TUNNEL_REDIR_DISABLE_PNP       = 0x10
+)
+
+const (
+	HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID   = 0x01
+	HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
+	HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT     = 0x04
+)
+
 const (
 	HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
 )
@@ -63,10 +95,9 @@ type RdgSession struct {
 	UserId        string
 	ConnIn        net.Conn
 	ConnOut       net.Conn
-	BufOut        *bufio.Writer
-	BufIn         *bufio.Reader
-	State         int
-	Remote 		  net.Conn
+	StateIn       int
+	StateOut      int
+	Remote        net.Conn
 }
 
 // ErrNotHijacker is an error returned when http.ResponseWriter does not
@@ -79,89 +110,196 @@ var ErrNotHijacker = RejectConnectionError(
 var DefaultSession RdgSession
 
 func Upgrade(next http.Handler) http.Handler {
-	return DefaultSession.RdgHandshake(next)
+	return RdgHandshake(next)
 }
 
 func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
-		log.Print("Accept connection")
-		hj, ok := w.(http.Hijacker)
-		if ok {
-			return hj.Hijack()
-		} else {
-			err = ErrNotHijacker
-		}
-		if err != nil {
-			httpError(w, err.Error(), http.StatusInternalServerError)
-			return nil, nil, err
-		}
-		return
+	log.Print("Accept connection")
+	hj, ok := w.(http.Hijacker)
+	if ok {
+		return hj.Hijack()
+	} else {
+		err = ErrNotHijacker
+	}
+	if err != nil {
+		httpError(w, err.Error(), http.StatusInternalServerError)
+		return nil, nil, err
+	}
+	return
 }
 
-func (s RdgSession) RdgHandshake(next http.Handler) http.Handler {
+var upgrader = websocket.Upgrader{}
+
+func RdgHandshake(next http.Handler) http.Handler {
+	c := cache.New(5*time.Minute, 10*time.Minute)
+
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		/*_, _, ok := r.BasicAuth()
+		var s RdgSession
 
-		if !ok && s.ConnIn == nil {
-			w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`)
-			w.WriteHeader(401)
-			w.Write([]byte("Unauthorized.\n"))
-			fmt.Println("Unauthorized")
-			return
-		}*/
+		connId := r.Header.Get(rdgConnectionIdKey)
+		x, found := c.Get(connId)
+		if !found {
+			log.Printf("No cached session found")
+			s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
+		} else {
+			log.Printf("Found cached session")
+			s = x.(RdgSession)
+		}
+
+		log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil)
 
-		conn, rw, _ := Accept(w)
 		if r.Method == MethodRDGOUT {
+			r.Method = "GET" // force
+			c, err := upgrader.Upgrade(w, r, nil)
+			if err != nil {
+				log.Printf("Cannot upgrade falling back to old protocol: %s", err)
+				return
+			}
+			defer c.Close()
+
+			fragment := false
+			buf := make([]byte, 4096)
+			index := 0
+			for {
+				mt, msg, err := c.ReadMessage()
+				if err != nil {
+					log.Printf("Error read: %s", err)
+					break
+				}
+				log.Printf("Message type: %d, message: %x", mt, msg)
+
+				// check for fragments
+				var pt uint16
+				var sz uint32
+				var pkt []byte
+
+				if !fragment {
+					pt, sz, pkt, err = readHeader(msg)
+					if err != nil {
+						// fragment received
+						log.Printf("Received non websocket fragment")
+						fragment = true
+						index = copy(buf, msg)
+						continue
+					}
+					index = 0
+				} else {
+					log.Printf("Dealing with fragment")
+					fragment = false
+					pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
+				}
+
+				switch pt {
+				case PKT_TYPE_HANDSHAKE_REQUEST:
+					major, minor, _, auth := readHandshake(pkt)
+					msg := handshakeResponse(major, minor, auth)
+					log.Printf("Handshake response: %x", msg)
+					c.WriteMessage(mt, msg)
+				case PKT_TYPE_TUNNEL_CREATE:
+					readCreateTunnelRequest(pkt)
+					msg := createTunnelResponse()
+					log.Printf("Create tunnel response: %x", msg)
+					c.WriteMessage(mt, msg)
+				case PKT_TYPE_TUNNEL_AUTH:
+					readTunnelAuthRequest(pkt)
+					msg := createTunnelAuthResponse()
+					log.Printf("Create tunnel auth response: %x", msg)
+					c.WriteMessage(mt, msg)
+				case PKT_TYPE_CHANNEL_CREATE:
+					server, port := readChannelCreateRequest(pkt)
+					s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
+					if err != nil {
+						log.Printf("Error connecting to %s, %d, %s", server, port, err)
+						return
+					}
+					msg := createChannelCreateResponse()
+					log.Printf("Create channel create response: %x", msg)
+					c.WriteMessage(mt, msg)
+					go handleWebsocketData(s.Remote, mt, c)
+				case PKT_TYPE_DATA:
+					forwardDataPacket(s.Remote, pkt)
+				case PKT_TYPE_KEEPALIVE:
+					c.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
+				case PKT_TYPE_CLOSE_CHANNEL:
+					s.Remote.Close()
+					return
+				default:
+					log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
+				}
+			}
+			conn, rw, _ := Accept(w)
 			log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
-			s.ConnId = r.Header.Get(rdgConnectionIdKey)
+
 			s.ConnOut = conn
-			s.BufOut = rw.Writer
-			WriteAcceptSeed(rw.Writer)
-			rw.Writer.Flush()
-		} else if r.Method == MethodRDGIN {
-			if s.ConnIn == nil {
-				defer conn.Close()
-				s.ConnIn = conn
-				s.BufIn = rw.Reader
-				log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
-				WriteAcceptSeed(rw.Writer)
-				rw.Writer.Flush()
-				p := make([]byte, 4096)
-				rw.Reader.Read(p)
-				//log.Printf("Read %q", p)
-
-				log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
-				scanner := bufio.NewScanner(rw.Reader)
-				scanner.Split(ReadPacket)
-				for scanner.Scan() {
-					packet := scanner.Bytes()
-					packetType, size, _, packet := readHeader(packet)
-					log.Printf("Scanned packet got packet type %x size %d", packetType, size)
-					switch packetType {
-					case PKT_TYPE_HANDSHAKE_REQUEST:
-						major, minor, _, auth := readHandshake(packet)
-						sendHandshakeResponse(s.BufOut, major, minor, auth)
-					case PKT_TYPE_TUNNEL_CREATE:
-						readCreateTunnelRequest(packet)
-						sendCreateTunnelResponse(s.BufOut)
-					case PKT_TYPE_TUNNEL_AUTH:
-						readTunnelAuthRequest(packet)
-						sendTunnelAuthResponse(s.BufOut)
-					case PKT_TYPE_CHANNEL_CREATE:
-						server, port := readChannelCreateRequest(packet)
-						var err error
-						s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
-						if err != nil {
-							log.Printf("Error connecting to %s, %d, %s", server, port, err)
-							return
-						}
-						sendChannelCreateResponse(s.BufOut)
-						go sendDataPacket(s.Remote, s.BufOut)
-					case PKT_TYPE_DATA:
-						receiveDataPacket(s.Remote, packet)
+			WriteAcceptSeed(rw.Writer, true)
+
+			//c.Set(connId, s, cache.DefaultExpiration)
+		} /*else if r.Method == MethodRDGIN {
+		if !checkNTLMAuth(w, &s, "IN") {
+			c.Set(connId, s, cache.DefaultExpiration)
+			return
+		}
+		conn, rw, _ := Accept(w)
+
+		if s.ConnIn == nil {
+			defer conn.Close()
+			s.ConnIn = conn
+			c.Set(connId, s, cache.DefaultExpiration)
+			log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
+			WriteAcceptSeed(rw.Writer, false)
+			p := make([]byte, 32767)
+			rw.Reader.Read(p)
+			//log.Printf("Read %q", p)
+
+			log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
+			chunkScanner := httputil.NewChunkedReader(rw.Reader)
+			packet := make([]byte, 4096) // bufio.defaultBufSize
+
+			for {
+				n, err := chunkScanner.Read(packet)
+				if err == io.EOF || n == 0 {
+					break
+				}
+				old_packet := packet
+				packetType, size, _, packet := readHeader(packet)
+				log.Printf("Scanned packet got packet type %x size %d", packetType, size)
+				switch packetType {
+				case PKT_TYPE_HANDSHAKE_REQUEST:
+					major, minor, _, auth := readHandshake(packet)
+					sendHandshakeResponse(s.ConnOut, major, minor, auth)
+				case PKT_TYPE_TUNNEL_CREATE:
+					readCreateTunnelRequest(packet)
+					sendCreateTunnelResponse(s.ConnOut)
+				case PKT_TYPE_TUNNEL_AUTH:
+					readTunnelAuthRequest(packet)
+					sendTunnelAuthResponse(s.ConnOut)
+				case PKT_TYPE_CHANNEL_CREATE:
+					server, port := readChannelCreateRequest(packet)
+					var err error
+					s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
+					if err != nil {
+						log.Printf("Error connecting to %s, %d, %s", server, port, err)
+						return
 					}
+					sendChannelCreateResponse(s.ConnOut)
+					// Make sure to start the flow from the RDP server first otherwise connections
+					// might hang eventually
+					go sendDataPacket(s.Remote, s.ConnOut)
+				case PKT_TYPE_DATA:
+					receiveDataPacket(s.Remote, packet)
+				case PKT_TYPE_KEEPALIVE:
+					s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
+				case PKT_TYPE_CLOSE_CHANNEL:
+					s.ConnIn.Close()
+					s.ConnOut.Close()
+					break
+				default:
+					log.Printf("UNKNOWN PACKET (%d): %x", n, old_packet[:n])
+					//receiveDataPacket(s.Remote, old_packet)
+					receiveUnknownPacket(s.Remote, old_packet, n)
 				}
 			}
-		}
+		}*/
 	})
 }
 
@@ -170,58 +308,50 @@ func (s RdgSession) RdgHandshake(next http.Handler) http.Handler {
 // This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
 // not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
 // connection after the last byte is sent.
-func WriteAcceptSeed(bw *bufio.Writer) {
+func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) {
 	bw.WriteString(HttpOK)
+	bw.WriteString("Server: Microsoft-HTTPAPI/2.0\r\n")
 	bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
-	bw.WriteString("Content-Type: application/octet-stream\r\n")
-	bw.WriteString("Content-Length: 0\r\n")
-	bw.WriteString(crlf)
-	seed := make([]byte, 10)
-	rand.Read(seed)
-	bw.Write(seed)
-}
-
-func ReadPacket(data []byte, atEOF bool) (advance int, packet []byte, err error) {
-	log.Printf("Reading data len = %d", len(data))
-	if atEOF && len(data) == 0 {
-		return 0, nil, nil
-	}
-
-	if i := bytes.Index(data, []byte{'\r', '\n'}); i >= 0 {
-		//log.Printf("Got rn at %d ", i)
-		chunkSize, err := strconv.ParseInt(string(data[0:i]), 16, 0)
-		log.Printf("chunkSize %d", chunkSize)
-		if err != nil {
-			return i + 2, data[0:i], err
-		}
-		//log.Printf("Return %d", i+2+int(chunkSize)+2)
-		return i + 2 + int(chunkSize) + 2, data[i+2 : i+2+int(chunkSize)+2], nil
+	if !doSeed {
+		bw.WriteString("Content-Length: 0\r\n")
 	}
+	bw.WriteString(crlf)
 
-	if atEOF {
-		return len(data), data, nil
+	if doSeed {
+		seed := make([]byte, 10)
+		rand.Read(seed)
+		// docs say it's a seed but 2019 responds with ab cd * 5
+		bw.Write(seed)
 	}
-
-	return 0, nil, nil
+	bw.Flush()
 }
 
-func readHeader(data []byte) (packetType uint16, size uint32, advance int, remain []byte) {
+func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
+	// header needs to be 8 min
+	if len(data) < 8 {
+		return 0, 0, nil, errors.New("header too short, fragment likely")
+	}
 	r := bytes.NewReader(data)
 	binary.Read(r, binary.LittleEndian, &packetType)
 	r.Seek(4, io.SeekStart)
 	binary.Read(r, binary.LittleEndian, &size)
-	return packetType, size, 8, data[8:]
+	if len(data) < int(size) {
+		return packetType, size, data[8:], errors.New("data incomplete, fragment received")
+	}
+	return packetType, size, data[8:], nil
 }
 
-func sendHandshakeResponse(w *bufio.Writer, major byte, minor byte, auth uint16) {
+// Creates a packet the is a response to a handshake request
+// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
+// but could be in Windows. However the NTLM protocol is insecure
+func handshakeResponse(major byte, minor byte, auth uint16) []byte {
 	buf := new(bytes.Buffer)
 	binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
 	buf.Write([]byte{major, minor})
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
-	binary.Write(buf, binary.LittleEndian, uint16(2)) // PAA
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                         // server version
+	binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth
 
-	w.Write(createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()))
-	w.Flush()
+	return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
 }
 
 func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
@@ -235,7 +365,7 @@ func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth
 	return
 }
 
-func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){
+func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
 	var fields uint16
 
 	r := bytes.NewReader(data)
@@ -255,16 +385,21 @@ func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){
 	return
 }
 
-func sendCreateTunnelResponse(w *bufio.Writer) {
+func createTunnelResponse() []byte {
 	buf := new(bytes.Buffer)
 
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
-	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // server version
+	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                    // error code
+	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
 
-	w.Write(createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()))
-	w.Flush()
+	// tunnel id ?
+	binary.Write(buf, binary.LittleEndian, uint32(15))
+	// caps ?
+	binary.Write(buf, binary.LittleEndian, uint32(2))
+
+	return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
 }
 
 func readTunnelAuthRequest(data []byte) {
@@ -278,18 +413,21 @@ func readTunnelAuthRequest(data []byte) {
 	log.Printf("Client: %s", clientName)
 }
 
-func sendTunnelAuthResponse(w *bufio.Writer) {
+func createTunnelAuthResponse() []byte {
 	buf := new(bytes.Buffer)
 
-	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
+	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                                        // error code
+	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                                        // reserved
+
+	// flags
+	binary.Write(buf, binary.LittleEndian, uint32(HTTP_TUNNEL_REDIR_ENABLE_ALL)) // redir flags
+	binary.Write(buf, binary.LittleEndian, uint32(0))                            // timeout in minutes
 
-	w.Write(createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()))
-	w.Flush()
+	return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
 }
 
-func readChannelCreateRequest(data []byte) (server string, port uint16){
+func readChannelCreateRequest(data []byte) (server string, port uint16) {
 	buf := bytes.NewReader(data)
 
 	var resourcesSize byte
@@ -313,55 +451,55 @@ func readChannelCreateRequest(data []byte) (server string, port uint16){
 	return
 }
 
-func sendChannelCreateResponse(w *bufio.Writer) {
+func createChannelCreateResponse() []byte {
 	buf := new(bytes.Buffer)
 
 	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
+	//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
 	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
 
-	w.Write(createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()))
-	w.Flush()
+	// optional fields
+	// channel id uint32 (4)
+	// udp port uint16 (2)
+	// udp auth cookie 1 byte for side channel
+	// length uint16
+
+	return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
 }
 
-func createPacket(pktType uint16, data []byte) (packet []byte){
+func createPacket(pktType uint16, data []byte) (packet []byte) {
 	size := len(data) + 8
 	buf := new(bytes.Buffer)
 
-	log.Printf("Data sent Size: %d", size)
-	// http chunk size in hex string
-	// fmt.Fprintf(buf,"%x\r\n", size)
-
 	binary.Write(buf, binary.LittleEndian, uint16(pktType))
-	binary.Write(buf, binary.LittleEndian, uint16(0))  // reserved
+	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
 	binary.Write(buf, binary.LittleEndian, uint32(size))
 	buf.Write(data)
 
-	// http close crlf
-	// buf.Write([]byte(crlf))
-	// log.Printf("data sent: %q", buf.Bytes())
 	return buf.Bytes()
 }
 
-func receiveDataPacket(conn net.Conn, data []byte) {
+func forwardDataPacket(conn net.Conn, data []byte) {
 	buf := bytes.NewReader(data)
 
 	var cblen uint16
 	binary.Read(buf, binary.LittleEndian, &cblen)
-	log.Printf("Received PKT_DATA %d", cblen)
+	//log.Printf("Received PKT_DATA %d", cblen)
 	pkt := make([]byte, cblen)
-	//binary.Read(buf, binary.LittleEndian, &pkt)
-	buf.Read(pkt)
+	binary.Read(buf, binary.LittleEndian, &pkt)
+	//n, _ := buf.Read(pkt)
+	//log.Printf("CBLEN: %d, N: %d", cblen, n)
 	//log.Printf("DATA FROM CLIENT %q", pkt)
 	conn.Write(pkt)
 }
 
-func sendDataPacket(conn net.Conn, w *bufio.Writer) {
-	defer conn.Close()
+func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
+	defer rdp.Close()
 	b1 := new(bytes.Buffer)
-	buf := make([]byte, 32767)
+	buf := make([]byte, 4086)
 	for {
-		n, err := conn.Read(buf)
+		n, err := rdp.Read(buf)
 		binary.Write(b1, binary.LittleEndian, uint16(n))
 		log.Printf("RDP SIZE: %d", n)
 		if err != nil {
@@ -369,16 +507,32 @@ func sendDataPacket(conn net.Conn, w *bufio.Writer) {
 			break
 		}
 		b1.Write(buf[:n])
-		w.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
-		w.Flush()
+		conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes()))
+		b1.Reset()
+	}
+}
+
+func sendDataPacket(connIn net.Conn, connOut net.Conn) {
+	defer connIn.Close()
+	b1 := new(bytes.Buffer)
+	buf := make([]byte, 4086)
+	for {
+		n, err := connIn.Read(buf)
+		binary.Write(b1, binary.LittleEndian, uint16(n))
+		log.Printf("RDP SIZE: %d", n)
+		if err != nil {
+			log.Printf("Error reading from conn %s", err)
+			break
+		}
+		b1.Write(buf[:n])
+		connOut.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
 		b1.Reset()
 	}
 }
 
 func DecodeUTF16(b []byte) (string, error) {
 	if len(b)%2 != 0 {
-		log.Printf("Error decoding utf16")
-		return "", fmt.Errorf("Must have even length byte slice")
+		return "", fmt.Errorf("must have even length byte slice")
 	}
 
 	u16s := make([]uint16, 1)
@@ -394,4 +548,4 @@ func DecodeUTF16(b []byte) (string, error) {
 	}
 
 	return ret.String(), nil
-}
\ No newline at end of file
+}
-- 
GitLab