diff --git a/protocol/client.go b/protocol/client.go index f6665fbece45ff0f9303503440d1f8808b5315ed..a5fd750f1c755dcbda1264f56a16376baade0928 100644 --- a/protocol/client.go +++ b/protocol/client.go @@ -4,7 +4,9 @@ import ( "bytes" "encoding/binary" "fmt" + "github.com/bolkedebruin/rdpgw/transport" "io" + "net" ) const ( @@ -17,6 +19,8 @@ type ClientConfig struct { SmartCardAuth bool PAAToken string NTLMAuth bool + GatewayConn transport.Transport + LocalConn net.Conn } func (c *ClientConfig) handshakeRequest() []byte { @@ -147,4 +151,39 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui } return -} \ No newline at end of file +} + +func (c *ClientConfig) channelRequest(server string, port uint16) []byte { + utf16server := EncodeUTF16(server) + + buf := new(bytes.Buffer) + binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names + binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3) + binary.Write(buf, binary.LittleEndian, uint16(port)) + binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3 + + binary.Write(buf, binary.LittleEndian, uint16(len(utf16server))) + buf.Write(utf16server) + + return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes()) +} + +func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) { + var errorCode uint32 + var fields uint16 + + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &errorCode) + binary.Read(r, binary.LittleEndian, &fields) + r.Seek(2, io.SeekCurrent) + + if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID { + binary.Read(r, binary.LittleEndian, &channelId) + } + + if errorCode > 0 { + return 0, fmt.Errorf("channel response error %d", errorCode) + } + + return channelId, nil +} diff --git a/protocol/common.go b/protocol/common.go index 662a8466fb92d03c08a9fb17ede4f28b689ef18a..6644e3b581a6f681335c1764ab27f7c22cdd8100 100644 --- a/protocol/common.go +++ b/protocol/common.go @@ -4,7 +4,10 @@ import ( "bytes" "encoding/binary" "errors" + "github.com/bolkedebruin/rdpgw/transport" "io" + "log" + "net" ) func createPacket(pktType uint16, data []byte) (packet []byte) { @@ -34,4 +37,35 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err return packetType, size, data[8:], nil } +// sends data wrapped inside the rdpgw protocol +func forward(in net.Conn, out transport.Transport) { + defer in.Close() + + b1 := new(bytes.Buffer) + buf := make([]byte, 4086) + + for { + n, err := in.Read(buf) + if err != nil { + log.Printf("Error reading from local conn %s", err) + break + } + binary.Write(b1, binary.LittleEndian, uint16(n)) + b1.Write(buf[:n]) + out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + b1.Reset() + } +} + +// receive data from the wire, unwrap and forward to the client +func receive(data []byte, out net.Conn) { + buf := bytes.NewReader(data) + + var cblen uint16 + binary.Read(buf, binary.LittleEndian, &cblen) + pkt := make([]byte, cblen) + binary.Read(buf, binary.LittleEndian, &pkt) + + out.Write(pkt) +} diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index 20bdc98a50122c87b17e2aef9a0b8983de6bbe07..ed20f90c8901faa1ac4b979fceb2a3125ecc164f 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -14,6 +14,8 @@ const ( TunnelCreateResponseLen = HeaderLen + 18 TunnelAuthLen = HeaderLen + 2 // + dynamic TunnelAuthResponseLen = HeaderLen + 16 + ChannelCreateLen = HeaderLen + 8 // + dynamic + ChannelResponseLen = HeaderLen + 12 ) func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { @@ -162,3 +164,44 @@ func TestTunnelAuth(t *testing.T) { timeout, hc.IdleTimeout) } } + +func TestChannelCreation(t *testing.T) { + client := ClientConfig{} + s := &SessionInfo{} + hc := &ServerConf{ + TokenAuth: true, + IdleTimeout: 10, + RedirectFlags: RedirectFlags{ + Clipboard: true, + }, + } + h := NewServer(s, hc) + server := "test_server" + port := uint16(3389) + + data := client.channelRequest(server, port) + _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2)) + if err != nil { + t.Fatalf("verifyHeader failed: %s", err) + } + hServer, hPort := h.channelRequest(pkt) + if hServer != server { + t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server) + } + if hPort != port { + t.Fatalf("channelRequest failed got port %d, expected %d", hPort, port) + } + + data = h.channelResponse() + _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen)) + if err != nil { + t.Fatalf("verifyHeader failed: %s", err) + } + channelId, err := client.channelResponse(pkt) + if err != nil { + t.Fatalf("channelResponse failed: %s", err) + } + if channelId < 1 { + t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId) + } +} \ No newline at end of file diff --git a/protocol/server.go b/protocol/server.go index ba946b7f7245a0ef29c8323cf8b3f851111592f3..0e59535f985f72b8ac2b73f583ab20997f6f8cda 100644 --- a/protocol/server.go +++ b/protocol/server.go @@ -148,7 +148,7 @@ func (s *Server) Process(ctx context.Context) error { // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually - go s.sendDataPacket() + go forward(s.Remote, s.Session.TransportOut) s.State = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: if s.State < SERVER_STATE_CHANNEL_CREATE { @@ -156,7 +156,7 @@ func (s *Server) Process(ctx context.Context) error { return errors.New("wrong state") } s.State = SERVER_STATE_OPENED - s.forwardDataPacket(pkt) + receive(pkt, s.Remote) case PKT_TYPE_KEEPALIVE: // keepalives can be received while the channel is not open yet if s.State < SERVER_STATE_CHANNEL_CREATE { @@ -357,34 +357,6 @@ func (s *Server) channelResponse() []byte { return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) } -func (s *Server) forwardDataPacket(data []byte) { - buf := bytes.NewReader(data) - - var cblen uint16 - binary.Read(buf, binary.LittleEndian, &cblen) - pkt := make([]byte, cblen) - binary.Read(buf, binary.LittleEndian, &pkt) - - s.Remote.Write(pkt) -} - -func (s *Server) sendDataPacket() { - defer s.Remote.Close() - b1 := new(bytes.Buffer) - buf := make([]byte, 4086) - for { - n, err := s.Remote.Read(buf) - binary.Write(b1, binary.LittleEndian, uint16(n)) - if err != nil { - log.Printf("Error reading from conn %s", err) - break - } - b1.Write(buf[:n]) - s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) - b1.Reset() - } -} - func makeRedirectFlags(flags RedirectFlags) int { var redir = 0