From 454d20307077d32544666637e96e3179f4d7f302 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Tue, 30 Aug 2022 11:47:26 +0200 Subject: [PATCH] Add acme support --- cmd/rdpgw/config/configuration.go | 12 +++---- cmd/rdpgw/main.go | 55 ++++++++++++++++++++++--------- go.mod | 2 +- 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 865da21..4f28ddd 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -31,7 +31,7 @@ type ServerConfig struct { SessionStore string `koanf:"sessionstore"` SendBuf int `koanf:"sendbuf"` ReceiveBuf int `koanf:"receivebuf"` - DisableTLS bool `koanf:"disabletls"` + Tls string `koanf:"disabletls"` Authentication string `koanf:"authentication"` AuthSocket string `koanf:"authsocket"` } @@ -117,9 +117,7 @@ func Load(configFile string) Configuration { var k = koanf.New(".") k.Load(confmap.Provider(map[string]interface{}{ - "Server.CertFile": "server.pem", - "Server.KeyFile": "key.pem", - "Server.TlsDisabled": false, + "Server.Tls": "auto", "Server.Port": 443, "Server.SessionStore": "cookie", "Server.HostSelection": "roundrobin", @@ -186,8 +184,8 @@ func Load(configFile string) Configuration { log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") } - if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS { - log.Fatalf("basicauth=local and disabletls are mutually exclusive") + if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" { + log.Fatalf("basicauth=local and tls=disable are mutually exclusive") } if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" { @@ -198,7 +196,7 @@ func Load(configFile string) Configuration { if !strings.Contains(Conf.Server.GatewayAddress, "//") { Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress } - + return Conf } diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 8f05a93..52865f0 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/tls" + "fmt" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/api" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" @@ -11,6 +12,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/thought-machine/go-flags" + "golang.org/x/crypto/acme/autocert" "golang.org/x/oauth2" "log" "net/http" @@ -71,6 +73,13 @@ func main() { api.UserTokenGenerator = security.GenerateUserToken } + // get callback url and external advertised gateway address + url, err := url.Parse(conf.Server.GatewayAddress) + if url.Scheme == "" { + url.Scheme = "https" + } + url.Path = "callback" + if conf.Server.Authentication == "openid" { // set oidc config provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) @@ -82,12 +91,6 @@ func main() { } verifier := provider.Verifier(oidcConfig) - // get callback url and external advertised gateway address - url, err := url.Parse(conf.Server.GatewayAddress) - if url.Scheme == "" { - url.Scheme = "https" - } - url.Path = "callback" api.GatewayAddress = url oauthConfig := oauth2.Config{ @@ -107,12 +110,11 @@ func main() { log.Printf("Starting remote desktop gateway server") cfg := &tls.Config{} - if conf.Server.DisableTLS { + if conf.Server.Tls == "disable" { log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator") } else { - if conf.Server.CertFile == "" || conf.Server.KeyFile == "" { - log.Fatal("Both certfile and keyfile need to be specified") - } + // auto config + tlsConfigured := false tlsDebug := os.Getenv("SSLKEYLOGFILE") if tlsDebug != "" { @@ -124,11 +126,34 @@ func main() { cfg.KeyLogWriter = w } - cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile) - if err != nil { - log.Fatal(err) + if conf.Server.KeyFile != "" && conf.Server.CertFile != "" { + cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile) + if err != nil { + log.Printf("Cannot load certfile or keyfile (%s) falling back to acme", err) + } + cfg.Certificates = append(cfg.Certificates, cert) + tlsConfigured = true + } + + if !tlsConfigured { + log.Printf("Using acme / letsencrypt for tls configuration. Enabling http (port 80) for verification") + // setup a simple handler which sends a HTHS header for six months (!) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Strict-Transport-Security", "max-age=15768000 ; includeSubDomains") + fmt.Fprintf(w, "Hello from RDPGW") + }) + + certMgr := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(url.Host), + Cache: autocert.DirCache("/tmp/rdpgw"), + } + cfg.GetCertificate = certMgr.GetCertificate + + go func() { + http.ListenAndServe(":http", certMgr.HTTPHandler(nil)) + }() } - cfg.Certificates = append(cfg.Certificates, cert) } server := http.Server{ @@ -175,7 +200,7 @@ func main() { http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/tokeninfo", api.TokenInfo) - if conf.Server.DisableTLS { + if conf.Server.Tls == "disabled" { err = server.ListenAndServe() } else { err = server.ListenAndServeTLS("", "") diff --git a/go.mod b/go.mod index 4e7f144..a0c35cd 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.12.1 github.com/thought-machine/go-flags v1.6.1 + golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c google.golang.org/grpc v1.49.0 google.golang.org/protobuf v1.28.1 @@ -32,7 +33,6 @@ require ( github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect github.com/stretchr/testify v1.7.1 // indirect - golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/text v0.3.7 // indirect -- GitLab