diff --git a/go.mod b/go.mod index 7b6b0d6d51c36dc0bfb02fd81afcf46912b4c96a..a2cdd6c626cfd77e66945dfe86163f0caf113724 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.14 require ( github.com/gorilla/websocket v1.4.2 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/prometheus/client_golang v1.7.1 ) diff --git a/main.go b/main.go index ba3da3daeacdd7170c006cb48fc3924e740cbc9e..8342ae70c70464c11654d6858c3f147b48a6eb96 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,8 @@ package main import ( "crypto/tls" "flag" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/prometheus/client_golang/prometheus" "log" "net/http" "os" @@ -42,9 +44,16 @@ func main() { server := http.Server{ Addr: ":" + strconv.Itoa(*port), TLSConfig: cfg, + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 } + http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) - + http.Handle("/metrics", promhttp.Handler()) + + prometheus.MustRegister(connectionCache) + prometheus.MustRegister(legacyConnections) + prometheus.MustRegister(websocketConnections) + err = server.ListenAndServeTLS("", "") if err != nil { log.Fatal("ListenAndServe: ", err) diff --git a/rdg.go b/rdg.go index 4d8526ad9744c7764fc50fd40d92e063a64c0a9f..b4984aa51ea03030b58aa8198d4a8a8e84cdc45e 100644 --- a/rdg.go +++ b/rdg.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" + "github.com/prometheus/client_golang/prometheus" "io" "log" "math/rand" @@ -15,7 +16,6 @@ import ( "net/http" "net/http/httputil" "strconv" - "time" "unicode/utf16" "unicode/utf8" @@ -84,6 +84,29 @@ const ( HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 ) +var ( + connectionCache = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "connection_cache", + Help: "The amount of connections in the cache", + }) + + websocketConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "websocket_connections", + Help: "The count of websocket connections", + }) + + legacyConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "rdpgw", + Name: "legacy_connections", + Help: "The count of legacy https connections", + }) +) + // HandshakeHeader is the interface that writes both upgrade request or // response headers into a given io.Writer. type HandshakeHeader interface { @@ -129,6 +152,7 @@ var upgrader = websocket.Upgrader{} var c = cache.New(5*time.Minute, 10*time.Minute) func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { + connectionCache.Set(float64(c.ItemCount())) if r.Method == MethodRDGOUT { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { handleLegacyProtocol(w, r) @@ -155,6 +179,9 @@ func handleWebsocketProtocol(conn *websocket.Conn) { var remote net.Conn + websocketConnections.Inc() + defer websocketConnections.Dec() + for { mt, msg, err := conn.ReadMessage() if err != nil { @@ -217,8 +244,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) { // do not write to make sure we do not create concurrency issues // conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) case PKT_TYPE_CLOSE_CHANNEL: - remote.Close() - return + break default: log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt) } @@ -252,6 +278,9 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { c.Set(connId, s, cache.DefaultExpiration) } else if r.Method == MethodRDGIN { + legacyConnections.Inc() + defer legacyConnections.Dec() + var remote net.Conn conn, rw, _ := Accept(w) @@ -314,7 +343,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { case PKT_TYPE_CLOSE_CHANNEL: s.ConnIn.Close() s.ConnOut.Close() - remote.Close() break default: log.Printf("Unknown packet (size %d): %x", n, packet)