Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Changes and improvements to how we initialize transformer modules fro…
Browse files Browse the repository at this point in the history
…m pretrained models (#5200)

* updates

* rename 'load_state_dict' -> 'read_state_dict'

* fix TransformerStack

* more fixes

* fix embeddings

* fix toolkit tests

* fix self attention

* fix bimodal encoder tests

* fix more tests

* fix T5!

* fixes

* fix backbone

* fix

* fixes

* fix

* doc fixes

* name changes

* patch models branch temporarily

* update CHANGELOG

* change default dist loading strategy to 'MEM_EFFICIENT' for T5

* fix distilbert test

* always use memory efficient distributed loading strategy

* Update .github/workflows/ci.yml

Co-authored-by: Pete <[email protected]>

Co-authored-by: Akshita Bhagia <[email protected]>
  • Loading branch information
epwalsh and AkshitaB authored May 17, 2021
1 parent cccb35d commit cf113d7
Show file tree
Hide file tree
Showing 31 changed files with 1,708 additions and 1,658 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Use `dist_reduce_sum` in distributed metrics.
- Allow Google Cloud Storage paths in `cached_path` ("gs://...").
- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`.
- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of
an actual `torch.nn.Module`. Other parameters to this method have changed as well.
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).

### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Added `nn.util.distributed_device()` helper function.
- Added `allennlp.nn.util.load_state_dict` helper function.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`.
Expand Down
6 changes: 3 additions & 3 deletions allennlp/commands/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from allennlp.commands.subcommand import Subcommand
from allennlp.common.file_utils import cached_path
from allennlp.nn.util import load_state_dict
from allennlp.nn.util import read_state_dict


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -249,10 +249,10 @@ def _get_checkpoint_path(checkpoint: str) -> str:
def _diff(args: argparse.Namespace):
checkpoint_1_path = _get_checkpoint_path(args.checkpoint1)
checkpoint_2_path = _get_checkpoint_path(args.checkpoint2)
checkpoint_1 = load_state_dict(
checkpoint_1 = read_state_dict(
checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False
)
checkpoint_2 = load_state_dict(
checkpoint_2 = read_state_dict(
checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False
)
for step in checkpoint_diff(checkpoint_1, checkpoint_2, args.scale, args.threshold):
Expand Down
9 changes: 8 additions & 1 deletion allennlp/common/testing/distributed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ def run_distributed_test(
func: `Callable`
`func` needs to be global for spawning the processes, so that it can be pickled.
start_method: `Optional[str]`, optional (default = `None`)
The start method to use for starting the workers. Defaults to "spawn" for GPU
processes and fork otherwise.
"""
device_ids = device_ids or [-1, -1]
check_for_gpu(device_ids)
# "fork" start method is the default and should be preferred, except when we're
# running the tests on GPU, in which case we need to use "spawn".
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
if "start_method" in kwargs:
start_method = kwargs.pop("start_method")
else:
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
nprocs = world_size = len(device_ids)
mp.start_processes(
init_process,
Expand Down
12 changes: 12 additions & 0 deletions allennlp/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,18 @@ def is_distributed() -> bool:
return dist.is_available() and dist.is_initialized()


def is_global_primary() -> bool:
"""
Checks if the distributed process group is the global primary (rank = 0).
If the distributed process group is not available or has not been initialized,
this trivially returns `True`.
"""
if not is_distributed():
return True
else:
return dist.get_rank() == 0


def sanitize_wordpiece(wordpiece: str) -> str:
"""
Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers.
Expand Down
2 changes: 1 addition & 1 deletion allennlp/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _load(

# Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
# if the state dict is missing keys because we handle this case below.
model_state = util.load_state_dict(weights_file, cuda_device=cuda_device)
model_state = util.read_state_dict(weights_file, cuda_device=cuda_device)
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

# Modules might define a class variable called `authorized_missing_keys`,
Expand Down
52 changes: 10 additions & 42 deletions allennlp/modules/backbones/vilbert_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from allennlp.data.fields.text_field import TextFieldTensors
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.backbones.backbone import Backbone
from allennlp.modules.transformer import BiModalEncoder, ImageFeatureEmbeddings, Embeddings
from allennlp.modules.transformer import (
BiModalEncoder,
ImageFeatureEmbeddings,
TransformerEmbeddings,
TransformerPooler,
)

logger = logging.getLogger(__name__)

Expand All @@ -23,7 +28,7 @@ class VilbertBackbone(Backbone):
def __init__(
self,
vocab: Vocabulary,
text_embeddings: Embeddings,
text_embeddings: TransformerEmbeddings,
image_embeddings: ImageFeatureEmbeddings,
encoder: BiModalEncoder,
pooled_output_dim: int,
Expand All @@ -36,7 +41,6 @@ def __init__(
self.text_embeddings = text_embeddings
self.image_embeddings = image_embeddings
self.encoder = encoder
from allennlp.modules.transformer import TransformerPooler

self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim)
self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim)
Expand Down Expand Up @@ -66,44 +70,7 @@ def from_huggingface_model_name(
image_fixed_layer: int,
fusion_method: str = "sum",
):
from transformers import AutoModel

transformer = AutoModel.from_pretrained(model_name)

from copy import deepcopy

text_embeddings = deepcopy(transformer.embeddings)

# Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size".
# To get them to the same dimensionality, it uses a linear transform after the embedding
# layer, which we need to pull out and copy here.
if hasattr(transformer.config, "embedding_size"):
config = transformer.config

from transformers.models.albert.modeling_albert import AlbertModel

if isinstance(transformer, AlbertModel):
linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in)
else:
logger.warning(
"Unknown model that uses separate embedding size; weights of the linear "
f"transform will not be initialized. Model type is: {transformer.__class__}"
)
linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim)

# We can't just use torch.nn.Sequential here, even though that's basically all this is,
# because Sequential doesn't accept *inputs, only a single argument.

class EmbeddingsShim(torch.nn.Module):
def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module):
super().__init__()
self.linear_transform = linear_transform
self.embeddings = embeddings

def forward(self, *inputs, **kwargs):
return self.linear_transform(self.embeddings(*inputs, **kwargs))

text_embeddings = EmbeddingsShim(text_embeddings, linear_transform)
text_embeddings = TransformerEmbeddings.from_pretrained_module(model_name)

image_embeddings = ImageFeatureEmbeddings(
feature_size=image_feature_dim,
Expand All @@ -112,7 +79,7 @@ def forward(self, *inputs, **kwargs):
)

encoder = BiModalEncoder.from_pretrained_module(
pretrained_module=transformer,
model_name,
num_hidden_layers2=image_num_hidden_layers,
hidden_size2=image_hidden_size,
num_attention_heads2=image_num_attention_heads,
Expand All @@ -126,6 +93,7 @@ def forward(self, *inputs, **kwargs):
fixed_layer1=text_fixed_layer,
fixed_layer2=image_fixed_layer,
)

return cls(
vocab=vocab,
text_embeddings=text_embeddings,
Expand Down
2 changes: 1 addition & 1 deletion allennlp/modules/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor):
```
"""

from allennlp.modules.transformer.layer_norm import LayerNorm
from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding

from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.transformer_embeddings import (
Embeddings,
Expand Down
5 changes: 3 additions & 2 deletions allennlp/modules/transformer/bimodal_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ def forward(
input_tensor2,
attention_mask1=None,
attention_mask2=None,
co_attention_mask=None,
co_attention_mask=None, # TODO: is this flag necessary?
use_co_attention_mask=False,
):
"""
# Parameters
input_tensor1 : `torch.Tensor`
Shape `batch_size x seq_len1 x hidden_dim1`
where `seq_len1` can be the sequence length
Expand All @@ -143,7 +145,6 @@ def forward(
if you know which words correspond to which regions in the image,
this mask can be applied to limit the attention given the bias.
use_co_attention_mask : `bool`
# TODO: is this flag necessary?
Whether to use co_attention_mask or not, default = `False`.
"""

Expand Down
2 changes: 1 addition & 1 deletion allennlp/modules/transformer/bimodal_connection_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2):

class BiModalConnectionLayer(TransformerModule, FromParams):

_huggingface_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"}
_pretrained_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"}

def __init__(
self,
Expand Down
110 changes: 17 additions & 93 deletions allennlp/modules/transformer/bimodal_encoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional, Dict, List, Union
from typing import Optional, List, TYPE_CHECKING

import torch

from allennlp.common import FromParams

from allennlp.modules.util import replicate_layers

from allennlp.modules.transformer.transformer_layer import TransformerLayer
from allennlp.modules.transformer.bimodal_connection_layer import BiModalConnectionLayer
from allennlp.modules.transformer.transformer_module import TransformerModule

if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig


class BiModalEncoder(TransformerModule, FromParams):
"""
Expand Down Expand Up @@ -46,8 +48,9 @@ class BiModalEncoder(TransformerModule, FromParams):
in_batch_pairs: `bool` (default = `False`)
"""

_huggingface_mapping = {"layer": "layers1"}
_relevant_module = "encoder"
_pretrained_mapping = {"layer": "layers1"}
_pretrained_relevant_module = ["encoder", "bert.encoder"]
_pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"]

def __init__(
self,
Expand Down Expand Up @@ -243,93 +246,14 @@ def forward(
)

@classmethod
def _get_input_arguments(
cls,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
The `pretrained_module` only supplies one of the modalities.
"""
submodules = cls._get_mapped_submodules(pretrained_module, source, mapping)

def _from_config(cls, config: "PretrainedConfig", **kwargs):
final_kwargs = {}

final_kwargs["num_hidden_layers1"] = len(submodules["layers1"])

final_kwargs["hidden_size1"] = submodules["layers1.0.attention.self.query"].in_features
final_kwargs["num_attention_heads1"] = submodules[
"layers1.0.attention.self"
].num_attention_heads
final_kwargs["attention_dropout1"] = submodules["layers1.0.attention.self.dropout"].p
final_kwargs["hidden_dropout1"] = submodules["layers1.0.attention.output.dropout"].p
final_kwargs["intermediate_size1"] = submodules["layers1.0.intermediate.dense"].out_features
final_kwargs["activation"] = submodules["layers1.0.intermediate"].intermediate_act_fn

final_kwargs["num_hidden_layers1"] = config.num_hidden_layers
final_kwargs["hidden_size1"] = config.hidden_size
final_kwargs["num_attention_heads1"] = config.num_attention_heads
final_kwargs["attention_dropout1"] = config.attention_probs_dropout_prob
final_kwargs["hidden_dropout1"] = config.hidden_dropout_prob
final_kwargs["intermediate_size1"] = config.intermediate_size
final_kwargs["activation"] = config.hidden_act
final_kwargs.update(**kwargs)

return final_kwargs

def _load_from_pretrained_module(
self,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
ignore_absent_parameters: Optional[List] = None,
):
if source == "huggingface":
ignore_absent_parameters = ["layers2", "c_layer"]
super()._load_from_pretrained_module(
pretrained_module, source, mapping, ignore_absent_parameters
)

@classmethod
def from_pretrained_module( # type: ignore
cls,
pretrained_module: Union[str, torch.nn.Module],
num_hidden_layers2: int,
hidden_size2: int,
combined_hidden_size: int,
intermediate_size2: int,
num_attention_heads2: int,
combined_num_attention_heads: int,
attention_dropout2: float,
hidden_dropout2: float,
biattention_id1: List[int],
biattention_id2: List[int],
fixed_layer1: int,
fixed_layer2: int,
fast_mode: bool = False,
with_coattention: bool = True,
in_batch_pairs: bool = False,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
# **kwargs,
):
"""
The `pretrained_module` only supplies one of the modalities.
"""
pretrained_module = cls.get_relevant_module(
pretrained_module, source=source, mapping=mapping
)
final_kwargs = {}
final_kwargs.update(cls._get_input_arguments(pretrained_module, source, mapping))
final_kwargs["num_hidden_layers2"] = num_hidden_layers2
final_kwargs["hidden_size2"] = hidden_size2
final_kwargs["combined_hidden_size"] = combined_hidden_size
final_kwargs["intermediate_size2"] = intermediate_size2
final_kwargs["num_attention_heads2"] = num_attention_heads2
final_kwargs["combined_num_attention_heads"] = combined_num_attention_heads
final_kwargs["attention_dropout2"] = attention_dropout2
final_kwargs["hidden_dropout2"] = hidden_dropout2
final_kwargs["biattention_id1"] = biattention_id1
final_kwargs["biattention_id2"] = biattention_id2
final_kwargs["fixed_layer1"] = fixed_layer1
final_kwargs["fixed_layer2"] = fixed_layer2
final_kwargs["fast_mode"] = fast_mode
final_kwargs["with_coattention"] = with_coattention
final_kwargs["in_batch_pairs"] = in_batch_pairs

return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs)
return cls(**final_kwargs)
7 changes: 7 additions & 0 deletions allennlp/modules/transformer/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch

from allennlp.modules.transformer.transformer_module import TransformerModule


class LayerNorm(torch.nn.LayerNorm, TransformerModule):
_pretrained_mapping = {"gamma": "weight", "beta": "bias"}
5 changes: 3 additions & 2 deletions allennlp/modules/transformer/output_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
from allennlp.common import FromParams

from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.layer_norm import LayerNorm


class OutputLayer(TransformerModule, FromParams):

_huggingface_mapping = {"LayerNorm": "layer_norm"}
_pretrained_mapping = {"LayerNorm": "layer_norm"}

def __init__(self, input_size: int, hidden_size: int, dropout: float):
super().__init__()
self.dense = torch.nn.Linear(input_size, hidden_size)
self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12)
self.layer_norm = LayerNorm(hidden_size, eps=1e-12)
self.dropout = torch.nn.Dropout(dropout)

def forward(self, hidden_states, input_tensor):
Expand Down
3 changes: 3 additions & 0 deletions allennlp/modules/transformer/positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __init__(self, min_timescale: float = 1.0, max_timescale: float = 1.0e4):
self.max_timescale = max_timescale

def forward(self, input_tensor: torch.Tensor):
"""
Adds a positional encoding to `input_tensor`.
"""
# TODO: Another option is to specify the expected size in init, so that we can construct
# the positional encoding beforehand, and simply add it to the input tensor in forward.
_, timesteps, hidden_dim = input_tensor.size()
Expand Down
Loading

0 comments on commit cf113d7

Please sign in to comment.