diff --git a/protocol/handler.go b/protocol/handler.go index b48ee7f9ecb374b8d4625ee21afd64516891c8bd..07be70d999359becb26579c52a94353208b30c0a 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -6,10 +6,28 @@ import ( "errors" "github.com/bolkedebruin/rdpgw/transport" "io" + "log" + "net" + "strconv" + "time" ) +// When should the client disconnect when idle in minutes +var IdleTimeout = 0 + +type VerifyPAACookieFunc func(string) (bool, error) +type VerifyTunnelAuthFunc func(string) (bool, error) +type VerifyServerFunc func(string) (bool, error) + type Handler struct { - Transport transport.Transport + Transport transport.Transport + VerifyPAACookieFunc VerifyPAACookieFunc + VerifyTunnelAuthFunc VerifyTunnelAuthFunc + VerifyServerFunc VerifyServerFunc + SmartCardAuth bool + TokenAuth bool + ClientName string + Remote net.Conn } func NewHandler(t transport.Transport) *Handler { @@ -19,15 +37,73 @@ func NewHandler(t transport.Transport) *Handler { return h } -func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { +func (h *Handler) Process() error { + for { + pt, sz, pkt, err := h.ReadMessage() + if err != nil { + log.Printf("Cannot read message from stream %s", err) + return err + } + + switch pt { + case PKT_TYPE_HANDSHAKE_REQUEST: + major, minor, _, auth := readHandshake(pkt) + msg := h.handshakeResponse(major, minor, auth) + h.Transport.WritePacket(msg) + case PKT_TYPE_TUNNEL_CREATE: + _, cookie := readCreateTunnelRequest(pkt) + if h.VerifyPAACookieFunc != nil { + if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false { + log.Printf("Invalid PAA cookie: %s", cookie) + return errors.New("invalid PAA cookie") + } + } + msg := createTunnelResponse() + h.Transport.WritePacket(msg) + case PKT_TYPE_TUNNEL_AUTH: + h.readTunnelAuthRequest(pkt) + msg := h.createTunnelAuthResponse() + h.Transport.WritePacket(msg) + case PKT_TYPE_CHANNEL_CREATE: + server, port := readChannelCreateRequest(pkt) + log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) + h.Remote, err = net.DialTimeout( + "tcp", + net.JoinHostPort(server, strconv.Itoa(int(port))), + time.Second*15) + if err != nil { + log.Printf("Error connecting to %s, %d, %s", server, port, err) + return err + } + log.Printf("Connection established") + msg := createChannelCreateResponse() + h.Transport.WritePacket(msg) + + // Make sure to start the flow from the RDP server first otherwise connections + // might hang eventually + go h.sendDataPacket() + case PKT_TYPE_DATA: + h.forwardDataPacket(pkt) + case PKT_TYPE_KEEPALIVE: + // avoid concurrency issues + // p.Transport.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + h.Transport.Close() + default: + log.Printf("Unknown packet (size %d): %x", sz, pkt) + } + } +} + +func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { fragment := false index := 0 buf := make([]byte, 4096) for { - size, pkt, err := p.Transport.ReadPacket() + size, pkt, err := h.Transport.ReadPacket() if err != nil { - return 0, 0, []byte{0,0}, err + return 0, 0, []byte{0, 0}, err } // check for fragments @@ -48,7 +124,7 @@ func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) // header is corrupted even after defragmenting if err != nil { - return 0, 0, []byte{0,0}, err + return 0, 0, []byte{0, 0}, err } } if !fragment { @@ -57,6 +133,27 @@ func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { } } +// Creates a packet the is a response to a handshake request +// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux +// but could be in Windows. However the NTLM protocol is insecure +func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte { + var caps uint16 + if h.SmartCardAuth { + caps = caps | HTTP_EXTENDED_AUTH_PAA + } + if h.TokenAuth { + caps = caps | HTTP_EXTENDED_AUTH_PAA + } + + buf := new(bytes.Buffer) + binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code + buf.Write([]byte{major, minor}) + binary.Write(buf, binary.LittleEndian, uint16(0)) // server version + binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth + + return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) +} + func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { // header needs to be 8 min if len(data) < 8 { @@ -71,3 +168,188 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err } return packetType, size, data[8:], nil } + +func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &major) + binary.Read(r, binary.LittleEndian, &minor) + binary.Read(r, binary.LittleEndian, &version) + binary.Read(r, binary.LittleEndian, &extAuth) + + log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth) + return +} + +func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) { + var fields uint16 + + r := bytes.NewReader(data) + + binary.Read(r, binary.LittleEndian, &caps) + binary.Read(r, binary.LittleEndian, &fields) + r.Seek(2, io.SeekCurrent) + + if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE { + var size uint16 + binary.Read(r, binary.LittleEndian, &size) + cookieB := make([]byte, size) + r.Read(cookieB) + cookie, _ = DecodeUTF16(cookieB) + } + log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie) + return +} + +func createTunnelResponse() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(0)) // server version + binary.Write(buf, binary.LittleEndian, uint32(0)) // error code + binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + + // tunnel id ? + binary.Write(buf, binary.LittleEndian, uint32(15)) + // caps ? + binary.Write(buf, binary.LittleEndian, uint32(2)) + + return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) +} + +func (h *Handler) readTunnelAuthRequest(data []byte) { + buf := bytes.NewReader(data) + + var size uint16 + binary.Read(buf, binary.LittleEndian, &size) + clData := make([]byte, size) + binary.Read(buf, binary.LittleEndian, &clData) + clientName, _ := DecodeUTF16(clData) + log.Printf("Client: %s", clientName) +} + +func (h *Handler) createTunnelAuthResponse() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint32(0)) // error code + binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + + // flags + var redir uint32 + /* + if conf.Caps.RedirectAll { + redir = HTTP_TUNNEL_REDIR_ENABLE_ALL + } else if conf.Caps.DisableRedirect { + redir = HTTP_TUNNEL_REDIR_DISABLE_ALL + } else { + if conf.Caps.DisableClipboard { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD + } + if conf.Caps.DisableDrive { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE + } + if conf.Caps.DisablePnp { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP + } + if conf.Caps.DisablePrinter { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER + } + if conf.Caps.DisablePort { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT + } + } + */ + redir = HTTP_TUNNEL_REDIR_ENABLE_ALL + + // idle timeout + if IdleTimeout < 0 { + IdleTimeout = 0 + } + + binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes + + return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) +} + +func readChannelCreateRequest(data []byte) (server string, port uint16) { + buf := bytes.NewReader(data) + + var resourcesSize byte + var alternative byte + var protocol uint16 + var nameSize uint16 + + binary.Read(buf, binary.LittleEndian, &resourcesSize) + binary.Read(buf, binary.LittleEndian, &alternative) + binary.Read(buf, binary.LittleEndian, &port) + binary.Read(buf, binary.LittleEndian, &protocol) + binary.Read(buf, binary.LittleEndian, &nameSize) + + nameData := make([]byte, nameSize) + binary.Read(buf, binary.LittleEndian, &nameData) + + log.Printf("Name data %q", nameData) + server, _ = DecodeUTF16(nameData) + + log.Printf("Should connect to %s on port %d", server, port) + return +} + +func createChannelCreateResponse() []byte { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint32(0)) // error code + //binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // fields + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + + // optional fields + // channel id uint32 (4) + // udp port uint16 (2) + // udp auth cookie 1 byte for side channel + // length uint16 + + return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) +} + +func (h *Handler) forwardDataPacket(data []byte) { + buf := bytes.NewReader(data) + + var cblen uint16 + binary.Read(buf, binary.LittleEndian, &cblen) + pkt := make([]byte, cblen) + binary.Read(buf, binary.LittleEndian, &pkt) + + h.Remote.Write(pkt) +} + +func (h *Handler) sendDataPacket() { + defer h.Remote.Close() + b1 := new(bytes.Buffer) + buf := make([]byte, 4086) + for { + n, err := h.Remote.Read(buf) + binary.Write(b1, binary.LittleEndian, uint16(n)) + if err != nil { + log.Printf("Error reading from conn %s", err) + break + } + b1.Write(buf[:n]) + h.Transport.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + b1.Reset() + } +} + +func createPacket(pktType uint16, data []byte) (packet []byte) { + size := len(data) + 8 + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(pktType)) + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint32(size)) + buf.Write(data) + + return buf.Bytes() +} diff --git a/protocol/types.go b/protocol/types.go new file mode 100644 index 0000000000000000000000000000000000000000..c12487d6086da5c90ee8c3fe25a456fe17401ecb --- /dev/null +++ b/protocol/types.go @@ -0,0 +1,60 @@ +package protocol + +const ( + PKT_TYPE_HANDSHAKE_REQUEST = 0x1 + PKT_TYPE_HANDSHAKE_RESPONSE = 0x2 + PKT_TYPE_EXTENDED_AUTH_MSG = 0x3 + PKT_TYPE_TUNNEL_CREATE = 0x4 + PKT_TYPE_TUNNEL_RESPONSE = 0x5 + PKT_TYPE_TUNNEL_AUTH = 0x6 + PKT_TYPE_TUNNEL_AUTH_RESPONSE = 0x7 + PKT_TYPE_CHANNEL_CREATE = 0x8 + PKT_TYPE_CHANNEL_RESPONSE = 0x9 + PKT_TYPE_DATA = 0xA + PKT_TYPE_SERVICE_MESSAGE = 0xB + PKT_TYPE_REAUTH_MESSAGE = 0xC + PKT_TYPE_KEEPALIVE = 0xD + PKT_TYPE_CLOSE_CHANNEL = 0x10 + PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11 +) + +const ( + HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01 + HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02 + HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04 + HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10 +) + +const ( + HTTP_EXTENDED_AUTH_NONE = 0x0 + HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */ + HTTP_EXTENDED_AUTH_PAA = 0x02 /* Pluggable authentication. */ + HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */ +) + +const ( + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01 + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02 + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04 +) + +const ( + HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000 + HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000 + HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01 + HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02 + HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x03 + HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08 + HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10 +) + +const ( + HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01 + HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02 + HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04 +) + +const ( + HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 +) + diff --git a/protocol/utf16.go b/protocol/utf16.go new file mode 100644 index 0000000000000000000000000000000000000000..963dce1746c7b2a3a334f819abf0f3526dd1a263 --- /dev/null +++ b/protocol/utf16.go @@ -0,0 +1,32 @@ +package protocol + +import ( + "bytes" + "fmt" + "unicode/utf16" + "unicode/utf8" +) + +func DecodeUTF16(b []byte) (string, error) { + if len(b)%2 != 0 { + return "", fmt.Errorf("must have even length byte slice") + } + + u16s := make([]uint16, 1) + ret := &bytes.Buffer{} + b8buf := make([]byte, 4) + + lb := len(b) + for i := 0; i < lb; i += 2 { + u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8) + r := utf16.Decode(u16s) + n := utf8.EncodeRune(b8buf, r[0]) + ret.Write(b8buf[:n]) + } + + bret := ret.Bytes() + if len(bret) > 0 && bret[len(bret)-1] == '\x00' { + bret = bret[:len(bret)-1] + } + return string(bret), nil +} diff --git a/rdg.go b/rdg.go index e478329d1a1be4cd5ebaa583a8e9a26183536669..1417ed8b584e30631e6310e8735d9e0b22368e97 100644 --- a/rdg.go +++ b/rdg.go @@ -3,7 +3,6 @@ package main import ( "bytes" "encoding/binary" - "errors" "fmt" "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/transport" @@ -157,80 +156,12 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { } func handleWebsocketProtocol(c *websocket.Conn) { - var remote net.Conn - websocketConnections.Inc() defer websocketConnections.Dec() inout, _ := transport.NewWS(c) handler := protocol.NewHandler(inout) - - var host string - for { - pt, sz, pkt, err := handler.ReadMessage() - if err != nil { - log.Printf("Cannot read message from stream %s", err) - return - } - switch pt { - case PKT_TYPE_HANDSHAKE_REQUEST: - major, minor, _, auth := readHandshake(pkt) - msg := handshakeResponse(major, minor, auth) - log.Printf("Handshake response: %x", msg) - inout.WritePacket(msg) - case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(pkt) - /*data, found := tokens.Get(cookie) - if found == false { - log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr()) - return - }*/ - host = conf.Server.HostTemplate - /* - for k, v := range data.(map[string]interface{}) { - if val, ok := v.(string); ok == true { - host = strings.Replace(host, "{{ " + k + " }}", val, 1) - } - }*/ - msg := createTunnelResponse() - log.Printf("Create tunnel response: %x", msg) - inout.WritePacket(msg) - case PKT_TYPE_TUNNEL_AUTH: - readTunnelAuthRequest(pkt) - msg := createTunnelAuthResponse() - log.Printf("Create tunnel auth response: %x", msg) - inout.WritePacket(msg) - case PKT_TYPE_CHANNEL_CREATE: - server, port := readChannelCreateRequest(pkt) - if conf.Server.EnableOverride == true { - log.Printf("Override allowed") - host = net.JoinHostPort(server, strconv.Itoa(int(port))) - } - log.Printf("Establishing connection to RDP server: %s", host) - remote, err = net.DialTimeout( - "tcp", - host, - time.Second * 30) - if err != nil { - log.Printf("Error connecting to %s", host) - return - } - log.Printf("Connection established") - msg := createChannelCreateResponse() - log.Printf("Create channel create response: %x", msg) - inout.WritePacket(msg) - go sendDataPacket(remote, inout) - case PKT_TYPE_DATA: - forwardDataPacket(remote, pkt) - case PKT_TYPE_KEEPALIVE: - // do not write to make sure we do not create concurrency issues - // inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) - case PKT_TYPE_CLOSE_CHANNEL: - break - default: - log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt) - } - } + handler.Process() } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server @@ -537,24 +468,6 @@ func forwardDataPacket(conn net.Conn, data []byte) { conn.Write(pkt) } -func handleWebsocketData(rdp net.Conn, conn transport.Transport) { - defer rdp.Close() - b1 := new(bytes.Buffer) - buf := make([]byte, 4086) - - for { - n, err := rdp.Read(buf) - binary.Write(b1, binary.LittleEndian, uint16(n)) - if err != nil { - log.Printf("Error reading from conn %s", err) - break - } - b1.Write(buf[:n]) - conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) - b1.Reset() - } -} - func sendDataPacket(connIn net.Conn, connOut transport.Transport) { defer connIn.Close() b1 := new(bytes.Buffer)