diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 90182de3c6eb3..f2b05d535a63e 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -82,6 +82,7 @@ "max_active_runs": { "type" : "number"}, "default_args": { "$ref": "#/definitions/dict" }, "start_date": { "$ref": "#/definitions/datetime" }, + "end_date": { "$ref": "#/definitions/datetime" }, "dagrun_timeout": { "$ref": "#/definitions/timedelta" }, "doc_md": { "type" : "string"} }, diff --git a/airflow/serialization/serialized_dag.py b/airflow/serialization/serialized_dag.py index 8749a504093d5..1f6d3b3316f00 100644 --- a/airflow/serialization/serialized_dag.py +++ b/airflow/serialization/serialized_dag.py @@ -126,6 +126,11 @@ def deserialize_dag(cls, encoded_dag): for task in dag.task_dict.values(): task.dag = dag task = cast(SerializedBaseOperator, task) + + for date_attr in ["start_date", "end_date"]: + if getattr(task, date_attr) is None: + setattr(task, date_attr, getattr(dag, date_attr)) + if task.subdag is not None: setattr(task.subdag, 'parent_dag', dag) task.subdag.is_subdag = True diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7d212ec2e5f94..b46dc4a92c33c 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -292,6 +292,55 @@ def validate_operator_extra_links(self, task): "https://www.google.com" ) + @parameterized.expand([ + (datetime(2019, 8, 1), None, datetime(2019, 8, 1)), + (datetime(2019, 8, 1), datetime(2019, 8, 2), datetime(2019, 8, 2)), + (datetime(2019, 8, 1), datetime(2019, 7, 30), datetime(2019, 8, 1)), + ]) + def test_deserialization_start_date(self, + dag_start_date, + task_start_date, + expected_task_start_date): + dag = DAG(dag_id='simple_dag', start_date=dag_start_date) + BaseOperator(task_id='simple_task', dag=dag, start_date=task_start_date) + + serialized_dag = SerializedDAG.to_dict(dag) + if not task_start_date or dag_start_date >= task_start_date: + # If dag.start_date > task.start_date -> task.start_date=dag.start_date + # because of the logic in dag.add_task() + self.assertNotIn("start_date", serialized_dag["dag"]["tasks"][0]) + else: + self.assertIn("start_date", serialized_dag["dag"]["tasks"][0]) + + dag = SerializedDAG.from_dict(serialized_dag) + simple_task = dag.task_dict["simple_task"] + self.assertEqual(simple_task.start_date, expected_task_start_date) + + @parameterized.expand([ + (datetime(2019, 8, 1), None, datetime(2019, 8, 1)), + (datetime(2019, 8, 1), datetime(2019, 8, 2), datetime(2019, 8, 1)), + (datetime(2019, 8, 1), datetime(2019, 7, 30), datetime(2019, 7, 30)), + ]) + def test_deserialization_end_date(self, + dag_end_date, + task_end_date, + expected_task_end_date): + dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), + end_date=dag_end_date) + BaseOperator(task_id='simple_task', dag=dag, end_date=task_end_date) + + serialized_dag = SerializedDAG.to_dict(dag) + if not task_end_date or dag_end_date <= task_end_date: + # If dag.end_date < task.end_date -> task.end_date=dag.end_date + # because of the logic in dag.add_task() + self.assertNotIn("end_date", serialized_dag["dag"]["tasks"][0]) + else: + self.assertIn("end_date", serialized_dag["dag"]["tasks"][0]) + + dag = SerializedDAG.from_dict(serialized_dag) + simple_task = dag.task_dict["simple_task"] + self.assertEqual(simple_task.end_date, expected_task_end_date) + @parameterized.expand([ (None, None), ("@weekly", "@weekly"),