diff --git a/download.go b/download.go
index 5d57a7aae9cbb470c341da343647332d1a5d81ba..9125ca9153ac5b55cdaca0170676d73a2ac570b3 100644
--- a/download.go
+++ b/download.go
@@ -2,7 +2,11 @@ package main
 
 import (
 	"encoding/hex"
+	"encoding/json"
 	"github.com/patrickmn/go-cache"
+	"github.com/spf13/viper"
+	"golang.org/x/oauth2"
+	"log"
 	"math/rand"
 	"net/http"
 	"strings"
@@ -10,21 +14,93 @@ import (
 )
 
 func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
+	cookie, err := r.Cookie("RDPGWSESSIONV1")
+	if err != nil {
+		http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound)
+		return
+	}
+
+	data, found := tokens.Get(cookie.Value)
+	if found == false {
+		log.Printf("Found expired or non existent session: %s", cookie.Value)
+		http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound)
+		return
+	}
+
+	// authenticated
 	seed := make([]byte, 16)
 	rand.Read(seed)
 	fn := hex.EncodeToString(seed) + ".rdp"
 
-	rand.Read(seed)
-	token := hex.EncodeToString(seed)
-
-	tokens.Set(token, token, cache.DefaultExpiration)
-
 	w.Header().Set("Content-Disposition", "attachment; filename="+fn)
 	w.Header().Set("Content-Type", "application/x-rdp")
 	http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
-		"full address:s:localhost\r\n"+
-			"gatewayhostname:s:localhost\r\n"+
+		"full address:s:" + host + "\r\n"+
+			"gatewayhostname:s:" + gateway +"\r\n"+
 			"gatewaycredentialssource:i:5\r\n"+
 			"gatewayusagemethod:i:1\r\n"+
-			"gatewayaccesstoken:s:" + token + "\r\n"))
+			"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
 }
+
+func handleCallback(w http.ResponseWriter, r *http.Request) {
+	if r.URL.Query().Get("state") != state {
+		http.Error(w, "state did not match", http.StatusBadRequest)
+		return
+	}
+
+	oauthToken, err := oauthConfig.Exchange(ctx, r.URL.Query().Get("code"))
+	if err != nil {
+		http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	rawIDToken, ok := oauthToken.Extra("id_token").(string)
+	if !ok {
+		http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
+		return
+	}
+	idToken, err := verifier.Verify(ctx, rawIDToken)
+	if err != nil {
+		http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	resp := struct {
+		OAuth2Token   *oauth2.Token
+		IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
+	}{oauthToken, new(json.RawMessage)}
+
+	if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	q, err := json.MarshalIndent(resp, "", "    ")
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	var data map[string]interface{}
+	if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	seed := make([]byte, 16)
+	rand.Read(seed)
+	token := hex.EncodeToString(seed)
+
+	cookie := http.Cookie{
+		Name: "RDPGWSESSIONV1",
+		Value: token,
+		Path: "/",
+		Secure: true,
+		HttpOnly: true,
+	}
+
+	tokens.Set(token, data[claim].(string), cache.DefaultExpiration)
+
+	http.SetCookie(w, &cookie)
+	http.Redirect(w, r, "/connect", http.StatusFound)
+}
\ No newline at end of file
diff --git a/go.mod b/go.mod
index f17e15d8a9752ea4f3fe46e724d25b665fd470e8..3b3589aeef825cdb4a04d58ab0730a6d65ebf690 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/bolkedebruin/rdpgw
 go 1.14
 
 require (
+	github.com/coreos/go-oidc/v3 v3.0.0-alpha.1
 	github.com/gorilla/websocket v1.4.2
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/prometheus/client_golang v1.7.1
diff --git a/main.go b/main.go
index 5cf47278adacb15e4302a6c4da4a7ab55abaa23d..846aa59b0b5969651bfb38a7170fdd5b3c703806 100644
--- a/main.go
+++ b/main.go
@@ -1,12 +1,15 @@
 package main
 
 import (
+	"context"
 	"crypto/tls"
+	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/patrickmn/go-cache"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
+	"golang.org/x/oauth2"
 	"log"
 	"net/http"
 	"os"
@@ -29,17 +32,39 @@ var (
 
 var tokens = cache.New(time.Minute *5, 10*time.Minute)
 
+var state string
+
+var oauthConfig oauth2.Config
+var oidcConfig *oidc.Config
+var verifier *oidc.IDTokenVerifier
+var ctx context.Context
+
+var gateway string
+var overrideHost bool
+var hostTemplate string
+var claim string
+
 func main() {
+	// get config
 	cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
 	cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
 	cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
 	cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml",  "config file (json, yaml, ini)")
+	cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost",  "gateway dns name")
+	cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to")
+	cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template")
+	cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template")
 
 	viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
 	viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
 	viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
+	viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway"))
+	viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride"))
+	viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate"))
+	viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim"))
 
-	viper.SetConfigFile(configFile)
+	viper.SetConfigName("rdpgw")
+	//viper.SetConfigFile(configFile)
 	viper.AddConfigPath(".")
 	viper.SetEnvPrefix("RDPGW")
 	viper.AutomaticEnv()
@@ -49,6 +74,33 @@ func main() {
 		log.Printf("No config file found. Using defaults")
 	}
 
+	// dont understand why I need to do this
+	gateway = viper.GetString("gateway")
+	hostTemplate = viper.GetString("hostTemplate")
+	overrideHost = viper.GetBool("hostOverride")
+
+	// set oidc config
+	ctx = context.Background()
+	provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl"))
+	if err != nil {
+		log.Fatalf("Cannot get oidc provider: %s", err)
+	}
+	oidcConfig = &oidc.Config{
+		ClientID: viper.GetString("clientId"),
+	}
+	verifier = provider.Verifier(oidcConfig)
+
+	oauthConfig = oauth2.Config{
+		ClientID: viper.GetString("clientId"),
+		ClientSecret: viper.GetString("clientSecret"),
+		RedirectURL: "https://" + gateway + "/callback",
+		Endpoint: provider.Endpoint(),
+		Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+	}
+
+	// check what is required
+	state = "rdpstate"
+
 	if certFile == "" || keyFile == "" {
 		log.Fatal("Both certfile and keyfile need to be specified")
 	}
@@ -81,6 +133,7 @@ func main() {
 	http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
 	http.HandleFunc("/connect", handleRdpDownload)
 	http.Handle("/metrics", promhttp.Handler())
+	http.HandleFunc("/callback", handleCallback)
 
 	prometheus.MustRegister(connectionCache)
 	prometheus.MustRegister(legacyConnections)
diff --git a/rdg.go b/rdg.go
index 25097775d59b6ebb76707bcb900ef14d116d69b6..2aaec71591fedd89803a102a48789edf5b9410b0 100644
--- a/rdg.go
+++ b/rdg.go
@@ -16,6 +16,7 @@ import (
 	"net/http"
 	"net/http/httputil"
 	"strconv"
+	"strings"
 	"time"
 	"unicode/utf16"
 	"unicode/utf8"
@@ -182,6 +183,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
 	websocketConnections.Inc()
 	defer websocketConnections.Dec()
 
+	var host string
 	for {
 		mt, msg, err := conn.ReadMessage()
 		if err != nil {
@@ -218,10 +220,12 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
 			conn.WriteMessage(mt, msg)
 		case PKT_TYPE_TUNNEL_CREATE:
 			_, cookie := readCreateTunnelRequest(pkt)
-			if _, found := tokens.Get(cookie); found == false {
+			data, found := tokens.Get(cookie)
+			if found == false {
 				log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
 				return
 			}
+			host = strings.Replace(hostTemplate, "%%", data.(string), 1)
 			msg := createTunnelResponse()
 			log.Printf("Create tunnel response: %x", msg)
 			conn.WriteMessage(mt, msg)
@@ -232,13 +236,17 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
 			conn.WriteMessage(mt, msg)
 		case PKT_TYPE_CHANNEL_CREATE:
 			server, port := readChannelCreateRequest(pkt)
-			log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
+			if overrideHost == true {
+				log.Printf("Override allowed")
+				host = net.JoinHostPort(server, strconv.Itoa(int(port)))
+			}
+			log.Printf("Establishing connection to RDP server: %s", host)
 			remote, err = net.DialTimeout(
 				"tcp",
-				net.JoinHostPort(server, strconv.Itoa(int(port))),
-				time.Second * 15)
+				host,
+				time.Second * 30)
 			if err != nil {
-				log.Printf("Error connecting to %s, %d, %s", server, port, err)
+				log.Printf("Error connecting to %s", host)
 				return
 			}
 			log.Printf("Connection established")
@@ -349,7 +357,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
 					msg := handshakeResponse(major, minor, auth)
 					s.ConnOut.Write(msg)
 				case PKT_TYPE_TUNNEL_CREATE:
-					readCreateTunnelRequest(pkt)
+					_, cookie := readCreateTunnelRequest(pkt)
+					if _, found := tokens.Get(cookie); found == false {
+						log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
+						return
+					}
 					msg := createTunnelResponse()
 					s.ConnOut.Write(msg)
 				case PKT_TYPE_TUNNEL_AUTH: