Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deferrable mode to MLEngineStartTrainingJobOperator #27405

Merged
merged 1 commit into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 110 additions & 5 deletions airflow/providers/google/cloud/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
import logging
import random
import time
from typing import Callable
from typing import Callable, cast

from aiohttp import ClientSession
from gcloud.aio.auth import AioSession, Token
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from httplib2 import Response
from requests import Session

from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,7 +136,7 @@ def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable |
# 409 means there is an existing job with the same job ID.
if e.resp.status == 409:
if use_existing_job_fn is not None:
existing_job = self._get_job(project_id, job_id)
existing_job = self.get_job(project_id, job_id)
if not use_existing_job_fn(existing_job):
self.log.error(
"Job with job_id %s already exist, but it does not match our expectation: %s",
Expand All @@ -147,6 +151,40 @@ def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable |

return self._wait_for_job_done(project_id, job_id)

@GoogleBaseHook.fallback_to_default_project_id
def create_job_without_waiting_result(
    self,
    body: dict,
    project_id: str,
) -> str:
    """
    Launch an MLEngine job and return immediately, without waiting for completion.

    Unlike ``create_job``, this method does not poll the job until it reaches a
    terminal state; the job may still be queued or running when this returns.
    Intended for deferrable operators that monitor the job via a trigger.

    :param project_id: The Google Cloud project id within which the MLEngine
        job will be launched. If set to None or missing, the default project_id from
        the Google Cloud connection is used.
    :param body: MLEngine Job object that should be provided to the MLEngine
        API, such as: ::

            {
                'jobId': 'my_job_id',
                'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                }
            }
    :return: The ``jobId`` taken from the request body.
    :raises googleapiclient.errors.HttpError: if the create request fails
        (e.g. 409 when a job with the same id already exists) — propagated
        to the caller, which handles the conflict case.
    """
    hook = self.get_conn()

    # Stamp the request body with the airflow-version label before submission.
    self._append_label(body)

    request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=body)
    job_id = body["jobId"]
    request.execute(num_retries=self.num_retries)
    return job_id

@GoogleBaseHook.fallback_to_default_project_id
def cancel_job(
self,
Expand Down Expand Up @@ -181,7 +219,7 @@ def cancel_job(
self.log.error("Failed to cancel MLEngine job: %s", e)
raise

def _get_job(self, project_id: str, job_id: str) -> dict:
def get_job(self, project_id: str, job_id: str) -> dict:
"""
Gets a MLEngine job based on the job id.

Expand Down Expand Up @@ -223,7 +261,7 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):
if interval <= 0:
raise ValueError("Interval must be > 0")
while True:
job = self._get_job(project_id, job_id)
job = self.get_job(project_id, job_id)
if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
return job
time.sleep(interval)
Expand Down Expand Up @@ -489,3 +527,70 @@ def _delete_all_versions(self, model_name: str, project_id: str):
def _append_label(self, model: dict) -> None:
    """Stamp *model* in place with the ``airflow-version`` label."""
    labels = model.setdefault("labels", {})
    labels["airflow-version"] = _AIRFLOW_VERSION


class MLEngineAsyncHook(GoogleBaseAsyncHook):
    """Uses gcloud-aio library to retrieve Job details."""

    sync_hook_class = MLEngineHook

    def _check_fields(
        self,
        job_id: str,
        project_id: str | None = None,
    ) -> None:
        """Validate that both a project id and a job id were supplied."""
        if not project_id:
            raise AirflowException("Google Cloud project id is required.")
        if not job_id:
            raise AirflowException("An unique job id is required for Google MLEngine training job.")

    async def _get_link(self, url: str, session: Session):
        """Perform an authenticated GET against *url* using gcloud-aio auth."""
        s = AioSession(session)
        t = Token(scopes=["https://www.googleapis.com/auth/cloud-platform"])
        headers = {
            # Token.get() is a coroutine in gcloud-aio; it must be awaited,
            # otherwise the header would contain the repr of a coroutine
            # object instead of the bearer token.
            "Authorization": f"Bearer {await t.get()}",
            "accept": "application/json",
            "accept-encoding": "gzip, deflate",
            "user-agent": "(gzip)",
            "x-goog-api-client": "gdcl/1.12.11 gl-python/3.8.15",
        }
        return await s.get(url=url, headers=headers)

    async def get_job(self, job_id: str, session: Session, project_id: str | None = None):
        """Get the specified job resource by job ID and project ID."""
        self._check_fields(project_id=project_id, job_id=job_id)

        url = f"https://ml.googleapis.com/v1/projects/{project_id}/jobs/{job_id}"
        return await self._get_link(url=url, session=session)

    async def get_job_status(
        self,
        job_id: str,
        project_id: str | None = None,
    ) -> str | None:
        """
        Polls for job status asynchronously using gcloud-aio.

        Note that an OSError is raised when Job results are still pending.
        Exception means that Job finished with errors.

        :return: ``"success"`` for a terminal state, ``"pending"`` while the
            job is still running, ``None`` for an unrecognized state, or the
            stringified exception on error.
        """
        self._check_fields(project_id=project_id, job_id=job_id)

        # Initialize so an unexpected state (e.g. "QUEUED") returns None
        # instead of raising UnboundLocalError at the return below.
        job_status: str | None = None
        async with ClientSession() as s:
            try:
                job_response = await self.get_job(
                    project_id=project_id, job_id=job_id, session=cast(Session, s)
                )
                json_response = await job_response.json()
                self.log.info("Retrieving json_response: %s", json_response)

                if json_response["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
                    job_status = "success"
                elif json_response["state"] in ["PREPARING", "RUNNING"]:
                    job_status = "pending"
            except OSError:
                job_status = "pending"
            except Exception as e:
                self.log.info("Query execution finished with errors...")
                job_status = str(e)
        return job_status
147 changes: 122 additions & 25 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

import logging
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, Sequence

from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook
Expand All @@ -34,6 +37,7 @@
MLEngineModelsListLink,
MLEngineModelVersionDetailsLink,
)
from airflow.providers.google.cloud.triggers.mlengine import MLEngineStartTrainingJobTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -1094,8 +1098,6 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
for the specified service account.
If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
:param project_id: The Google Cloud project name within which MLEngine training job should run.
If set to None or missing, the default project_id from the Google Cloud connection is used.
(templated)
:param gcp_conn_id: The connection ID to use when fetching connection info.
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
Expand All @@ -1116,6 +1118,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
:param deferrable: Run operator in the deferrable mode
"""

template_fields: Sequence[str] = (
Expand All @@ -1142,6 +1146,7 @@ def __init__(
*,
job_id: str,
region: str,
project_id: str,
package_uris: list[str] | None = None,
training_python_module: str | None = None,
training_args: list[str] | None = None,
Expand All @@ -1152,13 +1157,14 @@ def __init__(
python_version: str | None = None,
job_dir: str | None = None,
service_account: str | None = None,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
mode: str = "PRODUCTION",
labels: dict[str, str] | None = None,
impersonation_chain: str | Sequence[str] | None = None,
hyperparameters: dict | None = None,
deferrable: bool = False,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1181,6 +1187,8 @@ def __init__(
self._labels = labels
self._hyperparameters = hyperparameters
self._impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.cancel_on_kill = cancel_on_kill

custom = self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM"
custom_image = (
Expand Down Expand Up @@ -1209,10 +1217,16 @@ def __init__(
"a custom Docker image should be provided but not both."
)

def _handle_job_error(self, finished_training_job) -> None:
if finished_training_job["state"] != "SUCCEEDED":
self.log.error("MLEngine training job failed: %s", str(finished_training_job))
raise RuntimeError(finished_training_job["errorMessage"])

def execute(self, context: Context):
job_id = _normalize_mlengine_job_id(self._job_id)
self.job_id = job_id
training_request: dict[str, Any] = {
"jobId": job_id,
"jobId": self.job_id,
"trainingInput": {
"scaleTier": self._scale_tier,
"region": self._region,
Expand Down Expand Up @@ -1261,29 +1275,62 @@ def execute(self, context: Context):
delegate_to=self._delegate_to,
impersonation_chain=self._impersonation_chain,
)
self.hook = hook

# Helper method to check if the existing job's training input is the
# same as the request we get here.
def check_existing_job(existing_job):
existing_training_input = existing_job.get("trainingInput")
requested_training_input = training_request["trainingInput"]
if "scaleTier" not in existing_training_input:
existing_training_input["scaleTier"] = None

existing_training_input["args"] = existing_training_input.get("args")
requested_training_input["args"] = (
requested_training_input["args"] if requested_training_input["args"] else None
try:
self.log.info("Executing: %s'", training_request)
self.job_id = self.hook.create_job_without_waiting_result(
project_id=self._project_id,
body=training_request,
)

return existing_training_input == requested_training_input

finished_training_job = hook.create_job(
project_id=self._project_id, job=training_request, use_existing_job_fn=check_existing_job
)

if finished_training_job["state"] != "SUCCEEDED":
self.log.error("MLEngine training job failed: %s", str(finished_training_job))
raise RuntimeError(finished_training_job["errorMessage"])
except HttpError as e:
if e.resp.status == 409:
# If the job already exists retrieve it
self.hook.get_job(project_id=self._project_id, job_id=self.job_id)
if self._project_id:
MLEngineJobDetailsLink.persist(
context=context,
task_instance=self,
project_id=self._project_id,
job_id=self.job_id,
)
self.log.error(
"Failed to create new job with given name since it already exists. "
"The existing one will be used."
)
else:
raise e

context["ti"].xcom_push(key="job_id", value=self.job_id)
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=MLEngineStartTrainingJobTrigger(
conn_id=self._gcp_conn_id,
job_id=self.job_id,
project_id=self._project_id,
region=self._region,
runtime_version=self._runtime_version,
python_version=self._python_version,
job_dir=self._job_dir,
package_uris=self._package_uris,
training_python_module=self._training_python_module,
training_args=self._training_args,
labels=self._labels,
gcp_conn_id=self._gcp_conn_id,
impersonation_chain=self._impersonation_chain,
delegate_to=self._delegate_to,
),
method_name="execute_complete",
)
else:
finished_training_job = self._wait_for_job_done(self._project_id, self.job_id)
self._handle_job_error(finished_training_job)
gcp_metadata = {
"job_id": self.job_id,
"project_id": self._project_id,
}
context["task_instance"].xcom_push("gcp_metadata", gcp_metadata)

project_id = self._project_id or hook.project_id
if project_id:
Expand All @@ -1294,6 +1341,56 @@ def check_existing_job(existing_job):
job_id=job_id,
)

def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):
"""
Waits for the Job to reach a terminal state.

This method will periodically check the job state until the job reach
a terminal state.

:param project_id: The project in which the Job is located. If set to None or missing, the default
project_id from the Google Cloud connection is used. (templated)
:param job_id: A unique id for the Google MLEngine job. (templated)
:param interval: Time expressed in seconds after which the job status is checked again. (templated)
:raises: googleapiclient.errors.HttpError
"""
self.log.info("Waiting for job. job_id=%s", job_id)

if interval <= 0:
raise ValueError("Interval must be > 0")
while True:
job = self.hook.get_job(project_id, job_id)
if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
return job
time.sleep(interval)

def execute_complete(self, context: Context, event: dict[str, Any]):
    """
    Callback invoked when the trigger fires; returns immediately.

    Relies on the trigger to report failure via the ``event`` payload;
    otherwise execution is assumed to have been successful.
    """
    status = event["status"]
    message = event["message"]
    if status == "error":
        raise AirflowException(message)
    self.log.info(
        "%s completed with response %s ",
        self.task_id,
        message,
    )
    if self._project_id:
        MLEngineJobDetailsLink.persist(
            context=context,
            task_instance=self,
            project_id=self._project_id,
            job_id=self._job_id,
        )

def on_kill(self) -> None:
    """Cancel the submitted MLEngine job on task kill, when configured to do so."""
    if self.job_id and self.cancel_on_kill:
        self.hook.cancel_job(job_id=self.job_id, project_id=self._project_id)  # type: ignore[union-attr]
    else:
        # Fixed: the format string had three %s placeholders but only two
        # arguments, which made this logging call raise a formatting error.
        self.log.info("Skipping to cancel job: %s.%s", self._project_id, self.job_id)


class MLEngineTrainingCancelJobOperator(BaseOperator):
"""
Expand Down
Loading