diff --git a/cmd/rdpgw/security/basic.go b/cmd/rdpgw/security/basic.go index e2596b781c0ca4ef530d1a9e6eb3bea6c373b912..7419fd6e40301bfbc9a3afcd796b8a34783f555b 100644 --- a/cmd/rdpgw/security/basic.go +++ b/cmd/rdpgw/security/basic.go @@ -30,6 +30,8 @@ func CheckHost(ctx context.Context, host string) (bool, error) { if !ok { return false, errors.New("no valid session info or username found in context") } + } else { + username = s.UserName } log.Printf("Checking host for user %s", username) for _, h := range Hosts { diff --git a/cmd/rdpgw/security/basic_test.go b/cmd/rdpgw/security/basic_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d65422ff5d20a204faf5a3e822fcd7a91d510aa0 --- /dev/null +++ b/cmd/rdpgw/security/basic_test.go @@ -0,0 +1,56 @@ +package security + +import ( + "context" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" + "testing" +) + +var ( + info = protocol.SessionInfo{ + ConnId: "myid", + TransportIn: nil, + TransportOut: nil, + RemoteServer: "my.remote.server", + ClientIp: "10.0.0.1", + UserName: "Frank", + } + + hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"} +) + +func TestCheckHost(t *testing.T) { + ctx := context.WithValue(context.Background(), "SessionInfo", &info) + + Hosts = hosts + + // check any + HostSelection = "any" + host := "try.my.server:3389" + if ok, err := CheckHost(ctx, host); !ok || err != nil { + t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) + } + + HostSelection = "signed" + if ok, err := CheckHost(ctx, host); ok || err == nil { + t.Fatalf("signed host selection isnt supported at the moment") + } + + HostSelection = "roundrobin" + if ok, err := CheckHost(ctx, host); ok { + t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err) + } + + host = "my-Frank-host:3389" + if ok, err := CheckHost(ctx, host); !ok { + t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) + } + + info.UserName = "" + ctx = context.WithValue(ctx, "preferred_username", "dummy") + host = "my-dummy-host:3389" + if ok, err := CheckHost(ctx, host); !ok { + t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) + } + +}