diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 9eb2eba87bf2..4f890ffc9aa8 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -49,6 +49,10 @@ def tokenize_sft( messages = data_point["messages"] template = deepcopy(conversation_template) + + if messages[0]["from"] == "system": + template.system_message = str(messages[0]["content"]) + messages.pop(0) template.messages = [] for idx, mess in enumerate(messages): if mess["from"] != template.roles[idx % 2]: @@ -148,11 +152,14 @@ def tokenize_prompt( template = deepcopy(conversation_template) template.messages = [] + if messages[0]["from"] == "system": + template.system_message = str(messages[0]["content"]) + messages.pop(0) + for idx, mess in enumerate(messages): if mess["from"] != template.roles[idx % 2]: raise ValueError( - f"Message should iterate between user and assistant and starts with a \ - line from the user. Got the following data:\n{messages}" + f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}" ) template.append_message(mess["from"], mess["content"]) @@ -225,6 +232,10 @@ def tokenize_rlhf( template = deepcopy(conversation_template) template.clear() + if context[0]["from"] == "system": + template.system_message = str(context[0]["content"]) + context.pop(0) + for idx, mess in enumerate(context): if mess["from"] != template.roles[idx % 2]: raise ValueError( @@ -345,6 +356,10 @@ def tokenize_kto( template = deepcopy(conversation_template) template.clear() + if prompt[0]["from"] == "system": + template.system_message = str(prompt[0]["content"]) + prompt.pop(0) + if prompt[0].get("from", None) != "user": raise ValueError("conversation should start with user") if completion.get("from", None) != "assistant": diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index 840cca074c39..bd0bbd36b9bc 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -46,7 +46,10 @@ def forward( action_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: skip = False - ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + if action_mask is None: + ratio_ = (log_probs - old_log_probs).exp() + else: + ratio_ = ((log_probs - old_log_probs) * action_mask).exp() # note that if dropout is disabled (recommanded), ratio will always be 1. if ratio_.mean() > self.skip_threshold: @@ -56,7 +59,10 @@ def forward( surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages loss = -torch.min(surr1, surr2) - loss = masked_mean(loss, action_mask) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) loss = loss.mean() return loss, skip, ratio_.max() @@ -81,8 +87,10 @@ def forward( values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) surr1 = (values_clipped - returns) ** 2 surr2 = (values - returns) ** 2 - loss = torch.max(surr1, surr2) / torch.sum(action_mask) - loss = torch.sum(loss * action_mask) + if action_mask is not None: + loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask) + else: + loss = torch.mean(torch.max(surr1, surr2)) return 0.5 * loss diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index c7ef2be8f6c4..24ddca6545c8 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -56,6 +56,7 @@ def __init__( beta: float = 0.1, gamma: float = 0.0, length_normalization: bool = False, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -67,6 +68,7 @@ def __init__( self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.actor_loss_fn = DpoLoss(beta, gamma) + self.apply_loss_mask = apply_loss_mask self.save_interval = save_interval self.coordinator = coordinator self.save_dir = save_dir @@ -135,6 +137,10 @@ def _train(self, epoch: int): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_all_logits = self.model( @@ -284,6 +290,9 @@ def _eval(self, epoch: int): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) batch_size = chosen_input_ids.size()[0] diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 8ab0bc66bcf9..6462ba816686 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch -import torch.distributed +import torch.distributed as dist from coati.models.loss import KTOLoss from coati.models.utils import calc_masked_log_probs from coati.trainer.utils import all_reduce_mean @@ -59,6 +59,7 @@ def __init__( beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -70,6 +71,7 @@ def __init__( self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight) + self.apply_loss_mask = apply_loss_mask self.save_interval = save_interval self.coordinator = coordinator self.save_dir = save_dir @@ -134,6 +136,10 @@ def _train(self, epoch: int): batch["kl_attention_mask"], batch["kl_loss_mask"], ) + if not self.apply_loss_mask: + loss_mask = loss_mask.fill_(1.0) + kl_loss_mask = kl_loss_mask.fill_(1.0) + batch_size = input_ids.size()[0] # actor logits @@ -182,8 +188,28 @@ def _train(self, epoch: int): # sync loss_mean = all_reduce_mean(tensor=loss) - chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean()) - rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean()) + chosen_reward_mean = chosen_rewards.mean() + chosen_rewards_list = [ + torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size()) + ] + dist.all_gather(chosen_rewards_list, chosen_reward_mean) + rejected_reward_mean = rejected_rewards.mean() + rejected_rewards_list = [ + torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size()) + ] + dist.all_gather(rejected_rewards_list, rejected_reward_mean) + chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()] + rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()] + chosen_rewards_mean = ( + torch.stack(chosen_rewards_list).mean() + if len(chosen_rewards_list) > 0 + else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device) + ) + rejected_rewards_mean = ( + torch.stack(rejected_rewards_list).mean() + if len(rejected_rewards_list) > 0 + else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device) + ) self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) @@ -256,6 +282,11 @@ def _eval(self, epoch: int): batch["kl_attention_mask"], batch["kl_loss_mask"], ) + + if not self.apply_loss_mask: + loss_mask = loss_mask.fill_(1.0) + kl_loss_mask = kl_loss_mask.fill_(1.0) + batch_size = input_ids.size()[0] # actor logits diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index b039da4afc30..c2f75771cdff 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -52,6 +52,7 @@ def __init__( tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, lam: float = 0.1, + apply_loss_mask: bool = True, accumulation_steps: int = 1, start_epoch: int = 0, save_interval: int = 0, @@ -67,6 +68,7 @@ def __init__( self.save_dir = save_dir self.num_train_step = 0 self.lam = lam + self.apply_loss_mask = apply_loss_mask self.accumulation_steps = accumulation_steps self.device = get_current_device() self.accumulative_meter = AccumulativeMeanMeter() @@ -130,6 +132,11 @@ def _train(self, epoch: int): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_out = self.model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), @@ -263,6 +270,11 @@ def _eval(self, epoch: int): batch["reject_attention_mask"], batch["reject_loss_mask"], ) + + if not self.apply_loss_mask: + chosen_loss_mask = chosen_loss_mask.fill_(1.0) + reject_loss_mask = reject_loss_mask.fill_(1.0) + batch_size = chosen_input_ids.size()[0] actor_out = self.model( input_ids=torch.cat([chosen_input_ids, reject_input_ids]), diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py index 287767669516..63c813b39ef9 100755 --- a/applications/ColossalChat/coati/trainer/ppo.py +++ b/applications/ColossalChat/coati/trainer/ppo.py @@ -102,6 +102,7 @@ def __init__( sample_buffer: bool = False, dataloader_pin_memory: bool = True, offload_inference_models: bool = True, + apply_loss_mask: bool = True, accumulation_steps: int = 1, save_interval: int = 0, save_dir: str = None, @@ -140,6 +141,7 @@ def __init__( self.actor_optim = actor_optim self.critic_optim = critic_optim self.save_interval = save_interval + self.apply_loss_mask = apply_loss_mask self.coordinator = coordinator self.actor_save_dir = os.path.join(save_dir, "actor") self.critic_save_dir = os.path.join(save_dir, "critic") @@ -229,7 +231,10 @@ def _training_step(self, experience: Experience): action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) actor_loss, to_skip, max_ratio = self.actor_loss_fn( - action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask if self.apply_loss_mask else None, ) actor_loss = (1 - self.ptx_coef) * actor_loss if not to_skip: @@ -249,7 +254,10 @@ def _training_step(self, experience: Experience): input_ids=experience.sequences, attention_mask=experience.attention_mask ) # [batch size, prompt_length + response_length] critic_loss = self.critic_loss_fn( - values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask + values[:, -num_actions:], + experience.values, + experience.advantages, + action_mask=experience.action_mask if self.apply_loss_mask else None, ) critic_loss = critic_loss * self.vf_coef self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim) diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index c09d61034984..d37676ada3e0 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -41,6 +41,7 @@ def __init__( lr_scheduler: _LRScheduler, max_epochs: int = 2, accumulation_steps: int = 8, + apply_loss_mask: bool = True, start_epoch=0, save_interval: int = None, save_dir: str = None, @@ -55,6 +56,7 @@ def __init__( self.coordinator = coordinator self.num_train_step = 0 self.num_eval_step = 0 + self.apply_loss_mask = apply_loss_mask self.accumulative_meter = AccumulativeMeanMeter() def _before_fit( @@ -100,7 +102,11 @@ def _train(self, epoch: int): for i, batch in enumerate(self.train_dataloader): batch = to_device(batch, torch.cuda.current_device()) batch_size = batch["input_ids"].size(0) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) loss = outputs.loss self.booster.backward(loss=loss, optimizer=self.optimizer) @@ -158,7 +164,11 @@ def _eval(self, epoch: int): ) for batch in self.eval_dataloader: batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) loss_mean = all_reduce_mean(tensor=outputs.loss) self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) step_bar.update() diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index b749f197ed21..904d69cfcc4e 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai - save_dir: path to store the model checkpoints. - max_length: input will be padded/truncated to max_length before feeding to the model. - max_epochs: number of epochs to train. +- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss. - batch_size: training batch size. - mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility. - save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes. diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py index 103bd8d95016..5f59ba4528fc 100755 --- a/applications/ColossalChat/examples/inference/inference.py +++ b/applications/ColossalChat/examples/inference/inference.py @@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs tuple: A tuple containing the loaded model and tokenizer. """ - model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token model.to(device) diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt deleted file mode 100644 index ba02074c1a03..000000000000 --- a/applications/ColossalChat/examples/inference/round.txt +++ /dev/null @@ -1,104 +0,0 @@ - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 - -========== - - -========== -round 3: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. - -========== - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -who is the first president of the USA? [/INST] The first president of the United States was George Washington. - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. - -========== - - -========== -round 1: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? - -========== - - -========== -round 2: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? - -========== - - -========== -round 3: -[INST] <> -A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. - - -<> - -tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. - -========== diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index 44131f572445..d88750aebc8f 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -278,6 +278,7 @@ def train(args): beta=args.beta, gamma=args.gamma, length_normalization=args.length_normalization, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -346,6 +347,7 @@ def train(args): default=False, help="Disable the reference model (enabled by default)", ) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index d063b82bb214..598fd8062fcf 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -297,6 +297,7 @@ def train(args): beta=args.beta, desirable_weight=args.desirable_weight, undesirable_weight=args.undesirable_weight, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -341,6 +342,7 @@ def train(args): parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss") parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss") parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss") + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index f06524507d5f..87860f7ea023 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -259,6 +259,7 @@ def train(args): save_dir=args.save_dir, coordinator=coordinator, lam=args.lam, + apply_loss_mask=not args.disable_loss_mask, ) trainer.fit( @@ -301,6 +302,7 @@ def train(args): parser.add_argument("--pp", type=int, default=1) parser.add_argument("--sp", type=int, default=1) parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss") + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true") diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index 333be9963c06..c104183942c3 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -411,6 +411,7 @@ def train(args): use_cache=True, do_sample=True, temperature=0.7, + apply_loss_mask=not args.disable_loss_mask, accumulation_steps=args.accumulation_steps, save_dir=args.save_path, save_interval=args.save_interval, @@ -498,6 +499,7 @@ def train(args): parser.add_argument("--critic_lr", type=float, default=9e-6) parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--ptx_coef", type=float, default=0.0) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--max_seq_len", type=int, default=256) parser.add_argument("--log_dir", default="logs", type=str) diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index 6007a8599277..c4ef3b783d4d 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -272,6 +272,7 @@ def train(args): lr_scheduler=lr_scheduler, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps, + apply_loss_mask=not args.disable_loss_mask, start_epoch=start_epoch, save_interval=args.save_interval, save_dir=args.save_path, @@ -317,6 +318,7 @@ def train(args): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) parser.add_argument("--zero_cpu_offload", default=False, action="store_true")