From 0566f9048861392273f853b20c5c7114b6ffc160 Mon Sep 17 00:00:00 2001
From: Bolke de Bruin <bolke@xs4all.nl>
Date: Sat, 24 Sep 2022 16:47:03 +0200
Subject: [PATCH] Make sure to use right keys

---
 cmd/rdpgw/protocol/client.go     | 12 ++++++------
 cmd/rdpgw/protocol/gateway.go    | 22 +++++++++++-----------
 cmd/rdpgw/protocol/process.go    |  6 +++---
 cmd/rdpgw/protocol/tunnel.go     | 12 ++++++------
 cmd/rdpgw/security/basic_test.go |  2 --
 cmd/rdpgw/security/jwt.go        |  2 +-
 6 files changed, 27 insertions(+), 29 deletions(-)

diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go
index f28038c..c1c27ac 100644
--- a/cmd/rdpgw/protocol/client.go
+++ b/cmd/rdpgw/protocol/client.go
@@ -27,10 +27,10 @@ type ClientConfig struct {
 }
 
 func (c *ClientConfig) ConnectAndForward() error {
-	c.Session.TransportOut.WritePacket(c.handshakeRequest())
+	c.Session.transportOut.WritePacket(c.handshakeRequest())
 
 	for {
-		pt, sz, pkt, err := readMessage(c.Session.TransportIn)
+		pt, sz, pkt, err := readMessage(c.Session.transportIn)
 		if err != nil {
 			log.Printf("Cannot read message from stream %s", err)
 			return err
@@ -44,7 +44,7 @@ func (c *ClientConfig) ConnectAndForward() error {
 				return err
 			}
 			log.Printf("Handshake response received. Caps: %d", caps)
-			c.Session.TransportOut.WritePacket(c.tunnelRequest())
+			c.Session.transportOut.WritePacket(c.tunnelRequest())
 		case PKT_TYPE_TUNNEL_RESPONSE:
 			tid, caps, err := c.tunnelResponse(pkt)
 			if err != nil {
@@ -52,7 +52,7 @@ func (c *ClientConfig) ConnectAndForward() error {
 				return err
 			}
 			log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
-			c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
+			c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
 		case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
 			flags, timeout, err := c.tunnelAuthResponse(pkt)
 			if err != nil {
@@ -60,7 +60,7 @@ func (c *ClientConfig) ConnectAndForward() error {
 				return err
 			}
 			log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
-			c.Session.TransportOut.WritePacket(c.channelRequest())
+			c.Session.transportOut.WritePacket(c.channelRequest())
 		case PKT_TYPE_CHANNEL_RESPONSE:
 			cid, err := c.channelResponse(pkt)
 			if err != nil {
@@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
 				log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
 			}
 			log.Printf("Channel creation succesful. Channel id: %d", cid)
-			//go forward(c.LocalConn, c.Session.TransportOut)
+			//go forward(c.LocalConn, c.Session.transportOut)
 		case PKT_TYPE_DATA:
 			receive(pkt, c.LocalConn)
 		default:
diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go
index f484a49..5ce6ec7 100644
--- a/cmd/rdpgw/protocol/gateway.go
+++ b/cmd/rdpgw/protocol/gateway.go
@@ -165,8 +165,8 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
 	defer inout.Close()
 
 	t.Id = uuid.New().String()
-	t.TransportOut = inout
-	t.TransportIn = inout
+	t.transportOut = inout
+	t.transportIn = inout
 	t.ConnectedOn = time.Now()
 
 	handler := NewProcessor(g, t)
@@ -179,17 +179,17 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
 // and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
 // to ensure the connections do not get cached or terminated by a proxy prematurely.
 func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
-	log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil)
+	log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
 
 	if r.Method == MethodRDGOUT {
 		out, err := transport.NewLegacy(w)
 		if err != nil {
-			log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err)
+			log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
 			return
 		}
-		log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context()))
+		log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
 
-		t.TransportOut = out
+		t.transportOut = out
 		out.SendAccept(true)
 
 		c.Set(t.RDGId, t, cache.DefaultExpiration)
@@ -199,23 +199,23 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
 
 		in, err := transport.NewLegacy(w)
 		if err != nil {
-			log.Printf("cannot hijack connection to support RDG IN data channel: %t", err)
+			log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
 			return
 		}
 		defer in.Close()
 
-		if t.TransportIn == nil {
+		if t.transportIn == nil {
 			t.Id = uuid.New().String()
-			t.TransportIn = in
+			t.transportIn = in
 			c.Set(t.RDGId, t, cache.DefaultExpiration)
 
-			log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context()))
+			log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
 			in.SendAccept(false)
 
 			// read some initial data
 			in.Drain()
 
-			log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
+			log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
 			handler := NewProcessor(g, t)
 			RegisterTunnel(t, handler)
 			defer RemoveTunnel(t)
diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go
index b12378b..de6262d 100644
--- a/cmd/rdpgw/protocol/process.go
+++ b/cmd/rdpgw/protocol/process.go
@@ -159,7 +159,7 @@ func (p *Processor) Process(ctx context.Context) error {
 			}
 
 			// avoid concurrency issues
-			// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
+			// p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
 		case PKT_TYPE_CLOSE_CHANNEL:
 			log.Printf("Close channel")
 			if p.state != SERVER_STATE_OPENED {
@@ -168,8 +168,8 @@ func (p *Processor) Process(ctx context.Context) error {
 			}
 			msg := p.channelCloseResponse(ERROR_SUCCESS)
 			p.tunnel.Write(msg)
-			//p.tunnel.TransportIn.Close()
-			//p.tunnel.TransportOut.Close()
+			//p.tunnel.transportIn.Close()
+			//p.tunnel.transportOut.Close()
 			p.state = SERVER_STATE_CLOSED
 			return nil
 		default:
diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go
index dbf7c64..bb1cb69 100644
--- a/cmd/rdpgw/protocol/tunnel.go
+++ b/cmd/rdpgw/protocol/tunnel.go
@@ -12,11 +12,11 @@ type Tunnel struct {
 	// The connection-id (RDG-ConnID) as reported by the client
 	RDGId string
 	// The underlying incoming transport being either websocket or legacy http
-	// in case of websocket TransportOut will equal TransportIn
-	TransportIn transport.Transport
+	// in case of websocket transportOut will equal transportIn
+	transportIn transport.Transport
 	// The underlying outgoing transport being either websocket or legacy http
-	// in case of websocket TransportOut will equal TransportOut
-	TransportOut transport.Transport
+	// in case of websocket transportOut will equal transportOut
+	transportOut transport.Transport
 	// The remote desktop server (rdp, vnc etc) the clients intends to connect to
 	TargetServer string
 	// The obtained client ip address
@@ -43,7 +43,7 @@ type Tunnel struct {
 
 // Write puts the packet on the transport and updates the statistics for bytes sent
 func (t *Tunnel) Write(pkt []byte) {
-	n, _ := t.TransportOut.WritePacket(pkt)
+	n, _ := t.transportOut.WritePacket(pkt)
 	t.BytesSent += int64(n)
 }
 
@@ -51,7 +51,7 @@ func (t *Tunnel) Write(pkt []byte) {
 // packet, with the header removed, and the packet size. It updates the
 // statistics for bytes received
 func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
-	pt, size, pkt, err = readMessage(t.TransportIn)
+	pt, size, pkt, err = readMessage(t.transportIn)
 	t.BytesReceived += int64(size)
 	t.LastSeen = time.Now()
 
diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go
index 6ab50b6..d1b6a7c 100644
--- a/cmd/rdpgw/security/basic_test.go
+++ b/cmd/rdpgw/security/basic_test.go
@@ -10,8 +10,6 @@ import (
 var (
 	info = protocol.Tunnel{
 		RDGId:        "myid",
-		TransportIn:  nil,
-		TransportOut: nil,
 		TargetServer: "my.remote.server",
 		RemoteAddr:   "10.0.0.1",
 		UserName:     "Frank",
diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go
index cc8a3f9..c8654ec 100644
--- a/cmd/rdpgw/security/jwt.go
+++ b/cmd/rdpgw/security/jwt.go
@@ -289,7 +289,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
 }
 
 func getTunnel(ctx context.Context) *protocol.Tunnel {
-	s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
+	s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
 	if !ok {
 		log.Printf("cannot get session info from context")
 		return nil
-- 
GitLab