diff --git a/go/grpcweb/DOC.md b/go/grpcweb/DOC.md index a8377d44..29e6a620 100644 --- a/go/grpcweb/DOC.md +++ b/go/grpcweb/DOC.md @@ -136,6 +136,16 @@ requests - usually to check that the origin is valid. The default behaviour is to check that the origin of the request matches the host of the request and deny all requests from remote origins. +#### func WithWebsocketPingInterval + +```go +func WithWebsocketPingInterval(websocketPingInterval time.Duration) Option +``` +WithWebsocketPingInterval enables websocket keepalive pinging with the +configured timeout. + +The default behaviour is to disable websocket pinging. + #### func WithWebsockets ```go diff --git a/go/grpcweb/options.go b/go/grpcweb/options.go index 96aea91d..db59af20 100644 --- a/go/grpcweb/options.go +++ b/go/grpcweb/options.go @@ -3,7 +3,10 @@ package grpcweb -import "net/http" +import ( + "net/http" + "time" +) var ( defaultOptions = &options{ @@ -19,6 +22,7 @@ type options struct { corsForRegisteredEndpointsOnly bool originFunc func(origin string) bool enableWebsockets bool + websocketPingInterval time.Duration websocketOriginFunc func(req *http.Request) bool allowNonRootResources bool } @@ -92,6 +96,15 @@ func WithWebsockets(enableWebsockets bool) Option { } } +// WithWebsocketPingInterval enables websocket keepalive pinging with the configured timeout. +// +// The default behaviour is to disable websocket pinging. +func WithWebsocketPingInterval(websocketPingInterval time.Duration) Option { + return func(o *options) { + o.websocketPingInterval = websocketPingInterval + } +} + // WithWebsocketOriginFunc allows for customizing the acceptance of Websocket requests - usually to check that the origin // is valid. // diff --git a/go/grpcweb/websocket_wrapper.go b/go/grpcweb/websocket_wrapper.go index 0cd375f9..a920d00e 100644 --- a/go/grpcweb/websocket_wrapper.go +++ b/go/grpcweb/websocket_wrapper.go @@ -10,16 +10,20 @@ import ( "net/http" "net/textproto" "strings" + "time" + "github.com/desertbit/timer" "github.com/gorilla/websocket" "golang.org/x/net/http2" ) type webSocketResponseWriter struct { - writtenHeaders bool - wsConn *websocket.Conn - headers http.Header - flushedHeaders http.Header + writtenHeaders bool + wsConn *websocket.Conn + headers http.Header + flushedHeaders http.Header + timeOutInterval time.Duration + timer *timer.Timer } func newWebSocketResponseWriter(wsConn *websocket.Conn) *webSocketResponseWriter { @@ -31,6 +35,33 @@ func newWebSocketResponseWriter(wsConn *websocket.Conn) *webSocketResponseWriter } } +func (w *webSocketResponseWriter) enablePing(timeOutInterval time.Duration) { + w.timeOutInterval = timeOutInterval + w.timer = timer.NewTimer(w.timeOutInterval) + dispose := make(chan bool) + w.wsConn.SetCloseHandler(func(code int, text string) error { + close(dispose) + return nil + }) + go w.ping(dispose) +} + +func (w *webSocketResponseWriter) ping(dispose chan bool) { + if dispose == nil { + return + } + defer w.timer.Stop() + for { + select { + case <-dispose: + return + case <-w.timer.C: + w.timer.Reset(w.timeOutInterval) + w.wsConn.WriteMessage(websocket.PingMessage, []byte{}) + } + } +} + func (w *webSocketResponseWriter) Header() http.Header { return w.headers } @@ -39,6 +70,9 @@ func (w *webSocketResponseWriter) Write(b []byte) (int, error) { if !w.writtenHeaders { w.WriteHeader(http.StatusOK) } + if w.timeOutInterval > time.Second && w.timer != nil { + w.timer.Reset(w.timeOutInterval) + } return len(b), w.wsConn.WriteMessage(websocket.BinaryMessage, b) } diff --git a/go/grpcweb/wrapper.go b/go/grpcweb/wrapper.go index 2f1606e8..9616c868 100644 --- a/go/grpcweb/wrapper.go +++ b/go/grpcweb/wrapper.go @@ -162,6 +162,9 @@ func (w *WrappedGrpcServer) handleWebSocket(wsConn *websocket.Conn, req *http.Re defer cancelFunc() respWriter := newWebSocketResponseWriter(wsConn) + if w.opts.websocketPingInterval >= time.Second { + respWriter.enablePing(w.opts.websocketPingInterval) + } wrappedReader := newWebsocketWrappedReader(wsConn, respWriter, cancelFunc) req.Body = wrappedReader diff --git a/go/grpcwebproxy/main.go b/go/grpcwebproxy/main.go index fb8838f8..3f33fc38 100644 --- a/go/grpcwebproxy/main.go +++ b/go/grpcwebproxy/main.go @@ -38,7 +38,8 @@ var ( runHttpServer = pflag.Bool("run_http_server", true, "whether to run HTTP server") runTlsServer = pflag.Bool("run_tls_server", true, "whether to run TLS server") - useWebsockets = pflag.Bool("use_websockets", false, "whether to use beta websocket transport layer") + useWebsockets = pflag.Bool("use_websockets", false, "whether to use beta websocket transport layer") + websocketPingInterval = pflag.Duration("websocket_ping_interval", 0, "whether to use websocket keepalive pinging. Only used when using websockets. Configured interval must be >= 1s.") flagHttpMaxWriteTimeout = pflag.Duration("server_http_max_write_timeout", 10*time.Second, "HTTP server config, max write duration.") flagHttpMaxReadTimeout = pflag.Duration("server_http_max_read_timeout", 10*time.Second, "HTTP server config, max read duration.") @@ -71,6 +72,13 @@ func main() { grpcweb.WithWebsockets(true), grpcweb.WithWebsocketOriginFunc(makeWebsocketOriginFunc(allowedOrigins)), ) + if *websocketPingInterval >= time.Second { + logrus.Infof("websocket keepalive pinging enabled, the timeout interval is %s", websocketPingInterval.String()) + } + options = append( + options, + grpcweb.WithWebsocketPingInterval(*websocketPingInterval), + ) } wrappedGrpc := grpcweb.WrapServer(grpcServer, options...)