[LLM]add ktotrainer #9393

Merged · 2 commits · Nov 8, 2024
118 changes: 98 additions & 20 deletions paddlenlp/transformers/tensor_parallel_utils.py
@@ -254,12 +254,20 @@
labels_chunk = labels[token_start_idx:token_end_idx]

# logits calculations
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
logits_chunk_cast = paddle.matmul(
hidden_states_chunk,
lm_head_weight_cast,
transpose_y=transpose_y,
)
if lm_head_bias is not None:
logits_chunk_cast += lm_head_bias_cast
if tensor_parallel_degree > 1 and not tensor_parallel_output:
logits_chunk_cast_lst = []
dist.all_gather(logits_chunk_cast_lst, logits_chunk_cast, group=model_parallel_group)
dist.all_gather(
logits_chunk_cast_lst,
logits_chunk_cast,
group=model_parallel_group,
)
logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
logits_chunk = logits_chunk_cast.astype("float32")

@@ -271,18 +279,30 @@
exp_logits = paddle.exp(normalized_logits)
sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
if tensor_parallel_degree > 1 and tensor_parallel_output:
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
dist.all_reduce(
sum_exp_logits,
op=dist.ReduceOp.SUM,
group=model_parallel_group,
)
log_sum_exp_logits = paddle.log(sum_exp_logits)

# cross entropy
labels_one_hot = labels_chunk.unsqueeze(1) == indices
label_logits = paddle.sum(
paddle.where(labels_one_hot, normalized_logits, paddle.zeros_like(normalized_logits)),
paddle.where(
labels_one_hot,
normalized_logits,
paddle.zeros_like(normalized_logits),
),
axis=-1,
keepdim=True,
)
if tensor_parallel_degree > 1 and tensor_parallel_output:
dist.all_reduce(label_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
dist.all_reduce(
label_logits,
op=dist.ReduceOp.SUM,
group=model_parallel_group,
)
token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
@@ -298,18 +318,30 @@
grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
grad_logits_chunk = grad_logits_chunk.astype(dtype)
grad_logits_chunk = paddle.where(
cond.unsqueeze(1), grad_logits_chunk, paddle.zeros_like(grad_logits_chunk)
cond.unsqueeze(1),
grad_logits_chunk,
paddle.zeros_like(grad_logits_chunk),
)

if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y
grad_logits_chunk,
lm_head_weight_cast,
transpose_y=not transpose_y,
)
if grad_lm_head_weight is not None:
if transpose_y:
grad_lm_head_weight += paddle.matmul(grad_logits_chunk, hidden_states_chunk, transpose_x=True)
grad_lm_head_weight += paddle.matmul(
grad_logits_chunk,
hidden_states_chunk,
transpose_x=True,
)
else:
grad_lm_head_weight += paddle.matmul(hidden_states_chunk, grad_logits_chunk, transpose_x=True)
grad_lm_head_weight += paddle.matmul(
hidden_states_chunk,
grad_logits_chunk,
transpose_x=True,
)
if grad_lm_head_bias is not None:
grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)

@@ -340,7 +372,11 @@
grad_args = []
if ctx.hidden_states_has_grad:
if tensor_parallel_degree > 1:
dist.all_reduce(grad_hidden_states, op=dist.ReduceOp.SUM, group=model_parallel_group)
dist.all_reduce(
grad_hidden_states,
op=dist.ReduceOp.SUM,
group=model_parallel_group,
)
grad_args.append(grad_hidden_states.reshape(original_shape))
if ctx.lm_head_weight_has_grad:
grad_args.append(grad_lm_head_weight)
@@ -376,9 +412,20 @@
grad_lm_head_bias = None

if ctx.aux_num == 1:
return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None
return (
grad_hidden_states,
grad_lm_head_weight,
grad_lm_head_bias,
None,
)
else:
return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None, None
return (
grad_hidden_states,
grad_lm_head_weight,
grad_lm_head_bias,
None,
None,
)

# return_token_loss = True
grad_token_loss = grad_output.reshape([-1])
@@ -444,12 +491,20 @@
labels_chunk = labels[token_start_idx:token_end_idx]

# logits calculations
logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
logits_chunk_cast = paddle.matmul(
hidden_states_chunk,
lm_head_weight_cast,
transpose_y=transpose_y,
)
if lm_head_bias is not None:
logits_chunk_cast += lm_head_bias_cast
if tensor_parallel_degree > 1 and not tensor_parallel_output:
logits_chunk_cast_lst = []
dist.all_gather(logits_chunk_cast_lst, logits_chunk_cast, group=model_parallel_group)
dist.all_gather(
logits_chunk_cast_lst,
logits_chunk_cast,
group=model_parallel_group,
)
logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
logits_chunk = logits_chunk_cast.astype("float32")

@@ -461,7 +516,11 @@
exp_logits = paddle.exp(normalized_logits)
sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
if tensor_parallel_degree > 1 and tensor_parallel_output:
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=model_parallel_group)
dist.all_reduce(
sum_exp_logits,
op=dist.ReduceOp.SUM,
group=model_parallel_group,
)

labels_one_hot = labels_chunk.unsqueeze(1) == indices
if tensor_parallel_degree > 1 and not tensor_parallel_output:
@@ -473,12 +532,16 @@
grad_logits_chunk = grad_logits_chunk.astype(dtype)
cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
grad_logits_chunk = paddle.where(
cond.unsqueeze(1), grad_logits_chunk, paddle.zeros_like(grad_logits_chunk)
cond.unsqueeze(1),
grad_logits_chunk,
paddle.zeros_like(grad_logits_chunk),
)

if grad_hidden_states is not None:
grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
grad_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y
grad_logits_chunk,
lm_head_weight_cast,
transpose_y=not transpose_y,
)
if grad_lm_head_weight is not None:
if transpose_y:
@@ -490,10 +553,25 @@

if grad_hidden_states is not None:
if tensor_parallel_degree > 1:
dist.all_reduce(grad_hidden_states, op=dist.ReduceOp.SUM, group=model_parallel_group)
dist.all_reduce(
grad_hidden_states,
op=dist.ReduceOp.SUM,
group=model_parallel_group,
)
grad_hidden_states = grad_hidden_states.reshape(ctx.original_shape)

if ctx.aux_num == 1:
return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None
return (
grad_hidden_states,
grad_lm_head_weight,
grad_lm_head_bias,
None,
)
else:
return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None, None
return (
grad_hidden_states,
grad_lm_head_weight,
grad_lm_head_bias,
None,
None,
)
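Most of this file's diff is mechanical re-wrapping of existing calls across multiple lines, which makes the underlying computation easy to lose track of. As a reading aid only (not part of the PR), here is a minimal single-rank sketch of the chunked cross-entropy that these hunks wrap; the toy shapes and variable names are invented for illustration, and in the tensor-parallel branch sum_exp and label_logits would each be all-reduced over the model-parallel group before being combined.

import paddle

hidden = paddle.randn([4, 8])                       # [tokens, hidden]
weight = paddle.randn([16, 8])                      # [vocab, hidden], so transpose_y=True below
labels = paddle.randint(0, 16, [4])                 # token labels in the local vocab range

logits = paddle.matmul(hidden, weight, transpose_y=True)            # [tokens, vocab]
normalized = logits - paddle.max(logits, axis=-1, keepdim=True)     # numerical stability
sum_exp = paddle.sum(paddle.exp(normalized), axis=-1, keepdim=True)

indices = paddle.arange(16).unsqueeze(0)            # [1, vocab] column ids
one_hot = labels.unsqueeze(1) == indices            # [tokens, vocab] bool mask
label_logits = paddle.sum(
    paddle.where(one_hot, normalized, paddle.zeros_like(normalized)),
    axis=-1,
    keepdim=True,
)

# In the tensor-parallel branch of the diff, sum_exp and label_logits are each
# all-reduced across the model-parallel group before this subtraction.
token_loss = (paddle.log(sum_exp) - label_logits).squeeze(1)
print(float(token_loss.mean()))                     # mean cross-entropy over the 4 toy tokens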
2 changes: 2 additions & 0 deletions paddlenlp/trl/__init__.py
@@ -14,6 +14,8 @@

from .dpo_criterion import DPOCriterion
from .dpo_trainer import DPOTrainer
from .kto_criterion import KTOCriterion
from .kto_trainer import KTOTrainer
from .sft_trainer import *
from .trl_data import *
from .trl_utils import *
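The two added imports re-export the new KTO components at the package level. Assuming a PaddleNLP build that includes this PR, the trainer and criterion should then be importable directly from paddlenlp.trl; only the import path below is implied by this diff, since the constructor arguments live in the new kto_trainer module and are not shown here.

# Sanity check that the new re-exports resolve (requires a build containing this PR).
from paddlenlp.trl import KTOCriterion, KTOTrainer

print(KTOTrainer.__name__, KTOCriterion.__name__)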
8 changes: 4 additions & 4 deletions paddlenlp/trl/dpo_criterion.py
@@ -287,10 +287,10 @@
)
loss = dpo_loss + sft_loss
if self.use_infohub:
infohub.policy_chosen_logps.append(policy_chosen_logps)
infohub.policy_rejected_logps.append(policy_rejected_logps)
infohub.sft_loss.append(sft_loss)
infohub.dpo_loss.append(dpo_loss)
infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
infohub.sft_loss.append(sft_loss.detach())
infohub.dpo_loss.append(dpo_loss.detach())
return loss
else:
return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss
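The only functional change in this file is appending .detach() to the tensors stashed in infohub. A likely motivation is that these values are kept only for later metric logging, and storing graph-attached tensors would keep the whole backward graph (and its activations) alive. Below is a toy sketch of the detach semantics in Paddle, with invented names rather than the actual infohub API.

import paddle

logged = []                            # stands in for infohub's metric buffers
x = paddle.randn([2, 3])
x.stop_gradient = False
loss = (x * 2).sum()

logged.append(loss.detach())           # detached copy: no autograd graph retained
loss.backward()

print(bool(logged[0].stop_gradient))   # True: the stored copy tracks no gradient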