package main import ( "encoding/hex" "encoding/json" "github.com/patrickmn/go-cache" "github.com/spf13/viper" "golang.org/x/oauth2" "log" "math/rand" "net/http" "strings" "time" ) func handleRdpDownload(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("RDPGWSESSIONV1") if err != nil { http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound) return } data, found := tokens.Get(cookie.Value) if found == false { log.Printf("Found expired or non existent session: %s", cookie.Value) http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound) return } host := strings.Replace(viper.GetString("hostTemplate"), "%%", data.(string), 1) // authenticated seed := make([]byte, 16) rand.Read(seed) fn := hex.EncodeToString(seed) + ".rdp" w.Header().Set("Content-Disposition", "attachment; filename="+fn) w.Header().Set("Content-Type", "application/x-rdp") http.ServeContent(w, r, fn, time.Now(), strings.NewReader( "full address:s:" + host + "\r\n"+ "gatewayhostname:s:" + gateway +"\r\n"+ "gatewaycredentialssource:i:5\r\n"+ "gatewayusagemethod:i:1\r\n"+ "gatewayaccesstoken:s:" + cookie.Value + "\r\n")) } func handleCallback(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("state") != state { http.Error(w, "state did not match", http.StatusBadRequest) return } oauthToken, err := oauthConfig.Exchange(ctx, r.URL.Query().Get("code")) if err != nil { http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) return } rawIDToken, ok := oauthToken.Extra("id_token").(string) if !ok { http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) return } idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) return } resp := struct { OAuth2Token *oauth2.Token IDTokenClaims *json.RawMessage // ID Token payload is just JSON. }{oauthToken, new(json.RawMessage)} if err := idToken.Claims(&resp.IDTokenClaims); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } var data map[string]interface{} if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } seed := make([]byte, 16) rand.Read(seed) token := hex.EncodeToString(seed) cookie := http.Cookie{ Name: "RDPGWSESSIONV1", Value: token, Path: "/", Secure: true, HttpOnly: true, } tokens.Set(token, data[claim].(string), cache.DefaultExpiration) http.SetCookie(w, &cookie) http.Redirect(w, r, "/connect", http.StatusFound) }