
Commit

Add way to initialize SrlBert without pretrained BERT weights (#257)
* add way to initialize SrlBert without pretrained BERT weights

* tick cache version

* add test
epwalsh authored May 2, 2021
1 parent ab1e86a commit 845fe4c
Showing 6 changed files with 107 additions and 29 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -66,7 +66,7 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2

      - name: Install requirements
        run: |
@@ -192,7 +192,7 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2

      - name: Install requirements
        run: |
@@ -336,7 +336,7 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2

      - name: Install requirements
        run: |
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added tests for checklist suites for SQuAD-style reading comprehension models (`bidaf`), and textual entailment models (`decomposable_attention` and `esim`).
+- Added a way to initialize the `SrlBert` model without caching/loading pretrained transformer weights.
+  You need to set the `bert_model` parameter to the dictionary form of the corresponding `BertConfig` from HuggingFace.
+  See [PR #257](https://github.com/allenai/allennlp-models/pull/257) for more details.


## [v2.4.0](https://github.com/allenai/allennlp-models/releases/tag/v2.4.0) - 2021-04-22
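For readers who want to use the new option, here is a minimal sketch (not part of the commit) of obtaining the dictionary form of a `BertConfig` with the HuggingFace `transformers` API; the model name is only an example:

```python
from transformers import BertConfig

# Fetch only the config (no model weights) and convert it to a plain dictionary.
# "bert-base-uncased" is an example name; use whichever BERT model you trained with.
config_dict = BertConfig.from_pretrained("bert-base-uncased").to_dict()

# This dictionary is what you would pass as the `bert_model` parameter of the
# `srl_bert` model, e.g. under "model.bert_model" in an AllenNLP config file.
print(config_dict["hidden_size"])  # 768 for bert-base-uncased
```

Because only the configuration is read here, no pretrained weights are downloaded or cached when the model is later constructed from this dictionary.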
28 changes: 25 additions & 3 deletions allennlp_models/structured_prediction/models/srl_bert.py
@@ -1,9 +1,11 @@
+import warnings
from typing import Dict, List, Any, Union

from overrides import overrides
import torch
from torch.nn.modules import Linear, Dropout
import torch.nn.functional as F
+from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertModel

from allennlp.data import TextFieldTensors, Vocabulary
@@ -31,14 +33,26 @@ class SrlBert(Model):
    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
-    model : `Union[str, BertModel]`, required.
-        A string describing the BERT model to load or an already constructed BertModel.
+    bert_model : `Union[str, Dict[str, Any], BertModel]`, required.
+        A string describing the BERT model to load, a BERT config in the form of a dictionary,
+        or an already constructed BertModel.
+        !!! Note
+            If you pass a config `bert_model` (a dictionary), pretrained weights will
+            not be cached and loaded! This is ideal if you're loading this model from an
+            AllenNLP archive since the weights you need will already be included in the
+            archive, but not what you want if you're training.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `0.0`)
        Whether or not to use label smoothing on the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
@@ -47,7 +61,7 @@ class SrlBert(Model):
    def __init__(
        self,
        vocab: Vocabulary,
-        bert_model: Union[str, BertModel],
+        bert_model: Union[str, Dict[str, Any], BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
@@ -59,6 +73,14 @@ def __init__(

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
+        elif isinstance(bert_model, dict):
+            warnings.warn(
+                "Initializing BertModel without pretrained weights. This is fine if you're loading "
+                "from an AllenNLP archive, but not if you're training.",
+                UserWarning,
+            )
+            bert_config = BertConfig.from_dict(bert_model)
+            self.bert_model = BertModel(bert_config)
        else:
            self.bert_model = bert_model

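The new `elif` branch boils down to the following `transformers` usage. This is a standalone sketch, not code from the PR; the tiny config values mirror the test fixture and are purely illustrative:

```python
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertModel

# Illustrative config dictionary; any valid BertConfig fields may be given,
# and missing fields fall back to BertConfig defaults.
config_dict = {
    "hidden_size": 20,
    "num_hidden_layers": 1,
    "num_attention_heads": 1,
    "intermediate_size": 40,
    "vocab_size": 250,
}

# Constructing BertModel from a config gives randomly initialized weights --
# nothing is downloaded or cached. This is what SrlBert now does when
# `bert_model` is a dictionary.
bert = BertModel(BertConfig.from_dict(config_dict))

# The string form, by contrast, still downloads and caches pretrained weights:
# bert = BertModel.from_pretrained("bert-base-uncased")
```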
13 changes: 7 additions & 6 deletions test_fixtures/structured_prediction/srl/bert_srl.jsonnet
@@ -1,11 +1,12 @@
local bert_model = "allennlp/tests/fixtures/bert/vocab.txt";
local bert_model = "epwalsh/bert-xsmall-dummy";

{
"dataset_reader":{
"type":"srl",
"bert_model_name": "bert-base-uncased"
"dataset_reader":{
"type":"srl",
"bert_model_name": bert_model,
},
"train_data_path": "test_fixtures/structured_prediction/srl",
"validation_data_path": "test_fixtures/structured_prediction/srl",
"train_data_path": "test_fixtures/structured_prediction/srl",
"validation_data_path": "test_fixtures/structured_prediction/srl",
"model": {
"type": "srl_bert",
"bert_model": bert_model,
56 changes: 56 additions & 0 deletions test_fixtures/structured_prediction/srl/bert_srl_local_files.jsonnet
@@ -0,0 +1,56 @@
local bert_model = "test_fixtures/bert-xsmall-dummy";

# Taken from test_fixtures/bert-xsmall-dummy/config.json
local bert_config = {
    "architectures": [
        "BertModel"
    ],
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 20,
    "initializer_range": 0.02,
    "intermediate_size": 40,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 1,
    "num_hidden_layers": 1,
    "pad_token_id": 0,
    "type_vocab_size": 2,
    "vocab_size": 250
};

{
"dataset_reader":{
"type":"srl",
"bert_model_name": bert_model,
},
"train_data_path": "test_fixtures/structured_prediction/srl",
"validation_data_path": "test_fixtures/structured_prediction/srl",
"model": {
"type": "srl_bert",
"bert_model": bert_config,
"embedding_dropout": 0.0
},
"data_loader": {
"batch_sampler": {
"type": "bucket",
"batch_size": 5,
"padding_noise": 0.0
}
},
"trainer": {
"optimizer": {
"type": "adam",
"lr": 0.001
},
"checkpointer": {
"num_serialized_models_to_keep": 1
},
"num_epochs": 3,
"grad_norm": 10.0,
"patience": 5,
"cuda_device": -1
}
}
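As a rough usage sketch (untested, and assuming the fixture paths above), the `model` section of this config could also be instantiated directly with AllenNLP's `Params` API. In the actual test below, `ModelTestCase.set_up_model` takes care of this, including building the vocabulary from the fixture data:

```python
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.models import Model

import allennlp_models.structured_prediction  # noqa: F401 -- registers the "srl_bert" model type

# Load the fixture config and build only its "model" section. An empty Vocabulary
# is used here just for illustration; the real test builds one from fixture data.
params = Params.from_file(
    "test_fixtures/structured_prediction/srl/bert_srl_local_files.jsonnet"
)
model = Model.from_params(vocab=Vocabulary(), params=params.pop("model"))
```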
30 changes: 13 additions & 17 deletions tests/structured_prediction/models/bert_srl_test.py
@@ -1,6 +1,7 @@
import json

import numpy
+import pytest
from _pytest.monkeypatch import MonkeyPatch
from transformers.models.bert.modeling_bert import BertConfig, BertModel
from transformers.models.bert.tokenization_bert import BertTokenizer

@@ -14,28 +15,12 @@

class BertSrlTest(ModelTestCase):
    def setup_method(self):
-
-        self.monkeypatch = MonkeyPatch()
-        # monkeypatch the PretrainedBertModel to return the tiny test fixture model
-        config_path = FIXTURES_ROOT / "structured_prediction" / "srl" / "bert" / "config.json"
-        vocab_path = FIXTURES_ROOT / "structured_prediction" / "srl" / "bert" / "vocab.txt"
-        config = BertConfig.from_json_file(config_path)
-        self.monkeypatch.setattr(BertModel, "from_pretrained", lambda _: BertModel(config))
-        self.monkeypatch.setattr(
-            BertTokenizer, "from_pretrained", lambda _: BertTokenizer(vocab_path)
-        )
-
        super().setup_method()
        self.set_up_model(
            FIXTURES_ROOT / "structured_prediction" / "srl" / "bert_srl.jsonnet",
            FIXTURES_ROOT / "structured_prediction" / "srl" / "conll_2012",
        )

-    def teardown_method(self):
-        self.monkeypatch.undo()
-        self.monkeypatch.undo()
-        super().teardown_method()

    def test_bert_srl_model_can_train_save_and_load(self):
        ignore_grads = {"bert_model.pooler.dense.weight", "bert_model.pooler.dense.bias"}
        self.ensure_model_can_train_save_and_load(self.param_file, gradients_to_ignore=ignore_grads)
@@ -67,3 +52,14 @@ def test_decode_runs_correctly(self):
        # to_bioul throws an exception if the tag sequence is not well formed,
        # so here we can easily check that the sequence we produce is good.
        to_bioul(prediction, encoding="BIO")
+
+
+class BertSrlFromLocalFilesTest(ModelTestCase):
+    def test_init_from_local_files(self):
+        with pytest.warns(
+            UserWarning, match=r"Initializing BertModel without pretrained weights.*"
+        ):
+            self.set_up_model(
+                FIXTURES_ROOT / "structured_prediction" / "srl" / "bert_srl_local_files.jsonnet",
+                FIXTURES_ROOT / "structured_prediction" / "srl" / "conll_2012",
+            )
