Synthetic copy task data question #127

Open
cmmcirvin opened this issue Sep 22, 2024 · 7 comments

Comments

@cmmcirvin

I'm slightly confused by the line data[:, 0] = 1 in the data generator for the simple copy task.

It manually sets the first column of the data to 1. I'm not sure what the reason for this is; it sometimes seems to make the first value of the output array come out wrong, and removing it doesn't seem to hurt performance or break anything. What is this line intended to do?
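
For reference, here is roughly what the generator looks like, reconstructed from the snippets quoted elsewhere in this thread (the exact tensor handling in the notebook may differ):

    def data_gen(V, batch_size, nbatches):
        "Generate random src/tgt pairs for the copy task."
        for i in range(nbatches):
            data = torch.randint(1, V, size=(batch_size, 10))  # tokens drawn from 1 .. V-1
            data[:, 0] = 1                                      # the line in question
            src = data.clone().detach()
            tgt = data.clone().detach()
            yield Batch(src.to(device), tgt.to(device), 0)      # 0 is passed as the pad index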

Thanks for the help!

@hhxxttxs-tang

It might suggest that the SOS token is defined as "1".

@cmmcirvin
Author

I think the SOS token is 0, though? If we look at the final output example, the start_symbol parameter passed into the greedy_decode function is 0.

print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))

Looking at that greedy_decode function, it seems like the start_symbol token is used as the first element of the decoded output, which is exactly where I'd expect the SOS token to be.
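
For what it's worth, my reading of the decoding loop is roughly the following (paraphrased from the notebook, so details may differ); start_symbol seeds the decoder input and therefore ends up as the first element of the returned sequence:

    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)   # decoder input starts as [start_symbol]
    for i in range(max_len - 1):
        out = model.decode(memory, src_mask, ys,
                           subsequent_mask(ys.size(1)).type_as(src.data))
        prob = model.generator(out[:, -1])                          # distribution over the next token
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.view(1, 1).type_as(src.data)], dim=1)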

@hhxxttxs-tang

hhxxttxs-tang commented Sep 23, 2024

IMO, "0" as start_symbol is wrong. It needs to match what is used during training; otherwise you won't get correct output at inference.

"0" is used for padding instead - see the last line of the data_gen() function:

    yield Batch(src.to(device), tgt.to(device), 0) 
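
That last positional argument is the pad index. Inside Batch it is used (roughly, going from memory rather than the exact notebook code) to build the attention masks, along these lines:

    src_mask = (src != pad).unsqueeze(-2)     # hide padded source positions from attention
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)  # combine with the causal mask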

@cmmcirvin
Author

I think just changing the line above from data[:, 0] = 1 to data[:, 0] = 0 should be fine, correct? Training seems more stable when I do that, at least. If we use 0 as the start symbol everywhere by modifying that line, everything stays consistent.

I don't think 1 is a good SOS token, though, since the data for the copy task ranges from 1 to V-1 (see the line below), so it doesn't make sense for the SOS token to coincide with a valid token we want to copy.

data = torch.randint(1, V, size=(batch_size, 10))
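
As a sketch of what I mean (untested, and not what the notebook currently does), the special tokens could be kept out of the copy vocabulary entirely:

    # inside data_gen:
    PAD, SOS = 0, 1                                     # reserve 0 for padding, 1 for start-of-sequence
    data = torch.randint(2, V, size=(batch_size, 10))   # copyable tokens are now 2 .. V-1
    data[:, 0] = SOS                                    # first position carries the start symbol
    src = data.clone().detach()
    tgt = data.clone().detach()
    yield Batch(src.to(device), tgt.to(device), PAD)

    # and at inference time, the start symbol must match training:
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=SOS))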

@hhxxttxs-tang

hhxxttxs-tang commented Sep 23, 2024

"0" is good if no padding is involved, otherwise, i think they need to be defined separately.
not sure what's the best practice for a good SOS token.

@cmmcirvin
Author

Ah, I see how 0 is being used for padding now, thanks. I was misunderstanding that part before.

I still don't think the current implementation handles the SOS token correctly, since 1 is also a valid element of the dataset, but I now see why 0 would be a bad choice too. Thanks!

@kuraga

kuraga commented Sep 24, 2024

#116 ?
