Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle contraction of guide plates in TraceEnum_ELBO #1537

Merged
merged 3 commits into from
Feb 9, 2023

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Feb 8, 2023

There was a bug in TraceEnum_ELBO demonstrated in a new test_enum_elbo::test_guide_plate_contraction test. When a cost term depends on a non-reparameterizable guide site with extra plate dims they need to be product-contracted. Instead these plate dims were passed to _eager_contract_tensors as reduced_vars and sum-contracted. So I replaced _eager_contract_tensors with funsor.sum_product.sum_product which can do plated sum-products and eliminate plates.

@ordabayevy ordabayevy added the bug Something isn't working label Feb 8, 2023
def model(params):
with pyro.plate("a_axis", size=2):
a = pyro.sample("a", dist.Categorical(jnp.array([0.2, 0.8])))
pyro.sample("b", dist.Normal(jnp.sum(a), 1.0), obs=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that this invalidates this restriction: https://pyro.ai/examples/enumeration.html#Restriction-2:-no-downstream-coupling? Could you elaborate why we can enumerate here?

Copy link
Member Author

@ordabayevy ordabayevy Feb 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not being enumerated here but being used as a non-reparameterizable site. I should have used other distribution like Poisson to make it less confusing :)

But the point is that b depends on a non-reparametrizable site a which has a_axis plate. The dice_factor needs to be product-contracted to eliminate the extra a_axis plate before multiplying the cost term for site b. Instead a_axis is being passed to _eager_contract_tensors as reduced_vars and sum-contracted with logsumexp. Hope this clarifies it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. So we are using TraceEnum_ELBO but enumeration is disabled for those cases. Could you add a warning for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean in general or for this test?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean in general. I think using TraceEnum_ELBO without enumeration is confusing. Maybe raise error if we can't enumerate sites with infer={"enumerate": "parallel"}?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the distribution to Poisson in the test since it can be any non-reparameterizable distributions. I can open another issue/PR for enumeration configuration since it is a separate issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds reasonable to me, thanks!

@fehiepsi fehiepsi merged commit 562e1be into master Feb 9, 2023
@fehiepsi fehiepsi deleted the enum-guide-plate branch February 9, 2023 21:34
@ordabayevy
Copy link
Member Author

Thanks for reviewing @fehiepsi !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants