diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 600c4d3aa1439..8e97015d86182 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -139,20 +139,6 @@ def __init__( region: Optional[str] = None, **kwargs, ) -> None: - if compute: - if compute not in SUPPORTED_COMPUTE_VALUES: - raise ValueError("Provided compute type is not supported.") - elif (compute == 'nodegroup') and not nodegroup_role_arn: - raise ValueError( - MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement='nodegroup_role_arn') - ) - elif (compute == 'fargate') and not fargate_pod_execution_role_arn: - raise ValueError( - MISSING_ARN_MSG.format( - compute=FARGATE_FULL_NAME, requirement='fargate_pod_execution_role_arn' - ) - ) - self.compute = compute self.cluster_name = cluster_name self.cluster_role_arn = cluster_role_arn @@ -170,6 +156,20 @@ def __init__( super().__init__(**kwargs) def execute(self, context: 'Context'): + if self.compute: + if self.compute not in SUPPORTED_COMPUTE_VALUES: + raise ValueError("Provided compute type is not supported.") + elif (self.compute == 'nodegroup') and not self.nodegroup_role_arn: + raise ValueError( + MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement='nodegroup_role_arn') + ) + elif (self.compute == 'fargate') and not self.fargate_pod_execution_role_arn: + raise ValueError( + MISSING_ARN_MSG.format( + compute=FARGATE_FULL_NAME, requirement='fargate_pod_execution_role_arn' + ) + ) + eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 0950708c0d64b..d1610b4ffd5ef 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -19,6 +19,8 @@ from typing import Any, Dict, List from unittest import mock +import pytest + from airflow.providers.amazon.aws.hooks.eks import ClusterStates, EksHook from airflow.providers.amazon.aws.operators.eks import ( EksCreateClusterOperator, @@ -202,6 +204,42 @@ def test_execute_when_called_with_fargate_creates_both( **convert_keys(self.create_fargate_profile_params) ) + def test_invalid_compute_value(self): + invalid_compute = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute='infinite', + ) + + with pytest.raises(ValueError, match="Provided compute type is not supported."): + invalid_compute.execute({}) + + def test_nodegroup_compute_missing_nodegroup_role_arn(self): + missing_nodegroup_role_arn = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute='nodegroup', + ) + + with pytest.raises( + ValueError, + match="Creating an Amazon EKS managed node groups requires nodegroup_role_arn to be passed in.", + ): + missing_nodegroup_role_arn.execute({}) + + def test_fargate_compute_missing_fargate_pod_execution_role_arn(self): + missing_fargate_pod_execution_role_arn = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute='fargate', + ) + + with pytest.raises( + ValueError, + match="Creating an AWS Fargate profiles requires fargate_pod_execution_role_arn to be passed in.", + ): + missing_fargate_pod_execution_role_arn.execute({}) + class TestEksCreateFargateProfileOperator(unittest.TestCase): def setUp(self) -> None: