Skip to content

Commit

Permalink
Fix inconsistent processing of server variables in gorillamux router (#…
Browse files Browse the repository at this point in the history
…705)

Co-authored-by: Steve Lessard <[email protected]>
  • Loading branch information
slessard and sl255051 authored Dec 16, 2022
1 parent 6cbc1b0 commit 6a3b779
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 27 deletions.
60 changes: 34 additions & 26 deletions routers/gorillamux/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) {
muxRouter := mux.NewRouter().UseEncodedPath()
r := &Router{}
for _, path := range orderedPaths(doc.Paths) {
servers := servers

pathItem := doc.Paths[path]
if len(pathItem.Servers) > 0 {
if servers, err = makeServers(pathItem.Servers); err != nil {
Expand Down Expand Up @@ -140,19 +138,13 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" {
varsUpdater = func(vars map[string]string) { vars[sVar] = lhs }
}
servers = append(servers, srv{
base: server.Variables[sVar].Default,
server: server,
varsUpdater: varsUpdater,
})
continue
}
svr, err := newSrv(serverURL, server, varsUpdater)
if err != nil {
return nil, err
}

var schemes []string
if strings.Contains(serverURL, "://") {
scheme0 := strings.Split(serverURL, "://")[0]
schemes = permutePart(scheme0, server)
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
servers = append(servers, svr)
continue
}

// If a variable represents the port "http://domain.tld:{port}/bla"
Expand All @@ -172,21 +164,11 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
}
}

u, err := url.Parse(bEncode(serverURL))
svr, err := newSrv(serverURL, server, varsUpdater)
if err != nil {
return nil, err
}
path := bDecode(u.EscapedPath())
if len(path) > 0 && path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
servers = append(servers, srv{
host: bDecode(u.Host), //u.Hostname()?
base: path,
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
server: server,
varsUpdater: varsUpdater,
})
servers = append(servers, svr)
}
if len(servers) == 0 {
servers = append(servers, srv{})
Expand All @@ -195,6 +177,32 @@ func makeServers(in openapi3.Servers) ([]srv, error) {
return servers, nil
}

func newSrv(serverURL string, server *openapi3.Server, varsUpdater varsf) (srv, error) {
var schemes []string
if strings.Contains(serverURL, "://") {
scheme0 := strings.Split(serverURL, "://")[0]
schemes = permutePart(scheme0, server)
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
}

u, err := url.Parse(bEncode(serverURL))
if err != nil {
return srv{}, err
}
path := bDecode(u.EscapedPath())
if len(path) > 0 && path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
svr := srv{
host: bDecode(u.Host), //u.Hostname()?
base: path,
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
server: server,
varsUpdater: varsUpdater,
}
return svr, nil
}

func orderedPaths(paths map[string]*openapi3.PathItem) []string {
// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject
// When matching URLs, concrete (non-templated) paths would be matched
Expand Down
163 changes: 162 additions & 1 deletion routers/gorillamux/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/getkin/kin-openapi/openapi3"
Expand Down Expand Up @@ -249,7 +250,16 @@ func TestServerPath(t *testing.T) {
"http://example.com:{port}/path",
map[string]string{
"port": "8088",
})},
}),
newServerWithVariables(
"{server}",
map[string]string{
"server": "/",
}),
newServerWithVariables(
"/",
nil,
)},
})
require.NoError(t, err)
}
Expand Down Expand Up @@ -325,6 +335,157 @@ func TestRelativeURL(t *testing.T) {
require.Equal(t, "/hello", route.Path)
}

func Test_makeServers(t *testing.T) {
type testStruct struct {
name string
servers openapi3.Servers
want []srv
wantErr bool
initFn func(tt *testStruct)
}
tests := []testStruct{
{
name: "server is root path",
servers: openapi3.Servers{
newServerWithVariables("/", nil),
},
want: []srv{{
schemes: nil,
host: "",
base: "",
server: nil,
varsUpdater: nil,
}},
wantErr: false,
initFn: func(tt *testStruct) {
for i, server := range tt.servers {
tt.want[i].server = server
}
},
},
{
name: "server with single variable that evaluates to root path",
servers: openapi3.Servers{
newServerWithVariables("{server}", map[string]string{"server": "/"}),
},
want: []srv{{
schemes: nil,
host: "",
base: "",
server: nil,
varsUpdater: nil,
}},
wantErr: false,
initFn: func(tt *testStruct) {
for i, server := range tt.servers {
tt.want[i].server = server
}
},
},
{
name: "server is http://localhost:28002",
servers: openapi3.Servers{
newServerWithVariables("http://localhost:28002", nil),
},
want: []srv{{
schemes: []string{"http"},
host: "localhost:28002",
base: "",
server: nil,
varsUpdater: nil,
}},
wantErr: false,
initFn: func(tt *testStruct) {
for i, server := range tt.servers {
tt.want[i].server = server
}
},
},
{
name: "server with single variable that evaluates to http://localhost:28002",
servers: openapi3.Servers{
newServerWithVariables("{server}", map[string]string{"server": "http://localhost:28002"}),
},
want: []srv{{
schemes: []string{"http"},
host: "localhost:28002",
base: "",
server: nil,
varsUpdater: nil,
}},
wantErr: false,
initFn: func(tt *testStruct) {
for i, server := range tt.servers {
tt.want[i].server = server
}
},
},
{
name: "server with multiple variables that evaluates to http://localhost:28002",
servers: openapi3.Servers{
newServerWithVariables("{scheme}://{host}:{port}", map[string]string{"scheme": "http", "host": "localhost", "port": "28002"}),
},
want: []srv{{
schemes: []string{"http"},
host: "{host}:28002",
base: "",
server: nil,
varsUpdater: func(vars map[string]string) { vars["port"] = "28002" },
}},
wantErr: false,
initFn: func(tt *testStruct) {
for i, server := range tt.servers {
tt.want[i].server = server
}
},
},
{
name: "server with unparsable URL fails",
servers: openapi3.Servers{
newServerWithVariables("exam^ple.com:443", nil),
},
want: nil,
wantErr: true,
initFn: nil,
},
{
name: "server with single variable that evaluates to unparsable URL fails",
servers: openapi3.Servers{
newServerWithVariables("{server}", map[string]string{"server": "exam^ple.com:443"}),
},
want: nil,
wantErr: true,
initFn: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.initFn != nil {
tt.initFn(&tt)
}
got, err := makeServers(tt.servers)
if (err != nil) != tt.wantErr {
t.Errorf("makeServers() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Equal(t, len(tt.want), len(got), "expected and actual servers lengths are not equal")
for i := 0; i < len(tt.want); i++ {
// Unfortunately using assert.Equals or reflect.DeepEquals isn't
// an option because function pointers cannot be compared
assert.Equal(t, tt.want[i].schemes, got[i].schemes)
assert.Equal(t, tt.want[i].host, got[i].host)
assert.Equal(t, tt.want[i].host, got[i].host)
assert.Equal(t, tt.want[i].server, got[i].server)
if tt.want[i].varsUpdater == nil {
assert.Nil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function")
} else {
assert.NotNil(t, got[i].varsUpdater, "expected and actual varsUpdater should point to same function")
}
}
})
}
}

func newServerWithVariables(url string, variables map[string]string) *openapi3.Server {
var serverVariables = map[string]*openapi3.ServerVariable{}

Expand Down

0 comments on commit 6a3b779

Please sign in to comment.