diff --git a/README.md b/README.md index 5d6fe7d145396c2120931754aa936942e12d09c9..9a0f7ad9e0453e8f91a3823fd720ecec38bf7874 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,11 @@ server: # make sure to share this across the different pods sessionKey: thisisasessionkeyreplacethisjetzt sessionEncryptionKey: thisisasessionkeyreplacethisnunu! + # tries to set the receive / send buffer of the connections to the client + # in case of high latency high bandwidth the defaults set by the OS might + # be to low for a good experience + # receiveBuf: 12582912 + # sendBuf: 12582912 # Open ID Connect specific settings openId: providerUrl: http://keycloak/auth/realms/test diff --git a/config/configuration.go b/config/configuration.go index 43b77056781c10774ee19461cbe9ef90cbfb56ed..2e706444f75ef5c4401773a1ac4bdf3b1f96a3f3 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -22,6 +22,8 @@ type ServerConfig struct { RoundRobin bool SessionKey string SessionEncryptionKey string + SendBuf int + ReceiveBuf int } type OpenIDConfig struct { diff --git a/main.go b/main.go index c6836d6f1e0e86f23dda7dcdaa83d5845bf20a59..efc115fd5fa1b776d74e3cc2dea78956b322311e 100644 --- a/main.go +++ b/main.go @@ -128,6 +128,8 @@ func main() { }, VerifyTunnelCreate: security.VerifyPAAToken, VerifyServerFunc: security.VerifyServerFunc, + SendBuf: conf.Server.SendBuf, + ReceiveBuf: conf.Server.ReceiveBuf, } gw := protocol.Gateway{ ServerConf: &handlerConfig, diff --git a/protocol/common.go b/protocol/common.go index 744f334ad3ff6bf537002af5706e79dd7da797c2..405445dddb9fdc382255ff611a8e3cf19d2afaaf 100644 --- a/protocol/common.go +++ b/protocol/common.go @@ -8,6 +8,8 @@ import ( "io" "log" "net" + "os" + "syscall" ) type RedirectFlags struct { @@ -136,3 +138,11 @@ func receive(data []byte, out net.Conn) { out.Write(pkt) } +// wrapSyscallError takes an error and a syscall name. If the error is +// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. +func wrapSyscallError(name string, err error) error { + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError(name, err) + } + return err +} \ No newline at end of file diff --git a/protocol/gateway.go b/protocol/gateway.go index fe4eee211998deb0404d3ef26e81bc221f6210c6..ba864317415f5a53f1b8a5c4940f7a5c3e1c70be 100644 --- a/protocol/gateway.go +++ b/protocol/gateway.go @@ -2,13 +2,17 @@ package protocol import ( "context" + "errors" "github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/transport" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus" "log" + "net" "net/http" + "reflect" + "syscall" "time" ) @@ -81,12 +85,76 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) } defer conn.Close() + err = g.setSendReceiveBuffers(conn.UnderlyingConn()) + if err != nil { + log.Printf("Cannot set send/receive buffers: %s", err) + } + g.handleWebsocketProtocol(ctx, conn, s) } else if r.Method == MethodRDGIN { g.handleLegacyProtocol(w, r.WithContext(ctx), s) } } +func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { + if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 { + return nil + } + + // conn == tls.Conn + ptr := reflect.ValueOf(conn) + val := reflect.Indirect(ptr) + + if val.Kind() != reflect.Struct { + return errors.New("didn't get a struct from conn") + } + + // this gets net.Conn -> *net.TCPConn -> net.TCPConn + ptrConn := val.FieldByName("conn") + valConn := reflect.Indirect(ptrConn) + if !valConn.IsValid() { + return errors.New("cannot find conn field") + } + valConn = valConn.Elem().Elem() + + // net.FD + ptrNetFd := valConn.FieldByName("fd") + valNetFd := reflect.Indirect(ptrNetFd) + if !valNetFd.IsValid() { + return errors.New("cannot find fd field") + } + + // pfd member + ptrPfd := valNetFd.FieldByName("pfd") + valPfd := reflect.Indirect(ptrPfd) + if !valPfd.IsValid() { + return errors.New("cannot find pfd field") + } + + // finally the exported Sysfd + ptrSysFd := valPfd.FieldByName("Sysfd") + if !ptrSysFd.IsValid() { + return errors.New("cannot find Sysfd field") + } + fd := int(ptrSysFd.Int()) + + if g.ServerConf.ReceiveBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + if g.ServerConf.SendBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + return nil +} + func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) { websocketConnections.Inc() defer websocketConnections.Dec() diff --git a/protocol/server.go b/protocol/server.go index 13288dd28c7e6e27e662fa1e7b2b78a77043098f..cf87132c2a8d4ba67cdc3a507428454c57dd8fc4 100644 --- a/protocol/server.go +++ b/protocol/server.go @@ -39,6 +39,8 @@ type ServerConf struct { IdleTimeout int SmartCardAuth bool TokenAuth bool + ReceiveBuf int + SendBuf int } func NewServer(s *SessionInfo, conf *ServerConf) *Server {