Skip to content

Commit

Permalink
Add SSERouterConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
m110 committed Jan 10, 2020
1 parent 460e77f commit 84c1363
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
4 changes: 0 additions & 4 deletions pkg/http/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) error {
return errors.Wrap(ErrErrorResponse, resp.Status)
}

if err != nil {
return errors.Wrapf(err, "could not close response body for message %s", msg.UUID)
}

p.logger.Trace("Message published", logFields)
}

Expand Down
62 changes: 42 additions & 20 deletions pkg/http/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package http

import (
"context"
"net/http"

"github.com/go-chi/render"
"github.com/pkg/errors"
"net/http"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
Expand Down Expand Up @@ -32,38 +32,59 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
render.Respond(w, r, defaultErrorResponse{Error: err.Error()})
}

// SSERouter
// SSERouter is a router handling Server-Sent Events.
type SSERouter struct {
fanOut *gochannel.FanOut
errorHandler HandleErrorFunc
logger watermill.LoggerAdapter
fanOut *gochannel.FanOut
config SSERouterConfig
logger watermill.LoggerAdapter
}

// NewSSERouter creates new SSERouter.
type SSERouterConfig struct {
UpstreamSubscriber message.Subscriber
ErrorHandler HandleErrorFunc
}

func (c *SSERouterConfig) setDefaults() {
if c.ErrorHandler == nil {
c.ErrorHandler = DefaultErrorHandler
}
}

func (c SSERouterConfig) validate() error {
if c.UpstreamSubscriber == nil {
return errors.New("upstream subscriber is nil")
}

return nil
}

// NewSSERouter creates a new SSERouter.
func NewSSERouter(
upstreamSubscriber message.Subscriber,
errorHandler HandleErrorFunc,
config SSERouterConfig,
logger watermill.LoggerAdapter,
) (SSERouter, error) {
if errorHandler == nil {
errorHandler = DefaultErrorHandler
config.setDefaults()
if err := config.validate(); err != nil {
return SSERouter{}, errors.Wrap(err, "invalid SSERouter config")
}

if logger == nil {
logger = watermill.NopLogger{}
}

fanOut, err := gochannel.NewFanOut(upstreamSubscriber, logger)
fanOut, err := gochannel.NewFanOut(config.UpstreamSubscriber, logger)
if err != nil {
return SSERouter{}, err
return SSERouter{}, errors.Wrap(err, "could not create a FanOut")
}

return SSERouter{
fanOut: fanOut,
errorHandler: errorHandler,
logger: logger,
fanOut: fanOut,
config: config,
logger: logger,
}, nil
}

// AddHandler starts a new handler for a given topic.
func (r SSERouter) AddHandler(topic string, streamAdapter StreamAdapter) http.HandlerFunc {
r.logger.Trace("Adding handler for topic", watermill.LogFields{
"topic": topic,
Expand All @@ -75,13 +96,14 @@ func (r SSERouter) AddHandler(topic string, streamAdapter StreamAdapter) http.Ha
subscriber: r.fanOut,
topic: topic,
streamAdapter: streamAdapter,
errorHandler: r.errorHandler,
config: r.config,
logger: r.logger,
}

return handler.Handle
}

// Run starts the SSERouter.
func (r SSERouter) Run(ctx context.Context) error {
return r.fanOut.Run(ctx)
}
Expand All @@ -95,7 +117,7 @@ type sseHandler struct {
subscriber message.Subscriber
topic string
streamAdapter StreamAdapter
errorHandler HandleErrorFunc
config SSERouterConfig
logger watermill.LoggerAdapter
}

Expand All @@ -120,7 +142,7 @@ func (h sseHandler) handleGenericRequest(w http.ResponseWriter, r *http.Request)
func (h sseHandler) handleEventStream(w http.ResponseWriter, r *http.Request) {
messages, err := h.subscriber.Subscribe(r.Context(), h.topic)
if err != nil {
h.errorHandler(w, r, err)
h.config.ErrorHandler(w, r, err)
return
}

Expand Down Expand Up @@ -166,7 +188,7 @@ func (h sseHandler) processMessage(w http.ResponseWriter, r *http.Request, msg *
return nil, false
}

h.logger.Trace("Received valid message", nil)
h.logger.Trace("Received valid message", watermill.LogFields{"uuid": msg.UUID})

return h.streamAdapter.GetResponse(w, r)
}
10 changes: 7 additions & 3 deletions pkg/http/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@ import (
"testing"
"time"

"github.com/go-chi/chi"
"github.com/stretchr/testify/require"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill-http/pkg/http"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
"github.com/go-chi/chi"
"github.com/stretchr/testify/require"
)

func TestSSE(t *testing.T) {
pubsub := gochannel.NewGoChannel(gochannel.Config{}, watermill.NopLogger{})

sseRouter, err := http.NewSSERouter(pubsub, http.DefaultErrorHandler, watermill.NopLogger{})
sseRouter, err := http.NewSSERouter(http.SSERouterConfig{
UpstreamSubscriber: pubsub,
ErrorHandler: http.DefaultErrorHandler,
}, watermill.NopLogger{})
require.NoError(t, err)

postUpdatedTopic := "post-updated"
Expand Down

0 comments on commit 84c1363

Please sign in to comment.