diff --git a/apischema/json_schema/refs.py b/apischema/json_schema/refs.py index 802d77c1..d756f02e 100644 --- a/apischema/json_schema/refs.py +++ b/apischema/json_schema/refs.py @@ -27,7 +27,7 @@ def _default_ref(cls: AnyType) -> Ref: - if ( + if not hasattr(cls, "__parameters__") and ( is_dataclass(cls) or hasattr(cls, "__supertype__") or isinstance(cls, _TypedDictMeta) @@ -59,13 +59,21 @@ def __post_init__(self): raise ValueError("Empty schema ref not allowed") def check_type(self, cls: AnyType): - """Check if the given type can have a ref - - NewType of non-builtin types cannot have a ref because their serialization - could be customized, but the NewType ref would then erase this customization - in the schema""" if hasattr(cls, "__supertype__") and not is_builtin(cls): + # NewType of non-builtin types cannot have a ref because their serialization + # could be customized, but the NewType ref would then erase this + # customization in the schema. raise TypeError("NewType of non-builtin type can not have a ref") + if hasattr(cls, "__parameters__") and ( + not hasattr(cls, "__origin__") + or any( + isinstance(arg, TypeVar) # type: ignore + for arg in getattr(cls, "__args__", ()) + ) + ): + raise TypeError("Unspecialized generic types cannot have a ref") + if hasattr(cls, "__origin__") and self.ref is ...: + raise TypeError(f"Generic alias {cls} cannot have ... ref") def __call__(self, cls: T) -> T: self.check_type(cls) diff --git a/tests/test_refs.py b/tests/test_refs.py index 4d8c46ab..32da9f20 100644 --- a/tests/test_refs.py +++ b/tests/test_refs.py @@ -1,7 +1,11 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Generic, List, Optional, TypeVar + +from _pytest.python_api import raises +from pytest import mark from apischema import schema_ref +from apischema.json_schema import deserialization_schema from apischema.json_schema.generation.builder import DeserializationSchemaBuilder from apischema.typing import Annotated @@ -42,3 +46,52 @@ def test_find_refs(): "Bs2": (Annotated[List[B], schema_ref("Bs2")], 1), "Recursive": (Recursive, 2), } + + +T = TypeVar("T") +U = TypeVar("U") + + +@dataclass +class DataGeneric(Generic[T]): + a: T + + +schema_ref("StrData")(DataGeneric[str]) + + +@mark.parametrize("cls", [DataGeneric, DataGeneric[U], DataGeneric[int]]) +def test_generic_ref_error(cls): + with raises(TypeError): + schema_ref(...)(cls) + + +def test_generic_schema(): + schema_ref("StrData")(DataGeneric[str]) + print() + assert deserialization_schema(DataGeneric, all_refs=True) == { + "$schema": "http://json-schema.org/draft/2019-09/schema#", + "type": "object", + "properties": {"a": {}}, + "required": ["a"], + "additionalProperties": False, + } + assert deserialization_schema(DataGeneric[int], all_refs=True) == { + "$schema": "http://json-schema.org/draft/2019-09/schema#", + "type": "object", + "properties": {"a": {"type": "integer"}}, + "required": ["a"], + "additionalProperties": False, + } + assert deserialization_schema(DataGeneric[str], all_refs=True) == { + "$schema": "http://json-schema.org/draft/2019-09/schema#", + "$ref": "#/$defs/StrData", + "$defs": { + "StrData": { + "type": "object", + "properties": {"a": {"type": "string"}}, + "required": ["a"], + "additionalProperties": False, + } + }, + }