diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 5e09c97dc..9ef0d5f82 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -35,6 +35,10 @@ pub struct LiteralLookup { expected_py_dict: Option>, // Catch all for unhashable types like list expected_py_values: Option, usize)>>, + // Fallback for ints, bools, and strings to use Python hash and equality checks + // which we can't mix with `expected_py_dict`, as there may be conflicts + // for an example, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int + expected_py_primitives: Option>, pub values: Vec, } @@ -46,20 +50,24 @@ impl LiteralLookup { let mut expected_str: AHashMap = AHashMap::new(); let expected_py_dict = PyDict::new_bound(py); let mut expected_py_values = Vec::new(); + let expected_py_primitives = PyDict::new_bound(py); let mut values = Vec::new(); for (k, v) in expected { let id = values.len(); values.push(v); - if let Ok(bool) = k.validate_bool(true) { - if bool.into_inner() { + + if let Ok(bool_value) = k.validate_bool(true) { + if bool_value.into_inner() { expected_bool.true_id = Some(id); } else { expected_bool.false_id = Some(id); } + expected_py_primitives.set_item(&k, id)?; } if k.is_exact_instance_of::() { if let Ok(int_64) = k.extract::() { expected_int.insert(int_64, id); + expected_py_primitives.set_item(&k, id)?; } else { // cover the case of an int that's > i64::MAX etc. expected_py_dict.set_item(k, id)?; @@ -69,32 +77,20 @@ impl LiteralLookup { .as_cow() .map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?; expected_str.insert(str.to_string(), id); + expected_py_primitives.set_item(&k, id)?; } else if expected_py_dict.set_item(&k, id).is_err() { expected_py_values.push((k.as_unbound().clone_ref(py), id)); } } Ok(Self { - expected_bool: match expected_bool.true_id.is_some() || expected_bool.false_id.is_some() { - true => Some(expected_bool), - false => None, - }, - expected_int: match expected_int.is_empty() { - true => None, - false => Some(expected_int), - }, - expected_str: match expected_str.is_empty() { - true => None, - false => Some(expected_str), - }, - expected_py_dict: match expected_py_dict.is_empty() { - true => None, - false => Some(expected_py_dict.into()), - }, - expected_py_values: match expected_py_values.is_empty() { - true => None, - false => Some(expected_py_values), - }, + expected_bool: (expected_bool.true_id.is_some() || expected_bool.false_id.is_some()) + .then_some(expected_bool), + expected_int: (!expected_int.is_empty()).then_some(expected_int), + expected_str: (!expected_str.is_empty()).then_some(expected_str), + expected_py_dict: (!expected_py_dict.is_empty()).then_some(expected_py_dict.into()), + expected_py_values: (!expected_py_values.is_empty()).then_some(expected_py_values), + expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives.into()), values, }) } @@ -162,6 +158,19 @@ impl LiteralLookup { } } }; + + // this one must be last to avoid conflicts with the other lookups, think of this + // almost as a lax fallback + if let Some(expected_py_primitives) = &self.expected_py_primitives { + let py_input = py_input.get_or_insert_with(|| input.to_object(py)); + // We don't use ? to unpack the result of `get_item` in the next line because unhashable + // inputs will produce a TypeError, which in this case we just want to treat equivalently + // to a failed lookup + if let Ok(Some(v)) = expected_py_primitives.bind(py).get_item(&*py_input) { + let id: usize = v.extract().unwrap(); + return Ok(Some((input, &self.values[id]))); + } + }; Ok(None) } diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index 5f9f942e2..d0d1af279 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -389,3 +389,15 @@ def test_big_int(): m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error' with pytest.raises(ValidationError, match=m): v.validate_python(37) + + +def test_enum_for_str() -> None: + class S(str, Enum): + a = 'a' + + val_enum = SchemaValidator(core_schema.literal_schema([S.a])) + val_str = SchemaValidator(core_schema.literal_schema(['a'])) + + for val in [val_enum, val_str]: + assert val.validate_python('a') == 'a' + assert val.validate_python(S.a) == 'a'