From 26cc6eed38cd82da58c8e29d6d2b98d7bd0ca77d Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 19 Jun 2024 19:38:06 +1000 Subject: [PATCH] Use Typevar defaults for `TaskStatus` and `Matcher` (#3019) * Default TaskStatus to use None if unspecified * Default Matcher to BaseException if unspecified * Update Sphinx logic for new typevar name * Add some type tests for defaulted typevar classes --- docs/source/conf.py | 14 +++--- src/trio/_core/_run.py | 19 ++++---- src/trio/_tests/type_tests/raisesgroup.py | 8 ++++ src/trio/_tests/type_tests/task_status.py | 29 ++++++++++++ src/trio/testing/_raises_group.py | 55 ++++++++++++++--------- 5 files changed, 91 insertions(+), 34 deletions(-) create mode 100644 src/trio/_tests/type_tests/task_status.py diff --git a/docs/source/conf.py b/docs/source/conf.py index ff08adab48..7ea27de24b 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -113,8 +113,10 @@ def autodoc_process_signature( # name. assert isinstance(obj, property), obj assert isinstance(obj.fget, types.FunctionType), obj.fget - assert obj.fget.__annotations__["return"] == "type[E]", obj.fget.__annotations__ - obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.E]" + assert ( + obj.fget.__annotations__["return"] == "type[MatchE]" + ), obj.fget.__annotations__ + obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]" if signature is not None: signature = signature.replace("~_contextvars.Context", "~contextvars.Context") if name == "trio.lowlevel.RunVar": # Typevar is not useful here. @@ -123,13 +125,15 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") - if ( - name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") - and "+E" in signature + if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and ( + "+E" in signature or "+MatchE" in signature ): # This typevar being covariant isn't handled correctly in some cases, strip the + # and insert the fully-qualified name. signature = signature.replace("+E", "~trio.testing._raises_group.E") + signature = signature.replace( + "+MatchE", "~trio.testing._raises_group.MatchE" + ) if "DTLS" in name: signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 1eddaad431..1512fdf954 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -21,7 +21,6 @@ Final, NoReturn, Protocol, - TypeVar, cast, overload, ) @@ -54,12 +53,6 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -FnT = TypeVar("FnT", bound="Callable[..., Any]") -StatusT = TypeVar("StatusT") -StatusT_co = TypeVar("StatusT_co", covariant=True) -StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -RetT = TypeVar("RetT") - if TYPE_CHECKING: import contextvars @@ -77,9 +70,19 @@ # for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in # start_guest_run. Same with types.FrameType in iter_await_frames import outcome - from typing_extensions import Self, TypeVarTuple, Unpack + from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack PosArgT = TypeVarTuple("PosArgT") + StatusT = TypeVar("StatusT", default=None) + StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None) +else: + from typing import TypeVar + + StatusT = TypeVar("StatusT") + StatusT_contra = TypeVar("StatusT_contra", contravariant=True) + +FnT = TypeVar("FnT", bound="Callable[..., Any]") +RetT = TypeVar("RetT") DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py index ba88eb09cc..fe4053ebc5 100644 --- a/src/trio/_tests/type_tests/raisesgroup.py +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -37,6 +37,14 @@ def check_inheritance_and_assignments() -> None: assert a +def check_matcher_typevar_default(e: Matcher) -> object: + assert e.exception_type is not None + exc: type[BaseException] = e.exception_type + # this would previously pass, as the type would be `Any` + e.exception_type().blah() # type: ignore + return exc # Silence Pyright unused var warning + + def check_basic_contextmanager() -> None: # One level of Group is correctly translated - except it's a BaseExceptionGroup # instead of an ExceptionGroup. diff --git a/src/trio/_tests/type_tests/task_status.py b/src/trio/_tests/type_tests/task_status.py new file mode 100644 index 0000000000..90cfc6957f --- /dev/null +++ b/src/trio/_tests/type_tests/task_status.py @@ -0,0 +1,29 @@ +"""Check that started() can only be called for TaskStatus[None].""" + +from trio import TaskStatus +from typing_extensions import assert_type + + +async def check_status( + none_status_explicit: TaskStatus[None], + none_status_implicit: TaskStatus, + int_status: TaskStatus[int], +) -> None: + assert_type(none_status_explicit, TaskStatus[None]) + assert_type(none_status_implicit, TaskStatus[None]) # Default typevar + assert_type(int_status, TaskStatus[int]) + + # Omitting the parameter is only allowed for None. + none_status_explicit.started() + none_status_implicit.started() + int_status.started() # type: ignore + + # Explicit None is allowed. + none_status_explicit.started(None) + none_status_implicit.started(None) + int_status.started(None) # type: ignore + + none_status_explicit.started(42) # type: ignore + none_status_implicit.started(42) # type: ignore + int_status.started(42) + int_status.started(True) diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py index 16bde651f4..7c2e6c5a83 100644 --- a/src/trio/testing/_raises_group.py +++ b/src/trio/testing/_raises_group.py @@ -10,7 +10,6 @@ Literal, Pattern, Sequence, - TypeVar, cast, overload, ) @@ -26,35 +25,49 @@ import types from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback - from typing_extensions import TypeGuard + from typing_extensions import TypeGuard, TypeVar -if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + MatchE = TypeVar( + "MatchE", bound=BaseException, default=BaseException, covariant=True + ) +else: + from typing import TypeVar + MatchE = TypeVar("MatchE", bound=BaseException, covariant=True) +# RaisesGroup doesn't work with a default. E = TypeVar("E", bound=BaseException, covariant=True) +# These two typevars are special cased in sphinx config to workaround lookup bugs. + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup @final -class _ExceptionInfo(Generic[E]): +class _ExceptionInfo(Generic[MatchE]): """Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`.""" - _excinfo: tuple[type[E], E, types.TracebackType] | None + _excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None - def __init__(self, excinfo: tuple[type[E], E, types.TracebackType] | None): + def __init__( + self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None + ): self._excinfo = excinfo - def fill_unfilled(self, exc_info: tuple[type[E], E, types.TracebackType]) -> None: + def fill_unfilled( + self, exc_info: tuple[type[MatchE], MatchE, types.TracebackType] + ) -> None: """Fill an unfilled ExceptionInfo created with ``for_later()``.""" assert self._excinfo is None, "ExceptionInfo was already filled" self._excinfo = exc_info @classmethod - def for_later(cls) -> _ExceptionInfo[E]: + def for_later(cls) -> _ExceptionInfo[MatchE]: """Return an unfilled ExceptionInfo.""" return cls(None) + # Note, special cased in sphinx config, since "type" conflicts. @property - def type(self) -> type[E]: + def type(self) -> type[MatchE]: """The exception class.""" assert ( self._excinfo is not None @@ -62,7 +75,7 @@ def type(self) -> type[E]: return self._excinfo[0] @property - def value(self) -> E: + def value(self) -> MatchE: """The exception value.""" assert ( self._excinfo is not None @@ -95,7 +108,7 @@ def getrepr( showlocals: bool = False, style: str = "long", abspath: bool = False, - tbfilter: bool | Callable[[_ExceptionInfo[BaseException]], Traceback] = True, + tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True, funcargs: bool = False, truncate_locals: bool = True, chain: bool = True, @@ -135,7 +148,7 @@ def _stringify_exception(exc: BaseException) -> str: @final -class Matcher(Generic[E]): +class Matcher(Generic[MatchE]): """Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments. The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter. :meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions. @@ -154,10 +167,10 @@ class Matcher(Generic[E]): # At least one of the three parameters must be passed. @overload def __init__( - self: Matcher[E], - exception_type: type[E], + self: Matcher[MatchE], + exception_type: type[MatchE], match: str | Pattern[str] = ..., - check: Callable[[E], bool] = ..., + check: Callable[[MatchE], bool] = ..., ): ... @overload @@ -174,9 +187,9 @@ def __init__(self, *, check: Callable[[BaseException], bool]): ... def __init__( self, - exception_type: type[E] | None = None, + exception_type: type[MatchE] | None = None, match: str | Pattern[str] | None = None, - check: Callable[[E], bool] | None = None, + check: Callable[[MatchE], bool] | None = None, ): if exception_type is None and match is None and check is None: raise ValueError("You must specify at least one parameter to match on.") @@ -192,7 +205,7 @@ def __init__( self.match = match self.check = check - def matches(self, exception: BaseException) -> TypeGuard[E]: + def matches(self, exception: BaseException) -> TypeGuard[MatchE]: """Check if an exception matches the requirements of this Matcher. Examples:: @@ -220,7 +233,7 @@ def matches(self, exception: BaseException) -> TypeGuard[E]: return False # If exception_type is None check() accepts BaseException. # If non-none, we have done an isinstance check above. - if self.check is not None and not self.check(cast(E, exception)): + if self.check is not None and not self.check(cast(MatchE, exception)): return False return True @@ -254,8 +267,8 @@ def __str__(self) -> str: # We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups if TYPE_CHECKING: SuperClass = BaseExceptionGroup -# Inheriting at runtime leads to a series of TypeErrors, so we do not want to do that. else: + # At runtime, use a redundant Generic base class which effectively gets ignored. SuperClass = Generic