Skip to content

Commit

Permalink
Remove dependency on pyro ParamStoreDict (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Apr 12, 2019
1 parent 79a8621 commit 5053f55
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
3 changes: 2 additions & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def guide(data):
# Report the final values of the variational parameters
# in the guide after training.
if args.verbose:
for name, value in pyro.get_param_store().items():
for name in pyro.get_param_store():
value = pyro.param(name).data
print("{} = {}".format(name, value.detach().cpu().numpy()))

# For this simple (conjugate) model we know the exact posterior. In
Expand Down
30 changes: 19 additions & 11 deletions funsor/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
"""
from __future__ import absolute_import, division, print_function

import weakref
from collections import OrderedDict, namedtuple

import torch
from pyro.params.param_store import ParamStoreDict

import funsor

Expand Down Expand Up @@ -61,7 +61,7 @@ def expand_inputs(self, name, size):
# See http://docs.pyro.ai/en/0.3.1/parameters.html

PYRO_STACK = []
PARAM_STORE = ParamStoreDict() # A dict-like object that also supports constraints.
PARAM_STORE = {} # maps name -> (unconstrained_value, constraint)


def get_param_store():
Expand Down Expand Up @@ -239,25 +239,33 @@ def sample(name, fn, obs=None):
return msg["value"]


# param is an effectful version of PARAM_STORE.setdefault
# param is an effectful version of PARAM_STORE.setdefault that also handles constraints.
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def param(name, init_value=None, constraint=None, event_dim=None):
def param(name, init_value=None, constraint=torch.distributions.constraints.real, event_dim=None):
cond_indep_stack = {}
output = None
if init_value is not None:
if event_dim is None:
event_dim = init_value.dim()
output = funsor.reals(*init_value.shape[init_value.dim() - event_dim:])
if constraint is None:
constraint = torch.distributions.constraints.real

def fn(init_value, constraint):
if init_value is None:
value = PARAM_STORE[name]
if name in PARAM_STORE:
unconstrained_value, constraint = PARAM_STORE[name]
else:
value = PARAM_STORE.setdefault(name, init_value, constraint)
value.unconstrained()._funsor_metadata = (cond_indep_stack, output)
return tensor_to_funsor(value, *value.unconstrained()._funsor_metadata)
# Initialize with a constrained value.
assert init_value is not None
with torch.no_grad():
constrained_value = init_value.detach()
unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value)
unconstrained_value.requires_grad_()
unconstrained_value._funsor_metadata = (cond_indep_stack, output)
PARAM_STORE[name] = unconstrained_value, constraint

# Transform from unconstrained space to constrained space.
constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
return tensor_to_funsor(constrained_value, *unconstrained_value._funsor_metadata)

# if there are no active Messengers, we just draw a sample and return it as expected:
if not PYRO_STACK:
Expand Down
2 changes: 1 addition & 1 deletion test/test_minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def guide():
assert_close(actual, expected)


@pytest.mark.parametrize("backend", ["pyro", "funsor"])
@pytest.mark.parametrize("backend", ["pyro", "minipyro", "funsor"])
def test_constraints(backend):
data = torch.tensor(0.5)

Expand Down

0 comments on commit 5053f55

Please sign in to comment.