Update distribution shape inference to handle independent dims #402

Merged: 18 commits, merged on Dec 17, 2020
101 changes: 81 additions & 20 deletions funsor/distribution.py
@@ -16,13 +16,13 @@
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.cnf import Contraction, GaussianMixture
from funsor.domains import Array, Real, Reals
from funsor.domains import Array, Real, Reals, RealsType
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
                           ignore_jit_warnings, numeric_array, stack)
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \
    eager, to_data, to_funsor
    eager, reflect, to_data, to_funsor
from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property


@@ -57,12 +57,36 @@ class DistributionMeta(FunsorMeta):
"""
def __call__(cls, *args, **kwargs):
kwargs.update(zip(cls._ast_fields, args))
value = kwargs.pop('value', 'value')
kwargs = OrderedDict(
(k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
for k in cls._ast_fields if k != 'value')
value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
kwargs["value"] = kwargs.get("value", "value")
kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields) # make sure args are sorted

        domains = OrderedDict()
        for k, v in kwargs.items():
            if k == "value":
                continue

            # compute unbroadcasted param domains
            domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))
            # use to_funsor to infer output dimensions of e.g. tensors
            domains[k] = domain if domain is not None else to_funsor(v).output

            # broadcast individual param domains with Funsor inputs
            # this avoids .expand-ing underlying parameter tensors
@fehiepsi (Member) commented on Dec 4, 2020:

What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))? Currently, domains["scale"] will be Real in both cases. The second case will trigger an error at to_funsor(v, output=domains[k]) below.

In either case, I guess we need to rewrite eager_normal or eager_mvn to address Reals[2] loc. Maybe there is some trick to avoid doing so. cc @fritzo

@eb8680 (Member, Author) replied on Dec 4, 2020:

> What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))?

In the first case, it's Real, and in the second, it's Reals[2]. I guess I should add a second broadcasting condition below to handle the case where the parameter is a raw tensor:

if ops.is_numeric_array(v):  # at this point we know all of v's dims are output dims
    domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]

            if isinstance(v, Funsor) and isinstance(v.output, RealsType):
                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
            elif ops.is_numeric_array(v):
                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]

        # now use the broadcasted parameter shapes to infer the event_shape
        domains["value"] = cls._infer_value_domain(**domains)
        if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
            # try to broadcast the event shape with the value, in case they disagree
            domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]

        # finally, perform conversions to funsors
        kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
        args = numbers_to_tensors(*kwargs.values())

        return super(DistributionMeta, cls).__call__(*args)
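[Editor's note] To illustrate the broadcasting rule above: a minimal sketch (not part of the diff) of the Normal(Reals[2], ...) cases from the review thread, using only funsor.domains and funsor.util; broadcast_param_domain is a hypothetical helper that mirrors the two Reals[broadcast_shape(...)] branches.

from funsor.domains import Real, Reals
from funsor.util import broadcast_shape

def broadcast_param_domain(param_shape, param_domain):
    # mirrors: domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
    return Reals[broadcast_shape(param_shape, param_domain.shape)]

# Normal(Reals[2] loc, torch.ones(2)): a raw (2,)-shaped scale broadcasts to Reals[2]
assert broadcast_param_domain((2,), Real) == Reals[2]
# Normal(Reals[2] loc, 1.): a Python scalar is neither a Funsor with a RealsType
# output nor a numeric array, so neither branch fires and scale stays Real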


@@ -98,14 +122,6 @@ def eager_reduce(self, op, reduced_vars):
            return Number(0.)  # distributions are normalized
        return super(Distribution, self).eager_reduce(op, reduced_vars)

    @classmethod
    def eager_log_prob(cls, *params):
        inputs, tensors = align_tensors(*params)
        params = dict(zip(cls._ast_fields, tensors))
        value = params.pop('value')
        data = cls.dist_class(**params).log_prob(value)
        return Tensor(data, inputs)

    def _get_raw_dist(self):
        """
        Internal method for working with underlying distribution attributes
@@ -129,6 +145,23 @@ def has_rsample(self):
    def has_enumerate_support(self):
        return getattr(self.dist_class, "has_enumerate_support", False)

    @classmethod
    def eager_log_prob(cls, *params):
        params, value = params[:-1], params[-1]
        params = params + (Variable("value", value.output),)
        instance = reflect(cls, *params)
        raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
@eb8680 (Member, Author) commented:
I had to refactor eager_log_prob to use Distribution._get_raw_dist() to get the new tests to pass.

        assert value.output == value_output
        name_to_dim = {v: k for k, v in dim_to_name.items()}
        dim_to_name.update({-1 - d - len(raw_dist.batch_shape): name
                            for d, name in enumerate(value.inputs) if name not in name_to_dim})
        name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim})
        raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim))
        log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name)
        # final align() ensures that the inputs have the canonical order
        # implied by align_tensors, which is assumed pervasively in tests
        return log_prob.align(tuple(align_tensors(*(params[:-1] + (value,)))[0]))
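[Editor's note] To make the name/dim bookkeeping in eager_log_prob concrete, a small self-contained sketch (not part of the diff; the batch shape and names are invented for illustration):

# a raw distribution with batch_shape (5,) whose single batch dim is named "i",
# plus a value funsor carrying one extra input "j"
raw_batch_shape = (5,)
dim_to_name = {-1: "i"}               # as returned by _get_raw_dist()
value_inputs = ("j",)                 # extra named dims on the value
name_to_dim = {v: k for k, v in dim_to_name.items()}
dim_to_name.update({-1 - d - len(raw_batch_shape): name
                    for d, name in enumerate(value_inputs)
                    if name not in name_to_dim})
assert dim_to_name == {-1: "i", -2: "j"}  # "j" lands to the left of all batch dims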

    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

        # note this should handle transforms correctly via distribution_to_data
@@ -191,7 +224,13 @@ def _infer_value_domain(cls, **kwargs):
        # rely on the underlying distribution's logic to infer the event_shape given param domains
        instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()},
                                  validate_args=False)
        out_shape = instance.event_shape

        # Note inclusion of batch_shape here to handle independent event dimensions.
        # The arguments to _infer_value_domain are the .output shapes of parameters,
        # so any extra batch dimensions that aren't part of the instance event_shape
        # must be broadcasted output dimensions by construction.
        out_shape = instance.batch_shape + instance.event_shape
@eb8680 (Member, Author) commented on Dec 1, 2020:
This change to _infer_value_domain is the conceptual meat of the PR.
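[Editor's note] To see why (a sketch, not part of the PR, assuming the torch backend): a loc whose output shape is (3, 2) produces a raw distribution that reports those independent dims as batch dims, so they must be folded back into the value domain.

import torch
import torch.distributions as td

instance = td.Normal(torch.zeros(3, 2), torch.ones(3, 2), validate_args=False)
assert tuple(instance.batch_shape) == (3, 2)
assert tuple(instance.event_shape) == ()
out_shape = tuple(instance.batch_shape) + tuple(instance.event_shape)  # (3, 2), i.e. Reals[3, 2]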


        if type(instance.support).__name__ == "_IntegerInterval":
            out_dtype = int(instance.support.upper_bound + 1)
        else:
@@ -400,10 +439,32 @@ def __call__(self, cls, args, kwargs):

@to_data.register(Distribution)
def distribution_to_data(funsor_dist, name_to_dim=None):
    params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
              for param_name in funsor_dist._ast_fields if param_name != 'value']
    pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
    funsor_event_shape = funsor_dist.value.output.shape

    # attempt to generically infer the independent output dimensions
    instance = funsor_dist.dist_class(**{
A maintainer (Member) commented:

Beyond the scope of this PR, I'm concerned with the increasing overhead of shape computations that need to do tensor ops. I like @fehiepsi's recent suggestion of implementing .forward_event_shape() for transforms. I think it would be worthwhile to discuss and think about extensions to the Distribution interface that could replace all this need to create and throw away dummy distributions.

(Indeed in theory an optimizing compiler could remove all this overhead, but in practice our tensor backends either incur super-linear compile time cost, or fail to cover the wide range of probabilistic models we would like to handle. And while these dummy tensor ops are cheap, they add noise to debugging efforts.)

@eb8680 (Member, Author) replied:
Yes, I agree the repeated creation of distribution instances here is not ideal. Perhaps we could add counterparts of some of the shape inference methods from TFP (e.g. event_shape_tensor, param_shapes) upstream in torch.distributions.

        k: dummy_numeric_array(domain)
        for k, domain in zip(funsor_dist._ast_fields, (v.output for v in funsor_dist._ast_values))
        if k != "value"
    }, validate_args=False)
    event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape)
    # XXX is this final broadcast_shape necessary? should just be event_shape[...]?
    indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)])

    params = []
    for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]):
        param = to_data(funsor_param, name_to_dim=name_to_dim)

        # infer the independent dimensions of each parameter separately, since we chose to keep them unbroadcasted
        param_event_shape = getattr(funsor_dist._infer_param_domain(param_name, funsor_param.output.shape), "shape", ())
        param_indep_shape = funsor_param.output.shape[:len(funsor_param.output.shape) - len(param_event_shape)]
        for i in range(max(0, len(indep_shape) - len(param_indep_shape))):
            # add singleton event dimensions, leave broadcasting/expanding to backend
            param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape))

        params.append(param)

    pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
    pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
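[Editor's note] A minimal sketch (not part of the diff, assuming the torch backend) of the alignment loop above: a scalar scale paired with a loc of independent shape (4, 7) receives two trailing singleton dims, after which the backend can broadcast it against loc without .expand-ing.

import torch

loc = torch.randn(4, 7)   # param_indep_shape (4, 7), param_event_shape ()
scale = torch.tensor(1.)  # param_indep_shape (),     param_event_shape ()
indep_shape = (4, 7)
for _ in range(max(0, len(indep_shape) - 0)):  # scale has no indep dims yet
    scale = scale.unsqueeze(-1)                # -1 - len(()) == -1
assert scale.shape == (1, 1)
assert torch.broadcast_tensors(loc, scale)[0].shape == (4, 7)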

    # TODO get this working for all backends
7 changes: 7 additions & 0 deletions funsor/jax/distributions.py
@@ -215,6 +215,13 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
# Converting PyTorch Distributions to funsors
###############################################

dist.Independent.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.Independent.rsample = dist.Independent.sample
dist.MaskedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.MaskedDistribution.rsample = dist.MaskedDistribution.sample
dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample
A maintainer (Member) commented on lines +219 to +224:
Can you add a TODO pointing to a NumPyro issue to fix this bug, so we can delete this workaround once the bug is fixed? cc @fehiepsi

Another maintainer (Member) replied:
@neerajprad Should we add those new attributes to NumPyro distributions? We can make default behaviors for them so that there would be only a few changes in the code.

A third maintainer (Member) replied:
Sounds good. In numpyro, for distributions that have reparametrized samplers available both will be the same so we can just add a default Distribution.rsample method which delegates to sample and throws a NotImplemented error when not available.
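[Editor's note] A hypothetical sketch of that default (names and signature are assumptions, not NumPyro's actual API); with something like this upstream, the three monkeypatches above could be deleted.

class Distribution:
    # hypothetical base class attribute; subclasses with reparametrized
    # samplers would override this to True
    has_rsample = False

    def rsample(self, key, sample_shape=()):
        if not self.has_rsample:
            raise NotImplementedError("no reparametrized sampler available")
        # when a sampler is reparametrized, sample is already differentiable,
        # so rsample can simply delegate to it
        return self.sample(key, sample_shape)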


to_funsor.register(dist.Independent)(indepdist_to_funsor)
if hasattr(dist, "MaskedDistribution"):
    to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
2 changes: 1 addition & 1 deletion funsor/torch/distributions.py
@@ -137,7 +137,7 @@ def _infer_value_domain(**kwargs):
    @functools.lru_cache(maxsize=5000)
    def _infer_value_domain(cls, **kwargs):
        instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()}, validate_args=False)
        return Reals[instance.event_shape]
        return Reals[instance.batch_shape + instance.event_shape]


# TODO fix Delta.arg_constraints["v"] to be a
65 changes: 64 additions & 1 deletion test/test_distribution.py
@@ -246,7 +246,7 @@ def dirichlet(concentration: Reals[event_shape],
    check_funsor(expected, inputs, Real)
    actual = dist.Dirichlet(concentration, value)
    check_funsor(actual, inputs, Real)
    assert_close(actual, expected)
    assert_close(actual, expected, atol=1e-4)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@@ -1123,3 +1123,66 @@ def test_gamma_poisson_conjugate(batch_shape):

    obs = Tensor(ops.astype(ops.astype(ops.exp(randn(batch_shape)), 'int32'), 'float32'), inputs)
    _assert_conjugate_density_ok(latent, conditional, obs)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
@pytest.mark.parametrize('use_raw_scale', [False, True])
def test_normal_event_dim_conversion(batch_shape, event_shape, use_raw_scale):

    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

    value = Variable("value", Reals[event_shape])
    loc = Tensor(randn(batch_shape + event_shape), inputs)
    scale = Tensor(ops.exp(randn(batch_shape)), inputs)
    if use_raw_scale:
        if batch_shape:
            pytest.xfail(reason="raw scale is underspecified for nonempty batch_shape")
        scale = scale.data

    with interpretation(lazy):
        actual = dist.Normal(loc=loc, scale=scale, value=value)

    expected_inputs = inputs.copy()
    expected_inputs.update({"value": Reals[event_shape]})
    check_funsor(actual, expected_inputs, Real)

    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]

    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
        funsor.to_data(data, name_to_dim=name_to_dim))
    assert actual_log_prob.shape == expected_log_prob.shape
    assert_close(actual_log_prob, expected_log_prob)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
def test_mvnormal_event_dim_conversion(batch_shape, event_shape):

    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

    value = Variable("value", Reals[event_shape])
    loc = Tensor(randn(batch_shape + event_shape), inputs)
    scale_tril = Tensor(random_scale_tril(batch_shape + event_shape + event_shape[-1:]), inputs)

    with interpretation(lazy):
        actual = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril, value=value)

    expected_inputs = inputs.copy()
    expected_inputs.update({"value": Reals[event_shape]})
    check_funsor(actual, expected_inputs, Real)

    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]

    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
        funsor.to_data(data, name_to_dim=name_to_dim))
    assert actual_log_prob.shape == expected_log_prob.shape
    assert_close(actual_log_prob, expected_log_prob)