diff --git a/main.go b/main.go index 0d9203fe670dcff2ff2819853cb18410fccd72d6..7496a98391187c66eedc5d46bf2cd3212c15ef15 100644 --- a/main.go +++ b/main.go @@ -103,7 +103,7 @@ func main() { } // create the gateway - handlerConfig := protocol.HandlerConf{ + handlerConfig := protocol.ServerConf{ IdleTimeout: conf.Caps.IdleTimeout, TokenAuth: conf.Caps.TokenAuth, SmartCardAuth: conf.Caps.SmartCardAuth, @@ -120,7 +120,7 @@ func main() { VerifyServerFunc: security.VerifyServerFunc, } gw := protocol.Gateway{ - HandlerConf: &handlerConfig, + ServerConf: &handlerConfig, } http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) diff --git a/protocol/common.go b/protocol/common.go new file mode 100644 index 0000000000000000000000000000000000000000..662a8466fb92d03c08a9fb17ede4f28b689ef18a --- /dev/null +++ b/protocol/common.go @@ -0,0 +1,37 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +func createPacket(pktType uint16, data []byte) (packet []byte) { + size := len(data) + 8 + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(pktType)) + binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved + binary.Write(buf, binary.LittleEndian, uint32(size)) + buf.Write(data) + + return buf.Bytes() +} + +func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { + // header needs to be 8 min + if len(data) < 8 { + return 0, 0, nil, errors.New("header too short, fragment likely") + } + r := bytes.NewReader(data) + binary.Read(r, binary.LittleEndian, &packetType) + r.Seek(4, io.SeekStart) + binary.Read(r, binary.LittleEndian, &size) + if len(data) < int(size) { + return packetType, size, data[8:], errors.New("data incomplete, fragment received") + } + return packetType, size, data[8:], nil +} + + diff --git a/protocol/handler_test.go b/protocol/protocol_test.go similarity index 67% rename from protocol/handler_test.go rename to protocol/protocol_test.go index bfbbf7ca9faa88fa6849c300e456ee26994927f2..20bdc98a50122c87b17e2aef9a0b8983de6bbe07 100644 --- a/protocol/handler_test.go +++ b/protocol/protocol_test.go @@ -7,20 +7,20 @@ import ( ) const ( - HeaderLen = 8 - HandshakeRequestLen = HeaderLen + 6 - HandshakeResponseLen = HeaderLen + 10 - TunnelCreateRequestLen = HeaderLen + 8 // + dynamic + HeaderLen = 8 + HandshakeRequestLen = HeaderLen + 6 + HandshakeResponseLen = HeaderLen + 10 + TunnelCreateRequestLen = HeaderLen + 8 // + dynamic TunnelCreateResponseLen = HeaderLen + 18 - TunnelAuthLen = HeaderLen + 2 // + dynamic - TunnelAuthResponseLen = HeaderLen + 16 + TunnelAuthLen = HeaderLen + 2 // + dynamic + TunnelAuthResponseLen = HeaderLen + 16 ) -func verifyPacketHeader(data []byte , expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { +func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { pt, size, pkt, err := readHeader(data) if pt != expPt { - return 0,0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt) + return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt) } if size != expSize { @@ -38,6 +38,11 @@ func TestHandshake(t *testing.T) { client := ClientConfig{ PAAToken: "abab", } + s := &SessionInfo{} + hc := &ServerConf{ + TokenAuth: true, + } + h := NewServer(s, hc) data := client.handshakeRequest() @@ -49,23 +54,16 @@ func TestHandshake(t *testing.T) { log.Printf("pkt: %x", pkt) - major, minor, version, extAuth := readHandshake(pkt) + major, minor, version, extAuth := h.handshakeRequest(pkt) if major != MajorVersion || minor != MinorVersion || version != Version { - t.Fatalf("readHandshake failed got version %d.%d protocol %d, expected %d.%d protocol %d", + t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d", major, minor, version, MajorVersion, MinorVersion, Version) } if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) { - t.Fatalf("readHandshake failed got ext auth %d, expected %d", extAuth, extAuth | HTTP_EXTENDED_AUTH_PAA) - } - - s := &SessionInfo{} - hc := &HandlerConf{ - TokenAuth: true, + t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA) } - h := NewHandler(s, hc) - data = h.handshakeResponse(0x0, 0x0) _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen) if err != nil { @@ -75,7 +73,7 @@ func TestHandshake(t *testing.T) { caps, err := client.handshakeResponse(pkt) if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) { - t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps | HTTP_EXTENDED_AUTH_PAA) + t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps|HTTP_EXTENDED_AUTH_PAA) } } @@ -83,23 +81,28 @@ func TestTunnelCreation(t *testing.T) { client := ClientConfig{ PAAToken: "abab", } + s := &SessionInfo{} + hc := &ServerConf{ + TokenAuth: true, + } + h := NewServer(s, hc) data := client.tunnelRequest() _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, - uint32(TunnelCreateRequestLen + 2 + len(client.PAAToken)*2)) + uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2)) if err != nil { t.Fatalf("verifyHeader failed: %s", err) } - caps, token := readCreateTunnelRequest(pkt) + caps, token := h.tunnelRequest(pkt) if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) { - t.Fatalf("readCreateTunnelRequest failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT) + t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT) } if token != client.PAAToken { - t.Fatalf("readCreateTunnelRequest failed got token %s, expected %s", token, client.PAAToken) + t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken) } - data = createTunnelResponse() + data = h.tunnelResponse() _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen) if err != nil { t.Fatalf("verifyHeader failed: %s", err) @@ -113,35 +116,35 @@ func TestTunnelCreation(t *testing.T) { t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId) } if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) { - t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT) + t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT) } } func TestTunnelAuth(t *testing.T) { client := ClientConfig{} s := &SessionInfo{} - hc := &HandlerConf{ - TokenAuth: true, + hc := &ServerConf{ + TokenAuth: true, IdleTimeout: 10, RedirectFlags: RedirectFlags{ Clipboard: true, }, } - h := NewHandler(s, hc) + h := NewServer(s, hc) name := "test_name" data := client.tunnelAuthRequest(name) - _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen + len(name) * 2)) + _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2)) if err != nil { t.Fatalf("verifyHeader failed: %s", err) } - n := h.readTunnelAuthRequest(pkt) + n := h.tunnelAuthRequest(pkt) if n != name { - t.Fatalf("readTunnelAuthRequest failed got name %s, expected %s", n, name) + t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name) } - data = h.createTunnelAuthResponse() + data = h.tunnelAuthResponse() _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen) if err != nil { t.Fatalf("verifyHeader failed: %s", err) @@ -152,10 +155,10 @@ func TestTunnelAuth(t *testing.T) { } if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD { t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d", - flags, flags | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) + flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) } if int(timeout) != hc.IdleTimeout { t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d", timeout, hc.IdleTimeout) } -} \ No newline at end of file +} diff --git a/protocol/rdpgw.go b/protocol/rdpgw.go index f3e321ed016f7850d02c1efcaad28a3f4295e6d6..22729f61bbaa43b257d1c143e91b9b2a8176ce6c 100644 --- a/protocol/rdpgw.go +++ b/protocol/rdpgw.go @@ -42,7 +42,7 @@ var ( ) type Gateway struct { - HandlerConf *HandlerConf + ServerConf *ServerConf } type SessionInfo struct { @@ -102,12 +102,12 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn inout, _ := transport.NewWS(c) s.TransportOut = inout s.TransportIn = inout - handler := NewHandler(s, g.HandlerConf) + handler := NewServer(s, g.ServerConf) handler.Process(ctx) } // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server -// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different +// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different // to ensure the connections do not get cached or terminated by a proxy prematurely. func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) { log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) @@ -145,8 +145,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s // read some initial data in.Drain() - log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context())) - handler := NewHandler(s, g.HandlerConf) + log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context())) + handler := NewServer(s, g.ServerConf) handler.Process(r.Context()) } } diff --git a/protocol/handler.go b/protocol/server.go similarity index 69% rename from protocol/handler.go rename to protocol/server.go index d4fdb162e3a1914c493cad2f1e4043a8296b387c..ba946b7f7245a0ef29c8323cf8b3f851111592f3 100644 --- a/protocol/handler.go +++ b/protocol/server.go @@ -27,7 +27,7 @@ type RedirectFlags struct { EnableAll bool } -type Handler struct { +type Server struct { Session *SessionInfo VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc @@ -41,7 +41,7 @@ type Handler struct { State int } -type HandlerConf struct { +type ServerConf struct { VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyServerFunc VerifyServerFunc @@ -51,8 +51,8 @@ type HandlerConf struct { TokenAuth bool } -func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { - h := &Handler{ +func NewServer(s *SessionInfo, conf *ServerConf) *Server { + h := &Server{ State: SERVER_STATE_INITIAL, Session: s, RedirectFlags: makeRedirectFlags(conf.RedirectFlags), @@ -68,9 +68,9 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { const tunnelId = 10 -func (h *Handler) Process(ctx context.Context) error { +func (s *Server) Process(ctx context.Context) error { for { - pt, sz, pkt, err := h.ReadMessage() + pt, sz, pkt, err := s.ReadMessage() if err != nil { log.Printf("Cannot read message from stream %s", err) return err @@ -78,89 +78,89 @@ func (h *Handler) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshake from %s", client.GetClientIp(ctx)) - if h.State != SERVER_STATE_INITIAL { - log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL) + log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx)) + if s.State != SERVER_STATE_INITIAL { + log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL) return errors.New("wrong state") } - major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do - msg := h.handshakeResponse(major, minor) - h.Session.TransportOut.WritePacket(msg) - h.State = SERVER_STATE_HANDSHAKE + major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do + msg := s.handshakeResponse(major, minor) + s.Session.TransportOut.WritePacket(msg) + s.State = SERVER_STATE_HANDSHAKE case PKT_TYPE_TUNNEL_CREATE: log.Printf("Tunnel create") - if h.State != SERVER_STATE_HANDSHAKE { + if s.State != SERVER_STATE_HANDSHAKE { log.Printf("Tunnel create attempted while in wrong state %d != %d", - h.State, SERVER_STATE_HANDSHAKE) + s.State, SERVER_STATE_HANDSHAKE) return errors.New("wrong state") } - _, cookie := readCreateTunnelRequest(pkt) - if h.VerifyTunnelCreate != nil { - if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok { + _, cookie := s.tunnelRequest(pkt) + if s.VerifyTunnelCreate != nil { + if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok { log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx)) return errors.New("invalid PAA cookie") } } - msg := createTunnelResponse() - h.Session.TransportOut.WritePacket(msg) - h.State = SERVER_STATE_TUNNEL_CREATE + msg := s.tunnelResponse() + s.Session.TransportOut.WritePacket(msg) + s.State = SERVER_STATE_TUNNEL_CREATE case PKT_TYPE_TUNNEL_AUTH: log.Printf("Tunnel auth") - if h.State != SERVER_STATE_TUNNEL_CREATE { + if s.State != SERVER_STATE_TUNNEL_CREATE { log.Printf("Tunnel auth attempted while in wrong state %d != %d", - h.State, SERVER_STATE_TUNNEL_CREATE) + s.State, SERVER_STATE_TUNNEL_CREATE) return errors.New("wrong state") } - client := h.readTunnelAuthRequest(pkt) - if h.VerifyTunnelAuthFunc != nil { - if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok { + client := s.tunnelAuthRequest(pkt) + if s.VerifyTunnelAuthFunc != nil { + if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok { log.Printf("Invalid client name: %s", client) return errors.New("invalid client name") } } - msg := h.createTunnelAuthResponse() - h.Session.TransportOut.WritePacket(msg) - h.State = SERVER_STATE_TUNNEL_AUTHORIZE + msg := s.tunnelAuthResponse() + s.Session.TransportOut.WritePacket(msg) + s.State = SERVER_STATE_TUNNEL_AUTHORIZE case PKT_TYPE_CHANNEL_CREATE: log.Printf("Channel create") - if h.State != SERVER_STATE_TUNNEL_AUTHORIZE { + if s.State != SERVER_STATE_TUNNEL_AUTHORIZE { log.Printf("Channel create attempted while in wrong state %d != %d", - h.State, SERVER_STATE_TUNNEL_AUTHORIZE) + s.State, SERVER_STATE_TUNNEL_AUTHORIZE) return errors.New("wrong state") } - server, port := readChannelCreateRequest(pkt) + server, port := s.channelRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) - if h.VerifyServerFunc != nil { - if ok, _ := h.VerifyServerFunc(ctx, host); !ok { + if s.VerifyServerFunc != nil { + if ok, _ := s.VerifyServerFunc(ctx, host); !ok { log.Printf("Not allowed to connect to %s by policy handler", host) return errors.New("denied by security policy") } } log.Printf("Establishing connection to RDP server: %s", host) - h.Remote, err = net.DialTimeout("tcp", host, time.Second*15) + s.Remote, err = net.DialTimeout("tcp", host, time.Second*15) if err != nil { log.Printf("Error connecting to %s, %s", host, err) return err } log.Printf("Connection established") - msg := createChannelCreateResponse() - h.Session.TransportOut.WritePacket(msg) + msg := s.channelResponse() + s.Session.TransportOut.WritePacket(msg) // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually - go h.sendDataPacket() - h.State = SERVER_STATE_CHANNEL_CREATE + go s.sendDataPacket() + s.State = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: - if h.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE) + if s.State < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } - h.State = SERVER_STATE_OPENED - h.forwardDataPacket(pkt) + s.State = SERVER_STATE_OPENED + s.forwardDataPacket(pkt) case PKT_TYPE_KEEPALIVE: // keepalives can be received while the channel is not open yet - if h.State < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE) + if s.State < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) return errors.New("wrong state") } @@ -168,26 +168,26 @@ func (h *Handler) Process(ctx context.Context) error { // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: log.Printf("Close channel") - if h.State != SERVER_STATE_OPENED { - log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED) + if s.State != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED) return errors.New("wrong state") } - h.Session.TransportIn.Close() - h.Session.TransportOut.Close() - h.State = SERVER_STATE_CLOSED + s.Session.TransportIn.Close() + s.Session.TransportOut.Close() + s.State = SERVER_STATE_CLOSED default: log.Printf("Unknown packet (size %d): %x", sz, pkt) } } } -func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { +func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) { fragment := false index := 0 buf := make([]byte, 4096) for { - size, pkt, err := h.Session.TransportIn.ReadPacket() + size, pkt, err := s.Session.TransportIn.ReadPacket() if err != nil { return 0, 0, []byte{0, 0}, err } @@ -219,15 +219,15 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) { } } -// Creates a packet the is a response to a handshake request +// Creates a packet the is a response to a handshakeRequest request // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // but could be in Windows. However the NTLM protocol is insecure -func (h *Handler) handshakeResponse(major byte, minor byte) []byte { +func (s *Server) handshakeResponse(major byte, minor byte) []byte { var caps uint16 - if h.SmartCardAuth { + if s.SmartCardAuth { caps = caps | HTTP_EXTENDED_AUTH_SC } - if h.TokenAuth { + if s.TokenAuth { caps = caps | HTTP_EXTENDED_AUTH_PAA } @@ -240,22 +240,7 @@ func (h *Handler) handshakeResponse(major byte, minor byte) []byte { return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) } -func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { - // header needs to be 8 min - if len(data) < 8 { - return 0, 0, nil, errors.New("header too short, fragment likely") - } - r := bytes.NewReader(data) - binary.Read(r, binary.LittleEndian, &packetType) - r.Seek(4, io.SeekStart) - binary.Read(r, binary.LittleEndian, &size) - if len(data) < int(size) { - return packetType, size, data[8:], errors.New("data incomplete, fragment received") - } - return packetType, size, data[8:], nil -} - -func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { +func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { r := bytes.NewReader(data) binary.Read(r, binary.LittleEndian, &major) binary.Read(r, binary.LittleEndian, &minor) @@ -266,7 +251,7 @@ func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth return } -func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) { +func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) { var fields uint16 r := bytes.NewReader(data) @@ -285,7 +270,7 @@ func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) { return } -func createTunnelResponse() []byte { +func (s *Server) tunnelResponse() []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint16(0)) // server version @@ -301,7 +286,7 @@ func createTunnelResponse() []byte { return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) } -func (h *Handler) readTunnelAuthRequest(data []byte) string { +func (s *Server) tunnelAuthRequest(data []byte) string { buf := bytes.NewReader(data) var size uint16 @@ -313,7 +298,7 @@ func (h *Handler) readTunnelAuthRequest(data []byte) string { return clientName } -func (h *Handler) createTunnelAuthResponse() []byte { +func (s *Server) tunnelAuthResponse() []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(0)) // error code @@ -321,17 +306,17 @@ func (h *Handler) createTunnelAuthResponse() []byte { binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved // idle timeout - if h.IdleTimeout < 0 { - h.IdleTimeout = 0 + if s.IdleTimeout < 0 { + s.IdleTimeout = 0 } - binary.Write(buf, binary.LittleEndian, uint32(h.RedirectFlags)) // redir flags - binary.Write(buf, binary.LittleEndian, uint32(h.IdleTimeout)) // timeout in minutes + binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags + binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) } -func readChannelCreateRequest(data []byte) (server string, port uint16) { +func (s *Server) channelRequest(data []byte) (server string, port uint16) { buf := bytes.NewReader(data) var resourcesSize byte @@ -353,7 +338,7 @@ func readChannelCreateRequest(data []byte) (server string, port uint16) { return } -func createChannelCreateResponse() []byte { +func (s *Server) channelResponse() []byte { buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, uint32(0)) // error code @@ -372,7 +357,7 @@ func createChannelCreateResponse() []byte { return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) } -func (h *Handler) forwardDataPacket(data []byte) { +func (s *Server) forwardDataPacket(data []byte) { buf := bytes.NewReader(data) var cblen uint16 @@ -380,38 +365,26 @@ func (h *Handler) forwardDataPacket(data []byte) { pkt := make([]byte, cblen) binary.Read(buf, binary.LittleEndian, &pkt) - h.Remote.Write(pkt) + s.Remote.Write(pkt) } -func (h *Handler) sendDataPacket() { - defer h.Remote.Close() +func (s *Server) sendDataPacket() { + defer s.Remote.Close() b1 := new(bytes.Buffer) buf := make([]byte, 4086) for { - n, err := h.Remote.Read(buf) + n, err := s.Remote.Read(buf) binary.Write(b1, binary.LittleEndian, uint16(n)) if err != nil { log.Printf("Error reading from conn %s", err) break } b1.Write(buf[:n]) - h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } -func createPacket(pktType uint16, data []byte) (packet []byte) { - size := len(data) + 8 - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint16(pktType)) - binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved - binary.Write(buf, binary.LittleEndian, uint32(size)) - buf.Write(data) - - return buf.Bytes() -} - func makeRedirectFlags(flags RedirectFlags) int { var redir = 0