Skip to content

Commit

Permalink
Add schema tree node gathering for cleaning in pydantic GenerateSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Oct 19, 2024
1 parent 92a259e commit 7eae801
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
repository: pydantic/pydantic
repository: MarkusSintonen/pydantic # TODO remove before merging
ref: optimized-schema-building
path: pydantic

- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ python-source = "python"
module-name = "pydantic_core._pydantic_core"
bindings = 'pyo3'
features = ["pyo3/extension-module"]
profile = "release" # TEMPORARY: remove this

[tool.ruff]
line-length = 120
Expand Down
12 changes: 11 additions & 1 deletion python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
ValidationError,
__version__,
from_json,
gather_schemas_for_cleaning,
to_json,
to_jsonable_python,
validate_core_schema,
)
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType

if _sys.version_info < (3, 11):
from typing_extensions import NotRequired as _NotRequired
Expand Down Expand Up @@ -67,6 +68,7 @@
'from_json',
'to_jsonable_python',
'validate_core_schema',
'gather_schemas_for_cleaning',
]


Expand Down Expand Up @@ -137,3 +139,11 @@ class MultiHostHost(_TypedDict):
"""The host part of this host, or `None`."""
port: int | None
"""The port part of this host, or `None`."""


class GatherResult(_TypedDict):
    """Internal result of `gather_schemas_for_cleaning`.

    Produced by the Rust-side schema traversal and consumed by pydantic's
    schema-cleaning step; not part of the public API.
    """

    # All `definition-ref` schema dicts found in the tree, keyed by the ref
    # string they point at. Values are the actual dict objects from the
    # traversed schema (identity preserved), not copies.
    definition_refs: dict[str, list[DefinitionReferenceSchema]]
    # Ref strings whose definitions (transitively) reference themselves.
    recursive_refs: set[str]
    # (schema, discriminator) pairs gathered from the
    # 'pydantic.internal.union_discriminator' metadata key.
    deferred_discriminators: list[tuple[CoreSchema, _Any]]
7 changes: 6 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final
from _typeshed import SupportsAllComparisons
from typing_extensions import LiteralString, Self, TypeAlias

from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost
from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType

__all__ = [
Expand Down Expand Up @@ -35,6 +35,7 @@ __all__ = [
'list_all_errors',
'TzInfo',
'validate_core_schema',
'gather_schemas_for_cleaning',
]
__version__: str
build_profile: str
Expand Down Expand Up @@ -1164,3 +1165,7 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C
We may also remove this function altogether, do not rely on it being present if you are
using pydantic-core directly.
"""

def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
    """Traverse `schema`, resolving `definition-ref` schemas through `definitions`, and
    gather the information needed by pydantic's schema cleaning: every definition
    reference encountered, the set of recursive references, and any deferred union
    discriminators stored in schema metadata.

    Used internally when schemas are generated; not part of the public API.
    """
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod errors;
mod input;
mod lookup_key;
mod recursion_guard;
mod schema_traverse;
mod serializers;
mod tools;
mod url;
Expand All @@ -35,6 +36,7 @@ pub use build_tools::SchemaError;
pub use errors::{
list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError,
};
pub use schema_traverse::gather_schemas_for_cleaning;
pub use serializers::{
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
WarningsArg,
Expand Down Expand Up @@ -133,6 +135,7 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(from_json, m)?)?;
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?;
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
Ok(())
}
187 changes: 187 additions & 0 deletions src/schema_traverse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use crate::tools::py_err;
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple};
use pyo3::{intern, Bound, PyResult};
use std::collections::HashSet;

/// Metadata key under which pydantic stores a deferred union discriminator.
const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator";

/// Optional dict lookup with an interned key string: `$dict.get_item(intern!(py, $key))?`.
macro_rules! get {
    ($dict: expr, $key: expr) => {
        $dict.get_item(intern!($dict.py(), $key))?
    };
}

/// If `$key` is present in `$dict`, downcast its value and visit it with `$func`.
/// Absent keys are silently skipped.
macro_rules! traverse_key_fn {
    ($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{
        if let Some(v) = get!($dict, $key) {
            $func(v.downcast_exact()?, $ctx)?
        }
    }};
}

/// Visit the given `key => handler` pairs of a schema dict, then the parts any
/// schema may carry: its "serialization" sub-schema and its "metadata"
/// (checked for a deferred discriminator by `gather_meta`).
macro_rules! traverse {
    ($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{
        $(traverse_key_fn!($key, $func, $dict, $ctx);)*
        traverse_key_fn!("serialization", gather_serialization, $dict, $ctx);
        gather_meta($dict, $ctx)
    }}
}

/// Append `$value` to the `PyList` stored at `$dict[$key]`, creating the list
/// on first use — `collections.defaultdict(list)` semantics.
macro_rules! defaultdict_list_append {
    ($dict: expr, $key: expr, $value: expr) => {{
        match $dict.get_item($key)? {
            None => {
                let list = PyList::empty_bound($dict.py());
                list.append($value)?;
                $dict.set_item($key, list)?;
            }
            // Safety: we know that the value is a PyList as we just created it above
            Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?,
        };
    }};
}

/// Handle a `definition-ref` schema: record it in `ctx.def_refs` and walk into
/// the referenced definition, detecting reference cycles along the way.
fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
    if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") {
        let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?;
        let schema_ref_str = schema_ref_pystr.to_str()?;

        // Ref not currently being traversed: record this usage and recurse into its definition.
        if !ctx.recursively_seen_refs.contains(schema_ref_str) {
            defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict);

            // TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
            if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? {
                // Mark the ref as "on the traversal stack" so a cycle back to it is caught below.
                ctx.recursively_seen_refs.insert(schema_ref_str.to_string());

                gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?;
                // The ref schema itself may also carry serialization/metadata parts.
                traverse_key_fn!("serialization", gather_serialization, schema_ref_dict, ctx);
                gather_meta(schema_ref_dict, ctx)?;

                ctx.recursively_seen_refs.remove(schema_ref_str);
            }
        } else {
            // Cycle detected: this ref is already on the traversal stack. All refs
            // currently on the stack are treated as recursive.
            ctx.recursive_def_refs.add(schema_ref_pystr)?;
            for seen_ref in &ctx.recursively_seen_refs {
                let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref);
                ctx.recursive_def_refs.add(seen_ref_pystr)?;
            }
        }
        Ok(())
    } else {
        py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref")
    }
}

/// Visit a "serialization" sub-schema: both its "schema" and "return_schema"
/// entries may contain further core schemas.
fn gather_serialization(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
    traverse!("schema" => gather_schema, "return_schema" => gather_schema; schema, ctx)
}

/// If the schema's metadata carries a deferred union discriminator, append a
/// (schema, discriminator) pair to `ctx.discriminators`. The schema dict is
/// stored by reference (identity preserved), not copied.
fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
    if let Some(meta) = get!(schema, "metadata") {
        let meta_dict = meta.downcast_exact::<PyDict>()?;
        if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) {
            let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]);
            ctx.discriminators.append(schema_discriminator)?;
        }
    }
    Ok(())
}

/// Visit every schema dict contained in a list of schemas.
fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
    schema_list
        .iter()
        .try_for_each(|entry| gather_schema(entry.downcast_exact()?, ctx))
}

/// Visit every value of a mapping whose values are schema dicts; the keys are
/// not traversed.
fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
    schemas_by_key
        .iter()
        .try_for_each(|(_key, value)| gather_schema(value.downcast_exact()?, ctx))
}

/// Visit union choices. A choice is either a plain schema dict or a
/// `(schema, label)` tuple, in which case only the first element is a schema.
fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
    for choice in schema_list.iter() {
        match choice.downcast_exact::<PyTuple>() {
            Ok(labeled) => gather_schema(labeled.get_item(0)?.downcast_exact()?, ctx)?,
            Err(_) => gather_schema(choice.downcast_exact()?, ctx)?,
        }
    }
    Ok(())
}

fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {
for v in arguments.iter() {
traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);
}
Ok(())
}

/// Dispatch on a core schema's "type" key and visit every nested-schema field
/// that the schema kind can carry; unknown types fall through to the generic
/// "schema" key. Returns a `KeyError` if the mandatory "type" key is missing.
// Has 100% coverage in Pydantic side. This is exclusively used there
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {
    // `let … else` replaces the original is_none()/unwrap() pair on the optional lookup.
    let Some(type_) = get!(schema, "type") else {
        return py_err!(PyKeyError; "Schema type missing");
    };
    match type_.downcast_exact::<PyString>()?.to_str()? {
        "definition-ref" => gather_definition_ref(schema, ctx),
        "definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),
        "list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),
        "tuple" => traverse!("items_schema" => gather_list; schema, ctx),
        "dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),
        "union" => traverse!("choices" => gather_union_choices; schema, ctx),
        "tagged-union" => traverse!("choices" => gather_dict; schema, ctx),
        "chain" => traverse!("steps" => gather_list; schema, ctx),
        "lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),
        "json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),
        "model-fields" | "typed-dict" => traverse!(
            "extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
        ),
        "dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),
        "arguments" => traverse!(
            "arguments_schema" => gather_arguments,
            "var_args_schema" => gather_schema,
            "var_kwargs_schema" => gather_schema;
            schema, ctx
        ),
        "call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),
        "computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),
        "function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),
        _ => traverse!("schema" => gather_schema; schema, ctx),
    }
}

/// Mutable state threaded through the traversal; the `Bound` fields become the
/// entries of the Python result dict returned by `gather_schemas_for_cleaning`.
pub struct GatherCtx<'a, 'py> {
    /// The `definitions` mapping (ref string -> core schema dict) used to
    /// resolve `definition-ref` schemas.
    pub definitions_dict: &'a Bound<'py, PyDict>,
    /// ref string -> list of the `definition-ref` schema dicts pointing at it.
    pub def_refs: Bound<'py, PyDict>,
    /// Refs found to participate in a reference cycle.
    pub recursive_def_refs: Bound<'py, PySet>,
    /// (schema, discriminator) tuples gathered from schema metadata.
    pub discriminators: Bound<'py, PyList>,
    /// Refs on the current traversal stack — used for cycle detection only.
    recursively_seen_refs: HashSet<String>,
}

/// Python entry point: traverse `schema`, resolving `definition-ref` schemas
/// through the `definitions` dict, and return a dict with the keys
/// "definition_refs", "recursive_refs" and "deferred_discriminators"
/// (matching the Python-side `GatherResult` TypedDict).
#[pyfunction(signature = (schema, definitions))]
pub fn gather_schemas_for_cleaning<'py>(
    schema: &Bound<'py, PyAny>,
    definitions: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyDict>> {
    let py = schema.py();
    let mut ctx = GatherCtx {
        definitions_dict: definitions.downcast_exact()?,
        def_refs: PyDict::new_bound(py),
        recursive_def_refs: PySet::empty_bound(py)?,
        discriminators: PyList::empty_bound(py),
        recursively_seen_refs: HashSet::new(),
    };
    gather_schema(schema.downcast_exact::<PyDict>()?, &mut ctx)?;

    // Assemble the result dict consumed by pydantic's cleaning step.
    let res = PyDict::new_bound(py);
    res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?;
    res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;
    res.set_item(intern!(py, "deferred_discriminators"), ctx.discriminators)?;
    Ok(res)
}
107 changes: 107 additions & 0 deletions tests/test_gather_schemas_for_cleaning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from pydantic_core import core_schema, gather_schemas_for_cleaning


def test_no_refs():
    """A schema containing no definition refs yields empty gather results."""
    params = [
        core_schema.arguments_parameter('a', core_schema.int_schema()),
        core_schema.arguments_parameter('b', core_schema.int_schema()),
    ]
    result = gather_schemas_for_cleaning(core_schema.arguments_schema(params), definitions={})
    assert result['definition_refs'] == {}
    assert result['recursive_refs'] == set()
    assert result['deferred_discriminators'] == []


def test_simple_ref_schema():
    """A single definition-ref is collected, preserving object identity."""
    ref_schema = core_schema.definition_reference_schema('ref1')
    definitions = {'ref1': core_schema.int_schema(ref='ref1')}

    result = gather_schemas_for_cleaning(ref_schema, definitions)
    collected = result['definition_refs']
    assert collected == {'ref1': [ref_schema]}
    assert collected['ref1'][0] is ref_schema
    assert result['recursive_refs'] == set()
    assert result['deferred_discriminators'] == []


def test_deep_ref_schema():
    """Refs nested at several depths (union choice, model field, tuple) are all collected."""

    class Model:
        pass

    # Two distinct ref objects pointing at 'ref1', plus one pointing at 'ref2'.
    ref11 = core_schema.definition_reference_schema('ref1')
    ref12 = core_schema.definition_reference_schema('ref1')
    ref2 = core_schema.definition_reference_schema('ref2')

    # ref11 hides inside a labeled union choice, ref12 inside a tuple.
    union = core_schema.union_schema([core_schema.int_schema(), (ref11, 'ref_label')])
    tup = core_schema.tuple_schema([ref12, core_schema.str_schema()])
    schema = core_schema.model_schema(
        Model,
        core_schema.model_fields_schema(
            {'a': core_schema.model_field(union), 'b': core_schema.model_field(ref2), 'c': core_schema.model_field(tup)}
        ),
    )
    definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')}

    res = gather_schemas_for_cleaning(schema, definitions)
    # Both usages of 'ref1' are reported, as the very same objects passed in.
    assert res['definition_refs'] == {'ref1': [ref11, ref12], 'ref2': [ref2]}
    assert res['definition_refs']['ref1'][0] is ref11 and res['definition_refs']['ref1'][1] is ref12
    assert res['definition_refs']['ref2'][0] is ref2
    assert res['recursive_refs'] == set()
    assert res['deferred_discriminators'] == []


def test_ref_in_serialization_schema():
    """A ref nested in a serialization schema's return_schema is found."""
    ref = core_schema.definition_reference_schema('ref1')
    ser = core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=ref)
    result = gather_schemas_for_cleaning(
        core_schema.str_schema(serialization=ser),
        definitions={'ref1': core_schema.str_schema()},
    )
    assert result['definition_refs'] == {'ref1': [ref]}
    assert result['definition_refs']['ref1'][0] is ref
    assert result['recursive_refs'] == set()
    assert result['deferred_discriminators'] == []


def test_recursive_ref_schema():
    """A ref whose definition is itself is reported as recursive."""
    self_ref = core_schema.definition_reference_schema('ref1')
    result = gather_schemas_for_cleaning(self_ref, definitions={'ref1': self_ref})
    assert result['definition_refs'] == {'ref1': [self_ref]}
    assert result['definition_refs']['ref1'][0] is self_ref
    assert result['recursive_refs'] == {'ref1'}
    assert result['deferred_discriminators'] == []


def test_deep_recursive_ref_schema():
    """A cycle spanning several definitions (ref1 -> ref2 -> ref3 -> ref1) marks all of them recursive."""
    ref1 = core_schema.definition_reference_schema('ref1')
    ref2 = core_schema.definition_reference_schema('ref2')
    ref3 = core_schema.definition_reference_schema('ref3')

    res = gather_schemas_for_cleaning(
        core_schema.union_schema([ref1, core_schema.int_schema()]),
        definitions={
            'ref1': core_schema.union_schema([core_schema.int_schema(), ref2]),
            'ref2': core_schema.union_schema([ref3, core_schema.float_schema()]),
            'ref3': core_schema.union_schema([ref1, core_schema.str_schema()]),
        },
    )
    assert res['definition_refs'] == {'ref1': [ref1], 'ref2': [ref2], 'ref3': [ref3]}
    # Every ref participates in the cycle, so all three are reported recursive.
    assert res['recursive_refs'] == {'ref1', 'ref2', 'ref3'}
    # Identity of the collected ref objects is preserved.
    assert res['definition_refs']['ref1'][0] is ref1
    assert res['definition_refs']['ref2'][0] is ref2
    assert res['definition_refs']['ref3'][0] is ref3
    assert res['deferred_discriminators'] == []


def test_discriminator_meta():
    """Discriminators under the internal metadata key are gathered, including from definitions."""

    class Model:
        pass

    ref1 = core_schema.definition_reference_schema('ref1')

    field1 = core_schema.model_field(core_schema.str_schema())
    field1['metadata'] = {'pydantic.internal.union_discriminator': 'foobar1'}

    # field2 is only reachable through the ref's definition.
    field2 = core_schema.model_field(core_schema.int_schema())
    field2['metadata'] = {'pydantic.internal.union_discriminator': 'foobar2'}

    schema = core_schema.model_schema(Model, core_schema.model_fields_schema({'a': field1, 'b': ref1}))
    res = gather_schemas_for_cleaning(schema, definitions={'ref1': field2})
    assert res['definition_refs'] == {'ref1': [ref1]} and res['definition_refs']['ref1'][0] is ref1
    assert res['recursive_refs'] == set()
    # The schema element of each pair is the original dict object, not a copy.
    assert res['deferred_discriminators'] == [(field1, 'foobar1'), (field2, 'foobar2')]
    assert res['deferred_discriminators'][0][0] is field1 and res['deferred_discriminators'][1][0] is field2

0 comments on commit 7eae801

Please sign in to comment.