From 80d11598ec188d7064164a014418e1e2e287259d Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Thu, 9 Jul 2020 10:15:27 +0200 Subject: [PATCH] Working websockets --- go.mod | 8 + main.go | 215 ++++---------------------- rdg.go | 456 +++++++++++++++++++++++++++++++++++++------------------- 3 files changed, 340 insertions(+), 339 deletions(-) create mode 100644 go.mod diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7b6b0d6 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/bolkedebruin/rdpgw + +go 1.14 + +require ( + github.com/gorilla/websocket v1.4.2 + github.com/patrickmn/go-cache v2.1.0+incompatible +) diff --git a/main.go b/main.go index 035182f..9cb7531 100644 --- a/main.go +++ b/main.go @@ -2,212 +2,51 @@ package main import ( "crypto/tls" - "net/http/httputil" - "os" - //"time" - - //"bytes" - "fmt" + "flag" "log" - //"strings" - // "io" "net/http" - //"net/http/httputil" - //"math/rand" - //"encoding/binary" - //"encoding/base64" + "os" + "strconv" ) +func main() { + port := flag.Int("port", 443, "port to listen on for incoming connections") + certFile := flag.String("certfile", "server.pem", "public key certificate file") + keyFile := flag.String("keyfile", "key.pem", "private key file") -/* -func handleConnection(s *MySession) { - inData := make([]byte, 4096) + flag.Parse() - for { - size, err := s.buffIn.Read(inData) - if err != nil { - s.inConn.Close() - s.outConn.Close() - fmt.Println(err) - } - fmt.Printf("Bytes read on IN %d\n", size) + if *certFile == "" || *keyFile == "" { + log.Fatal("Both certfile and keyfile need to be specified") } -}*/ - -/* -func MethodOverride(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println(r.Method) - dump, _ := httputil.DumpRequest(r, false) - fmt.Printf("%q\n", dump) - - headerKey := "Rdg-Connection-Id" - connId := r.Header.Get(headerKey) - if connId != "" { - s.guid = connId - } - auth := r.Header.Get("Authorization") - fmt.Printf("Connection ID: %s\n", s.guid) - - if strings.Contains(auth,"NTLMX") { - /*var msg_req_b []byte - base64.StdEncoding.Decode(msg_req_b, []byte(auth[strings.Index(auth,"NTLM")+6:])) - - msg_type := binary.LittleEndian.Uint32(msg_req_b[0:4]) - fmt.Printf("Message type %v\n", msg_type) - if msg_type == 1 { - var nonce [8]byte - r := make([]byte, 8) - rand.Read(r) - copy(nonce[:], r) - - sig_buf := new(bytes.Buffer) - var signature [8]byte - binary.Write(sig_buf, binary.LittleEndian, "NTLMSSP\000") - copy(signature[:], sig_buf.Bytes()) - zero := make([]byte, 7) - pad := make([]byte, 2) - - rand.Read(nonce) - - buf := new(bytes.Buffer) - msg := NtlmChallenge{ - signature, - uint32(0x02), - 0, - 0, - 0, - []byte(), - nonce, - 0, - 0 - } - _ := binary.Write(buf, binary.LittleEndian, msg) - header := "NTLM" + base64.StdEncoding.EncodeToString(buf.Bytes()) - w.Header().Set("WWW-Authenticate", header) - w.WriteHeader(401) - w.Write([]byte("Unauthorized.\n")) - fmt.Println("Unauthorized") - return - } - } else { - _, _, ok := r.BasicAuth() - - if !ok && !s.hasIn { - w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`) - w.WriteHeader(401) - w.Write([]byte("Unauthorized.\n")) - fmt.Println("Unauthorized") - return - } - } + //mux := http.NewServeMux() + //mux.HandleFunc("*", HelloServer) - if r.Method == "RDG_OUT_DATA" { - fmt.Println("Hijacking OUT") - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) - return - } - conn, bufrw, err := hj.Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.outConn = conn - s.buffOut = bufrw - - if !s.hasOut { - fmt.Printf("Creating OUT and sending seed\n") - s.hasOut = true - seed := make([]byte, 100) - rand.Read(seed) - bufrw.WriteString("HTTP/1.1 200 OK\r\n") - fmt.Fprintf(bufrw, "Date: %s\r\n", time.Now().Format(time.RFC1123)) - bufrw.WriteString("Content-Type: application/octet-stream\r\n") - bufrw.WriteString("Content-Length: 0\r\n") - bufrw.WriteString(crlf) - bufrw.Write(seed) - bufrw.Flush() - return - } else { - fmt.Printf("Handle OUT\n") - handleConnection(s) - return - } - } - - if r.Method == "RDG_IN_DATA" { - fmt.Println("Hijacking IN") - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) - return - } - conn, bufrw, err := hj.Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.inConn = conn - s.buffIn = bufrw - - if !s.hasIn { - fmt.Printf("Creating IN and sending seed\n") - s.hasIn = true - seed := make([]byte, 100) - rand.Read(seed) - bufrw.WriteString("HTTP/1.1 200 OK\r\n") - fmt.Fprintf(bufrw, "Date: %s\r\n", time.Now().Format(time.RFC1123)) - bufrw.WriteString("Content-Type: application/octet-stream\r\n") - bufrw.WriteString("Content-Length: 0\r\n") - bufrw.WriteString(crlf) - bufrw.Write(seed) - bufrw.Flush() - return - } else { - fmt.Printf("Handle IN\n") - - handleConnection(s) - return - } + log.Printf("Starting remote desktop gateway server") + cfg := &tls.Config{} + tlsDebug := os.Getenv("SSLKEYLOGFILE") + if tlsDebug != "" { + w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err) } - - next.ServeHTTP(w, r) - }) -} -*/ - -func HelloServer(w http.ResponseWriter, req *http.Request) { - dump, _ := httputil.DumpRequest(req, true) - fmt.Println(dump) - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("This is an example server.\n")) - // io.WriteString(w, "This is an example server.\n") -} - -func main() { - fmt.Println("Hello!") - mux := http.NewServeMux() - mux.HandleFunc("*", HelloServer) - - w, err := os.OpenFile("tls-secrets.txt", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - cfg := &tls.Config{ - KeyLogWriter: w, + log.Printf("Key log file set to: %s", tlsDebug) + cfg.KeyLogWriter = w } - cert, err := tls.LoadX509KeyPair("server.pem", "key.pem") + cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatal(err) } cfg.Certificates = append(cfg.Certificates, cert) server := http.Server{ - Addr: ":8000", - Handler: Upgrade(mux), + Addr: ":" + strconv.Itoa(*port), + Handler: Upgrade(nil), TLSConfig: cfg, } - err = server.ListenAndServeTLS("","") + + err = server.ListenAndServeTLS("", "") if err != nil { log.Fatal("ListenAndServe: ", err) } -} \ No newline at end of file +} diff --git a/rdg.go b/rdg.go index 2d4ebcd..40b018a 100644 --- a/rdg.go +++ b/rdg.go @@ -4,22 +4,25 @@ import ( "bufio" "bytes" "encoding/binary" + "errors" "fmt" + "github.com/patrickmn/go-cache" "io" "log" "math/rand" "net" "net/http" + //"net/http/httputil" "strconv" "time" "unicode/utf16" "unicode/utf8" + "github.com/gorilla/websocket" ) const ( - crlf = "\r\n" + crlf = "\r\n" rdgConnectionIdKey = "Rdg-Connection-Id" - HANDSHAKE = 1 ) const ( @@ -40,6 +43,13 @@ const ( PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11 ) +const ( + HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01 + HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02 + HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04 + HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10 +) + const ( HTTP_EXTENDED_AUTH_NONE = 0x0 HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */ @@ -47,6 +57,28 @@ const ( HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */ ) +const ( + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01 + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02 + HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04 +) + +const ( + HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000 + HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000 + HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01 + HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02 + HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x03 + HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08 + HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10 +) + +const ( + HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01 + HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02 + HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04 +) + const ( HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 ) @@ -63,10 +95,9 @@ type RdgSession struct { UserId string ConnIn net.Conn ConnOut net.Conn - BufOut *bufio.Writer - BufIn *bufio.Reader - State int - Remote net.Conn + StateIn int + StateOut int + Remote net.Conn } // ErrNotHijacker is an error returned when http.ResponseWriter does not @@ -79,89 +110,196 @@ var ErrNotHijacker = RejectConnectionError( var DefaultSession RdgSession func Upgrade(next http.Handler) http.Handler { - return DefaultSession.RdgHandshake(next) + return RdgHandshake(next) } func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { - log.Print("Accept connection") - hj, ok := w.(http.Hijacker) - if ok { - return hj.Hijack() - } else { - err = ErrNotHijacker - } - if err != nil { - httpError(w, err.Error(), http.StatusInternalServerError) - return nil, nil, err - } - return + log.Print("Accept connection") + hj, ok := w.(http.Hijacker) + if ok { + return hj.Hijack() + } else { + err = ErrNotHijacker + } + if err != nil { + httpError(w, err.Error(), http.StatusInternalServerError) + return nil, nil, err + } + return } -func (s RdgSession) RdgHandshake(next http.Handler) http.Handler { +var upgrader = websocket.Upgrader{} + +func RdgHandshake(next http.Handler) http.Handler { + c := cache.New(5*time.Minute, 10*time.Minute) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - /*_, _, ok := r.BasicAuth() + var s RdgSession - if !ok && s.ConnIn == nil { - w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`) - w.WriteHeader(401) - w.Write([]byte("Unauthorized.\n")) - fmt.Println("Unauthorized") - return - }*/ + connId := r.Header.Get(rdgConnectionIdKey) + x, found := c.Get(connId) + if !found { + log.Printf("No cached session found") + s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0} + } else { + log.Printf("Found cached session") + s = x.(RdgSession) + } + + log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil) - conn, rw, _ := Accept(w) if r.Method == MethodRDGOUT { + r.Method = "GET" // force + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Cannot upgrade falling back to old protocol: %s", err) + return + } + defer c.Close() + + fragment := false + buf := make([]byte, 4096) + index := 0 + for { + mt, msg, err := c.ReadMessage() + if err != nil { + log.Printf("Error read: %s", err) + break + } + log.Printf("Message type: %d, message: %x", mt, msg) + + // check for fragments + var pt uint16 + var sz uint32 + var pkt []byte + + if !fragment { + pt, sz, pkt, err = readHeader(msg) + if err != nil { + // fragment received + log.Printf("Received non websocket fragment") + fragment = true + index = copy(buf, msg) + continue + } + index = 0 + } else { + log.Printf("Dealing with fragment") + fragment = false + pt, sz, pkt, _ = readHeader(append(buf[:index], msg...)) + } + + switch pt { + case PKT_TYPE_HANDSHAKE_REQUEST: + major, minor, _, auth := readHandshake(pkt) + msg := handshakeResponse(major, minor, auth) + log.Printf("Handshake response: %x", msg) + c.WriteMessage(mt, msg) + case PKT_TYPE_TUNNEL_CREATE: + readCreateTunnelRequest(pkt) + msg := createTunnelResponse() + log.Printf("Create tunnel response: %x", msg) + c.WriteMessage(mt, msg) + case PKT_TYPE_TUNNEL_AUTH: + readTunnelAuthRequest(pkt) + msg := createTunnelAuthResponse() + log.Printf("Create tunnel auth response: %x", msg) + c.WriteMessage(mt, msg) + case PKT_TYPE_CHANNEL_CREATE: + server, port := readChannelCreateRequest(pkt) + s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) + if err != nil { + log.Printf("Error connecting to %s, %d, %s", server, port, err) + return + } + msg := createChannelCreateResponse() + log.Printf("Create channel create response: %x", msg) + c.WriteMessage(mt, msg) + go handleWebsocketData(s.Remote, mt, c) + case PKT_TYPE_DATA: + forwardDataPacket(s.Remote, pkt) + case PKT_TYPE_KEEPALIVE: + c.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + s.Remote.Close() + return + default: + log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz) + } + } + conn, rw, _ := Accept(w) log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String()) - s.ConnId = r.Header.Get(rdgConnectionIdKey) + s.ConnOut = conn - s.BufOut = rw.Writer - WriteAcceptSeed(rw.Writer) - rw.Writer.Flush() - } else if r.Method == MethodRDGIN { - if s.ConnIn == nil { - defer conn.Close() - s.ConnIn = conn - s.BufIn = rw.Reader - log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String()) - WriteAcceptSeed(rw.Writer) - rw.Writer.Flush() - p := make([]byte, 4096) - rw.Reader.Read(p) - //log.Printf("Read %q", p) - - log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) - scanner := bufio.NewScanner(rw.Reader) - scanner.Split(ReadPacket) - for scanner.Scan() { - packet := scanner.Bytes() - packetType, size, _, packet := readHeader(packet) - log.Printf("Scanned packet got packet type %x size %d", packetType, size) - switch packetType { - case PKT_TYPE_HANDSHAKE_REQUEST: - major, minor, _, auth := readHandshake(packet) - sendHandshakeResponse(s.BufOut, major, minor, auth) - case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(packet) - sendCreateTunnelResponse(s.BufOut) - case PKT_TYPE_TUNNEL_AUTH: - readTunnelAuthRequest(packet) - sendTunnelAuthResponse(s.BufOut) - case PKT_TYPE_CHANNEL_CREATE: - server, port := readChannelCreateRequest(packet) - var err error - s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) - if err != nil { - log.Printf("Error connecting to %s, %d, %s", server, port, err) - return - } - sendChannelCreateResponse(s.BufOut) - go sendDataPacket(s.Remote, s.BufOut) - case PKT_TYPE_DATA: - receiveDataPacket(s.Remote, packet) + WriteAcceptSeed(rw.Writer, true) + + //c.Set(connId, s, cache.DefaultExpiration) + } /*else if r.Method == MethodRDGIN { + if !checkNTLMAuth(w, &s, "IN") { + c.Set(connId, s, cache.DefaultExpiration) + return + } + conn, rw, _ := Accept(w) + + if s.ConnIn == nil { + defer conn.Close() + s.ConnIn = conn + c.Set(connId, s, cache.DefaultExpiration) + log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String()) + WriteAcceptSeed(rw.Writer, false) + p := make([]byte, 32767) + rw.Reader.Read(p) + //log.Printf("Read %q", p) + + log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) + chunkScanner := httputil.NewChunkedReader(rw.Reader) + packet := make([]byte, 4096) // bufio.defaultBufSize + + for { + n, err := chunkScanner.Read(packet) + if err == io.EOF || n == 0 { + break + } + old_packet := packet + packetType, size, _, packet := readHeader(packet) + log.Printf("Scanned packet got packet type %x size %d", packetType, size) + switch packetType { + case PKT_TYPE_HANDSHAKE_REQUEST: + major, minor, _, auth := readHandshake(packet) + sendHandshakeResponse(s.ConnOut, major, minor, auth) + case PKT_TYPE_TUNNEL_CREATE: + readCreateTunnelRequest(packet) + sendCreateTunnelResponse(s.ConnOut) + case PKT_TYPE_TUNNEL_AUTH: + readTunnelAuthRequest(packet) + sendTunnelAuthResponse(s.ConnOut) + case PKT_TYPE_CHANNEL_CREATE: + server, port := readChannelCreateRequest(packet) + var err error + s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) + if err != nil { + log.Printf("Error connecting to %s, %d, %s", server, port, err) + return } + sendChannelCreateResponse(s.ConnOut) + // Make sure to start the flow from the RDP server first otherwise connections + // might hang eventually + go sendDataPacket(s.Remote, s.ConnOut) + case PKT_TYPE_DATA: + receiveDataPacket(s.Remote, packet) + case PKT_TYPE_KEEPALIVE: + s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + s.ConnIn.Close() + s.ConnOut.Close() + break + default: + log.Printf("UNKNOWN PACKET (%d): %x", n, old_packet[:n]) + //receiveDataPacket(s.Remote, old_packet) + receiveUnknownPacket(s.Remote, old_packet, n) } } - } + }*/ }) } @@ -170,58 +308,50 @@ func (s RdgSession) RdgHandshake(next http.Handler) http.Handler { // This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does // not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the // connection after the last byte is sent. -func WriteAcceptSeed(bw *bufio.Writer) { +func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) { bw.WriteString(HttpOK) + bw.WriteString("Server: Microsoft-HTTPAPI/2.0\r\n") bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - bw.WriteString("Content-Type: application/octet-stream\r\n") - bw.WriteString("Content-Length: 0\r\n") - bw.WriteString(crlf) - seed := make([]byte, 10) - rand.Read(seed) - bw.Write(seed) -} - -func ReadPacket(data []byte, atEOF bool) (advance int, packet []byte, err error) { - log.Printf("Reading data len = %d", len(data)) - if atEOF && len(data) == 0 { - return 0, nil, nil - } - - if i := bytes.Index(data, []byte{'\r', '\n'}); i >= 0 { - //log.Printf("Got rn at %d ", i) - chunkSize, err := strconv.ParseInt(string(data[0:i]), 16, 0) - log.Printf("chunkSize %d", chunkSize) - if err != nil { - return i + 2, data[0:i], err - } - //log.Printf("Return %d", i+2+int(chunkSize)+2) - return i + 2 + int(chunkSize) + 2, data[i+2 : i+2+int(chunkSize)+2], nil + if !doSeed { + bw.WriteString("Content-Length: 0\r\n") } + bw.WriteString(crlf) - if atEOF { - return len(data), data, nil + if doSeed { + seed := make([]byte, 10) + rand.Read(seed) + // docs say it's a seed but 2019 responds with ab cd * 5 + bw.Write(seed) } - - return 0, nil, nil + bw.Flush() } -func readHeader(data []byte) (packetType uint16, size uint32, advance int, remain []byte) { +func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { + // header needs to be 8 min + if len(data) < 8 { + return 0, 0, nil, errors.New("header too short, fragment likely") + } r := bytes.NewReader(data) binary.Read(r, binary.LittleEndian, &packetType) r.Seek(4, io.SeekStart) binary.Read(r, binary.LittleEndian, &size) - return packetType, size, 8, data[8:] + if len(data) < int(size) { + return packetType, size, data[8:], errors.New("data incomplete, fragment received") + } + return packetType, size, data[8:], nil } -func sendHandshakeResponse(w *bufio.Writer, major byte, minor byte, auth uint16) { +// Creates a packet the is a response to a handshake request +// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux +// but could be in Windows. However the NTLM protocol is insecure +func handshakeResponse(major byte, minor byte, auth uint16) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code buf.Write([]byte{major, minor}) - binary.Write(buf, binary.LittleEndian, uint16(0)) // server version - binary.Write(buf, binary.LittleEndian, uint16(2)) // PAA + binary.Write(buf, binary.LittleEndian, uint16(0)) // server version + binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth - w.Write(createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())) - w.Flush() + return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) } func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { @@ -235,7 +365,7 @@ func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth return } -func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){ +func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) { var fields uint16 r := bytes.NewReader(data) @@ -255,16 +385,21 @@ func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){ return } -func sendCreateTunnelResponse(w *bufio.Writer) { +func createTunnelResponse() []byte { buf := new(bytes.Buffer) - binary.Write(buf, binary.LittleEndian, uint16(0)) // server version - binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint16(0)) // server version + binary.Write(buf, binary.LittleEndian, uint32(0)) // error code + binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - w.Write(createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())) - w.Flush() + // tunnel id ? + binary.Write(buf, binary.LittleEndian, uint32(15)) + // caps ? + binary.Write(buf, binary.LittleEndian, uint32(2)) + + return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) } func readTunnelAuthRequest(data []byte) { @@ -278,18 +413,21 @@ func readTunnelAuthRequest(data []byte) { log.Printf("Client: %s", clientName) } -func sendTunnelAuthResponse(w *bufio.Writer) { +func createTunnelAuthResponse() []byte { buf := new(bytes.Buffer) - binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint32(0)) // error code + binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + + // flags + binary.Write(buf, binary.LittleEndian, uint32(HTTP_TUNNEL_REDIR_ENABLE_ALL)) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(0)) // timeout in minutes - w.Write(createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())) - w.Flush() + return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) } -func readChannelCreateRequest(data []byte) (server string, port uint16){ +func readChannelCreateRequest(data []byte) (server string, port uint16) { buf := bytes.NewReader(data) var resourcesSize byte @@ -313,55 +451,55 @@ func readChannelCreateRequest(data []byte) (server string, port uint16){ return } -func sendChannelCreateResponse(w *bufio.Writer) { +func createChannelCreateResponse() []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(0)) // error code - binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present + //binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present + binary.Write(buf, binary.LittleEndian, uint16(0)) // fields binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - w.Write(createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())) - w.Flush() + // optional fields + // channel id uint32 (4) + // udp port uint16 (2) + // udp auth cookie 1 byte for side channel + // length uint16 + + return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) } -func createPacket(pktType uint16, data []byte) (packet []byte){ +func createPacket(pktType uint16, data []byte) (packet []byte) { size := len(data) + 8 buf := new(bytes.Buffer) - log.Printf("Data sent Size: %d", size) - // http chunk size in hex string - // fmt.Fprintf(buf,"%x\r\n", size) - binary.Write(buf, binary.LittleEndian, uint16(pktType)) - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved binary.Write(buf, binary.LittleEndian, uint32(size)) buf.Write(data) - // http close crlf - // buf.Write([]byte(crlf)) - // log.Printf("data sent: %q", buf.Bytes()) return buf.Bytes() } -func receiveDataPacket(conn net.Conn, data []byte) { +func forwardDataPacket(conn net.Conn, data []byte) { buf := bytes.NewReader(data) var cblen uint16 binary.Read(buf, binary.LittleEndian, &cblen) - log.Printf("Received PKT_DATA %d", cblen) + //log.Printf("Received PKT_DATA %d", cblen) pkt := make([]byte, cblen) - //binary.Read(buf, binary.LittleEndian, &pkt) - buf.Read(pkt) + binary.Read(buf, binary.LittleEndian, &pkt) + //n, _ := buf.Read(pkt) + //log.Printf("CBLEN: %d, N: %d", cblen, n) //log.Printf("DATA FROM CLIENT %q", pkt) conn.Write(pkt) } -func sendDataPacket(conn net.Conn, w *bufio.Writer) { - defer conn.Close() +func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) { + defer rdp.Close() b1 := new(bytes.Buffer) - buf := make([]byte, 32767) + buf := make([]byte, 4086) for { - n, err := conn.Read(buf) + n, err := rdp.Read(buf) binary.Write(b1, binary.LittleEndian, uint16(n)) log.Printf("RDP SIZE: %d", n) if err != nil { @@ -369,16 +507,32 @@ func sendDataPacket(conn net.Conn, w *bufio.Writer) { break } b1.Write(buf[:n]) - w.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())) - w.Flush() + conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes())) + b1.Reset() + } +} + +func sendDataPacket(connIn net.Conn, connOut net.Conn) { + defer connIn.Close() + b1 := new(bytes.Buffer) + buf := make([]byte, 4086) + for { + n, err := connIn.Read(buf) + binary.Write(b1, binary.LittleEndian, uint16(n)) + log.Printf("RDP SIZE: %d", n) + if err != nil { + log.Printf("Error reading from conn %s", err) + break + } + b1.Write(buf[:n]) + connOut.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } func DecodeUTF16(b []byte) (string, error) { if len(b)%2 != 0 { - log.Printf("Error decoding utf16") - return "", fmt.Errorf("Must have even length byte slice") + return "", fmt.Errorf("must have even length byte slice") } u16s := make([]uint16, 1) @@ -394,4 +548,4 @@ func DecodeUTF16(b []byte) (string, error) { } return ret.String(), nil -} \ No newline at end of file +} -- GitLab