Skip to content

Commit

Permalink
ability to pass context to serialization (pydantic#7143) (#1215)
Browse files Browse the repository at this point in the history
Co-authored-by: ornariece <[email protected]>
  • Loading branch information
ornariece and ornariece authored Mar 6, 2024
1 parent c6301fe commit 1083986
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 21 deletions.
22 changes: 17 additions & 5 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> Any:
"""
Expand Down Expand Up @@ -131,7 +131,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> bool:
"""
Expand All @@ -148,7 +148,7 @@ class SchemaValidator:
input: str | bytes | bytearray,
*,
strict: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> Any:
"""
Expand Down Expand Up @@ -176,7 +176,7 @@ class SchemaValidator:
The validated Python object.
"""
def validate_strings(
self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None
self, input: _StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> Any:
"""
Validate a string against the schema and return the validated Python object.
Expand Down Expand Up @@ -206,7 +206,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]:
"""
Validate an assignment to a field on a model.
Expand Down Expand Up @@ -278,6 +278,7 @@ class SchemaSerializer:
round_trip: bool = False,
warnings: bool = True,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> Any:
"""
Serialize/marshal a Python object to a Python object including transforming and filtering data.
Expand All @@ -297,6 +298,8 @@ class SchemaSerializer:
warnings: Whether to log warnings when invalid fields are encountered.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].
Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand All @@ -318,6 +321,7 @@ class SchemaSerializer:
round_trip: bool = False,
warnings: bool = True,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> bytes:
"""
Serialize a Python object to JSON including transforming and filtering data.
Expand All @@ -336,6 +340,8 @@ class SchemaSerializer:
warnings: Whether to log warnings when invalid fields are encountered.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].
Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand All @@ -358,6 +364,7 @@ def to_json(
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> bytes:
"""
Serialize a Python object to JSON including transforming and filtering data.
Expand All @@ -379,6 +386,8 @@ def to_json(
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].
Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand Down Expand Up @@ -419,6 +428,7 @@ def to_jsonable_python(
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> Any:
"""
Serialize/marshal a Python object to a JSON-serializable Python object including transforming and filtering data.
Expand All @@ -440,6 +450,8 @@ def to_jsonable_python(
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].
Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand Down
4 changes: 4 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def include(self) -> IncExCall: ...
@property
def exclude(self) -> IncExCall: ...

@property
def context(self) -> Any | None:
"""Current serialization context."""

@property
def mode(self) -> str: ...

Expand Down
2 changes: 1 addition & 1 deletion src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl ValidationError {
include_input: bool,
) -> PyResult<&'py PyString> {
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None);
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None, None);
let serializer = ValidationErrorSerializer {
py,
line_errors: &self.line_errors,
Expand Down
12 changes: 10 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl SerializationState {
round_trip: bool,
serialize_unknown: bool,
fallback: Option<&'py PyAny>,
context: Option<&'py PyAny>,
) -> Extra<'py> {
Extra::new(
py,
Expand All @@ -59,6 +60,7 @@ impl SerializationState {
&self.rec_guard,
serialize_unknown,
fallback,
context,
)
}

Expand Down Expand Up @@ -90,6 +92,7 @@ pub(crate) struct Extra<'a> {
pub field_name: Option<&'a str>,
pub serialize_unknown: bool,
pub fallback: Option<&'a PyAny>,
pub context: Option<&'a PyAny>,
}

impl<'a> Extra<'a> {
Expand All @@ -107,6 +110,7 @@ impl<'a> Extra<'a> {
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
context: Option<&'a PyAny>,
) -> Self {
Self {
mode,
Expand All @@ -124,6 +128,7 @@ impl<'a> Extra<'a> {
field_name: None,
serialize_unknown,
fallback,
context,
}
}

Expand Down Expand Up @@ -178,10 +183,11 @@ pub(crate) struct ExtraOwned {
config: SerializationConfig,
rec_guard: SerRecursionState,
check: SerCheck,
model: Option<PyObject>,
pub model: Option<PyObject>,
field_name: Option<String>,
serialize_unknown: bool,
fallback: Option<PyObject>,
pub fallback: Option<PyObject>,
pub context: Option<PyObject>,
}

impl ExtraOwned {
Expand All @@ -201,6 +207,7 @@ impl ExtraOwned {
field_name: extra.field_name.map(ToString::to_string),
serialize_unknown: extra.serialize_unknown,
fallback: extra.fallback.map(Into::into),
context: extra.context.map(Into::into),
}
}

Expand All @@ -221,6 +228,7 @@ impl ExtraOwned {
field_name: self.field_name.as_deref(),
serialize_unknown: self.serialize_unknown,
fallback: self.fallback.as_ref().map(|m| m.as_ref(py)),
context: self.context.as_ref().map(|m| m.as_ref(py)),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub(crate) fn infer_to_python_known(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.context,
);
serializer.serializer.to_python(value, include, exclude, &extra)
};
Expand Down Expand Up @@ -468,6 +469,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.context,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
Expand Down
18 changes: 14 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl SchemaSerializer {
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
context: Option<&'a PyAny>,
) -> Extra<'b> {
Extra::new(
py,
Expand All @@ -69,6 +70,7 @@ impl SchemaSerializer {
rec_guard,
serialize_unknown,
fallback,
context,
)
}
}
Expand All @@ -95,7 +97,7 @@ impl SchemaSerializer {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = true,
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
fallback = None))]
fallback = None, context = None))]
pub fn to_python(
&self,
py: Python,
Expand All @@ -110,6 +112,7 @@ impl SchemaSerializer {
round_trip: bool,
warnings: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let mode: SerMode = mode.into();
let warnings = CollectWarnings::new(warnings);
Expand All @@ -126,6 +129,7 @@ impl SchemaSerializer {
&rec_guard,
false,
fallback,
context,
);
let v = self.serializer.to_python(value, include, exclude, &extra)?;
warnings.final_check(py)?;
Expand All @@ -135,7 +139,7 @@ impl SchemaSerializer {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
fallback = None))]
fallback = None, context = None))]
pub fn to_json(
&self,
py: Python,
Expand All @@ -150,6 +154,7 @@ impl SchemaSerializer {
round_trip: bool,
warnings: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionState::default();
Expand All @@ -165,6 +170,7 @@ impl SchemaSerializer {
&rec_guard,
false,
fallback,
context,
);
let bytes = to_json_bytes(
value,
Expand Down Expand Up @@ -213,7 +219,7 @@ impl SchemaSerializer {
#[pyfunction]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8",
inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))]
pub fn to_json(
py: Python,
value: &PyAny,
Expand All @@ -228,6 +234,7 @@ pub fn to_json(
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
Expand All @@ -238,6 +245,7 @@ pub fn to_json(
round_trip,
serialize_unknown,
fallback,
context,
);
let serializer = type_serializers::any::AnySerializer.into();
let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?;
Expand All @@ -249,7 +257,7 @@ pub fn to_json(
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false,
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))]
pub fn to_jsonable_python(
py: Python,
value: &PyAny,
Expand All @@ -263,6 +271,7 @@ pub fn to_jsonable_python(
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
Expand All @@ -273,6 +282,7 @@ pub fn to_jsonable_python(
round_trip,
serialize_unknown,
fallback,
context,
);
let v = infer::infer_to_python(value, include, exclude, &extra)?;
state.final_check(py)?;
Expand Down
Loading

0 comments on commit 1083986

Please sign in to comment.