diff --git a/mux.go b/mux.go index 240ae676..6dc1904d 100644 --- a/mux.go +++ b/mux.go @@ -452,6 +452,10 @@ func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) { // Find the route if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil { + if supportsPathValue { + setPathValue(rctx, r) + } + h.ServeHTTP(w, r) return } diff --git a/path_value.go b/path_value.go new file mode 100644 index 00000000..7e78171e --- /dev/null +++ b/path_value.go @@ -0,0 +1,20 @@ +//go:build go1.22 +// +build go1.22 + +package chi + +import "net/http" + +// supportsPathValue is true if the Go version is 1.22 and above. +// +// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. +const supportsPathValue = true + +// setPathValue sets the path values in the Request value +// based on the provided request context. +func setPathValue(rctx *Context, r *http.Request) { + for i, key := range rctx.URLParams.Keys { + value := rctx.URLParams.Values[i] + r.SetPathValue(key, value) + } +} diff --git a/path_value_fallback.go b/path_value_fallback.go new file mode 100644 index 00000000..f551781a --- /dev/null +++ b/path_value_fallback.go @@ -0,0 +1,19 @@ +//go:build !go1.22 +// +build !go1.22 + +package chi + +import "net/http" + +// supportsPathValue is true if the Go version is 1.22 and above. +// +// If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. +const supportsPathValue = false + +// setPathValue sets the path values in the Request value +// based on the provided request context. +// +// setPathValue is only supported in Go 1.22 and above so +// this is just a blank function so that it compiles. +func setPathValue(rctx *Context, r *http.Request) { +} diff --git a/path_value_test.go b/path_value_test.go new file mode 100644 index 00000000..389360ea --- /dev/null +++ b/path_value_test.go @@ -0,0 +1,69 @@ +//go:build go1.22 +// +build go1.22 + +package chi + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestPathValue(t *testing.T) { + testCases := []struct { + name string + pattern string + method string + pathKeys []string + requestPath string + expectedBody string + }{ + { + name: "Basic path value", + pattern: "/hubs/{hubID}", + method: "GET", + pathKeys: []string{"hubID"}, + requestPath: "/hubs/392", + expectedBody: "392", + }, + { + name: "Two path values", + pattern: "/users/{userID}/conversations/{conversationID}", + method: "POST", + pathKeys: []string{"userID", "conversationID"}, + requestPath: "/users/Gojo/conversations/2948", + expectedBody: "Gojo 2948", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := NewRouter() + + r.Handle(tc.method+" "+tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pathValues := []string{} + for _, pathKey := range tc.pathKeys { + pathValue := r.PathValue(pathKey) + if pathValue == "" { + pathValue = "NOT_FOUND:" + pathKey + } + + pathValues = append(pathValues, pathValue) + } + + body := strings.Join(pathValues, " ") + + w.Write([]byte(body)) + })) + + ts := httptest.NewServer(r) + defer ts.Close() + + _, body := testRequest(t, ts, tc.method, tc.requestPath, nil) + if body != tc.expectedBody { + t.Fatalf("expecting %q, got %q", tc.expectedBody, body) + } + }) + } +}