Distributed VQA metric (#250)
* Fixes the VQA metric in the distributed case

* Adds a test for the VQA metric

* Changelog

* Set up devices properly

* Use the new number_of_runs parameter

* Productivity through formatting

* Make sure data types align

* Fix the test

There are _multiple_ labels per instance. That's the whole point of this metric.
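
To make that concrete: a minimal usage sketch, not part of the commit, showing how `VqaMeasure` consumes several weighted labels for a single instance. The logits, labels, and weights here are invented; the import path matches the tests below.

import torch

from allennlp_models.vision import VqaMeasure

vqa = VqaMeasure()
logits = torch.tensor([[0.2, 0.5, 0.3]])    # one instance, three answer classes; argmax picks class 1
labels = torch.tensor([[1, 2]])             # two acceptable answers for the same instance
label_weights = torch.tensor([[0.9, 0.3]])  # a VQA-style soft score for each answer
vqa(logits, labels, label_weights)
print(vqa.get_metric()["score"])  # 0.9: the predicted answer contributes its label's weight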
dirkgr authored Apr 15, 2021
1 parent 54b332e commit 419bc90
Showing 3 changed files with 139 additions and 26 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md

@@ -5,10 +5,19 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 
 ## Unreleased
 
-## [v2.3.0](https://github.com/allenai/allennlp-models/releases/tag/v2.3.0) - 2021-04-14
+### Fixed
+
+- VQA metric now calculates correctly in the distributed case
+
+### Added
+
+- Tests for the VQA metric
+
+
+## [v2.3.0](https://github.com/allenai/allennlp-models/releases/tag/v2.3.0) - 2021-04-14
 
 ### Fixed
46 changes: 21 additions & 25 deletions allennlp_models/vision/metrics/vqa.py

@@ -1,5 +1,3 @@
-from typing import Union
-
 import torch
 from overrides import overrides
 
@@ -21,8 +19,8 @@ class VqaMeasure(Metric):
     """
 
     def __init__(self) -> None:
-        self._sum_of_scores: Union[None, torch.Tensor] = None
-        self._score_count: Union[None, torch.Tensor] = None
+        self._sum_of_scores = 0.0
+        self._score_count = 0
 
     @overrides
     def __call__(self, logits: torch.Tensor, labels: torch.Tensor, label_weights: torch.Tensor):
@@ -38,39 +36,37 @@ def __call__(self, logits: torch.Tensor, labels: torch.Tensor, label_weights: torch.Tensor):
             every one of the labels.
         """
 
-        device = logits.device
-
-        if self._sum_of_scores is None:
-            self._sum_of_scores = torch.zeros([], device=device, dtype=label_weights.dtype)
-        if self._score_count is None:
-            self._score_count = torch.zeros([], device=device, dtype=torch.int32)
-
         logits, labels, label_weights = self.detach_tensors(logits, labels, label_weights)
         predictions = logits.argmax(dim=1)
 
         # Sum over dimension 1 gives the score per question. We care about the overall sum though,
        # so we sum over all dimensions.
-        self._sum_of_scores += (label_weights * (labels == predictions.unsqueeze(-1))).sum()
-        self._score_count += labels.size(0)
+        local_sum_of_scores = (
+            (label_weights * (labels == predictions.unsqueeze(-1))).sum().to(torch.float32)
+        )
+        local_score_count = torch.tensor(labels.size(0), dtype=torch.int32, device=labels.device)
 
         from allennlp.common.util import is_distributed
 
         if is_distributed():
-            dist.all_reduce(self._sum_of_scores, op=dist.ReduceOp.SUM)
-            dist.all_reduce(self._score_count, op=dist.ReduceOp.SUM)
+            dist.all_reduce(local_sum_of_scores, op=dist.ReduceOp.SUM)
+            dist.all_reduce(local_score_count, op=dist.ReduceOp.SUM)
+
+        self._sum_of_scores += local_sum_of_scores.item()
+        self._score_count += local_score_count.item()
 
     @overrides
     def get_metric(self, reset: bool = False):
-        """
-        # Returns
-
-        score : `float`
-        """
-        from allennlp.common.util import nan_safe_tensor_divide
-
-        return {"score": nan_safe_tensor_divide(self._sum_of_scores, self._score_count).item()}
+        if self._score_count > 0:
+            result = self._sum_of_scores / self._score_count
+        else:
+            result = 0.0
+        result_dict = {"score": result}
+        if reset:
+            self.reset()
+        return result_dict
 
     @overrides
     def reset(self) -> None:
-        self._sum_of_scores = None
-        self._score_count = None
+        self._sum_of_scores = 0.0
+        self._score_count = 0
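
The crux of the fix, as a sketch: the old code ran `all_reduce` on the running accumulators themselves, so sums from earlier batches were re-added across workers on every call. Below is a minimal single-process simulation of that failure mode, assuming two workers and a SUM all-reduce; the `all_reduce_sum` helper and the numbers are invented for illustration.

def all_reduce_sum(values):
    # Simulates dist.all_reduce(op=SUM): every worker ends up with the global sum.
    total = sum(values)
    return [total for _ in values]

# Old (buggy) pattern: all-reduce the accumulator itself on every call.
accumulators = [0.0, 0.0]  # one running total per worker
for batch_scores in ([1.0, 2.0], [3.0, 4.0]):  # two successive batches
    accumulators = [a + s for a, s in zip(accumulators, batch_scores)]
    accumulators = all_reduce_sum(accumulators)  # re-sums totals that are already global
print(accumulators[0])  # 13.0, although all the scores only add up to 10.0

# New pattern: all-reduce only the local batch statistic, then accumulate.
accumulators = [0.0, 0.0]
for batch_scores in ([1.0, 2.0], [3.0, 4.0]):
    reduced = all_reduce_sum(batch_scores)
    accumulators = [a + r for a, r in zip(accumulators, reduced)]
print(accumulators[0])  # 10.0, the correct global total

Reducing the per-batch statistic and adding the already-global result to a plain Python accumulator, as the new `__call__` does, counts every batch exactly once and keeps `get_metric` free of collective calls.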
108 changes: 108 additions & 0 deletions tests/vision/metrics/vqa_test.py

@@ -0,0 +1,108 @@
+from typing import Any, Dict, List, Tuple, Union
+
+import pytest
+import torch
+from allennlp.common.testing import (
+    AllenNlpTestCase,
+    multi_device,
+    global_distributed_metric,
+    run_distributed_test,
+)
+
+from allennlp_models.vision import VqaMeasure
+
+
+class VqaMeasureTest(AllenNlpTestCase):
+    @multi_device
+    def test_vqa(self, device: str):
+        vqa = VqaMeasure()
+        logits = torch.tensor(
+            [[0.35, 0.25, 0.1, 0.1, 0.2], [0.1, 0.6, 0.1, 0.2, 0.0]], device=device
+        )
+        labels = torch.tensor([[0], [3]], device=device)
+        label_weights = torch.tensor([[1 / 3], [2 / 3]], device=device)
+        vqa(logits, labels, label_weights)
+        vqa_score = vqa.get_metric()["score"]
+        assert vqa_score == pytest.approx((1 / 3) / 2)
+
+    @multi_device
+    def test_vqa_accumulates_and_resets_correctly(self, device: str):
+        vqa = VqaMeasure()
+        logits = torch.tensor(
+            [[0.35, 0.25, 0.1, 0.1, 0.2], [0.1, 0.6, 0.1, 0.2, 0.0]], device=device
+        )
+        labels = torch.tensor([[0], [3]], device=device)
+        labels2 = torch.tensor([[4], [4]], device=device)
+        label_weights = torch.tensor([[1 / 3], [2 / 3]], device=device)
+
+        vqa(logits, labels, label_weights)
+        vqa(logits, labels, label_weights)
+        vqa(logits, labels2, label_weights)
+        vqa(logits, labels2, label_weights)
+
+        vqa_score = vqa.get_metric(reset=True)["score"]
+        assert vqa_score == pytest.approx((1 / 3 + 1 / 3 + 0 + 0) / 8)
+
+        vqa(logits, labels, label_weights)
+        vqa_score = vqa.get_metric(reset=True)["score"]
+        assert vqa_score == pytest.approx((1 / 3) / 2)
+
+    @multi_device
+    def test_does_not_divide_by_zero_with_no_count(self, device: str):
+        vqa = VqaMeasure()
+        assert vqa.get_metric()["score"] == pytest.approx(0.0)
+
+    def test_distributed_accuracy(self):
+        logits = [
+            torch.tensor([[0.35, 0.25, 0.1, 0.1, 0.2]]),
+            torch.tensor([[0.1, 0.6, 0.1, 0.2, 0.0]]),
+        ]
+        labels = [torch.tensor([[0]]), torch.tensor([[3]])]
+        label_weights = [torch.tensor([[1 / 3]]), torch.tensor([[2 / 3]])]
+        metric_kwargs = {"logits": logits, "labels": labels, "label_weights": label_weights}
+        desired_accuracy = {"score": (1 / 3) / 2}
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            VqaMeasure(),
+            metric_kwargs,
+            desired_accuracy,
+            exact=False,
+        )
+
+    def test_distributed_accuracy_unequal_batches(self):
+        logits = [
+            torch.tensor([[0.35, 0.25, 0.1, 0.1, 0.2], [0.35, 0.25, 0.1, 0.1, 0.2]]),
+            torch.tensor([[0.1, 0.6, 0.1, 0.2, 0.0]]),
+        ]
+        labels = [torch.tensor([[0], [0]]), torch.tensor([[3]])]
+        label_weights = [torch.tensor([[1], [1]]), torch.tensor([[1 / 3]])]
+        metric_kwargs = {"logits": logits, "labels": labels, "label_weights": label_weights}
+        desired_accuracy = {"score": (1 + 1 + 0) / 3}
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            VqaMeasure(),
+            metric_kwargs,
+            desired_accuracy,
+            exact=False,
+        )
+
+    def test_multiple_distributed_runs(self):
+        logits = [
+            torch.tensor([[0.35, 0.25, 0.1, 0.1, 0.2]]),
+            torch.tensor([[0.1, 0.6, 0.1, 0.2, 0.0]]),
+        ]
+        labels = [torch.tensor([[0]]), torch.tensor([[3]])]
+        label_weights = [torch.tensor([[1 / 3]]), torch.tensor([[2 / 3]])]
+        metric_kwargs = {"logits": logits, "labels": labels, "label_weights": label_weights}
+        desired_accuracy = {"score": (1 / 3) / 2}
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_metric,
+            VqaMeasure(),
+            metric_kwargs,
+            desired_accuracy,
+            exact=True,
+            number_of_runs=200,
+        )
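
To verify locally, the new module should run on its own in an environment with `allennlp`, `allennlp-models`, and `pytest` installed; a hypothetical one-liner, not part of the commit (the distributed tests spawn their own worker processes via `run_distributed_test`):

import pytest

raise SystemExit(pytest.main(["tests/vision/metrics/vqa_test.py", "-v"]))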
