Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Trace functions which return Awaitable #15650

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15650.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support to trace a function which returns an `Awaitable`.
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
37 changes: 26 additions & 11 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
ContextManager,
Expand Down Expand Up @@ -903,6 +904,7 @@ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) ->
"""

if inspect.iscoroutinefunction(func):
# For this branch, we handle async functions like `async def func() -> RInner`.
# In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
async def _wrapper(
Expand All @@ -914,36 +916,49 @@ async def _wrapper(
return await func(*args, **kwargs) # type: ignore[misc]

else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
# The other case here handles sync functions including those decorated with
# `@defer.inlineCallbacks` or return a `Deferred`, and those that return
# `Awaitable`
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__()

try:
result = func(*args, **kwargs)

if isinstance(result, defer.Deferred):

def call_back(result: R) -> R:
scope.__exit__(None, None, None)
return result

def err_back(result: R) -> R:
# XXX: Feels like we could put the error details into the
# `scope.__exit__(...)`
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
scope.__exit__(None, None, None)
return result

result.addCallbacks(call_back, err_back)

elif inspect.isawaitable(result):

async def await_coroutine() -> Any:
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
try:
assert isinstance(result, Awaitable)
awaited_result = await result
scope.__exit__(None, None, None)
return awaited_result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise

# The original method returned a coroutine, so we create another
# coroutine wrapping it, that calls `scope.__exit__(...)`.
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
return await_coroutine()
else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)

# Just a simple sync function so we can just exit the scope and
# return the result without any fuss.
scope.__exit__(None, None, None)

return result
Expand Down
38 changes: 37 additions & 1 deletion tests/logging/test_opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from typing import Awaitable, cast

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock
Expand Down Expand Up @@ -277,3 +277,39 @@ async def fixture_async_func() -> str:
[span.operation_name for span in self._reporter.get_spans()],
["fixture_async_func"],
)

def test_trace_decorator_awaitable_return(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return an awaitable (e.g. a coroutine)
"""
reactor = MemoryReactorClock()

with LoggingContext("root context"):
# Something we can return without `await` to get a coroutine
async def fixture_async_func() -> str:
return "foo"

# The actual kind of function we want to test that returns an awaitable
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
@tag_args
def fixture_awaitable_return_func() -> Awaitable[str]:
return fixture_async_func()

# Something we can run with `defer.ensureDeferred(runner())` and pump the
# whole async tasks through to completion.
async def runner() -> str:
return await fixture_awaitable_return_func()

d1 = defer.ensureDeferred(runner())
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved

# let the tasks complete
reactor.pump((2,) * 8)

MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(self.successResultOf(d1), "foo")

# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_awaitable_return_func"],
)