Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[src] Multichannel models #427

Merged
merged 10 commits into from
Feb 6, 2021
15 changes: 10 additions & 5 deletions asteroid/models/base_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import warnings
from typing import Optional

from .. import separate
from ..masknn import activations
Expand All @@ -22,20 +23,24 @@ def _unsqueeze_to_3d(x):
class BaseModel(torch.nn.Module):
"""Base class for serializable models.

Defines saving/loading procedures as well as separation methods from
file, torch tensors and numpy arrays.
Need to overwrite the `forward` method, the `sample_rate` property and
the `get_model_args` method.
Defines saving/loading procedures, and separation interface to `separate`.
Need to overwrite the `forward` and `get_model_args` methods.

Models inheriting from `BaseModel` can be used by :mod:`asteroid.separate`
and by the `asteroid-infer` CLI. For models whose `forward` doesn't go from
waveform to waveform tensors, overwrite `forward_wav` to return
waveform tensors.

Args:
sample_rate (float): Operating sample rate of the model.
n_channels: Supported number of channels of the model.
If None, no checks will be performed.
"""

def __init__(self, sample_rate: float = 8000.0):
def __init__(self, sample_rate: float = 8000.0, n_channels: Optional[int] = 1):
super().__init__()
self.__sample_rate = sample_rate
self.n_channels = n_channels
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this needs to be a property as the sample_rate but we could do it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not ? some models are tied to the number of channels and actually to the array topology

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sample_rate, we made it like that because the model holds reference to the sample_rate, but the filterbank as well, so we wanted the raise the warning when setting it.

But for the number of channels, for now nothing holds reference to it.
If we see it's a limitation in the future we can always write a setter/getter.

Asteroid, as Python, is for consenting adults 😉


def forward(self, *args, **kwargs):
raise NotImplementedError
Expand Down
13 changes: 10 additions & 3 deletions asteroid/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ class Protocol:
class Separatable(Protocol):
"""Things that are separatable."""

def forward_wav(self, wav, **kwargs):
n_channels: int

def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
wav (torch.Tensor): waveform tensor.
Expand All @@ -34,7 +36,7 @@ def forward_wav(self, wav, **kwargs):
...

@property
def sample_rate(self):
def sample_rate(self) -> float:
"""Operating sample rate of the model (float)."""
...

Expand Down Expand Up @@ -88,6 +90,11 @@ def separate(
@torch.no_grad()
def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Tensor:
"""Core logic of `separate`."""
if model.n_channels is not None and wav.shape[-2] != model.n_channels:
raise RuntimeError(
f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels."
f"Please match the number of channels."
)
Comment on lines +94 to +98
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a flag to ignore that, if passed, make us ignore that and take the first channels or something?
Something that would be passed from the CLI to here (--ignore-channels-check).
I'm not sure it's useful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@popcornell what's your opinion on that?
Maybe we can start by not having it. And if we find it useful later, or there is a user demand, we can change that.

# Handle device placement
input_device = get_device(wav, default="cpu")
model_device = get_device(model, default="cpu")
Expand Down Expand Up @@ -159,7 +166,7 @@ def file_separate(
f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically."
)
# Pass wav as [batch, n_chan, time]; here: [1, 1, time]
mpariente marked this conversation as resolved.
Show resolved Hide resolved
wav = wav[:, 0][None, None]
wav = wav.T[None]
(est_srcs,) = numpy_separate(model, wav, **kwargs)
# Resample to original sr
est_srcs = [
Expand Down
18 changes: 18 additions & 0 deletions tests/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning():
model.sample_rate = 16000.0


def test_multichannel_model_loading():
class MCModel(BaseModel):
def __init__(self, sample_rate=8000.0, n_channels=2):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said in the PR. We must have the sample_rate as argument.

If we want fixed number of channels, we don't need it in the __init__, we just super(n_channels=2) and that works.

super().__init__(sample_rate=sample_rate, n_channels=n_channels)

def forward(self, x, **kwargs):
return x

def get_model_args(self):
return {"sample_rate": self.sample_rate, "n_channels": self.n_channels}

model = MCModel()
model_conf = model.serialize()

new_model = MCModel.from_pretrained(model_conf)
assert model.n_channels == new_model.n_channels


def test_convtasnet_sep():
nnet = ConvTasNet(
n_src=2,
Expand Down