Commit

Update audio docs (#1360)
* exclude update and compute
* update modular documentation
* matching function docs
* Apply suggestions from code review

Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
5 people committed Nov 30, 2022
1 parent bc057d4 commit 8bcfb73
Showing 18 changed files with 204 additions and 243 deletions.
5 changes: 5 additions & 0 deletions docs/source/audio/perceptual_evaluation_speech_quality.rst
@@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

##############################################
Perceptual Evaluation of Speech Quality (PESQ)
##############################################
@@ -11,8 +13,11 @@ Module Interface
________________

.. autoclass:: torchmetrics.audio.pesq.PerceptualEvaluationSpeechQuality
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality
:noindex:
1 change: 1 addition & 0 deletions docs/source/audio/permutation_invariant_training.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.PermutationInvariantTraining
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
3 changes: 3 additions & 0 deletions docs/source/audio/scale_invariant_signal_distortion_ratio.rst
@@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

###################################################
Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
###################################################
@@ -12,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
3 changes: 3 additions & 0 deletions docs/source/audio/scale_invariant_signal_noise_ratio.rst
@@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

##############################################
Scale-Invariant Signal-to-Noise Ratio (SI-SNR)
##############################################
@@ -12,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.ScaleInvariantSignalNoiseRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
3 changes: 3 additions & 0 deletions docs/source/audio/short_time_objective_intelligibility.rst
@@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

###########################################
Short-Time Objective Intelligibility (STOI)
###########################################
@@ -12,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.audio.stoi.ShortTimeObjectiveIntelligibility
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
3 changes: 3 additions & 0 deletions docs/source/audio/signal_distortion_ratio.rst
@@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

################################
Signal to Distortion Ratio (SDR)
################################
@@ -12,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.SignalDistortionRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/audio/signal_noise_ratio.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.SignalNoiseRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
13 changes: 13 additions & 0 deletions docs/source/links.rst
@@ -102,3 +102,16 @@
.. _CLIP score: https://arxiv.org/pdf/2104.08718.pdf
.. _Huggingface OpenAI: https://huggingface.co/openai
.. _Theil's U: https://en.wikipedia.org/wiki/Uncertainty_coefficient
.. _Perceptual Evaluation of Speech Quality: https://en.wikipedia.org/wiki/Perceptual_Evaluation_of_Speech_Quality
.. _pesq package: https://github.com/ludlows/python-pesq
.. _Cees Taal's website: http://www.ceestaal.nl/code/
.. _pystoi package: https://github.com/mpariente/pystoi
.. _stoi ref1: https://ieeexplore.ieee.org/document/5495701
.. _stoi ref2: https://ieeexplore.ieee.org/document/5713237
.. _stoi ref3: https://ieeexplore.ieee.org/document/7539284
.. _sdr ref1: https://ieeexplore.ieee.org/document/1643671
.. _sdr ref2: https://arxiv.org/abs/2110.06440
.. _Scale-invariant signal-to-distortion ratio: https://arxiv.org/abs/1811.02508
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
37 changes: 18 additions & 19 deletions src/torchmetrics/audio/pesq.py
@@ -23,21 +23,28 @@


class PerceptualEvaluationSpeechQuality(Metric):
"""Perceptual Evaluation of Speech Quality (PESQ)
"""Calculates `Perceptual Evaluation of Speech Quality`_ (PESQ). It's a recognized industry standard for audio
quality that takes into consideration characteristics such as audio sharpness, call volume, background noise,
clipping, and audio interference. PESQ returns a score between -0.5 and 4.5, with higher scores indicating
better quality.
This is a wrapper for the pesq package [1]. Note that input will be moved to `cpu`
to perform the metric calculation.
This metric is a wrapper for the `pesq package`_. Note that input will be moved to ``cpu`` to perform the metric
calculation.
As input to ``forward`` and ``update`` the metric accepts the following input
- ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
- ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
As output of ``forward`` and ``compute`` the metric returns the following output
- ``pesq`` (:class:`~torch.Tensor`): float tensor with shape ``(...,)`` of PESQ value per sample
.. note:: using this metric requires you to have ``pesq`` installed. Either install as ``pip install
torchmetrics[audio]`` or ``pip install pesq``. Note that ``pesq`` will compile with your currently
torchmetrics[audio]`` or ``pip install pesq``. ``pesq`` will compile with your currently
installed version of numpy, meaning that if you upgrade numpy at some point in the future you will
most likely have to reinstall ``pesq``.
Forward accepts
- ``preds``: ``shape [...,time]``
- ``target``: ``shape [...,time]``
Args:
fs: sampling frequency, should be 16000 or 8000 (Hz)
mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
@@ -66,9 +73,6 @@ class PerceptualEvaluationSpeechQuality(Metric):
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)
References:
[1] https://github.com/ludlows/python-pesq
"""

sum_pesq: Tensor
@@ -104,12 +108,7 @@ def __init__(
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
"""Update state with predictions and targets."""
pesq_batch = perceptual_evaluation_speech_quality(
preds, target, self.fs, self.mode, False, self.n_processes
).to(self.sum_pesq.device)
@@ -118,5 +117,5 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.total += pesq_batch.numel()

def compute(self) -> Tensor:
"""Computes average PESQ."""
"""Computes metric."""
return self.sum_pesq / self.total
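
The `update`/`compute` pair above follows a sum-and-count accumulation pattern: each batch adds its scores and its element count to two states, and `compute` divides one by the other. A minimal sketch of that pattern in plain Python follows. `RunningMean` and its attribute names are hypothetical stand-ins for illustration, not part of torchmetrics, and the sketch ignores devices and distributed reduction.

```python
from typing import Sequence


class RunningMean:
    """Hypothetical stand-in for a sum/total state pair (not a torchmetrics class)."""

    def __init__(self) -> None:
        self.total_sum = 0.0  # plays the role of the summed per-sample scores
        self.count = 0        # plays the role of the number of samples seen

    def update(self, batch_scores: Sequence[float]) -> None:
        # Accumulate the batch sum and the number of scores in the batch.
        self.total_sum += sum(batch_scores)
        self.count += len(batch_scores)

    def compute(self) -> float:
        # Average over every score seen so far, across all batches.
        return self.total_sum / self.count
```

Because only the sum and the count are stored, the result does not depend on how the samples were split into batches.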
38 changes: 14 additions & 24 deletions src/torchmetrics/audio/pit.py
@@ -14,36 +14,36 @@
from typing import Any, Callable, Dict

from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.audio.pit import permutation_invariant_training
from torchmetrics.metric import Metric


class PermutationInvariantTraining(Metric):
"""Permutation invariant training (PermutationInvariantTraining). The PermutationInvariantTraining implements
the famous Permutation Invariant Training method.
"""Calculates `Permutation invariant training`_ (PIT), which evaluates models for speaker-independent
multi-talker speech separation in a permutation-invariant way.
[1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
As input to ``forward`` and ``update`` the metric accepts the following input
Forward accepts
- ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(batch_size,num_speakers,...)``
- ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(batch_size,num_speakers,...)``
- ``preds``: ``shape [batch, spk, ...]``
- ``target``: ``shape [batch, spk, ...]``
As output of ``forward`` and ``compute`` the metric returns the following output
- ``pit`` (:class:`~torch.Tensor`): float scalar tensor with the average PIT metric value over samples
Args:
metric_func:
a metric function that accepts a batch of targets and estimates,
i.e. ``metric_func(preds[:, i, ...], target[:, j, ...])``, and returns a batch of metric tensors ``[batch]``
i.e. ``metric_func(preds[:, i, ...], target[:, j, ...])``, and returns a batch of metric
tensors ``(batch,)``
eval_func:
the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better
or the larger the better.
kwargs: Additional keyword arguments for either the ``metric_func`` or distributed communication,
see :ref:`Metric kwargs` for more info.
Returns:
average PermutationInvariantTraining metric
Example:
>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
@@ -54,11 +54,6 @@ class PermutationInvariantTraining(Metric):
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)
Reference:
[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for
speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech
Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
"""

full_state_update: bool = False
@@ -69,7 +64,7 @@ class PermutationInvariantTraining(Metric):
def __init__(
self,
metric_func: Callable,
eval_func: str = "max",
eval_func: Literal["max", "min"] = "max",
**kwargs: Any,
) -> None:
base_kwargs: Dict[str, Any] = {
@@ -86,17 +81,12 @@ def __init__(
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
"""Update state with predictions and targets."""
pit_metric = permutation_invariant_training(preds, target, self.metric_func, self.eval_func, **self.kwargs)[0]

self.sum_pit_metric += pit_metric.sum()
self.total += pit_metric.numel()

def compute(self) -> Tensor:
"""Computes average PermutationInvariantTraining metric."""
"""Computes metric."""
return self.sum_pit_metric / self.total
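
To illustrate what "permutation invariant" means in the docstring above, here is a rough sketch for a single sample in plain Python: try every assignment of estimated speakers to target speakers, score each one with `metric_func`, and keep the best according to `eval_func`. The names `best_permutation` and `neg_squared_error` are hypothetical, and this brute-force search is an assumption for illustration only; it is not claimed to be how `permutation_invariant_training` is implemented.

```python
from itertools import permutations
from typing import Callable, Sequence, Tuple

Signal = Sequence[float]


def best_permutation(
    preds: Sequence[Signal],
    target: Sequence[Signal],
    metric_func: Callable[[Signal, Signal], float],
    eval_func: str = "max",
) -> Tuple[float, Tuple[int, ...]]:
    """Hypothetical helper: brute-force search over speaker assignments."""
    if eval_func not in ("max", "min"):
        raise ValueError("eval_func must be 'max' or 'min'")
    pick = max if eval_func == "max" else min
    num_speakers = len(target)
    scored = []
    for perm in permutations(range(num_speakers)):
        # perm[j] is the index of the estimate matched to target speaker j.
        score = sum(metric_func(preds[perm[j]], target[j]) for j in range(num_speakers)) / num_speakers
        scored.append((score, perm))
    # Keep the assignment with the largest (or smallest) mean score.
    return pick(scored, key=lambda item: item[0])


def neg_squared_error(estimate: Signal, reference: Signal) -> float:
    """Toy metric (larger is better), used only for this example."""
    return -sum((e - r) ** 2 for e, r in zip(estimate, reference))
```

With two speakers whose estimates come out in swapped order, the search with `eval_func="max"` recovers the swap instead of penalizing it. The cost grows factorially with the number of speakers, so this form is only practical for small speaker counts.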