diff --git a/api/web.go b/api/web.go index 7dd24adf57ca07e8ca9f709ab8aecce06009faca..f3686d60c0d82bcfe1a9ce7a5c71067c5b99f3fd 100644 --- a/api/web.go +++ b/api/web.go @@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler { return } - next.ServeHTTP(w, r) + ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) + next.ServeHTTP(w, r.WithContext(ctx)) }) } func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { - session, err := c.store.Get(r, RdpGwSession) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + ctx := r.Context() + userName, ok := ctx.Value("preferred_username").(string) - userName := session.Values["preferred_username"] - if userName == nil || userName.(string) == "" { - // This shouldnt happen if the Authenticated handler is used to wrap this func - log.Printf("Found expired or non existent session") - http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError) + if !ok { + log.Printf("preferred_username not found in context") + http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) return } // do a round robin selection for now rand.Seed(time.Now().Unix()) host := c.Hosts[rand.Intn(len(c.Hosts))] - host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1) + host = strings.Replace(host, "{{ preferred_username }}", userName, 1) - user := userName.(string) + user := userName if c.UsernameTemplate != "" { user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1) if c.UsernameTemplate == user { diff --git a/protocol/handler.go b/protocol/handler.go index 5ee4468241e96703a891cac1eeaca589267ec76d..c622057308a4cdd298175af0e3519039157e3ad3 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -2,6 +2,7 @@ package protocol import ( "bytes" + "context" "encoding/binary" "errors" "github.com/bolkedebruin/rdpgw/transport" @@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { const tunnelId = 10 -func (h *Handler) Process() error { +func (h *Handler) Process(ctx context.Context) error { for { pt, sz, pkt, err := h.ReadMessage() if err != nil { diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index b1a22ca196bff265472479d4ef2b2c6b2bae9d21..f30855c7edca2735f4dc908b92bfee50c1023846 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -1,6 +1,7 @@ package protocol import ( + "context" "github.com/bolkedebruin/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" @@ -64,6 +65,9 @@ func init() { } func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + connectionCache.Set(float64(c.ItemCount())) var s *SessionInfo @@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - g.handleLegacyProtocol(w, r, s) + g.handleLegacyProtocol(w, r.WithContext(ctx), s) return } r.Method = "GET" // force @@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) } defer conn.Close() - g.handleWebsocketProtocol(conn, s) + g.handleWebsocketProtocol(ctx, conn, s) } else if r.Method == MethodRDGIN { - g.handleLegacyProtocol(w, r, s) + g.handleLegacyProtocol(w, r.WithContext(ctx), s) } } -func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) { +func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) { websocketConnections.Inc() defer websocketConnections.Dec() @@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) { s.TransportOut = inout s.TransportIn = inout handler := NewHandler(s, g.HandlerConf) - handler.Process() + handler.Process(ctx) } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server @@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) handler := NewHandler(s, g.HandlerConf) - handler.Process() + handler.Process(r.Context()) } } }