From 2f78a7fd8e8c2d86fcb12330fea42c2de832a1a4 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Mon, 20 Jul 2020 14:29:24 +0200 Subject: [PATCH] Normalize packet handling --- protocol/handler.go | 73 +++++++++++++++++++++++++++++++++++++++ rdg.go | 83 ++++++--------------------------------------- 2 files changed, 84 insertions(+), 72 deletions(-) create mode 100644 protocol/handler.go diff --git a/protocol/handler.go b/protocol/handler.go new file mode 100644 index 0000000..b48ee7f --- /dev/null +++ b/protocol/handler.go @@ -0,0 +1,73 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "github.com/bolkedebruin/rdpgw/transport" + "io" +) + +type Handler struct { + Transport transport.Transport +} + +func NewHandler(t transport.Transport) *Handler { + h := &Handler{ + Transport: t, + } + return h +} + +func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { + fragment := false + index := 0 + buf := make([]byte, 4096) + + for { + size, pkt, err := p.Transport.ReadPacket() + if err != nil { + return 0, 0, []byte{0,0}, err + } + + // check for fragments + var pt uint16 + var sz uint32 + var msg []byte + + if !fragment { + pt, sz, msg, err = readHeader(pkt[:size]) + if err != nil { + fragment = true + index = copy(buf, pkt[:size]) + continue + } + index = 0 + } else { + fragment = false + pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) + // header is corrupted even after defragmenting + if err != nil { + return 0, 0, []byte{0,0}, err + } + } + if !fragment { + return int(pt), int(sz), msg, nil + } + } +} + +func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { + // header needs to be 8 min + if len(data) < 8 { + return 0, 0, nil, errors.New("header too short, fragment likely") + } + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &packetType) + r.Seek(4, io.SeekStart) + binary.Read(r, binary.LittleEndian, &size) + if len(data) < int(size) { + return packetType, size, data[8:], errors.New("data incomplete, fragment received") + } + return packetType, size, data[8:], nil +} diff --git a/rdg.go b/rdg.go index 0cd3094..e478329 100644 --- a/rdg.go +++ b/rdg.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" @@ -156,46 +157,21 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { } func handleWebsocketProtocol(c *websocket.Conn) { - fragment := false - buf := make([]byte, 4096) - index := 0 - var remote net.Conn websocketConnections.Inc() defer websocketConnections.Dec() inout, _ := transport.NewWS(c) + handler := protocol.NewHandler(inout) var host string for { - _, msg, err := inout.ReadPacket() + pt, sz, pkt, err := handler.ReadMessage() if err != nil { - log.Printf("Error read: %s", err) - break - } - - // check for fragments - var pt uint16 - var sz uint32 - var pkt []byte - - if !fragment { - pt, sz, pkt, err = readHeader(msg) - if err != nil { - // fragment received - // log.Printf("Received non websocket fragment") - fragment = true - index = copy(buf, msg) - continue - } - index = 0 - } else { - //log.Printf("Dealing with fragment") - fragment = false - pt, sz, pkt, _ = readHeader(append(buf[:index], msg...)) + log.Printf("Cannot read message from stream %s", err) + return } - switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: major, minor, _, auth := readHandshake(pkt) @@ -301,10 +277,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { defer in.Close() if s.TransportIn == nil { - fragment := false - index := 0 - buf := make([]byte, 4096) - s.TransportIn = in c.Set(connId, s, cache.DefaultExpiration) @@ -315,30 +287,12 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { in.Drain() log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String()) - + handler := protocol.NewHandler(in) for { - n, msg, err := in.ReadPacket() - if err == io.EOF || n == 0 { - break - } - - // check for fragments - var pt uint16 - var sz uint32 - var pkt []byte - - if !fragment { - pt, sz, pkt, err = readHeader(msg[:n]) - if err != nil { - // fragment received - fragment = true - index = copy(buf, msg[:n]) - continue - } - index = 0 - } else { - fragment = false - pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...)) + pt, sz, pkt, err := handler.ReadMessage() + if err != nil { + log.Printf("Cannot read message from stream %s", err) + return } switch pt { @@ -386,28 +340,13 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { s.TransportOut.Close() break default: - log.Printf("Unknown packet (size %d): %x", sz, pkt[:n]) + log.Printf("Unknown packet (size %d): %x", sz, pkt) } } } } } -func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { - // header needs to be 8 min - if len(data) < 8 { - return 0, 0, nil, errors.New("header too short, fragment likely") - } - r := bytes.NewReader(data) - binary.Read(r, binary.LittleEndian, &packetType) - r.Seek(4, io.SeekStart) - binary.Read(r, binary.LittleEndian, &size) - if len(data) < int(size) { - return packetType, size, data[8:], errors.New("data incomplete, fragment received") - } - return packetType, size, data[8:], nil -} - // Creates a packet the is a response to a handshake request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // but could be in Windows. However the NTLM protocol is insecure -- GitLab