diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 9a2c13ca9d6845..143bdcdcc89137 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -133,6 +133,7 @@ def submit_job( job_driver: dict, configuration_overrides: Optional[dict] = None, client_request_token: Optional[str] = None, + tags: Optional[dict] = None, ) -> str: """ Submit a job to the EMR Containers API and return the job ID. @@ -148,6 +149,7 @@ def submit_job( specifically either application configuration or monitoring configuration. :param client_request_token: The client idempotency token of the job run request. Use this if you want to specify a unique ID to prevent two jobs from getting started. + :param tags: The tags assigned to job runs. :return: Job ID """ params = { @@ -157,6 +159,7 @@ def submit_job( "releaseLabel": release_label, "jobDriver": job_driver, "configurationOverrides": configuration_overrides or {}, + "tags": tags or {}, } if client_request_token: params["clientToken"] = client_request_token diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 6f942e297fb9c1..a1f3fa753d8172 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -136,6 +136,8 @@ class EmrContainerOperator(BaseOperator): :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR :param max_tries: Maximum number of times to wait for the job run to finish. Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. + :param tags: The tags assigned to job runs. + Defaults to None """ template_fields: Sequence[str] = ( @@ -160,6 +162,7 @@ def __init__( aws_conn_id: str = "aws_default", poll_interval: int = 30, max_tries: Optional[int] = None, + tags: Optional[dict] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -173,6 +176,7 @@ def __init__( self.client_request_token = client_request_token or str(uuid4()) self.poll_interval = poll_interval self.max_tries = max_tries + self.tags = tags self.job_id: Optional[str] = None @cached_property @@ -192,6 +196,7 @@ def execute(self, context: 'Context') -> Optional[str]: self.job_driver, self.configuration_overrides, self.client_request_token, + self.tags, ) query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index 24d7eb6f828ec4..8bd58e2b6aa015 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -51,6 +51,7 @@ def setUp(self, emr_hook_mock): configuration_overrides={}, poll_interval=0, client_request_token=GENERATED_UUID, + tags={}, ) @mock.patch.object(EmrContainerHook, 'submit_job') @@ -66,7 +67,7 @@ def test_execute_without_failure( self.emr_container.execute(None) mock_submit_job.assert_called_once_with( - 'test_emr_job', 'arn:aws:somerole', '6.3.0-latest', {}, {}, GENERATED_UUID + 'test_emr_job', 'arn:aws:somerole', '6.3.0-latest', {}, {}, GENERATED_UUID, {} ) mock_check_query_status.assert_called_once_with('jobid_123456') assert self.emr_container.release_label == '6.3.0-latest'