introduce base class for EKS sensors #29053

Merged · 5 commits · Jan 26, 2023
177 changes: 84 additions & 93 deletions airflow/providers/amazon/aws/sensors/eks.py
@@ -17,6 +17,7 @@
"""Tracking the state of Amazon EKS Clusters, Amazon EKS managed node groups, and AWS Fargate profiles."""
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Sequence

from airflow.compat.functools import cached_property
@@ -57,7 +58,71 @@
)


class EksClusterStateSensor(BaseSensorOperator):
class EksBaseSensor(BaseSensorOperator):
"""
Base class to check various EKS states.
Subclasses must implement the get_state and get_terminal_states methods.

:param cluster_name: The name of the Cluster.
:param target_state: The state the sensor waits for; the sensor returns successfully once this state is reached.
:param target_state_type: The enum class containing the valid states,
used to convert target_state when it is passed as a string.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then the default boto3 configuration would be used (and must be
maintained on each worker node).
:param region: Which AWS region the connection should use.
If this is None or empty then the default boto3 behaviour is used.
"""

def __init__(
self,
*,
cluster_name: str,
target_state: ClusterStates | NodegroupStates | FargateProfileStates,
target_state_type: type,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.aws_conn_id = aws_conn_id
self.region = region
self.target_state = (
target_state
if isinstance(target_state, target_state_type)
else target_state_type(str(target_state).upper())
)

@cached_property
def hook(self) -> EksHook:
return EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

def poke(self, context: Context) -> bool:
state = self.get_state()
self.log.info("Current state: %s", state)
if state in (self.get_terminal_states() - {self.target_state}):
# If we reach a terminal state which is not the target state:
raise AirflowException(
UNEXPECTED_TERMINAL_STATE_MSG.format(current_state=state, target_state=self.target_state)
)
return state == self.target_state

@abstractmethod
def get_state(self) -> ClusterStates | NodegroupStates | FargateProfileStates:
...

@abstractmethod
def get_terminal_states(self) -> frozenset:
...


class EksClusterStateSensor(EksBaseSensor):
"""
Check the state of an Amazon EKS Cluster until it reaches the target state or another terminal state.

@@ -83,43 +148,19 @@ class EksClusterStateSensor(BaseSensorOperator):
def __init__(
self,
*,
cluster_name: str,
target_state: ClusterStates = ClusterStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
self.target_state = (
target_state
if isinstance(target_state, ClusterStates)
else ClusterStates(str(target_state).upper())
)
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
super().__init__(target_state=target_state, target_state_type=ClusterStates, **kwargs)

@cached_property
def hook(self):
return EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
def get_state(self) -> ClusterStates:
return self.hook.get_cluster_state(clusterName=self.cluster_name)

def poke(self, context: Context):
cluster_state = self.hook.get_cluster_state(clusterName=self.cluster_name)
self.log.info("Cluster state: %s", cluster_state)
if cluster_state in (CLUSTER_TERMINAL_STATES - {self.target_state}):
# If we reach a terminal state which is not the target state:
raise AirflowException(
UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=cluster_state, target_state=self.target_state
)
)
return cluster_state == self.target_state
def get_terminal_states(self) -> frozenset:
return CLUSTER_TERMINAL_STATES


class EksFargateProfileStateSensor(BaseSensorOperator):
class EksFargateProfileStateSensor(EksBaseSensor):
"""
Check the state of an AWS Fargate profile until it reaches the target state or another terminal state.

@@ -152,47 +193,23 @@ class EksFargateProfileStateSensor(BaseSensorOperator):
def __init__(
self,
*,
cluster_name: str,
fargate_profile_name: str,
target_state: FargateProfileStates = FargateProfileStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
super().__init__(target_state=target_state, target_state_type=FargateProfileStates, **kwargs)
self.fargate_profile_name = fargate_profile_name
self.target_state = (
target_state
if isinstance(target_state, FargateProfileStates)
else FargateProfileStates(str(target_state).upper())
)
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self):
return EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

def poke(self, context: Context):
fargate_profile_state = self.hook.get_fargate_profile_state(
def get_state(self) -> FargateProfileStates:
return self.hook.get_fargate_profile_state(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
)
self.log.info("Fargate profile state: %s", fargate_profile_state)
if fargate_profile_state in (FARGATE_TERMINAL_STATES - {self.target_state}):
# If we reach a terminal state which is not the target state:
raise AirflowException(
UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=fargate_profile_state, target_state=self.target_state
)
)
return fargate_profile_state == self.target_state

def get_terminal_states(self) -> frozenset:
return FARGATE_TERMINAL_STATES


class EksNodegroupStateSensor(BaseSensorOperator):
class EksNodegroupStateSensor(EksBaseSensor):
"""
Check the state of an EKS managed node group until it reaches the target state or another terminal state.

@@ -225,41 +242,15 @@ class EksNodegroupStateSensor(BaseSensorOperator):
def __init__(
self,
*,
cluster_name: str,
nodegroup_name: str,
target_state: NodegroupStates = NodegroupStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
super().__init__(target_state=target_state, target_state_type=NodegroupStates, **kwargs)
self.nodegroup_name = nodegroup_name
self.target_state = (
target_state
if isinstance(target_state, NodegroupStates)
else NodegroupStates(str(target_state).upper())
)
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self):
return EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
def get_state(self) -> NodegroupStates:
return self.hook.get_nodegroup_state(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)

def poke(self, context: Context):
nodegroup_state = self.hook.get_nodegroup_state(
clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
)
self.log.info("Nodegroup state: %s", nodegroup_state)
if nodegroup_state in (NODEGROUP_TERMINAL_STATES - {self.target_state}):
# If we reach a terminal state which is not the target state:
raise AirflowException(
UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=nodegroup_state, target_state=self.target_state
)
)
return nodegroup_state == self.target_state
def get_terminal_states(self) -> frozenset:
return NODEGROUP_TERMINAL_STATES
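
Below, outside the diff, is a minimal sketch of how the new EksBaseSensor could be extended in user code. The class name EksClusterDeletedSensor and the choice of ClusterStates.NONEXISTENT as the target state are illustrative, not part of this PR; the imported names (EksBaseSensor, CLUSTER_TERMINAL_STATES, ClusterStates) already exist in the module above and in the Amazon provider hooks.

```python
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.eks import ClusterStates
from airflow.providers.amazon.aws.sensors.eks import CLUSTER_TERMINAL_STATES, EksBaseSensor


class EksClusterDeletedSensor(EksBaseSensor):
    """Hypothetical sensor that waits for an EKS cluster to finish deleting."""

    def __init__(self, *, cluster_name: str, **kwargs):
        # The base class resolves string target states, stores the connection
        # details, and builds the EksHook; the subclass only picks the target.
        super().__init__(
            cluster_name=cluster_name,
            target_state=ClusterStates.NONEXISTENT,
            target_state_type=ClusterStates,
            **kwargs,
        )

    def get_state(self) -> ClusterStates:
        # Same lookup the built-in EksClusterStateSensor delegates to.
        return self.hook.get_cluster_state(clusterName=self.cluster_name)

    def get_terminal_states(self) -> frozenset:
        return CLUSTER_TERMINAL_STATES
```

Because poke now lives in the base class, every subclass handles unexpected terminal states the same way, which appears to be the duplication this PR removes.
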
1 change: 1 addition & 0 deletions tests/always/test_project_structure.py
@@ -393,6 +393,7 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
"airflow.providers.amazon.aws.operators.appflow.AppflowBaseOperator",
"airflow.providers.amazon.aws.operators.ecs.EcsBaseOperator",
"airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor",
"airflow.providers.amazon.aws.sensors.eks.EksBaseSensor",
}

MISSING_EXAMPLES_FOR_CLASSES = {
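
As a closing usage note (not part of this change), the public constructor arguments of the three sensors are unchanged, so existing DAGs should continue to work. A sketch, with illustrative dag_id, cluster, nodegroup, and region values, assuming Airflow 2.4+ for the schedule argument:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates
from airflow.providers.amazon.aws.sensors.eks import (
    EksClusterStateSensor,
    EksNodegroupStateSensor,
)

# All resource names below are placeholders.
with DAG(dag_id="eks_state_sensors_example", start_date=datetime(2023, 1, 1), schedule=None):
    wait_for_cluster = EksClusterStateSensor(
        task_id="wait_for_cluster_active",
        cluster_name="my-cluster",
        target_state=ClusterStates.ACTIVE,  # also accepts the string "ACTIVE"
        region="us-east-1",
    )

    wait_for_nodegroup = EksNodegroupStateSensor(
        task_id="wait_for_nodegroup_active",
        cluster_name="my-cluster",
        nodegroup_name="my-nodegroup",
        target_state=NodegroupStates.ACTIVE,
        region="us-east-1",
    )

    # The nodegroup only becomes ACTIVE after its cluster does.
    wait_for_cluster >> wait_for_nodegroup
```
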