Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-6347] BugFix: Can't get task logs when serialization is enabled #7092

Merged
merged 5 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
"is_paused_upon_creation": { "type": "boolean" }
},
"required": [
"params",
"_dag_id",
"fileloc",
"tasks"
Expand Down
16 changes: 9 additions & 7 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,13 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any) -> bool:
user explicitly specifies an attribute with the same "value" as the
default. (This is because ``"default" is "default"`` will be False as
they are different strings with the same characters.)

Also returns True if the value is an empty list or empty dict. This is done
to account for the case where the default value of the field is None but has the
``field = field or {}`` set.
"""
if attrname in cls._CONSTRUCTOR_PARAMS and cls._CONSTRUCTOR_PARAMS[attrname].default is value:
if attrname in cls._CONSTRUCTOR_PARAMS and \
(cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
ashb marked this conversation as resolved.
Show resolved Hide resolved
return True
return False

Expand All @@ -265,11 +270,11 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
Class specific attributes used by UI are move to object attributes.
"""

_decorated_fields = {'executor_config', }
_decorated_fields = {'executor_config'}

_CONSTRUCTOR_PARAMS = {
k: v for k, v in signature(BaseOperator).parameters.items()
if v.default is not v.empty and v.default is not None
if v.default is not v.empty
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -366,9 +371,6 @@ def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator):
dag_date = getattr(op.dag, attrname, None)
if var is dag_date or var == dag_date:
return True
if attrname in {"executor_config", "params"} and not var:
# Don't store empty executor config or params dicts.
return True
return super()._is_excluded(var, attrname, op)

@classmethod
Expand Down Expand Up @@ -470,7 +472,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument
}
return {
param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
if v.default is not v.empty and v.default is not None
if v.default is not v.empty
}
_CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore
del __get_constructor_defaults
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def wrapper(*args, **kwargs):
dag_args = copy(dag.default_args) or {}
dag_params = copy(dag.params) or {}

params = kwargs.get('params', {})
params = kwargs.get('params', {}) or {}
ashb marked this conversation as resolved.
Show resolved Hide resolved
dag_params.update(params)

default_args = {}
Expand Down
70 changes: 68 additions & 2 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
},
"start_date": 1564617600.0,
"is_paused_upon_creation": False,
"params": {},
"_dag_id": "simple_dag",
"fileloc": None,
"tasks": [
Expand Down Expand Up @@ -159,6 +158,9 @@ def collect_dags():
dags.update(make_example_dags(example_dags))
dags.update(make_example_dags(contrib_example_dags))
dags.update(make_example_dags(gcp_example_dags))

# Filter subdags as they are stored in same row in Serialized Dag table
dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
return dags


Expand Down Expand Up @@ -241,6 +243,9 @@ def test_deserialization(self):
self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))

# Verify deserialized DAGs.
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])

example_skip_dag = stringified_dags['example_skip_dag']
skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
self.validate_deserialized_task(
Expand All @@ -260,6 +265,22 @@ def test_deserialization(self):
SubDagOperator.ui_fgcolor
)

def validate_deserialized_dag(self, serialized_dag, dag):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
"""
fields_to_check = [
"task_ids", "params", "fileloc", "max_active_runs", "concurrency",
"is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
"catchup", "description", "start_date", "end_date", "parent_dag",
"template_searchpath"
]

# fields_to_check = dag.get_serialized_fields()
for field in fields_to_check:
self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))

def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
"""Verify non-airflow operators are casted to BaseOperator."""
self.assertTrue(isinstance(task, SerializedBaseOperator))
Expand All @@ -275,6 +296,8 @@ def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
self.assertTrue(isinstance(task.subdag, DAG))
else:
self.assertIsNone(task.subdag)
self.assertEqual({}, task.params)
self.assertEqual({}, task.executor_config)

@parameterized.expand([
(datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
Expand Down Expand Up @@ -335,7 +358,6 @@ def test_deserialization_schedule_interval(self, serialized_schedule_interval, e
"__version": 1,
"dag": {
"default_args": {"__type": "dict", "__var": {}},
"params": {},
"_dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
Expand Down Expand Up @@ -365,6 +387,50 @@ def test_roundtrip_relativedelta(self, val, expected):
round_tripped = SerializedDAG._deserialize(serialized)
self.assertEqual(val, round_tripped)

@parameterized.expand([
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
])
def test_dag_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id='simple_dag', params=val)
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
if val:
self.assertIn("params", serialized_dag["dag"])
else:
self.assertNotIn("params", serialized_dag["dag"])

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
self.assertEqual(expected_val, deserialized_dag.params)
self.assertEqual(expected_val, deserialized_simple_task.params)

@parameterized.expand([
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
])
def test_task_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id='simple_dag')
BaseOperator(task_id='simple_task', dag=dag, params=val,
start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
if val:
self.assertIn("params", serialized_dag["dag"]["tasks"][0])
else:
self.assertNotIn("params", serialized_dag["dag"]["tasks"][0])

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
self.assertEqual(expected_val, deserialized_simple_task.params)

def test_extra_serialized_field_and_operator_links(self):
"""
Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.
Expand Down