diff --git a/CHANGELOG.md b/CHANGELOG.md
index b62874f70..ee1aa1d22 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added support for NLVR2 visual entailment, including a dataset reader, two models, and training configs.
+
 ## [v2.5.0](https://github.com/allenai/allennlp-models/releases/tag/v2.5.0) - 2021-06-03
 
diff --git a/README.md b/README.md
index 2b80ad3db..b1cf2e0b7 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,8 @@ Here is a list of pre-trained models currently available.
 - [`mc-roberta-commonsenseqa`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/mc-roberta-commonsenseqa.json) - RoBERTa-based multiple choice model for CommonSenseQA.
 - [`mc-roberta-piqa`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/mc-roberta-piqa.json) - RoBERTa-based multiple choice model for PIQA.
 - [`mc-roberta-swag`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/mc-roberta-swag.json) - RoBERTa-based multiple choice model for SWAG.
+- [`nlvr2-vilbert-head`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/nlvr2-vilbert-head.json) - ViLBERT-based model for Visual Entailment, using an NLVR2-specific model head.
+- [`nlvr2-vilbert`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/nlvr2-vilbert.json) - ViLBERT-based model for Visual Entailment.
 - [`pair-classification-binary-gender-bias-mitigated-roberta-snli`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-binary-gender-bias-mitigated-roberta-snli.json) - RoBERTa finetuned on SNLI with binary gender bias mitigation.
 - [`pair-classification-decomposable-attention-elmo`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-decomposable-attention-elmo.json) - The decomposable attention model (Parikh et al, 2017) combined with ELMo embeddings trained on SNLI.
 - [`pair-classification-esim`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-esim.json) - Enhanced LSTM trained on SNLI.
diff --git a/allennlp_models/modelcards/nlvr2-vilbert-head.json b/allennlp_models/modelcards/nlvr2-vilbert-head.json
new file mode 100644
index 000000000..37a2d81a2
--- /dev/null
+++ b/allennlp_models/modelcards/nlvr2-vilbert-head.json
@@ -0,0 +1,66 @@
+{
+    "id": "nlvr2-vilbert-head",
+    "registered_model_name": "nlvr2",
+    "registered_predictor_name": null,
+    "display_name": "Visual Entailment - NLVR2",
+    "task_id": "nlvr2",
+    "model_details": {
+        "description": "This model uses a ViLBERT-based backbone with an NLVR2-specific model head. The image features are obtained using the ResNet backbone and Faster RCNN (region detection).",
+        "short_description": "ViLBERT-based model for Visual Entailment.",
+        "developed_by": "Lu et al",
+        "contributed_by": "Jacob Morrison",
+        "date": "2021-05-27",
+        "version": "2",
+        "model_type": "ViLBERT based on BERT large",
+        "paper": {
+            "citation": "\n@inproceedings{Lu2019ViLBERTPT,\ntitle={ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks},\nauthor={Jiasen Lu and Dhruv Batra and D. 
Parikh and Stefan Lee},\nbooktitle={NeurIPS},\nyear={2019}", + "title": "ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks", + "url": "https://api.semanticscholar.org/CorpusID:199453025" + }, + "license": null, + "contact": "allennlp-contact@allenai.org" + }, + "intended_use": { + "primary_uses": "This model is developed for the AllenNLP demo.", + "primary_users": null, + "out_of_scope_use_cases": null + }, + "factors": { + "relevant_factors": null, + "evaluation_factors": null + }, + "metrics": { + "model_performance_measures": "Accuracy and F1-score", + "decision_thresholds": null, + "variation_approaches": null + }, + "evaluation_data": { + "dataset": { + "name": "Natural Language for Visual Reasoning For Real dev set", + "url": "https://github.com/lil-lab/nlvr/tree/master/nlvr2", + "notes": "Evaluation requires a large amount of images to be accessible locally, so we cannot provide a command you can easily copy and paste." + }, + "motivation": null, + "preprocessing": null + }, + "training_data": { + "dataset": { + "name": "Natural Language for Visual Reasoning For Real train set", + "url": "https://github.com/lil-lab/nlvr/tree/master/nlvr2" + }, + "motivation": null, + "preprocessing": null + }, + "quantitative_analyses": { + "unitary_results": "On the validation set:\nF1: 33.7%\nAccuracy: 50.8%.\nThese scores do not match the performance in the 12-in-1 paper because this was trained as a standalone task, not as part of a multitask setup. Please contact us if you want to match those scores!", + "intersectional_results": null + }, + "model_ethical_considerations": { + "ethical_considerations": null + }, + "model_usage": { + "archive_file": "vilbert-nlvr2-head-2021.06.01.tar.gz", + "training_config": "vilbert_nlvr2_pretrained.jsonnet", + "install_instructions": "pip install allennlp>=2.5.1 allennlp-models>=2.5.1" + } +} diff --git a/allennlp_models/modelcards/nlvr2-vilbert.json b/allennlp_models/modelcards/nlvr2-vilbert.json new file mode 100644 index 000000000..966ee3adb --- /dev/null +++ b/allennlp_models/modelcards/nlvr2-vilbert.json @@ -0,0 +1,66 @@ +{ + "id": "nlvr2-vilbert", + "registered_model_name": "nlvr2", + "registered_predictor_name": null, + "display_name": "Visual Entailment - NLVR2", + "task_id": "nlvr2", + "model_details": { + "description": "This model is based on the ViLBERT multitask architecture. The image features are obtained using the ResNet backbone and Faster RCNN (region detection).", + "short_description": "ViLBERT-based model for Visual Entailment.", + "developed_by": "Lu et al", + "contributed_by": "Jacob Morrison", + "date": "2021-05-27", + "version": "2", + "model_type": "ViLBERT based on BERT large", + "paper": { + "citation": "\n@inproceedings{Lu2019ViLBERTPT,\ntitle={ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks},\nauthor={Jiasen Lu and Dhruv Batra and D. 
Parikh and Stefan Lee},\nbooktitle={NeurIPS},\nyear={2019}", + "title": "ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations for Vision-and-Language Tasks", + "url": "https://api.semanticscholar.org/CorpusID:199453025" + }, + "license": null, + "contact": "allennlp-contact@allenai.org" + }, + "intended_use": { + "primary_uses": "This model is developed for the AllenNLP demo.", + "primary_users": null, + "out_of_scope_use_cases": null + }, + "factors": { + "relevant_factors": null, + "evaluation_factors": null + }, + "metrics": { + "model_performance_measures": "Accuracy and F1-score", + "decision_thresholds": null, + "variation_approaches": null + }, + "evaluation_data": { + "dataset": { + "name": "Natural Language for Visual Reasoning For Real dev set", + "url": "https://github.com/lil-lab/nlvr/tree/master/nlvr2", + "notes": "Evaluation requires a large amount of images to be accessible locally, so we cannot provide a command you can easily copy and paste." + }, + "motivation": null, + "preprocessing": null + }, + "training_data": { + "dataset": { + "name": "Natural Language for Visual Reasoning For Real train set", + "url": "https://github.com/lil-lab/nlvr/tree/master/nlvr2" + }, + "motivation": null, + "preprocessing": null + }, + "quantitative_analyses": { + "unitary_results": "On the validation set:\nF1: 33.7%\nAccuracy: 50.8%.\nThese scores do not match the performance in the 12-in-1 paper because this was trained as a standalone task, not as part of a multitask setup. Please contact us if you want to match those scores!", + "intersectional_results": null + }, + "model_ethical_considerations": { + "ethical_considerations": null + }, + "model_usage": { + "archive_file": "vilbert-nlvr2-2021.06.01.tar.gz", + "training_config": "vilbert_nlvr2_pretrained.jsonnet", + "install_instructions": "pip install allennlp>=2.5.1 allennlp-models>=2.5.1" + } +} diff --git a/allennlp_models/vision/dataset_readers/__init__.py b/allennlp_models/vision/dataset_readers/__init__.py index 436a95914..124cfe0af 100644 --- a/allennlp_models/vision/dataset_readers/__init__.py +++ b/allennlp_models/vision/dataset_readers/__init__.py @@ -1,5 +1,6 @@ from allennlp_models.vision.dataset_readers.vision_reader import VisionReader from allennlp_models.vision.dataset_readers.gqa import GQAReader +from allennlp_models.vision.dataset_readers.nlvr2 import Nlvr2Reader from allennlp_models.vision.dataset_readers.vgqa import VGQAReader from allennlp_models.vision.dataset_readers.vqav2 import VQAv2Reader from allennlp_models.vision.dataset_readers.visual_entailment import VisualEntailmentReader diff --git a/allennlp_models/vision/dataset_readers/nlvr2.py b/allennlp_models/vision/dataset_readers/nlvr2.py new file mode 100644 index 000000000..659142fc4 --- /dev/null +++ b/allennlp_models/vision/dataset_readers/nlvr2.py @@ -0,0 +1,220 @@ +import logging +from os import PathLike +from typing import Any, Dict, Iterable, Tuple, Union, Optional + +from overrides import overrides +import torch +from torch import Tensor + +from allennlp.common.file_utils import cached_path, json_lines_from_file +from allennlp.common.lazy import Lazy +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.fields import ArrayField, LabelField, ListField, MetadataField, TextField +from allennlp.data.image_loader import ImageLoader +from allennlp.data.instance import Instance +from allennlp.data.token_indexers import TokenIndexer +from allennlp.data.tokenizers import Tokenizer +from 
allennlp.modules.vision.grid_embedder import GridEmbedder +from allennlp.modules.vision.region_detector import RegionDetector + +from allennlp_models.vision.dataset_readers.vision_reader import VisionReader + +logger = logging.getLogger(__name__) + + +@DatasetReader.register("nlvr2") +class Nlvr2Reader(VisionReader): + """ + Reads the NLVR2 dataset from [http://lil.nlp.cornell.edu/nlvr/](http://lil.nlp.cornell.edu/nlvr/). + In this task, the model is presented with two images and a hypothesis referring to those images. + The task for the model is to identify whether the hypothesis is true or false. + Accordingly, the instances produced by this reader contain two images, featurized into the + fields "box_features" and "box_coordinates". In addition to that, it produces a `TextField` + called "hypothesis", and a `MetadataField` called "identifier". The latter contains the question + id from the question set. + + Parameters + ---------- + image_dir: `str` + Path to directory containing `png` image files. + image_loader: `ImageLoader` + An image loader to read the images with + image_featurizer: `GridEmbedder` + The backbone image processor (like a ResNet), whose output will be passed to the region + detector for finding object boxes in the image. + region_detector: `RegionDetector` + For pulling out regions of the image (both coordinates and features) that will be used by + downstream models. + feature_cache_dir: `str`, optional + If given, the reader will attempt to use the featurized image cache in this directory. + Caching the featurized images can result in big performance improvements, so it is + recommended to set this. + tokenizer: `Tokenizer`, optional, defaults to `PretrainedTransformerTokenizer("bert-base-uncased")` + token_indexers: `Dict[str, TokenIndexer]`, optional, + defaults to`{"tokens": PretrainedTransformerIndexer("bert-base-uncased")}` + cuda_device: `int`, optional + Set this to run image featurization on the given GPU. By default, image featurization runs on CPU. + max_instances: `int`, optional + If set, the reader only returns the first `max_instances` instances, and then stops. + This is useful for testing. + image_processing_batch_size: `int` + The number of images to process at one time while featurizing. Default is 8. + """ + + def __init__( + self, + image_dir: Optional[Union[str, PathLike]] = None, + *, + image_loader: Optional[ImageLoader] = None, + image_featurizer: Optional[Lazy[GridEmbedder]] = None, + region_detector: Optional[Lazy[RegionDetector]] = None, + feature_cache_dir: Optional[Union[str, PathLike]] = None, + tokenizer: Optional[Tokenizer] = None, + token_indexers: Optional[Dict[str, TokenIndexer]] = None, + cuda_device: Optional[Union[int, torch.device]] = None, + max_instances: Optional[int] = None, + image_processing_batch_size: int = 8, + write_to_cache: bool = True, + ) -> None: + run_featurization = image_loader and image_featurizer and region_detector + if image_dir is None and run_featurization: + raise ValueError( + "Because of the size of the image datasets, we don't download them automatically. " + "Please go to https://github.com/lil-lab/nlvr/tree/master/nlvr2, download the datasets you need, " + "and set the image_dir parameter to point to your download location. This dataset " + "reader does not care about the exact directory structure. It finds the images " + "wherever they are." 
+ ) + + super().__init__( + image_dir, + image_loader=image_loader, + image_featurizer=image_featurizer, + region_detector=region_detector, + feature_cache_dir=feature_cache_dir, + tokenizer=tokenizer, + token_indexers=token_indexers, + cuda_device=cuda_device, + max_instances=max_instances, + image_processing_batch_size=image_processing_batch_size, + write_to_cache=write_to_cache, + ) + + github_url = "https://raw.githubusercontent.com/lil-lab/nlvr/" + nlvr_commit = "68a11a766624a5b665ec7594982b8ecbedc728c7" + data_dir = f"{github_url}{nlvr_commit}/nlvr2/data" + self.splits = { + "dev": f"{data_dir}/dev.json", + "test": f"{data_dir}/test1.json", + "train": f"{data_dir}/train.json", + "balanced_dev": f"{data_dir}/balanced/balanced_dev.json", + "balanced_test": f"{data_dir}/balanced/balanced_test1.json", + "unbalanced_dev": f"{data_dir}/balanced/unbalanced_dev.json", + "unbalanced_test": f"{data_dir}/balanced/unbalanced_test1.json", + } + + @overrides + def _read(self, split_or_filename: str): + filename = self.splits.get(split_or_filename, split_or_filename) + + json_file_path = cached_path(filename) + + blobs = [] + json_blob: Dict[str, Any] + for json_blob in json_lines_from_file(json_file_path): + blobs.append(json_blob) + + blob_dicts = list(self.shard_iterable(blobs)) + processed_images1: Iterable[Optional[Tuple[Tensor, Tensor]]] + processed_images2: Iterable[Optional[Tuple[Tensor, Tensor]]] + if self.produce_featurized_images: + # It would be much easier to just process one image at a time, but it's faster to process + # them in batches. So this code gathers up instances until it has enough to fill up a batch + # that needs processing, and then processes them all. + + try: + image_paths1 = [] + image_paths2 = [] + for blob in blob_dicts: + identifier = blob["identifier"] + image_name_base = identifier[: identifier.rindex("-")] + image_paths1.append(self.images[f"{image_name_base}-img0.png"]) + image_paths2.append(self.images[f"{image_name_base}-img1.png"]) + except KeyError as e: + missing_id = e.args[0] + raise KeyError( + missing_id, + f"We could not find an image with the id {missing_id}. " + "Because of the size of the image datasets, we don't download them automatically. " + "Please go to https://github.com/lil-lab/nlvr/tree/master/nlvr2, download the " + "datasets you need, and set the image_dir parameter to point to your download " + "location. This dataset reader does not care about the exact directory " + "structure. 
It finds the images wherever they are.", + ) + + processed_images1 = self._process_image_paths(image_paths1) + processed_images2 = self._process_image_paths(image_paths2) + else: + processed_images1 = [None for _ in range(len(blob_dicts))] + processed_images2 = [None for _ in range(len(blob_dicts))] + + attempted_instances = 0 + for json_blob, image1, image2 in zip(blob_dicts, processed_images1, processed_images2): + identifier = json_blob["identifier"] + hypothesis = json_blob["sentence"] + label = json_blob["label"] == "True" + instance = self.text_to_instance(identifier, hypothesis, image1, image2, label) + if instance is not None: + attempted_instances += 1 + yield instance + logger.info(f"Successfully yielded {attempted_instances} instances") + + def extract_image_features(self, image: Union[str, Tuple[Tensor, Tensor]], use_cache: bool): + if isinstance(image, str): + features, coords = next(self._process_image_paths([image], use_cache=use_cache)) + else: + features, coords = image + + return ( + ArrayField(features), + ArrayField(coords), + ArrayField( + features.new_ones((features.shape[0],), dtype=torch.bool), + padding_value=False, + dtype=torch.bool, + ), + ) + + @overrides + def text_to_instance( + self, # type: ignore + identifier: Optional[str], + hypothesis: str, + image1: Union[str, Tuple[Tensor, Tensor]], + image2: Union[str, Tuple[Tensor, Tensor]], + label: bool, + use_cache: bool = True, + ) -> Instance: + hypothesis_field = TextField(self._tokenizer.tokenize(hypothesis), None) + box_features1, box_coordinates1, box_mask1 = self.extract_image_features(image1, use_cache) + box_features2, box_coordinates2, box_mask2 = self.extract_image_features(image2, use_cache) + + fields = { + "hypothesis": ListField([hypothesis_field, hypothesis_field]), + "box_features": ListField([box_features1, box_features2]), + "box_coordinates": ListField([box_coordinates1, box_coordinates2]), + "box_mask": ListField([box_mask1, box_mask2]), + } + + if identifier is not None: + fields["identifier"] = MetadataField(identifier) + + if label is not None: + fields["label"] = LabelField(int(label), skip_indexing=True) + + return Instance(fields) + + @overrides + def apply_token_indexers(self, instance: Instance) -> None: + instance["hypothesis"][0].token_indexers = self._token_indexers # type: ignore + instance["hypothesis"][1].token_indexers = self._token_indexers # type: ignore diff --git a/allennlp_models/vision/models/__init__.py b/allennlp_models/vision/models/__init__.py index 8d96d285f..3c8627eed 100644 --- a/allennlp_models/vision/models/__init__.py +++ b/allennlp_models/vision/models/__init__.py @@ -1,3 +1,4 @@ +from allennlp_models.vision.models.nlvr2 import Nlvr2Model from allennlp_models.vision.models.vision_text_model import VisionTextModel from allennlp_models.vision.models.visual_entailment import VisualEntailmentModel from allennlp_models.vision.models.vilbert_vqa import VqaVilbert diff --git a/allennlp_models/vision/models/heads/__init__.py b/allennlp_models/vision/models/heads/__init__.py index d55c627fa..581e912db 100644 --- a/allennlp_models/vision/models/heads/__init__.py +++ b/allennlp_models/vision/models/heads/__init__.py @@ -1,2 +1,3 @@ +from allennlp_models.vision.models.heads.nlvr2_head import Nlvr2Head from allennlp_models.vision.models.heads.vqa_head import VqaHead from allennlp_models.vision.models.heads.visual_entailment_head import VisualEntailmentHead diff --git a/allennlp_models/vision/models/heads/nlvr2_head.py b/allennlp_models/vision/models/heads/nlvr2_head.py new 
file mode 100644 index 000000000..64fa4402c --- /dev/null +++ b/allennlp_models/vision/models/heads/nlvr2_head.py @@ -0,0 +1,77 @@ +from typing import Dict, Optional + +import torch +from overrides import overrides + +from allennlp.data.vocabulary import Vocabulary +from allennlp.models.heads.head import Head + + +@Head.register("nlvr2") +class Nlvr2Head(Head): + def __init__(self, vocab: Vocabulary, embedding_dim: int, label_namespace: str = "labels"): + super().__init__(vocab) + + self.label_namespace = label_namespace + + self.layer1 = torch.nn.Linear(embedding_dim * 2, embedding_dim) + self.layer2 = torch.nn.Linear(embedding_dim, 2) + + self.activation = torch.nn.ReLU() + + from allennlp.training.metrics import CategoricalAccuracy + from allennlp.training.metrics import FBetaMeasure + + self.accuracy = CategoricalAccuracy() + self.fbeta = FBetaMeasure(beta=1.0, average="macro") + + @overrides + def forward( + self, # type: ignore + encoded_boxes: torch.Tensor, + encoded_boxes_mask: torch.Tensor, + encoded_boxes_pooled: torch.Tensor, + encoded_text: torch.Tensor, + encoded_text_mask: torch.Tensor, + encoded_text_pooled: torch.Tensor, + pooled_boxes_and_text: torch.Tensor, + label: Optional[torch.Tensor] = None, + label_weights: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + pooled_boxes_and_text = pooled_boxes_and_text.transpose(0, 1) + hidden = self.layer1( + torch.cat((pooled_boxes_and_text[0], pooled_boxes_and_text[1]), dim=-1) + ) + logits = self.layer2(self.activation(hidden)) + probs = torch.softmax(logits, dim=-1) + + output = {"logits": logits, "probs": probs} + + assert label_weights is None + if label is not None: + output["loss"] = torch.nn.functional.cross_entropy(logits, label) / logits.size(0) + self.accuracy(logits, label) + self.fbeta(probs, label) + + return output + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + result = self.fbeta.get_metric(reset) + result["accuracy"] = self.accuracy.get_metric(reset) + return result + + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + if len(output_dict) <= 0: + return output_dict + logits = output_dict["logits"] + entailment_answer_index = logits.argmax(-1) + entailment_answer = [ + self.vocab.get_token_from_index(int(i), "labels") for i in entailment_answer_index + ] + output_dict["entailment_answer"] = entailment_answer + return output_dict + + default_predictor = "nlvr2" diff --git a/allennlp_models/vision/models/nlvr2.py b/allennlp_models/vision/models/nlvr2.py new file mode 100644 index 000000000..e322d9f05 --- /dev/null +++ b/allennlp_models/vision/models/nlvr2.py @@ -0,0 +1,142 @@ +import logging +from typing import Dict, Optional, List, Any + +from overrides import overrides +import numpy as np +import torch + +from allennlp.data import TextFieldTensors, Vocabulary +from allennlp.models.model import Model +from allennlp.modules.transformer import ( + TransformerEmbeddings, + ImageFeatureEmbeddings, + BiModalEncoder, +) +from allennlp.training.metrics import CategoricalAccuracy +from allennlp.training.metrics import FBetaMeasure + +from allennlp_models.vision.models.vision_text_model import VisionTextModel + +logger = logging.getLogger(__name__) + + +@Model.register("nlvr2") +@Model.register("nlvr2_from_huggingface", constructor="from_huggingface_model_name") +class Nlvr2Model(VisionTextModel): + """ + Model for visual entailment task based on the paper + [A Corpus for Reasoning About Natural Language Grounded in 
Photographs] + (https://api.semanticscholar.org/CorpusID:53178856). + + # Parameters + + vocab : `Vocabulary` + text_embeddings : `TransformerEmbeddings` + image_embeddings : `ImageFeatureEmbeddings` + encoder : `BiModalEncoder` + pooled_output_dim : `int` + fusion_method : `str`, optional (default = `"mul"`) + dropout : `float`, optional (default = `0.1`) + label_namespace : `str`, optional (default = `labels`) + """ + + def __init__( + self, + vocab: Vocabulary, + text_embeddings: TransformerEmbeddings, + image_embeddings: ImageFeatureEmbeddings, + encoder: BiModalEncoder, + pooled_output_dim: int, + fusion_method: str = "mul", + dropout: float = 0.1, + label_namespace: str = "labels", + *, + ignore_text: bool = False, + ignore_image: bool = False, + ) -> None: + + super().__init__( + vocab, + text_embeddings, + image_embeddings, + encoder, + pooled_output_dim, + fusion_method, + dropout, + label_namespace, + is_multilabel=False, + ) + + self.pooled_output_dim = pooled_output_dim + + self.layer1 = torch.nn.Linear(pooled_output_dim * 2, pooled_output_dim) + self.layer2 = torch.nn.Linear(pooled_output_dim, 2) + + self.activation = torch.nn.ReLU() + + self.accuracy = CategoricalAccuracy() + self.fbeta = FBetaMeasure(beta=1.0, average="macro") + + @overrides + def forward( + self, # type: ignore + box_features: torch.Tensor, + box_coordinates: torch.Tensor, + box_mask: torch.Tensor, + hypothesis: TextFieldTensors, + label: Optional[torch.Tensor] = None, + identifier: List[Dict[str, Any]] = None, + ) -> Dict[str, torch.Tensor]: + batch_size = box_features.shape[0] + + pooled_outputs = self.backbone(box_features, box_coordinates, box_mask, hypothesis)[ + "pooled_boxes_and_text" + ].transpose(0, 1) + + hidden = self.layer1(torch.cat((pooled_outputs[0], pooled_outputs[1]), dim=-1)) + + # Shape: (batch_size, num_labels) + logits = self.layer2(self.activation(hidden)) + + # Shape: (batch_size, num_labels) + probs = torch.softmax(logits, dim=-1) + + outputs = {"logits": logits, "probs": probs} + outputs = self._compute_loss_and_metrics(batch_size, outputs, label) + + return outputs + + @overrides + def _compute_loss_and_metrics( + self, + batch_size: int, + outputs: torch.Tensor, + label: torch.Tensor, + ): + if label is not None: + outputs["loss"] = ( + torch.nn.functional.cross_entropy(outputs["logits"], label) / batch_size + ) + self.accuracy(outputs["logits"], label) + self.fbeta(outputs["probs"], label) + return outputs + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics = self.fbeta.get_metric(reset) + accuracy = self.accuracy.get_metric(reset) + metrics.update({"accuracy": accuracy}) + return metrics + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + batch_labels = [] + for batch_index, batch in enumerate(output_dict["probs"]): + labels = np.argmax(batch, axis=-1) + batch_labels.append(labels) + output_dict["labels"] = batch_labels + return output_dict + + default_predictor = "nlvr2" diff --git a/allennlp_models/vision/predictors/nlvr2.py b/allennlp_models/vision/predictors/nlvr2.py new file mode 100644 index 000000000..d863352c5 --- /dev/null +++ b/allennlp_models/vision/predictors/nlvr2.py @@ -0,0 +1,44 @@ +from typing import List, Dict + +from overrides import overrides +import numpy + +from allennlp.common.file_utils import cached_path +from allennlp.common.util import JsonDict +from allennlp.data import Instance +from allennlp.data.fields import LabelField +from 
allennlp.predictors.predictor import Predictor
+
+
+@Predictor.register("nlvr2")
+class Nlvr2Predictor(Predictor):
+    def predict(self, image1: str, image2: str, hypothesis: str) -> JsonDict:
+        image1 = cached_path(image1)
+        image2 = cached_path(image2)
+        return self.predict_json({"image1": image1, "image2": image2, "hypothesis": hypothesis})
+
+    @overrides
+    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
+        from allennlp_models.vision.dataset_readers.nlvr2 import Nlvr2Reader
+
+        image1 = cached_path(json_dict["image1"])
+        image2 = cached_path(json_dict["image2"])
+        hypothesis = json_dict["hypothesis"]
+        if isinstance(self._dataset_reader, Nlvr2Reader):
+            return self._dataset_reader.text_to_instance(
+                None, hypothesis, image1, image2, label=None, use_cache=False
+            )
+        else:
+            raise ValueError(
+                f"Dataset reader is of type {self._dataset_reader.__class__.__name__}. "
+                f"Expected {Nlvr2Reader.__name__}."
+            )
+
+    @overrides
+    def predictions_to_labeled_instances(
+        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
+    ) -> List[Instance]:
+        new_instance = instance.duplicate()
+        label = numpy.argmax(outputs["probs"])
+        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
+        return [new_instance]
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-0-img0.png b/test_fixtures/vision/images/nlvr2/dev-850-0-img0.png
new file mode 100644
index 000000000..494e87578
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-0-img0.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-0-img1.png b/test_fixtures/vision/images/nlvr2/dev-850-0-img1.png
new file mode 100644
index 000000000..4c12be3a9
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-0-img1.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-1-img0.png b/test_fixtures/vision/images/nlvr2/dev-850-1-img0.png
new file mode 100644
index 000000000..3fc1c616b
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-1-img0.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-1-img1.png b/test_fixtures/vision/images/nlvr2/dev-850-1-img1.png
new file mode 100644
index 000000000..225a83895
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-1-img1.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-2-img0.png b/test_fixtures/vision/images/nlvr2/dev-850-2-img0.png
new file mode 100644
index 000000000..3dd18d8f1
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-2-img0.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-2-img1.png b/test_fixtures/vision/images/nlvr2/dev-850-2-img1.png
new file mode 100644
index 000000000..bf6768b41
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-2-img1.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-3-img0.png b/test_fixtures/vision/images/nlvr2/dev-850-3-img0.png
new file mode 100644
index 000000000..1b9d9b32c
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-3-img0.png differ
diff --git a/test_fixtures/vision/images/nlvr2/dev-850-3-img1.png b/test_fixtures/vision/images/nlvr2/dev-850-3-img1.png
new file mode 100644
index 000000000..85401c4b6
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/dev-850-3-img1.png differ
diff --git a/test_fixtures/vision/images/nlvr2/test1-0-0-img0.png b/test_fixtures/vision/images/nlvr2/test1-0-0-img0.png
new file mode 100644
index 000000000..3deac95d0
Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-0-img0.png 
differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-0-img1.png b/test_fixtures/vision/images/nlvr2/test1-0-0-img1.png new file mode 100644 index 000000000..189d6b55a Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-0-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-1-img0.png b/test_fixtures/vision/images/nlvr2/test1-0-1-img0.png new file mode 100644 index 000000000..8ebb45431 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-1-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-1-img1.png b/test_fixtures/vision/images/nlvr2/test1-0-1-img1.png new file mode 100644 index 000000000..6db5e5259 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-1-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-2-img0.png b/test_fixtures/vision/images/nlvr2/test1-0-2-img0.png new file mode 100644 index 000000000..8439d6746 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-2-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-2-img1.png b/test_fixtures/vision/images/nlvr2/test1-0-2-img1.png new file mode 100644 index 000000000..da51cdb9a Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-2-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-3-img0.png b/test_fixtures/vision/images/nlvr2/test1-0-3-img0.png new file mode 100644 index 000000000..d95f28bbd Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-3-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/test1-0-3-img1.png b/test_fixtures/vision/images/nlvr2/test1-0-3-img1.png new file mode 100644 index 000000000..d6a62ccce Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/test1-0-3-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-10171-0-img0.png b/test_fixtures/vision/images/nlvr2/train-10171-0-img0.png new file mode 100644 index 000000000..4d59b1863 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-10171-0-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-10171-0-img1.png b/test_fixtures/vision/images/nlvr2/train-10171-0-img1.png new file mode 100644 index 000000000..9af2753d6 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-10171-0-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-4100-0-img0.png b/test_fixtures/vision/images/nlvr2/train-4100-0-img0.png new file mode 100644 index 000000000..ae8ffc18e Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-4100-0-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-4100-0-img1.png b/test_fixtures/vision/images/nlvr2/train-4100-0-img1.png new file mode 100644 index 000000000..bc8b8dcc4 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-4100-0-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-4933-2-img0.png b/test_fixtures/vision/images/nlvr2/train-4933-2-img0.png new file mode 100644 index 000000000..54c08fd94 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-4933-2-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-4933-2-img1.png b/test_fixtures/vision/images/nlvr2/train-4933-2-img1.png new file mode 100644 index 000000000..48f6b2c51 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-4933-2-img1.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-6623-1-img0.png b/test_fixtures/vision/images/nlvr2/train-6623-1-img0.png new 
file mode 100644 index 000000000..b186f45ea Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-6623-1-img0.png differ diff --git a/test_fixtures/vision/images/nlvr2/train-6623-1-img1.png b/test_fixtures/vision/images/nlvr2/train-6623-1-img1.png new file mode 100644 index 000000000..197080e06 Binary files /dev/null and b/test_fixtures/vision/images/nlvr2/train-6623-1-img1.png differ diff --git a/test_fixtures/vision/nlvr2/experiment.jsonnet b/test_fixtures/vision/nlvr2/experiment.jsonnet new file mode 100644 index 000000000..49b149874 --- /dev/null +++ b/test_fixtures/vision/nlvr2/experiment.jsonnet @@ -0,0 +1,78 @@ +local model_name = "epwalsh/bert-xsmall-dummy"; + +{ + "dataset_reader": { + "type": "nlvr2", + "image_dir": "test_fixtures/vision/images/nlvr2", + "image_loader": "torch", + "image_featurizer": "null", + "region_detector": { + "type": "random", + "seed": 322 + }, + "tokenizer": { + "type": "pretrained_transformer", + "model_name": model_name + }, + "token_indexers": { + "tokens": { + "type": "pretrained_transformer", + "model_name": model_name + } + } + }, + "train_data_path": "test_fixtures/vision/nlvr2/tiny-dev.json", + "validation_data_path": "test_fixtures/vision/nlvr2/tiny-dev.json", + "model": { + "type": "nlvr2", + "text_embeddings": { + "vocab_size": 250, + "embedding_size": 20, + "pad_token_id": 0, + "max_position_embeddings": 512, + "type_vocab_size": 2, + "dropout": 0.0 + }, + "image_embeddings": { + "feature_size": 10, + "embedding_size": 200 + }, + "encoder": { + # text + "hidden_size1": 20, + "num_hidden_layers1": 1, + "intermediate_size1": 40, + "num_attention_heads1": 1, + "attention_dropout1": 0.1, + "hidden_dropout1": 0.1, + "biattention_id1": [0, 1], + "fixed_layer1": 0, + + # vision + "hidden_size2": 200, + "num_hidden_layers2": 1, + "intermediate_size2": 50, + "num_attention_heads2": 1, + "attention_dropout2": 0.0, + "hidden_dropout2": 0.0, + "biattention_id2": [0, 1], + "fixed_layer2": 0, + + "combined_num_attention_heads": 2, + "combined_hidden_size": 200, + "activation": "gelu", + }, + "pooled_output_dim": 100, + "fusion_method": "sum", + }, + "data_loader": { + "batch_size": 4 + }, + "trainer": { + "optimizer": { + "type": "huggingface_adamw", + "lr": 0.00005 + }, + "num_epochs": 1, + }, +} diff --git a/test_fixtures/vision/nlvr2/experiment_from_huggingface.jsonnet b/test_fixtures/vision/nlvr2/experiment_from_huggingface.jsonnet new file mode 100644 index 000000000..10d91979c --- /dev/null +++ b/test_fixtures/vision/nlvr2/experiment_from_huggingface.jsonnet @@ -0,0 +1,58 @@ +local model_name = "epwalsh/bert-xsmall-dummy"; +{ + "dataset_reader": { + "type": "nlvr2", + "image_dir": "test_fixtures/vision/images/nlvr2", + "image_loader": "torch", + "image_featurizer": "null", + "region_detector": { + "type": "random", + "seed": 322 + }, + "tokenizer": { + "type": "pretrained_transformer", + "model_name": model_name + }, + "token_indexers": { + "tokens": { + "type": "pretrained_transformer", + "model_name": model_name + } + } + }, + "train_data_path": "test_fixtures/vision/nlvr2/tiny-dev.json", + "validation_data_path": "test_fixtures/vision/nlvr2/tiny-dev.json", + "model": { + "type": "nlvr2_from_huggingface", + "model_name": model_name, + "image_feature_dim": 10, + "image_num_hidden_layers": 1, + "image_hidden_size": 200, + "image_num_attention_heads": 1, + "image_intermediate_size": 50, + "image_attention_dropout": 0.0, + "image_hidden_dropout": 0.0, + "image_biattention_id": [0, 1], + "image_fixed_layer": 0, + + "text_biattention_id": 
[0, 1], + "text_fixed_layer": 0, + + "combined_hidden_size": 200, + "combined_num_attention_heads": 4, + + "pooled_output_dim": 100, + "fusion_method": "sum", + "pooled_dropout": 0.0, + }, + "data_loader": { + "batch_size": 32 + }, + "trainer": { + "optimizer": { + "type": "huggingface_adamw", + "lr": 0.00005 + }, + "num_epochs": 1, + }, +} diff --git a/test_fixtures/vision/nlvr2/tiny-dev.json b/test_fixtures/vision/nlvr2/tiny-dev.json new file mode 100644 index 000000000..0fb5a62f7 --- /dev/null +++ b/test_fixtures/vision/nlvr2/tiny-dev.json @@ -0,0 +1,8 @@ +{"validation": {"61": "False"}, "sentence": "The right image shows a curving walkway of dark glass circles embedded in dirt and flanked by foliage.", "left_url": "https://i.kinja-img.com/gawker-media/image/upload/s--UyhVSznS--/18iy0hwo5wdrpjpg.jpg", "writer": "61", "label": "False", "right_url": "https://cdn.pixabay.com/photo/2015/09/21/12/57/beer-bottles-949793_960_720.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-0-0", "extra_validations": {"154": "False", "139": "False", "149": "False", "62": "False"}} +{"validation": {"160": "True"}, "sentence": "The right image shows a curving walkway of dark glass circles embedded in dirt and flanked by foliage.", "left_url": "https://i.pinimg.com/originals/56/9d/99/569d99a9ae49f55cafc676b851d5b48e.jpg", "writer": "61", "label": "True", "right_url": "https://i.pinimg.com/originals/80/11/05/801105102b8e65810f3c276895f80bf7.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-2-0", "extra_validations": {"136": "True", "154": "True", "149": "True", "62": "True"}} +{"validation": {"48": "False"}, "sentence": "The right image shows a curving walkway of dark glass circles embedded in dirt and flanked by foliage.", "left_url": "https://i.pinimg.com/736x/fa/52/07/fa52079fe0db452cd47946fb8c7d553f--soda-bottles-sodas.jpg", "writer": "61", "label": "False", "right_url": "http://www.historicbottles.com/queenolive.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-3-0", "extra_validations": {"16": "False", "136": "False", "130": "False", "138": "False"}} +{"validation": {"83": "True"}, "sentence": "The right image shows a curving walkway of dark glass circles embedded in dirt and flanked by foliage.", "left_url": "http://www.thebruery.com/wp-content/uploads/2014/07/the-bruery-DIY-craft-beer-bottle-planter-garden-craft-end.jpg", "writer": "61", "label": "True", "right_url": "https://s-media-cache-ak0.pinimg.com/originals/72/bf/56/72bf569dae62342a629adc71da1bf407.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-1-0", "extra_validations": {"58": "True", "108": "True", "157": "True", "134": "True"}} +{"validation": {"66": "True"}, "sentence": "IN at least one image there are at least four bottle rows that together make a walking path.", "left_url": "http://www.thebruery.com/wp-content/uploads/2014/07/the-bruery-DIY-craft-beer-bottle-planter-garden-craft-end.jpg", "writer": "31", "label": "True", "right_url": "https://s-media-cache-ak0.pinimg.com/originals/72/bf/56/72bf569dae62342a629adc71da1bf407.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-1-1", "extra_validations": {"136": "True", "154": "True", "77": "False", "71": "True"}} +{"validation": {"66": "False"}, "sentence": "IN at least one image there are at least four bottle rows that together make a walking path.", "left_url": 
"https://i.kinja-img.com/gawker-media/image/upload/s--UyhVSznS--/18iy0hwo5wdrpjpg.jpg", "writer": "31", "label": "False", "right_url": "https://cdn.pixabay.com/photo/2015/09/21/12/57/beer-bottles-949793_960_720.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-0-1", "extra_validations": {"160": "False", "154": "False", "43": "False", "127": "False"}} +{"validation": {"157": "False"}, "sentence": "IN at least one image there are at least four bottle rows that together make a walking path.", "left_url": "https://i.pinimg.com/736x/fa/52/07/fa52079fe0db452cd47946fb8c7d553f--soda-bottles-sodas.jpg", "writer": "31", "label": "False", "right_url": "http://www.historicbottles.com/queenolive.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-3-1", "extra_validations": {"56": "False", "160": "False", "61": "False", "62": "False"}} +{"validation": {"160": "True"}, "sentence": "IN at least one image there are at least four bottle rows that together make a walking path.", "left_url": "https://i.pinimg.com/originals/56/9d/99/569d99a9ae49f55cafc676b851d5b48e.jpg", "writer": "31", "label": "True", "right_url": "https://i.pinimg.com/originals/80/11/05/801105102b8e65810f3c276895f80bf7.jpg", "synset": "beer bottle", "query": "some beer bottles38", "identifier": "dev-850-2-1", "extra_validations": {"154": "True", "156": "False", "61": "True", "31": "True"}} \ No newline at end of file diff --git a/tests/vision/dataset_readers/nlvr2_test.py b/tests/vision/dataset_readers/nlvr2_test.py new file mode 100644 index 000000000..0eb6a4f44 --- /dev/null +++ b/tests/vision/dataset_readers/nlvr2_test.py @@ -0,0 +1,49 @@ +from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.lazy import Lazy +from allennlp.data import Batch, Vocabulary +from allennlp.data.image_loader import TorchImageLoader +from allennlp.data.tokenizers import WhitespaceTokenizer +from allennlp.data.token_indexers import SingleIdTokenIndexer +from allennlp.modules.vision.grid_embedder import NullGridEmbedder +from allennlp.modules.vision.region_detector import RandomRegionDetector + +from tests import FIXTURES_ROOT + + +class TestNlvr2Reader(AllenNlpTestCase): + def test_read(self): + from allennlp_models.vision.dataset_readers.nlvr2 import Nlvr2Reader + + reader = Nlvr2Reader( + image_dir=FIXTURES_ROOT / "vision" / "images" / "nlvr2", + image_loader=TorchImageLoader(), + image_featurizer=Lazy(NullGridEmbedder), + region_detector=Lazy(RandomRegionDetector), + tokenizer=WhitespaceTokenizer(), + token_indexers={"tokens": SingleIdTokenIndexer()}, + ) + instances = list(reader.read("test_fixtures/vision/nlvr2/tiny-dev.json")) + assert len(instances) == 8 + + instance = instances[0] + assert len(instance.fields) == 6 + assert instance["hypothesis"][0] == instance["hypothesis"][1] + assert len(instance["hypothesis"][0]) == 18 + hypothesis_tokens = [t.text for t in instance["hypothesis"][0]] + assert hypothesis_tokens[:6] == ["The", "right", "image", "shows", "a", "curving"] + assert instance["label"].label == 0 + assert instances[1]["label"].label == 1 + assert instance["identifier"].metadata == "dev-850-0-0" + + batch = Batch(instances) + batch.index_instances(Vocabulary()) + tensors = batch.as_tensor_dict() + + # (batch size, 2 images per instance, num boxes (fake), num features (fake)) + assert tensors["box_features"].size() == (8, 2, 2, 10) + + # (batch size, 2 images per instance, num boxes (fake), 4 coords) + assert tensors["box_coordinates"].size() == (8, 2, 2, 
4) + + # (batch size, 2 images per instance, num boxes (fake)) + assert tensors["box_mask"].size() == (8, 2, 2) diff --git a/tests/vision/models/vilbert_nlvr2_test.py b/tests/vision/models/vilbert_nlvr2_test.py new file mode 100644 index 000000000..c9d96e8bf --- /dev/null +++ b/tests/vision/models/vilbert_nlvr2_test.py @@ -0,0 +1,77 @@ +from torch.testing import assert_allclose +from transformers import AutoModel + +from allennlp.common.testing import ModelTestCase +from allennlp.data import Vocabulary +from allennlp.common.testing import assert_equal_parameters + +from allennlp_models import vision # noqa: F401 + +from tests import FIXTURES_ROOT + + +class TestNlvr2Vilbert(ModelTestCase): + def test_model_can_train_save_and_load_small_model(self): + param_file = FIXTURES_ROOT / "vision" / "nlvr2" / "experiment.jsonnet" + self.ensure_model_can_train_save_and_load( + param_file, gradients_to_ignore={"classifier.weight", "classifier.bias"} + ) + + def test_model_can_train_save_and_load_with_cache(self): + import tempfile + + with tempfile.TemporaryDirectory(prefix=self.__class__.__name__) as d: + overrides = {"dataset_reader": {"feature_cache_dir": str(d)}} + import json + + overrides = json.dumps(overrides) + param_file = FIXTURES_ROOT / "vision" / "nlvr2" / "experiment.jsonnet" + self.ensure_model_can_train_save_and_load( + param_file, + overrides=overrides, + gradients_to_ignore={"classifier.weight", "classifier.bias"}, + ) + + def test_model_can_train_save_and_load_from_huggingface(self): + param_file = FIXTURES_ROOT / "vision" / "nlvr2" / "experiment_from_huggingface.jsonnet" + self.ensure_model_can_train_save_and_load( + param_file, gradients_to_ignore={"classifier.weight", "classifier.bias"} + ) + + def test_model_loads_weights_correctly(self): + from allennlp_models.vision.models.nlvr2 import Nlvr2Model + + vocab = Vocabulary() + model_name = "epwalsh/bert-xsmall-dummy" + model = Nlvr2Model.from_huggingface_model_name( + vocab=vocab, + model_name=model_name, + image_feature_dim=2048, + image_num_hidden_layers=1, + image_hidden_size=3, + image_num_attention_heads=1, + combined_num_attention_heads=1, + combined_hidden_size=5, + pooled_output_dim=7, + image_intermediate_size=11, + image_attention_dropout=0.0, + image_hidden_dropout=0.0, + image_biattention_id=[0, 1], + text_biattention_id=[0, 1], + text_fixed_layer=0, + image_fixed_layer=0, + ) + + transformer = AutoModel.from_pretrained(model_name) + + # compare embedding parameters + assert_allclose( + transformer.embeddings.word_embeddings.weight.data, + model.backbone.text_embeddings.embeddings.word_embeddings.weight.data, + ) + + # compare encoder parameters + assert_allclose( + transformer.encoder.layer[0].intermediate.dense.weight.data, + model.backbone.encoder.layers1[0].intermediate.dense.weight.data, + ) diff --git a/training_config/vision/vilbert_nlvr2_head_pretrained.jsonnet b/training_config/vision/vilbert_nlvr2_head_pretrained.jsonnet new file mode 100644 index 000000000..4c4cde7a5 --- /dev/null +++ b/training_config/vision/vilbert_nlvr2_head_pretrained.jsonnet @@ -0,0 +1,111 @@ +local model_name = "bert-large-uncased"; +local num_gpus = 1; +local effective_batch_size = 64; +local gpu_batch_size = effective_batch_size / num_gpus; +local num_epochs = 10; +local patience = 5; +local num_gradient_accumulation_steps = effective_batch_size / gpu_batch_size / std.max(1, num_gpus); +local num_instances = 86373; + + +local reader_common = { + "image_loader": "torch", + "region_detector": "faster_rcnn", + "image_featurizer": 
"resnet_backbone", + "tokenizer": { + "type": "pretrained_transformer", + "model_name": model_name + }, + "token_indexers": { + "tokens": { + "type": "pretrained_transformer", + "model_name": model_name + } + }, + // "max_instances": 1000, # DEBUG + "image_processing_batch_size": 16, +}; + +{ + "dataset_reader": { + "type": "multitask", + "readers": { + "nlvr2": reader_common { + "type": "nlvr2", + "image_dir": "/net/nfs2.allennlp/data/vision/nlvr2/images", + "feature_cache_dir": "/net/nfs2.allennlp/data/vision/nlvr2/feature_cache", + } + } + }, + "train_data_path": { + "nlvr2": "train", + }, + "validation_data_path": { + "nlvr2": "dev", + }, + "test_data_path": { + "nlvr2": "test", + }, + "model": { + "type": "multitask", + "arg_name_mapping": { + "backbone": {"question": "text", "hypothesis": "text"} + }, + "backbone": { + "type": "vilbert_from_huggingface", + "model_name": model_name, + "image_feature_dim": 1024, + "image_num_hidden_layers": 6, + "image_hidden_size": 1024, + "image_num_attention_heads": 8, + "image_intermediate_size": 1024, + "image_attention_dropout": 0.1, + "image_hidden_dropout": 0.1, + "image_biattention_id": [0, 1, 2, 3, 4, 5], + "text_biattention_id": [6, 7, 8, 9, 10, 11], + "text_fixed_layer": 0, + "image_fixed_layer": 0, + "combined_hidden_size": 1024, + "combined_num_attention_heads": 8, + "pooled_output_dim": 1024, + "fusion_method": "mul" + }, + "heads": { + "nlvr2": { + "type": "nlvr2", + "embedding_dim": 1024 + }, + } + }, + "data_loader": { + "type": "multitask", + "scheduler": { + "batch_size": gpu_batch_size, + }, + "shuffle": true, + }, + [if num_gpus > 1 then "distributed"]: { + "cuda_devices": std.range(0, num_gpus - 1) + //"cuda_devices": std.repeat([-1], num_gpus) # Use this for debugging on CPU + }, + "trainer": { + "optimizer": { + "type": "huggingface_adamw", + "lr": 4e-5, + "correct_bias": true, + "weight_decay": 0.01, + "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]], + }, + "learning_rate_scheduler": { + "type": "linear_with_warmup", + "warmup_steps" : std.ceil(0.1 * num_instances * num_epochs * num_gradient_accumulation_steps / effective_batch_size) + }, + "validation_metric": ["+nlvr2_accuracy"], + "patience": patience, + "num_epochs": num_epochs, + "num_gradient_accumulation_steps": num_gradient_accumulation_steps, + }, + "random_seed": 876170670, + "numpy_seed": 876170670, + "pytorch_seed": 876170670, +} diff --git a/training_config/vision/vilbert_nlvr2_pretrained.jsonnet b/training_config/vision/vilbert_nlvr2_pretrained.jsonnet new file mode 100644 index 000000000..48c0867d0 --- /dev/null +++ b/training_config/vision/vilbert_nlvr2_pretrained.jsonnet @@ -0,0 +1,78 @@ +local model_name = "bert-large-uncased"; +local num_gpus = 1; +local effective_batch_size = 64; +local gpu_batch_size = effective_batch_size / num_gpus; +local num_epochs = 10; +local patience = 5; +local num_gradient_accumulation_steps = effective_batch_size / gpu_batch_size / std.max(1, num_gpus); +local num_instances = 86373; + +{ + "dataset_reader": { + "type": "nlvr2", + "image_dir": "/net/nfs2.allennlp/data/vision/nlvr2/images", + "feature_cache_dir": "/net/nfs2.allennlp/data/vision/nlvr2/feature_cache", + "image_loader": "torch", + "image_featurizer": "resnet_backbone", + "region_detector": "faster_rcnn", + "tokenizer": { + "type": "pretrained_transformer", + "model_name": model_name + }, + "token_indexers": { + "tokens": { + "type": "pretrained_transformer", + "model_name": model_name + } + }, + 
"image_processing_batch_size": 16, + // "max_instances": 1000 + }, + "train_data_path": "train", + "validation_data_path": "dev", + "test_data_path": "test", + "evaluate_on_test": true, + "model": { + "type": "nlvr2_from_huggingface", + "model_name": model_name, + "image_feature_dim": 1024, + "image_hidden_size": 1024, + "image_num_attention_heads": 8, + "image_num_hidden_layers": 6, + "combined_hidden_size": 1024, + "combined_num_attention_heads": 8, + "pooled_output_dim": 1024, + "image_intermediate_size": 1024, + "image_attention_dropout": 0.1, + "image_hidden_dropout": 0.1, + "image_biattention_id": [0, 1, 2, 3, 4, 5], + "text_biattention_id": [6, 7, 8, 9, 10, 11], + "text_fixed_layer": 0, + "image_fixed_layer": 0, + "fusion_method": "mul" + }, + "data_loader": { + "batch_size": gpu_batch_size, + "shuffle": true, + }, + [if num_gpus > 1 then "distributed"]: { + "cuda_devices": std.range(0, num_gpus - 1) + #"cuda_devices": std.repeat([-1], num_gpus) # Use this for debugging on CPU + }, + "trainer": { + "optimizer": { + "type": "huggingface_adamw", + "lr": 2e-5, + "weight_decay": 0.01, + }, + "learning_rate_scheduler": { + "type": "linear_with_warmup", + "warmup_steps" : std.ceil(0.1 * num_instances * num_epochs * num_gradient_accumulation_steps / effective_batch_size) + // "warmup_steps": 5000 + }, + "num_gradient_accumulation_steps": num_gradient_accumulation_steps, + "validation_metric": "+accuracy", + "num_epochs": num_epochs, + "patience": patience + }, +}