Skip to content

Commit

Permalink
Add support for associating custom tags to job runs submitted via Emr…
Browse files Browse the repository at this point in the history
…ContainerOperator (#23769)

Co-authored-by: Sandeep Kadyan <[email protected]>
  • Loading branch information
skadyan and Sandeep Kadyan authored May 22, 2022
1 parent 65f3b18 commit e54ca47
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
3 changes: 3 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def submit_job(
job_driver: dict,
configuration_overrides: Optional[dict] = None,
client_request_token: Optional[str] = None,
tags: Optional[dict] = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
Expand All @@ -148,6 +149,7 @@ def submit_job(
specifically either application configuration or monitoring configuration.
:param client_request_token: The client idempotency token of the job run request.
Use this if you want to specify a unique ID to prevent two jobs from getting started.
:param tags: The tags assigned to job runs.
:return: Job ID
"""
params = {
Expand All @@ -157,6 +159,7 @@ def submit_job(
"releaseLabel": release_label,
"jobDriver": job_driver,
"configurationOverrides": configuration_overrides or {},
"tags": tags or {},
}
if client_request_token:
params["clientToken"] = client_request_token
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class EmrContainerOperator(BaseOperator):
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
:param max_tries: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param tags: The tags assigned to job runs.
Defaults to None
"""

template_fields: Sequence[str] = (
Expand All @@ -160,6 +162,7 @@ def __init__(
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
max_tries: Optional[int] = None,
tags: Optional[dict] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -173,6 +176,7 @@ def __init__(
self.client_request_token = client_request_token or str(uuid4())
self.poll_interval = poll_interval
self.max_tries = max_tries
self.tags = tags
self.job_id: Optional[str] = None

@cached_property
Expand All @@ -192,6 +196,7 @@ def execute(self, context: 'Context') -> Optional[str]:
self.job_driver,
self.configuration_overrides,
self.client_request_token,
self.tags,
)
query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)

Expand Down
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/operators/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def setUp(self, emr_hook_mock):
configuration_overrides={},
poll_interval=0,
client_request_token=GENERATED_UUID,
tags={},
)

@mock.patch.object(EmrContainerHook, 'submit_job')
Expand All @@ -66,7 +67,7 @@ def test_execute_without_failure(
self.emr_container.execute(None)

mock_submit_job.assert_called_once_with(
'test_emr_job', 'arn:aws:somerole', '6.3.0-latest', {}, {}, GENERATED_UUID
'test_emr_job', 'arn:aws:somerole', '6.3.0-latest', {}, {}, GENERATED_UUID, {}
)
mock_check_query_status.assert_called_once_with('jobid_123456')
assert self.emr_container.release_label == '6.3.0-latest'
Expand Down

0 comments on commit e54ca47

Please sign in to comment.