From 0566f9048861392273f853b20c5c7114b6ffc160 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin <bolke@xs4all.nl> Date: Sat, 24 Sep 2022 16:47:03 +0200 Subject: [PATCH] Make sure to use right keys --- cmd/rdpgw/protocol/client.go | 12 ++++++------ cmd/rdpgw/protocol/gateway.go | 22 +++++++++++----------- cmd/rdpgw/protocol/process.go | 6 +++--- cmd/rdpgw/protocol/tunnel.go | 12 ++++++------ cmd/rdpgw/security/basic_test.go | 2 -- cmd/rdpgw/security/jwt.go | 2 +- 6 files changed, 27 insertions(+), 29 deletions(-) diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go index f28038c..c1c27ac 100644 --- a/cmd/rdpgw/protocol/client.go +++ b/cmd/rdpgw/protocol/client.go @@ -27,10 +27,10 @@ type ClientConfig struct { } func (c *ClientConfig) ConnectAndForward() error { - c.Session.TransportOut.WritePacket(c.handshakeRequest()) + c.Session.transportOut.WritePacket(c.handshakeRequest()) for { - pt, sz, pkt, err := readMessage(c.Session.TransportIn) + pt, sz, pkt, err := readMessage(c.Session.transportIn) if err != nil { log.Printf("Cannot read message from stream %s", err) return err @@ -44,7 +44,7 @@ func (c *ClientConfig) ConnectAndForward() error { return err } log.Printf("Handshake response received. Caps: %d", caps) - c.Session.TransportOut.WritePacket(c.tunnelRequest()) + c.Session.transportOut.WritePacket(c.tunnelRequest()) case PKT_TYPE_TUNNEL_RESPONSE: tid, caps, err := c.tunnelResponse(pkt) if err != nil { @@ -52,7 +52,7 @@ func (c *ClientConfig) ConnectAndForward() error { return err } log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) - c.Session.TransportOut.WritePacket(c.tunnelAuthRequest()) + c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) case PKT_TYPE_TUNNEL_AUTH_RESPONSE: flags, timeout, err := c.tunnelAuthResponse(pkt) if err != nil { @@ -60,7 +60,7 @@ func (c *ClientConfig) ConnectAndForward() error { return err } log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) - c.Session.TransportOut.WritePacket(c.channelRequest()) + c.Session.transportOut.WritePacket(c.channelRequest()) case PKT_TYPE_CHANNEL_RESPONSE: cid, err := c.channelResponse(pkt) if err != nil { @@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error { log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) } log.Printf("Channel creation succesful. Channel id: %d", cid) - //go forward(c.LocalConn, c.Session.TransportOut) + //go forward(c.LocalConn, c.Session.transportOut) case PKT_TYPE_DATA: receive(pkt, c.LocalConn) default: diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index f484a49..5ce6ec7 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -165,8 +165,8 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn defer inout.Close() t.Id = uuid.New().String() - t.TransportOut = inout - t.TransportIn = inout + t.transportOut = inout + t.transportIn = inout t.ConnectedOn = time.Now() handler := NewProcessor(g, t) @@ -179,17 +179,17 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn // and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different // to ensure the connections do not get cached or terminated by a proxy prematurely. func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { - log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil) + log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) if r.Method == MethodRDGOUT { out, err := transport.NewLegacy(w) if err != nil { - log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err) + log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) return } - log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context())) + log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) - t.TransportOut = out + t.transportOut = out out.SendAccept(true) c.Set(t.RDGId, t, cache.DefaultExpiration) @@ -199,23 +199,23 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t in, err := transport.NewLegacy(w) if err != nil { - log.Printf("cannot hijack connection to support RDG IN data channel: %t", err) + log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) return } defer in.Close() - if t.TransportIn == nil { + if t.transportIn == nil { t.Id = uuid.New().String() - t.TransportIn = in + t.transportIn = in c.Set(t.RDGId, t, cache.DefaultExpiration) - log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context())) + log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) in.SendAccept(false) // read some initial data in.Drain() - log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context())) + log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) handler := NewProcessor(g, t) RegisterTunnel(t, handler) defer RemoveTunnel(t) diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index b12378b..de6262d 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -159,7 +159,7 @@ func (p *Processor) Process(ctx context.Context) error { } // avoid concurrency issues - // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: log.Printf("Close channel") if p.state != SERVER_STATE_OPENED { @@ -168,8 +168,8 @@ func (p *Processor) Process(ctx context.Context) error { } msg := p.channelCloseResponse(ERROR_SUCCESS) p.tunnel.Write(msg) - //p.tunnel.TransportIn.Close() - //p.tunnel.TransportOut.Close() + //p.tunnel.transportIn.Close() + //p.tunnel.transportOut.Close() p.state = SERVER_STATE_CLOSED return nil default: diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index dbf7c64..bb1cb69 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -12,11 +12,11 @@ type Tunnel struct { // The connection-id (RDG-ConnID) as reported by the client RDGId string // The underlying incoming transport being either websocket or legacy http - // in case of websocket TransportOut will equal TransportIn - TransportIn transport.Transport + // in case of websocket transportOut will equal transportIn + transportIn transport.Transport // The underlying outgoing transport being either websocket or legacy http - // in case of websocket TransportOut will equal TransportOut - TransportOut transport.Transport + // in case of websocket transportOut will equal transportOut + transportOut transport.Transport // The remote desktop server (rdp, vnc etc) the clients intends to connect to TargetServer string // The obtained client ip address @@ -43,7 +43,7 @@ type Tunnel struct { // Write puts the packet on the transport and updates the statistics for bytes sent func (t *Tunnel) Write(pkt []byte) { - n, _ := t.TransportOut.WritePacket(pkt) + n, _ := t.transportOut.WritePacket(pkt) t.BytesSent += int64(n) } @@ -51,7 +51,7 @@ func (t *Tunnel) Write(pkt []byte) { // packet, with the header removed, and the packet size. It updates the // statistics for bytes received func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { - pt, size, pkt, err = readMessage(t.TransportIn) + pt, size, pkt, err = readMessage(t.transportIn) t.BytesReceived += int64(size) t.LastSeen = time.Now() diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go index 6ab50b6..d1b6a7c 100644 --- a/cmd/rdpgw/security/basic_test.go +++ b/cmd/rdpgw/security/basic_test.go @@ -10,8 +10,6 @@ import ( var ( info = protocol.Tunnel{ RDGId: "myid", - TransportIn: nil, - TransportOut: nil, TargetServer: "my.remote.server", RemoteAddr: "10.0.0.1", UserName: "Frank", diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index cc8a3f9..c8654ec 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -289,7 +289,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin } func getTunnel(ctx context.Context) *protocol.Tunnel { - s, ok := ctx.Value("Tunnel").(*protocol.Tunnel) + s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel) if !ok { log.Printf("cannot get session info from context") return nil -- GitLab