Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mapping types in OpenAPI schema #476

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions blacksheep/server/openapi/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import sys
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, fields, is_dataclass
from datetime import date, datetime
from decimal import Decimal
from enum import Enum, IntEnum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -666,6 +668,11 @@ def _get_schema_by_type(
if schema:
return schema

# Dict, OrderedDict, defaultdict are handled first than GenericAlias
schema = self._try_get_schema_for_mapping(object_type, type_args)
if schema:
return schema

# List, Set, Tuple are handled first than GenericAlias
schema = self._try_get_schema_for_iterable(object_type, type_args)
if schema:
Expand Down Expand Up @@ -733,6 +740,44 @@ def _try_get_schema_for_iterable(
items=self.get_schema_by_type(item_type, context_type_args),
)

def _try_get_schema_for_mapping(
self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None
) -> Optional[Schema]:
if object_type in {dict, defaultdict, OrderedDict}:
# the user didn't specify the key and value types
return Schema(
type=ValueType.OBJECT,
additional_properties=Schema(
type=ValueType.STRING,
),
)

origin = get_origin(object_type)

if not origin or origin not in {
dict,
Dict,
collections_abc.Mapping,
}:
return None

# can be Dict, Dict[str, str] or dict[str, str] (Python 3.9),
# note: it could also be union if it wasn't handled above for dataclasses
try:
_, value_type = object_type.__args__ # type: ignore
except AttributeError: # pragma: no cover
value_type = str

if context_type_args and value_type in context_type_args:
value_type = context_type_args.get(value_type, value_type)

return Schema(
type=ValueType.OBJECT,
additional_properties=self.get_schema_by_type(
value_type, context_type_args
),
)

def get_fields(self, object_type: Any) -> List[FieldInfo]:
for handler in self._object_types_handlers:
if handler.handles_type(object_type):
Expand Down
62 changes: 61 additions & 1 deletion tests/test_openapi_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from datetime import date, datetime
from enum import IntEnum
from typing import Generic, List, Optional, Sequence, TypeVar, Union
from typing import Generic, List, Mapping, Optional, Sequence, TypeVar, Union
from uuid import UUID

import pytest
Expand Down Expand Up @@ -1513,6 +1513,66 @@ def home() -> Sequence[Cat]: ...
)


@pytest.mark.asyncio
async def test_handling_of_mapping(docs: OpenAPIHandler, serializer: Serializer):
app = get_app()

@app.router.route("/")
def home() -> Mapping[str, Mapping[int, List[Cat]]]:
...

docs.bind_app(app)
await app.start()

yaml = serializer.to_yaml(docs.generate_documentation(app))

assert (
yaml.strip()
== r"""
openapi: 3.0.3
info:
title: Example
version: 0.0.1
paths:
/:
get:
responses:
'200':
description: Success response
content:
application/json:
schema:
type: object
additionalProperties:
type: object
additionalProperties:
type: array
nullable: false
items:
$ref: '#/components/schemas/Cat'
nullable: false
nullable: false
operationId: home
components:
schemas:
Cat:
type: object
required:
- id
- name
properties:
id:
type: integer
format: int64
nullable: false
name:
type: string
nullable: false
tags: []
""".strip()
)


def test_handling_of_generic_with_forward_references(docs: OpenAPIHandler):
with pytest.warns(UserWarning):
docs.register_schema_for_type(GenericWithForwardRefExample[Cat])
Expand Down
Loading