Skip to content

Commit

Permalink
WIP con-type support
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Dec 30, 2020
1 parent ca5bc7f commit eb7aaca
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 6 deletions.
228 changes: 227 additions & 1 deletion pydantic/_hypothesis_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
import ipaddress
import json
import math
from fractions import Fraction
from typing import cast

import hypothesis.strategies as st

import pydantic
import pydantic.color
from pydantic.networks import import_email_validator
import pydantic.types
from pydantic.networks import ascii_domain_regex, import_email_validator, int_domain_regex

# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
# them on-disk, and that's unsafe in general without being told *where* to do so.
Expand Down Expand Up @@ -141,6 +143,80 @@ def add_luhn_digit(card_number: str) -> str:
st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit), # type: ignore[arg-type]
)


RESOLVERS = {}


def resolves(typ):
def inner(f):
assert f not in RESOLVERS
RESOLVERS[typ] = f
return f

return inner


@resolves(pydantic.JsonWrapper)
def resolve_jsonwrapper(cls): # type: ignore[no-untyped-def]
return st.builds(
json.dumps,
st.from_type(cls.inner_type),
ensure_ascii=st.booleans(),
indent=st.none() | st.integers(0, 16),
sort_keys=st.booleans(),
)


# URLs


@resolves(pydantic.AnyUrl)
def resolve_anyurl(cls): # type: ignore[no-untyped-def]
domains = st.one_of(
st.from_regex(ascii_domain_regex(), fullmatch=True),
st.from_regex(int_domain_regex(), fullmatch=True),
)
if cls.tld_required:

def has_tld(s: str) -> bool:
assert isinstance(s, str)
match = ascii_domain_regex().fullmatch(s) or int_domain_regex().fullmatch(s)
return bool(match and match.group('tld'))

hosts = domains.filter(has_tld)
else:
hosts = domains | st.from_regex(
r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})' r'|(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])',
fullmatch=True,
)

return st.builds(
cls.build,
scheme=(
st.sampled_from(sorted(cls.allowed_schemes))
if cls.allowed_schemes
else st.from_regex(r'(?P<scheme>[a-z][a-z0-9+\-.]+)', fullmatch=True)
),
user=st.one_of(
st.nothing() if cls.user_required else st.none(),
st.from_regex(r'(?P<user>[^\s:/]+)', fullmatch=True),
),
password=st.none() | st.from_regex(r'(?P<password>[^\s/]*)', fullmatch=True),
host=hosts,
port=st.none() | st.integers(0, 2 ** 16 - 1).map(str),
path=st.none() | st.from_regex(r'(?P<path>/[^\s?]*)', fullmatch=True),
query=st.none() | st.from_regex(r'(?P<query>[^\s#]+)', fullmatch=True),
fragment=st.none() | st.from_regex(r'(?P<fragment>\S+)', fullmatch=True),
).filter(lambda url: cls.min_length <= len(url) <= cls.max_length)


st.register_type_strategy(pydantic.AnyUrl, resolve_anyurl)
st.register_type_strategy(pydantic.AnyHttpUrl, resolve_anyurl)
st.register_type_strategy(pydantic.HttpUrl, resolve_anyurl)
st.register_type_strategy(pydantic.PostgresDsn, resolve_anyurl)
st.register_type_strategy(pydantic.RedisDsn, resolve_anyurl)


# UUIDs
st.register_type_strategy(pydantic.UUID1, st.uuids(version=1)) # type: ignore[arg-type]
st.register_type_strategy(pydantic.UUID3, st.uuids(version=3)) # type: ignore[arg-type]
Expand Down Expand Up @@ -173,3 +249,153 @@ def add_luhn_digit(card_number: str) -> str:
st.register_type_strategy(pydantic.NegativeInt, st.integers(max_value=-1)) # type: ignore[arg-type]
st.register_type_strategy(pydantic.PositiveFloat, st.floats(min_value=0, exclude_min=True)) # type: ignore[arg-type]
st.register_type_strategy(pydantic.NegativeFloat, st.floats(max_value=-0.0, exclude_max=True)) # type: ignore[arg-type]


@resolves(pydantic.ConstrainedBytes)
def resolve_conbytes(cls): # type: ignore[no-untyped-def] # pragma: no cover
min_size = cls.min_length or 0
max_size = cls.max_length
if not cls.strip_whitespace:
return st.binary(min_size=min_size, max_size=max_size)
# Fun with regex to ensure we neither start nor end with whitespace
repeats = '{{{},{}}}'.format(
min_size - 2 if min_size > 2 else 0,
max_size - 2 if (max_size or 0) > 2 else '',
)
if min_size >= 2:
pattern = rf'\W.{repeats}\W'
elif min_size == 1:
pattern = rf'\W(.{repeats}\W)?'
else:
assert min_size == 0
pattern = rf'(\W(.{repeats}\W)?)?'
return st.from_regex(pattern.encode(), fullmatch=True)


@resolves(pydantic.ConstrainedDecimal)
def resolve_condecimal(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt
# max_digits, decimal_places, and multiple_of are handled via the filter
return st.decimals(min_value, max_value, allow_nan=False).filter(cls.validate)


@resolves(pydantic.ConstrainedFloat)
def resolve_confloat(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
exclude_min = False
exclude_max = False
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt
exclude_min = True
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt
exclude_max = True
# multiple_of is handled via the filter
return st.floats(
min_value,
max_value,
exclude_min=exclude_min,
exclude_max=exclude_max,
allow_nan=False,
).filter(cls.validate)


@resolves(pydantic.ConstrainedInt)
def resolve_conint(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt + 1
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt - 1

if cls.multiple_of is None or cls.multiple_of == 1:
return st.integers(min_value, max_value)

# These adjustments and the .map handle integer-valued multiples, while the
# .filter handles trickier cases as for confloat.
if min_value is not None:
min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of))
if max_value is not None:
max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of))
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of).filter(cls.validate)


@resolves(pydantic.ConstrainedList)
def resolve_conlist(cls): # type: ignore[no-untyped-def]
return st.lists(
st.from_type(cls.item_type),
min_size=cls.min_items,
max_size=cls.max_items,
)


@resolves(pydantic.ConstrainedSet)
def resolve_conset(cls): # type: ignore[no-untyped-def]
return st.sets(
st.from_type(cls.item_type),
min_size=cls.min_items,
max_size=cls.max_items,
)


@resolves(pydantic.ConstrainedStr)
def resolve_constr(cls): # type: ignore[no-untyped-def] # pragma: no cover
min_size = cls.min_length or 0
max_size = cls.max_length

if cls.regex is None and not cls.strip_whitespace:
return st.text(min_size=min_size, max_size=max_size)

if cls.regex is not None:
strategy = st.from_regex(cls.regex)
if cls.strip_whitespace:
strategy = strategy.filter(lambda s: s == s.strip())
elif cls.strip_whitespace:
repeats = '{{{},{}}}'.format(
min_size - 2 if min_size > 2 else 0,
max_size - 2 if (max_size or 0) > 2 else '',
)
if min_size >= 2:
strategy = st.from_regex(rf'\W.{repeats}\W')
elif min_size == 1:
strategy = st.from_regex(rf'\W(.{repeats}\W)?')
else:
assert min_size == 0
strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?')

if min_size == 0 and max_size is None:
return strategy
elif max_size is None:
return strategy.filter(lambda s: min_size <= len(s))
return strategy.filter(lambda s: min_size <= len(s) <= max_size)


def _registered(typ):
# This function replaces the version in `pydantic.types`, in order to
# effect the registration of new constrained types so that Hypothesis
# can generate valid examples.
pydantic.types._DEFINED_TYPES.add(typ)
for supertype, resolver in RESOLVERS.items():
if issubclass(typ, supertype):
st.register_type_strategy(typ, resolver(typ))
break
return typ


# Register all previously-defined types, then patch in our new function
for typ in pydantic.types._DEFINED_TYPES:
_registered(typ)
pydantic.types._registered = _registered
24 changes: 19 additions & 5 deletions pydantic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
cast,
)
from uuid import UUID
from weakref import WeakSet

from . import errors
from .utils import import_string, update_not_none
Expand Down Expand Up @@ -109,6 +110,19 @@

ModelOrDc = Type[Union['BaseModel', 'Dataclass']]

_DEFINED_TYPES = WeakSet()


def _registered(typ):
# In order to generate valid examples of constrained types, Hypothesis needs
# to inspect the type object - so we keep a weakref to each contype object
# until it can be registered. When (or if) our Hypothesis plugin is loaded,
# it monkeypatches this function.
# If Hypothesis is never used, the total effect is to keep a weak reference
# which has minimal memory usage and doesn't even affect garbage collection.
_DEFINED_TYPES.add(typ)
return typ


class ConstrainedBytes(bytes):
strip_whitespace = False
Expand All @@ -134,7 +148,7 @@ class StrictBytes(ConstrainedBytes):
def conbytes(*, strip_whitespace: bool = False, min_length: int = None, max_length: int = None) -> Type[bytes]:
# use kwargs then define conf in a dict to aid with IDE type hinting
namespace = dict(strip_whitespace=strip_whitespace, min_length=min_length, max_length=max_length)
return type('ConstrainedBytesValue', (ConstrainedBytes,), namespace)
return _registered(type('ConstrainedBytesValue', (ConstrainedBytes,), namespace))


T = TypeVar('T')
Expand Down Expand Up @@ -179,7 +193,7 @@ def conlist(item_type: Type[T], *, min_items: int = None, max_items: int = None)
# __args__ is needed to conform to typing generics api
namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': (item_type,)}
# We use new_class to be able to deal with Generic types
return new_class('ConstrainedListValue', (ConstrainedList,), {}, lambda ns: ns.update(namespace))
return _registered(new_class('ConstrainedListValue', (ConstrainedList,), {}, lambda ns: ns.update(namespace)))


# This types superclass should be Set[T], but cython chokes on that...
Expand Down Expand Up @@ -218,7 +232,7 @@ def conset(item_type: Type[T], *, min_items: int = None, max_items: int = None)
# __args__ is needed to conform to typing generics api
namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]}
# We use new_class to be able to deal with Generic types
return new_class('ConstrainedSetValue', (ConstrainedSet,), {}, lambda ns: ns.update(namespace))
return _registered(new_class('ConstrainedSetValue', (ConstrainedSet,), {}, lambda ns: ns.update(namespace)))


class ConstrainedStr(str):
Expand Down Expand Up @@ -272,7 +286,7 @@ def constr(
curtail_length=curtail_length,
regex=regex and re.compile(regex),
)
return type('ConstrainedStrValue', (ConstrainedStr,), namespace)
return _registered(type('ConstrainedStrValue', (ConstrainedStr,), namespace))


class StrictStr(ConstrainedStr):
Expand Down Expand Up @@ -344,7 +358,7 @@ def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]) -> 'ConstrainedInt'
if new_cls.lt is not None and new_cls.le is not None:
raise errors.ConfigError('bounds lt and le cannot be specified at the same time')

return new_cls
return _registered(new_cls)


class ConstrainedInt(int, metaclass=ConstrainedNumberMeta):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_hypothesis_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import pytest
from hypothesis import given, settings, strategies as st

Expand Down Expand Up @@ -46,6 +48,45 @@ class NumbersModel(pydantic.BaseModel):
posfloat: pydantic.PositiveFloat
negfloat: pydantic.NegativeFloat

class JsonModel(pydantic.BaseModel):
json_any: pydantic.Json
json_int: pydantic.Json[int]
json_float: pydantic.Json[float]
json_str: pydantic.Json[str]
json_int_or_str: pydantic.Json[typing.Union[int, str]]
json_list_of_float: pydantic.Json[typing.List[float]]

class URLsModel(pydantic.BaseModel):
anyurl: pydantic.AnyUrl
anyhttp: pydantic.AnyHttpUrl
http: pydantic.HttpUrl
postgres: pydantic.PostgresDsn
redis: pydantic.RedisDsn

class ConstrainedNumbersModel(pydantic.BaseModel):
conintt: pydantic.conint(gt=10, lt=100)
coninte: pydantic.conint(ge=10, le=100)
conintmul: pydantic.conint(ge=10, le=100, multiple_of=7)
confloatt: pydantic.confloat(gt=10, lt=100)
confloate: pydantic.confloat(ge=10, le=100)
condecimalt: pydantic.condecimal(gt=10, lt=100)
condecimale: pydantic.condecimal(ge=10, le=100)

class CollectionsModel(pydantic.BaseModel):
conset: pydantic.conset(int, min_items=2, max_items=4)
conlist: pydantic.conlist(int, min_items=2, max_items=4)

class Foo:
# Trivial class to test constrained collections element type
pass

class CollectionsFooModel(pydantic.BaseModel):
conset: pydantic.conset(Foo, min_items=2, max_items=4)
conlist: pydantic.conlist(Foo, min_items=2, max_items=4)

class Config:
arbitrary_types_allowed = True

yield from (
MiscModel,
StringsModel,
Expand All @@ -55,6 +96,11 @@ class NumbersModel(pydantic.BaseModel):
IPvAnyNetwork,
StrictNumbersModel,
NumbersModel,
JsonModel,
URLsModel,
ConstrainedNumbersModel,
CollectionsModel,
CollectionsFooModel,
)

try:
Expand Down

0 comments on commit eb7aaca

Please sign in to comment.