diff --git a/rdg.go b/rdg.go index b4984aa51ea03030b58aa8198d4a8a8e84cdc45e..c1f57aef76d39c249201c3638e656217bac8ec52 100644 --- a/rdg.go +++ b/rdg.go @@ -154,10 +154,10 @@ var c = cache.New(5*time.Minute, 10*time.Minute) func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { connectionCache.Set(float64(c.ItemCount())) if r.Method == MethodRDGOUT { - if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { + //if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { handleLegacyProtocol(w, r) return - } + //} r.Method = "GET" // force conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -270,7 +270,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil) if r.Method == MethodRDGOUT { - conn, rw, _ := Accept(w) + conn, rw, err := Accept(w) + if err != nil { + log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) + return + } log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String()) s.ConnOut = conn @@ -283,10 +287,18 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { var remote net.Conn - conn, rw, _ := Accept(w) + conn, rw, err := Accept(w) + if err != nil { + log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) + return + } + defer conn.Close() if s.ConnIn == nil { - defer conn.Close() + fragment := false + index := 0 + buf := make([]byte, 4096) + s.ConnIn = conn c.Set(connId, s, cache.DefaultExpiration) log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String()) @@ -296,33 +308,51 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) chunkScanner := httputil.NewChunkedReader(rw.Reader) - packet := make([]byte, 4096) // bufio.defaultBufSize + msg := make([]byte, 4096) // bufio.defaultBufSize for { - n, err := chunkScanner.Read(packet) + n, err := chunkScanner.Read(msg) if err == io.EOF || n == 0 { break } - packetType, size, packet, err := readHeader(packet) - if err != nil { - log.Printf("Need to deal with fragment %s", err) + + // check for fragments + var pt uint16 + var sz uint32 + var pkt []byte + + if !fragment { + pt, sz, pkt, err = readHeader(msg[:n]) + if err != nil { + // fragment received + log.Printf("Received non websocket fragment") + fragment = true + index = copy(buf, msg[:n]) + continue + } + index = 0 + } else { + log.Printf("Dealing with fragment") + fragment = false + pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...)) } - log.Printf("Scanned packet got packet type %x size %d", packetType, size) - switch packetType { + + log.Printf("Scanned packet got packet type %x size %d", pt, sz) + switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - major, minor, _, auth := readHandshake(packet) + major, minor, _, auth := readHandshake(pkt) msg := handshakeResponse(major, minor, auth) s.ConnOut.Write(msg) case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(packet) + readCreateTunnelRequest(pkt) msg := createTunnelResponse() s.ConnOut.Write(msg) case PKT_TYPE_TUNNEL_AUTH: - readTunnelAuthRequest(packet) + readTunnelAuthRequest(pkt) msg := createTunnelAuthResponse() s.ConnOut.Write(msg) case PKT_TYPE_CHANNEL_CREATE: - server, port := readChannelCreateRequest(packet) + server, port := readChannelCreateRequest(pkt) var err error remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port)))) if err != nil { @@ -336,7 +366,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { // might hang eventually go sendDataPacket(remote, s.ConnOut) case PKT_TYPE_DATA: - forwardDataPacket(remote, packet) + forwardDataPacket(remote, pkt) case PKT_TYPE_KEEPALIVE: // avoid concurrency issues // s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) @@ -345,7 +375,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { s.ConnOut.Close() break default: - log.Printf("Unknown packet (size %d): %x", n, packet) + log.Printf("Unknown packet (size %d): %x", sz, pkt[:n]) } } }