diff --git a/download.go b/download.go
new file mode 100644
index 0000000000000000000000000000000000000000..5d57a7aae9cbb470c341da343647332d1a5d81ba
--- /dev/null
+++ b/download.go
@@ -0,0 +1,30 @@
+package main
+
+import (
+	"encoding/hex"
+	"github.com/patrickmn/go-cache"
+	"math/rand"
+	"net/http"
+	"strings"
+	"time"
+)
+
+func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
+	seed := make([]byte, 16)
+	rand.Read(seed)
+	fn := hex.EncodeToString(seed) + ".rdp"
+
+	rand.Read(seed)
+	token := hex.EncodeToString(seed)
+
+	tokens.Set(token, token, cache.DefaultExpiration)
+
+	w.Header().Set("Content-Disposition", "attachment; filename="+fn)
+	w.Header().Set("Content-Type", "application/x-rdp")
+	http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
+		"full address:s:localhost\r\n"+
+			"gatewayhostname:s:localhost\r\n"+
+			"gatewaycredentialssource:i:5\r\n"+
+			"gatewayusagemethod:i:1\r\n"+
+			"gatewayaccesstoken:s:" + token + "\r\n"))
+}
diff --git a/go.mod b/go.mod
index a2cdd6c626cfd77e66945dfe86163f0caf113724..f17e15d8a9752ea4f3fe46e724d25b665fd470e8 100644
--- a/go.mod
+++ b/go.mod
@@ -6,4 +6,6 @@ require (
 	github.com/gorilla/websocket v1.4.2
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/prometheus/client_golang v1.7.1
+	github.com/spf13/cobra v1.0.0
+	github.com/spf13/viper v1.7.0
 )
diff --git a/main.go b/main.go
index 8342ae70c70464c11654d6858c3f147b48a6eb96..5cf47278adacb15e4302a6c4da4a7ab55abaa23d 100644
--- a/main.go
+++ b/main.go
@@ -2,23 +2,54 @@ package main
 
 import (
 	"crypto/tls"
-	"flag"
+	"github.com/patrickmn/go-cache"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/prometheus/client_golang/prometheus"
+	"github.com/spf13/cobra"
+	"github.com/spf13/viper"
 	"log"
 	"net/http"
 	"os"
 	"strconv"
+	"time"
 )
 
+var cmd = &cobra.Command{
+	Use:	"rdpgw",
+	Long:	"Remote Desktop Gateway",
+}
+
+var (
+	port		int
+	certFile	string
+	keyFile		string
+
+	configFile	string
+)
+
+var tokens = cache.New(time.Minute *5, 10*time.Minute)
+
 func main() {
-	port := flag.Int("port", 443, "port to listen on for incoming connections")
-	certFile := flag.String("certfile", "server.pem", "public key certificate file")
-	keyFile := flag.String("keyfile", "key.pem", "private key file")
+	cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
+	cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
+	cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
+	cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml",  "config file (json, yaml, ini)")
+
+	viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
+	viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
+	viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
+
+	viper.SetConfigFile(configFile)
+	viper.AddConfigPath(".")
+	viper.SetEnvPrefix("RDPGW")
+	viper.AutomaticEnv()
 
-	flag.Parse()
+	err := viper.ReadInConfig()
+	if err != nil {
+		log.Printf("No config file found. Using defaults")
+	}
 
-	if *certFile == "" || *keyFile == "" {
+	if certFile == "" || keyFile == "" {
 		log.Fatal("Both certfile and keyfile need to be specified")
 	}
 
@@ -36,18 +67,19 @@ func main() {
 		log.Printf("Key log file set to: %s", tlsDebug)
 		cfg.KeyLogWriter = w
 	}
-	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
+	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
 	if err != nil {
 		log.Fatal(err)
 	}
 	cfg.Certificates = append(cfg.Certificates, cert)
 	server := http.Server{
-		Addr:      ":" + strconv.Itoa(*port),
+		Addr:      ":" + strconv.Itoa(port),
 		TLSConfig: cfg,
 		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
 	}
 
 	http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
+	http.HandleFunc("/connect", handleRdpDownload)
 	http.Handle("/metrics", promhttp.Handler())
 
 	prometheus.MustRegister(connectionCache)
diff --git a/rdg.go b/rdg.go
index 9ebad663be4d72ac5a8dbee16eb136d82ac89716..25097775d59b6ebb76707bcb900ef14d116d69b6 100644
--- a/rdg.go
+++ b/rdg.go
@@ -217,7 +217,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
 			log.Printf("Handshake response: %x", msg)
 			conn.WriteMessage(mt, msg)
 		case PKT_TYPE_TUNNEL_CREATE:
-			readCreateTunnelRequest(pkt)
+			_, cookie := readCreateTunnelRequest(pkt)
+			if _, found := tokens.Get(cookie); found == false {
+				log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
+				return
+			}
 			msg := createTunnelResponse()
 			log.Printf("Create tunnel response: %x", msg)
 			conn.WriteMessage(mt, msg)