Skip to content

Commit

Permalink
feat: Add CORS capability to HTTP API (sourcenetwork#467)
Browse files Browse the repository at this point in the history
RELEVANT ISSUE(S)
Resolves sourcenetwork#408

DESCRIPTION
This PR adds CORS capability to the HTTP API. In situations where the database will be accessed from a different domain than that of the Defra host, the HTTP API needs to communicate to the browser the allowed origins and handle preflight requests.
  • Loading branch information
fredcarle authored Jun 8, 2022
1 parent 7515494 commit 930a591
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 16 deletions.
10 changes: 8 additions & 2 deletions api/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ import (
type handler struct {
db client.DB
*chi.Mux

// user configurable options
options serverOptions
}

type ctxDB struct{}

// newHandler returns a handler with the router instantiated.
func newHandler(db client.DB) *handler {
return setRoutes(&handler{db: db})
func newHandler(db client.DB, opts serverOptions) *handler {
return setRoutes(&handler{
db: db,
options: opts,
})
}

func (h *handler) handle(f http.HandlerFunc) http.HandlerFunc {
Expand Down
56 changes: 55 additions & 1 deletion api/http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

func TestNewHandlerWithLogger(t *testing.T) {
h := newHandler(nil)
h := newHandler(nil, serverOptions{})

dir := t.TempDir()

Expand Down Expand Up @@ -188,3 +188,57 @@ func TestDbFromContext(t *testing.T) {
_, err = dbFromContext(reqCtx)
assert.NoError(t, err)
}

func TestCORSRequest(t *testing.T) {
cases := []struct {
name string
method string
reqHeaders map[string]string
resHeaders map[string]string
}{
{
"DisallowedOrigin",
"OPTIONS",
map[string]string{
"Origin": "https://notsource.network",
},
map[string]string{
"Vary": "Origin",
},
},
{
"AllowedOrigin",
"OPTIONS",
map[string]string{
"Origin": "https://source.network",
},
map[string]string{
"Access-Control-Allow-Origin": "https://source.network",
"Vary": "Origin",
},
},
}

s := NewServer(nil, WithAllowedOrigins("https://source.network"))

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
req, err := http.NewRequest(c.method, PingPath, nil)
if err != nil {
t.Fatal(err)
}

for header, value := range c.reqHeaders {
req.Header.Add(header, value)
}

rec := httptest.NewRecorder()

s.Handler.ServeHTTP(rec, req)

for header, value := range c.resHeaders {
assert.Equal(t, value, rec.Result().Header.Get(header))
}
})
}
}
2 changes: 1 addition & 1 deletion api/http/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func TestLoggerKeyValueOutput(t *testing.T) {

rec2 := httptest.NewRecorder()

h := newHandler(nil)
h := newHandler(nil, serverOptions{})
log.ApplyConfig(logging.Config{
EncoderFormat: logging.NewEncoderFormatOption(logging.JSON),
OutputPaths: []string{logFile},
Expand Down
11 changes: 11 additions & 0 deletions api/http/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"strings"

"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
"github.com/pkg/errors"
)

Expand All @@ -35,6 +36,16 @@ var schemeError = errors.New("base must start with the http or https scheme")
func setRoutes(h *handler) *handler {
h.Mux = chi.NewRouter()

// setup CORS
if len(h.options.allowedOrigins) != 0 {
h.Use(cors.Handler(cors.Options{
AllowedOrigins: h.options.allowedOrigins,
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
AllowedHeaders: []string{"Content-Type"},
MaxAge: 300,
}))
}

// setup logger middleware
h.Use(loggerMiddleware)

Expand Down
42 changes: 35 additions & 7 deletions api/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,48 @@ import (

// The Server struct holds the Handler for the HTTP API
type Server struct {
options serverOptions
http.Server
}

type serverOptions struct {
allowedOrigins []string
}

// NewServer instantiated a new server with the given http.Handler.
func NewServer(db client.DB) *Server {
return &Server{
http.Server{
Handler: newHandler(db),
},
func NewServer(db client.DB, options ...func(*Server)) *Server {
svr := &Server{}

for _, opt := range append(options, DefaultOpts()) {
opt(svr)
}

svr.Server.Handler = newHandler(db, svr.options)

return svr
}

func DefaultOpts() func(*Server) {
return func(s *Server) {
if s.Addr == "" {
s.Addr = "localhost:9181"
}
}
}

func WithAllowedOrigins(origins ...string) func(*Server) {
return func(s *Server) {
s.options.allowedOrigins = append(s.options.allowedOrigins, origins...)
}
}

func WithAddress(addr string) func(*Server) {
return func(s *Server) {
s.Addr = addr
}
}

// Listen calls ListenAndServe with our router.
func (s *Server) Listen(addr string) error {
s.Addr = addr
func (s *Server) Listen() error {
return s.ListenAndServe()
}
23 changes: 20 additions & 3 deletions api/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ import (
)

func TestNewServerAndListen(t *testing.T) {
s := NewServer(nil)
s := NewServer(nil, WithAddress(":303000"))
if ok := assert.NotNil(t, s); ok {
assert.Error(t, s.Listen(":303000"))
assert.Error(t, s.Listen())
}

serverRunning := make(chan struct{})
serverDone := make(chan struct{})
s = NewServer(nil, WithAddress(":3131"))
go func() {
close(serverRunning)
err := s.Listen(":3131")
err := s.Listen()
assert.ErrorIs(t, http.ErrServerClosed, err)
defer close(serverDone)
}()
Expand All @@ -39,3 +40,19 @@ func TestNewServerAndListen(t *testing.T) {

<-serverDone
}

func TestNewServerWithoutOptions(t *testing.T) {
s := NewServer(nil)
assert.Equal(t, "localhost:9181", s.Addr)
assert.Equal(t, []string(nil), s.options.allowedOrigins)
}

func TestNewServerWithAddress(t *testing.T) {
s := NewServer(nil, WithAddress("localhost:9999"))
assert.Equal(t, "localhost:9999", s.Addr)
}

func TestNewServerWithAllowedOrigins(t *testing.T) {
s := NewServer(nil, WithAllowedOrigins("https://source.network", "https://app.source.network"))
assert.Equal(t, []string{"https://source.network", "https://app.source.network"}, s.options.allowedOrigins)
}
4 changes: 2 additions & 2 deletions cli/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ var startCmd = &cobra.Command{
httpapi.RootPath,
),
)
s := http.NewServer(db)
if err := s.Listen(config.Database.Address); err != nil {
s := http.NewServer(db, http.WithAddress(config.Database.Address))
if err := s.Listen(); err != nil {
log.ErrorE(ctx, "Failed to start HTTP API listener", err)
if n != nil {
n.Close() //nolint
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ require (
github.com/fatih/color v1.13.0
github.com/go-chi/chi/v5 v5.0.7
github.com/iancoleman/strcase v0.2.0
github.com/go-chi/cors v1.2.1
github.com/pkg/errors v0.9.1
golang.org/x/text v0.3.7
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aev
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-delve/delve v1.5.0/go.mod h1:c6b3a1Gry6x8a4LGCe/CWzrocrfaHvkUxCj3k4bvSUQ=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
Expand Down

0 comments on commit 930a591

Please sign in to comment.