Skip to content

Commit

Permalink
Enhance CORS support
Browse files Browse the repository at this point in the history
  • Loading branch information
Rokas Mikalkėnas committed Jun 16, 2023
1 parent 4291d45 commit 65bc7c7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 64 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ require (
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/roadrunner-server/tcplisten v1.3.0 // indirect
github.com/rs/cors v1.9.0 // indirect
go.opentelemetry.io/otel/metric v1.16.0 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ github.com/roadrunner-server/sdk/v4 v4.2.6 h1:BSQ+HHklszJKGCo91jqRwvgjhSkuz097cb
github.com/roadrunner-server/sdk/v4 v4.2.6/go.mod h1:WBLEsz9EMY6CkwpdeageMEPLevD/PaUf4rOOsBsaKlo=
github.com/roadrunner-server/tcplisten v1.3.0 h1:VDd6IbP8oIjm5vKvMVozeZgeHgOcoP0XYLOyOqcZHCY=
github.com/roadrunner-server/tcplisten v1.3.0/go.mod h1:VR6Ob5am0oEuLMOeLiVvQxG9ShykAEgrlvZddX8EfoU=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8=
Expand Down
100 changes: 36 additions & 64 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package headers
import (
"fmt"
"net/http"
"strconv"
"strings"

"github.com/roadrunner-server/errors"
"github.com/roadrunner-server/sdk/v4/utils"
"github.com/rs/cors"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
jprop "go.opentelemetry.io/contrib/propagators/jaeger"
"go.opentelemetry.io/otel/propagation"
Expand Down Expand Up @@ -60,6 +61,40 @@ func (p *Plugin) Init(cfg Configurer) error {

// Middleware is HTTP plugin middleware to serve headers
func (p *Plugin) Middleware(next http.Handler) http.Handler {
// Configure CORS handler
if p.cfg.CORS != nil {
corsOptions := cors.Options{
// Keep BC with previous implementation
OptionsSuccessStatus: http.StatusOK,
}

if p.cfg.CORS.AllowedOrigin != "" {
corsOptions.AllowedOrigins = strings.Split(p.cfg.CORS.AllowedOrigin, ",")
}

if p.cfg.CORS.AllowedMethods != "" {
corsOptions.AllowedMethods = strings.Split(p.cfg.CORS.AllowedMethods, ",")
}

if p.cfg.CORS.AllowedHeaders != "" {
corsOptions.AllowedHeaders = strings.Split(p.cfg.CORS.AllowedHeaders, ",")
}

if p.cfg.CORS.ExposedHeaders != "" {
corsOptions.ExposedHeaders = strings.Split(p.cfg.CORS.ExposedHeaders, ",")
}

if p.cfg.CORS.MaxAge > 0 {
corsOptions.MaxAge = p.cfg.CORS.MaxAge
}

if p.cfg.CORS.AllowCredentials != nil {
corsOptions.AllowCredentials = *p.cfg.CORS.AllowCredentials
}

next = cors.New(corsOptions).Handler(next)
}

// Define the http.HandlerFunc
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if val, ok := r.Context().Value(utils.OtelTracerNameKey).(string); ok {
Expand All @@ -86,73 +121,10 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
}
}

if p.cfg.CORS != nil {
if r.Method == http.MethodOptions {
p.preflightRequest(w)

return
}
p.corsHeaders(w)
}

next.ServeHTTP(w, r)
})
}

func (p *Plugin) Name() string {
return PluginName
}

// configure OPTIONS response
func (p *Plugin) preflightRequest(w http.ResponseWriter) {
headers := w.Header()

headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")

if p.cfg.CORS.AllowedOrigin != "" {
headers.Set("Access-Control-Allow-Origin", p.cfg.CORS.AllowedOrigin)
}

if p.cfg.CORS.AllowedHeaders != "" {
headers.Set("Access-Control-Allow-Headers", p.cfg.CORS.AllowedHeaders)
}

if p.cfg.CORS.AllowedMethods != "" {
headers.Set("Access-Control-Allow-Methods", p.cfg.CORS.AllowedMethods)
}

if p.cfg.CORS.AllowCredentials != nil {
headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*p.cfg.CORS.AllowCredentials))
}

if p.cfg.CORS.MaxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(p.cfg.CORS.MaxAge))
}

w.WriteHeader(http.StatusOK)
}

// configure CORS headers
func (p *Plugin) corsHeaders(w http.ResponseWriter) {
headers := w.Header()

headers.Add("Vary", "Origin")

if p.cfg.CORS.AllowedOrigin != "" {
headers.Set("Access-Control-Allow-Origin", p.cfg.CORS.AllowedOrigin)
}

if p.cfg.CORS.AllowedHeaders != "" {
headers.Set("Access-Control-Allow-Headers", p.cfg.CORS.AllowedHeaders)
}

if p.cfg.CORS.ExposedHeaders != "" {
headers.Set("Access-Control-Expose-Headers", p.cfg.CORS.ExposedHeaders)
}

if p.cfg.CORS.AllowCredentials != nil {
headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*p.cfg.CORS.AllowCredentials))
}
}

0 comments on commit 65bc7c7

Please sign in to comment.