diff --git a/cmd/rdpgw/protocol/protocol_test.go b/cmd/rdpgw/protocol/protocol_test.go index 6d74ae61adc6b195122c966ec6fa414ae1ad2b4e..a57213a597ddee7311e6dabdad5c7ae6cb9bfaf1 100644 --- a/cmd/rdpgw/protocol/protocol_test.go +++ b/cmd/rdpgw/protocol/protocol_test.go @@ -66,7 +66,7 @@ func TestHandshake(t *testing.T) { t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA) } - data = h.handshakeResponse(0x0, 0x0, 0, ERROR_SUCCESS) + data = h.handshakeResponse(0x0, 0x0, HTTP_EXTENDED_AUTH_PAA, ERROR_SUCCESS) _, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen) if err != nil { t.Fatalf("verifyHeader failed: %s", err) @@ -79,6 +79,58 @@ func TestHandshake(t *testing.T) { } } +func capsHelper(h Server) uint16 { + var caps uint16 + if h.TokenAuth { + caps = caps | HTTP_EXTENDED_AUTH_PAA + } + if h.SmartCardAuth { + caps = caps | HTTP_EXTENDED_AUTH_SC + } + return caps +} + +func TestMatchAuth(t *testing.T) { + s := &SessionInfo{} + hc := &ServerConf{ + TokenAuth: false, + SmartCardAuth: false, + } + + h:= NewServer(s, hc) + + in := uint16(0) + caps, err := h.matchAuth(in) + if err != nil { + t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*h), err) + } + if caps > in { + t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*h), in) + } + + in = HTTP_EXTENDED_AUTH_PAA + caps, err = h.matchAuth(in) + if err == nil { + t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps) + } else { + t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) + } + + h.SmartCardAuth = true + caps, err = h.matchAuth(in) + if err == nil { + t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps) + } else { + t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) + } + + h.TokenAuth = true + caps, err = h.matchAuth(in) + if err != nil { + t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*h), in, err) + } +} + func TestTunnelCreation(t *testing.T) { client := ClientConfig{ PAAToken: "abab", diff --git a/cmd/rdpgw/protocol/server.go b/cmd/rdpgw/protocol/server.go index 64bca2fa314d90d7e2a8f64e963154313decbe85..a85a088b039e68a09e6d5725f7f3de1efbb4896d 100644 --- a/cmd/rdpgw/protocol/server.go +++ b/cmd/rdpgw/protocol/server.go @@ -232,7 +232,7 @@ func (s *Server) matchAuth(extAuth uint16) (caps uint16, err error) { caps = caps | HTTP_EXTENDED_AUTH_PAA } - if caps & extAuth == 0 { + if caps & extAuth == 0 && extAuth > 0 { return 0, fmt.Errorf("%x has no matching capability configured (%x). Did you configure caps? ", extAuth, caps) }