diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index 6f337c978901cb13eea45a960efc0a2a7e9d8e2c..2e5026bf7d65d37caae8a2cd108e7786633dc333 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -236,7 +236,7 @@ func (s *ServerConfig) KerberosEnabled() bool { } func (s *ServerConfig) BasicAuthEnabled() bool { - return s.matchAuth("local") + return s.matchAuth("local") || s.matchAuth("basic") } func (s *ServerConfig) matchAuth(needle string) bool { diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index bdc3e9d376bba74fa1204b85e09374ab4baae8d4..42b0e5f6146f31e59cf312f0382200c393333ec3 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -226,12 +226,13 @@ func main() { // for stacking of authentication auth := web.NewAuthMux() + rdp.MatcherFunc(web.NoAuthz).HandlerFunc(auth.SetAuthenticate) // basic auth if conf.Server.BasicAuthEnabled() { log.Printf("enabling basic authentication") q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} - rdp.Headers("Authorization", "Basic*").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) + rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) auth.Register(`Basic realm="restricted", charset="UTF-8"`) } @@ -242,7 +243,7 @@ func main() { if err != nil { log.Fatalf("Cannot load keytab: %s", err) } - rdp.Headers("Authorization", "Negotiate*").Handler( + rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").Handler( spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)), keytab, service.Logger(log.Default()))) @@ -253,9 +254,6 @@ func main() { auth.Register("Negotiate") } - // allow stacking of authentication - rdp.Use(auth.Route) - // setup server server := http.Server{ Addr: ":" + strconv.Itoa(conf.Server.Port), diff --git a/cmd/rdpgw/web/router.go b/cmd/rdpgw/web/router.go index 12cc8af956e11016f71d72c82018f00cc49be728..02069ae35108abebee473b73ee07c30ab630d610 100644 --- a/cmd/rdpgw/web/router.go +++ b/cmd/rdpgw/web/router.go @@ -1,6 +1,7 @@ package web import ( + "github.com/gorilla/mux" "net/http" ) @@ -16,16 +17,13 @@ func (a *AuthMux) Register(s string) { a.headers = append(a.headers, s) } -func (a *AuthMux) Route(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := r.Header.Get("Authorization") - if h == "" { - for _, s := range a.headers { - w.Header().Add("WWW-Authenticate", s) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - } - next.ServeHTTP(w, r) - }) +func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) { + for _, s := range a.headers { + w.Header().Add("WWW-Authenticate", s) + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + +func NoAuthz(r *http.Request, rm *mux.RouteMatch) bool { + return r.Header.Get("Authorization") == "" }