Implement BLEU metric (class) (#95)
Summary:
Pull Request resolved: #95

Implemented BLEU class metric.

Reviewed By: ninginthecloud

Differential Revision: D40813045

fbshipit-source-id: 6c9b1873f9bed9227cc2b67d7bc4d22a0805f331
Erika Lal authored and facebook-github-bot committed Dec 27, 2022
1 parent d2401f1 commit 911d8d0
Showing 4 changed files with 249 additions and 0 deletions.
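Before the diff itself, a minimal usage sketch of the new class (taken from the docstring and tests added below; the printed value is the docstring's expected output):

import torch
from torcheval.metrics import BLEUScore

# One update with a batch of candidate translations, each paired with a
# list of reference translations; compute() returns the corpus-level score.
metric = BLEUScore(n_gram=4)
candidates = ["the squirrel is eating the nut", "the cat is on the mat"]
references = [
    ["a squirrel is eating a nut", "the squirrel is eating a tasty nut"],
    ["there is a cat on the mat", "a cat is on the mat"],
]
metric.update(candidates, references)
print(metric.compute())  # tensor(0.65341892, dtype=torch.float64)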
107 changes: 107 additions & 0 deletions tests/metrics/text/test_bleu.py
@@ -0,0 +1,107 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torcheval.metrics.text import BLEUScore
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester


class TestBleu(MetricClassTester):
    def test_bleu_invalid_update(self) -> None:
        candidates = ["the squirrel is eating the nut"]
        references = [
            ["a squirrel is eating a nut", "the squirrel is eating a tasty nut"],
            ["there is a cat on the mat", "a cat is on the mat"],
        ]
        metric = BLEUScore(n_gram=4)
        with self.assertRaisesRegex(
            ValueError,
            "Input and target corpus should have same sizes",
        ):
            metric.update(candidates, references)

    def test_bleu_invalid_w(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            "the length of weights should equal n_gram",
        ):
            BLEUScore(n_gram=4, weights=torch.tensor([0.3, 0.3, 0.4]))

    def test_bleu_invalid_n(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            "n_gram should be 1, 2, 3, or 4",
        ):
            BLEUScore(n_gram=5)

    def test_bleu_single_example(self) -> None:
        candidate = ["the squirrel is eating the nut"]
        reference = [
            ["a squirrel is eating a nut", "the squirrel is eating a tasty nut"]
        ]
        metric = BLEUScore(n_gram=4)
        metric.update(candidate, reference)
        val = metric.compute()
        self.assertAlmostEqual(val.item(), 0.53728497)

    def test_bleu_multiple_updates(self) -> None:
        candidates = [["the squirrel is eating the nut"], ["the cat is on the mat"]]
        references = [
            [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"]],
            [["there is a cat on the mat", "a cat is on the mat"]],
        ]
        self.run_class_implementation_tests(
            metric=BLEUScore(n_gram=4),
            state_names={
                "input_len",
                "target_len",
                "matches_by_order",
                "possible_matches_by_order",
            },
            update_kwargs={
                "input": candidates,
                "target": references,
            },
            compute_result=torch.tensor(0.65341892, dtype=torch.float64),
            num_total_updates=2,
            num_processes=2,
        )

    def test_bleu_multiple_examples_per_update(self) -> None:
        candidates = [
            ["the squirrel is eating the nut", "the cat is on the mat"],
            ["i like ice cream and apple pie"],
        ]
        references = [
            [
                ["a squirrel is eating a nut", "the squirrel is eating a tasty nut"],
                ["there is a cat on the mat", "a cat is on the mat"],
            ],
            [
                [
                    "i like apple pie with ice cream on top",
                    "i like ice cream with my apple pie",
                    "i enjoy my apple pie with ice cream",
                ]
            ],
        ]
        self.run_class_implementation_tests(
            metric=BLEUScore(n_gram=4),
            state_names={
                "input_len",
                "target_len",
                "matches_by_order",
                "possible_matches_by_order",
            },
            update_kwargs={
                "input": candidates,
                "target": references,
            },
            compute_result=torch.tensor(0.56377503, dtype=torch.float64),
            num_total_updates=2,
            num_processes=2,
        )
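As a companion to test_bleu_invalid_w above, a short sketch of a valid custom weight distribution (the 0.6/0.4 split is an illustrative choice of mine, not from this diff); per the constructor, len(weights) must equal n_gram, and uniform weights are used when the argument is omitted:

import torch
from torcheval.metrics import BLEUScore

# Valid: two weights for n_gram=2. A mismatched length would raise the
# ValueError exercised in test_bleu_invalid_w.
metric = BLEUScore(n_gram=2, weights=torch.tensor([0.6, 0.4]))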
2 changes: 2 additions & 0 deletions torcheval/metrics/__init__.py
@@ -46,6 +46,7 @@
from torcheval.metrics.regression import MeanSquaredError, R2Score

from torcheval.metrics.text import (
    BLEUScore,
    Perplexity,
    WordErrorRate,
    WordInformationLost,
@@ -78,6 +79,7 @@
"BinaryPrecisionRecallCurve",
"BinaryRecall",
"BinaryRecallAtFixedPrecision",
"BLEUScore",
"Cat",
"ClickThroughRate",
"HitRate",
2 changes: 2 additions & 0 deletions torcheval/metrics/text/__init__.py
@@ -4,12 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torcheval.metrics.text.bleu import BLEUScore
from torcheval.metrics.text.perplexity import Perplexity
from torcheval.metrics.text.word_error_rate import WordErrorRate
from torcheval.metrics.text.word_information_lost import WordInformationLost
from torcheval.metrics.text.word_information_preserved import WordInformationPreserved

__all__ = [
    "BLEUScore",
    "Perplexity",
    "WordErrorRate",
    "WordInformationLost",
138 changes: 138 additions & 0 deletions torcheval/metrics/text/bleu.py
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, Sequence, TypeVar, Union

import torch
from torcheval.metrics.functional.text.bleu import (
    _bleu_score_compute,
    _bleu_score_update,
)

from torcheval.metrics.metric import Metric

TBLEUScore = TypeVar("TBLEUScore")


class BLEUScore(Metric[torch.Tensor]):
    """
    Compute BLEU score (https://en.wikipedia.org/wiki/BLEU) given translations and references.
    Its functional version is ``torcheval.metrics.functional.text.bleu``.

    Args:
        n_gram: Maximum n-gram to use when computing the BLEU score. Can be 1, 2, 3, or 4.
        weights: Optional weight distribution of n-grams. Requires len(weights) = n_gram. If unspecified, uniform weights are used.

    Examples:
        >>> import torch
        >>> from torcheval.metrics import BLEUScore
        >>> metric = BLEUScore(n_gram=4)
        >>> candidates = ["the squirrel is eating the nut", "the cat is on the mat"]
        >>> references = [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"], ["there is a cat on the mat", "a cat is on the mat"]]
        >>> metric.update(candidates, references)
        >>> metric.compute()
        tensor(0.65341892)
        >>> candidates = ["i like ice cream and apple pie"]
        >>> references = [["i like apple pie with ice cream on top", "i like ice cream with my apple pie", "i enjoy my apple pie with ice cream"]]
        >>> metric.update(candidates, references)
        >>> metric.compute()
        tensor(0.56377503)
    """

    def __init__(
        self: TBLEUScore,
        *,
        n_gram: int,
        weights: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)

        if n_gram not in [1, 2, 3, 4]:
            raise ValueError(f"n_gram should be 1, 2, 3, or 4, got {n_gram}.")
        if weights is not None and n_gram != len(weights):
            raise ValueError(
                f"the length of weights should equal n_gram, got len(weights)={len(weights)}, n_gram={n_gram}"
            )

        self.weights = weights
        self.n_gram = n_gram
        self._add_state(
            "input_len", torch.tensor(0.0, dtype=torch.float64, device=device)
        )
        self._add_state(
            "target_len", torch.tensor(0.0, dtype=torch.float64, device=device)
        )
        self._add_state(
            "matches_by_order",
            torch.zeros(n_gram, dtype=torch.float64, device=device),
        )
        self._add_state(
            "possible_matches_by_order",
            torch.zeros(n_gram, dtype=torch.float64, device=device),
        )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TBLEUScore,
        input: Union[str, Sequence[str]],
        target: Sequence[Union[str, Sequence[str]]],
    ) -> TBLEUScore:
        """
        Update the metric state with new inputs.

        Args:
            input: Translations to score.
            target: List of references for each translation.
        """
        (
            input_len,
            target_len,
            matches_by_order,
            possible_matches_by_order,
        ) = _bleu_score_update(input, target, self.n_gram, self.device)
        self.input_len += input_len
        self.target_len += target_len
        self.matches_by_order += matches_by_order
        self.possible_matches_by_order += possible_matches_by_order
        return self

    @torch.inference_mode()
    def compute(self: TBLEUScore) -> torch.Tensor:
        """
        Return the running BLEU score. If no ``update()`` calls have been made
        before ``compute()`` is called, return ``tensor(0.0)``.
        """
        if torch.sum(self.matches_by_order) == 0:
            return torch.tensor(0.0, dtype=torch.float64, device=self.device)
        return _bleu_score_compute(
            self.input_len,
            self.target_len,
            self.matches_by_order,
            self.possible_matches_by_order,
            self.n_gram,
            self.weights,
        )

    @torch.inference_mode()
    def merge_state(self: TBLEUScore, metrics: Iterable[TBLEUScore]) -> TBLEUScore:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for metric in metrics:
            self.input_len += metric.input_len.to(self.device)
            self.target_len += metric.target_len.to(self.device)
            self.matches_by_order += metric.matches_by_order.to(self.device)
            self.possible_matches_by_order += metric.possible_matches_by_order.to(
                self.device
            )
        return self
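To illustrate merge_state, a sketch that merges two independently updated instances (the instance names are mine; the expected merged score follows from test_bleu_multiple_updates above):

import torch
from torcheval.metrics.text import BLEUScore

# Two instances updated on disjoint shards of the corpus, e.g. one per process.
metric_a = BLEUScore(n_gram=4)
metric_a.update(
    ["the squirrel is eating the nut"],
    [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"]],
)
metric_b = BLEUScore(n_gram=4)
metric_b.update(
    ["the cat is on the mat"],
    [["there is a cat on the mat", "a cat is on the mat"]],
)

# Fold metric_b's counters into metric_a; the merged result should match a
# single metric updated on both shards: tensor(0.65341892, dtype=torch.float64).
metric_a.merge_state([metric_b])
print(metric_a.compute())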
