support flashmask #8670

Merged 3 commits on Jun 27, 2024
22 changes: 17 additions & 5 deletions paddlenlp/transformers/llama/modeling_pp.py
```diff
@@ -208,11 +208,23 @@
             alibi = position_ids
             position_ids = attn_mask_startend_row_indices
             attn_mask_startend_row_indices = None
-        elif not self.config.alibi and position_ids is None and attn_mask_startend_row_indices is not None:
-            # hidden_states, attention_mask, position_ids
-            position_ids = attn_mask_startend_row_indices
-            attn_mask_startend_row_indices = None
-            alibi = None
+        elif not self.config.alibi:
+            if get_env_device() in ["gpu"]:
+                if attention_mask is not None and attention_mask.dtype == paddle.int32:
+                    attention_mask, attn_mask_startend_row_indices, position_ids = (
+                        None,
+                        attention_mask,
+                        attn_mask_startend_row_indices,
+                    )
+                elif attention_mask is not None and attention_mask.dtype == paddle.int64:
+                    attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
+                elif (
+                    attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64
+                ):
+                    attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
+            elif position_ids is None and attn_mask_startend_row_indices is not None:
+                position_ids = attn_mask_startend_row_indices
+                attn_mask_startend_row_indices = None

         has_gradient = not hidden_states.stop_gradient
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
```

Codecov (codecov/patch) flagged the added lines #L211-L214, #L219-L221, and #L224-L227 as not covered by tests.
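The GPU branch above disambiguates a single pipeline-parallel input slot by dtype: an int32 tensor arriving as `attention_mask` is treated as FlashMask start/end row indices, while an int64 tensor (in either slot) is treated as `position_ids`. A minimal sketch of that dispatch, using a hypothetical `FakeTensor` stand-in that only carries a `dtype` string (the real code compares against `paddle.int32` / `paddle.int64` on Paddle tensors):

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a Paddle tensor; only dtype matters here."""
    dtype: str  # "int32" or "int64"


def dispatch_gpu(attention_mask, attn_mask_startend_row_indices, position_ids):
    """Sketch of the dtype-based dispatch in the PR's GPU branch:
    - int32 attention_mask  -> actually FlashMask startend row indices
    - int64 attention_mask  -> actually position_ids
    - int64 row indices     -> actually position_ids
    Returns the reassigned (attention_mask, row_indices, position_ids).
    """
    if attention_mask is not None and attention_mask.dtype == "int32":
        attention_mask, attn_mask_startend_row_indices, position_ids = (
            None,
            attention_mask,
            attn_mask_startend_row_indices,
        )
    elif attention_mask is not None and attention_mask.dtype == "int64":
        attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
    elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == "int64":
        attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
    return attention_mask, attn_mask_startend_row_indices, position_ids


# An int32 mask is reinterpreted as FlashMask row indices:
mask = FakeTensor("int32")
print(dispatch_gpu(mask, None, None))
```

This mirrors the diff's control flow only; it is not the PaddleNLP implementation, which operates on real tensors and also handles the non-GPU fallback branch.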