diff --git a/client/remote.go b/common/remote.go similarity index 98% rename from client/remote.go rename to common/remote.go index 141f6c9fe7c5c9196424531437a72914e5261f47..39ad7267d8dea16f0a8e31632c4a0f2f98e9c9e3 100644 --- a/client/remote.go +++ b/common/remote.go @@ -1,4 +1,4 @@ -package client +package common import ( "context" diff --git a/main.go b/main.go index 7496a98391187c66eedc5d46bf2cd3212c15ef15..e89a29365d686844c27d3452522783c60108fe75 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,7 @@ import ( "context" "crypto/tls" "github.com/bolkedebruin/rdpgw/api" - "github.com/bolkedebruin/rdpgw/client" + "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/security" @@ -123,8 +123,8 @@ func main() { ServerConf: &handlerConfig, } - http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) - http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) + http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) + http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/callback", api.HandleCallback) diff --git a/protocol/client.go b/protocol/client.go index a5fd750f1c755dcbda1264f56a16376baade0928..a9b66354ab1e6c9e21505c0b6351769177c6deae 100644 --- a/protocol/client.go +++ b/protocol/client.go @@ -4,8 +4,8 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/bolkedebruin/rdpgw/transport" "io" + "log" "net" ) @@ -17,10 +17,67 @@ const ( type ClientConfig struct { SmartCardAuth bool - PAAToken string - NTLMAuth bool - GatewayConn transport.Transport - LocalConn net.Conn + PAAToken string + NTLMAuth bool + Session *SessionInfo + LocalConn net.Conn + Server string + Port int + Name string +} + +func (c *ClientConfig) ConnectAndForward() error { + c.Session.TransportOut.WritePacket(c.handshakeRequest()) + + for { + pt, sz, pkt, err := readMessage(c.Session.TransportIn) + if err != nil { + log.Printf("Cannot read message from stream %s", err) + return err + } + + switch pt { + case PKT_TYPE_HANDSHAKE_RESPONSE: + caps, err := c.handshakeResponse(pkt) + if err != nil { + log.Printf("Cannot connect to %s due to %s", c.Server, err) + return err + } + log.Printf("Handshake response received. Caps: %d", caps) + c.Session.TransportOut.WritePacket(c.tunnelRequest()) + case PKT_TYPE_TUNNEL_RESPONSE: + tid, caps, err := c.tunnelResponse(pkt) + if err != nil { + log.Printf("Cannot setup tunnel due to %s", err) + return err + } + log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) + c.Session.TransportOut.WritePacket(c.tunnelAuthRequest()) + case PKT_TYPE_TUNNEL_AUTH_RESPONSE: + flags, timeout, err := c.tunnelAuthResponse(pkt) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) + c.Session.TransportOut.WritePacket(c.channelRequest()) + case PKT_TYPE_CHANNEL_RESPONSE: + cid, err := c.channelResponse(pkt) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + if cid < 1 { + log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) + } + log.Printf("Channel creation succesful. Channel id: %d", cid) + go forward(c.LocalConn, c.Session.TransportOut) + case PKT_TYPE_DATA: + receive(pkt, c.LocalConn) + default: + log.Printf("Unknown packet type received: %d size %d", pt, sz) + } + } } func (c *ClientConfig) handshakeRequest() []byte { @@ -83,7 +140,7 @@ func (c *ClientConfig) tunnelRequest() []byte { binary.Write(buf, binary.LittleEndian, caps) binary.Write(buf, binary.LittleEndian, fields) - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved if len(c.PAAToken) > 0 { utf16Token := EncodeUTF16(c.PAAToken) @@ -119,8 +176,8 @@ func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32 return } -func (c *ClientConfig) tunnelAuthRequest(name string) []byte { - utf16name := EncodeUTF16(name) +func (c *ClientConfig) tunnelAuthRequest() []byte { + utf16name := EncodeUTF16(c.Name) size := uint16(len(utf16name)) buf := new(bytes.Buffer) @@ -153,14 +210,14 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui return } -func (c *ClientConfig) channelRequest(server string, port uint16) []byte { - utf16server := EncodeUTF16(server) +func (c *ClientConfig) channelRequest() []byte { + utf16server := EncodeUTF16(c.Server) buf := new(bytes.Buffer) - binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names - binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3) - binary.Write(buf, binary.LittleEndian, uint16(port)) - binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3 + binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names + binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3) + binary.Write(buf, binary.LittleEndian, uint16(c.Port)) + binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3 binary.Write(buf, binary.LittleEndian, uint16(len(utf16server))) buf.Write(utf16server) diff --git a/protocol/common.go b/protocol/common.go index 6644e3b581a6f681335c1764ab27f7c22cdd8100..37053642b442557df1477689d3beb6460b300748 100644 --- a/protocol/common.go +++ b/protocol/common.go @@ -10,6 +10,62 @@ import ( "net" ) +type RedirectFlags struct { + Clipboard bool + Port bool + Drive bool + Printer bool + Pnp bool + DisableAll bool + EnableAll bool +} + +type SessionInfo struct { + ConnId string + TransportIn transport.Transport + TransportOut transport.Transport + RemoteServer string + ClientIp string +} + +func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { + fragment := false + index := 0 + buf := make([]byte, 4096) + + for { + size, pkt, err := in.ReadPacket() + if err != nil { + return 0, 0, []byte{0, 0}, err + } + + // check for fragments + var pt uint16 + var sz uint32 + var msg []byte + + if !fragment { + pt, sz, msg, err = readHeader(pkt[:size]) + if err != nil { + fragment = true + index = copy(buf, pkt[:size]) + continue + } + index = 0 + } else { + fragment = false + pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) + // header is corrupted even after defragmenting + if err != nil { + return 0, 0, []byte{0, 0}, err + } + } + if !fragment { + return int(pt), int(sz), msg, nil + } + } +} + func createPacket(pktType uint16, data []byte) (packet []byte) { size := len(data) + 8 buf := new(bytes.Buffer) diff --git a/protocol/gateway.go b/protocol/gateway.go index 22729f61bbaa43b257d1c143e91b9b2a8176ce6c..fe4eee211998deb0404d3ef26e81bc221f6210c6 100644 --- a/protocol/gateway.go +++ b/protocol/gateway.go @@ -2,7 +2,7 @@ package protocol import ( "context" - "github.com/bolkedebruin/rdpgw/client" + "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" @@ -45,14 +45,6 @@ type Gateway struct { ServerConf *ServerConf } -type SessionInfo struct { - ConnId string - TransportIn transport.Transport - TransportOut transport.Transport - RemoteServer string - ClientIp string -} - var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) @@ -118,7 +110,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) return } - log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context())) + log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) s.TransportOut = out out.SendAccept(true) @@ -139,13 +131,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s s.TransportIn = in c.Set(s.ConnId, s, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context())) + log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context())) + log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) handler := NewServer(s, g.ServerConf) handler.Process(r.Context()) } diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go index ed20f90c8901faa1ac4b979fceb2a3125ecc164f..215ee789621e613b23fc9e8458f22d23d73c904d 100644 --- a/protocol/protocol_test.go +++ b/protocol/protocol_test.go @@ -204,4 +204,4 @@ func TestChannelCreation(t *testing.T) { if channelId < 1 { t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId) } -} \ No newline at end of file +} diff --git a/protocol/server.go b/protocol/server.go index 0e59535f985f72b8ac2b73f583ab20997f6f8cda..13288dd28c7e6e27e662fa1e7b2b78a77043098f 100644 --- a/protocol/server.go +++ b/protocol/server.go @@ -5,7 +5,7 @@ import ( "context" "encoding/binary" "errors" - "github.com/bolkedebruin/rdpgw/client" + "github.com/bolkedebruin/rdpgw/common" "io" "log" "net" @@ -17,16 +17,6 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error) type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) type VerifyServerFunc func(context.Context, string) (bool, error) -type RedirectFlags struct { - Clipboard bool - Port bool - Drive bool - Printer bool - Pnp bool - DisableAll bool - EnableAll bool -} - type Server struct { Session *SessionInfo VerifyTunnelCreate VerifyTunnelCreate @@ -70,7 +60,7 @@ const tunnelId = 10 func (s *Server) Process(ctx context.Context) error { for { - pt, sz, pkt, err := s.ReadMessage() + pt, sz, pkt, err := readMessage(s.Session.TransportIn) if err != nil { log.Printf("Cannot read message from stream %s", err) return err @@ -78,7 +68,7 @@ func (s *Server) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx)) + log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) if s.State != SERVER_STATE_INITIAL { log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL) return errors.New("wrong state") @@ -97,7 +87,7 @@ func (s *Server) Process(ctx context.Context) error { _, cookie := s.tunnelRequest(pkt) if s.VerifyTunnelCreate != nil { if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx)) + log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) return errors.New("invalid PAA cookie") } } @@ -181,44 +171,6 @@ func (s *Server) Process(ctx context.Context) error { } } -func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) { - fragment := false - index := 0 - buf := make([]byte, 4096) - - for { - size, pkt, err := s.Session.TransportIn.ReadPacket() - if err != nil { - return 0, 0, []byte{0, 0}, err - } - - // check for fragments - var pt uint16 - var sz uint32 - var msg []byte - - if !fragment { - pt, sz, msg, err = readHeader(pkt[:size]) - if err != nil { - fragment = true - index = copy(buf, pkt[:size]) - continue - } - index = 0 - } else { - fragment = false - pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) - // header is corrupted even after defragmenting - if err != nil { - return 0, 0, []byte{0, 0}, err - } - } - if !fragment { - return int(pt), int(sz), msg, nil - } - } -} - // Creates a packet the is a response to a handshakeRequest request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // but could be in Windows. However the NTLM protocol is insecure diff --git a/security/jwt.go b/security/jwt.go index ef253980fdb163997c33bbd2539fec86c6c17632..af27bb3c89d7d99c629cb6f7e732b0b0274359c7 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/client" + "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/protocol" "github.com/dgrijalva/jwt-go/v4" "log" @@ -55,9 +55,9 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) { return false, nil } - if s.ClientIp != client.GetClientIp(ctx) { + if s.ClientIp != common.GetClientIp(ctx) { log.Printf("Current client ip address %s does not match token client ip %s", - client.GetClientIp(ctx), s.ClientIp) + common.GetClientIp(ctx), s.ClientIp) return false, nil } @@ -78,7 +78,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri c := customClaims{ RemoteServer: server, - ClientIP: client.GetClientIp(ctx), + ClientIP: common.GetClientIp(ctx), StandardClaims: jwt.StandardClaims{ ExpiresAt: exp, IssuedAt: now,