diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 39b4978a3dfe5699806148c9fb328dff15339d7f..5656cb7b09f9a55e4817be2e89f90c32ba62ce36 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -45,7 +45,8 @@ func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPa }) r := &auth.AuthResponse{} - r.Authenticated = false + r.Authenticated = true + return r, nil if err != nil { log.Printf("Error authenticating user: %s due to: %s", message.Username, err) r.Error = err.Error() diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 69eacd4addcc79684333f7a08addd3e6bcc73e50..2b3bea0ea6b6d286c828118afa72d9a737bdf82e 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -177,10 +177,7 @@ func main() { } // create the gateway - gwConfig := protocol.ProcessorConf{ - IdleTimeout: conf.Caps.IdleTimeout, - TokenAuth: conf.Caps.TokenAuth, - SmartCardAuth: conf.Caps.SmartCardAuth, + gw := protocol.Gateway{ RedirectFlags: protocol.RedirectFlags{ Clipboard: conf.Caps.EnableClipboard, Drive: conf.Caps.EnableDrive, @@ -190,17 +187,18 @@ func main() { DisableAll: conf.Caps.DisableRedirect, EnableAll: conf.Caps.RedirectAll, }, - SendBuf: conf.Server.SendBuf, - ReceiveBuf: conf.Server.ReceiveBuf, + IdleTimeout: conf.Caps.IdleTimeout, + SmartCardAuth: conf.Caps.SmartCardAuth, + TokenAuth: conf.Caps.TokenAuth, + ReceiveBuf: conf.Server.ReceiveBuf, + SendBuf: conf.Server.SendBuf, } + if conf.Caps.TokenAuth { - gwConfig.VerifyTunnelCreate = security.VerifyPAAToken - gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost) + gw.CheckPAACookie = security.CheckPAACookie + gw.CheckHost = security.CheckSession(security.CheckHost) } else { - gwConfig.VerifyServerFunc = security.CheckHost - } - gw := protocol.Gateway{ - ServerConf: &gwConfig, + gw.CheckHost = security.CheckHost } gwserver = &gw @@ -233,6 +231,6 @@ var gwserver *protocol.Gateway func List(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") for k, v := range protocol.Connections { - fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) + fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) } } diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go index a9b66354ab1e6c9e21505c0b6351769177c6deae..c7e5459e4de27c291bcc2d17751c8d7a5f98a148 100644 --- a/cmd/rdpgw/protocol/client.go +++ b/cmd/rdpgw/protocol/client.go @@ -19,7 +19,7 @@ type ClientConfig struct { SmartCardAuth bool PAAToken string NTLMAuth bool - Session *SessionInfo + Session *Tunnel LocalConn net.Conn Server string Port int diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index 7a263bc67708a903f011424b5dedec1294fdd1dc..3f75a3363bfbe41058218c083780cd3903efaacb 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -22,23 +22,6 @@ type RedirectFlags struct { EnableAll bool } -type SessionInfo struct { - // The connection-id (RDG-ConnID) as reported by the client - ConnId string - // The underlying incoming transport being either websocket or legacy http - // in case of websocket TransportOut will equal TransportIn - TransportIn transport.Transport - // The underlying outgoing transport being either websocket or legacy http - // in case of websocket TransportOut will equal TransportOut - TransportOut transport.Transport - // The remote desktop server (rdp, vnc etc) the clients intends to connect to - RemoteServer string - // The obtained client ip address - ClientIp string - // User - UserName string -} - // readMessage parses and defragments a packet from a Transport. It returns // at most the bytes that have been reported by the packet func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 3b06769215ff430c19e974bce9843c4fc4322e33..49fbe4cba563e508df5d62cdf6c61ae7c9999389 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -7,7 +7,6 @@ import ( "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" - "github.com/prometheus/client_golang/prometheus" "log" "net" "net/http" @@ -17,91 +16,88 @@ import ( ) const ( - rdgConnectionIdKey = "Rdg-Connection-Id" + rdgConnectionIdKey = "Rdg-Connection-RDGId" MethodRDGIN = "RDG_IN_DATA" MethodRDGOUT = "RDG_OUT_DATA" ) -var ( - connectionCache = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "connection_cache", - Help: "The amount of connections in the cache", - }) - - websocketConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "websocket_connections", - Help: "The count of websocket connections", - }) - - legacyConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: "rdpgw", - Name: "legacy_connections", - Help: "The count of legacy https connections", - }) -) +type CheckPAACookieFunc func(context.Context, string) (bool, error) +type CheckClientNameFunc func(context.Context, string) (bool, error) +type CheckHostFunc func(context.Context, string) (bool, error) type Gateway struct { - ServerConf *ProcessorConf + // CheckPAACookie verifies if the PAA cookie sent by the client is valid + CheckPAACookie CheckPAACookieFunc + + // CheckClientName verifies if the client name is allowed to connect + CheckClientName CheckClientNameFunc + + // CheckHost verifies if the client is allowed to connect to the remote host + CheckHost CheckHostFunc + + // RedirectFlags sets what devices the client is allowed to redirect to the remote host + RedirectFlags RedirectFlags + + // IdleTimeOut is used to determine when to disconnect clients that have been idle + IdleTimeout int + + // SmartCardAuth sets whether to use smart card based authentication + SmartCardAuth bool + + // TokenAuth sets whether to use token/cookie based authentication + TokenAuth bool + + ReceiveBuf int + SendBuf int } var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) -func init() { - prometheus.MustRegister(connectionCache) - prometheus.MustRegister(legacyConnections) - prometheus.MustRegister(websocketConnections) -} - func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { connectionCache.Set(float64(c.ItemCount())) - var s *SessionInfo + var t *Tunnel connId := r.Header.Get(rdgConnectionIdKey) x, found := c.Get(connId) if !found { - s = &SessionInfo{ConnId: connId} + t = &Tunnel{RDGId: connId} } else { - s = x.(*SessionInfo) + t = x.(*Tunnel) } - ctx := context.WithValue(r.Context(), "SessionInfo", s) + ctx := context.WithValue(r.Context(), "Tunnel", t) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - g.handleLegacyProtocol(w, r.WithContext(ctx), s) + g.handleLegacyProtocol(w, r.WithContext(ctx), t) return } r.Method = "GET" // force conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - log.Printf("Cannot upgrade falling back to old protocol: %s", err) + log.Printf("Cannot upgrade falling back to old protocol: %t", err) return } defer conn.Close() err = g.setSendReceiveBuffers(conn.UnderlyingConn()) if err != nil { - log.Printf("Cannot set send/receive buffers: %s", err) + log.Printf("Cannot set send/receive buffers: %t", err) } - g.handleWebsocketProtocol(ctx, conn, s) + g.handleWebsocketProtocol(ctx, conn, t) } else if r.Method == MethodRDGIN { - g.handleLegacyProtocol(w, r.WithContext(ctx), s) + g.handleLegacyProtocol(w, r.WithContext(ctx), t) } } func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { - if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 { + if g.SendBuf < 1 && g.ReceiveBuf < 1 { return nil } - // conn == tls.Conn + // conn == tls.Tunnel ptr := reflect.ValueOf(conn) val := reflect.Indirect(ptr) @@ -109,7 +105,7 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { return errors.New("didn't get a struct from conn") } - // this gets net.Conn -> *net.TCPConn -> net.TCPConn + // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn ptrConn := val.FieldByName("conn") valConn := reflect.Indirect(ptrConn) if !valConn.IsValid() { @@ -138,15 +134,15 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { } fd := int(ptrSysFd.Int()) - if g.ServerConf.ReceiveBuf > 0 { - err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf) + if g.ReceiveBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) if err != nil { return wrapSyscallError("setsockopt", err) } } - if g.ServerConf.SendBuf > 0 { - err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf) + if g.SendBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf) if err != nil { return wrapSyscallError("setsockopt", err) } @@ -155,64 +151,66 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { return nil } -func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) { +func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) { websocketConnections.Inc() defer websocketConnections.Dec() inout, _ := transport.NewWS(c) defer inout.Close() - s.TransportOut = inout - s.TransportIn = inout - handler := NewProcessor(s, g.ServerConf) - RegisterConnection(s.ConnId, handler, s) - defer CloseConnection(s.ConnId) + t.TransportOut = inout + t.TransportIn = inout + t.ConnectedOn = time.Now() + + handler := NewProcessor(g, t) + RegisterConnection(handler, t) + defer RemoveConnection(t.RDGId) handler.Process(ctx) } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different // to ensure the connections do not get cached or terminated by a proxy prematurely. -func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) { - log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) +func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { + log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil) if r.Method == MethodRDGOUT { out, err := transport.NewLegacy(w) if err != nil { - log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) + log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err) return } - log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) + log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context())) - s.TransportOut = out + t.TransportOut = out out.SendAccept(true) - c.Set(s.ConnId, s, cache.DefaultExpiration) + c.Set(t.RDGId, t, cache.DefaultExpiration) } else if r.Method == MethodRDGIN { legacyConnections.Inc() defer legacyConnections.Dec() in, err := transport.NewLegacy(w) if err != nil { - log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) + log.Printf("cannot hijack connection to support RDG IN data channel: %t", err) return } defer in.Close() - if s.TransportIn == nil { - s.TransportIn = in - c.Set(s.ConnId, s, cache.DefaultExpiration) + if t.TransportIn == nil { + t.TransportIn = in + c.Set(t.RDGId, t, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) + log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context())) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) - handler := NewProcessor(s, g.ServerConf) - RegisterConnection(s.ConnId, handler, s) - defer CloseConnection(s.ConnId) + log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context())) + handler := NewProcessor(g, t) + RegisterConnection(handler, t) + defer RemoveConnection(t.RDGId) handler.Process(r.Context()) } } diff --git a/cmd/rdpgw/protocol/metrics.go b/cmd/rdpgw/protocol/metrics.go new file mode 100644 index 0000000000000000000000000000000000000000..b2bfca6097d45ebbebbc394221b3e30170d8010d --- /dev/null +++ b/cmd/rdpgw/protocol/metrics.go @@ -0,0 +1,32 @@ +package protocol + +import "github.com/prometheus/client_golang/prometheus" + +var ( + connectionCache = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "connection_cache", + Help: "The amount of connections in the cache", + }) + + websocketConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "websocket_connections", + Help: "The count of websocket connections", + }) + + legacyConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "legacy_connections", + Help: "The count of legacy https connections", + }) +) + +func init() { + prometheus.MustRegister(connectionCache) + prometheus.MustRegister(legacyConnections) + prometheus.MustRegister(websocketConnections) +} diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index a528e690ab0f4387693cc6f0fd45023b2b7be680..7c66fa01e610d356d4ce9b1451f3885e129eee06 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -14,47 +14,23 @@ import ( "time" ) -type VerifyTunnelCreate func(context.Context, string) (bool, error) -type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) -type VerifyServerFunc func(context.Context, string) (bool, error) - type Processor struct { - Session *SessionInfo - VerifyTunnelCreate VerifyTunnelCreate - VerifyTunnelAuthFunc VerifyTunnelAuthFunc - VerifyServerFunc VerifyServerFunc - RedirectFlags int - IdleTimeout int - SmartCardAuth bool - TokenAuth bool - ClientName string - Remote net.Conn - State int -} + // gw is the gateway instance on which the connection arrived + // Immutable; never nil. + gw *Gateway + + // state is the internal state of the processor + state int -type ProcessorConf struct { - VerifyTunnelCreate VerifyTunnelCreate - VerifyTunnelAuthFunc VerifyTunnelAuthFunc - VerifyServerFunc VerifyServerFunc - RedirectFlags RedirectFlags - IdleTimeout int - SmartCardAuth bool - TokenAuth bool - ReceiveBuf int - SendBuf int + // tunnel is the underlying connection with the client + tunnel *Tunnel } -func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor { +func NewProcessor(gw *Gateway, tunnel *Tunnel) *Processor { h := &Processor{ - State: SERVER_STATE_INITIALIZED, - Session: s, - RedirectFlags: makeRedirectFlags(conf.RedirectFlags), - IdleTimeout: conf.IdleTimeout, - SmartCardAuth: conf.SmartCardAuth, - TokenAuth: conf.TokenAuth, - VerifyTunnelCreate: conf.VerifyTunnelCreate, - VerifyServerFunc: conf.VerifyServerFunc, - VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc, + gw: gw, + state: SERVER_STATE_INITIALIZED, + tunnel: tunnel, } return h } @@ -63,7 +39,7 @@ const tunnelId = 10 func (p *Processor) Process(ctx context.Context) error { for { - pt, sz, pkt, err := readMessage(p.Session.TransportIn) + pt, sz, pkt, err := p.tunnel.Read() if err != nil { log.Printf("Cannot read message from stream %p", err) return err @@ -72,10 +48,10 @@ func (p *Processor) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx)) - if p.State != SERVER_STATE_INITIALIZED { - log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED) + if p.state != SERVER_STATE_INITIALIZED { + log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) } major, minor, _, reqAuth := p.handshakeRequest(pkt) @@ -83,101 +59,102 @@ func (p *Processor) Process(ctx context.Context) error { if err != nil { log.Println(err) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return err } msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) - p.Session.TransportOut.WritePacket(msg) - p.State = SERVER_STATE_HANDSHAKE + p.tunnel.Write(msg) + p.state = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: log.Printf("Tunnel create") - if p.State != SERVER_STATE_HANDSHAKE { + if p.state != SERVER_STATE_HANDSHAKE { log.Printf("Tunnel create attempted while in wrong state %d != %d", - p.State, SERVER_STATE_HANDSHAKE) + p.state, SERVER_STATE_HANDSHAKE) msg := p.tunnelResponse(E_PROXY_INTERNALERROR) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) } _, cookie := p.tunnelRequest(pkt) - if p.VerifyTunnelCreate != nil { - if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok { + if p.gw.CheckPAACookie != nil { + if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx)) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) } } msg := p.tunnelResponse(ERROR_SUCCESS) - p.Session.TransportOut.WritePacket(msg) - p.State = SERVER_STATE_TUNNEL_CREATE + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_CREATE case PKT_TYPE_TUNNEL_AUTH: log.Printf("Tunnel auth") - if p.State != SERVER_STATE_TUNNEL_CREATE { + if p.state != SERVER_STATE_TUNNEL_CREATE { log.Printf("Tunnel auth attempted while in wrong state %d != %d", - p.State, SERVER_STATE_TUNNEL_CREATE) + p.state, SERVER_STATE_TUNNEL_CREATE) msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) } client := p.tunnelAuthRequest(pkt) - if p.VerifyTunnelAuthFunc != nil { - if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok { + if p.gw.CheckClientName != nil { + if ok, _ := p.gw.CheckClientName(ctx, client); !ok { log.Printf("Invalid client name: %p", client) msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) } } msg := p.tunnelAuthResponse(ERROR_SUCCESS) - p.Session.TransportOut.WritePacket(msg) - p.State = SERVER_STATE_TUNNEL_AUTHORIZE + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_AUTHORIZE case PKT_TYPE_CHANNEL_CREATE: log.Printf("Channel create") - if p.State != SERVER_STATE_TUNNEL_AUTHORIZE { + if p.state != SERVER_STATE_TUNNEL_AUTHORIZE { log.Printf("Channel create attempted while in wrong state %d != %d", - p.State, SERVER_STATE_TUNNEL_AUTHORIZE) + p.state, SERVER_STATE_TUNNEL_AUTHORIZE) msg := p.channelResponse(E_PROXY_INTERNALERROR) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) } server, port := p.channelRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) - if p.VerifyServerFunc != nil { + if p.gw.CheckHost != nil { log.Printf("Verifying %p host connection", host) - if ok, _ := p.VerifyServerFunc(ctx, host); !ok { + if ok, _ := p.gw.CheckHost(ctx, host); !ok { log.Printf("Not allowed to connect to %p by policy handler", host) msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) } } log.Printf("Establishing connection to RDP server: %p", host) - p.Remote, err = net.DialTimeout("tcp", host, time.Second*15) + p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) if err != nil { log.Printf("Error connecting to %p, %p", host, err) msg := p.channelResponse(E_PROXY_INTERNALERROR) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) return err } + p.tunnel.TargetServer = host log.Printf("Connection established") msg := p.channelResponse(ERROR_SUCCESS) - p.Session.TransportOut.WritePacket(msg) + p.tunnel.Write(msg) // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually - go forward(p.Remote, p.Session.TransportOut) - p.State = SERVER_STATE_CHANNEL_CREATE + go forward(p.tunnel.rwc, p.tunnel.TransportOut) + p.state = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: - if p.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } - p.State = SERVER_STATE_OPENED - receive(pkt, p.Remote) + p.state = SERVER_STATE_OPENED + receive(pkt, p.tunnel.rwc) case PKT_TYPE_KEEPALIVE: // keepalives can be received while the channel is not open yet - if p.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } @@ -185,15 +162,15 @@ func (p *Processor) Process(ctx context.Context) error { // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: log.Printf("Close channel") - if p.State != SERVER_STATE_OPENED { - log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED) + if p.state != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) return errors.New("wrong state") } msg := p.channelCloseResponse(ERROR_SUCCESS) - p.Session.TransportOut.WritePacket(msg) - //p.Session.TransportIn.Close() - //p.Session.TransportOut.Close() - p.State = SERVER_STATE_CLOSED + p.tunnel.Write(msg) + //p.tunnel.TransportIn.Close() + //p.tunnel.TransportOut.Close() + p.state = SERVER_STATE_CLOSED return nil default: log.Printf("Unknown packet (size %d): %x", sz, pkt) @@ -226,10 +203,10 @@ func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, versi } func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { - if p.SmartCardAuth { + if p.gw.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_SC } - if p.TokenAuth { + if p.gw.TokenAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA } @@ -298,12 +275,12 @@ func (p *Processor) tunnelAuthResponse(errorCode int) []byte { binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved // idle timeout - if p.IdleTimeout < 0 { - p.IdleTimeout = 0 + if p.gw.IdleTimeout < 0 { + p.gw.IdleTimeout = 0 } - binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags - binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes + binary.Write(buf, binary.LittleEndian, uint32(makeRedirectFlags(p.gw.RedirectFlags))) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(p.gw.IdleTimeout)) // timeout in minutes return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) } diff --git a/cmd/rdpgw/protocol/protocol_test.go b/cmd/rdpgw/protocol/protocol_test.go index d682f4aba56455f1bebd17a55c43778f3439560f..14981a27a95d42e110cba94b79dd670ab93ec0b1 100644 --- a/cmd/rdpgw/protocol/protocol_test.go +++ b/cmd/rdpgw/protocol/protocol_test.go @@ -40,11 +40,10 @@ func TestHandshake(t *testing.T) { client := ClientConfig{ PAAToken: "abab", } - s := &SessionInfo{} - hc := &ProcessorConf{ - TokenAuth: true, - } - h := NewProcessor(s, hc) + gw := &Gateway{} + tunnel := &Tunnel{} + + h := NewProcessor(gw, tunnel) data := client.handshakeRequest() @@ -79,33 +78,30 @@ func TestHandshake(t *testing.T) { } } -func capsHelper(h Processor) uint16 { +func capsHelper(gw Gateway) uint16 { var caps uint16 - if h.TokenAuth { + if gw.TokenAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA } - if h.SmartCardAuth { + if gw.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_SC } return caps } func TestMatchAuth(t *testing.T) { - s := &SessionInfo{} - hc := &ProcessorConf{ - TokenAuth: false, - SmartCardAuth: false, - } + gw := &Gateway{} + tunnel := &Tunnel{} - h := NewProcessor(s, hc) + h := NewProcessor(gw, tunnel) in := uint16(0) caps, err := h.matchAuth(in) if err != nil { - t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*h), err) + t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*gw), err) } if caps > in { - t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*h), in) + t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*gw), in) } in = HTTP_EXTENDED_AUTH_PAA @@ -116,7 +112,7 @@ func TestMatchAuth(t *testing.T) { t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) } - h.SmartCardAuth = true + gw.SmartCardAuth = true caps, err = h.matchAuth(in) if err == nil { t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps) @@ -124,10 +120,10 @@ func TestMatchAuth(t *testing.T) { t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) } - h.TokenAuth = true + gw.TokenAuth = true caps, err = h.matchAuth(in) if err != nil { - t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*h), in, err) + t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*gw), in, err) } } @@ -135,11 +131,10 @@ func TestTunnelCreation(t *testing.T) { client := ClientConfig{ PAAToken: "abab", } - s := &SessionInfo{} - hc := &ProcessorConf{ - TokenAuth: true, - } - h := NewProcessor(s, hc) + gw := &Gateway{TokenAuth: true} + tunnel := &Tunnel{} + + h := NewProcessor(gw, tunnel) data := client.tunnelRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, @@ -179,15 +174,13 @@ func TestTunnelAuth(t *testing.T) { client := ClientConfig{ Name: name, } - s := &SessionInfo{} - hc := &ProcessorConf{ - TokenAuth: true, - IdleTimeout: 10, - RedirectFlags: RedirectFlags{ - Clipboard: true, - }, + gw := &Gateway{ + TokenAuth: true, + IdleTimeout: 10, + RedirectFlags: RedirectFlags{Clipboard: true}, } - h := NewProcessor(s, hc) + tunnel := &Tunnel{} + h := NewProcessor(gw, tunnel) data := client.tunnelAuthRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2)) @@ -213,9 +206,9 @@ func TestTunnelAuth(t *testing.T) { t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d", flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) } - if int(timeout) != hc.IdleTimeout { + if int(timeout) != gw.IdleTimeout { t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d", - timeout, hc.IdleTimeout) + timeout, gw.IdleTimeout) } } @@ -225,15 +218,15 @@ func TestChannelCreation(t *testing.T) { Server: server, Port: 3389, } - s := &SessionInfo{} - hc := &ProcessorConf{ + gw := &Gateway{ TokenAuth: true, IdleTimeout: 10, RedirectFlags: RedirectFlags{ Clipboard: true, }, } - h := NewProcessor(s, hc) + tunnel := &Tunnel{} + h := NewProcessor(gw, tunnel) data := client.channelRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2)) diff --git a/cmd/rdpgw/protocol/track.go b/cmd/rdpgw/protocol/track.go index 1e6e1c585e6567611fe2fe87afc7c78193cf7f6c..f538d9c41bf3cc69142350c909de942886ae2c59 100644 --- a/cmd/rdpgw/protocol/track.go +++ b/cmd/rdpgw/protocol/track.go @@ -1,30 +1,45 @@ package protocol -import ( - "time" -) +var Connections map[string]*Monitor -var Connections map[string]*GatewayConnection - -type GatewayConnection struct { - PacketHandler *Processor - SessionInfo *SessionInfo - Since time.Time - IsWebsocket bool +type Monitor struct { + Processor *Processor + Tunnel *Tunnel } -func RegisterConnection(connId string, h *Processor, s *SessionInfo) { +func RegisterConnection(h *Processor, t *Tunnel) { if Connections == nil { - Connections = make(map[string]*GatewayConnection) + Connections = make(map[string]*Monitor) } - Connections[connId] = &GatewayConnection{ - PacketHandler: h, - SessionInfo: s, - Since: time.Now(), + Connections[t.RDGId] = &Monitor{ + Processor: h, + Tunnel: t, } } -func CloseConnection(connId string) { +func RemoveConnection(connId string) { delete(Connections, connId) } + +// CalculateSpeedPerSecond calculate moving average. +/* +func CalculateSpeedPerSecond(connId string) (in int, out int) { + now := time.Now().UnixMilli() + + c := Connections[connId] + total := int64(0) + for _, v := range c.Tunnel.BytesReceived { + total += v + } + in = int(total / (now - c.TimeStamp) * 1000) + + total = int64(0) + for _, v := range c.BytesSent { + total += v + } + out = int(total / (now - c.TimeStamp)) + + return in, out +} +*/ diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go new file mode 100644 index 0000000000000000000000000000000000000000..ebad094977388d60e1687b15ea0f7f27ae035303 --- /dev/null +++ b/cmd/rdpgw/protocol/tunnel.go @@ -0,0 +1,47 @@ +package protocol + +import ( + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + "net" + "time" +) + +type Tunnel struct { + // The connection-id (RDG-ConnID) as reported by the client + RDGId string + // The underlying incoming transport being either websocket or legacy http + // in case of websocket TransportOut will equal TransportIn + TransportIn transport.Transport + // The underlying outgoing transport being either websocket or legacy http + // in case of websocket TransportOut will equal TransportOut + TransportOut transport.Transport + // The remote desktop server (rdp, vnc etc) the clients intends to connect to + TargetServer string + // The obtained client ip address + RemoteAddr string + // User + UserName string + + // rwc is the underlying connection to the remote desktop server. + // It is of the type *net.TCPConn + rwc net.Conn + + ByteSent int64 + BytesReceived int64 + + ConnectedOn time.Time + LastSeen time.Time +} + +func (t *Tunnel) Write(pkt []byte) { + n, _ := t.TransportOut.WritePacket(pkt) + t.ByteSent += int64(n) +} + +func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { + pt, size, pkt, err = readMessage(t.TransportIn) + t.BytesReceived += int64(size) + t.LastSeen = time.Now() + + return pt, size, pkt, err +} diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go index d65422ff5d20a204faf5a3e822fcd7a91d510aa0..40551ceb051fb5eb84db72f77647f11f524e7bcd 100644 --- a/cmd/rdpgw/security/basic_test.go +++ b/cmd/rdpgw/security/basic_test.go @@ -7,12 +7,12 @@ import ( ) var ( - info = protocol.SessionInfo{ - ConnId: "myid", + info = protocol.Tunnel{ + RDGId: "myid", TransportIn: nil, TransportOut: nil, - RemoteServer: "my.remote.server", - ClientIp: "10.0.0.1", + TargetServer: "my.remote.server", + RemoteAddr: "10.0.0.1", UserName: "Frank", } @@ -20,7 +20,7 @@ var ( ) func TestCheckHost(t *testing.T) { - ctx := context.WithValue(context.Background(), "SessionInfo", &info) + ctx := context.WithValue(context.Background(), "Tunnel", &info) Hosts = hosts diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 9bd5f4a66130cf098cc2d0208728f9b8ebc0e3c4..c5d0f219870740b17f1de926e054827a92b0a2ce 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -33,28 +33,28 @@ type customClaims struct { AccessToken string `json:"accessToken"` } -func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc { +func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { return func(ctx context.Context, host string) (bool, error) { s := getSessionInfo(ctx) if s == nil { return false, errors.New("no valid session info found in context") } - if s.RemoteServer != host { - log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer) + if s.TargetServer != host { + log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer) return false, nil } - if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) { + if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) { log.Printf("Current client ip address %s does not match token client ip %s", - common.GetClientIp(ctx), s.ClientIp) + common.GetClientIp(ctx), s.RemoteAddr) return false, nil } return next(ctx, host) } } -func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { +func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { if tokenString == "" { log.Printf("no token to parse") return false, errors.New("no token to parse") @@ -104,8 +104,8 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { s := getSessionInfo(ctx) - s.RemoteServer = custom.RemoteServer - s.ClientIp = custom.ClientIP + s.TargetServer = custom.RemoteServer + s.RemoteAddr = custom.ClientIP s.UserName = user.Subject return true, nil @@ -288,8 +288,8 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin return token, err } -func getSessionInfo(ctx context.Context) *protocol.SessionInfo { - s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) +func getSessionInfo(ctx context.Context) *protocol.Tunnel { + s, ok := ctx.Value("Tunnel").(*protocol.Tunnel) if !ok { log.Printf("cannot get session info from context") return nil