From d5840267c4d22a78355cc9631614f8c2657c25d8 Mon Sep 17 00:00:00 2001
From: Bas Harenslak
Date: Mon, 28 Oct 2019 21:34:49 +0100
Subject: [PATCH 1/2] [AIRFLOW-5775] Migrate AWS Batch components to /providers/aws [AIP-21]

---
 UPDATING.md                                   |   1 +
 .../contrib/operators/awsbatch_operator.py    | 196 +---------------
 airflow/providers/aws/operators/batch.py      | 209 ++++++++++++++++++
 docs/operators-and-hooks-ref.rst              |   4 +-
 scripts/ci/pylint_todo.txt                    |   2 +-
 .../aws/operators/test_batch.py}              |   2 +-
 tests/test_core_to_contrib.py                 |   4 +
 7 files changed, 227 insertions(+), 191 deletions(-)
 create mode 100644 airflow/providers/aws/operators/batch.py
 rename tests/{contrib/operators/test_awsbatch_operator.py => providers/aws/operators/test_batch.py} (99%)

diff --git a/UPDATING.md b/UPDATING.md
index a09271f909f97f..6538a71745f1bd 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -52,6 +52,7 @@ Migrated are:
 |-----------------------------------------------------------------|----------------------------------------------------------|
 | airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook             | airflow.providers.aws.hooks.athena.AWSAthenaHook         |
 | airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator | airflow.providers.aws.operators.athena.AWSAthenaOperator |
+| airflow.contrib.operators.awsbatch_operator.AWSBatchOperator    | airflow.providers.aws.operators.batch.AWSBatchOperator   |
 | airflow.contrib.sensors.aws_athena_sensor.AthenaSensor          | airflow.providers.aws.sensors.athena.AthenaSensor        |
 
 ### Additional arguments passed to BaseOperator cause an exception
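For DAG authors, the table above boils down to a one-line import change. A hypothetical before/after sketch (the old path keeps working for now through the deprecation shim introduced in the next hunk):

    # Hypothetical DAG snippet; only the import path changes.
    # Before (deprecated, emits a DeprecationWarning after this patch):
    #   from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
    # After:
    from airflow.providers.aws.operators.batch import AWSBatchOperator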
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
index 22508d17c56a2a..aa6af273cec17c 100644
--- a/airflow/contrib/operators/awsbatch_operator.py
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -16,194 +16,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#
-import sys
-from math import pow
-from random import randint
-from time import sleep
-from typing import Optional
-
-from airflow.contrib.hooks.aws_hook import AwsHook
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.typing import Protocol
-from airflow.utils.decorators import apply_defaults
-
-
-class BatchProtocol(Protocol):
-    def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides):
-        ...
-
-    def get_waiter(self, x: str):
-        ...
-
-    def describe_jobs(self, jobs):
-        ...
-
-    def terminate_job(self, jobId: str, reason: str):
-        ...
-
-
-class AWSBatchOperator(BaseOperator):
-    """
-    Execute a job on AWS Batch Service
-
-    .. warning: the queue parameter was renamed to job_queue to segregate the
-        internal CeleryExecutor queue from the AWS Batch internal queue.
-
-    :param job_name: the name for the job that will run on AWS Batch (templated)
-    :type job_name: str
-    :param job_definition: the job definition name on AWS Batch
-    :type job_definition: str
-    :param job_queue: the queue name on AWS Batch
-    :type job_queue: str
-    :param overrides: the same parameter that boto3 will receive on
-        containerOverrides (templated):
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
-    :type overrides: dict
-    :param array_properties: the same parameter that boto3 will receive on
-        arrayProperties:
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
-    :type array_properties: dict
-    :param max_retries: exponential backoff retries while waiter is not
-        merged, 4200 = 48 hours
-    :type max_retries: int
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
-    :type aws_conn_id: str
-    :param region_name: region name to use in AWS Hook.
-        Override the region_name in connection (if provided)
-    :type region_name: str
-    """
-
-    ui_color = '#c3dae0'
-    client = None  # type: Optional[BatchProtocol]
-    arn = None  # type: Optional[str]
-    template_fields = ('job_name', 'overrides',)
-
-    @apply_defaults
-    def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
-                 max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
-        super().__init__(**kwargs)
-
-        self.job_name = job_name
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.job_definition = job_definition
-        self.job_queue = job_queue
-        self.overrides = overrides
-        self.array_properties = array_properties
-        self.max_retries = max_retries
-
-        self.jobId = None  # pylint: disable=invalid-name
-        self.jobName = None  # pylint: disable=invalid-name
-
-        self.hook = self.get_hook()
-
-    def execute(self, context):
-        self.log.info(
-            'Running AWS Batch Job - Job definition: %s - on queue %s',
-            self.job_definition, self.job_queue
-        )
-        self.log.info('AWSBatchOperator overrides: %s', self.overrides)
-
-        self.client = self.hook.get_client_type(
-            'batch',
-            region_name=self.region_name
-        )
-
-        try:
-            response = self.client.submit_job(
-                jobName=self.job_name,
-                jobQueue=self.job_queue,
-                jobDefinition=self.job_definition,
-                arrayProperties=self.array_properties,
-                containerOverrides=self.overrides)
-
-            self.log.info('AWS Batch Job started: %s', response)
-
-            self.jobId = response['jobId']
-            self.jobName = response['jobName']
-
-            self._wait_for_task_ended()
-
-            self._check_success_task()
-
-            self.log.info('AWS Batch Job has been successfully executed: %s', response)
-        except Exception as e:
-            self.log.info('AWS Batch Job has failed executed')
-            raise AirflowException(e)
-
-    def _wait_for_task_ended(self):
-        """
-        Try to use a waiter from the below pull request
-
-        * https://github.com/boto/botocore/pull/1307
-
-        If the waiter is not available apply a exponential backoff
-
-        * docs.aws.amazon.com/general/latest/gr/api-retries.html
-        """
-        try:
-            waiter = self.client.get_waiter('job_execution_complete')
-            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
-            waiter.wait(jobs=[self.jobId])
-        except ValueError:
-            # If waiter not available use expo
-
-            # Allow a batch job some time to spin up. A random interval
-            # decreases the chances of exceeding an AWS API throttle
-            # limit when there are many concurrent tasks.
-            pause = randint(5, 30)
-
-            retries = 1
-            while retries <= self.max_retries:
-                self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
-                              self.jobId, retries, self.max_retries, pause)
-                sleep(pause)
-
-                response = self.client.describe_jobs(jobs=[self.jobId])
-                status = response['jobs'][-1]['status']
-                self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
-                if status in ['SUCCEEDED', 'FAILED']:
-                    break
-
-                retries += 1
-                pause = 1 + pow(retries * 0.3, 2)
-
-    def _check_success_task(self):
-        response = self.client.describe_jobs(
-            jobs=[self.jobId],
-        )
-
-        self.log.info('AWS Batch stopped, check status: %s', response)
-        if len(response.get('jobs')) < 1:
-            raise AirflowException('No job found for {}'.format(response))
-        for job in response['jobs']:
-            job_status = job['status']
-            if job_status == 'FAILED':
-                reason = job['statusReason']
-                raise AirflowException('Job failed with status {}'.format(reason))
-            elif job_status in [
-                    'SUBMITTED',
-                    'PENDING',
-                    'RUNNABLE',
-                    'STARTING',
-                    'RUNNING'
-            ]:
-                raise AirflowException(
-                    'This task is still pending {}'.format(job_status))
+"""This module is deprecated. Please use `airflow.providers.aws.operators.batch`."""
 
-    def get_hook(self):
-        return AwsHook(
-            aws_conn_id=self.aws_conn_id
-        )
+import warnings
 
-    def on_kill(self):
-        response = self.client.terminate_job(
-            jobId=self.jobId,
-            reason='Task killed by the user')
+# pylint: disable=unused-import
+from airflow.providers.aws.operators.batch import BatchProtocol, AWSBatchOperator  # noqa
 
-        self.log.info(response)
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.aws.operators.batch`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
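The hunk above reduces the contrib module to a thin alias. A small illustrative check of the intended behaviour, run in a fresh interpreter since the warning fires only on first import of the shim module:

    import warnings

    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator as OldAWSBatchOperator

    from airflow.providers.aws.operators.batch import AWSBatchOperator

    # The shim re-exports the very same class object...
    assert OldAWSBatchOperator is AWSBatchOperator
    # ...and importing through the old path surfaces the deprecation notice.
    assert any(issubclass(w.category, DeprecationWarning) for w in captured)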
diff --git a/airflow/providers/aws/operators/batch.py b/airflow/providers/aws/operators/batch.py
new file mode 100644
index 00000000000000..22508d17c56a2a
--- /dev/null
+++ b/airflow/providers/aws/operators/batch.py
@@ -0,0 +1,209 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import sys
+from math import pow
+from random import randint
+from time import sleep
+from typing import Optional
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.typing import Protocol
+from airflow.utils.decorators import apply_defaults
+
+
+class BatchProtocol(Protocol):
+    def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides):
+        ...
+
+    def get_waiter(self, x: str):
+        ...
+
+    def describe_jobs(self, jobs):
+        ...
+
+    def terminate_job(self, jobId: str, reason: str):
+        ...
+
+
+class AWSBatchOperator(BaseOperator):
+    """
+    Execute a job on AWS Batch Service
+
+    .. warning:: the queue parameter was renamed to job_queue to segregate the
+        internal CeleryExecutor queue from the AWS Batch internal queue.
+
+    :param job_name: the name for the job that will run on AWS Batch (templated)
+    :type job_name: str
+    :param job_definition: the job definition name on AWS Batch
+    :type job_definition: str
+    :param job_queue: the queue name on AWS Batch
+    :type job_queue: str
+    :param overrides: the same parameter that boto3 will receive on
+        containerOverrides (templated):
+        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+    :type overrides: dict
+    :param array_properties: the same parameter that boto3 will receive on
+        arrayProperties:
+        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+    :type array_properties: dict
+    :param max_retries: exponential backoff retries while the waiter is not
+        merged, 4200 = 48 hours
+    :type max_retries: int
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+        the default boto3 credential strategy will be used
+        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+    :type aws_conn_id: str
+    :param region_name: region name to use in AWS Hook.
+        Overrides the region_name in connection (if provided)
+    :type region_name: str
+    """
+
+    ui_color = '#c3dae0'
+    client = None  # type: Optional[BatchProtocol]
+    arn = None  # type: Optional[str]
+    template_fields = ('job_name', 'overrides',)
+
+    @apply_defaults
+    def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
+                 max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.job_name = job_name
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.job_definition = job_definition
+        self.job_queue = job_queue
+        self.overrides = overrides
+        self.array_properties = array_properties
+        self.max_retries = max_retries
+
+        self.jobId = None  # pylint: disable=invalid-name
+        self.jobName = None  # pylint: disable=invalid-name
+
+        self.hook = self.get_hook()
+
+    def execute(self, context):
+        self.log.info(
+            'Running AWS Batch Job - Job definition: %s - on queue %s',
+            self.job_definition, self.job_queue
+        )
+        self.log.info('AWSBatchOperator overrides: %s', self.overrides)
+
+        self.client = self.hook.get_client_type(
+            'batch',
+            region_name=self.region_name
+        )
+
+        try:
+            response = self.client.submit_job(
+                jobName=self.job_name,
+                jobQueue=self.job_queue,
+                jobDefinition=self.job_definition,
+                arrayProperties=self.array_properties,
+                containerOverrides=self.overrides)
+
+            self.log.info('AWS Batch Job started: %s', response)
+
+            self.jobId = response['jobId']
+            self.jobName = response['jobName']
+
+            self._wait_for_task_ended()
+
+            self._check_success_task()
+
+            self.log.info('AWS Batch Job has been successfully executed: %s', response)
+        except Exception as e:
+            self.log.info('AWS Batch Job execution failed')
+            raise AirflowException(e)
+
+    def _wait_for_task_ended(self):
+        """
+        Try to use a waiter from the pull request below
+
+        * https://github.com/boto/botocore/pull/1307
+
+        If the waiter is not available, apply an exponential backoff
+
+        * docs.aws.amazon.com/general/latest/gr/api-retries.html
+        """
+        try:
+            waiter = self.client.get_waiter('job_execution_complete')
+            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
+            waiter.wait(jobs=[self.jobId])
+        except ValueError:
+            # If the waiter is not available, fall back to polling with backoff.
+
+            # Allow a batch job some time to spin up. A random interval
+            # decreases the chances of exceeding an AWS API throttle
+            # limit when there are many concurrent tasks.
+            pause = randint(5, 30)
+
+            retries = 1
+            while retries <= self.max_retries:
+                self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
+                              self.jobId, retries, self.max_retries, pause)
+                sleep(pause)
+
+                response = self.client.describe_jobs(jobs=[self.jobId])
+                status = response['jobs'][-1]['status']
+                self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
+                if status in ['SUCCEEDED', 'FAILED']:
+                    break
+
+                retries += 1
+                pause = 1 + pow(retries * 0.3, 2)
+
+    def _check_success_task(self):
+        response = self.client.describe_jobs(
+            jobs=[self.jobId],
+        )
+
+        self.log.info('AWS Batch stopped, check status: %s', response)
+        if len(response.get('jobs')) < 1:
+            raise AirflowException('No job found for {}'.format(response))
+
+        for job in response['jobs']:
+            job_status = job['status']
+            if job_status == 'FAILED':
+                reason = job['statusReason']
+                raise AirflowException('Job failed with status {}'.format(reason))
+            elif job_status in [
+                    'SUBMITTED',
+                    'PENDING',
+                    'RUNNABLE',
+                    'STARTING',
+                    'RUNNING'
+            ]:
+                raise AirflowException(
+                    'This task is still pending {}'.format(job_status))
+
+    def get_hook(self):
+        return AwsHook(
+            aws_conn_id=self.aws_conn_id
+        )
+
+    def on_kill(self):
+        response = self.client.terminate_job(
+            jobId=self.jobId,
+            reason='Task killed by the user')
+
+        self.log.info(response)
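With the module in place under airflow/providers, a minimal usage sketch follows; the connection id, job definition, queue, and region below are placeholders, not values taken from this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.aws.operators.batch import AWSBatchOperator

    with DAG(dag_id="example_aws_batch",
             start_date=datetime(2019, 10, 1),
             schedule_interval=None) as dag:
        submit = AWSBatchOperator(
            task_id="submit_batch_job",
            job_name="example-job",               # placeholder
            job_definition="example-job-def:1",   # placeholder
            job_queue="example-queue",            # placeholder
            overrides={"command": ["echo", "hello world"]},
            aws_conn_id="aws_default",            # assumes a configured AWS connection
            region_name="eu-west-1",              # placeholder
        )

The operator submits the job in execute(), polls describe_jobs() until a terminal state, and raises AirflowException on FAILED, so the Airflow task state tracks the Batch job outcome.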
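The mapping added above is what keeps the alias honest. Illustrative only, since the real file drives parametrized assertions from these (new, old) pairs, but the core check amounts to:

    import importlib

    NEW_PATH = "airflow.providers.aws.operators.batch.AWSBatchOperator"
    OLD_PATH = "airflow.contrib.operators.awsbatch_operator.AWSBatchOperator"

    def resolve(dotted_path):
        # Split "pkg.module.ClassName" into module and attribute, then import.
        module_name, _, attr = dotted_path.rpartition(".")
        return getattr(importlib.import_module(module_name), attr)

    # Old and new dotted paths must resolve to the same class object.
    assert resolve(OLD_PATH) is resolve(NEW_PATH)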
From 1e28da15b6bcaa6d241080a53304b3ac3c516012 Mon Sep 17 00:00:00 2001
From: Bas Harenslak
Date: Tue, 29 Oct 2019 09:24:33 +0100
Subject: [PATCH 2/2] isort awsbatch_operator.py

---
 airflow/contrib/operators/awsbatch_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
index aa6af273cec17c..b0955d704d784f 100644
--- a/airflow/contrib/operators/awsbatch_operator.py
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -22,7 +22,7 @@
 import warnings
 
 # pylint: disable=unused-import
-from airflow.providers.aws.operators.batch import BatchProtocol, AWSBatchOperator  # noqa
+from airflow.providers.aws.operators.batch import AWSBatchOperator, BatchProtocol  # noqa
 
 warnings.warn(
     "This module is deprecated. Please use `airflow.providers.aws.operators.batch`.",
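One behavioural detail of the migrated code worth calling out: despite the docstring's mention of exponential backoff, the fallback polling loop in _wait_for_task_ended grows its pause quadratically. A quick sketch of the schedule it produces (the first pause is randint(5, 30) seconds):

    from math import pow

    # Pause after each status check, as computed in _wait_for_task_ended.
    pauses = {retries: 1 + pow(retries * 0.3, 2) for retries in (2, 5, 10, 50, 100)}
    # approximately {2: 1.36, 5: 3.25, 10: 10.0, 50: 226.0, 100: 901.0} seconds

So early checks are frequent and later ones increasingly sparse; max_retries=4200 acts as an upper bound that a normally terminating job should never reach.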