Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine: Add websocket data handler register function #935

Merged
merged 8 commits into from
May 15, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,12 @@ func (bot *Engine) RegisterWebsocketDataHandler(fn WebsocketDataHandler, interce
}
return bot.websocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly)
}

// SetDefaultWebsocketDataHandler sets the default websocket handler and
// removing all pre-existing handlers
func (bot *Engine) SetDefaultWebsocketDataHandler() error {
if bot == nil {
return errNilBot
}
return bot.websocketRoutineManager.setWebsocketDataHandler(bot.websocketRoutineManager.websocketDataHandler)
}
15 changes: 15 additions & 0 deletions engine/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,18 @@ func TestRegisterWebsocketDataHandler(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}

func TestSetDefaultWebsocketDataHandler(t *testing.T) {
t.Parallel()
var e *Engine
err := e.SetDefaultWebsocketDataHandler()
if !errors.Is(err, errNilBot) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot)
}

e = &Engine{websocketRoutineManager: &websocketRoutineManager{}}
err = e.SetDefaultWebsocketDataHandler()
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}
72 changes: 50 additions & 22 deletions engine/websocketroutine_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ func (m *websocketRoutineManager) websocketRoutine() {
log.Errorf(log.WebsocketMgr, "%v", err)
}

if atomic.LoadInt32(&m.started) != 0 {
m.wg.Add(1)
go m.websocketDataReceiver(ws)
err = m.websocketDataReceiver(ws)
if err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}

err = ws.FlushChannels()
Expand All @@ -137,26 +137,42 @@ func (m *websocketRoutineManager) websocketRoutine() {

// WebsocketDataReceiver handles websocket data coming from a websocket feed
// associated with an exchange
func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) {
defer m.wg.Done()
for {
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
log.Error(log.WebsocketMgr, err)
func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error {
if m == nil {
return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem)
}

if ws == nil {
return errNilWebsocket
}

if atomic.LoadInt32(&m.started) == 0 {
return errRoutineManagerNotStarted
}

m.wg.Add(1)
go func() {
defer m.wg.Done()
for {
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
log.Error(log.WebsocketMgr, err)
}
}
m.mu.RUnlock()
}
m.mu.RUnlock()
}
}
}()
return nil
}

// websocketDataHandler is the default central point for exchange websocket
Expand Down Expand Up @@ -356,12 +372,24 @@ func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataH
defer m.mu.Unlock()

if interceptorOnly {
m.dataHandlers = []WebsocketDataHandler{fn}
return nil
return m.setWebsocketDataHandler(fn)
}

// Push front so that any registered data handler has first preference
// over the gct default handler.
m.dataHandlers = append([]WebsocketDataHandler{fn}, m.dataHandlers...)
return nil
}

// setWebsocketDataHandler sets a single websocket data handler, removing all
// pre-existing handlers.
func (m *websocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error {
if m == nil {
return fmt.Errorf("%T %w", m, ErrNilSubsystem)
}
if fn == nil {
return errNilWebsocketDataHandlerFunction
}
m.dataHandlers = []WebsocketDataHandler{fn}
gloriousCode marked this conversation as resolved.
Show resolved Hide resolved
return nil
}
53 changes: 50 additions & 3 deletions engine/websocketroutine_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,11 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {

mock := stream.New()
mock.ToRoutine = make(chan interface{})

m.wg.Add(1)
go m.websocketDataReceiver(mock)
m.started = 1
err = m.websocketDataReceiver(mock)
if err != nil {
t.Fatal(err)
}

mock.ToRoutine <- nil
mock.ToRoutine <- 1336
Expand All @@ -307,3 +309,48 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
close(m.shutdown)
m.wg.Wait()
}

func TestSetWebsocketDataHandler(t *testing.T) {
t.Parallel()
var m *websocketRoutineManager
err := m.setWebsocketDataHandler(nil)
if !errors.Is(err, ErrNilSubsystem) {
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem)
}

m = new(websocketRoutineManager)
m.shutdown = make(chan struct{})

err = m.setWebsocketDataHandler(nil)
if !errors.Is(err, errNilWebsocketDataHandlerFunction) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction)
}

err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}

err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}

err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}

if len(m.dataHandlers) != 3 {
t.Fatal("unexpected data handler count")
}

err = m.setWebsocketDataHandler(m.websocketDataHandler)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}

if len(m.dataHandlers) != 1 {
t.Fatal("unexpected data handler count")
}
}
2 changes: 2 additions & 0 deletions engine/websocketroutine_manager_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ var (
errNilCurrencyConfig = errors.New("nil currency config received")
errNilCurrencyPairFormat = errors.New("nil currency pair format received")
errNilWebsocketDataHandlerFunction = errors.New("websocket data handler function is nil")
errNilWebsocket = errors.New("websocket is nil")
errRoutineManagerNotStarted = errors.New("websocket routine manager not started")
)

// websocketRoutineManager is used to process websocket updates from a unified location
Expand Down