Additional Datasets for Finetuning (#1803)
* wip support for additional datasets

* support for splits and load_dataset args

* clean up

* c4 and op working with splits

* load less data, run faster
Sara Adkins authored Nov 1, 2023
1 parent 1f350a8 commit ab69c11
Showing 9 changed files with 181 additions and 49 deletions.
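Taken together, the changes below register C4 and Open-Platypus datasets, add per-split dataset loading, and forward extra keyword arguments to datasets.load_dataset. As a minimal sketch of driving the new options, mirroring test_trainer.py at the bottom of this diff (the exact argument values are illustrative, not prescribed):

from sparseml.transformers.finetune.text_generation import main

main(
    model_name_or_path="./obcq_deployment",
    dataset_name="c4",
    do_train=True,
    output_dir="./output_finetune",
    # New in this commit: per-split slice strings and extra load_dataset kwargs.
    splits={"train": "train[:5%]"},
    raw_kwargs={"data_files": {"train": "en/c4-train.00000-of-01024.json.gz"}},
)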
2 changes: 2 additions & 0 deletions src/sparseml/transformers/finetune/data/__init__.py
@@ -15,4 +15,6 @@
# flake8: noqa

from .base import TextGenerationDataset
from .c4 import C4Dataset
from .open_platypus import OpenPlatypusDataset
from .wikitext import WikiTextDataset
36 changes: 12 additions & 24 deletions src/sparseml/transformers/finetune/data/base.py
@@ -22,16 +22,24 @@


class TextGenerationDataset(RegistryMixin):
def __init__(self, text_column, data_args, tokenizer):
def __init__(self, text_column, data_args, split, tokenizer):

self.text_column = text_column
self.tokenizer = tokenizer
self.data_args = data_args
self.raw_kwargs = data_args.raw_kwargs or {}
self.split = split

if data_args.concatenate_data:
self.padding = False
elif data_args.pad_to_max_length:
self.padding = "max_length"
else:
self.padding = False

if self.padding:
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

max_seq_length = data_args.max_seq_length
if max_seq_length > tokenizer.model_max_length:
@@ -43,7 +51,9 @@ def __init__(self, text_column, data_args, tokenizer):
self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

def get_raw_dataset(self, cache_dir):
return get_raw_dataset(self.data_args, cache_dir)
return get_raw_dataset(
self.data_args, cache_dir, split=self.split, **self.raw_kwargs
)

def tokenize_and_process(self, raw_dataset):
def tokenize_fn(data):
@@ -99,25 +109,3 @@ def label_fn(data):
)

return dataset

def make_dataset_splits(self, tokenized_dataset, do_train, do_eval, do_predict):
train_split = eval_split = predict_split = None
if do_train:
if "train" not in tokenized_dataset:
raise ValueError("--do_train requires a train dataset")
train_split = tokenized_dataset["train"]
if do_eval:
if "validation" not in tokenized_dataset:
raise ValueError("--do_eval requires a validation dataset")
eval_split = tokenized_dataset["validation"]
if do_predict:
if "validation" not in tokenized_dataset:
raise ValueError("--do_predict requires a test dataset")
predict_split = tokenized_dataset["test"]

split_datasets = {
"train": train_split,
"validation": eval_split,
"test": predict_split,
}
return split_datasets
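The base class now also resolves a padding strategy (no padding when concatenating, "max_length" when pad_to_max_length is set) and falls back to the EOS token as pad token when padding is enabled. The body of tokenize_and_process is collapsed in this view; as a rough sketch (assumed, not the file's exact code), the nested tokenize_fn would typically apply the resolved settings like so:

def tokenize_fn(data):
    # Sketch only: self.padding and self.max_seq_length come from __init__ above;
    # sequences longer than max_seq_length are truncated.
    return self.tokenizer(
        data[self.text_column],
        padding=self.padding,
        max_length=self.max_seq_length,
        truncation=True,
    )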
27 changes: 27 additions & 0 deletions src/sparseml/transformers/finetune/data/c4.py
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from sparseml.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="c4")
class C4Dataset(TextGenerationDataset):
def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "allenai/c4"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)
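C4Dataset only pins the Hugging Face dataset name and the text column; everything else comes from the base class. A minimal sketch of wiring up another corpus through the same registry (the "my_corpus" name and dataset id are hypothetical):

from copy import deepcopy

from sparseml.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="my_corpus")  # hypothetical registry name
class MyCorpusDataset(TextGenerationDataset):
    def __init__(self, data_args, split, tokenizer):
        data_args = deepcopy(data_args)
        data_args.dataset_name = "my-org/my-corpus"  # hypothetical HF dataset id
        super().__init__(
            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
        )

The training entrypoint would then pick it up via TextGenerationDataset.load_from_registry("my_corpus", ...), mirroring the loop added in text_generation.py below.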
18 changes: 9 additions & 9 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional
from typing import Dict, Optional


@dataclass
@@ -50,6 +50,14 @@ class DataTrainingArguments:
"help": "Whether or not to concatenate datapoints to fill max_seq_length"
},
)
raw_kwargs: Optional[Dict] = field(
default=None,
metadata={"help": "Additional keyboard args to pass to datasets load_data"},
)
splits: Optional[Dict] = field(
default=None,
metadata={"help": "Optional percentages of each split to download"},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached preprocessed datasets or not."},
@@ -89,11 +97,3 @@ class DataTrainingArguments:
),
},
)
eval_on_test: bool = field(
default=False,
metadata={"help": "Evaluate the test dataset."},
)
num_export_samples: int = field(
default=0,
metadata={"help": "Number of samples (inputs/outputs) to export during eval."},
)
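Both new fields are plain dictionaries. A sketch of the values they are meant to hold, based on the test script at the end of this diff (the validation slice is only an illustration of the datasets slice syntax):

# Forwarded to datasets.load_dataset(), e.g. to restrict which files are pulled.
raw_kwargs = {"data_files": {"train": "en/c4-train.00000-of-01024.json.gz"}}

# Maps a split name to a datasets split/slice string; None means load everything.
splits = {
    "train": "train[:5%]",
    "validation": "validation[:1%]",  # illustrative only
}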
29 changes: 29 additions & 0 deletions src/sparseml/transformers/finetune/data/helpers.py
@@ -15,6 +15,9 @@
from datasets import Dataset, load_dataset


__all__ = ["get_raw_dataset", "make_dataset_splits"]


def get_raw_dataset(data_args, cache_dir: str, **kwargs) -> Dataset:
raw_datasets = load_dataset(
data_args.dataset_name,
@@ -24,3 +27,29 @@ def get_raw_dataset(data_args, cache_dir: str, **kwargs) -> Dataset:
)

return raw_datasets


def make_dataset_splits(tokenized_datasets, do_train, do_eval, do_predict):
if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
tokenized_datasets = tokenized_datasets.get("all")

train_split = eval_split = predict_split = None
if do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_split = tokenized_datasets["train"]
if do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_split = tokenized_datasets["validation"]
if do_predict:
if "validation" not in tokenized_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_split = tokenized_datasets["test"]

split_datasets = {
"train": train_split,
"validation": eval_split,
"test": predict_split,
}
return split_datasets
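A usage sketch for the relocated helper: when no explicit splits were configured, the loaders place a single DatasetDict under the "all" key, which make_dataset_splits unwraps before checking for the requested splits (dataset names below are illustrative, and tokenization is omitted; real callers pass already-tokenized datasets):

from datasets import load_dataset

from sparseml.transformers.finetune.data.helpers import make_dataset_splits

# Case 1: no explicit splits requested, so everything sits under "all".
tokenized = {"all": load_dataset("wikitext", "wikitext-2-raw-v1")}
ds = make_dataset_splits(tokenized, do_train=True, do_eval=True, do_predict=False)

# Case 2: splits were loaded individually and keyed by name.
tokenized = {"train": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")}
ds = make_dataset_splits(tokenized, do_train=True, do_eval=False, do_predict=False)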
66 changes: 66 additions & 0 deletions src/sparseml/transformers/finetune/data/open_platypus.py
@@ -0,0 +1,66 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.helpers import get_raw_dataset


@TextGenerationDataset.register(name="open_platypus")
class OpenPlatypusDataset(TextGenerationDataset):
ALPACA_TEMPLATE = {
"prompt_input": "Below is an instruction that describes a task, paired with an "
"input that provides further context. Write a response that appropriately "
"completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n"
"{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a "
"response that appropriately completes the request.\n\n### Instruction:\n{"
"instruction}\n\n### Response:\n",
}

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset_name = "garage-bAInd/Open-Platypus"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)

def get_raw_dataset(self, cache_dir):
raw_dataset = get_raw_dataset(
self.data_args, cache_dir, split=self.split, **self.raw_kwargs
)

def restructure_fn(sample):
if "input" in sample:
sample["text"] = self.ALPACA_TEMPLATE["prompt_input"].format(
instruction=sample["instruction"], input=sample["input"]
)
else:
sample["text"] = self.ALPACA_TEMPLATE["prompt_no_input"].format(
instruction=sample["instruction"]
)

if "output" in sample:
sample["text"] += sample["output"]
return sample

raw_dataset = raw_dataset.map(
restructure_fn,
batched=False,
remove_columns=["input", "output", "instruction", "data_source"],
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Restructuring Platypus Dataset",
)
return raw_dataset
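For reference, a sketch of what restructure_fn produces for a single record (the instruction/input/output values are made up; real rows come from garage-bAInd/Open-Platypus):

from sparseml.transformers.finetune.data import OpenPlatypusDataset

sample = {
    "instruction": "Summarize the passage.",  # illustrative values
    "input": "SparseML adds dataset registries for finetuning.",
    "output": "It describes new finetuning dataset support.",
}

# Alpaca-style prompt followed by the expected response; this becomes the single
# "text" column kept for language-model finetuning.
text = OpenPlatypusDataset.ALPACA_TEMPLATE["prompt_input"].format(
    instruction=sample["instruction"], input=sample["input"]
) + sample["output"]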
6 changes: 4 additions & 2 deletions src/sparseml/transformers/finetune/data/wikitext.py
@@ -17,5 +17,7 @@

@TextGenerationDataset.register(name="wikitext")
class WikiTextDataset(TextGenerationDataset):
def __init__(self, data_args, tokenizer):
super().__init__(text_column="text", data_args=data_args, tokenizer=tokenizer)
def __init__(self, data_args, split, tokenizer):
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)
32 changes: 22 additions & 10 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -37,6 +37,7 @@
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.helpers import make_dataset_splits
from sparseml.transformers.finetune.helpers import apply_recipe_structure_to_model
from sparseml.transformers.finetune.model_args import ModelArguments
from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src
@@ -152,15 +153,26 @@ def main(**kwargs):

# Load datasets
# TODO: will any of this cause problems with FSDP?
do_eval = training_args.do_eval or data_args.num_export_samples > 0
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset_name, data_args=data_args, tokenizer=tokenizer
)
raw_dataset = dataset_manager.get_raw_dataset(model_args.cache_dir)

tokenized_datasets = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets = dataset_manager.make_dataset_splits(
tokenized_datasets, training_args.do_train, do_eval, training_args.do_predict
splits = data_args.splits
tokenized_datasets = {}
if data_args.splits is None:
splits = {"all": None}
for split_name, split_str in splits.items():
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset_name,
data_args=data_args,
split=split_str,
tokenizer=tokenizer,
)
raw_dataset = dataset_manager.get_raw_dataset(model_args.cache_dir)
tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets[split_name] = tokenized_dataset

tokenized_datasets = make_dataset_splits(
tokenized_datasets,
training_args.do_train,
training_args.do_eval,
training_args.do_predict,
)
train_dataset = tokenized_datasets.get("train")
eval_dataset = tokenized_datasets.get("validation")
@@ -190,7 +202,7 @@ def main(**kwargs):
args=training_args,
data_args=data_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if do_eval else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
14 changes: 10 additions & 4 deletions test_trainer.py
@@ -2,15 +2,19 @@ def run():
from sparseml.transformers.finetune.text_generation import main

model = "./obcq_deployment"
dataset_name = "wikitext"
dataset_config_name = "wikitext-2-raw-v1"
concatenate_data = True
dataset_name = "c4"
dataset_config_name = "allenai--c4"
concatenate_data = None
do_train = True
do_eval = False
output_dir = "./output_finetune"
recipe = "test_trainer_recipe.yaml"
    num_train_epochs = 1
overwrite_output_dir = True
raw_kwargs = {"data_files": {"train": "en/c4-train.00000-of-01024.json.gz"}}
splits = {
"train": "train[:5%]",
}

main(
model_name_or_path=model,
@@ -22,7 +26,9 @@ def run():
recipe=recipe,
num_train_epochs=num_train_epochs,
overwrite_output_dir=overwrite_output_dir,
concatenate_data = concatenate_data
        concatenate_data=concatenate_data,
        splits=splits,
raw_kwargs=raw_kwargs
)

if __name__ == "__main__":
