followup changes to allow unsupported datasets #261

Merged 2 commits on Apr 24, 2024
5 changes: 0 additions & 5 deletions torchtitan/datasets/__init__.py
@@ -11,8 +11,3 @@
"build_hf_data_loader",
"create_tokenizer",
]

dataloader_fn = {
"c4_mini": build_hf_data_loader,
"c4": build_hf_data_loader,
}
23 changes: 15 additions & 8 deletions torchtitan/datasets/hf_datasets.py
@@ -36,7 +36,8 @@ class HuggingFaceDataset(IterableDataset):
         rank (int): rank of the current data parallel process
         infinite (bool): whether to loop infinitely over the dataset
 
-    We currently support the c4 dataset:
+    We currently support the c4 dataset and a subset of it:
+    c4_mini (45K training entries)
     c4 (177M training entries - this dataset is streamed due to the size)
 
     >> c4 (EN) <<:
@@ -65,6 +66,19 @@ def __init__(
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
+        # allow user to pass in a local path to use unsupported datasets
+        if dataset_name not in _supported_datasets:
+            if dataset_path:
+                logger.warning(
+                    f"Dataset {dataset_name} is not tested or verified. "
+                    f"Recommended datasets are: {list(_supported_datasets.keys())}."
+                )
+            else:
+                raise ValueError(
+                    f"Dataset {dataset_name} is not supported. "
+                    f"Supported datasets are: {list(_supported_datasets.keys())}."
+                )
@fegin (Contributor) commented on Apr 24, 2024:

Should we also error out if users pass both a supported dataset_name and dataset_path?

The author replied:

Yeah, I think this was a decision discussed long ago. The point is that for any dataset we'd like to offer two ways to use the data: one is to download it from the HF hub, the other is to use local files if a path is provided. In the latter case, the user still needs to be clear about the correspondence between dataset name and dataset path. Let's keep the current behavior to make it less error-prone.
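To make the agreed behavior concrete, here is a self-contained sketch of the check this PR adds. `validate_dataset` is a hypothetical helper name, and the contents of `_supported_datasets` are inferred from the diff rather than copied from the actual registry:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Assumed registry contents, inferred from the diff: c4_mini ships with the
# repo, while c4 is fetched from the HF hub.
_supported_datasets = {
    "c4_mini": "torchtitan/datasets/c4_mini",
    "c4": "allenai/c4",
}


def validate_dataset(dataset_name: str, dataset_path: Optional[str]) -> None:
    """Mirror of the __init__ check above: an unverified name with a local
    path only warns; an unverified name without a path is a hard error."""
    if dataset_name not in _supported_datasets:
        if dataset_path:
            logger.warning(
                f"Dataset {dataset_name} is not tested or verified. "
                f"Recommended datasets are: {list(_supported_datasets.keys())}."
            )
        else:
            raise ValueError(
                f"Dataset {dataset_name} is not supported. "
                f"Supported datasets are: {list(_supported_datasets.keys())}."
            )


validate_dataset("c4", None)                      # supported name: passes silently
validate_dataset("my_corpus", "/data/my_corpus")  # unverified, has a path: warns
# validate_dataset("my_corpus", None)             # unverified, no path: ValueError
```

Note that a supported name combined with a path is deliberately not an error: the local-files route stays available for every dataset, as long as the name and path agree.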


+        # special case to auto-load c4_mini (and any future datasets) from local dir
+        if dataset_name == "c4_mini":
+            dataset_path = f"torchtitan/datasets/{dataset_name}"
@@ -79,13 +93,6 @@ def __init__(
         logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
         # Setting `streaming=True` works for large datasets, but is slightly
         # slower and unstable.
-        if dataset_name not in _supported_datasets:
-            import warnings
-
-            warnings.warn(
-                f"Dataset {dataset_name} is not tested/verfied. "
-                f"Recommended datasets are: {_supported_datasets.keys()}."
-            )
         if dataset_name == "c4":
             # c4 is huge, and requires both streaming and language selection
             # (we default to en).
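For reference, the c4 branch above most likely reduces to a single streaming call into the `datasets` library. A minimal sketch using the standard `load_dataset` API (the `allenai/c4` hub path and the `en` config name are assumptions, not taken from this diff):

```python
from datasets import load_dataset

# Streaming avoids materializing all ~177M training entries up front;
# name="en" performs the language selection mentioned in the comment above.
ds = load_dataset("allenai/c4", name="en", split="train", streaming=True)
print(next(iter(ds))["text"][:80])  # peek at the first streamed sample
```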
5 changes: 2 additions & 3 deletions train.py
@@ -25,7 +25,7 @@

 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer, dataloader_fn
+from torchtitan.datasets import build_hf_data_loader, create_tokenizer
 from torchtitan.float8_linear import build_fp8_linear
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_scheduler
@@ -137,14 +137,13 @@ def main(job_config: JobConfig):
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
     # build dataloader
-    build_dataloader_fn = dataloader_fn[job_config.training.dataset]
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
         dp_degree = dp_mesh.size()
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
-    data_loader = build_dataloader_fn(
+    data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
         tokenizer,
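The call above is truncated by the diff view. A plausible completion is sketched below; every argument after `tokenizer` is hypothesized from the surrounding code (the batch size and sequence length from the job config, plus the `dp_degree`/`dp_rank` computed just above), not taken from the diff:

```python
# Hypothetical completion of the truncated call; the arguments after
# `tokenizer` are assumptions, not part of this PR's diff.
data_loader = build_hf_data_loader(
    job_config.training.dataset,       # dataset name, e.g. "c4"
    job_config.training.dataset_path,  # optional local path for unverified datasets
    tokenizer,
    job_config.training.batch_size,    # assumed config field
    job_config.training.seq_len,       # assumed config field
    dp_degree,                         # data-parallel world size from above
    dp_rank,                           # data-parallel rank from above
)
```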