Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix equality checks for primitives in literals #1459

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub struct LiteralLookup<T: Debug> {
expected_py_dict: Option<Py<PyDict>>,
// Catch all for unhashable types like list
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
// Fallback for ints, bools, and strings to use Python equality checks
expected_py_primitives: Option<Vec<(Py<PyAny>, usize)>>,

pub values: Vec<T>,
}
Expand All @@ -46,20 +48,24 @@ impl<T: Debug> LiteralLookup<T> {
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
let expected_py_dict = PyDict::new_bound(py);
let mut expected_py_values = Vec::new();
let mut expected_py_primitives = Vec::new();
let mut values = Vec::new();
for (k, v) in expected {
let id = values.len();
values.push(v);
if let Ok(bool) = k.validate_bool(true) {
if bool.into_inner() {

if let Ok(bool_value) = k.validate_bool(true) {
if bool_value.into_inner() {
expected_bool.true_id = Some(id);
} else {
expected_bool.false_id = Some(id);
}
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
}
if k.is_exact_instance_of::<PyInt>() {
if let Ok(int_64) = k.extract::<i64>() {
expected_int.insert(int_64, id);
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
} else {
// cover the case of an int that's > i64::MAX etc.
expected_py_dict.set_item(k, id)?;
Expand All @@ -69,32 +75,20 @@ impl<T: Debug> LiteralLookup<T> {
.as_cow()
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
expected_str.insert(str.to_string(), id);
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
} else if expected_py_dict.set_item(&k, id).is_err() {
expected_py_values.push((k.as_unbound().clone_ref(py), id));
}
}

Ok(Self {
expected_bool: match expected_bool.true_id.is_some() || expected_bool.false_id.is_some() {
true => Some(expected_bool),
false => None,
},
expected_int: match expected_int.is_empty() {
true => None,
false => Some(expected_int),
},
expected_str: match expected_str.is_empty() {
true => None,
false => Some(expected_str),
},
expected_py_dict: match expected_py_dict.is_empty() {
true => None,
false => Some(expected_py_dict.into()),
},
expected_py_values: match expected_py_values.is_empty() {
true => None,
false => Some(expected_py_values),
},
expected_bool: (expected_bool.true_id.is_some() || expected_bool.false_id.is_some())
.then_some(expected_bool),
expected_int: (!expected_int.is_empty()).then_some(expected_int),
expected_str: (!expected_str.is_empty()).then_some(expected_str),
expected_py_dict: (!expected_py_dict.is_empty()).then_some(expected_py_dict.into()),
expected_py_values: (!expected_py_values.is_empty()).then_some(expected_py_values),
expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder, is there an advantage of expected_py_primitives over expected_py_dict?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conclusion was that expected_py_primitives will be used only in lax mode. I think it probably makes sense to store as a dict and use hash lookup (as that checks __hash__ and __eq__) - we can always make more lax later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also conclusion - we're not going to fall back to union validation, that's too lax + time consuming.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright - can't enforce strict / lax here as we're in the literal lookup, but I've added the hash check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If enforcing the hash check, surely a dict is just straight up more performant?

values,
})
}
Expand Down Expand Up @@ -162,6 +156,15 @@ impl<T: Debug> LiteralLookup<T> {
}
}
};

if let Some(expected_py_primitives) = &self.expected_py_primitives {
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
for (k, id) in expected_py_primitives {
if k.bind(py).eq(&*py_input).unwrap_or(false) {
return Ok(Some((input, &self.values[*id])));
}
}
};
Ok(None)
}

Expand Down
12 changes: 12 additions & 0 deletions tests/validators/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,15 @@ def test_big_int():
m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error'
with pytest.raises(ValidationError, match=m):
v.validate_python(37)


def test_enum_for_str() -> None:
class S(str, Enum):
a = 'a'

val_enum = SchemaValidator(core_schema.literal_schema([S.a]))
val_str = SchemaValidator(core_schema.literal_schema(['a']))

for val in [val_enum, val_str]:
assert val.validate_python('a') == 'a'
assert val.validate_python(S.a) == 'a'
Loading