This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add way to initialize SrlBert without pretrained BERT weights #257

Merged
epwalsh merged 3 commits into main from srl-no-load-weights on May 2, 2021

Conversation

@epwalsh (Member) commented Apr 30, 2021

Closes allenai/allennlp#5170.

You can avoid caching/loading pretrained BERT weights by setting the bert_model parameter of SrlBert to a dictionary that corresponds to the BertConfig from HuggingFace. You'll also need a local copy of the config and vocab to avoid downloads from the dataset reader, so the easiest complete work-around would look something like this:

from transformers import AutoConfig
from allennlp.predictors import Predictor

transformer_model_name = "bert-base-uncased"
archive_path = "https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz"

# Need copies of the transformer config and vocab in a local directory.
local_config_path = "./" + transformer_model_name + "-local"

config = AutoConfig.from_pretrained(local_config_path)

predictor = Predictor.from_path(
    archive_path,
    overrides={
        "model.bert_model": config.to_dict(),
        "dataset_reader.bert_model_name": local_config_path,
    },
)

You can set up the local files you need by running this snippet once, before loading the predictor:

import os
from transformers import AutoConfig, AutoTokenizer

transformer_model_name = "bert-base-uncased"
local_config_path = "./" + transformer_model_name + "-local"
os.makedirs(local_config_path, exist_ok=True)

# Save the tokenizer (vocab) and config into the local directory.
tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
config = AutoConfig.from_pretrained(transformer_model_name)
tokenizer.save_pretrained(local_config_path)
config.to_json_file(local_config_path + "/config.json")
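
Once the local files are in place, the predictor loaded with the overrides above behaves like any other AllenNLP SRL predictor. A minimal usage sketch, assuming the standard SRL predictor output format (the example sentence is arbitrary):

# Run semantic role labeling on one sentence and print the tagged frames.
result = predictor.predict(
    sentence="The keys, which were needed to access the building, were locked in the car."
)
for verb in result["verbs"]:
    print(verb["description"])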

This is related to allenai/allennlp#5172, but it required its own solution since the SrlBert model is a bit of an oddball: it uses the BERT model class from transformers directly, instead of going through AllenNLP's PretrainedTransformerEmbedder.
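
For illustration, the core idea is that the bert_model argument can now be either a model name string or a config dictionary. A simplified sketch of that dispatch (not the literal code from this PR; _load_bert is a hypothetical helper name):

from typing import Any, Dict, Union
from transformers import BertConfig, BertModel

def _load_bert(bert_model: Union[str, Dict[str, Any]]) -> BertModel:
    # Hypothetical sketch: a dict is treated as a BertConfig, so the model is
    # built with randomly initialized weights and nothing is downloaded;
    # a string still goes through from_pretrained as before.
    if isinstance(bert_model, dict):
        return BertModel(BertConfig.from_dict(bert_model))
    return BertModel.from_pretrained(bert_model)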

@ArjunSubramonian (Contributor) left a comment

Looks great! Is there a test for this?

@@ -66,7 +66,7 @@ jobs:
       - uses: actions/cache@v2
         with:
           path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2
@epwalsh (Member, Author) commented:

The cache was corrupted for some reason.

@epwalsh (Member, Author) commented Apr 30, 2021

> Looks great! Is there a test for this?

No, but there should be. I'll add one.

@ArjunSubramonian (Contributor) left a comment

LGTM!

@epwalsh epwalsh merged commit 845fe4c into main May 2, 2021
@epwalsh epwalsh deleted the srl-no-load-weights branch May 2, 2021 21:50
Development

Successfully merging this pull request may close this issue: Cannot load the pre-trained models (allenai/allennlp#5170).