diff --git a/hug/middleware.py b/hug/middleware.py index ea522794..19a3c964 100644 --- a/hug/middleware.py +++ b/hug/middleware.py @@ -166,15 +166,18 @@ def match_route(self, reqpath): """Match a request with parameter to it's corresponding route""" route_dicts = [routes for _, routes in self.api.http.routes.items()][0] routes = [route for route, _ in route_dicts.items()] - if reqpath not in routes: - for route in routes: # replace params in route with regex - reqpath = re.sub(r"^(/v\d*/?)", "/", reqpath) - base_url = getattr(self.api.http, "base_url", "") - reqpath = reqpath.replace(base_url, "", 1) if base_url else reqpath - if re.match(re.sub(r"/{[^{}]+}", ".+", route) + "$", reqpath, re.DOTALL): - return route - - return reqpath + # If the route is valid, it should return the valid route. + for route in routes: # replace params in route with regex + reqpath = re.sub(r"^(/v\d*/?)", "/", reqpath) + # This will match the path with our without the trailing slash + if reqpath in route: + return route + base_url = getattr(self.api.http, "base_url", "") + reqpath = reqpath.replace(base_url, "", 1) if base_url else reqpath + if re.match(re.sub(r"/{[^{}]+}", ".+", route) + "$", reqpath, re.DOTALL): + return route + # If match route does not find a valid http route, it should return None + return None def process_response(self, request, response, resource, req_succeeded): """Add CORS headers to the response""" @@ -185,12 +188,17 @@ def process_response(self, request, response, resource, req_succeeded): response.set_header("Access-Control-Allow-Origin", origin) if request.method == "OPTIONS": # check if we are handling a preflight request - allowed_methods = set( - method - for _, routes in self.api.http.routes.items() - for method, _ in routes[self.match_route(request.path)].items() - ) - allowed_methods.add("OPTIONS") + allowed_methods = set(["OPTIONS"]) + # If we cannot match the route of a preflight request, send not_found from the origin. + route = self.match_route(request.path) + if not route: + self.api.http.not_found(request, response) + else: + allowed_methods.update(set( + method + for _, routes in self.api.http.routes.items() + for method, _ in routes[route].items() + )) # return allowed methods response.set_header("Access-Control-Allow-Methods", ", ".join(allowed_methods)) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 219e49fd..b1c914fd 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -153,3 +153,13 @@ def get_demo(param): assert set(methods.split(",")) == set(["OPTIONS", "GET", "DELETE", "PUT"]) assert set(allow.split(",")) == set(["OPTIONS", "GET", "DELETE", "PUT"]) assert response.headers_dict["access-control-max-age"] == "10" + + assert "404" in hug.test.get(hug_api, "/not_there").status + + response = hug.test.options(hug_api, "/not_there") + methods = response.headers_dict["access-control-allow-methods"].replace(" ", "") + allow = response.headers_dict["allow"].replace(" ", "") + assert set(methods.split(",")) == set(["OPTIONS"]) + assert set(allow.split(",")) == set(["OPTIONS"]) + assert response.headers_dict["access-control-max-age"] == "10" + assert "404" in response.status