From 454d20307077d32544666637e96e3179f4d7f302 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Tue, 30 Aug 2022 11:47:26 +0200
Subject: [PATCH] Add acme support

---
 cmd/rdpgw/config/configuration.go | 12 +++----
 cmd/rdpgw/main.go                 | 55 ++++++++++++++++++++++---------
 go.mod                            |  2 +-
 3 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go
index 865da21..4f28ddd 100644
--- a/cmd/rdpgw/config/configuration.go
+++ b/cmd/rdpgw/config/configuration.go
@@ -31,7 +31,7 @@ type ServerConfig struct {
 	SessionStore         string   `koanf:"sessionstore"`
 	SendBuf              int      `koanf:"sendbuf"`
 	ReceiveBuf           int      `koanf:"receivebuf"`
-	DisableTLS           bool     `koanf:"disabletls"`
+	Tls                  string   `koanf:"disabletls"`
 	Authentication       string   `koanf:"authentication"`
 	AuthSocket           string   `koanf:"authsocket"`
 }
@@ -117,9 +117,7 @@ func Load(configFile string) Configuration {
 	var k = koanf.New(".")
 
 	k.Load(confmap.Provider(map[string]interface{}{
-		"Server.CertFile":            "server.pem",
-		"Server.KeyFile":             "key.pem",
-		"Server.TlsDisabled":         false,
+		"Server.Tls":                 "auto",
 		"Server.Port":                443,
 		"Server.SessionStore":        "cookie",
 		"Server.HostSelection":       "roundrobin",
@@ -186,8 +184,8 @@ func Load(configFile string) Configuration {
 		log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
 	}
 
-	if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
-		log.Fatalf("basicauth=local and disabletls are mutually exclusive")
+	if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" {
+		log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
 	}
 
 	if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" {
@@ -198,7 +196,7 @@ func Load(configFile string) Configuration {
 	if !strings.Contains(Conf.Server.GatewayAddress, "//") {
 		Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress
 	}
-	
+
 	return Conf
 
 }
diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index 8f05a93..52865f0 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -3,6 +3,7 @@ package main
 import (
 	"context"
 	"crypto/tls"
+	"fmt"
 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
@@ -11,6 +12,7 @@ import (
 	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/thought-machine/go-flags"
+	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/oauth2"
 	"log"
 	"net/http"
@@ -71,6 +73,13 @@ func main() {
 		api.UserTokenGenerator = security.GenerateUserToken
 	}
 
+	// get callback url and external advertised gateway address
+	url, err := url.Parse(conf.Server.GatewayAddress)
+	if url.Scheme == "" {
+		url.Scheme = "https"
+	}
+	url.Path = "callback"
+
 	if conf.Server.Authentication == "openid" {
 		// set oidc config
 		provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
@@ -82,12 +91,6 @@ func main() {
 		}
 		verifier := provider.Verifier(oidcConfig)
 
-		// get callback url and external advertised gateway address
-		url, err := url.Parse(conf.Server.GatewayAddress)
-		if url.Scheme == "" {
-			url.Scheme = "https"
-		}
-		url.Path = "callback"
 		api.GatewayAddress = url
 
 		oauthConfig := oauth2.Config{
@@ -107,12 +110,11 @@ func main() {
 	log.Printf("Starting remote desktop gateway server")
 	cfg := &tls.Config{}
 
-	if conf.Server.DisableTLS {
+	if conf.Server.Tls == "disable" {
 		log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator")
 	} else {
-		if conf.Server.CertFile == "" || conf.Server.KeyFile == "" {
-			log.Fatal("Both certfile and keyfile need to be specified")
-		}
+		// auto config
+		tlsConfigured := false
 
 		tlsDebug := os.Getenv("SSLKEYLOGFILE")
 		if tlsDebug != "" {
@@ -124,11 +126,34 @@ func main() {
 			cfg.KeyLogWriter = w
 		}
 
-		cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile)
-		if err != nil {
-			log.Fatal(err)
+		if conf.Server.KeyFile != "" && conf.Server.CertFile != "" {
+			cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile)
+			if err != nil {
+				log.Printf("Cannot load certfile or keyfile (%s) falling back to acme", err)
+			}
+			cfg.Certificates = append(cfg.Certificates, cert)
+			tlsConfigured = true
+		}
+
+		if !tlsConfigured {
+			log.Printf("Using acme / letsencrypt for tls configuration. Enabling http (port 80) for verification")
+			// setup a simple handler which sends a HTHS header for six months (!)
+			http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+				w.Header().Set("Strict-Transport-Security", "max-age=15768000 ; includeSubDomains")
+				fmt.Fprintf(w, "Hello from RDPGW")
+			})
+
+			certMgr := autocert.Manager{
+				Prompt:     autocert.AcceptTOS,
+				HostPolicy: autocert.HostWhitelist(url.Host),
+				Cache:      autocert.DirCache("/tmp/rdpgw"),
+			}
+			cfg.GetCertificate = certMgr.GetCertificate
+
+			go func() {
+				http.ListenAndServe(":http", certMgr.HTTPHandler(nil))
+			}()
 		}
-		cfg.Certificates = append(cfg.Certificates, cert)
 	}
 
 	server := http.Server{
@@ -175,7 +200,7 @@ func main() {
 	http.Handle("/metrics", promhttp.Handler())
 	http.HandleFunc("/tokeninfo", api.TokenInfo)
 
-	if conf.Server.DisableTLS {
+	if conf.Server.Tls == "disabled" {
 		err = server.ListenAndServe()
 	} else {
 		err = server.ListenAndServeTLS("", "")
diff --git a/go.mod b/go.mod
index 4e7f144..a0c35cd 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/prometheus/client_golang v1.12.1
 	github.com/thought-machine/go-flags v1.6.1
+	golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
 	golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
 	google.golang.org/grpc v1.49.0
 	google.golang.org/protobuf v1.28.1
@@ -32,7 +33,6 @@ require (
 	github.com/prometheus/common v0.32.1 // indirect
 	github.com/prometheus/procfs v0.7.3 // indirect
 	github.com/stretchr/testify v1.7.1 // indirect
-	golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
 	golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect
 	golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
 	golang.org/x/text v0.3.7 // indirect
-- 
GitLab