Add option to wait for completion on the EmrCreateJobFlowOperator (#28827)

* Fix countdown logic to handle None (=infinite)

* Convert None to inf float for simpler code logic

---------

Co-authored-by: Niko Oliveira <[email protected]>
BasPH and o-nikolas authored Jan 30, 2023
1 parent 246d778 commit 5490102
Showing 3 changed files with 67 additions and 14 deletions.
9 changes: 9 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
@@ -179,6 +179,15 @@ def add_job_flow_steps(
)
return response["StepIds"]

def terminate_job_flow(self, job_flow_id: str) -> None:
"""
Terminate a given EMR cluster (job flow) by id. If TerminationProtected=True on the cluster,
termination will be unsuccessful.
:param job_flow_id: id of the job flow to terminate
"""
self.get_conn().terminate_job_flows(JobFlowIds=[job_flow_id])

def test_connection(self):
"""
Return failed state for test Amazon Elastic MapReduce Connection (untestable).
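As a usage note (not part of the diff): a minimal sketch of how the new hook method added above could be called. The connection id and the job flow id are placeholder values, not taken from this commit.

from airflow.providers.amazon.aws.hooks.emr import EmrHook

# terminate_job_flow() wraps the boto3 client's terminate_job_flows(JobFlowIds=[...]);
# per the docstring above, it will not succeed if the cluster has TerminationProtected=True.
hook = EmrHook(aws_conn_id="aws_default")
hook.terminate_job_flow(job_flow_id="j-EXAMPLE1234567")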
62 changes: 51 additions & 11 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -524,6 +524,12 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param job_flow_overrides: boto3 style arguments or reference to an arguments file
(must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
:param region_name: Region name passed to EmrHook
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
:param waiter_countdown: Max. seconds to wait for jobflow completion (only in combination with
wait_for_completion=True, None = no limit)
:param waiter_check_interval_seconds: Number of seconds between polling the jobflow state. Defaults to 60
seconds.
"""

template_fields: Sequence[str] = ("job_flow_overrides",)
@@ -539,42 +545,76 @@ def __init__(
emr_conn_id: str | None = "emr_default",
job_flow_overrides: str | dict[str, Any] | None = None,
region_name: str | None = None,
wait_for_completion: bool = False,
waiter_countdown: int | None = None,
waiter_check_interval_seconds: int = 60,
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.waiter_countdown = waiter_countdown
self.waiter_check_interval_seconds = waiter_check_interval_seconds

self._job_flow_id: str | None = None

def execute(self, context: Context) -> str:
emr = EmrHook(
@cached_property
def _emr_hook(self) -> EmrHook:
"""Create and return an EmrHook."""
return EmrHook(
aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name
)

def execute(self, context: Context) -> str | None:
self.log.info(
"Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s", self.aws_conn_id, self.emr_conn_id
"Creating job flow using aws_conn_id: %s, emr_conn_id: %s", self.aws_conn_id, self.emr_conn_id
)
if isinstance(self.job_flow_overrides, str):
job_flow_overrides: dict[str, Any] = ast.literal_eval(self.job_flow_overrides)
self.job_flow_overrides = job_flow_overrides
else:
job_flow_overrides = self.job_flow_overrides
response = emr.create_job_flow(job_flow_overrides)
response = self._emr_hook.create_job_flow(job_flow_overrides)

if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
raise AirflowException(f"JobFlow creation failed: {response}")
raise AirflowException(f"Job flow creation failed: {response}")
else:
job_flow_id = response["JobFlowId"]
self.log.info("JobFlow with id %s created", job_flow_id)
self._job_flow_id = response["JobFlowId"]
self.log.info("Job flow with id %s created", self._job_flow_id)
EmrClusterLink.persist(
context=context,
operator=self,
region_name=emr.conn_region_name,
aws_partition=emr.conn_partition,
job_flow_id=job_flow_id,
region_name=self._emr_hook.conn_region_name,
aws_partition=self._emr_hook.conn_partition,
job_flow_id=self._job_flow_id,
)
return job_flow_id

if self.wait_for_completion:
# Didn't use a boto-supplied waiter because those don't support waiting for WAITING state.
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#waiters
waiter(
get_state_callable=self._emr_hook.get_conn().describe_cluster,
get_state_args={"ClusterId": self._job_flow_id},
parse_response=["Cluster", "Status", "State"],
# Cluster will be in WAITING after finishing if KeepJobFlowAliveWhenNoSteps is True
desired_state={"WAITING", "TERMINATED"},
failure_states={"TERMINATED_WITH_ERRORS"},
object_type="job flow",
action="finished",
countdown=self.waiter_countdown,
check_interval_seconds=self.waiter_check_interval_seconds,
)

return self._job_flow_id

def on_kill(self) -> None:
"""Terminate job flow."""
if self._job_flow_id:
self.log.info("Terminating job flow %s", self._job_flow_id)
self._emr_hook.terminate_job_flow(self._job_flow_id)


class EmrModifyClusterOperator(BaseOperator):
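For context on the new operator parameters above, a hedged sketch of a DAG that opts into waiting for the job flow. Only the parameter names shown in the diff are taken from this commit; the DAG id, schedule, dates and the contents of JOB_FLOW_OVERRIDES are illustrative assumptions.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

# Placeholder cluster configuration; real configurations will differ.
JOB_FLOW_OVERRIDES = {
    "Name": "example-cluster",
    "ReleaseLabel": "emr-6.7.0",
    "Instances": {
        # With KeepJobFlowAliveWhenNoSteps=True the cluster settles in WAITING,
        # which the operator's waiter treats as a desired state.
        "KeepJobFlowAliveWhenNoSteps": True,
        "InstanceGroups": [
            {
                "Name": "Primary node",
                "Market": "ON_DEMAND",
                "InstanceRole": "MASTER",
                "InstanceType": "m5.xlarge",
                "InstanceCount": 1,
            }
        ],
    },
}

with DAG(dag_id="example_emr_create_job_flow", start_date=datetime(2023, 1, 1), schedule=None):
    create_job_flow = EmrCreateJobFlowOperator(
        task_id="create_job_flow",
        job_flow_overrides=JOB_FLOW_OVERRIDES,
        wait_for_completion=True,          # block until the cluster reaches WAITING or TERMINATED
        waiter_countdown=30 * 60,          # stop waiting after 30 minutes; None would mean no limit
        waiter_check_interval_seconds=60,  # poll describe_cluster once a minute
    )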
10 changes: 7 additions & 3 deletions airflow/providers/amazon/aws/utils/waiter.py
@@ -34,7 +34,7 @@ def waiter(
failure_states: set,
object_type: str,
action: str,
countdown: int = 25 * 60,
countdown: int | float | None = 25 * 60,
check_interval_seconds: int = 60,
) -> None:
"""
@@ -49,8 +49,8 @@
exception if any are reached before the desired_state
:param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
:param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
:param countdown: Total amount of time the waiter should wait for the desired state
before timing out (in seconds). Defaults to 25 * 60 seconds.
:param countdown: Number of seconds the waiter should wait for the desired state before timing out.
Defaults to 25 * 60 seconds. None = infinite.
:param check_interval_seconds: Number of seconds waiter should wait before attempting
to retry get_state_callable. Defaults to 60 seconds.
"""
@@ -60,6 +60,10 @@
break
if state in failure_states:
raise AirflowException(f"{object_type.title()} reached failure state {state}.")

if countdown is None:
countdown = float("inf")

if countdown > check_interval_seconds:
countdown -= check_interval_seconds
log.info("Waiting for %s to be %s.", object_type.lower(), action.lower())
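To illustrate the countdown change above, a toy sketch of calling waiter() with countdown=None, which the new code converts to float("inf") so the wait never times out. The fake describe-cluster callable and its sequence of states are invented for the example and are not part of this commit.

from airflow.providers.amazon.aws.utils.waiter import waiter

_states = iter(["STARTING", "RUNNING", "WAITING"])

def fake_describe_cluster(ClusterId):
    # Mimics the response shape drilled into by parse_response=["Cluster", "Status", "State"].
    return {"Cluster": {"Status": {"State": next(_states)}}}

waiter(
    get_state_callable=fake_describe_cluster,
    get_state_args={"ClusterId": "j-EXAMPLE1234567"},
    parse_response=["Cluster", "Status", "State"],
    desired_state={"WAITING", "TERMINATED"},
    failure_states={"TERMINATED_WITH_ERRORS"},
    object_type="job flow",
    action="finished",
    countdown=None,                # converted internally to float("inf"): no time limit
    check_interval_seconds=1,      # short interval so the sketch finishes quickly
)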
