From 69bcf812307c914436a8cd7516c250666601befe Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Wed, 24 Aug 2022 22:44:44 +0200 Subject: [PATCH] Fix randomstring generation --- cmd/rdpgw/api/basic.go | 2 ++ cmd/rdpgw/main.go | 63 +++++++++++++++++++----------------- cmd/rdpgw/security/jwt.go | 12 ++++--- cmd/rdpgw/security/string.go | 2 +- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/cmd/rdpgw/api/basic.go b/cmd/rdpgw/api/basic.go index 91f6d17..afa4108 100644 --- a/cmd/rdpgw/api/basic.go +++ b/cmd/rdpgw/api/basic.go @@ -47,6 +47,8 @@ func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc { if !res.Authenticated { log.Printf("User %s is not authenticated for this service", username) } else { + ctx := context.WithValue(r.Context(), "preferred_username", username) + ctx = context.WithValue(ctx, "access_token", "EMPTY") next.ServeHTTP(w, r.WithContext(ctx)) return } diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 24e4bb0..2800cf7 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -42,37 +42,8 @@ func main() { security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey) - // set oidc config - provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) - if err != nil { - log.Fatalf("Cannot get oidc provider: %s", err) - } - oidcConfig := &oidc.Config{ - ClientID: conf.OpenId.ClientId, - } - verifier := provider.Verifier(oidcConfig) - - // get callback url and external advertised gateway address - url, err := url.Parse(conf.Server.GatewayAddress) - if url.Scheme == "" { - url.Scheme = "https" - } - url.Path = "callback" - - oauthConfig := oauth2.Config{ - ClientID: conf.OpenId.ClientId, - ClientSecret: conf.OpenId.ClientSecret, - RedirectURL: url.String(), - Endpoint: provider.Endpoint(), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - } - security.OIDCProvider = provider - security.Oauth2Config = oauthConfig - + // configure api api := &api.Config{ - GatewayAddress: url.Host, - OAuth2Config: &oauthConfig, - OIDCTokenVerifier: verifier, PAATokenGenerator: security.GeneratePAAToken, UserTokenGenerator: security.GenerateUserToken, QueryInfo: security.QueryInfo, @@ -92,6 +63,38 @@ func main() { SocketAddress: conf.Server.AuthSocket, Authentication: conf.Server.Authentication, } + + if conf.Server.Authentication == "openid" { + // set oidc config + provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) + if err != nil { + log.Fatalf("Cannot get oidc provider: %s", err) + } + oidcConfig := &oidc.Config{ + ClientID: conf.OpenId.ClientId, + } + verifier := provider.Verifier(oidcConfig) + + // get callback url and external advertised gateway address + url, err := url.Parse(conf.Server.GatewayAddress) + if url.Scheme == "" { + url.Scheme = "https" + } + url.Path = "callback" + api.GatewayAddress = url.Host + + oauthConfig := oauth2.Config{ + ClientID: conf.OpenId.ClientId, + ClientSecret: conf.OpenId.ClientSecret, + RedirectURL: url.String(), + Endpoint: provider.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + security.OIDCProvider = provider + security.Oauth2Config = oauthConfig + api.OAuth2Config = &oauthConfig + api.OIDCTokenVerifier = verifier + } api.NewApi() log.Printf("Starting remote desktop gateway server") diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 8deef42..84ab15b 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -65,11 +65,13 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { } // validate the access token - tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) - _, err = OIDCProvider.UserInfo(ctx, tokenSource) - if err != nil { - log.Printf("Cannot get user info for access token: %s", err) - return false, err + if custom.AccessToken != "EMPTY" { + tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) + _, err = OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + log.Printf("Cannot get user info for access token: %s", err) + return false, err + } } s := getSessionInfo(ctx) diff --git a/cmd/rdpgw/security/string.go b/cmd/rdpgw/security/string.go index dcf5821..a79dbff 100644 --- a/cmd/rdpgw/security/string.go +++ b/cmd/rdpgw/security/string.go @@ -32,7 +32,7 @@ func GenerateRandomString(n int) (string, error) { if err != nil { return "", err } - ret = append(ret, letters[num.Int64()]) + ret[i] = letters[num.Int64()] } return string(ret), nil -- GitLab