diff --git a/main.go b/main.go index ea1f3f234b5b061d783ce6af7fe668b3b3ddb226..759aea634f49d4723560018fb8f5a7aa8a747229 100644 --- a/main.go +++ b/main.go @@ -41,8 +41,7 @@ func main() { security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) // set oidc config - ctx := context.Background() - provider, err := oidc.NewProvider(ctx, conf.OpenId.ProviderUrl) + provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) if err != nil { log.Fatalf("Cannot get oidc provider: %s", err) } @@ -58,6 +57,8 @@ func main() { Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } + security.OIDCProvider = provider + security.Oauth2Config = oauthConfig api := &api.Config{ GatewayAddress: conf.Server.GatewayAddress, diff --git a/security/jwt.go b/security/jwt.go index 482ea906f45f1b9868a30839623a13504b0ae6ca..a80654a338a1e7c00a4ab29114fb3cc145d250fd 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -6,8 +6,10 @@ import ( "fmt" "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/protocol" + "github.com/coreos/go-oidc/v3/oidc" "github.com/square/go-jose/v3" "github.com/square/go-jose/v3/jwt" + "golang.org/x/oauth2" "log" "time" ) @@ -17,6 +19,8 @@ var ( EncryptionKey []byte UserSigningKey []byte UserEncryptionKey []byte + OIDCProvider *oidc.Provider + Oauth2Config oauth2.Config ) var ExpiryTime time.Duration = 5 @@ -58,6 +62,14 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { return false, err } + // validate the access token + tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) + _, err = OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + log.Printf("Cannot get user info for access token: %s", err) + return false, err + } + s := getSessionInfo(ctx) s.RemoteServer = custom.RemoteServer