diff --git a/protocol/client.go b/protocol/client.go index a15d1cbf7ea04b89a017cf8a87f0b04eaae52b65..73aadf32fd9dc12b01c26d0d0c37d87e174637fa 100644 --- a/protocol/client.go +++ b/protocol/client.go @@ -3,6 +3,8 @@ package protocol import ( "bytes" "encoding/binary" + "fmt" + "io" ) const ( @@ -43,4 +45,72 @@ func (c *ClientConfig) handshakeRequest() []byte { return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes()) } -func (c *ClientConfig) readServerHandshakeResponse(data []byte) () +func (c *ClientConfig) handshakeResponse(data []byte) (caps uint16, err error) { + var errorCode int32 + var major byte + var minor byte + var version uint16 + + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &errorCode) + binary.Read(r, binary.LittleEndian, &major) + binary.Read(r, binary.LittleEndian, &minor) + binary.Read(r, binary.LittleEndian, &version) + binary.Read(r, binary.LittleEndian, &caps) + + if errorCode > 0 { + return 0, fmt.Errorf("error code: %d", errorCode) + } + + return caps, nil +} + +func (c *ClientConfig) tunnelRequest() []byte { + buf := new(bytes.Buffer) + var caps uint32 + var size uint16 + var fields uint16 + + if len(c.PAAToken) > 0 { + fields = fields | HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE + } + + caps = caps | HTTP_CAPABILITY_IDLE_TIMEOUT + + binary.Write(buf, binary.LittleEndian, caps) + binary.Write(buf, binary.LittleEndian, fields) + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + + if len(c.PAAToken) > 0 { + utf16Token := EncodeUTF16(c.PAAToken) + size = uint16(len(utf16Token)) + binary.Write(buf, binary.LittleEndian, size) + buf.Write(utf16Token) + } + + return createPacket(PKT_TYPE_TUNNEL_CREATE, buf.Bytes()) +} + +func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32, err error) { + var version uint16 + var errorCode uint32 + var fields uint16 + + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &version) + binary.Read(r, binary.LittleEndian, &errorCode) + binary.Read(r, binary.LittleEndian, &fields) + r.Seek(2, io.SeekCurrent) + if (fields & HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID) == HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID { + binary.Read(r, binary.LittleEndian, &tunnelId) + } + if (fields & HTTP_TUNNEL_RESPONSE_FIELD_CAPS) == HTTP_TUNNEL_RESPONSE_FIELD_CAPS { + binary.Read(r, binary.LittleEndian, &caps) + } + + if errorCode != 0 { + err = fmt.Errorf("tunnel error %d", errorCode) + } + + return +} \ No newline at end of file diff --git a/protocol/handler_test.go b/protocol/handler_test.go index c938cade2c56a40ae470952896e62e1fb006acc3..18308cb552dbc3970dece9542454c48b7da96dbc 100644 --- a/protocol/handler_test.go +++ b/protocol/handler_test.go @@ -1,6 +1,7 @@ package protocol import ( + "fmt" "log" "testing" ) @@ -8,26 +9,40 @@ import ( const ( HeaderLen = 8 HandshakeRequestLen = HeaderLen + 6 + HandshakeResponseLen = HeaderLen + 10 + TunnelCreateRequestLen = HeaderLen + 8 // + dynamic + TunnelCreateResponseLen = HeaderLen + 18 ) +func verifyPacketHeader(data []byte , expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { + pt, size, pkt, err := readHeader(data) + + if pt != expPt { + return 0,0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt) + } + + if size != expSize { + return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size) + } + + if err != nil { + return 0, 0, []byte{}, err + } + + return pt, size, pkt, nil +} + func TestHandshake(t *testing.T) { client := ClientConfig{ PAAToken: "abab", } data := client.handshakeRequest() - pt, size, pkt, err := readHeader(data) - if pt != PKT_TYPE_HANDSHAKE_REQUEST { - t.Fatalf("readHeader failed, expected packet type %d got %d", PKT_TYPE_HANDSHAKE_REQUEST, pt) - } - - if size != HandshakeRequestLen { - t.Fatalf("readHeader failed, expected size %d, got %d", HandshakeRequestLen, size) - } + _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen) if err != nil { - t.Fatalf("readHeader failed got error %s", err) + t.Fatalf("verifyHeader failed: %s", err) } log.Printf("pkt: %x", pkt) @@ -41,4 +56,61 @@ func TestHandshake(t *testing.T) { if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) { t.Fatalf("readHandshake failed got ext auth %d, expected %d", extAuth, extAuth | HTTP_EXTENDED_AUTH_PAA) } + + s := &SessionInfo{} + hc := &HandlerConf{ + TokenAuth: true, + } + + h := NewHandler(s, hc) + + data = h.handshakeResponse(0x0, 0x0) + _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen) + if err != nil { + t.Fatalf("verifyHeader failed: %s", err) + } + log.Printf("pkt: %x", pkt) + + caps, err := client.handshakeResponse(pkt) + if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) { + t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps | HTTP_EXTENDED_AUTH_PAA) + } } + +func TestTunnelCreation(t *testing.T) { + client := ClientConfig{ + PAAToken: "abab", + } + + data := client.tunnelRequest() + _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, + uint32(TunnelCreateRequestLen + 2 + len(client.PAAToken)*2)) + if err != nil { + t.Fatalf("verifyHeader failed: %s", err) + } + + caps, token := readCreateTunnelRequest(pkt) + if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) { + t.Fatalf("readCreateTunnelRequest failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT) + } + if token != client.PAAToken { + t.Fatalf("readCreateTunnelRequest failed got token %s, expected %s", token, client.PAAToken) + } + + data = createTunnelResponse() + _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen) + if err != nil { + t.Fatalf("verifyHeader failed: %s", err) + } + + tid, caps, err := client.tunnelResponse(pkt) + if err != nil { + t.Fatalf("Error %s", err) + } + if tid != tunnelId { + t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId) + } + if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) { + t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT) + } +} \ No newline at end of file diff --git a/protocol/rdpgw_test.go b/protocol/rdpgw_test.go deleted file mode 100644 index dc6bdf66d68bc7983ecff99fdc9c1551ee8baec5..0000000000000000000000000000000000000000 --- a/protocol/rdpgw_test.go +++ /dev/null @@ -1,2 +0,0 @@ -package protocol - diff --git a/protocol/utf16.go b/protocol/utf16.go index 963dce1746c7b2a3a334f819abf0f3526dd1a263..2c574d2edda25a75a2c31624f3e3715644999cf1 100644 --- a/protocol/utf16.go +++ b/protocol/utf16.go @@ -2,6 +2,7 @@ package protocol import ( "bytes" + "encoding/binary" "fmt" "unicode/utf16" "unicode/utf8" @@ -30,3 +31,12 @@ func DecodeUTF16(b []byte) (string, error) { } return string(bret), nil } + +func EncodeUTF16(s string) []byte { + ret := new(bytes.Buffer) + enc := utf16.Encode([]rune(s)) + for c := range enc { + binary.Write(ret, binary.LittleEndian, enc[c]) + } + return ret.Bytes() +} \ No newline at end of file