Skip to content

Commit

Permalink
collapse arguments into ValidationState (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Aug 15, 2023
1 parent 0a072b0 commit 2e59f25
Show file tree
Hide file tree
Showing 43 changed files with 702 additions and 938 deletions.
47 changes: 9 additions & 38 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::recursion_guard::RecursionGuard;
use crate::tools::py_err;
use crate::validators::{CombinedValidator, Extra, Validator};
use crate::validators::{CombinedValidator, ValidationState, Validator};

use super::parse_json::{JsonArray, JsonInput, JsonObject};
use super::{py_error_on_minusone, Input};
Expand Down Expand Up @@ -157,15 +156,13 @@ fn validate_iter_to_vec<'a, 's>(
capacity: usize,
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
validator: &'s CombinedValidator,
extra: &Extra,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'a, Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state) {
Ok(item) => {
max_length_check.incr()?;
output.push(item);
Expand Down Expand Up @@ -226,14 +223,12 @@ fn validate_iter_to_set<'a, 's>(
field_type: &'static str,
max_length: Option<usize>,
validator: &'s CombinedValidator,
extra: &Extra,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'a, ()> {
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state) {
Ok(item) => {
set.build_add(item)?;
if let Some(max_length) = max_length {
Expand Down Expand Up @@ -315,9 +310,7 @@ impl<'a> GenericIterable<'a> {
max_length: Option<usize>,
field_type: &'static str,
validator: &'s CombinedValidator,
extra: &Extra,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'a, Vec<PyObject>> {
let capacity = self
.generic_len()
Expand All @@ -326,16 +319,7 @@ impl<'a> GenericIterable<'a> {

macro_rules! validate {
($iter:expr) => {
validate_iter_to_vec(
py,
$iter,
capacity,
max_length_check,
validator,
extra,
definitions,
recursion_guard,
)
validate_iter_to_vec(py, $iter, capacity, max_length_check, validator, state)
};
}

Expand All @@ -360,24 +344,11 @@ impl<'a> GenericIterable<'a> {
max_length: Option<usize>,
field_type: &'static str,
validator: &'s CombinedValidator,
extra: &Extra,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'a, ()> {
macro_rules! validate_set {
($iter:expr) => {
validate_iter_to_set(
py,
set,
$iter,
input,
field_type,
max_length,
validator,
extra,
definitions,
recursion_guard,
)
validate_iter_to_set(py, set, $iter, input, field_type, max_length, validator, state)
};
}

Expand Down
9 changes: 2 additions & 7 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use pyo3::types::PyDict;
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{validation_state::ValidationState, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

/// This might seem useless, but it's useful in DictValidator to avoid Option<Validator> a lot
#[derive(Debug, Clone)]
Expand All @@ -31,11 +29,8 @@ impl Validator for AnyValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
_extra: &Extra,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
_state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
// Ok(input.clone().into_py(py))
Ok(input.to_object(py))
}

Expand Down
22 changes: 8 additions & 14 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input};
use crate::lookup_key::LookupKey;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
struct Parameter {
Expand Down Expand Up @@ -165,9 +165,7 @@ impl Validator for ArgumentsValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let args = input.validate_args()?;

Expand Down Expand Up @@ -205,9 +203,7 @@ impl Validator for ArgumentsValidator {
));
}
(Some(pos_value), None) => {
match parameter
.validator
.validate(py, pos_value, extra, definitions, recursion_guard)
match parameter.validator.validate(py, pos_value, state)
{
Ok(value) => output_args.push(value),
Err(ValError::LineErrors(line_errors)) => {
Expand All @@ -217,9 +213,7 @@ impl Validator for ArgumentsValidator {
}
}
(None, Some((lookup_path, kw_value))) => {
match parameter
.validator
.validate(py, kw_value, extra, definitions, recursion_guard)
match parameter.validator.validate(py, kw_value, state)
{
Ok(value) => output_kwargs.set_item(parameter.kwarg_key.as_ref().unwrap(), value)?,
Err(ValError::LineErrors(line_errors)) => {
Expand All @@ -231,7 +225,7 @@ impl Validator for ArgumentsValidator {
}
}
(None, None) => {
if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), extra, definitions, recursion_guard)? {
if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), state)? {
if let Some(ref kwarg_key) = parameter.kwarg_key {
output_kwargs.set_item(kwarg_key, value)?;
} else {
Expand Down Expand Up @@ -261,7 +255,7 @@ impl Validator for ArgumentsValidator {
if len > self.positional_params_count {
if let Some(ref validator) = self.var_args_validator {
for (index, item) in $slice_macro!(args, self.positional_params_count, len).iter().enumerate() {
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state) {
Ok(value) => output_args.push(value),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
Expand Down Expand Up @@ -303,7 +297,7 @@ impl Validator for ArgumentsValidator {
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value, extra, definitions, recursion_guard) {
Some(ref validator) => match validator.validate(py, value, state) {
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
Expand Down
11 changes: 4 additions & 7 deletions src/validators/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ use crate::build_tools::is_strict;
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BoolValidator {
Expand Down Expand Up @@ -36,13 +34,12 @@ impl Validator for BoolValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
// and back again, might be worth profiling?
Ok(input.validate_bool(extra.strict.unwrap_or(self.strict))?.into_py(py))
let strict = state.strict_or(self.strict);
Ok(input.validate_bool(strict)?.into_py(py))
}

fn different_strict_behavior(
Expand Down
15 changes: 5 additions & 10 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ use crate::build_tools::is_strict;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BytesValidator {
Expand Down Expand Up @@ -45,11 +44,9 @@ impl Validator for BytesValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?;
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?;
Ok(either_bytes.into_py(py))
}

Expand Down Expand Up @@ -84,11 +81,9 @@ impl Validator for BytesConstrainedValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?;
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?;
let len = either_bytes.len()?;

if let Some(min_length) = self.min_length {
Expand Down
14 changes: 5 additions & 9 deletions src/validators/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use pyo3::types::{PyDict, PyTuple};
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub struct CallValidator {
Expand Down Expand Up @@ -76,13 +76,9 @@ impl Validator for CallValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let args = self
.arguments_validator
.validate(py, input, extra, definitions, recursion_guard)?;
let args = self.arguments_validator.validate(py, input, state)?;

let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) {
self.function.call(py, args, Some(kwargs))?
Expand All @@ -95,7 +91,7 @@ impl Validator for CallValidator {

if let Some(return_validator) = &self.return_validator {
return_validator
.validate(py, return_value.into_ref(py), extra, definitions, recursion_guard)
.validate(py, return_value.into_ref(py), state)
.map_err(|e| e.with_outer_location("return".into()))
} else {
Ok(return_value.to_object(py))
Expand Down
8 changes: 2 additions & 6 deletions src/validators/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use pyo3::types::PyDict;
use crate::errors::{ErrorTypeDefaults, ValError, ValResult};
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct CallableValidator;
Expand All @@ -30,9 +28,7 @@ impl Validator for CallableValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
_extra: &Extra,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
_state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
match input.callable() {
true => Ok(input.to_object(py)),
Expand Down
14 changes: 5 additions & 9 deletions src/validators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use pyo3::types::{PyDict, PyList};
use crate::build_tools::py_schema_err;
use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub struct ChainValidator {
Expand Down Expand Up @@ -74,17 +74,13 @@ impl Validator for ChainValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let mut steps_iter = self.steps.iter();
let first_step = steps_iter.next().unwrap();
let value = first_step.validate(py, input, extra, definitions, recursion_guard)?;
let value = first_step.validate(py, input, state)?;

steps_iter.try_fold(value, |v, step| {
step.validate(py, v.into_ref(py), extra, definitions, recursion_guard)
})
steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state))
}

fn different_strict_behavior(
Expand Down
10 changes: 4 additions & 6 deletions src/validators/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use pyo3::types::PyDict;
use crate::build_tools::py_schema_err;
use crate::errors::{ErrorType, PydanticCustomError, PydanticKnownError, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub enum CustomError {
Expand Down Expand Up @@ -92,12 +92,10 @@ impl Validator for CustomErrorValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
self.validator
.validate(py, input, extra, definitions, recursion_guard)
.validate(py, input, state)
.map_err(|_| self.custom_error.as_val_error(input))
}

Expand Down
Loading

0 comments on commit 2e59f25

Please sign in to comment.