diff --git a/cmd/rdpgw/api/basic.go b/cmd/rdpgw/api/basic.go index afa4108d5dd6affa1ec5f2a2b2901a1c19626bd2..d2540ba726984f55519625432b023601b18b70d7 100644 --- a/cmd/rdpgw/api/basic.go +++ b/cmd/rdpgw/api/basic.go @@ -33,7 +33,7 @@ func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc { defer conn.Close() c := auth.NewAuthenticateClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() req := &auth.UserPass{Username: username, Password: password} diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 2800cf78fc7f9b975450a8c5fa8599275a0c95aa..96ce034ba76ce93b94b44768fa211eadd16e774e 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -44,8 +44,6 @@ func main() { // configure api api := &api.Config{ - PAATokenGenerator: security.GeneratePAAToken, - UserTokenGenerator: security.GenerateUserToken, QueryInfo: security.QueryInfo, QueryTokenIssuer: conf.Security.QueryTokenIssuer, EnableUserToken: conf.Security.EnableUserToken, @@ -64,6 +62,13 @@ func main() { Authentication: conf.Server.Authentication, } + if conf.Caps.TokenAuth { + api.PAATokenGenerator = security.GeneratePAAToken + } + if conf.Security.EnableUserToken { + api.UserTokenGenerator = security.GenerateUserToken + } + if conf.Server.Authentication == "openid" { // set oidc config provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) @@ -144,10 +149,12 @@ func main() { DisableAll: conf.Caps.DisableRedirect, EnableAll: conf.Caps.RedirectAll, }, - VerifyTunnelCreate: security.VerifyPAAToken, - VerifyServerFunc: security.VerifyServerFunc, - SendBuf: conf.Server.SendBuf, - ReceiveBuf: conf.Server.ReceiveBuf, + SendBuf: conf.Server.SendBuf, + ReceiveBuf: conf.Server.ReceiveBuf, + } + if conf.Caps.TokenAuth { + handlerConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken + handlerConfig.VerifyServerFunc = security.VerifyServerFunc } gw := protocol.Gateway{ ServerConf: &handlerConfig, diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/server.go index a85a088b039e68a09e6d5725f7f3de1efbb4896d..6571ece871e011b759a2a9c55c163608b8092939 100644 --- a/cmd/rdpgw/protocol/server.go +++ b/cmd/rdpgw/protocol/server.go @@ -78,8 +78,8 @@ func (s *Server) Process(ctx context.Context) error { s.Session.TransportOut.WritePacket(msg) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) } - major, minor, _, auth := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do - caps, err := s.matchAuth(auth) + major, minor, _, reqAuth := s.handshakeRequest(pkt) + caps, err := s.matchAuth(reqAuth) if err != nil { log.Println(err) msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) @@ -224,7 +224,7 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version return } -func (s *Server) matchAuth(extAuth uint16) (caps uint16, err error) { +func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { if s.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_SC } @@ -232,10 +232,13 @@ func (s *Server) matchAuth(extAuth uint16) (caps uint16, err error) { caps = caps | HTTP_EXTENDED_AUTH_PAA } - if caps & extAuth == 0 && extAuth > 0 { - return 0, fmt.Errorf("%x has no matching capability configured (%x). Did you configure caps? ", extAuth, caps) + if caps&clientAuthCaps == 0 && clientAuthCaps > 0 { + return 0, fmt.Errorf("%x has no matching capability configured (%x). Did you configure caps? ", clientAuthCaps, caps) } + if caps > 0 && clientAuthCaps == 0 { + return 0, fmt.Errorf("%d caps are required by the server, but the client does not support them", caps) + } return caps, nil } diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 84ab15b5949bb38116b8a64b3b05254827a18cda..60f272b730e04b1a945f170c9e7a46863de2197b 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -34,7 +34,16 @@ type customClaims struct { } func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { + if tokenString == "" { + log.Printf("no token to parse") + return false, errors.New("no token to parse") + } + token, err := jwt.ParseSigned(tokenString) + if err != nil { + log.Printf("cannot parse token due to: %s", err) + return false, err + } // check if the signing algo matches what we expect for _, header := range token.Headers {