From 7502cad2844139d57e4276d971c0706a361d9dbe Mon Sep 17 00:00:00 2001
From: Darren Weber
Date: Tue, 17 Dec 2019 08:15:47 -0800
Subject: [PATCH] [AIRFLOW-6206] Move and rename AWS batch operator [AIP-21] (#6764)

- conform to AIP-21
  - see https://issues.apache.org/jira/browse/AIRFLOW-4733
  - see https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-21%3A+Changes+in+import+paths
- use airflow.providers.amazon.aws.operators.batch.AwsBatchOperator
- deprecate airflow.contrib.operators.awsbatch_operator.AWSBatchOperator
- fix pylint for airflow/providers/amazon/aws/operators/batch.py
---
 UPDATING.md                                  |   1 +
 .../contrib/operators/awsbatch_operator.py   | 279 +--------------
 .../providers/amazon/aws/operators/batch.py  | 322 ++++++++++++++++++
 docs/operators-and-hooks-ref.rst             |   2 +-
 scripts/ci/pylint_todo.txt                   |   1 -
 .../amazon/aws/operators/test_batch.py}      |  18 +-
 tests/test_core_to_contrib.py                |   4 +
 7 files changed, 354 insertions(+), 273 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/operators/batch.py
 rename tests/{contrib/operators/test_awsbatch_operator.py => providers/amazon/aws/operators/test_batch.py} (94%)

diff --git a/UPDATING.md b/UPDATING.md
index c2e43adef79759..a43ef1d4b4d880 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -148,6 +148,7 @@ Migrated are:
 | airflow.contrib.hooks.aws_sqs_hook.SQSHook | airflow.providers.amazon.aws.hooks.sqs.SQSHook |
 | airflow.contrib.hooks.aws_sns_hook.AwsSnsHook | airflow.providers.amazon.aws.hooks.sns.AwsSnsHook |
 | airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator | airflow.providers.amazon.aws.operators.athena.AWSAthenaOperator |
+| airflow.contrib.operators.awsbatch_operator.AWSBatchOperator | airflow.providers.amazon.aws.operators.batch.AwsBatchOperator |
 | airflow.contrib.operators.aws_sqs_publish_operator.SQSPublishOperator | airflow.providers.amazon.aws.operators.sqs.SQSPublishOperator |
 | airflow.contrib.operators.aws_sns_publish_operator.SnsPublishOperator | airflow.providers.amazon.aws.operators.sns.SnsPublishOperator |
 | airflow.contrib.sensors.aws_athena_sensor.AthenaSensor | airflow.providers.amazon.aws.sensors.athena.AthenaSensor |
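The table row above maps the old contrib import path to the new providers path. A quick way to sanity-check the mapping after applying this patch is to import both names and confirm the old one is just a deprecated alias. This is a minimal sketch, assuming the patch is installed; note the module-level DeprecationWarning only fires on the first import of the shim module:

    import warnings

    from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # first import of the shim module triggers the module-level warning
        from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator

    # the old name subclasses the new implementation, so isinstance checks
    # written against either name keep working
    assert issubclass(AWSBatchOperator, AwsBatchOperator)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)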
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
index 3c25d319acf463..c77d62001b9ace 100644
--- a/airflow/contrib/operators/awsbatch_operator.py
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -17,275 +17,30 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import sys
-from math import pow
-from random import randint
-from time import sleep
-from typing import Optional
-
-import botocore.exceptions
-import botocore.waiter
+"""This module is deprecated. Please use `airflow.providers.amazon.aws.operators.batch`."""
 
-from airflow.contrib.hooks.aws_hook import AwsHook
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.typing_compat import Protocol
-from airflow.utils.decorators import apply_defaults
+import warnings
 
+from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator
 
-class BatchProtocol(Protocol):
-    """
-    .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
-    """
-
-    def describe_jobs(self, jobs) -> dict:
-        ...
-
-    def get_waiter(self, x: str) -> botocore.waiter.Waiter:
-        ...
-
-    def submit_job(
-        self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides
-    ) -> dict:
-        ...
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.batch`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-    def terminate_job(self, jobId: str, reason: str) -> dict:
-        ...
-
 
-class AWSBatchOperator(BaseOperator):
+class AWSBatchOperator(AwsBatchOperator):
     """
-    Execute a job on AWS Batch Service
-
-    .. warning: the queue parameter was renamed to job_queue to segregate the
-    internal CeleryExecutor queue from the AWS Batch internal queue.
-
-    :param job_name: the name for the job that will run on AWS Batch (templated)
-    :type job_name: str
-    :param job_definition: the job definition name on AWS Batch
-    :type job_definition: str
-    :param job_queue: the queue name on AWS Batch
-    :type job_queue: str
-    :param overrides: the same parameter that boto3 will receive on
-        containerOverrides (templated)
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type overrides: dict
-    :param array_properties: the same parameter that boto3 will receive on
-        arrayProperties
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type array_properties: dict
-    :param parameters: the same parameter that boto3 will receive on
-        parameters (templated)
-        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
-    :type parameters: dict
-    :param max_retries: exponential backoff retries while waiter is not
-        merged, 4200 = 48 hours
-    :type max_retries: int
-    :param status_retries: number of retries to get job description (status), 10
-    :type status_retries: int
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
-    :type aws_conn_id: str
-    :param region_name: region name to use in AWS Hook.
-        Override the region_name in connection (if provided)
-    :type region_name: str
+    This class is deprecated. Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.
     """
 
-    MAX_RETRIES = 4200
-    STATUS_RETRIES = 10
-
-    ui_color = "#c3dae0"
-    client = None  # type: BatchProtocol
-    arn = None  # type: str
-    template_fields = (
-        "job_name",
-        "overrides",
-        "parameters",
-    )
-
-    @apply_defaults
-    def __init__(
-        self,
-        job_name,
-        job_definition,
-        job_queue,
-        overrides,
-        array_properties=None,
-        parameters=None,
-        max_retries=MAX_RETRIES,
-        status_retries=STATUS_RETRIES,
-        aws_conn_id=None,
-        region_name=None,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.job_name = job_name
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.job_definition = job_definition
-        self.job_queue = job_queue
-        self.overrides = overrides
-        self.array_properties = array_properties or {}
-        self.parameters = parameters
-        self.max_retries = max_retries
-        self.status_retries = status_retries
-
-        self.jobId = None  # pylint: disable=invalid-name
-        self.jobName = None  # pylint: disable=invalid-name
-
-        self.hook = self.get_hook()
-
-    def execute(self, context):
-        self.log.info(
-            "Running AWS Batch Job - Job definition: %s - on queue %s",
-            self.job_definition,
-            self.job_queue,
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            """This class is deprecated.
+ Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.""", + DeprecationWarning, + stacklevel=2, ) - self.log.info("AWSBatchOperator overrides: %s", self.overrides) - - self.client = self.hook.get_client_type("batch", region_name=self.region_name) - - try: - - response = self.client.submit_job( - jobName=self.job_name, - jobQueue=self.job_queue, - jobDefinition=self.job_definition, - arrayProperties=self.array_properties, - parameters=self.parameters, - containerOverrides=self.overrides, - ) - - self.log.info("AWS Batch Job started: %s", response) - - self.jobId = response["jobId"] - self.jobName = response["jobName"] - - self._wait_for_task_ended() - self._check_success_task() - - self.log.info("AWS Batch Job has been successfully executed: %s", response) - except Exception as e: - self.log.info("AWS Batch Job has failed executed") - raise AirflowException(e) - - def _wait_for_task_ended(self): - """ - Try to use a waiter from the below pull request - - * https://github.com/boto/botocore/pull/1307 - - If the waiter is not available apply a exponential backoff - - * docs.aws.amazon.com/general/latest/gr/api-retries.html - """ - try: - waiter = self.client.get_waiter("job_execution_complete") - waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow - waiter.wait(jobs=[self.jobId]) - except ValueError: - self._poll_for_task_ended() - - def _poll_for_task_ended(self): - """ - Poll for job status - - * docs.aws.amazon.com/general/latest/gr/api-retries.html - """ - # Allow a batch job some time to spin up. A random interval - # decreases the chances of exceeding an AWS API throttle - # limit when there are many concurrent tasks. - pause = randint(5, 30) - - tries = 0 - while tries < self.max_retries: - tries += 1 - self.log.info( - "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", - self.jobId, - tries, - self.max_retries, - pause, - ) - sleep(pause) - - response = self._get_job_description() - jobs = response.get("jobs") - status = jobs[-1]["status"] # check last job status - self.log.info("AWS Batch job (%s) status: %s", self.jobId, status) - - # status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' - if status in ["SUCCEEDED", "FAILED"]: - break - - pause = 1 + pow(tries * 0.3, 2) - - def _get_job_description(self) -> Optional[dict]: - """ - Get job description - - * https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html - * https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html - """ - tries = 0 - while tries < self.status_retries: - tries += 1 - try: - response = self.client.describe_jobs(jobs=[self.jobId]) - if response and response.get("jobs"): - return response - else: - self.log.error("Job description has no jobs (%s): %s", self.jobId, response) - except botocore.exceptions.ClientError as err: - response = err.response - self.log.error("Job description error (%s): %s", self.jobId, response) - if tries < self.status_retries: - error = response.get("Error", {}) - if error.get("Code") == "TooManyRequestsException": - pause = randint(1, 10) # avoid excess requests with a random pause - self.log.info( - "AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds", - self.jobId, - tries, - self.status_retries, - pause, - ) - sleep(pause) - continue - - msg = "Failed to get job description ({})".format(self.jobId) - raise AirflowException(msg) - - def _check_success_task(self): - """ - Check the final status of the batch job; the job 
status options are: - 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' - """ - response = self._get_job_description() - jobs = response.get("jobs") - - matching_jobs = [job for job in jobs if job["jobId"] == self.jobId] - if not matching_jobs: - raise AirflowException( - "Job ({}) has no job description {}".format(self.jobId, response) - ) - - job = matching_jobs[0] - self.log.info("AWS Batch stopped, check status: %s", job) - job_status = job["status"] - if job_status == "FAILED": - reason = job["statusReason"] - raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason)) - elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]: - raise AirflowException( - "Job ({}) is still pending {}".format(self.jobId, job_status) - ) - - def get_hook(self): - return AwsHook(aws_conn_id=self.aws_conn_id) - - def on_kill(self): - response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user") - self.log.info(response) + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py new file mode 100644 index 00000000000000..49468f33f12d0e --- /dev/null +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Airflow operator for AWS batch service + +.. seealso:: http://boto3.readthedocs.io/en/latest/reference/services/batch.html +""" + +import sys +from random import randint +from time import sleep +from typing import Optional + +import botocore.exceptions +import botocore.waiter + +from airflow.contrib.hooks.aws_hook import AwsHook +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.typing_compat import Protocol +from airflow.utils.decorators import apply_defaults + +# pylint: disable=invalid-name, unused-argument + + +class BatchProtocol(Protocol): + """ + .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html + """ + + def describe_jobs(self, jobs) -> dict: + """Get job descriptions from AWS batch""" + ... + + def get_waiter(self, x: str) -> botocore.waiter.Waiter: + """Get an AWS service waiter + + Note that AWS batch might not have any waiters (until botocore PR1307 is merged and released). + + .. code-block:: python + + import boto3 + boto3.client('batch').waiter_names == [] + + .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters + .. seealso:: https://github.com/boto/botocore/pull/1307 + """ + ... 
+ + def submit_job( + self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides + ) -> dict: + """Submit a batch job + :type jobName: str + :type jobQueue: str + :type jobDefinition: str + :type arrayProperties: dict + :type parameters: dict + :type containerOverrides: dict + """ + ... + + def terminate_job(self, jobId: str, reason: str) -> dict: + """Terminate a batch job""" + ... + + +class AwsBatchOperator(BaseOperator): + """ + Execute a job on AWS Batch Service + + .. warning: the queue parameter was renamed to job_queue to segregate the + internal CeleryExecutor queue from the AWS Batch internal queue. + + :param job_name: the name for the job that will run on AWS Batch (templated) + :type job_name: str + :param job_definition: the job definition name on AWS Batch + :type job_definition: str + :param job_queue: the queue name on AWS Batch + :type job_queue: str + :param overrides: the same parameter that boto3 will receive on + containerOverrides (templated) + http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job + :type overrides: dict + :param array_properties: the same parameter that boto3 will receive on + arrayProperties + http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job + :type array_properties: dict + :param parameters: the same parameter that boto3 will receive on + parameters (templated) + http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job + :type parameters: dict + :param max_retries: exponential backoff retries while waiter is not + merged, 4200 = 48 hours + :type max_retries: int + :param status_retries: number of retries to get job description (status), 10 + :type status_retries: int + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used + (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + :type aws_conn_id: str + :param region_name: region name to use in AWS Hook. 
+ Override the region_name in connection (if provided) + :type region_name: str + """ + + MAX_RETRIES = 4200 + STATUS_RETRIES = 10 + + ui_color = "#c3dae0" + client = None # type: BatchProtocol + arn = None # type: str + template_fields = ( + "job_name", + "overrides", + "parameters", + ) + + @apply_defaults + def __init__( + self, + job_name, + job_definition, + job_queue, + overrides, + array_properties=None, + parameters=None, + max_retries=None, + status_retries=None, + aws_conn_id=None, + region_name=None, + **kwargs, + ): # pylint: disable=too-many-arguments + super().__init__(**kwargs) + + self.job_name = job_name + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.job_definition = job_definition + self.job_queue = job_queue + self.overrides = overrides + self.array_properties = array_properties or {} + self.parameters = parameters + self.max_retries = max_retries or self.MAX_RETRIES + self.status_retries = status_retries or self.STATUS_RETRIES + + self.jobId = None # pylint: disable=invalid-name + self.jobName = None # pylint: disable=invalid-name + + self.hook = self.get_hook() + + def execute(self, context): + self.log.info( + "Running AWS Batch Job - job definition: %s - on queue %s", + self.job_definition, + self.job_queue, + ) + self.log.info("AWS Batch Job - container overrides: %s", self.overrides) + + self.client = self.hook.get_client_type("batch", region_name=self.region_name) + + try: + response = self.client.submit_job( + jobName=self.job_name, + jobQueue=self.job_queue, + jobDefinition=self.job_definition, + arrayProperties=self.array_properties, + parameters=self.parameters, + containerOverrides=self.overrides, + ) + + self.log.info("AWS Batch Job started: %s", response) + + self.jobId = response["jobId"] + self.jobName = response["jobName"] + + self._wait_for_task_ended() + self._check_success_task() + + self.log.info("AWS Batch Job has been successfully executed: %s", response) + + except Exception as e: + self.log.info("AWS Batch Job has failed") + raise AirflowException(e) + + def _wait_for_task_ended(self): + """ + Try to use a waiter from the below pull request + + * https://github.com/boto/botocore/pull/1307 + + If the waiter is not available apply a exponential backoff + + * docs.aws.amazon.com/general/latest/gr/api-retries.html + """ + try: + waiter = self.client.get_waiter("job_execution_complete") + waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow + waiter.wait(jobs=[self.jobId]) + except ValueError: + self._poll_for_task_ended() + + def _poll_for_task_ended(self): + """ + Poll for job status + + * docs.aws.amazon.com/general/latest/gr/api-retries.html + """ + # Allow a batch job some time to spin up. A random interval + # decreases the chances of exceeding an AWS API throttle + # limit when there are many concurrent tasks. 
+ pause = randint(5, 30) + + tries = 0 + while tries < self.max_retries: + tries += 1 + self.log.info( + "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", + self.jobId, + tries, + self.max_retries, + pause, + ) + sleep(pause) + + response = self._get_job_description() + jobs = response.get("jobs") + status = jobs[-1]["status"] # check last job status + self.log.info("AWS Batch job (%s) status: %s", self.jobId, status) + + # status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' + if status in ["SUCCEEDED", "FAILED"]: + break + + pause = 1 + pow(tries * 0.3, 2) + + def _get_job_description(self) -> Optional[dict]: + """ + Get job description + + * https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html + * https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html + """ + tries = 0 + while tries < self.status_retries: + tries += 1 + try: + response = self.client.describe_jobs(jobs=[self.jobId]) + if response and response.get("jobs"): + return response + else: + self.log.error("Job description has no jobs (%s): %s", self.jobId, response) + except botocore.exceptions.ClientError as err: + response = err.response + self.log.error("Job description error (%s): %s", self.jobId, response) + if tries < self.status_retries: + error = response.get("Error", {}) + if error.get("Code") == "TooManyRequestsException": + pause = randint(1, 10) # avoid excess requests with a random pause + self.log.info( + "AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds", + self.jobId, + tries, + self.status_retries, + pause, + ) + sleep(pause) + continue + + msg = "Failed to get job description ({})".format(self.jobId) + raise AirflowException(msg) + + def _check_success_task(self): + """ + Check the final status of the batch job; the job status options are: + 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' + """ + response = self._get_job_description() + jobs = response.get("jobs") + + matching_jobs = [job for job in jobs if job["jobId"] == self.jobId] + if not matching_jobs: + raise AirflowException( + "Job ({}) has no job description {}".format(self.jobId, response) + ) + + job = matching_jobs[0] + self.log.info("AWS Batch stopped, check status: %s", job) + job_status = job["status"] + if job_status == "FAILED": + reason = job["statusReason"] + raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason)) + elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]: + raise AirflowException( + "Job ({}) is still pending {}".format(self.jobId, job_status) + ) + + def get_hook(self): + """Get an AWS API client (boto3)""" + return AwsHook(aws_conn_id=self.aws_conn_id) + + def on_kill(self): + response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user") + self.log.info(response) diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 50e82c232dad10..2d2832c4fa8329 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -302,7 +302,7 @@ These integrations allow you to perform various operations within the Amazon Web * - `AWS Batch `__ - - - :mod:`airflow.contrib.operators.awsbatch_operator` + - :mod:`airflow.providers.amazon.aws.operators.batch` - * - `AWS DataSync `__ diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt index cb4b05ac7efaf8..4efa91d278b2dc 100644 --- a/scripts/ci/pylint_todo.txt +++ b/scripts/ci/pylint_todo.txt 
@@ -31,7 +31,6 @@
 ./airflow/contrib/hooks/vertica_hook.py
 ./airflow/contrib/hooks/wasb_hook.py
 ./airflow/contrib/operators/adls_list_operator.py
-./airflow/contrib/operators/awsbatch_operator.py
 ./airflow/contrib/operators/azure_container_instances_operator.py
 ./airflow/contrib/operators/azure_cosmos_operator.py
 ./airflow/contrib/operators/dingding_operator.py
diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/providers/amazon/aws/operators/test_batch.py
similarity index 94%
rename from tests/contrib/operators/test_awsbatch_operator.py
rename to tests/providers/amazon/aws/operators/test_batch.py
index 3cccfa7beadf68..0eec05d0d95cda 100644
--- a/tests/contrib/operators/test_awsbatch_operator.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -23,8 +23,8 @@
 
 import botocore.exceptions
 
-from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
 from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator
 from tests.compat import mock
 
 JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
@@ -36,15 +36,15 @@
 }
 
 
-class TestAWSBatchOperator(unittest.TestCase):
+class TestAwsBatchOperator(unittest.TestCase):
 
     MAX_RETRIES = 2
     STATUS_RETRIES = 3
 
-    @mock.patch("airflow.contrib.operators.awsbatch_operator.AwsHook")
+    @mock.patch("airflow.providers.amazon.aws.operators.batch.AwsHook")
     def setUp(self, aws_hook_mock):
         self.aws_hook_mock = aws_hook_mock
-        self.batch = AWSBatchOperator(
+        self.batch = AwsBatchOperator(
             task_id="task",
             job_name=JOB_NAME,
             job_queue="queue",
@@ -76,8 +76,8 @@ def test_init(self):
     def test_template_fields_overrides(self):
         self.assertEqual(self.batch.template_fields, ("job_name", "overrides", "parameters",))
 
-    @mock.patch.object(AWSBatchOperator, "_wait_for_task_ended")
-    @mock.patch.object(AWSBatchOperator, "_check_success_task")
+    @mock.patch.object(AwsBatchOperator, "_wait_for_task_ended")
+    @mock.patch.object(AwsBatchOperator, "_check_success_task")
     def test_execute_without_failures(self, check_mock, wait_mock):
         client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
         client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
@@ -130,7 +130,7 @@ def test_wait_end_tasks(self):
         client_mock.get_waiter.return_value.wait.assert_called_once_with(jobs=[JOB_ID])
         self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
 
-    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
+    @mock.patch("airflow.providers.amazon.aws.operators.batch.randint")
    def test_poll_job_status_success(self, mock_randint):
         client_mock = mock.Mock()
         self.batch.jobId = JOB_ID
@@ -146,7 +146,7 @@ def test_poll_job_status_success(self, mock_randint):
 
         client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
 
-    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
+    @mock.patch("airflow.providers.amazon.aws.operators.batch.randint")
     def test_poll_job_status_running(self, mock_randint):
         client_mock = mock.Mock()
         self.batch.jobId = JOB_ID
@@ -164,7 +164,7 @@ def test_poll_job_status_running(self, mock_randint):
         client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
         self.assertEqual(client_mock.describe_jobs.call_count, self.MAX_RETRIES)
 
-    @mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
+    @mock.patch("airflow.providers.amazon.aws.operators.batch.randint")
     def test_poll_job_status_hit_api_throttle(self, mock_randint):
         client_mock = mock.Mock()
         self.batch.jobId = JOB_ID
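One subtlety in the renamed test module above: mock.patch targets the location where a name is looked up, so moving the AwsHook import into the new module means every patch-target string must move with it. A minimal illustration, assuming the patched operator is importable (the test body is a sketch, not a test from this patch):

    from unittest import mock

    from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator

    # patch where the operator *uses* AwsHook, not where AwsHook is defined
    @mock.patch("airflow.providers.amazon.aws.operators.batch.AwsHook")
    def test_hook_is_mocked(aws_hook_mock):
        op = AwsBatchOperator(
            task_id="task",
            job_name="job",
            job_definition="job-def",
            job_queue="queue",
            overrides={},
        )
        assert op.hook is aws_hook_mock.return_value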
diff --git a/tests/test_core_to_contrib.py b/tests/test_core_to_contrib.py
index 6b2aed72b49ca2..6027707038d04d 100644
--- a/tests/test_core_to_contrib.py
+++ b/tests/test_core_to_contrib.py
@@ -770,6 +770,10 @@
         "airflow.providers.amazon.aws.operators.athena.AWSAthenaOperator",
         "airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator",
     ),
+    (
+        "airflow.providers.amazon.aws.operators.batch.AwsBatchOperator",
+        "airflow.contrib.operators.awsbatch_operator.AWSBatchOperator",
+    ),
     (
         "airflow.providers.amazon.aws.operators.sqs.SQSPublishOperator",
         "airflow.contrib.operators.aws_sqs_publish_operator.SQSPublishOperator",
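For reference, the moved operator is used exactly as before; only the import path changes. A hedged usage sketch follows; the DAG id, dates, and job names are placeholders, not values from this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator

    with DAG(dag_id="example_aws_batch", start_date=datetime(2019, 12, 1),
             schedule_interval=None) as dag:
        submit = AwsBatchOperator(
            task_id="submit_batch_job",
            job_name="example-job",            # templated field
            job_definition="example-job-def",
            job_queue="example-queue",
            overrides={"command": ["echo", "hello"]},  # boto3 containerOverrides
            aws_conn_id="aws_default",
        )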
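BatchProtocol in the new module exists because boto3 builds its clients dynamically at runtime, which leaves static checkers nothing to verify against; typing.Protocol supplies a structural interface instead. A generic sketch of the idea (the names below are illustrative, not from the patch; Airflow itself imports Protocol via airflow.typing_compat for pre-3.8 Pythons):

    from typing import Protocol  # Python 3.8+


    class SubmitsJobs(Protocol):
        def submit_job(self, jobName: str, jobQueue: str, jobDefinition: str,
                       arrayProperties: dict, parameters: dict,
                       containerOverrides: dict) -> dict: ...


    def submit(client: SubmitsJobs) -> dict:
        # a type checker accepts any object with a matching submit_job method,
        # including the dynamically generated boto3 "batch" client
        return client.submit_job(
            jobName="j", jobQueue="q", jobDefinition="jd",
            arrayProperties={}, parameters={}, containerOverrides={},
        )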
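The _wait_for_task_ended logic leans on a boto3 detail: get_waiter raises ValueError for a waiter name the service model does not define, and AWS Batch shipped no waiters at the time (hence the reference to botocore PR #1307). A sketch of that control flow, assuming a configured boto3 Batch client; poll_until_done is a hypothetical stand-in for the operator's polling fallback:

    import sys

    import boto3

    client = boto3.client("batch")

    try:
        # raises ValueError until botocore ships a Batch waiter of this name
        waiter = client.get_waiter("job_execution_complete")
        waiter.config.max_attempts = sys.maxsize  # let Airflow own the timeout
        waiter.wait(jobs=["my-job-id"])
    except ValueError:
        poll_until_done(client, "my-job-id")  # hypothetical polling fallback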
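The polling pause in _poll_for_task_ended starts at a random 5 to 30 seconds (to de-synchronize many concurrent tasks against the AWS API throttle) and then grows quadratically with the retry count, not strictly exponentially despite the docstring. A small sketch of the schedule it produces:

    def poll_pause(tries: int) -> float:
        """Seconds slept before status check number `tries` + 1 (tries >= 1)."""
        return 1 + pow(tries * 0.3, 2)

    # poll_pause(1) == 1.09, poll_pause(5) == 3.25,
    # poll_pause(10) == 10.0, poll_pause(50) == 226.0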
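Finally, _get_job_description treats TooManyRequestsException specially because describe_jobs is the call most likely to be throttled when many tasks poll at once; botocore surfaces this as a ClientError with the error code nested in err.response. A minimal sketch of the check, assuming `client` is a boto3 Batch client (the single inline retry is a simplification of the operator's retry loop):

    from random import randint
    from time import sleep

    import botocore.exceptions

    def describe_once(client, job_id: str) -> dict:
        try:
            return client.describe_jobs(jobs=[job_id])
        except botocore.exceptions.ClientError as err:
            error = err.response.get("Error", {})
            if error.get("Code") == "TooManyRequestsException":
                sleep(randint(1, 10))  # random pause spreads out retries
                return client.describe_jobs(jobs=[job_id])
            raise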