From e9b7b352cf2b30ef084975cf38e1ad0af9232a9a Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Mon, 20 Jul 2020 15:51:23 +0200 Subject: [PATCH] More refactor --- errors.go | 6 +- main.go | 8 +- protocol/handler.go | 30 ++- protocol/rdg.go | 156 +++++++++++++ rdg.go | 550 -------------------------------------------- 5 files changed, 180 insertions(+), 570 deletions(-) create mode 100644 protocol/rdg.go delete mode 100644 rdg.go diff --git a/errors.go b/errors.go index bfd90b7..b4a5830 100644 --- a/errors.go +++ b/errors.go @@ -1,5 +1,7 @@ package main +import "github.com/bolkedebruin/rdpgw/protocol" + // RejectOption represents an option used to control the way connection is // rejected. type RejectOption func(*rejectConnectionError) @@ -22,7 +24,7 @@ func RejectionStatus(code int) RejectOption { // RejectionHeader returns an option that makes connection to be rejected with // given HTTP headers. -func RejectionHeader(h HandshakeHeader) RejectOption { +func RejectionHeader(h protocol.HandshakeHeader) RejectOption { return func(err *rejectConnectionError) { err.header = h } @@ -44,7 +46,7 @@ func RejectConnectionError(options ...RejectOption) error { type rejectConnectionError struct { reason string code int - header HandshakeHeader + header protocol.HandshakeHeader } // Error implements error interface. diff --git a/main.go b/main.go index e5d9960..a3a9163 100644 --- a/main.go +++ b/main.go @@ -4,9 +4,9 @@ import ( "context" "crypto/tls" "github.com/bolkedebruin/rdpgw/config" + "github.com/bolkedebruin/rdpgw/protocol" "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "golang.org/x/oauth2" @@ -89,15 +89,11 @@ func main() { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } - http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) + http.HandleFunc("/remoteDesktopGateway/", protocol.HandleGatewayProtocol) http.HandleFunc("/connect", handleRdpDownload) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/callback", handleCallback) - prometheus.MustRegister(connectionCache) - prometheus.MustRegister(legacyConnections) - prometheus.MustRegister(websocketConnections) - err = server.ListenAndServeTLS("", "") if err != nil { log.Fatal("ListenAndServe: ", err) diff --git a/protocol/handler.go b/protocol/handler.go index 07be70d..28f5f43 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -20,19 +20,21 @@ type VerifyTunnelAuthFunc func(string) (bool, error) type VerifyServerFunc func(string) (bool, error) type Handler struct { - Transport transport.Transport + TransportIn transport.Transport + TransportOut transport.Transport VerifyPAACookieFunc VerifyPAACookieFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc SmartCardAuth bool TokenAuth bool ClientName string - Remote net.Conn + Remote net.Conn } -func NewHandler(t transport.Transport) *Handler { +func NewHandler(in transport.Transport, out transport.Transport) *Handler { h := &Handler{ - Transport: t, + TransportIn: in, + TransportOut: out, } return h } @@ -49,8 +51,9 @@ func (h *Handler) Process() error { case PKT_TYPE_HANDSHAKE_REQUEST: major, minor, _, auth := readHandshake(pkt) msg := h.handshakeResponse(major, minor, auth) - h.Transport.WritePacket(msg) + h.TransportOut.WritePacket(msg) case PKT_TYPE_TUNNEL_CREATE: + log.Printf("Tunnel create") _, cookie := readCreateTunnelRequest(pkt) if h.VerifyPAACookieFunc != nil { if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false { @@ -59,11 +62,13 @@ func (h *Handler) Process() error { } } msg := createTunnelResponse() - h.Transport.WritePacket(msg) + h.TransportOut.WritePacket(msg) + log.Printf("Tunnel done") case PKT_TYPE_TUNNEL_AUTH: + log.Printf("Tunnel auth") h.readTunnelAuthRequest(pkt) msg := h.createTunnelAuthResponse() - h.Transport.WritePacket(msg) + h.TransportOut.WritePacket(msg) case PKT_TYPE_CHANNEL_CREATE: server, port := readChannelCreateRequest(pkt) log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) @@ -77,7 +82,7 @@ func (h *Handler) Process() error { } log.Printf("Connection established") msg := createChannelCreateResponse() - h.Transport.WritePacket(msg) + h.TransportOut.WritePacket(msg) // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually @@ -86,9 +91,10 @@ func (h *Handler) Process() error { h.forwardDataPacket(pkt) case PKT_TYPE_KEEPALIVE: // avoid concurrency issues - // p.Transport.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: - h.Transport.Close() + h.TransportIn.Close() + h.TransportOut.Close() default: log.Printf("Unknown packet (size %d): %x", sz, pkt) } @@ -101,7 +107,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { buf := make([]byte, 4096) for { - size, pkt, err := h.Transport.ReadPacket() + size, pkt, err := h.TransportIn.ReadPacket() if err != nil { return 0, 0, []byte{0, 0}, err } @@ -337,7 +343,7 @@ func (h *Handler) sendDataPacket() { break } b1.Write(buf[:n]) - h.Transport.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } diff --git a/protocol/rdg.go b/protocol/rdg.go new file mode 100644 index 0000000..47cf3bb --- /dev/null +++ b/protocol/rdg.go @@ -0,0 +1,156 @@ +package protocol + +import ( + "github.com/bolkedebruin/rdpgw/transport" + "github.com/gorilla/websocket" + "github.com/patrickmn/go-cache" + "github.com/prometheus/client_golang/prometheus" + "io" + "log" + "net" + "net/http" + "time" +) + +const ( + rdgConnectionIdKey = "Rdg-Connection-Id" + MethodRDGIN = "RDG_IN_DATA" + MethodRDGOUT = "RDG_OUT_DATA" +) + +var ( + connectionCache = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "connection_cache", + Help: "The amount of connections in the cache", + }) + + websocketConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "websocket_connections", + Help: "The count of websocket connections", + }) + + legacyConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "legacy_connections", + Help: "The count of legacy https connections", + }) +) + +// HandshakeHeader is the interface that writes both upgrade request or +// response headers into a given io.Writer. +type HandshakeHeader interface { + io.WriterTo +} + +type RdgSession struct { + ConnId string + CorrelationId string + UserId string + TransportIn transport.Transport + TransportOut transport.Transport + StateIn int + StateOut int + Remote net.Conn +} + +var DefaultSession RdgSession + +var upgrader = websocket.Upgrader{} +var c = cache.New(5*time.Minute, 10*time.Minute) + +func init() { + prometheus.MustRegister(connectionCache) + prometheus.MustRegister(legacyConnections) + prometheus.MustRegister(websocketConnections) +} + +func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { + connectionCache.Set(float64(c.ItemCount())) + if r.Method == MethodRDGOUT { + if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { + handleLegacyProtocol(w, r) + return + } + r.Method = "GET" // force + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Cannot upgrade falling back to old protocol: %s", err) + return + } + defer conn.Close() + + handleWebsocketProtocol(conn) + } else if r.Method == MethodRDGIN { + handleLegacyProtocol(w, r) + } +} + +func handleWebsocketProtocol(c *websocket.Conn) { + websocketConnections.Inc() + defer websocketConnections.Dec() + + inout, _ := transport.NewWS(c) + handler := NewHandler(inout, inout) + handler.Process() +} + +// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server +// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different +// to ensure the connections do not get cached or terminated by a proxy prematurely. +func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { + var s RdgSession + + connId := r.Header.Get(rdgConnectionIdKey) + x, found := c.Get(connId) + if !found { + s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0} + } else { + s = x.(RdgSession) + } + + log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) + + if r.Method == MethodRDGOUT { + out, err := transport.NewLegacy(w) + if err != nil { + log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) + return + } + log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String()) + + s.TransportOut = out + out.SendAccept(true) + + c.Set(connId, s, cache.DefaultExpiration) + } else if r.Method == MethodRDGIN { + legacyConnections.Inc() + defer legacyConnections.Dec() + + in, err := transport.NewLegacy(w) + if err != nil { + log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) + return + } + defer in.Close() + + if s.TransportIn == nil { + s.TransportIn = in + c.Set(connId, s, cache.DefaultExpiration) + + log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String()) + in.SendAccept(false) + + // read some initial data + in.Drain() + + log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) + handler := NewHandler(in, s.TransportOut) + handler.Process() + } + } +} \ No newline at end of file diff --git a/rdg.go b/rdg.go deleted file mode 100644 index 1417ed8..0000000 --- a/rdg.go +++ /dev/null @@ -1,550 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "fmt" - "github.com/bolkedebruin/rdpgw/protocol" - "github.com/bolkedebruin/rdpgw/transport" - "github.com/gorilla/websocket" - "github.com/patrickmn/go-cache" - "github.com/prometheus/client_golang/prometheus" - "io" - "log" - "net" - "net/http" - "strconv" - "time" - "unicode/utf16" - "unicode/utf8" -) - -const ( - crlf = "\r\n" - rdgConnectionIdKey = "Rdg-Connection-Id" -) - -const ( - PKT_TYPE_HANDSHAKE_REQUEST = 0x1 - PKT_TYPE_HANDSHAKE_RESPONSE = 0x2 - PKT_TYPE_EXTENDED_AUTH_MSG = 0x3 - PKT_TYPE_TUNNEL_CREATE = 0x4 - PKT_TYPE_TUNNEL_RESPONSE = 0x5 - PKT_TYPE_TUNNEL_AUTH = 0x6 - PKT_TYPE_TUNNEL_AUTH_RESPONSE = 0x7 - PKT_TYPE_CHANNEL_CREATE = 0x8 - PKT_TYPE_CHANNEL_RESPONSE = 0x9 - PKT_TYPE_DATA = 0xA - PKT_TYPE_SERVICE_MESSAGE = 0xB - PKT_TYPE_REAUTH_MESSAGE = 0xC - PKT_TYPE_KEEPALIVE = 0xD - PKT_TYPE_CLOSE_CHANNEL = 0x10 - PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11 -) - -const ( - HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01 - HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02 - HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04 - HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10 -) - -const ( - HTTP_EXTENDED_AUTH_NONE = 0x0 - HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */ - HTTP_EXTENDED_AUTH_PAA = 0x02 /* Pluggable authentication. */ - HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */ -) - -const ( - HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01 - HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02 - HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04 -) - -const ( - HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000 - HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000 - HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01 - HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02 - HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x03 - HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08 - HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10 -) - -const ( - HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01 - HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02 - HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04 -) - -const ( - HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 -) - -var ( - connectionCache = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "connection_cache", - Help: "The amount of connections in the cache", - }) - - websocketConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "websocket_connections", - Help: "The count of websocket connections", - }) - - legacyConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "legacy_connections", - Help: "The count of legacy https connections", - }) -) - -// HandshakeHeader is the interface that writes both upgrade request or -// response headers into a given io.Writer. -type HandshakeHeader interface { - io.WriterTo -} - -type RdgSession struct { - ConnId string - CorrelationId string - UserId string - TransportIn transport.Transport - TransportOut transport.Transport - StateIn int - StateOut int - Remote net.Conn -} - -// ErrNotHijacker is an error returned when http.ResponseWriter does not -// implement http.Hijacker interface. -var ErrNotHijacker = RejectConnectionError( - RejectionStatus(http.StatusInternalServerError), - RejectionReason("given http.ResponseWriter is not a http.Hijacker"), -) - -var DefaultSession RdgSession - -var upgrader = websocket.Upgrader{} -var c = cache.New(5*time.Minute, 10*time.Minute) - -func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { - connectionCache.Set(float64(c.ItemCount())) - if r.Method == MethodRDGOUT { - if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - handleLegacyProtocol(w, r) - return - } - r.Method = "GET" // force - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("Cannot upgrade falling back to old protocol: %s", err) - return - } - defer conn.Close() - - handleWebsocketProtocol(conn) - } else if r.Method == MethodRDGIN { - handleLegacyProtocol(w, r) - } -} - -func handleWebsocketProtocol(c *websocket.Conn) { - websocketConnections.Inc() - defer websocketConnections.Dec() - - inout, _ := transport.NewWS(c) - handler := protocol.NewHandler(inout) - handler.Process() -} - -// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server -// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different -// to ensure the connections do not get cached or terminated by a proxy prematurely. -func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { - var s RdgSession - - connId := r.Header.Get(rdgConnectionIdKey) - x, found := c.Get(connId) - if !found { - log.Printf("No cached session found") - s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0} - } else { - log.Printf("Found cached session") - s = x.(RdgSession) - } - - log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) - - if r.Method == MethodRDGOUT { - out, err := transport.NewLegacy(w) - if err != nil { - log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) - return - } - log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String()) - - s.TransportOut = out - out.SendAccept(true) - - c.Set(connId, s, cache.DefaultExpiration) - } else if r.Method == MethodRDGIN { - legacyConnections.Inc() - defer legacyConnections.Dec() - - var remote net.Conn - - in, err := transport.NewLegacy(w) - if err != nil { - log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) - return - } - defer in.Close() - - if s.TransportIn == nil { - s.TransportIn = in - c.Set(connId, s, cache.DefaultExpiration) - - //log.Printf("Opening RDGIN for client %s", in.RemoteAddr().String()) - in.SendAccept(false) - - // read some initial data - in.Drain() - - log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String()) - handler := protocol.NewHandler(in) - for { - pt, sz, pkt, err := handler.ReadMessage() - if err != nil { - log.Printf("Cannot read message from stream %s", err) - return - } - - switch pt { - case PKT_TYPE_HANDSHAKE_REQUEST: - major, minor, _, auth := readHandshake(pkt) - msg := handshakeResponse(major, minor, auth) - s.TransportOut.WritePacket(msg) - case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(pkt) - /*if _, found := tokens.Get(cookie); found == false { - log.Printf("Invalid PAA cookie: %s from %s", cookie, in.Conn.RemoteAddr()) - return - }*/ - msg := createTunnelResponse() - s.TransportOut.WritePacket(msg) - case PKT_TYPE_TUNNEL_AUTH: - readTunnelAuthRequest(pkt) - msg := createTunnelAuthResponse() - s.TransportOut.WritePacket(msg) - case PKT_TYPE_CHANNEL_CREATE: - server, port := readChannelCreateRequest(pkt) - log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) - remote, err = net.DialTimeout( - "tcp", - net.JoinHostPort(server, strconv.Itoa(int(port))), - time.Second * 15) - if err != nil { - log.Printf("Error connecting to %s, %d, %s", server, port, err) - return - } - log.Printf("Connection established") - msg := createChannelCreateResponse() - s.TransportOut.WritePacket(msg) - - // Make sure to start the flow from the RDP server first otherwise connections - // might hang eventually - go sendDataPacket(remote, s.TransportOut) - case PKT_TYPE_DATA: - forwardDataPacket(remote, pkt) - case PKT_TYPE_KEEPALIVE: - // avoid concurrency issues - // s.TransportOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) - case PKT_TYPE_CLOSE_CHANNEL: - s.TransportIn.Close() - s.TransportOut.Close() - break - default: - log.Printf("Unknown packet (size %d): %x", sz, pkt) - } - } - } - } -} - -// Creates a packet the is a response to a handshake request -// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux -// but could be in Windows. However the NTLM protocol is insecure -func handshakeResponse(major byte, minor byte, auth uint16) []byte { - var caps uint16 - if conf.Caps.SmartCardAuth { - caps = caps | HTTP_EXTENDED_AUTH_PAA - } - if conf.Caps.TokenAuth { - caps = caps | HTTP_EXTENDED_AUTH_PAA - } - - buf := new(bytes.Buffer) - binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code - buf.Write([]byte{major, minor}) - binary.Write(buf, binary.LittleEndian, uint16(0)) // server version - binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth - - return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) -} - -func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { - r := bytes.NewReader(data) - binary.Read(r, binary.LittleEndian, &major) - binary.Read(r, binary.LittleEndian, &minor) - binary.Read(r, binary.LittleEndian, &version) - binary.Read(r, binary.LittleEndian, &extAuth) - - log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth) - return -} - -func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) { - var fields uint16 - - r := bytes.NewReader(data) - - binary.Read(r, binary.LittleEndian, &caps) - binary.Read(r, binary.LittleEndian, &fields) - r.Seek(2, io.SeekCurrent) - - if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE { - var size uint16 - binary.Read(r, binary.LittleEndian, &size) - cookieB := make([]byte, size) - r.Read(cookieB) - cookie, _ = DecodeUTF16(cookieB) - } - log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie) - return -} - -func createTunnelResponse() []byte { - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint16(0)) // server version - binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - - // tunnel id ? - binary.Write(buf, binary.LittleEndian, uint32(15)) - // caps ? - binary.Write(buf, binary.LittleEndian, uint32(2)) - - return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) -} - -func readTunnelAuthRequest(data []byte) { - buf := bytes.NewReader(data) - - var size uint16 - binary.Read(buf, binary.LittleEndian, &size) - clData := make([]byte, size) - binary.Read(buf, binary.LittleEndian, &clData) - clientName, _ := DecodeUTF16(clData) - log.Printf("Client: %s", clientName) -} - -func createTunnelAuthResponse() []byte { - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - - // flags - var redir uint32 - if conf.Caps.RedirectAll { - redir = HTTP_TUNNEL_REDIR_ENABLE_ALL - } else if conf.Caps.DisableRedirect { - redir = HTTP_TUNNEL_REDIR_DISABLE_ALL - } else { - if conf.Caps.DisableClipboard { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD - } - if conf.Caps.DisableDrive { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE - } - if conf.Caps.DisablePnp { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP - } - if conf.Caps.DisablePrinter { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER - } - if conf.Caps.DisablePort { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT - } - } - - // idle timeout - timeout := conf.Caps.IdleTimeout - if timeout < 0 { - timeout = 0 - } - - binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags - binary.Write(buf, binary.LittleEndian, uint32(timeout)) // timeout in minutes - - return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) -} - -func readChannelCreateRequest(data []byte) (server string, port uint16) { - buf := bytes.NewReader(data) - - var resourcesSize byte - var alternative byte - var protocol uint16 - var nameSize uint16 - - binary.Read(buf, binary.LittleEndian, &resourcesSize) - binary.Read(buf, binary.LittleEndian, &alternative) - binary.Read(buf, binary.LittleEndian, &port) - binary.Read(buf, binary.LittleEndian, &protocol) - binary.Read(buf, binary.LittleEndian, &nameSize) - - nameData := make([]byte, nameSize) - binary.Read(buf, binary.LittleEndian, &nameData) - - log.Printf("Name data %q", nameData) - server, _ = DecodeUTF16(nameData) - - log.Printf("Should connect to %s on port %d", server, port) - return -} - -func createChannelCreateResponse() []byte { - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - //binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present - binary.Write(buf, binary.LittleEndian, uint16(0)) // fields - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - - // optional fields - // channel id uint32 (4) - // udp port uint16 (2) - // udp auth cookie 1 byte for side channel - // length uint16 - - return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) -} - -func createPacket(pktType uint16, data []byte) (packet []byte) { - size := len(data) + 8 - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint16(pktType)) - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - binary.Write(buf, binary.LittleEndian, uint32(size)) - buf.Write(data) - - return buf.Bytes() -} - -func forwardDataPacket(conn net.Conn, data []byte) { - buf := bytes.NewReader(data) - - var cblen uint16 - binary.Read(buf, binary.LittleEndian, &cblen) - //log.Printf("Received PKT_DATA %d", cblen) - pkt := make([]byte, cblen) - binary.Read(buf, binary.LittleEndian, &pkt) - //n, _ := buf.Read(pkt) - //log.Printf("CBLEN: %d, N: %d", cblen, n) - //log.Printf("DATA FROM CLIENT %q", pkt) - conn.Write(pkt) -} - -func sendDataPacket(connIn net.Conn, connOut transport.Transport) { - defer connIn.Close() - b1 := new(bytes.Buffer) - buf := make([]byte, 4086) - for { - n, err := connIn.Read(buf) - binary.Write(b1, binary.LittleEndian, uint16(n)) - log.Printf("RDP SIZE: %d", n) - if err != nil { - log.Printf("Error reading from conn %s", err) - break - } - b1.Write(buf[:n]) - connOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) - b1.Reset() - } -} - -func DecodeUTF16(b []byte) (string, error) { - if len(b)%2 != 0 { - return "", fmt.Errorf("must have even length byte slice") - } - - u16s := make([]uint16, 1) - ret := &bytes.Buffer{} - b8buf := make([]byte, 4) - - lb := len(b) - for i := 0; i < lb; i += 2 { - u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8) - r := utf16.Decode(u16s) - n := utf8.EncodeRune(b8buf, r[0]) - ret.Write(b8buf[:n]) - } - - bret := ret.Bytes() - if len(bret) > 0 && bret[len(bret)-1] == '\x00' { - bret = bret[:len(bret)-1] - } - return string(bret), nil -} - -// UTF-16 endian byte order -const ( - unknownEndian = iota - bigEndian - littleEndian -) - -// dropCREndian drops a terminal \r from the endian data. -func dropCREndian(data []byte, t1, t2 byte) []byte { - if len(data) > 1 { - if data[len(data)-2] == t1 && data[len(data)-1] == t2 { - return data[0 : len(data)-2] - } - } - return data -} - -// dropCRBE drops a terminal \r from the big endian data. -func dropCRBE(data []byte) []byte { - return dropCREndian(data, '\x00', '\r') -} - -// dropCRLE drops a terminal \r from the little endian data. -func dropCRLE(data []byte) []byte { - return dropCREndian(data, '\r', '\x00') -} - -// dropCR drops a terminal \r from the data. -func dropCR(data []byte) ([]byte, int) { - var endian = unknownEndian - switch ld := len(data); { - case ld != len(dropCRLE(data)): - endian = littleEndian - case ld != len(dropCRBE(data)): - endian = bigEndian - } - return data, endian -} -- GitLab