This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

When loading archived fine-tuned models for prediction, prevent non-fine-tuned pretrained transformer models from being downloaded #4599

Closed
bwriordan opened this issue Aug 25, 2020 · 1 comment · Fixed by #5172

Comments

@bwriordan

Is your feature request related to a problem? Please describe.
When loading an archived model with a pretrained transformer embedder for prediction, where the model has been fine-tuned on a dataset, a Hugging Face pretrained transformer model is always downloaded to ~/.cache/torch. The archived model is then loaded, and its fine-tuned weights replace the downloaded ones. When doing prediction with an existing archived model, the Hugging Face download is not necessary.

Describe the solution you'd like
For prediction, prevent pretrained transformer models from being downloaded.

Describe alternatives you've considered
There doesn't seem to be a way to pass an argument to from_path() to prevent the model from being downloaded.
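
For reference, a minimal sketch of the loading path described above (the archive path and predictor name are illustrative), showing where the redundant download happens:

    from allennlp.predictors import Predictor

    # Loading an archive for prediction. Internally this calls load_archive(),
    # which calls Model.from_params(); constructing the
    # PretrainedTransformerEmbedder there triggers AutoModel.from_pretrained()
    # and downloads the Hugging Face weights, even though they are immediately
    # replaced by the fine-tuned weights stored in the archive.
    predictor = Predictor.from_path("model.tar.gz", "text_classifier")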

Additional context

The downloading happens here:
model.py: model = Model.from_params(vocab=vocab, params=model_params)
pretrained_transformer_embedder.py:

self.transformer_model = cached_transformers.get(
    model_name, True, override_weights_file, override_weights_strip_prefix
)

cached_transformers.py: transformer = AutoModel.from_pretrained(model_name)

There is logic in model.py to remove pretrained embedding parameters: remove_pretrained_embedding_params(model_params).
However, this seems to target only pretrained embeddings like GloVe, via the pretrained_file config parameter.
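
One possible direction, offered only as a sketch against the transformers API rather than anything AllenNLP currently exposes: build the architecture from its configuration alone, so only the small config file is fetched and the archive's fine-tuned weights can be loaded on top of randomly initialized parameters.

    from transformers import AutoConfig, AutoModel

    model_name = "bert-base-uncased"  # illustrative

    # AutoModel.from_pretrained(model_name) would download config + weights.
    # AutoModel.from_config() builds the same architecture with randomly
    # initialized weights, so no pretrained weight files are downloaded;
    # the archive's fine-tuned weights can then be loaded over it.
    config = AutoConfig.from_pretrained(model_name)
    transformer = AutoModel.from_config(config)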

@matt-gardner
Contributor

Yes, the current behavior here isn't ideal, but I don't think it has an easy solution. That call to AutoModel.from_pretrained gives us not just the weights but also the specifics of the architecture; we don't know the weight sizes, for example, without it. It isn't trivial to bypass the way we do for a simple embedding layer.

This is pretty low priority for us, especially since you probably already downloaded the model when you trained it, so a cache miss here is the rare case, not the typical one. But if anyone figures out a good solution to this problem, we'd be happy to review a PR.
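
For anyone considering such a PR, a rough sketch of an opt-in flag, reusing the from_config idea above; the load_weights parameter and helper below are hypothetical illustrations, not AllenNLP's actual API:

    from transformers import AutoConfig, AutoModel

    def get_transformer(model_name: str, load_weights: bool = True):
        # Hypothetical variant of cached_transformers.get(): when load_weights
        # is False, build the architecture from its config only, so no
        # pretrained weights are downloaded (the archive's fine-tuned weights
        # are loaded on top of it later).
        if load_weights:
            return AutoModel.from_pretrained(model_name)
        return AutoModel.from_config(AutoConfig.from_pretrained(model_name))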
