Skip to content

Commit

Permalink
Improve code coverage for CgroupTaskRunner (apache#39896)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangyulely authored and fdemiane committed Jun 6, 2024
1 parent 5862653 commit 804d4e8
Showing 1 changed file with 50 additions and 7 deletions.
57 changes: 50 additions & 7 deletions tests/task/task_runner/test_cgroup_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@

from unittest import mock

from cgroupspy.nodes import Node

from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner


class TestCgroupTaskRunner:
def setup_method(self):
job = mock.Mock()
job.job_type = None
job.task_instance = mock.MagicMock()
job.task_instance.run_as_user = None
job.task_instance.command_as_list.return_value = ["sleep", "1000"]
job.task_instance.task.resources = None
self.job = job

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish")
def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init):
Expand All @@ -32,14 +43,46 @@ def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_i
and when task finishes, CgroupTaskRunner.on_finish() calls
super().on_finish() to delete the temp cfg file.
"""
Job = mock.Mock()
Job.job_type = None
Job.task_instance = mock.MagicMock()
Job.task_instance.run_as_user = None
Job.task_instance.command_as_list.return_value = ["sleep", "1000"]

runner = CgroupTaskRunner(Job)
runner = CgroupTaskRunner(self.job)
assert mock_super_init.called

runner.on_finish()
assert mock_super_on_finish.called

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("cgroupspy.nodes.Node.create_cgroup")
def test_create_cgroup_not_exist(self, mock_create_cgroup, mock_super_init):
mock_create_cgroup.return_value = Node("test_node")
node = CgroupTaskRunner(self.job)._create_cgroup("./test_cgroup")
assert node.name.decode() == "test_node"

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("cgroupspy.nodes.Node.delete_cgroup")
def test_delete_cgroup_exist(self, mock_delete_cgroup, mock_super_init):
CgroupTaskRunner(self.job)._delete_cgroup("./test_cgroup")
assert not mock_delete_cgroup.called

@mock.patch("airflow.utils.log.logging_mixin.LoggingMixin.__init__")
@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.run_command")
@mock.patch("cgroupspy.nodes.Node.create_cgroup")
def test_start_task(self, mock_create_cgroup, mock_run_command, logging_init):
CgroupTaskRunner(self.job).start()
assert mock_create_cgroup.called
assert mock_run_command.called
assert logging_init.called

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
def test_return_code_none(self, mock_super_init):
return_code = CgroupTaskRunner(self.job).return_code()
assert not return_code

@mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__")
@mock.patch("builtins.open", new_callable=mock.mock_open)
def test_log_memory_usage(self, mock_open_file, mock_super_init):
mock_open_file.return_value.read.return_value = "12345789"
mem_cgroup_node = mock.Mock()
mem_cgroup_node.full_path = "/test/cgroup"
mem_cgroup_node.controller = mock.MagicMock()
mem_cgroup_node.controller.limit_in_bytes = 123456
CgroupTaskRunner(self.job)._log_memory_usage(mem_cgroup_node)
assert mock_open_file.called

0 comments on commit 804d4e8

Please sign in to comment.