diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 0f15d27a09bb5a..d20dd3ddabdd64 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -98,7 +98,6 @@ "is_paused_upon_creation": { "type": "boolean" } }, "required": [ - "params", "_dag_id", "fileloc", "tasks" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 959a48e9e3ecfb..d39bd506083623 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -252,8 +252,13 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any) -> bool: user explicitly specifies an attribute with the same "value" as the default. (This is because ``"default" is "default"`` will be False as they are different strings with the same characters.) + + Also returns True if the value is an empty list or empty dict. This is done + to account for the case where the default value of the field is None but has the + ``field = field or {}`` set. """ - if attrname in cls._CONSTRUCTOR_PARAMS and cls._CONSTRUCTOR_PARAMS[attrname].default is value: + if attrname in cls._CONSTRUCTOR_PARAMS and \ + (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])): return True return False @@ -265,11 +270,11 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): Class specific attributes used by UI are move to object attributes. """ - _decorated_fields = {'executor_config', } + _decorated_fields = {'executor_config'} _CONSTRUCTOR_PARAMS = { k: v for k, v in signature(BaseOperator).parameters.items() - if v.default is not v.empty and v.default is not None + if v.default is not v.empty } def __init__(self, *args, **kwargs): @@ -366,9 +371,6 @@ def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator): dag_date = getattr(op.dag, attrname, None) if var is dag_date or var == dag_date: return True - if attrname in {"executor_config", "params"} and not var: - # Don't store empty executor config or params dicts. - return True return super()._is_excluded(var, attrname, op) @classmethod @@ -470,7 +472,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument } return { param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items() - if v.default is not v.empty and v.default is not None + if v.default is not v.empty } _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore del __get_constructor_defaults diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 17f1546f041054..ca91f02bc45f5d 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -64,7 +64,7 @@ def wrapper(*args, **kwargs): dag_args = copy(dag.default_args) or {} dag_params = copy(dag.params) or {} - params = kwargs.get('params', {}) + params = kwargs.get('params', {}) or {} dag_params.update(params) default_args = {} diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 8e8ed2037caeb6..ae2cbea8894d13 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -55,7 +55,6 @@ }, "start_date": 1564617600.0, "is_paused_upon_creation": False, - "params": {}, "_dag_id": "simple_dag", "fileloc": None, "tasks": [ @@ -159,6 +158,9 @@ def collect_dags(): dags.update(make_example_dags(example_dags)) dags.update(make_example_dags(contrib_example_dags)) dags.update(make_example_dags(gcp_example_dags)) + + # Filter subdags as they are stored in same row in Serialized Dag table + dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag} return dags @@ -241,6 +243,9 @@ def test_deserialization(self): self.assertTrue(set(stringified_dags.keys()) == set(dags.keys())) # Verify deserialized DAGs. + for dag_id in stringified_dags: + self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id]) + example_skip_dag = stringified_dags['example_skip_dag'] skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1'] self.validate_deserialized_task( @@ -260,6 +265,22 @@ def test_deserialization(self): SubDagOperator.ui_fgcolor ) + def validate_deserialized_dag(self, serialized_dag, dag): + """ + Verify that all example DAGs work with DAG Serialization by + checking fields between Serialized Dags & non-Serialized Dags + """ + fields_to_check = [ + "task_ids", "params", "fileloc", "max_active_runs", "concurrency", + "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag", + "catchup", "description", "start_date", "end_date", "parent_dag", + "template_searchpath" + ] + + # fields_to_check = dag.get_serialized_fields() + for field in fields_to_check: + self.assertEqual(getattr(serialized_dag, field), getattr(dag, field)) + def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor): """Verify non-airflow operators are casted to BaseOperator.""" self.assertTrue(isinstance(task, SerializedBaseOperator)) @@ -275,6 +296,8 @@ def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor): self.assertTrue(isinstance(task.subdag, DAG)) else: self.assertIsNone(task.subdag) + self.assertEqual({}, task.params) + self.assertEqual({}, task.executor_config) @parameterized.expand([ (datetime(2019, 8, 1), None, datetime(2019, 8, 1)), @@ -335,7 +358,6 @@ def test_deserialization_schedule_interval(self, serialized_schedule_interval, e "__version": 1, "dag": { "default_args": {"__type": "dict", "__var": {}}, - "params": {}, "_dag_id": "simple_dag", "fileloc": __file__, "tasks": [], @@ -365,6 +387,50 @@ def test_roundtrip_relativedelta(self, val, expected): round_tripped = SerializedDAG._deserialize(serialized) self.assertEqual(val, round_tripped) + @parameterized.expand([ + (None, {}), + ({"param_1": "value_1"}, {"param_1": "value_1"}), + ]) + def test_dag_params_roundtrip(self, val, expected_val): + """ + Test that params work both on Serialized DAGs & Tasks + """ + dag = DAG(dag_id='simple_dag', params=val) + BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1)) + + serialized_dag = SerializedDAG.to_dict(dag) + if val: + self.assertIn("params", serialized_dag["dag"]) + else: + self.assertNotIn("params", serialized_dag["dag"]) + + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_simple_task = deserialized_dag.task_dict["simple_task"] + self.assertEqual(expected_val, deserialized_dag.params) + self.assertEqual(expected_val, deserialized_simple_task.params) + + @parameterized.expand([ + (None, {}), + ({"param_1": "value_1"}, {"param_1": "value_1"}), + ]) + def test_task_params_roundtrip(self, val, expected_val): + """ + Test that params work both on Serialized DAGs & Tasks + """ + dag = DAG(dag_id='simple_dag') + BaseOperator(task_id='simple_task', dag=dag, params=val, + start_date=datetime(2019, 8, 1)) + + serialized_dag = SerializedDAG.to_dict(dag) + if val: + self.assertIn("params", serialized_dag["dag"]["tasks"][0]) + else: + self.assertNotIn("params", serialized_dag["dag"]["tasks"][0]) + + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_simple_task = deserialized_dag.task_dict["simple_task"] + self.assertEqual(expected_val, deserialized_simple_task.params) + def test_extra_serialized_field_and_operator_links(self): """ Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.