diff --git a/rdg.go b/rdg.go index 94326450784eff21b6802aeb8c12b64b2b1a291c..c70c16f1d4a1394978227d6f1160d45244ddc6ff 100644 --- a/rdg.go +++ b/rdg.go @@ -13,6 +13,7 @@ import ( "math/rand" "net" "net/http" + "net/http/httputil" "strconv" "time" @@ -133,21 +134,11 @@ var c = cache.New(5*time.Minute, 10*time.Minute) func handleGatewayProtocol(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var s RdgSession - - connId := r.Header.Get(rdgConnectionIdKey) - x, found := c.Get(connId) - if !found { - log.Printf("No cached session found") - s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0} - } else { - log.Printf("Found cached session") - s = x.(RdgSession) - } - - log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil) - if r.Method == MethodRDGOUT { + if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { + handleLegacyProtocol(w, r) + return + } r.Method = "GET" // force conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -157,81 +148,9 @@ func handleGatewayProtocol(next http.Handler) http.Handler { defer conn.Close() handleWebsocketProtocol(conn) - - - //conn, rw, _ := Accept(w) - //log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String()) - - //s.ConnOut = conn - //WriteAcceptSeed(rw.Writer, true) - - //c.Set(connId, s, cache.DefaultExpiration) - } /*else if r.Method == MethodRDGIN { - if !checkNTLMAuth(w, &s, "IN") { - c.Set(connId, s, cache.DefaultExpiration) - return + } else if r.Method == MethodRDGIN { + handleLegacyProtocol(w, r) } - conn, rw, _ := Accept(w) - - if s.ConnIn == nil { - defer conn.Close() - s.ConnIn = conn - c.Set(connId, s, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String()) - WriteAcceptSeed(rw.Writer, false) - p := make([]byte, 32767) - rw.Reader.Read(p) - //log.Printf("Read %q", p) - - log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) - chunkScanner := httputil.NewChunkedReader(rw.Reader) - packet := make([]byte, 4096) // bufio.defaultBufSize - - for { - n, err := chunkScanner.Read(packet) - if err == io.EOF || n == 0 { - break - } - old_packet := packet - packetType, size, _, packet := readHeader(packet) - log.Printf("Scanned packet got packet type %x size %d", packetType, size) - switch packetType { - case PKT_TYPE_HANDSHAKE_REQUEST: - major, minor, _, auth := readHandshake(packet) - sendHandshakeResponse(s.ConnOut, major, minor, auth) - case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(packet) - sendCreateTunnelResponse(s.ConnOut) - case PKT_TYPE_TUNNEL_AUTH: - readTunnelAuthRequest(packet) - sendTunnelAuthResponse(s.ConnOut) - case PKT_TYPE_CHANNEL_CREATE: - server, port := readChannelCreateRequest(packet) - var err error - s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) - if err != nil { - log.Printf("Error connecting to %s, %d, %s", server, port, err) - return - } - sendChannelCreateResponse(s.ConnOut) - // Make sure to start the flow from the RDP server first otherwise connections - // might hang eventually - go sendDataPacket(s.Remote, s.ConnOut) - case PKT_TYPE_DATA: - receiveDataPacket(s.Remote, packet) - case PKT_TYPE_KEEPALIVE: - s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) - case PKT_TYPE_CLOSE_CHANNEL: - s.ConnIn.Close() - s.ConnOut.Close() - break - default: - log.Printf("UNKNOWN PACKET (%d): %x", n, old_packet[:n]) - //receiveDataPacket(s.Remote, old_packet) - receiveUnknownPacket(s.Remote, old_packet, n) - } - } - }*/ }) } @@ -311,17 +230,115 @@ func handleWebsocketProtocol(conn *websocket.Conn) { } } +// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server +// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different +// to ensure the connections do not get cached or terminated by a proxy prematurely. +func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { + var s RdgSession + + connId := r.Header.Get(rdgConnectionIdKey) + x, found := c.Get(connId) + if !found { + log.Printf("No cached session found") + s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0} + } else { + log.Printf("Found cached session") + s = x.(RdgSession) + } + + log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil) + + if r.Method == MethodRDGOUT { + conn, rw, _ := Accept(w) + log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String()) + + s.ConnOut = conn + WriteAcceptSeed(rw.Writer, true) + + c.Set(connId, s, cache.DefaultExpiration) + } else if r.Method == MethodRDGIN { + var remote net.Conn + + conn, rw, _ := Accept(w) + + if s.ConnIn == nil { + defer conn.Close() + s.ConnIn = conn + c.Set(connId, s, cache.DefaultExpiration) + log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String()) + WriteAcceptSeed(rw.Writer, false) + p := make([]byte, 32767) + rw.Reader.Read(p) + + log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) + chunkScanner := httputil.NewChunkedReader(rw.Reader) + packet := make([]byte, 4096) // bufio.defaultBufSize + + for { + n, err := chunkScanner.Read(packet) + if err == io.EOF || n == 0 { + break + } + packetType, size, packet, err := readHeader(packet) + if err != nil { + log.Printf("Need to deal with fragment %s", err) + } + log.Printf("Scanned packet got packet type %x size %d", packetType, size) + switch packetType { + case PKT_TYPE_HANDSHAKE_REQUEST: + major, minor, _, auth := readHandshake(packet) + msg := handshakeResponse(major, minor, auth) + s.ConnOut.Write(msg) + case PKT_TYPE_TUNNEL_CREATE: + readCreateTunnelRequest(packet) + msg := createTunnelResponse() + s.ConnOut.Write(msg) + case PKT_TYPE_TUNNEL_AUTH: + readTunnelAuthRequest(packet) + msg := createTunnelAuthResponse() + s.ConnOut.Write(msg) + case PKT_TYPE_CHANNEL_CREATE: + server, port := readChannelCreateRequest(packet) + var err error + remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) + if err != nil { + log.Printf("Error connecting to %s, %d, %s", server, port, err) + return + } + msg := createChannelCreateResponse() + s.ConnOut.Write(msg) + + // Make sure to start the flow from the RDP server first otherwise connections + // might hang eventually + go sendDataPacket(remote, s.ConnOut) + case PKT_TYPE_DATA: + forwardDataPacket(remote, packet) + case PKT_TYPE_KEEPALIVE: + s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + s.ConnIn.Close() + s.ConnOut.Close() + remote.Close() + break + default: + log.Printf("UNKNOWN PACKET (%d): %x", n, packet) + } + } + } + } +} + // [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0 // The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes). // This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does // not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the // connection after the last byte is sent. func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) { + log.Printf("Writing accept") bw.WriteString(HttpOK) - bw.WriteString("Server: Microsoft-HTTPAPI/2.0\r\n") - bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") + bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf) if !doSeed { - bw.WriteString("Content-Length: 0\r\n") + bw.WriteString("Content-Length: 0" + crlf) } bw.WriteString(crlf)