From e9b7b352cf2b30ef084975cf38e1ad0af9232a9a Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Mon, 20 Jul 2020 15:51:23 +0200
Subject: [PATCH] More refactor

---
 errors.go           |   6 +-
 main.go             |   8 +-
 protocol/handler.go |  30 ++-
 protocol/rdg.go     | 156 +++++++++++++
 rdg.go              | 550 --------------------------------------------
 5 files changed, 180 insertions(+), 570 deletions(-)
 create mode 100644 protocol/rdg.go
 delete mode 100644 rdg.go

diff --git a/errors.go b/errors.go
index bfd90b7..b4a5830 100644
--- a/errors.go
+++ b/errors.go
@@ -1,5 +1,7 @@
 package main
 
+import "github.com/bolkedebruin/rdpgw/protocol"
+
 // RejectOption represents an option used to control the way connection is
 // rejected.
 type RejectOption func(*rejectConnectionError)
@@ -22,7 +24,7 @@ func RejectionStatus(code int) RejectOption {
 
 // RejectionHeader returns an option that makes connection to be rejected with
 // given HTTP headers.
-func RejectionHeader(h HandshakeHeader) RejectOption {
+func RejectionHeader(h protocol.HandshakeHeader) RejectOption {
 	return func(err *rejectConnectionError) {
 		err.header = h
 	}
@@ -44,7 +46,7 @@ func RejectConnectionError(options ...RejectOption) error {
 type rejectConnectionError struct {
 	reason string
 	code   int
-	header HandshakeHeader
+	header protocol.HandshakeHeader
 }
 
 // Error implements error interface.
diff --git a/main.go b/main.go
index e5d9960..a3a9163 100644
--- a/main.go
+++ b/main.go
@@ -4,9 +4,9 @@ import (
 	"context"
 	"crypto/tls"
 	"github.com/bolkedebruin/rdpgw/config"
+	"github.com/bolkedebruin/rdpgw/protocol"
 	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/patrickmn/go-cache"
-	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/spf13/cobra"
 	"golang.org/x/oauth2"
@@ -89,15 +89,11 @@ func main() {
 		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
 	}
 
-	http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
+	http.HandleFunc("/remoteDesktopGateway/", protocol.HandleGatewayProtocol)
 	http.HandleFunc("/connect", handleRdpDownload)
 	http.Handle("/metrics", promhttp.Handler())
 	http.HandleFunc("/callback", handleCallback)
 
-	prometheus.MustRegister(connectionCache)
-	prometheus.MustRegister(legacyConnections)
-	prometheus.MustRegister(websocketConnections)
-
 	err = server.ListenAndServeTLS("", "")
 	if err != nil {
 		log.Fatal("ListenAndServe: ", err)
diff --git a/protocol/handler.go b/protocol/handler.go
index 07be70d..28f5f43 100644
--- a/protocol/handler.go
+++ b/protocol/handler.go
@@ -20,19 +20,21 @@ type VerifyTunnelAuthFunc func(string) (bool, error)
 type VerifyServerFunc func(string) (bool, error)
 
 type Handler struct {
-	Transport            transport.Transport
+	TransportIn          transport.Transport
+	TransportOut		 transport.Transport
 	VerifyPAACookieFunc  VerifyPAACookieFunc
 	VerifyTunnelAuthFunc VerifyTunnelAuthFunc
 	VerifyServerFunc     VerifyServerFunc
 	SmartCardAuth        bool
 	TokenAuth            bool
 	ClientName           string
-	Remote				 net.Conn
+	Remote               net.Conn
 }
 
-func NewHandler(t transport.Transport) *Handler {
+func NewHandler(in transport.Transport, out transport.Transport) *Handler {
 	h := &Handler{
-		Transport: t,
+		TransportIn: in,
+		TransportOut: out,
 	}
 	return h
 }
@@ -49,8 +51,9 @@ func (h *Handler) Process() error {
 		case PKT_TYPE_HANDSHAKE_REQUEST:
 			major, minor, _, auth := readHandshake(pkt)
 			msg := h.handshakeResponse(major, minor, auth)
-			h.Transport.WritePacket(msg)
+			h.TransportOut.WritePacket(msg)
 		case PKT_TYPE_TUNNEL_CREATE:
+			log.Printf("Tunnel create")
 			_, cookie := readCreateTunnelRequest(pkt)
 			if h.VerifyPAACookieFunc != nil {
 				if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
@@ -59,11 +62,13 @@ func (h *Handler) Process() error {
 				}
 			}
 			msg := createTunnelResponse()
-			h.Transport.WritePacket(msg)
+			h.TransportOut.WritePacket(msg)
+			log.Printf("Tunnel done")
 		case PKT_TYPE_TUNNEL_AUTH:
+			log.Printf("Tunnel auth")
 			h.readTunnelAuthRequest(pkt)
 			msg := h.createTunnelAuthResponse()
-			h.Transport.WritePacket(msg)
+			h.TransportOut.WritePacket(msg)
 		case PKT_TYPE_CHANNEL_CREATE:
 			server, port := readChannelCreateRequest(pkt)
 			log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
@@ -77,7 +82,7 @@ func (h *Handler) Process() error {
 			}
 			log.Printf("Connection established")
 			msg := createChannelCreateResponse()
-			h.Transport.WritePacket(msg)
+			h.TransportOut.WritePacket(msg)
 
 			// Make sure to start the flow from the RDP server first otherwise connections
 			// might hang eventually
@@ -86,9 +91,10 @@ func (h *Handler) Process() error {
 			h.forwardDataPacket(pkt)
 		case PKT_TYPE_KEEPALIVE:
 			// avoid concurrency issues
-			// p.Transport.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
+			// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
 		case PKT_TYPE_CLOSE_CHANNEL:
-			h.Transport.Close()
+			h.TransportIn.Close()
+			h.TransportOut.Close()
 		default:
 			log.Printf("Unknown packet (size %d): %x", sz, pkt)
 		}
@@ -101,7 +107,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
 	buf := make([]byte, 4096)
 
 	for {
-		size, pkt, err := h.Transport.ReadPacket()
+		size, pkt, err := h.TransportIn.ReadPacket()
 		if err != nil {
 			return 0, 0, []byte{0, 0}, err
 		}
@@ -337,7 +343,7 @@ func (h *Handler) sendDataPacket() {
 			break
 		}
 		b1.Write(buf[:n])
-		h.Transport.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
+		h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
 		b1.Reset()
 	}
 }
diff --git a/protocol/rdg.go b/protocol/rdg.go
new file mode 100644
index 0000000..47cf3bb
--- /dev/null
+++ b/protocol/rdg.go
@@ -0,0 +1,156 @@
+package protocol
+
+import (
+	"github.com/bolkedebruin/rdpgw/transport"
+	"github.com/gorilla/websocket"
+	"github.com/patrickmn/go-cache"
+	"github.com/prometheus/client_golang/prometheus"
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"time"
+)
+
+const (
+	rdgConnectionIdKey = "Rdg-Connection-Id"
+	MethodRDGIN  = "RDG_IN_DATA"
+	MethodRDGOUT = "RDG_OUT_DATA"
+)
+
+var (
+	connectionCache = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Namespace: "rdpgw",
+			Name:      "connection_cache",
+			Help:      "The amount of connections in the cache",
+		})
+
+	websocketConnections = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Namespace: "rdpgw",
+			Name:      "websocket_connections",
+			Help:      "The count of websocket connections",
+		})
+
+	legacyConnections = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Namespace: "rdpgw",
+			Name:      "legacy_connections",
+			Help:      "The count of legacy https connections",
+		})
+)
+
+// HandshakeHeader is the interface that writes both upgrade request or
+// response headers into a given io.Writer.
+type HandshakeHeader interface {
+	io.WriterTo
+}
+
+type RdgSession struct {
+	ConnId        string
+	CorrelationId string
+	UserId        string
+	TransportIn   transport.Transport
+	TransportOut  transport.Transport
+	StateIn       int
+	StateOut      int
+	Remote        net.Conn
+}
+
+var DefaultSession RdgSession
+
+var upgrader = websocket.Upgrader{}
+var c = cache.New(5*time.Minute, 10*time.Minute)
+
+func init() {
+	prometheus.MustRegister(connectionCache)
+	prometheus.MustRegister(legacyConnections)
+	prometheus.MustRegister(websocketConnections)
+}
+
+func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
+	connectionCache.Set(float64(c.ItemCount()))
+	if r.Method == MethodRDGOUT {
+		if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
+			handleLegacyProtocol(w, r)
+			return
+		}
+		r.Method = "GET" // force
+		conn, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			log.Printf("Cannot upgrade falling back to old protocol: %s", err)
+			return
+		}
+		defer conn.Close()
+
+		handleWebsocketProtocol(conn)
+	} else if r.Method == MethodRDGIN {
+		handleLegacyProtocol(w, r)
+	}
+}
+
+func handleWebsocketProtocol(c *websocket.Conn) {
+	websocketConnections.Inc()
+	defer websocketConnections.Dec()
+
+	inout, _ := transport.NewWS(c)
+	handler := NewHandler(inout, inout)
+	handler.Process()
+}
+
+// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
+// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
+// to ensure the connections do not get cached or terminated by a proxy prematurely.
+func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
+	var s RdgSession
+
+	connId := r.Header.Get(rdgConnectionIdKey)
+	x, found := c.Get(connId)
+	if !found {
+		s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
+	} else {
+		s = x.(RdgSession)
+	}
+
+	log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
+
+	if r.Method == MethodRDGOUT {
+		out, err := transport.NewLegacy(w)
+		if err != nil {
+			log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
+			return
+		}
+		log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
+
+		s.TransportOut = out
+		out.SendAccept(true)
+
+		c.Set(connId, s, cache.DefaultExpiration)
+	} else if r.Method == MethodRDGIN {
+		legacyConnections.Inc()
+		defer legacyConnections.Dec()
+
+		in, err := transport.NewLegacy(w)
+		if err != nil {
+			log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
+			return
+		}
+		defer in.Close()
+
+		if s.TransportIn == nil {
+			s.TransportIn = in
+			c.Set(connId, s, cache.DefaultExpiration)
+
+			log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
+			in.SendAccept(false)
+
+			// read some initial data
+			in.Drain()
+
+			log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
+			handler := NewHandler(in, s.TransportOut)
+			handler.Process()
+		}
+	}
+}
\ No newline at end of file
diff --git a/rdg.go b/rdg.go
deleted file mode 100644
index 1417ed8..0000000
--- a/rdg.go
+++ /dev/null
@@ -1,550 +0,0 @@
-package main
-
-import (
-	"bytes"
-	"encoding/binary"
-	"fmt"
-	"github.com/bolkedebruin/rdpgw/protocol"
-	"github.com/bolkedebruin/rdpgw/transport"
-	"github.com/gorilla/websocket"
-	"github.com/patrickmn/go-cache"
-	"github.com/prometheus/client_golang/prometheus"
-	"io"
-	"log"
-	"net"
-	"net/http"
-	"strconv"
-	"time"
-	"unicode/utf16"
-	"unicode/utf8"
-)
-
-const (
-	crlf               = "\r\n"
-	rdgConnectionIdKey = "Rdg-Connection-Id"
-)
-
-const (
-	PKT_TYPE_HANDSHAKE_REQUEST      = 0x1
-	PKT_TYPE_HANDSHAKE_RESPONSE     = 0x2
-	PKT_TYPE_EXTENDED_AUTH_MSG      = 0x3
-	PKT_TYPE_TUNNEL_CREATE          = 0x4
-	PKT_TYPE_TUNNEL_RESPONSE        = 0x5
-	PKT_TYPE_TUNNEL_AUTH            = 0x6
-	PKT_TYPE_TUNNEL_AUTH_RESPONSE   = 0x7
-	PKT_TYPE_CHANNEL_CREATE         = 0x8
-	PKT_TYPE_CHANNEL_RESPONSE       = 0x9
-	PKT_TYPE_DATA                   = 0xA
-	PKT_TYPE_SERVICE_MESSAGE        = 0xB
-	PKT_TYPE_REAUTH_MESSAGE         = 0xC
-	PKT_TYPE_KEEPALIVE              = 0xD
-	PKT_TYPE_CLOSE_CHANNEL          = 0x10
-	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
-)
-
-const (
-	HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID   = 0x01
-	HTTP_TUNNEL_RESPONSE_FIELD_CAPS        = 0x02
-	HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ     = 0x04
-	HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
-)
-
-const (
-	HTTP_EXTENDED_AUTH_NONE      = 0x0
-	HTTP_EXTENDED_AUTH_SC        = 0x1  /* Smart card authentication. */
-	HTTP_EXTENDED_AUTH_PAA       = 0x02 /* Pluggable authentication. */
-	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
-)
-
-const (
-	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS  = 0x01
-	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
-	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
-)
-
-const (
-	HTTP_TUNNEL_REDIR_ENABLE_ALL        = 0x80000000
-	HTTP_TUNNEL_REDIR_DISABLE_ALL       = 0x40000000
-	HTTP_TUNNEL_REDIR_DISABLE_DRIVE     = 0x01
-	HTTP_TUNNEL_REDIR_DISABLE_PRINTER   = 0x02
-	HTTP_TUNNEL_REDIR_DISABLE_PORT      = 0x03
-	HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
-	HTTP_TUNNEL_REDIR_DISABLE_PNP       = 0x10
-)
-
-const (
-	HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID   = 0x01
-	HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
-	HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT     = 0x04
-)
-
-const (
-	HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
-)
-
-var (
-	connectionCache = prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Namespace: "rdpgw",
-			Name:      "connection_cache",
-			Help:      "The amount of connections in the cache",
-		})
-
-	websocketConnections = prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Namespace: "rdpgw",
-			Name:      "websocket_connections",
-			Help:      "The count of websocket connections",
-		})
-
-	legacyConnections = prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Namespace: "rdpgw",
-			Name:      "legacy_connections",
-			Help:      "The count of legacy https connections",
-		})
-)
-
-// HandshakeHeader is the interface that writes both upgrade request or
-// response headers into a given io.Writer.
-type HandshakeHeader interface {
-	io.WriterTo
-}
-
-type RdgSession struct {
-	ConnId        string
-	CorrelationId string
-	UserId        string
-	TransportIn   transport.Transport
-	TransportOut  transport.Transport
-	StateIn       int
-	StateOut      int
-	Remote        net.Conn
-}
-
-// ErrNotHijacker is an error returned when http.ResponseWriter does not
-// implement http.Hijacker interface.
-var ErrNotHijacker = RejectConnectionError(
-	RejectionStatus(http.StatusInternalServerError),
-	RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
-)
-
-var DefaultSession RdgSession
-
-var upgrader = websocket.Upgrader{}
-var c = cache.New(5*time.Minute, 10*time.Minute)
-
-func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
-	connectionCache.Set(float64(c.ItemCount()))
-	if r.Method == MethodRDGOUT {
-		if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
-			handleLegacyProtocol(w, r)
-			return
-		}
-		r.Method = "GET" // force
-		conn, err := upgrader.Upgrade(w, r, nil)
-		if err != nil {
-			log.Printf("Cannot upgrade falling back to old protocol: %s", err)
-			return
-		}
-		defer conn.Close()
-
-		handleWebsocketProtocol(conn)
-	} else if r.Method == MethodRDGIN {
-		handleLegacyProtocol(w, r)
-	}
-}
-
-func handleWebsocketProtocol(c *websocket.Conn) {
-	websocketConnections.Inc()
-	defer websocketConnections.Dec()
-
-	inout, _ := transport.NewWS(c)
-	handler := protocol.NewHandler(inout)
-	handler.Process()
-}
-
-// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
-// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
-// to ensure the connections do not get cached or terminated by a proxy prematurely.
-func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
-	var s RdgSession
-
-	connId := r.Header.Get(rdgConnectionIdKey)
-	x, found := c.Get(connId)
-	if !found {
-		log.Printf("No cached session found")
-		s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
-	} else {
-		log.Printf("Found cached session")
-		s = x.(RdgSession)
-	}
-
-	log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
-
-	if r.Method == MethodRDGOUT {
-		out, err := transport.NewLegacy(w)
-		if err != nil {
-			log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
-			return
-		}
-		log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
-
-		s.TransportOut = out
-		out.SendAccept(true)
-
-		c.Set(connId, s, cache.DefaultExpiration)
-	} else if r.Method == MethodRDGIN {
-		legacyConnections.Inc()
-		defer legacyConnections.Dec()
-
-		var remote net.Conn
-
-		in, err := transport.NewLegacy(w)
-		if err != nil {
-			log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
-			return
-		}
-		defer in.Close()
-
-		if s.TransportIn == nil {
-			s.TransportIn = in
-			c.Set(connId, s, cache.DefaultExpiration)
-
-			//log.Printf("Opening RDGIN for client %s", in.RemoteAddr().String())
-			in.SendAccept(false)
-
-			// read some initial data
-			in.Drain()
-
-			log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String())
-			handler := protocol.NewHandler(in)
-			for {
-				pt, sz, pkt, err := handler.ReadMessage()
-				if err != nil {
-					log.Printf("Cannot read message from stream %s", err)
-					return
-				}
-
-				switch pt {
-				case PKT_TYPE_HANDSHAKE_REQUEST:
-					major, minor, _, auth := readHandshake(pkt)
-					msg := handshakeResponse(major, minor, auth)
-					s.TransportOut.WritePacket(msg)
-				case PKT_TYPE_TUNNEL_CREATE:
-					readCreateTunnelRequest(pkt)
-					/*if _, found := tokens.Get(cookie); found == false {
-						log.Printf("Invalid PAA cookie: %s from %s", cookie, in.Conn.RemoteAddr())
-						return
-					}*/
-					msg := createTunnelResponse()
-					s.TransportOut.WritePacket(msg)
-				case PKT_TYPE_TUNNEL_AUTH:
-					readTunnelAuthRequest(pkt)
-					msg := createTunnelAuthResponse()
-					s.TransportOut.WritePacket(msg)
-				case PKT_TYPE_CHANNEL_CREATE:
-					server, port := readChannelCreateRequest(pkt)
-					log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
-					remote, err = net.DialTimeout(
-						"tcp",
-						net.JoinHostPort(server, strconv.Itoa(int(port))),
-						time.Second * 15)
-					if err != nil {
-						log.Printf("Error connecting to %s, %d, %s", server, port, err)
-						return
-					}
-					log.Printf("Connection established")
-					msg := createChannelCreateResponse()
-					s.TransportOut.WritePacket(msg)
-
-					// Make sure to start the flow from the RDP server first otherwise connections
-					// might hang eventually
-					go sendDataPacket(remote, s.TransportOut)
-				case PKT_TYPE_DATA:
-					forwardDataPacket(remote, pkt)
-				case PKT_TYPE_KEEPALIVE:
-					// avoid concurrency issues
-					// s.TransportOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
-				case PKT_TYPE_CLOSE_CHANNEL:
-					s.TransportIn.Close()
-					s.TransportOut.Close()
-					break
-				default:
-					log.Printf("Unknown packet (size %d): %x", sz, pkt)
-				}
-			}
-		}
-	}
-}
-
-// Creates a packet the is a response to a handshake request
-// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
-// but could be in Windows. However the NTLM protocol is insecure
-func handshakeResponse(major byte, minor byte, auth uint16) []byte {
-	var caps uint16
-	if conf.Caps.SmartCardAuth {
-		caps = caps | HTTP_EXTENDED_AUTH_PAA
-	}
-	if conf.Caps.TokenAuth {
-		caps = caps | HTTP_EXTENDED_AUTH_PAA
-	}
-
-	buf := new(bytes.Buffer)
-	binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
-	buf.Write([]byte{major, minor})
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                            // server version
-	binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
-
-	return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
-}
-
-func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
-	r := bytes.NewReader(data)
-	binary.Read(r, binary.LittleEndian, &major)
-	binary.Read(r, binary.LittleEndian, &minor)
-	binary.Read(r, binary.LittleEndian, &version)
-	binary.Read(r, binary.LittleEndian, &extAuth)
-
-	log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
-	return
-}
-
-func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
-	var fields uint16
-
-	r := bytes.NewReader(data)
-
-	binary.Read(r, binary.LittleEndian, &caps)
-	binary.Read(r, binary.LittleEndian, &fields)
-	r.Seek(2, io.SeekCurrent)
-
-	if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
-		var size uint16
-		binary.Read(r, binary.LittleEndian, &size)
-		cookieB := make([]byte, size)
-		r.Read(cookieB)
-		cookie, _ = DecodeUTF16(cookieB)
-	}
-	log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie)
-	return
-}
-
-func createTunnelResponse() []byte {
-	buf := new(bytes.Buffer)
-
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // server version
-	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                    // error code
-	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
-
-	// tunnel id ?
-	binary.Write(buf, binary.LittleEndian, uint32(15))
-	// caps ?
-	binary.Write(buf, binary.LittleEndian, uint32(2))
-
-	return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
-}
-
-func readTunnelAuthRequest(data []byte) {
-	buf := bytes.NewReader(data)
-
-	var size uint16
-	binary.Read(buf, binary.LittleEndian, &size)
-	clData := make([]byte, size)
-	binary.Read(buf, binary.LittleEndian, &clData)
-	clientName, _ := DecodeUTF16(clData)
-	log.Printf("Client: %s", clientName)
-}
-
-func createTunnelAuthResponse() []byte {
-	buf := new(bytes.Buffer)
-
-	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                                        // error code
-	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
-	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                                        // reserved
-
-	// flags
-	var redir uint32
-	if conf.Caps.RedirectAll {
-		redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
-	} else if conf.Caps.DisableRedirect {
-		redir = HTTP_TUNNEL_REDIR_DISABLE_ALL
-	} else {
-		if conf.Caps.DisableClipboard {
-			redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
-		}
-		if conf.Caps.DisableDrive {
-			redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
-		}
-		if conf.Caps.DisablePnp {
-			redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
-		}
-		if conf.Caps.DisablePrinter {
-			redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
-		}
-		if conf.Caps.DisablePort {
-			redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
-		}
-	}
-
-	// idle timeout
-	timeout := conf.Caps.IdleTimeout
-	if timeout < 0 {
-		timeout = 0
-	}
-
-	binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags
-	binary.Write(buf, binary.LittleEndian, uint32(timeout))                            // timeout in minutes
-
-	return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
-}
-
-func readChannelCreateRequest(data []byte) (server string, port uint16) {
-	buf := bytes.NewReader(data)
-
-	var resourcesSize byte
-	var alternative byte
-	var protocol uint16
-	var nameSize uint16
-
-	binary.Read(buf, binary.LittleEndian, &resourcesSize)
-	binary.Read(buf, binary.LittleEndian, &alternative)
-	binary.Read(buf, binary.LittleEndian, &port)
-	binary.Read(buf, binary.LittleEndian, &protocol)
-	binary.Read(buf, binary.LittleEndian, &nameSize)
-
-	nameData := make([]byte, nameSize)
-	binary.Read(buf, binary.LittleEndian, &nameData)
-
-	log.Printf("Name data %q", nameData)
-	server, _ = DecodeUTF16(nameData)
-
-	log.Printf("Should connect to %s on port %d", server, port)
-	return
-}
-
-func createChannelCreateResponse() []byte {
-	buf := new(bytes.Buffer)
-
-	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
-	//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
-
-	// optional fields
-	// channel id uint32 (4)
-	// udp port uint16 (2)
-	// udp auth cookie 1 byte for side channel
-	// length uint16
-
-	return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
-}
-
-func createPacket(pktType uint16, data []byte) (packet []byte) {
-	size := len(data) + 8
-	buf := new(bytes.Buffer)
-
-	binary.Write(buf, binary.LittleEndian, uint16(pktType))
-	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
-	binary.Write(buf, binary.LittleEndian, uint32(size))
-	buf.Write(data)
-
-	return buf.Bytes()
-}
-
-func forwardDataPacket(conn net.Conn, data []byte) {
-	buf := bytes.NewReader(data)
-
-	var cblen uint16
-	binary.Read(buf, binary.LittleEndian, &cblen)
-	//log.Printf("Received PKT_DATA %d", cblen)
-	pkt := make([]byte, cblen)
-	binary.Read(buf, binary.LittleEndian, &pkt)
-	//n, _ := buf.Read(pkt)
-	//log.Printf("CBLEN: %d, N: %d", cblen, n)
-	//log.Printf("DATA FROM CLIENT %q", pkt)
-	conn.Write(pkt)
-}
-
-func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
-	defer connIn.Close()
-	b1 := new(bytes.Buffer)
-	buf := make([]byte, 4086)
-	for {
-		n, err := connIn.Read(buf)
-		binary.Write(b1, binary.LittleEndian, uint16(n))
-		log.Printf("RDP SIZE: %d", n)
-		if err != nil {
-			log.Printf("Error reading from conn %s", err)
-			break
-		}
-		b1.Write(buf[:n])
-		connOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
-		b1.Reset()
-	}
-}
-
-func DecodeUTF16(b []byte) (string, error) {
-	if len(b)%2 != 0 {
-		return "", fmt.Errorf("must have even length byte slice")
-	}
-
-	u16s := make([]uint16, 1)
-	ret := &bytes.Buffer{}
-	b8buf := make([]byte, 4)
-
-	lb := len(b)
-	for i := 0; i < lb; i += 2 {
-		u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
-		r := utf16.Decode(u16s)
-		n := utf8.EncodeRune(b8buf, r[0])
-		ret.Write(b8buf[:n])
-	}
-
-	bret := ret.Bytes()
-	if len(bret) > 0 && bret[len(bret)-1] == '\x00' {
-		bret = bret[:len(bret)-1]
-	}
-	return string(bret), nil
-}
-
-// UTF-16 endian byte order
-const (
-	unknownEndian = iota
-	bigEndian
-	littleEndian
-)
-
-// dropCREndian drops a terminal \r from the endian data.
-func dropCREndian(data []byte, t1, t2 byte) []byte {
-	if len(data) > 1 {
-		if data[len(data)-2] == t1 && data[len(data)-1] == t2 {
-			return data[0 : len(data)-2]
-		}
-	}
-	return data
-}
-
-// dropCRBE drops a terminal \r from the big endian data.
-func dropCRBE(data []byte) []byte {
-	return dropCREndian(data, '\x00', '\r')
-}
-
-// dropCRLE drops a terminal \r from the little endian data.
-func dropCRLE(data []byte) []byte {
-	return dropCREndian(data, '\r', '\x00')
-}
-
-// dropCR drops a terminal \r from the data.
-func dropCR(data []byte) ([]byte, int) {
-	var endian = unknownEndian
-	switch ld := len(data); {
-	case ld != len(dropCRLE(data)):
-		endian = littleEndian
-	case ld != len(dropCRBE(data)):
-		endian = bigEndian
-	}
-	return data, endian
-}
-- 
GitLab