diff --git a/download.go b/download.go index 5d57a7aae9cbb470c341da343647332d1a5d81ba..9125ca9153ac5b55cdaca0170676d73a2ac570b3 100644 --- a/download.go +++ b/download.go @@ -2,7 +2,11 @@ package main import ( "encoding/hex" + "encoding/json" "github.com/patrickmn/go-cache" + "github.com/spf13/viper" + "golang.org/x/oauth2" + "log" "math/rand" "net/http" "strings" @@ -10,21 +14,93 @@ import ( ) func handleRdpDownload(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("RDPGWSESSIONV1") + if err != nil { + http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound) + return + } + + data, found := tokens.Get(cookie.Value) + if found == false { + log.Printf("Found expired or non existent session: %s", cookie.Value) + http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound) + return + } + + // authenticated seed := make([]byte, 16) rand.Read(seed) fn := hex.EncodeToString(seed) + ".rdp" - rand.Read(seed) - token := hex.EncodeToString(seed) - - tokens.Set(token, token, cache.DefaultExpiration) - w.Header().Set("Content-Disposition", "attachment; filename="+fn) w.Header().Set("Content-Type", "application/x-rdp") http.ServeContent(w, r, fn, time.Now(), strings.NewReader( - "full address:s:localhost\r\n"+ - "gatewayhostname:s:localhost\r\n"+ + "full address:s:" + host + "\r\n"+ + "gatewayhostname:s:" + gateway +"\r\n"+ "gatewaycredentialssource:i:5\r\n"+ "gatewayusagemethod:i:1\r\n"+ - "gatewayaccesstoken:s:" + token + "\r\n")) + "gatewayaccesstoken:s:" + cookie.Value + "\r\n")) } + +func handleCallback(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + http.Error(w, "state did not match", http.StatusBadRequest) + return + } + + oauthToken, err := oauthConfig.Exchange(ctx, r.URL.Query().Get("code")) + if err != nil { + http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + return + } + + rawIDToken, ok := oauthToken.Extra("id_token").(string) + if !ok { + http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) + return + } + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) + return + } + + resp := struct { + OAuth2Token *oauth2.Token + IDTokenClaims *json.RawMessage // ID Token payload is just JSON. + }{oauthToken, new(json.RawMessage)} + + if err := idToken.Claims(&resp.IDTokenClaims); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + q, err := json.MarshalIndent(resp, "", " ") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var data map[string]interface{} + if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + seed := make([]byte, 16) + rand.Read(seed) + token := hex.EncodeToString(seed) + + cookie := http.Cookie{ + Name: "RDPGWSESSIONV1", + Value: token, + Path: "/", + Secure: true, + HttpOnly: true, + } + + tokens.Set(token, data[claim].(string), cache.DefaultExpiration) + + http.SetCookie(w, &cookie) + http.Redirect(w, r, "/connect", http.StatusFound) +} \ No newline at end of file diff --git a/go.mod b/go.mod index f17e15d8a9752ea4f3fe46e724d25b665fd470e8..3b3589aeef825cdb4a04d58ab0730a6d65ebf690 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/bolkedebruin/rdpgw go 1.14 require ( + github.com/coreos/go-oidc/v3 v3.0.0-alpha.1 github.com/gorilla/websocket v1.4.2 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.7.1 diff --git a/main.go b/main.go index 5cf47278adacb15e4302a6c4da4a7ab55abaa23d..846aa59b0b5969651bfb38a7170fdd5b3c703806 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,15 @@ package main import ( + "context" "crypto/tls" + "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" "github.com/spf13/viper" + "golang.org/x/oauth2" "log" "net/http" "os" @@ -29,17 +32,39 @@ var ( var tokens = cache.New(time.Minute *5, 10*time.Minute) +var state string + +var oauthConfig oauth2.Config +var oidcConfig *oidc.Config +var verifier *oidc.IDTokenVerifier +var ctx context.Context + +var gateway string +var overrideHost bool +var hostTemplate string +var claim string + func main() { + // get config cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection") cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file") cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file") cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") + cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost", "gateway dns name") + cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to") + cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template") + cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template") viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port")) viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile")) viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile")) + viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway")) + viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride")) + viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate")) + viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim")) - viper.SetConfigFile(configFile) + viper.SetConfigName("rdpgw") + //viper.SetConfigFile(configFile) viper.AddConfigPath(".") viper.SetEnvPrefix("RDPGW") viper.AutomaticEnv() @@ -49,6 +74,33 @@ func main() { log.Printf("No config file found. Using defaults") } + // dont understand why I need to do this + gateway = viper.GetString("gateway") + hostTemplate = viper.GetString("hostTemplate") + overrideHost = viper.GetBool("hostOverride") + + // set oidc config + ctx = context.Background() + provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl")) + if err != nil { + log.Fatalf("Cannot get oidc provider: %s", err) + } + oidcConfig = &oidc.Config{ + ClientID: viper.GetString("clientId"), + } + verifier = provider.Verifier(oidcConfig) + + oauthConfig = oauth2.Config{ + ClientID: viper.GetString("clientId"), + ClientSecret: viper.GetString("clientSecret"), + RedirectURL: "https://" + gateway + "/callback", + Endpoint: provider.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + + // check what is required + state = "rdpstate" + if certFile == "" || keyFile == "" { log.Fatal("Both certfile and keyfile need to be specified") } @@ -81,6 +133,7 @@ func main() { http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) http.HandleFunc("/connect", handleRdpDownload) http.Handle("/metrics", promhttp.Handler()) + http.HandleFunc("/callback", handleCallback) prometheus.MustRegister(connectionCache) prometheus.MustRegister(legacyConnections) diff --git a/rdg.go b/rdg.go index 25097775d59b6ebb76707bcb900ef14d116d69b6..2aaec71591fedd89803a102a48789edf5b9410b0 100644 --- a/rdg.go +++ b/rdg.go @@ -16,6 +16,7 @@ import ( "net/http" "net/http/httputil" "strconv" + "strings" "time" "unicode/utf16" "unicode/utf8" @@ -182,6 +183,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) { websocketConnections.Inc() defer websocketConnections.Dec() + var host string for { mt, msg, err := conn.ReadMessage() if err != nil { @@ -218,10 +220,12 @@ func handleWebsocketProtocol(conn *websocket.Conn) { conn.WriteMessage(mt, msg) case PKT_TYPE_TUNNEL_CREATE: _, cookie := readCreateTunnelRequest(pkt) - if _, found := tokens.Get(cookie); found == false { + data, found := tokens.Get(cookie) + if found == false { log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) return } + host = strings.Replace(hostTemplate, "%%", data.(string), 1) msg := createTunnelResponse() log.Printf("Create tunnel response: %x", msg) conn.WriteMessage(mt, msg) @@ -232,13 +236,17 @@ func handleWebsocketProtocol(conn *websocket.Conn) { conn.WriteMessage(mt, msg) case PKT_TYPE_CHANNEL_CREATE: server, port := readChannelCreateRequest(pkt) - log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) + if overrideHost == true { + log.Printf("Override allowed") + host = net.JoinHostPort(server, strconv.Itoa(int(port))) + } + log.Printf("Establishing connection to RDP server: %s", host) remote, err = net.DialTimeout( "tcp", - net.JoinHostPort(server, strconv.Itoa(int(port))), - time.Second * 15) + host, + time.Second * 30) if err != nil { - log.Printf("Error connecting to %s, %d, %s", server, port, err) + log.Printf("Error connecting to %s", host) return } log.Printf("Connection established") @@ -349,7 +357,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { msg := handshakeResponse(major, minor, auth) s.ConnOut.Write(msg) case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(pkt) + _, cookie := readCreateTunnelRequest(pkt) + if _, found := tokens.Get(cookie); found == false { + log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) + return + } msg := createTunnelResponse() s.ConnOut.Write(msg) case PKT_TYPE_TUNNEL_AUTH: