diff --git a/protocol/handler.go b/protocol/handler.go
new file mode 100644
index 0000000000000000000000000000000000000000..b48ee7f9ecb374b8d4625ee21afd64516891c8bd
--- /dev/null
+++ b/protocol/handler.go
@@ -0,0 +1,73 @@
+package protocol
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"github.com/bolkedebruin/rdpgw/transport"
+	"io"
+)
+
+type Handler struct {
+	Transport transport.Transport
+}
+
+func NewHandler(t transport.Transport) *Handler {
+	h := &Handler{
+		Transport: t,
+	}
+	return h
+}
+
+func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
+	fragment := false
+	index := 0
+	buf := make([]byte, 4096)
+
+	for {
+		size, pkt, err := p.Transport.ReadPacket()
+		if err != nil {
+			return 0, 0, []byte{0,0}, err
+		}
+
+		// check for fragments
+		var pt uint16
+		var sz uint32
+		var msg []byte
+
+		if !fragment {
+			pt, sz, msg, err = readHeader(pkt[:size])
+			if err != nil {
+				fragment = true
+				index = copy(buf, pkt[:size])
+				continue
+			}
+			index = 0
+		} else {
+			fragment = false
+			pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
+			// header is corrupted even after defragmenting
+			if err != nil {
+				return 0, 0, []byte{0,0}, err
+			}
+		}
+		if !fragment {
+			return int(pt), int(sz), msg, nil
+		}
+	}
+}
+
+func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
+	// header needs to be 8 min
+	if len(data) < 8 {
+		return 0, 0, nil, errors.New("header too short, fragment likely")
+	}
+	r := bytes.NewReader(data)
+	binary.Read(r, binary.LittleEndian, &packetType)
+	r.Seek(4, io.SeekStart)
+	binary.Read(r, binary.LittleEndian, &size)
+	if len(data) < int(size) {
+		return packetType, size, data[8:], errors.New("data incomplete, fragment received")
+	}
+	return packetType, size, data[8:], nil
+}
diff --git a/rdg.go b/rdg.go
index 0cd309459cc7e90764680f90935303befd74aae8..e478329d1a1be4cd5ebaa583a8e9a26183536669 100644
--- a/rdg.go
+++ b/rdg.go
@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"github.com/bolkedebruin/rdpgw/protocol"
 	"github.com/bolkedebruin/rdpgw/transport"
 	"github.com/gorilla/websocket"
 	"github.com/patrickmn/go-cache"
@@ -156,46 +157,21 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
 }
 
 func handleWebsocketProtocol(c *websocket.Conn) {
-	fragment := false
-	buf := make([]byte, 4096)
-	index := 0
-
 	var remote net.Conn
 
 	websocketConnections.Inc()
 	defer websocketConnections.Dec()
 
 	inout, _ := transport.NewWS(c)
+	handler := protocol.NewHandler(inout)
 
 	var host string
 	for {
-		_, msg, err := inout.ReadPacket()
+		pt, sz, pkt, err := handler.ReadMessage()
 		if err != nil {
-			log.Printf("Error read: %s", err)
-			break
-		}
-
-		// check for fragments
-		var pt uint16
-		var sz uint32
-		var pkt []byte
-
-		if !fragment {
-			pt, sz, pkt, err = readHeader(msg)
-			if err != nil {
-				// fragment received
-				// log.Printf("Received non websocket fragment")
-				fragment = true
-				index = copy(buf, msg)
-				continue
-			}
-			index = 0
-		} else {
-			//log.Printf("Dealing with fragment")
-			fragment = false
-			pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
+			log.Printf("Cannot read message from stream %s", err)
+			return
 		}
-
 		switch pt {
 		case PKT_TYPE_HANDSHAKE_REQUEST:
 			major, minor, _, auth := readHandshake(pkt)
@@ -301,10 +277,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
 		defer in.Close()
 
 		if s.TransportIn == nil {
-			fragment := false
-			index := 0
-			buf := make([]byte, 4096)
-
 			s.TransportIn = in
 			c.Set(connId, s, cache.DefaultExpiration)
 
@@ -315,30 +287,12 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
 			in.Drain()
 
 			log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String())
-
+			handler := protocol.NewHandler(in)
 			for {
-				n, msg, err := in.ReadPacket()
-				if err == io.EOF || n == 0 {
-					break
-				}
-
-				// check for fragments
-				var pt uint16
-				var sz uint32
-				var pkt []byte
-
-				if !fragment {
-					pt, sz, pkt, err = readHeader(msg[:n])
-					if err != nil {
-						// fragment received
-						fragment = true
-						index = copy(buf, msg[:n])
-						continue
-					}
-					index = 0
-				} else {
-					fragment = false
-					pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
+				pt, sz, pkt, err := handler.ReadMessage()
+				if err != nil {
+					log.Printf("Cannot read message from stream %s", err)
+					return
 				}
 
 				switch pt {
@@ -386,28 +340,13 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
 					s.TransportOut.Close()
 					break
 				default:
-					log.Printf("Unknown packet (size %d): %x", sz, pkt[:n])
+					log.Printf("Unknown packet (size %d): %x", sz, pkt)
 				}
 			}
 		}
 	}
 }
 
-func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
-	// header needs to be 8 min
-	if len(data) < 8 {
-		return 0, 0, nil, errors.New("header too short, fragment likely")
-	}
-	r := bytes.NewReader(data)
-	binary.Read(r, binary.LittleEndian, &packetType)
-	r.Seek(4, io.SeekStart)
-	binary.Read(r, binary.LittleEndian, &size)
-	if len(data) < int(size) {
-		return packetType, size, data[8:], errors.New("data incomplete, fragment received")
-	}
-	return packetType, size, data[8:], nil
-}
-
 // Creates a packet the is a response to a handshake request
 // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
 // but could be in Windows. However the NTLM protocol is insecure