diff --git a/config/configuration.go b/config/configuration.go index 0f921b0e44ab56b6c8af29758c7225acb4c41cdf..59812f05ec2c03d816d5bff67a8bd9798307e28d 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -48,6 +48,7 @@ type SecurityConfig struct { PAATokenSigningKey string UserTokenEncryptionKey string UserTokenSigningKey string + VerifyClientIp bool } type ClientConfig struct { @@ -61,9 +62,9 @@ func init() { viper.SetDefault("server.certFile", "server.pem") viper.SetDefault("server.keyFile", "key.pem") viper.SetDefault("server.port", 443) - viper.SetDefault("security.enableOpenId", true) viper.SetDefault("client.networkAutoDetect", 1) viper.SetDefault("client.bandwidthAutoDetect", 1) + viper.SetDefault("security.verifyClientIp", true) } func Load(configFile string) Configuration { diff --git a/main.go b/main.go index 759aea634f49d4723560018fb8f5a7aa8a747229..8c415a7b38e4938d6da4b596152866a904cd8085 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,8 @@ func main() { cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)") conf = config.Load(configFile) + security.VerifyClientIP = conf.Security.VerifyClientIp + // set security keys security.SigningKey = []byte(conf.Security.PAATokenSigningKey) security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) diff --git a/protocol/common.go b/protocol/common.go index 37053642b442557df1477689d3beb6460b300748..0e3f7f7ce7e56d3cc68610d563265a64df07b55d 100644 --- a/protocol/common.go +++ b/protocol/common.go @@ -21,13 +21,22 @@ type RedirectFlags struct { } type SessionInfo struct { + // The connection-id (RDG-ConnID) as reported by the client ConnId string + // The underlying incoming transport being either websocket or legacy http + // in case of websocket TransportOut will equal TransportIn TransportIn transport.Transport + // The underlying outgoing transport being either websocket or legacy http + // in case of websocket TransportOut will equal TransportOut TransportOut transport.Transport + // The remote desktop server (rdp, vnc etc) the clients intends to connect to RemoteServer string + // The obtained client ip address ClientIp string } +// readMessage parses and defragments a packet from a Transport. It returns +// at most the bytes that have been reported by the packet func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { fragment := false index := 0 @@ -66,6 +75,7 @@ func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) } } +// createPacket wraps the data into the protocol packet func createPacket(pktType uint16, data []byte) (packet []byte) { size := len(data) + 8 buf := new(bytes.Buffer) @@ -78,6 +88,7 @@ func createPacket(pktType uint16, data []byte) (packet []byte) { return buf.Bytes() } +// readHeader parses a packet and verifies its reported size func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { // header needs to be 8 min if len(data) < 8 { @@ -90,10 +101,10 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err if len(data) < int(size) { return packetType, size, data[8:], errors.New("data incomplete, fragment received") } - return packetType, size, data[8:], nil + return packetType, size, data[8:size-8], nil } -// sends data wrapped inside the rdpgw protocol +// forwards data from a Connection to Transport and wraps it in the rdpgw protocol func forward(in net.Conn, out transport.Transport) { defer in.Close() @@ -113,7 +124,7 @@ func forward(in net.Conn, out transport.Transport) { } } -// receive data from the wire, unwrap and forward to the client +// receive data received from the gateway client, unwrap and forward the remote desktop server func receive(data []byte, out net.Conn) { buf := bytes.NewReader(data) diff --git a/security/jwt.go b/security/jwt.go index a80654a338a1e7c00a4ab29114fb3cc145d250fd..6cae11498d05d905b5e2db894ff60ac8e23f630d 100644 --- a/security/jwt.go +++ b/security/jwt.go @@ -24,6 +24,7 @@ var ( ) var ExpiryTime time.Duration = 5 +var VerifyClientIP bool = true type customClaims struct { RemoteServer string `json:"remoteServer"` @@ -89,11 +90,11 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) { return false, nil } - /*if s.ClientIp != common.GetClientIp(ctx) { + if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) { log.Printf("Current client ip address %s does not match token client ip %s", common.GetClientIp(ctx), s.ClientIp) return false, nil - }*/ + } return true, nil }