diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 96ce034ba76ce93b94b44768fa211eadd16e774e..e89e77d587cdc5fbcb360a30cff2f0a7273b1094 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -41,6 +41,8 @@ func main() { security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey) + security.HostSelection = conf.Server.HostSelection + security.Hosts = conf.Server.Hosts // configure api api := &api.Config{ @@ -136,7 +138,7 @@ func main() { } // create the gateway - handlerConfig := protocol.ServerConf{ + gwConfig := protocol.ServerConf{ IdleTimeout: conf.Caps.IdleTimeout, TokenAuth: conf.Caps.TokenAuth, SmartCardAuth: conf.Caps.SmartCardAuth, @@ -153,11 +155,13 @@ func main() { ReceiveBuf: conf.Server.ReceiveBuf, } if conf.Caps.TokenAuth { - handlerConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken - handlerConfig.VerifyServerFunc = security.VerifyServerFunc + gwConfig.VerifyTunnelAuthFunc = security.VerifyPAAToken + gwConfig.VerifyServerFunc = security.VerifyServerFunc + } else { + gwConfig.VerifyServerFunc = security.BasicVerifyServer } gw := protocol.Gateway{ - ServerConf: &handlerConfig, + ServerConf: &gwConfig, } if conf.Server.Authentication == "local" { diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/server.go index 6571ece871e011b759a2a9c55c163608b8092939..3f073338563fa393b53d04a56b51f496f5759804 100644 --- a/cmd/rdpgw/protocol/server.go +++ b/cmd/rdpgw/protocol/server.go @@ -143,6 +143,7 @@ func (s *Server) Process(ctx context.Context) error { server, port := s.channelRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) if s.VerifyServerFunc != nil { + log.Printf("Verifying %s host connection", host) if ok, _ := s.VerifyServerFunc(ctx, host); !ok { log.Printf("Not allowed to connect to %s by policy handler", host) msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED) diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go new file mode 100644 index 0000000000000000000000000000000000000000..64b5db0ad6c9b25d61d7db3d460247b2d827c2f1 --- /dev/null +++ b/cmd/rdpgw/security/basic.go @@ -0,0 +1,36 @@ +package security + +import ( + "context" + "errors" + "fmt" + "log" +) + +var ( + Hosts []string + HostSelection string +) + +func BasicVerifyServer(ctx context.Context, host string) (bool, error) { + if HostSelection == "any" { + return true, nil + } + + if HostSelection == "signed" { + // todo get from context + return false, errors.New("cannot verify host in 'signed' mode as token data is missing") + } + + if HostSelection == "roundrobin" || HostSelection == "unsigned" { + log.Printf("Checking host") + for _, h := range Hosts { + if h == host { + return true, nil + } + } + return false, fmt.Errorf("invalid host %s", host) + } + + return false, errors.New("unrecognized host selection criteria") +}