diff --git a/tests/conftest.py b/tests/conftest.py index 3c3ae4373b..b8ed0a32f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ import subprocess import time from pathlib import Path +from unittest.mock import patch import dask import pandas as pd @@ -371,3 +372,19 @@ def devices(request): @pytest.fixture def report(request): return request.config.getoption("--report") + + +@pytest.fixture(scope="function", autouse=True) +def cleanup_dataloader(): + """After each test runs. Call .stop() on any dataloaders created during the test. + The avoids issues with background threads hanging around and interfering with subsequent tests. + This happens when a dataloader is partially consumed (not all batches are iterated through). + """ + from merlin.dataloader.loader_base import LoaderBase + + with patch.object( + LoaderBase, "__iter__", side_effect=LoaderBase.__iter__, autospec=True + ) as patched: + yield + for call in patched.call_args_list: + call.args[0].stop()