diff --git a/download.go b/download.go new file mode 100644 index 0000000000000000000000000000000000000000..5d57a7aae9cbb470c341da343647332d1a5d81ba --- /dev/null +++ b/download.go @@ -0,0 +1,30 @@ +package main + +import ( + "encoding/hex" + "github.com/patrickmn/go-cache" + "math/rand" + "net/http" + "strings" + "time" +) + +func handleRdpDownload(w http.ResponseWriter, r *http.Request) { + seed := make([]byte, 16) + rand.Read(seed) + fn := hex.EncodeToString(seed) + ".rdp" + + rand.Read(seed) + token := hex.EncodeToString(seed) + + tokens.Set(token, token, cache.DefaultExpiration) + + w.Header().Set("Content-Disposition", "attachment; filename="+fn) + w.Header().Set("Content-Type", "application/x-rdp") + http.ServeContent(w, r, fn, time.Now(), strings.NewReader( + "full address:s:localhost\r\n"+ + "gatewayhostname:s:localhost\r\n"+ + "gatewaycredentialssource:i:5\r\n"+ + "gatewayusagemethod:i:1\r\n"+ + "gatewayaccesstoken:s:" + token + "\r\n")) +} diff --git a/go.mod b/go.mod index a2cdd6c626cfd77e66945dfe86163f0caf113724..f17e15d8a9752ea4f3fe46e724d25b665fd470e8 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,6 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.7.1 + github.com/spf13/cobra v1.0.0 + github.com/spf13/viper v1.7.0 ) diff --git a/main.go b/main.go index 8342ae70c70464c11654d6858c3f147b48a6eb96..5cf47278adacb15e4302a6c4da4a7ab55abaa23d 100644 --- a/main.go +++ b/main.go @@ -2,23 +2,54 @@ package main import ( "crypto/tls" - "flag" + "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" + "github.com/spf13/viper" "log" "net/http" "os" "strconv" + "time" ) +var cmd = &cobra.Command{ + Use: "rdpgw", + Long: "Remote Desktop Gateway", +} + +var ( + port int + certFile string + keyFile string + + configFile string +) + +var tokens = cache.New(time.Minute *5, 10*time.Minute) + func main() { - port := flag.Int("port", 443, "port to listen on for incoming connections") - certFile := flag.String("certfile", "server.pem", "public key certificate file") - keyFile := flag.String("keyfile", "key.pem", "private key file") + cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection") + cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file") + cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file") + cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") + + viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port")) + viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile")) + viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile")) + + viper.SetConfigFile(configFile) + viper.AddConfigPath(".") + viper.SetEnvPrefix("RDPGW") + viper.AutomaticEnv() - flag.Parse() + err := viper.ReadInConfig() + if err != nil { + log.Printf("No config file found. Using defaults") + } - if *certFile == "" || *keyFile == "" { + if certFile == "" || keyFile == "" { log.Fatal("Both certfile and keyfile need to be specified") } @@ -36,18 +67,19 @@ func main() { log.Printf("Key log file set to: %s", tlsDebug) cfg.KeyLogWriter = w } - cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { log.Fatal(err) } cfg.Certificates = append(cfg.Certificates, cert) server := http.Server{ - Addr: ":" + strconv.Itoa(*port), + Addr: ":" + strconv.Itoa(port), TLSConfig: cfg, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) + http.HandleFunc("/connect", handleRdpDownload) http.Handle("/metrics", promhttp.Handler()) prometheus.MustRegister(connectionCache) diff --git a/rdg.go b/rdg.go index 9ebad663be4d72ac5a8dbee16eb136d82ac89716..25097775d59b6ebb76707bcb900ef14d116d69b6 100644 --- a/rdg.go +++ b/rdg.go @@ -217,7 +217,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) { log.Printf("Handshake response: %x", msg) conn.WriteMessage(mt, msg) case PKT_TYPE_TUNNEL_CREATE: - readCreateTunnelRequest(pkt) + _, cookie := readCreateTunnelRequest(pkt) + if _, found := tokens.Get(cookie); found == false { + log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) + return + } msg := createTunnelResponse() log.Printf("Create tunnel response: %x", msg) conn.WriteMessage(mt, msg)