Skip to content

Commit

Permalink
Merge pull request #3 from ThreeDotsLabs/sse-handler
Browse files Browse the repository at this point in the history
Add SSERouter
  • Loading branch information
m110 authored Jan 10, 2020
2 parents cbee977 + 84c1363 commit f9529a1
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 7 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module github.com/ThreeDotsLabs/watermill-http

require (
github.com/ThreeDotsLabs/watermill v1.0.2
github.com/ThreeDotsLabs/watermill v1.1.0
github.com/go-chi/chi v4.0.2+incompatible
github.com/go-chi/render v1.0.1
github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.3.0
)
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/ThreeDotsLabs/watermill v1.0.2 h1:UGiE61pRWAMwEX3z/AVZTaZmm/Y+AeEd+cXoeQz9NuM=
github.com/ThreeDotsLabs/watermill v1.0.2/go.mod h1:vZCPh7eN0P7r2qKau4SfmcUZ83+3JXWkRl4BiWUlqFw=
github.com/ThreeDotsLabs/watermill v1.1.0 h1:RWVfySGHEaK4TZhr8L/rKkYkSbyzWRzE8ut7QP7esLY=
github.com/ThreeDotsLabs/watermill v1.1.0/go.mod h1:Qd1xNFxolCAHCzcMrm6RnjW0manbvN+DJVWc1MWRFlI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
Expand All @@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs=
github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8=
github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
Expand Down
4 changes: 0 additions & 4 deletions pkg/http/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) error {
return errors.Wrap(ErrErrorResponse, resp.Status)
}

if err != nil {
return errors.Wrapf(err, "could not close response body for message %s", msg.UUID)
}

p.logger.Trace("Message published", logFields)
}

Expand Down
194 changes: 194 additions & 0 deletions pkg/http/sse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package http

import (
"context"
"github.com/go-chi/render"
"github.com/pkg/errors"
"net/http"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
)

type StreamAdapter interface {
// GetResponse returns the response to be sent back to client.
// Any errors that occur should be handled and written to `w`, returning false as `ok`.
GetResponse(w http.ResponseWriter, r *http.Request) (response interface{}, ok bool)
// Validate validates if the incoming message should be handled by this handler.
// Typically this involves checking some kind of model ID.
Validate(r *http.Request, msg *message.Message) (ok bool)
}

type HandleErrorFunc func(w http.ResponseWriter, r *http.Request, err error)

type defaultErrorResponse struct {
Error string `json:"error"`
}

// DefaultErrorHandler writes JSON error response along with Internal Server Error code (500).
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(500)
render.Respond(w, r, defaultErrorResponse{Error: err.Error()})
}

// SSERouter is a router handling Server-Sent Events.
type SSERouter struct {
fanOut *gochannel.FanOut
config SSERouterConfig
logger watermill.LoggerAdapter
}

type SSERouterConfig struct {
UpstreamSubscriber message.Subscriber
ErrorHandler HandleErrorFunc
}

func (c *SSERouterConfig) setDefaults() {
if c.ErrorHandler == nil {
c.ErrorHandler = DefaultErrorHandler
}
}

func (c SSERouterConfig) validate() error {
if c.UpstreamSubscriber == nil {
return errors.New("upstream subscriber is nil")
}

return nil
}

// NewSSERouter creates a new SSERouter.
func NewSSERouter(
config SSERouterConfig,
logger watermill.LoggerAdapter,
) (SSERouter, error) {
config.setDefaults()
if err := config.validate(); err != nil {
return SSERouter{}, errors.Wrap(err, "invalid SSERouter config")
}

if logger == nil {
logger = watermill.NopLogger{}
}

fanOut, err := gochannel.NewFanOut(config.UpstreamSubscriber, logger)
if err != nil {
return SSERouter{}, errors.Wrap(err, "could not create a FanOut")
}

return SSERouter{
fanOut: fanOut,
config: config,
logger: logger,
}, nil
}

// AddHandler starts a new handler for a given topic.
func (r SSERouter) AddHandler(topic string, streamAdapter StreamAdapter) http.HandlerFunc {
r.logger.Trace("Adding handler for topic", watermill.LogFields{
"topic": topic,
})

r.fanOut.AddSubscription(topic)

handler := sseHandler{
subscriber: r.fanOut,
topic: topic,
streamAdapter: streamAdapter,
config: r.config,
logger: r.logger,
}

return handler.Handle
}

// Run starts the SSERouter.
func (r SSERouter) Run(ctx context.Context) error {
return r.fanOut.Run(ctx)
}

// Running is closed when the SSERouter is running.
func (r SSERouter) Running() chan struct{} {
return r.fanOut.Running()
}

type sseHandler struct {
subscriber message.Subscriber
topic string
streamAdapter StreamAdapter
config SSERouterConfig
logger watermill.LoggerAdapter
}

func (h sseHandler) Handle(w http.ResponseWriter, r *http.Request) {
if render.GetAcceptedContentType(r) == render.ContentTypeEventStream {
h.handleEventStream(w, r)
return
}

h.handleGenericRequest(w, r)
}

func (h sseHandler) handleGenericRequest(w http.ResponseWriter, r *http.Request) {
response, ok := h.streamAdapter.GetResponse(w, r)
if !ok {
return
}

render.Respond(w, r, response)
}

func (h sseHandler) handleEventStream(w http.ResponseWriter, r *http.Request) {
messages, err := h.subscriber.Subscribe(r.Context(), h.topic)
if err != nil {
h.config.ErrorHandler(w, r, err)
return
}

responsesChan := make(chan interface{})

go func() {
defer func() {
h.logger.Trace("Closing SSE handler", nil)
close(responsesChan)
}()

response, ok := h.streamAdapter.GetResponse(w, r)
if !ok {
return
}

responsesChan <- response

h.logger.Trace("Listening for messages", nil)

for msg := range messages {
msg.Ack()

response, ok := h.processMessage(w, r, msg)
if ok {
responsesChan <- response
}

select {
case <-r.Context().Done():
return
default:
}
}
}()

render.Respond(w, r, responsesChan)
}

func (h sseHandler) processMessage(w http.ResponseWriter, r *http.Request, msg *message.Message) (interface{}, bool) {
ok := h.streamAdapter.Validate(r, msg)
if !ok {
return nil, false
}

h.logger.Trace("Received valid message", watermill.LogFields{"uuid": msg.UUID})

return h.streamAdapter.GetResponse(w, r)
}
Loading

0 comments on commit f9529a1

Please sign in to comment.