[src] Multichannel models #427
super().__init__()
self.__sample_rate = sample_rate
self.n_channels = n_channels
I don't think this needs to be a property like sample_rate, but we could do it.
Why not? Some models are tied to the number of channels, and actually to the array topology.
For sample_rate, we made it like that because the model holds a reference to the sample_rate, and the filterbank does as well, so we wanted to raise a warning when setting it.
But for the number of channels, for now nothing holds a reference to it.
If we see it's a limitation in the future we can always write a setter/getter.
Asteroid, as Python, is for consenting adults 😉
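The distinction discussed above can be sketched as follows. This is a minimal stand-in, not Asteroid's actual BaseModel: the class name SketchModel and the warning text are assumptions made for illustration. sample_rate goes through a property so that setting it can warn, while n_channels stays a plain attribute.

```python
import warnings


class SketchModel:
    """Hypothetical stand-in for a model base class (not Asteroid's BaseModel)."""

    def __init__(self, sample_rate, n_channels=None):
        # Private storage for the value exposed by the property below.
        self._sample_rate = sample_rate
        # Plain attribute: nothing else holds a reference to it, so no setter is needed.
        self.n_channels = n_channels

    @property
    def sample_rate(self):
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, new_sample_rate):
        # Other components (e.g. a filterbank) may still assume the old rate,
        # so warn instead of changing it silently. The message is illustrative.
        warnings.warn(
            "Changing sample_rate after construction; other components may not follow.",
            UserWarning,
        )
        self._sample_rate = new_sample_rate
```

If something else later starts holding a reference to the number of channels, n_channels could be turned into a property in the same way without changing how callers read it.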
if model.n_channels is not None and wav.shape[-2] != model.n_channels:
    raise RuntimeError(
        f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels. "
        f"Please match the number of channels."
    )
Should we add a flag that, if passed, makes us skip this check and take the first channels or something?
Something that would be passed from the CLI to here (--ignore-channels-check).
I'm not sure it's useful.
Not sure...
@popcornell what's your opinion on that?
Maybe we can start by not having it. And if we find it useful later, or there is a user demand, we can change that.
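For reference, the relaxed behaviour floated in this thread could look roughly like the sketch below. Everything here is hypothetical: the function name select_channels and the ignore_channels_check option do not exist in the PR, and the audio is represented as a plain list of channels rather than a tensor whose channel axis is wav.shape[-2].

```python
def select_channels(wav, n_channels, ignore_channels_check=False):
    """Return audio whose channel count matches n_channels.

    wav is a list of channels, each a list of samples (a simplification of
    the tensor used in the PR). ignore_channels_check is a hypothetical
    option, not an existing flag.
    """
    if n_channels is None or len(wav) == n_channels:
        # Model is not tied to a channel count, or the input already matches.
        return wav
    if ignore_channels_check and len(wav) > n_channels:
        # Relaxed mode: keep only the first n_channels channels.
        return wav[:n_channels]
    raise RuntimeError(
        f"Model supports {n_channels}-channel inputs but found audio with "
        f"{len(wav)} channels. Please match the number of channels."
    )
```

Note that even in relaxed mode the sketch still raises when the input has fewer channels than the model expects, since there is nothing sensible to take in that case.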
@@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning():
    model.sample_rate = 16000.0


def test_multichannel_model_loading():
    class MCModel(BaseModel):
        def __init__(self, sample_rate=8000.0, n_channels=2):
As said in the PR, we must have the sample_rate as an argument.
If we want a fixed number of channels, we don't need it in the __init__, we just super(n_channels=2) and that works.
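A minimal sketch of that pattern, assuming a simplified base class: SketchBase and StereoModel are hypothetical names, and the base class signature here is an assumption rather than Asteroid's real BaseModel.

```python
class SketchBase:
    """Hypothetical stand-in for the base class (not Asteroid's BaseModel)."""

    def __init__(self, sample_rate, n_channels=None):
        self.sample_rate = sample_rate
        self.n_channels = n_channels


class StereoModel(SketchBase):
    """Model tied to two channels: n_channels is not an __init__ argument."""

    def __init__(self, sample_rate):
        # sample_rate stays an explicit argument; the channel count is fixed here.
        super().__init__(sample_rate=sample_rate, n_channels=2)
```

With this shape, a caller cannot pass a different channel count by accident, because the subclass constructor does not accept one.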
Another note: there will be several scenarios for
We should add some tests and if the second case doesn't work, raise a useful error.
How about moving the check for missing sample rate to after the model object has been constructed, and then checking using
I thought about that.
How about we drop it? It's backwards incompatible, but loading models will still work, and it's easy to fix for people. Maybe it's time we don't default to 8 kHz anymore now that people are using Asteroid for other things than traditional 8 kHz speech separation.
Well, I'd be ok with that!
This is enough for this PR.
Started implementing what I suggested in #420.
I'll highlight the places where I have doubts in the PR.
Funny things I found along the way:
This can be serialized, but not loaded, because sample_rate is needed in the model conf to load the model.
OTOH:
This fails at loading, because it tries to pass sample_rate as kwargs, but cannot.
Should we do something about this?