diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 382a6c804..304a0612d 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -355,6 +355,7 @@ def to_json( round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, ) -> bytes: @@ -373,6 +374,7 @@ def to_json( round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. fallback: A function to call when an unknown value is encountered, @@ -414,6 +416,7 @@ def to_jsonable_python( round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, ) -> Any: @@ -432,6 +435,7 @@ def to_jsonable_python( round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. fallback: A function to call when an unknown value is encountered, diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index e77b21974..cd55e9aef 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -303,7 +303,7 @@ impl ValidationError { include_context: bool, include_input: bool, ) -> PyResult<&'py PyString> { - let state = SerializationState::new("iso8601", "utf8")?; + let state = SerializationState::new("iso8601", "utf8", "constants")?; let extra = state.extra(py, &SerMode::Json, true, false, false, true, None); let serializer = ValidationErrorSerializer { py, diff --git a/src/serializers/config.rs b/src/serializers/config.rs index e83497f64..422ee4162 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -15,60 +15,98 @@ use crate::tools::SchemaDict; use super::errors::py_err_se_err; #[derive(Debug, Clone)] +#[allow(clippy::struct_field_names)] pub(crate) struct SerializationConfig { pub timedelta_mode: TimedeltaMode, pub bytes_mode: BytesMode, + pub inf_nan_mode: InfNanMode, } impl SerializationConfig { pub fn from_config(config: Option<&PyDict>) -> PyResult { let timedelta_mode = TimedeltaMode::from_config(config)?; let bytes_mode = BytesMode::from_config(config)?; + let inf_nan_mode = InfNanMode::from_config(config)?; Ok(Self { timedelta_mode, bytes_mode, + inf_nan_mode, }) } - pub fn from_args(timedelta_mode: &str, bytes_mode: &str) -> PyResult { + pub fn from_args(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { Ok(Self { timedelta_mode: TimedeltaMode::from_str(timedelta_mode)?, bytes_mode: BytesMode::from_str(bytes_mode)?, + inf_nan_mode: InfNanMode::from_str(inf_nan_mode)?, }) } } -#[derive(Default, Debug, Clone)] -pub(crate) enum TimedeltaMode { - #[default] - Iso8601, - Float, +pub trait FromConfig { + fn from_config(config: Option<&PyDict>) -> PyResult + where + Self: Sized; } -impl FromStr for TimedeltaMode { - type Err = PyErr; - - fn from_str(s: &str) -> Result { - match s { - "iso8601" => Ok(Self::Iso8601), - "float" => Ok(Self::Float), - s => py_schema_err!( - "Invalid timedelta serialization mode: `{}`, expected `iso8601` or `float`", - s - ), +macro_rules! serialization_mode { + ($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => { + #[derive(Default, Debug, Clone, PartialEq, Eq)] + pub(crate) enum $name { + #[default] + $($variant,)* } - } + + impl FromStr for $name { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + $($value => Ok(Self::$variant),)* + s => py_schema_err!( + concat!("Invalid ", stringify!($name), " serialization mode: `{}`, expected ", $($value, " or "),*), + s + ), + } + } + } + + impl FromConfig for $name { + fn from_config(config: Option<&PyDict>) -> PyResult { + let Some(config_dict) = config else { + return Ok(Self::default()); + }; + let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), $config_key))?; + raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) + } + } + + }; } -impl TimedeltaMode { - pub fn from_config(config: Option<&PyDict>) -> PyResult { - let Some(config_dict) = config else { - return Ok(Self::default()); - }; - let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_timedelta"))?; - raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) - } +serialization_mode! { + TimedeltaMode, + "ser_json_timedelta", + Iso8601 => "iso8601", + Float => "float", +} + +serialization_mode! { + BytesMode, + "ser_json_bytes", + Utf8 => "utf8", + Base64 => "base64", + Hex => "hex", +} + +serialization_mode! { + InfNanMode, + "ser_json_inf_nan", + Null => "null", + Constants => "constants", +} +impl TimedeltaMode { fn total_seconds(py_timedelta: &PyDelta) -> PyResult<&PyAny> { py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds")) } @@ -124,39 +162,7 @@ impl TimedeltaMode { } } -#[derive(Default, Debug, Clone)] -pub(crate) enum BytesMode { - #[default] - Utf8, - Base64, - Hex, -} - -impl FromStr for BytesMode { - type Err = PyErr; - - fn from_str(s: &str) -> Result { - match s { - "utf8" => Ok(Self::Utf8), - "base64" => Ok(Self::Base64), - "hex" => Ok(Self::Hex), - s => py_schema_err!( - "Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`", - s - ), - } - } -} - impl BytesMode { - pub fn from_config(config: Option<&PyDict>) -> PyResult { - let Some(config_dict) = config else { - return Ok(Self::default()); - }; - let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_bytes"))?; - raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) - } - pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult> { match self { Self::Utf8 => from_utf8(bytes) @@ -190,28 +196,6 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub(crate) enum InfNanMode { - #[default] - Null, - Constants, -} - -impl FromStr for InfNanMode { - type Err = PyErr; - - fn from_str(s: &str) -> Result { - match s { - "null" => Ok(Self::Null), - "constants" => Ok(Self::Constants), - s => py_schema_err!( - "Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`", - s - ), - } - } -} - impl FromPyObject<'_> for InfNanMode { fn extract(ob: &'_ PyAny) -> PyResult { let s = ob.extract::<&str>()?; diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 7a9b84704..37307055e 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -21,10 +21,10 @@ pub(crate) struct SerializationState { } impl SerializationState { - pub fn new(timedelta_mode: &str, bytes_mode: &str) -> PyResult { + pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { let warnings = CollectWarnings::new(false); let rec_guard = SerRecursionGuard::default(); - let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?; + let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?; Ok(Self { warnings, rec_guard, diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 265967c57..13c20062b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -10,6 +10,7 @@ use pyo3::types::{ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; use crate::input::{EitherTimedelta, Int}; +use crate::serializers::config::InfNanMode; use crate::serializers::errors::SERIALIZATION_ERR_MARKER; use crate::serializers::filter::SchemaFilter; use crate::serializers::shared::{PydanticSerializer, TypeSerializer}; @@ -120,10 +121,16 @@ pub(crate) fn infer_to_python_known( let value = match extra.mode { SerMode::Json => match ob_type { // `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types - ObType::None | ObType::Bool | ObType::Int | ObType::Float | ObType::Str => value.into_py(py), + ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py), // have to do this to make sure subclasses of for example str are upcast to `str` ObType::IntSubclass => extract_i64(value)?.into_py(py), - ObType::FloatSubclass => value.extract::()?.into_py(py), + ObType::Float | ObType::FloatSubclass => { + let v = value.extract::()?; + if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null { + return Ok(py.None()); + } + v.into_py(py) + } ObType::Decimal => value.to_string().into_py(py), ObType::StrSubclass => value.extract::<&str>()?.into_py(py), ObType::Bytes => extra diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index e9208a510..8159691cb 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -213,7 +213,7 @@ impl SchemaSerializer { #[pyfunction] #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8", - serialize_unknown = false, fallback = None))] + inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] pub fn to_json( py: Python, value: &PyAny, @@ -225,10 +225,11 @@ pub fn to_json( round_trip: bool, timedelta_mode: &str, bytes_mode: &str, + inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, ) -> PyResult { - let state = SerializationState::new(timedelta_mode, bytes_mode)?; + let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( py, &SerMode::Json, @@ -248,7 +249,7 @@ pub fn to_json( #[allow(clippy::too_many_arguments)] #[pyfunction] #[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, - timedelta_mode = "iso8601", bytes_mode = "utf8", serialize_unknown = false, fallback = None))] + timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] pub fn to_jsonable_python( py: Python, value: &PyAny, @@ -259,10 +260,11 @@ pub fn to_jsonable_python( round_trip: bool, timedelta_mode: &str, bytes_mode: &str, + inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, ) -> PyResult { - let state = SerializationState::new(timedelta_mode, bytes_mode)?; + let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( py, &SerMode::Json, diff --git a/src/serializers/type_serializers/bytes.rs b/src/serializers/type_serializers/bytes.rs index 67dbe794b..bd354aa75 100644 --- a/src/serializers/type_serializers/bytes.rs +++ b/src/serializers/type_serializers/bytes.rs @@ -4,6 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use crate::definitions::DefinitionsBuilder; +use crate::serializers::config::{BytesMode, FromConfig}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, @@ -11,17 +12,20 @@ use super::{ }; #[derive(Debug, Clone)] -pub struct BytesSerializer; +pub struct BytesSerializer { + bytes_mode: BytesMode, +} impl BuildSerializer for BytesSerializer { const EXPECTED_TYPE: &'static str = "bytes"; fn build( _schema: &PyDict, - _config: Option<&PyDict>, + config: Option<&PyDict>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - Ok(Self {}.into()) + let bytes_mode = BytesMode::from_config(config)?; + Ok(Self { bytes_mode }.into()) } } @@ -38,8 +42,7 @@ impl TypeSerializer for BytesSerializer { let py = value.py(); match value.downcast::() { Ok(py_bytes) => match extra.mode { - SerMode::Json => extra - .config + SerMode::Json => self .bytes_mode .bytes_to_string(py, py_bytes.as_bytes()) .map(|s| s.into_py(py)), @@ -54,7 +57,7 @@ impl TypeSerializer for BytesSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match key.downcast::() { - Ok(py_bytes) => extra.config.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()), + Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) @@ -71,7 +74,7 @@ impl TypeSerializer for BytesSerializer { extra: &Extra, ) -> Result { match value.downcast::() { - Ok(py_bytes) => extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer), + Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer), Err(_) => { extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; infer_serialize(value, serializer, include, exclude, extra) diff --git a/src/serializers/type_serializers/timedelta.rs b/src/serializers/type_serializers/timedelta.rs index baadbe0aa..042a8ffea 100644 --- a/src/serializers/type_serializers/timedelta.rs +++ b/src/serializers/type_serializers/timedelta.rs @@ -5,6 +5,7 @@ use pyo3::types::PyDict; use crate::definitions::DefinitionsBuilder; use crate::input::EitherTimedelta; +use crate::serializers::config::{FromConfig, TimedeltaMode}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, @@ -12,17 +13,20 @@ use super::{ }; #[derive(Debug, Clone)] -pub struct TimeDeltaSerializer; +pub struct TimeDeltaSerializer { + timedelta_mode: TimedeltaMode, +} impl BuildSerializer for TimeDeltaSerializer { const EXPECTED_TYPE: &'static str = "timedelta"; fn build( _schema: &PyDict, - _config: Option<&PyDict>, + config: Option<&PyDict>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - Ok(Self {}.into()) + let timedelta_mode = TimedeltaMode::from_config(config)?; + Ok(Self { timedelta_mode }.into()) } } @@ -38,10 +42,7 @@ impl TypeSerializer for TimeDeltaSerializer { ) -> PyResult { match extra.mode { SerMode::Json => match EitherTimedelta::try_from(value) { - Ok(either_timedelta) => extra - .config - .timedelta_mode - .either_delta_to_json(value.py(), &either_timedelta), + Ok(either_timedelta) => self.timedelta_mode.either_delta_to_json(value.py(), &either_timedelta), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), value, extra)?; infer_to_python(value, include, exclude, extra) @@ -53,7 +54,7 @@ impl TypeSerializer for TimeDeltaSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match EitherTimedelta::try_from(key) { - Ok(either_timedelta) => extra.config.timedelta_mode.json_key(key.py(), &either_timedelta), + Ok(either_timedelta) => self.timedelta_mode.json_key(key.py(), &either_timedelta), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) @@ -70,12 +71,9 @@ impl TypeSerializer for TimeDeltaSerializer { extra: &Extra, ) -> Result { match EitherTimedelta::try_from(value) { - Ok(either_timedelta) => { - extra - .config - .timedelta_mode - .timedelta_serialize(value.py(), &either_timedelta, serializer) - } + Ok(either_timedelta) => self + .timedelta_mode + .timedelta_serialize(value.py(), &either_timedelta, serializer), Err(_) => { extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; infer_serialize(value, serializer, include, exclude, extra) diff --git a/tests/serializers/test_bytes.py b/tests/serializers/test_bytes.py index 13849bed0..cc2d44785 100644 --- a/tests/serializers/test_bytes.py +++ b/tests/serializers/test_bytes.py @@ -4,7 +4,7 @@ import pytest -from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema +from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema, to_json def test_bytes(): @@ -126,3 +126,37 @@ def test_any_bytes_base64(): assert s.to_json(b'foobar') == b'"Zm9vYmFy"' assert s.to_json({b'foobar': 123}) == b'{"Zm9vYmFy":123}' assert s.to_python({b'foobar': 123}, mode='json') == {'Zm9vYmFy': 123} + + +class BasicModel: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +def test_bytes_mode_set_via_model_config_not_serializer_config(): + s = SchemaSerializer( + core_schema.model_schema( + BasicModel, + core_schema.model_fields_schema( + { + 'foo': core_schema.model_field(core_schema.bytes_schema()), + } + ), + config=core_schema.CoreConfig(ser_json_bytes='base64'), + ) + ) + + bm = BasicModel(foo=b'foobar') + assert s.to_python(bm) == {'foo': b'foobar'} + assert s.to_json(bm) == b'{"foo":"Zm9vYmFy"}' + assert s.to_python(bm, mode='json') == {'foo': 'Zm9vYmFy'} + + # assert doesn't override serializer config + # in V3, we can change the serialization settings provided to to_json to override model config settings, + # but that'd be a breaking change + BasicModel.__pydantic_serializer__ = s + assert to_json(bm, bytes_mode='utf8') == b'{"foo":"Zm9vYmFy"}' + + assert to_json({'foo': b'some bytes'}, bytes_mode='base64') == b'{"foo":"c29tZSBieXRlcw=="}' + assert to_json({'bar': bm}, bytes_mode='base64') == b'{"bar":{"foo":"Zm9vYmFy"}}'