diff --git a/CHANGELOG.md b/CHANGELOG.md index 140d4bfe..2ae16e21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,9 +3,12 @@ Most recent version is listed first. +# v0.1.13 +- ong/mux: add ability to merge Muxers: https://github.com/komuw/ong/pull/482 + # v0.1.12 - ong/mux: add flexible pattern that allows a handler to serve almost all request URIs: https://github.com/komuw/ong/pull/481 -- + # v0.1.11 - ong/cry: use constant parameters for argon key generation: https://github.com/komuw/ong/pull/477 diff --git a/internal/mx/mx.go b/internal/mx/mx.go index 162e0e16..8222c012 100644 --- a/internal/mx/mx.go +++ b/internal/mx/mx.go @@ -8,6 +8,9 @@ import ( "fmt" "net/http" "net/url" + "path" + "slices" + "strings" "github.com/komuw/ong/config" "github.com/komuw/ong/middleware" @@ -75,21 +78,24 @@ func New(opt config.Opts, notFoundHandler http.Handler, routes ...Route) (Muxer, mid = middleware.All } - if err := m.addPattern( + m.addPattern( rt.method, rt.pattern, rt.originalHandler, mid(rt.originalHandler, opt), - ); err != nil { - return Muxer{}, err - } + ) + } + + // Try and detect conflicting routes. + if err := detectConflict(m); err != nil { + return Muxer{}, err } return m, nil } -func (m Muxer) addPattern(method, pattern string, originalHandler, wrappingHandler http.Handler) error { - return m.router.handle(method, pattern, originalHandler, wrappingHandler) +func (m Muxer) addPattern(method, pattern string, originalHandler, wrappingHandler http.Handler) { + m.router.handle(method, pattern, originalHandler, wrappingHandler) } // ServeHTTP implements a http.Handler @@ -124,12 +130,36 @@ func (m Muxer) Resolve(path string) Route { // Users of ong should not use this method. Instead, pass all your routes when calling [New] func (m Muxer) AddRoute(rt Route) error { // AddRoute should only be used internally by ong. - return m.addPattern( + m.addPattern( rt.method, rt.pattern, rt.originalHandler, middleware.All(rt.originalHandler, m.opt), ) + + // Try and detect conflicting routes. + if err := detectConflict(m); err != nil { + return err + } + + return nil +} + +// Merge combines mxs into m. The resulting muxer uses the opts & notFoundHandler of m. +func (m Muxer) Merge(mxs ...Muxer) (Muxer, error) { + if len(mxs) < 1 { + return m, nil + } + + for _, v := range mxs { + m.router.routes = append(m.router.routes, v.router.routes...) + } + + if err := detectConflict(m); err != nil { + return m, err + } + + return m, nil } // Param gets the path/url parameter from the specified Context. @@ -140,3 +170,75 @@ func Param(ctx context.Context, param string) string { } return vStr } + +// detectConflict returns an error with a diagnostic message when you try to add a route that would conflict with an already existing one. +// +// The error message looks like: +// +// You are trying to add +// pattern: /post/:id/ +// method: GET +// handler: github.com/myAPp/server/main.loginHandler - /home/server/main.go:351 +// However +// pattern: post/create +// method: GET +// handler: github.com/myAPp/server/main.logoutHandler - /home/server/main.go:345 +// already exists and would conflict. +// +// / +func detectConflict(m Muxer) error { + for k := range m.router.routes { + candidate := m.router.routes[k] + pattern := candidate.pattern + incomingSegments := pathSegments(pattern) + + for _, rt := range m.router.routes { + if pattern == rt.pattern && (slices.Equal(candidate.segments, rt.segments)) && (getfunc(candidate.originalHandler) == getfunc(rt.originalHandler)) { + continue + } + + existingSegments := rt.segments + sameLen := len(incomingSegments) == len(existingSegments) + if !sameLen { + // no conflict + continue + } + + errMsg := fmt.Errorf(` +You are trying to add + pattern: %s + method: %s + handler: %v +However + pattern: %s + method: %s + handler: %v +already exists and would conflict`, + pattern, + strings.ToUpper(candidate.method), + getfunc(candidate.originalHandler), + path.Join(rt.segments...), + strings.ToUpper(rt.method), + getfunc(rt.originalHandler), + ) + + if len(existingSegments) == 1 && existingSegments[0] == "*" && len(incomingSegments) > 0 { + return errMsg + } + + if pattern == rt.pattern { + return errMsg + } + + if strings.Contains(pattern, ":") && (incomingSegments[0] == existingSegments[0]) { + return errMsg + } + + if strings.Contains(rt.pattern, ":") && (incomingSegments[0] == existingSegments[0]) { + return errMsg + } + } + } + + return nil +} diff --git a/internal/mx/mx_test.go b/internal/mx/mx_test.go index 0e3562c9..cd648e1b 100644 --- a/internal/mx/mx_test.go +++ b/internal/mx/mx_test.go @@ -204,7 +204,7 @@ func TestMux(t *testing.T) { msg := "hello world" uri1 := "/api/hi" - uri2 := "/api/:someId" // This conflicts with uri1 + uri2 := "api/:someId" // This conflicts with uri1 method := MethodGet rt1, err := NewRoute( @@ -406,6 +406,150 @@ func TestMux(t *testing.T) { attest.Equal(t, string(rb2), "someOtherMuxHandler") } }) + + t.Run("merge", func(t *testing.T) { + t.Parallel() + + httpsPort := tst.GetPort() + domain := "localhost" + + t.Run("success", func(t *testing.T) { + t.Parallel() + rt1, err := NewRoute("/abc", MethodGet, someMuxHandler("hello")) + attest.Ok(t, err) + + mux1, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt1) + attest.Ok(t, err) + + rt2, err := NewRoute("/ijk", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + attest.Ok(t, err) + rt3, err := NewRoute("/xyz", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + attest.Ok(t, err) + + mux2, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt2, rt3) + attest.Ok(t, err) + + m, err := mux1.Merge(mux2) + attest.Ok(t, err) + + attest.Equal(t, m.opt, mux1.opt) + attest.Equal(t, fmt.Sprintf("%p", m.router.notFoundHandler), fmt.Sprintf("%p", mux1.router.notFoundHandler)) + attest.Equal(t, len(m.router.routes), 3) + }) + + t.Run("conflict", func(t *testing.T) { + rt1, err := NewRoute("/abc", MethodGet, someMuxHandler("hello")) + attest.Ok(t, err) + + mux1, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt1) + attest.Ok(t, err) + + rt2, err := NewRoute("/ijk", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + attest.Ok(t, err) + rt3, err := NewRoute("/abc", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + attest.Ok(t, err) + + mux2, err := New(config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), nil, rt2, rt3) + attest.Ok(t, err) + + _, errM := mux1.Merge(mux2) + attest.Error(t, errM) + rStr := errM.Error() + attest.Subsequence(t, rStr, "would conflict") + attest.Subsequence(t, rStr, "ong/internal/mx/mx_test.go:28") // location where `someMuxHandler` is declared. + attest.Subsequence(t, rStr, "ong/internal/mx/mx_test.go:449") // location where the other handler is declared. + }) + }) +} + +func TestConflicts(t *testing.T) { + t.Parallel() + + l := log.New(context.Background(), &bytes.Buffer{}, 500) + + t.Run("conflicts detected", func(t *testing.T) { + t.Parallel() + + msg1 := "firstRoute" + msg2 := "secondRoute" + + rt1, err := NewRoute("/post/create", http.MethodGet, firstRoute(msg1)) + attest.Ok(t, err) + + rt2, err := NewRoute("/post/:id", http.MethodGet, secondRoute(msg2)) + attest.Ok(t, err) + + _, errH := New( + config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt1, + rt2, + ) + attest.Error(t, errH) + }) + + t.Run("different http methods same path conflicts detected", func(t *testing.T) { + t.Parallel() + + msg1 := "firstRoute" + msg2 := "secondRoute" + + rt1, err := NewRoute("/post", http.MethodGet, firstRoute(msg1)) + attest.Ok(t, err) + + rt2, err := NewRoute("post/", http.MethodDelete, secondRoute(msg2)) + attest.Ok(t, err) + + _, errH := New( + config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt1, + rt2, + ) + attest.Error(t, errH) + }) + + t.Run("no conflict", func(t *testing.T) { + t.Parallel() + + msg1 := "firstRoute-one" + msg2 := "secondRoute-two" + + rt1, err := NewRoute("/w00tw00t.at.blackhats.romanian.anti-sec:)", http.MethodGet, firstRoute(msg1)) + attest.Ok(t, err) + + rt2, err := NewRoute("/index.php", http.MethodGet, secondRoute(msg2)) + attest.Ok(t, err) + + _, errH := New( + config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt1, + rt2, + ) + attest.Ok(t, errH) + }) + + t.Run("http MethodAll conflicts with all other methods", func(t *testing.T) { + t.Parallel() + + msg1 := "firstRoute" + msg2 := "secondRoute" + + rt1, err := NewRoute("/post", http.MethodGet, firstRoute(msg1)) + attest.Ok(t, err) + + rt2, err := NewRoute("post/", MethodAll, secondRoute(msg2)) + attest.Ok(t, err) + + _, errH := New( + config.WithOpts("localhost", 443, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt1, + rt2, + ) + attest.Error(t, errH) + }) } func TestMuxFlexiblePattern(t *testing.T) { diff --git a/internal/mx/route.go b/internal/mx/route.go index eb0a6bc8..c504e8f2 100644 --- a/internal/mx/route.go +++ b/internal/mx/route.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "path" "reflect" "runtime" "strings" @@ -140,7 +139,7 @@ func pathSegments(p string) []string { // handle adds a handler with the specified method and pattern. // Pattern can contain path segments such as: /item/:id which is // accessible via the Param function. -func (r *router) handle(method, pattern string, originalHandler, wrappingHandler http.Handler) error { +func (r *router) handle(method, pattern string, originalHandler, wrappingHandler http.Handler) { if !strings.HasSuffix(pattern, "/") { // this will make the mux send requests for; // - localhost:80/check @@ -152,11 +151,6 @@ func (r *router) handle(method, pattern string, originalHandler, wrappingHandler pattern = "/" + pattern } - // Try and detect conflict before adding a new route. - if err := r.detectConflict(method, pattern, originalHandler); err != nil { - return err - } - rt := Route{ method: strings.ToUpper(method), pattern: pattern, @@ -165,8 +159,6 @@ func (r *router) handle(method, pattern string, originalHandler, wrappingHandler wrappingHandler: wrappingHandler, } r.routes = append(r.routes, rt) - - return nil } // serveHTTP routes the incoming http.Request based on method and path extracting path parameters as it goes. @@ -182,73 +174,6 @@ func (r *router) serveHTTP(w http.ResponseWriter, req *http.Request) { r.notFoundHandler.ServeHTTP(w, req) } -// detectConflict returns an error with a diagnostic message when you try to add a route that would conflict with an already existing one. -// -// The error message looks like: -// -// You are trying to add -// pattern: /post/:id/ -// method: GET -// handler: github.com/myAPp/server/main.loginHandler - /home/server/main.go:351 -// However -// pattern: post/create -// method: GET -// handler: github.com/myAPp/server/main.logoutHandler - /home/server/main.go:345 -// already exists and would conflict. -// -// / -func (r *router) detectConflict(method, pattern string, originalHandler http.Handler) error { - // Conflicting routes are a bad thing. - // They can be a source of bugs and confusion. - // see: https://www.alexedwards.net/blog/which-go-router-should-i-use - - incomingSegments := pathSegments(pattern) - for _, rt := range r.routes { - existingSegments := rt.segments - sameLen := len(incomingSegments) == len(existingSegments) - if !sameLen { - // no conflict - continue - } - - errMsg := fmt.Errorf(` -You are trying to add - pattern: %s - method: %s - handler: %v -However - pattern: %s - method: %s - handler: %v -already exists and would conflict`, - pattern, - strings.ToUpper(method), - getfunc(originalHandler), - path.Join(rt.segments...), - strings.ToUpper(rt.method), - getfunc(rt.originalHandler), - ) - - if len(existingSegments) == 1 && existingSegments[0] == "*" && len(incomingSegments) > 0 { - return errMsg - } - - if pattern == rt.pattern { - return errMsg - } - - if strings.Contains(pattern, ":") && (incomingSegments[0] == existingSegments[0]) { - return errMsg - } - - if strings.Contains(rt.pattern, ":") && (incomingSegments[0] == existingSegments[0]) { - return errMsg - } - } - - return nil -} - func getfunc(handler http.Handler) string { fn := runtime.FuncForPC(reflect.ValueOf(handler).Pointer()) file, line := fn.FileLine(fn.Entry()) diff --git a/internal/mx/route_test.go b/internal/mx/route_test.go index b861f069..358589e2 100644 --- a/internal/mx/route_test.go +++ b/internal/mx/route_test.go @@ -3,7 +3,6 @@ package mx import ( "context" "fmt" - "io" "net/http" "net/http/httptest" "testing" @@ -248,11 +247,10 @@ func TestRouter(t *testing.T) { match := false var ctx context.Context - err := r.handle(tt.RouteMethod, tt.RoutePattern, nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.handle(tt.RouteMethod, tt.RoutePattern, nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { match = true ctx = r.Context() })) - attest.Ok(t, err) req, err := http.NewRequest(tt.Method, tt.Path, nil) attest.Ok(t, err) @@ -285,10 +283,9 @@ func TestMultipleRoutesDifferentMethods(t *testing.T) { r := newRouter(nil) var match string - err := r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { match = r.Method })) - attest.Ok(t, err) req, err := http.NewRequest(http.MethodGet, "/path", nil) attest.Ok(t, err) @@ -318,87 +315,6 @@ func secondRoute(msg string) http.HandlerFunc { } } -func TestConflicts(t *testing.T) { - t.Parallel() - - t.Run("conflicts detected", func(t *testing.T) { - t.Parallel() - r := newRouter(nil) - - msg1 := "firstRoute" - msg2 := "secondRoute" - err := r.handle(http.MethodGet, "/post/create", firstRoute(msg1), firstRoute(msg1)) - attest.Ok(t, err) - - // This one returns with a conflict message. - errH := r.handle(http.MethodGet, "/post/:id", secondRoute(msg2), secondRoute(msg2)) - attest.Error(t, errH) - - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/post/create", nil) - r.serveHTTP(rec, req) - - res := rec.Result() - defer res.Body.Close() - - rb, err := io.ReadAll(res.Body) - attest.Ok(t, err) - - attest.Equal(t, res.StatusCode, http.StatusOK) - attest.Equal(t, string(rb), msg1) - }) - - t.Run("different http methods same path conflicts detected", func(t *testing.T) { - t.Parallel() - r := newRouter(nil) - - msg1 := "firstRoute" - msg2 := "secondRoute" - err := r.handle(http.MethodGet, "/post", firstRoute(msg1), firstRoute(msg1)) - attest.Ok(t, err) - - // This one returns with a conflict message. - errH := r.handle(http.MethodGet, "/post/", secondRoute(msg2), secondRoute(msg2)) - attest.Error(t, errH) - - // This one returns with a conflict message. - errB := r.handle(http.MethodDelete, "post/", secondRoute(msg2), secondRoute(msg2)) - attest.Error(t, errB) - - // This one returns with a conflict message. - errC := r.handle(http.MethodPut, "post", secondRoute(msg2), secondRoute(msg2)) - attest.Error(t, errC) - }) - - t.Run("no conflict", func(t *testing.T) { - t.Parallel() - r := newRouter(nil) - - msg1 := "firstRoute-one" - msg2 := "secondRoute-two" - err := r.handle(http.MethodGet, "/w00tw00t.at.blackhats.romanian.anti-sec:)", firstRoute(msg1), firstRoute(msg1)) - attest.Ok(t, err) - - // This one should not conflict. - errH := r.handle(http.MethodGet, "/index.php", secondRoute(msg2), secondRoute(msg2)) - attest.Ok(t, errH) - }) - - t.Run("http MethodAll conflicts with all other methods", func(t *testing.T) { - t.Parallel() - r := newRouter(nil) - - msg1 := "firstRoute" - msg2 := "secondRoute" - err := r.handle(http.MethodGet, "/post", firstRoute(msg1), firstRoute(msg1)) - attest.Ok(t, err) - - // This one returns with a conflict message. - errB := r.handle(MethodAll, "post/", secondRoute(msg2), secondRoute(msg2)) - attest.Error(t, errB) - }) -} - func TestNotFound(t *testing.T) { t.Parallel() @@ -407,10 +323,9 @@ func TestNotFound(t *testing.T) { r := newRouter(nil) var match string - err := r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { match = r.Method })) - attest.Ok(t, err) req, err := http.NewRequest(http.MethodGet, "/path", nil) attest.Ok(t, err) @@ -427,10 +342,9 @@ func TestNotFound(t *testing.T) { r := newRouter(nil) var match string - err := r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { match = r.Method })) - attest.Ok(t, err) req, err := http.NewRequest(http.MethodGet, "/not-found-path", nil) attest.Ok(t, err) @@ -451,10 +365,9 @@ func TestNotFound(t *testing.T) { }) r := newRouter(notFoundHandler) - err := r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.handle(MethodAll, "/path", nil, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { match = r.Method })) - attest.Ok(t, err) req, err := http.NewRequest(http.MethodGet, "/not-found-path", nil) attest.Ok(t, err) diff --git a/mux/mux.go b/mux/mux.go index d6414203..99b986bc 100644 --- a/mux/mux.go +++ b/mux/mux.go @@ -90,6 +90,21 @@ func (m Muxer) Unwrap() mx.Muxer { return m.internalMux } +// Merge combines mxs into m. The resulting muxer uses the opts & notFoundHandler of m. +func (m Muxer) Merge(mxs ...Muxer) (Muxer, error) { + mi := []mx.Muxer{} + for _, v := range mxs { + mi = append(mi, v.internalMux) + } + + mm, err := m.internalMux.Merge(mi...) + if err != nil { + return Muxer{}, err + } + + return Muxer{internalMux: mm}, nil +} + // Param gets the path/url parameter from the specified Context. // It returns an empty string if the parameter was not found. func Param(ctx context.Context, param string) string { diff --git a/mux/mux_test.go b/mux/mux_test.go index 2938aced..99171653 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -11,6 +11,16 @@ import ( "go.akshayshah.org/attest" ) +func tarpitRoutes() []Route { + return []Route{ + NewRoute( + "/libraries/joomla/", + MethodAll, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ), + } +} + func TestNew(t *testing.T) { l := log.New(context.Background(), &bytes.Buffer{}, 500) @@ -42,12 +52,48 @@ func TestNew(t *testing.T) { }) } -func tarpitRoutes() []Route { - return []Route{ - NewRoute( - "/libraries/joomla/", - MethodAll, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - ), - } +func TestMerge(t *testing.T) { + t.Parallel() + + l := log.New(context.Background(), &bytes.Buffer{}, 500) + + t.Run("okay", func(t *testing.T) { + t.Parallel() + + rt1 := []Route{ + NewRoute("/home", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + NewRoute("/health/", MethodAll, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + } + + rt2 := []Route{ + NewRoute("/uri2", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + } + + mx1 := New(config.DevOpts(l, "secretKey12@34String"), nil, rt1...) + mx2 := New(config.DevOpts(l, "secretKey12@34String"), nil, rt2...) + + _, err := mx1.Merge(mx2) + attest.Ok(t, err) + }) + + t.Run("conflict", func(t *testing.T) { + t.Parallel() + + rt1 := []Route{ + NewRoute("/home", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + NewRoute("/health/", MethodAll, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + } + + rt2 := []Route{ + NewRoute("/uri2", MethodGet, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + NewRoute("health", MethodPost, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + } + + mx1 := New(config.DevOpts(l, "secretKey12@34String"), nil, rt1...) + mx2 := New(config.DevOpts(l, "secretKey12@34String"), nil, rt2...) + + _, err := mx1.Merge(mx2) + attest.Error(t, err) + attest.Subsequence(t, err.Error(), "would conflict") + }) }