Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a custom Serializer/Deserializer to fix Nan and Infinity float #888

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ impl ValidationError {
Some(indent) => {
let indent = vec![b' '; indent];
let formatter = PrettyFormatter::with_indent(&indent);
let mut ser = serde_json::Serializer::with_formatter(writer, formatter);
let mut ser = crate::serde::PythonSerializer::with_formatter(writer, formatter);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
None => {
let mut ser = serde_json::Serializer::new(writer);
let mut ser = crate::serde::PythonSerializer::new(writer);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
Expand Down
4 changes: 2 additions & 2 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<'a> Input<'a> for JsonInput {

fn parse_json(&'a self) -> ValResult<'a, JsonInput> {
match self {
JsonInput::String(s) => serde_json::from_str(s.as_str()).map_err(|e| map_json_err(self, e)),
JsonInput::String(s) => crate::serde::from_str(s.as_str()).map_err(|e| map_json_err(self, e)),
_ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)),
}
}
Expand Down Expand Up @@ -392,7 +392,7 @@ impl<'a> Input<'a> for String {
}

fn parse_json(&'a self) -> ValResult<'a, JsonInput> {
serde_json::from_str(self.as_str()).map_err(|e| map_json_err(self, e))
crate::serde::from_str(self.as_str()).map_err(|e| map_json_err(self, e))
}

fn validate_str(&'a self, _strict: bool) -> ValResult<EitherString<'a>> {
Expand Down
6 changes: 3 additions & 3 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ impl<'a> Input<'a> for PyAny {

fn parse_json(&'a self) -> ValResult<'a, JsonInput> {
if let Ok(py_bytes) = self.downcast::<PyBytes>() {
serde_json::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e))
crate::serde::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e))
} else if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_str.to_str()?;
serde_json::from_str(str).map_err(|e| map_json_err(self, e))
crate::serde::from_str(str).map_err(|e| map_json_err(self, e))
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e))
crate::serde::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e))
} else {
Err(ValError::new(ErrorTypeDefaults::JsonType, self))
}
Expand Down
1 change: 0 additions & 1 deletion src/input/parse_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ impl<'de> Deserialize<'de> for JsonInput {
Some(first_key) => {
let mut values = LazyIndexMap::new();
let first_value = visitor.next_value()?;

// serde_json will parse arbitrary precision numbers into a map
// structure with a "number" key and a String value
'try_number: {
Expand Down
29 changes: 28 additions & 1 deletion src/input/shared.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use num_bigint::BigInt;
use pyo3::exceptions::PyValueError;
use pyo3::pyclass;

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::input::EitherInt;

use super::{EitherFloat, Input};

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: crate::serde::PydanticSerdeError) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
error: error.to_string(),
Expand Down Expand Up @@ -136,3 +138,28 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a,
Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input))
}
}

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
pub struct PythonDeserializerError {
pub message: String,
}

impl std::fmt::Display for PythonDeserializerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}

impl std::error::Error for PythonDeserializerError {}

impl serde::ser::Error for PythonDeserializerError {
fn custom<T>(msg: T) -> Self
where
T: std::fmt::Display,
{
PythonDeserializerError {
message: format!("{msg}"),
}
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod input;
mod lazy_index_map;
mod lookup_key;
mod recursion_guard;
mod serde;
mod serializers;
mod tools;
mod url;
Expand Down
Loading