diff --git a/docs/pages/development.rst b/docs/pages/development.rst index 82e7d1d27..0733c9adc 100644 --- a/docs/pages/development.rst +++ b/docs/pages/development.rst @@ -54,7 +54,7 @@ Or as a decorator: >>> from returns.result import Failure, Result >>> from returns.primitives.tracing import collect_traces - >>> @collect_traces() + >>> @collect_traces ... def traced_function(value: str) -> IOResult[str, str]: ... return IOFailure(value) diff --git a/returns/primitives/tracing.py b/returns/primitives/tracing.py index 4ae98c6ed..42e7f2838 100644 --- a/returns/primitives/tracing.py +++ b/returns/primitives/tracing.py @@ -1,13 +1,35 @@ import types from contextlib import contextmanager from inspect import FrameInfo, stack -from typing import List, Optional +from typing import ( + Callable, + ContextManager, + Iterator, + List, + Optional, + TypeVar, + Union, + overload, +) from returns.result import _Failure +_FunctionType = TypeVar('_FunctionType', bound=Callable) -@contextmanager -def collect_traces(): + +@overload +def collect_traces() -> ContextManager[None]: + """Context Manager to active traces collect to the Failures.""" + + +@overload +def collect_traces(function: _FunctionType) -> _FunctionType: + """Decorator to active traces collect to the Failures.""" + + +def collect_traces( + function: Optional[_FunctionType] = None, +) -> Union[_FunctionType, ContextManager[None]]: # noqa: DAR101, DAR201, DAR301 """ Context Manager/Decorator to active traces collect to the Failures. @@ -36,13 +58,16 @@ def collect_traces(): # doctest: # noqa: DAR301, E501 """ - unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009 - substitute_get_trace = types.MethodType(_get_trace, _Failure) - setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010 - try: - yield - finally: - setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010 + @contextmanager + def factory() -> Iterator[None]: + unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009 + substitute_get_trace = types.MethodType(_get_trace, _Failure) + setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010 + try: # noqa: WPS501 + yield + finally: + setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010 + return factory()(function) if function else factory() def _get_trace(_self: _Failure) -> Optional[List[FrameInfo]]: diff --git a/typesafety/test_primitives/test_tracing/test_collect_traces.yml b/typesafety/test_primitives/test_tracing/test_collect_traces.yml new file mode 100644 index 000000000..e6dee6a5a --- /dev/null +++ b/typesafety/test_primitives/test_tracing/test_collect_traces.yml @@ -0,0 +1,36 @@ +- case: collect_traces_context_manager_return_type_one + disable_cache: true + main: | + from returns.primitives.tracing import collect_traces + + reveal_type(collect_traces) # N: Revealed type is 'Overload(def () -> typing.ContextManager[None], def [_FunctionType <: def (*Any, **Any) -> Any] (function: _FunctionType`-1) -> _FunctionType`-1)' + +- case: collect_traces_context_manager_return_type_two + disable_cache: true + main: | + from returns.primitives.tracing import collect_traces + + with reveal_type(collect_traces()): # N: Revealed type is 'typing.ContextManager[None]' + pass + +- case: collect_traces_decorated_function_return_type + disable_cache: true + main: | + from returns.primitives.tracing import collect_traces + + @collect_traces + def function() -> int: + return 0 + + reveal_type(function) # N: Revealed type is 'def () -> builtins.int' + +- case: collect_traces_decorated_function_with_argument_return_type + disable_cache: true + main: | + from returns.primitives.tracing import collect_traces + + @collect_traces + def function(number: int) -> str: + return str(number) + + reveal_type(function) # N: Revealed type is 'def (number: builtins.int) -> builtins.str'