Skip to content
Snippets Groups Projects
Commit 2f78a7fd authored by Bolke de Bruin's avatar Bolke de Bruin
Browse files

Normalize packet handling

parent 46be8de0
Branches
Tags
No related merge requests found
package protocol
import (
"bytes"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
"io"
)
type Handler struct {
Transport transport.Transport
}
func NewHandler(t transport.Transport) *Handler {
h := &Handler{
Transport: t,
}
return h
}
func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := p.Transport.ReadPacket()
if err != nil {
return 0, 0, []byte{0,0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0,0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}
......@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
......@@ -156,46 +157,21 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
}
func handleWebsocketProtocol(c *websocket.Conn) {
fragment := false
buf := make([]byte, 4096)
index := 0
var remote net.Conn
websocketConnections.Inc()
defer websocketConnections.Dec()
inout, _ := transport.NewWS(c)
handler := protocol.NewHandler(inout)
var host string
for {
_, msg, err := inout.ReadPacket()
pt, sz, pkt, err := handler.ReadMessage()
if err != nil {
log.Printf("Error read: %s", err)
break
}
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg)
if err != nil {
// fragment received
// log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
//log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
log.Printf("Cannot read message from stream %s", err)
return
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
......@@ -301,10 +277,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
defer in.Close()
if s.TransportIn == nil {
fragment := false
index := 0
buf := make([]byte, 4096)
s.TransportIn = in
c.Set(connId, s, cache.DefaultExpiration)
......@@ -315,30 +287,12 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
in.Drain()
log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String())
handler := protocol.NewHandler(in)
for {
n, msg, err := in.ReadPacket()
if err == io.EOF || n == 0 {
break
}
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg[:n])
pt, sz, pkt, err := handler.ReadMessage()
if err != nil {
// fragment received
fragment = true
index = copy(buf, msg[:n])
continue
}
index = 0
} else {
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
log.Printf("Cannot read message from stream %s", err)
return
}
switch pt {
......@@ -386,28 +340,13 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
s.TransportOut.Close()
break
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt[:n])
log.Printf("Unknown packet (size %d): %x", sz, pkt)
}
}
}
}
}
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}
// Creates a packet the is a response to a handshake request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment