Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement wait_until_job_complete parameter for KubernetesJobOperator #37998

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
65 changes: 64 additions & 1 deletion airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import tempfile
from functools import cached_property
from time import sleep
from typing import TYPE_CHECKING, Any, Generator

import aiofiles
Expand All @@ -42,6 +43,13 @@

LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."

# Condition types that mark a Job as finished, whether it succeeded or not.
JOB_FINAL_STATUS_CONDITION_TYPES = {"Complete", "Failed"}

# All condition types the hook recognizes; "Suspended" is not a final state.
JOB_STATUS_CONDITION_TYPES = JOB_FINAL_STATUS_CONDITION_TYPES | {"Suspended"}


def _load_body_to_dict(body: str) -> dict:
try:
Expand Down Expand Up @@ -504,14 +512,41 @@ def create_job(
return resp

def get_job(self, job_name: str, namespace: str) -> V1Job:
"""Get Job of specified name from Google Cloud.
"""Get Job of specified name and namespace.

:param job_name: Name of Job to fetch.
:param namespace: Namespace of the Job.
:return: Job object
"""
return self.batch_v1_client.read_namespaced_job(name=job_name, namespace=namespace, pretty=True)

def get_job_status(self, job_name: str, namespace: str) -> V1Job:
"""Get job with status of specified name and namespace.

:param job_name: Name of Job to fetch.
:param namespace: Namespace of the Job.
:return: Job object
"""
return self.batch_v1_client.read_namespaced_job_status(
name=job_name, namespace=namespace, pretty=True
)

def wait_until_job_complete(self, job_name: str, namespace: str, job_poll_interval: float = 10) -> V1Job:
"""Block job of specified name and namespace until it is complete or failed.

:param job_name: Name of Job to fetch.
:param namespace: Namespace of the Job.
:param job_poll_interval: Interval in seconds between polling the job status
:return: Job object
"""
while True:
self.log.info("Requesting status for the job '%s' ", job_name)
job: V1Job = self.get_job_status(job_name=job_name, namespace=namespace)
if self.is_job_complete(job=job):
return job
self.log.info("The job '%s' is incomplete. Sleeping for %i sec.", job_name, job_poll_interval)
sleep(job_poll_interval)

def list_jobs_all_namespaces(self) -> V1JobList:
"""Get list of Jobs from all namespaces.

Expand All @@ -527,6 +562,34 @@ def list_jobs_from_namespace(self, namespace: str) -> V1JobList:
"""
return self.batch_v1_client.list_namespaced_job(namespace=namespace, pretty=True)

def is_job_complete(self, job: V1Job) -> bool:
"""Check whether the given job is complete (with success or fail).

:return: Boolean indicating that the given job is complete.
"""
if conditions := job.status.conditions:
if final_condition_types := list(
c for c in conditions if c.type in JOB_FINAL_STATUS_CONDITION_TYPES and c.status
):
s = "s" if len(final_condition_types) > 1 else ""
self.log.info(
"The job '%s' state%s: %s",
job.metadata.name,
s,
", ".join(f"{c.type} at {c.last_transition_time}" for c in final_condition_types),
)
return True
return False

@staticmethod
def is_job_failed(job: V1Job) -> bool:
"""Check whether the given job is failed.

:return: Boolean indicating that the given job is failed.
"""
conditions = job.status.conditions or []
return bool(next((c for c in conditions if c.type == "Failed" and c.status), None))


def _get_bool(val) -> bool | None:
"""Convert val to bool if can be done with certainty; if we cannot infer intention we return None."""
Expand Down
21 changes: 21 additions & 0 deletions airflow/providers/cncf/kubernetes/operators/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from kubernetes.client import BatchV1Api, models as k8s
from kubernetes.client.api_client import ApiClient

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
add_unique_suffix,
Expand Down Expand Up @@ -65,6 +66,10 @@ class KubernetesJobOperator(KubernetesPodOperator):
:param selector: The selector of this V1JobSpec.
:param suspend: Suspend specifies whether the Job controller should create Pods or not.
:param ttl_seconds_after_finished: ttlSecondsAfterFinished limits the lifetime of a Job that has finished execution (either Complete or Failed).
:param wait_until_job_complete: Whether to wait until started job finished execution (either Complete or
Failed). Default is False.
:param job_poll_interval: Interval in seconds between polling the job status. Default is 10.
Used if the parameter `wait_until_job_complete` set True.
"""

template_fields: Sequence[str] = tuple({"job_template_file"} | set(KubernetesPodOperator.template_fields))
Expand All @@ -82,6 +87,8 @@ def __init__(
selector: k8s.V1LabelSelector | None = None,
suspend: bool | None = None,
ttl_seconds_after_finished: int | None = None,
wait_until_job_complete: bool = False,
job_poll_interval: float = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -97,6 +104,8 @@ def __init__(
self.selector = selector
self.suspend = suspend
self.ttl_seconds_after_finished = ttl_seconds_after_finished
self.wait_until_job_complete = wait_until_job_complete
self.job_poll_interval = job_poll_interval

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -135,6 +144,18 @@ def execute(self, context: Context):
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

if self.wait_until_job_complete:
self.job = self.hook.wait_until_job_complete(
job_name=self.job.metadata.name,
namespace=self.job.metadata.namespace,
job_poll_interval=self.job_poll_interval,
)
ti.xcom_push(
key="job", value=self.hook.batch_v1_client.api_client.sanitize_for_serialization(self.job)
)
if self.hook.is_job_failed(job=self.job):
raise AirflowException(f"Kubernetes job '{self.job.metadata.name}' is failed")

@staticmethod
def deserialize_job_template_file(path: str) -> k8s.V1Job:
"""
Expand Down
116 changes: 116 additions & 0 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
ASYNC_CONFIG_PATH = "/files/path/to/config/file"
POD_NAME = "test-pod"
NAMESPACE = "test-namespace"
JOB_NAME = "test-job"
POLL_INTERVAL = 100


class DeprecationRemovalRequired(AirflowException):
Expand Down Expand Up @@ -446,6 +448,120 @@ def test_delete_custom_object(
_preload_content="_preload_content",
)

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
def test_get_job_status(self, mock_client, mock_kube_config_merger, mock_kube_config_loader):
job_expected = mock_client.read_namespaced_job_status.return_value

hook = KubernetesHook()
job_actual = hook.get_job_status(job_name=JOB_NAME, namespace=NAMESPACE)

mock_client.read_namespaced_job_status.assert_called_once_with(
name=JOB_NAME, namespace=NAMESPACE, pretty=True
)
assert job_actual == job_expected

@pytest.mark.parametrize(
"conditions, expected_result",
[
(None, False),
([], False),
([mock.MagicMock(type="Complete", status=True)], False),
([mock.MagicMock(type="Complete", status=False)], False),
([mock.MagicMock(type="Failed", status=False)], False),
([mock.MagicMock(type="Failed", status=True)], True),
(
[mock.MagicMock(type="Complete", status=False), mock.MagicMock(type="Failed", status=True)],
True,
),
(
[mock.MagicMock(type="Complete", status=True), mock.MagicMock(type="Failed", status=True)],
True,
),
],
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_failed(self, mock_merger, mock_loader, conditions, expected_result):
mock_job = mock.MagicMock()
mock_job.status.conditions = conditions

hook = KubernetesHook()
actual_result = hook.is_job_failed(mock_job)

assert actual_result == expected_result

@pytest.mark.parametrize(
"condition_type, status, expected_result",
[
("Complete", False, False),
("Complete", True, True),
("Failed", False, False),
("Failed", True, True),
("Suspended", False, False),
("Suspended", True, False),
("Unknown", False, False),
("Unknown", True, False),
],
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_complete(self, mock_merger, mock_loader, condition_type, status, expected_result):
mock_job = mock.MagicMock()
mock_job.status.conditions = [mock.MagicMock(type=condition_type, status=status)]

hook = KubernetesHook()
actual_result = hook.is_job_complete(mock_job)

assert actual_result == expected_result

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch(f"{HOOK_MODULE}.KubernetesHook.get_job_status")
def test_wait_until_job_complete(self, mock_job_status, mock_kube_config_merger, mock_kube_config_loader):
job_expected = mock.MagicMock(
status=mock.MagicMock(
conditions=[
mock.MagicMock(type="TestType1"),
mock.MagicMock(type="TestType2"),
mock.MagicMock(type="Complete", status=True),
]
)
)
mock_job_status.side_effect = [
mock.MagicMock(status=mock.MagicMock(conditions=None)),
mock.MagicMock(status=mock.MagicMock(conditions=[mock.MagicMock(type="TestType")])),
mock.MagicMock(
status=mock.MagicMock(
conditions=[
mock.MagicMock(type="TestType1"),
mock.MagicMock(type="TestType2"),
]
)
),
mock.MagicMock(
status=mock.MagicMock(
conditions=[
mock.MagicMock(type="TestType1"),
mock.MagicMock(type="TestType2"),
mock.MagicMock(type="Complete", status=False),
]
)
),
job_expected,
]

hook = KubernetesHook()
with patch(f"{HOOK_MODULE}.sleep", return_value=None) as mock_sleep:
job_actual = hook.wait_until_job_complete(
job_name=JOB_NAME, namespace=NAMESPACE, job_poll_interval=POLL_INTERVAL
)

mock_job_status.assert_has_calls([mock.call(job_name=JOB_NAME, namespace=NAMESPACE)] * 5)
mock_sleep.assert_has_calls([mock.call(POLL_INTERVAL)] * 4)
assert job_actual == job_expected


class TestKubernetesHookIncorrectConfiguration:
@pytest.mark.parametrize(
Expand Down
71 changes: 70 additions & 1 deletion tests/providers/cncf/kubernetes/operators/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@
from __future__ import annotations

import re
from unittest import mock
from unittest.mock import patch

import pendulum
import pytest
from kubernetes.client import ApiClient, models as k8s

from airflow.exceptions import AirflowException
from airflow.models import DAG, DagModel, DagRun, TaskInstance
from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType

DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0)
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.job.KubernetesHook"
JOB_OPERATORS_PATH = "airflow.providers.cncf.kubernetes.operators.job.{}"
HOOK_CLASS = JOB_OPERATORS_PATH.format("KubernetesHook")
POLL_INTERVAL = 100


def create_context(task, persist_to_db=False, map_index=None):
Expand Down Expand Up @@ -450,3 +454,68 @@ def test_task_id_as_name_dag_id_is_ignored(self):
)
job = k.build_job_request_obj({})
assert re.match(r"job-a-very-reasonable-task-name-[a-z0-9-]+", job.metadata.name) is not None

@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj):
mock_hook.return_value.is_job_failed.return_value = False
mock_job_request_obj = mock_build_job_request_obj.return_value
mock_job_expected = mock_create_job.return_value
mock_ti = mock.MagicMock()
context = dict(ti=mock_ti)

op = KubernetesJobOperator(
task_id="test_task_id",
)
execute_result = op.execute(context=context)

mock_build_job_request_obj.assert_called_once_with(context)
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
mock_ti.xcom_push.assert_has_calls(
[
mock.call(key="job_name", value=mock_job_expected.metadata.name),
mock.call(key="job_namespace", value=mock_job_expected.metadata.namespace),
]
)

assert op.job_request_obj == mock_job_request_obj
assert op.job == mock_job_expected
assert not op.wait_until_job_complete
assert execute_result is None
assert not mock_hook.wait_until_job_complete.called

@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_execute_fail(self, mock_hook, mock_create_job, mock_build_job_request_obj):
mock_hook.return_value.is_job_failed.return_value = True

op = KubernetesJobOperator(
task_id="test_task_id",
)

with pytest.raises(AirflowException):
op.execute(context=dict(ti=mock.MagicMock()))

@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(f"{HOOK_CLASS}.wait_until_job_complete")
def test_wait_until_job_complete(
self, mock_wait_until_job_complete, mock_create_job, mock_build_job_request_obj
):
mock_job_expected = mock_create_job.return_value
mock_ti = mock.MagicMock()

op = KubernetesJobOperator(
task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL
)
op.execute(context=dict(ti=mock_ti))

assert op.wait_until_job_complete
assert op.job_poll_interval == POLL_INTERVAL
mock_wait_until_job_complete.assert_called_once_with(
job_name=mock_job_expected.metadata.name,
namespace=mock_job_expected.metadata.namespace,
job_poll_interval=POLL_INTERVAL,
)