Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation with body which is intended for plain-text #317

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions spectree/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Context(NamedTuple):
query: list
json: list
form: list
body: list
headers: dict
cookies: dict

Expand Down Expand Up @@ -60,6 +61,7 @@ def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand Down
50 changes: 32 additions & 18 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,33 +167,31 @@ def parse_path(self, route, path_parameter_descriptions):

return f'/{"/".join(subs)}', parameters

def request_validation(self, req, query, json, form, headers, cookies):
def request_validation(self, req, query, json, form, body, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
if headers:
req.context.headers = headers.parse_obj(req.headers)
if cookies:
req.context.cookies = cookies.parse_obj(req.cookies)
if json:
try:
media = req.media
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
media = self._get_req_media(req)
req.context.json = json.parse_obj(media)
if form:
# TODO - possible to pass the BodyPart here?
# req_form = {x.name: x for x in req.get_media()}
req_form = {x.name: x.stream.read() for x in req.get_media()}
req.context.form = form.parse_obj(req_form)
if body:
req.context.body = self._get_req_media(req)

def validate(
self,
func: Callable,
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -208,10 +206,10 @@ def validate(
_self, _req, _resp = args[:3]
req_validation_error, resp_validation_error = None, None
try:
self.request_validation(_req, query, json, form, headers, cookies)
self.request_validation(_req, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

Expand Down Expand Up @@ -248,6 +246,14 @@ def validate(
def _data_set_manually(self, resp):
return (resp.text is not None or resp.data is not None) and resp.media is None

def _get_req_media(self, req):
try:
return req.media
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
return None

def bypass(self, func, method):
if isinstance(func, partial):
return True
Expand All @@ -261,20 +267,15 @@ class FalconAsgiPlugin(FalconPlugin):
OPEN_API_ROUTE_CLASS = OpenAPIAsgi
DOC_PAGE_ROUTE_CLASS = DocPageAsgi

async def request_validation(self, req, query, json, form, headers, cookies):
async def request_validation(self, req, query, json, form, body, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
if headers:
req.context.headers = headers.parse_obj(req.headers)
if cookies:
req.context.cookies = cookies.parse_obj(req.cookies)
if json:
try:
media = await req.get_media()
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
media = await self._get_req_media(req)
req.context.json = json.parse_obj(media)
if form:
try:
Expand All @@ -289,13 +290,16 @@ async def request_validation(self, req, query, json, form, headers, cookies):
res_data[x.name] = x
await x.data # TODO - how to avoid this?
req.context.form = form.parse_obj(res_data)
if body:
req.context.body = self._get_req_media(req)

async def validate(
self,
func: Callable,
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -310,10 +314,12 @@ async def validate(
_self, _req, _resp = args[:3]
req_validation_error, resp_validation_error = None, None
try:
await self.request_validation(_req, query, json, form, headers, cookies)
await self.request_validation(
_req, query, json, form, body, headers, cookies
)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

Expand Down Expand Up @@ -347,3 +353,11 @@ async def validate(
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)

async def _get_req_media(self, req):
try:
return await req.get_media()
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
return None
9 changes: 6 additions & 3 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def parse_path(

return "".join(subs), parameters

def request_validation(self, request, query, json, form, headers, cookies):
def request_validation(self, request, query, json, form, body, headers, cookies):
"""
req_query: werkzeug.datastructures.ImmutableMultiDict
req_json: dict
Expand All @@ -149,11 +149,13 @@ def request_validation(self, request, query, json, form, headers, cookies):
and has_data
and any([x in request.mimetype for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and request.mimetype == "text/plain"

request.context = Context(
query.parse_obj(req_query) if query else None,
json.parse_obj(request.get_json(silent=True) or {}) if use_json else None,
form.parse_obj(self._fill_form(request)) if use_form else None,
body.parse_obj(request.get_data() or {}) if use_body else None,
headers.parse_obj(req_headers) if headers else None,
cookies.parse_obj(req_cookies) if cookies else None,
)
Expand All @@ -169,6 +171,7 @@ def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -181,10 +184,10 @@ def validate(
):
response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, form, headers, cookies)
self.request_validation(request, query, json, form, body, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
13 changes: 10 additions & 3 deletions spectree/plugins/quart_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def parse_path(

return "".join(subs), parameters

async def request_validation(self, request, query, json, form, headers, cookies):
async def request_validation(
self, request, query, json, form, body, headers, cookies
):
"""
req_query: werkzeug.datastructures.ImmutableMultiDict
req_json: dict
Expand All @@ -152,13 +154,15 @@ async def request_validation(self, request, query, json, form, headers, cookies)
and has_data
and any([x in request.mimetype for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and request.mimetype == "text/plain"

request.context = Context(
query.parse_obj(req_query) if query else None,
json.parse_obj(await request.get_json(silent=True) or {})
if use_json
else None,
form.parse_obj(self._fill_form(request)) if use_form else None,
body.parse_obj(await request.get_data() or {}) if use_body else None,
headers.parse_obj(req_headers) if headers else None,
cookies.parse_obj(req_cookies) if cookies else None,
)
Expand All @@ -174,6 +178,7 @@ async def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -186,10 +191,12 @@ async def validate(
):
response, req_validation_error, resp_validation_error = None, None, None
try:
await self.request_validation(request, query, json, form, headers, cookies)
await self.request_validation(
request, query, json, form, body, headers, cookies
)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
13 changes: 10 additions & 3 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,21 @@ def register_route(self, app):
),
)

async def request_validation(self, request, query, json, form, headers, cookies):
async def request_validation(
self, request, query, json, form, body, headers, cookies
):
has_data = request.method not in ("GET", "DELETE")
content_type = request.headers.get("content-type", "").lower()
use_json = json and has_data and content_type == "application/json"
use_form = (
form and has_data and any([x in content_type for x in self.FORM_MIMETYPE])
)
use_body = body and has_data and content_type == "text/plain"
request.context = Context(
query.parse_obj(request.query_params) if query else None,
json.parse_obj(await request.json() or {}) if use_json else None,
form.parse_obj(await request.form() or {}) if use_form else None,
body.parse_obj(await request.body() or {}) if use_body else None,
headers.parse_obj(request.headers) if headers else None,
cookies.parse_obj(request.cookies) if cookies else None,
)
Expand All @@ -76,6 +80,7 @@ async def validate(
query: Optional[ModelType],
json: Optional[ModelType],
form: Optional[ModelType],
body: Optional[ModelType],
headers: Optional[ModelType],
cookies: Optional[ModelType],
resp: Optional[Response],
Expand All @@ -95,10 +100,12 @@ async def validate(
req_validation_error = resp_validation_error = json_decode_error = None

try:
await self.request_validation(request, query, json, form, headers, cookies)
await self.request_validation(
request, query, json, form, body, headers, cookies
)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "form", "headers", "cookies"):
for name in ("query", "json", "form", "body", "headers", "cookies"):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
Expand Down
11 changes: 8 additions & 3 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def validate(
query: Optional[ModelType] = None,
json: Optional[ModelType] = None,
form: Optional[ModelType] = None,
body: Optional[ModelType] = None,
headers: Optional[ModelType] = None,
cookies: Optional[ModelType] = None,
resp: Optional[Response] = None,
Expand All @@ -150,6 +151,7 @@ def validate(
:param query: `pydantic.BaseModel`, query in uri like `?name=value`
:param json: `pydantic.BaseModel`, JSON format request body
:param form: `pydantic.BaseModel`, form-data request body
:param body: `pydantic.BaseModel`, raw (plain text) request body
:param headers: `pydantic.BaseModel`, if you have specific headers
:param cookies: `pydantic.BaseModel`, if you have cookies for this route
:param resp: `spectree.Response`
Expand Down Expand Up @@ -182,6 +184,7 @@ def sync_validate(*args: Any, **kwargs: Any):
query,
json,
form,
body,
headers,
cookies,
resp,
Expand All @@ -201,6 +204,7 @@ async def async_validate(*args: Any, **kwargs: Any):
query,
json,
form,
body,
headers,
cookies,
resp,
Expand All @@ -217,18 +221,19 @@ async def async_validate(*args: Any, **kwargs: Any):
)

if self.config.annotations:
nonlocal query, json, form, headers, cookies
nonlocal query, json, form, body, headers, cookies
annotations = get_type_hints(func)
query = annotations.get("query", query)
json = annotations.get("json", json)
form = annotations.get("form", form)
body = annotations.get("body", body)
headers = annotations.get("headers", headers)
cookies = annotations.get("cookies", cookies)

# register
for name, model in zip(
("query", "json", "form", "headers", "cookies"),
(query, json, form, headers, cookies),
("query", "json", "form", "body", "headers", "cookies"),
(query, json, form, body, headers, cookies),
):
if model is not None:
model_key = self._add_model(model=model)
Expand Down
5 changes: 5 additions & 0 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def parse_request(func: Any) -> Dict[str, Any]:
"schema": {"$ref": f"#/components/schemas/{func.form}"}
}

if hasattr(func, "body"):
content_items["text/plain"] = {
"schema": {"$ref": f"#/components/schemas/{func.body}"}
}
Comment on lines +94 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will the BaseModel look like? Can you provide an example? If possible, can you add some tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kemingy It is intended for text inputs like CSV etc. BaseModel is used to describe data in this model and it can be like this:

class FooBarPostBodyData(BaseModel):
    """
    CSV body for example for some import
    """
    __root__: str = Field(description="Data in CSV format - it consists of columns name, surname, ..., ..., ... .")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kemingy : ?

Copy link
Member

@kemingy kemingy Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response.

I feel that if we need to create a special argument for each MIME type, there will be endless arguments. By the way, they all belong to the request body. I wonder if we can add an attribute to the request model, for example, spec_mime_type or something else. This should be able to handle this kind of feature. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some old discussions here #176

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it could be useful to process on one place for multiple mime-types.


if not content_items:
return {}

Expand Down