diff --git a/qlora.py b/qlora.py
index c9aae684..7a5ec29e 100644
--- a/qlora.py
+++ b/qlora.py
@@ -20,9 +20,9 @@
 from torch.nn.utils.rnn import pad_sequence
 import argparse
 from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    set_seed,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    set_seed,
     Seq2SeqTrainer,
     BitsAndBytesConfig,
     LlamaTokenizer
@@ -202,17 +202,17 @@ class GenerationArguments:
     num_beams: Optional[int] = field(default=1)
     num_beam_groups: Optional[int] = field(default=1)
     penalty_alpha: Optional[float] = field(default=None)
-    use_cache: Optional[bool] = field(default=True)
+    use_cache: Optional[bool] = field(default=True)
 
     # Hyperparameters for logit manipulation
     temperature: Optional[float] = field(default=1.0)
     top_k: Optional[int] = field(default=50)
     top_p: Optional[float] = field(default=1.0)
     typical_p: Optional[float] = field(default=1.0)
-    diversity_penalty: Optional[float] = field(default=0.0)
-    repetition_penalty: Optional[float] = field(default=1.0)
+    diversity_penalty: Optional[float] = field(default=0.0)
+    repetition_penalty: Optional[float] = field(default=1.0)
     length_penalty: Optional[float] = field(default=1.0)
-    no_repeat_ngram_size: Optional[int] = field(default=0)
+    no_repeat_ngram_size: Optional[int] = field(default=0)
 
 def find_all_linear_names(args, model):
     cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
@@ -260,6 +260,14 @@ def get_accelerate_model(args, checkpoint_dir):
     n_gpus = torch.cuda.device_count()
     max_memory = f'{args.max_memory_MB}MB'
     max_memory = {i: max_memory for i in range(n_gpus)}
+    device_map = "auto"
+
+    # if we are in a distributed setting, we need to set the device map and max memory per device
+    if os.environ.get('LOCAL_RANK') is not None:
+        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
+        device_map = {'': local_rank}
+        max_memory = {'': max_memory[local_rank]}
+
 
     if args.full_finetune: assert args.bits in [16, 32]
 
@@ -270,7 +278,7 @@ def get_accelerate_model(args, checkpoint_dir):
         cache_dir=args.cache_dir,
         load_in_4bit=args.bits == 4,
         load_in_8bit=args.bits == 8,
-        device_map='auto',
+        device_map=device_map,
         max_memory=max_memory,
         quantization_config=BitsAndBytesConfig(
             load_in_4bit=args.bits == 4,
@@ -396,9 +404,9 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         )
         # Build the input and labels for causal LM
         input_ids = []
-        labels = []
+        labels = []
         for tokenized_source, tokenized_target in zip(
-            tokenized_sources_with_prompt['input_ids'],
+            tokenized_sources_with_prompt['input_ids'],
             tokenized_targets['input_ids']
         ):
             if not self.predict_with_generate:
@@ -470,7 +478,7 @@ def local_dataset(dataset_name):
         full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t'))
     else:
         raise ValueError(f"Unsupported dataset format: {dataset_name}")
-    
+
     split_dataset = full_dataset.train_test_split(test_size=0.1)
     return split_dataset
 
@@ -481,7 +489,7 @@ def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
 
     Available datasets to be selected with `dataset` argument:
         - alpaca, 52002 examples
-        - alpaca cleaned, 51942 examples
+        - alpaca cleaned, 51942 examples
         - chip2 (OIG), 210289 examples
         - self-instruct, 82612 examples
         - hh-rlhf (Anthropic), 160800 examples
@@ -518,7 +526,7 @@ def load_data(dataset_name):
         else:
             if os.path.exists(dataset_name):
                 try:
-                    args.dataset_format = args.dataset_format if args.dataset_format else "alpaca"
+                    args.dataset_format = args.dataset_format if args.dataset_format else "input-output"
                     full_dataset = local_dataset(dataset_name)
                     return full_dataset
                 except:
@@ -528,7 +536,7 @@ def load_data(dataset_name):
 
     def format_dataset(dataset, dataset_format):
         if (
-            dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
+            dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
             (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean'])
         ):
             dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
@@ -550,12 +558,15 @@ def format_dataset(dataset, dataset_format):
                 'input': '',
                 'output': x['text'],
             })
+        elif dataset_format == 'input-output':
+            # leave as is
+            pass
         # Remove unused columns.
        dataset = dataset.remove_columns(
             [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
         )
         return dataset
-    
+
     # Load dataset.
     dataset = load_data(args.dataset)
     dataset = format_dataset(dataset, args.dataset_format)
@@ -582,14 +593,14 @@ def format_dataset(dataset, dataset_format):
             train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
 
     data_collator = DataCollatorForCausalLM(
-        tokenizer=tokenizer,
+        tokenizer=tokenizer,
         source_max_len=args.source_max_len,
         target_max_len=args.target_max_len,
         train_on_source=args.train_on_source,
         predict_with_generate=args.predict_with_generate,
     )
     return dict(
-        train_dataset=train_dataset if args.do_train else None,
+        train_dataset=train_dataset if args.do_train else None,
         eval_dataset=eval_dataset if args.do_eval else None,
         predict_dataset=eval_dataset if args.do_predict else None,
         data_collator=data_collator
@@ -648,19 +659,19 @@ def train():
     if 'llama' in args.model_name_or_path or isinstance(tokenizer, LlamaTokenizer):
         # LLaMA tokenizer may not have correct special tokens set.
         # Check and add them if missing to prevent them from being parsed into different tokens.
-        # Note that these are present in the vocabulary.
+        # Note that these are present in the vocabulary.
         # Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
         print('Adding special tokens.')
         tokenizer.add_special_tokens({
                 "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
                 "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
-                "unk_token": tokenizer.convert_ids_to_tokens(
+                "unk_token": tokenizer.convert_ids_to_tokens(
                     model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
                 ),
         })
     data_module = make_data_module(tokenizer=tokenizer, args=args)
     trainer = Seq2SeqTrainer(
-        model=model,
+        model=model,
         tokenizer=tokenizer,
         args=training_args,
         **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
@@ -748,7 +759,7 @@ def on_evaluate(self, args, state, control, model, **kwargs):
     if args.do_train:
         logger.info("*** Train ***")
         # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
-        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
+        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
         train_result = trainer.train()
         metrics = train_result.metrics
         trainer.log_metrics("train", metrics)