Skip to content

Commit

Permalink
Adding state parser utility that can be used for modifying worker states
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
gokulavasan committed Jun 18, 2024
1 parent 958eeb0 commit cb10f71
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
78 changes: 78 additions & 0 deletions test/stateful_dataloader/test_state_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torch.testing._internal.common_utils import TestCase

from torch.utils.data import Dataset, IterableDataset
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader, StateParserUtil


class StatefulIterableDataset(IterableDataset, Stateful):
    """Endless iterable dataset used as a test fixture.

    Each ``__next__`` call bumps ``num_calls`` and returns the new count. The
    counter is the only piece of state that is saved and restored.
    """

    def __init__(self):
        # Number of items handed out so far; also the last value returned.
        self.num_calls = 0

    def state_dict(self):
        """Return the current counter as a checkpointable dict."""
        return {"num_calls": self.num_calls}

    def load_state_dict(self, state_dict):
        """Restore the counter from a dict produced by ``state_dict``."""
        self.num_calls = state_dict["num_calls"]

    def __iter__(self):
        # The dataset is its own iterator.
        return self

    def __next__(self):
        current = self.num_calls + 1
        self.num_calls = current
        return current


def identity(x):
    """Return ``x`` unchanged (passed as ``collate_fn`` in the test below)."""
    return x


class TestIteratorDataset(TestCase):
    """Exercise ``StateParserUtil`` against a state dict from a 2-worker loader."""

    def test_increasing_worker(self):
        """Grow a 2-worker state dict to 4 workers and resume from it."""
        dataset = StatefulIterableDataset()
        loader = StatefulDataLoader(dataset, num_workers=2, collate_fn=identity)
        batch_iter = iter(loader)
        next(batch_iter)
        state = loader.state_dict()
        print(state)
        del loader

        parser = StateParserUtil(state)
        per_worker = parser.fetch_dataset_state()
        # Add dataset state for two workers that did not exist in the snapshot.
        per_worker[2] = {"num_calls": 2}
        per_worker[3] = {"num_calls": 3}
        parser.set_dataset_state(per_worker)

        # worker state doesn't equal num workers setting
        with self.assertRaises(AssertionError):
            parser.get_state_dict()
        parser.set_num_workers(4)

        # last worker yielded id is greater than num workers
        parser.set_last_worker_yielded_id(10)
        with self.assertRaises(AssertionError):
            parser.get_state_dict()
        parser.set_last_worker_yielded_id(0)

        # load the modified state
        modified_state = parser.get_state_dict()
        print(modified_state)
        loader = StatefulDataLoader(dataset, num_workers=4, collate_fn=identity)
        loader.load_state_dict(modified_state)
        batch_iter = iter(loader)
        # Flatten the next four batches into a single list of samples.
        values = [sample for _ in range(4) for sample in next(batch_iter)]
        print(values)
        self.assertEqual(values, [1, 3, 4, 2])


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
3 changes: 2 additions & 1 deletion torchdata/stateful_dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .state_parser import StateParserUtil
from .stateful import Stateful
from .stateful_dataloader import StatefulDataLoader

__all__ = ["Stateful", "StatefulDataLoader"]
__all__ = ["Stateful", "StatefulDataLoader", "StateParserUtil"]
74 changes: 74 additions & 0 deletions torchdata/stateful_dataloader/state_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Union

logger = logging.getLogger(__name__)


class StateParserUtil:
    """
    Utility class that can be used to modify state returned by the dataloader.

    The wrapped ``state_dict`` is stored by reference and edited in place; no
    copy is made. A state dict is treated as a multiprocess state when it has
    a top-level ``"_snapshot"`` key, and as a single process state otherwise.

    Setters perform no validation. Consistency of a multiprocess state is
    checked only when it is read back through ``get_state_dict``.
    """

    def __init__(self, state_dict: Dict[str, Any]):
        """Wrap ``state_dict`` and record whether it is a multiprocess state."""
        self._state_dict = state_dict
        self._is_multiprocess_state = "_snapshot" in self._state_dict

    def fetch_dataset_state(self) -> Dict[int, Any]:
        """
        Return the dataset state stored in the wrapped state dict.

        For a single process state this is the ``"dataset_state"`` entry
        itself. For a multiprocess state it is a newly built dict mapping each
        worker snapshot's ``"worker_id"`` to its ``"dataset_state"``.
        """
        if not self._is_multiprocess_state:
            return self._state_dict["dataset_state"]
        worker_snapshots = self._state_dict["_snapshot"]["_worker_snapshots"]
        return {state["worker_id"]: state["dataset_state"] for state in worker_snapshots.values()}

    def set_last_worker_yielded_id(self, last_worker_yielded: int) -> None:
        """
        Set the last yielded worker id of a multiprocess state.

        The value is stored as given; the range check against the number of
        workers happens in ``get_state_dict``. On a single process state this
        logs a warning and changes nothing.
        """
        if not self._is_multiprocess_state:
            logger.warning("Cannot set last worker yielded id on a single process state dict")
            return
        self._state_dict["_snapshot"]["_last_yielded_worker_id"] = last_worker_yielded

    def set_num_workers(self, num_workers: int) -> None:
        """
        Set the number of workers recorded in a multiprocess state.

        On a single process state this logs a warning and changes nothing.
        """
        if not self._is_multiprocess_state:
            logger.warning("Cannot set num_workers on a single process state dict")
            return
        self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"] = num_workers

    def set_dataset_state(self, dataset_state: Union[Dict[int, Any], Any]) -> None:
        """
        Replace the dataset state stored in the wrapped state dict.

        For a single process state ``dataset_state`` is stored as is. For a
        multiprocess state it must be a dict mapping worker id to that
        worker's dataset state: existing worker snapshots get their
        ``"dataset_state"`` replaced, and unknown worker ids get a new
        snapshot whose ``"fetcher_state"`` is ``None``.
        """
        if not self._is_multiprocess_state:
            self._state_dict["dataset_state"] = dataset_state
            return

        worker_states = self._state_dict["_snapshot"]["_worker_snapshots"]
        for worker_id, state in dataset_state.items():
            worker_key = f"worker_{worker_id}"
            if worker_key in worker_states:
                worker_states[worker_key]["dataset_state"] = state
            else:
                worker_states[worker_key] = {"worker_id": worker_id, "dataset_state": state, "fetcher_state": None}

    def get_state_dict(self) -> Dict[str, Any]:
        """
        Validate and return the (possibly modified) state dict.

        A single process state is returned without checks. For a multiprocess
        state the following must hold:

        a) the number of worker snapshots equals ``_num_workers``
        b) the recorded worker ids are exactly ``0 .. _num_workers - 1``
        c) ``_last_yielded_worker_id`` is smaller than ``_num_workers``

        Raises:
            AssertionError: if any of the checks fails. The exception is raised
                explicitly so validation still runs under ``python -O``.
        """
        if not self._is_multiprocess_state:
            return self._state_dict

        snapshot = self._state_dict["_snapshot"]
        last_yielded_id = snapshot["_last_yielded_worker_id"]
        num_workers = snapshot["_main_snapshot"]["_num_workers"]
        worker_snapshots = snapshot["_worker_snapshots"]

        if len(worker_snapshots) != num_workers:
            raise AssertionError(
                f"Number of worker states {len(worker_snapshots)} should be equal to num_workers setting {num_workers}"
            )

        # Dict keys are always unique, so counting them cannot detect a gap in
        # the worker ids; compare the recorded ids against the expected range.
        found_ids = {state.get("worker_id") for state in worker_snapshots.values()}
        if found_ids != set(range(num_workers)):
            raise AssertionError(
                f"Worker state for all from [0, {num_workers}) should be present. "
                f"Instead found state for only {list(worker_snapshots)} workers"
            )

        if last_yielded_id >= num_workers:
            raise AssertionError("Last yielded id should be strictly within the number of workers")
        return self._state_dict

0 comments on commit cb10f71

Please sign in to comment.