Skip to content

Commit

Permalink
Support serialization mode specification from model config and `Seria…
Browse files Browse the repository at this point in the history
…lizationConfig` (#1122)

Co-authored-by: David Hewitt <[email protected]>
  • Loading branch information
sydney-runkle and davidhewitt authored Jan 9, 2024
1 parent f3d0cc5 commit e3ae7f6
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 111 deletions.
4 changes: 4 additions & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def to_json(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> bytes:
Expand All @@ -373,6 +374,7 @@ def to_json(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand Down Expand Up @@ -414,6 +416,7 @@ def to_jsonable_python(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> Any:
Expand All @@ -432,6 +435,7 @@ def to_jsonable_python(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand Down
2 changes: 1 addition & 1 deletion src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl ValidationError {
include_context: bool,
include_input: bool,
) -> PyResult<&'py PyString> {
let state = SerializationState::new("iso8601", "utf8")?;
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None);
let serializer = ValidationErrorSerializer {
py,
Expand Down
144 changes: 64 additions & 80 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,98 @@ use crate::tools::SchemaDict;
use super::errors::py_err_se_err;

#[derive(Debug, Clone)]
#[allow(clippy::struct_field_names)]
pub(crate) struct SerializationConfig {
pub timedelta_mode: TimedeltaMode,
pub bytes_mode: BytesMode,
pub inf_nan_mode: InfNanMode,
}

impl SerializationConfig {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let timedelta_mode = TimedeltaMode::from_config(config)?;
let bytes_mode = BytesMode::from_config(config)?;
let inf_nan_mode = InfNanMode::from_config(config)?;
Ok(Self {
timedelta_mode,
bytes_mode,
inf_nan_mode,
})
}

pub fn from_args(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
pub fn from_args(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
Ok(Self {
timedelta_mode: TimedeltaMode::from_str(timedelta_mode)?,
bytes_mode: BytesMode::from_str(bytes_mode)?,
inf_nan_mode: InfNanMode::from_str(inf_nan_mode)?,
})
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum TimedeltaMode {
#[default]
Iso8601,
Float,
pub trait FromConfig {
fn from_config(config: Option<&PyDict>) -> PyResult<Self>
where
Self: Sized;
}

impl FromStr for TimedeltaMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"iso8601" => Ok(Self::Iso8601),
"float" => Ok(Self::Float),
s => py_schema_err!(
"Invalid timedelta serialization mode: `{}`, expected `iso8601` or `float`",
s
),
macro_rules! serialization_mode {
($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => {
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum $name {
#[default]
$($variant,)*
}
}

impl FromStr for $name {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
$($value => Ok(Self::$variant),)*
s => py_schema_err!(
concat!("Invalid ", stringify!($name), " serialization mode: `{}`, expected ", $($value, " or "),*),
s
),
}
}
}

impl FromConfig for $name {
fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), $config_key))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}
}

};
}

impl TimedeltaMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_timedelta"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}
serialization_mode! {
TimedeltaMode,
"ser_json_timedelta",
Iso8601 => "iso8601",
Float => "float",
}

serialization_mode! {
BytesMode,
"ser_json_bytes",
Utf8 => "utf8",
Base64 => "base64",
Hex => "hex",
}

serialization_mode! {
InfNanMode,
"ser_json_inf_nan",
Null => "null",
Constants => "constants",
}

impl TimedeltaMode {
fn total_seconds(py_timedelta: &PyDelta) -> PyResult<&PyAny> {
py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds"))
}
Expand Down Expand Up @@ -124,39 +162,7 @@ impl TimedeltaMode {
}
}

#[derive(Default, Debug, Clone)]
pub(crate) enum BytesMode {
#[default]
Utf8,
Base64,
Hex,
}

impl FromStr for BytesMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"utf8" => Ok(Self::Utf8),
"base64" => Ok(Self::Base64),
"hex" => Ok(Self::Hex),
s => py_schema_err!(
"Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`",
s
),
}
}
}

impl BytesMode {
pub fn from_config(config: Option<&PyDict>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_bytes"))?;
raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str)
}

pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult<Cow<'py, str>> {
match self {
Self::Utf8 => from_utf8(bytes)
Expand Down Expand Up @@ -190,28 +196,6 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ pub(crate) struct SerializationState {
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str) -> PyResult<Self> {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(false);
let rec_guard = SerRecursionGuard::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?;
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
Ok(Self {
warnings,
rec_guard,
Expand Down
11 changes: 9 additions & 2 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pyo3::types::{
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};

use crate::input::{EitherTimedelta, Int};
use crate::serializers::config::InfNanMode;
use crate::serializers::errors::SERIALIZATION_ERR_MARKER;
use crate::serializers::filter::SchemaFilter;
use crate::serializers::shared::{PydanticSerializer, TypeSerializer};
Expand Down Expand Up @@ -120,10 +121,16 @@ pub(crate) fn infer_to_python_known(
let value = match extra.mode {
SerMode::Json => match ob_type {
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Float | ObType::Str => value.into_py(py),
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::FloatSubclass => value.extract::<f64>()?.into_py(py),
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
return Ok(py.None());
}
v.into_py(py)
}
ObType::Decimal => value.to_string().into_py(py),
ObType::StrSubclass => value.extract::<&str>()?.into_py(py),
ObType::Bytes => extra
Expand Down
10 changes: 6 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl SchemaSerializer {
#[pyfunction]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8",
serialize_unknown = false, fallback = None))]
inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_json(
py: Python,
value: &PyAny,
Expand All @@ -225,10 +225,11 @@ pub fn to_json(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
Expand All @@ -248,7 +249,7 @@ pub fn to_json(
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false,
timedelta_mode = "iso8601", bytes_mode = "utf8", serialize_unknown = false, fallback = None))]
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
pub fn to_jsonable_python(
py: Python,
value: &PyAny,
Expand All @@ -259,10 +260,11 @@ pub fn to_jsonable_python(
round_trip: bool,
timedelta_mode: &str,
bytes_mode: &str,
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode)?;
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
py,
&SerMode::Json,
Expand Down
17 changes: 10 additions & 7 deletions src/serializers/type_serializers/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};

use crate::definitions::DefinitionsBuilder;
use crate::serializers::config::{BytesMode, FromConfig};

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode,
TypeSerializer,
};

#[derive(Debug, Clone)]
pub struct BytesSerializer;
pub struct BytesSerializer {
bytes_mode: BytesMode,
}

impl BuildSerializer for BytesSerializer {
const EXPECTED_TYPE: &'static str = "bytes";

fn build(
_schema: &PyDict,
_config: Option<&PyDict>,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
Ok(Self {}.into())
let bytes_mode = BytesMode::from_config(config)?;
Ok(Self { bytes_mode }.into())
}
}

Expand All @@ -38,8 +42,7 @@ impl TypeSerializer for BytesSerializer {
let py = value.py();
match value.downcast::<PyBytes>() {
Ok(py_bytes) => match extra.mode {
SerMode::Json => extra
.config
SerMode::Json => self
.bytes_mode
.bytes_to_string(py, py_bytes.as_bytes())
.map(|s| s.into_py(py)),
Expand All @@ -54,7 +57,7 @@ impl TypeSerializer for BytesSerializer {

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
match key.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()),
Err(_) => {
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
Expand All @@ -71,7 +74,7 @@ impl TypeSerializer for BytesSerializer {
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match value.downcast::<PyBytes>() {
Ok(py_bytes) => extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer),
Err(_) => {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
infer_serialize(value, serializer, include, exclude, extra)
Expand Down
Loading

0 comments on commit e3ae7f6

Please sign in to comment.