Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to merge muxer #482

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
Most recent version is listed first.


# v0.1.13
- ong/mux: add ability to merge Muxers: https://github.com/komuw/ong/pull/482

# v0.1.12
- ong/mux: add flexible pattern that allows a handler to serve almost all request URIs: https://github.com/komuw/ong/pull/481
-

# v0.1.11
- ong/cry: use constant parameters for argon key generation: https://github.com/komuw/ong/pull/477

Expand Down
116 changes: 109 additions & 7 deletions internal/mx/mx.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"fmt"
"net/http"
"net/url"
"path"
"slices"
"strings"

"github.com/komuw/ong/config"
"github.com/komuw/ong/middleware"
Expand Down Expand Up @@ -75,21 +78,24 @@ func New(opt config.Opts, notFoundHandler http.Handler, routes ...Route) (Muxer,
mid = middleware.All
}

if err := m.addPattern(
m.addPattern(
rt.method,
rt.pattern,
rt.originalHandler,
mid(rt.originalHandler, opt),
); err != nil {
return Muxer{}, err
}
)
}

// Try and detect conflicting routes.
if err := detectConflict(m); err != nil {
return Muxer{}, err
}

return m, nil
}

func (m Muxer) addPattern(method, pattern string, originalHandler, wrappingHandler http.Handler) error {
return m.router.handle(method, pattern, originalHandler, wrappingHandler)
func (m Muxer) addPattern(method, pattern string, originalHandler, wrappingHandler http.Handler) {
m.router.handle(method, pattern, originalHandler, wrappingHandler)
}

// ServeHTTP implements a http.Handler
Expand Down Expand Up @@ -124,12 +130,36 @@ func (m Muxer) Resolve(path string) Route {
// Users of ong should not use this method. Instead, pass all your routes when calling [New]
func (m Muxer) AddRoute(rt Route) error {
// AddRoute should only be used internally by ong.
return m.addPattern(
m.addPattern(
rt.method,
rt.pattern,
rt.originalHandler,
middleware.All(rt.originalHandler, m.opt),
)

// Try and detect conflicting routes.
if err := detectConflict(m); err != nil {
return err
}

return nil
}

// Merge combines mxs into m. The resulting muxer uses the opts & notFoundHandler of m.
func (m Muxer) Merge(mxs ...Muxer) (Muxer, error) {
if len(mxs) < 1 {
return m, nil
}

for _, v := range mxs {
m.router.routes = append(m.router.routes, v.router.routes...)
}

if err := detectConflict(m); err != nil {
return m, err
}

return m, nil
}

// Param gets the path/url parameter from the specified Context.
Expand All @@ -140,3 +170,75 @@ func Param(ctx context.Context, param string) string {
}
return vStr
}

// detectConflict returns an error with a diagnostic message when you try to add a route that would conflict with an already existing one.
//
// The error message looks like:
//
// You are trying to add
// pattern: /post/:id/
// method: GET
// handler: github.com/myAPp/server/main.loginHandler - /home/server/main.go:351
// However
// pattern: post/create
// method: GET
// handler: github.com/myAPp/server/main.logoutHandler - /home/server/main.go:345
// already exists and would conflict.
//
// /
func detectConflict(m Muxer) error {
for k := range m.router.routes {
candidate := m.router.routes[k]
pattern := candidate.pattern
incomingSegments := pathSegments(pattern)

for _, rt := range m.router.routes {
if pattern == rt.pattern && (slices.Equal(candidate.segments, rt.segments)) && (getfunc(candidate.originalHandler) == getfunc(rt.originalHandler)) {
continue
}

existingSegments := rt.segments
sameLen := len(incomingSegments) == len(existingSegments)
if !sameLen {
// no conflict
continue
}

errMsg := fmt.Errorf(`
You are trying to add
pattern: %s
method: %s
handler: %v
However
pattern: %s
method: %s
handler: %v
already exists and would conflict`,
pattern,
strings.ToUpper(candidate.method),
getfunc(candidate.originalHandler),
path.Join(rt.segments...),
strings.ToUpper(rt.method),
getfunc(rt.originalHandler),
)

if len(existingSegments) == 1 && existingSegments[0] == "*" && len(incomingSegments) > 0 {
return errMsg
}

if pattern == rt.pattern {
return errMsg
}

if strings.Contains(pattern, ":") && (incomingSegments[0] == existingSegments[0]) {
return errMsg
}

if strings.Contains(rt.pattern, ":") && (incomingSegments[0] == existingSegments[0]) {
return errMsg
}
}
}

return nil
}
146 changes: 145 additions & 1 deletion internal/mx/mx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func TestMux(t *testing.T) {

msg := "hello world"
uri1 := "/api/hi"
uri2 := "/api/:someId" // This conflicts with uri1
uri2 := "api/:someId" // This conflicts with uri1
method := MethodGet

rt1, err := NewRoute(
Expand Down Expand Up @@ -406,6 +406,150 @@ func TestMux(t *testing.T) {
attest.Equal(t, string(rb2), "someOtherMuxHandler")
}
})

t.Run("merge", func(t *testing.T) {
t.Parallel()

httpsPort := tst.GetPort()
domain := "localhost"

t.Run("success", func(t *testing.T) {
t.Parallel()
rt1, err := NewRoute("/abc", MethodGet, someMuxHandler("hello"))
attest.Ok(t, err)

mux1, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt1)
attest.Ok(t, err)

rt2, err := NewRoute("/ijk", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
attest.Ok(t, err)
rt3, err := NewRoute("/xyz", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
attest.Ok(t, err)

mux2, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt2, rt3)
attest.Ok(t, err)

m, err := mux1.Merge(mux2)
attest.Ok(t, err)

attest.Equal(t, m.opt, mux1.opt)
attest.Equal(t, fmt.Sprintf("%p", m.router.notFoundHandler), fmt.Sprintf("%p", mux1.router.notFoundHandler))
attest.Equal(t, len(m.router.routes), 3)
})

t.Run("conflict", func(t *testing.T) {
rt1, err := NewRoute("/abc", MethodGet, someMuxHandler("hello"))
attest.Ok(t, err)

mux1, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt1)
attest.Ok(t, err)

rt2, err := NewRoute("/ijk", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
attest.Ok(t, err)
rt3, err := NewRoute("/abc", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
attest.Ok(t, err)

mux2, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt2, rt3)
attest.Ok(t, err)

_, errM := mux1.Merge(mux2)
attest.Error(t, errM)
rStr := errM.Error()
attest.Subsequence(t, rStr, "would conflict")
attest.Subsequence(t, rStr, "ong/internal/mx/mx_test.go:28") // location where `someMuxHandler` is declared.
attest.Subsequence(t, rStr, "ong/internal/mx/mx_test.go:449") // location where the other handler is declared.
})
})
}

func TestConflicts(t *testing.T) {
t.Parallel()

l := log.New(context.Background(), &bytes.Buffer{}, 500)

t.Run("conflicts detected", func(t *testing.T) {
t.Parallel()

msg1 := "firstRoute"
msg2 := "secondRoute"

rt1, err := NewRoute("/post/create", http.MethodGet, firstRoute(msg1))
attest.Ok(t, err)

rt2, err := NewRoute("/post/:id", http.MethodGet, secondRoute(msg2))
attest.Ok(t, err)

_, errH := New(
config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l),
nil,
rt1,
rt2,
)
attest.Error(t, errH)
})

t.Run("different http methods same path conflicts detected", func(t *testing.T) {
t.Parallel()

msg1 := "firstRoute"
msg2 := "secondRoute"

rt1, err := NewRoute("/post", http.MethodGet, firstRoute(msg1))
attest.Ok(t, err)

rt2, err := NewRoute("post/", http.MethodDelete, secondRoute(msg2))
attest.Ok(t, err)

_, errH := New(
config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l),
nil,
rt1,
rt2,
)
attest.Error(t, errH)
})

t.Run("no conflict", func(t *testing.T) {
t.Parallel()

msg1 := "firstRoute-one"
msg2 := "secondRoute-two"

rt1, err := NewRoute("/w00tw00t.at.blackhats.romanian.anti-sec:)", http.MethodGet, firstRoute(msg1))
attest.Ok(t, err)

rt2, err := NewRoute("/index.php", http.MethodGet, secondRoute(msg2))
attest.Ok(t, err)

_, errH := New(
config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l),
nil,
rt1,
rt2,
)
attest.Ok(t, errH)
})

t.Run("http MethodAll conflicts with all other methods", func(t *testing.T) {
t.Parallel()

msg1 := "firstRoute"
msg2 := "secondRoute"

rt1, err := NewRoute("/post", http.MethodGet, firstRoute(msg1))
attest.Ok(t, err)

rt2, err := NewRoute("post/", MethodAll, secondRoute(msg2))
attest.Ok(t, err)

_, errH := New(
config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l),
nil,
rt1,
rt2,
)
attest.Error(t, errH)
})
}

func TestMuxFlexiblePattern(t *testing.T) {
Expand Down
Loading
Loading