From 3fdbe64304990909edc05ee27c3d8d8ec2cca044 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 6 Mar 2024 19:38:25 -0800
Subject: [PATCH] add data loading option to load from local file system

ghstack-source-id: 65ce0eb749f205d98beffad37244c1b87cb1e0ae
Pull Request resolved: https://github.com/pytorch/torchtrain/pull/117
---
 torchtrain/config_manager.py       |  8 +++++
 torchtrain/datasets/hf_datasets.py | 50 +++++++++++++++++-------------
 train.py                           |  1 +
 3 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index 1c44a678..57ae25f3 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -151,6 +151,14 @@ def init_args_from_command_line(
     parser.add_argument(
         "--training.dataset", type=str, default="alpaca", help="dataset to use"
     )
+    parser.add_argument(
+        "--training.dataset_path",
+        type=str,
+        help=(
+            "Path to the dataset in the file system. If provided, data will be "
+            "loaded from this path instead of downloaded."
+        ),
+    )
     parser.add_argument(
         "--training.batch_size", type=int, default=8, help="batch size"
     )
diff --git a/torchtrain/datasets/hf_datasets.py b/torchtrain/datasets/hf_datasets.py
index 1ce96f49..ceeb1a9c 100644
--- a/torchtrain/datasets/hf_datasets.py
+++ b/torchtrain/datasets/hf_datasets.py
@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch.utils.data import DataLoader, IterableDataset
@@ -10,7 +10,7 @@
 from torchtrain.logging_utils import rank0_log
 from torchtrain.utils import Color
 
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from datasets.distributed import split_dataset_by_node
 
 _supported_datasets = {
@@ -20,15 +20,7 @@
 
 
 class HuggingFaceDataset(IterableDataset):
-    """PyTorch Representation of a Dataset from Hugging Face.
-
-    We currently support two datasets:
-    minipile (1M training entries)
-    alpaca (52K training entries)
-
-    >> MiniPile <<:
-    MiniPile dataset is detailed in the following paper:
-    https://arxiv.org/abs/2304.08442
+    """PyTorch Representation of the HuggingFace Dataset.
 
     Args:
         dataset_name (str): name of the dataset to load
@@ -38,12 +30,9 @@ class HuggingFaceDataset(IterableDataset):
         rank (int): rank of the current data parallel process
         infinite (bool): whether to loop infinitely over the dataset
 
-    Data input format (minipile):
-    {
-        "text": "Open-end spinning devices with such rotor bearing arrangements are known in
-                various different embodiments, and have been extensively described,
-                for example in German Patent Publications"
-    }
+    We currently support two datasets:
+    minipile (1M training entries)
+    alpaca (52K training entries)
 
     >> Alpaca <<:
     Data input format (alpaca):
@@ -55,6 +44,15 @@ class HuggingFaceDataset(IterableDataset):
             Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",  # noqa: B950
     }
 
+    >> MiniPile <<:
+    MiniPile dataset is detailed in the following paper: https://arxiv.org/abs/2304.08442
+    Data input format (minipile):
+    {
+        "text": "Open-end spinning devices with such rotor bearing arrangements are known in
+                various different embodiments, and have been extensively described,
+                for example in German Patent Publications"
+    }
+
     Example:
     >>> alpaca_ds = HuggingFaceDataset(tokenizer=tokenizer)
     >>> for batch in Dataloader(alpaca_ds, batch_size=8):
@@ -65,21 +63,28 @@ class HuggingFaceDataset(IterableDataset):
 
     def __init__(
         self,
         dataset_name: str,
+        dataset_path: Optional[str],
         tokenizer: TokenizerIf,
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
-        # TODO: This is a temporary solution for small datasets like Alpaca.
-        # For larger datasets we need to use a more scalable approach.
-        # Setting `streaming=True` works for large dataset, but the speed is slow.
         if dataset_name not in _supported_datasets:
             raise ValueError(
                 f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
             )
-        ds = load_dataset(_supported_datasets[dataset_name], split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        # For larger datasets we need to use a more scalable approach.
+        if dataset_path:
+            rank0_log(f"{Color.green}Loading '{dataset_name}' dataset locally from {dataset_path}...{Color.reset}")
+            ds = load_from_disk(dataset_path)
+        else:
+            rank0_log(f"{Color.green}Downloading '{dataset_name}' dataset from HuggingFace...{Color.reset}")
+            # Setting `streaming=True` works for large dataset, but the speed is slow.
+            ds = load_dataset(_supported_datasets[dataset_name], split="train")
+
         self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
@@ -115,6 +120,7 @@ def __iter__(self):
 
 def build_hf_data_loader(
     dataset_name: str,
+    dataset_path: Optional[str],
     tokenizer: TokenizerIf,
     batch_size: int,
     seq_len: int,
@@ -123,7 +129,7 @@ def build_hf_data_loader(
     infinite: bool = True,
 ):
     hf_ds = HuggingFaceDataset(
-        dataset_name, tokenizer, seq_len, world_size, rank, infinite
+        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
     )
 
     return DataLoader(hf_ds, batch_size=batch_size)
diff --git a/train.py b/train.py
index 5de57b4e..15f0e493 100644
--- a/train.py
+++ b/train.py
@@ -115,6 +115,7 @@ def main(job_config: JobConfig):
         dp_degree, dp_rank = 1, 0
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
+        job_config.training.dataset_path,
         tokenizer,
         job_config.training.batch_size,
         job_config.training.seq_len,
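
Usage note (not part of the patch): the value given via --training.dataset_path is passed to
datasets.load_from_disk(), so it must point at a dataset directory previously written with
save_to_disk(), not at raw text or JSON files. Below is a minimal sketch of how such a
directory could be prepared once; the Hub id "tatsu-lab/alpaca" and the output path
"/data/alpaca" are illustrative assumptions, not taken from this patch (the actual
name-to-Hub-id mapping lives in _supported_datasets).

    # prepare_alpaca_locally.py -- hypothetical one-time helper, run on a host with network access
    from datasets import load_dataset

    # Download the training split once (assumed Hub id for the "alpaca" entry).
    ds = load_dataset("tatsu-lab/alpaca", split="train")
    # Write it in the on-disk format that datasets.load_from_disk() expects.
    ds.save_to_disk("/data/alpaca")

A later training run could then pass, e.g.,
    --training.dataset alpaca --training.dataset_path /data/alpaca
which takes the load_from_disk() branch added in hf_datasets.py instead of downloading the data.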