Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO #2893

Draft
wants to merge 62 commits into
base: dev
Choose a base branch
from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Jul 3, 2021

Design Doc

New version of Trace_ELBO that extends TraceEnum_ELBO:

I get wrong values for elbo (much larger absolute value compared to pyro.infer.trace_elbo.Trace_ELBO), presumably because sum(log_factors, to_funsor(0.0)) in line 41 broadcasts terms in log_factors and then that leads to large absolute values after elbo.reduce(funsor.ops.add, plate_vars) summation.

Here I try to fix it by reducing each cost term individually similar to TraceEnum_ELBO. I'm also not sure if integration is needed here.

Yerdos Ordabayev added 2 commits July 3, 2021 00:57
@ordabayevy ordabayevy added the bug label Jul 3, 2021
@fritzo fritzo added this to the 1.7 release milestone Jul 4, 2021

elbo = to_funsor(0.0)
for cost in costs:
elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this missing Dice factors included in log_measures? IIRC that was the reason for using Integrate.

Copy link
Member Author

@ordabayevy ordabayevy Jul 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the test from #2894 which has a simple model/guide pair. When running that model (Elbo=Trace_ELBO, backend=contrib.funsor, reparam-False) both guide_terms["log_measures"] and model_terms["log_measures"] are empty. I can't find Dice factors anywhere in model_terms or guide_terms.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess they're not included because Funsor.sample isn't used in the evaluation of Trace_ELBO. I don't think contrib.funsor.infer.Trace_ELBO is tested extensively outside the pyro-api tests in tests/contrib/funsor/test_pyroapi_funsor.py, which is why this wasn't noticed before.

A more general Funsor-based implementation of Trace_ELBO is certainly possible and would look very similar to the guide-side enumeration handling logic in TraceEnum_ELBO. We might even be able to write a custom "enumeration" strategy that just called Funsor.sample and reuse TraceEnum_ELBO as the Trace_ELBO implementation.

I believe a completely general version might require variable elimination logic beyond what's currently in funsor.sum_product handling cases where the guide had plate structure incompatible with the restrictions there, although I can't immediately think of existing tests or examples where that would be the case.

@fritzo fritzo removed this from the 1.7 release milestone Jul 6, 2021
@ordabayevy ordabayevy added the WIP label Jul 16, 2021
@eb8680 eb8680 mentioned this pull request Aug 3, 2021
2 tasks
- df_a * logqa
- df_a * (qb * logqb).sum()
- df_a * (qb * df_c * logqc).sum()
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680 can you check the math here? This is an example with b enumerated in the guide. Trace_ELBO works correctly here.

# +-----------+
# a -|-> b --> c |
# | \--> d |
# +-----------+
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e is observed and b is enumerated.

# guide (c is enumerated)
# +-----------+
# a -|-> b --> c |
# +-----------+
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d is observed and c is enumerated.

# guide (b is enumerated)
# +-----------+
# a -|-> b --> c |
# +-----------+
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d is observed and b is enumerated.

@ordabayevy ordabayevy marked this pull request as draft October 4, 2023 12:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants