diff --git a/protocol/handler.go b/protocol/handler.go index 28f5f432990b08b58c5d5d0f7af15be90b685022..549b4efc3a44796e219dbbff2b79c81e79c5e837 100644 --- a/protocol/handler.go +++ b/protocol/handler.go @@ -21,7 +21,7 @@ type VerifyServerFunc func(string) (bool, error) type Handler struct { TransportIn transport.Transport - TransportOut transport.Transport + TransportOut transport.Transport VerifyPAACookieFunc VerifyPAACookieFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc @@ -29,12 +29,14 @@ type Handler struct { TokenAuth bool ClientName string Remote net.Conn + State int } func NewHandler(in transport.Transport, out transport.Transport) *Handler { h := &Handler{ - TransportIn: in, + TransportIn: in, TransportOut: out, + State: SERVER_STATE_INITIAL, } return h } @@ -49,35 +51,61 @@ func (h *Handler) Process() error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: + if h.State != SERVER_STATE_INITIAL { + log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL) + return errors.New("wrong state") + } major, minor, _, auth := readHandshake(pkt) msg := h.handshakeResponse(major, minor, auth) h.TransportOut.WritePacket(msg) + h.State = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: - log.Printf("Tunnel create") + if h.State != SERVER_STATE_HANDSHAKE { + log.Printf("Tunnel create attempted while in wrong state %d != %d", + h.State, SERVER_STATE_HANDSHAKE) + return errors.New("wrong state") + } _, cookie := readCreateTunnelRequest(pkt) if h.VerifyPAACookieFunc != nil { - if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false { + if ok, _ := h.VerifyPAACookieFunc(cookie); !ok { log.Printf("Invalid PAA cookie: %s", cookie) return errors.New("invalid PAA cookie") } } msg := createTunnelResponse() h.TransportOut.WritePacket(msg) - log.Printf("Tunnel done") + h.State = SERVER_STATE_TUNNEL_CREATE case PKT_TYPE_TUNNEL_AUTH: - log.Printf("Tunnel auth") - h.readTunnelAuthRequest(pkt) + if h.State != SERVER_STATE_TUNNEL_CREATE { + log.Printf("Tunnel auth attempted while in wrong state %d != %d", + h.State, SERVER_STATE_TUNNEL_CREATE) + return errors.New("wrong state") + } + client := h.readTunnelAuthRequest(pkt) + if ok, _ := h.VerifyTunnelAuthFunc(client); !ok { + log.Printf("Invalid client name: %s", client) + return errors.New("invalid client name") + } msg := h.createTunnelAuthResponse() h.TransportOut.WritePacket(msg) + h.State = SERVER_STATE_TUNNEL_AUTHORIZE case PKT_TYPE_CHANNEL_CREATE: + if h.State != SERVER_STATE_TUNNEL_AUTHORIZE { + log.Printf("Channel create attempted while in wrong state %d != %d", + h.State, SERVER_STATE_TUNNEL_AUTHORIZE) + return errors.New("wrong state") + } server, port := readChannelCreateRequest(pkt) - log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) - h.Remote, err = net.DialTimeout( - "tcp", - net.JoinHostPort(server, strconv.Itoa(int(port))), - time.Second*15) + host := net.JoinHostPort(server, strconv.Itoa(int(port))) + if h.VerifyServerFunc != nil { + if ok, _ := h.VerifyServerFunc(host); !ok { + log.Printf("Not allowed to connect to %s by policy handler", host) + } + } + log.Printf("Establishing connection to RDP server: %s", host) + h.Remote, err = net.DialTimeout("tcp", host, time.Second*15) if err != nil { - log.Printf("Error connecting to %s, %d, %s", server, port, err) + log.Printf("Error connecting to %s, %s", host, err) return err } log.Printf("Connection established") @@ -87,14 +115,31 @@ func (h *Handler) Process() error { // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually go h.sendDataPacket() + h.State = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: + if h.State != SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } + h.State = SERVER_STATE_OPENED h.forwardDataPacket(pkt) case PKT_TYPE_KEEPALIVE: + // keepalives can be received while the channel is not open yet + if h.State < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } + // avoid concurrency issues // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: + if h.State != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED) + return errors.New("wrong state") + } h.TransportIn.Close() h.TransportOut.Close() + h.State = SERVER_STATE_CLOSED default: log.Printf("Unknown packet (size %d): %x", sz, pkt) } @@ -223,7 +268,7 @@ func createTunnelResponse() []byte { return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) } -func (h *Handler) readTunnelAuthRequest(data []byte) { +func (h *Handler) readTunnelAuthRequest(data []byte) string { buf := bytes.NewReader(data) var size uint16 @@ -231,7 +276,8 @@ func (h *Handler) readTunnelAuthRequest(data []byte) { clData := make([]byte, size) binary.Read(buf, binary.LittleEndian, &clData) clientName, _ := DecodeUTF16(clData) - log.Printf("Client: %s", clientName) + + return clientName } func (h *Handler) createTunnelAuthResponse() []byte { diff --git a/protocol/types.go b/protocol/types.go index c12487d6086da5c90ee8c3fe25a456fe17401ecb..a8e788ddfc13411a01a725d63c71e0d847418ec5 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -58,3 +58,12 @@ const ( HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 ) +const ( + SERVER_STATE_INITIAL = 0x0 + SERVER_STATE_HANDSHAKE = 0x1 + SERVER_STATE_TUNNEL_CREATE = 0x2 + SERVER_STATE_TUNNEL_AUTHORIZE = 0x3 + SERVER_STATE_CHANNEL_CREATE = 0x4 + SERVER_STATE_OPENED = 0x5 + SERVER_STATE_CLOSED = 0x6 +) diff --git a/transport/legacy.go b/transport/legacy.go index 70f74ff7bedef5294fb3c64cabeb563fdd212f69..cfce51778e7af1c78b408f610b651d33e6558082 100644 --- a/transport/legacy.go +++ b/transport/legacy.go @@ -79,4 +79,4 @@ func (t *LegacyPKT) SendAccept(doSeed bool) { func (t *LegacyPKT) Drain() { p := make([]byte, 32767) t.Conn.Read(p) -} \ No newline at end of file +}