Skip to content
Snippets Groups Projects
Commit 33a5e0e0 authored by Bolke de Bruin's avatar Bolke de Bruin
Browse files

Add prometheus and disable http2

parent 1f191a5e
No related branches found
No related tags found
No related merge requests found
...@@ -5,4 +5,5 @@ go 1.14 ...@@ -5,4 +5,5 @@ go 1.14
require ( require (
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.7.1
) )
...@@ -3,6 +3,8 @@ package main ...@@ -3,6 +3,8 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/client_golang/prometheus"
"log" "log"
"net/http" "net/http"
"os" "os"
...@@ -42,8 +44,15 @@ func main() { ...@@ -42,8 +44,15 @@ func main() {
server := http.Server{ server := http.Server{
Addr: ":" + strconv.Itoa(*port), Addr: ":" + strconv.Itoa(*port),
TLSConfig: cfg, TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
} }
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
http.Handle("/metrics", promhttp.Handler())
prometheus.MustRegister(connectionCache)
prometheus.MustRegister(legacyConnections)
prometheus.MustRegister(websocketConnections)
err = server.ListenAndServeTLS("", "") err = server.ListenAndServeTLS("", "")
if err != nil { if err != nil {
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"io" "io"
"log" "log"
"math/rand" "math/rand"
...@@ -15,7 +16,6 @@ import ( ...@@ -15,7 +16,6 @@ import (
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strconv" "strconv"
"time" "time"
"unicode/utf16" "unicode/utf16"
"unicode/utf8" "unicode/utf8"
...@@ -84,6 +84,29 @@ const ( ...@@ -84,6 +84,29 @@ const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
) )
var (
connectionCache = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "connection_cache",
Help: "The amount of connections in the cache",
})
websocketConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "websocket_connections",
Help: "The count of websocket connections",
})
legacyConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "legacy_connections",
Help: "The count of legacy https connections",
})
)
// HandshakeHeader is the interface that writes both upgrade request or // HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer. // response headers into a given io.Writer.
type HandshakeHeader interface { type HandshakeHeader interface {
...@@ -129,6 +152,7 @@ var upgrader = websocket.Upgrader{} ...@@ -129,6 +152,7 @@ var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount()))
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r) handleLegacyProtocol(w, r)
...@@ -155,6 +179,9 @@ func handleWebsocketProtocol(conn *websocket.Conn) { ...@@ -155,6 +179,9 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
var remote net.Conn var remote net.Conn
websocketConnections.Inc()
defer websocketConnections.Dec()
for { for {
mt, msg, err := conn.ReadMessage() mt, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
...@@ -217,8 +244,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) { ...@@ -217,8 +244,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
// do not write to make sure we do not create concurrency issues // do not write to make sure we do not create concurrency issues
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
remote.Close() break
return
default: default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt) log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
} }
...@@ -252,6 +278,9 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { ...@@ -252,6 +278,9 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
c.Set(connId, s, cache.DefaultExpiration) c.Set(connId, s, cache.DefaultExpiration)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
legacyConnections.Inc()
defer legacyConnections.Dec()
var remote net.Conn var remote net.Conn
conn, rw, _ := Accept(w) conn, rw, _ := Accept(w)
...@@ -314,7 +343,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { ...@@ -314,7 +343,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
s.ConnIn.Close() s.ConnIn.Close()
s.ConnOut.Close() s.ConnOut.Close()
remote.Close()
break break
default: default:
log.Printf("Unknown packet (size %d): %x", n, packet) log.Printf("Unknown packet (size %d): %x", n, packet)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment