From 927b78bc94e0f8b8621d5acea02e813200d24bbb Mon Sep 17 00:00:00 2001
From: Gokul
Date: Fri, 24 May 2024 11:32:08 -0700
Subject: [PATCH] Save iterator state once instead of storing it again in dataset state

Differential Revision: D57765902

Pull Request resolved: https://github.com/pytorch/data/pull/1258
---
 test/stateful_dataloader/test_state_dict.py | 76 ++++++++++++++++++-
 .../stateful_dataloader.py                  |  8 +-
 torchdata/stateful_dataloader/worker.py     |  7 +-
 3 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index 2d50c5b8e..72d9cb1c3 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -79,6 +79,40 @@ def __len__(self):
         return self.size


+class DummyIteratorIterableDataset(torch.utils.data.IterableDataset, Iterator, Stateful):
+    def __init__(self, samples, shuffle, include_generator):
+        self.samples = samples
+        self.shuffle = shuffle
+        self.include_generator = include_generator
+        self.size = len(self.samples)
+        self.i = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.i >= len(self.samples):
+            raise StopIteration
+        if self.shuffle:
+            i = torch.randint(self.size, (1,)).item()
+        else:
+            i = self.i
+        sample = self.samples[i]
+        self.i += 1
+        return sample
+
+    def state_dict(self):
+        sd = {"i": self.i}
+        if self.include_generator:
+            sd["g"] = torch.get_rng_state()
+        return sd
+
+    def load_state_dict(self, state_dict):
+        self.i = state_dict["i"]
+        if self.include_generator:
+            torch.set_rng_state(state_dict["g"])
+
+
 class DummyIterableDataset(torch.utils.data.IterableDataset):
     def __init__(self, sizes_for_all_workers, shuffle=False, include_generator=True):
         self.sizes_for_all_workers = sizes_for_all_workers
@@ -131,8 +165,11 @@ def identity(x):


 class TestStatefulDataLoaderIterable_shard0(TestCase):
+    def _get_dataset(self, shuffle):
+        return DummyIterableDataset([0, 100, 37], shuffle=shuffle)
+
     def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
-        dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
+        dataset = self._get_dataset(shuffle)
         dl = StatefulDataLoader(
             dataset=dataset,
             num_workers=num_workers,
@@ -987,5 +1024,42 @@ def test_json_serde_multi_process_map(self):
         self._run_test_map(3)


+class TestStatefulDataLoaderIterable2_shard3(TestStatefulDataLoaderIterable_shard0):
+    # Perform sanity test checks with the iterable dataset that is also an iterator
+    def _get_dataset(self, shuffle):
+        return DummyIteratorIterableDataset(list(range(100)), shuffle=shuffle, include_generator=True)
+
+
+class TestDatasetIteratorStateDuplication_shard3(TestCase):
+    def test(self):
+        dataset = DummyIteratorIterableDataset(list(range(100)), shuffle=True, include_generator=True)
+        for num_workers in (0, 2):
+            dl = StatefulDataLoader(
+                dataset=dataset,
+                num_workers=num_workers,
+                collate_fn=identity,
+                multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            )
+            it = iter(dl)
+            # Fetch at least one batch from each worker
+            for _ in range(num_workers + 1):
+                next(it)
+            state_dict = dl.state_dict()
+            print(state_dict)
+
+            if num_workers > 0:
+                for i in range(num_workers):
+                    # Ensure worker state is stored only once if the dataset is also the iterator
+                    self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                    self.assertTrue(
+                        state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
+                            "dataset_iter_state"
+                        ]
+                    )
+            else:
+                self.assertEqual(state_dict["dataset_state"], None)
+                self.assertTrue(state_dict["fetcher_state"]["dataset_iter_state"])
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index ed6a32994..9c521a279 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -326,8 +326,14 @@ def state_dict(self):
                 _DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter),
                 _FETCHER_ENDED: self._dataset_fetcher.ended,
             }
+            dataset_state = (
+                try_to_serialize(self._dataset_fetcher.dataset)
+                if self._dataset_fetcher.dataset_iter is not self._dataset_fetcher.dataset
+                else None
+            )
         else:
             fetcher_state = None
+            dataset_state = try_to_serialize(self._dataset_fetcher.dataset)

         state_dict = {
             _INDEX_SAMPLER_STATE: try_to_serialize(self._index_sampler),
@@ -337,7 +343,7 @@
             _ITERABLEDATASET_LEN_CALLED: self._IterableDataset_len_called,
             _SHARED_SEED: self._shared_seed,
             _FETCHER_STATE: fetcher_state,
-            _DATASET_STATE: try_to_serialize(self._dataset_fetcher.dataset),
+            _DATASET_STATE: dataset_state,
             _ITERATOR_FINISHED: self._finished,
         }
         return state_dict
diff --git a/torchdata/stateful_dataloader/worker.py b/torchdata/stateful_dataloader/worker.py
index 2a2227835..2152e677e 100644
--- a/torchdata/stateful_dataloader/worker.py
+++ b/torchdata/stateful_dataloader/worker.py
@@ -213,10 +213,13 @@ def _worker_loop(
                 _DATASET_ITER_STATE: try_to_serialize(fetcher.dataset_iter),  # type: ignore[union-attr]
                 _FETCHER_ENDED: fetcher.ended,  # type: ignore[union-attr]
             }
+            # Pick up any user-defined dataset state if it is not the iterator as it is already captured in fetcher_state's dataset_iter_state
+            dataset_state = try_to_serialize(dataset) if fetcher.dataset_iter is not dataset else None  # type: ignore[union-attr]
         else:
             fetcher_state = None
-        # Pick up any user-defined dataset state, for both map/iterable style datasets
-        dataset_state = try_to_serialize(dataset)
+            # Pick up any user-defined dataset state
+            dataset_state = try_to_serialize(dataset)
+
         state_dict = {
             _WORKER_ID: worker_id,
             _FETCHER_STATE: fetcher_state,