Skip to content

Commit

Permalink
Fix equality checks for primitives in literals (#1459)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Sep 25, 2024
1 parent 4aa52a8 commit f389728
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
53 changes: 31 additions & 22 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ pub struct LiteralLookup<T: Debug> {
expected_py_dict: Option<Py<PyDict>>,
// Catch all for unhashable types like list
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
// Fallback for ints, bools, and strings to use Python hash and equality checks
// which we can't mix with `expected_py_dict`, as there may be conflicts
// for an example, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int
expected_py_primitives: Option<Py<PyDict>>,

pub values: Vec<T>,
}
Expand All @@ -46,20 +50,24 @@ impl<T: Debug> LiteralLookup<T> {
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
let expected_py_dict = PyDict::new_bound(py);
let mut expected_py_values = Vec::new();
let expected_py_primitives = PyDict::new_bound(py);
let mut values = Vec::new();
for (k, v) in expected {
let id = values.len();
values.push(v);
if let Ok(bool) = k.validate_bool(true) {
if bool.into_inner() {

if let Ok(bool_value) = k.validate_bool(true) {
if bool_value.into_inner() {
expected_bool.true_id = Some(id);
} else {
expected_bool.false_id = Some(id);
}
expected_py_primitives.set_item(&k, id)?;
}
if k.is_exact_instance_of::<PyInt>() {
if let Ok(int_64) = k.extract::<i64>() {
expected_int.insert(int_64, id);
expected_py_primitives.set_item(&k, id)?;
} else {
// cover the case of an int that's > i64::MAX etc.
expected_py_dict.set_item(k, id)?;
Expand All @@ -69,32 +77,20 @@ impl<T: Debug> LiteralLookup<T> {
.as_cow()
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
expected_str.insert(str.to_string(), id);
expected_py_primitives.set_item(&k, id)?;
} else if expected_py_dict.set_item(&k, id).is_err() {
expected_py_values.push((k.as_unbound().clone_ref(py), id));
}
}

Ok(Self {
expected_bool: match expected_bool.true_id.is_some() || expected_bool.false_id.is_some() {
true => Some(expected_bool),
false => None,
},
expected_int: match expected_int.is_empty() {
true => None,
false => Some(expected_int),
},
expected_str: match expected_str.is_empty() {
true => None,
false => Some(expected_str),
},
expected_py_dict: match expected_py_dict.is_empty() {
true => None,
false => Some(expected_py_dict.into()),
},
expected_py_values: match expected_py_values.is_empty() {
true => None,
false => Some(expected_py_values),
},
expected_bool: (expected_bool.true_id.is_some() || expected_bool.false_id.is_some())
.then_some(expected_bool),
expected_int: (!expected_int.is_empty()).then_some(expected_int),
expected_str: (!expected_str.is_empty()).then_some(expected_str),
expected_py_dict: (!expected_py_dict.is_empty()).then_some(expected_py_dict.into()),
expected_py_values: (!expected_py_values.is_empty()).then_some(expected_py_values),
expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives.into()),
values,
})
}
Expand Down Expand Up @@ -162,6 +158,19 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

// this one must be last to avoid conflicts with the other lookups, think of this
// almost as a lax fallback
if let Some(expected_py_primitives) = &self.expected_py_primitives {
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
// inputs will produce a TypeError, which in this case we just want to treat equivalently
// to a failed lookup
if let Ok(Some(v)) = expected_py_primitives.bind(py).get_item(&*py_input) {
let id: usize = v.extract().unwrap();
return Ok(Some((input, &self.values[id])));
}
};
Ok(None)
}

Expand Down
12 changes: 12 additions & 0 deletions tests/validators/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,15 @@ def test_big_int():
m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error'
with pytest.raises(ValidationError, match=m):
v.validate_python(37)


def test_enum_for_str() -> None:
class S(str, Enum):
a = 'a'

val_enum = SchemaValidator(core_schema.literal_schema([S.a]))
val_str = SchemaValidator(core_schema.literal_schema(['a']))

for val in [val_enum, val_str]:
assert val.validate_python('a') == 'a'
assert val.validate_python(S.a) == 'a'

0 comments on commit f389728

Please sign in to comment.