This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds tests for multitask predictor (#248)
* Makes the multitask tests work * Adds some tests for the new multitask predict story * Changelog
- Loading branch information
Showing
7 changed files
with
237 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 0 additions & 1 deletion
1
test_fixtures/vision/vilbert_multitask.json → ...res/vision/vilbert_multitask/dataset.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
{"task": "vqa", "image": "https://i.imgur.com/UOt9Q4J.jpeg", "question": "What's the color of the pyramid in the foreground?"} | ||
{"task": "vqa", "image": "https://i.imgur.com/9JNTNQd.jpeg", "question": "How many human skulls are there?"} | ||
{"task": "gqa", "image": "https://i.imgur.com/FB4749j.jpeg", "question": "What color are the statues?"} | ||
{"task": "ve", "image": "https://i.imgur.com/FB4749j.jpeg", "hypothesis": "The statues are hugging each other."} |
111 changes: 111 additions & 0 deletions
111
test_fixtures/vision/vilbert_multitask/experiment.jsonnet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
local model_name = "epwalsh/bert-xsmall-dummy"; | ||
|
||
local vocabulary = { | ||
"type": "from_files", | ||
"directory": "https://storage.googleapis.com/allennlp-public-data/vilbert/vilbert_multitask.bert-base-uncased.vocab.tar.gz" | ||
}; | ||
|
||
local reader_common = { | ||
"image_loader": "torch", | ||
"image_featurizer": "null", | ||
"region_detector": { | ||
"type": "random", | ||
"seed": 322 | ||
}, | ||
"tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name | ||
}, | ||
"token_indexers": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name | ||
} | ||
} | ||
}; | ||
|
||
{ | ||
"dataset_reader": { | ||
"type": "multitask", | ||
"readers": { | ||
"vqa": reader_common { | ||
"type": "vqav2", | ||
"image_dir": "test_fixtures/vision/images/vqav2", | ||
"answer_vocab": vocabulary, | ||
"multiple_answers_per_question": true | ||
}, | ||
"ve": reader_common { | ||
"type": "visual-entailment", | ||
"image_dir": "test_fixtures/vision/images/visual_entailment", | ||
} | ||
} | ||
}, | ||
"validation_dataset_reader": self.dataset_reader { | ||
"readers": super.readers { | ||
"vqa": super.vqa { | ||
"answer_vocab": null // make sure we don't skip unanswerable questions during validation | ||
} | ||
} | ||
}, | ||
"vocabulary": vocabulary, | ||
"train_data_path": { | ||
"vqa": "unittest", | ||
"ve": "test_fixtures/vision/visual_entailment/sample_pairs.jsonl", | ||
}, | ||
"validation_data_path": { | ||
"vqa": "unittest", | ||
"ve": "test_fixtures/vision/visual_entailment/sample_pairs.jsonl", | ||
}, | ||
"model": { | ||
"type": "multitask", | ||
"arg_name_mapping": { | ||
"backbone": {"question": "text", "hypothesis": "text"} | ||
}, | ||
"backbone": { | ||
"type": "vilbert_from_huggingface", | ||
"model_name": model_name, | ||
|
||
"image_feature_dim": 10, | ||
"image_num_hidden_layers": 1, | ||
"image_hidden_size": 20, | ||
"image_num_attention_heads": 1, | ||
"image_intermediate_size": 5, | ||
"image_attention_dropout": 0.0, | ||
"image_hidden_dropout": 0.0, | ||
"image_biattention_id": [0, 1], | ||
"image_fixed_layer": 0, | ||
|
||
"text_biattention_id": [0, 1], | ||
"text_fixed_layer": 0, | ||
|
||
"combined_hidden_size": 20, | ||
"combined_num_attention_heads": 2, | ||
|
||
"pooled_output_dim": 20, | ||
"fusion_method": "sum", | ||
}, | ||
"heads": { | ||
"vqa": { | ||
"type": "vqa", | ||
"embedding_dim": 20 | ||
}, | ||
"ve": { | ||
"type": "visual_entailment", | ||
"embedding_dim": 20 | ||
} | ||
} | ||
}, | ||
"data_loader": { | ||
"type": "multitask", | ||
"scheduler": { "batch_size": 2 } | ||
}, | ||
"trainer": { | ||
"optimizer": { | ||
"type": "huggingface_adamw", | ||
"lr": 4e-5, | ||
}, | ||
"validation_metric": ["+vqa_vqa", "+ve_acc"], | ||
"patience": 1, | ||
"num_epochs": 3, | ||
} | ||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# These test should really be in the core repo, but they are here because the multitask model is | ||
# here. | ||
import json | ||
import os | ||
import pathlib | ||
import shutil | ||
import sys | ||
import tempfile | ||
|
||
import pytest | ||
from allennlp.commands import main | ||
from allennlp.common.checks import ConfigurationError | ||
|
||
from allennlp.common.testing import AllenNlpTestCase | ||
|
||
from tests import FIXTURES_ROOT | ||
|
||
|
||
class TestMultitaskPredict(AllenNlpTestCase): | ||
def setup_method(self): | ||
super().setup_method() | ||
self.classifier_model_path = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "model.tar.gz" | ||
self.classifier_data_path = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "dataset.json" | ||
self.tempdir = pathlib.Path(tempfile.mkdtemp()) | ||
self.infile = self.tempdir / "inputs.txt" | ||
self.outfile = self.tempdir / "outputs.txt" | ||
|
||
def test_works_with_multitask_model(self): | ||
sys.argv = [ | ||
"__main__.py", # executable | ||
"predict", # command | ||
str(self.classifier_model_path), | ||
str(self.classifier_data_path), | ||
"--output-file", | ||
str(self.outfile), | ||
"--silent", | ||
] | ||
|
||
main() | ||
|
||
assert os.path.exists(self.outfile) | ||
|
||
with open(self.outfile, "r") as f: | ||
results = [json.loads(line) for line in f] | ||
|
||
assert len(results) == 3 | ||
for result in results: | ||
assert "vqa_best_answer" in result.keys() or "ve_entailment_answer" in result.keys() | ||
|
||
shutil.rmtree(self.tempdir) | ||
|
||
def test_using_dataset_reader_works_with_specified_multitask_head(self): | ||
sys.argv = [ | ||
"__main__.py", # executable | ||
"predict", # command | ||
str(self.classifier_model_path), | ||
"unittest", # "path" of the input data, but it's not really a path for VQA | ||
"--output-file", | ||
str(self.outfile), | ||
"--silent", | ||
"--use-dataset-reader", | ||
"--multitask-head", | ||
"vqa", | ||
] | ||
|
||
main() | ||
|
||
assert os.path.exists(self.outfile) | ||
|
||
with open(self.outfile, "r") as f: | ||
results = [json.loads(line) for line in f] | ||
|
||
assert len(results) == 3 | ||
for result in results: | ||
assert "vqa_best_answer" in result.keys() | ||
|
||
shutil.rmtree(self.tempdir) | ||
|
||
def test_using_dataset_reader_fails_with_missing_parameter(self): | ||
sys.argv = [ | ||
"__main__.py", # executable | ||
"predict", # command | ||
str(self.classifier_model_path), | ||
"unittest", # "path" of the input data, but it's not really a path for VQA | ||
"--output-file", | ||
str(self.outfile), | ||
"--silent", | ||
"--use-dataset-reader", | ||
] | ||
|
||
with pytest.raises(ConfigurationError): | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from allennlp.common.testing import ModelTestCase | ||
|
||
from tests import FIXTURES_ROOT | ||
|
||
|
||
class TestVilbertMultitask(ModelTestCase): | ||
def test_predict(self): | ||
from allennlp.models import load_archive | ||
from allennlp.predictors import Predictor | ||
import allennlp_models.vision | ||
|
||
archive = load_archive(FIXTURES_ROOT / "vision" / "vilbert_multitask" / "model.tar.gz") | ||
predictor = Predictor.from_archive(archive) | ||
|
||
with open( | ||
FIXTURES_ROOT / "vision" / "vilbert_multitask" / "dataset.json", "r" | ||
) as file_input: | ||
json_input = [predictor.load_line(line) for line in file_input if not line.isspace()] | ||
predictions = predictor.predict_batch_json(json_input) | ||
assert all( | ||
"gqa_best_answer" in p or "vqa_best_answer" in p or "ve_entailment_answer" in p | ||
for p in predictions | ||
) | ||
|
||
def test_model_can_train_save_and_load_small_model(self): | ||
param_file = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "experiment.jsonnet" | ||
|
||
# The VQA weights are going to be zero because the last batch is Visual Entailment only, | ||
# and so the gradients for VQA don't get set. | ||
self.ensure_model_can_train_save_and_load( | ||
param_file, | ||
gradients_to_ignore={"_heads.vqa.classifier.bias", "_heads.vqa.classifier.weight"}, | ||
) |
This file was deleted.
Oops, something went wrong.