This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add way of skipping pretrained weights download #5172

Merged 3 commits on May 2, 2021

4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`.
You can do this by setting the parameter `load_weights` to `False`.
  See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.


## Unreleased
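For reference, a minimal usage sketch of the feature described in the changelog entry above (not part of this diff; the model name is a placeholder and a standard AllenNLP 2.x setup is assumed):

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Only the transformer's config is fetched; the architecture is instantiated
# with randomly initialized parameters instead of the pretrained weights.
embedder = PretrainedTransformerEmbedder("bert-base-uncased", load_weights=False)
```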
36 changes: 33 additions & 3 deletions allennlp/common/cached_transformers.py
@@ -1,5 +1,7 @@
import logging
import warnings
from typing import NamedTuple, Optional, Dict, Tuple

import transformers
from transformers import AutoModel, AutoConfig

@@ -21,6 +23,7 @@ def get(
make_copy: bool,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
**kwargs,
) -> transformers.PreTrainedModel:
"""
@@ -34,18 +37,35 @@
If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the
parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but
make sure to `copy.deepcopy()` the bits you are keeping.
override_weights_file : `str`, optional
override_weights_file : `str`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix : `str`, optional
override_weights_strip_prefix : `str`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights : `bool`, optional (default = `True`)
If set to `False`, no weights will be loaded. This is helpful when you only
want to initialize the architecture, like when you've already fine-tuned a model
and are going to load the weights from a state dict elsewhere.
"""
global _model_cache
spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix)
transformer = _model_cache.get(spec, None)
if transformer is None:
if override_weights_file is not None:
if not load_weights:
if override_weights_file is not None:
warnings.warn(
"You specified an 'override_weights_file' in allennlp.common.cached_transformers.get(), "
"but 'load_weights' is set to False, so 'override_weights_file' will be ignored.",
UserWarning,
)
transformer = AutoModel.from_config(
AutoConfig.from_pretrained(
model_name,
**kwargs,
)
)
elif override_weights_file is not None:
from allennlp.common.file_utils import cached_path
import torch

@@ -121,3 +141,13 @@ def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer
)
_tokenizer_cache[cache_key] = tokenizer
return tokenizer


def _clear_caches():
"""
Clears in-memory transformer and tokenizer caches.
"""
global _model_cache
global _tokenizer_cache
_model_cache.clear()
_tokenizer_cache.clear()
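A quick sketch of the new code path in `cached_transformers.get()` (illustrative only; assumes network access for the config download):

```python
from allennlp.common import cached_transformers

# With load_weights=False, only AutoConfig.from_pretrained() hits the network;
# the model is built via AutoModel.from_config() with random parameters.
model = cached_transformers.get(
    "epwalsh/bert-xsmall-dummy",
    False,  # make_copy
    load_weights=False,
)

# The private helper added above clears the in-memory model/tokenizer caches;
# the tests later in this diff use it to keep test methods independent.
cached_transformers._clear_caches()
```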
16 changes: 13 additions & 3 deletions allennlp/modules/seq2vec_encoders/bert_pooler.py
@@ -26,7 +26,15 @@ class BertPooler(Seq2VecEncoder):
The pretrained BERT model to use. If this is a string,
we will call `transformers.AutoModel.from_pretrained(pretrained_model)`
and use that.
requires_grad : `bool`, optional, (default = `True`)
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
        Whether to load the pretrained weights.
requires_grad : `bool`, optional (default = `True`)
If True, the weights of the pooler will be updated during training.
Otherwise they will not.
dropout : `float`, optional, (default = `0.0`)
@@ -43,6 +51,7 @@ def __init__(
*,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
requires_grad: bool = True,
dropout: float = 0.0,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -54,8 +63,9 @@ def __init__(
model = cached_transformers.get(
pretrained_model,
False,
override_weights_file,
override_weights_strip_prefix,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
**(transformer_kwargs or {}),
)

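As with the embedders, a brief sketch of the updated `BertPooler` constructor (the model name is a placeholder):

```python
from allennlp.modules.seq2vec_encoders import BertPooler

# The pooler's architecture comes from the model config only; no pretrained
# weights are downloaded or loaded.
pooler = BertPooler("bert-base-uncased", load_weights=False)
```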
allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -49,6 +49,17 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
When `True` (the default), only the final layer of the pretrained transformer is taken
for the embeddings. But if set to `False`, a scalar mix of all of the layers
is used.
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
to avoid unnecessarily caching and loading the original pretrained weights,
since the archive will already contain all of the weights needed.
gradient_checkpointing: `bool`, optional (default = `None`)
Enable or disable gradient checkpointing.
tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
@@ -74,6 +85,7 @@ def __init__(
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -86,6 +98,7 @@
True,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
**(transformer_kwargs or {}),
)

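The docstring above recommends disabling the weight download when restoring from an archive. A hedged sketch of what that might look like (the override key path is hypothetical and depends on where the embedder sits in your config; passing `overrides` as a dict assumes a recent AllenNLP version):

```python
from allennlp.models.archival import load_archive

# "model.embedder.token_embedders.tokens.load_weights" is a made-up key path;
# point it at the PretrainedTransformerEmbedder in your own configuration.
archive = load_archive(
    "model.tar.gz",
    overrides={"model.embedder.token_embedders.tokens.load_weights": False},
)
model = archive.model  # weights come from the archive, not from huggingface
```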
allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
@@ -32,6 +32,17 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
When `True` (the default), only the final layer of the pretrained transformer is taken
for the embeddings. But if set to `False`, a scalar mix of all of the layers
is used.
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
to avoid unnecessarily caching and loading the original pretrained weights,
since the archive will already contain all of the weights needed.
gradient_checkpointing: `bool`, optional (default = `None`)
Enable or disable gradient checkpointing.
tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
@@ -56,6 +67,9 @@ def __init__(
max_length: int = None,
train_parameters: bool = True,
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -68,6 +82,9 @@
max_length=max_length,
train_parameters=train_parameters,
last_layer_only=last_layer_only,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
gradient_checkpointing=gradient_checkpointing,
tokenizer_kwargs=tokenizer_kwargs,
transformer_kwargs=transformer_kwargs,
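The mismatched embedder simply forwards the new arguments to the wrapped `PretrainedTransformerEmbedder`, so usage looks the same (a sketch; the model name is a placeholder):

```python
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

# override_weights_file, override_weights_strip_prefix, and load_weights are
# all passed through to the underlying PretrainedTransformerEmbedder.
embedder = PretrainedTransformerMismatchedEmbedder(
    "bert-base-uncased",
    load_weights=False,
)
```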
32 changes: 22 additions & 10 deletions allennlp/modules/transformer/transformer_module.py
@@ -147,21 +157,32 @@ def get_relevant_module(
relevant_module: Optional[Union[str, List[str]]] = None,
source: str = "huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
):
"""
Returns the relevant underlying module given a model name/object.

# Parameters:

pretrained_module: Name of the transformer model containing the layer,
or the actual layer (not the model object).
relevant_module: Name of the desired module. Defaults to cls._relevant_module.
source: Where the model came from. Default - huggingface.
mapping: Optional mapping that determines any differences in the module names
between the class modules and the input model's modules. Default - cls._huggingface_mapping
# Parameters

pretrained_module : `Union[str, torch.nn.Module]`
Name of the transformer model containing the layer,
or the actual layer (not the model object).
relevant_module : `Optional[Union[str, List[str]]]`, optional
Name of the desired module. Defaults to cls._relevant_module.
source : `str`, optional
Where the model came from. Default - huggingface.
mapping : `Dict[str, str]`, optional
Optional mapping that determines any differences in the module names
between the class modules and the input model's modules.
Default - cls._huggingface_mapping
load_weights : `bool`, optional
Whether or not to load the pretrained weights.
Default is `True`.
"""
if isinstance(pretrained_module, str):
pretrained_module = cached_transformers.get(pretrained_module, False)
pretrained_module = cached_transformers.get(
pretrained_module, False, load_weights=load_weights
)

relevant_module = relevant_module or cls._relevant_module

@@ -192,6 +203,7 @@ def from_pretrained_module(
pretrained_module: Union[str, torch.nn.Module],
source: str = "huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
**kwargs,
):
"""
Expand All @@ -208,7 +220,7 @@ def from_pretrained_module(
)

pretrained_module = cls.get_relevant_module(
pretrained_module, source=source, mapping=mapping
pretrained_module, source=source, mapping=mapping, load_weights=load_weights
)
final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping)
final_kwargs.update(kwargs)
9 changes: 8 additions & 1 deletion allennlp/modules/transformer/transformer_stack.py
@@ -172,6 +172,7 @@ def from_pretrained_module( # type: ignore
num_hidden_layers: Optional[Union[int, range]] = None,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
**kwargs,
):
final_kwargs = {}
@@ -185,4 +186,10 @@
else:
final_kwargs["num_hidden_layers"] = num_hidden_layers

return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs)
return super().from_pretrained_module(
pretrained_module,
source=source,
mapping=mapping,
load_weights=load_weights,
**final_kwargs,
)
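A sketch of the lower-level `from_pretrained_module` path with the new flag (assumes `TransformerStack` is importable from `allennlp.modules.transformer`; the model name and layer count are placeholders):

```python
from allennlp.modules.transformer import TransformerStack

# Builds the first 4 encoder layers of the architecture without downloading
# the pretrained weights from huggingface.
stack = TransformerStack.from_pretrained_module(
    "bert-base-uncased",
    num_hidden_layers=4,
    load_weights=False,
)
```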
59 changes: 54 additions & 5 deletions tests/common/cached_transformers_test.py
@@ -10,6 +10,14 @@


class TestCachedTransformers(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
cached_transformers._clear_caches()

def teardown_method(self):
super().teardown_method()
cached_transformers._clear_caches()

def test_get_missing_from_cache_local_files_only(self):
with pytest.raises((OSError, ValueError)):
cached_transformers.get(
@@ -19,17 +27,23 @@ def test_get_missing_from_cache_local_files_only(self):
local_files_only=True,
)

def clear_test_dir(self):
for f in os.listdir(str(self.TEST_DIR)):
os.remove(str(self.TEST_DIR) + "/" + f)
assert len(os.listdir(str(self.TEST_DIR))) == 0

def test_from_pretrained_avoids_weights_download_if_override_weights(self):
config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
# only download config because downloading pretrained weights in addition takes too long
transformer = AutoModel.from_config(
AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
)
transformer = AutoModel.from_config(config)

# clear cache directory
for f in os.listdir(str(self.TEST_DIR)):
os.remove(str(self.TEST_DIR) + "/" + f)
assert len(os.listdir(str(self.TEST_DIR))) == 0
self.clear_test_dir()

save_weights_path = str(self.TEST_DIR) + "/bert_weights.pth"
save_weights_path = str(self.TEST_DIR / "bert_weights.pth")
torch.save(transformer.state_dict(), save_weights_path)

override_transformer = cached_transformers.get(
@@ -44,7 +58,7 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
# so this assertion could fail in the future
json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
assert len(json_fnames) == 1
json_data = json.load(open(str(self.TEST_DIR) + "/" + json_fnames[0]))
json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
assert (
json_data["url"]
== "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
@@ -58,6 +72,41 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
assert p1.data.ne(p2.data).sum() == 0

def test_from_pretrained_no_load_weights(self):
_ = cached_transformers.get(
"epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
)
# check that only three files were downloaded (filename.json, filename, filename.lock), for config.json
# if more than three files were downloaded, then model weights were also (incorrectly) downloaded
# NOTE: downloaded files are not explicitly detailed in Huggingface's public API,
# so this assertion could fail in the future
json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
assert len(json_fnames) == 1
json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
assert (
json_data["url"]
== "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
)
resource_id = os.path.splitext(json_fnames[0])[0]
assert set(os.listdir(str(self.TEST_DIR))) == set(
[json_fnames[0], resource_id, resource_id + ".lock"]
)

def test_from_pretrained_no_load_weights_local_config(self):
config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
self.clear_test_dir()

# Save config to file.
local_config_path = str(self.TEST_DIR / "local_config.json")
config.to_json_file(local_config_path, use_diff=False)

# Now load the model from the local config.
_ = cached_transformers.get(
local_config_path, False, load_weights=False, cache_dir=self.TEST_DIR
)
# Make sure no other files were downloaded.
assert os.listdir(str(self.TEST_DIR)) == ["local_config.json"]

def test_get_tokenizer_missing_from_cache_local_files_only(self):
with pytest.raises((OSError, ValueError)):
cached_transformers.get_tokenizer(