diff --git a/src/serializers/type_serializers/simple.rs b/src/serializers/type_serializers/simple.rs index dafb2b786..65fbee146 100644 --- a/src/serializers/type_serializers/simple.rs +++ b/src/serializers/type_serializers/simple.rs @@ -5,11 +5,12 @@ use std::borrow::Cow; use serde::Serialize; +use crate::PydanticSerializationUnexpectedValue; use crate::{definitions::DefinitionsBuilder, input::Int}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, - SerMode, TypeSerializer, + SerCheck, SerMode, TypeSerializer, }; #[derive(Debug, Clone)] @@ -85,7 +86,7 @@ impl TypeSerializer for NoneSerializer { } macro_rules! build_simple_serializer { - ($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => { + ($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => { #[derive(Debug, Clone)] pub struct $struct_name; @@ -114,12 +115,15 @@ macro_rules! build_simple_serializer { let py = value.py(); match extra.ob_type_lookup.is_type(value, $ob_type) { IsType::Exact => Ok(value.into_py(py)), - IsType::Subclass => match extra.mode { - SerMode::Json => { - let rust_value = value.extract::<$rust_type>()?; - Ok(rust_value.to_object(py)) - } - _ => infer_to_python(value, include, exclude, extra), + IsType::Subclass => match extra.check { + SerCheck::Strict => Err(PydanticSerializationUnexpectedValue::new_err(None)), + SerCheck::Lax | SerCheck::None => match extra.mode { + SerMode::Json => { + let rust_value = value.extract::<$rust_type>()?; + Ok(rust_value.to_object(py)) + } + _ => infer_to_python(value, include, exclude, extra), + }, }, IsType::False => { extra.warnings.on_fallback_py(self.get_name(), value, extra)?; @@ -160,6 +164,10 @@ macro_rules! build_simple_serializer { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn retry_with_lax_check(&self) -> bool { + $subtypes_allowed + } } }; } @@ -168,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult> { Ok(key.str()?.to_string_lossy()) } -build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key); +build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true); pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { let v = if key.is_true().unwrap_or(false) { @@ -179,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { Ok(Cow::Borrowed(v)) } -build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key); +build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false); diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 9b021e66e..ee5ed3fc4 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -1,6 +1,8 @@ import dataclasses import json import re +import uuid +from decimal import Decimal from typing import Any, ClassVar, Union import pytest @@ -510,3 +512,117 @@ class Item(BaseModel): ) assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}] + + +EXAMPLE_UUID = uuid.uuid4() + + +class IntSubclass(int): + pass + + +@pytest.mark.parametrize('reverse', [False, True]) +@pytest.mark.parametrize( + 'core_schema_left,core_schema_right,input_value,expected_value', + [ + (core_schema.int_schema(), core_schema.bool_schema(), True, True), + (core_schema.int_schema(), core_schema.bool_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), '1', '1'), + (core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1), + ( + core_schema.decimal_schema(), + core_schema.int_schema(), + Decimal('1'), + Decimal('1'), + ), + (core_schema.decimal_schema(), core_schema.int_schema(), 1, 1), + ( + core_schema.decimal_schema(), + core_schema.float_schema(), + Decimal('1.'), + Decimal('1.'), + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + Decimal('_1'), + Decimal('_1'), + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + '_1', + '_1', + ), + ( + core_schema.uuid_schema(), + core_schema.str_schema(), + EXAMPLE_UUID, + EXAMPLE_UUID, + ), + ( + core_schema.uuid_schema(), + core_schema.str_schema(), + str(EXAMPLE_UUID), + str(EXAMPLE_UUID), + ), + ], +) +def test_union_serializer_picks_exact_type_over_subclass( + core_schema_left, core_schema_right, input_value, expected_value, reverse +): + s = SchemaSerializer( + core_schema.union_schema( + [core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right] + ) + ) + assert s.to_python(input_value) == expected_value + + +@pytest.mark.parametrize('reverse', [False, True]) +@pytest.mark.parametrize( + 'core_schema_left,core_schema_right,input_value,expected_value', + [ + (core_schema.int_schema(), core_schema.bool_schema(), True, True), + (core_schema.int_schema(), core_schema.bool_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), '1', '1'), + (core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1), + ( + core_schema.decimal_schema(), + core_schema.int_schema(), + Decimal('1'), + '1', + ), + (core_schema.decimal_schema(), core_schema.int_schema(), 1, 1), + ( + core_schema.decimal_schema(), + core_schema.float_schema(), + Decimal('1.'), + '1', + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + Decimal('_1'), + '1', + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + '_1', + '_1', + ), + ], +) +def test_union_serializer_picks_exact_type_over_subclass_json( + core_schema_left, core_schema_right, input_value, expected_value, reverse +): + s = SchemaSerializer( + core_schema.union_schema( + [core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right] + ) + ) + assert s.to_python(input_value, mode='json') == expected_value + assert s.to_json(input_value) == json.dumps(expected_value).encode()