ECS Executor - add support to adopt orphaned tasks. #37786

Merged
9 changes: 9 additions & 0 deletions airflow/executors/base_executor.py
@@ -330,6 +330,15 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, TaskInstanceState.SUCCESS, info)

def queued(self, key: TaskInstanceKey, info=None) -> None:
"""
Set queued state for the event.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, TaskInstanceState.QUEUED, info)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Return and flush the event buffer.
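The new queued() helper mirrors the existing success()/fail() wrappers around change_state(). Below is a minimal sketch of how an executor subclass might use it; the class, backend, and _submit() helper are hypothetical, and it assumes (as this PR relies on) that the info reported with the QUEUED event ends up persisted as the task instance's external_executor_id.

```python
from __future__ import annotations

from airflow.executors.base_executor import BaseExecutor


class MyRemoteExecutor(BaseExecutor):
    """Hypothetical executor, used only to illustrate the new queued() hook."""

    def execute_async(self, key, command, queue=None, executor_config=None):
        # Hypothetical call that launches the work on some backend and returns
        # a backend-specific identifier (for ECS this is the task ARN).
        remote_id = self._submit(command)
        # Report QUEUED together with that identifier so it can be stored as
        # TaskInstance.external_executor_id and later used by
        # try_adopt_task_instances() to re-attach to still-running work.
        self.queued(key, remote_id)

    def _submit(self, command) -> str:  # placeholder for a real backend call
        raise NotImplementedError
```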
43 changes: 41 additions & 2 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -26,7 +26,7 @@
import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from botocore.exceptions import ClientError, NoCredentialsError

@@ -47,12 +47,13 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
ExecutorConfigType,
@@ -240,6 +241,7 @@ def __update_running_task(self, task):
# Get state of current task.
task_state = task.get_task_state()
task_key = self.active_workers.arn_to_key[task.task_arn]

# Mark finished tasks as either a success/failure.
if task_state == State.FAILED:
self.fail(task_key)
@@ -394,6 +396,7 @@ def attempt_task_runs(self):
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
self.queued(task_key, task.task_arn)
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
@@ -494,3 +497,39 @@ def get_container(self, container_list):
'container "name" must be provided in "containerOverrides" configuration'
)
raise KeyError(f"No such container found by container name: {self.container_name}")

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Adopt task instances which have an external_executor_id (the ECS task ARN).

Anything that is not adopted will be cleared by the scheduler and become eligible for re-scheduling.
"""
with Stats.timer("ecs_executor.adopt_task_instances.duration"):
adopted_tis: list[TaskInstance] = []

if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
task_descriptions = self.__describe_tasks(task_arns).get("tasks", [])

for task in task_descriptions:
ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
self.active_workers.add_task(
task,
ti.key,
ti.queue,
ti.command_as_list(),
ti.executor_config,
ti.prev_attempted_tries,
)
adopted_tis.append(ti)

if adopted_tis:
tasks = [f"{task} in state {task.state}" for task in adopted_tis]
task_instance_str = "\n\t".join(tasks)
self.log.info(
"Adopted the following %d tasks from a dead executor:\n\t%s",
len(adopted_tis),
task_instance_str,
)

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis
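The method follows the standard BaseExecutor.try_adopt_task_instances contract: anything whose ARN still resolves through describe_tasks is re-tracked in active_workers, and the rest is handed back to the caller for clearing. A rough sketch of how a caller could exercise it; the function and its reporting are illustrative, and the AwsEcsExecutor class name is assumed from the provider package rather than shown in this diff.

```python
from typing import Sequence

from airflow.models.taskinstance import TaskInstance
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor


def reattach_after_scheduler_crash(
    executor: AwsEcsExecutor, orphaned_tis: Sequence[TaskInstance]
) -> Sequence[TaskInstance]:
    """Illustrative only: exercise the adoption behaviour added in this PR."""
    # Task instances whose ECS task ARNs are still known to the cluster are
    # re-added to executor.active_workers; everything else is returned.
    not_adopted = executor.try_adopt_task_instances(orphaned_tis)
    # In the real scheduler the returned instances would be cleared and become
    # eligible for re-scheduling; here we just report them.
    for ti in not_adopted:
        print(f"Could not adopt {ti.key}; it will be rescheduled from scratch.")
    return not_adopted
```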
12 changes: 7 additions & 5 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -78,30 +78,30 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys):
ASSIGN_PUBLIC_IP = "assign_public_ip"
CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
CLUSTER = "cluster"
CONTAINER_NAME = "container_name"
LAUNCH_TYPE = "launch_type"
PLATFORM_VERSION = "platform_version"
SECURITY_GROUPS = "security_groups"
SUBNETS = "subnets"
TASK_DEFINITION = "task_definition"
CONTAINER_NAME = "container_name"


class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
"""All keys loaded into the config which are related to the ECS Executor."""

MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
AWS_CONN_ID = "conn_id"
RUN_TASK_KWARGS = "run_task_kwargs"
REGION_NAME = "region_name"
CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
REGION_NAME = "region_name"
RUN_TASK_KWARGS = "run_task_kwargs"


class EcsExecutorException(Exception):
"""Thrown when something unexpected has occurred within the ECS ecosystem."""


class EcsExecutorTask:
"""Data Transfer Object for an ECS Fargate Task."""
"""Data Transfer Object for an ECS Task."""

def __init__(
self,
@@ -111,13 +111,15 @@ def __init__(
containers: list[dict[str, Any]],
started_at: Any | None = None,
stopped_reason: str | None = None,
external_executor_id: str | None = None,
):
self.task_arn = task_arn
self.last_status = last_status
self.desired_status = desired_status
self.containers = containers
self.started_at = started_at
self.stopped_reason = stopped_reason
self.external_executor_id = external_executor_id

def get_task_state(self) -> str:
"""
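EcsExecutorTask objects are normally deserialized from describe_tasks responses, and the new external_executor_id keyword defaults to None, so existing call sites keep working. A small sketch of constructing one directly; the ARN and container name are placeholders, and the leading parameter names are inferred from the assignments shown above.

```python
from airflow.providers.amazon.aws.executors.ecs.utils import EcsExecutorTask

# Placeholder ARN, purely for illustration.
arn = "arn:aws:ecs:us-east-1:123456789012:task/example-cluster/0123456789abcdef"

task = EcsExecutorTask(
    task_arn=arn,
    last_status="RUNNING",
    desired_status="RUNNING",
    containers=[{"name": "some-ecs-container"}],
    # New in this PR: carried along so adoption code can round-trip the id.
    external_executor_id=arn,
)
# Should map to Airflow's RUNNING state while lastStatus is RUNNING.
print(task.get_task_state())
```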
40 changes: 40 additions & 0 deletions tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -34,6 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
@@ -829,6 +830,45 @@ def test_update_running_tasks_failed(self, mock_executor, caplog):
"test failure" in caplog.messages[0]
)

def test_try_adopt_task_instances(self, mock_executor):
"""Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
mock_executor.ecs.describe_tasks.return_value = {
"tasks": [
{
"taskArn": "001",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "some-ecs-container"}],
},
{
"taskArn": "002",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "another-ecs-container"}],
},
],
"failures": [],
}

orphaned_tasks = [
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
]
orphaned_tasks[0].external_executor_id = "001" # Matches a running task_arn
orphaned_tasks[1].external_executor_id = "002" # Matches a running task_arn
orphaned_tasks[2].external_executor_id = None # One orphaned task has no external_executor_id
for task in orphaned_tasks:
task.prev_attempted_tries = 1

not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

mock_executor.ecs.describe_tasks.assert_called_once()
# Two of the three tasks should be adopted.
assert len(orphaned_tasks) - 1 == len(mock_executor.active_workers)
# The remaining one task is unable to be adopted.
assert 1 == len(not_adopted_tasks)


class TestEcsExecutorConfig:
@pytest.fixture