diff --git a/api/web.go b/api/web.go index f3686d60c0d82bcfe1a9ce7a5c71067c5b99f3fd..c262203b86cdb15ddcfe7e4c866edab9f39fa800 100644 --- a/api/web.go +++ b/api/web.go @@ -19,9 +19,10 @@ import ( const ( RdpGwSession = "RDPGWSESSION" + MaxAge = 120 ) -type TokenGeneratorFunc func(string, string) (string, error) +type TokenGeneratorFunc func(context.Context, string, string) (string, error) type Config struct { SessionKey []byte @@ -99,6 +100,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) { return } + session.Options.MaxAge = MaxAge session.Values["preferred_username"] = data["preferred_username"] session.Values["authenticated"] = true @@ -157,7 +159,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { } } - token, err := c.TokenGenerator(user, host) + token, err := c.TokenGenerator(ctx, user, host) if err != nil { log.Printf("Cannot generate token for user %s due to %s", user, err) http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) diff --git a/main.go b/main.go index f59dd120400c1d10faebe895d0ed818ced45f9a2..98fed5c963a500ae073faee331e80c81c15675ca 100644 --- a/main.go +++ b/main.go @@ -116,6 +116,7 @@ func main() { EnableAll: conf.Caps.RedirectAll, }, VerifyTunnelCreate: security.VerifyPAAToken, + VerifyServerFunc: security.VerifyServerFunc, } gw := protocol.Gateway{ HandlerConf: &handlerConfig, diff --git a/protocol/handler.go b/protocol/handler.go index c622057308a4cdd298175af0e3519039157e3ad3..138b6057b165bfe9990fdd77e554743baf3db73b 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -5,7 +5,6 @@ import ( "context" "encoding/binary" "errors" - "github.com/bolkedebruin/rdpgw/transport" "io" "log" "net" @@ -13,9 +12,9 @@ import ( "time" ) -type VerifyTunnelCreate func(*SessionInfo, string) (bool, error) -type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error) -type VerifyServerFunc func(*SessionInfo, string) (bool, error) +type VerifyTunnelCreate func(context.Context, string) (bool, error) +type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) +type VerifyServerFunc func(context.Context, string) (bool, error) type RedirectFlags struct { Clipboard bool @@ -29,8 +28,6 @@ type RedirectFlags struct { type Handler struct { Session *SessionInfo - TransportIn transport.Transport - TransportOut transport.Transport VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc @@ -55,10 +52,8 @@ type HandlerConf struct { func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { h := &Handler{ + State: SERVER_STATE_INITIAL, Session: s, - TransportIn: s.TransportIn, - TransportOut: s.TransportOut, - State: SERVER_STATE_INITIAL, RedirectFlags: makeRedirectFlags(conf.RedirectFlags), IdleTimeout: conf.IdleTimeout, SmartCardAuth: conf.SmartCardAuth, @@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error { } major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do msg := h.handshakeResponse(major, minor) - h.TransportOut.WritePacket(msg) + h.Session.TransportOut.WritePacket(msg) h.State = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: log.Printf("Tunnel create") @@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error { } _, cookie := readCreateTunnelRequest(pkt) if h.VerifyTunnelCreate != nil { - if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok { + if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok { log.Printf("Invalid PAA cookie received") return errors.New("invalid PAA cookie") } } msg := createTunnelResponse() - h.TransportOut.WritePacket(msg) + h.Session.TransportOut.WritePacket(msg) h.State = SERVER_STATE_TUNNEL_CREATE case PKT_TYPE_TUNNEL_AUTH: log.Printf("Tunnel auth") @@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error { } client := h.readTunnelAuthRequest(pkt) if h.VerifyTunnelAuthFunc != nil { - if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok { + if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok { log.Printf("Invalid client name: %s", client) return errors.New("invalid client name") } } msg := h.createTunnelAuthResponse() - h.TransportOut.WritePacket(msg) + h.Session.TransportOut.WritePacket(msg) h.State = SERVER_STATE_TUNNEL_AUTHORIZE case PKT_TYPE_CHANNEL_CREATE: log.Printf("Channel create") @@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error { server, port := readChannelCreateRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) if h.VerifyServerFunc != nil { - if ok, _ := h.VerifyServerFunc(h.Session, host); !ok { + if ok, _ := h.VerifyServerFunc(ctx, host); !ok { log.Printf("Not allowed to connect to %s by policy handler", host) + return errors.New("denied by security policy") } } log.Printf("Establishing connection to RDP server: %s", host) @@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error { } log.Printf("Connection established") msg := createChannelCreateResponse() - h.TransportOut.WritePacket(msg) + h.Session.TransportOut.WritePacket(msg) // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually @@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error { log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED) return errors.New("wrong state") } - h.TransportIn.Close() - h.TransportOut.Close() + h.Session.TransportIn.Close() + h.Session.TransportOut.Close() h.State = SERVER_STATE_CLOSED default: log.Printf("Unknown packet (size %d): %x", sz, pkt) @@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { buf := make([]byte, 4096) for { - size, pkt, err := h.TransportIn.ReadPacket() + size, pkt, err := h.Session.TransportIn.ReadPacket() if err != nil { return 0, 0, []byte{0, 0}, err } @@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() { break } b1.Write(buf[:n]) - h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index f30855c7edca2735f4dc908b92bfee50c1023846..12bff90ca536db0aaaba59894a06bf74583c3fd0 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -46,13 +46,11 @@ type Gateway struct { type SessionInfo struct { ConnId string - CorrelationId string - ClientGeneration string TransportIn transport.Transport TransportOut transport.Transport RemoteAddress string - ProxyAddresses string - UserName string + ProxyAddress string + RemoteServer string } var upgrader = websocket.Upgrader{} @@ -65,9 +63,6 @@ func init() { } func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - connectionCache.Set(float64(c.ItemCount())) var s *SessionInfo @@ -79,6 +74,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) } else { s = x.(*SessionInfo) } + ctx := context.WithValue(r.Context(), "SessionInfo", s) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { diff --git a/security/jwt.go b/security/jwt.go index e083009cfb4009827528e43909f76ea86bc67a34..e90038b0ff9544f2063117b1a55e23695ba31d8f 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -1,6 +1,7 @@ package security import ( + "context" "errors" "fmt" "github.com/bolkedebruin/rdpgw/protocol" @@ -17,7 +18,7 @@ type customClaims struct { jwt.StandardClaims } -func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) { +func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -30,7 +31,9 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) { return false, err } - if _, ok := token.Claims.(*customClaims); ok && token.Valid { + if c, ok := token.Claims.(*customClaims); ok && token.Valid { + s := getSessionInfo(ctx) + s.RemoteServer = c.RemoteServer return true, nil } @@ -38,7 +41,21 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) { return false, err } -func GeneratePAAToken(username string, server string) (string, error) { +func VerifyServerFunc(ctx context.Context, host string) (bool, error) { + s := getSessionInfo(ctx) + if s == nil { + return false, errors.New("no valid session info found in context") + } + + if s.RemoteServer != host { + log.Printf("Client host %s does not match token host %s", host, s.RemoteServer) + return false, nil + } + + return true, nil +} + +func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) { if len(SigningKey) < 32 { return "", errors.New("token signing key not long enough or not specified") } @@ -67,4 +84,13 @@ func GeneratePAAToken(username string, server string) (string, error) { } else { return ss, nil } +} + +func getSessionInfo(ctx context.Context) *protocol.SessionInfo { + s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) + if !ok { + log.Printf("cannot get session info from context") + return nil + } + return s } \ No newline at end of file