diff --git a/download.go b/download.go index 1cee6a6af2fb6b2e4e4352e30380ae7b98b2f1d4..04cc61899915dd27ee24e0066567254c7de837fa 100644 --- a/download.go +++ b/download.go @@ -4,11 +4,9 @@ import ( "encoding/hex" "encoding/json" "github.com/patrickmn/go-cache" - "github.com/spf13/viper" "golang.org/x/oauth2" "log" "math/rand" - "net" "net/http" "strings" "time" @@ -30,7 +28,12 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) { return } - host := strings.Replace(viper.GetString("hostTemplate"), "%%", data.(string), 1) + var host = conf.Server.HostTemplate + for k, v := range data.(map[string]interface{}) { + if val, ok := v.(string); ok == true { + host = strings.Replace(host, "{{ " + k + " }}", val, 1) + } + } // authenticated seed := make([]byte, 16) @@ -41,7 +44,7 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/x-rdp") http.ServeContent(w, r, fn, time.Now(), strings.NewReader( "full address:s:" + host + "\r\n"+ - "gatewayhostname:s:" + net.JoinHostPort(conf.Server.GatewayAddress, string(conf.Server.Port)) +"\r\n"+ + "gatewayhostname:s:" + conf.Server.GatewayAddress +"\r\n"+ "gatewaycredentialssource:i:5\r\n"+ "gatewayusagemethod:i:1\r\n"+ "gatewayaccesstoken:s:" + cookie.Value + "\r\n")) @@ -99,7 +102,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request) { } // TODO: make dynamic - tokens.Set(token, data["preferred_username"].(string), cache.DefaultExpiration) + tokens.Set(token, data, cache.DefaultExpiration) http.SetCookie(w, &cookie) http.Redirect(w, r, "/connect", http.StatusFound) diff --git a/main.go b/main.go index 7ab8b3fc8f1065311051cb4577ed2f82e5372012..5e854b13e98d471fe55e9bcbf7c18bcc9206f666 100644 --- a/main.go +++ b/main.go @@ -59,13 +59,13 @@ func main() { log.Fatalf("Cannot get oidc provider: %s", err) } oidcConfig := &oidc.Config{ - ClientID: viper.GetString("clientId"), + ClientID: conf.OpenId.ClientId, } verifier = provider.Verifier(oidcConfig) oauthConfig = oauth2.Config{ - ClientID: viper.GetString("clientId"), - ClientSecret: viper.GetString("clientSecret"), + ClientID: conf.OpenId.ClientId, + ClientSecret: conf.OpenId.ClientSecret, RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback", Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, diff --git a/rdg.go b/rdg.go index a09d57c7aba9aefae396b970de42d56c86635a20..5b65ff047c2aa659fc3c5eb4df457c4df993b972 100644 --- a/rdg.go +++ b/rdg.go @@ -225,7 +225,12 @@ func handleWebsocketProtocol(conn *websocket.Conn) { log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) return } - host = strings.Replace(conf.Server.HostTemplate, "%%", data.(string), 1) + host = conf.Server.HostTemplate + for k, v := range data.(map[string]interface{}) { + if val, ok := v.(string); ok == true { + host = strings.Replace(host, "{{ " + k + " }}", val, 1) + } + } msg := createTunnelResponse() log.Printf("Create tunnel response: %x", msg) conn.WriteMessage(mt, msg)