Skip to content

Commit

Permalink
Implement WIL metric
Browse files Browse the repository at this point in the history
Summary: Implemented functional and class based Word Information Loss metric and added tests.

Reviewed By: ninginthecloud

Differential Revision: D41942828

fbshipit-source-id: 9b1fefce0693344450af80d139ceb3fe5c56279a
  • Loading branch information
andreasfloros authored and facebook-github-bot committed Dec 16, 2022
1 parent 02f7ede commit 94de3e3
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 3 deletions.
42 changes: 42 additions & 0 deletions tests/metrics/functional/text/test_word_information_lost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torcheval.metrics.functional import word_information_lost


class TestWordInformationLost(unittest.TestCase):
def test_word_information_lost(self) -> None:

input = ["hello world", "welcome to the facebook"]
target = ["hello metaverse", "welcome to meta"]
torch.testing.assert_close(
word_information_lost(input, target),
torch.tensor(0.7, dtype=torch.float64),
)

input = ["this is the prediction", "there is an other sample"]
target = ["this is the reference", "there is another one"]
torch.testing.assert_close(
word_information_lost(input, target),
torch.tensor(0.6527777, dtype=torch.float64),
)

def test_word_information_lost_with_invalid_input(self) -> None:
with self.assertRaisesRegex(
AssertionError,
"Arguments must contain the same number of strings.",
):
word_information_lost(
["hello metaverse", "welcome to meta"],
[
"welcome to meta",
"this is the prediction",
"there is an other sample",
],
)
49 changes: 49 additions & 0 deletions tests/metrics/text/test_word_information_lost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torcheval.metrics.text import WordInformationLost
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester


class TestWordInformationLost(MetricClassTester):
def test_word_information_lost(self) -> None:
self.run_class_implementation_tests(
metric=WordInformationLost(),
state_names={"correct_total", "target_total", "preds_total"},
update_kwargs={
"input": [
["hello world", "welcome to the facebook"],
["hello world", "welcome to the facebook"],
["hello world", "welcome to the facebook"],
["hello world", "welcome to the facebook"],
],
"target": [
["hello metaverse", "welcome to meta"],
["hello metaverse", "welcome to meta"],
["hello metaverse", "welcome to meta"],
["hello metaverse", "welcome to meta"],
],
},
compute_result=torch.tensor(0.7, dtype=torch.float64),
num_total_updates=4,
)

def test_word_information_lost_with_invalid_input(self) -> None:
metric = WordInformationLost()

with self.assertRaisesRegex(
AssertionError,
"Arguments must contain the same number of strings.",
):
metric.update(
["hello metaverse", "welcome to meta"],
[
"welcome to meta",
"this is the prediction",
"there is an other sample",
],
)
3 changes: 2 additions & 1 deletion torcheval/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
from torcheval.metrics.functional.text import (
perplexity,
word_error_rate,
word_information_lost,
word_information_preserved,
)


__all__ = [
"auc",
"binary_auprc",
Expand Down Expand Up @@ -89,4 +89,5 @@
"weighted_calibration",
"word_error_rate",
"word_information_preserved",
"word_information_lost",
]
10 changes: 9 additions & 1 deletion torcheval/metrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@

from torcheval.metrics.functional.text.perplexity import perplexity
from torcheval.metrics.functional.text.word_error_rate import word_error_rate
from torcheval.metrics.functional.text.word_information_lost import (
word_information_lost,
)
from torcheval.metrics.functional.text.word_information_preserved import (
word_information_preserved,
)

__all__ = ["perplexity", "word_error_rate", "word_information_preserved"]
__all__ = [
"perplexity",
"word_error_rate",
"word_information_preserved",
"word_information_lost",
]
76 changes: 76 additions & 0 deletions torcheval/metrics/functional/text/word_information_lost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple, Union

import torch

from torcheval.metrics.functional.text.helper import _get_errors_and_totals


def _wil_update(
input: Union[str, List[str]],
target: Union[str, List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Update the wil score with the current set of references and predictions.
Args:
input: Transcription(s) to score as a string or list of strings
target: Reference(s) for each speech input as a string or list of strings
Returns:
Number of correct words
Number of words overall references
Number of words overall predictions
"""
if isinstance(input, str):
input = [input]
if isinstance(target, str):
target = [target]
assert len(input) == len(
target
), f"Arguments must contain the same number of strings, but got len(input)={len(input)} and len(target)={len(target)}"
errors, max_total, target_total, input_total = _get_errors_and_totals(input, target)
return errors - max_total, target_total, input_total


def _wil_compute(
correct_total: torch.Tensor, target_total: torch.Tensor, preds_total: torch.Tensor
) -> torch.Tensor:
"""Compute the Word Information Lost.
Args:
correct_total: Number of correct words
target_total: Number of words overall references
preds_total: Number of words overall prediction
Returns:
Word Information Lost score
"""
return 1 - ((correct_total / target_total) * (correct_total / preds_total))


@torch.inference_mode()
def word_information_lost(
input: Union[str, List[str]],
target: Union[str, List[str]],
) -> torch.Tensor:
"""Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This
value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better
the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score.
Its class version is ``torcheval.metrics.WordInformationLost``.
Args:
input: Transcription(s) to score as a string or list of strings
target: Reference(s) for each speech input as a string or list of strings
Returns:
Word Information Lost rate
Examples:
>>> from torcheval.metrics.functional import word_information_lost
>>> input = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_information_lost(input, target)
tensor(0.6528)
"""
correct_total, target_total, preds_total = _wil_update(input, target)
return _wil_compute(correct_total, target_total, preds_total)
8 changes: 7 additions & 1 deletion torcheval/metrics/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

from torcheval.metrics.text.perplexity import Perplexity
from torcheval.metrics.text.word_error_rate import WordErrorRate
from torcheval.metrics.text.word_information_lost import WordInformationLost
from torcheval.metrics.text.word_information_preserved import WordInformationPreserved

__all__ = ["Perplexity", "WordErrorRate", "WordInformationPreserved"]
__all__ = [
"Perplexity",
"WordErrorRate",
"WordInformationLost",
"WordInformationPreserved",
]
96 changes: 96 additions & 0 deletions torcheval/metrics/text/word_information_lost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, List, Optional, TypeVar, Union

import torch

from torcheval.metrics.functional.text.word_information_lost import (
_wil_compute,
_wil_update,
)

from torcheval.metrics.metric import Metric

TWordInformationLost = TypeVar("TWordInformationLost")


class WordInformationLost(Metric[torch.Tensor]):
"""Word Information Lost (WIL_) is a metric of the performance of an automatic speech recognition system. This
value indicates the percentage of words that were incorrectly predicted between a set of ground-truth sentences
and a set of hypothesis sentences. The lower the value, the better the performance of the ASR system with a
WordInformationLost of 0 being a perfect score. Word Information Lost rate can then be computed as:
.. math::
wil = 1 - \frac{C}{N} * \frac{C}{P}
where:
- :math:`C` is the number of correct words,
- :math:`N` is the number of words in the reference
- :math:`P` is the number of words in the prediction
Its functional version is :func:`torcheval.metrics.functional.word_information_lost`.
Examples:
>>> from torcheval.metrics.text import WordInformationLost
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric = WordInformationLost()
>>> metric(preds, target)
tensor(0.6528)
"""

def __init__(
self: TWordInformationLost,
device: Optional[torch.device] = None,
) -> None:
super().__init__(device=device)
self._add_state("correct_total", torch.tensor(0.0, dtype=torch.float64))
self._add_state("target_total", torch.tensor(0.0, dtype=torch.float64))
self._add_state("preds_total", torch.tensor(0.0, dtype=torch.float64))

@torch.inference_mode()
# pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
def update(
self: TWordInformationLost,
input: Union[str, List[str]],
target: Union[str, List[str]],
) -> TWordInformationLost:
"""Store predictions/references for computing Word Information Lost scores.
Args:
input: Transcription(s) to score as a string or list of strings
target: Reference(s) for each speech input as a string or list of strings
"""
correct_total, target_total, preds_total = _wil_update(input, target)
self.correct_total += correct_total
self.target_total += target_total
self.preds_total += preds_total
return self

@torch.inference_mode()
def compute(self: TWordInformationLost) -> torch.Tensor:
"""Calculate the Word Information Lost.
Returns:
Word Information Lost score
"""
return _wil_compute(self.correct_total, self.target_total, self.preds_total)

@torch.inference_mode()
def merge_state(
self: TWordInformationLost,
metrics: Iterable[TWordInformationLost],
) -> TWordInformationLost:
"""
Merge the metric state with its counterparts from other metric instances.
Args:
metrics (Iterable[Metric]): metric instances whose states are to be merged.
"""
for metric in metrics:
self.correct_total += metric.correct_total.to(self.device)
self.target_total += metric.target_total.to(self.device)
self.preds_total += metric.preds_total.to(self.device)
return self

0 comments on commit 94de3e3

Please sign in to comment.