This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Find the right embedding layer for mismatched cases (#4179)
* Find the right embedding layer for mismatched cases

* black
matt-gardner authored Apr 30, 2020
1 parent 74c8404 commit 2544e59
Showing 1 changed file with 23 additions and 4 deletions.
allennlp/nn/util.py (27 changes: 23 additions & 4 deletions)
@@ -1658,11 +1658,30 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
     )
     from allennlp.modules.token_embedders.embedding import Embedding
 
+    # The special case only works if we don't have mismatched embedding. What we essentially want
+    # to do is grab a transformer's wordpiece embedding layer, because using that is a lot easier
+    # for something like hotflip than running a network to embed all tokens in a vocabulary. If
+    # you've used a mismatched embedder, though, your input tokens are actually *words*, so that's
+    # what we'll be visualizing and attacking, even though you're modeling things at the wordpiece
+    # level. In this case, we need to return gradients and things at the word level, so we can't
+    # use our shortcut of just returning the wordpiece embedding.
+    mismatched = False
     for module in model.modules():
-        if isinstance(module, BertEmbeddings):
-            return module.word_embeddings
-        if isinstance(module, GPT2Model):
-            return module.wte
+        if "Mismatched" in module.__class__.__name__:
+            # We don't currently have a good way to check whether an embedder is mismatched, and it
+            # doesn't seem like it's worth it to try to add an API call for this somewhere,
+            # especially as we can't really call it here in a type-safe way, anyway, as we're
+            # iterating over plain pytorch Modules. This check should work for now (v1.0), but it's
+            # possible that some class gets added later that will require this check to change.
+            mismatched = True
+
+    if not mismatched:
+        for module in model.modules():
+            if isinstance(module, BertEmbeddings):
+                return module.word_embeddings
+            if isinstance(module, GPT2Model):
+                return module.wte
 
     for module in model.modules():
         if isinstance(module, TextFieldEmbedder):
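To make the new check concrete, here is a minimal, self-contained sketch of the class-name heuristic added above. It does not require AllenNLP: FakeMismatchedEmbedder, ToyModel, and uses_mismatched_embedder are hypothetical names standing in for a real mismatched token embedder (one whose class name contains "Mismatched") and for the detection loop in the diff.

import torch


class FakeMismatchedEmbedder(torch.nn.Module):
    """Hypothetical stand-in for an embedder whose class name contains 'Mismatched'."""

    def __init__(self) -> None:
        super().__init__()
        # A real mismatched embedder wraps a wordpiece-level transformer but exposes
        # word-level outputs; a plain Embedding is enough for this sketch.
        self.inner = torch.nn.Embedding(100, 8)


class ToyModel(torch.nn.Module):
    """Hypothetical model that happens to contain a mismatched embedder."""

    def __init__(self) -> None:
        super().__init__()
        self.embedder = FakeMismatchedEmbedder()
        self.classifier = torch.nn.Linear(8, 2)


def uses_mismatched_embedder(model: torch.nn.Module) -> bool:
    # Same heuristic as the diff: there is no type-safe API for this, so we look
    # for "Mismatched" in the class name of any submodule.
    return any("Mismatched" in module.__class__.__name__ for module in model.modules())


print(uses_mismatched_embedder(ToyModel()))              # True
print(uses_mismatched_embedder(torch.nn.Linear(8, 2)))   # False

When the check fires, find_embedding_layer skips the BertEmbeddings/GPT2Model wordpiece shortcut and falls through to the TextFieldEmbedder search further down, so gradients (and attacks like hotflip) are computed at the word level, as the comment in the diff explains.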
