Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[src] Multichannel models #427

Merged
merged 10 commits into from
Feb 6, 2021
Merged

[src] Multichannel models #427

merged 10 commits into from
Feb 6, 2021

Conversation

mpariente
Copy link
Collaborator

Started implementing what I suggested in #420.

I'll highlight the places where I doubt in the PR.

Funny things I found along the way:

  • It's hard to make a model that doesn't take the sample rate as argument
class Model(BaseModel):
    def __init__(self):
        super().__init__(sample_rate=44100)
        
    def forward(self, x):
        return x

    def get_model_args(self):
        return {}

This can be serialized, but not loaded because sample_rate is needed in the model conf to load the model.

OTOH:

class Model(BaseModel):
    def __init__(self):
        super().__init__(sample_rate=44100)

    def forward(self, x):
        return x

    def get_model_args(self):
        return {"sample_rate": self.sample_rate}

This fails at loading, because it tries to pass sample_rate as kwargs, but cannot.
Should we do something about this?

super().__init__()
self.__sample_rate = sample_rate
self.n_channels = n_channels
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this needs to be a property as the sample_rate but we could do it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not ? some models are tied to the number of channels and actually to the array topology

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sample_rate, we made it like that because the model holds reference to the sample_rate, but the filterbank as well, so we wanted the raise the warning when setting it.

But for the number of channels, for now nothing holds reference to it.
If we see it's a limitation in the future we can always write a setter/getter.

Asteroid, as Python, is for consenting adults 😉

Comment on lines +93 to +97
if model.n_channels is not None and wav.shape[-2] != model.n_channels:
raise RuntimeError(
f"Model supports {model.n_channels}-channel inputs but found audio with {wav.shape[-2]} channels."
f"Please match the number of channels."
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a flag to ignore that, if passed, make us ignore that and take the first channels or something?
Something that would be passed from the CLI to here (--ignore-channels-check).
I'm not sure it's useful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@popcornell what's your opinion on that?
Maybe we can start by not having it. And if we find it useful later, or there is a user demand, we can change that.

asteroid/separate.py Outdated Show resolved Hide resolved
@@ -32,6 +32,24 @@ def test_set_sample_rate_raises_warning():
model.sample_rate = 16000.0


def test_multichannel_model_loading():
class MCModel(BaseModel):
def __init__(self, sample_rate=8000.0, n_channels=2):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said in the PR. We must have the sample_rate as argument.

If we want fixed number of channels, we don't need it in the __init__, we just super(n_channels=2) and that works.

@mpariente
Copy link
Collaborator Author

Another note: there will be several scenarios for LambdaOverlapAdd:

  • Multichannel input, single output (should work fine)
  • Multichannel input, multichannel output (not sure it works).

We should add some tests and if the second case doesn't work, raise a useful error.

@jonashaag
Copy link
Collaborator

jonashaag commented Feb 3, 2021

This can be serialized, but not loaded because sample_rate is needed in the model conf to load the model.

How about moving the check for missing sample rate to after the model object has been constructed, and then checking using hasattr(model, "sample_rate")? That way you are free to set the sample_rate property however you like as long as it's present.

@mpariente
Copy link
Collaborator Author

This can be serialized, but not loaded because sample_rate is needed in the model conf to load the model.

How about moving the check for missing sample rate to after the model object has been constructed, and then checking using hasattr(model, "sample_rate")? That way you are free to set the sample_rate property however you like as long as it's present.

I thought about that.
That seems fine to me, let's do that.

@mpariente
Copy link
Collaborator Author

How about moving the check for missing sample rate to after the model object has been constructed, and then checking using hasattr(model, "sample_rate")? That way you are free to set the sample_rate property however you like as long as it's present.

sample_rate has a default value, so actually this won't work because the sample_rate property will be there.

@jonashaag
Copy link
Collaborator

jonashaag commented Feb 5, 2021

How about we drop it? It's backwards incompatible, but loading models will still work, and it's easy to fix for people. Maybe it's time we don't default to 8 kHz anymore now that people are using Asteroid for other things than traditional 8 kHz speech separation

@mpariente
Copy link
Collaborator Author

Well, I'd be ok with that!

@mpariente
Copy link
Collaborator Author

This is enough for this PR.
After merging this, I'll create a PR to remove the default on the sample rate.

@mpariente mpariente merged commit 40bba0d into master Feb 6, 2021
@mpariente mpariente deleted the multichannel branch February 6, 2021 21:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants