-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #111 Implemented functional and class based Word Information Lost metric and added tests. Reviewed By: ninginthecloud Differential Revision: D41942828 fbshipit-source-id: 18befe96d893f3074972be6fc5356f2f000e76e3
- Loading branch information
1 parent
7f63371
commit e80aa1f
Showing
9 changed files
with
297 additions
and
7 deletions.
There are no files selected for viewing
42 changes: 42 additions & 0 deletions
42
tests/metrics/functional/text/test_word_information_lost.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from torcheval.metrics.functional import word_information_lost | ||
|
||
|
||
class TestWordInformationLost(unittest.TestCase):
    """Tests for the functional ``word_information_lost`` metric."""

    def test_word_information_lost(self) -> None:
        """Score two small batches and compare against the expected values."""
        # Each case is (hypotheses, references, expected score).
        cases = (
            (
                ["hello world", "welcome to the facebook"],
                ["hello metaverse", "welcome to meta"],
                0.7,
            ),
            (
                ["this is the prediction", "there is an other sample"],
                ["this is the reference", "there is another one"],
                0.6527777,
            ),
        )
        for hypotheses, references, expected in cases:
            torch.testing.assert_close(
                word_information_lost(hypotheses, references),
                torch.tensor(expected, dtype=torch.float64),
            )

    def test_word_information_lost_with_invalid_input(self) -> None:
        """A different number of hypotheses and references must be rejected."""
        two_sentences = ["hello metaverse", "welcome to meta"]
        three_sentences = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]
        expected_message = "Arguments must contain the same number of strings."
        with self.assertRaisesRegex(AssertionError, expected_message):
            word_information_lost(two_sentences, three_sentences)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from torcheval.metrics.text import WordInformationLost | ||
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester | ||
|
||
|
||
class TestWordInformationLost(MetricClassTester):
    """Tests for the class-based ``WordInformationLost`` metric."""

    def test_word_information_lost(self) -> None:
        """Feed the same batch on every update and check the aggregated score."""
        update_count = 4
        hypotheses = ["hello world", "welcome to the facebook"]
        references = ["hello metaverse", "welcome to meta"]
        self.run_class_implementation_tests(
            metric=WordInformationLost(),
            state_names={"correct_total", "target_total", "preds_total"},
            update_kwargs={
                # One fresh list per update so no list object is shared.
                "input": [list(hypotheses) for _ in range(update_count)],
                "target": [list(references) for _ in range(update_count)],
            },
            compute_result=torch.tensor(0.7, dtype=torch.float64),
            num_total_updates=update_count,
        )

    def test_word_information_lost_with_invalid_input(self) -> None:
        """``update`` must reject inputs of different lengths."""
        metric = WordInformationLost()
        two_sentences = ["hello metaverse", "welcome to meta"]
        three_sentences = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]
        expected_message = "Arguments must contain the same number of strings."
        with self.assertRaisesRegex(AssertionError, expected_message):
            metric.update(two_sentences, three_sentences)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
torcheval/metrics/functional/text/word_information_lost.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List, Tuple, Union | ||
|
||
import torch | ||
|
||
from torcheval.metrics.functional.text.helper import _get_errors_and_totals | ||
|
||
|
||
def _wil_update(
    input: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the batch statistics used by the Word Information Lost score.

    Args:
        input: Transcription(s) to score as a string or list of strings
        target: Reference(s) for each speech input as a string or list of strings

    Returns:
        A tuple ``(correct_total, target_total, input_total)``:

        - ``correct_total``: the correct-word term, computed as
          ``errors - max_total``. NOTE(review): assuming ``errors`` never
          exceeds ``max_total`` this is the *negated* number of correct
          words; ``_wil_compute`` only uses its square, so the sign does not
          affect the score -- confirm against ``_get_errors_and_totals``.
        - ``target_total``: number of words over all references
        - ``input_total``: number of words over all predictions

    Raises:
        AssertionError: if ``input`` and ``target`` do not contain the same
            number of strings.
    """
    if isinstance(input, str):
        input = [input]
    if isinstance(target, str):
        target = [target]
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let mismatched inputs through.
    # The exception type and message are kept for existing callers.
    if len(input) != len(target):
        raise AssertionError(
            f"Arguments must contain the same number of strings, but got len(input)={len(input)} and len(target)={len(target)}"
        )
    errors, max_total, target_total, input_total = _get_errors_and_totals(input, target)
    return errors - max_total, target_total, input_total
|
||
|
||
def _wil_compute(
    correct_total: torch.Tensor, target_total: torch.Tensor, preds_total: torch.Tensor
) -> torch.Tensor:
    """Turn the accumulated word counts into a Word Information Lost score.

    The score is ``1 - (C / N) * (C / P)``. ``correct_total`` appears in both
    factors, so only its square influences the result.

    Args:
        correct_total: Number of correct words (``C``)
        target_total: Number of words overall references (``N``)
        preds_total: Number of words overall prediction (``P``)

    Returns:
        Word Information Lost score
    """
    share_of_target = correct_total / target_total
    share_of_preds = correct_total / preds_total
    return 1 - (share_of_target * share_of_preds)
|
||
|
||
@torch.inference_mode()
def word_information_lost(
    input: Union[str, List[str]],
    target: Union[str, List[str]],
) -> torch.Tensor:
    """Word Information Lost rate is a metric of the performance of an automatic speech recognition system.

    The value indicates the proportion of words that were incorrectly predicted
    between the reference(s) and the transcription(s). The lower the value, the
    better the performance of the ASR system, with a Word Information Lost rate
    of 0 being a perfect score.
    Its class version is ``torcheval.metrics.WordInformationLost``.

    Args:
        input: Transcription(s) to score as a string or list of strings
        target: Reference(s) for each speech input as a string or list of strings

    Returns:
        Word Information Lost rate

    Examples:
        >>> from torcheval.metrics.functional import word_information_lost
        >>> input = ["this is the prediction", "there is an other sample"]
        >>> target = ["this is the reference", "there is another one"]
        >>> word_information_lost(input, target)
        tensor(0.6528, dtype=torch.float64)
    """
    # ``_wil_update`` yields (correct_total, target_total, preds_total), which
    # is exactly the positional order ``_wil_compute`` expects.
    return _wil_compute(*_wil_update(input, target))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-ignore-all-errors[16]: Undefined attribute of metric states. | ||
|
||
from typing import Iterable, List, Optional, TypeVar, Union | ||
|
||
import torch | ||
|
||
from torcheval.metrics.functional.text.word_information_lost import ( | ||
_wil_compute, | ||
_wil_update, | ||
) | ||
|
||
from torcheval.metrics.metric import Metric | ||
|
||
# Type variable used to annotate ``self`` and the fluent return values below.
TWordInformationLost = TypeVar("TWordInformationLost")


class WordInformationLost(Metric[torch.Tensor]):
    r"""Word Information Lost (WIL) is a metric of the performance of an automatic speech recognition system.

    This value indicates the percentage of words that were incorrectly predicted
    between a set of ground-truth sentences and a set of hypothesis sentences.
    The lower the value, the better the performance of the ASR system, with a
    WordInformationLost of 0 being a perfect score. Word Information Lost rate
    can then be computed as:

    .. math::
        wil = 1 - \frac{C}{N} * \frac{C}{P}

    where:

    - :math:`C` is the number of correct words,
    - :math:`N` is the number of words in the reference
    - :math:`P` is the number of words in the prediction

    Its functional version is :func:`torcheval.metrics.functional.word_information_lost`.

    Examples:
        >>> from torcheval.metrics.text import WordInformationLost
        >>> preds = ["this is the prediction", "there is an other sample"]
        >>> target = ["this is the reference", "there is another one"]
        >>> metric = WordInformationLost()
        >>> metric.update(preds, target).compute()
        tensor(0.6528, dtype=torch.float64)
    """

    def __init__(
        self: TWordInformationLost,
        device: Optional[torch.device] = None,
    ) -> None:
        """Create the metric with all three running totals set to zero.

        Args:
            device: Device on which the state tensors are kept; passed through
                to the ``Metric`` base class.
        """
        super().__init__(device=device)
        # All three states are float64 scalars that accumulate across updates.
        self._add_state(
            "correct_total", torch.tensor(0.0, dtype=torch.float64, device=self.device)
        )
        self._add_state(
            "target_total", torch.tensor(0.0, dtype=torch.float64, device=self.device)
        )
        self._add_state(
            "preds_total", torch.tensor(0.0, dtype=torch.float64, device=self.device)
        )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TWordInformationLost,
        input: Union[str, List[str]],
        target: Union[str, List[str]],
    ) -> TWordInformationLost:
        """Accumulate the statistics of a batch of predictions/references.

        Args:
            input: Transcription(s) to score as a string or list of strings
            target: Reference(s) for each speech input as a string or list of strings

        Returns:
            The metric itself, so calls can be chained.
        """
        correct_total, target_total, preds_total = _wil_update(input, target)
        # Move the batch statistics to the metric's device before adding them.
        self.correct_total += correct_total.to(self.device)
        self.target_total += target_total.to(self.device)
        self.preds_total += preds_total.to(self.device)
        return self

    @torch.inference_mode()
    def compute(self: TWordInformationLost) -> torch.Tensor:
        """Calculate the Word Information Lost from the accumulated totals.

        Returns:
            Word Information Lost score
        """
        return _wil_compute(self.correct_total, self.target_total, self.preds_total)

    @torch.inference_mode()
    def merge_state(
        self: TWordInformationLost,
        metrics: Iterable[TWordInformationLost],
    ) -> TWordInformationLost:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.

        Returns:
            The metric itself, with the other instances' totals added in.
        """
        for metric in metrics:
            # The totals are plain sums, so merging is element-wise addition.
            self.correct_total += metric.correct_total.to(self.device)
            self.target_total += metric.target_total.to(self.device)
            self.preds_total += metric.preds_total.to(self.device)
        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters