diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 27ed7838f6593cd2ca4108ce1b90c9343204b06e..69eacd4addcc79684333f7a08addd3e6bcc73e50 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -177,7 +177,7 @@ func main() { } // create the gateway - gwConfig := protocol.ServerConf{ + gwConfig := protocol.ProcessorConf{ IdleTimeout: conf.Caps.IdleTimeout, TokenAuth: conf.Caps.TokenAuth, SmartCardAuth: conf.Caps.SmartCardAuth, @@ -202,6 +202,7 @@ func main() { gw := protocol.Gateway{ ServerConf: &gwConfig, } + gwserver = &gw if conf.Server.Authentication == config.AuthenticationBasic { h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} @@ -215,6 +216,7 @@ func main() { } http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/tokeninfo", web.TokenInfo) + http.HandleFunc("/list", List) if conf.Server.Tls == config.TlsDisable { err = server.ListenAndServe() @@ -225,3 +227,12 @@ func main() { log.Fatal("ListenAndServe: ", err) } } + +var gwserver *protocol.Gateway + +func List(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + for k, v := range protocol.Connections { + fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) + } +} diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 2cd96a57b66e56373eec7c8ed28e314db9707b71..3b06769215ff430c19e974bce9843c4fc4322e33 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -46,7 +46,7 @@ var ( ) type Gateway struct { - ServerConf *ServerConf + ServerConf *ProcessorConf } var upgrader = websocket.Upgrader{} @@ -164,7 +164,9 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn s.TransportOut = inout s.TransportIn = inout - handler := NewServer(s, g.ServerConf) + handler := NewProcessor(s, g.ServerConf) + RegisterConnection(s.ConnId, handler, s) + defer CloseConnection(s.ConnId) handler.Process(ctx) } @@ -208,7 +210,9 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s in.Drain() log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) - handler := NewServer(s, g.ServerConf) + handler := NewProcessor(s, g.ServerConf) + RegisterConnection(s.ConnId, handler, s) + defer CloseConnection(s.ConnId) handler.Process(r.Context()) } } diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/process.go similarity index 66% rename from cmd/rdpgw/protocol/server.go rename to cmd/rdpgw/protocol/process.go index 3f073338563fa393b53d04a56b51f496f5759804..a528e690ab0f4387693cc6f0fd45023b2b7be680 100644 --- a/cmd/rdpgw/protocol/server.go +++ b/cmd/rdpgw/protocol/process.go @@ -18,7 +18,7 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error) type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) type VerifyServerFunc func(context.Context, string) (bool, error) -type Server struct { +type Processor struct { Session *SessionInfo VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc @@ -32,7 +32,7 @@ type Server struct { State int } -type ServerConf struct { +type ProcessorConf struct { VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc @@ -44,8 +44,8 @@ type ServerConf struct { SendBuf int } -func NewServer(s *SessionInfo, conf *ServerConf) *Server { - h := &Server{ +func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor { + h := &Processor{ State: SERVER_STATE_INITIALIZED, Session: s, RedirectFlags: makeRedirectFlags(conf.RedirectFlags), @@ -61,123 +61,123 @@ func NewServer(s *SessionInfo, conf *ServerConf) *Server { const tunnelId = 10 -func (s *Server) Process(ctx context.Context) error { +func (p *Processor) Process(ctx context.Context) error { for { - pt, sz, pkt, err := readMessage(s.Session.TransportIn) + pt, sz, pkt, err := readMessage(p.Session.TransportIn) if err != nil { - log.Printf("Cannot read message from stream %s", err) + log.Printf("Cannot read message from stream %p", err) return err } switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) - if s.State != SERVER_STATE_INITIALIZED { - log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIALIZED) - msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) - s.Session.TransportOut.WritePacket(msg) + log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx)) + if p.State != SERVER_STATE_INITIALIZED { + log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) } - major, minor, _, reqAuth := s.handshakeRequest(pkt) - caps, err := s.matchAuth(reqAuth) + major, minor, _, reqAuth := p.handshakeRequest(pkt) + caps, err := p.matchAuth(reqAuth) if err != nil { log.Println(err) - msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) - s.Session.TransportOut.WritePacket(msg) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) + p.Session.TransportOut.WritePacket(msg) return err } - msg := s.handshakeResponse(major, minor, caps, ERROR_SUCCESS) - s.Session.TransportOut.WritePacket(msg) - s.State = SERVER_STATE_HANDSHAKE + msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) + p.Session.TransportOut.WritePacket(msg) + p.State = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: log.Printf("Tunnel create") - if s.State != SERVER_STATE_HANDSHAKE { + if p.State != SERVER_STATE_HANDSHAKE { log.Printf("Tunnel create attempted while in wrong state %d != %d", - s.State, SERVER_STATE_HANDSHAKE) - msg := s.tunnelResponse(E_PROXY_INTERNALERROR) - s.Session.TransportOut.WritePacket(msg) + p.State, SERVER_STATE_HANDSHAKE) + msg := p.tunnelResponse(E_PROXY_INTERNALERROR) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) } - _, cookie := s.tunnelRequest(pkt) - if s.VerifyTunnelCreate != nil { - if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) - msg := s.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) - s.Session.TransportOut.WritePacket(msg) + _, cookie := p.tunnelRequest(pkt) + if p.VerifyTunnelCreate != nil { + if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok { + log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx)) + msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) } } - msg := s.tunnelResponse(ERROR_SUCCESS) - s.Session.TransportOut.WritePacket(msg) - s.State = SERVER_STATE_TUNNEL_CREATE + msg := p.tunnelResponse(ERROR_SUCCESS) + p.Session.TransportOut.WritePacket(msg) + p.State = SERVER_STATE_TUNNEL_CREATE case PKT_TYPE_TUNNEL_AUTH: log.Printf("Tunnel auth") - if s.State != SERVER_STATE_TUNNEL_CREATE { + if p.State != SERVER_STATE_TUNNEL_CREATE { log.Printf("Tunnel auth attempted while in wrong state %d != %d", - s.State, SERVER_STATE_TUNNEL_CREATE) - msg := s.tunnelAuthResponse(E_PROXY_INTERNALERROR) - s.Session.TransportOut.WritePacket(msg) + p.State, SERVER_STATE_TUNNEL_CREATE) + msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) } - client := s.tunnelAuthRequest(pkt) - if s.VerifyTunnelAuthFunc != nil { - if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok { - log.Printf("Invalid client name: %s", client) - msg := s.tunnelAuthResponse(ERROR_ACCESS_DENIED) - s.Session.TransportOut.WritePacket(msg) + client := p.tunnelAuthRequest(pkt) + if p.VerifyTunnelAuthFunc != nil { + if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok { + log.Printf("Invalid client name: %p", client) + msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) } } - msg := s.tunnelAuthResponse(ERROR_SUCCESS) - s.Session.TransportOut.WritePacket(msg) - s.State = SERVER_STATE_TUNNEL_AUTHORIZE + msg := p.tunnelAuthResponse(ERROR_SUCCESS) + p.Session.TransportOut.WritePacket(msg) + p.State = SERVER_STATE_TUNNEL_AUTHORIZE case PKT_TYPE_CHANNEL_CREATE: log.Printf("Channel create") - if s.State != SERVER_STATE_TUNNEL_AUTHORIZE { + if p.State != SERVER_STATE_TUNNEL_AUTHORIZE { log.Printf("Channel create attempted while in wrong state %d != %d", - s.State, SERVER_STATE_TUNNEL_AUTHORIZE) - msg := s.channelResponse(E_PROXY_INTERNALERROR) - s.Session.TransportOut.WritePacket(msg) + p.State, SERVER_STATE_TUNNEL_AUTHORIZE) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) } - server, port := s.channelRequest(pkt) + server, port := p.channelRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) - if s.VerifyServerFunc != nil { - log.Printf("Verifying %s host connection", host) - if ok, _ := s.VerifyServerFunc(ctx, host); !ok { - log.Printf("Not allowed to connect to %s by policy handler", host) - msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED) - s.Session.TransportOut.WritePacket(msg) + if p.VerifyServerFunc != nil { + log.Printf("Verifying %p host connection", host) + if ok, _ := p.VerifyServerFunc(ctx, host); !ok { + log.Printf("Not allowed to connect to %p by policy handler", host) + msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) + p.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) } } - log.Printf("Establishing connection to RDP server: %s", host) - s.Remote, err = net.DialTimeout("tcp", host, time.Second*15) + log.Printf("Establishing connection to RDP server: %p", host) + p.Remote, err = net.DialTimeout("tcp", host, time.Second*15) if err != nil { - log.Printf("Error connecting to %s, %s", host, err) - msg := s.channelResponse(E_PROXY_INTERNALERROR) - s.Session.TransportOut.WritePacket(msg) + log.Printf("Error connecting to %p, %p", host, err) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.Session.TransportOut.WritePacket(msg) return err } log.Printf("Connection established") - msg := s.channelResponse(ERROR_SUCCESS) - s.Session.TransportOut.WritePacket(msg) + msg := p.channelResponse(ERROR_SUCCESS) + p.Session.TransportOut.WritePacket(msg) // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually - go forward(s.Remote, s.Session.TransportOut) - s.State = SERVER_STATE_CHANNEL_CREATE + go forward(p.Remote, p.Session.TransportOut) + p.State = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: - if s.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) + if p.State < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } - s.State = SERVER_STATE_OPENED - receive(pkt, s.Remote) + p.State = SERVER_STATE_OPENED + receive(pkt, p.Remote) case PKT_TYPE_KEEPALIVE: // keepalives can be received while the channel is not open yet - if s.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) + if p.State < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } @@ -185,15 +185,15 @@ func (s *Server) Process(ctx context.Context) error { // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: log.Printf("Close channel") - if s.State != SERVER_STATE_OPENED { - log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED) + if p.State != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED) return errors.New("wrong state") } - msg := s.channelCloseResponse(ERROR_SUCCESS) - s.Session.TransportOut.WritePacket(msg) - //s.Session.TransportIn.Close() - //s.Session.TransportOut.Close() - s.State = SERVER_STATE_CLOSED + msg := p.channelCloseResponse(ERROR_SUCCESS) + p.Session.TransportOut.WritePacket(msg) + //p.Session.TransportIn.Close() + //p.Session.TransportOut.Close() + p.State = SERVER_STATE_CLOSED return nil default: log.Printf("Unknown packet (size %d): %x", sz, pkt) @@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context) error { // Creates a packet the is a response to a handshakeRequest request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // but could be in Windows. However the NTLM protocol is insecure -func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte { +func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code buf.Write([]byte{major, minor}) @@ -214,7 +214,7 @@ func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCod return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) } -func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { +func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { r := bytes.NewReader(data) binary.Read(r, binary.LittleEndian, &major) binary.Read(r, binary.LittleEndian, &minor) @@ -225,11 +225,11 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version return } -func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { - if s.SmartCardAuth { +func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { + if p.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_SC } - if s.TokenAuth { + if p.TokenAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA } @@ -243,7 +243,7 @@ func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { return caps, nil } -func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) { +func (p *Processor) tunnelRequest(data []byte) (caps uint32, cookie string) { var fields uint16 r := bytes.NewReader(data) @@ -262,7 +262,7 @@ func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) { return } -func (s *Server) tunnelResponse(errorCode int) []byte { +func (p *Processor) tunnelResponse(errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint16(0)) // server version @@ -278,7 +278,7 @@ func (s *Server) tunnelResponse(errorCode int) []byte { return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) } -func (s *Server) tunnelAuthRequest(data []byte) string { +func (p *Processor) tunnelAuthRequest(data []byte) string { buf := bytes.NewReader(data) var size uint16 @@ -290,7 +290,7 @@ func (s *Server) tunnelAuthRequest(data []byte) string { return clientName } -func (s *Server) tunnelAuthResponse(errorCode int) []byte { +func (p *Processor) tunnelAuthResponse(errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code @@ -298,17 +298,17 @@ func (s *Server) tunnelAuthResponse(errorCode int) []byte { binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved // idle timeout - if s.IdleTimeout < 0 { - s.IdleTimeout = 0 + if p.IdleTimeout < 0 { + p.IdleTimeout = 0 } - binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags - binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes + binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) } -func (s *Server) channelRequest(data []byte) (server string, port uint16) { +func (p *Processor) channelRequest(data []byte) (server string, port uint16) { buf := bytes.NewReader(data) var resourcesSize byte @@ -330,7 +330,7 @@ func (s *Server) channelRequest(data []byte) (server string, port uint16) { return } -func (s *Server) channelResponse(errorCode int) []byte { +func (p *Processor) channelResponse(errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code @@ -349,7 +349,7 @@ func (s *Server) channelResponse(errorCode int) []byte { return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) } -func (s *Server) channelCloseResponse(errorCode int) []byte { +func (p *Processor) channelCloseResponse(errorCode int) []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code diff --git a/cmd/rdpgw/protocol/protocol_test.go b/cmd/rdpgw/protocol/protocol_test.go index a57213a597ddee7311e6dabdad5c7ae6cb9bfaf1..d682f4aba56455f1bebd17a55c43778f3439560f 100644 --- a/cmd/rdpgw/protocol/protocol_test.go +++ b/cmd/rdpgw/protocol/protocol_test.go @@ -14,8 +14,8 @@ const ( TunnelCreateResponseLen = HeaderLen + 18 TunnelAuthLen = HeaderLen + 2 // + dynamic TunnelAuthResponseLen = HeaderLen + 16 - ChannelCreateLen = HeaderLen + 8 // + dynamic - ChannelResponseLen = HeaderLen + 12 + ChannelCreateLen = HeaderLen + 8 // + dynamic + ChannelResponseLen = HeaderLen + 12 ) func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { @@ -41,10 +41,10 @@ func TestHandshake(t *testing.T) { PAAToken: "abab", } s := &SessionInfo{} - hc := &ServerConf{ + hc := &ProcessorConf{ TokenAuth: true, } - h := NewServer(s, hc) + h := NewProcessor(s, hc) data := client.handshakeRequest() @@ -79,7 +79,7 @@ func TestHandshake(t *testing.T) { } } -func capsHelper(h Server) uint16 { +func capsHelper(h Processor) uint16 { var caps uint16 if h.TokenAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA @@ -92,12 +92,12 @@ func capsHelper(h Server) uint16 { func TestMatchAuth(t *testing.T) { s := &SessionInfo{} - hc := &ServerConf{ - TokenAuth: false, + hc := &ProcessorConf{ + TokenAuth: false, SmartCardAuth: false, } - h:= NewServer(s, hc) + h := NewProcessor(s, hc) in := uint16(0) caps, err := h.matchAuth(in) @@ -136,10 +136,10 @@ func TestTunnelCreation(t *testing.T) { PAAToken: "abab", } s := &SessionInfo{} - hc := &ServerConf{ + hc := &ProcessorConf{ TokenAuth: true, } - h := NewServer(s, hc) + h := NewProcessor(s, hc) data := client.tunnelRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, @@ -180,14 +180,14 @@ func TestTunnelAuth(t *testing.T) { Name: name, } s := &SessionInfo{} - hc := &ServerConf{ + hc := &ProcessorConf{ TokenAuth: true, IdleTimeout: 10, RedirectFlags: RedirectFlags{ Clipboard: true, }, } - h := NewServer(s, hc) + h := NewProcessor(s, hc) data := client.tunnelAuthRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2)) @@ -223,17 +223,17 @@ func TestChannelCreation(t *testing.T) { server := "test_server" client := ClientConfig{ Server: server, - Port: 3389, + Port: 3389, } s := &SessionInfo{} - hc := &ServerConf{ + hc := &ProcessorConf{ TokenAuth: true, IdleTimeout: 10, RedirectFlags: RedirectFlags{ Clipboard: true, }, } - h := NewServer(s, hc) + h := NewProcessor(s, hc) data := client.channelRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2)) diff --git a/cmd/rdpgw/protocol/track.go b/cmd/rdpgw/protocol/track.go new file mode 100644 index 0000000000000000000000000000000000000000..1e6e1c585e6567611fe2fe87afc7c78193cf7f6c --- /dev/null +++ b/cmd/rdpgw/protocol/track.go @@ -0,0 +1,30 @@ +package protocol + +import ( + "time" +) + +var Connections map[string]*GatewayConnection + +type GatewayConnection struct { + PacketHandler *Processor + SessionInfo *SessionInfo + Since time.Time + IsWebsocket bool +} + +func RegisterConnection(connId string, h *Processor, s *SessionInfo) { + if Connections == nil { + Connections = make(map[string]*GatewayConnection) + } + + Connections[connId] = &GatewayConnection{ + PacketHandler: h, + SessionInfo: s, + Since: time.Now(), + } +} + +func CloseConnection(connId string) { + delete(Connections, connId) +}