diff --git a/airflow/sentry.py b/airflow/sentry.py index 8dc9091513845..62eac9abf7610 100644 --- a/airflow/sentry.py +++ b/airflow/sentry.py @@ -21,7 +21,7 @@ from functools import wraps from airflow.configuration import conf -from airflow.utils.session import provide_session +from airflow.utils.session import find_session_idx, provide_session from airflow.utils.state import State log = logging.getLogger(__name__) @@ -149,14 +149,21 @@ def add_breadcrumbs(self, task_instance, session=None): def enrich_errors(self, func): """Wrap TaskInstance._run_raw_task to support task specific tags and breadcrumbs.""" + session_args_idx = find_session_idx(func) @wraps(func) - def wrapper(task_instance, *args, session=None, **kwargs): + def wrapper(task_instance, *args, **kwargs): # Wrapping the _run_raw_task function with push_scope to contain # tags and breadcrumbs to a specific Task Instance + + try: + session = kwargs.get('session', args[session_args_idx]) + except IndexError: + session = None + with sentry_sdk.push_scope(): try: - return func(task_instance, *args, session=session, **kwargs) + return func(task_instance, *args, **kwargs) except Exception as e: self.add_tagging(task_instance) self.add_breadcrumbs(task_instance, session=session) diff --git a/airflow/utils/session.py b/airflow/utils/session.py index 4001a0f454387..f8b9bcd071eb4 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -40,6 +40,18 @@ def create_session(): RT = TypeVar("RT") # pylint: disable=invalid-name +def find_session_idx(func: Callable[..., RT]) -> int: + """Find session index in function call parameter.""" + func_params = signature(func).parameters + try: + # func_params is an ordered dict -- this is the "recommended" way of getting the position + session_args_idx = tuple(func_params).index("session") + except ValueError: + raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None + + return session_args_idx + + def provide_session(func: Callable[..., RT]) -> Callable[..., RT]: """ Function decorator that provides a session if it isn't provided. @@ -47,14 +59,7 @@ def provide_session(func: Callable[..., RT]) -> Callable[..., RT]: database transaction, you pass it to the function, if not this wrapper will create one and close it for you. """ - func_params = signature(func).parameters - try: - # func_params is an ordered dict -- this is the "recommended" way of getting the position - session_args_idx = tuple(func_params).index("session") - except ValueError: - raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None - # We don't need this anymore -- ensure we don't keep a reference to it by mistake - del func_params + session_args_idx = find_session_idx(func) @wraps(func) def wrapper(*args, **kwargs) -> RT: diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py new file mode 100644 index 0000000000000..08f317f42e2a3 --- /dev/null +++ b/tests/utils/test_session.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import pytest + +from airflow.utils.session import provide_session + + +class TestSession: + def dummy_session(self, session=None): + return session + + def test_raised_provide_session(self): + with pytest.raises(ValueError, match="Function .*dummy has no `session` argument"): + + @provide_session + def dummy(): + pass + + def test_provide_session_without_args_and_kwargs(self): + assert self.dummy_session() is None + + wrapper = provide_session(self.dummy_session) + + assert wrapper() is not None + + def test_provide_session_with_args(self): + wrapper = provide_session(self.dummy_session) + + session = object() + assert wrapper(session) is session + + def test_provide_session_with_kwargs(self): + wrapper = provide_session(self.dummy_session) + + session = object() + assert wrapper(session=session) is session