Skip to content

Commit

Permalink
use parametrize decorator in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kahlstrm committed May 11, 2024
1 parent c43ebd0 commit 33c830e
Showing 1 changed file with 47 additions and 45 deletions.
92 changes: 47 additions & 45 deletions tests/utils/test_log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,56 +313,58 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc
else:
mock_k8s_get_task_log.assert_not_called()

def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
@pytest.mark.parametrize(
"state", [TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RETRY]
)
def test__read_for_celery_executor_fallbacks_to_worker(self, state, create_task_instance):
"""Test for executors which do not have `get_task_log` method, it fallbacks to reading
log from worker if and only if remote logs aren't found"""
executor_name = "CeleryExecutor"
# Reading logs from worker should occur when the task is either running, deferred, or up for retry.
for state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RETRY):
ti = create_task_instance(
dag_id=f"dag_for_testing_celery_executor_log_read_{state}",
task_id="task_for_testing_celery_executor_log_read",
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
ti = create_task_instance(
dag_id=f"dag_for_testing_celery_executor_log_read_{state}",
task_id="task_for_testing_celery_executor_log_read",
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
)
ti.try_number = 2
ti.state = state
with conf_vars({("core", "executor"): executor_name}):
reload(executor_loader)
fth = FileTaskHandler("")
fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=2)
fth._read_from_logs_server.assert_called_once()
# If we are in the up for retry or deferred state, log has ended and a new task try has started.
# When the state is running, the log has not ended.
expected_end_of_log = state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RETRY)
assert actual == (
"*** this message\nthis\nlog\ncontent",
{"end_of_log": expected_end_of_log, "log_pos": 16},
)

# Previous try_number should return served logs when remote logs aren't implemented
fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["served logs try_number=1"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_from_logs_server.assert_called_once()
assert actual == (
"*** served logs try_number=1\nthis\nlog\ncontent",
{"end_of_log": True, "log_pos": 16},
)

# When remote_logs is implemented, previous try_number is from remote logs without reaching worker server
fth._read_from_logs_server.reset_mock()
fth._read_remote_logs = mock.Mock()
fth._read_remote_logs.return_value = ["remote logs"], ["remote\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_remote_logs.assert_called_once()
fth._read_from_logs_server.assert_not_called()
assert actual == (
"*** remote logs\nremote\nlog\ncontent",
{"end_of_log": True, "log_pos": 18},
)
ti.try_number = 2
ti.state = state
with conf_vars({("core", "executor"): executor_name}):
reload(executor_loader)
fth = FileTaskHandler("")
fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=2)
fth._read_from_logs_server.assert_called_once()
# If we are in the up for retry or deferred state, log has ended and a new task try has started.
# When the state is running, the log has not ended.
expected_end_of_log = state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RETRY)
assert actual == (
"*** this message\nthis\nlog\ncontent",
{"end_of_log": expected_end_of_log, "log_pos": 16},
)

# Previous try_number should return served logs when remote logs aren't implemented
fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["served logs try_number=1"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_from_logs_server.assert_called_once()
assert actual == (
"*** served logs try_number=1\nthis\nlog\ncontent",
{"end_of_log": True, "log_pos": 16},
)

# When remote_logs is implemented, previous try_number is from remote logs without reaching worker server
fth._read_from_logs_server.reset_mock()
fth._read_remote_logs = mock.Mock()
fth._read_remote_logs.return_value = ["remote logs"], ["remote\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_remote_logs.assert_called_once()
fth._read_from_logs_server.assert_not_called()
assert actual == (
"*** remote logs\nremote\nlog\ncontent",
{"end_of_log": True, "log_pos": 18},
)

@pytest.mark.parametrize(
"remote_logs, local_logs, served_logs_checked",
Expand Down

0 comments on commit 33c830e

Please sign in to comment.