diff --git a/protocol/handler.go b/protocol/handler.go
index b48ee7f9ecb374b8d4625ee21afd64516891c8bd..07be70d999359becb26579c52a94353208b30c0a 100644
--- a/protocol/handler.go
+++ b/protocol/handler.go
@@ -6,10 +6,28 @@ import (
 	"errors"
 	"github.com/bolkedebruin/rdpgw/transport"
 	"io"
+	"log"
+	"net"
+	"strconv"
+	"time"
 )
 
+// When should the client disconnect when idle in minutes
+var IdleTimeout = 0
+
+type VerifyPAACookieFunc func(string) (bool, error)
+type VerifyTunnelAuthFunc func(string) (bool, error)
+type VerifyServerFunc func(string) (bool, error)
+
 type Handler struct {
-	Transport transport.Transport
+	Transport            transport.Transport
+	VerifyPAACookieFunc  VerifyPAACookieFunc
+	VerifyTunnelAuthFunc VerifyTunnelAuthFunc
+	VerifyServerFunc     VerifyServerFunc
+	SmartCardAuth        bool
+	TokenAuth            bool
+	ClientName           string
+	Remote				 net.Conn
 }
 
 func NewHandler(t transport.Transport) *Handler {
@@ -19,15 +37,73 @@ func NewHandler(t transport.Transport) *Handler {
 	return h
 }
 
-func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
+func (h *Handler) Process() error {
+	for {
+		pt, sz, pkt, err := h.ReadMessage()
+		if err != nil {
+			log.Printf("Cannot read message from stream %s", err)
+			return err
+		}
+
+		switch pt {
+		case PKT_TYPE_HANDSHAKE_REQUEST:
+			major, minor, _, auth := readHandshake(pkt)
+			msg := h.handshakeResponse(major, minor, auth)
+			h.Transport.WritePacket(msg)
+		case PKT_TYPE_TUNNEL_CREATE:
+			_, cookie := readCreateTunnelRequest(pkt)
+			if h.VerifyPAACookieFunc != nil {
+				if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
+					log.Printf("Invalid PAA cookie: %s", cookie)
+					return errors.New("invalid PAA cookie")
+				}
+			}
+			msg := createTunnelResponse()
+			h.Transport.WritePacket(msg)
+		case PKT_TYPE_TUNNEL_AUTH:
+			h.readTunnelAuthRequest(pkt)
+			msg := h.createTunnelAuthResponse()
+			h.Transport.WritePacket(msg)
+		case PKT_TYPE_CHANNEL_CREATE:
+			server, port := readChannelCreateRequest(pkt)
+			log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
+			h.Remote, err = net.DialTimeout(
+				"tcp",
+				net.JoinHostPort(server, strconv.Itoa(int(port))),
+				time.Second*15)
+			if err != nil {
+				log.Printf("Error connecting to %s, %d, %s", server, port, err)
+				return err
+			}
+			log.Printf("Connection established")
+			msg := createChannelCreateResponse()
+			h.Transport.WritePacket(msg)
+
+			// Make sure to start the flow from the RDP server first otherwise connections
+			// might hang eventually
+			go h.sendDataPacket()
+		case PKT_TYPE_DATA:
+			h.forwardDataPacket(pkt)
+		case PKT_TYPE_KEEPALIVE:
+			// avoid concurrency issues
+			// p.Transport.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
+		case PKT_TYPE_CLOSE_CHANNEL:
+			h.Transport.Close()
+		default:
+			log.Printf("Unknown packet (size %d): %x", sz, pkt)
+		}
+	}
+}
+
+func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
 	fragment := false
 	index := 0
 	buf := make([]byte, 4096)
 
 	for {
-		size, pkt, err := p.Transport.ReadPacket()
+		size, pkt, err := h.Transport.ReadPacket()
 		if err != nil {
-			return 0, 0, []byte{0,0}, err
+			return 0, 0, []byte{0, 0}, err
 		}
 
 		// check for fragments
@@ -48,7 +124,7 @@ func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
 			pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
 			// header is corrupted even after defragmenting
 			if err != nil {
-				return 0, 0, []byte{0,0}, err
+				return 0, 0, []byte{0, 0}, err
 			}
 		}
 		if !fragment {
@@ -57,6 +133,27 @@ func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
 	}
 }
 
+// Creates a packet the is a response to a handshake request
+// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
+// but could be in Windows. However the NTLM protocol is insecure
+func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte {
+	var caps uint16
+	if h.SmartCardAuth {
+		caps = caps | HTTP_EXTENDED_AUTH_PAA
+	}
+	if h.TokenAuth {
+		caps = caps | HTTP_EXTENDED_AUTH_PAA
+	}
+
+	buf := new(bytes.Buffer)
+	binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
+	buf.Write([]byte{major, minor})
+	binary.Write(buf, binary.LittleEndian, uint16(0))    // server version
+	binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
+
+	return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
+}
+
 func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
 	// header needs to be 8 min
 	if len(data) < 8 {
@@ -71,3 +168,188 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
 	}
 	return packetType, size, data[8:], nil
 }
+
+func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
+	r := bytes.NewReader(data)
+	binary.Read(r, binary.LittleEndian, &major)
+	binary.Read(r, binary.LittleEndian, &minor)
+	binary.Read(r, binary.LittleEndian, &version)
+	binary.Read(r, binary.LittleEndian, &extAuth)
+
+	log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
+	return
+}
+
+func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
+	var fields uint16
+
+	r := bytes.NewReader(data)
+
+	binary.Read(r, binary.LittleEndian, &caps)
+	binary.Read(r, binary.LittleEndian, &fields)
+	r.Seek(2, io.SeekCurrent)
+
+	if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
+		var size uint16
+		binary.Read(r, binary.LittleEndian, &size)
+		cookieB := make([]byte, size)
+		r.Read(cookieB)
+		cookie, _ = DecodeUTF16(cookieB)
+	}
+	log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie)
+	return
+}
+
+func createTunnelResponse() []byte {
+	buf := new(bytes.Buffer)
+
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // server version
+	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                    // error code
+	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                    // reserved
+
+	// tunnel id ?
+	binary.Write(buf, binary.LittleEndian, uint32(15))
+	// caps ?
+	binary.Write(buf, binary.LittleEndian, uint32(2))
+
+	return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
+}
+
+func (h *Handler) readTunnelAuthRequest(data []byte) {
+	buf := bytes.NewReader(data)
+
+	var size uint16
+	binary.Read(buf, binary.LittleEndian, &size)
+	clData := make([]byte, size)
+	binary.Read(buf, binary.LittleEndian, &clData)
+	clientName, _ := DecodeUTF16(clData)
+	log.Printf("Client: %s", clientName)
+}
+
+func (h *Handler) createTunnelAuthResponse() []byte {
+	buf := new(bytes.Buffer)
+
+	binary.Write(buf, binary.LittleEndian, uint32(0))                                                                                        // error code
+	binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0))                                                                                        // reserved
+
+	// flags
+	var redir uint32
+	/*
+		if conf.Caps.RedirectAll {
+			redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
+		} else if conf.Caps.DisableRedirect {
+			redir = HTTP_TUNNEL_REDIR_DISABLE_ALL
+		} else {
+			if conf.Caps.DisableClipboard {
+				redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
+			}
+			if conf.Caps.DisableDrive {
+				redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
+			}
+			if conf.Caps.DisablePnp {
+				redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
+			}
+			if conf.Caps.DisablePrinter {
+				redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
+			}
+			if conf.Caps.DisablePort {
+				redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
+			}
+		}
+	*/
+	redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
+
+	// idle timeout
+	if IdleTimeout < 0 {
+		IdleTimeout = 0
+	}
+
+	binary.Write(buf, binary.LittleEndian, uint32(redir))       // redir flags
+	binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes
+
+	return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
+}
+
+func readChannelCreateRequest(data []byte) (server string, port uint16) {
+	buf := bytes.NewReader(data)
+
+	var resourcesSize byte
+	var alternative byte
+	var protocol uint16
+	var nameSize uint16
+
+	binary.Read(buf, binary.LittleEndian, &resourcesSize)
+	binary.Read(buf, binary.LittleEndian, &alternative)
+	binary.Read(buf, binary.LittleEndian, &port)
+	binary.Read(buf, binary.LittleEndian, &protocol)
+	binary.Read(buf, binary.LittleEndian, &nameSize)
+
+	nameData := make([]byte, nameSize)
+	binary.Read(buf, binary.LittleEndian, &nameData)
+
+	log.Printf("Name data %q", nameData)
+	server, _ = DecodeUTF16(nameData)
+
+	log.Printf("Should connect to %s on port %d", server, port)
+	return
+}
+
+func createChannelCreateResponse() []byte {
+	buf := new(bytes.Buffer)
+
+	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
+	//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
+	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
+	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
+
+	// optional fields
+	// channel id uint32 (4)
+	// udp port uint16 (2)
+	// udp auth cookie 1 byte for side channel
+	// length uint16
+
+	return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
+}
+
+func (h *Handler) forwardDataPacket(data []byte) {
+	buf := bytes.NewReader(data)
+
+	var cblen uint16
+	binary.Read(buf, binary.LittleEndian, &cblen)
+	pkt := make([]byte, cblen)
+	binary.Read(buf, binary.LittleEndian, &pkt)
+
+	h.Remote.Write(pkt)
+}
+
+func (h *Handler) sendDataPacket() {
+	defer h.Remote.Close()
+	b1 := new(bytes.Buffer)
+	buf := make([]byte, 4086)
+	for {
+		n, err := h.Remote.Read(buf)
+		binary.Write(b1, binary.LittleEndian, uint16(n))
+		if err != nil {
+			log.Printf("Error reading from conn %s", err)
+			break
+		}
+		b1.Write(buf[:n])
+		h.Transport.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
+		b1.Reset()
+	}
+}
+
+func createPacket(pktType uint16, data []byte) (packet []byte) {
+	size := len(data) + 8
+	buf := new(bytes.Buffer)
+
+	binary.Write(buf, binary.LittleEndian, uint16(pktType))
+	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
+	binary.Write(buf, binary.LittleEndian, uint32(size))
+	buf.Write(data)
+
+	return buf.Bytes()
+}
diff --git a/protocol/types.go b/protocol/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..c12487d6086da5c90ee8c3fe25a456fe17401ecb
--- /dev/null
+++ b/protocol/types.go
@@ -0,0 +1,60 @@
+package protocol
+
+const (
+	PKT_TYPE_HANDSHAKE_REQUEST      = 0x1
+	PKT_TYPE_HANDSHAKE_RESPONSE     = 0x2
+	PKT_TYPE_EXTENDED_AUTH_MSG      = 0x3
+	PKT_TYPE_TUNNEL_CREATE          = 0x4
+	PKT_TYPE_TUNNEL_RESPONSE        = 0x5
+	PKT_TYPE_TUNNEL_AUTH            = 0x6
+	PKT_TYPE_TUNNEL_AUTH_RESPONSE   = 0x7
+	PKT_TYPE_CHANNEL_CREATE         = 0x8
+	PKT_TYPE_CHANNEL_RESPONSE       = 0x9
+	PKT_TYPE_DATA                   = 0xA
+	PKT_TYPE_SERVICE_MESSAGE        = 0xB
+	PKT_TYPE_REAUTH_MESSAGE         = 0xC
+	PKT_TYPE_KEEPALIVE              = 0xD
+	PKT_TYPE_CLOSE_CHANNEL          = 0x10
+	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
+)
+
+const (
+	HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID   = 0x01
+	HTTP_TUNNEL_RESPONSE_FIELD_CAPS        = 0x02
+	HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ     = 0x04
+	HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
+)
+
+const (
+	HTTP_EXTENDED_AUTH_NONE      = 0x0
+	HTTP_EXTENDED_AUTH_SC        = 0x1  /* Smart card authentication. */
+	HTTP_EXTENDED_AUTH_PAA       = 0x02 /* Pluggable authentication. */
+	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
+)
+
+const (
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS  = 0x01
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
+	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
+)
+
+const (
+	HTTP_TUNNEL_REDIR_ENABLE_ALL        = 0x80000000
+	HTTP_TUNNEL_REDIR_DISABLE_ALL       = 0x40000000
+	HTTP_TUNNEL_REDIR_DISABLE_DRIVE     = 0x01
+	HTTP_TUNNEL_REDIR_DISABLE_PRINTER   = 0x02
+	HTTP_TUNNEL_REDIR_DISABLE_PORT      = 0x03
+	HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
+	HTTP_TUNNEL_REDIR_DISABLE_PNP       = 0x10
+)
+
+const (
+	HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID   = 0x01
+	HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
+	HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT     = 0x04
+)
+
+const (
+	HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
+)
+
diff --git a/protocol/utf16.go b/protocol/utf16.go
new file mode 100644
index 0000000000000000000000000000000000000000..963dce1746c7b2a3a334f819abf0f3526dd1a263
--- /dev/null
+++ b/protocol/utf16.go
@@ -0,0 +1,32 @@
+package protocol
+
+import (
+	"bytes"
+	"fmt"
+	"unicode/utf16"
+	"unicode/utf8"
+)
+
+func DecodeUTF16(b []byte) (string, error) {
+	if len(b)%2 != 0 {
+		return "", fmt.Errorf("must have even length byte slice")
+	}
+
+	u16s := make([]uint16, 1)
+	ret := &bytes.Buffer{}
+	b8buf := make([]byte, 4)
+
+	lb := len(b)
+	for i := 0; i < lb; i += 2 {
+		u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
+		r := utf16.Decode(u16s)
+		n := utf8.EncodeRune(b8buf, r[0])
+		ret.Write(b8buf[:n])
+	}
+
+	bret := ret.Bytes()
+	if len(bret) > 0 && bret[len(bret)-1] == '\x00' {
+		bret = bret[:len(bret)-1]
+	}
+	return string(bret), nil
+}
diff --git a/rdg.go b/rdg.go
index e478329d1a1be4cd5ebaa583a8e9a26183536669..1417ed8b584e30631e6310e8735d9e0b22368e97 100644
--- a/rdg.go
+++ b/rdg.go
@@ -3,7 +3,6 @@ package main
 import (
 	"bytes"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"github.com/bolkedebruin/rdpgw/protocol"
 	"github.com/bolkedebruin/rdpgw/transport"
@@ -157,80 +156,12 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
 }
 
 func handleWebsocketProtocol(c *websocket.Conn) {
-	var remote net.Conn
-
 	websocketConnections.Inc()
 	defer websocketConnections.Dec()
 
 	inout, _ := transport.NewWS(c)
 	handler := protocol.NewHandler(inout)
-
-	var host string
-	for {
-		pt, sz, pkt, err := handler.ReadMessage()
-		if err != nil {
-			log.Printf("Cannot read message from stream %s", err)
-			return
-		}
-		switch pt {
-		case PKT_TYPE_HANDSHAKE_REQUEST:
-			major, minor, _, auth := readHandshake(pkt)
-			msg := handshakeResponse(major, minor, auth)
-			log.Printf("Handshake response: %x", msg)
-			inout.WritePacket(msg)
-		case PKT_TYPE_TUNNEL_CREATE:
-			readCreateTunnelRequest(pkt)
-			/*data, found := tokens.Get(cookie)
-			if found == false {
-				log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
-				return
-			}*/
-			host = conf.Server.HostTemplate
-			/*
-			for k, v := range data.(map[string]interface{}) {
-				if val, ok := v.(string); ok == true {
-					host = strings.Replace(host, "{{ " + k + " }}", val, 1)
-				}
-			}*/
-			msg := createTunnelResponse()
-			log.Printf("Create tunnel response: %x", msg)
-			inout.WritePacket(msg)
-		case PKT_TYPE_TUNNEL_AUTH:
-			readTunnelAuthRequest(pkt)
-			msg := createTunnelAuthResponse()
-			log.Printf("Create tunnel auth response: %x", msg)
-			inout.WritePacket(msg)
-		case PKT_TYPE_CHANNEL_CREATE:
-			server, port := readChannelCreateRequest(pkt)
-			if conf.Server.EnableOverride == true {
-				log.Printf("Override allowed")
-				host = net.JoinHostPort(server, strconv.Itoa(int(port)))
-			}
-			log.Printf("Establishing connection to RDP server: %s", host)
-			remote, err = net.DialTimeout(
-				"tcp",
-				host,
-				time.Second * 30)
-			if err != nil {
-				log.Printf("Error connecting to %s", host)
-				return
-			}
-			log.Printf("Connection established")
-			msg := createChannelCreateResponse()
-			log.Printf("Create channel create response: %x", msg)
-			inout.WritePacket(msg)
-			go sendDataPacket(remote, inout)
-		case PKT_TYPE_DATA:
-			forwardDataPacket(remote, pkt)
-		case PKT_TYPE_KEEPALIVE:
-			// do not write to make sure we do not create concurrency issues
-			// inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
-		case PKT_TYPE_CLOSE_CHANNEL:
-			break
-		default:
-			log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
-		}
-	}
+	handler.Process()
 }
 
 // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@@ -537,24 +468,6 @@ func forwardDataPacket(conn net.Conn, data []byte) {
 	conn.Write(pkt)
 }
 
-func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
-	defer rdp.Close()
-	b1 := new(bytes.Buffer)
-	buf := make([]byte, 4086)
-
-	for {
-		n, err := rdp.Read(buf)
-		binary.Write(b1, binary.LittleEndian, uint16(n))
-		if err != nil {
-			log.Printf("Error reading from conn %s", err)
-			break
-		}
-		b1.Write(buf[:n])
-		conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
-		b1.Reset()
-	}
-}
-
 func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
 	defer connIn.Close()
 	b1 := new(bytes.Buffer)