Skip to content

Commit

Permalink
Ensure secrets are masked
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Mar 22, 2024
1 parent 0abc9b4 commit 548164b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 16 deletions.
3 changes: 2 additions & 1 deletion airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __init__(self, ti: TaskInstance, render_templates=True):

self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
self.rendered_fields = {
field: serialize_template_field(getattr(self.task, field)) for field in self.task.template_fields
field: serialize_template_field(getattr(self.task, field), field)
for field in self.task.template_fields
}

self._redact()
Expand Down
6 changes: 4 additions & 2 deletions airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

from airflow.configuration import conf
from airflow.settings import json
from airflow.utils.log.secrets_masker import redact


def serialize_template_field(template_field: Any) -> str | dict | list | int | float:
def serialize_template_field(template_field: Any, name) -> str | dict | list | int | float:
"""Return a serializable representation of the templated field.
If ``templated_field`` contains a class or instance that requires recursive
Expand All @@ -42,9 +43,10 @@ def is_jsonable(x):
max_length = conf.getint("core", "max_templated_field_length")

if template_field and len(str(template_field)) > max_length:
rendered = redact(str(template_field), name)
return (
"Truncated. You can change this behaviour in [core]max_templated_field_length. "
f"{str(template_field)[:max_length-79]}... "
f"{rendered[:max_length-79]!r}... "
)
if not is_jsonable(template_field):
return str(template_field)
Expand Down
2 changes: 1 addition & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)
)
value = getattr(op, template_field, None)
if not cls._is_excluded(value, template_field, op):
serialize_op[template_field] = serialize_template_field(value)
serialize_op[template_field] = serialize_template_field(value, template_field)

if op.params:
serialize_op["params"] = cls._serialize_params_dict(op.params)
Expand Down
19 changes: 7 additions & 12 deletions tests/models/test_renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __str__(self):


max_length = conf.getint("core", "max_templated_field_length")
LargeStr = "a" * 5000


class TestRenderedTaskInstanceFields:
Expand Down Expand Up @@ -125,11 +126,11 @@ def teardown_method(self):
),
(
"a" * 5000,
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {('a'*5000)[:max_length-79]}... ",
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {('a'*5000)[:max_length-79]!r}... ",
),
(
LargeStrObject(),
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {str(LargeStrObject())[:max_length-79]}... ",
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {str(LargeStrObject())[:max_length-79]!r}... ",
),
],
)
Expand Down Expand Up @@ -168,32 +169,26 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da
# Fetching them will return None
assert RTIF.get_templated_fields(ti=ti2) is None

@mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True)
def test_secrets_are_masked_when_large_string(self, redact, dag_maker):
def test_secrets_are_masked_when_large_string(self, dag_maker):
"""
Test that secrets are masked when the templated field is a large string
"""
Variable.set(
key="test_key",
key="api_key",
value="Some very long secret with private information that asserts private is not in the rendered field"
* 5000,
)
with dag_maker("test_serialized_rendered_fields"):
task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
task = BashOperator(task_id="test", bash_command="echo {{ var.value.api_key }}")
dr = dag_maker.create_dagrun()
redact.side_effect = [
"val 1", # bash_command
"val 2",
"val 3",
]
ti = dr.task_instances[0]
ti.task = task
rtif = RTIF(ti=ti)

assert ti.dag_id == rtif.dag_id
assert ti.task_id == rtif.task_id
assert ti.run_id == rtif.run_id
assert "val 1" == rtif.rendered_fields.get("bash_command")
assert "***" in rtif.rendered_fields.get("bash_command")

@pytest.mark.parametrize(
"rtif_num, num_to_keep, remaining_rtifs, expected_query_count",
Expand Down

0 comments on commit 548164b

Please sign in to comment.