add repair_run support to DatabricksRunNowOperator in deferrable mode (#38619)

* feat(databricks): add repair_run support to DatabricksRunNowOperator in deferrable mode

* test(databricks): add test case test_execute_complete_failure_and_repair_run
Lee-W authored Apr 1, 2024
1 parent 9da08a5 commit 39b684d
Showing 5 changed files with 109 additions and 53 deletions.
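The feature is opt-in on the operator. A minimal usage sketch (not part of the commit; the DAG id, connection id, and job_id below are placeholder assumptions):

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(dag_id="example_databricks_repair", start_date=datetime(2024, 1, 1), schedule=None):
    run_job = DatabricksRunNowOperator(
        task_id="run_job",
        databricks_conn_id="databricks_default",  # assumed pre-configured connection
        job_id=1234,  # placeholder Databricks job id
        deferrable=True,  # poll from the triggerer instead of a worker slot
        repair_run=True,  # after this commit, this also works in deferrable mode
    )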
2 changes: 1 addition & 1 deletion airflow/providers/databricks/hooks/databricks.py
@@ -529,7 +529,7 @@ def repair_run(self, json: dict) -> int:

     def get_latest_repair_id(self, run_id: int) -> int | None:
         """Get latest repair id if any exist for run_id else None."""
-        json = {"run_id": run_id, "include_history": True}
+        json = {"run_id": run_id, "include_history": "true"}
         response = self._do_api_call(GET_RUN_ENDPOINT, json)
         repair_history = response["repair_history"]
         if len(repair_history) == 1:
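The hooks change above swaps the boolean for the lowercase string "true": the get-run endpoint receives its payload as query parameters, where Python's True would serialize as "True", which the API does not parse as a boolean. The repair-id lookup can be pictured as below; the payload shape is an illustrative assumption modelled on the Jobs API 2.1 get-run response, not data captured from this commit:

# Hypothetical response: the first repair_history entry is the ORIGINAL run,
# so a single-entry history means no repairs have happened yet.
response = {
    "repair_history": [
        {"type": "ORIGINAL", "id": 11, "task_run_ids": [101, 102]},
        {"type": "REPAIR", "id": 22, "task_run_ids": [102]},
    ]
}

repair_history = response["repair_history"]
latest_repair_id = None if len(repair_history) == 1 else repair_history[-1]["id"]
assert latest_repair_id == 22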
119 changes: 67 additions & 52 deletions airflow/providers/databricks/operators/databricks.py
@@ -52,6 +52,7 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
     """
     if operator.do_xcom_push and context is not None:
         context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+
     log.info("Run submitted with run_id: %s", operator.run_id)
     run_page_url = hook.get_run_page_url(operator.run_id)
     if operator.do_xcom_push and context is not None:
@@ -66,52 +67,52 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
                 log.info("%s completed successfully.", operator.task_id)
                 log.info("View run status, Spark UI, and logs at %s", run_page_url)
                 return
-            else:
-                if run_state.result_state == "FAILED":
-                    task_run_id = None
-                    if "tasks" in run_info:
-                        for task in run_info["tasks"]:
-                            if task.get("state", {}).get("result_state", "") == "FAILED":
-                                task_run_id = task["run_id"]
-                    if task_run_id is not None:
-                        run_output = hook.get_run_output(task_run_id)
-                        if "error" in run_output:
-                            notebook_error = run_output["error"]
-                        else:
-                            notebook_error = run_state.state_message
-                        error_message = (
-                            f"{operator.task_id} failed with terminal state: {run_state} "
-                            f"and with the error {notebook_error}"
-                        )
-                    else:
-                        error_message = (
-                            f"{operator.task_id} failed with terminal state: {run_state} "
-                            f"and with the error {run_state.state_message}"
-                        )
-                    if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
-                        operator.repair_run = False
-                        log.warning(
-                            "%s but since repair run is set, repairing the run with all failed tasks",
-                            error_message,
-                        )
-
-                        latest_repair_id = hook.get_latest_repair_id(operator.run_id)
-                        repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
-                        if latest_repair_id is not None:
-                            repair_json["latest_repair_id"] = latest_repair_id
-                        operator.json["latest_repair_id"] = hook.repair_run(repair_json)
-                        _handle_databricks_operator_execution(operator, hook, log, context)
-                    raise AirflowException(error_message)
-
-                else:
-                    log.info("%s in run state: %s", operator.task_id, run_state)
-                    log.info("View run status, Spark UI, and logs at %s", run_page_url)
-                    log.info("Sleeping for %s seconds.", operator.polling_period_seconds)
-                    time.sleep(operator.polling_period_seconds)
-        else:
-            log.info("View run status, Spark UI, and logs at %s", run_page_url)
+
+            if run_state.result_state == "FAILED":
+                task_run_id = None
+                if "tasks" in run_info:
+                    for task in run_info["tasks"]:
+                        if task.get("state", {}).get("result_state", "") == "FAILED":
+                            task_run_id = task["run_id"]
+                if task_run_id is not None:
+                    run_output = hook.get_run_output(task_run_id)
+                    if "error" in run_output:
+                        notebook_error = run_output["error"]
+                    else:
+                        notebook_error = run_state.state_message
+                    error_message = (
+                        f"{operator.task_id} failed with terminal state: {run_state} "
+                        f"and with the error {notebook_error}"
+                    )
+                else:
+                    error_message = (
+                        f"{operator.task_id} failed with terminal state: {run_state} "
+                        f"and with the error {run_state.state_message}"
+                    )
+
+                if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
+                    operator.repair_run = False
+                    log.warning(
+                        "%s but since repair run is set, repairing the run with all failed tasks",
+                        error_message,
+                    )
+
+                    latest_repair_id = hook.get_latest_repair_id(operator.run_id)
+                    repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
+                    if latest_repair_id is not None:
+                        repair_json["latest_repair_id"] = latest_repair_id
+                    operator.json["latest_repair_id"] = hook.repair_run(repair_json)
+                    _handle_databricks_operator_execution(operator, hook, log, context)
+                raise AirflowException(error_message)
+
+            log.info("%s in run state: %s", operator.task_id, run_state)
+            log.info("View run status, Spark UI, and logs at %s", run_page_url)
+            log.info("Sleeping for %s seconds.", operator.polling_period_seconds)
+            time.sleep(operator.polling_period_seconds)
+
+    log.info("View run status, Spark UI, and logs at %s", run_page_url)
 
 
 def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
@@ -146,6 +147,7 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
                 retry_delay=operator.databricks_retry_delay,
                 retry_args=operator.databricks_retry_args,
                 run_page_url=run_page_url,
+                repair_run=getattr(operator, "repair_run", False),
             ),
             method_name=DEFER_METHOD_NAME,
         )
@@ -163,9 +165,15 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None:
     if run_state.is_successful:
         log.info("Job run completed successfully.")
         return
-    else:
-        error_message = f"Job run failed with terminal state: {run_state}"
-        raise AirflowException(error_message)
+
+    error_message = f"Job run failed with terminal state: {run_state}"
+    if event["repair_run"]:
+        log.warning(
+            "%s but since repair run is set, repairing the run with all failed tasks",
+            error_message,
+        )
+        return
+    raise AirflowException(error_message)
 
 
 class DatabricksJobRunLink(BaseOperatorLink):
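The handler now distinguishes two failure paths. With repair_run=False it raises immediately, as before; with repair_run=True it only logs a warning and returns, leaving it to DatabricksRunNowOperator.execute_complete (further down) to submit the repair. A sketch of the event it consumes; the keys mirror the TriggerEvent built in triggers/databricks.py, while the values are made up for illustration:

event = {
    "run_id": 123456,  # placeholder run id
    "run_page_url": "https://example.cloud.databricks.com/#job/1234/run/123456",
    "run_state": '{"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}',
    "repair_run": True,  # signals execute_complete to repair rather than fail outright
}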
@@ -584,9 +592,6 @@ def execute(self, context):
         self.run_id = hook.submit_run(json_normalised)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
 
-    def execute_complete(self, context: dict | None, event: dict):
-        _handle_deferrable_databricks_operator_completion(event, self.log)
-
 
 class DatabricksRunNowOperator(BaseOperator):
     """
@@ -734,7 +739,7 @@ class DatabricksRunNowOperator(BaseOperator):
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
     :param deferrable: Run operator in the deferrable mode.
-    :param repair_run: Repair the databricks run in case of failure, doesn't work in deferrable mode
+    :param repair_run: Repair the databricks run in case of failure.
     """
 
     # Used in airflow.models.BaseOperator
@@ -825,6 +830,7 @@ def execute(self, context: Context):
                 raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
             self.json["job_id"] = job_id
             del self.json["job_name"]
+
        self.run_id = hook.run_now(self.json)
        if self.deferrable:
            _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
@@ -834,8 +840,17 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
         if event:
             _handle_deferrable_databricks_operator_completion(event, self.log)
+            if event["repair_run"]:
+                self.repair_run = False
+                self.run_id = event["run_id"]
+                latest_repair_id = self._hook.get_latest_repair_id(self.run_id)
+                repair_json = {"run_id": self.run_id, "rerun_all_failed_tasks": True}
+                if latest_repair_id is not None:
+                    repair_json["latest_repair_id"] = latest_repair_id
+                self.json["latest_repair_id"] = self._hook.repair_run(repair_json)
+                _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
 
-    def on_kill(self):
+    def on_kill(self) -> None:
         if self.run_id:
             self._hook.cancel_run(self.run_id)
             self.log.info(
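Note the one-shot semantics of the execute_complete block above: self.repair_run is cleared before the operator re-defers, so a repaired run that fails again raises AirflowException instead of looping forever. The repair request it builds can be sketched as follows (placeholder ids; latest_repair_id is only present once the run has been repaired before):

repair_json = {
    "run_id": 123456,  # placeholder run id taken from the trigger event
    "rerun_all_failed_tasks": True,
    "latest_repair_id": 22,  # omitted on the first repair attempt
}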
4 changes: 4 additions & 0 deletions airflow/providers/databricks/triggers/databricks.py
@@ -47,6 +47,7 @@ def __init__(
         retry_delay: int = 10,
         retry_args: dict[Any, Any] | None = None,
         run_page_url: str | None = None,
+        repair_run: bool = False,
     ) -> None:
         super().__init__()
         self.run_id = run_id
@@ -56,6 +57,7 @@ def __init__(
         self.retry_delay = retry_delay
         self.retry_args = retry_args
         self.run_page_url = run_page_url
+        self.repair_run = repair_run
         self.hook = DatabricksHook(
             databricks_conn_id,
             retry_limit=self.retry_limit,
@@ -74,6 +76,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "retry_delay": self.retry_delay,
                 "retry_args": self.retry_args,
                 "run_page_url": self.run_page_url,
+                "repair_run": self.repair_run,
             },
         )

@@ -87,6 +90,7 @@ async def run(self):
                         "run_id": self.run_id,
                         "run_page_url": self.run_page_url,
                         "run_state": run_state.to_json(),
+                        "repair_run": self.repair_run,
                     }
                 )
                 return
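Because the triggerer may restart while a run is deferred, a trigger must be reconstructable from its serialize() output alone, which is why repair_run has to round-trip through the serialized kwargs. A short sketch, under the assumption that the provider is installed and the connection is only resolved when the hook is actually used:

from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

trigger = DatabricksExecutionTrigger(
    run_id=123456,  # placeholder run id
    databricks_conn_id="databricks_default",
    repair_run=True,
)
classpath, kwargs = trigger.serialize()
assert kwargs["repair_run"] is True

# An equivalent trigger can be rebuilt from the serialized form:
rebuilt = DatabricksExecutionTrigger(**kwargs)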
34 changes: 34 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
@@ -963,6 +963,7 @@ def test_execute_complete_failure(self, db_mock_class):
             "run_id": RUN_ID,
             "run_page_url": RUN_PAGE_URL,
             "run_state": run_state_failed.to_json(),
+            "repair_run": False,
         }
 
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
@@ -1063,6 +1064,7 @@ def test_init_with_json(self):
             "python_params": PYTHON_PARAMS,
             "spark_submit_params": SPARK_SUBMIT_PARAMS,
             "job_id": JOB_ID,
+            "repair_run": False,
         }
         op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
 
@@ -1073,6 +1075,7 @@ def test_init_with_json(self):
                 "python_params": PYTHON_PARAMS,
                 "spark_submit_params": SPARK_SUBMIT_PARAMS,
                 "job_id": JOB_ID,
+                "repair_run": False,
             }
         )
 
@@ -1442,6 +1445,7 @@ def test_execute_complete_success(self):
             "run_id": RUN_ID,
             "run_page_url": RUN_PAGE_URL,
             "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
+            "repair_run": False,
         }
 
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
@@ -1458,6 +1462,7 @@ def test_execute_complete_failure(self, db_mock_class):
             "run_id": RUN_ID,
             "run_page_url": RUN_PAGE_URL,
             "run_state": run_state_failed.to_json(),
+            "repair_run": False,
         }
 
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
@@ -1471,6 +1476,35 @@ def test_execute_complete_failure(self, db_mock_class):
         with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
             op.execute_complete(context=None, event=event)
 
+    @mock.patch(
+        "airflow.providers.databricks.operators.databricks._handle_deferrable_databricks_operator_execution"
+    )
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_complete_failure_and_repair_run(
+        self, db_mock_class, mock_handle_deferrable_databricks_operator_execution
+    ):
+        """
+        Test `execute_complete` function in case the Trigger has returned a failure event with repair_run=True.
+        """
+        run_state_failed = RunState("TERMINATED", "FAILED", "")
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        event = {
+            "run_id": RUN_ID,
+            "run_page_url": RUN_PAGE_URL,
+            "run_state": run_state_failed.to_json(),
+            "repair_run": True,
+        }
+
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        op.execute_complete(context=None, event=event)
+
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+        db_mock.get_latest_repair_id.assert_called_once()
+        db_mock.repair_run.assert_called_once()
+        mock_handle_deferrable_databricks_operator_execution.assert_called_once()
+
     def test_execute_complete_incorrect_event_validation_failure(self):
         event = {"event_id": "no such column"}
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID)
3 changes: 3 additions & 0 deletions tests/providers/databricks/triggers/test_databricks.py
@@ -96,6 +96,7 @@ def test_serialize(self):
                 "retry_limit": 3,
                 "retry_args": None,
                 "run_page_url": RUN_PAGE_URL,
+                "repair_run": False,
             },
         )

@@ -119,6 +120,7 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
                     life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
                 ).to_json(),
                 "run_page_url": RUN_PAGE_URL,
+                "repair_run": False,
             }
         )

@@ -148,6 +150,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
                     life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS"
                 ).to_json(),
                 "run_page_url": RUN_PAGE_URL,
+                "repair_run": False,
             }
         )
         mock_sleep.assert_called_once()
