diff --git a/cmd/rdpgw/common/remote.go b/cmd/rdpgw/common/remote.go index d6d97ee6b0132a71c7024a054914698a6f5c20c1..9e919425263883bdc680420347ae76bd4912c6fa 100644 --- a/cmd/rdpgw/common/remote.go +++ b/cmd/rdpgw/common/remote.go @@ -12,6 +12,8 @@ const ( ClientIPCtx = "ClientIP" ProxyAddressesCtx = "ProxyAddresses" RemoteAddressCtx = "RemoteAddress" + TunnelCtx = "TUNNEL" + UsernameCtx = "preferred_username" ) func EnrichContext(next http.Handler) http.Handler { @@ -57,4 +59,4 @@ func GetAccessToken(ctx context.Context) string { return "" } return token -} \ No newline at end of file +} diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 2b3bea0ea6b6d286c828118afa72d9a737bdf82e..26caeed986ec11f662cc760d637cf41a4aff9afa 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -231,6 +231,8 @@ var gwserver *protocol.Gateway func List(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") for k, v := range protocol.Connections { - fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) + fmt.Fprintf(w, "Id: %s Rdg-Id: %s User: %s From: %s Connected Since: %s Bytes Sent: %d Bytes Received: %d Last Seen: %s Target: %s\n", + k, v.Tunnel.RDGId, v.Tunnel.UserName, v.Tunnel.RemoteAddr, v.Tunnel.ConnectedOn, v.Tunnel.BytesSent, v.Tunnel.BytesReceived, + v.Tunnel.LastSeen, v.Tunnel.TargetServer) } } diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go index c7e5459e4de27c291bcc2d17751c8d7a5f98a148..f28038cbaa852b9d00fb952e5d80ca383e57e552 100644 --- a/cmd/rdpgw/protocol/client.go +++ b/cmd/rdpgw/protocol/client.go @@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error { log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) } log.Printf("Channel creation succesful. Channel id: %d", cid) - go forward(c.LocalConn, c.Session.TransportOut) + //go forward(c.LocalConn, c.Session.TransportOut) case PKT_TYPE_DATA: receive(pkt, c.LocalConn) default: diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index 3f75a3363bfbe41058218c083780cd3903efaacb..a8875b2d13cf162a64f15fbc64f430b04f6ec8a4 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -92,7 +92,7 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err } // forwards data from a Connection to Transport and wraps it in the rdpgw protocol -func forward(in net.Conn, out transport.Transport) { +func forward(in net.Conn, tunnel *Tunnel) { defer in.Close() b1 := new(bytes.Buffer) @@ -106,7 +106,7 @@ func forward(in net.Conn, out transport.Transport) { } binary.Write(b1, binary.LittleEndian, uint16(n)) b1.Write(buf[:n]) - out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) + tunnel.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())) b1.Reset() } } diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 49fbe4cba563e508df5d62cdf6c61ae7c9999389..f484a492488ef3f524b7c7ac85db780b3e4c51e5 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" "log" @@ -16,7 +17,7 @@ import ( ) const ( - rdgConnectionIdKey = "Rdg-Connection-RDGId" + rdgConnectionIdKey = "Rdg-Connection-Id" MethodRDGIN = "RDG_IN_DATA" MethodRDGOUT = "RDG_OUT_DATA" ) @@ -59,14 +60,19 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) var t *Tunnel + ctx := context.WithValue(r.Context(), common.TunnelCtx, t) + connId := r.Header.Get(rdgConnectionIdKey) x, found := c.Get(connId) if !found { - t = &Tunnel{RDGId: connId} + t = &Tunnel{ + RDGId: connId, + RemoteAddr: ctx.Value(common.ClientIPCtx).(string), + UserName: ctx.Value(common.UsernameCtx).(string), + } } else { t = x.(*Tunnel) } - ctx := context.WithValue(r.Context(), "Tunnel", t) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { @@ -158,13 +164,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn inout, _ := transport.NewWS(c) defer inout.Close() + t.Id = uuid.New().String() t.TransportOut = inout t.TransportIn = inout t.ConnectedOn = time.Now() handler := NewProcessor(g, t) - RegisterConnection(handler, t) - defer RemoveConnection(t.RDGId) + RegisterTunnel(t, handler) + defer RemoveTunnel(t) handler.Process(ctx) } @@ -198,6 +205,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t defer in.Close() if t.TransportIn == nil { + t.Id = uuid.New().String() t.TransportIn = in c.Set(t.RDGId, t, cache.DefaultExpiration) @@ -209,8 +217,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context())) handler := NewProcessor(g, t) - RegisterConnection(handler, t) - defer RemoveConnection(t.RDGId) + RegisterTunnel(t, handler) + defer RemoveTunnel(t) handler.Process(r.Context()) } } diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index 7c66fa01e610d356d4ce9b1451f3885e129eee06..b12378b911fb181a5b8251e8ec0d96d59d022919 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -47,7 +47,7 @@ func (p *Processor) Process(ctx context.Context) error { switch pt { case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx)) + log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) if p.state != SERVER_STATE_INITIALIZED { log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) @@ -77,7 +77,7 @@ func (p *Processor) Process(ctx context.Context) error { _, cookie := p.tunnelRequest(pkt) if p.gw.CheckPAACookie != nil { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx)) + log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) p.tunnel.Write(msg) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) @@ -98,7 +98,7 @@ func (p *Processor) Process(ctx context.Context) error { client := p.tunnelAuthRequest(pkt) if p.gw.CheckClientName != nil { if ok, _ := p.gw.CheckClientName(ctx, client); !ok { - log.Printf("Invalid client name: %p", client) + log.Printf("Invalid client name: %s", client) msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) p.tunnel.Write(msg) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) @@ -119,18 +119,18 @@ func (p *Processor) Process(ctx context.Context) error { server, port := p.channelRequest(pkt) host := net.JoinHostPort(server, strconv.Itoa(int(port))) if p.gw.CheckHost != nil { - log.Printf("Verifying %p host connection", host) + log.Printf("Verifying %s host connection", host) if ok, _ := p.gw.CheckHost(ctx, host); !ok { - log.Printf("Not allowed to connect to %p by policy handler", host) + log.Printf("Not allowed to connect to %s by policy handler", host) msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) p.tunnel.Write(msg) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) } } - log.Printf("Establishing connection to RDP server: %p", host) + log.Printf("Establishing connection to RDP server: %s", host) p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) if err != nil { - log.Printf("Error connecting to %p, %p", host, err) + log.Printf("Error connecting to %s, %s", host, err) msg := p.channelResponse(E_PROXY_INTERNALERROR) p.tunnel.Write(msg) return err @@ -142,7 +142,7 @@ func (p *Processor) Process(ctx context.Context) error { // Make sure to start the flow from the RDP server first otherwise connections // might hang eventually - go forward(p.tunnel.rwc, p.tunnel.TransportOut) + go forward(p.tunnel.rwc, p.tunnel) p.state = SERVER_STATE_CHANNEL_CREATE case PKT_TYPE_DATA: if p.state < SERVER_STATE_CHANNEL_CREATE { diff --git a/cmd/rdpgw/protocol/track.go b/cmd/rdpgw/protocol/track.go index f538d9c41bf3cc69142350c909de942886ae2c59..83c4179e3e3913bef7add1b592da19d5d96e6d8c 100644 --- a/cmd/rdpgw/protocol/track.go +++ b/cmd/rdpgw/protocol/track.go @@ -7,19 +7,19 @@ type Monitor struct { Tunnel *Tunnel } -func RegisterConnection(h *Processor, t *Tunnel) { +func RegisterTunnel(t *Tunnel, p *Processor) { if Connections == nil { Connections = make(map[string]*Monitor) } - Connections[t.RDGId] = &Monitor{ - Processor: h, + Connections[t.Id] = &Monitor{ + Processor: p, Tunnel: t, } } -func RemoveConnection(connId string) { - delete(Connections, connId) +func RemoveTunnel(t *Tunnel) { + delete(Connections, t.Id) } // CalculateSpeedPerSecond calculate moving average. diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index ebad094977388d60e1687b15ea0f7f27ae035303..dbf7c64f35bea9f57e12e81e9656877fb5db07e1 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -7,6 +7,8 @@ import ( ) type Tunnel struct { + // Id identifies the connection in the server + Id string // The connection-id (RDG-ConnID) as reported by the client RDGId string // The underlying incoming transport being either websocket or legacy http @@ -26,18 +28,28 @@ type Tunnel struct { // It is of the type *net.TCPConn rwc net.Conn - ByteSent int64 + // BytesSent is the total amount of bytes sent by the server to the client minus tunnel overhead + BytesSent int64 + + // BytesReceived is the total amount of bytes received by the server from the client minus tunnel overhad BytesReceived int64 + // ConnectedOn is when the client connected to the server ConnectedOn time.Time - LastSeen time.Time + + // LastSeen is when the server received the last packet from the client + LastSeen time.Time } +// Write puts the packet on the transport and updates the statistics for bytes sent func (t *Tunnel) Write(pkt []byte) { n, _ := t.TransportOut.WritePacket(pkt) - t.ByteSent += int64(n) + t.BytesSent += int64(n) } +// Read picks up a packet from the transport and returns the packet type +// packet, with the header removed, and the packet size. It updates the +// statistics for bytes received func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { pt, size, pkt, err = readMessage(t.TransportIn) t.BytesReceived += int64(size) diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index 7419fd6e40301bfbc9a3afcd796b8a34783f555b..ca7d2c009a895031d187a1b92bde6484f3232083 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "log" "strings" ) @@ -23,10 +24,10 @@ func CheckHost(ctx context.Context, host string) (bool, error) { case "roundrobin", "unsigned": var username string - s := getSessionInfo(ctx) + s := getTunnel(ctx) if s == nil || s.UserName == "" { var ok bool - username, ok = ctx.Value("preferred_username").(string) + username, ok = ctx.Value(common.UsernameCtx).(string) if !ok { return false, errors.New("no valid session info or username found in context") } diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go index 40551ceb051fb5eb84db72f77647f11f524e7bcd..6ab50b60fec38291178e56b562df227593be057f 100644 --- a/cmd/rdpgw/security/basic_test.go +++ b/cmd/rdpgw/security/basic_test.go @@ -2,6 +2,7 @@ package security import ( "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "testing" ) @@ -20,7 +21,7 @@ var ( ) func TestCheckHost(t *testing.T) { - ctx := context.WithValue(context.Background(), "Tunnel", &info) + ctx := context.WithValue(context.Background(), common.TunnelCtx, &info) Hosts = hosts diff --git a/cmd/rdpgw/security/jwt.go b/cmd/rdpgw/security/jwt.go index c5d0f219870740b17f1de926e054827a92b0a2ce..cc8a3f936c836d8a99c225cb028845f727635d03 100644 --- a/cmd/rdpgw/security/jwt.go +++ b/cmd/rdpgw/security/jwt.go @@ -35,7 +35,7 @@ type customClaims struct { func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { return func(ctx context.Context, host string) (bool, error) { - s := getSessionInfo(ctx) + s := getTunnel(ctx) if s == nil { return false, errors.New("no valid session info found in context") } @@ -62,7 +62,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { token, err := jwt.ParseSigned(tokenString) if err != nil { - log.Printf("cannot parse token due to: %s", err) + log.Printf("cannot parse token due to: %tunnel", err) return false, err } @@ -79,7 +79,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { // Claims automagically checks the signature... err = token.Claims(SigningKey, &standard, &custom) if err != nil { - log.Printf("token signature validation failed due to %s", err) + log.Printf("token signature validation failed due to %tunnel", err) return false, err } @@ -90,7 +90,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { }) if err != nil { - log.Printf("token validation failed due to %s", err) + log.Printf("token validation failed due to %tunnel", err) return false, err } @@ -98,15 +98,15 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) { tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) user, err := OIDCProvider.UserInfo(ctx, tokenSource) if err != nil { - log.Printf("Cannot get user info for access token: %s", err) + log.Printf("Cannot get user info for access token: %tunnel", err) return false, err } - s := getSessionInfo(ctx) + tunnel := getTunnel(ctx) - s.TargetServer = custom.RemoteServer - s.RemoteAddr = custom.ClientIP - s.UserName = user.Subject + tunnel.TargetServer = custom.RemoteServer + tunnel.RemoteAddr = custom.ClientIP + tunnel.UserName = user.Subject return true, nil } @@ -288,7 +288,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin return token, err } -func getSessionInfo(ctx context.Context) *protocol.Tunnel { +func getTunnel(ctx context.Context) *protocol.Tunnel { s, ok := ctx.Value("Tunnel").(*protocol.Tunnel) if !ok { log.Printf("cannot get session info from context") diff --git a/cmd/rdpgw/web/basic.go b/cmd/rdpgw/web/basic.go index 946036a6e693835ab2009e262155e8fdc758cb7b..5c1443eb5e1322e81d263b737002664ad9421af0 100644 --- a/cmd/rdpgw/web/basic.go +++ b/cmd/rdpgw/web/basic.go @@ -2,6 +2,7 @@ package web import ( "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/shared/auth" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -12,7 +13,7 @@ import ( ) const ( - protocol = "unix" + protocolGrpc = "unix" ) type BasicAuthHandler struct { @@ -27,7 +28,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return net.Dial(protocol, addr) + return net.Dial(protocolGrpc, addr) })) if err != nil { log.Printf("Cannot reach authentication provider: %s", err) @@ -51,7 +52,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc { if !res.Authenticated { log.Printf("User %s is not authenticated for this service", username) } else { - ctx := context.WithValue(r.Context(), "preferred_username", username) + ctx := context.WithValue(r.Context(), common.UsernameCtx, username) next.ServeHTTP(w, r.WithContext(ctx)) return } diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index a06a1394367f73a1b44bb913ff64c1329adb2916..93b9945fde9baff84d30f39f1ea460d68a3d47b7 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "encoding/json" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/patrickmn/go-cache" @@ -119,7 +120,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler { return } - ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) + ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"]) ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/go.mod b/go.mod index bdf03cc9646a610a266fca23f03f9550ed312ea2..5998bd2a6e83186a72a05441627b9019e37b1162 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/coreos/go-oidc/v3 v3.2.0 github.com/fatih/structs v1.1.0 github.com/go-jose/go-jose/v3 v3.0.0 + github.com/google/uuid v1.1.2 github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.5.0 github.com/knadh/koanf v1.4.2