From b0f9afb2b5bbc380557553a135628c344d47c528 Mon Sep 17 00:00:00 2001 From: Ivan Date: Thu, 24 Aug 2023 14:33:15 +0100 Subject: [PATCH] Set context inside templates (#33645) * Set context inside templates --------- Co-authored-by: Ivan Afonichkin Co-authored-by: Tzu-ping Chung (cherry picked from commit 9fa782f622ad9f6e568f0efcadf93595f67b8a20) --- airflow/models/taskinstance.py | 4 +++- tests/models/test_taskinstance.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 975e615d95414..0bd9d3d79d2a3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1641,7 +1641,9 @@ def signal_handler(signum, frame): # Set the validated/merged params on the task object. self.task.params = context["params"] - task_orig = self.render_templates(context=context) + with set_current_context(context): + task_orig = self.render_templates(context=context) + if not test_mode: rtif = RenderedTaskInstanceFields(ti=self, render_templates=False) RenderedTaskInstanceFields.write(rtif) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index ac2f7f71748ef..e50917e101074 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2890,6 +2890,33 @@ def test_echo_env_variables(self, dag_maker): ti.refresh_from_db() assert ti.state == State.SUCCESS + def test_get_current_context_works_in_template(self, dag_maker): + def user_defined_macro(): + from airflow.operators.python import get_current_context + + get_current_context() + + with dag_maker( + "test_context_inside_template", + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + user_defined_macros={"user_defined_macro": user_defined_macro}, + ): + + def foo(arg): + print(arg) + + PythonOperator( + task_id="context_inside_template", + python_callable=foo, + op_kwargs={"arg": "{{ user_defined_macro() }}"}, + ), + dagrun = dag_maker.create_dagrun() + tis = dagrun.get_task_instances() + ti: TaskInstance = next(x for x in tis if x.task_id == "context_inside_template") + ti._run_raw_task() + assert ti.state == State.SUCCESS + @patch.object(Stats, "incr") def test_task_stats(self, stats_mock, create_task_instance): ti = create_task_instance(