Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic to_funsor conversion methods for funsor.distributions #321

Merged
merged 64 commits into from
Mar 26, 2020

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Mar 4, 2020

This PR adds implementations of to_funsor and to_data for the revamped funsor.distributions.Distribution from #320. These implementations make use of a generic Distribution._infer_param_domain method that uses the arg_constraints TorchDistribution metadata to infer event shapes.

The git history is kind of gross, but there's nothing too complicated that's new here.

Tested:

@eb8680 eb8680 changed the base branch from distributions-refactor to master March 25, 2020 22:12
@eb8680 eb8680 removed the Blocked Blocked by other issues label Mar 25, 2020
@eb8680 eb8680 requested a review from fritzo March 25, 2020 22:17
funsor_dist_class = getattr(funsor.distributions, type(pyro_dist).__name__.split("__")[-1])
params = [to_funsor(
getattr(pyro_dist, param_name),
output=funsor_dist_class._infer_param_domain(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distribution._infer_param_domain is used here to enable a generic torchdistribution_to_funsor conversion that works by converting TorchDistribution parameters to funsors and constructing Distribution funsors from the results.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that this assumes all parameters are batched scalars? So e.g. MultivariateNormal and LowRankMultivariateNormal and Dirichlet will need special casing? If so do you think we could add an assert .event_dim == 0?

Copy link
Member Author

@eb8680 eb8680 Mar 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_infer_param_domain looks at the arg_constraint for each parameter to infer their outputs. For example, in the case of MultivariateNormal, arg_constraints["loc"] is a constraints.RealVector object, so _infer_param_domain knows that its output shape is loc.shape[-1:].

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Is there any logic you'd like reviewed with extra scrutiny?

Can you confirm these changes are pure refactoring and thus covered by existing tests?

@@ -104,7 +110,7 @@ def __getattribute__(self, attr):
return super().__getattribute__(attr)

@classmethod
@functools.lru_cache(maxsize=None)
@functools.lru_cache(maxsize=5000)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, were you seeing unbounded growth?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in any existing tests, I just figured it would be good to add a ceiling.

@eb8680
Copy link
Member Author

eb8680 commented Mar 26, 2020

Is there any logic you'd like reviewed with extra scrutiny?

Nope, this is mostly boring refactoring

Can you confirm these changes are pure refactoring and thus covered by existing tests?

Everything here should be covered by existing tests in test/pyro/test_convert.py.

@fritzo fritzo merged commit be19281 into master Mar 26, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants