diff --git a/config/configuration.go b/config/configuration.go new file mode 100644 index 0000000000000000000000000000000000000000..0f1ac22809cd01373bd047f6c49eedeb2d1e0c79 --- /dev/null +++ b/config/configuration.go @@ -0,0 +1,37 @@ +package config + +type Configuration struct { + Server ServerConfig + OpenId OpenIDConfig + Caps RDGCapsConfig +} + +type ServerConfig struct { + GatewayAddress string + Port int + CertFile string + KeyFile string + FarmHosts []string + EnableOverride bool + HostTemplate string +} + +type OpenIDConfig struct { + ProviderUrl string + ClientId string + ClientSecret string + CallbackHost string +} + +type RDGCapsConfig struct { + SmartCardAuth bool + TokenAuth bool + IdleTimeout int + RedirectAll bool + DisableRedirect bool + DisableClipboard bool + DisablePrinter bool + DisablePort bool + DisablePnp bool + DisableDrive bool +} diff --git a/download.go b/download.go index 5da5f0784fca610d20bedb859f2456e452c00956..1cee6a6af2fb6b2e4e4352e30380ae7b98b2f1d4 100644 --- a/download.go +++ b/download.go @@ -8,11 +8,14 @@ import ( "golang.org/x/oauth2" "log" "math/rand" + "net" "net/http" "strings" "time" ) +const state = "thisismystatebutshouldberandom" + func handleRdpDownload(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("RDPGWSESSIONV1") if err != nil { @@ -38,7 +41,7 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/x-rdp") http.ServeContent(w, r, fn, time.Now(), strings.NewReader( "full address:s:" + host + "\r\n"+ - "gatewayhostname:s:" + gateway +"\r\n"+ + "gatewayhostname:s:" + net.JoinHostPort(conf.Server.GatewayAddress, string(conf.Server.Port)) +"\r\n"+ "gatewaycredentialssource:i:5\r\n"+ "gatewayusagemethod:i:1\r\n"+ "gatewayaccesstoken:s:" + cookie.Value + "\r\n")) @@ -95,7 +98,8 @@ func handleCallback(w http.ResponseWriter, r *http.Request) { HttpOnly: true, } - tokens.Set(token, data[claim].(string), cache.DefaultExpiration) + // TODO: make dynamic + tokens.Set(token, data["preferred_username"].(string), cache.DefaultExpiration) http.SetCookie(w, &cookie) http.Redirect(w, r, "/connect", http.StatusFound) diff --git a/main.go b/main.go index 846aa59b0b5969651bfb38a7170fdd5b3c703806..7ab8b3fc8f1065311051cb4577ed2f82e5372012 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,11 @@ package main import ( "context" "crypto/tls" + "github.com/bolkedebruin/rdpgw/config" "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "github.com/spf13/viper" "golang.org/x/oauth2" @@ -23,69 +24,41 @@ var cmd = &cobra.Command{ } var ( - port int - certFile string - keyFile string - configFile string ) var tokens = cache.New(time.Minute *5, 10*time.Minute) +var conf config.Configuration -var state string - -var oauthConfig oauth2.Config -var oidcConfig *oidc.Config var verifier *oidc.IDTokenVerifier +var oauthConfig oauth2.Config var ctx context.Context -var gateway string -var overrideHost bool -var hostTemplate string -var claim string - func main() { // get config - cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection") - cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file") - cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file") cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") - cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost", "gateway dns name") - cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to") - cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template") - cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template") - - viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port")) - viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile")) - viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile")) - viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway")) - viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride")) - viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate")) - viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim")) viper.SetConfigName("rdpgw") - //viper.SetConfigFile(configFile) + viper.SetConfigFile(configFile) viper.AddConfigPath(".") viper.SetEnvPrefix("RDPGW") viper.AutomaticEnv() - err := viper.ReadInConfig() - if err != nil { - log.Printf("No config file found. Using defaults") + if err := viper.ReadInConfig(); err != nil { + log.Printf("No config file found (%s). Using defaults", err) } - // dont understand why I need to do this - gateway = viper.GetString("gateway") - hostTemplate = viper.GetString("hostTemplate") - overrideHost = viper.GetBool("hostOverride") + if err := viper.Unmarshal(&conf); err != nil { + log.Fatalf("Cannot unmarshal the config file; %s", err) + } // set oidc config ctx = context.Background() - provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl")) + provider, err := oidc.NewProvider(ctx, conf.OpenId.ProviderUrl) if err != nil { log.Fatalf("Cannot get oidc provider: %s", err) } - oidcConfig = &oidc.Config{ + oidcConfig := &oidc.Config{ ClientID: viper.GetString("clientId"), } verifier = provider.Verifier(oidcConfig) @@ -93,15 +66,12 @@ func main() { oauthConfig = oauth2.Config{ ClientID: viper.GetString("clientId"), ClientSecret: viper.GetString("clientSecret"), - RedirectURL: "https://" + gateway + "/callback", + RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback", Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } - // check what is required - state = "rdpstate" - - if certFile == "" || keyFile == "" { + if conf.Server.CertFile == "" || conf.Server.KeyFile == "" { log.Fatal("Both certfile and keyfile need to be specified") } @@ -109,6 +79,7 @@ func main() { //mux.HandleFunc("*", HelloServer) log.Printf("Starting remote desktop gateway server") + cfg := &tls.Config{} tlsDebug := os.Getenv("SSLKEYLOGFILE") if tlsDebug != "" { @@ -119,13 +90,15 @@ func main() { log.Printf("Key log file set to: %s", tlsDebug) cfg.KeyLogWriter = w } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) + + + cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile) if err != nil { log.Fatal(err) } cfg.Certificates = append(cfg.Certificates, cert) server := http.Server{ - Addr: ":" + strconv.Itoa(port), + Addr: ":" + strconv.Itoa(conf.Server.Port), TLSConfig: cfg, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } diff --git a/rdg.go b/rdg.go index 2aaec71591fedd89803a102a48789edf5b9410b0..a09d57c7aba9aefae396b970de42d56c86635a20 100644 --- a/rdg.go +++ b/rdg.go @@ -225,7 +225,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) { log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) return } - host = strings.Replace(hostTemplate, "%%", data.(string), 1) + host = strings.Replace(conf.Server.HostTemplate, "%%", data.(string), 1) msg := createTunnelResponse() log.Printf("Create tunnel response: %x", msg) conn.WriteMessage(mt, msg) @@ -236,7 +236,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) { conn.WriteMessage(mt, msg) case PKT_TYPE_CHANNEL_CREATE: server, port := readChannelCreateRequest(pkt) - if overrideHost == true { + if conf.Server.EnableOverride == true { log.Printf("Override allowed") host = net.JoinHostPort(server, strconv.Itoa(int(port))) }