Add label smoothing to CopyNet #287
Conversation
Looks good, just a small comment.
```python
one_hot_targets = torch.zeros_like(log_probs).scatter_(
    -1, target_tokens.unsqueeze(1), 1.0 - self._label_smoothing
)
smoothed_targets = one_hot_targets + smoothing_value
```
It seems like it would be faster to start with `torch.full_like(log_probs, smoothing_value)` and go from there? I did not measure it, though.

`one_hot_targets` isn't used anywhere else, is it? It's not even properly "one hot" like this. Maybe it should be called `one_warm_targets` 🤣.
> It seems like it would be faster to start with `torch.full_like(log_probs, smoothing_value)` and go from there? I did not measure it, though.
Something like this?
```python
one_hot_targets = torch.full_like(log_probs, smoothing_value).scatter_(
    1, target_tokens.unsqueeze(1), 1.0 - self._label_smoothing + smoothing_value
)
```
It appears to be slightly faster based on a quick test:
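A minimal sketch of that kind of micro-benchmark (the shapes, smoothing value, and iteration count below are made-up assumptions, not the actual test):

```python
import timeit

import torch

# Hypothetical shapes; the real test would use CopyNet's actual tensors.
batch_size, num_classes = 64, 20000
label_smoothing = 0.1
smoothing_value = label_smoothing / num_classes

log_probs = torch.randn(batch_size, num_classes)
target_tokens = torch.randint(num_classes, (batch_size,))

def zeros_then_add():
    # Original: scatter into zeros, then add the smoothing value everywhere.
    one_hot = torch.zeros_like(log_probs).scatter_(
        -1, target_tokens.unsqueeze(1), 1.0 - label_smoothing
    )
    return one_hot + smoothing_value

def full_then_scatter():
    # Suggested: fill with the smoothing value and scatter once, no extra add.
    return torch.full_like(log_probs, smoothing_value).scatter_(
        1, target_tokens.unsqueeze(1), 1.0 - label_smoothing + smoothing_value
    )

print("zeros + add:   ", timeit.timeit(zeros_then_add, number=1000))
print("full + scatter:", timeit.timeit(full_then_scatter, number=1000))
```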
> `one_hot_targets` isn't used anywhere else, is it? It's not even properly "one hot" like this. Maybe it should be called `one_warm_targets` 🤣.
Good point! I am stealing the variable `one_hot_targets` right from here. If we switch to the approach above, we could just call it `smoothed_targets`:
```python
smoothed_targets = torch.full_like(log_probs, smoothing_value).scatter_(
    1, target_tokens.unsqueeze(1), 1.0 - self._label_smoothing + smoothing_value
)
```
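A quick, self-contained check (small made-up shapes) that the single-scatter version matches the original two-step construction, and that each row still sums to one:

```python
import torch

num_classes = 10
label_smoothing = 0.1
smoothing_value = label_smoothing / num_classes  # as in sequence_cross_entropy_with_logits

log_probs = torch.randn(4, num_classes)
target_tokens = torch.randint(num_classes, (4,))

# Original two-step construction: scatter into zeros, then add everywhere.
two_step = torch.zeros_like(log_probs).scatter_(
    -1, target_tokens.unsqueeze(1), 1.0 - label_smoothing
) + smoothing_value

# Single-scatter construction from above.
single_scatter = torch.full_like(log_probs, smoothing_value).scatter_(
    1, target_tokens.unsqueeze(1), 1.0 - label_smoothing + smoothing_value
)

assert torch.allclose(two_step, single_scatter)
# Each row is a valid distribution: the row sum is
# num_classes * smoothing_value + (1 - label_smoothing) = 1.
assert torch.allclose(single_scatter.sum(-1), torch.ones(4))
```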
Might be worth updating https://github.com/allenai/allennlp/blob/cf113d705b9054d329c67cf9bb29cbc3f191015d/allennlp/nn/util.py#L825-L828 to use this micro-optimization.
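The change there would presumably look something like this (a sketch based on a reading of the linked lines, not an actual patch; `log_probs_flat` and `targets_flat` are the names used in `sequence_cross_entropy_with_logits`):

```python
# Before (scatter into zeros, then add the smoothing value everywhere):
#
#     one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
#         -1, targets_flat, 1.0 - label_smoothing
#     )
#     smoothed_targets = one_hot_targets + smoothing_value
#
# After (fill with the smoothing value and scatter once):
smoothed_targets = torch.full_like(log_probs_flat, smoothing_value).scatter_(
    -1, targets_flat, 1.0 - label_smoothing + smoothing_value
)
```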
Oh, I see it's also missing a changelog entry. This definitely has enough magnitude to warrant one.

Done!

Thanks! I'll make the switch in util.py.
This PR adds label smoothing to `CopyNetSeq2Seq`. As discussed in allenai/allennlp#5276, label smoothing is applied to the generation scores only. It is mostly a re-working of the existing label smoothing code in `sequence_cross_entropy_with_logits`.

As a sanity check, I ran the code with my own model. A model with a small `label_smoothing` value reaches performance similar to a model with `label_smoothing == 0.0`. As for additional unit tests, I think a modification of `test_get_ll_contrib` might make the most sense.
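The core idea is to replace the hard one-hot target over the generation vocabulary with a smoothed distribution before computing the log-likelihood, leaving the copy scores untouched. A rough sketch (the helper name and signature here are made up for illustration; the real change lives in the code that `test_get_ll_contrib` covers):

```python
import torch

def smoothed_generation_log_likelihood(
    generation_log_probs: torch.Tensor,  # shape: (batch_size, gen_vocab_size)
    target_tokens: torch.Tensor,         # shape: (batch_size,)
    label_smoothing: float,
) -> torch.Tensor:
    """Hypothetical helper: log-likelihood of the targets under a smoothed
    target distribution, mirroring sequence_cross_entropy_with_logits."""
    num_classes = generation_log_probs.size(-1)
    smoothing_value = label_smoothing / num_classes
    smoothed_targets = torch.full_like(generation_log_probs, smoothing_value).scatter_(
        -1, target_tokens.unsqueeze(1), 1.0 - label_smoothing + smoothing_value
    )
    # Expected log-likelihood under the smoothed distribution; copy scores
    # are left alone, per the discussion in allenai/allennlp#5276.
    return (smoothed_targets * generation_log_probs).sum(-1)
```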