diff --git a/cmd/rdpgw/api/web.go b/cmd/rdpgw/api/web.go deleted file mode 100644 index d02f5c393911b290f7c6b3671c23e7bc31f45380..0000000000000000000000000000000000000000 --- a/cmd/rdpgw/api/web.go +++ /dev/null @@ -1,293 +0,0 @@ -package api - -import ( - "context" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/sessions" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" - "log" - "math/rand" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "time" -) - -const ( - RdpGwSession = "RDPGWSESSION" - MaxAge = 120 -) - -type TokenGeneratorFunc func(context.Context, string, string) (string, error) -type UserTokenGeneratorFunc func(context.Context, string) (string, error) -type QueryInfoFunc func(context.Context, string, string) (string, error) - -type Config struct { - SessionKey []byte - SessionEncryptionKey []byte - SessionStore string - PAATokenGenerator TokenGeneratorFunc - UserTokenGenerator UserTokenGeneratorFunc - QueryInfo QueryInfoFunc - QueryTokenIssuer string - EnableUserToken bool - OAuth2Config *oauth2.Config - store sessions.Store - OIDCTokenVerifier *oidc.IDTokenVerifier - stateStore *cache.Cache - Hosts []string - HostSelection string - GatewayAddress *url.URL - UsernameTemplate string - NetworkAutoDetect int - BandwidthAutoDetect int - ConnectionType int - SplitUserDomain bool - DefaultDomain string - SocketAddress string - Authentication string -} - -func (c *Config) NewApi() { - if len(c.SessionKey) < 32 { - log.Fatal("Session key too small") - } - if len(c.Hosts) < 1 { - log.Fatal("Not enough hosts to connect to specified") - } - if c.SessionStore == "file" { - log.Println("Filesystem is used as session storage") - c.store = sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey) - } else { - log.Println("Cookies are used as session storage") - c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey) - } - c.stateStore = cache.New(time.Minute*2, 5*time.Minute) -} - -func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) { - state := r.URL.Query().Get("state") - s, found := c.stateStore.Get(state) - if !found { - http.Error(w, "unknown state", http.StatusBadRequest) - return - } - url := s.(string) - - ctx := context.Background() - oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code")) - if err != nil { - http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) - return - } - - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) - return - } - idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken) - if err != nil { - http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) - return - } - - resp := struct { - OAuth2Token *oauth2.Token - IDTokenClaims *json.RawMessage // ID Token payload is just JSON. - }{oauth2Token, new(json.RawMessage)} - - if err := idToken.Claims(&resp.IDTokenClaims); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - var data map[string]interface{} - if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - session, err := c.store.Get(r, RdpGwSession) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - session.Options.MaxAge = MaxAge - session.Values["preferred_username"] = data["preferred_username"] - session.Values["authenticated"] = true - session.Values["access_token"] = oauth2Token.AccessToken - - if err = session.Save(r, w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - - http.Redirect(w, r, url, http.StatusFound) -} - -func (c *Config) Authenticated(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := c.store.Get(r, RdpGwSession) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - found := session.Values["authenticated"] - if found == nil || !found.(bool) { - seed := make([]byte, 16) - rand.Read(seed) - state := hex.EncodeToString(seed) - c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration) - http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound) - return - } - - ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) - ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) - - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -func (c *Config) selectRandomHost() string { - rand.Seed(time.Now().Unix()) - host := c.Hosts[rand.Intn(len(c.Hosts))] - return host -} - -func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) { - switch c.HostSelection { - case "roundrobin": - return c.selectRandomHost(), nil - case "signed": - hosts, ok := u.Query()["host"] - if !ok { - return "", errors.New("invalid query parameter") - } - host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer) - if err != nil { - return "", err - } - found := false - for _, check := range c.Hosts { - if check == host { - found = true - break - } - } - if !found { - log.Printf("Invalid host %s specified in token", hosts[0]) - return "", errors.New("invalid host specified in query token") - } - return host, nil - case "unsigned": - hosts, ok := u.Query()["host"] - if !ok { - return "", errors.New("invalid query parameter") - } - for _, check := range c.Hosts { - if check == hosts[0] { - return hosts[0], nil - } - } - // not found - log.Printf("Invalid host %s specified in client request", hosts[0]) - return "", errors.New("invalid host specified in query parameter") - case "any": - hosts, ok := u.Query()["host"] - if !ok { - return "", errors.New("invalid query parameter") - } - return hosts[0], nil - default: - return c.selectRandomHost(), nil - } -} - -func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userName, ok := ctx.Value("preferred_username").(string) - - if !ok { - log.Printf("preferred_username not found in context") - http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) - return - } - - // determine host to connect to - host, err := c.getHost(ctx, r.URL) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - host = strings.Replace(host, "{{ preferred_username }}", userName, 1) - - // split the username into user and domain - var user = userName - var domain = c.DefaultDomain - if c.SplitUserDomain { - creds := strings.SplitN(userName, "@", 2) - user = creds[0] - if len(creds) > 1 { - domain = creds[1] - } - } - - render := user - if c.UsernameTemplate != "" { - render = fmt.Sprintf(c.UsernameTemplate) - render = strings.Replace(render, "{{ username }}", user, 1) - if c.UsernameTemplate == render { - log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user) - http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError) - return - } - } - - token, err := c.PAATokenGenerator(ctx, user, host) - if err != nil { - log.Printf("Cannot generate PAA token for user %s due to %s", user, err) - http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) - } - - if c.EnableUserToken { - userToken, err := c.UserTokenGenerator(ctx, user) - if err != nil { - log.Printf("Cannot generate token for user %s due to %s", user, err) - http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) - } - render = strings.Replace(render, "{{ token }}", userToken, 1) - } - - // authenticated - seed := make([]byte, 16) - rand.Read(seed) - fn := hex.EncodeToString(seed) + ".rdp" - - w.Header().Set("Content-Disposition", "attachment; filename="+fn) - w.Header().Set("Content-Type", "application/x-rdp") - data := "full address:s:" + host + "\r\n" + - "gatewayhostname:s:" + c.GatewayAddress.Host + "\r\n" + - "gatewaycredentialssource:i:5\r\n" + - "gatewayusagemethod:i:1\r\n" + - "gatewayprofileusagemethod:i:1\r\n" + - "gatewayaccesstoken:s:" + token + "\r\n" + - "networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" + - "bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" + - "connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" + - "username:s:" + render + "\r\n" + - "domain:s:" + domain + "\r\n" + - "bitmapcachesize:i:32000\r\n" + - "smart sizing:i:1\r\n" - - http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data)) -} diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 4f28dddfe5433700afdc89072c9e0b07a980d177..28b9b7a821384665b2a10f96dd27c3a4e3520830 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -31,7 +31,7 @@ type ServerConfig struct { SessionStore string `koanf:"sessionstore"` SendBuf int `koanf:"sendbuf"` ReceiveBuf int `koanf:"receivebuf"` - Tls string `koanf:"disabletls"` + Tls string `koanf:"tls"` Authentication string `koanf:"authentication"` AuthSocket string `koanf:"authsocket"` } diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 52865f02cafe1fae85c237713cb57d378419af04..6b1d7cd53d12afce160117abb9d19c2ef3b91efa 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -4,12 +4,13 @@ import ( "context" "crypto/tls" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/api" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/web" "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/sessions" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/thought-machine/go-flags" "golang.org/x/crypto/acme/autocert" @@ -27,17 +28,56 @@ var opts struct { var conf config.Configuration +func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC { + // set oidc config + provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) + if err != nil { + log.Fatalf("Cannot get oidc provider: %s", err) + } + oidcConfig := &oidc.Config{ + ClientID: conf.OpenId.ClientId, + } + verifier := provider.Verifier(oidcConfig) + + oauthConfig := oauth2.Config{ + ClientID: conf.OpenId.ClientId, + ClientSecret: conf.OpenId.ClientSecret, + RedirectURL: callbackUrl.String(), + Endpoint: provider.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + security.OIDCProvider = provider + security.Oauth2Config = oauthConfig + + o := web.OIDCConfig{ + OAuth2Config: &oauthConfig, + OIDCTokenVerifier: verifier, + SessionStore: store, + } + + return o.New() +} + func main() { - // get config + // load config _, err := flags.Parse(&opts) if err != nil { panic(err) } conf = config.Load(opts.ConfigFile) - security.VerifyClientIP = conf.Security.VerifyClientIp + // set callback url and external advertised gateway address + url, err := url.Parse(conf.Server.GatewayAddress) + if err != nil { + log.Printf("Cannot parse server gateway address %s due to %s", url, err) + } + if url.Scheme == "" { + url.Scheme = "https" + } + url.Path = "callback" - // set security keys + // set security options + security.VerifyClientIP = conf.Security.VerifyClientIp security.SigningKey = []byte(conf.Security.PAATokenSigningKey) security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) @@ -46,66 +86,39 @@ func main() { security.HostSelection = conf.Server.HostSelection security.Hosts = conf.Server.Hosts - // configure api - api := &api.Config{ - QueryInfo: security.QueryInfo, - QueryTokenIssuer: conf.Security.QueryTokenIssuer, - EnableUserToken: conf.Security.EnableUserToken, + // init session store + sessionConf := web.SessionManagerConf{ SessionKey: []byte(conf.Server.SessionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), - SessionStore: conf.Server.SessionStore, - Hosts: conf.Server.Hosts, - HostSelection: conf.Server.HostSelection, - NetworkAutoDetect: conf.Client.NetworkAutoDetect, - UsernameTemplate: conf.Client.UsernameTemplate, - BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, - ConnectionType: conf.Client.ConnectionType, - SplitUserDomain: conf.Client.SplitUserDomain, - DefaultDomain: conf.Client.DefaultDomain, - SocketAddress: conf.Server.AuthSocket, - Authentication: conf.Server.Authentication, + StoreType: conf.Server.SessionStore, + } + store := sessionConf.Init() + + // configure web backend + w := &web.Config{ + QueryInfo: security.QueryInfo, + QueryTokenIssuer: conf.Security.QueryTokenIssuer, + EnableUserToken: conf.Security.EnableUserToken, + SessionStore: store, + Hosts: conf.Server.Hosts, + HostSelection: conf.Server.HostSelection, + RdpOpts: web.RdpOpts{ + UsernameTemplate: conf.Client.UsernameTemplate, + SplitUserDomain: conf.Client.SplitUserDomain, + DefaultDomain: conf.Client.DefaultDomain, + NetworkAutoDetect: conf.Client.NetworkAutoDetect, + BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, + ConnectionType: conf.Client.ConnectionType, + }, } if conf.Caps.TokenAuth { - api.PAATokenGenerator = security.GeneratePAAToken + w.PAATokenGenerator = security.GeneratePAAToken } if conf.Security.EnableUserToken { - api.UserTokenGenerator = security.GenerateUserToken - } - - // get callback url and external advertised gateway address - url, err := url.Parse(conf.Server.GatewayAddress) - if url.Scheme == "" { - url.Scheme = "https" - } - url.Path = "callback" - - if conf.Server.Authentication == "openid" { - // set oidc config - provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) - if err != nil { - log.Fatalf("Cannot get oidc provider: %s", err) - } - oidcConfig := &oidc.Config{ - ClientID: conf.OpenId.ClientId, - } - verifier := provider.Verifier(oidcConfig) - - api.GatewayAddress = url - - oauthConfig := oauth2.Config{ - ClientID: conf.OpenId.ClientId, - ClientSecret: conf.OpenId.ClientSecret, - RedirectURL: url.String(), - Endpoint: provider.Endpoint(), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - } - security.OIDCProvider = provider - security.Oauth2Config = oauthConfig - api.OAuth2Config = &oauthConfig - api.OIDCTokenVerifier = verifier + w.UserTokenGenerator = security.GenerateUserToken } - api.NewApi() + h := w.NewHandler() log.Printf("Starting remote desktop gateway server") cfg := &tls.Config{} @@ -151,7 +164,7 @@ func main() { cfg.GetCertificate = certMgr.GetCertificate go func() { - http.ListenAndServe(":http", certMgr.HTTPHandler(nil)) + http.ListenAndServe(":80", certMgr.HTTPHandler(nil)) }() } } @@ -190,15 +203,17 @@ func main() { } if conf.Server.Authentication == "local" { - http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol))) + h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} + http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol))) } else { // openid - http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) + oidc := initOIDC(url, store) + http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload)))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) - http.HandleFunc("/callback", api.HandleCallback) + http.HandleFunc("/callback", oidc.HandleCallback) } http.Handle("/metrics", promhttp.Handler()) - http.HandleFunc("/tokeninfo", api.TokenInfo) + http.HandleFunc("/tokeninfo", web.TokenInfo) if conf.Server.Tls == "disabled" { err = server.ListenAndServe() diff --git a/cmd/rdpgw/api/basic.go b/cmd/rdpgw/web/basic.go similarity index 89% rename from cmd/rdpgw/api/basic.go rename to cmd/rdpgw/web/basic.go index 8085519b77a4aabe95e77dd815222f1440b159fe..946036a6e693835ab2009e262155e8fdc758cb7b 100644 --- a/cmd/rdpgw/api/basic.go +++ b/cmd/rdpgw/web/basic.go @@ -1,4 +1,4 @@ -package api +package web import ( "context" @@ -15,13 +15,17 @@ const ( protocol = "unix" ) -func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc { +type BasicAuthHandler struct { + SocketAddress string +} + +func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if ok { ctx := r.Context() - conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), + conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return net.Dial(protocol, addr) })) diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go new file mode 100644 index 0000000000000000000000000000000000000000..a06a1394367f73a1b44bb913ff64c1329adb2916 --- /dev/null +++ b/cmd/rdpgw/web/oidc.go @@ -0,0 +1,127 @@ +package web + +import ( + "context" + "encoding/hex" + "encoding/json" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/sessions" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "math/rand" + "net/http" + "time" +) + +const ( + CacheExpiration = time.Minute * 2 + CleanupInterval = time.Minute * 5 +) + +type OIDC struct { + oAuth2Config *oauth2.Config + oidcTokenVerifier *oidc.IDTokenVerifier + stateStore *cache.Cache + sessionStore sessions.Store +} + +type OIDCConfig struct { + OAuth2Config *oauth2.Config + OIDCTokenVerifier *oidc.IDTokenVerifier + SessionStore sessions.Store +} + +func (c *OIDCConfig) New() *OIDC { + return &OIDC{ + oAuth2Config: c.OAuth2Config, + oidcTokenVerifier: c.OIDCTokenVerifier, + stateStore: cache.New(CacheExpiration, CleanupInterval), + sessionStore: c.SessionStore, + } +} + +func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { + state := r.URL.Query().Get("state") + s, found := h.stateStore.Get(state) + if !found { + http.Error(w, "unknown state", http.StatusBadRequest) + return + } + url := s.(string) + + ctx := r.Context() + oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code")) + if err != nil { + http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + return + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) + return + } + idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken) + if err != nil { + http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) + return + } + + resp := struct { + OAuth2Token *oauth2.Token + IDTokenClaims *json.RawMessage // ID Token payload is just JSON. + }{oauth2Token, new(json.RawMessage)} + + if err := idToken.Claims(&resp.IDTokenClaims); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var data map[string]interface{} + if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + session, err := h.sessionStore.Get(r, RdpGwSession) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + session.Options.MaxAge = MaxAge + session.Values["preferred_username"] = data["preferred_username"] + session.Values["authenticated"] = true + session.Values["access_token"] = oauth2Token.AccessToken + + if err = session.Save(r, w); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + + http.Redirect(w, r, url, http.StatusFound) +} + +func (h *OIDC) Authenticated(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := h.sessionStore.Get(r, RdpGwSession) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + found := session.Values["authenticated"] + if found == nil || !found.(bool) { + seed := make([]byte, 16) + rand.Read(seed) + state := hex.EncodeToString(seed) + h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration) + http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound) + return + } + + ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) + ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/cmd/rdpgw/web/session.go b/cmd/rdpgw/web/session.go new file mode 100644 index 0000000000000000000000000000000000000000..1f9a0d9fa700b491c416c27230bb7918a75a6bda --- /dev/null +++ b/cmd/rdpgw/web/session.go @@ -0,0 +1,30 @@ +package web + +import ( + "github.com/gorilla/sessions" + "log" + "os" +) + +type SessionManagerConf struct { + SessionKey []byte + SessionEncryptionKey []byte + StoreType string +} + +func (c *SessionManagerConf) Init() sessions.Store { + if len(c.SessionKey) < 32 { + log.Fatal("Session key too small") + } + if len(c.SessionEncryptionKey) < 32 { + log.Fatal("Session key too small") + } + + if c.StoreType == "file" { + log.Println("Filesystem is used as session storage") + return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey) + } else { + log.Println("Cookies are used as session storage") + return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey) + } +} diff --git a/cmd/rdpgw/api/token.go b/cmd/rdpgw/web/token.go similarity index 92% rename from cmd/rdpgw/api/token.go rename to cmd/rdpgw/web/token.go index 328b23328cd7b5557bf87f6214270af0d16a3767..82693936e9e434ab6e072553b50f1bbeee04fd40 100644 --- a/cmd/rdpgw/api/token.go +++ b/cmd/rdpgw/web/token.go @@ -1,4 +1,4 @@ -package api +package web import ( "context" @@ -9,7 +9,7 @@ import ( "net/http" ) -func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) { +func TokenInfo(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Invalid request", http.StatusMethodNotAllowed) return @@ -37,4 +37,4 @@ func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "cannot encode json", http.StatusInternalServerError) return } -} \ No newline at end of file +} diff --git a/cmd/rdpgw/web/web.go b/cmd/rdpgw/web/web.go new file mode 100644 index 0000000000000000000000000000000000000000..180fee9a34823a423ad90a65a75523b6d5ba6bc1 --- /dev/null +++ b/cmd/rdpgw/web/web.go @@ -0,0 +1,215 @@ +package web + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "github.com/gorilla/sessions" + "log" + "math/rand" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const ( + RdpGwSession = "RDPGWSESSION" + MaxAge = 120 +) + +type TokenGeneratorFunc func(context.Context, string, string) (string, error) +type UserTokenGeneratorFunc func(context.Context, string) (string, error) +type QueryInfoFunc func(context.Context, string, string) (string, error) + +type Config struct { + SessionStore sessions.Store + PAATokenGenerator TokenGeneratorFunc + UserTokenGenerator UserTokenGeneratorFunc + QueryInfo QueryInfoFunc + QueryTokenIssuer string + EnableUserToken bool + Hosts []string + HostSelection string + GatewayAddress *url.URL + RdpOpts RdpOpts +} + +type RdpOpts struct { + UsernameTemplate string + SplitUserDomain bool + DefaultDomain string + NetworkAutoDetect int + BandwidthAutoDetect int + ConnectionType int +} + +type Handler struct { + sessionStore sessions.Store + paaTokenGenerator TokenGeneratorFunc + enableUserToken bool + userTokenGenerator UserTokenGeneratorFunc + queryInfo QueryInfoFunc + queryTokenIssuer string + gatewayAddress *url.URL + hosts []string + hostSelection string + rdpOpts RdpOpts +} + +func (c *Config) NewHandler() *Handler { + if len(c.Hosts) < 1 { + log.Fatal("Not enough hosts to connect to specified") + } + return &Handler{ + sessionStore: c.SessionStore, + paaTokenGenerator: c.PAATokenGenerator, + enableUserToken: c.EnableUserToken, + userTokenGenerator: c.UserTokenGenerator, + queryInfo: c.QueryInfo, + queryTokenIssuer: c.QueryTokenIssuer, + gatewayAddress: c.GatewayAddress, + hosts: c.Hosts, + hostSelection: c.HostSelection, + rdpOpts: c.RdpOpts, + } +} + +func (h *Handler) selectRandomHost() string { + rand.Seed(time.Now().Unix()) + host := h.hosts[rand.Intn(len(h.hosts))] + return host +} + +func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) { + switch h.hostSelection { + case "roundrobin": + return h.selectRandomHost(), nil + case "signed": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + host, err := h.queryInfo(ctx, hosts[0], h.queryTokenIssuer) + if err != nil { + return "", err + } + found := false + for _, check := range h.hosts { + if check == host { + found = true + break + } + } + if !found { + log.Printf("Invalid host %s specified in token", hosts[0]) + return "", errors.New("invalid host specified in query token") + } + return host, nil + case "unsigned": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + for _, check := range h.hosts { + if check == hosts[0] { + return hosts[0], nil + } + } + // not found + log.Printf("Invalid host %s specified in client request", hosts[0]) + return "", errors.New("invalid host specified in query parameter") + case "any": + hosts, ok := u.Query()["host"] + if !ok { + return "", errors.New("invalid query parameter") + } + return hosts[0], nil + default: + return h.selectRandomHost(), nil + } +} + +func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + userName, ok := ctx.Value("preferred_username").(string) + + opts := h.rdpOpts + + if !ok { + log.Printf("preferred_username not found in context") + http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) + return + } + + // determine host to connect to + host, err := h.getHost(ctx, r.URL) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + host = strings.Replace(host, "{{ preferred_username }}", userName, 1) + + // split the username into user and domain + var user = userName + var domain = opts.DefaultDomain + if opts.SplitUserDomain { + creds := strings.SplitN(userName, "@", 2) + user = creds[0] + if len(creds) > 1 { + domain = creds[1] + } + } + + render := user + if opts.UsernameTemplate != "" { + render = fmt.Sprintf(h.rdpOpts.UsernameTemplate) + render = strings.Replace(render, "{{ username }}", user, 1) + if h.rdpOpts.UsernameTemplate == render { + log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user) + http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError) + return + } + } + + token, err := h.paaTokenGenerator(ctx, user, host) + if err != nil { + log.Printf("Cannot generate PAA token for user %s due to %s", user, err) + http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) + } + + if h.enableUserToken { + userToken, err := h.userTokenGenerator(ctx, user) + if err != nil { + log.Printf("Cannot generate token for user %s due to %s", user, err) + http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) + } + render = strings.Replace(render, "{{ token }}", userToken, 1) + } + + // authenticated + seed := make([]byte, 16) + rand.Read(seed) + fn := hex.EncodeToString(seed) + ".rdp" + + w.Header().Set("Content-Disposition", "attachment; filename="+fn) + w.Header().Set("Content-Type", "application/x-rdp") + + data := "full address:s:" + host + "\r\n" + + "gatewayhostname:s:" + h.gatewayAddress.Host + "\r\n" + + "gatewaycredentialssource:i:5\r\n" + + "gatewayusagemethod:i:1\r\n" + + "gatewayprofileusagemethod:i:1\r\n" + + "gatewayaccesstoken:s:" + token + "\r\n" + + "networkautodetect:i:" + strconv.Itoa(opts.NetworkAutoDetect) + "\r\n" + + "bandwidthautodetect:i:" + strconv.Itoa(opts.BandwidthAutoDetect) + "\r\n" + + "connection type:i:" + strconv.Itoa(opts.ConnectionType) + "\r\n" + + "username:s:" + render + "\r\n" + + "domain:s:" + domain + "\r\n" + + "bitmapcachesize:i:32000\r\n" + + "smart sizing:i:1\r\n" + + http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data)) +} diff --git a/cmd/rdpgw/api/web_test.go b/cmd/rdpgw/web/web_test.go similarity index 87% rename from cmd/rdpgw/api/web_test.go rename to cmd/rdpgw/web/web_test.go index c1e4e5c2803ce803fd21ee10e1a94b19c49a6860..166d3cb0d4f4799d215e90a5dd59a48aa528ce07 100644 --- a/cmd/rdpgw/api/web_test.go +++ b/cmd/rdpgw/web/web_test.go @@ -1,4 +1,4 @@ -package api +package web import ( "context" @@ -27,12 +27,14 @@ func TestGetHost(t *testing.T) { HostSelection: "roundrobin", Hosts: hosts, } + h := c.NewHandler() + u := &url.URL{ Host: "example.com", } vals := u.Query() - host, err := c.getHost(ctx, u) + host, err := h.getHost(ctx, u) if err != nil { t.Fatalf("#{err}") } @@ -44,14 +46,16 @@ func TestGetHost(t *testing.T) { c.HostSelection = "unsigned" vals.Set("host", "in.valid.host") u.RawQuery = vals.Encode() - host, err = c.getHost(ctx, u) + h = c.NewHandler() + host, err = h.getHost(ctx, u) if err == nil { t.Fatalf("Accepted host %s is not in hosts list", host) } vals.Set("host", hosts[0]) u.RawQuery = vals.Encode() - host, err = c.getHost(ctx, u) + h = c.NewHandler() + host, err = h.getHost(ctx, u) if err != nil { t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) } @@ -64,7 +68,8 @@ func TestGetHost(t *testing.T) { test := "bla.bla.com" vals.Set("host", test) u.RawQuery = vals.Encode() - host, err = c.getHost(ctx, u) + h = c.NewHandler() + host, err = h.getHost(ctx, u) if err != nil { t.Fatalf("%s is not accepted", host) } @@ -83,7 +88,8 @@ func TestGetHost(t *testing.T) { } vals.Set("host", queryToken) u.RawQuery = vals.Encode() - host, err = c.getHost(ctx, u) + h = c.NewHandler() + host, err = h.getHost(ctx, u) if err != nil { t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) }