This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Update pretrained.py: Quick fix to be able to load pretrained models directly to GPU. (#254)

* Update pretrained.py

Quick fix to be able to load pretrained models directly to GPU.

* Update CHANGELOG.md
davidberenstein1957 committed Apr 23, 2021
1 parent 2ba85b5 commit 659c71f
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

+- `pretrained.load_predictor()` now allows for loading model onto GPU.
- `VqaMeasure` now calculates correctly in the distributed case.
- `ConllCorefScores` now calculates correctly in the distributed case.
- `SrlEvalScorer` raises an appropriate error if run in the distributed setting.
2 changes: 2 additions & 0 deletions allennlp_models/pretrained.py
@@ -61,6 +61,7 @@ def get_pretrained_models() -> Dict[str, ModelCard]:
def load_predictor(
    model_id: str,
    pretrained_models: Dict[str, ModelCard] = None,
+   cuda_device: int = -1,
    overrides: Union[str, Dict[str, Any]] = None,
) -> Predictor:
    """
@@ -76,5 +77,6 @@ def load_predictor(
    return Predictor.from_path(
        model_card.model_usage.archive_file,
        predictor_name=model_card.registered_predictor_name,
+       cuda_device=cuda_device,
        overrides=overrides,
    )
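
For reference, a minimal usage sketch of the new argument (the model id below is only an illustrative placeholder; any id listed by `get_pretrained_models()` works, and `cuda_device=0` assumes a machine with at least one GPU):

from allennlp_models.pretrained import load_predictor

# cuda_device=-1 (the default) keeps the model on CPU; 0 loads it onto the first GPU.
predictor = load_predictor("structured-prediction-srl-bert", cuda_device=0)

# The predict() keyword arguments depend on the chosen predictor; SRL predictors accept a sentence.
result = predictor.predict(sentence="The archive is loaded straight onto the GPU.")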
