Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix t5 dataset #459

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions libai/data/datasets/t5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class T5Dataset(flow.utils.data.Dataset):
All values are padded to this length. Defaults to 512.
max_seq_length_dec (int, optional): Maximum length of the sequence passing into decoder.
All values are padded to this length. Defaults to 128.
mask_lm_prob (float, optional): Probability to mask tokens. Defaults to 0.15.
max_preds_per_seq (int, optional): Maximum number of masked tokens in each sentence.
Defaults to None.
masked_lm_prob (float, optional): Probability to mask tokens. Defaults to 0.15.
short_seq_prob (float, optional):
Probability of producing a short sequence. Defaults to 0.0.
seed (int, optional):
Expand Down
5 changes: 3 additions & 2 deletions projects/MT5/layers/attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def forward(
)
else:
position_bias = self.compute_bias(
real_seq_length, key_length, placement=attention_mask.placement
real_seq_length, key_length, placement=attention_scores.placement
)

if past_key_value is not None:
Expand All @@ -228,13 +228,14 @@ def forward(
if use_cache:
attention_mask = attention_mask.expand_as(attention_scores)

attention_dropout_prob = self.attention_dropout_prob if self.training else 0.0
attention_weights = flow._C.fused_bias_add_scale_mask_softmax_dropout(
attention_scores,
position_bias,
attention_mask,
fill_value=-10000.0,
scale=1,
p=self.attention_dropout_prob,
p=attention_dropout_prob,
)[0]
else:
attention_scores = attention_scores + position_bias
Expand Down
11 changes: 6 additions & 5 deletions projects/MT5/mt5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def forward(
position_bias = None
encoder_decoder_position_bias = None
self.set_cache(encoder_states=None, past_key_values=None)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask) if encoder_attn_mask is not None else encoder_attn_mask
enc_embedding_output = self.embedding(encoder_input_ids)
# transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
enc_hidden_states = enc_embedding_output.transpose(0, 1)
Expand All @@ -219,10 +219,11 @@ def forward(
if only_encoder:
return encoder_states

decoder_attn_mask = self.extended_attn_mask(
decoder_attn_mask, decoder_input_ids, is_decoder=True
)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
if decoder_attn_mask is not None:
decoder_attn_mask = self.extended_attn_mask(
decoder_attn_mask, decoder_input_ids, is_decoder=True
)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask) if encoder_decoder_attn_mask is not None else encoder_decoder_attn_mask

dec_embedding_output = self.embedding(decoder_input_ids)
# transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
Expand Down