From 0e95347c3bde6247a272182198e0b70fb3969d61 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Dec 2020 14:32:06 -0500 Subject: [PATCH 01/17] Update distribution shape inference to handle independent dims --- funsor/distribution.py | 44 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 4537e282f..52255f23b 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -16,7 +16,7 @@ import funsor.ops as ops from funsor.affine import is_affine from funsor.cnf import Contraction, GaussianMixture -from funsor.domains import Array, Real, Reals +from funsor.domains import Array, Real, Reals, RealsType from funsor.gaussian import Gaussian from funsor.interpreter import gensym from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, @@ -57,12 +57,34 @@ class DistributionMeta(FunsorMeta): """ def __call__(cls, *args, **kwargs): kwargs.update(zip(cls._ast_fields, args)) - value = kwargs.pop('value', 'value') - kwargs = OrderedDict( - (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ())))) - for k in cls._ast_fields if k != 'value') - value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()})) - args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,))) + kwargs["value"] = kwargs.get("value", "value") + kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields) # make sure args are sorted + + domains = OrderedDict() + for k, v in kwargs.items(): + if k == "value": + continue + + # compute unbroadcasted param domains + domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ())) + # use to_funsor to infer output dimensions of e.g. tensors + domains[k] = domain if domain is not None else to_funsor(v).output + + # broadcast individual param domains with Funsor inputs + # this avoids .expand-ing underlying parameter tensors + if isinstance(v, Funsor) and isinstance(v.output, RealsType): + domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)] + + # now use the broadcasted parameter shapes to infer the event_shape + domains["value"] = cls._infer_value_domain(**domains) + if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType): + # try to broadcast the event shape with the value, in case they disagree + domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)] + + # finally, perform conversions to funsors + kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items()) + args = numbers_to_tensors(*kwargs.values()) + return super(DistributionMeta, cls).__call__(*args) @@ -191,7 +213,13 @@ def _infer_value_domain(cls, **kwargs): # rely on the underlying distribution's logic to infer the event_shape given param domains instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()}, validate_args=False) - out_shape = instance.event_shape + + # Note inclusion of batch_shape here to handle independent event dimensions. + # The arguments to _infer_value_domain are the .output shapes of parameters, + # so any extra batch dimensions that aren't part of the instance event_shape + # must be broadcasted output dimensions by construction. + out_shape = instance.batch_shape + instance.event_shape + if type(instance.support).__name__ == "_IntegerInterval": out_dtype = int(instance.support.upper_bound + 1) else: From 5f8357380db0c9ea0c033db1936dc3e493be238f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Dec 2020 23:28:36 -0500 Subject: [PATCH 02/17] update eager_log_prob and add a test --- funsor/distribution.py | 35 ++++++++++++++++++++----------- test/test_distribution.py | 28 +++++++++++++++++++++++++ test/test_distribution_generic.py | 4 +++- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 52255f23b..bdb8cb0e2 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -22,7 +22,7 @@ from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, ignore_jit_warnings, numeric_array, stack) from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \ - eager, to_data, to_funsor + eager, reflect, to_data, to_funsor from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property @@ -120,14 +120,6 @@ def eager_reduce(self, op, reduced_vars): return Number(0.) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) - @classmethod - def eager_log_prob(cls, *params): - inputs, tensors = align_tensors(*params) - params = dict(zip(cls._ast_fields, tensors)) - value = params.pop('value') - data = cls.dist_class(**params).log_prob(value) - return Tensor(data, inputs) - def _get_raw_dist(self): """ Internal method for working with underlying distribution attributes @@ -151,6 +143,19 @@ def has_rsample(self): def has_enumerate_support(self): return getattr(self.dist_class, "has_enumerate_support", False) + @classmethod + def eager_log_prob(cls, *params): + params, value = params[:-1], params[-1] + params = params + (Variable("value", value.output),) + raw_dist, value_name, value_output, dim_to_name = reflect(cls, *params)._get_raw_dist() + assert value.output == value_output + name_to_dim = {v: k for k, v in dim_to_name.items()} + dim_to_name.update({-1 - d -len(raw_dist.batch_shape): name + for d, name in enumerate(value.inputs) if name not in name_to_dim}) + name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim}) + raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim)) + return to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): # note this should handle transforms correctly via distribution_to_data @@ -428,11 +433,17 @@ def __call__(self, cls, args, kwargs): @to_data.register(Distribution) def distribution_to_data(funsor_dist, name_to_dim=None): - params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim) - for param_name in funsor_dist._ast_fields if param_name != 'value'] - pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) funsor_event_shape = funsor_dist.value.output.shape + params = [] + for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): + import pdb; pdb.set_trace() + param = to_data(funsor_param, name_to_dim=name_to_dim) + for i in range(max(0, len(funsor_event_shape) - len(funsor_param.output.shape))): + param = param.unsqueeze(-1 - len(funsor_param.output.shape)) + params.append(param) + pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)) + import pdb; pdb.set_trace() # TODO get this working for all backends if not isinstance(funsor_dist.value, Variable): diff --git a/test/test_distribution.py b/test/test_distribution.py index 2f621779d..e66e420b7 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1123,3 +1123,31 @@ def test_gamma_poisson_conjugate(batch_shape): obs = Tensor(ops.astype(ops.astype(ops.exp(randn(batch_shape)), 'int32'), 'float32'), inputs) _assert_conjugate_density_ok(latent, conditional, obs) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str) +def test_normal_event_dim_conversion(batch_shape, event_shape): + + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape)) + + value = Variable("value", Reals[event_shape]) + loc = Tensor(randn(batch_shape + event_shape), inputs) + scale = Tensor(ops.exp(randn(batch_shape)), inputs) + + with interpretation(lazy): + actual = dist.Normal(loc=loc, scale=scale, value=value) + + expected_inputs = inputs.copy() + expected_inputs.update({"value": Reals[event_shape]}) + check_funsor(actual, expected_inputs, Real) + + name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)} + data = actual.sample(frozenset(["value"])).terms[0][1][0] + + actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) + expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( + funsor.to_data(data, name_to_dim=name_to_dim)) + assert actual_log_prob.shape == expected_log_prob.shape + assert_close(actual_log_prob, expected_log_prob) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 50a00e237..d72f67653 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -459,6 +459,7 @@ def test_generic_distribution_to_funsor(case): funsor_dist = [term for term in funsor_dist.terms if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0] + import pdb; pdb.set_trace() actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) assert isinstance(actual_dist, backend_dist.Distribution) @@ -496,7 +497,8 @@ def test_generic_log_prob(case, use_lazy): raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name) - assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-4 if use_lazy else 1e-3) + actual_logprob = funsor_dist(value=funsor_value).align(tuple(expected_logprob.inputs)) + assert_close(actual_logprob, expected_logprob, rtol=1e-4 if use_lazy else 1e-3) @pytest.mark.parametrize("case", TEST_CASES, ids=str) From 13d94f4fdd5137bea334ad0202b4957d4497f530 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 11:22:52 -0500 Subject: [PATCH 03/17] remove pdb, add mvnormal test --- funsor/distribution.py | 2 -- test/test_distribution.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index bdb8cb0e2..5177653c1 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -436,14 +436,12 @@ def distribution_to_data(funsor_dist, name_to_dim=None): funsor_event_shape = funsor_dist.value.output.shape params = [] for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): - import pdb; pdb.set_trace() param = to_data(funsor_param, name_to_dim=name_to_dim) for i in range(max(0, len(funsor_event_shape) - len(funsor_param.output.shape))): param = param.unsqueeze(-1 - len(funsor_param.output.shape)) params.append(param) pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)) - import pdb; pdb.set_trace() # TODO get this working for all backends if not isinstance(funsor_dist.value, Variable): diff --git a/test/test_distribution.py b/test/test_distribution.py index e66e420b7..244be0406 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1151,3 +1151,31 @@ def test_normal_event_dim_conversion(batch_shape, event_shape): funsor.to_data(data, name_to_dim=name_to_dim)) assert actual_log_prob.shape == expected_log_prob.shape assert_close(actual_log_prob, expected_log_prob) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str) +def test_mvnormal_event_dim_conversion(batch_shape, event_shape): + + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape)) + + value = Variable("value", Reals[event_shape]) + loc = Tensor(randn(batch_shape + event_shape), inputs) + scale_tril = Tensor(random_scale_tril(batch_shape + event_shape + event_shape[-1:]), inputs) + + with interpretation(lazy): + actual = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril, value=value) + + expected_inputs = inputs.copy() + expected_inputs.update({"value": Reals[event_shape]}) + check_funsor(actual, expected_inputs, Real) + + name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)} + data = actual.sample(frozenset(["value"])).terms[0][1][0] + + actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) + expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( + funsor.to_data(data, name_to_dim=name_to_dim)) + assert actual_log_prob.shape == expected_log_prob.shape + assert_close(actual_log_prob, expected_log_prob) From 7b8c2eefff5ec2cd8ebbd64eb3fc29a858658554 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 11:46:29 -0500 Subject: [PATCH 04/17] fix alignment in eager_log_prob --- funsor/distribution.py | 12 ++++++++---- test/test_distribution_generic.py | 3 +-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 5177653c1..40b180528 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -16,7 +16,7 @@ import funsor.ops as ops from funsor.affine import is_affine from funsor.cnf import Contraction, GaussianMixture -from funsor.domains import Array, Real, Reals, RealsType +from funsor.domains import Array, BintType, Real, Reals, RealsType from funsor.gaussian import Gaussian from funsor.interpreter import gensym from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, @@ -147,14 +147,18 @@ def has_enumerate_support(self): def eager_log_prob(cls, *params): params, value = params[:-1], params[-1] params = params + (Variable("value", value.output),) - raw_dist, value_name, value_output, dim_to_name = reflect(cls, *params)._get_raw_dist() + instance = reflect(cls, *params) + raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist() assert value.output == value_output name_to_dim = {v: k for k, v in dim_to_name.items()} - dim_to_name.update({-1 - d -len(raw_dist.batch_shape): name + dim_to_name.update({-1 - d - len(raw_dist.batch_shape): name for d, name in enumerate(value.inputs) if name not in name_to_dim}) name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim}) raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim)) - return to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) + log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) + inputs = value.inputs.copy() + inputs.update(instance.inputs) + return log_prob.align(tuple(k for k, v in inputs.items() if k in log_prob.inputs and isinstance(v, BintType))) def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index d72f67653..291843a90 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -459,7 +459,6 @@ def test_generic_distribution_to_funsor(case): funsor_dist = [term for term in funsor_dist.terms if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0] - import pdb; pdb.set_trace() actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) assert isinstance(actual_dist, backend_dist.Distribution) @@ -497,7 +496,7 @@ def test_generic_log_prob(case, use_lazy): raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name) - actual_logprob = funsor_dist(value=funsor_value).align(tuple(expected_logprob.inputs)) + actual_logprob = funsor_dist(value=funsor_value) assert_close(actual_logprob, expected_logprob, rtol=1e-4 if use_lazy else 1e-3) From 88cb6036a1c9e9ec68e2db7dbc747ba10d3d0045 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 16:28:45 -0500 Subject: [PATCH 05/17] attempting a fix, and ops.unsqueeze for jax --- funsor/distribution.py | 20 +++++++++++++++++++- funsor/torch/distributions.py | 2 +- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 40b180528..73097a92d 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -438,12 +438,30 @@ def __call__(self, cls, args, kwargs): @to_data.register(Distribution) def distribution_to_data(funsor_dist, name_to_dim=None): funsor_event_shape = funsor_dist.value.output.shape + + # # attempt to generically infer the independent output dimensions + # # TODO we still need to combine this somehow with param shape information... + # instance = funsor_dist.dist_class(**{k: dummy_numeric_array(domain) + # for k, domain in zip(funsor_dist._ast_fields, + # (v.output for v in funsor_dist._ast_values)) + # if k != "value"}, + # validate_args=False) + # event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) + # # XXX is this final broadcast_shape necessary? should just be event_shape[...]? + # indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)]) + params = [] for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): param = to_data(funsor_param, name_to_dim=name_to_dim) + + # FIXME: this loop is invalid for distributions like DirichletMultinomial + # which have some parameters (e.g. total_count) whose output shapes + # are smaller than the distribution's event_shape. for i in range(max(0, len(funsor_event_shape) - len(funsor_param.output.shape))): - param = param.unsqueeze(-1 - len(funsor_param.output.shape)) + param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape)) + params.append(param) + pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params))) pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)) diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index b93431ad1..1bcd0185a 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -137,7 +137,7 @@ def _infer_value_domain(**kwargs): @functools.lru_cache(maxsize=5000) def _infer_value_domain(cls, **kwargs): instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()}, validate_args=False) - return Reals[instance.event_shape] + return Reals[instance.batch_shape + instance.event_shape] # TODO fix Delta.arg_constraints["v"] to be a From 2de25bec01575a55ded399fc5ccaa444cea018c0 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 17:56:15 -0500 Subject: [PATCH 06/17] fix dirichletmultinomial --- funsor/distribution.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 73097a92d..54a807516 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -439,25 +439,23 @@ def __call__(self, cls, args, kwargs): def distribution_to_data(funsor_dist, name_to_dim=None): funsor_event_shape = funsor_dist.value.output.shape - # # attempt to generically infer the independent output dimensions - # # TODO we still need to combine this somehow with param shape information... - # instance = funsor_dist.dist_class(**{k: dummy_numeric_array(domain) - # for k, domain in zip(funsor_dist._ast_fields, - # (v.output for v in funsor_dist._ast_values)) - # if k != "value"}, - # validate_args=False) - # event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) - # # XXX is this final broadcast_shape necessary? should just be event_shape[...]? - # indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)]) + # attempt to generically infer the independent output dimensions + instance = funsor_dist.dist_class(**{ + k: dummy_numeric_array(domain) + for k, domain in zip(funsor_dist._ast_fields, (v.output for v in funsor_dist._ast_values)) + if k != "value" + }, validate_args=False) + event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) + # XXX is this final broadcast_shape necessary? should just be event_shape[...]? + indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)]) params = [] for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): param = to_data(funsor_param, name_to_dim=name_to_dim) - # FIXME: this loop is invalid for distributions like DirichletMultinomial - # which have some parameters (e.g. total_count) whose output shapes - # are smaller than the distribution's event_shape. - for i in range(max(0, len(funsor_event_shape) - len(funsor_param.output.shape))): + param_event_shape = getattr(funsor_dist._infer_param_domain(param_name, funsor_param.output.shape), "shape", ()) + param_indep_shape = funsor_param.output.shape[:len(funsor_param.output.shape) - len(param_event_shape)] + for i in range(max(0, len(indep_shape) - len(param_indep_shape))): param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape)) params.append(param) From 9a0ccf3c26c4fe0b9a061d2075372c740eee8d98 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 18:01:10 -0500 Subject: [PATCH 07/17] comments --- funsor/distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/funsor/distribution.py b/funsor/distribution.py index 54a807516..ca38662c9 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -453,9 +453,11 @@ def distribution_to_data(funsor_dist, name_to_dim=None): for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): param = to_data(funsor_param, name_to_dim=name_to_dim) + # infer the independent dimensions of each parameter separately, since we chose to keep them unbroadcasted param_event_shape = getattr(funsor_dist._infer_param_domain(param_name, funsor_param.output.shape), "shape", ()) param_indep_shape = funsor_param.output.shape[:len(funsor_param.output.shape) - len(param_event_shape)] for i in range(max(0, len(indep_shape) - len(param_indep_shape))): + # add singleton event dimensions, leave broadcasting/expanding to backend param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape)) params.append(param) From 010221044492db8ea006e9611897a1ae0802a7ab Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Dec 2020 18:11:37 -0500 Subject: [PATCH 08/17] fix tests for jax --- funsor/distribution.py | 3 ++- test/test_distribution.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index ca38662c9..5374a38d6 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -173,7 +173,8 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): sample_shape = tuple(v.size for v in sample_inputs.values()) sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) if self.has_rsample: - raw_value = raw_dist.rsample(*sample_args) + # TODO fix this hack by adding rsample and has_rsample to Independent upstream in NumPyro + raw_value = getattr(raw_dist, "rsample", raw_dist.sample)(*sample_args) else: raw_value = ops.detach(raw_dist.sample(*sample_args)) diff --git a/test/test_distribution.py b/test/test_distribution.py index 244be0406..5f939748e 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1144,7 +1144,8 @@ def test_normal_event_dim_conversion(batch_shape, event_shape): check_funsor(actual, expected_inputs, Real) name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)} - data = actual.sample(frozenset(["value"])).terms[0][1][0] + rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) + data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0] actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( @@ -1172,7 +1173,8 @@ def test_mvnormal_event_dim_conversion(batch_shape, event_shape): check_funsor(actual, expected_inputs, Real) name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)} - data = actual.sample(frozenset(["value"])).terms[0][1][0] + rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) + data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0] actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( From 9f85c50c7605c8861295495ec2a5c8d9bbf0721e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Dec 2020 12:47:46 -0500 Subject: [PATCH 09/17] add extra broadcasting condition --- funsor/distribution.py | 2 ++ test/test_distribution.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 5374a38d6..80d6f7759 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -74,6 +74,8 @@ def __call__(cls, *args, **kwargs): # this avoids .expand-ing underlying parameter tensors if isinstance(v, Funsor) and isinstance(v.output, RealsType): domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)] + elif ops.is_numeric_array(v): + domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)] # now use the broadcasted parameter shapes to infer the event_shape domains["value"] = cls._infer_value_domain(**domains) diff --git a/test/test_distribution.py b/test/test_distribution.py index 5f939748e..8f730dc2c 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1127,7 +1127,8 @@ def test_gamma_poisson_conjugate(batch_shape): @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) @pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str) -def test_normal_event_dim_conversion(batch_shape, event_shape): +@pytest.mark.parametrize('use_raw_scale', [False, True]) +def test_normal_event_dim_conversion(batch_shape, event_shape, use_raw_scale): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape)) @@ -1135,6 +1136,10 @@ def test_normal_event_dim_conversion(batch_shape, event_shape): value = Variable("value", Reals[event_shape]) loc = Tensor(randn(batch_shape + event_shape), inputs) scale = Tensor(ops.exp(randn(batch_shape)), inputs) + if use_raw_scale: + if batch_shape: + pytest.xfail(reason="raw scale is underspecified for nonempty batch_shape") + scale = scale.data with interpretation(lazy): actual = dist.Normal(loc=loc, scale=scale, value=value) From 9f5f391792ae41e35b2639d3d7091fe4843f3be4 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Dec 2020 12:57:12 -0500 Subject: [PATCH 10/17] patch jax compound dists --- funsor/distribution.py | 3 +-- funsor/jax/distributions.py | 7 +++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 80d6f7759..b91f0523b 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -175,8 +175,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): sample_shape = tuple(v.size for v in sample_inputs.values()) sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) if self.has_rsample: - # TODO fix this hack by adding rsample and has_rsample to Independent upstream in NumPyro - raw_value = getattr(raw_dist, "rsample", raw_dist.sample)(*sample_args) + raw_value = raw_dist.rsample(*sample_args) else: raw_value = ops.detach(raw_dist.sample(*sample_args)) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 23a06be69..9d625301a 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -215,6 +215,13 @@ def deltadist_to_data(funsor_dist, name_to_dim=None): # Converting PyTorch Distributions to funsors ############################################### +dist.Independent.has_rsample = property(lambda self: self.base_dist.has_rsample) +dist.Independent.rsample = dist.Independent.sample +dist.MaskedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample) +dist.MaskedDistribution.rsample = dist.MaskedDistribution.sample +dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample) +dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample + to_funsor.register(dist.Independent)(indepdist_to_funsor) if hasattr(dist, "MaskedDistribution"): to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor) From a3a4c4b132ceade0fa698bd93d91790394424235 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Dec 2020 14:45:50 -0500 Subject: [PATCH 11/17] align to align_tensors --- funsor/distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index b91f0523b..a3af11bbd 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -158,9 +158,9 @@ def eager_log_prob(cls, *params): name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim}) raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim)) log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) - inputs = value.inputs.copy() - inputs.update(instance.inputs) - return log_prob.align(tuple(k for k, v in inputs.items() if k in log_prob.inputs and isinstance(v, BintType))) + # final align() ensures that the inputs have the canonical order + # implied by align_tensors, which is assumed pervasively in tests + return log_prob.align(tuple(align_tensors(*(params[:-1] + (value,)))[0])) def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): From 260d2a23f9103110cf7fd6762cd927c474b25577 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Dec 2020 14:57:54 -0500 Subject: [PATCH 12/17] lint --- funsor/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index a3af11bbd..fc61f5db8 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -16,7 +16,7 @@ import funsor.ops as ops from funsor.affine import is_affine from funsor.cnf import Contraction, GaussianMixture -from funsor.domains import Array, BintType, Real, Reals, RealsType +from funsor.domains import Array, Real, Reals, RealsType from funsor.gaussian import Gaussian from funsor.interpreter import gensym from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, From 204b357b8a66496c93ee5189d9898d973da668f8 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Dec 2020 13:36:23 -0500 Subject: [PATCH 13/17] tweak tolerance --- test/test_distribution.py | 2 +- test/test_distribution_generic.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index 8f730dc2c..a21b463bd 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -246,7 +246,7 @@ def dirichlet(concentration: Reals[event_shape], check_funsor(expected, inputs, Real) actual = dist.Dirichlet(concentration, value) check_funsor(actual, inputs, Real) - assert_close(actual, expected) + assert_close(actual, expected, atol=1e-4) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 291843a90..50a00e237 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -496,8 +496,7 @@ def test_generic_log_prob(case, use_lazy): raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name) - actual_logprob = funsor_dist(value=funsor_value) - assert_close(actual_logprob, expected_logprob, rtol=1e-4 if use_lazy else 1e-3) + assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-4 if use_lazy else 1e-3) @pytest.mark.parametrize("case", TEST_CASES, ids=str) From aa56a2de4562a01912184627c088b57e9e6e5e5d Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Dec 2020 09:25:57 -0500 Subject: [PATCH 14/17] tolerance --- test/test_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index a21b463bd..ceb13a913 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -246,7 +246,7 @@ def dirichlet(concentration: Reals[event_shape], check_funsor(expected, inputs, Real) actual = dist.Dirichlet(concentration, value) check_funsor(actual, inputs, Real) - assert_close(actual, expected, atol=1e-4) + assert_close(actual, expected, atol=1e-3) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) From 51d7bef06fa88cd0b89201960a168d577a826a0c Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Dec 2020 10:14:09 -0500 Subject: [PATCH 15/17] rtol --- test/test_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index ceb13a913..d72e83f4b 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -246,7 +246,7 @@ def dirichlet(concentration: Reals[event_shape], check_funsor(expected, inputs, Real) actual = dist.Dirichlet(concentration, value) check_funsor(actual, inputs, Real) - assert_close(actual, expected, atol=1e-3) + assert_close(actual, expected, atol=1e-3, rtol=1e-3) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) From e7f4441a4c76573da49c37b0e7893b8e90c18809 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 14 Dec 2020 20:41:30 -0500 Subject: [PATCH 16/17] address comments --- funsor/distribution.py | 17 ++++++++++------- funsor/jax/distributions.py | 5 +++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index fc61f5db8..fa22f7d32 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -158,9 +158,12 @@ def eager_log_prob(cls, *params): name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim}) raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim)) log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name) - # final align() ensures that the inputs have the canonical order + # this logic ensures that the inputs have the canonical order # implied by align_tensors, which is assumed pervasively in tests - return log_prob.align(tuple(align_tensors(*(params[:-1] + (value,)))[0])) + inputs = OrderedDict() + for x in params[:-1] + (value,): + inputs.update(x.inputs) + return log_prob.align(inputs) def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): @@ -443,13 +446,13 @@ def distribution_to_data(funsor_dist, name_to_dim=None): # attempt to generically infer the independent output dimensions instance = funsor_dist.dist_class(**{ - k: dummy_numeric_array(domain) - for k, domain in zip(funsor_dist._ast_fields, (v.output for v in funsor_dist._ast_values)) - if k != "value" + k: dummy_numeric_array(v.output) + for k, v in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]) }, validate_args=False) event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) - # XXX is this final broadcast_shape necessary? should just be event_shape[...]? - indep_shape = broadcast_shape(instance.batch_shape, event_shape[:len(event_shape) - len(instance.event_shape)]) + reinterpreted_batch_ndims = len(event_shape) - len(instance.event_shape) + assert reinterpreted_batch_ndims > 0 # XXX is this ever nonzero? + indep_shape = broadcast_shape(instance.batch_shape, event_shape[:reinterpreted_batch_ndims]) params = [] for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]): diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 9d625301a..ff1b099de 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -200,7 +200,7 @@ def _infer_param_domain(cls, name, raw_shape): ########################################################### -# Converting distribution funsors to PyTorch distributions +# Converting distribution funsors to NumPyro distributions ########################################################### # Convert Delta **distribution** to raw data @@ -212,9 +212,10 @@ def deltadist_to_data(funsor_dist, name_to_dim=None): ############################################### -# Converting PyTorch Distributions to funsors +# Converting NumPyro Distributions to funsors ############################################### +# TODO move these properties upstream to numpyro.distributions dist.Independent.has_rsample = property(lambda self: self.base_dist.has_rsample) dist.Independent.rsample = dist.Independent.sample dist.MaskedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample) From 925cdfb75948db34ae7a73fb3511fcfca4447ebd Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 15 Dec 2020 00:10:39 -0500 Subject: [PATCH 17/17] fix test --- funsor/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index fa22f7d32..73235c09d 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -163,7 +163,7 @@ def eager_log_prob(cls, *params): inputs = OrderedDict() for x in params[:-1] + (value,): inputs.update(x.inputs) - return log_prob.align(inputs) + return log_prob.align(tuple(inputs)) def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): @@ -451,7 +451,7 @@ def distribution_to_data(funsor_dist, name_to_dim=None): }, validate_args=False) event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape) reinterpreted_batch_ndims = len(event_shape) - len(instance.event_shape) - assert reinterpreted_batch_ndims > 0 # XXX is this ever nonzero? + assert reinterpreted_batch_ndims >= 0 # XXX is this ever nonzero? indep_shape = broadcast_shape(instance.batch_shape, event_shape[:reinterpreted_batch_ndims]) params = []