diff --git a/api/web.go b/api/web.go
index f3686d60c0d82bcfe1a9ce7a5c71067c5b99f3fd..c262203b86cdb15ddcfe7e4c866edab9f39fa800 100644
--- a/api/web.go
+++ b/api/web.go
@@ -19,9 +19,10 @@ import (
 
 const (
 	RdpGwSession = "RDPGWSESSION"
+	MaxAge 		 = 120
 )
 
-type TokenGeneratorFunc func(string, string) (string, error)
+type TokenGeneratorFunc func(context.Context, string, string) (string, error)
 
 type Config struct {
 	SessionKey           []byte
@@ -99,6 +100,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	session.Options.MaxAge = MaxAge
 	session.Values["preferred_username"] = data["preferred_username"]
 	session.Values["authenticated"] = true
 
@@ -157,7 +159,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	token, err := c.TokenGenerator(user, host)
+	token, err := c.TokenGenerator(ctx, user, host)
 	if err != nil {
 		log.Printf("Cannot generate token for user %s due to %s", user, err)
 		http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
diff --git a/main.go b/main.go
index f59dd120400c1d10faebe895d0ed818ced45f9a2..98fed5c963a500ae073faee331e80c81c15675ca 100644
--- a/main.go
+++ b/main.go
@@ -116,6 +116,7 @@ func main() {
 			EnableAll: conf.Caps.RedirectAll,
 		},
 		VerifyTunnelCreate: security.VerifyPAAToken,
+		VerifyServerFunc: security.VerifyServerFunc,
 	}
 	gw := protocol.Gateway{
 		HandlerConf: &handlerConfig,
diff --git a/protocol/handler.go b/protocol/handler.go
index c622057308a4cdd298175af0e3519039157e3ad3..138b6057b165bfe9990fdd77e554743baf3db73b 100644
--- a/protocol/handler.go
+++ b/protocol/handler.go
@@ -5,7 +5,6 @@ import (
 	"context"
 	"encoding/binary"
 	"errors"
-	"github.com/bolkedebruin/rdpgw/transport"
 	"io"
 	"log"
 	"net"
@@ -13,9 +12,9 @@ import (
 	"time"
 )
 
-type VerifyTunnelCreate func(*SessionInfo, string) (bool, error)
-type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error)
-type VerifyServerFunc func(*SessionInfo, string) (bool, error)
+type VerifyTunnelCreate func(context.Context, string) (bool, error)
+type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
+type VerifyServerFunc func(context.Context, string) (bool, error)
 
 type RedirectFlags struct {
 	Clipboard  bool
@@ -29,8 +28,6 @@ type RedirectFlags struct {
 
 type Handler struct {
 	Session              *SessionInfo
-	TransportIn          transport.Transport
-	TransportOut         transport.Transport
 	VerifyTunnelCreate   VerifyTunnelCreate
 	VerifyTunnelAuthFunc VerifyTunnelAuthFunc
 	VerifyServerFunc     VerifyServerFunc
@@ -55,10 +52,8 @@ type HandlerConf struct {
 
 func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
 	h := &Handler{
+		State:				  SERVER_STATE_INITIAL,
 		Session:              s,
-		TransportIn:          s.TransportIn,
-		TransportOut:         s.TransportOut,
-		State:                SERVER_STATE_INITIAL,
 		RedirectFlags:        makeRedirectFlags(conf.RedirectFlags),
 		IdleTimeout:          conf.IdleTimeout,
 		SmartCardAuth:        conf.SmartCardAuth,
@@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error {
 			}
 			major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
 			msg := h.handshakeResponse(major, minor)
-			h.TransportOut.WritePacket(msg)
+			h.Session.TransportOut.WritePacket(msg)
 			h.State = SERVER_STATE_HANDSHAKE
 		case PKT_TYPE_TUNNEL_CREATE:
 			log.Printf("Tunnel create")
@@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error {
 			}
 			_, cookie := readCreateTunnelRequest(pkt)
 			if h.VerifyTunnelCreate != nil {
-				if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok {
+				if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
 					log.Printf("Invalid PAA cookie received")
 					return errors.New("invalid PAA cookie")
 				}
 			}
 			msg := createTunnelResponse()
-			h.TransportOut.WritePacket(msg)
+			h.Session.TransportOut.WritePacket(msg)
 			h.State = SERVER_STATE_TUNNEL_CREATE
 		case PKT_TYPE_TUNNEL_AUTH:
 			log.Printf("Tunnel auth")
@@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error {
 			}
 			client := h.readTunnelAuthRequest(pkt)
 			if h.VerifyTunnelAuthFunc != nil {
-				if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok {
+				if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
 					log.Printf("Invalid client name: %s", client)
 					return errors.New("invalid client name")
 				}
 			}
 			msg := h.createTunnelAuthResponse()
-			h.TransportOut.WritePacket(msg)
+			h.Session.TransportOut.WritePacket(msg)
 			h.State = SERVER_STATE_TUNNEL_AUTHORIZE
 		case PKT_TYPE_CHANNEL_CREATE:
 			log.Printf("Channel create")
@@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error {
 			server, port := readChannelCreateRequest(pkt)
 			host := net.JoinHostPort(server, strconv.Itoa(int(port)))
 			if h.VerifyServerFunc != nil {
-				if ok, _ := h.VerifyServerFunc(h.Session, host); !ok {
+				if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
 					log.Printf("Not allowed to connect to %s by policy handler", host)
+					return errors.New("denied by security policy")
 				}
 			}
 			log.Printf("Establishing connection to RDP server: %s", host)
@@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error {
 			}
 			log.Printf("Connection established")
 			msg := createChannelCreateResponse()
-			h.TransportOut.WritePacket(msg)
+			h.Session.TransportOut.WritePacket(msg)
 
 			// Make sure to start the flow from the RDP server first otherwise connections
 			// might hang eventually
@@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error {
 				log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
 				return errors.New("wrong state")
 			}
-			h.TransportIn.Close()
-			h.TransportOut.Close()
+			h.Session.TransportIn.Close()
+			h.Session.TransportOut.Close()
 			h.State = SERVER_STATE_CLOSED
 		default:
 			log.Printf("Unknown packet (size %d): %x", sz, pkt)
@@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
 	buf := make([]byte, 4096)
 
 	for {
-		size, pkt, err := h.TransportIn.ReadPacket()
+		size, pkt, err := h.Session.TransportIn.ReadPacket()
 		if err != nil {
 			return 0, 0, []byte{0, 0}, err
 		}
@@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() {
 			break
 		}
 		b1.Write(buf[:n])
-		h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
+		h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
 		b1.Reset()
 	}
 }
diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go
index f30855c7edca2735f4dc908b92bfee50c1023846..12bff90ca536db0aaaba59894a06bf74583c3fd0 100644
--- a/protocol/rdpgw.go
+++ b/protocol/rdpgw.go
@@ -46,13 +46,11 @@ type Gateway struct {
 
 type SessionInfo struct {
 	ConnId           string
-	CorrelationId    string
-	ClientGeneration string
 	TransportIn      transport.Transport
 	TransportOut     transport.Transport
 	RemoteAddress	 string
-	ProxyAddresses	 string
-	UserName		 string
+	ProxyAddress	 string
+	RemoteServer	 string
 }
 
 var upgrader = websocket.Upgrader{}
@@ -65,9 +63,6 @@ func init() {
 }
 
 func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
 	connectionCache.Set(float64(c.ItemCount()))
 
 	var s *SessionInfo
@@ -79,6 +74,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
 	} else {
 		s = x.(*SessionInfo)
 	}
+	ctx := context.WithValue(r.Context(), "SessionInfo", s)
 
 	if r.Method == MethodRDGOUT {
 		if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
diff --git a/security/jwt.go b/security/jwt.go
index e083009cfb4009827528e43909f76ea86bc67a34..e90038b0ff9544f2063117b1a55e23695ba31d8f 100644
--- a/security/jwt.go
+++ b/security/jwt.go
@@ -1,6 +1,7 @@
 package security
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"github.com/bolkedebruin/rdpgw/protocol"
@@ -17,7 +18,7 @@ type customClaims struct {
 	jwt.StandardClaims
 }
 
-func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
+func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
 	token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) {
 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -30,7 +31,9 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
 		return false, err
 	}
 
-	if _, ok := token.Claims.(*customClaims); ok && token.Valid {
+	if c, ok := token.Claims.(*customClaims); ok && token.Valid {
+		s := getSessionInfo(ctx)
+		s.RemoteServer = c.RemoteServer
 		return true, nil
 	}
 
@@ -38,7 +41,21 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
 	return false, err
 }
 
-func GeneratePAAToken(username string, server string) (string, error) {
+func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
+	s := getSessionInfo(ctx)
+	if s == nil {
+		return false, errors.New("no valid session info found in context")
+	}
+
+	if s.RemoteServer != host {
+		log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
+		return false, nil
+	}
+
+	return true, nil
+}
+
+func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
 	if len(SigningKey) < 32 {
 		return "", errors.New("token signing key not long enough or not specified")
 	}
@@ -67,4 +84,13 @@ func GeneratePAAToken(username string, server string) (string, error) {
 	} else {
 		return ss, nil
 	}
+}
+
+func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
+	s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
+	if !ok {
+		log.Printf("cannot get session info from context")
+		return nil
+	}
+	return s
 }
\ No newline at end of file