Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adds tests for multitask predictor (#248)
Browse files Browse the repository at this point in the history
* Makes the multitask tests work

* Adds some tests for the new multitask predict story

* Changelog
  • Loading branch information
dirkgr authored Apr 14, 2021
1 parent f4fb932 commit acc3424
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added script that produces the coref training data.
- Added tests for using `allennlp predict` on multitask models.


## [v2.2.0](https://github.com/allenai/allennlp-models/releases/tag/v2.2.0) - 2021-03-26
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{"task": "vqa", "image": "https://i.imgur.com/UOt9Q4J.jpeg", "question": "What's the color of the pyramid in the foreground?"}
{"task": "vqa", "image": "https://i.imgur.com/9JNTNQd.jpeg", "question": "How many human skulls are there?"}
{"task": "gqa", "image": "https://i.imgur.com/FB4749j.jpeg", "question": "What color are the statues?"}
{"task": "ve", "image": "https://i.imgur.com/FB4749j.jpeg", "hypothesis": "The statues are hugging each other."}
111 changes: 111 additions & 0 deletions test_fixtures/vision/vilbert_multitask/experiment.jsonnet
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
local model_name = "epwalsh/bert-xsmall-dummy";

local vocabulary = {
"type": "from_files",
"directory": "https://storage.googleapis.com/allennlp-public-data/vilbert/vilbert_multitask.bert-base-uncased.vocab.tar.gz"
};

local reader_common = {
"image_loader": "torch",
"image_featurizer": "null",
"region_detector": {
"type": "random",
"seed": 322
},
"tokenizer": {
"type": "pretrained_transformer",
"model_name": model_name
},
"token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name
}
}
};

{
"dataset_reader": {
"type": "multitask",
"readers": {
"vqa": reader_common {
"type": "vqav2",
"image_dir": "test_fixtures/vision/images/vqav2",
"answer_vocab": vocabulary,
"multiple_answers_per_question": true
},
"ve": reader_common {
"type": "visual-entailment",
"image_dir": "test_fixtures/vision/images/visual_entailment",
}
}
},
"validation_dataset_reader": self.dataset_reader {
"readers": super.readers {
"vqa": super.vqa {
"answer_vocab": null // make sure we don't skip unanswerable questions during validation
}
}
},
"vocabulary": vocabulary,
"train_data_path": {
"vqa": "unittest",
"ve": "test_fixtures/vision/visual_entailment/sample_pairs.jsonl",
},
"validation_data_path": {
"vqa": "unittest",
"ve": "test_fixtures/vision/visual_entailment/sample_pairs.jsonl",
},
"model": {
"type": "multitask",
"arg_name_mapping": {
"backbone": {"question": "text", "hypothesis": "text"}
},
"backbone": {
"type": "vilbert_from_huggingface",
"model_name": model_name,

"image_feature_dim": 10,
"image_num_hidden_layers": 1,
"image_hidden_size": 20,
"image_num_attention_heads": 1,
"image_intermediate_size": 5,
"image_attention_dropout": 0.0,
"image_hidden_dropout": 0.0,
"image_biattention_id": [0, 1],
"image_fixed_layer": 0,

"text_biattention_id": [0, 1],
"text_fixed_layer": 0,

"combined_hidden_size": 20,
"combined_num_attention_heads": 2,

"pooled_output_dim": 20,
"fusion_method": "sum",
},
"heads": {
"vqa": {
"type": "vqa",
"embedding_dim": 20
},
"ve": {
"type": "visual_entailment",
"embedding_dim": 20
}
}
},
"data_loader": {
"type": "multitask",
"scheduler": { "batch_size": 2 }
},
"trainer": {
"optimizer": {
"type": "huggingface_adamw",
"lr": 4e-5,
},
"validation_metric": ["+vqa_vqa", "+ve_acc"],
"patience": 1,
"num_epochs": 3,
}
}
Binary file not shown.
92 changes: 92 additions & 0 deletions tests/commands/multitask_predict_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# These test should really be in the core repo, but they are here because the multitask model is
# here.
import json
import os
import pathlib
import shutil
import sys
import tempfile

import pytest
from allennlp.commands import main
from allennlp.common.checks import ConfigurationError

from allennlp.common.testing import AllenNlpTestCase

from tests import FIXTURES_ROOT


class TestMultitaskPredict(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
self.classifier_model_path = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "model.tar.gz"
self.classifier_data_path = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "dataset.json"
self.tempdir = pathlib.Path(tempfile.mkdtemp())
self.infile = self.tempdir / "inputs.txt"
self.outfile = self.tempdir / "outputs.txt"

def test_works_with_multitask_model(self):
sys.argv = [
"__main__.py", # executable
"predict", # command
str(self.classifier_model_path),
str(self.classifier_data_path),
"--output-file",
str(self.outfile),
"--silent",
]

main()

assert os.path.exists(self.outfile)

with open(self.outfile, "r") as f:
results = [json.loads(line) for line in f]

assert len(results) == 3
for result in results:
assert "vqa_best_answer" in result.keys() or "ve_entailment_answer" in result.keys()

shutil.rmtree(self.tempdir)

def test_using_dataset_reader_works_with_specified_multitask_head(self):
sys.argv = [
"__main__.py", # executable
"predict", # command
str(self.classifier_model_path),
"unittest", # "path" of the input data, but it's not really a path for VQA
"--output-file",
str(self.outfile),
"--silent",
"--use-dataset-reader",
"--multitask-head",
"vqa",
]

main()

assert os.path.exists(self.outfile)

with open(self.outfile, "r") as f:
results = [json.loads(line) for line in f]

assert len(results) == 3
for result in results:
assert "vqa_best_answer" in result.keys()

shutil.rmtree(self.tempdir)

def test_using_dataset_reader_fails_with_missing_parameter(self):
sys.argv = [
"__main__.py", # executable
"predict", # command
str(self.classifier_model_path),
"unittest", # "path" of the input data, but it's not really a path for VQA
"--output-file",
str(self.outfile),
"--silent",
"--use-dataset-reader",
]

with pytest.raises(ConfigurationError):
main()
33 changes: 33 additions & 0 deletions tests/vision/models/vilbert_multitask_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from allennlp.common.testing import ModelTestCase

from tests import FIXTURES_ROOT


class TestVilbertMultitask(ModelTestCase):
def test_predict(self):
from allennlp.models import load_archive
from allennlp.predictors import Predictor
import allennlp_models.vision

archive = load_archive(FIXTURES_ROOT / "vision" / "vilbert_multitask" / "model.tar.gz")
predictor = Predictor.from_archive(archive)

with open(
FIXTURES_ROOT / "vision" / "vilbert_multitask" / "dataset.json", "r"
) as file_input:
json_input = [predictor.load_line(line) for line in file_input if not line.isspace()]
predictions = predictor.predict_batch_json(json_input)
assert all(
"gqa_best_answer" in p or "vqa_best_answer" in p or "ve_entailment_answer" in p
for p in predictions
)

def test_model_can_train_save_and_load_small_model(self):
param_file = FIXTURES_ROOT / "vision" / "vilbert_multitask" / "experiment.jsonnet"

# The VQA weights are going to be zero because the last batch is Visual Entailment only,
# and so the gradients for VQA don't get set.
self.ensure_model_can_train_save_and_load(
param_file,
gradients_to_ignore={"_heads.vqa.classifier.bias", "_heads.vqa.classifier.weight"},
)
23 changes: 0 additions & 23 deletions tests/vision/vilbert_multitask.py

This file was deleted.

0 comments on commit acc3424

Please sign in to comment.