Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: EmrServerlessStartJobOperator not serializing DAGs correctly whe… #38022

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,27 +1253,77 @@ def operator_extra_links(self):
op_extra_links = []

if isinstance(self, MappedOperator):
operator_class = self.operator_class
enable_application_ui_links = self.partial_kwargs.get(
"enable_application_ui_links"
) or self.expand_input.value.get("enable_application_ui_links")
job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver")
job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get(
"job_driver", {}
)
configuration_overrides = self.partial_kwargs.get(
"configuration_overrides"
) or self.expand_input.value.get("configuration_overrides")

# Configuration overrides can be either a list or a dictionary, depending on whether they are passed in via expand() or partial().
if isinstance(configuration_overrides, list):
if any(
[
operator_class.is_monitoring_in_job_override(
self=operator_class,
config_key="s3MonitoringConfiguration",
job_override=job_override,
)
for job_override in configuration_overrides
]
):
op_extra_links.extend([EmrServerlessS3LogsLink()])
if any(
[
operator_class.is_monitoring_in_job_override(
self=operator_class,
config_key="cloudWatchLoggingConfiguration",
job_override=job_override,
)
for job_override in configuration_overrides
]
):
op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
else:
if operator_class.is_monitoring_in_job_override(
self=operator_class,
config_key="s3MonitoringConfiguration",
job_override=configuration_overrides,
):
op_extra_links.extend([EmrServerlessS3LogsLink()])
if operator_class.is_monitoring_in_job_override(
self=operator_class,
config_key="cloudWatchLoggingConfiguration",
job_override=configuration_overrides,
):
op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])

else:
operator_class = self
enable_application_ui_links = self.enable_application_ui_links
configuration_overrides = self.configuration_overrides
job_driver = self.job_driver

if operator_class.is_monitoring_in_job_override(
"s3MonitoringConfiguration", configuration_overrides
):
op_extra_links.extend([EmrServerlessS3LogsLink()])
if operator_class.is_monitoring_in_job_override(
"cloudWatchLoggingConfiguration", configuration_overrides
):
op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])

if enable_application_ui_links:
op_extra_links.extend([EmrServerlessDashboardLink()])
if "sparkSubmit" in job_driver:
if isinstance(job_driver, list):
if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver):
op_extra_links.extend([EmrServerlessLogsLink()])
elif "sparkSubmit" in job_driver:
op_extra_links.extend([EmrServerlessLogsLink()])
if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides):
op_extra_links.extend([EmrServerlessS3LogsLink()])
if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides):
op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])

return tuple(op_extra_links)

Expand Down
55 changes: 55 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
EmrServerlessCloudWatchLogsLink,
EmrServerlessDashboardLink,
EmrServerlessLogsLink,
EmrServerlessS3LogsLink,
)
from airflow.providers.amazon.aws.operators.emr import (
EmrServerlessCreateApplicationOperator,
EmrServerlessDeleteApplicationOperator,
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)
from airflow.serialization.serialized_objects import (
SerializedBaseOperator,
)
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
Expand Down Expand Up @@ -1096,6 +1105,52 @@ def test_links_spark_without_applicationui_enabled(
job_run_id=job_run_id,
)

def test_operator_extra_links_mapped_without_applicationui_enabled(self):
    """Check extra links of a mapped start-job operator after a serialization round trip.

    The operator is built with ``enable_application_ui_links=False`` in
    ``partial()`` and two ``configuration_overrides`` entries in ``expand()``.
    After serializing and deserializing it, only the S3 and CloudWatch log
    links are expected.
    """
    partial_op = EmrServerlessStartJobOperator.partial(
        task_id=task_id,
        application_id=application_id,
        execution_role_arn=execution_role_arn,
        job_driver=spark_job_driver,
        enable_application_ui_links=False,
    )
    mapped_op = partial_op.expand(
        configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
    )

    # Round-trip through the serializer, as happens when a DAG is stored and reloaded.
    round_tripped = SerializedBaseOperator.deserialize_operator(
        SerializedBaseOperator.serialize(mapped_op)
    )

    expected_links = [
        EmrServerlessS3LogsLink(),
        EmrServerlessCloudWatchLogsLink(),
    ]
    assert round_tripped.operator_extra_links == expected_links

def test_operator_extra_links_mapped_with_applicationui_enabled_at_partial(self):
    """Check extra links when application UI links are enabled through ``partial()``.

    The operator is built with ``enable_application_ui_links=True`` and a
    Spark job driver in ``partial()``, and two ``configuration_overrides``
    entries in ``expand()``. After serializing and deserializing it, the S3,
    CloudWatch, dashboard and logs links are expected, in that order.
    """
    partial_op = EmrServerlessStartJobOperator.partial(
        task_id=task_id,
        application_id=application_id,
        execution_role_arn=execution_role_arn,
        job_driver=spark_job_driver,
        enable_application_ui_links=True,
    )
    mapped_op = partial_op.expand(
        configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
    )

    # Round-trip through the serializer, as happens when a DAG is stored and reloaded.
    round_tripped = SerializedBaseOperator.deserialize_operator(
        SerializedBaseOperator.serialize(mapped_op)
    )

    expected_links = [
        EmrServerlessS3LogsLink(),
        EmrServerlessCloudWatchLogsLink(),
        EmrServerlessDashboardLink(),
        EmrServerlessLogsLink(),
    ]
    assert round_tripped.operator_extra_links == expected_links


class TestEmrServerlessDeleteOperator:
@mock.patch.object(EmrServerlessHook, "get_waiter")
Expand Down
Loading