Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serialization mode specification from model config and SerializationConfig #1122

Merged
merged 12 commits into from
Jan 9, 2024
4 changes: 4 additions & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def to_json(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> bytes:
Expand All @@ -373,6 +374,7 @@ def to_json(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand Down Expand Up @@ -414,6 +416,7 @@ def to_jsonable_python(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> Any:
Expand All @@ -432,6 +435,7 @@ def to_jsonable_python(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand Down
2 changes: 1 addition & 1 deletion src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl ValidationError {
include_context: bool,
include_input: bool,
) -> PyResult<&'py PyString> {
let state = SerializationState::new("iso8601", "utf8")?;
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None);
let serializer = ValidationErrorSerializer {
py,
Expand Down
144 changes: 64 additions & 80 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,98 @@ use crate::tools::SchemaDict;
use super::errors::py_err_se_err;

/// The set of serialization modes used while serializing.
///
/// Built either from an optional config dict (`from_config`) or from explicit mode strings (`from_args`).
#[derive(Debug, Clone)]
// All three fields share the `_mode` suffix, which is the pattern clippy's `struct_field_names` lint reports.
#[allow(clippy::struct_field_names)]
pub(crate) struct SerializationConfig {
/// How `timedelta` values are serialized (config key `ser_json_timedelta`).
pub timedelta_mode: TimedeltaMode,
/// How `bytes` values are serialized (config key `ser_json_bytes`).
pub bytes_mode: BytesMode,
/// How infinite and NaN floats are serialized (config key `ser_json_inf_nan`).
pub inf_nan_mode: InfNanMode,
}

impl SerializationConfig {
/// Reads all three serialization modes from the optional config dict.
///
/// Each mode is resolved by its own `from_config`, in field order; the first error is returned unchanged.
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
    Ok(Self {
        timedelta_mode: TimedeltaMode::from_config(config)?,
        bytes_mode: BytesMode::from_config(config)?,
        inf_nan_mode: InfNanMode::from_config(config)?,
    })
}

pub fn from_args(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
/// Builds the config from explicit mode strings rather than a config dict.
///
/// Each string is parsed with the matching mode's `FromStr` impl, in argument order;
/// the first parse error is propagated.
pub fn from_args(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
    let timedelta_mode = TimedeltaMode::from_str(timedelta_mode)?;
    let bytes_mode = BytesMode::from_str(bytes_mode)?;
    let inf_nan_mode = InfNanMode::from_str(inf_nan_mode)?;
    Ok(Self {
        timedelta_mode,
        bytes_mode,
        inf_nan_mode,
    })
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum TimedeltaMode {
#[default]
Iso8601,
Float,
/// Construction of a value from an optional config dict.
///
/// `serialization_mode!` implements this for every mode enum it generates.
pub trait FromConfig {
/// Builds `Self` from `config`.
///
/// The `serialization_mode!` implementations return `Self::default()` when `config` is `None`
/// or the relevant key is missing, and an error when the stored string is not a known mode.
fn from_config(config: Option<&PyDict>) -> PyResult<Self>
where
Self: Sized;
}

impl FromStr for TimedeltaMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"iso8601" => Ok(Self::Iso8601),
"float" => Ok(Self::Float),
s => py_schema_err!(
"Invalid timedelta serialization mode: `{}`, expected `iso8601` or `float`",
s
),
// Generates a serialization-mode enum together with its string and config parsers.
//
// `serialization_mode!(Name, "config_key", VariantA => "a", VariantB => "b", ...)` expands to:
//
// * `pub(crate) enum Name` deriving `Default, Debug, Clone, PartialEq, Eq`; the FIRST listed
//   variant is the `#[default]` one.
// * `impl FromStr for Name`, mapping each string to its variant and returning a schema error
//   (via `py_schema_err!`) that lists every accepted value for anything else.
// * `impl FromConfig for Name`, which reads `config_key` from the optional config dict and falls
//   back to `Name::default()` when the dict or the key is absent.
//
// The first variant is matched separately from the rest so that the default variant is explicit
// and the "expected ..." list can be joined with " or " without a trailing separator.
macro_rules! serialization_mode {
    (
        $name:ident,
        $config_key:expr,
        $first_variant:ident => $first_value:expr
        $(, $variant:ident => $value:expr)* $(,)?
    ) => {
        #[derive(Default, Debug, Clone, PartialEq, Eq)]
        pub(crate) enum $name {
            #[default]
            $first_variant,
            $($variant,)*
        }

        impl FromStr for $name {
            type Err = PyErr;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $first_value => Ok(Self::$first_variant),
                    $($value => Ok(Self::$variant),)*
                    s => py_schema_err!(
                        // Produces e.g. "... expected `iso8601` or `float`".
                        concat!(
                            "Invalid ", stringify!($name), " serialization mode: `{}`, expected `",
                            $first_value, "`",
                            $(" or `", $value, "`",)*
                        ),
                        s
                    ),
                }
            }
        }

        impl FromConfig for $name {
            fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
                // No config dict at all: use the default mode.
                let Some(config_dict) = config else {
                    return Ok(Self::default());
                };
                let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), $config_key))?;
                // Key missing: default mode; key present: it must parse as a known mode.
                raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
            }
        }
    };
}

impl TimedeltaMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_timedelta"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}
// Serialization mode for `timedelta` values, read from the `ser_json_timedelta` config key.
// `Iso8601` is listed first, so it is the variant that receives `#[default]` in the expansion.
serialization_mode! {
TimedeltaMode,
"ser_json_timedelta",
Iso8601 => "iso8601",
Float => "float",
}

// Serialization mode for `bytes` values, read from the `ser_json_bytes` config key.
// `Utf8` is listed first, so it is the variant that receives `#[default]` in the expansion.
serialization_mode! {
BytesMode,
"ser_json_bytes",
Utf8 => "utf8",
Base64 => "base64",
Hex => "hex",
}

// Serialization mode for infinite and NaN floats, read from the `ser_json_inf_nan` config key.
// `Null` is listed first, so it is the variant that receives `#[default]` in the expansion.
// NOTE(review): the `to_json` / `to_jsonable_python` keyword argument defaults to "constants",
// whereas a missing config key resolves to `Null` here — confirm the two defaults are meant to differ.
serialization_mode! {
InfNanMode,
"ser_json_inf_nan",
Null => "null",
Constants => "constants",
}

impl TimedeltaMode {
/// Calls the Python-level `total_seconds()` method on `py_timedelta` and returns its result.
fn total_seconds(py_timedelta: &PyDelta) -> PyResult<&PyAny> {
    let method_name = intern!(py_timedelta.py(), "total_seconds");
    py_timedelta.call_method0(method_name)
}
Expand Down Expand Up @@ -124,39 +162,7 @@ impl TimedeltaMode {
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum BytesMode {
#[default]
Utf8,
Base64,
Hex,
}

impl FromStr for BytesMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"utf8" => Ok(Self::Utf8),
"base64" => Ok(Self::Base64),
"hex" => Ok(Self::Hex),
s => py_schema_err!(
"Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`",
s
),
}
}
}

impl BytesMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_bytes"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}

pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult<Cow<'py, str>> {
match self {
Self::Utf8 => from_utf8(bytes)
Expand Down Expand Up @@ -190,28 +196,6 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ pub(crate) struct SerializationState {
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(false);
let rec_guard = SerRecursionGuard::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?;
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
Ok(Self {
warnings,
rec_guard,
Expand Down
11 changes: 9 additions & 2 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pyo3::types::{
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};

use crate::input::{EitherTimedelta, Int};
use crate::serializers::config::InfNanMode;
use crate::serializers::errors::SERIALIZATION_ERR_MARKER;
use crate::serializers::filter::SchemaFilter;
use crate::serializers::shared::{PydanticSerializer, TypeSerializer};
Expand Down Expand Up @@ -120,10 +121,16 @@ pub(crate) fn infer_to_python_known(
let value = match extra.mode {
SerMode::Json => match ob_type {
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Float | ObType::Str => value.into_py(py),
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::FloatSubclass => value.extract::<f64>()?.into_py(py),
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
return Ok(py.None());
}
v.into_py(py)
}
ObType::Decimal => value.to_string().into_py(py),
ObType::StrSubclass => value.extract::<&str>()?.into_py(py),
ObType::Bytes => extra
Expand Down
10 changes: 6 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl SchemaSerializer {
#[pyfunction]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8",
serialize_unknown = false, fallback = None))]
inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_json(
py: Python,
value: &PyAny,
Expand All @@ -225,10 +225,11 @@ pub fn to_json(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
Expand All @@ -248,7 +249,7 @@ pub fn to_json(
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false,
timedelta_mode = "iso8601", bytes_mode = "utf8", serialize_unknown = false, fallback = None))]
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_jsonable_python(
py: Python,
value: &PyAny,
Expand All @@ -259,10 +260,11 @@ pub fn to_jsonable_python(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
Expand Down
17 changes: 10 additions & 7 deletions src/serializers/type_serializers/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};

use crate::definitions::DefinitionsBuilder;
use crate::serializers::config::{BytesMode, FromConfig};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode,
TypeSerializer,
};

#[derive(Debug, Clone)]
pub struct BytesSerializer;
pub struct BytesSerializer {
bytes_mode: BytesMode,
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
}

impl BuildSerializer for BytesSerializer {
const EXPECTED_TYPE: &'static str = "bytes";

fn build(
_schema: &PyDict,
_config: Option<&PyDict>,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
Ok(Self {}.into())
let bytes_mode = BytesMode::from_config(config)?;
Ok(Self { bytes_mode }.into())
}
}

Expand All @@ -38,8 +42,7 @@ impl TypeSerializer for BytesSerializer {
let py = value.py();
match value.downcast::<PyBytes>() {
Ok(py_bytes) => match extra.mode {
SerMode::Json => extra
.config
SerMode::Json => self
.bytes_mode
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
.bytes_to_string(py, py_bytes.as_bytes())
.map(|s| s.into_py(py)),
Expand All @@ -54,7 +57,7 @@ impl TypeSerializer for BytesSerializer {

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
match key.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
Expand All @@ -71,7 +74,7 @@ impl TypeSerializer for BytesSerializer {
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match value.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Err(_) => {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
Expand Down
Loading