diff --git a/main.go b/main.go index a3a9163d73e6f1df0049e3d00ba635b52834aed2..804872adbc832cea785e576be0a9bc937a9b3664 100644 --- a/main.go +++ b/main.go @@ -89,7 +89,18 @@ func main() { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } - http.HandleFunc("/remoteDesktopGateway/", protocol.HandleGatewayProtocol) + // create the gateway + handlerConfig := protocol.HandlerConf{ + TokenAuth: true, + RedirectFlags: protocol.RedirectFlags{ + Clipboard: true, + }, + } + gw := protocol.Gateway{ + HandlerConf: &handlerConfig, + } + + http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol) http.HandleFunc("/connect", handleRdpDownload) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/callback", handleCallback) diff --git a/protocol/handler.go b/protocol/handler.go index 0890b9150041be20fcfd4fea527833c93164bd7f..247fbb8b7f0afefffce12bc20ca2e3dbc4fde818 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -12,19 +12,28 @@ import ( "time" ) -// When should the client disconnect when idle in minutes -var IdleTimeout = 0 - type VerifyPAACookieFunc func(string) (bool, error) type VerifyTunnelAuthFunc func(string) (bool, error) type VerifyServerFunc func(string) (bool, error) +type RedirectFlags struct { + Clipboard bool + Port bool + Drive bool + Printer bool + Pnp bool + disableAll bool + enableAll bool +} + type Handler struct { TransportIn transport.Transport TransportOut transport.Transport VerifyPAACookieFunc VerifyPAACookieFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc + RedirectFlags int + IdleTimeout int SmartCardAuth bool TokenAuth bool ClientName string @@ -32,11 +41,28 @@ type Handler struct { State int } -func NewHandler(in transport.Transport, out transport.Transport) *Handler { +type HandlerConf struct { + VerifyPAACookieFunc VerifyPAACookieFunc + VerifyTunnelAuthFunc VerifyTunnelAuthFunc + VerifyServerFunc VerifyServerFunc + RedirectFlags RedirectFlags + IdleTimeout int + SmartCardAuth bool + TokenAuth bool +} + +func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler { h := &Handler{ - TransportIn: in, - TransportOut: out, - State: SERVER_STATE_INITIAL, + TransportIn: in, + TransportOut: out, + State: SERVER_STATE_INITIAL, + RedirectFlags: makeRedirectFlags(conf.RedirectFlags), + IdleTimeout: conf.IdleTimeout, + SmartCardAuth: conf.SmartCardAuth, + TokenAuth: conf.TokenAuth, + VerifyPAACookieFunc: conf.VerifyPAACookieFunc, + VerifyServerFunc: conf.VerifyServerFunc, + VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc, } return h } @@ -55,8 +81,8 @@ func (h *Handler) Process() error { log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL) return errors.New("wrong state") } - major, minor, _, auth := readHandshake(pkt) - msg := h.handshakeResponse(major, minor, auth) + major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do + msg := h.handshakeResponse(major, minor) h.TransportOut.WritePacket(msg) h.State = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: @@ -189,7 +215,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { // Creates a packet the is a response to a handshake request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // but could be in Windows. However the NTLM protocol is insecure -func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte { +func (h *Handler) handshakeResponse(major byte, minor byte) []byte { var caps uint16 if h.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA @@ -289,40 +315,13 @@ func (h *Handler) createTunnelAuthResponse() []byte { binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - // flags - var redir uint32 - /* - if conf.Caps.RedirectAll { - redir = HTTP_TUNNEL_REDIR_ENABLE_ALL - } else if conf.Caps.DisableRedirect { - redir = HTTP_TUNNEL_REDIR_DISABLE_ALL - } else { - if conf.Caps.DisableClipboard { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD - } - if conf.Caps.DisableDrive { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE - } - if conf.Caps.DisablePnp { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP - } - if conf.Caps.DisablePrinter { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER - } - if conf.Caps.DisablePort { - redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT - } - } - */ - redir = HTTP_TUNNEL_REDIR_ENABLE_ALL - // idle timeout - if IdleTimeout < 0 { - IdleTimeout = 0 + if h.IdleTimeout < 0 { + h.IdleTimeout = 0 } - binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags - binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes + binary.Write(buf, binary.LittleEndian, uint32(h.RedirectFlags)) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(h.IdleTimeout)) // timeout in minutes return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) } @@ -405,3 +404,31 @@ func createPacket(pktType uint16, data []byte) (packet []byte) { return buf.Bytes() } + +func makeRedirectFlags(flags RedirectFlags) int { + var redir = 0 + + if flags.disableAll { + return HTTP_TUNNEL_REDIR_DISABLE_ALL + } + if flags.enableAll { + return HTTP_TUNNEL_REDIR_ENABLE_ALL + } + + if !flags.Port { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT + } + if !flags.Clipboard { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD + } + if !flags.Drive { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE + } + if !flags.Pnp { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP + } + if !flags.Printer { + redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER + } + return redir +} diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index bc337b01b280cb447a91182a581f0342a4d3a364..ee395e21b554ca00f7db264763676d4f2a4285a8 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -39,6 +39,10 @@ var ( }) ) +type Gateway struct { + HandlerConf *HandlerConf +} + type SessionInfo struct { ConnId string CorrelationId string @@ -60,14 +64,11 @@ func init() { prometheus.MustRegister(websocketConnections) } -func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { +func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { connectionCache.Set(float64(c.ItemCount())) if r.Method == MethodRDGOUT { - for name, value := range r.Header { - log.Printf("Header Name: %s Value: %s", name, value) - } if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - handleLegacyProtocol(w, r) + g.handleLegacyProtocol(w, r) return } r.Method = "GET" // force @@ -78,25 +79,25 @@ func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - handleWebsocketProtocol(conn) + g.handleWebsocketProtocol(conn) } else if r.Method == MethodRDGIN { - handleLegacyProtocol(w, r) + g.handleLegacyProtocol(w, r) } } -func handleWebsocketProtocol(c *websocket.Conn) { +func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) { websocketConnections.Inc() defer websocketConnections.Dec() inout, _ := transport.NewWS(c) - handler := NewHandler(inout, inout) + handler := NewHandler(inout, inout, g.HandlerConf) handler.Process() } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different // to ensure the connections do not get cached or terminated by a proxy prematurely. -func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { +func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { var s SessionInfo connId := r.Header.Get(rdgConnectionIdKey) @@ -143,7 +144,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { in.Drain() log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) - handler := NewHandler(in, s.TransportOut) + handler := NewHandler(in, s.TransportOut, g.HandlerConf) handler.Process() } } diff --git a/security/simple.go b/security/simple.go new file mode 100644 index 0000000000000000000000000000000000000000..30c6cda187385581d54f8ff5225e73c6473fdd21 --- /dev/null +++ b/security/simple.go @@ -0,0 +1,5 @@ +package security + +func VerifyServerTemplate(server string) (bool, err) { + +}