diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py index c3999e19b09a30..b8b78f60647d37 100644 --- a/tests/task/task_runner/test_cgroup_task_runner.py +++ b/tests/task/task_runner/test_cgroup_task_runner.py @@ -19,10 +19,21 @@ from unittest import mock +from cgroupspy.nodes import Node + from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner class TestCgroupTaskRunner: + def setup_method(self): + job = mock.Mock() + job.job_type = None + job.task_instance = mock.MagicMock() + job.task_instance.run_as_user = None + job.task_instance.command_as_list.return_value = ["sleep", "1000"] + job.task_instance.task.resources = None + self.job = job + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish") def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init): @@ -32,14 +43,46 @@ def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_i and when task finishes, CgroupTaskRunner.on_finish() calls super().on_finish() to delete the temp cfg file. """ - Job = mock.Mock() - Job.job_type = None - Job.task_instance = mock.MagicMock() - Job.task_instance.run_as_user = None - Job.task_instance.command_as_list.return_value = ["sleep", "1000"] - - runner = CgroupTaskRunner(Job) + runner = CgroupTaskRunner(self.job) assert mock_super_init.called runner.on_finish() assert mock_super_on_finish.called + + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") + @mock.patch("cgroupspy.nodes.Node.create_cgroup") + def test_create_cgroup_not_exist(self, mock_create_cgroup, mock_super_init): + mock_create_cgroup.return_value = Node("test_node") + node = CgroupTaskRunner(self.job)._create_cgroup("./test_cgroup") + assert node.name.decode() == "test_node" + + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") + @mock.patch("cgroupspy.nodes.Node.delete_cgroup") + def test_delete_cgroup_exist(self, mock_delete_cgroup, mock_super_init): + CgroupTaskRunner(self.job)._delete_cgroup("./test_cgroup") + assert not mock_delete_cgroup.called + + @mock.patch("airflow.utils.log.logging_mixin.LoggingMixin.__init__") + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.run_command") + @mock.patch("cgroupspy.nodes.Node.create_cgroup") + def test_start_task(self, mock_create_cgroup, mock_run_command, logging_init): + CgroupTaskRunner(self.job).start() + assert mock_create_cgroup.called + assert mock_run_command.called + assert logging_init.called + + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") + def test_return_code_none(self, mock_super_init): + return_code = CgroupTaskRunner(self.job).return_code() + assert not return_code + + @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") + @mock.patch("builtins.open", new_callable=mock.mock_open) + def test_log_memory_usage(self, mock_open_file, mock_super_init): + mock_open_file.return_value.read.return_value = "12345789" + mem_cgroup_node = mock.Mock() + mem_cgroup_node.full_path = "/test/cgroup" + mem_cgroup_node.controller = mock.MagicMock() + mem_cgroup_node.controller.limit_in_bytes = 123456 + CgroupTaskRunner(self.job)._log_memory_usage(mem_cgroup_node) + assert mock_open_file.called