Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise an error if there is no common scale when model enumerated #1536

Merged
merged 7 commits into from
Feb 6, 2023

Conversation

ordabayevy
Copy link
Member

It turns out that current subsample scaling logic is not handled correctly in TraceEnum_ELBO (sorry for that). I added some tests to demonstrate that. I think the best place to handle subsampling scaling is in funsor.sum_product.sum_product. For now I added a check for a common scale which raises an error if there is no common scale.

else None
scales_set = set(
[
model_trace[name]["scale"]
Copy link
Member

@fehiepsi fehiepsi Feb 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to check the uniqueness of jax array scales because jax arrays is not hashable. Could you replace this by something like

scale = None
for name in (group_names | group_sum_vars):
    site_scale = model_trace[name]["scale"]
    if isinstance(scale, (int, float)) and isinstance(site_scale, (int, float, type(None))) and (site_scale != scale):
        raise ValueError(...)
    scale = site_scale

Btw, does this mean that we don't support enumeration for models with both global and local variables?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the implementation from Pyro which first checks that array is scalar, converts it to scalar, and then compares them. Does it look alright to you?

https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/traceenum_elbo.py#L37-L41

Btw, does this mean that we don't support enumeration for models with both global and local variables?

Yeah, if you enumerate a global variable than you cannot subsample a local variable that depends on it.
I'm working on it but that will require first changing funsor.sum_product.sum_product to handle plate-wise scaling there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we can't convert a tracer to float, float(site_scale) will fail when scale is a tracer. Maybe we need to raise error if it is a jnp.array?

Copy link
Member Author

@ordabayevy ordabayevy Feb 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Changed it so that it raises an error if scale is jnp.array.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks, Yerdos.

@fehiepsi fehiepsi merged commit 643412d into master Feb 6, 2023
@fehiepsi fehiepsi deleted the subsample-scale branch February 6, 2023 22:06
@ordabayevy
Copy link
Member Author

Thanks for reviewing @fehiepsi .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants