Newer
Older
package main
import (
	"bufio"
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/http/httputil"
	"strconv"
	"time"
	"unicode/utf16"
	"unicode/utf8"

	"github.com/prometheus/client_golang/prometheus"
)
// rdgConnectionIdKey is the HTTP request header whose value identifies an RDG
// session; the legacy handler uses it as the session-cache key.
const rdgConnectionIdKey = "Rdg-Connection-Id"

// Packet types of the RDG HTTP transport. The type is the first uint16 of the
// 8-byte header that precedes every packet.
const (
	PKT_TYPE_HANDSHAKE_REQUEST      = 0x01
	PKT_TYPE_HANDSHAKE_RESPONSE     = 0x02
	PKT_TYPE_EXTENDED_AUTH_MSG      = 0x03
	PKT_TYPE_TUNNEL_CREATE          = 0x04
	PKT_TYPE_TUNNEL_RESPONSE        = 0x05
	PKT_TYPE_TUNNEL_AUTH            = 0x06
	PKT_TYPE_TUNNEL_AUTH_RESPONSE   = 0x07
	PKT_TYPE_CHANNEL_CREATE         = 0x08
	PKT_TYPE_CHANNEL_RESPONSE       = 0x09
	PKT_TYPE_DATA                   = 0x0A
	PKT_TYPE_SERVICE_MESSAGE        = 0x0B
	PKT_TYPE_REAUTH_MESSAGE         = 0x0C
	PKT_TYPE_KEEPALIVE              = 0x0D
	PKT_TYPE_CLOSE_CHANNEL          = 0x10
	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
)

// Bits of the "fields present" word in a tunnel response.
const (
	HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID   = 0x01
	HTTP_TUNNEL_RESPONSE_FIELD_CAPS        = 0x02
	HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ     = 0x04
	HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
)

// Extended authentication methods advertised in the handshake response.
const (
	HTTP_EXTENDED_AUTH_NONE      = 0x00
	HTTP_EXTENDED_AUTH_SC        = 0x01 // smart card authentication
	HTTP_EXTENDED_AUTH_PAA       = 0x02 // pluggable authentication
	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 // NTLM extended authentication
)

// Bits of the "fields present" word in a tunnel auth response.
const (
	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS  = 0x01
	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
)
// Device-redirection flags sent in the tunnel auth response. The DISABLE_*
// values are individual bits meant to be OR-ed together, so each must be a
// distinct power of two.
const (
	HTTP_TUNNEL_REDIR_ENABLE_ALL        = 0x80000000
	HTTP_TUNNEL_REDIR_DISABLE_ALL       = 0x40000000
	HTTP_TUNNEL_REDIR_DISABLE_DRIVE     = 0x01
	HTTP_TUNNEL_REDIR_DISABLE_PRINTER   = 0x02
	HTTP_TUNNEL_REDIR_DISABLE_PORT      = 0x04 // 0x03 would alias DRIVE|PRINTER
	HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
	HTTP_TUNNEL_REDIR_DISABLE_PNP       = 0x10
)
// Bits of the "fields present" word in a channel response (see
// createChannelCreateResponse, which currently sets none of them).
const (
	HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID   = 0x01
	HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
	HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT     = 0x04
)

// HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE is the "fields present" value in a
// tunnel-create request that signals a PAA cookie follows.
const HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x01
// Prometheus gauges maintained by the gateway, all in the "rdpgw" namespace.
// NOTE(review): their registration is not visible in this section.
var (
	// connectionCache is set to the number of entries in the session cache c.
	connectionCache = prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: "rdpgw",
		Name:      "connection_cache",
		Help:      "The amount of connections in the cache",
	})

	// websocketConnections tracks websocket connections currently being served.
	websocketConnections = prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: "rdpgw",
		Name:      "websocket_connections",
		Help:      "The count of websocket connections",
	})

	// legacyConnections tracks legacy RDG_IN_DATA connections currently being served.
	legacyConnections = prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: "rdpgw",
		Name:      "legacy_connections",
		Help:      "The count of legacy https connections",
	})
)
// HandshakeHeader is implemented by values that can serialize an upgrade
// request or response header into an io.Writer.
type HandshakeHeader interface {
	// WriteTo writes the header to w and reports how many bytes were written.
	WriteTo(w io.Writer) (n int64, err error)
}
// RdgSession is the cached state of one legacy (non-websocket) RDG session.
// Both HTTP requests of a session carry the same Rdg-Connection-Id header:
// ConnOut is the hijacked RDG_OUT_DATA connection (server -> client data) and
// ConnIn the hijacked RDG_IN_DATA connection (client -> server data).
type RdgSession struct {
	ConnId        string
	CorrelationId string
	UserId        string
	ConnIn        net.Conn
	ConnOut       net.Conn

	// StateIn and StateOut are initialised to 0 by handleLegacyProtocol when
	// it creates a session. NOTE(review): their meaning is not visible in
	// this file; confirm the intended type and use.
	StateIn  int
	StateOut int
}
// ErrNotHijacker is returned by Accept when the http.ResponseWriter it is
// given does not implement http.Hijacker, so the raw connection cannot be
// taken over.
var ErrNotHijacker = RejectConnectionError(
	RejectionStatus(http.StatusInternalServerError),
	RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
)

// DefaultSession is a package-level zero-value RdgSession.
// NOTE(review): it is not referenced in the visible code; confirm it is used.
var DefaultSession = RdgSession{}
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
log.Print("Accept connection")
hj, ok := w.(http.Hijacker)
if ok {
return hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return nil, nil, err
}
return
// c caches legacy RdgSession values keyed by the Rdg-Connection-Id header, so
// the RDG_OUT_DATA and RDG_IN_DATA requests of one session can find each other.
// NOTE(review): assuming this is patrickmn/go-cache, whose New takes
// (defaultExpiration, cleanupInterval) — confirm against the import.
var c = cache.New(
	5*time.Minute,
	10*time.Minute,
)
// handleGatewayProtocol is the HTTP handler for RDG gateway requests. It first
// publishes the number of cached legacy sessions to the connection_cache gauge
// and then hands the request to the legacy or websocket protocol handler.
//
// NOTE(review): this body appears garbled/truncated as shown and does not
// compile. The `} else if r.Method == MethodRDGIN {` below has no opening `if`.
// The `if err != nil {` branch is never closed, so `defer conn.Close()` sits
// after its `return`. The function's own closing brace is missing.
// handleLegacyProtocol is also called unconditionally before the websocket
// upgrade is attempted. Restore the missing lines from version control rather
// than guessing at them.
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount()))
//if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r)
// Rewrite the method before upgrading. NOTE(review): presumably because the
// websocket upgrader only accepts GET requests — confirm.
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return
defer conn.Close()
handleWebsocketProtocol(conn)
} else if r.Method == MethodRDGIN {
handleLegacyProtocol(w, r)
}
func handleWebsocketProtocol(conn *websocket.Conn) {
fragment := false
buf := make([]byte, 4096)
index := 0
var remote net.Conn
websocketConnections.Inc()
defer websocketConnections.Dec()
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
for {
mt, msg, err := conn.ReadMessage()
if err != nil {
log.Printf("Error read: %s", err)
break
}
log.Printf("Message type: %d, message: %x", mt, msg)
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg)
if err != nil {
// fragment received
log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(pkt)
msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg)
conn.WriteMessage(mt, msg)
go handleWebsocketData(remote, mt, conn)
case PKT_TYPE_DATA:
forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE:
// do not write to make sure we do not create concurrency issues
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
// to ensure the connections do not get cached or terminated by a proxy prematurely.
//
// handleLegacyProtocol serves one of the two HTTP requests of a legacy session.
// Both requests carry the same Rdg-Connection-Id header and share an
// RdgSession stored in the cache c under that id:
//   - RDG_OUT_DATA: the connection is hijacked, stored as the session's
//     ConnOut and answered with an accept seed.
//   - RDG_IN_DATA: the connection is hijacked and read as a chunked stream of
//     RDG packets. Every response, and all data coming back from the RDP
//     server, is written to the session's ConnOut.
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
	var s RdgSession

	connId := r.Header.Get(rdgConnectionIdKey)
	x, found := c.Get(connId)
	if !found {
		log.Printf("No cached session found")
		s = RdgSession{ConnId: connId}
	} else {
		log.Printf("Found cached session")
		s = x.(RdgSession)
	}

	log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil)

	if r.Method == MethodRDGOUT {
		conn, rw, err := Accept(w)
		if err != nil {
			log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
			return
		}
		log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())

		s.ConnOut = conn
		WriteAcceptSeed(rw.Writer, true)

		c.Set(connId, s, cache.DefaultExpiration)
	} else if r.Method == MethodRDGIN {
		legacyConnections.Inc()
		defer legacyConnections.Dec()

		var remote net.Conn
		// Close the RDP server connection when this request ends; this also
		// stops the sendDataPacket goroutine reading from it.
		defer func() {
			if remote != nil {
				remote.Close()
			}
		}()

		conn, rw, err := Accept(w)
		if err != nil {
			log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
			return
		}
		defer conn.Close()

		if s.ConnIn == nil {
			// Every response below is written to ConnOut, so the RDG_OUT_DATA
			// request must already be part of this session; writing to a nil
			// net.Conn would panic.
			if s.ConnOut == nil {
				log.Printf("No RDGOUT connection for session %s, closing RDGIN", connId)
				return
			}

			fragment := false
			index := 0
			buf := make([]byte, 4096)

			s.ConnIn = conn
			c.Set(connId, s, cache.DefaultExpiration)
			log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
			WriteAcceptSeed(rw.Writer, false)

			// Read and discard up to 32767 bytes before parsing the chunked
			// packet stream. NOTE(review): presumably the client's initial
			// request body; a single Read may return fewer bytes — confirm.
			p := make([]byte, 32767)
			rw.Reader.Read(p)

			log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
			chunkScanner := httputil.NewChunkedReader(rw.Reader)
			msg := make([]byte, 4096) // bufio.defaultBufSize

		readLoop:
			for {
				n, err := chunkScanner.Read(msg)
				if err == io.EOF || n == 0 {
					break
				}

				// check for fragments
				var pt uint16
				var sz uint32
				var pkt []byte

				if !fragment {
					pt, sz, pkt, err = readHeader(msg[:n])
					if err != nil {
						// fragment received
						log.Printf("Received non websocket fragment")
						fragment = true
						index = copy(buf, msg[:n])
						continue
					}
					index = 0
				} else {
					log.Printf("Dealing with fragment")
					fragment = false
					// NOTE(review): only one continuation chunk is merged; if
					// the combined data is still incomplete it is processed as-is.
					pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
				}

				log.Printf("Scanned packet got packet type %x size %d", pt, sz)
				switch pt {
				case PKT_TYPE_HANDSHAKE_REQUEST:
					major, minor, _, auth := readHandshake(pkt)
					s.ConnOut.Write(handshakeResponse(major, minor, auth))
				case PKT_TYPE_TUNNEL_CREATE:
					readCreateTunnelRequest(pkt)
					s.ConnOut.Write(createTunnelResponse())
				case PKT_TYPE_TUNNEL_AUTH:
					readTunnelAuthRequest(pkt)
					s.ConnOut.Write(createTunnelAuthResponse())
				case PKT_TYPE_CHANNEL_CREATE:
					server, port := readChannelCreateRequest(pkt)
					remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
					if err != nil {
						log.Printf("Error connecting to %s, %d, %s", server, port, err)
						return
					}
					s.ConnOut.Write(createChannelCreateResponse())

					// Make sure to start the flow from the RDP server first otherwise connections
					// might hang eventually
					go sendDataPacket(remote, s.ConnOut)
				case PKT_TYPE_DATA:
					forwardDataPacket(remote, pkt)
				case PKT_TYPE_KEEPALIVE:
					// avoid concurrency issues
					// s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
				case PKT_TYPE_CLOSE_CHANNEL:
					s.ConnIn.Close()
					s.ConnOut.Close()
					// A bare break would only leave the switch; stop reading.
					break readLoop
				default:
					// pkt already excludes the 8-byte header; slicing it by the
					// raw read length n read past the packet (or panicked).
					log.Printf("Unknown packet (size %d): %x", sz, pkt)
				}
			}
		}
	}
}
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
//
// WriteAcceptSeed writes the accept response on a hijacked connection's
// buffered writer. handleLegacyProtocol calls it with doSeed=true for
// RDG_OUT_DATA and doSeed=false for RDG_IN_DATA.
//
// NOTE(review): as visible, this block differs from the spec text above and
// looks truncated:
//   - no status line is written here;
//   - "Content-Length: 0" is written even when a seed body follows;
//   - the seed is 10 bytes rather than 100;
//   - the `if doSeed {` block is never closed;
//   - bw is never flushed.
// Restore the missing lines from version control.
func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) {
log.Printf("Writing accept")
bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf)
bw.WriteString("Content-Length: 0" + crlf)
if doSeed {
seed := make([]byte, 10)
rand.Read(seed)
// The spec calls this a random seed, but Windows Server 2019 reportedly
// answers with the bytes ab cd repeated 5 times.
bw.Write(seed)
// readHeader parses the 8-byte RDG packet header at the start of data.
//
// Header layout (little endian): packet type (uint16), reserved (uint16),
// total packet size including the header (uint32). It returns the packet
// type, the declared size and data[8:] as the packet body.
//
// An error is returned when data is shorter than the header, or shorter than
// the declared size; callers treat both as a fragment and wait for more data.
// In the second case packetType, size and packet are still filled in.
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
	// header needs to be 8 min
	if len(data) < 8 {
		return 0, 0, nil, errors.New("header too short, fragment likely")
	}
	packetType = binary.LittleEndian.Uint16(data[0:2])
	// data[2:4] is reserved
	size = binary.LittleEndian.Uint32(data[4:8])

	// Compare as uint64: int(size) can wrap negative on 32-bit platforms and
	// would make an incomplete packet look complete.
	if uint64(len(data)) < uint64(size) {
		return packetType, size, data[8:], errors.New("data incomplete, fragment received")
	}
	return packetType, size, data[8:], nil
}
// handshakeResponse builds the PKT_TYPE_HANDSHAKE_RESPONSE packet that answers
// a handshake request.
//
// The body is: error code 0 (uint32), the client's major and minor version
// echoed back, server version 0 (uint16) and the supported extended
// authentication methods (uint16), always PAA | smart card. The auth argument
// is currently not consulted.
//
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not offered: it is not supported on Linux,
// and although it could be on Windows, the NTLM protocol is insecure.
func handshakeResponse(major byte, minor byte, auth uint16) []byte {
	var body bytes.Buffer
	fields := []interface{}{
		uint32(0),    // error code
		major, minor, // echo the requested protocol version
		uint16(0), // server version
		uint16(HTTP_EXTENDED_AUTH_PAA | HTTP_EXTENDED_AUTH_SC), // extended auth
	}
	for _, field := range fields {
		binary.Write(&body, binary.LittleEndian, field)
	}
	return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, body.Bytes())
}
// readHandshake decodes the body of a PKT_TYPE_HANDSHAKE_REQUEST packet:
// major version (1 byte), minor version (1 byte), version (uint16 LE) and
// extended-auth flags (uint16 LE), in that order. Fields that do not fit in
// data are left as zero. The decoded values are logged.
func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
	in := bytes.NewReader(data)
	for _, dst := range []interface{}{&major, &minor, &version, &extAuth} {
		binary.Read(in, binary.LittleEndian, dst)
	}
	log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
	return major, minor, version, extAuth
}
func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &caps)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
var size uint16
binary.Read(r, binary.LittleEndian, &size)
cookieB := make([]byte, size)
r.Read(cookieB)
cookie, _ = DecodeUTF16(cookieB)
log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// tunnel id ?
binary.Write(buf, binary.LittleEndian, uint32(15))
// caps ?
binary.Write(buf, binary.LittleEndian, uint32(2))
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
// readTunnelAuthRequest parses the body of a PKT_TYPE_TUNNEL_AUTH packet
// (HTTP_TUNNEL_AUTH_PACKET in [MS-TSGU]) and logs the client name it carries.
//
// Layout (little endian): FieldsPresent (uint16), cbClientName (uint16), then
// cbClientName bytes of UTF-16LE client name. Optional fields that follow the
// client name are not parsed.
func readTunnelAuthRequest(data []byte) {
	buf := bytes.NewReader(data)
	var fields uint16
	var size uint16
	// FieldsPresent comes first; the name length follows it.
	binary.Read(buf, binary.LittleEndian, &fields)
	binary.Read(buf, binary.LittleEndian, &size)
	clData := make([]byte, size)
	if _, err := io.ReadFull(buf, clData); err != nil {
		log.Printf("Tunnel auth request truncated: %s", err)
		return
	}
	clientName, _ := DecodeUTF16(clData)
	log.Printf("Client: %s", clientName)
}
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// flags
binary.Write(buf, binary.LittleEndian, uint32(HTTP_TUNNEL_REDIR_ENABLE_ALL)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(0)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
// readChannelCreateRequest decodes the body of a PKT_TYPE_CHANNEL_CREATE
// packet and returns the target server name and port.
//
// Layout (little endian): resource count (1 byte), alternative resource count
// (1 byte), port (uint16), protocol (uint16), name length in bytes (uint16)
// and the UTF-16LE server name. Both the raw name bytes and the decoded target
// are logged.
func readChannelCreateRequest(data []byte) (server string, port uint16) {
	var (
		numResources    byte
		numAltResources byte
		protocol        uint16
		nameLen         uint16
	)
	in := bytes.NewReader(data)
	for _, dst := range []interface{}{&numResources, &numAltResources, &port, &protocol, &nameLen} {
		binary.Read(in, binary.LittleEndian, dst)
	}

	name := make([]byte, nameLen)
	binary.Read(in, binary.LittleEndian, &name)
	log.Printf("Name data %q", name)

	server, _ = DecodeUTF16(name)
	log.Printf("Should connect to %s on port %d", server, port)
	return server, port
}
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// optional fields
// channel id uint32 (4)
// udp port uint16 (2)
// udp auth cookie 1 byte for side channel
// length uint16
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
// createPacket frames data as an RDG packet. The packet starts with an 8-byte
// little-endian header: packet type (uint16), reserved (uint16, always 0) and
// total length including the header (uint32). The data follows the header.
func createPacket(pktType uint16, data []byte) (packet []byte) {
	total := 8 + len(data)
	out := bytes.NewBuffer(make([]byte, 0, total))
	for _, field := range []interface{}{pktType, uint16(0), uint32(total)} {
		binary.Write(out, binary.LittleEndian, field)
	}
	out.Write(data)
	packet = out.Bytes()
	return packet
}
// forwardDataPacket writes the payload of a PKT_TYPE_DATA packet body to conn.
//
// The body is a uint16 little-endian length (cbLen) followed by that many
// payload bytes; any bytes after the payload are ignored. Nothing is written
// when conn is nil (no channel created yet) or when the body is shorter than
// it claims. Write errors are logged.
func forwardDataPacket(conn net.Conn, data []byte) {
	if conn == nil {
		log.Printf("Dropping data packet: no channel has been created yet")
		return
	}
	buf := bytes.NewReader(data)
	var cblen uint16
	if err := binary.Read(buf, binary.LittleEndian, &cblen); err != nil {
		log.Printf("Dropping data packet without length: %s", err)
		return
	}
	pkt := make([]byte, cblen)
	if _, err := io.ReadFull(buf, pkt); err != nil {
		log.Printf("Dropping truncated data packet (want %d bytes): %s", cblen, err)
		return
	}
	if _, err := conn.Write(pkt); err != nil {
		log.Printf("Error forwarding data packet: %s", err)
	}
}
func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
defer rdp.Close()
binary.Write(b1, binary.LittleEndian, uint16(n))
log.Printf("RDP SIZE: %d", n)
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
// sendDataPacket copies data from connIn (the RDP server connection) to
// connOut (the client's RDG_OUT_DATA connection).
//
// Every read is wrapped in a PKT_TYPE_DATA packet whose body is a uint16
// length followed by the bytes read. It returns, closing connIn, when reading
// from connIn or writing to connOut fails.
func sendDataPacket(connIn net.Conn, connOut net.Conn) {
	defer connIn.Close()
	b1 := new(bytes.Buffer)
	// 4086 payload bytes + 2-byte length + 8-byte packet header = 4096 bytes.
	buf := make([]byte, 4086)
	for {
		n, err := connIn.Read(buf)
		log.Printf("RDP SIZE: %d", n)
		// Forward whatever was read before looking at err (io.Reader contract).
		if n > 0 {
			b1.Reset()
			binary.Write(b1, binary.LittleEndian, uint16(n))
			b1.Write(buf[:n])
			if _, werr := connOut.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())); werr != nil {
				log.Printf("Error writing to client %s", werr)
				return
			}
		}
		if err != nil {
			log.Printf("Error reading from conn %s", err)
			return
		}
	}
}
// DecodeUTF16 decodes b as little-endian UTF-16 and returns the text as a
// UTF-8 string. It returns an error when b has an odd number of bytes.
//
// NOTE(review): each 16-bit code unit is decoded on its own, so a surrogate
// pair (a character outside the Basic Multilingual Plane) comes out as two
// U+FFFD replacement characters instead of the original character. Decoding
// the whole []uint16 with a single utf16.Decode call would avoid this.
func DecodeUTF16(b []byte) (string, error) {
if len(b)%2 != 0 {
return "", fmt.Errorf("must have even length byte slice")
}
u16s := make([]uint16, 1)
ret := &bytes.Buffer{}
// Scratch space for one UTF-8 encoded rune (at most 4 bytes).
b8buf := make([]byte, 4)
lb := len(b)
for i := 0; i < lb; i += 2 {
// Low byte first: the input is little-endian.
u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
r := utf16.Decode(u16s)
n := utf8.EncodeRune(b8buf, r[0])
ret.Write(b8buf[:n])
}
return ret.String(), nil