From fb58cb299e7a661611300267ea69ec6603b68c55 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Wed, 24 Aug 2022 13:47:26 +0200
Subject: [PATCH] Add server implementation of basic auth

---
 cmd/auth/auth.go                  | 84 +++++++++++++++++++------------
 cmd/auth/proto/auth.proto         | 14 ------
 cmd/rdpgw/api/basic.go            | 62 +++++++++++++++++++++++
 cmd/rdpgw/api/web.go              |  2 +
 cmd/rdpgw/config/configuration.go |  9 +++-
 cmd/rdpgw/main.go                 | 13 +++--
 go.mod                            |  7 +--
 proto/auth.proto                  | 19 +++++++
 8 files changed, 157 insertions(+), 53 deletions(-)
 delete mode 100644 cmd/auth/proto/auth.proto
 create mode 100644 cmd/rdpgw/api/basic.go
 create mode 100644 proto/auth.proto

diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go
index 611762d..40d2542 100644
--- a/cmd/auth/auth.go
+++ b/cmd/auth/auth.go
@@ -1,38 +1,66 @@
 package main
 
 import (
+	"context"
 	"errors"
-	"github.com/golang/protobuf/proto"
-	ipc "github.com/james-barrow/golang-ipc"
+	"github.com/bolkedebruin/rdpgw/shared/auth"
 	"github.com/msteinert/pam"
 	"github.com/thought-machine/go-flags"
+	"google.golang.org/grpc"
 	"log"
+	"net"
+	"os"
+	"syscall"
+)
+
+const (
+	protocol = "unix"
 )
 
 var opts struct {
-	serviceName string `short:"s" long:"service" default:"rdpgw" description:"the PAM service name to use"`
+	ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"`
+	SocketAddr  string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"`
+}
+
+type AuthServiceImpl struct {
+	serviceName string
+}
+
+var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil)
+
+func NewAuthService(serviceName string) auth.AuthenticateServer {
+	s := &AuthServiceImpl{serviceName: serviceName}
+	return s
 }
 
-func auth(service, user, passwd string) error {
-	t, err := pam.StartFunc(service, user, func(s pam.Style, msg string) (string, error) {
+func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) {
+	t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) {
 		switch s {
 		case pam.PromptEchoOff:
-			return passwd, nil
+			return message.Password, nil
 		case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo:
 			return "", nil
 		}
 		return "", errors.New("unrecognized PAM message style")
 	})
 
+	r := &auth.AuthResponse{}
+	r.Authenticated = false
 	if err != nil {
-		return err
+		log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
+		r.Error = err.Error()
+		return r, err
 	}
 
 	if err = t.Authenticate(0); err != nil {
-		return err
+		log.Printf("Authentication for user: %s failed due to: %s", message.Username, err)
+		r.Error = err.Error()
+		return r, nil
 	}
 
-	return nil
+	log.Printf("User: %s authenticated", message.Username)
+	r.Authenticated = true
+	return r, nil
 }
 
 func main() {
@@ -41,32 +69,24 @@ func main() {
 		panic(err)
 	}
 
-	config := &ipc.ServerConfig{UnmaskPermissions: true}
-	sc, err := ipc.StartServer("rdpgw-auth", config)
-	for {
-		msg, err := sc.Read()
-		if err != nil {
-			log.Printf("server error, %s", err)
-			continue
-		}
-		if msg.MsgType > 0 {
-			req := &UserPass{}
-			if err = proto.Unmarshal(msg.Data, req); err != nil {
-				log.Printf("cannot unmarshal request %s", string(msg.Data))
-				continue
-			}
-			err := auth(opts.serviceName, req.Username, req.Password)
-			if err != nil {
-				res := &Response{Status: "cannot authenticate"}
-				out, err := proto.Marshal(res)
-				if err != nil {
-					log.Fatalf("cannot marshal response due to %s", err)
-				}
-				sc.Write(1, out)
+	log.Printf("Starting auth server on %s", opts.SocketAddr)
+	cleanup := func() {
+		if _, err := os.Stat(opts.SocketAddr); err == nil {
+			if err := os.RemoveAll(opts.SocketAddr); err != nil {
+				log.Fatal(err)
 			}
 		}
 	}
+	cleanup()
+
+	oldUmask := syscall.Umask(0)
+	listener, err := net.Listen(protocol, opts.SocketAddr)
+	syscall.Umask(oldUmask)
 	if err != nil {
-		log.Printf("cannot authenticate due to %s", err)
+		log.Fatal(err)
 	}
+	server := grpc.NewServer()
+	service := NewAuthService(opts.ServiceName)
+	auth.RegisterAuthenticateServer(server, service)
+	server.Serve(listener)
 }
diff --git a/cmd/auth/proto/auth.proto b/cmd/auth/proto/auth.proto
deleted file mode 100644
index acc33a2..0000000
--- a/cmd/auth/proto/auth.proto
+++ /dev/null
@@ -1,14 +0,0 @@
-syntax = "proto3";
-
-package main;
-
-option go_package = "./auth;main";
-
-message UserPass {
-  string username = 1;
-  string password = 2;
-}
-
-message Response {
-  string status = 1;
-}
\ No newline at end of file
diff --git a/cmd/rdpgw/api/basic.go b/cmd/rdpgw/api/basic.go
new file mode 100644
index 0000000..91f6d17
--- /dev/null
+++ b/cmd/rdpgw/api/basic.go
@@ -0,0 +1,62 @@
+package api
+
+import (
+	"context"
+	"github.com/bolkedebruin/rdpgw/shared/auth"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"log"
+	"net"
+	"net/http"
+	"time"
+)
+
+const (
+	protocol = "unix"
+)
+
+func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		username, password, ok := r.BasicAuth()
+		if ok {
+			ctx := r.Context()
+
+			conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
+				grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
+					return net.Dial(protocol, addr)
+				}))
+			if err != nil {
+				log.Printf("Cannot reach authentication provider: %s", err)
+				http.Error(w, "Server error", http.StatusInternalServerError)
+				return
+			}
+			defer conn.Close()
+
+			c := auth.NewAuthenticateClient(conn)
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+
+			req := &auth.UserPass{Username: username, Password: password}
+			res, err := c.Authenticate(ctx, req)
+			if err != nil {
+				log.Printf("Error talking to authentication provider: %s", err)
+				http.Error(w, "Server error", http.StatusInternalServerError)
+				return
+			}
+
+			if !res.Authenticated {
+				log.Printf("User %s is not authenticated for this service", username)
+			} else {
+				next.ServeHTTP(w, r.WithContext(ctx))
+				return
+			}
+
+		}
+		// If the Authentication header is not present, is invalid, or the
+		// username or password is wrong, then set a WWW-Authenticate
+		// header to inform the client that we expect them to use basic
+		// authentication and send a 401 Unauthorized response.
+		w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
+		http.Error(w, "Unauthorized", http.StatusUnauthorized)
+	}
+}
diff --git a/cmd/rdpgw/api/web.go b/cmd/rdpgw/api/web.go
index acf1720..b1da3ab 100644
--- a/cmd/rdpgw/api/web.go
+++ b/cmd/rdpgw/api/web.go
@@ -51,6 +51,8 @@ type Config struct {
 	ConnectionType       int
 	SplitUserDomain      bool
 	DefaultDomain        string
+	SocketAddress        string
+	Authentication       string
 }
 
 func (c *Config) NewApi() {
diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go
index 43c9445..f034964 100644
--- a/cmd/rdpgw/config/configuration.go
+++ b/cmd/rdpgw/config/configuration.go
@@ -30,8 +30,10 @@ type ServerConfig struct {
 	SessionEncryptionKey string   `koanf:"sessionencryptionkey"`
 	SessionStore         string   `koanf:"sessionstore"`
 	SendBuf              int      `koanf:"sendbuf"`
-	ReceiveBuf           int      `koanf:"recievebuf"`
+	ReceiveBuf           int      `koanf:"receivebuf"`
 	DisableTLS           bool     `koanf:"disabletls"`
+	Authentication       string   `koanf:"authentication"`
+	AuthSocket           string   `koanf:"authsocket"`
 }
 
 type OpenIDConfig struct {
@@ -121,6 +123,8 @@ func Load(configFile string) Configuration {
 		"Server.Port":                443,
 		"Server.SessionStore":        "cookie",
 		"Server.HostSelection":       "roundrobin",
+		"Server.Authentication":      "openid",
+		"Server.AuthSocket":          "/tmp/rdpgw-auth.sock",
 		"Client.NetworkAutoDetect":   1,
 		"Client.BandwidthAutoDetect": 1,
 		"Security.VerifyClientIp":    true,
@@ -182,6 +186,9 @@ func Load(configFile string) Configuration {
 		log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
 	}
 
+	if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
+		log.Fatalf("basicauth=local and disabletls are mutually exclusive")
+	}
 	return Conf
 
 }
diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index 567c677..24e4bb0 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -89,6 +89,8 @@ func main() {
 		ConnectionType:       conf.Client.ConnectionType,
 		SplitUserDomain:      conf.Client.SplitUserDomain,
 		DefaultDomain:        conf.Client.DefaultDomain,
+		SocketAddress:        conf.Server.AuthSocket,
+		Authentication:       conf.Server.Authentication,
 	}
 	api.NewApi()
 
@@ -148,11 +150,16 @@ func main() {
 		ServerConf: &handlerConfig,
 	}
 
-	http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
-	http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
+	if conf.Server.Authentication == "local" {
+		http.Handle("/connect", common.EnrichContext(api.BasicAuth(api.HandleDownload)))
+		http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
+	} else {
+		// openid
+		http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
+		http.HandleFunc("/callback", api.HandleCallback)
+	}
 	http.Handle("/metrics", promhttp.Handler())
 	http.HandleFunc("/tokeninfo", api.TokenInfo)
-	http.HandleFunc("/callback", api.HandleCallback)
 
 	if conf.Server.DisableTLS {
 		err = server.ListenAndServe()
diff --git a/go.mod b/go.mod
index 1922f09..4e7f144 100644
--- a/go.mod
+++ b/go.mod
@@ -7,17 +7,17 @@ require (
 	github.com/go-jose/go-jose/v3 v3.0.0
 	github.com/gorilla/sessions v1.2.1
 	github.com/gorilla/websocket v1.5.0
-	github.com/james-barrow/golang-ipc v1.0.0
 	github.com/knadh/koanf v1.4.2
 	github.com/msteinert/pam v1.0.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/prometheus/client_golang v1.12.1
 	github.com/thought-machine/go-flags v1.6.1
 	golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
+	google.golang.org/grpc v1.49.0
+	google.golang.org/protobuf v1.28.1
 )
 
 require (
-	github.com/Microsoft/go-winio v0.4.16 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/fsnotify/fsnotify v1.5.4 // indirect
@@ -35,8 +35,9 @@ require (
 	golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
 	golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect
 	golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
+	golang.org/x/text v0.3.7 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/protobuf v1.28.0 // indirect
+	google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 // indirect
 	gopkg.in/square/go-jose.v2 v2.6.0 // indirect
 	gopkg.in/yaml.v3 v3.0.0 // indirect
 )
diff --git a/proto/auth.proto b/proto/auth.proto
new file mode 100644
index 0000000..de07903
--- /dev/null
+++ b/proto/auth.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package auth;
+
+option go_package = "./auth";
+
+message UserPass {
+  string username = 1;
+  string password = 2;
+}
+
+message AuthResponse {
+  bool authenticated = 1;
+  string error = 2;
+}
+
+service Authenticate {
+  rpc Authenticate (UserPass) returns (AuthResponse) {}
+}
\ No newline at end of file
-- 
GitLab