diff --git a/cmd/rdpgw/kdcproxy/proxy.go b/cmd/rdpgw/kdcproxy/proxy.go index 86e796e5021d5eacecec987f38893c9e003c21c4..6069134161c822271f24d4a9b906e4236ef3cc0a 100644 --- a/cmd/rdpgw/kdcproxy/proxy.go +++ b/cmd/rdpgw/kdcproxy/proxy.go @@ -23,6 +23,13 @@ type KdcProxyMsg struct { Flags int `asn1:"tag:2,optional"` } +type Kdc struct { + Realm string + Host string + Proto string + Conn net.Conn +} + type KerberosProxy struct { krb5Config *krbconfig.Config } @@ -97,39 +104,71 @@ func (k *KerberosProxy) forward(realm string, data []byte) (resp []byte, err err } // load udp first as is the default for kerberos - c, kdcs, err := k.krb5Config.GetKDCs(realm, false) - if err != nil || c < 1 { - return nil, fmt.Errorf("cannot get kdc for realm %s due to %s", realm, err) + udpCnt, udpKdcs, err := k.krb5Config.GetKDCs(realm, false) + if err != nil { + return nil, fmt.Errorf("cannot get udp kdc for realm %s due to %s", realm, err) } + // load tcp + tcpCnt, tcpKdcs, err := k.krb5Config.GetKDCs(realm, true) + if err != nil { + return nil, fmt.Errorf("cannot get tcp kdc for realm %s due to %s", realm, err) + } + + if tcpCnt+udpCnt == 0 { + return nil, fmt.Errorf("cannot get any kdcs (tcp or udp) for realm %s", realm) + } + + // merge the kdcs + kdcs := make([]Kdc, tcpCnt+udpCnt) + for i := range udpKdcs { + kdcs[i] = Kdc{Realm: realm, Host: udpKdcs[i], Proto: "udp"} + } + for i := range tcpKdcs { + kdcs[i+udpCnt] = Kdc{Realm: realm, Host: tcpKdcs[i], Proto: "tcp"} + } + + replies := make(chan []byte, len(kdcs)) for i := range kdcs { - conn, err := net.Dial("tcp", kdcs[i]) + conn, err := net.Dial(kdcs[i].Proto, kdcs[i].Host) + if err != nil { log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err) continue } conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Write(data) + // if we proxy over UDP remove the length prefix + if kdcs[i].Proto == "tcp" { + _, err = conn.Write(data) + } else { + _, err = conn.Write(data[4:]) + } if err != nil { log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err) conn.Close() continue } - // todo check header - resp, err = io.ReadAll(conn) - if err != nil { - log.Printf("error reading from kdc %s due to %s, trying next if available", kdcs[i], err) - conn.Close() - continue + kdcs[i].Conn = conn + go awaitReply(conn, kdcs[i].Proto == "udp", replies) + } + + reply := <-replies + + // close all the connections and return the first reply + for kdc := range kdcs { + if kdcs[kdc].Conn != nil { + kdcs[kdc].Conn.Close() } - conn.Close() + <-replies + } - return resp, nil + if reply != nil { + return reply, nil } - return nil, fmt.Errorf("no kdcs found for realm %s", realm) + return nil, fmt.Errorf("no replies received from kdcs for realm %s", realm) } func decode(data []byte) (msg *KdcProxyMsg, err error) { @@ -155,3 +194,17 @@ func encode(krb5data []byte) (r []byte, err error) { } return enc, nil } + +func awaitReply(conn net.Conn, isUdp bool, reply chan<- []byte) { + resp, err := io.ReadAll(conn) + if err != nil { + log.Printf("error reading from kdc due to %s", err) + reply <- nil + return + } + if isUdp { + // udp will be missing the length prefix so add it + resp = append([]byte{byte(len(resp))}, resp...) + } + reply <- resp +}