Skip to content

Commit

Permalink
Remove most allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
rs committed Sep 5, 2023
1 parent d3f0a2b commit 080e86e
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 128 deletions.
93 changes: 58 additions & 35 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ import (
"strings"
)

var headerVaryOrigin = []string{"Origin"}
var headerOriginAll = []string{"*"}
var headerTrue = []string{"true"}

// Options is a configuration container to setup the CORS middleware.
type Options struct {
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
Expand Down Expand Up @@ -108,9 +112,10 @@ type Cors struct {
allowedHeaders []string
// Normalized list of allowed methods
allowedMethods []string
// Normalized list of exposed headers
// Pre-computed normalized list of exposed headers
exposedHeaders []string
maxAge int
// Pre-computed maxAge header value
maxAge []string
// Set to true when allowed origins contains a "*"
allowedOriginsAll bool
// Set to true when allowed headers contains a "*"
Expand All @@ -120,15 +125,14 @@ type Cors struct {
allowCredentials bool
allowPrivateNetwork bool
optionPassthrough bool
preflightVary []string
}

// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
c := &Cors{
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowCredentials: options.AllowCredentials,
allowPrivateNetwork: options.AllowPrivateNetwork,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
Log: options.Logger,
}
Expand Down Expand Up @@ -211,6 +215,23 @@ func New(options Options) *Cors {
c.optionsSuccessStatus = options.OptionsSuccessStatus
}

// Pre-compute exposed headers header value
c.exposedHeaders = []string{strings.Join(convert(options.ExposedHeaders, http.CanonicalHeaderKey), ", ")}

// Pre-compute prefight Vary header to save allocations
if c.allowPrivateNetwork {
c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"}
} else {
c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"}
}

// Precompute max-age
if options.MaxAge > 0 {
c.maxAge = []string{strconv.Itoa(options.MaxAge)}
} else if options.MaxAge < 0 {
c.maxAge = []string{"0"}
}

return c
}

Expand Down Expand Up @@ -307,13 +328,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
if c.allowPrivateNetwork {
headers.Add("Vary", "Access-Control-Request-Private-Network")
if vary, found := headers["Vary"]; found {
headers["Vary"] = append(vary, c.preflightVary[0])
} else {
headers["Vary"] = c.preflightVary
}

allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin)
if len(additionalVaryHeaders) > 0 {
headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", "))
Expand All @@ -333,42 +352,41 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
c.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
return
}
// Amazon API Gateway is sometimes feeding multiple values for
// Access-Control-Request-Headers in a way where r.Header.Values() picks
// them all up, but r.Header.Get() does not.
// I suspect it is something like this: https://stackoverflow.com/a/4371395
reqHeaderList := strings.Join(r.Header.Values("Access-Control-Request-Headers"), ",")
reqHeaders := parseHeaderList(reqHeaderList)
reqHeadersRaw := r.Header["Access-Control-Request-Headers"]
reqHeaders, reqHeadersEdited := convertDidCopy(splitHeaderValues(reqHeadersRaw), http.CanonicalHeaderKey)
if !c.areHeadersAllowed(reqHeaders) {
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
return
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
headers["Access-Control-Allow-Origin"] = headerOriginAll
} else {
headers.Set("Access-Control-Allow-Origin", origin)
headers["Access-Control-Allow-Origin"] = r.Header["Origin"]
}
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers.Set("Access-Control-Allow-Methods", reqMethod)
headers["Access-Control-Allow-Methods"] = r.Header["Access-Control-Request-Method"]
if len(reqHeaders) > 0 {

// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) {
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
} else {
headers["Access-Control-Allow-Headers"] = reqHeadersRaw
}
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
headers["Access-Control-Allow-Credentials"] = headerTrue
}
if c.allowPrivateNetwork && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
headers.Set("Access-Control-Allow-Private-Network", "true")
headers["Access-Control-Allow-Private-Network"] = headerTrue
}
if len(c.maxAge) > 0 {
headers["Access-Control-Max-Age"] = c.maxAge
}
if c.maxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
} else if c.maxAge < 0 {
headers.Set("Access-Control-Max-Age", "0")
if c.Log != nil {
c.logf(" Preflight response headers: %v", headers)
}
c.logf(" Preflight response headers: %v", headers)
}

// handleActualRequest handles simple cross-origin requests, actual request or redirects
Expand All @@ -379,7 +397,11 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin)

// Always set Vary, see https://github.com/rs/cors/issues/10
headers.Add("Vary", "Origin")
if vary, found := headers["Vary"]; found {
headers["Vary"] = append(vary, headerVaryOrigin[0])
} else {
headers["Vary"] = headerVaryOrigin
}
if len(additionalVaryHeaders) > 0 {
headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", "))
}
Expand All @@ -401,17 +423,19 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
return
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
headers["Access-Control-Allow-Origin"] = headerOriginAll
} else {
headers.Set("Access-Control-Allow-Origin", origin)
headers["Access-Control-Allow-Origin"] = r.Header["Origin"]
}
if len(c.exposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
headers["Access-Control-Expose-Headers"] = c.exposedHeaders
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
headers["Access-Control-Allow-Credentials"] = headerTrue
}
if c.Log != nil {
c.logf(" Actual response added headers: %v", headers)
}
c.logf(" Actual response added headers: %v", headers)
}

// convenience method. checks if a logger is set.
Expand Down Expand Up @@ -477,7 +501,6 @@ func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
return true
}
for _, header := range requestedHeaders {
header = http.CanonicalHeaderKey(header)
found := false
for _, h := range c.allowedHeaders {
if h == header {
Expand Down
19 changes: 11 additions & 8 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
"testing"
)

var testResponse = []byte("bar")
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
_, _ = w.Write(testResponse)
})

var allHeaders = []string{
Expand All @@ -26,6 +27,7 @@ var allHeaders = []string{
}

func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) {
t.Helper()
for _, name := range allHeaders {
got := strings.Join(resHeaders[name], ", ")
want := expHeaders[name]
Expand All @@ -36,6 +38,7 @@ func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]s
}

func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) {
t.Helper()
if responseCode != res.Code {
t.Errorf("assertResponse: expected response code to be %d but got %d. ", responseCode, res.Code)
}
Expand Down Expand Up @@ -721,37 +724,37 @@ func TestCorsAreHeadersAllowed(t *testing.T) {
{
name: "nil allowedHeaders",
allowedHeaders: nil,
requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"),
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: false,
},
{
name: "star allowedHeaders",
allowedHeaders: []string{"*"},
requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"),
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: true,
},
{
name: "empty reqHeader",
allowedHeaders: nil,
requestedHeaders: parseHeaderList(""),
requestedHeaders: []string{},
want: true,
},
{
name: "match allowedHeaders",
allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"},
requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"),
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: true,
},
{
name: "not matched allowedHeaders",
allowedHeaders: []string{"X-PINGOTHER"},
requestedHeaders: parseHeaderList("X-API-KEY, Content-Type"),
requestedHeaders: []string{"X-API-KEY, Content-Type"},
want: false,
},
{
name: "allowedHeaders should be a superset of requestedHeaders",
allowedHeaders: []string{"X-PINGOTHER"},
requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"),
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: false,
},
}
Expand All @@ -761,7 +764,7 @@ func TestCorsAreHeadersAllowed(t *testing.T) {

t.Run(tt.name, func(t *testing.T) {
c := New(Options{AllowedHeaders: tt.allowedHeaders})
have := c.areHeadersAllowed(tt.requestedHeaders)
have := c.areHeadersAllowed(convert(splitHeaderValues(tt.requestedHeaders), http.CanonicalHeaderKey))
if have != tt.want {
t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want)
}
Expand Down
95 changes: 48 additions & 47 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package cors

import "strings"

const toLower = 'a' - 'A'
import (
"strings"
)

type converter func(string) string

Expand All @@ -15,57 +15,58 @@ func (w wildcard) match(s string) bool {
return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix)
}

// convert converts a list of string using the passed converter function
func convert(s []string, c converter) []string {
out := []string{}
for _, i := range s {
out = append(out, c(i))
}
return out
}

// parseHeaderList tokenize + normalize a string containing a list of headers
func parseHeaderList(headerList string) []string {
l := len(headerList)
h := make([]byte, 0, l)
upper := true
// Estimate the number headers in order to allocate the right splice size
t := 0
for i := 0; i < l; i++ {
if headerList[i] == ',' {
t++
}
}
headers := make([]string, 0, t)
for i := 0; i < l; i++ {
b := headerList[i]
switch {
case b >= 'a' && b <= 'z':
if upper {
h = append(h, b-toLower)
} else {
h = append(h, b)
// split compounded header values ["foo, bar", "baz"] -> ["foo", "bar", "baz"]
func splitHeaderValues(values []string) []string {
out := values
copied := false
for i, v := range values {
needsSplit := strings.IndexByte(v, ',') != -1
if !copied {
if needsSplit {
split := strings.Split(v, ",")
out = make([]string, i, len(values)+len(split)-1)
copy(out, values[:i])
for _, s := range split {
out = append(out, strings.TrimSpace(s))
}
copied = true
}
case b >= 'A' && b <= 'Z':
if !upper {
h = append(h, b+toLower)
} else {
if needsSplit {
split := strings.Split(v, ",")
for _, s := range split {
out = append(out, strings.TrimSpace(s))
}
} else {
h = append(h, b)
out = append(out, v)
}
case b == '-' || b == '_' || b == '.' || (b >= '0' && b <= '9'):
h = append(h, b)
}
}
return out
}

// convert converts a list of string using the passed converter function
func convert(s []string, c converter) []string {
out, _ := convertDidCopy(s, c)
return out
}

if b == ' ' || b == ',' || i == l-1 {
if len(h) > 0 {
// Flush the found header
headers = append(headers, string(h))
h = h[:0]
upper = true
// convertDidCopy is same as convert but returns true if it copied the slice
func convertDidCopy(s []string, c converter) ([]string, bool) {
out := s
copied := false
for i, v := range s {
if !copied {
v2 := c(v)
if v2 != v {
out = make([]string, len(s))
copy(out, s[:i])
out[i] = v2
copied = true
}
} else {
upper = b == '-' || b == '_'
out[i] = c(v)
}
}
return headers
return out, copied
}
Loading

1 comment on commit 080e86e

@jub0bs
Copy link
Contributor

@jub0bs jub0bs commented on 080e86e Sep 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😏

Please sign in to comment.