diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 92d08eb41210fc3abdecf5b1b5d65bab0b67710e..6f337c978901cb13eea45a960efc0a2a7e9d8e2c 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -48,7 +48,7 @@ type ServerConfig struct { SendBuf int `koanf:"sendbuf"` ReceiveBuf int `koanf:"receivebuf"` Tls string `koanf:"tls"` - Authentication string `koanf:"authentication"` + Authentication []string `koanf:"authentication"` AuthSocket string `koanf:"authsocket"` } @@ -206,15 +206,15 @@ func Load(configFile string) Configuration { log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") } - if Conf.Server.Authentication == "local" && Conf.Server.Tls == "disable" { + if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { log.Fatalf("basicauth=local and tls=disable are mutually exclusive") } - if !Conf.Caps.TokenAuth && Conf.Server.Authentication == "openid" { + if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { log.Fatalf("openid is configured but tokenauth disabled") } - if Conf.Server.Authentication == AuthenticationKerberos && Conf.Kerberos.Keytab == "" { + if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" { log.Fatalf("kerberos is configured but no keytab was specified") } @@ -226,3 +226,24 @@ func Load(configFile string) Configuration { return Conf } + +func (s *ServerConfig) OpenIDEnabled() bool { + return s.matchAuth("openid") +} + +func (s *ServerConfig) KerberosEnabled() bool { + return s.matchAuth("kerberos") +} + +func (s *ServerConfig) BasicAuthEnabled() bool { + return s.matchAuth("local") +} + +func (s *ServerConfig) matchAuth(needle string) bool { + for _, q := range s.Authentication { + if q == needle { + return true + } + } + return false +} diff --git a/cmd/rdpgw/identity/identity.go b/cmd/rdpgw/identity/identity.go new file mode 100644 index 0000000000000000000000000000000000000000..2ef7678fdbaf4e405e961ff3eb6670d44a45efb6 --- /dev/null +++ b/cmd/rdpgw/identity/identity.go @@ -0,0 +1,57 @@ +package identity + +import ( + "context" + "net/http" + "time" +) + +const ( + CTXKey = "github.com/bolkedebruin/rdpgw/common/identity" + + AttrRemoteAddr = "remoteAddr" + AttrClientIp = "clientIp" + AttrProxies = "proxyAddresses" + AttrAccessToken = "accessToken" // todo remove for security reasons +) + +type Identity interface { + UserName() string + SetUserName(string) + DisplayName() string + SetDisplayName(string) + Domain() string + SetDomain(string) + Authenticated() bool + SetAuthenticated(bool) + AuthTime() time.Time + SetAuthTime(time2 time.Time) + SessionId() string + SetAttribute(string, interface{}) + GetAttribute(string) interface{} + Attributes() map[string]interface{} + DelAttribute(string) + Email() string + SetEmail(string) + Expiry() time.Time + SetExpiry(time.Time) + Marshal() ([]byte, error) + Unmarshal([]byte) error +} + +func AddToRequestCtx(id Identity, r *http.Request) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, CTXKey, id) + return r.WithContext(ctx) +} + +func FromRequestCtx(r *http.Request) Identity { + return FromCtx(r.Context()) +} + +func FromCtx(ctx context.Context) Identity { + if id, ok := ctx.Value(CTXKey).(Identity); ok { + return id + } + return nil +} diff --git a/cmd/rdpgw/identity/identity_test.go b/cmd/rdpgw/identity/identity_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fd66e1f2bed6f7807e5dcc0cc17ee7b3d9640405 --- /dev/null +++ b/cmd/rdpgw/identity/identity_test.go @@ -0,0 +1,28 @@ +package identity + +import ( + "log" + "testing" +) + +func TestMarshalling(t *testing.T) { + u := NewUser() + u.SetUserName("ANAME") + u.SetAuthenticated(true) + u.SetDomain("DOMAIN") + + c := NewUser() + data, err := u.Marshal() + if err != nil { + log.Fatalf("Cannot marshal %s", err) + } + + err = c.Unmarshal(data) + if err != nil { + t.Fatalf("Error while unmarshalling: %s", err) + } + + if u.UserName() != c.UserName() || u.Authenticated() != c.Authenticated() || u.Domain() != c.Domain() { + t.Fatalf("identities not equal: %+v != %+v", u, c) + } +} diff --git a/cmd/rdpgw/common/identity.go b/cmd/rdpgw/identity/user.go similarity index 58% rename from cmd/rdpgw/common/identity.go rename to cmd/rdpgw/identity/user.go index baa8b36f3a7963cfb14b83fdd4d7a3990880ca58..a141853503f3d1a32ca625c61f5c88f5a69967f7 100644 --- a/cmd/rdpgw/common/identity.go +++ b/cmd/rdpgw/identity/user.go @@ -1,60 +1,12 @@ -package common +package identity import ( - "context" + "bytes" + "encoding/gob" "github.com/google/uuid" - "net/http" "time" ) -const ( - CTXKey = "github.com/bolkedebruin/rdpgw/common/identity" - - AttrRemoteAddr = "remoteAddr" - AttrClientIp = "clientIp" - AttrProxies = "proxyAddresses" - AttrAccessToken = "accessToken" // todo remove for security reasons -) - -type Identity interface { - UserName() string - SetUserName(string) - DisplayName() string - SetDisplayName(string) - Domain() string - SetDomain(string) - Authenticated() bool - SetAuthenticated(bool) - AuthTime() time.Time - SetAuthTime(time2 time.Time) - SessionId() string - SetAttribute(string, interface{}) - GetAttribute(string) interface{} - Attributes() map[string]interface{} - DelAttribute(string) - Email() string - SetEmail(string) - Expiry() time.Time - SetExpiry(time.Time) -} - -func AddToRequestCtx(id Identity, r *http.Request) *http.Request { - ctx := r.Context() - ctx = context.WithValue(ctx, CTXKey, id) - return r.WithContext(ctx) -} - -func FromRequestCtx(r *http.Request) Identity { - return FromCtx(r.Context()) -} - -func FromCtx(ctx context.Context) Identity { - if id, ok := ctx.Value(CTXKey).(Identity); ok { - return id - } - return nil -} - type User struct { authenticated bool domain string @@ -68,6 +20,19 @@ type User struct { groupMembership map[string]bool } +type user struct { + Authenticated bool + UserName string + Domain string + DisplayName string + Email string + AuthTime time.Time + SessionId string + Expiry time.Time + Attributes map[string]interface{} + GroupMembership map[string]bool +} + func NewUser() *User { uuid := uuid.New().String() return &User{ @@ -158,3 +123,48 @@ func (u *User) Expiry() time.Time { func (u *User) SetExpiry(t time.Time) { u.expiry = t } + +func (u *User) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + uu := user{ + Authenticated: u.authenticated, + UserName: u.userName, + Domain: u.domain, + DisplayName: u.displayName, + Email: u.email, + AuthTime: u.authTime, + SessionId: u.sessionId, + Expiry: u.expiry, + Attributes: u.attributes, + GroupMembership: u.groupMembership, + } + err := enc.Encode(uu) + + if err != nil { + return []byte{}, err + } + return buf.Bytes(), nil +} + +func (u *User) Unmarshal(b []byte) error { + buf := bytes.NewBuffer(b) + dec := gob.NewDecoder(buf) + var uu user + err := dec.Decode(&uu) + if err != nil { + return err + } + u.sessionId = uu.SessionId + u.userName = uu.UserName + u.domain = uu.Domain + u.displayName = uu.DisplayName + u.email = uu.Email + u.authenticated = uu.Authenticated + u.authTime = uu.AuthTime + u.expiry = uu.Expiry + u.attributes = uu.Attributes + u.groupMembership = uu.GroupMembership + + return nil +} diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 5f02cffcbf47a08020794b25f7f3d378dd370cf0..bdc3e9d376bba74fa1204b85e09374ab4baae8d4 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -7,14 +7,13 @@ import ( "github.com/bolkedebruin/gokrb5/v8/keytab" "github.com/bolkedebruin/gokrb5/v8/service" "github.com/bolkedebruin/gokrb5/v8/spnego" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/web" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/sessions" + "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/thought-machine/go-flags" "golang.org/x/crypto/acme/autocert" @@ -26,13 +25,18 @@ import ( "strconv" ) +const ( + gatewayEndPoint = "/remoteDesktopGateway/" + kdcProxyEndPoint = "/KdcProxy" +) + var opts struct { ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"` } var conf config.Configuration -func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC { +func initOIDC(callbackUrl *url.URL) *web.OIDC { // set oidc config provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) if err != nil { @@ -56,7 +60,6 @@ func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC { o := web.OIDCConfig{ OAuth2Config: &oauthConfig, OIDCTokenVerifier: verifier, - SessionStore: store, } return o.New() @@ -91,19 +94,13 @@ func main() { security.Hosts = conf.Server.Hosts // init session store - sessionConf := web.SessionManagerConf{ - SessionKey: []byte(conf.Server.SessionKey), - SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), - StoreType: conf.Server.SessionStore, - } - store := sessionConf.Init() + web.InitStore([]byte(conf.Server.SessionKey), []byte(conf.Server.SessionEncryptionKey), conf.Server.SessionStore) // configure web backend w := &web.Config{ QueryInfo: security.QueryInfo, QueryTokenIssuer: conf.Security.QueryTokenIssuer, EnableUserToken: conf.Security.EnableUserToken, - SessionStore: store, Hosts: conf.Server.Hosts, HostSelection: conf.Server.HostSelection, RdpOpts: web.RdpOpts{ @@ -128,6 +125,7 @@ func main() { log.Printf("Starting remote desktop gateway server") cfg := &tls.Config{} + // configure tls security if conf.Server.Tls == config.TlsDisable { log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator") } else { @@ -174,13 +172,7 @@ func main() { } } - server := http.Server{ - Addr: ":" + strconv.Itoa(conf.Server.Port), - TLSConfig: cfg, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 - } - - // create the gateway + // gateway confg gw := protocol.Gateway{ RedirectFlags: protocol.RedirectFlags{ Clipboard: conf.Caps.EnableClipboard, @@ -205,31 +197,72 @@ func main() { gw.CheckHost = security.CheckHost } - if conf.Server.Authentication == config.AuthenticationBasic { - h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} - http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol))) - } else if conf.Server.Authentication == config.AuthenticationKerberos { + r := mux.NewRouter() + + // ensure identity is set in context and get some extra info + r.Use(web.EnrichContext) + + // prometheus metrics + r.Handle("/metrics", promhttp.Handler()) + + // for sso callbacks + r.HandleFunc("/tokeninfo", web.TokenInfo) + + // gateway endpoint + rdp := r.PathPrefix(gatewayEndPoint).Subrouter() + + // openid + if conf.Server.OpenIDEnabled() { + log.Printf("enabling openid extended authentication") + o := initOIDC(url) + r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload))) + r.HandleFunc("/callback", o.HandleCallback) + + // only enable un-auth endpoint for openid only config + if !conf.Server.KerberosEnabled() || !conf.Server.BasicAuthEnabled() { + rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol) + } + } + + // for stacking of authentication + auth := web.NewAuthMux() + + // basic auth + if conf.Server.BasicAuthEnabled() { + log.Printf("enabling basic authentication") + q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} + rdp.Headers("Authorization", "Basic*").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) + auth.Register(`Basic realm="restricted", charset="UTF-8"`) + } + + // spnego / kerberos + if conf.Server.KerberosEnabled() { + log.Printf("enabling kerberos authentication") keytab, err := keytab.Load(conf.Kerberos.Keytab) if err != nil { log.Fatalf("Cannot load keytab: %s", err) } - http.Handle("/remoteDesktopGateway/", common.EnrichContext( - spnego.SPNEGOKRB5Authenticate( - common.FixKerberosContext(http.HandlerFunc(gw.HandleGatewayProtocol)), + rdp.Headers("Authorization", "Negotiate*").Handler( + spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)), keytab, - service.Logger(log.Default()))), - ) + service.Logger(log.Default()))) + + // kdcproxy k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf) - http.HandleFunc("/KdcProxy", k.Handler) - } else { - // openid - oidc := initOIDC(url, store) - http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload)))) - http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) - http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback))) - } - http.Handle("/metrics", promhttp.Handler()) - http.HandleFunc("/tokeninfo", web.TokenInfo) + r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST") + auth.Register("Negotiate") + } + + // allow stacking of authentication + rdp.Use(auth.Route) + + // setup server + server := http.Server{ + Addr: ":" + strconv.Itoa(conf.Server.Port), + Handler: r, + TLSConfig: cfg, + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 + } if conf.Server.Tls == config.TlsDisable { err = server.ListenAndServe() diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index b2b42c4d52893678bbc3bd376354b1d31c47120a..51fae1a911004b037cad9f4cb599a6fe06392a92 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -3,7 +3,7 @@ package protocol import ( "context" "errors" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -61,14 +61,14 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) var t *Tunnel ctx := r.Context() - id := common.FromRequestCtx(r) + id := identity.FromRequestCtx(r) connId := r.Header.Get(rdgConnectionIdKey) x, found := c.Get(connId) if !found { t = &Tunnel{ RDGId: connId, - RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string), + RemoteAddr: id.GetAttribute(identity.AttrRemoteAddr).(string), User: id, } } else { @@ -183,14 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) - id := common.FromRequestCtx(r) + id := identity.FromRequestCtx(r) if r.Method == MethodRDGOUT { out, err := transport.NewLegacy(w) if err != nil { log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) return } - log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp)) + log.Printf("Opening RDGOUT for client %s", id.GetAttribute(identity.AttrClientIp)) t.transportOut = out out.SendAccept(true) @@ -212,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t t.transportIn = in c.Set(t.RDGId, t, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp)) + log.Printf("Opening RDGIN for client %s", id.GetAttribute(identity.AttrClientIp)) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp)) + log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(identity.AttrClientIp)) handler := NewProcessor(g, t) RegisterTunnel(t, handler) defer RemoveTunnel(t) diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index 9dc4788f1e404565de9cae30d7a8dd1e75e331b4..bf90c074bb86c0eda4067a6f850f2346a66da96a 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -6,7 +6,7 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "io" "log" "net" @@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) + log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) if p.state != SERVER_STATE_INITIALIZED { log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) @@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error { _, cookie := p.tunnelRequest(pkt) if p.gw.CheckPAACookie != nil { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp)) + log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) p.tunnel.Write(msg) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index cba7e1e0a497ae22235101707564c3ba15f9ee54..dc3b1d5fef4678c2e59e8a83923df5c911453b48 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -1,7 +1,7 @@ package protocol import ( - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "net" "time" @@ -27,7 +27,7 @@ type Tunnel struct { // The obtained client ip address RemoteAddr string // User - User common.Identity + User identity.Identity // rwc is the underlying connection to the remote desktop server. // It is of the type *net.TCPConn diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go index 1026a6ccb702d92eea726b4468c6c7f7b462f3c1..c141e818ce41eaff5e993a8358f01ce70d5761d8 100644 --- a/cmd/rdpgw/security/basic_test.go +++ b/cmd/rdpgw/security/basic_test.go @@ -2,7 +2,7 @@ package security import ( "context" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "testing" ) @@ -18,7 +18,7 @@ var ( ) func TestCheckHost(t *testing.T) { - info.User = common.NewUser() + info.User = identity.NewUser() info.User.SetUserName("MYNAME") ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info) diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index 6f21d17260446735f156f90f4649d2189ea8bc75..cd0f3a870ba04b1a2954685bd953456091d94431 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/coreos/go-oidc/v3/oidc" "github.com/go-jose/go-jose/v3" @@ -46,10 +46,10 @@ func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { } // use identity from context rather then set by tunnel - id := common.FromCtx(ctx) - if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) { + id := identity.FromCtx(ctx) + if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(identity.AttrClientIp) { log.Printf("Current client ip address %s does not match token client ip %s", - id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr) + id.GetAttribute(identity.AttrClientIp), tunnel.RemoteAddr) return false, nil } return next(ctx, host) @@ -129,11 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri Subject: username, } - id := common.FromCtx(ctx) + id := identity.FromCtx(ctx) private := customClaims{ RemoteServer: server, - ClientIP: id.GetAttribute(common.AttrClientIp).(string), - AccessToken: id.GetAttribute(common.AttrAccessToken).(string), + ClientIP: id.GetAttribute(identity.AttrClientIp).(string), + AccessToken: id.GetAttribute(identity.AttrAccessToken).(string), } if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil { diff --git a/cmd/rdpgw/web/basic.go b/cmd/rdpgw/web/basic.go index a8ef8077326727bbe142fc05c8656bc8d0d1debe..84724e3b8742413f95997826ede4360c4b3050ed 100644 --- a/cmd/rdpgw/web/basic.go +++ b/cmd/rdpgw/web/basic.go @@ -2,7 +2,7 @@ package web import ( "context" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/shared/auth" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -53,11 +53,11 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { log.Printf("User %s is not authenticated for this service", username) } else { log.Printf("User %s authenticated", username) - id := common.FromRequestCtx(r) + id := identity.FromRequestCtx(r) id.SetUserName(username) id.SetAuthenticated(true) id.SetAuthTime(time.Now()) - next.ServeHTTP(w, common.AddToRequestCtx(id, r)) + next.ServeHTTP(w, identity.AddToRequestCtx(id, r)) return } @@ -66,7 +66,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { // username or password is wrong, then set a WWW-Authenticate // header to inform the client that we expect them to use basic // authentication and send a 401 Unauthorized response. - w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) + w.Header().Add("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) http.Error(w, "Unauthorized", http.StatusUnauthorized) } } diff --git a/cmd/rdpgw/common/remote.go b/cmd/rdpgw/web/context.go similarity index 57% rename from cmd/rdpgw/common/remote.go rename to cmd/rdpgw/web/context.go index bb73a7225923358c0bec6791f84356509ab54fe9..af8e2d342c7b9daac6a338c09b2c58d5c658e691 100644 --- a/cmd/rdpgw/common/remote.go +++ b/cmd/rdpgw/web/context.go @@ -1,7 +1,7 @@ -package common +package web import ( - "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/jcmturner/goidentity/v6" "log" "net" @@ -9,16 +9,22 @@ import ( "strings" ) -const ( - CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token" -) - func EnrichContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := FromRequestCtx(r) + id, err := GetSessionIdentity(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if id == nil { - id = NewUser() + id = identity.NewUser() + if err := SaveSessionIdentity(r, w, id); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } + log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t", id.SessionId(), id.UserName(), id.Authenticated()) @@ -33,39 +39,30 @@ func EnrichContext(next http.Handler) http.Handler { if len(ips) > 1 { proxies = ips[1:] } - id.SetAttribute(AttrClientIp, clientIp) - id.SetAttribute(AttrProxies, proxies) + id.SetAttribute(identity.AttrClientIp, clientIp) + id.SetAttribute(identity.AttrProxies, proxies) } - id.SetAttribute(AttrRemoteAddr, r.RemoteAddr) + id.SetAttribute(identity.AttrRemoteAddr, r.RemoteAddr) if h == "" { clientIp, _, _ := net.SplitHostPort(r.RemoteAddr) - id.SetAttribute(AttrClientIp, clientIp) + id.SetAttribute(identity.AttrClientIp, clientIp) } - next.ServeHTTP(w, AddToRequestCtx(id, r)) + next.ServeHTTP(w, identity.AddToRequestCtx(id, r)) }) } -func FixKerberosContext(next http.Handler) http.Handler { +func TransposeSPNEGOContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gid := goidentity.FromHTTPRequestContext(r) if gid != nil { - id := FromRequestCtx(r) + id := identity.FromRequestCtx(r) id.SetUserName(gid.UserName()) id.SetAuthenticated(gid.Authenticated()) id.SetDomain(gid.Domain()) id.SetAuthTime(gid.AuthTime()) - r = AddToRequestCtx(id, r) + r = identity.AddToRequestCtx(id, r) } next.ServeHTTP(w, r) }) } - -func GetAccessToken(ctx context.Context) string { - token, ok := ctx.Value(CtxAccessToken).(string) - if !ok { - log.Printf("cannot get access token from context") - return "" - } - return token -} diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index 34104e55ac20d6e51761cb1c3df4efb67094fb50..1a41b01b7fa4cbe16640a5ffe6de9cafd71200ec 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -3,9 +3,8 @@ package web import ( "encoding/hex" "encoding/json" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/sessions" "github.com/patrickmn/go-cache" "golang.org/x/oauth2" "math/rand" @@ -14,23 +13,20 @@ import ( ) const ( - CacheExpiration = time.Minute * 2 - CleanupInterval = time.Minute * 5 - sessionKeyAuthenticated = "authenticated" - oidcKeyUserName = "preferred_username" + CacheExpiration = time.Minute * 2 + CleanupInterval = time.Minute * 5 + oidcKeyUserName = "preferred_username" ) type OIDC struct { oAuth2Config *oauth2.Config oidcTokenVerifier *oidc.IDTokenVerifier stateStore *cache.Cache - sessionStore sessions.Store } type OIDCConfig struct { OAuth2Config *oauth2.Config OIDCTokenVerifier *oidc.IDTokenVerifier - SessionStore sessions.Store } func (c *OIDCConfig) New() *OIDC { @@ -38,7 +34,6 @@ func (c *OIDCConfig) New() *OIDC { oAuth2Config: c.OAuth2Config, oidcTokenVerifier: c.OIDCTokenVerifier, stateStore: cache.New(CacheExpiration, CleanupInterval), - sessionStore: c.SessionStore, } } @@ -85,22 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { return } - session, err := h.sessionStore.Get(r, RdpGwSession) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - id := common.FromRequestCtx(r) + id := identity.FromRequestCtx(r) id.SetUserName(data[oidcKeyUserName].(string)) id.SetAuthenticated(true) id.SetAuthTime(time.Now()) - id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken) - - session.Options.MaxAge = MaxAge - session.Values[common.CTXKey] = id + id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken) - if err = session.Save(r, w); err != nil { + if err = SaveSessionIdentity(r, w, id); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -109,14 +95,9 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { func (h *OIDC) Authenticated(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := h.sessionStore.Get(r, RdpGwSession) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + id := identity.FromRequestCtx(r) - id := session.Values[common.CTXKey].(common.Identity) - if id == nil { + if !id.Authenticated() { seed := make([]byte, 16) rand.Read(seed) state := hex.EncodeToString(seed) @@ -126,6 +107,6 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler { } // replace the identity with the one from the sessions - next.ServeHTTP(w, common.AddToRequestCtx(id, r)) + next.ServeHTTP(w, r) }) } diff --git a/cmd/rdpgw/web/router.go b/cmd/rdpgw/web/router.go new file mode 100644 index 0000000000000000000000000000000000000000..12cc8af956e11016f71d72c82018f00cc49be728 --- /dev/null +++ b/cmd/rdpgw/web/router.go @@ -0,0 +1,31 @@ +package web + +import ( + "net/http" +) + +type AuthMux struct { + headers []string +} + +func NewAuthMux() *AuthMux { + return &AuthMux{} +} + +func (a *AuthMux) Register(s string) { + a.headers = append(a.headers, s) +} + +func (a *AuthMux) Route(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := r.Header.Get("Authorization") + if h == "" { + for _, s := range a.headers { + w.Header().Add("WWW-Authenticate", s) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } + next.ServeHTTP(w, r) + }) +} diff --git a/cmd/rdpgw/web/session.go b/cmd/rdpgw/web/session.go index 1f9a0d9fa700b491c416c27230bb7918a75a6bda..fca4bc92b10299b3dd82b33eac587937dd73f430 100644 --- a/cmd/rdpgw/web/session.go +++ b/cmd/rdpgw/web/session.go @@ -1,30 +1,75 @@ package web import ( + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/gorilla/sessions" "log" + "net/http" "os" ) -type SessionManagerConf struct { - SessionKey []byte - SessionEncryptionKey []byte - StoreType string -} +const ( + rdpGwSession = "RDPGWSESSION" + MaxAge = 120 + identityKey = "RDPGWID" +) + +var sessionStore sessions.Store -func (c *SessionManagerConf) Init() sessions.Store { - if len(c.SessionKey) < 32 { +func InitStore(sessionKey []byte, encryptionKey []byte, storeType string) { + if len(sessionKey) < 32 { log.Fatal("Session key too small") } - if len(c.SessionEncryptionKey) < 32 { + if len(encryptionKey) < 32 { log.Fatal("Session key too small") } - if c.StoreType == "file" { + if storeType == "file" { log.Println("Filesystem is used as session storage") - return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey) + sessionStore = sessions.NewFilesystemStore(os.TempDir(), sessionKey, encryptionKey) } else { log.Println("Cookies are used as session storage") - return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey) + sessionStore = sessions.NewCookieStore(sessionKey, encryptionKey) + } +} + +func GetSession(r *http.Request) (*sessions.Session, error) { + session, err := sessionStore.Get(r, rdpGwSession) + if err != nil { + return nil, err } + return session, nil +} + +func GetSessionIdentity(r *http.Request) (identity.Identity, error) { + s, err := GetSession(r) + if err != nil { + return nil, err + } + + idData := s.Values[identityKey] + if idData == nil { + return nil, nil + + } + id := identity.NewUser() + id.Unmarshal(idData.([]byte)) + return id, nil +} + +func SaveSessionIdentity(r *http.Request, w http.ResponseWriter, id identity.Identity) error { + session, err := GetSession(r) + if err != nil { + return err + } + session.Options.MaxAge = MaxAge + + idData, err := id.Marshal() + if err != nil { + return err + } + session.Values[identityKey] = idData + + return sessionStore.Save(r, w, session) + } diff --git a/cmd/rdpgw/web/web.go b/cmd/rdpgw/web/web.go index 4db4a906aaebe210d0eaef9e290709d1ce14d34b..0c87206c3ffb5225e841d7fb6e7bea32c4c157b0 100644 --- a/cmd/rdpgw/web/web.go +++ b/cmd/rdpgw/web/web.go @@ -5,7 +5,7 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/gorilla/sessions" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "log" "math/rand" "net/http" @@ -14,17 +14,11 @@ import ( "time" ) -const ( - RdpGwSession = "RDPGWSESSION" - MaxAge = 120 -) - type TokenGeneratorFunc func(context.Context, string, string) (string, error) type UserTokenGeneratorFunc func(context.Context, string) (string, error) type QueryInfoFunc func(context.Context, string, string) (string, error) type Config struct { - SessionStore sessions.Store PAATokenGenerator TokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc QueryInfo QueryInfoFunc @@ -46,7 +40,6 @@ type RdpOpts struct { } type Handler struct { - sessionStore sessions.Store paaTokenGenerator TokenGeneratorFunc enableUserToken bool userTokenGenerator UserTokenGeneratorFunc @@ -63,7 +56,6 @@ func (c *Config) NewHandler() *Handler { log.Fatal("Not enough hosts to connect to specified") } return &Handler{ - sessionStore: c.SessionStore, paaTokenGenerator: c.PAATokenGenerator, enableUserToken: c.EnableUserToken, userTokenGenerator: c.UserTokenGenerator, @@ -132,13 +124,13 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) { } func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { + id := identity.FromRequestCtx(r) ctx := r.Context() - userName, ok := ctx.Value("preferred_username").(string) opts := h.rdpOpts - if !ok { - log.Printf("preferred_username not found in context") + if !id.Authenticated() { + log.Printf("unauthenticated user %s", id.UserName()) http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) return } @@ -149,13 +141,13 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - host = strings.Replace(host, "{{ preferred_username }}", userName, 1) + host = strings.Replace(host, "{{ preferred_username }}", id.UserName(), 1) // split the username into user and domain - var user = userName + var user = id.UserName() var domain = opts.DefaultDomain if opts.SplitUserDomain { - creds := strings.SplitN(userName, "@", 2) + creds := strings.SplitN(id.UserName(), "@", 2) user = creds[0] if len(creds) > 1 { domain = creds[1] @@ -203,6 +195,8 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { rdp.Connection.GatewayHostname = h.gatewayAddress.Host rdp.Connection.GatewayCredentialsSource = SourceCookie rdp.Connection.GatewayAccessToken = token + rdp.Connection.GatewayCredentialMethod = 1 + rdp.Connection.GatewayUsageMethod = 1 rdp.Session.NetworkAutodetect = opts.NetworkAutoDetect != 0 rdp.Session.BandwidthAutodetect = opts.BandwidthAutoDetect != 0 rdp.Session.ConnectionType = opts.ConnectionType diff --git a/cmd/rdpgw/web/web_test.go b/cmd/rdpgw/web/web_test.go index 6ceccdf7702c97519627f2c0cf5f09b35720ab12..02aae987f26a60cbb5fe8085b0a2795d70f8227d 100644 --- a/cmd/rdpgw/web/web_test.go +++ b/cmd/rdpgw/web/web_test.go @@ -2,6 +2,7 @@ package web import ( "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "net/http" "net/http/httptest" @@ -113,9 +114,13 @@ func TestHandler_HandleDownload(t *testing.T) { } rr := httptest.NewRecorder() + id := identity.NewUser() + + id.SetUserName(testuser) + id.SetAuthenticated(true) + + req = identity.AddToRequestCtx(id, req) ctx := req.Context() - ctx = context.WithValue(ctx, "preferred_username", testuser) - req = req.WithContext(ctx) u, _ := url.Parse(gateway) c := Config{ diff --git a/go.mod b/go.mod index 01c99e0e337a4bf2dfcbf7690a355b9c3a116224..7f96c7f90818368fad15dd11ea381020568726ea 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/fatih/structs v1.1.0 github.com/go-jose/go-jose/v3 v3.0.0 github.com/google/uuid v1.1.2 + github.com/gorilla/mux v1.8.0 github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.5.0 github.com/jcmturner/gofork v1.7.6