Skip to content

Commit

Permalink
fix: Eliminate race conditions in reconnect feature. (#2285)
Browse files Browse the repository at this point in the history
  • Loading branch information
mturoci authored Mar 12, 2024
1 parent e10e866 commit 1658a6a
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 51 deletions.
6 changes: 6 additions & 0 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ func newBroker(site *Site, editable, noStore, noLog, keepAppLive, debug bool) *B
}
}

func (b *Broker) getClient(id string) *Client {
b.unicastsMux.RLock()
defer b.unicastsMux.RUnlock()
return b.clientsByID[id]
}

func (b *Broker) addApp(mode, route, addr, keyID, keySecret string) {
s := newApp(b, mode, route, addr, keyID, keySecret)

Expand Down
88 changes: 58 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ import (
"context"
"encoding/json"
"net/http"
"sync"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
)

const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second

// Maximum message size allowed from peer.
maxMessageSize = 1 * 1024 * 1024 // bytes
writeWait = 10 * time.Second // Time allowed to write a message to the peer.
maxMessageSize = 1 * 1024 * 1024 // bytes Maximum message size allowed from peer.
// TODO: Refactor into iota.
STATE_CREATED = "CREATED"
STATE_TIMEOUT = "TIMEOUT"
STATE_LISTEN = "LISTEN"
STATE_RECONNECT = "RECONNECT"
STATE_DISCONNECT = "DISCONNECT"
STATE_CLOSED = "CLOSED"
)

var (
Expand Down Expand Up @@ -63,16 +68,17 @@ type Client struct {
header *http.Header // forwarded headers from the WS connection
appPath string // path of the app this client is connected to, doesn't change throughout WS lifetime
pingInterval time.Duration
isReconnect bool
cancel context.CancelFunc
reconnectTimeout time.Duration
lock *sync.Mutex
state string
}

// TODO: Refactor some of the params into a Config struct.
func newClient(addr string, auth *Auth, session *Session, broker *Broker, conn *websocket.Conn, editable bool,
baseURL string, header *http.Header, pingInterval time.Duration, isReconnect bool, reconnectTimeout time.Duration) *Client {
baseURL string, header *http.Header, pingInterval time.Duration, reconnectTimeout time.Duration) *Client {
id := uuid.New().String()
return &Client{id, auth, addr, session, broker, conn, nil, make(chan []byte, 256), editable, baseURL, header, "", pingInterval, isReconnect, nil, reconnectTimeout}
return &Client{id, auth, addr, session, broker, conn, nil, make(chan []byte, 256),
editable, baseURL, header, "", pingInterval, reconnectTimeout, &sync.Mutex{}, STATE_CREATED}
}

func (c *Client) refreshToken() error {
Expand All @@ -90,29 +96,44 @@ func (c *Client) refreshToken() error {
return nil
}

func (c *Client) setState(newState string) {
c.lock.Lock()
c.state = newState
c.lock.Unlock()
}

func (c *Client) listen() {
defer func() {
ctx, cancel := context.WithCancel(context.Background())
c.cancel = cancel
go func(ctx context.Context) {
select {
// Send disconnect message only if client doesn't reconnect within the specified timeframe.
case <-time.After(c.reconnectTimeout):
app := c.broker.getApp(c.appPath)
if app != nil {
app.forward(c.id, c.session, disconnectMsg)
if err := app.disconnect(c.id); err != nil {
echo(Log{"t": "disconnect", "client": c.addr, "route": c.appPath, "err": err.Error()})
}
}
c.lock.Lock()
defer c.lock.Unlock()
if c.state != STATE_DISCONNECT {
return
}
// This defer runs to completion. If the client drops, reconnects and drops out again, ignore first drop timeout.
timeoutID := STATE_TIMEOUT + c.addr
c.state = timeoutID
c.lock.Unlock()

c.broker.unsubscribe <- c
case <-ctx.Done():
select {
// Send disconnect message only if client doesn't reconnect within the specified timeframe.
case <-time.After(c.reconnectTimeout):
c.lock.Lock()
if c.state != timeoutID {
return
}
}(ctx)
app := c.broker.getApp(c.appPath)
if app != nil {
app.forward(c.id, c.session, disconnectMsg)
if err := app.disconnect(c.id); err != nil {
echo(Log{"t": "disconnect", "client": c.addr, "route": c.appPath, "err": err.Error()})
}
}

c.conn.Close()
echo(Log{"t": "client_unsubscribe", "client": c.id})
c.broker.unsubscribe <- c
c.state = STATE_CLOSED
return
}
}()
// Time allowed to read the next pong message from the peer. Must be greater than ping interval.
pongWait := 10 * c.pingInterval / 9
Expand All @@ -127,10 +148,8 @@ func (c *Client) listen() {
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
echo(Log{"t": "socket_read", "client": c.addr, "err": err.Error()})
} else {
// Firefox follows spec closely and requires a close message to be sent before closing the connection.
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
}
c.setState(STATE_DISCONNECT)
break
}

Expand Down Expand Up @@ -173,7 +192,10 @@ func (c *Client) listen() {
c.broker.sendAll(c.broker.clients[app.route], clearStateMsg)
}
case watchMsgT:
if c.isReconnect {
c.lock.Lock()
state := c.state
c.lock.Unlock()
if state == STATE_RECONNECT {
continue
}
c.subscribe(m.addr) // subscribe even if page is currently NA
Expand Down Expand Up @@ -238,10 +260,13 @@ func (c *Client) flush() {
defer func() {
ticker.Stop()
c.conn.Close()
c.lock.Unlock()
}()
for {
select {
case data, ok := <-c.data:
// An alternative to the mutex here would be a new channel for closing the connection so it does not race with reconnect.
c.lock.Lock()
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// broker closed the channel.
Expand All @@ -265,11 +290,14 @@ func (c *Client) flush() {
if err := w.Close(); err != nil {
return
}
c.lock.Unlock()
case <-ticker.C:
c.lock.Lock()
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
c.lock.Unlock()
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,6 @@ type Conf struct {
SkipLogin bool `cfg:"oidc-skip-login" env:"H2O_WAVE_OIDC_SKIP_LOGIN" cfgDefault:"false" cfgHelper:"do not display the login form during OIDC authorization"`
KeepAppLive bool `cfg:"keep-app-live" env:"H2O_WAVE_KEEP_APP_LIVE" cfgDefault:"false" cfgHelper:"do not unregister unresponsive apps"`
Conf string `cfg:"conf" env:"H2O_WAVE_CONF" cfgDefault:".env" cfgHelper:"path to configuration file"`
ReconnectTimeout string `cfg:"reconnect-timeout" env:"H2O_WAVE_RECONNECT_TIMEOUT" cfgDefault:"2s" cfgHelper:"Time to wait for reconnect before dropping the client"`
ReconnectTimeout string `cfg:"reconnect-timeout" env:"H2O_WAVE_RECONNECT_TIMEOUT" cfgDefault:"5s" cfgHelper:"Time to wait for reconnect before dropping the client"`
AllowedOrigins string `cfg:"allowed-origins" env:"H2O_WAVE_ALLOWED_ORIGINS" cfgDefault:"" cfgHelper:"comma-separated list of allowed origins (e.g. http://foo.com) for websocket upgrades"`
}
31 changes: 16 additions & 15 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,29 @@ func (s *SocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
clientID := r.URL.Query().Get("client-id")
client, ok := s.broker.clientsByID[clientID]
if ok {
client := s.broker.getClient(clientID)
if client != nil {
client.lock.Lock()
// Close prev connection gracefully.
client.conn.WriteMessage(websocket.CloseMessage, []byte{})
client.conn.Close()
client.conn = conn
client.isReconnect = true
if client.cancel != nil {
client.cancel()
}
if s.broker.debug {
echo(Log{"t": "socket_reconnect", "client_id": clientID, "addr": getRemoteAddr(r)})
}
client.state = STATE_RECONNECT
client.addr = getRemoteAddr(r)
client.lock.Unlock()
echo(Log{"t": "client_reconnect", "client_id": client.id, "addr": getRemoteAddr(r)})
} else {
client = newClient(getRemoteAddr(r), s.auth, session, s.broker, conn, s.editable, s.baseURL, &header, s.pingInterval, false, s.reconnectTimeout)
}
client = newClient(getRemoteAddr(r), s.auth, session, s.broker, conn, s.editable, s.baseURL, &header, s.pingInterval, s.reconnectTimeout)

if msg, err := json.Marshal(OpsD{I: client.id}); err == nil {
sw, err := conn.NextWriter(websocket.TextMessage)
helloMsg, err := json.Marshal(OpsD{I: client.id})
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
sw.Write(msg)
sw.Close()
if !client.send(helloMsg) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}

go client.flush()
Expand Down
7 changes: 3 additions & 4 deletions ui/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -946,12 +946,11 @@ export const
const
slug = window.location.pathname,
reconnect = (address: S) => {
if (_clientID && !address.includes('?client-id')) {
address = `${address}?${new URLSearchParams({ 'client-id': _clientID })}`
}
let wsAddr = address
if (_clientID) wsAddr = `${address}?${new URLSearchParams({ 'client-id': _clientID })}`

const retry = () => reconnect(address)
const socket = new WebSocket(address)
const socket = new WebSocket(wsAddr)
socket.onopen = () => {
_reconnectFailures = 0
_socket = socket
Expand Down
2 changes: 1 addition & 1 deletion website/docs/routing.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ Note that when a user logs out of the Wave daemon, all the apps linked to the da

### Handling client (browser tab) disconnect

To get notified when a user closes the tab, use the system-wide `@system.client_disconnect` event.
To get notified when a user closes the tab, use the system-wide `@system.client_disconnect` event. The time if takes for this function to be called depends on the value of `H2O_WAVE_RECONNECT_TIMEOUT` (which defaults to `5s`).

```py
@on('@system.client_disconnect')
Expand Down

0 comments on commit 1658a6a

Please sign in to comment.