Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds missing return type to collect_traces function #442

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pages/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Or as a decorator:
>>> from returns.result import Failure, Result
>>> from returns.primitives.tracing import collect_traces

>>> @collect_traces()
>>> @collect_traces
... def traced_function(value: str) -> IOResult[str, str]:
... return IOFailure(value)

Expand Down
45 changes: 35 additions & 10 deletions returns/primitives/tracing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
import types
from contextlib import contextmanager
from inspect import FrameInfo, stack
from typing import List, Optional
from typing import (
Callable,
ContextManager,
Iterator,
List,
Optional,
TypeVar,
Union,
overload,
)

from returns.result import _Failure

_FunctionType = TypeVar('_FunctionType', bound=Callable)

@contextmanager
def collect_traces():

@overload
def collect_traces() -> ContextManager[None]:
"""Context Manager to active traces collect to the Failures."""


@overload
def collect_traces(function: _FunctionType) -> _FunctionType:
"""Decorator to active traces collect to the Failures."""


def collect_traces(
function: Optional[_FunctionType] = None,
) -> Union[_FunctionType, ContextManager[None]]: # noqa: DAR101, DAR201, DAR301
"""
Context Manager/Decorator to active traces collect to the Failures.

Expand Down Expand Up @@ -36,13 +58,16 @@ def collect_traces():
# doctest: # noqa: DAR301, E501

"""
unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009
substitute_get_trace = types.MethodType(_get_trace, _Failure)
setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010
try:
yield
finally:
setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010
@contextmanager
def factory() -> Iterator[None]:
unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009
substitute_get_trace = types.MethodType(_get_trace, _Failure)
setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010
try: # noqa: WPS501
yield
finally:
setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010
return factory()(function) if function else factory()


def _get_trace(_self: _Failure) -> Optional[List[FrameInfo]]:
Expand Down
36 changes: 36 additions & 0 deletions typesafety/test_primitives/test_tracing/test_collect_traces.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
- case: collect_traces_context_manager_return_type_one
disable_cache: true
main: |
from returns.primitives.tracing import collect_traces

reveal_type(collect_traces) # N: Revealed type is 'Overload(def () -> typing.ContextManager[None], def [_FunctionType <: def (*Any, **Any) -> Any] (function: _FunctionType`-1) -> _FunctionType`-1)'

- case: collect_traces_context_manager_return_type_two
disable_cache: true
main: |
from returns.primitives.tracing import collect_traces

with reveal_type(collect_traces()): # N: Revealed type is 'typing.ContextManager[None]'
pass

- case: collect_traces_decorated_function_return_type
thepabloaguilar marked this conversation as resolved.
Show resolved Hide resolved
disable_cache: true
main: |
from returns.primitives.tracing import collect_traces

@collect_traces
def function() -> int:
return 0

reveal_type(function) # N: Revealed type is 'def () -> builtins.int'

- case: collect_traces_decorated_function_with_argument_return_type
disable_cache: true
main: |
from returns.primitives.tracing import collect_traces

@collect_traces
def function(number: int) -> str:
return str(number)

reveal_type(function) # N: Revealed type is 'def (number: builtins.int) -> builtins.str'