Skip to content

Commit

Permalink
Suppress instrumentation in excluded URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Oct 9, 2024
1 parent c82306b commit 075077e
Showing 1 changed file with 53 additions and 10 deletions.
63 changes: 53 additions & 10 deletions logfire/_internal/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import inspect
from contextlib import contextmanager
from contextvars import ContextVar
from functools import lru_cache
from typing import Any, Awaitable, Callable, ContextManager, Iterable, cast
from weakref import WeakKeyDictionary
Expand All @@ -15,6 +16,9 @@
from starlette.responses import Response
from starlette.websockets import WebSocket

import logfire

from ... import suppress_instrumentation
from ..main import Logfire
from ..stack_info import StackInfo, get_code_object_info
from ..utils import maybe_capture_server_headers
Expand Down Expand Up @@ -111,6 +115,18 @@ def uninstrument_context():
return uninstrument_context()


current_request: ContextVar[Request | WebSocket | None] = ContextVar('current_request', default=None)


@contextmanager
def current_request_context(request: Request | WebSocket):
token = current_request.set(request)
try:
yield
finally:
current_request.reset(token)


@lru_cache # only patch once
def patch_fastapi():
"""Globally monkeypatch fastapi functions and return a dictionary for recording instrumentation config per app."""
Expand All @@ -134,14 +150,37 @@ async def patched_run_endpoint_function(*, dependant: Any, values: dict[str, Any
if isinstance(values, _InstrumentedValues):
request = values.request
if instrumentation := registry.get(request.app): # pragma: no branch
return await instrumentation.run_endpoint_function(
original_run_endpoint_function, request, dependant, values, **kwargs
)
if instrumentation.is_url_excluded(request):
with suppress_instrumentation():
return await original_run_endpoint_function(dependant=dependant, values=values, **kwargs)

elif (
(request := current_request.get())
and (instrumentation := registry.get(request.app))
and instrumentation.is_url_excluded(request)
):
with suppress_instrumentation():
return await original_run_endpoint_function(
dependant=dependant, values=values, **kwargs
) # pragma: no cover

return await original_run_endpoint_function(dependant=dependant, values=values, **kwargs) # pragma: no cover

original_run_endpoint_function = fastapi.routing.run_endpoint_function
fastapi.routing.run_endpoint_function = patched_run_endpoint_function

def patched_get_request_handler(*args: Any, **kwargs: Any) -> Any:
original_handler = original_get_request_handler(*args, **kwargs)

async def wrapped(request: Any) -> Any:
with current_request_context(request):
return await original_handler(request)

return wrapped

original_get_request_handler = fastapi.routing.get_request_handler
fastapi.routing.get_request_handler = patched_get_request_handler

return registry


Expand All @@ -168,15 +207,11 @@ def __init__(
self.excluded_urls_list = parse_excluded_urls(excluded_urls) # pragma: no cover

async def solve_dependencies(self, request: Request | WebSocket, original: Awaitable[Any]) -> Any:
try:
url = cast(str, get_host_port_url_tuple(request.scope)[2])
excluded = self.excluded_urls_list.url_disabled(url)
except Exception: # pragma: no cover
excluded = False
self.logfire_instance.exception('Error checking if URL is excluded from instrumentation')
excluded = self.is_url_excluded(request)

if excluded:
return await original # pragma: no cover
with logfire.suppress_instrumentation():
return await original # pragma: no cover

with self.logfire_instance.span('FastAPI arguments') as span:
result: Any = await original
Expand Down Expand Up @@ -275,6 +310,14 @@ async def run_endpoint_function(
):
return await original_run_endpoint_function(dependant=dependant, values=values, **kwargs)

def is_url_excluded(self, request: Request | WebSocket) -> bool:
try:
url = cast(str, get_host_port_url_tuple(request.scope)[2])
return self.excluded_urls_list.url_disabled(url)
except Exception: # pragma: no cover
self.logfire_instance.exception('Error checking if URL is excluded from instrumentation')
return False


def _default_request_attributes_mapper(
_request: Request | WebSocket,
Expand Down

0 comments on commit 075077e

Please sign in to comment.