[AIRFLOW-6206] Move and rename AWS batch operator [AIP-21] (#6764)
- conform to AIP-21
  - see https://issues.apache.org/jira/browse/AIRFLOW-4733
  - see https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-21%3A+Changes+in+import+paths
  - use airflow.providers.amazon.aws.operators.batch.AwsBatchOperator
  - deprecate airflow.contrib.operators.awsbatch_operator.AWSBatchOperator

- fix pylint for airflow/providers/amazon/aws/operators/batch.py
Darren Weber authored and ashb committed Dec 17, 2019
1 parent c4c635d commit 7502cad
Showing 7 changed files with 354 additions and 273 deletions.
1 change: 1 addition & 0 deletions UPDATING.md
@@ -148,6 +148,7 @@ Migrated are:
 | airflow.contrib.hooks.aws_sqs_hook.SQSHook | airflow.providers.amazon.aws.hooks.sqs.SQSHook |
 | airflow.contrib.hooks.aws_sns_hook.AwsSnsHook | airflow.providers.amazon.aws.hooks.sns.AwsSnsHook |
 | airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator | airflow.providers.amazon.aws.operators.athena.AWSAthenaOperator |
+| airflow.contrib.operators.awsbatch_operator.AWSBatchOperator | airflow.providers.amazon.aws.operators.batch.AwsBatchOperator |
 | airflow.contrib.operators.aws_sqs_publish_operator.SQSPublishOperator | airflow.providers.amazon.aws.operators.sqs.SQSPublishOperator |
 | airflow.contrib.operators.aws_sns_publish_operator.SnsPublishOperator | airflow.providers.amazon.aws.operators.sns.SnsPublishOperator |
 | airflow.contrib.sensors.aws_athena_sensor.AthenaSensor | airflow.providers.amazon.aws.sensors.athena.AthenaSensor |
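For downstream code, the migration is a one-line import change; the deprecated path keeps working but warns. A minimal sketch of a task on the new path (the task id, job name, queue, and job definition below are illustrative placeholders, not values from this commit):

```python
# Deprecated path, still importable but warns:
#   from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
# New path per AIP-21:
from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator

submit_batch_job = AwsBatchOperator(
    task_id="submit_batch_job",         # hypothetical task id
    job_name="example-job",             # templated field; hypothetical name
    job_definition="example-job-def",   # assumes this job definition exists in AWS Batch
    job_queue="example-queue",          # assumes this job queue exists in AWS Batch
    overrides={},                       # passed through to boto3 submit_job as containerOverrides
)
```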
279 changes: 17 additions & 262 deletions airflow/contrib/operators/awsbatch_operator.py
@@ -17,275 +17,30 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import sys
-from math import pow
-from random import randint
-from time import sleep
-from typing import Optional
-
-import botocore.exceptions
-import botocore.waiter
+"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.batch`."""

-from airflow.contrib.hooks.aws_hook import AwsHook
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.typing_compat import Protocol
-from airflow.utils.decorators import apply_defaults
+import warnings

+from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator

-class BatchProtocol(Protocol):
-    """
-    .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
-    """
-
-    def describe_jobs(self, jobs) -> dict:
-        ...
-
-    def get_waiter(self, x: str) -> botocore.waiter.Waiter:
-        ...
-
-    def submit_job(
-        self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides
-    ) -> dict:
-        ...
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.batch`.",
+    DeprecationWarning,
+    stacklevel=2,
+)

-    def terminate_job(self, jobId: str, reason: str) -> dict:
-        ...
-

-class AWSBatchOperator(BaseOperator):
+class AWSBatchOperator(AwsBatchOperator):
     """
-    Execute a job on AWS Batch Service
-
-    .. warning: the queue parameter was renamed to job_queue to segregate the
-        internal CeleryExecutor queue from the AWS Batch internal queue.
-
-    :param job_name: the name for the job that will run on AWS Batch (templated)
-    :type job_name: str
-    :param job_definition: the job definition name on AWS Batch
-    :type job_definition: str
-    :param job_queue: the queue name on AWS Batch
-    :type job_queue: str
-    :param overrides: the same parameter that boto3 will receive on
-        containerOverrides (templated)
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type overrides: dict
-    :param array_properties: the same parameter that boto3 will receive on
-        arrayProperties
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type array_properties: dict
-    :param parameters: the same parameter that boto3 will receive on
-        parameters (templated)
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type parameters: dict
-    :param max_retries: exponential backoff retries while waiter is not
-        merged, 4200 = 48 hours
-    :type max_retries: int
-    :param status_retries: number of retries to get job description (status), 10
-    :type status_retries: int
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
-    :type aws_conn_id: str
-    :param region_name: region name to use in AWS Hook.
-        Override the region_name in connection (if provided)
-    :type region_name: str
+    This class is deprecated. Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.
     """

-    MAX_RETRIES = 4200
-    STATUS_RETRIES = 10
-
-    ui_color = "#c3dae0"
-    client = None  # type: BatchProtocol
-    arn = None  # type: str
-    template_fields = (
-        "job_name",
-        "overrides",
-        "parameters",
-    )
-
-    @apply_defaults
-    def __init__(
-        self,
-        job_name,
-        job_definition,
-        job_queue,
-        overrides,
-        array_properties=None,
-        parameters=None,
-        max_retries=MAX_RETRIES,
-        status_retries=STATUS_RETRIES,
-        aws_conn_id=None,
-        region_name=None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.job_name = job_name
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.job_definition = job_definition
-        self.job_queue = job_queue
-        self.overrides = overrides
-        self.array_properties = array_properties or {}
-        self.parameters = parameters
-        self.max_retries = max_retries
-        self.status_retries = status_retries
-
-        self.jobId = None  # pylint: disable=invalid-name
-        self.jobName = None  # pylint: disable=invalid-name
-
-        self.hook = self.get_hook()
-
-    def execute(self, context):
-        self.log.info(
-            "Running AWS Batch Job - Job definition: %s - on queue %s",
-            self.job_definition,
-            self.job_queue,
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            """This class is deprecated.
+            Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.""",
+            DeprecationWarning,
+            stacklevel=2,
         )
-        self.log.info("AWSBatchOperator overrides: %s", self.overrides)
-
-        self.client = self.hook.get_client_type("batch", region_name=self.region_name)
-
-        try:
-
-            response = self.client.submit_job(
-                jobName=self.job_name,
-                jobQueue=self.job_queue,
-                jobDefinition=self.job_definition,
-                arrayProperties=self.array_properties,
-                parameters=self.parameters,
-                containerOverrides=self.overrides,
-            )
-
-            self.log.info("AWS Batch Job started: %s", response)
-
-            self.jobId = response["jobId"]
-            self.jobName = response["jobName"]
-
-            self._wait_for_task_ended()
-            self._check_success_task()
-
-            self.log.info("AWS Batch Job has been successfully executed: %s", response)
-        except Exception as e:
-            self.log.info("AWS Batch Job has failed executed")
-            raise AirflowException(e)
-
-    def _wait_for_task_ended(self):
-        """
-        Try to use a waiter from the below pull request
-
-        * https://github.com/boto/botocore/pull/1307
-
-        If the waiter is not available apply a exponential backoff
-
-        * docs.aws.amazon.com/general/latest/gr/api-retries.html
-        """
-        try:
-            waiter = self.client.get_waiter("job_execution_complete")
-            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
-            waiter.wait(jobs=[self.jobId])
-        except ValueError:
-            self._poll_for_task_ended()
-
-    def _poll_for_task_ended(self):
-        """
-        Poll for job status
-
-        * docs.aws.amazon.com/general/latest/gr/api-retries.html
-        """
-        # Allow a batch job some time to spin up. A random interval
-        # decreases the chances of exceeding an AWS API throttle
-        # limit when there are many concurrent tasks.
-        pause = randint(5, 30)
-
-        tries = 0
-        while tries < self.max_retries:
-            tries += 1
-            self.log.info(
-                "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
-                self.jobId,
-                tries,
-                self.max_retries,
-                pause,
-            )
-            sleep(pause)
-
-            response = self._get_job_description()
-            jobs = response.get("jobs")
-            status = jobs[-1]["status"]  # check last job status
-            self.log.info("AWS Batch job (%s) status: %s", self.jobId, status)
-
-            # status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
-            if status in ["SUCCEEDED", "FAILED"]:
-                break
-
-            pause = 1 + pow(tries * 0.3, 2)
-
-    def _get_job_description(self) -> Optional[dict]:
-        """
-        Get job description
-
-        * https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
-        * https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
-        """
-        tries = 0
-        while tries < self.status_retries:
-            tries += 1
-            try:
-                response = self.client.describe_jobs(jobs=[self.jobId])
-                if response and response.get("jobs"):
-                    return response
-                else:
-                    self.log.error("Job description has no jobs (%s): %s", self.jobId, response)
-            except botocore.exceptions.ClientError as err:
-                response = err.response
-                self.log.error("Job description error (%s): %s", self.jobId, response)
-
-            if tries < self.status_retries:
-                error = response.get("Error", {})
-                if error.get("Code") == "TooManyRequestsException":
-                    pause = randint(1, 10)  # avoid excess requests with a random pause
-                    self.log.info(
-                        "AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds",
-                        self.jobId,
-                        tries,
-                        self.status_retries,
-                        pause,
-                    )
-                    sleep(pause)
-                    continue
-
-        msg = "Failed to get job description ({})".format(self.jobId)
-        raise AirflowException(msg)
-
-    def _check_success_task(self):
-        """
-        Check the final status of the batch job; the job status options are:
-        'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
-        """
-        response = self._get_job_description()
-        jobs = response.get("jobs")
-
-        matching_jobs = [job for job in jobs if job["jobId"] == self.jobId]
-        if not matching_jobs:
-            raise AirflowException(
-                "Job ({}) has no job description {}".format(self.jobId, response)
-            )
-
-        job = matching_jobs[0]
-        self.log.info("AWS Batch stopped, check status: %s", job)
-
-        job_status = job["status"]
-        if job_status == "FAILED":
-            reason = job["statusReason"]
-            raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason))
-        elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
-            raise AirflowException(
-                "Job ({}) is still pending {}".format(self.jobId, job_status)
-            )
-
-    def get_hook(self):
-        return AwsHook(aws_conn_id=self.aws_conn_id)
-
-    def on_kill(self):
-        response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user")
-        self.log.info(response)
+        super().__init__(*args, **kwargs)
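With the shim in place, the old import path keeps working but raises a `DeprecationWarning`: once when the module body first runs (later imports hit Python's module cache and stay silent), and again each time `AWSBatchOperator` is instantiated, before it delegates to the provider class. A quick sketch of how the module-level warning surfaces, assuming a fresh interpreter with this version of Airflow installed:

```python
import warnings

with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    # First import executes the module body, including its warnings.warn() call.
    from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator  # noqa: F401

assert any(
    issubclass(w.category, DeprecationWarning)
    and "airflow.providers.amazon.aws.operators.batch" in str(w.message)
    for w in captured
)
```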
