diff --git a/protocol/handler.go b/protocol/handler.go index 98cb9cd1e63fcd839fa919a8fdff7647b0b9c3b0..ad5d43bd1f70ff59b6f33ba4383b92310ed79a4b 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -51,10 +51,10 @@ type HandlerConf struct { TokenAuth bool } -func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler { +func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { h := &Handler{ - TransportIn: in, - TransportOut: out, + TransportIn: s.TransportIn, + TransportOut: s.TransportOut, State: SERVER_STATE_INITIAL, RedirectFlags: makeRedirectFlags(conf.RedirectFlags), IdleTimeout: conf.IdleTimeout, diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index ee395e21b554ca00f7db264763676d4f2a4285a8..b1a22ca196bff265472479d4ef2b2c6b2bae9d21 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -51,10 +51,9 @@ type SessionInfo struct { TransportOut transport.Transport RemoteAddress string ProxyAddresses string + UserName string } -var DefaultSession SessionInfo - var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) @@ -66,9 +65,20 @@ func init() { func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { connectionCache.Set(float64(c.ItemCount())) + + var s *SessionInfo + + connId := r.Header.Get(rdgConnectionIdKey) + x, found := c.Get(connId) + if !found { + s = &SessionInfo{ConnId: connId} + } else { + s = x.(*SessionInfo) + } + if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - g.handleLegacyProtocol(w, r) + g.handleLegacyProtocol(w, r, s) return } r.Method = "GET" // force @@ -79,35 +89,27 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) } defer conn.Close() - g.handleWebsocketProtocol(conn) + g.handleWebsocketProtocol(conn, s) } else if r.Method == MethodRDGIN { - g.handleLegacyProtocol(w, r) + g.handleLegacyProtocol(w, r, s) } } -func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) { +func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) { websocketConnections.Inc() defer websocketConnections.Dec() inout, _ := transport.NewWS(c) - handler := NewHandler(inout, inout, g.HandlerConf) + s.TransportOut = inout + s.TransportIn = inout + handler := NewHandler(s, g.HandlerConf) handler.Process() } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different // to ensure the connections do not get cached or terminated by a proxy prematurely. -func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { - var s SessionInfo - - connId := r.Header.Get(rdgConnectionIdKey) - x, found := c.Get(connId) - if !found { - s = SessionInfo{ConnId: connId} - } else { - s = x.(SessionInfo) - } - +func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) { log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) if r.Method == MethodRDGOUT { @@ -121,7 +123,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { s.TransportOut = out out.SendAccept(true) - c.Set(connId, s, cache.DefaultExpiration) + c.Set(s.ConnId, s, cache.DefaultExpiration) } else if r.Method == MethodRDGIN { legacyConnections.Inc() defer legacyConnections.Dec() @@ -135,7 +137,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { if s.TransportIn == nil { s.TransportIn = in - c.Set(connId, s, cache.DefaultExpiration) + c.Set(s.ConnId, s, cache.DefaultExpiration) log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String()) in.SendAccept(false) @@ -144,7 +146,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { in.Drain() log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) - handler := NewHandler(in, s.TransportOut, g.HandlerConf) + handler := NewHandler(s, g.HandlerConf) handler.Process() } }