From 7948574540df6685e4fd94261db8b55d4baa7d20 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 30 Oct 2023 08:30:51 -0500 Subject: [PATCH 1/2] supporting digits validation for normalized and non-normalized Decimal values --- src/validators/decimal.rs | 129 ++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 53 deletions(-) diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 730eeac69..eb3141c31 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator { gt }); +fn extract_decimal_digits_info<'data>( + decimal: &PyAny, + normalized: bool, + py: Python<'data>, +) -> ValResult<'data, (u64, u64)> { + let mut normalized_decimal: Option<&PyAny> = None; + if normalized { + normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal)); + } + let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal + .unwrap_or(decimal) + .call_method0(intern!(py, "as_tuple"))? + .extract()?; + + // finite values have numeric exponent, we checked is_finite above + let exponent: i64 = exponent.extract()?; + let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; + let decimals; + if exponent >= 0 { + // A positive exponent adds that many trailing zeros. + digits += exponent as u64; + decimals = 0; + } else { + // If the absolute value of the negative exponent is larger than the + // number of digits, then it's the same as the number of digits, + // because it'll consume all the digits in digit_tuple and then + // add abs(exponent) - len(digit_tuple) leading zeros after the + // decimal point. + decimals = exponent.unsigned_abs(); + digits = digits.max(decimals); + } + + Ok((decimals, digits)) +} + impl Validator for DecimalValidator { fn validate<'data>( &self, @@ -98,65 +133,53 @@ impl Validator for DecimalValidator { } if self.check_digits { - let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal); - let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = - normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?; + if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) { + if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) { + if let Some(max_digits) = self.max_digits { + if (digits > max_digits) & (normalized_digits > max_digits) { + return Err(ValError::new( + ErrorType::DecimalMaxDigits { + max_digits, + context: None, + }, + input, + )); + } + } - // finite values have numeric exponent, we checked is_finite above - let exponent: i64 = exponent.extract()?; - let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; - let decimals; - if exponent >= 0 { - // A positive exponent adds that many trailing zeros. - digits += exponent as u64; - decimals = 0; - } else { - // If the absolute value of the negative exponent is larger than the - // number of digits, then it's the same as the number of digits, - // because it'll consume all the digits in digit_tuple and then - // add abs(exponent) - len(digit_tuple) leading zeros after the - // decimal point. - decimals = exponent.unsigned_abs(); - digits = digits.max(decimals); - } + if let Some(decimal_places) = self.decimal_places { + if (decimals > decimal_places) & (normalized_decimals > decimal_places) { + return Err(ValError::new( + ErrorType::DecimalMaxPlaces { + decimal_places, + context: None, + }, + input, + )); + } - if let Some(max_digits) = self.max_digits { - if digits > max_digits { - return Err(ValError::new( - ErrorType::DecimalMaxDigits { - max_digits, - context: None, - }, - input, - )); - } - } + if let Some(max_digits) = self.max_digits { + let whole_digits = digits.saturating_sub(decimals); + let max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(decimal_places) = self.decimal_places { - if decimals > decimal_places { - return Err(ValError::new( - ErrorType::DecimalMaxPlaces { - decimal_places, - context: None, - }, - input, - )); - } + let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals); + let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(max_digits) = self.max_digits { - let whole_digits = digits.saturating_sub(decimals); - let max_whole_digits = max_digits.saturating_sub(decimal_places); - if whole_digits > max_whole_digits { - return Err(ValError::new( - ErrorType::DecimalWholeDigits { - whole_digits: max_whole_digits, - context: None, - }, - input, - )); + if (whole_digits > max_whole_digits) + & (normalized_whole_digits > normalized_max_whole_digits) + { + return Err(ValError::new( + ErrorType::DecimalWholeDigits { + whole_digits: max_whole_digits, + context: None, + }, + input, + )); + } + } } } - } + }; } } From 20c5e17b5e5b06b25ab4026b5102f35061f01496 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 30 Oct 2023 08:43:53 -0500 Subject: [PATCH 2/2] adding tests --- tests/validators/test_decimal.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index cd54c89ae..43b3d19b9 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -437,3 +437,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec def test_validate_scientific_notation_from_json(input_value, expected): v = SchemaValidator({'type': 'decimal'}) assert v.validate_json(input_value) == expected + + +def test_validate_max_digits_and_decimal_places() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2}) + + # valid inputs + assert v.validate_json('1.23') == Decimal('1.23') + assert v.validate_json('123.45') == Decimal('123.45') + assert v.validate_json('-123.45') == Decimal('-123.45') + + # invalid inputs + with pytest.raises(ValidationError): + v.validate_json('1234.56') # too many digits + with pytest.raises(ValidationError): + v.validate_json('123.456') # too many decimal places + with pytest.raises(ValidationError): + v.validate_json('123456') # too many digits + with pytest.raises(ValidationError): + v.validate_json('abc') # not a valid decimal + + +def test_validate_max_digits_and_decimal_places_edge_case() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18}) + + # valid inputs + assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal( + '9999999999999999.999999999999999999' + )