Merge pull request #593 from allenai/shanea/inspect-train-data-no-indices

Inspect training data without data indices
2015aroras authored May 24, 2024
2 parents eafd154 + d137a1f commit 5789cfe
Showing 2 changed files with 137 additions and 14 deletions.
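
In short: build_train_dataloader gains an optional world_size override, and scripts/inspect_train_data.py learns to reconstruct what each rank saw at a given step from a checkpoint's trainer state (its global_train_examples_seen_this_epoch counter) when the run has no saved data-indices files.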
3 changes: 2 additions & 1 deletion olmo/data/__init__.py
@@ -80,7 +80,7 @@ def build_eval_dataloader(
)


def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int] = None) -> DataLoader:
assert train_config.device_train_batch_size is not None
collator = DataCollator(
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
@@ -103,6 +103,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
seed=seed + (train_config.epoch or 0),
shuffle=True,
drop_last=train_config.data.drop_last,
world_size=world_size,
work_dir=work_dir,
),
batch_size=train_config.device_train_batch_size,
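
For orientation, here is a minimal sketch (not part of the diff) of driving the widened build_train_dataloader signature directly, mirroring what the updated script below does: load a run's config, simulate a rank through the RANK environment variable, and pass the world size explicitly. The config path and world size are placeholders.

import os
import tempfile

from olmo.config import TrainConfig
from olmo.data import build_train_dataloader

cfg = TrainConfig.load("path/to/run/config.yaml")  # placeholder path
world_size = 8                                      # placeholder world size
cfg.data.num_workers = 1
cfg.device_train_batch_size = cfg.global_train_batch_size // world_size

with tempfile.TemporaryDirectory() as tmpdir:
    cfg.save_folder = tmpdir   # scratch location for the global data-order file
    os.environ["RANK"] = "0"   # pretend to be rank 0 of the simulated run
    dataloader = build_train_dataloader(cfg, world_size=world_size)
    batch = next(iter(dataloader))  # first device batch rank 0 would have seen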
148 changes: 135 additions & 13 deletions scripts/inspect_train_data.py
@@ -2,28 +2,122 @@
Use this script to inspect the data in given batches from a training run.
"""

import argparse
import gzip
import sys
import os
import tempfile
from pathlib import Path
from typing import Optional
from typing import List, Optional

from cached_path import cached_path

from olmo.checkpoint import load_state_dict
from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset
from olmo.exceptions import OLMoCliError
from olmo.data import build_memmap_dataset, build_train_dataloader
from olmo.data.iterable_dataset import IterableDataset
from olmo.tokenizer import Tokenizer
from olmo.util import clean_opt, prepare_cli_environment
from olmo.util import add_cached_path_clients, clean_opt, prepare_cli_environment


def get_global_train_examples_seen_before_step(step: int, trainer_state: dict, cfg: TrainConfig):
global_step = trainer_state["global_step"]

if global_step >= step:
raise ValueError(f"Step {step} must be after training first step {global_step}")

global_train_examples_seen_this_epoch = trainer_state.get(
"global_train_examples_seen_this_epoch",
trainer_state.get( # for backwards compatibility
"global_train_examples_seen",
trainer_state.get("global_data_step", global_step) * cfg.global_train_batch_size,
),
)

# Subtract 1 from step because we want to be just before that step
global_train_examples_seen_this_epoch += (step - 1 - global_step) * cfg.global_train_batch_size
return global_train_examples_seen_this_epoch
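
To make the offset arithmetic above concrete, a worked example with made-up numbers: a reference checkpoint at global step 1000, a global batch size of 2048, and 2,048,000 examples already seen this epoch; positioning the loader just before step 1500 then skips 499 further batches.

# Illustrative numbers only; nothing here comes from a real run.
global_step = 1000
global_train_batch_size = 2048
examples_seen_at_checkpoint = 2_048_000

target_step = 1500
examples_before_target = (
    examples_seen_at_checkpoint + (target_step - 1 - global_step) * global_train_batch_size
)
assert examples_before_target == 3_069_952  # 2_048_000 + 499 * 2048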


def main(save_folder: Path, *steps: int, rank: Optional[int] = None):
def inspect_data_without_device_data_indices(
run_path: str, *steps: int, world_size: int, ranks: List[int], reference_step: int
):
cfg = TrainConfig.load(
cached_path(os.path.join(run_path, f"step{reference_step}/config.yaml")),
overrides=[clean_opt("--evaluators=[]"), clean_opt("--save_overwrite")],
)
cfg.data.num_workers = 1

if cfg.global_train_batch_size % world_size != 0:
raise ValueError(f"World size must divide global_train_batch_size {cfg.global_train_batch_size}")
cfg.device_train_batch_size = cfg.global_train_batch_size // world_size

try:
trainer_state = load_state_dict(run_path, f"step{reference_step}/train/rank0.pt", map_location="cpu")
except FileNotFoundError:
try:
# Unsharded checkpointing
trainer_state = load_state_dict(run_path, f"step{reference_step}/train.pt", map_location="cpu")
except FileNotFoundError:
# Legacy checkpointing
trainer_state = load_state_dict(run_path, f"step{reference_step}/rank0.pt", map_location="cpu")

tokenizer = Tokenizer.from_train_config(cfg)

with tempfile.TemporaryDirectory() as tmpdir:
cfg.save_folder = tmpdir

# Build dataloader in rank 0 to generate the indices file
os.environ["RANK"] = "0"
dataloader = build_train_dataloader(cfg, world_size=world_size)

for rank in ranks:
os.environ["RANK"] = str(rank)
# Set FS_LOCAL_RANK to a non-zero number so that global data indices are not rewritten
os.environ["FS_LOCAL_RANK"] = "1"

for step in steps:
# With the current implementation, this does not rebuild the global data indices if the FS local rank is non-zero
dataloader = build_train_dataloader(cfg, world_size=world_size)
assert isinstance(dataloader.dataset, IterableDataset)
dataloader.dataset.start_index = get_global_train_examples_seen_before_step(
step, trainer_state, cfg
)
batch = next(iter(dataloader))
for i, batch_entry in enumerate(batch["input_ids"].tolist()):
example = tokenizer.decode(batch_entry)
print(f'[step={step}, rank={rank}, example={i}] "{example}"\n')
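
The division of labor in the loop above follows from the comments: the initial build under RANK=0 is what writes the global data order to the working directory, while the later per-rank rebuilds run with FS_LOCAL_RANK=1 so they only read that order back; every simulated rank therefore iterates the same shuffle and decodes the device batch its RANK would have received.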


def main(
run_path: str,
*steps: int,
world_size: Optional[int] = None,
rank: Optional[int] = None,
reference_step: Optional[int] = None,
use_data_indices: bool = True,
):
save_folder = Path(run_path)
if not use_data_indices or not (save_folder / "data-indices").is_dir():
assert world_size is not None
assert reference_step is not None
ranks = [rank] if rank is not None else list(range(world_size))
inspect_data_without_device_data_indices(
run_path, *steps, world_size=world_size, ranks=ranks, reference_step=reference_step
)
return

cfg = TrainConfig.load(save_folder / "config.yaml", overrides=[clean_opt("--evaluators=[]")])
dataset = build_memmap_dataset(cfg, cfg.data)
tokenizer = Tokenizer.from_train_config(cfg)

if rank is None:
world_size = len(list((save_folder / "data-indices").glob("*.tsv.gz")))
num_indices_files = len(list((save_folder / "data-indices").glob("*.tsv.gz")))
if world_size is not None and world_size != num_indices_files:
raise ValueError(f"World size {world_size} does not match number of indices files {num_indices_files}")

indices_files = {
rank: gzip.open(save_folder / "data-indices" / f"rank{rank}.tsv.gz", "rt")
for rank in range(world_size)
for rank in range(num_indices_files)
}
else:
indices_files = {rank: gzip.open(save_folder / "data-indices" / f"rank{rank}.tsv.gz", "rt")}
@@ -48,9 +142,37 @@ def main(save_folder: Path, *steps: int, rank: Optional[int] = None):
if __name__ == "__main__":
prepare_cli_environment()

try:
save_folder, rank, steps = sys.argv[1], int(sys.argv[2]), [int(i) for i in sys.argv[3:]]
except (IndexError, ValueError):
raise OLMoCliError(f"Usage: {sys.argv[0]} [SAVE_FOLDER] [RANK] [STEP_NUMBER...]")
add_cached_path_clients()

parser = argparse.ArgumentParser()

parser.add_argument("run_path", help="Path to run of which you want to inspect training data")
parser.add_argument(
"rank",
type=int,
help="Device rank for which you want to see training data. Set to `-1` to get all ranks.",
)
parser.add_argument("steps", nargs="+", type=int, help="Steps of run for which you want to see training data")
parser.add_argument(
"--no_data_indices",
action="store_false",
dest="use_data_indices",
help="If set, this script acts as if data indices are not present.",
)
parser.add_argument(
"--checkpoint_num",
type=int,
help="Step number of checkpoint from which training state is to be obtained. Required when data indices are not present.",
)
parser.add_argument("--world_size", type=int, help="World size. Required when data indices are not present.")

args = parser.parse_args()

main(Path(save_folder), *steps, rank=rank if rank >= 0 else None)
main(
args.run_path,
*args.steps,
world_size=args.world_size,
rank=args.rank if args.rank >= 0 else None,
    reference_step=args.checkpoint_num if args.checkpoint_num is not None and args.checkpoint_num >= 0 else None,
use_data_indices=args.use_data_indices,
)
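
Putting the new CLI surface together, a hypothetical invocation of the no-data-indices path might look like the call below (run path, world size, reference step, and step numbers are all placeholders); the roughly equivalent command line is shown in the comment.

# Roughly equivalent to:
#   python scripts/inspect_train_data.py s3://my-bucket/my-run -1 1000 2000 \
#       --no_data_indices --world_size 64 --checkpoint_num 500
# i.e. decode what every rank saw at steps 1000 and 2000, reconstructing the
# data order from the trainer state saved at step 500.
main(
    "s3://my-bucket/my-run",
    1000,
    2000,
    world_size=64,          # must divide the run's global_train_batch_size
    rank=None,              # None means "all ranks", like passing -1 on the command line
    reference_step=500,
    use_data_indices=False,
)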
