-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #95 Implemented BLEU class metric. Reviewed By: ninginthecloud Differential Revision: D40813045 fbshipit-source-id: 6c9b1873f9bed9227cc2b67d7bc4d22a0805f331
- Loading branch information
1 parent
d2401f1
commit 911d8d0
Showing
4 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env fbpython | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from torcheval.metrics.text import BLEUScore | ||
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester | ||
|
||
|
||
class TestBleu(MetricClassTester): | ||
def test_bleu_invalid_update(self) -> None: | ||
candidates = ["the squirrel is eating the nut"] | ||
references = [ | ||
["a squirrel is eating a nut", "the squirrel is eating a tasty nut"], | ||
["there is a cat on the mat", "a cat is on the mat"], | ||
] | ||
metric = BLEUScore(n_gram=4) | ||
with self.assertRaisesRegex( | ||
ValueError, | ||
"Input and target corpus should have same sizes", | ||
): | ||
metric.update(candidates, references) | ||
|
||
def test_bleu_invalid_w(self) -> None: | ||
with self.assertRaisesRegex( | ||
ValueError, | ||
"the length of weights should equal n_gram", | ||
): | ||
BLEUScore(n_gram=4, weights=torch.tensor([0.3, 0.3, 0.4])) | ||
|
||
def test_bleu_invalid_n(self) -> None: | ||
with self.assertRaisesRegex( | ||
ValueError, | ||
"n_gram should be 1, 2, 3, or 4", | ||
): | ||
BLEUScore(n_gram=5) | ||
|
||
def test_bleu_single_example(self) -> None: | ||
candidate = ["the squirrel is eating the nut"] | ||
reference = [ | ||
["a squirrel is eating a nut", "the squirrel is eating a tasty nut"] | ||
] | ||
metric = BLEUScore(n_gram=4) | ||
metric.update(candidate, reference) | ||
val = metric.compute() | ||
self.assertAlmostEqual(val.item(), 0.53728497) | ||
|
||
def test_bleu_multiple_updates(self) -> None: | ||
candidates = [["the squirrel is eating the nut"], ["the cat is on the mat"]] | ||
references = [ | ||
[["a squirrel is eating a nut", "the squirrel is eating a tasty nut"]], | ||
[["there is a cat on the mat", "a cat is on the mat"]], | ||
] | ||
self.run_class_implementation_tests( | ||
metric=BLEUScore(n_gram=4), | ||
state_names={ | ||
"input_len", | ||
"target_len", | ||
"matches_by_order", | ||
"possible_matches_by_order", | ||
}, | ||
update_kwargs={ | ||
"input": candidates, | ||
"target": references, | ||
}, | ||
compute_result=torch.tensor(0.65341892, dtype=torch.float64), | ||
num_total_updates=2, | ||
num_processes=2, | ||
) | ||
|
||
def test_bleu_multiple_examples_per_update(self) -> None: | ||
candidates = [ | ||
["the squirrel is eating the nut", "the cat is on the mat"], | ||
["i like ice cream and apple pie"], | ||
] | ||
references = [ | ||
[ | ||
["a squirrel is eating a nut", "the squirrel is eating a tasty nut"], | ||
["there is a cat on the mat", "a cat is on the mat"], | ||
], | ||
[ | ||
[ | ||
"i like apple pie with ice cream on top", | ||
"i like ice cream with my apple pie", | ||
"i enjoy my apple pie with ice cream", | ||
] | ||
], | ||
] | ||
self.run_class_implementation_tests( | ||
metric=BLEUScore(n_gram=4), | ||
state_names={ | ||
"input_len", | ||
"target_len", | ||
"matches_by_order", | ||
"possible_matches_by_order", | ||
}, | ||
update_kwargs={ | ||
"input": candidates, | ||
"target": references, | ||
}, | ||
compute_result=torch.tensor(0.56377503, dtype=torch.float64), | ||
num_total_updates=2, | ||
num_processes=2, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-ignore-all-errors[16]: Undefined attribute of metric states. | ||
|
||
from typing import Iterable, Optional, Sequence, TypeVar, Union | ||
|
||
import torch | ||
from torcheval.metrics.functional.text.bleu import ( | ||
_bleu_score_compute, | ||
_bleu_score_update, | ||
) | ||
|
||
from torcheval.metrics.metric import Metric | ||
|
||
TBLEUScore = TypeVar("TBLEUScore") | ||
|
||
|
||
class BLEUScore(Metric[torch.Tensor]): | ||
""" | ||
Compute BLEU score (https://en.wikipedia.org/wiki/BLEU) given translations and references. | ||
Its functional version is ``torcheval.metrics.functional.text.bleu``. | ||
Args: | ||
n_gram: Maximum n-gram to use when computing BLEU score. Can be 1, 2, 3, or 4. | ||
weights: Optional weight distribution of n-grams. Requires len(weights) = n_gram. If unspecified, will use uniform weights. | ||
Examples: | ||
>>> import torch | ||
>>> from torcheval.metrics import BLEUScore | ||
>>> metric = BLEUScore(n_gram=4) | ||
>>> candidates = ["the squirrel is eating the nut", "the cat is on the mat"] | ||
>>> references = [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"], ["there is a cat on the mat", "a cat is on the mat"]] | ||
>>> metric.update(candidates, references) | ||
>>> metric.compute() | ||
tensor(0.65341892) | ||
>>> candidates = ["i like ice cream and apple pie"] | ||
>>> references = [["i like apple pie with ice cream on top", "i like ice cream with my apple pie", "i enjoy my apple pie with ice cream"]] | ||
>>> metric.update(candidates, references) | ||
>>> metric.compute() | ||
tensor([0.56377503]) | ||
""" | ||
|
||
def __init__( | ||
self: TBLEUScore, | ||
*, | ||
n_gram: int, | ||
weights: Optional[torch.Tensor] = None, | ||
device: Optional[torch.device] = None, | ||
) -> None: | ||
super().__init__(device=device) | ||
|
||
if n_gram not in [1, 2, 3, 4]: | ||
raise ValueError(f"n_gram should be 1, 2, 3, or 4, got {n_gram}.") | ||
if weights is not None and n_gram != len(weights): | ||
raise ValueError( | ||
f"the length of weights should equal n_gram, got len(weights)={len(weights)}, n_gram={n_gram}" | ||
) | ||
|
||
self.weights = weights | ||
self.n_gram = n_gram | ||
self._add_state( | ||
"input_len", torch.tensor(0.0, dtype=torch.float64, device=device) | ||
) | ||
self._add_state( | ||
"target_len", torch.tensor(0.0, dtype=torch.float64, device=device) | ||
) | ||
self._add_state( | ||
"matches_by_order", | ||
torch.zeros(n_gram, dtype=torch.float64, device=device), | ||
) | ||
self._add_state( | ||
"possible_matches_by_order", | ||
torch.zeros(n_gram, dtype=torch.float64, device=device), | ||
) | ||
|
||
@torch.inference_mode() | ||
# pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently. | ||
def update( | ||
self: TBLEUScore, | ||
input: Union[str, Sequence[str]], | ||
target: Sequence[Union[str, Sequence[str]]], | ||
) -> TBLEUScore: | ||
""" | ||
Update the metric state with new inputs. | ||
Args: | ||
input: Translations to score. | ||
target: List of references for each translation. | ||
""" | ||
( | ||
input_len, | ||
target_len, | ||
matches_by_order, | ||
possible_matches_by_order, | ||
) = _bleu_score_update(input, target, self.n_gram, self.device) | ||
self.input_len += input_len | ||
self.target_len += target_len | ||
self.matches_by_order += matches_by_order | ||
self.possible_matches_by_order += possible_matches_by_order | ||
return self | ||
|
||
@torch.inference_mode() | ||
def compute(self: TBLEUScore) -> torch.Tensor: | ||
""" | ||
Returns the running BLEUScore. If no ``update()`` calls are made before | ||
``compute()`` is called, return tensor(0.0). | ||
""" | ||
if torch.sum(self.matches_by_order) == 0: | ||
return torch.tensor(0.0, dtype=torch.float64, device=self.device) | ||
return _bleu_score_compute( | ||
self.input_len, | ||
self.target_len, | ||
self.matches_by_order, | ||
self.possible_matches_by_order, | ||
self.n_gram, | ||
self.weights, | ||
) | ||
|
||
@torch.inference_mode() | ||
def merge_state(self: TBLEUScore, metrics: Iterable[TBLEUScore]) -> TBLEUScore: | ||
""" | ||
Merge the metric state with its counterparts from other metric instances. | ||
Args: | ||
metrics (Iterable[Metric]): metric instances whose states are to be merged. | ||
""" | ||
for metric in metrics: | ||
self.input_len += metric.input_len.to(self.device) | ||
self.target_len += metric.target_len.to(self.device) | ||
self.matches_by_order += metric.matches_by_order.to(self.device) | ||
self.possible_matches_by_order += metric.possible_matches_by_order.to( | ||
self.device | ||
) | ||
return self |