[src] Multichannel models #427
Changes from 7 commits
9cd5fe5
6c43a9a
2c2c356
7d6c923
c1aae1b
49179cd
a9659fc
f3f4d46
46cb293
035c234
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,9 @@ class Protocol: | |
class Separatable(Protocol): | ||
"""Things that are separatable.""" | ||
|
||
def forward_wav(self, wav, **kwargs): | ||
n_channels: int | ||
|
||
def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor: | ||
""" | ||
Args: | ||
wav (torch.Tensor): waveform tensor. | ||
|
@@ -34,7 +36,7 @@ def forward_wav(self, wav, **kwargs): | |
... | ||
|
||
@property | ||
def sample_rate(self): | ||
def sample_rate(self) -> float: | ||
"""Operating sample rate of the model (float).""" | ||
... | ||
|
||
|
@@ -88,6 +90,11 @@ def separate( | |
@torch.no_grad() | ||
def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Tensor: | ||
"""Core logic of `separate`.""" | ||
if model.n_channels is not None and wav.shape[-2] != model.n_channels: | ||
raise RuntimeError( | ||
f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels. " | ||
"Please match the number of channels." | ||
) | ||
Comment on lines +94 to +98
Should we add a flag that, if passed, skips this check and takes the first channels or something?
Not sure...
@popcornell what's your opinion on that?
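The flag floated in this thread could look roughly like the sketch below. It is not part of the PR: `match_channels` and its `truncate` parameter are hypothetical names, and NumPy arrays stand in for `torch.Tensor` here (the `shape` attribute and the slicing used behave the same way on tensors).

```python
import numpy as np


def match_channels(wav, n_channels, truncate=False):
    """Hypothetical helper: check (or optionally trim) the channel axis.

    `wav` is expected as [..., n_chan, time], as in `torch_separate`.
    """
    if n_channels is None:
        # The model does not declare a channel count: nothing to check.
        return wav
    found = wav.shape[-2]
    if found == n_channels:
        return wav
    if truncate and found > n_channels:
        # Keep only the first `n_channels` channels.
        return wav[..., :n_channels, :]
    raise RuntimeError(
        f"Model supports {n_channels}-channel inputs but found audio with "
        f"{found} channels. Please match the number of channels."
    )


stereo = np.zeros((1, 2, 16000))
mono = match_channels(stereo, n_channels=1, truncate=True)
print(mono.shape)  # (1, 1, 16000)
```

Without `truncate=True` the same call raises the `RuntimeError`, which matches the behaviour in the diff. Whether silently dropping channels is a sensible default is exactly the open question in this thread.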
||
# Handle device placement | ||
input_device = get_device(wav, default="cpu") | ||
model_device = get_device(model, default="cpu") | ||
|
@@ -159,7 +166,7 @@ def file_separate( | |
f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically." | ||
) | ||
# Pass wav as [batch, n_chan, time]; here: [1, 1, time] | ||
mpariente marked this conversation as resolved.
|
||
wav = wav[:, 0][None, None] | ||
wav = wav.T[None] | ||
(est_srcs,) = numpy_separate(model, wav, **kwargs) | ||
# Resample to original sr | ||
est_srcs = [ | ||
|
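The effect of replacing `wav[:, 0][None, None]` with `wav.T[None]` in `file_separate` can be checked on a dummy array. This is a small sketch that assumes the file was read as a 2-D array shaped `[time, n_chan]`, which is what the original `wav[:, 0]` indexing implies; the array sizes are made up for illustration.

```python
import numpy as np

# Assumption: audio read as [time, n_chan] (here stereo, 1 s at 8 kHz).
wav = np.zeros((8000, 2))

old = wav[:, 0][None, None]  # first channel only -> [1, 1, time]
new = wav.T[None]            # all channels kept  -> [1, n_chan, time]

print(old.shape)  # (1, 1, 8000)
print(new.shape)  # (1, 2, 8000)
```

The old line always discarded every channel but the first, while the new one passes the full multichannel signal on to the model.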
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning(): | |
model.sample_rate = 16000.0 | ||
|
||
|
||
def test_multichannel_model_loading(): | ||
class MCModel(BaseModel): | ||
def __init__(self, sample_rate=8000.0, n_channels=2): | ||
As said in the PR. We must have the If we want fixed number of channels, we don't need it in the |
||
super().__init__(sample_rate=sample_rate, n_channels=n_channels) | ||
|
||
def forward(self, x, **kwargs): | ||
return x | ||
|
||
def get_model_args(self): | ||
return {"sample_rate": self.sample_rate, "n_channels": self.n_channels} | ||
|
||
model = MCModel() | ||
model_conf = model.serialize() | ||
|
||
new_model = MCModel.from_pretrained(model_conf) | ||
assert model.n_channels == new_model.n_channels | ||
|
||
|
||
def test_convtasnet_sep(): | ||
nnet = ConvTasNet( | ||
n_src=2, | ||
|
I don't think this needs to be a property like sample_rate, but we could do it.
Why not? Some models are tied to the number of channels, and actually to the array topology.
For sample_rate, we made it a property because both the model and the filterbank hold a reference to it, so we wanted to raise a warning when setting it.
But for the number of channels, nothing holds a reference to it for now.
If we see it's a limitation in the future, we can always write a setter/getter.
Asteroid, as Python, is for consenting adults 😉
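If the setter/getter mentioned above were ever needed, it might look something like this sketch. `ChannelAwareModel` is a hypothetical stand-in class, not Asteroid's `BaseModel`, and the warning text is invented for illustration.

```python
import warnings


class ChannelAwareModel:
    """Hypothetical stand-in for a model class; not Asteroid's BaseModel."""

    def __init__(self, n_channels=None):
        self._n_channels = n_channels

    @property
    def n_channels(self):
        return self._n_channels

    @n_channels.setter
    def n_channels(self, new_n_channels):
        # Warn on every change, mirroring what the thread describes for sample_rate.
        warnings.warn(
            "Changing n_channels does not change the model's layers or weights.",
            UserWarning,
        )
        self._n_channels = new_n_channels


model = ChannelAwareModel(n_channels=2)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.n_channels = 1  # triggers the warning, then updates the value
```

A plain attribute, as in the PR, avoids this boilerplate as long as nothing else holds a reference to the channel count.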