diff --git a/main.go b/main.go index 6ee11e11fa7bec78bb3032db77836636b25ac4a3..053b97765fa91aae509223849c7d9ac9d7bcfb3c 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/protocol" + "github.com/bolkedebruin/rdpgw/security" "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -89,6 +90,11 @@ func main() { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } + // setup security + securityConfig := &security.Config{ + Store: tokens, + } + // create the gateway handlerConfig := protocol.HandlerConf{ IdleTimeout: conf.Caps.IdleTimeout, @@ -103,6 +109,7 @@ func main() { DisableAll: conf.Caps.DisableRedirect, EnableAll: conf.Caps.RedirectAll, }, + VerifyTunnelCreate: securityConfig.VerifyPAAToken, } gw := protocol.Gateway{ HandlerConf: &handlerConfig, diff --git a/protocol/handler.go b/protocol/handler.go index ad5d43bd1f70ff59b6f33ba4383b92310ed79a4b..d39f3f9cbcdd0d30b2f392e0971495144e11d22c 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -12,9 +12,9 @@ import ( "time" ) -type VerifyPAACookieFunc func(string) (bool, error) -type VerifyTunnelAuthFunc func(string) (bool, error) -type VerifyServerFunc func(string) (bool, error) +type VerifyTunnelCreate func(*SessionInfo, string) (bool, error) +type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error) +type VerifyServerFunc func(*SessionInfo, string) (bool, error) type RedirectFlags struct { Clipboard bool @@ -27,9 +27,10 @@ type RedirectFlags struct { } type Handler struct { + Session *SessionInfo TransportIn transport.Transport TransportOut transport.Transport - VerifyPAACookieFunc VerifyPAACookieFunc + VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc RedirectFlags int @@ -42,7 +43,7 @@ type Handler struct { } type HandlerConf struct { - VerifyPAACookieFunc VerifyPAACookieFunc + VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc RedirectFlags RedirectFlags @@ -53,6 +54,7 @@ type HandlerConf struct { func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { h := &Handler{ + Session: s, TransportIn: s.TransportIn, TransportOut: s.TransportOut, State: SERVER_STATE_INITIAL, @@ -60,7 +62,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { IdleTimeout: conf.IdleTimeout, SmartCardAuth: conf.SmartCardAuth, TokenAuth: conf.TokenAuth, - VerifyPAACookieFunc: conf.VerifyPAACookieFunc, + VerifyTunnelCreate: conf.VerifyTunnelCreate, VerifyServerFunc: conf.VerifyServerFunc, VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc, } @@ -92,8 +94,8 @@ func (h *Handler) Process() error { return errors.New("wrong state") } _, cookie := readCreateTunnelRequest(pkt) - if h.VerifyPAACookieFunc != nil { - if ok, _ := h.VerifyPAACookieFunc(cookie); !ok { + if h.VerifyTunnelCreate != nil { + if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok { log.Printf("Invalid PAA cookie: %s", cookie) return errors.New("invalid PAA cookie") } @@ -109,7 +111,7 @@ func (h *Handler) Process() error { } client := h.readTunnelAuthRequest(pkt) if h.VerifyTunnelAuthFunc != nil { - if ok, _ := h.VerifyTunnelAuthFunc(client); !ok { + if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok { log.Printf("Invalid client name: %s", client) return errors.New("invalid client name") } @@ -126,7 +128,7 @@ func (h *Handler) Process() error { server, port := readChannelCreateRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) if h.VerifyServerFunc != nil { - if ok, _ := h.VerifyServerFunc(host); !ok { + if ok, _ := h.VerifyServerFunc(h.Session, host); !ok { log.Printf("Not allowed to connect to %s by policy handler", host) } } diff --git a/security/simple.go b/security/simple.go index 30c6cda187385581d54f8ff5225e73c6473fdd21..2ab4260f3db176736e0dfaefa9beb030170ff958 100644 --- a/security/simple.go +++ b/security/simple.go @@ -1,5 +1,21 @@ package security -func VerifyServerTemplate(server string) (bool, err) { +import ( + "github.com/bolkedebruin/rdpgw/protocol" + "github.com/patrickmn/go-cache" + "log" +) +type Config struct { + Store *cache.Cache } + +func (c *Config) VerifyPAAToken(s *protocol.SessionInfo, token string) (bool, error) { + _, found := c.Store.Get(token) + if !found { + log.Printf("PAA Token %s not found", token) + return false, nil + } + + return true, nil +} \ No newline at end of file