package main import ( "context" "crypto/tls" "github.com/bolkedebruin/rdpgw/api" "github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/security" "github.com/coreos/go-oidc/v3/oidc" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" "golang.org/x/oauth2" "log" "net/http" "os" "strconv" ) var cmd = &cobra.Command{ Use: "rdpgw", Long: "Remote Desktop Gateway", } var ( configFile string ) var conf config.Configuration func main() { // get config cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") conf = config.Load(configFile) // set security keys security.SigningKey = []byte(conf.Security.TokenSigningKey) // set oidc config ctx := context.Background() provider, err := oidc.NewProvider(ctx, conf.OpenId.ProviderUrl) if err != nil { log.Fatalf("Cannot get oidc provider: %s", err) } oidcConfig := &oidc.Config{ ClientID: conf.OpenId.ClientId, } verifier := provider.Verifier(oidcConfig) oauthConfig := oauth2.Config{ ClientID: conf.OpenId.ClientId, ClientSecret: conf.OpenId.ClientSecret, RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback", Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } api := &api.Config{ GatewayAddress: conf.Server.GatewayAddress, OAuth2Config: &oauthConfig, TokenVerifier: verifier, TokenGenerator: security.GeneratePAAToken, SessionKey: []byte(conf.Server.SessionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), Hosts: conf.Server.Hosts, NetworkAutoDetect: conf.Client.NetworkAutoDetect, UsernameTemplate: conf.Client.UsernameTemplate, BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, ConnectionType: conf.Client.ConnectionType, } api.NewApi() if conf.Server.CertFile == "" || conf.Server.KeyFile == "" { log.Fatal("Both certfile and keyfile need to be specified") } //mux := http.NewServeMux() //mux.HandleFunc("*", HelloServer) log.Printf("Starting remote desktop gateway server") cfg := &tls.Config{} tlsDebug := os.Getenv("SSLKEYLOGFILE") if tlsDebug != "" { w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err) } log.Printf("Key log file set to: %s", tlsDebug) cfg.KeyLogWriter = w } cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile) if err != nil { log.Fatal(err) } cfg.Certificates = append(cfg.Certificates, cert) server := http.Server{ Addr: ":" + strconv.Itoa(conf.Server.Port), TLSConfig: cfg, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } // create the gateway handlerConfig := protocol.HandlerConf{ IdleTimeout: conf.Caps.IdleTimeout, TokenAuth: conf.Caps.TokenAuth, SmartCardAuth: conf.Caps.SmartCardAuth, RedirectFlags: protocol.RedirectFlags{ Clipboard: conf.Caps.EnableClipboard, Drive: conf.Caps.EnableDrive, Printer: conf.Caps.EnablePrinter, Port: conf.Caps.EnablePort, Pnp: conf.Caps.EnablePnp, DisableAll: conf.Caps.DisableRedirect, EnableAll: conf.Caps.RedirectAll, }, VerifyTunnelCreate: security.VerifyPAAToken, VerifyServerFunc: security.VerifyServerFunc, } gw := protocol.Gateway{ HandlerConf: &handlerConfig, } http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/callback", api.HandleCallback) err = server.ListenAndServeTLS("", "") if err != nil { log.Fatal("ListenAndServe: ", err) } }