diff --git a/paddlenlp/transformers/tensor_parallel_utils.py b/paddlenlp/transformers/tensor_parallel_utils.py
index ea68e41207c5..679345ff78c1 100644
--- a/paddlenlp/transformers/tensor_parallel_utils.py
+++ b/paddlenlp/transformers/tensor_parallel_utils.py
@@ -254,12 +254,20 @@ def forward(
             labels_chunk = labels[token_start_idx:token_end_idx]
 
             # logits calculations
-            logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
+            logits_chunk_cast = paddle.matmul(
+                hidden_states_chunk,
+                lm_head_weight_cast,
+                transpose_y=transpose_y,
+            )
             if lm_head_bias is not None:
                 logits_chunk_cast += lm_head_bias_cast
             if tensor_parallel_degree > 1 and not tensor_parallel_output:
                 logits_chunk_cast_lst = []
-                dist.all_gather(logits_chunk_cast_lst, logits_chunk_cast, group=model_parallel_group)
+                dist.all_gather(
+                    logits_chunk_cast_lst,
+                    logits_chunk_cast,
+                    group=model_parallel_group,
+                )
                 logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
             logits_chunk = logits_chunk_cast.astype("float32")
 
@@ -271,18 +279,30 @@ def forward(
             exp_logits = paddle.exp(normalized_logits)
             sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
             if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
+                dist.all_reduce(
+                    sum_exp_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
             log_sum_exp_logits = paddle.log(sum_exp_logits)
 
             # cross entropy
             labels_one_hot = labels_chunk.unsqueeze(1) == indices
             label_logits = paddle.sum(
-                paddle.where(labels_one_hot, normalized_logits, paddle.zeros_like(normalized_logits)),
+                paddle.where(
+                    labels_one_hot,
+                    normalized_logits,
+                    paddle.zeros_like(normalized_logits),
+                ),
                 axis=-1,
                 keepdim=True,
             )
             if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(label_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
+                dist.all_reduce(
+                    label_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
             token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
             cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
             token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
@@ -298,18 +318,30 @@ def forward(
             grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
             grad_logits_chunk = grad_logits_chunk.astype(dtype)
             grad_logits_chunk = paddle.where(
-                cond.unsqueeze(1), grad_logits_chunk, paddle.zeros_like(grad_logits_chunk)
+                cond.unsqueeze(1),
+                grad_logits_chunk,
+                paddle.zeros_like(grad_logits_chunk),
             )
 
             if grad_hidden_states is not None:
                 grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
-                    grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y
+                    grad_logits_chunk,
+                    lm_head_weight_cast,
+                    transpose_y=not transpose_y,
                 )
             if grad_lm_head_weight is not None:
                 if transpose_y:
-                    grad_lm_head_weight += paddle.matmul(grad_logits_chunk, hidden_states_chunk, transpose_x=True)
+                    grad_lm_head_weight += paddle.matmul(
+                        grad_logits_chunk,
+                        hidden_states_chunk,
+                        transpose_x=True,
+                    )
                 else:
-                    grad_lm_head_weight += paddle.matmul(hidden_states_chunk, grad_logits_chunk, transpose_x=True)
+                    grad_lm_head_weight += paddle.matmul(
+                        hidden_states_chunk,
+                        grad_logits_chunk,
+                        transpose_x=True,
+                    )
             if grad_lm_head_bias is not None:
                 grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
 
@@ -340,7 +372,11 @@ def forward(
         grad_args = []
         if ctx.hidden_states_has_grad:
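+            # hidden-state gradients are partial sums under tensor parallelism
+            # and must be accumulated across the model-parallel group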
             if tensor_parallel_degree > 1:
-                dist.all_reduce(grad_hidden_states, op=dist.ReduceOp.SUM, group=model_parallel_group)
+                dist.all_reduce(
+                    grad_hidden_states,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
             grad_args.append(grad_hidden_states.reshape(original_shape))
         if ctx.lm_head_weight_has_grad:
             grad_args.append(grad_lm_head_weight)
@@ -376,9 +412,20 @@ def backward(ctx, grad_output):
             grad_lm_head_bias = None
 
             if ctx.aux_num == 1:
-                return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None
+                return (
+                    grad_hidden_states,
+                    grad_lm_head_weight,
+                    grad_lm_head_bias,
+                    None,
+                )
             else:
-                return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None, None
+                return (
+                    grad_hidden_states,
+                    grad_lm_head_weight,
+                    grad_lm_head_bias,
+                    None,
+                    None,
+                )
 
         # return_token_loss = True
         grad_token_loss = grad_output.reshape([-1])
@@ -444,12 +491,20 @@ def backward(ctx, grad_output):
             labels_chunk = labels[token_start_idx:token_end_idx]
 
             # logits calculations
-            logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
+            logits_chunk_cast = paddle.matmul(
+                hidden_states_chunk,
+                lm_head_weight_cast,
+                transpose_y=transpose_y,
+            )
             if lm_head_bias is not None:
                 logits_chunk_cast += lm_head_bias_cast
             if tensor_parallel_degree > 1 and not tensor_parallel_output:
                 logits_chunk_cast_lst = []
-                dist.all_gather(logits_chunk_cast_lst, logits_chunk_cast, group=model_parallel_group)
+                dist.all_gather(
+                    logits_chunk_cast_lst,
+                    logits_chunk_cast,
+                    group=model_parallel_group,
+                )
                 logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
             logits_chunk = logits_chunk_cast.astype("float32")
 
@@ -461,7 +516,11 @@ def backward(ctx, grad_output):
             exp_logits = paddle.exp(normalized_logits)
             sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
             if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
+                dist.all_reduce(
+                    sum_exp_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
 
             labels_one_hot = labels_chunk.unsqueeze(1) == indices
             if tensor_parallel_degree > 1 and not tensor_parallel_output:
@@ -473,12 +532,16 @@ def backward(ctx, grad_output):
             grad_logits_chunk = grad_logits_chunk.astype(dtype)
             cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
             grad_logits_chunk = paddle.where(
-                cond.unsqueeze(1), grad_logits_chunk, paddle.zeros_like(grad_logits_chunk)
+                cond.unsqueeze(1),
+                grad_logits_chunk,
+                paddle.zeros_like(grad_logits_chunk),
             )
 
             if grad_hidden_states is not None:
                 grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
-                    grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y
+                    grad_logits_chunk,
+                    lm_head_weight_cast,
+                    transpose_y=not transpose_y,
                 )
             if grad_lm_head_weight is not None:
                 if transpose_y:
@@ -490,10 +553,25 @@ def backward(ctx, grad_output):
         if grad_hidden_states is not None:
             if tensor_parallel_degree > 1:
-                dist.all_reduce(grad_hidden_states, op=dist.ReduceOp.SUM, group=model_parallel_group)
+                dist.all_reduce(
+                    grad_hidden_states,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
             grad_hidden_states = grad_hidden_states.reshape(ctx.original_shape)
 
         if ctx.aux_num == 1:
-            return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None
+            return (
+                grad_hidden_states,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                None,
+            )
         else:
-            return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None, None
+            return (
+                grad_hidden_states,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                None,
+                None,
+            )
 
diff --git a/paddlenlp/trl/__init__.py b/paddlenlp/trl/__init__.py
index 41e0d6556e2a..a67fd0e69f6d 100644
--- a/paddlenlp/trl/__init__.py
+++ b/paddlenlp/trl/__init__.py
@@ -14,6 +14,8 @@
 from .dpo_criterion import DPOCriterion
 from .dpo_trainer import DPOTrainer
+from .kto_criterion import KTOCriterion
+from .kto_trainer import KTOTrainer
 from .sft_trainer import *
 from .trl_data import *
 from .trl_utils import *
diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py
index 2af2a8ef2096..be454e2ce4d1 100644
--- a/paddlenlp/trl/dpo_criterion.py
+++ b/paddlenlp/trl/dpo_criterion.py
@@ -287,10 +287,10 @@ def forward(
         )
         loss = dpo_loss + sft_loss
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.sft_loss.append(sft_loss)
-            infohub.dpo_loss.append(dpo_loss)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.sft_loss.append(sft_loss.detach())
+            infohub.dpo_loss.append(dpo_loss.detach())
             return loss
         else:
             return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss
diff --git a/paddlenlp/trl/kto_criterion.py b/paddlenlp/trl/kto_criterion.py
new file mode 100644
index 000000000000..a6ca6c4c837a
--- /dev/null
+++ b/paddlenlp/trl/kto_criterion.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
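+
+""" KTO Criterion """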
+import copy
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
+
+from paddlenlp.transformers import (
+    AllGatherVarlenOp,
+    fused_head_and_loss_fn,
+    parallel_linear,
+    parallel_matmul,
+    sequence_parallel_sparse_mask_labels,
+)
+from paddlenlp.utils import infohub
+
+
+class KTOCriterion(nn.Layer):
+    """KTO Criterion"""
+
+    def __init__(self, config, kto_config=None, ignore_label=0, use_infohub=False):
+        super(KTOCriterion, self).__init__()
+        self.config = config
+        if kto_config is None:
+            if getattr(self.config, "kto_config", None) is None:
+                raise ValueError("KTO Criterion requires model_config.kto_config.")
+            self.kto_config = copy.deepcopy(config.kto_config)
+        else:
+            self.kto_config = kto_config
+        if self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1:
+            self.logprobs = ParallelCrossEntropy()
+        else:
+            self.logprobs = nn.CrossEntropyLoss(reduction="none")
+        self.use_infohub = use_infohub
+        self.ignore_label = ignore_label
+        # allgather kl in criterion
+        topo = fleet.get_hybrid_communicate_group()._topo
+        parallel_groups = topo.get_comm_list("pipe")
+        ranks = []
+        for group in parallel_groups:
+            ranks.append(group[-1])
+        self.comm_group = paddle.distributed.new_group(ranks=ranks)
+
+    def _nested_gather(self, tensors):
+        """
+        Gather the value of `tensors` across the criterion communication group and
+        concatenate the results along axis 0.
+        """
+        local_rank = -1
+        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
+        if env_local_rank != -1 and env_local_rank != local_rank and paddle.distributed.get_world_size() > 1:
+            local_rank = env_local_rank
+        if tensors is None:
+            return
+        if local_rank != -1:
+            output_tensors = []
+            paddle.distributed.all_gather(
+                output_tensors, paddle.tile(tensors, repeat_times=[1, 1]), group=self.comm_group
+            )
+            tensors = paddle.concat(output_tensors, axis=0)
+        return tensors
+
+    def kto_logps(self, logits, response_labels, response_kl_labels, response_indexs):
+        """KTO logprobs"""
+        labels = response_labels + response_kl_labels
+        if self.config.use_fused_head_and_loss_fn:
+            hidden_states, weight, bias, transpose_y = logits
+        elif self.config.use_sparse_head_and_loss_fn:
+            hidden_states, weight, bias = logits
+        if self.config.use_sparse_head_and_loss_fn:
+            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+                labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, self.ignore_label)
+
+                hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0)
+                hidden_states = AllGatherVarlenOp.apply(hidden_states)
+            else:
+                labels = labels.flatten()
+                sparse_tgt_idx = paddle.nonzero(labels != self.ignore_label).flatten()
+                labels = paddle.take_along_axis(labels, sparse_tgt_idx, axis=0)
+
+                hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
+                hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0)
+        if self.config.use_fused_head_and_loss_fn:
+            per_token_logps = -fused_head_and_loss_fn(
+                hidden_states,
+                weight,
+                bias,
+                labels,
+                None,
+                transpose_y,
+                self.config.vocab_size,
+                self.config.tensor_parallel_degree,
+                self.config.tensor_parallel_output,
+                self.config.fused_linear,
+                getattr(self.config, "chunk_size", 1024),
+                return_token_loss=True,
+                ignore_index=self.ignore_label,
+            )
+        elif self.config.use_sparse_head_and_loss_fn:
+            if bias is None:
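+                # bias-free head: a single tensor-parallel matmul is enough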
+                logits = parallel_matmul(hidden_states, weight, self.config.tensor_parallel_output)
+            else:
+                logits = parallel_linear(
+                    hidden_states,
+                    weight,
+                    bias,
+                    self.config.tensor_parallel_output,
+                )
+            logits = logits.astype("float32")
+            per_token_logps = -self.logprobs(logits, labels)
+        else:
+            logits = logits.astype("float32")
+            if logits.shape[:-1] != labels.shape:
+                raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+            # bs, seq
+            per_token_logps = -self.logprobs(logits, labels.unsqueeze(2)).squeeze(2)
+
+        if len(response_indexs.shape) == 3:
+            response_indexs = response_indexs[0]
+        if self.config.use_sparse_head_and_loss_fn:
+            chosen_logps_list = [
+                (per_token_logps[response_index[1] : response_index[2]]).sum()
+                for response_index in response_indexs
+                if response_index[4] == 1
+            ]
+            rejected_logps_list = [
+                (per_token_logps[response_index[1] : response_index[2]]).sum()
+                for response_index in response_indexs
+                if response_index[4] == 0
+            ]
+            kl_logps_list = [
+                (per_token_logps[response_index[2] : response_index[3]]).sum() for response_index in response_indexs
+            ]
+        else:
+            chosen_logps_list = [
+                (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
+                for response_index in response_indexs
+                if response_index[4] == 1
+            ]
+            rejected_logps_list = [
+                (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
+                for response_index in response_indexs
+                if response_index[4] == 0
+            ]
+            kl_logps_list = [
+                (per_token_logps[response_index[0]][response_index[2] : response_index[3]]).sum()
+                for response_index in response_indexs
+            ]
+        if len(chosen_logps_list) == 0:
+            chosen_logps = paddle.zeros([0], dtype="float32")
+        else:
+            chosen_logps = paddle.stack(chosen_logps_list, axis=0)
+        if len(rejected_logps_list) == 0:
+            rejected_logps = paddle.zeros([0], dtype="float32")
+        else:
+            rejected_logps = paddle.stack(rejected_logps_list, axis=0)
+        kl_logps = paddle.stack(kl_logps_list, axis=0)
+        return chosen_logps, rejected_logps, kl_logps
+
+    def kto_loss(
+        self,
+        policy_chosen_logps,
+        policy_rejected_logps,
+        policy_kl_logps,
+        reference_chosen_logps,
+        reference_rejected_logps,
+        reference_kl_logps,
+    ):
+        """KTO Loss"""
+        # KL estimate: mean policy/reference log-ratio on the KL completions,
+        # averaged across the communication group and clipped at zero
+        kl = (policy_kl_logps - reference_kl_logps).mean().detach()
+        kl = self._nested_gather(paddle.tile(kl, repeat_times=[1, 1])).mean().clip(min=0)
+        if policy_chosen_logps.shape[0] == 0 or reference_chosen_logps.shape[0] == 0:
+            chosen_losses = paddle.zeros([0])
+        else:
+            # desirable term: 1 - sigmoid(beta * (log-ratio - KL))
+            chosen_logratios = policy_chosen_logps - reference_chosen_logps
+            chosen_losses = 1 - F.sigmoid(self.kto_config.beta * (chosen_logratios - kl))
+        if policy_rejected_logps.shape[0] == 0 or reference_rejected_logps.shape[0] == 0:
+            rejected_losses = paddle.zeros([0])
+        else:
+            # undesirable term: 1 - sigmoid(beta * (KL - log-ratio))
+            rejected_logratios = policy_rejected_logps - reference_rejected_logps
+            rejected_losses = 1 - F.sigmoid(self.kto_config.beta * (kl - rejected_logratios))
+        losses = paddle.concat(
+            (
+                self.kto_config.desirable_weight * chosen_losses,
+                self.kto_config.undesirable_weight * rejected_losses,
+            ),
+            0,
+        )
+        return losses.mean(), kl
+
+    def forward(
+        self,
+        logits,
+        labels,
+    ):
+        """Forward"""
+        (
+            response_labels,
+            response_kl_labels,
+            response_indexs,
+            reference_chosen_logps,
+            reference_rejected_logps,
+            reference_kl_logps,
+        ) = labels
+        if reference_chosen_logps is None or reference_rejected_logps is None or reference_kl_logps is None:
+            (
+                reference_chosen_logps,
+                reference_rejected_logps,
+                reference_kl_logps,
+            ) = self.kto_logps(
+                logits, response_labels, response_kl_labels, response_indexs
+            )
+            if self.use_infohub:
+                infohub.reference_chosen_logps.append(reference_chosen_logps)
+                infohub.reference_rejected_logps.append(reference_rejected_logps)
+                infohub.reference_kl_logps.append(reference_kl_logps)
+                # pipeline mode requires return loss when self._compute_loss is True
+                return paddle.zeros([1])
+            else:
+                return (
+                    reference_chosen_logps,
+                    reference_rejected_logps,
+                    reference_kl_logps,
+                )
+        policy_chosen_logps, policy_rejected_logps, policy_kl_logps = self.kto_logps(
+            logits, response_labels, response_kl_labels, response_indexs
+        )
+        loss, kl = self.kto_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_kl_logps,
+            reference_chosen_logps,
+            reference_rejected_logps,
+            reference_kl_logps,
+        )
+        if self.use_infohub:
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.policy_kl_logps.append(policy_kl_logps.detach())
+            infohub.kl.append(kl.detach())
+            return loss
+        else:
+            return (
+                policy_chosen_logps,
+                policy_rejected_logps,
+                policy_kl_logps,
+                loss,
+                kl,
+            )
diff --git a/paddlenlp/trl/kto_trainer.py b/paddlenlp/trl/kto_trainer.py
new file mode 100644
index 000000000000..861326fae59c
--- /dev/null
+++ b/paddlenlp/trl/kto_trainer.py
@@ -0,0 +1,555 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" KTO Trainer """
+from collections import OrderedDict, defaultdict
+
+import paddle
+from paddle.distributed import fleet
+
+from paddlenlp.trainer import Trainer
+from paddlenlp.transformers.model_utils import unwrap_model
+from paddlenlp.trl import KTOCriterion
+from paddlenlp.utils import infohub
+
+
+def disable_dropout_in_model(model: paddle.nn.Layer) -> None:
+    """Disable dropout in `model` and all of its sublayers."""
+    # walk sublayers recursively so nested Dropout layers are disabled as well
+    for module in model.sublayers():
+        if isinstance(module, paddle.nn.Dropout):
+            module.p = 0
+
+
+try:
+    from paddlenlp.peft.lora.lora_model import AVAILABLE_LAYERS
+except ImportError:
+    from paddlenlp.peft.lora.lora_model import AVALIABLE_LAYERS
+
+    AVAILABLE_LAYERS = AVALIABLE_LAYERS
+
+KTO_INFO_KEYS = [
+    "reference_chosen_logps",
+    "reference_rejected_logps",
+    "reference_kl_logps",
+    "policy_chosen_logps",
+    "policy_rejected_logps",
+    "policy_kl_logps",
+    "kl",
+]
+
+
+class KTOTrainer(Trainer):
+    """
+    Trainer for Kahneman-Tversky Optimization (KTO).
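+
+    KTO aligns a policy model with binary desirable/undesirable feedback rather
+    than paired preferences. The trainer runs a frozen reference model (or the
+    policy with LoRA disabled when `kto_config.lora` is set) to obtain reference
+    log-probabilities, then optimizes the KTO loss computed by `KTOCriterion`.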
+ """ + + def __init__( + self, + model, + data_collator, + ref_model=None, + kto_config=None, + disable_dropout: bool = True, + padding_value: int = 0, + kto_criterion=None, + ignore_label: int = 0, + **kwargs, + ): + super().__init__(model, data_collator=data_collator, **kwargs) + if kto_config is None: + raise ValueError("kto_config is None") + else: + self.kto_config = kto_config + if ref_model: + self.ref_model = ref_model + self.ref_model_wrapped = self._wrap_ref_model(self.ref_model) + self.ref_model_wrapped.eval() + elif self.kto_config.lora: + self.ref_model = None + self.ref_model_wrapped = None + else: + raise ValueError("ref_model is None! KTO requires a reference model") + if not self.args.pipeline_parallel_degree > 1: + if kto_criterion is None: + self.kto_criterion = KTOCriterion(self.model.config, kto_config=kto_config, ignore_label=ignore_label) + elif isinstance(kto_criterion, KTOCriterion): + self.kto_criterion = kto_criterion + else: + raise ValueError(f"kto_criterion should be None or KTOCriterion. Got {type(kto_criterion)}") + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.padding_value = padding_value + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.model.config.tensor_parallel_output and self.model.config.tensor_parallel_degree > 1: + self.logprobs = paddle.distributed.fleet.meta_parallel.ParallelCrossEntropy() + else: + self.logprobs = paddle.nn.CrossEntropyLoss(reduction="none") + self.reset_dpo_infohub() + + def get_batch_metrics(self, ref_model, model, batch, train_eval="train"): + """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" + inputs = { + "input_ids": batch["input_ids"], + "position_ids": batch["position_ids"], + } + if "attention_mask" in batch: + inputs["attention_mask"] = batch["attention_mask"] + elif "attn_mask_start_row_indices" in batch: + inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"] + else: + raise ValueError("No attention mask found in batch.") + labels = ( + batch["response_labels"], + batch["response_kl_labels"], + batch["response_indexs"], + None, + None, + None, + ) + with paddle.no_grad(): + if self.kto_config.lora: + self.disable_lora(model) + model.eval() + logits = model(**inputs) + self.enable_lora(model) + model.train() + else: + logits = ref_model(**inputs) + ( + reference_chosen_logps, + reference_rejected_logps, + reference_kl_logps, + ) = self.kto_criterion(logits, labels) + labels = labels[:3] + ( + reference_chosen_logps, + reference_rejected_logps, + reference_kl_logps, + ) + logits = model(**inputs) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_kl_logps, + loss, + kl, + ) = self.kto_criterion(logits, labels) + + # metrics + metric_inputs = dict( + policy_chosen_logps=policy_chosen_logps, + policy_rejected_logps=policy_rejected_logps, + reference_chosen_logps=reference_chosen_logps, + reference_rejected_logps=reference_rejected_logps, + kl=kl, + train_eval=train_eval, + ) + self.log_metric(**metric_inputs) + return loss + + def log_metric( + self, + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + kl, + train_eval, + ): + metrics = {} + chosen_rewards = self.kto_config.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.kto_config.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + prefix = "eval_" if train_eval == 
"eval" else "" + metrics[f"{prefix}count/chosen"] = paddle.to_tensor(chosen_rewards.shape[0]) + metrics[f"{prefix}count/rejected"] = paddle.to_tensor(rejected_rewards.shape[0]) + + if policy_chosen_logps.shape[0] == 0 or len(reference_chosen_logps.shape) == 0: + metrics[f"{prefix}rewards/chosen"] = paddle.zeros([]) + metrics[f"{prefix}logps/chosen"] = paddle.zeros([]) + else: + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean() + if policy_rejected_logps.shape[0] == 0 or reference_rejected_logps.shape[0] == 0: + metrics[f"{prefix}rewards/rejected"] = paddle.zeros([]) + metrics[f"{prefix}logps/rejected"] = paddle.zeros([]) + else: + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean() + + for key in metrics: + if "count" in key: + metrics[key] = self._nested_gather(paddle.tile(metrics[key], repeat_times=[1, 1])).sum().cpu() + metrics[key] /= max(self.args.tensor_parallel_degree, 1) + else: + metrics[key] = self._nested_gather(paddle.tile(metrics[key], repeat_times=[1, 1])).mean().cpu() + metrics[f"{prefix}kl"] = kl + metrics[f"{prefix}rewards/margins"] = metrics[f"{prefix}rewards/chosen"] - metrics[f"{prefix}rewards/rejected"] + if self.args.should_save: + self.store_metrics(metrics, train_eval=train_eval) + + def compute_loss(self, model, inputs): + """Compute the KTO loss for the given batch of inputs.""" + loss = self.get_batch_metrics(self.ref_model_wrapped, model, inputs, train_eval="train") + return loss + + def _wrap_ref_model(self, model): + """Wrap reference model.""" + if unwrap_model(model) is not model: + return model + self.amp_dtype = "float16" if self.args.fp16 else "bfloat16" + model = paddle.amp.decorate( + models=model, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + ) + model = fleet.distributed_model(model) + if self.args.pipeline_parallel_degree > 1: + model._prepare_pipeline_inputs_func = prepare_pipeline_dpo_inputs_func + return model + + def _wrap_model(self, model, training=True): + """Wrap model.""" + model = super()._wrap_model(model, training) + if self.args.pipeline_parallel_degree > 1: + model._prepare_pipeline_inputs_func = prepare_pipeline_dpo_inputs_func + return model + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"): + """evaluate""" + self.model_wrapped = self._wrap_ref_model(self.model_wrapped) + return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix) + + def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None): + """prediction_step""" + if self.args.pipeline_parallel_degree > 1: + # hack for pipeline mode + inputs = self._prepare_inputs(inputs) + return self.prediction_pipeline_step(self.ref_model_wrapped, model, inputs) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with paddle.no_grad(): + with self.autocast_smart_context_manager(): + loss = self.get_batch_metrics(self.ref_model_wrapped, model, inputs, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + else: + raise NotImplementedError("KTOTrainer only supports prediction_loss_only=True for now.") + + def store_metrics(self, metrics, train_eval="train"): + """store_metrics""" + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def log(self, logs, **kwargs): + """ + Log 
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            if "count" in key:
+                logs[key] = paddle.to_tensor(metrics).sum().item()
+            else:
+                logs[key] = paddle.to_tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        if self.state.epoch is not None and train_eval == "train":
+            self.state.epoch *= self.args.num_train_epochs
+        return super().log(logs, **kwargs)
+
+    def disable_lora(self, model):
+        """Disable LoRA layers."""
+        for _, layer in model.named_sublayers():
+            if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
+                layer.disable_lora = True
+
+    def enable_lora(self, model):
+        """Enable LoRA layers."""
+        for _, layer in model.named_sublayers():
+            if any(isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS):
+                layer.disable_lora = False
+
+    def training_pipeline_step(self, model, inputs):
+        """
+        Perform a training step on a batch of inputs.
+        """
+        # accumulate micro-batches until a full pipeline step is buffered
+        if not hasattr(self, "_pp_data_buffer"):
+            self._pp_data_buffer = []
+        self._pp_data_buffer.append(inputs)
+        if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps:
+            return paddle.zeros([])
+
+        concatenated_inputs = {}
+        for key in self._pp_data_buffer[0].keys():
+            concatenated_inputs[key] = [
+                self._pp_data_buffer[i][key] for i in range(self.args.gradient_accumulation_steps)
+            ]
+        concatenated_inputs["reference_chosen_logps"] = None
+        concatenated_inputs["reference_rejected_logps"] = None
+        concatenated_inputs["reference_kl_logps"] = None
+        self._pp_data_buffer = []
+        inputs, labels = model._prepare_pipeline_inputs_func(concatenated_inputs)
+        model_config_backup = model.micro_batch_size, model.accumulate_steps
+        model.micro_batch_size = self.args.per_device_train_batch_size
+        model.accumulate_steps = self.args.gradient_accumulation_steps
+
+        if self.kto_config.lora:
+            self.disable_lora(model)
+            model.eval()
+            with paddle.no_grad():
+                with self.autocast_smart_context_manager():
+                    model.eval_batch(data=[inputs, labels], compute_loss=True)
+            self.enable_lora(model)
+            model._p2p_helper.clear_meta_cache()
+            model.train()
+        else:
+            ref_model = self.ref_model_wrapped
+            ref_model_config_backup = (
+                ref_model.micro_batch_size,
+                ref_model.accumulate_steps,
+            )
+            ref_model.accumulate_steps = model.accumulate_steps
+            ref_model.micro_batch_size = model.micro_batch_size
+            with paddle.no_grad():
+                with self.autocast_smart_context_manager():
+                    ref_model.eval_batch(data=[inputs, labels], compute_loss=True)
+            ref_model.micro_batch_size, ref_model.accumulate_steps = ref_model_config_backup
+        reference_chosen_logps = infohub.reference_chosen_logps
+        reference_rejected_logps = infohub.reference_rejected_logps
+        reference_kl_logps = infohub.reference_kl_logps
+
+        if model.is_pipeline_last_stage(ignore_virtual=model._layers._num_virtual_pipeline_stages > 1):
+            labels = labels[:3] + (
+                reference_chosen_logps,
+                reference_rejected_logps,
+                reference_kl_logps,
+            )
+        train_inputs = [inputs, labels]
+        train_inputs = model._prepare_training(train_inputs, self.optimizer, self.lr_scheduler)
+        model.optimizer = None  # we do not use `PipelineParallel` to handle the optimizer step
+        model.lr_scheduler = None
+        with self.autocast_smart_context_manager():
+            loss = model.forward_backward_pipeline(train_inputs, self.scaler if self.do_grad_scaling else None)
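+        # restore the micro-batch configuration that was overridden for this pipeline step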
+        model.micro_batch_size, model.accumulate_steps = model_config_backup
+
+        # broadcast KTO_INFO_KEYS
+        self.broadcast_last_stage_infohub_tensor()
+
+        # metrics
+        metric_inputs = dict(
+            policy_chosen_logps=infohub.policy_chosen_logps,
+            policy_rejected_logps=infohub.policy_rejected_logps,
+            reference_chosen_logps=infohub.reference_chosen_logps,
+            reference_rejected_logps=infohub.reference_rejected_logps,
+            kl=infohub.kl,
+            train_eval="train",
+        )
+        self.log_metric(**metric_inputs)
+        self.reset_kto_infohub()
+        return loss.detach()
+
+    def prediction_pipeline_step(
+        self,
+        ref_model,
+        model,
+        batch,
+    ):
+        """
+        prediction_step function for pipeline parallel mode.
+        """
+        concatenated_inputs = {}
+        # the last batch may be smaller when drop_last is disabled
+        per_device_train_batch_size = self.args.per_device_train_batch_size
+        gradient_accumulation_steps = self.args.gradient_accumulation_steps
+        # preprocess inputs: tuple(List[Tensor])
+        for key in batch.keys():
+            if key != "response_indexs":
+                concatenated_inputs[key] = [
+                    batch[key][i * per_device_train_batch_size : (i + 1) * per_device_train_batch_size]
+                    for i in range(gradient_accumulation_steps)
+                ]
+            else:
+                concatenated_inputs["response_indexs"] = [[] for _ in range(gradient_accumulation_steps)]
+                for i in range(gradient_accumulation_steps):
+                    for response_index in batch[key]:
+                        if response_index[0] in list(
+                            range(
+                                i * per_device_train_batch_size,
+                                (i + 1) * per_device_train_batch_size,
+                            )
+                        ):
+                            response_index[0] -= i * per_device_train_batch_size
+                            concatenated_inputs["response_indexs"][i].append(response_index)
+                    concatenated_inputs["response_indexs"][i] = paddle.stack(concatenated_inputs["response_indexs"][i])
+                    if model._layers.config.use_sparse_head_and_loss_fn:
+                        last_batch_response_length = concatenated_inputs["response_indexs"][i][0, 1]
+                        concatenated_inputs["response_indexs"][i][:, 1:] -= last_batch_response_length
+
+        concatenated_inputs["reference_chosen_logps"] = None
+        concatenated_inputs["reference_rejected_logps"] = None
+        concatenated_inputs["reference_kl_logps"] = None
+
+        self._pp_data_buffer = []
+        inputs, labels = model._prepare_pipeline_inputs_func(concatenated_inputs)
+
+        if self.kto_config.lora:
+            self.disable_lora(model)
+            model.eval()
+            with paddle.no_grad():
+                with self.autocast_smart_context_manager():
+                    model.eval_batch(data=[inputs, labels], compute_loss=True)
+            self.enable_lora(model)
+            model._p2p_helper.clear_meta_cache()
+            model.train()
+        else:
+            ref_model = self.ref_model_wrapped
+            with paddle.no_grad():
+                with self.autocast_smart_context_manager():
+                    ref_model.eval_batch(data=[inputs, labels], compute_loss=True)
+        reference_chosen_logps = infohub.reference_chosen_logps
+        reference_rejected_logps = infohub.reference_rejected_logps
+        reference_kl_logps = infohub.reference_kl_logps
+
+        if model.is_pipeline_last_stage(ignore_virtual=model._layers._num_virtual_pipeline_stages > 1):
+            labels = labels[:3] + (
+                reference_chosen_logps,
+                reference_rejected_logps,
+                reference_kl_logps,
+            )
+        with paddle.no_grad():
+            with self.autocast_smart_context_manager():
+                loss = model.eval_batch(data=[inputs, labels], compute_loss=True)
+
+        # broadcast KTO_INFO_KEYS
+        self.broadcast_last_stage_infohub_tensor()
+        # metrics
+        metric_inputs = dict(
+            policy_chosen_logps=infohub.policy_chosen_logps,
+            policy_rejected_logps=infohub.policy_rejected_logps,
+            reference_chosen_logps=infohub.reference_chosen_logps,
+            reference_rejected_logps=infohub.reference_rejected_logps,
+            kl=infohub.kl,
train_eval="eval", + ) + self.log_metric(**metric_inputs) + self.reset_dpo_infohub() + return (loss, None, None) + + def reset_dpo_infohub(self): + """Initialize infohub""" + for key in KTO_INFO_KEYS: + setattr(infohub, key, []) + + def broadcast_last_stage_infohub_tensor(self): + for key in KTO_INFO_KEYS: + if self.model_wrapped.is_pipeline_last_stage( + ignore_virtual=self.model_wrapped._layers._num_virtual_pipeline_stages > 1 + ): + if key == "kl": + tensor = paddle.stack(getattr(infohub, key)).mean().detach() + elif "logps" in key: + logps_list = getattr(infohub, key) + if all(logps.shape == [0] for logps in logps_list): + tensor = paddle.zeros([1]) + else: + tensor = paddle.concat(getattr(infohub, key), axis=0).detach() + tensor_shape = paddle.to_tensor(tensor.shape, dtype="int64") + paddle.distributed.broadcast( + tensor_shape, + src=self.model_wrapped.global_rank, + group=self.model_wrapped.pp_group, + ) + else: + raise ValueError(f"Invalid key: {key}") + paddle.distributed.broadcast( + tensor, + src=self.model_wrapped.global_rank, + group=self.model_wrapped.pp_group, + ) + else: + if key == "kl": + tensor = paddle.zeros([], "float32") + elif "logps" in key: + tensor_shape = paddle.empty([1], dtype="int64") + paddle.distributed.broadcast( + tensor_shape, + src=self.model_wrapped._hcg.get_rank_from_stage(self.model_wrapped.num_stages - 1), + group=self.model_wrapped.pp_group, + ) + tensor = paddle.zeros(tensor_shape, "float32") + else: + raise ValueError(f"Invalid key: {key}") + paddle.distributed.broadcast( + tensor, + src=self.model_wrapped._hcg.get_rank_from_stage(self.model_wrapped.num_stages - 1), + group=self.model_wrapped.pp_group, + ) + setattr(infohub, key, tensor) + + +def prepare_pipeline_dpo_inputs_func(inputs): + """Prepare pipeline inputs""" + if "attention_mask" in inputs: + first_stage_keys = [ + "input_ids", + "attention_mask", + "position_ids", + ] + else: + first_stage_keys = [ + "input_ids", + "attn_mask_start_row_indices", + "position_ids", + ] + + last_stage_keys = [ + "response_labels", + "response_kl_labels", + "response_indexs", + "reference_chosen_logps", + "reference_rejected_logps", + "reference_kl_logps", + ] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ]