followup changes to allow unsupported datasets
ghstack-source-id: ce288a19c67fccd0751c6fd92ae14a161da8bfa3
Pull Request resolved: #261
tianyu-l committed Apr 24, 2024
1 parent be432e1 commit f65ffdd
Showing 3 changed files with 17 additions and 16 deletions.
5 changes: 0 additions & 5 deletions torchtitan/datasets/__init__.py
@@ -11,8 +11,3 @@
     "build_hf_data_loader",
     "create_tokenizer",
 ]
-
-dataloader_fn = {
-    "c4_mini": build_hf_data_loader,
-    "c4": build_hf_data_loader,
-}
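
With the registry gone, callers import build_hf_data_loader directly instead of looking it up by dataset name; the train.py hunk below shows the real call site. A minimal before/after sketch, assuming a JobConfig-like object named cfg (illustrative):

# Sketch only: `cfg` and "..." stand in for the config object and the
# remaining loader arguments shown in the train.py hunk below.
from torchtitan.datasets import build_hf_data_loader

# Before (removed): every dataset needed a registry entry.
#   data_loader = dataloader_fn[cfg.training.dataset](...)
# After: one loader handles any dataset name.
#   data_loader = build_hf_data_loader(cfg.training.dataset, ...)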
23 changes: 15 additions & 8 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,8 @@ class HuggingFaceDataset(IterableDataset):
         rank (int): rank of the current data parallel process
         infinite (bool): whether to loop infinitely over the dataset
 
-    We currently support the c4 dataset:
+    We currently support the c4 dataset and a subset of it:
+    c4_mini (45K training entries)
     c4 (177M training entries - this dataset is streamed due to the size)
 
     >> c4 (EN) <<:
Expand Down Expand Up @@ -65,6 +66,19 @@ def __init__(
rank: int = 0,
infinite: bool = False,
) -> None:
# allow user to pass in a local path to use unsupported datasets
if dataset_name not in _supported_datasets:
if dataset_path:
logger.warning(
f"Dataset {dataset_name} is not tested or verfied. "
f"Recommended datasets are: {list(_supported_datasets.keys())}."
)
else:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(_supported_datasets.keys())}."
)

# special case to auto-load c4_mini (and any future datasets) from local dir
if dataset_name == "c4_mini":
dataset_path = f"torchtitan/datasets/{dataset_name}"
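
The new check consults _supported_datasets, which is defined elsewhere in hf_datasets.py and not shown in this diff. A plausible sketch of the shape the check assumes (entries are illustrative, not the repo's actual definition):

# Hypothetical registry of verified datasets; only the keys matter to the
# membership test and the messages in the new check above.
_supported_datasets = {
    "c4_mini": "torchtitan/datasets/c4_mini",  # 45K-entry local subset
    "c4": "allenai/c4",                        # 177M entries, streamed
}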
@@ -79,13 +93,6 @@ def __init__(
         logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
         # Setting `streaming=True` works for large dataset, but is slightly
         # slower and unstable.
-        if dataset_name not in _supported_datasets:
-            import warnings
-
-            warnings.warn(
-                f"Dataset {dataset_name} is not tested/verfied. "
-                f"Recommended datasets are: {_supported_datasets.keys()}."
-            )
         if dataset_name == "c4":
             # c4 is huge, and requires both streaming and language selection
             # (we default to en).
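
Taken together, the two hunks replace a late, non-fatal warning with an up-front check that has three outcomes. A hedged sketch of each (the tokenizer tok, the custom dataset name, and the local path are illustrative; constructor arguments not visible in the diff are omitted):

from torchtitan.datasets.hf_datasets import HuggingFaceDataset

# Supported dataset: loads as before, no warning.
ds = HuggingFaceDataset("c4_mini", dataset_path=None, tokenizer=tok)

# Unsupported dataset with a local path: proceeds, logging a warning
# that the dataset is untested.
ds = HuggingFaceDataset("my_corpus", dataset_path="/data/my_corpus", tokenizer=tok)

# Unsupported dataset without a path: raises ValueError listing the
# supported datasets.
ds = HuggingFaceDataset("my_corpus", dataset_path=None, tokenizer=tok)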
5 changes: 2 additions & 3 deletions train.py
@@ -25,7 +25,7 @@
 
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer, dataloader_fn
+from torchtitan.datasets import build_hf_data_loader, create_tokenizer
 from torchtitan.float8_linear import build_fp8_linear
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_scheduler
@@ -137,14 +137,13 @@ def main(job_config: JobConfig):
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
     # build dataloader
-    build_dataloader_fn = dataloader_fn[job_config.training.dataset]
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
         dp_degree = dp_mesh.size()
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
-    data_loader = build_dataloader_fn(
+    data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
         tokenizer,
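
With the registry gone, pointing a run at a custom dataset only requires changing the two training config fields forwarded here. An illustrative expansion of the call for that case (values are examples; the trailing arguments are truncated in the diff above):

data_loader = build_hf_data_loader(
    "my_corpus",        # job_config.training.dataset (unsupported name -> warning)
    "/data/my_corpus",  # job_config.training.dataset_path (local copy)
    tokenizer,
    # ... remaining arguments as in the original call, truncated in the diff
)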
