Inconsistencies between data collator output and masked permute in original paper #18

Open · ghost opened this issue Jul 20, 2022 · 0 comments

Hi all on the MPNet research team,

I am in the process of converting the fairseq training code for MPNet into a training loop that is compatible with Hugging Face. Although many of the convenience classes already exist there (like MPNetForMaskedLM), it has become clear that we will also need to port the collator function in MaskedDataset (under tasks/masked_permutation_lm).

In exploring how this collator works, I understand the logic as:

  1. Permute the input IDs (at the whole-word-span or token level, selected via an argument) together with their positions
  2. Create masked/corrupted tokens from the final n indices of the permuted sequence, where n is the prediction size (i.e. seq_len × 0.15 at default values)
  3. Concatenate these as concat(seq, mask, mask) and concat(positions, predict_positions, predict_positions) (see the sketch after this list)
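
For concreteness, here is a minimal sketch of that logic in plain Python. This is my reading of the fairseq code rather than a verbatim port: the real collator works on tensors, supports whole-word spans, and noises the prediction block (hence the `<corrupted>` entries below), all of which is simplified away here; `mask_idx` and the 0.15 prediction ratio are assumed defaults.

```python
import random

def collate_sketch(src_tokens, mask_idx, pred_ratio=0.15):
    """Simplified, token-level sketch of MaskedDataset's collation."""
    seq_len = len(src_tokens)
    n_pred = max(1, int(seq_len * pred_ratio))

    # 1. Permute the tokens together with their original position ids
    order = list(range(seq_len))
    random.shuffle(order)
    tokens = [src_tokens[i] for i in order]
    positions = order  # each token keeps its pre-permutation position

    # 2. The final n_pred entries of the permuted sequence become the prediction targets
    pred_positions = positions[-n_pred:]

    # 3. Append TWO masked copies of the prediction block:
    #    concat(seq, mask, mask) and concat(positions, pred_pos, pred_pos)
    new_ids = tokens + [mask_idx] * n_pred * 2
    new_positions = positions + pred_positions * 2
    return new_ids, new_positions
```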

Using this logic, we might expect the collator function to perform the following operation on some dummy input IDs:

src_tokens = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]

# Once the collator permutes everything and we append the mask portions, we expect something like
new_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22,  16,  24,  25, <mask>,  <corrupted>, <mask>, <mask>,  <corrupted>, <mask>]
new_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15,  6, 14, 15]

However, after rereading the MPNet paper (especially Sections 2.2 and 2.3, with attention to Figure 2), it would SEEM that the output of the collator is inconsistent with what those sections describe.

Figure 2 shows that the content and query masks are built from a permuted sequence that looks like:

src_tokens = [x_1, x_2, x_3, x_4, x_5, x_6]

# Once permuted we get:
new_ids = [x_1, x_3, x_5, <mask>, <mask>, <mask>,  x_4, x_6, x_2]
new_positions = [1, 3, 5, 4, 6, 2, 4, 6, 2]
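
Under the same simplifying assumptions as the sketch above, the paper's arrangement would be built roughly like this (a hypothetical helper, not code from the repo):

```python
def paper_layout_sketch(tokens, positions, mask_idx, n_pred):
    """Figure 2's arrangement: mask the last n_pred tokens in place,
    then append their true contents (with repeated positions) at the end."""
    new_ids = tokens[:-n_pred] + [mask_idx] * n_pred + tokens[-n_pred:]
    new_positions = positions + positions[-n_pred:]
    return new_ids, new_positions
```

Applied to Figure 2's inputs (tokens = [x_1, x_3, x_5, x_4, x_6, x_2], positions = [1, 3, 5, 4, 6, 2], n_pred = 3), this reproduces the new_ids and new_positions shown above.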

In this example from the paper, we mask the pred_len tokens in place and then append their content to the end for the content stream. The collator output, however, KEEPS the token content in the main sequence and adds TWO blocks of mask tokens to the end, which seems necessarily different from what the paper describes. Referring back to the dummy example above, I can outline the discrepancies I'm seeing:

collator_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22,  16,  24,  25, <mask>,  <corrupted>, <mask>, <mask>,  <corrupted>, <mask>]
collator_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15,  6, 14, 15]

paper_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22, <mask>,  <corrupted>, <mask>, 16, 24, 25]
paper_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15]
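
Note that the two layouts do not even have the same length: the collator output has seq_len + 2 × n_pred entries (27 here), while the paper's layout has seq_len + n_pred (24 here).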

My question, then, is this: am I correct in understanding that the collator implementation differs from what is described in the paper? If so, why?
