diff --git a/go.mod b/go.mod index 49e035a..29ec6e4 100644 --- a/go.mod +++ b/go.mod @@ -16,5 +16,6 @@ require ( github.com/go-logr/logr v1.2.4 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/roadrunner-server/tcplisten v1.3.0 // indirect + github.com/rs/cors v1.9.0 // indirect go.opentelemetry.io/otel/metric v1.16.0 // indirect ) diff --git a/go.sum b/go.sum index 27224a8..5a72575 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/roadrunner-server/sdk/v4 v4.2.6 h1:BSQ+HHklszJKGCo91jqRwvgjhSkuz097cb github.com/roadrunner-server/sdk/v4 v4.2.6/go.mod h1:WBLEsz9EMY6CkwpdeageMEPLevD/PaUf4rOOsBsaKlo= github.com/roadrunner-server/tcplisten v1.3.0 h1:VDd6IbP8oIjm5vKvMVozeZgeHgOcoP0XYLOyOqcZHCY= github.com/roadrunner-server/tcplisten v1.3.0/go.mod h1:VR6Ob5am0oEuLMOeLiVvQxG9ShykAEgrlvZddX8EfoU= +github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= +github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8= diff --git a/plugin.go b/plugin.go index 0c73492..9e3c0bd 100644 --- a/plugin.go +++ b/plugin.go @@ -3,10 +3,11 @@ package headers import ( "fmt" "net/http" - "strconv" + "strings" "github.com/roadrunner-server/errors" "github.com/roadrunner-server/sdk/v4/utils" + "github.com/rs/cors" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" jprop "go.opentelemetry.io/contrib/propagators/jaeger" "go.opentelemetry.io/otel/propagation" @@ -60,6 +61,40 @@ func (p *Plugin) Init(cfg Configurer) error { // Middleware is HTTP plugin middleware to serve headers func (p *Plugin) Middleware(next http.Handler) http.Handler { + // Configure CORS handler + if p.cfg.CORS != nil { + corsOptions := cors.Options{ + // Keep BC with previous implementation + OptionsSuccessStatus: http.StatusOK, + } + + if p.cfg.CORS.AllowedOrigin != "" { + corsOptions.AllowedOrigins = strings.Split(p.cfg.CORS.AllowedOrigin, ",") + } + + if p.cfg.CORS.AllowedMethods != "" { + corsOptions.AllowedMethods = strings.Split(p.cfg.CORS.AllowedMethods, ",") + } + + if p.cfg.CORS.AllowedHeaders != "" { + corsOptions.AllowedHeaders = strings.Split(p.cfg.CORS.AllowedHeaders, ",") + } + + if p.cfg.CORS.ExposedHeaders != "" { + corsOptions.ExposedHeaders = strings.Split(p.cfg.CORS.ExposedHeaders, ",") + } + + if p.cfg.CORS.MaxAge > 0 { + corsOptions.MaxAge = p.cfg.CORS.MaxAge + } + + if p.cfg.CORS.AllowCredentials != nil { + corsOptions.AllowCredentials = *p.cfg.CORS.AllowCredentials + } + + next = cors.New(corsOptions).Handler(next) + } + // Define the http.HandlerFunc return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if val, ok := r.Context().Value(utils.OtelTracerNameKey).(string); ok { @@ -86,15 +121,6 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { } } - if p.cfg.CORS != nil { - if r.Method == http.MethodOptions { - p.preflightRequest(w) - - return - } - p.corsHeaders(w) - } - next.ServeHTTP(w, r) }) } @@ -102,57 +128,3 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { func (p *Plugin) Name() string { return PluginName } - -// configure OPTIONS response -func (p *Plugin) preflightRequest(w http.ResponseWriter) { - headers := w.Header() - - headers.Add("Vary", "Origin") - headers.Add("Vary", "Access-Control-Request-Method") - headers.Add("Vary", "Access-Control-Request-Headers") - - if p.cfg.CORS.AllowedOrigin != "" { - headers.Set("Access-Control-Allow-Origin", p.cfg.CORS.AllowedOrigin) - } - - if p.cfg.CORS.AllowedHeaders != "" { - headers.Set("Access-Control-Allow-Headers", p.cfg.CORS.AllowedHeaders) - } - - if p.cfg.CORS.AllowedMethods != "" { - headers.Set("Access-Control-Allow-Methods", p.cfg.CORS.AllowedMethods) - } - - if p.cfg.CORS.AllowCredentials != nil { - headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*p.cfg.CORS.AllowCredentials)) - } - - if p.cfg.CORS.MaxAge > 0 { - headers.Set("Access-Control-Max-Age", strconv.Itoa(p.cfg.CORS.MaxAge)) - } - - w.WriteHeader(http.StatusOK) -} - -// configure CORS headers -func (p *Plugin) corsHeaders(w http.ResponseWriter) { - headers := w.Header() - - headers.Add("Vary", "Origin") - - if p.cfg.CORS.AllowedOrigin != "" { - headers.Set("Access-Control-Allow-Origin", p.cfg.CORS.AllowedOrigin) - } - - if p.cfg.CORS.AllowedHeaders != "" { - headers.Set("Access-Control-Allow-Headers", p.cfg.CORS.AllowedHeaders) - } - - if p.cfg.CORS.ExposedHeaders != "" { - headers.Set("Access-Control-Expose-Headers", p.cfg.CORS.ExposedHeaders) - } - - if p.cfg.CORS.AllowCredentials != nil { - headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*p.cfg.CORS.AllowCredentials)) - } -}