diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index e89e77d587cdc5fbcb360a30cff2f0a7273b1094..d4d99927e47f2d56a3f7286a0f4f3b9d126c0dcb 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -156,9 +156,9 @@ func main() { } if conf.Caps.TokenAuth { gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken - gwConfig.VerifyServerFunc = security.VerifyServerFunc + gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost) } else { - gwConfig.VerifyServerFunc = security.BasicVerifyServer + gwConfig.VerifyServerFunc = security.CheckHost } gw := protocol.Gateway{ ServerConf: &gwConfig, diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index 64b5db0ad6c9b25d61d7db3d460247b2d827c2f1..5a066119c271c6bb58f6d4b49831f3c7ce116558 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "strings" ) var ( @@ -12,19 +13,20 @@ var ( HostSelection string ) -func BasicVerifyServer(ctx context.Context, host string) (bool, error) { - if HostSelection == "any" { +func CheckHost(ctx context.Context, host string) (bool, error) { + switch HostSelection { + case "any": return true, nil - } - - if HostSelection == "signed" { - // todo get from context + case "signed": + // todo get from context? return false, errors.New("cannot verify host in 'signed' mode as token data is missing") - } - - if HostSelection == "roundrobin" || HostSelection == "unsigned" { + case "roundrobin", "unsigned": log.Printf("Checking host") + username := ctx.Value("preferred_username").(string) for _, h := range Hosts { + if username != "" { + h = strings.Replace(h, "{{ preferred_username }}", username, 1) + } if h == host { return true, nil } diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 60f272b730e04b1a945f170c9e7a46863de2197b..324c8737bb2eacc673d4653be92618de142221b1 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -33,6 +33,27 @@ type customClaims struct { AccessToken string `json:"accessToken"` } +func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc { + return func(ctx context.Context, host string) (bool, error) { + s := getSessionInfo(ctx) + if s == nil { + return false, errors.New("no valid session info found in context") + } + + if s.RemoteServer != host { + log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer) + return false, nil + } + + if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) { + log.Printf("Current client ip address %s does not match token client ip %s", + common.GetClientIp(ctx), s.ClientIp) + return false, nil + } + return next(ctx, host) + } +} + func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { if tokenString == "" { log.Printf("no token to parse") @@ -91,26 +112,6 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { return true, nil } -func VerifyServerFunc(ctx context.Context, host string) (bool, error) { - s := getSessionInfo(ctx) - if s == nil { - return false, errors.New("no valid session info found in context") - } - - if s.RemoteServer != host { - log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer) - return false, nil - } - - if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) { - log.Printf("Current client ip address %s does not match token client ip %s", - common.GetClientIp(ctx), s.ClientIp) - return false, nil - } - - return true, nil -} - func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) { if len(SigningKey) < 32 { return "", errors.New("token signing key not long enough or not specified")