Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cancel_previous_runs to DatabricksRunNowOperator #38702

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``spark_submit_params``
- ``idempotency_token``
- ``repair_run``
- ``cancel_previous_runs``

:param job_id: the job_id of the existing Databricks job.
This field will be templated.
Expand Down Expand Up @@ -740,6 +741,7 @@ class DatabricksRunNowOperator(BaseOperator):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
:param repair_run: Repair the databricks run in case of failure.
:param cancel_previous_runs: Cancel all existing running jobs before submitting new one.
"""

# Used in airflow.models.BaseOperator
Expand Down Expand Up @@ -771,6 +773,7 @@ def __init__(
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
repair_run: bool = False,
cancel_previous_runs: bool = False,
**kwargs,
) -> None:
"""Create a new ``DatabricksRunNowOperator``."""
Expand All @@ -784,6 +787,7 @@ def __init__(
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
self.json["job_id"] = job_id
Expand Down Expand Up @@ -830,6 +834,8 @@ def execute(self, context: Context):
raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
self.json["job_id"] = job_id
del self.json["job_name"]
if self.cancel_previous_runs:
hook.cancel_all_runs(job_id)

self.run_id = hook.run_now(self.json)
if self.deferrable:
Expand Down
64 changes: 64 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,70 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class):

db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_cancel_previous_runs(self, db_mock_class):
    """With cancel_previous_runs=True, all existing runs of the job are cancelled before run_now."""
    run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
    op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, json=run)
    db_mock = db_mock_class.return_value
    db_mock.run_now.return_value = 1

    assert op.cancel_previous_runs

    op.execute(None)

    expected = utils.normalise_json_content(
        {
            "notebook_params": NOTEBOOK_PARAMS,
            "notebook_task": NOTEBOOK_TASK,
            "jar_params": JAR_PARAMS,
            "job_id": JOB_ID,
        }
    )
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller="DatabricksRunNowOperator",
    )

    # cancel_all_runs must be called exactly once (for this job) before the new run is submitted.
    db_mock.cancel_all_runs.assert_called_once_with(JOB_ID)
    db_mock.run_now.assert_called_once_with(expected)
    db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
    db_mock.get_run.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_no_cancel_previous_runs(self, db_mock_class):
    """With cancel_previous_runs=False, existing runs are left untouched when a new run is submitted."""
    payload = {
        "notebook_params": NOTEBOOK_PARAMS,
        "notebook_task": NOTEBOOK_TASK,
        "jar_params": JAR_PARAMS,
    }
    op = DatabricksRunNowOperator(
        task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, json=payload
    )
    hook = db_mock_class.return_value
    hook.run_now.return_value = 1

    assert not op.cancel_previous_runs

    op.execute(None)

    # The operator merges the resolved job_id into the submitted JSON payload.
    expected = utils.normalise_json_content({**payload, "job_id": JOB_ID})
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller="DatabricksRunNowOperator",
    )

    hook.cancel_all_runs.assert_not_called()
    hook.run_now.assert_called_once_with(expected)
    hook.get_run_page_url.assert_called_once_with(RUN_ID)
    hook.get_run.assert_not_called()


class TestDatabricksRunNowDeferrableOperator:
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
Expand Down
Loading