Commit
isolate some tests
andrewkho committed May 9, 2024
1 parent: c86ab2b · commit: cdd6368
Showing 1 changed file with 88 additions and 88 deletions.
test/stateful_dataloader/test_state_dict.py
@@ -219,104 +219,104 @@ class TestStatefulDataLoaderIterable(TestCase):
     #             shuffle=True,
     #         )

-    # # class TestStatefulDataLoaderMap(TestCase):
-    # def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
-    #     if num_workers == 0:
-    #         return
-    #     dataset = DummyMapDataset(100, shuffle=shuffle)
-    #     generator = torch.Generator()
-    #     generator.manual_seed(13)
-    #     sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
-    #     dl = StatefulDataLoader(
-    #         dataset=dataset,
-    #         num_workers=num_workers,
-    #         collate_fn=identity,
-    #         snapshot_every_n_steps=every_n_steps,
-    #         persistent_workers=pw,
-    #         batch_size=batch_size,
-    #         sampler=sampler,
-    #         multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-    #     )
+    # class TestStatefulDataLoaderMap(TestCase):
+    def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+        if num_workers == 0:
+            return
+        dataset = DummyMapDataset(100, shuffle=shuffle)
+        generator = torch.Generator()
+        generator.manual_seed(13)
+        sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )

-    #     if interrupt is None:
-    #         interrupt = len(dl)
+        if interrupt is None:
+            interrupt = len(dl)

-    #     it = iter(dl)
-    #     for _ in range(interrupt):
-    #         next(it)
+        it = iter(dl)
+        for _ in range(interrupt):
+            next(it)

-    #     state_dict = dl.state_dict()
-    #     exp = []
-    #     for batch in it:
-    #         exp.append(batch)
+        state_dict = dl.state_dict()
+        exp = []
+        for batch in it:
+            exp.append(batch)

-    #     # Restore new instance from state
-    #     generator = torch.Generator()
-    #     generator.manual_seed(13)
-    #     sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
-    #     dl = StatefulDataLoader(
-    #         dataset=dataset,
-    #         num_workers=num_workers,
-    #         collate_fn=identity,
-    #         snapshot_every_n_steps=every_n_steps,
-    #         persistent_workers=pw,
-    #         batch_size=batch_size,
-    #         sampler=sampler,
-    #         multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-    #     )
-    #     dl.load_state_dict(state_dict)
-    #     batches = []
-    #     for batch in dl:
-    #         batches.append(batch)
+        # Restore new instance from state
+        generator = torch.Generator()
+        generator.manual_seed(13)
+        sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )
+        dl.load_state_dict(state_dict)
+        batches = []
+        for batch in dl:
+            batches.append(batch)

-    #     self.assertEqual(batches, exp)
+        self.assertEqual(batches, exp)

-    # def test_no_mp3(self):
-    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-    #         self._run_and_checkpoint3(
-    #             num_workers=0,
-    #             batch_size=batch_size,
-    #             pw=False,
-    #             interrupt=interrupt,
-    #         )
+    def test_no_mp3(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            self._run_and_checkpoint3(
+                num_workers=0,
+                batch_size=batch_size,
+                pw=False,
+                interrupt=interrupt,
+            )

-    # def test_mp_x3(self):
-    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-    #         self._run_and_checkpoint3(
-    #             num_workers=3,
-    #             batch_size=batch_size,
-    #             pw=False,
-    #             interrupt=interrupt,
-    #         )
+    def test_mp_x3(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            self._run_and_checkpoint3(
+                num_workers=3,
+                batch_size=batch_size,
+                pw=False,
+                interrupt=interrupt,
+            )

-    # def test_mp_pw3(self):
-    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-    #         self._run_and_checkpoint3(
-    #             num_workers=3,
-    #             batch_size=batch_size,
-    #             pw=True,
-    #             interrupt=interrupt,
-    #         )
+    def test_mp_pw3(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            self._run_and_checkpoint3(
+                num_workers=3,
+                batch_size=batch_size,
+                pw=True,
+                interrupt=interrupt,
+            )

-    # def test_mp_every_n_steps3(self):
-    #     batch_size = 7
-    #     for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
-    #         self._run_and_checkpoint3(
-    #             num_workers=3,
-    #             batch_size=batch_size,
-    #             pw=True,
-    #             interrupt=interrupt,
-    #         )
+    def test_mp_every_n_steps3(self):
+        batch_size = 7
+        for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
+            self._run_and_checkpoint3(
+                num_workers=3,
+                batch_size=batch_size,
+                pw=True,
+                interrupt=interrupt,
+            )

-    # def test_random_state3(self):
-    #     for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
-    #         self._run_and_checkpoint3(
-    #             num_workers=num_workers,
-    #             batch_size=7,
-    #             pw=False,
-    #             interrupt=interrupt,
-    #             shuffle=True,
-    #         )
+    def test_random_state3(self):
+        for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
+            self._run_and_checkpoint3(
+                num_workers=num_workers,
+                batch_size=7,
+                pw=False,
+                interrupt=interrupt,
+                shuffle=True,
+            )

     # class TestStatefulSampler(TestCase):
     def _run_and_checkpoint2(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
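
The restored tests all exercise the same checkpoint-and-resume contract: advance a StatefulDataLoader partway through an epoch, capture state_dict(), then build a fresh loader and confirm that load_state_dict() replays exactly the batches the interrupted run would still have produced. A minimal standalone sketch of that pattern follows; SquaresDataset is a hypothetical stand-in for the test suite's DummyMapDataset, and the single-process settings (num_workers=0, batch_size=7) are illustrative choices, not the suite's full matrix.

from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class SquaresDataset(Dataset):
    """Hypothetical stand-in for the test suite's DummyMapDataset."""

    def __len__(self):
        return 100

    def __getitem__(self, i):
        return i * i

def make_loader():
    return StatefulDataLoader(SquaresDataset(), batch_size=7, num_workers=0)

# Advance partway through an epoch, then snapshot the loader's state.
dl = make_loader()
it = iter(dl)
for _ in range(3):
    next(it)
state = dl.state_dict()

# The batches the interrupted run would still have produced.
expected = [batch.tolist() for batch in it]

# A fresh loader restored from the snapshot resumes mid-epoch
# rather than replaying the epoch from the start.
resumed = make_loader()
resumed.load_state_dict(state)
assert [batch.tolist() for batch in resumed] == expected

The suite sidesteps tensor comparison by passing collate_fn=identity, so batches stay plain Python lists; the sketch instead converts the default-collated tensors with .tolist() so the equality check is unambiguous.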
