diff --git a/api/web.go b/api/web.go
index 7dd24adf57ca07e8ca9f709ab8aecce06009faca..f3686d60c0d82bcfe1a9ce7a5c71067c5b99f3fd 100644
--- a/api/web.go
+++ b/api/web.go
@@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
 			return
 		}
 
-		next.ServeHTTP(w, r)
+		ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
+		next.ServeHTTP(w, r.WithContext(ctx))
 	})
 }
 
 func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
-	session, err := c.store.Get(r, RdpGwSession)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
+	ctx := r.Context()
+	userName, ok := ctx.Value("preferred_username").(string)
 
-	userName := session.Values["preferred_username"]
-	if userName == nil || userName.(string) == "" {
-		// This shouldnt happen if the Authenticated handler is used to wrap this func
-		log.Printf("Found expired or non existent session")
-		http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError)
+	if !ok {
+		log.Printf("preferred_username not found in context")
+		http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
 		return
 	}
 
 	// do a round robin selection for now
 	rand.Seed(time.Now().Unix())
 	host := c.Hosts[rand.Intn(len(c.Hosts))]
-	host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1)
+	host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
 
-	user := userName.(string)
+	user := userName
 	if c.UsernameTemplate != "" {
 		user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
 		if c.UsernameTemplate == user {
diff --git a/protocol/handler.go b/protocol/handler.go
index 5ee4468241e96703a891cac1eeaca589267ec76d..c622057308a4cdd298175af0e3519039157e3ad3 100644
--- a/protocol/handler.go
+++ b/protocol/handler.go
@@ -2,6 +2,7 @@ package protocol
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"errors"
 	"github.com/bolkedebruin/rdpgw/transport"
@@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
 
 const tunnelId = 10
 
-func (h *Handler) Process() error {
+func (h *Handler) Process(ctx context.Context) error {
 	for {
 		pt, sz, pkt, err := h.ReadMessage()
 		if err != nil {
diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go
index b1a22ca196bff265472479d4ef2b2c6b2bae9d21..f30855c7edca2735f4dc908b92bfee50c1023846 100644
--- a/protocol/rdpgw.go
+++ b/protocol/rdpgw.go
@@ -1,6 +1,7 @@
 package protocol
 
 import (
+	"context"
 	"github.com/bolkedebruin/rdpgw/transport"
 	"github.com/gorilla/websocket"
 	"github.com/patrickmn/go-cache"
@@ -64,6 +65,9 @@ func init() {
 }
 
 func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	connectionCache.Set(float64(c.ItemCount()))
 
 	var s *SessionInfo
@@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
 
 	if r.Method == MethodRDGOUT {
 		if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
-			g.handleLegacyProtocol(w, r, s)
+			g.handleLegacyProtocol(w, r.WithContext(ctx), s)
 			return
 		}
 		r.Method = "GET" // force
@@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
 		}
 		defer conn.Close()
 
-		g.handleWebsocketProtocol(conn, s)
+		g.handleWebsocketProtocol(ctx, conn, s)
 	} else if r.Method == MethodRDGIN {
-		g.handleLegacyProtocol(w, r, s)
+		g.handleLegacyProtocol(w, r.WithContext(ctx), s)
 	}
 }
 
-func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
+func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
 	websocketConnections.Inc()
 	defer websocketConnections.Dec()
 
@@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
 	s.TransportOut = inout
 	s.TransportIn = inout
 	handler := NewHandler(s, g.HandlerConf)
-	handler.Process()
+	handler.Process(ctx)
 }
 
 // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
 
 			log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
 			handler := NewHandler(s, g.HandlerConf)
-			handler.Process()
+			handler.Process(r.Context())
 		}
 	}
 }