Skip to content

Commit

Permalink
support dict as Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
OmmyZhang committed Nov 28, 2022
1 parent 5ac7296 commit a17b409
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
7 changes: 5 additions & 2 deletions flask_smorest/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps
import http

import marshmallow as ma
from webargs.flaskparser import FlaskParser

from .utils import deepupdate
Expand All @@ -28,8 +29,8 @@ def arguments(
):
"""Decorator specifying the schema used to deserialize parameters
:param type|Schema schema: Marshmallow ``Schema`` class or instance
used to deserialize and validate the argument.
:param type|Schema|dict schema: Marshmallow ``Schema`` class or instance
or dict used to deserialize and validate the argument.
:param str location: Location of the argument.
:param str content_type: Content type of the argument.
Should only be used in conjunction with ``json``, ``form`` or
Expand All @@ -56,6 +57,8 @@ def arguments(
See :doc:`Arguments <arguments>`.
"""
if isinstance(schema, dict):
schema = ma.Schema.from_dict(schema)
# At this stage, put schema instance in doc dictionary. Il will be
# replaced later on by $ref or json.
parameters = {
Expand Down
8 changes: 7 additions & 1 deletion flask_smorest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import abc

import marshmallow as ma
from werkzeug.datastructures import Headers
from flask import g
from apispec.utils import trim_docstring, dedent
Expand Down Expand Up @@ -31,9 +32,14 @@ def remove_none(mapping):
def resolve_schema_instance(schema):
"""Return schema instance for given schema (instance or class).
:param type|Schema schema: marshmallow.Schema instance or class
:param type|Schema|dict schema: marshmallow.Schema instance or class or dict
:return: schema instance of given schema
"""

# this dict may be used to document a file response, no a schema dict
if isinstance(schema, dict) and all([isinstance(v, (type, ma.fields.Field)) for v in schema.values()]):
schema = ma.Schema.from_dict(schema)

return schema() if isinstance(schema, type) else schema


Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema):
error_id = ma.fields.Str()
text = ma.fields.Str()

return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))(
DocSchema, QueryArgsSchema, ClientErrorSchema
DictSchema = {
"item_id": ma.fields.Int(dump_only=True),
"field": ma.fields.Int(attribute="db_field"),
}

return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema"))(
DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema
)
59 changes: 59 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,65 @@ def func(document, query_args):
"query_args": {"arg1": "test"},
}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_dict_argument_schema(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
api = Api(app)
blp = Blueprint("test", __name__, url_prefix="/test")
client = app.test_client()

@blp.route("/", methods=("POST",))
@blp.arguments(schemas.DictSchema)
def func(document):
return {"document": document}

api.register_blueprint(blp)
spec = api.spec.to_dict()

# Check parameters are documented
if openapi_version == "2.0":
parameters = spec["paths"]["/test/"]["post"]["parameters"]
assert len(parameters) == 1
assert parameters[0]["in"] == "body"
assert "schema" in parameters[0]
else:
assert (
"schema"
in spec["paths"]["/test/"]["post"]["requestBody"]["content"][
"application/json"
]
)

# Check parameters are passed as arguments to view function
item_data = {"field": 12}
response = client.post(
"/test/",
data=json.dumps(item_data),
content_type="application/json",
)
assert response.status_code == 200
assert response.json == {
"document": {"db_field": 12},
}

@pytest.mark.parametrize("openapi_version", ["2.0", "3.0.2"])
def test_blueprint_dict_response_schema(self, app, schemas, openapi_version):
"""Check alt_response passes response transparently"""
app.config["OPENAPI_VERSION"] = openapi_version
api = Api(app)
blp = Blueprint("test", "test", url_prefix="/test")
client = app.test_client()

@blp.route("/")
@blp.response(200, schema=schemas.DictSchema)
def func():
return {"item_id": 12}

api.register_blueprint(blp)

resp = client.get("/test/")
assert resp.json == {"item_id": 12}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
Expand Down

0 comments on commit a17b409

Please sign in to comment.