Skip to content
Snippets Groups Projects
main.go 4.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bolke de Bruin's avatar
    Bolke de Bruin committed
    package main
    
    import (
    	"context"
    	"crypto/tls"
    	"log"
    	"net/http"
    	"os"
    	"strconv"
    	"time"

    	"github.com/coreos/go-oidc/v3/oidc"
    	"github.com/patrickmn/go-cache"
    	"github.com/prometheus/client_golang/prometheus"
    	"github.com/prometheus/client_golang/prometheus/promhttp"
    	"github.com/spf13/cobra"
    	"github.com/spf13/viper"
    	"golang.org/x/oauth2"
    )
    
    // cmd is the root cobra command of the gateway binary. main hangs all of the
    // persistent command-line flags off it.
    var cmd = &cobra.Command{
    	Long: "Remote Desktop Gateway",
    	Use:  "rdpgw",
    }
    
    // Listener and config-file settings. main registers one command-line flag
    // for each of these ("port", "certfile", "keyfile" and "conf").
    var (
    	// port is the TCP port the HTTPS server listens on.
    	port int

    	// certFile and keyFile are handed to tls.LoadX509KeyPair in main;
    	// configFile receives the value of the "conf" flag.
    	certFile, keyFile, configFile string
    )
    
    // Package-level authentication state. Apart from tokens, everything here is
    // assigned once in main before the HTTP server starts.
    var (
    	// tokens is an in-memory go-cache created with a five minute default
    	// expiration and a ten minute cleanup interval.
    	// NOTE(review): presumably holds issued connection tokens — confirm
    	// against the handlers that use it.
    	tokens = cache.New(5*time.Minute, 10*time.Minute)

    	// state is set to a fixed string in main.
    	// NOTE(review): looks like the OAuth2 "state" parameter — confirm.
    	state string

    	// oauthConfig is the OAuth2 client configuration built in main from the
    	// provider endpoint and the clientId/clientSecret settings.
    	oauthConfig oauth2.Config

    	// oidcConfig carries the client ID used to construct verifier.
    	oidcConfig *oidc.Config

    	// verifier is obtained from the OIDC provider in main.
    	verifier *oidc.IDTokenVerifier

    	// ctx is the background context created in main for the OIDC provider.
    	ctx context.Context
    )
    
    // Gateway behaviour settings. main registers a command-line flag for each
    // one ("gateway", "hostOverride", "hostTemplate" and "claim").
    var (
    	// gateway is the DNS name of this gateway; main uses it to build the
    	// OAuth2 redirect URL.
    	gateway string

    	// overrideHost is the value of the "hostOverride" setting.
    	overrideHost bool

    	// hostTemplate is the value of the "hostTemplate" setting and claim names
    	// the OpenID claim used for filling in that template.
    	hostTemplate, claim string
    )
    
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    func main() {
    
    	cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
    	cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
    	cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
    	cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml",  "config file (json, yaml, ini)")
    
    	cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost",  "gateway dns name")
    	cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to")
    	cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template")
    	cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template")
    
    
    	viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
    	viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
    	viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
    
    	viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway"))
    	viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride"))
    	viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate"))
    	viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim"))
    
    	viper.SetConfigName("rdpgw")
    	//viper.SetConfigFile(configFile)
    
    	viper.AddConfigPath(".")
    	viper.SetEnvPrefix("RDPGW")
    	viper.AutomaticEnv()
    
    	err := viper.ReadInConfig()
    	if err != nil {
    		log.Printf("No config file found. Using defaults")
    	}
    
    	// dont understand why I need to do this
    	gateway = viper.GetString("gateway")
    	hostTemplate = viper.GetString("hostTemplate")
    	overrideHost = viper.GetBool("hostOverride")
    
    	// set oidc config
    	ctx = context.Background()
    	provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl"))
    	if err != nil {
    		log.Fatalf("Cannot get oidc provider: %s", err)
    	}
    	oidcConfig = &oidc.Config{
    		ClientID: viper.GetString("clientId"),
    	}
    	verifier = provider.Verifier(oidcConfig)
    
    	oauthConfig = oauth2.Config{
    		ClientID: viper.GetString("clientId"),
    		ClientSecret: viper.GetString("clientSecret"),
    		RedirectURL: "https://" + gateway + "/callback",
    		Endpoint: provider.Endpoint(),
    		Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
    	}
    
    	// check what is required
    	state = "rdpstate"
    
    
    	if certFile == "" || keyFile == "" {
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    		log.Fatal("Both certfile and keyfile need to be specified")
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    	//mux := http.NewServeMux()
    	//mux.HandleFunc("*", HelloServer)
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    	log.Printf("Starting remote desktop gateway server")
    	cfg := &tls.Config{}
    	tlsDebug := os.Getenv("SSLKEYLOGFILE")
    	if tlsDebug != "" {
    		w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
    		if err != nil {
    			log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err)
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    		log.Printf("Key log file set to: %s", tlsDebug)
    		cfg.KeyLogWriter = w
    
    	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    	if err != nil {
    		log.Fatal(err)
    	}
    	cfg.Certificates = append(cfg.Certificates, cert)
    	server := http.Server{
    
    		Addr:      ":" + strconv.Itoa(port),
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    		TLSConfig: cfg,
    
    		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
    
    	http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
    
    	http.HandleFunc("/connect", handleRdpDownload)
    
    	http.Handle("/metrics", promhttp.Handler())
    
    	http.HandleFunc("/callback", handleCallback)
    
    
    	prometheus.MustRegister(connectionCache)
    	prometheus.MustRegister(legacyConnections)
    	prometheus.MustRegister(websocketConnections)
    
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    	err = server.ListenAndServeTLS("", "")
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    	if err != nil {
    		log.Fatal("ListenAndServe: ", err)
    	}
    
    Bolke de Bruin's avatar
    Bolke de Bruin committed
    }