Skip to content
Snippets Groups Projects
Commit 22d796c5 authored by Bolke de Bruin's avatar Bolke de Bruin
Browse files

Add tokeninfo endpoint

parent 188f077d
Branches
Tags
No related merge requests found
package api
import (
"context"
"encoding/json"
"fmt"
"github.com/bolkedebruin/rdpgw/security"
"log"
"net/http"
)
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
return
}
tokens, ok := r.URL.Query()["access_token"]
if !ok || len(tokens[0]) < 1 {
log.Printf("Missing access_token in request")
http.Error(w, "access_token missing in request", http.StatusBadRequest)
return
}
token := tokens[0]
info, err := security.UserInfo(context.Background(), token)
if err != nil {
log.Printf("Token validation failed due to %s", err)
http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden)
return
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
if err = json.NewEncoder(w).Encode(info); err != nil {
log.Printf("Cannot encode json due to %s", err)
http.Error(w, "cannot encode json", http.StatusInternalServerError)
return
}
}
\ No newline at end of file
......@@ -130,6 +130,7 @@ func main() {
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo)
http.HandleFunc("/callback", api.HandleCallback)
err = server.ListenAndServeTLS("", "")
......
......@@ -145,6 +145,69 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
return token, err
}
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
standard := jwt.Claims{}
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
enc, err := jwt.ParseSignedAndEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
token, err := enc.Decrypt(UserEncryptionKey)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
} else if len(UserSigningKey) == 0 {
token, err := jwt.ParseEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
err = token.Claims(UserEncryptionKey, &standard)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
} else {
token, err := jwt.ParseSigned(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
err = token.Claims(UserSigningKey, &standard)
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
}
// go-jose doesnt verify the expiry
err := standard.Validate(jwt.Expected{
Issuer: "rdpgw",
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return standard, fmt.Errorf("token validation failed due to %s", err)
}
return standard, nil
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok {
......@@ -153,3 +216,12 @@ func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
}
return s
}
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
for _, header := range headers {
if header.Algorithm != alg {
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
}
}
return true, nil
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment