
Save iterator state once instead of storing it again in dataset state
Differential Revision: D57765902

Pull Request resolved: #1258
gokulavasan authored May 24, 2024
1 parent a0412de commit 927b78b
Showing 3 changed files with 87 additions and 4 deletions.
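
In short: when an IterableDataset is its own iterator (its __iter__ returns self and it implements Stateful), the snapshot previously serialized the same object twice: once under fetcher_state["dataset_iter_state"] and again under dataset_state. This commit keeps only the fetcher copy. Below is a minimal sketch of the rule, with a stand-in for torchdata's try_to_serialize helper; the real snapshots in the diff carry more keys.

def try_to_serialize(obj):
    # Stand-in: return the object's state_dict() if it is Stateful, else None.
    sd = getattr(obj, "state_dict", None)
    return sd() if callable(sd) else None

def snapshot(dataset, dataset_iter):
    fetcher_state = {"dataset_iter_state": try_to_serialize(dataset_iter)}
    # If the dataset *is* the iterator, its state was just captured above,
    # so storing it again under dataset_state would duplicate it.
    dataset_state = try_to_serialize(dataset) if dataset_iter is not dataset else None
    return {"fetcher_state": fetcher_state, "dataset_state": dataset_state}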
76 changes: 75 additions & 1 deletion test/stateful_dataloader/test_state_dict.py
@@ -79,6 +79,40 @@ def __len__(self):
        return self.size


class DummyIteratorIterableDataset(torch.utils.data.IterableDataset, Iterator, Stateful):
def __init__(self, samples, shuffle, include_generator):
self.samples = samples
self.shuffle = shuffle
self.include_generator = include_generator
self.size = len(self.samples)
self.i = 0

def __iter__(self):
return self

def __next__(self):
if self.i >= len(self.samples):
raise StopIteration
if self.shuffle:
i = torch.randint(self.size, (1,)).item()
else:
i = self.i
sample = self.samples[i]
self.i += 1
return sample

def state_dict(self):
sd = {"i": self.i}
if self.include_generator:
sd["g"] = torch.get_rng_state()
return sd

def load_state_dict(self, state_dict):
self.i = state_dict["i"]
if self.include_generator:
torch.set_rng_state(state_dict["g"])


class DummyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, sizes_for_all_workers, shuffle=False, include_generator=True):
self.sizes_for_all_workers = sizes_for_all_workers
@@ -131,8 +165,11 @@ def identity(x):


class TestStatefulDataLoaderIterable_shard0(TestCase):
def _get_dataset(self, shuffle):
return DummyIterableDataset([0, 100, 37], shuffle=shuffle)

def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
        dataset = self._get_dataset(shuffle)
dl = StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
@@ -987,5 +1024,42 @@ def test_json_serde_multi_process_map(self):
self._run_test_map(3)


class TestStatefulDataLoaderIterable2_shard3(TestStatefulDataLoaderIterable_shard0):
    # Perform the sanity checks with an iterable dataset that is also an iterator
def _get_dataset(self, shuffle):
return DummyIteratorIterableDataset(list(range(100)), shuffle=shuffle, include_generator=True)


class TestDatasetIteratorStateDuplication_shard3(TestCase):
def test(self):
dataset = DummyIteratorIterableDataset(list(range(100)), shuffle=True, include_generator=True)
for num_workers in (0, 2):
dl = StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
)
it = iter(dl)
# Fetch at least one batch from each worker
for _ in range(num_workers + 1):
next(it)
state_dict = dl.state_dict()
print(state_dict)

if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
]
)
else:
self.assertEqual(state_dict["dataset_state"], None)
self.assertTrue(state_dict["fetcher_state"]["dataset_iter_state"])


if __name__ == "__main__":
unittest.main()
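
For reference, a sketch of the snapshot layout the multiprocessing assertions above expect; key names are copied from the test, values are illustrative placeholders, and unrelated keys are omitted.

# Illustrative dl.state_dict() layout for num_workers=2 after this change.
expected_shape = {
    "_snapshot": {
        "_worker_snapshots": {
            "worker_0": {
                "dataset_state": None,  # no longer duplicated here
                "fetcher_state": {"dataset_iter_state": {"i": 1, "g": "<rng state>"}},
            },
            "worker_1": {
                "dataset_state": None,
                "fetcher_state": {"dataset_iter_state": {"i": 1, "g": "<rng state>"}},
            },
        },
    },
}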
8 changes: 7 additions & 1 deletion torchdata/stateful_dataloader/stateful_dataloader.py
@@ -326,8 +326,14 @@ def state_dict(self):
_DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter),
_FETCHER_ENDED: self._dataset_fetcher.ended,
}
dataset_state = (
try_to_serialize(self._dataset_fetcher.dataset)
if self._dataset_fetcher.dataset_iter is not self._dataset_fetcher.dataset
else None
)
else:
fetcher_state = None
dataset_state = try_to_serialize(self._dataset_fetcher.dataset)

state_dict = {
_INDEX_SAMPLER_STATE: try_to_serialize(self._index_sampler),
@@ -337,7 +343,7 @@ def state_dict(self):
_ITERABLEDATASET_LEN_CALLED: self._IterableDataset_len_called,
_SHARED_SEED: self._shared_seed,
_FETCHER_STATE: fetcher_state,
            _DATASET_STATE: dataset_state,
_ITERATOR_FINISHED: self._finished,
}
return state_dict
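
An end-to-end sketch of the single-process path above, reusing DummyIteratorIterableDataset from the test diff; the assertions mirror the num_workers == 0 branch of the test and are not new API.

from torchdata.stateful_dataloader import StatefulDataLoader

ds = DummyIteratorIterableDataset(list(range(10)), shuffle=False, include_generator=False)
dl = StatefulDataLoader(ds, num_workers=0)
it = iter(dl)
next(it)  # advance one step so there is iterator state to capture
sd = dl.state_dict()
assert sd["dataset_state"] is None  # stored once, not here ...
assert sd["fetcher_state"]["dataset_iter_state"]  # ... but here, e.g. {"i": 1}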
7 changes: 5 additions & 2 deletions torchdata/stateful_dataloader/worker.py
@@ -213,10 +213,13 @@ def _worker_loop(
_DATASET_ITER_STATE: try_to_serialize(fetcher.dataset_iter), # type: ignore[union-attr]
_FETCHER_ENDED: fetcher.ended, # type: ignore[union-attr]
}
            # Pick up any user-defined dataset state if it is not the iterator, as it is already captured in fetcher_state's dataset_iter_state
dataset_state = try_to_serialize(dataset) if fetcher.dataset_iter is not dataset else None # type: ignore[union-attr]
else:
fetcher_state = None
                # Pick up any user-defined dataset state
                dataset_state = try_to_serialize(dataset)

state_dict = {
_WORKER_ID: worker_id,
_FETCHER_STATE: fetcher_state,
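
To make the identity check concrete: the deduplication only fires for iterable datasets whose __iter__ returns self. A hedged illustration with hypothetical classes, not from this repo:

import torch

class SelfIterating(torch.utils.data.IterableDataset):
    # __iter__ returns the dataset itself, so fetcher.dataset_iter is dataset
    # and the state is saved only under fetcher_state's dataset_iter_state.
    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration  # empty stream; enough for illustration

class GeneratorBacked(torch.utils.data.IterableDataset):
    # __iter__ returns a separate iterator object, so the dataset's own
    # state (if any) is still saved under dataset_state.
    def __iter__(self):
        return iter(range(3))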
