diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 0eedc444fb97b..247d810a6bf3d 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -651,6 +651,7 @@ class DatabricksRunNowOperator(BaseOperator): - ``spark_submit_params`` - ``idempotency_token`` - ``repair_run`` + - ``cancel_previous_runs`` :param job_id: the job_id of the existing Databricks job. This field will be templated. @@ -740,6 +741,7 @@ class DatabricksRunNowOperator(BaseOperator): :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param deferrable: Run operator in the deferrable mode. :param repair_run: Repair the databricks run in case of failure. + :param cancel_previous_runs: Cancel all existing running jobs before submitting a new one. """ # Used in airflow.models.BaseOperator @@ -771,6 +773,7 @@ def __init__( wait_for_termination: bool = True, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), repair_run: bool = False, + cancel_previous_runs: bool = False, **kwargs, ) -> None: """Create a new ``DatabricksRunNowOperator``.""" @@ -784,6 +787,7 @@ def __init__( self.wait_for_termination = wait_for_termination self.deferrable = deferrable self.repair_run = repair_run + self.cancel_previous_runs = cancel_previous_runs if job_id is not None: self.json["job_id"] = job_id @@ -831,6 +835,9 @@ def execute(self, context: Context): self.json["job_id"] = job_id del self.json["job_name"] + if self.cancel_previous_runs and self.json.get("job_id") is not None: + hook.cancel_all_runs(self.json["job_id"]) + self.run_id = hook.run_now(self.json) if self.deferrable: _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst index 
a4b00d9005c81..facf47e7d6c56 100644 --- a/docs/apache-airflow-providers-databricks/operators/run_now.rst +++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst @@ -45,6 +45,9 @@ All other parameters are optional and described in documentation for ``Databrick * ``python_named_parameters`` * ``jar_params`` * ``spark_submit_params`` +* ``idempotency_token`` +* ``repair_run`` +* ``cancel_previous_runs`` DatabricksRunNowDeferrableOperator ================================== diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 6797377161962..f2a3441f435cf 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -737,7 +737,7 @@ def test_exec_success(self, db_mock_class): } op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) @@ -767,7 +767,7 @@ def test_exec_pipeline_name(self, db_mock_class): op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"] - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) @@ -798,7 +798,7 @@ def test_exec_failure(self, db_mock_class): } op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") with pytest.raises(AirflowException): @@ -845,7 +845,7 @@ def test_wait_for_termination(self, db_mock_class): } op = DatabricksSubmitRunOperator(task_id=TASK_ID, 
json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") assert op.wait_for_termination @@ -875,7 +875,7 @@ def test_no_wait_for_termination(self, db_mock_class): } op = DatabricksSubmitRunOperator(task_id=TASK_ID, wait_for_termination=False, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID assert not op.wait_for_termination @@ -909,7 +909,7 @@ def test_execute_task_deferred(self, db_mock_class): } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING") with pytest.raises(TaskDeferred) as exc: @@ -971,7 +971,7 @@ def test_execute_complete_failure(self, db_mock_class): op.execute_complete(context=None, event=event) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"): @@ -993,7 +993,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") op.execute(None) @@ -1023,7 +1023,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID 
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) @@ -1147,7 +1147,7 @@ def test_exec_success(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) @@ -1181,7 +1181,7 @@ def test_exec_failure(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") with pytest.raises(AirflowException): @@ -1215,7 +1215,7 @@ def test_exec_failure_with_message(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = mock_dict( { "job_id": JOB_ID, @@ -1279,7 +1279,7 @@ def test_wait_for_termination(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") assert op.wait_for_termination @@ -1311,7 +1311,7 @@ def test_no_wait_for_termination(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, 
"jar_params": JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, wait_for_termination=False, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID assert not op.wait_for_termination @@ -1357,7 +1357,7 @@ def test_exec_with_job_name(self, db_mock_class): op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run) db_mock = db_mock_class.return_value db_mock.find_job_id_by_name.return_value = JOB_ID - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) @@ -1397,6 +1397,74 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class): db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_cancel_previous_runs(self, db_mock_class): + run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} + op = DatabricksRunNowOperator( + task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run + ) + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + + assert op.cancel_previous_runs + + op.execute(None) + + expected = utils.normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + + db_mock.cancel_all_runs.assert_called_once_with(JOB_ID) + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_not_called() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + 
def test_no_cancel_previous_runs(self, db_mock_class): + run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} + op = DatabricksRunNowOperator( + task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run + ) + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + + assert not op.cancel_previous_runs + + op.execute(None) + + expected = utils.normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + + db_mock.cancel_all_runs.assert_not_called() + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_not_called() + class TestDatabricksRunNowDeferrableOperator: @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") @@ -1407,7 +1475,7 @@ def test_execute_task_deferred(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING") with pytest.raises(TaskDeferred) as exc: @@ -1470,7 +1538,7 @@ def test_execute_complete_failure(self, db_mock_class): op.execute_complete(context=None, event=event) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") with pytest.raises(AirflowException, match=f"Job run failed with terminal state: 
{run_state_failed}"): @@ -1499,7 +1567,7 @@ def test_execute_complete_failure_and_repair_run( op.execute_complete(context=None, event=event) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") db_mock.get_latest_repair_id.assert_called_once() db_mock.repair_run.assert_called_once() @@ -1521,7 +1589,7 @@ def test_operator_failed_before_defer(self, mock_defer, db_mock_class): } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 + db_mock.submit_run.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") op.execute(None) @@ -1548,7 +1616,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") op.execute(None) @@ -1581,7 +1649,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value - db_mock.run_now.return_value = 1 + db_mock.run_now.return_value = RUN_ID db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None)