Rank and files #126
Merged (4 commits), Jun 6, 2023
Changes from 3 commits
qlora.py: 55 changes (33 additions, 22 deletions)

@@ -20,9 +20,9 @@
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer
@@ -202,17 +202,17 @@ class GenerationArguments:
    num_beams: Optional[int] = field(default=1)
    num_beam_groups: Optional[int] = field(default=1)
    penalty_alpha: Optional[float] = field(default=None)
    use_cache: Optional[bool] = field(default=True)

    # Hyperparameters for logit manipulation
    temperature: Optional[float] = field(default=1.0)
    top_k: Optional[int] = field(default=50)
    top_p: Optional[float] = field(default=1.0)
    typical_p: Optional[float] = field(default=1.0)
    diversity_penalty: Optional[float] = field(default=0.0)
    repetition_penalty: Optional[float] = field(default=1.0)
    length_penalty: Optional[float] = field(default=1.0)
    no_repeat_ngram_size: Optional[int] = field(default=0)

def find_all_linear_names(args, model):
    cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
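Stepping out of the diff for a moment: generation fields like those in GenerationArguments above typically reach model.generate() through a transformers GenerationConfig. The sketch below shows that wiring with illustrative values; the exact plumbing is assumed here, not shown in this diff.

from transformers import GenerationConfig

# Illustrative only: the dataclass fields above map one-to-one onto
# GenerationConfig arguments, which model.generate() then consumes.
generation_config = GenerationConfig(
    max_new_tokens=64,          # assumed value, not part of the fields shown here
    num_beams=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    typical_p=1.0,
    repetition_penalty=1.0,
    length_penalty=1.0,
    no_repeat_ngram_size=0,
    use_cache=True,
)
# outputs = model.generate(**inputs, generation_config=generation_config)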
@@ -260,6 +260,14 @@ def get_accelerate_model(args, checkpoint_dir):
    n_gpus = torch.cuda.device_count()
    max_memory = f'{args.max_memory_MB}MB'
    max_memory = {i: max_memory for i in range(n_gpus)}
+    device_map = "auto"
+
+    # if we are in a distributed setting, we need to set the device map and max memory per device
+    if os.environ.get('LOCAL_RANK') is not None:
+        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
+        device_map = {'': local_rank}
+        max_memory = {'': max_memory[local_rank]}

    if args.full_finetune: assert args.bits in [16, 32]
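A note on the {'': local_rank} convention used in the added block: a device map whose only key is the empty string places the entire model on that single device, which is what each process wants in a data-parallel launch. Below is a small standalone sketch of the same placement logic; the GPU count and memory figure are assumed, not taken from the PR.

import os

# Standalone illustration of the per-rank placement above. Assumes two visible
# GPUs with roughly 24 GB each; launchers such as torchrun export LOCAL_RANK.
n_gpus = 2
max_memory = {i: '24000MB' for i in range(n_gpus)}
device_map = 'auto'

if os.environ.get('LOCAL_RANK') is not None:
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    device_map = {'': local_rank}              # whole model on this rank's GPU
    max_memory = {'': max_memory[local_rank]}

print(device_map, max_memory)  # on rank 1: {'': 1} {'': '24000MB'}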

@@ -270,7 +278,7 @@ def get_accelerate_model(args, checkpoint_dir):
        cache_dir=args.cache_dir,
        load_in_4bit=args.bits == 4,
        load_in_8bit=args.bits == 8,
-        device_map='auto',
+        device_map=device_map,
        max_memory=max_memory,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=args.bits == 4,
@@ -396,9 +404,9 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        )
        # Build the input and labels for causal LM
        input_ids = []
        labels = []
        for tokenized_source, tokenized_target in zip(
            tokenized_sources_with_prompt['input_ids'],
            tokenized_targets['input_ids']
        ):
            if not self.predict_with_generate:
@@ -463,14 +471,14 @@ def local_dataset(dataset_name):
    if dataset_name.endswith('.json'):
        full_dataset = Dataset.from_json(path_or_paths=dataset_name)
    elif dataset_name.endswith('.jsonl'):
-        full_dataset = Dataset.from_json(filename=dataset_name, format='jsonlines')
+        full_dataset = Dataset.from_pandas(pd.read_json(dataset_name, lines=True))
Contributor: This change would load the full dataset in memory. Is that intentional?

Owner: Typically, NLP datasets are small enough to fit in memory, so this should be fine in most cases. However, I am unaware of the benefits of using pandas vs. HF Datasets for loading and have not benchmarked the two libraries. Could you provide some more details? Otherwise, I lean towards using the HF Datasets method.

Contributor (PR author): I think using the datasets library directly would be better; the previous code didn't work, though, so I fixed it the way I knew how. You are right that it would be better to just correct the syntax.

Contributor (@lhoestq, Jun 5, 2023): HF Datasets converts the data to an Arrow file and memory-maps it from disk. This gives high speed while keeping RAM usage to a minimum. It is also useful in distributed setups because the memory-mapped file can be treated as shared memory across processes, so there is no need to copy the data to each process.

Contributor (PR author): I'd switch it, but I'm traveling for the rest of the week.

artidoro marked this conversation as resolved.
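For reference, the HF Datasets route the thread converges on could look roughly like the sketch below (an assumption, not the code merged in this PR). The 'json' builder reads both regular JSON and JSON Lines files and backs the dataset with an on-disk Arrow file, so it stays memory-mapped rather than going through pandas.

from datasets import load_dataset

# Hypothetical local file name; the jsonl is memory-mapped via Arrow instead of
# being materialized in RAM through pandas.
dataset_name = 'my_data.jsonl'
full_dataset = load_dataset('json', data_files=dataset_name, split='train')
split_dataset = full_dataset.train_test_split(test_size=0.1)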
    elif dataset_name.endswith('.csv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name))
    elif dataset_name.endswith('.tsv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t'))
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_name}")

    split_dataset = full_dataset.train_test_split(test_size=0.1)
    return split_dataset

@@ -481,7 +489,7 @@ def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:

    Available datasets to be selected with `dataset` argument:
        - alpaca, 52002 examples
        - alpaca cleaned, 51942 examples
        - chip2 (OIG), 210289 examples
        - self-instruct, 82612 examples
        - hh-rlhf (Anthropic), 160800 examples
@@ -518,7 +526,7 @@ def load_data(dataset_name):
        else:
            if os.path.exists(dataset_name):
                try:
-                    args.dataset_format = args.dataset_format if args.dataset_format else "alpaca"
+                    args.dataset_format = args.dataset_format if args.dataset_format else "input-output"
                    full_dataset = local_dataset(dataset_name)
                    return full_dataset
                except:
@@ -528,7 +536,7 @@ def load_data(dataset_name):

    def format_dataset(dataset, dataset_format):
        if (
            dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
            (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean'])
        ):
            dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
@@ -550,12 +558,15 @@ def format_dataset(dataset, dataset_format):
                'input': '',
                'output': x['text'],
            })
+        elif dataset_format == 'input-output':
+            # leave as is
+            pass
        # Remove unused columns.
        dataset = dataset.remove_columns(
            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
        )
        return dataset

    # Load dataset.
    dataset = load_data(args.dataset)
    dataset = format_dataset(dataset, args.dataset_format)
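Because the load_data change above makes 'input-output' the default format for local files, a local dataset is expected to already provide 'input' and 'output' columns; format_dataset() then leaves it untouched and only strips extra columns. The example below is illustrative only; the records are made up, not taken from the PR.

from datasets import Dataset

# Records already in the 'input'/'output' schema, so the new 'input-output'
# branch is a no-op and remove_columns() finds nothing extra to drop.
records = [
    {'input': 'Translate to French: Hello', 'output': 'Bonjour'},
    {'input': 'What is 2 + 2?', 'output': '4'},
]
ds = Dataset.from_list(records)
print(ds.column_names)  # ['input', 'output']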
@@ -582,14 +593,14 @@ def format_dataset(dataset, dataset_format):
        train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    data_collator = DataCollatorForCausalLM(
        tokenizer=tokenizer,
        source_max_len=args.source_max_len,
        target_max_len=args.target_max_len,
        train_on_source=args.train_on_source,
        predict_with_generate=args.predict_with_generate,
    )
    return dict(
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        predict_dataset=eval_dataset if args.do_predict else None,
        data_collator=data_collator
@@ -648,19 +659,19 @@ def train():
    if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
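A side note on the special-token block above: one way to see what it guards against is to compare what the tokenizer does with the eos string before and after registering it as a special token. The check below is illustrative only; it reuses the tokenizer and model variables from the surrounding code and is not part of the PR.

# Illustrative check, not part of the PR: if '</s>' is not registered as a
# special token, the tokenizer splits the literal string into several pieces
# instead of emitting the single id that model.config.eos_token_id points to.
eos_token = tokenizer.convert_ids_to_tokens(model.config.eos_token_id)  # usually '</s>'
print(tokenizer(eos_token, add_special_tokens=False).input_ids)
# after tokenizer.add_special_tokens({...}) above, this should be a single id again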
@@ -748,7 +759,7 @@ def on_evaluate(self, args, state, control, model, **kwargs):
    if args.do_train:
        logger.info("*** Train ***")
        # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)