add data loading option to load from local file system
ghstack-source-id: 3c930054d3b04faf3866048740a2ef887d066dd6
Pull Request resolved: #117
tianyu-l committed Mar 7, 2024
1 parent 6927e45 commit d902a47
Showing 3 changed files with 43 additions and 23 deletions.
8 changes: 8 additions & 0 deletions torchtrain/config_manager.py
@@ -151,6 +151,14 @@ def init_args_from_command_line(
parser.add_argument(
"--training.dataset", type=str, default="alpaca", help="dataset to use"
)
parser.add_argument(
"--training.dataset_path",
type=str,
help=(
"Path to the dataset in the file system. If provided, data will be"
"loaded from this path instead of downloaded.",
),
)
parser.add_argument(
"--training.batch_size", type=int, default=8, help="batch size"
)
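
For reference, here is a minimal standalone sketch of how the new option behaves when parsed (it mirrors only the arguments shown above; the example path is hypothetical). Because the option name contains a dot, the parsed value has to be read back with getattr:

# Minimal sketch, assuming a bare argparse parser that mirrors the options above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.dataset", type=str, default="alpaca", help="dataset to use"
)
parser.add_argument(
    "--training.dataset_path",
    type=str,
    help=(
        "Path to the dataset in the file system. If provided, data will be "
        "loaded from this path instead of downloaded."
    ),
)

args = parser.parse_args(
    ["--training.dataset", "alpaca", "--training.dataset_path", "/data/alpaca"]
)
# argparse keeps the dot in the destination name, so the value is read via getattr.
print(getattr(args, "training.dataset_path"))  # /data/alpaca
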
57 changes: 34 additions & 23 deletions torchtrain/datasets/hf_datasets.py
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import List
from typing import List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
@@ -10,7 +10,7 @@
from torchtrain.logging_utils import rank0_log
from torchtrain.utils import Color

from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from datasets.distributed import split_dataset_by_node

_supported_datasets = {
@@ -20,30 +20,20 @@


class HuggingFaceDataset(IterableDataset):
"""PyTorch Representation of a Dataset from Hugging Face.
We currently support two datasets:
minipile (1M training entries)
alpaca (52K training entries)
>> MiniPile <<:
MiniPile dataset is detailed in the following paper:
https://arxiv.org/abs/2304.08442
"""PyTorch Representation of the HuggingFace Dataset.
Args:
dataset_name (str): name of the dataset to load
dataset_path (Optional[str]): Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded.
tokenizer (TokenizerIf): Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset
Data input format (minipile):
{
"text": "Open-end spinning devices with such rotor bearing arrangements are known in
various different embodiments, and have been extensively described,
for example in German Patent Publications"
}
We currently support two datasets:
alpaca (52K training entries)
minipile (1M training entries)
>> Alpaca <<:
Data input format (alpaca):
@@ -55,8 +45,17 @@ class HuggingFaceDataset(IterableDataset):
Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples", # noqa: B950
}
>> MiniPile <<:
MiniPile dataset is detailed in the paper: https://arxiv.org/abs/2304.08442
Data input format (minipile):
{
"text": "Open-end spinning devices with such rotor bearing arrangements are known in
various different embodiments, and have been extensively described,
for example in German Patent Publications"
}
Example:
>>> alpaca_ds = HuggingFaceDataset(tokenizer=tokenizer)
>>> alpaca_ds = HuggingFaceDataset(dataset_name="alpaca", dataset_path=None, tokenizer=tokenizer)
>>> for batch in DataLoader(alpaca_ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
@@ -65,21 +64,32 @@ class HuggingFaceDataset(IterableDataset):
def __init__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
infinite: bool = False,
) -> None:
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
# Setting `streaming=True` works for large datasets, but the speed is slow.
if dataset_name not in _supported_datasets:
raise ValueError(
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
)

ds = load_dataset(_supported_datasets[dataset_name], split="train")
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if dataset_path:
rank0_log(
f"{Color.green}Loading '{dataset_name}' dataset locally from {dataset_path}...{Color.reset}"
)
ds = load_from_disk(dataset_path)
else:
rank0_log(
f"{Color.green}Downloading '{dataset_name}' dataset from HuggingFace...{Color.reset}"
)
# Setting `streaming=True` works for large datasets, but the speed is slow.
ds = load_dataset(_supported_datasets[dataset_name], split="train")

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
@@ -115,6 +125,7 @@ def __iter__(self):

def build_hf_data_loader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
batch_size: int,
seq_len: int,
@@ -123,7 +134,7 @@ def build_hf_data_loader(
infinite: bool = True,
):
hf_ds = HuggingFaceDataset(
dataset_name, tokenizer, seq_len, world_size, rank, infinite
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
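
Since load_from_disk reads the Hugging Face on-disk (Arrow) format, a natural workflow is to download and save a supported dataset once, then pass the saved directory via --training.dataset_path. A hedged sketch, assuming the _supported_datasets mapping above and a hypothetical local path:

# Sketch: pre-download a supported dataset and save it so later training runs
# can read it locally with load_from_disk (no network access needed).
from datasets import load_dataset, load_from_disk

from torchtrain.datasets.hf_datasets import _supported_datasets

local_path = "/data/datasets/alpaca"  # hypothetical location

# One-time step on a machine with internet access.
ds = load_dataset(_supported_datasets["alpaca"], split="train")
ds.save_to_disk(local_path)

# This mirrors what HuggingFaceDataset does when dataset_path is provided.
ds_local = load_from_disk(local_path)
print(len(ds_local))  # ~52K training entries for alpaca
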
1 change: 1 addition & 0 deletions train.py
@@ -115,6 +115,7 @@ def main(job_config: JobConfig):
dp_degree, dp_rank = 1, 0
data_loader = build_dataloader_fn(
job_config.training.dataset,
job_config.training.dataset_path,
tokenizer,
job_config.training.batch_size,
job_config.training.seq_len,
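
Putting it together, the data loader builder now takes the local path right after the dataset name. A sketch of the updated call, with literal values standing in for the job-config fields (the path and the tokenizer placeholder are assumptions):

# Hedged sketch of the updated build_hf_data_loader call made from train.py.
from torchtrain.datasets.hf_datasets import build_hf_data_loader

tokenizer = ...  # a TokenizerIf instance, constructed elsewhere in train.py

data_loader = build_hf_data_loader(
    "alpaca",        # job_config.training.dataset
    "/data/alpaca",  # job_config.training.dataset_path; None falls back to downloading
    tokenizer,
    8,               # job_config.training.batch_size
    2048,            # job_config.training.seq_len
    1,               # dp_degree
    0,               # dp_rank
)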
