diff --git a/asteroid/dsp/overlap_add.py b/asteroid/dsp/overlap_add.py index 0c8b5e4cc..c25c97432 100644 --- a/asteroid/dsp/overlap_add.py +++ b/asteroid/dsp/overlap_add.py @@ -66,6 +66,7 @@ def __init__( self.window_size = window_size self.hop_size = hop_size if hop_size is not None else window_size // 2 self.n_src = n_src + self.n_channels = getattr(nnet, "n_channels", None) if window: from scipy.signal import get_window # for torch.hub diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index 35f4b11e9..a4bf1ddad 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -1,5 +1,6 @@ import torch import warnings +from typing import Optional from .. import separate from ..masknn import activations @@ -22,20 +23,24 @@ def _unsqueeze_to_3d(x): class BaseModel(torch.nn.Module): """Base class for serializable models. - Defines saving/loading procedures as well as separation methods from - file, torch tensors and numpy arrays. - Need to overwrite the `forward` method, the `sample_rate` property and - the `get_model_args` method. + Defines saving/loading procedures, and separation interface to `separate`. + Need to overwrite the `forward` and `get_model_args` methods. Models inheriting from `BaseModel` can be used by :mod:`asteroid.separate` and by the `asteroid-infer` CLI. For models whose `forward` doesn't go from waveform to waveform tensors, overwrite `forward_wav` to return waveform tensors. + + Args: + sample_rate (float): Operating sample rate of the model. + n_channels (int, optional): Supported number of channels of the model. + If None, no checks will be performed. 
""" - def __init__(self, sample_rate: float = 8000.0): + def __init__(self, sample_rate: float = 8000.0, n_channels: Optional[int] = 1): super().__init__() self.__sample_rate = sample_rate + self.n_channels = n_channels def forward(self, *args, **kwargs): raise NotImplementedError diff --git a/asteroid/separate.py b/asteroid/separate.py index 04be7d771..d6b73de5a 100644 --- a/asteroid/separate.py +++ b/asteroid/separate.py @@ -3,6 +3,7 @@ import torch import numpy as np import soundfile as sf +from typing import Optional try: from typing import Protocol @@ -19,7 +20,9 @@ class Protocol: class Separatable(Protocol): """Things that are separatable.""" - def forward_wav(self, wav, **kwargs): + n_channels: Optional[int] + + def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: wav (torch.Tensor): waveform tensor. @@ -34,7 +37,7 @@ def forward_wav(self, wav, **kwargs): ... @property - def sample_rate(self): + def sample_rate(self) -> float: """Operating sample rate of the model (float).""" ... @@ -88,6 +91,11 @@ def separate( @torch.no_grad() def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Tensor: """Core logic of `separate`.""" + if model.n_channels is not None and wav.shape[-2] != model.n_channels: + raise RuntimeError( + f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels. " + "Please match the number of channels." + ) # Handle device placement input_device = get_device(wav, default="cpu") model_device = get_device(model, default="cpu") @@ -158,8 +166,8 @@ def file_separate( f"Received a signal with a sampling rate of {fs}Hz for a model " f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically." 
) - # Pass wav as [batch, n_chan, time]; here: [1, 1, time] - wav = wav[:, 0][None, None] + # Pass wav as [batch, n_chan, time]; here: [1, n_chan, time] + wav = wav.T[None] (est_srcs,) = numpy_separate(model, wav, **kwargs) # Resample to original sr est_srcs = [ diff --git a/tests/models/models_test.py b/tests/models/models_test.py index 0d5c316d5..a5d4160bf 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning(): model.sample_rate = 16000.0 +def test_multichannel_model_loading(): + class MCModel(BaseModel): + def __init__(self, sample_rate=8000.0, n_channels=2): + super().__init__(sample_rate=sample_rate, n_channels=n_channels) + + def forward(self, x, **kwargs): + return x + + def get_model_args(self): + return {"sample_rate": self.sample_rate, "n_channels": self.n_channels} + + model = MCModel() + model_conf = model.serialize() + + new_model = MCModel.from_pretrained(model_conf) + assert model.n_channels == new_model.n_channels + + def test_convtasnet_sep(): nnet = ConvTasNet( n_src=2,