diff --git a/rdg.go b/rdg.go index eb816e14bc5ab5068d0e9025fc669a60e7fe33d6..0cd309459cc7e90764680f90935303befd74aae8 100644 --- a/rdg.go +++ b/rdg.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "bytes" "encoding/binary" "errors" @@ -15,7 +14,6 @@ import ( "net" "net/http" "strconv" - "strings" "time" "unicode/utf16" "unicode/utf8" @@ -117,8 +115,8 @@ type RdgSession struct { ConnId string CorrelationId string UserId string - TransportIn transport.HttpLayer - TransportOut transport.HttpLayer + TransportIn transport.Transport + TransportOut transport.Transport StateIn int StateOut int Remote net.Conn @@ -133,21 +131,6 @@ var ErrNotHijacker = RejectConnectionError( var DefaultSession RdgSession -func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { - log.Print("Accept connection") - hj, ok := w.(http.Hijacker) - if ok { - return hj.Hijack() - } else { - err = ErrNotHijacker - } - if err != nil { - httpError(w, err.Error(), http.StatusInternalServerError) - return nil, nil, err - } - return -} - var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) @@ -172,7 +155,7 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { } } -func handleWebsocketProtocol(conn *websocket.Conn) { +func handleWebsocketProtocol(c *websocket.Conn) { fragment := false buf := make([]byte, 4096) index := 0 @@ -182,9 +165,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) { websocketConnections.Inc() defer websocketConnections.Dec() + inout, _ := transport.NewWS(c) + var host string for { - mt, msg, err := conn.ReadMessage() + _, msg, err := inout.ReadPacket() if err != nil { log.Printf("Error read: %s", err) break @@ -216,28 +201,29 @@ func handleWebsocketProtocol(conn *websocket.Conn) { major, minor, _, auth := readHandshake(pkt) msg := handshakeResponse(major, minor, auth) log.Printf("Handshake response: %x", msg) - conn.WriteMessage(mt, msg) + inout.WritePacket(msg) case PKT_TYPE_TUNNEL_CREATE: - _, cookie := readCreateTunnelRequest(pkt) - data, found := tokens.Get(cookie) + readCreateTunnelRequest(pkt) + /*data, found := tokens.Get(cookie) if found == false { - log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) + log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr()) return - } + }*/ host = conf.Server.HostTemplate + /* for k, v := range data.(map[string]interface{}) { if val, ok := v.(string); ok == true { host = strings.Replace(host, "{{ " + k + " }}", val, 1) } - } + }*/ msg := createTunnelResponse() log.Printf("Create tunnel response: %x", msg) - conn.WriteMessage(mt, msg) + inout.WritePacket(msg) case PKT_TYPE_TUNNEL_AUTH: readTunnelAuthRequest(pkt) msg := createTunnelAuthResponse() log.Printf("Create tunnel auth response: %x", msg) - conn.WriteMessage(mt, msg) + inout.WritePacket(msg) case PKT_TYPE_CHANNEL_CREATE: server, port := readChannelCreateRequest(pkt) if conf.Server.EnableOverride == true { @@ -256,13 +242,13 @@ func handleWebsocketProtocol(conn *websocket.Conn) { log.Printf("Connection established") msg := createChannelCreateResponse() log.Printf("Create channel create response: %x", msg) - conn.WriteMessage(mt, msg) - go handleWebsocketData(remote, mt, conn) + inout.WritePacket(msg) + go sendDataPacket(remote, inout) case PKT_TYPE_DATA: forwardDataPacket(remote, pkt) case PKT_TYPE_KEEPALIVE: // do not write to make sure we do not create concurrency issues - // conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + // inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: break default: @@ -612,7 +598,7 @@ func forwardDataPacket(conn net.Conn, data []byte) { conn.Write(pkt) } -func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) { +func handleWebsocketData(rdp net.Conn, conn transport.Transport) { defer rdp.Close() b1 := new(bytes.Buffer) buf := make([]byte, 4086) @@ -625,12 +611,12 @@ func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) { break } b1.Write(buf[:n]) - conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes())) + conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } -func sendDataPacket(connIn net.Conn, connOut transport.HttpLayer) { +func sendDataPacket(connIn net.Conn, connOut transport.Transport) { defer connIn.Close() b1 := new(bytes.Buffer) buf := make([]byte, 4086) diff --git a/transport/transport.go b/transport/transport.go index 92cb89457943e5371a3bec6a00ec2ca8ea2a9e21..fa6e8aa0c26ac966e6177b960f0ee4a5bc8a84dc 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -1,6 +1,6 @@ package transport -type HttpLayer interface { +type Transport interface { ReadPacket() (n int, p []byte, err error) WritePacket(b []byte) (n int, err error) Close() error