Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve code coverage for CgroupTaskRunner #39896

Merged
merged 10 commits into from
May 29, 2024
57 changes: 50 additions & 7 deletions tests/task/task_runner/test_cgroup_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@

from unittest import mock

from cgroupspy.nodes import Node

from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner


class TestCgroupTaskRunner:
def setup_method(self):
job = mock.Mock()
job.job_type = None
job.task_instance = mock.MagicMock()
job.task_instance.run_as_user = None
job.task_instance.command_as_list.return_value = ["sleep", "1000"]
job.task_instance.task.resources = None
self.job = job

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish")
def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init):
Expand All @@ -32,14 +43,46 @@ def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_i
and when task finishes, CgroupTaskRunner.on_finish() calls
super().on_finish() to delete the temp cfg file.
"""
Job = mock.Mock()
Job.job_type = None
Job.task_instance = mock.MagicMock()
Job.task_instance.run_as_user = None
Job.task_instance.command_as_list.return_value = ["sleep", "1000"]

runner = CgroupTaskRunner(Job)
runner = CgroupTaskRunner(self.job)
assert mock_super_init.called

runner.on_finish()
assert mock_super_on_finish.called

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("cgroupspy.nodes.Node.create_cgroup")
def test_create_cgroup_not_exist(self, mock_create_cgroup, mock_super_init):
mock_create_cgroup.return_value = Node("test_node")
node = CgroupTaskRunner(self.job)._create_cgroup("./test_cgroup")
assert node.name.decode() == "test_node"

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("cgroupspy.nodes.Node.delete_cgroup")
def test_delete_cgroup_exist(self, mock_delete_cgroup, mock_super_init):
CgroupTaskRunner(self.job)._delete_cgroup("./test_cgroup")
assert not mock_delete_cgroup.called

@mock.patch("airflow.utils.log.logging_mixin.LoggingMixin.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.run_command")
@mock.patch("cgroupspy.nodes.Node.create_cgroup")
def test_start_task(self, mock_create_cgroup, mock_run_command, logging_init):
CgroupTaskRunner(self.job).start()
assert mock_create_cgroup.called
assert mock_run_command.called
assert logging_init.called

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
def test_return_code_none(self, mock_super_init):
return_code = CgroupTaskRunner(self.job).return_code()
assert not return_code

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("builtins.open", new_callable=mock.mock_open)
def test_log_memory_usage(self, mock_open_file, mock_super_init):
mock_open_file.return_value.read.return_value = "12345789"
mem_cgroup_node = mock.Mock()
mem_cgroup_node.full_path = "/test/cgroup"
mem_cgroup_node.controller = mock.MagicMock()
mem_cgroup_node.controller.limit_in_bytes = 123456
CgroupTaskRunner(self.job)._log_memory_usage(mem_cgroup_node)
assert mock_open_file.called