diff --git a/api/token.go b/api/token.go new file mode 100644 index 0000000000000000000000000000000000000000..07b90c09eef246ff48cac5ead9c60056b5641fd1 --- /dev/null +++ b/api/token.go @@ -0,0 +1,40 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "github.com/bolkedebruin/rdpgw/security" + "log" + "net/http" +) + +func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid request", http.StatusMethodNotAllowed) + return + } + + tokens, ok := r.URL.Query()["access_token"] + if !ok || len(tokens[0]) < 1 { + log.Printf("Missing access_token in request") + http.Error(w, "access_token missing in request", http.StatusBadRequest) + return + } + + token := tokens[0] + + info, err := security.UserInfo(context.Background(), token) + if err != nil { + log.Printf("Token validation failed due to %s", err) + http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden) + return + } + + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + if err = json.NewEncoder(w).Encode(info); err != nil { + log.Printf("Cannot encode json due to %s", err) + http.Error(w, "cannot encode json", http.StatusInternalServerError) + return + } +} \ No newline at end of file diff --git a/main.go b/main.go index 26c0d092fd08799f1ce74bf5e6f401f7c3df6a5a..ea1f3f234b5b061d783ce6af7fe668b3b3ddb226 100644 --- a/main.go +++ b/main.go @@ -130,6 +130,7 @@ func main() { http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/metrics", promhttp.Handler()) + http.HandleFunc("/tokeninfo", api.TokenInfo) http.HandleFunc("/callback", api.HandleCallback) err = server.ListenAndServeTLS("", "") diff --git a/security/jwt.go b/security/jwt.go index 97511624b0c6e829ab6fd805ade46474d0d75980..482ea906f45f1b9868a30839623a13504b0ae6ca 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -145,6 +145,69 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) { return token, err } +func UserInfo(ctx context.Context, token string) (jwt.Claims, error) { + standard := jwt.Claims{} + if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 { + enc, err := jwt.ParseSignedAndEncrypted(token) + if err != nil { + log.Printf("Cannot get token %s", err) + return standard, errors.New("cannot get token") + } + token, err := enc.Decrypt(UserEncryptionKey) + if err != nil { + log.Printf("Cannot decrypt token %s", err) + return standard, errors.New("cannot decrypt token") + } + if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil { + log.Printf("signature validation failure: %s", err) + return standard, errors.New("signature validation failure") + } + if err = token.Claims(UserSigningKey, &standard); err != nil { + log.Printf("cannot verify signature %s", err) + return standard, errors.New("cannot verify signature") + } + } else if len(UserSigningKey) == 0 { + token, err := jwt.ParseEncrypted(token) + if err != nil { + log.Printf("Cannot get token %s", err) + return standard, errors.New("cannot get token") + } + err = token.Claims(UserEncryptionKey, &standard) + if err != nil { + log.Printf("Cannot decrypt token %s", err) + return standard, errors.New("cannot decrypt token") + } + } else { + token, err := jwt.ParseSigned(token) + if err != nil { + log.Printf("Cannot get token %s", err) + return standard, errors.New("cannot get token") + } + if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil { + log.Printf("signature validation failure: %s", err) + return standard, errors.New("signature validation failure") + } + err = token.Claims(UserSigningKey, &standard) + if err = token.Claims(UserSigningKey, &standard); err != nil { + log.Printf("cannot verify signature %s", err) + return standard, errors.New("cannot verify signature") + } + } + + // go-jose doesnt verify the expiry + err := standard.Validate(jwt.Expected{ + Issuer: "rdpgw", + Time: time.Now(), + }) + + if err != nil { + log.Printf("token validation failed due to %s", err) + return standard, fmt.Errorf("token validation failed due to %s", err) + } + + return standard, nil +} + func getSessionInfo(ctx context.Context) *protocol.SessionInfo { s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) if !ok { @@ -153,3 +216,12 @@ func getSessionInfo(ctx context.Context) *protocol.SessionInfo { } return s } + +func verifyAlg(headers []jose.Header, alg string) (bool, error) { + for _, header := range headers { + if header.Algorithm != alg { + return false, fmt.Errorf("invalid signing method %s", header.Algorithm) + } + } + return true, nil +} \ No newline at end of file