Skip to content

Commit

Permalink
Validation with body which is intended for plain-text
Browse files Browse the repository at this point in the history
  • Loading branch information
Vlczech committed Jun 27, 2023
1 parent ac6143e commit 424589e
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 30 deletions.
2 changes: 2 additions & 0 deletions spectree/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Context(NamedTuple):
query: list
json: list
form: list
body: list
headers: dict
cookies: dict

Expand Down Expand Up @@ -60,6 +61,7 @@ def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand Down
48 changes: 30 additions & 18 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,33 +167,31 @@ def parse_path(self, route, path_parameter_descriptions):

return f'/{"/".join(subs)}', parameters

def request_validation(self, req, query, json, form, headers, cookies):
def request_validation(self, req, query, json, form, body, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
if headers:
req.context.headers = headers.parse_obj(req.headers)
if cookies:
req.context.cookies = cookies.parse_obj(req.cookies)
if json:
try:
media = req.media
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
media = self._get_req_media(req)
req.context.json = json.parse_obj(media)
if form:
# TODO - possible to pass the BodyPart here?
# req_form = {x.name: x for x in req.get_media()}
req_form = {x.name: x.stream.read() for x in req.get_media()}
req.context.form = form.parse_obj(req_form)
if body:
req.context.body = self._get_req_media(req)

def validate(
self,
func: Callable,
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -208,10 +206,10 @@ def validate(
_self, _req, _resp = args[:3]
req_validation_error, resp_validation_error = None, None
try:
self.request_validation(_req, query, json, form, headers, cookies)
self.request_validation(_req, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

Expand Down Expand Up @@ -248,6 +246,14 @@ def validate(
def _data_set_manually(self, resp):
return (resp.text is not None or resp.data is not None) and resp.media is None

def _get_req_media(self, req):
try:
return req.media
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
return None

def bypass(self, func, method):
if isinstance(func, partial):
return True
Expand All @@ -261,20 +267,15 @@ class FalconAsgiPlugin(FalconPlugin):
OPEN_API_ROUTE_CLASS = OpenAPIAsgi
DOC_PAGE_ROUTE_CLASS = DocPageAsgi

async def request_validation(self, req, query, json, form, headers, cookies):
async def request_validation(self, req, query, json, form, body, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
if headers:
req.context.headers = headers.parse_obj(req.headers)
if cookies:
req.context.cookies = cookies.parse_obj(req.cookies)
if json:
try:
media = await req.get_media()
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
media = await self._get_req_media(req)
req.context.json = json.parse_obj(media)
if form:
try:
Expand All @@ -289,13 +290,16 @@ async def request_validation(self, req, query, json, form, headers, cookies):
res_data[x.name] = x
await x.data # TODO - how to avoid this?
req.context.form = form.parse_obj(res_data)
if body:
req.context.body = self._get_req_media(req)

async def validate(
self,
func: Callable,
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -310,10 +314,10 @@ async def validate(
_self, _req, _resp = args[:3]
req_validation_error, resp_validation_error = None, None
try:
await self.request_validation(_req, query, json, form, headers, cookies)
await self.request_validation(_req, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

Expand Down Expand Up @@ -347,3 +351,11 @@ async def validate(
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)

async def _get_req_media(self, req):
try:
return await req.get_media()
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
return None
9 changes: 6 additions & 3 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def parse_path(

return "".join(subs), parameters

def request_validation(self, request, query, json, form, headers, cookies):
def request_validation(self, request, query, json, form, body, headers, cookies):
"""
req_query: werkzeug.datastructures.ImmutableMultiDict
req_json: dict
Expand All @@ -149,11 +149,13 @@ def request_validation(self, request, query, json, form, headers, cookies):
and has_data
and any([x in request.mimetype for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and request.mimetype == "text/plain"

request.context = Context(
query.parse_obj(req_query) if query else None,
json.parse_obj(request.get_json(silent=True) or {}) if use_json else None,
form.parse_obj(self._fill_form(request)) if use_form else None,
body.parse_obj(request.get_data() or {}) if use_body else None,
headers.parse_obj(req_headers) if headers else None,
cookies.parse_obj(req_cookies) if cookies else None,
)
Expand All @@ -169,6 +171,7 @@ def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -181,10 +184,10 @@ def validate(
):
response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, form, headers, cookies)
self.request_validation(request, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
9 changes: 6 additions & 3 deletions spectree/plugins/quart_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def parse_path(

return "".join(subs), parameters

async def request_validation(self, request, query, json, form, headers, cookies):
async def request_validation(self, request, query, json, form, body, headers, cookies):
"""
req_query: werkzeug.datastructures.ImmutableMultiDict
req_json: dict
Expand All @@ -152,13 +152,15 @@ async def request_validation(self, request, query, json, form, headers, cookies)
and has_data
and any([x in request.mimetype for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and request.mimetype == "text/plain"

request.context = Context(
query.parse_obj(req_query) if query else None,
json.parse_obj(await request.get_json(silent=True) or {})
if use_json
else None,
form.parse_obj(self._fill_form(request)) if use_form else None,
body.parse_obj(await request.get_data() or {}) if use_body else None,
headers.parse_obj(req_headers) if headers else None,
cookies.parse_obj(req_cookies) if cookies else None,
)
Expand All @@ -174,6 +176,7 @@ async def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -186,10 +189,10 @@ async def validate(
):
response, req_validation_error, resp_validation_error = None, None, None
try:
await self.request_validation(request, query, json, form, headers, cookies)
await self.request_validation(request, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
9 changes: 6 additions & 3 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ def register_route(self, app):
),
)

async def request_validation(self, request, query, json, form, headers, cookies):
async def request_validation(self, request, query, json, form, body, headers, cookies):
has_data = request.method not in ("GET", "DELETE")
content_type = request.headers.get("content-type", "").lower()
use_json = json and has_data and content_type == "application/json"
use_form = (
form and has_data and any([x in content_type for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and content_type == "text/plain"
request.context = Context(
query.parse_obj(request.query_params) if query else None,
json.parse_obj(await request.json() or {}) if use_json else None,
form.parse_obj(await request.form() or {}) if use_form else None,
body.parse_obj(await request.body() or {}) if use_body else None,
headers.parse_obj(request.headers) if headers else None,
cookies.parse_obj(request.cookies) if cookies else None,
)
Expand All @@ -76,6 +78,7 @@ async def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -95,10 +98,10 @@ async def validate(
req_validation_error = resp_validation_error = json_decode_error = None

try:
await self.request_validation(request, query, json, form, headers, cookies)
await self.request_validation(request, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
11 changes: 8 additions & 3 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def validate(
query: Optional[ModelType] = None,
json: Optional[ModelType] = None,
form: Optional[ModelType] = None,
body: Optional[ModelType] = None,
headers: Optional[ModelType] = None,
cookies: Optional[ModelType] = None,
resp: Optional[Response] = None,
Expand All @@ -150,6 +151,7 @@ def validate(
:param query: `pydantic.BaseModel`, query in uri like `?name=value`
:param json: `pydantic.BaseModel`, JSON format request body
:param form: `pydantic.BaseModel`, form-data request body
:param body: `pydantic.BaseModel`, raw (plain text) request body
:param headers: `pydantic.BaseModel`, if you have specific headers
:param cookies: `pydantic.BaseModel`, if you have cookies for this route
:param resp: `spectree.Response`
Expand Down Expand Up @@ -182,6 +184,7 @@ def sync_validate(*args: Any, **kwargs: Any):
query,
json,
form,
body,
headers,
cookies,
resp,
Expand All @@ -201,6 +204,7 @@ async def async_validate(*args: Any, **kwargs: Any):
query,
json,
form,
body,
headers,
cookies,
resp,
Expand All @@ -217,18 +221,19 @@ async def async_validate(*args: Any, **kwargs: Any):
)

if self.config.annotations:
nonlocal query, json, form, headers, cookies
nonlocal query, json, form, body, headers, cookies
annotations = get_type_hints(func)
query = annotations.get("query", query)
json = annotations.get("json", json)
form = annotations.get("form", form)
body = annotations.get("body", body)
headers = annotations.get("headers", headers)
cookies = annotations.get("cookies", cookies)

# register
for name, model in zip(
("query", "json", "form", "headers", "cookies"),
(query, json, form, headers, cookies),
("query", "json", "form", "body", "headers", "cookies"),
(query, json, form, body, headers, cookies),
):
if model is not None:
model_key = self._add_model(model=model)
Expand Down
5 changes: 5 additions & 0 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def parse_request(func: Any) -> Dict[str, Any]:
"schema": {"$ref": f"#/components/schemas/{func.form}"}
}

if hasattr(func, "body"):
content_items["text/plain"] = {
"schema": {"$ref": f"#/components/schemas/{func.body}"}
}

if not content_items:
return {}

Expand Down

0 comments on commit 424589e

Please sign in to comment.