diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index 1a41b01b7fa4cbe16640a5ffe6de9cafd71200ec..03cece1891713b0a535fc99909a70905b5b1cd75 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -15,7 +15,6 @@ import ( const ( CacheExpiration = time.Minute * 2 CleanupInterval = time.Minute * 5 - oidcKeyUserName = "preferred_username" ) type OIDC struct { @@ -81,7 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { } id := identity.FromRequestCtx(r) - id.SetUserName(data[oidcKeyUserName].(string)) + + userName := findUsernameInClaims(data) + if userName == "" { + http.Error(w, "no oidc claim for username found", http.StatusInternalServerError) + } + + id.SetUserName(userName) id.SetAuthenticated(true) id.SetAuthTime(time.Now()) id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken) @@ -93,6 +98,18 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, url, http.StatusFound) } +func findUsernameInClaims(data map[string]interface{}) string { + candidates := []string{"preferred_username", "unique_name", "upn"} + for _, claim := range candidates { + userName, found := data[claim].(string) + if found { + return userName + } + } + + return "" +} + func (h *OIDC) Authenticated(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := identity.FromRequestCtx(r) diff --git a/cmd/rdpgw/web/oidc_test.go b/cmd/rdpgw/web/oidc_test.go new file mode 100644 index 0000000000000000000000000000000000000000..37eb90853d6c65ac4204d326690850dc6a7771ad --- /dev/null +++ b/cmd/rdpgw/web/oidc_test.go @@ -0,0 +1,49 @@ +package web + +import "testing" + +func TestFindUserNameInClaims(t *testing.T) { + cases := []struct { + data map[string]interface{} + ret string + name string + }{ + { + data: map[string]interface{}{ + "preferred_username": "exists", + }, + ret: "exists", + name: "preferred_username", + }, + { + data: map[string]interface{}{ + "upn": "exists", + }, + ret: "exists", + name: "upn", + }, + { + data: map[string]interface{}{ + "unique_name": "exists", + }, + ret: "exists", + name: "unique_name", + }, + { + data: map[string]interface{}{ + "fail": "exists", + }, + ret: "", + name: "fail", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := findUsernameInClaims(tc.data) + if s != tc.ret { + t.Fatalf("expected return: %v, got: %v", tc.ret, s) + } + }) + } +}