Skip to content

Commit

Permalink
add ser_json_inf_nan setting
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Nov 6, 2023
1 parent 227ad33 commit c62b637
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 36 deletions.
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class CoreConfig(TypedDict, total=False):
allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`.
ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'.
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -102,6 +104,7 @@ class CoreConfig(TypedDict, total=False):
# the config options are used to customise serialization to JSON
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants'] # default: 'null'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
Expand Down
29 changes: 29 additions & 0 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,32 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
Err(err) => err,
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
Self::from_str(s)
}
}
21 changes: 12 additions & 9 deletions src/serializers/type_serializers/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::borrow::Cow;
use serde::Serializer;

use crate::definitions::DefinitionsBuilder;
use crate::serializers::config::InfNanMode;
use crate::tools::SchemaDict;

use super::simple::to_str_json_key;
Expand All @@ -16,21 +17,22 @@ use super::{

#[derive(Debug, Clone)]
pub struct FloatSerializer {
allow_inf_nan: bool,
inf_nan_mode: InfNanMode,
}

impl BuildSerializer for FloatSerializer {
const EXPECTED_TYPE: &'static str = "float";

fn build(
schema: &PyDict,
_config: Option<&PyDict>,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let allow_inf_nan = schema
.get_as::<bool>(intern!(schema.py(), "allow_inf_nan"))?
.unwrap_or(false);
Ok(Self { allow_inf_nan }.into())
let inf_nan_mode = config
.and_then(|c| c.get_as(intern!(schema.py(), "ser_json_inf_nan")).transpose())
.transpose()?
.unwrap_or_default();
Ok(Self { inf_nan_mode }.into())
}
}

Expand Down Expand Up @@ -81,10 +83,11 @@ impl TypeSerializer for FloatSerializer {
) -> Result<S::Ok, S::Error> {
match value.extract::<f64>() {
Ok(v) => {
if (v.is_nan() || v.is_infinite()) && !self.allow_inf_nan {
return serializer.serialize_none();
if (v.is_nan() || v.is_infinite()) && self.inf_nan_mode == InfNanMode::Null {
serializer.serialize_none()
} else {
serializer.serialize_f64(v)
}
serializer.serialize_f64(v)
}
Err(_) => {
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
Expand Down
46 changes: 19 additions & 27 deletions tests/serializers/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,35 +139,27 @@ def test_numpy():


@pytest.mark.parametrize(
'schema_type,value,expected,allow_inf_nan',
'value,expected_json,config',
[
('float', float('inf'), float('inf'), True),
('float', float('+inf'), float('+inf'), True),
('float', float('-inf'), float('-inf'), True),
('float', float('inf'), None, False),
('float', float('+inf'), None, False),
('float', float('-inf'), None, False),
('float', float('NaN'), float('NaN'), True),
('float', float('NAN'), float('NAN'), True),
('float', float('NaN'), None, False),
('float', float('NAN'), None, False),
# default values of ser_json_inf_nan
(float('inf'), 'null', {}),
(float('-inf'), 'null', {}),
(float('nan'), 'null', {}),
# explicit values of ser_json_inf_nan
(float('inf'), 'null', {'ser_json_inf_nan': 'null'}),
(float('-inf'), 'null', {'ser_json_inf_nan': 'null'}),
(float('nan'), 'null', {'ser_json_inf_nan': 'null'}),
(float('inf'), 'Infinity', {'ser_json_inf_nan': 'constants'}),
(float('-inf'), '-Infinity', {'ser_json_inf_nan': 'constants'}),
(float('nan'), 'NaN', {'ser_json_inf_nan': 'constants'}),
],
)
def test_float_inf_and_nan_serializers(schema_type, value, expected, allow_inf_nan):
schema = {'type': schema_type, 'allow_inf_nan': allow_inf_nan}
def test_float_inf_and_nan_serializers(value, expected_json, config):
s = SchemaSerializer(core_schema.float_schema(), config)

s = SchemaSerializer(schema)
v = s.to_python(value)

if allow_inf_nan:
assert type(v) == type(expected)
else:
assert expected is None
# Python can represent these values without needing any changes
assert s.to_python(value) is value
assert s.to_python(value, mode='json') is value

assert s.to_json(value) == json.dumps(expected).encode('utf-8')

v_json = s.to_python(value, mode='json')
if allow_inf_nan:
assert type(v_json) == type(expected)
else:
assert expected is None
# Serialized JSON value respects the ser_json_inf_nan setting
assert s.to_json(value).decode() == expected_json

0 comments on commit c62b637

Please sign in to comment.