Skip to content

Commit

Permalink
[AIRFLOW-3871] Operators template fields can now render fields inside…
Browse files Browse the repository at this point in the history
… objects (#4743)

(cherry picked from commit d567f9a)
  • Loading branch information
galak75 authored and ashb committed Sep 24, 2019
1 parent d0b8e1f commit 6ad673c
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 6 deletions.
33 changes: 27 additions & 6 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,14 +670,18 @@ def render_template_fields(self, context, jinja_env=None):
if not jinja_env:
jinja_env = self.get_template_env()

for attr_name in self.template_fields:
content = getattr(self, attr_name)
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())

def _do_render_template_fields(self, parent, template_fields, context, jinja_env, seen_oids):
# type: (Any, Iterable[str], Dict, jinja2.Environment, Set) -> None
for attr_name in template_fields:
content = getattr(parent, attr_name)
if content:
rendered_content = self.render_template(content, context, jinja_env)
setattr(self, attr_name, rendered_content)
rendered_content = self.render_template(content, context, jinja_env, seen_oids)
setattr(parent, attr_name, rendered_content)

def render_template(self, content, context, jinja_env=None):
# type: (Any, Dict, Optional[jinja2.Environment]) -> Any
def render_template(self, content, context, jinja_env=None, seen_oids=None):
# type: (Any, Dict, Optional[jinja2.Environment], Optional[Set]) -> Any
"""
Render a templated string. The content can be a collection holding multiple templated strings and will
be templated recursively.
Expand All @@ -689,6 +693,8 @@ def render_template(self, content, context, jinja_env=None):
:param jinja_env: Jinja environment. Can be provided to avoid re-creating Jinja environments during
recursion.
:type jinja_env: jinja2.Environment
:param seen_oids: template fields already rendered (to avoid RecursionError on circular dependencies)
:type seen_oids: set
:return: Templated content
"""

Expand Down Expand Up @@ -721,8 +727,23 @@ def render_template(self, content, context, jinja_env=None):
return {self.render_template(element, context, jinja_env) for element in content}

else:
if seen_oids is None:
seen_oids = set()
self._render_nested_template_fields(content, context, jinja_env, seen_oids)
return content

def _render_nested_template_fields(self, content, context, jinja_env, seen_oids):
# type: (Any, Dict, jinja2.Environment, Set) -> None
if id(content) not in seen_oids:
seen_oids.add(id(content))
try:
nested_template_fields = content.template_fields
except AttributeError:
# content has no inner template fields
return

self._do_render_template_fields(content, nested_template_fields, context, jinja_env, seen_oids)

def get_template_env(self): # type: () -> jinja2.Environment
"""Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
return self.dag.get_template_env() if self.has_dag() else jinja2.Environment(cache_size=0)
Expand Down
55 changes: 55 additions & 0 deletions docs/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,61 @@ You can use Jinja templating with every parameter that is marked as "templated"
in the documentation. Template substitution occurs just before the pre_execute
function of your operator is called.

You can also use Jinja templating with nested fields, as long as these nested fields
are marked as templated in the structure they belong to: fields registered in
``template_fields`` property will be submitted to template substitution, like the
``path`` field in the example below:

.. code:: python
class MyDataReader:
template_fields = ['path']
def __init__(self, my_path):
self.path = my_path
# [additional code here...]
t = PythonOperator(
task_id='transform_data',
python_callable=transform_data
op_args=[
MyDataReader('/tmp/{{ ds }}/my_file')
],
dag=dag)
.. note:: ``template_fields`` property can equally be a class variable or an
instance variable.

Deep nested fields can also be substituted, as long as all intermediate fields are
marked as template fields:

.. code:: python
class MyDataTransformer:
template_fields = ['reader']
def __init__(self, my_reader):
self.reader = my_reader
# [additional code here...]
class MyDataReader:
template_fields = ['path']
def __init__(self, my_path):
self.path = my_path
# [additional code here...]
t = PythonOperator(
task_id='transform_data',
python_callable=transform_data
op_args=[
MyDataTransformer(MyDataReader('/tmp/{{ ds }}/my_file'))
],
dag=dag)
Packaged dags
'''''''''''''
While often you will specify dags in a single ``.py`` file it might sometimes
Expand Down
95 changes: 95 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,39 @@ def execute(self, context):
TestNamedTuple = namedtuple("TestNamedTuple", ["var1", "var2"])


class ClassWithCustomAttributes:
"""Class for testing purpose: allows to create objects with custom attributes in one single statement."""

def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

def __str__(self):
return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__))

def __repr__(self):
return self.__str__()

def __eq__(self, other):
return self.__dict__ == other.__dict__

def __ne__(self, other):
return not self.__eq__(other)


# Objects with circular references (for testing purpose)
object1 = ClassWithCustomAttributes(
attr="{{ foo }}_1",
template_fields=["ref"]
)
object2 = ClassWithCustomAttributes(
attr="{{ foo }}_2",
ref=object1,
template_fields=["ref"]
)
setattr(object1, 'ref', object2)


class BaseOperatorTest(unittest.TestCase):
@parameterized.expand(
[
Expand All @@ -72,6 +105,49 @@ class BaseOperatorTest(unittest.TestCase):
(datetime.datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime.datetime(2018, 12, 6, 10, 55)),
(TestNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, TestNamedTuple("bar_1", "bar_2")),
({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}),
(None, {}, None),
([], {}, []),
({}, {}, {}),
(
# check nested fields can be templated
ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]),
{"foo": "bar"},
ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]),
),
(
# check deep nested fields can be templated
ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="{{ foo }}_1",
att2="{{ foo }}_2",
template_fields=["att1"]),
nested2=ClassWithCustomAttributes(att3="{{ foo }}_3",
att4="{{ foo }}_4",
template_fields=["att3"]),
template_fields=["nested1"]),
{"foo": "bar"},
ClassWithCustomAttributes(nested1=ClassWithCustomAttributes(att1="bar_1",
att2="{{ foo }}_2",
template_fields=["att1"]),
nested2=ClassWithCustomAttributes(att3="{{ foo }}_3",
att4="{{ foo }}_4",
template_fields=["att3"]),
template_fields=["nested1"]),
),
(
# check null value on nested template field
ClassWithCustomAttributes(att1=None,
template_fields=["att1"]),
{},
ClassWithCustomAttributes(att1=None,
template_fields=["att1"]),
),
(
# check there is no RecursionError on circular references
object1,
{"foo": "bar"},
object1,
),
# By default, Jinja2 drops one (single) trailing newline
("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"),
]
)
def test_render_template(self, content, context, expected_output):
Expand Down Expand Up @@ -133,6 +209,25 @@ def test_render_template_field_undefined_strict(self):
with self.assertRaises(jinja2.UndefinedError):
task.render_template("{{ foo }}", {})

def test_nested_template_fields_declared_must_exist(self):
"""Test render_template when a nested template field is missing."""
with DAG("test-dag", start_date=DEFAULT_DATE):
task = DummyOperator(task_id="op1")

with self.assertRaises(AttributeError) as e:
task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {})

self.assertEqual("'ClassWithCustomAttributes' object has no attribute 'missing_field'",
str(e.exception))

def test_jinja_invalid_expression_is_just_propagated(self):
"""Test render_template propagates Jinja invalid expression errors."""
with DAG("test-dag", start_date=DEFAULT_DATE):
task = DummyOperator(task_id="op1")

with self.assertRaises(jinja2.exceptions.TemplateSyntaxError):
task.render_template("{{ invalid expression }}", {})

@mock.patch("jinja2.Environment", autospec=True)
def test_jinja_env_creation(self, mock_jinja_env):
"""Verify if a Jinja environment is created only once when templating."""
Expand Down

0 comments on commit 6ad673c

Please sign in to comment.