Skip to content
Snippets Groups Projects
Commit 4e99b4e8 authored by Bolke de Bruin's avatar Bolke de Bruin
Browse files

Refactor and add tests

parent ecfa9e6c
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,9 @@ import (
"bytes"
"encoding/binary"
"fmt"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"net"
)
const (
......@@ -17,6 +19,8 @@ type ClientConfig struct {
SmartCardAuth bool
PAAToken string
NTLMAuth bool
GatewayConn transport.Transport
LocalConn net.Conn
}
func (c *ClientConfig) handshakeRequest() []byte {
......@@ -148,3 +152,38 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
return
}
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
utf16server := EncodeUTF16(server)
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server)
return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes())
}
func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) {
var errorCode uint32
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID {
binary.Read(r, binary.LittleEndian, &channelId)
}
if errorCode > 0 {
return 0, fmt.Errorf("channel response error %d", errorCode)
}
return channelId, nil
}
......@@ -4,7 +4,10 @@ import (
"bytes"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"log"
"net"
)
func createPacket(pktType uint16, data []byte) (packet []byte) {
......@@ -34,4 +37,35 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
return packetType, size, data[8:], nil
}
// sends data wrapped inside the rdpgw protocol
func forward(in net.Conn, out transport.Transport) {
defer in.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := in.Read(buf)
if err != nil {
log.Printf("Error reading from local conn %s", err)
break
}
binary.Write(b1, binary.LittleEndian, uint16(n))
b1.Write(buf[:n])
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
// receive data from the wire, unwrap and forward to the client
func receive(data []byte, out net.Conn) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
out.Write(pkt)
}
......@@ -14,6 +14,8 @@ const (
TunnelCreateResponseLen = HeaderLen + 18
TunnelAuthLen = HeaderLen + 2 // + dynamic
TunnelAuthResponseLen = HeaderLen + 16
ChannelCreateLen = HeaderLen + 8 // + dynamic
ChannelResponseLen = HeaderLen + 12
)
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
......@@ -162,3 +164,44 @@ func TestTunnelAuth(t *testing.T) {
timeout, hc.IdleTimeout)
}
}
func TestChannelCreation(t *testing.T) {
client := ClientConfig{}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
server := "test_server"
port := uint16(3389)
data := client.channelRequest(server, port)
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
hServer, hPort := h.channelRequest(pkt)
if hServer != server {
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
}
if hPort != port {
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, port)
}
data = h.channelResponse()
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
channelId, err := client.channelResponse(pkt)
if err != nil {
t.Fatalf("channelResponse failed: %s", err)
}
if channelId < 1 {
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
}
}
\ No newline at end of file
......@@ -148,7 +148,7 @@ func (s *Server) Process(ctx context.Context) error {
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go s.sendDataPacket()
go forward(s.Remote, s.Session.TransportOut)
s.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA:
if s.State < SERVER_STATE_CHANNEL_CREATE {
......@@ -156,7 +156,7 @@ func (s *Server) Process(ctx context.Context) error {
return errors.New("wrong state")
}
s.State = SERVER_STATE_OPENED
s.forwardDataPacket(pkt)
receive(pkt, s.Remote)
case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet
if s.State < SERVER_STATE_CHANNEL_CREATE {
......@@ -357,34 +357,6 @@ func (s *Server) channelResponse() []byte {
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
}
func (s *Server) forwardDataPacket(data []byte) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
s.Remote.Write(pkt)
}
func (s *Server) sendDataPacket() {
defer s.Remote.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := s.Remote.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func makeRedirectFlags(flags RedirectFlags) int {
var redir = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment