-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add generic to_funsor conversion methods for funsor.distributions #321
Conversation
funsor_dist_class = getattr(funsor.distributions, type(pyro_dist).__name__.split("__")[-1]) | ||
params = [to_funsor( | ||
getattr(pyro_dist, param_name), | ||
output=funsor_dist_class._infer_param_domain( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Distribution._infer_param_domain
is used here to enable a generic torchdistribution_to_funsor
conversion that works by converting TorchDistribution
parameters to funsors and constructing Distribution
funsors from the results.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I understand correctly that this assumes all parameters are batched scalars? So e.g. MultivariateNormal
and LowRankMultivariateNormal
and Dirichlet
will need special casing? If so do you think we could add an assert .event_dim == 0
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_infer_param_domain
looks at the arg_constraint
for each parameter to infer their output
s. For example, in the case of MultivariateNormal
, arg_constraints["loc"]
is a constraints.RealVector
object, so _infer_param_domain
knows that its output shape is loc.shape[-1:]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Is there any logic you'd like reviewed with extra scrutiny?
Can you confirm these changes are pure refactoring and thus covered by existing tests?
@@ -104,7 +110,7 @@ def __getattribute__(self, attr): | |||
return super().__getattribute__(attr) | |||
|
|||
@classmethod | |||
@functools.lru_cache(maxsize=None) | |||
@functools.lru_cache(maxsize=5000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, were you seeing unbounded growth?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not in any existing tests, I just figured it would be good to add a ceiling.
Nope, this is mostly boring refactoring
Everything here should be covered by existing tests in |
This PR adds implementations of
to_funsor
andto_data
for the revampedfunsor.distributions.Distribution
from #320. These implementations make use of a genericDistribution._infer_param_domain
method that uses thearg_constraints
TorchDistribution
metadata to infer event shapes.The git history is kind of gross, but there's nothing too complicated that's new here.
Tested:
to_funsor
/to_data
implementations are exercised by existing tests intest/pyro/test_convert.py
via the changes in this PR tofunsor.pyro.convert