From 9cd5fe5285e6271e9e8fc0d17dd594bbf16df8c2 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 21:41:59 +0100 Subject: [PATCH 01/10] Add n_channels to BaseModel --- asteroid/models/base_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index 35f4b11e9..63caeea0f 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -33,9 +33,10 @@ class BaseModel(torch.nn.Module): waveform tensors. """ - def __init__(self, sample_rate: float = 8000.0): + def __init__(self, sample_rate: float = 8000.0, n_channels: int = 1): super().__init__() self.__sample_rate = sample_rate + self.n_channels = n_channels def forward(self, *args, **kwargs): raise NotImplementedError From 6c43a9a1716e049e3551c23b3e09f7e9cfc7aa04 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 21:44:07 +0100 Subject: [PATCH 02/10] Add commented problematic places --- asteroid/models/base_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index 63caeea0f..ababda1c6 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -179,6 +179,11 @@ def serialize(self): state_dict=self.get_state_dict(), model_args=self.get_model_args(), ) + # FIXME + # if hasattr(self, "sample_rate") and "sample_rate" not in model_conf["model_args"]: + # model_conf["model_args"]["sample_rate"] = self.sample_rate + # if hasattr(self, "n_channels") and "n_channels" not in model_conf["model_args"]: + # model_conf["model_args"]["n_channels"] = self.n_channels # Additional infos infos = dict() infos["software_versions"] = dict( From 2c2c356b8b22eb1a1a84a243115fc1d69bec0b27 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 21:44:49 +0100 Subject: [PATCH 03/10] Load/save test for multichannel models --- tests/models/models_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/models/models_test.py b/tests/models/models_test.py index 0d5c316d5..a5d4160bf 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning(): model.sample_rate = 16000.0 +def test_multichannel_model_loading(): + class MCModel(BaseModel): + def __init__(self, sample_rate=8000.0, n_channels=2): + super().__init__(sample_rate=sample_rate, n_channels=n_channels) + + def forward(self, x, **kwargs): + return x + + def get_model_args(self): + return {"sample_rate": self.sample_rate, "n_channels": self.n_channels} + + model = MCModel() + model_conf = model.serialize() + + new_model = MCModel.from_pretrained(model_conf) + assert model.n_channels == new_model.n_channels + + def test_convtasnet_sep(): nnet = ConvTasNet( n_src=2, From 7d6c9234ea32b08909f339af720030990c180b68 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 21:45:15 +0100 Subject: [PATCH 04/10] Add n_channels to Separatable --- asteroid/separate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/asteroid/separate.py b/asteroid/separate.py index 04be7d771..70e7ea3a4 100644 --- a/asteroid/separate.py +++ b/asteroid/separate.py @@ -19,7 +19,9 @@ class Protocol: class Separatable(Protocol): """Things that are separatable.""" - def forward_wav(self, wav, **kwargs): + n_channels: int + + def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: wav (torch.Tensor): waveform tensor. @@ -34,7 +36,7 @@ def forward_wav(self, wav, **kwargs): ... @property - def sample_rate(self): + def sample_rate(self) -> float: """Operating sample rate of the model (float).""" ... From c1aae1b0edeee1525b6cc6ae7f22d6fa6682ca33 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 21:51:35 +0100 Subject: [PATCH 05/10] Remove the comments on loading --- asteroid/models/base_models.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index ababda1c6..63caeea0f 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -179,11 +179,6 @@ def serialize(self): state_dict=self.get_state_dict(), model_args=self.get_model_args(), ) - # FIXME - # if hasattr(self, "sample_rate") and "sample_rate" not in model_conf["model_args"]: - # model_conf["model_args"]["sample_rate"] = self.sample_rate - # if hasattr(self, "n_channels") and "n_channels" not in model_conf["model_args"]: - # model_conf["model_args"]["n_channels"] = self.n_channels # Additional infos infos = dict() infos["software_versions"] = dict( From 49179cd46c9f79ad26de0e09d7b1d1e20dec9823 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 22:04:34 +0100 Subject: [PATCH 06/10] Multichannel in separate --- asteroid/separate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asteroid/separate.py b/asteroid/separate.py index 70e7ea3a4..8706dc096 100644 --- a/asteroid/separate.py +++ b/asteroid/separate.py @@ -90,6 +90,11 @@ def separate( @torch.no_grad() def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Tensor: """Core logic of `separate`.""" + if model.n_channels is not None and wav.shape[-2] != model.n_channels: + raise RuntimeError( + f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels." + f"Please match the number of channels." + ) # Handle device placement input_device = get_device(wav, default="cpu") model_device = get_device(model, default="cpu") @@ -161,7 +166,7 @@ def file_separate( f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically." ) # Pass wav as [batch, n_chan, time]; here: [1, 1, time] - wav = wav[:, 0][None, None] + wav = wav.T[None] (est_srcs,) = numpy_separate(model, wav, **kwargs) # Resample to original sr est_srcs = [ From a9659fcc8c8b783565092265e7b99210df838e4c Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 22:05:11 +0100 Subject: [PATCH 07/10] Edit BaseModel docstrings + add Optional on n_channels --- asteroid/models/base_models.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py index 63caeea0f..a4bf1ddad 100644 --- a/asteroid/models/base_models.py +++ b/asteroid/models/base_models.py @@ -1,5 +1,6 @@ import torch import warnings +from typing import Optional from .. import separate from ..masknn import activations @@ -22,18 +23,21 @@ def _unsqueeze_to_3d(x): class BaseModel(torch.nn.Module): """Base class for serializable models. - Defines saving/loading procedures as well as separation methods from - file, torch tensors and numpy arrays. - Need to overwrite the `forward` method, the `sample_rate` property and - the `get_model_args` method. + Defines saving/loading procedures, and separation interface to `separate`. + Need to overwrite the `forward` and `get_model_args` methods. Models inheriting from `BaseModel` can be used by :mod:`asteroid.separate` and by the `asteroid-infer` CLI. For models whose `forward` doesn't go from waveform to waveform tensors, overwrite `forward_wav` to return waveform tensors. + + Args: + sample_rate (float): Operating sample rate of the model. + n_channels: Supported number of channels of the model. + If None, no checks will be performed. """ - def __init__(self, sample_rate: float = 8000.0, n_channels: int = 1): + def __init__(self, sample_rate: float = 8000.0, n_channels: Optional[int] = 1): super().__init__() self.__sample_rate = sample_rate self.n_channels = n_channels From f3f4d46f0f7c8297311cb6f32db8e2c4982ff8a1 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 22:06:46 +0100 Subject: [PATCH 08/10] Optional on n_channels in separate.py --- asteroid/separate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asteroid/separate.py b/asteroid/separate.py index 8706dc096..0629949cc 100644 --- a/asteroid/separate.py +++ b/asteroid/separate.py @@ -3,6 +3,7 @@ import torch import numpy as np import soundfile as sf +from typing import Optional try: from typing import Protocol @@ -19,7 +20,7 @@ class Protocol: class Separatable(Protocol): """Things that are separatable.""" - n_channels: int + n_channels: Optional[int] def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor: """ From 46cb2932eec0e8a561de5f59d738b2bdb50f8db4 Mon Sep 17 00:00:00 2001 From: Pariente Manuel Date: Tue, 2 Feb 2021 22:10:51 +0100 Subject: [PATCH 09/10] Update asteroid/separate.py --- asteroid/separate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asteroid/separate.py b/asteroid/separate.py index 0629949cc..d6b73de5a 100644 --- a/asteroid/separate.py +++ b/asteroid/separate.py @@ -166,7 +166,7 @@ def file_separate( f"Received a signal with a sampling rate of {fs}Hz for a model " f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically." ) - # Pass wav as [batch, n_chan, time]; here: [1, 1, time] + # Pass wav as [batch, n_chan, time]; here: [1, chan, time] wav = wav.T[None] (est_srcs,) = numpy_separate(model, wav, **kwargs) # Resample to original sr From 035c2349c36963cccfa2b9fd8e9048334c2263fb Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Tue, 2 Feb 2021 22:13:10 +0100 Subject: [PATCH 10/10] Add n_channels to LambdaOverlapAdd --- asteroid/dsp/overlap_add.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asteroid/dsp/overlap_add.py b/asteroid/dsp/overlap_add.py index 0c8b5e4cc..c25c97432 100644 --- a/asteroid/dsp/overlap_add.py +++ b/asteroid/dsp/overlap_add.py @@ -66,6 +66,7 @@ def __init__( self.window_size = window_size self.hop_size = hop_size if hop_size is not None else window_size // 2 self.n_src = n_src + self.n_channels = getattr(nnet, "n_channels", None) if window: from scipy.signal import get_window # for torch.hub