Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conversion of Transforms and TransformedDistributions to and from Funsors #365

Merged
merged 46 commits into from
Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
9b6a3bc
transformed distribution conversion and sampling
eb8680 Sep 9, 2020
6b3371a
most distribution tests passing?
eb8680 Sep 10, 2020
fcb70ea
nit
eb8680 Sep 10, 2020
78747a5
dont add lebesgue by default
eb8680 Sep 13, 2020
a469cd7
add lebesgue in conversion
eb8680 Sep 13, 2020
be80ef0
transform?
eb8680 Sep 13, 2020
2016ba1
add transform conversions
eb8680 Sep 17, 2020
7880e40
Merge branch 'master' into lebesgue-2
eb8680 Sep 17, 2020
2a4d548
Merge branch 'master' into lebesgue-2
eb8680 Sep 17, 2020
2f0b0ea
add some tests
eb8680 Sep 17, 2020
08949cf
add composition to transform conversion
eb8680 Sep 20, 2020
8d5a5ad
nit
eb8680 Sep 20, 2020
05cd1a2
remove sigmoid transform conversion patterns
eb8680 Sep 20, 2020
a3f8050
fix transform composition order
eb8680 Sep 20, 2020
fbc91b7
nit
eb8680 Sep 20, 2020
7400b41
Merge branch 'master' into lebesgue-2
eb8680 Sep 26, 2020
77aca44
Merge branch 'master' into lebesgue-2
eb8680 Oct 25, 2020
c14f3ac
fix merge
eb8680 Oct 25, 2020
67777fa
increase tolerance in flaky binomial test?
eb8680 Oct 25, 2020
cd13b30
fix errors in transformed sampling
eb8680 Oct 25, 2020
6765bd7
support sampling from transformed distribution in Contraction
eb8680 Oct 25, 2020
c89fa88
Add atanh and tanh ops for TanhTransform
eb8680 Oct 29, 2020
e10b962
fix error in lognormal test
eb8680 Oct 29, 2020
d6efe96
fix tests
eb8680 Oct 29, 2020
18af534
sigmoid
eb8680 Oct 29, 2020
1c42e8e
Merge branch 'tanh-op' into lebesgue-2
eb8680 Oct 29, 2020
22b2115
register tanh and sigmoid for conversion
eb8680 Oct 29, 2020
f78b0ab
Merge branch 'master' into lebesgue-2
eb8680 Oct 31, 2020
e18f5ce
Merge branch 'master' into lebesgue-2
eb8680 Nov 11, 2020
f24a599
fix jax tests
eb8680 Nov 11, 2020
974b21b
Merge branch 'master' into lebesgue-2
eb8680 Nov 11, 2020
2d93270
simplify unscaled_sample
eb8680 Nov 11, 2020
092141c
Re-xfail lognormal sampler test
eb8680 Nov 11, 2020
ddb6d86
Merge branch 'master' into lebesgue-2
eb8680 Nov 12, 2020
bcf52cc
organization nit
eb8680 Nov 12, 2020
0d46e5d
add some transformeddist test cases
eb8680 Nov 12, 2020
6e33c93
Merge branch 'master' into lebesgue-2
eb8680 Nov 12, 2020
0966eb0
fix conversion by avoiding double inversion
eb8680 Nov 15, 2020
6c2cc27
use normalize
eb8680 Nov 15, 2020
b59a0d8
add xfail option to DistTestCase and add an xfailing Independent(Tran…
eb8680 Nov 15, 2020
8d9d965
add jax xfails
eb8680 Nov 16, 2020
bdde5b4
move Lebesgue to a separate branch
eb8680 Nov 16, 2020
8aab02d
lint
eb8680 Nov 16, 2020
0df073f
address comments
eb8680 Nov 20, 2020
b8f553f
isdisjoint
eb8680 Nov 20, 2020
18c87f3
Merge branch 'master' into lebesgue-2
eb8680 Nov 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import opt_einsum
from multipledispatch.variadic import Variadic

import funsor
import funsor.ops as ops
from funsor.affine import affine_inputs
from funsor.delta import Delta
Expand Down Expand Up @@ -127,6 +128,14 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
terms.append(-gaussian.log_normalizer)
terms.append(term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]))
result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
elif any(isinstance(term, funsor.distribution.Distribution)
and greedy_vars.intersection(term.value.inputs) for term in greedy_terms):
eb8680 marked this conversation as resolved.
Show resolved Hide resolved
sampled_terms = [
term.unscaled_sample(greedy_vars.intersection(term.value.inputs), sample_inputs)
for term in greedy_terms if isinstance(term, funsor.distribution.Distribution)
and greedy_vars.intersection(term.value.inputs)
]
result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *(terms + sampled_terms))
else:
raise NotImplementedError('Unhandled case: {}'.format(
', '.join(str(type(t)) for t in greedy_terms)))
Expand Down
58 changes: 39 additions & 19 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from funsor.interpreter import gensym
from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
ignore_jit_warnings, numeric_array, stack)
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, eager, to_data, to_funsor
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \
eager, to_data, to_funsor
from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property


Expand Down Expand Up @@ -109,17 +110,16 @@ def _get_raw_dist(self):
"""
Internal method for working with underlying distribution attributes
"""
if isinstance(self.value, Variable):
value_name = self.value.name
else:
raise NotImplementedError("cannot get raw dist for {}".format(self))
value_name = [name for name, domain in self.value.inputs.items() # TODO is this right?
if domain == self.value.output][0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird. Can you explain what's going on? When is value_name != "value"? The assertion in self.__init__() suggests value_name == "value" IIUC.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value_name can potentially be anything, since self.value is a lazy expression. This logic is meant to solve the problem of identifying the value name when self.value is not a Variable or even has more than one input (which happens when constructing the Funsor version of a TransformedDistribution).

The simplest nontrivial example of the latter case would be an affine or power transform where the parameters are funsor.Tensors with nontrivial .inputs, although these are not handled in this PR since funsor.delta.solve does not yet support inverting such expressions.

The main use for value_name in this PR is in Distribution.unscaled_sample, which needs to know value_name to construct a sample Delta with the correct .inputs.

# arbitrary name-dim mapping, since we're converting back to a funsor anyway
name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items())
if isinstance(domain.dtype, int) and name != value_name}
raw_dist = to_data(self, name_to_dim=name_to_dim)
dim_to_name = {dim: name for name, dim in name_to_dim.items()}
# also return value output, dim_to_name for converting results back to funsor
return raw_dist, self.value.output, dim_to_name
value_output = self.inputs[value_name]
return raw_dist, value_name, value_output, dim_to_name

@property
def has_rsample(self):
Expand All @@ -130,16 +130,15 @@ def has_enumerate_support(self):
return getattr(self.dist_class, "has_enumerate_support", False)

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
params = OrderedDict(self.params)
value = params.pop("value")
assert all(isinstance(v, (Number, Tensor)) for v in params.values())
assert isinstance(value, Variable) and value.name in sampled_vars

value_name = value.name
raw_dist, value_output, dim_to_name = self._get_raw_dist()
# note this should handle transforms correctly via distribution_to_data
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
dim_to_name[-d - len(raw_dist.batch_shape)] = name

if value_name not in sampled_vars:
return self

sample_shape = tuple(v.size for v in sample_inputs.values())
sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
if self.has_rsample:
Expand All @@ -161,23 +160,23 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

def enumerate_support(self, expand=False):
assert self.has_enumerate_support and isinstance(self.value, Variable)
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.enumerate_support(expand=expand)
dim_to_name[min(dim_to_name.keys(), default=0)-1] = self.value.name
dim_to_name[min(dim_to_name.keys(), default=0)-1] = value_name
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

def entropy(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.entropy()
return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name)

def mean(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.mean
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

def variance(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.variance
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

Expand Down Expand Up @@ -234,7 +233,6 @@ def _infer_param_domain(cls, name, raw_shape):
# Distribution Wrappers
################################################################################


def make_dist(backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True):
if not param_names:
param_names = tuple(name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
Expand Down Expand Up @@ -312,8 +310,19 @@ def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
return mask * funsor_base_dist


# converts TransformedDistributions
def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
raise NotImplementedError("TODO implement conversion of TransformedDistribution")
dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
base_dist, transforms = backend_dist, []
while isinstance(base_dist, dist_module.TransformedDistribution):
transforms = base_dist.transforms + transforms
base_dist = base_dist.base_dist
funsor_base_dist = to_funsor(base_dist, output=output, dim_to_name=dim_to_name)
# TODO make this work with transforms that change the output type
transform = to_funsor(dist_module.transforms.ComposeTransform(transforms),
funsor_base_dist.inputs["value"], dim_to_name)
_, inv_transform, ldj = funsor.delta.solve(transform, to_funsor("value", funsor_base_dist.inputs["value"]))
return -ldj + funsor_base_dist(value=inv_transform)


class CoerceDistributionToFunsor:
Expand Down Expand Up @@ -396,6 +405,17 @@ def distribution_to_data(funsor_dist, name_to_dim=None):
pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
funsor_event_shape = funsor_dist.value.output.shape
pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))

# TODO get this working for all backends
if not isinstance(funsor_dist.value, Variable):
if get_backend() != "torch":
raise NotImplementedError("transformed distributions not yet supported under this backend,"
"try set_backend('torch')")
inv_value = funsor.delta.solve(funsor_dist.value, Variable("value", funsor_dist.value.output))[1]
transforms = to_data(inv_value, name_to_dim=name_to_dim)
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

if pyro_dist.event_shape != funsor_event_shape:
raise ValueError("Event shapes don't match, something went wrong")
return pyro_dist
Expand Down
86 changes: 85 additions & 1 deletion funsor/torch/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor
from funsor.terms import Binary, Funsor, Unary, Variable, eager, to_data, to_funsor
from funsor.util import methodof


Expand Down Expand Up @@ -221,10 +221,94 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
return dist.Delta(v, log_density, event_dim=len(funsor_dist.v.output.shape))


@functools.singledispatch
def op_to_torch_transform(op, name_to_dim=None):
    """Generic entry point for turning an op into a torch ``Transform``.

    This is the ``functools.singledispatch`` base case: it is reached only
    when no handler has been registered for ``type(op)`` (or one of its
    bases), and it always fails.

    :param op: the object to convert.
    :param name_to_dim: accepted for signature compatibility; unused here.
    :raises NotImplementedError: always.
    """
    message = "cannot convert {} to a Transform".format(op)
    raise NotImplementedError(message)


@op_to_torch_transform.register(ops.TransformOp)
def transform_to_torch_transform(op, name_to_dim=None):
    """Handler registered for ``ops.TransformOp``; always fails.

    NOTE(review): presumably this is only reached for ``TransformOp``
    instances whose concrete type has no more specific registration --
    confirm against the op class hierarchy.

    :param op: the transform op that could not be converted.
    :param name_to_dim: accepted for signature compatibility; unused here.
    :raises NotImplementedError: always.
    """
    message = "{} is not a currently supported transform".format(op)
    raise NotImplementedError(message)


@op_to_torch_transform.register(ops.ExpOp)
def exp_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.ExpOp`` to a freshly constructed torch ``ExpTransform``.

    Neither ``op`` nor ``name_to_dim`` is consulted; the result is the same
    kind of transform on every call.
    """
    transforms = torch.distributions.transforms
    return transforms.ExpTransform()


@op_to_torch_transform.register(ops.LogOp)
def log_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.LogOp`` to the inverse of a torch ``ExpTransform``.

    The result is obtained from the ``.inv`` attribute of a new
    ``ExpTransform``. Neither ``op`` nor ``name_to_dim`` is consulted.
    """
    exp_transform = torch.distributions.transforms.ExpTransform()
    return exp_transform.inv


@op_to_torch_transform.register(ops.SigmoidOp)
def sigmoid_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.SigmoidOp`` to a freshly constructed ``SigmoidTransform``.

    Neither ``op`` nor ``name_to_dim`` is consulted.
    """
    transforms = torch.distributions.transforms
    return transforms.SigmoidTransform()


@op_to_torch_transform.register(ops.TanhOp)
def tanh_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.TanhOp`` to a freshly constructed torch ``TanhTransform``.

    Neither ``op`` nor ``name_to_dim`` is consulted.
    """
    transforms = torch.distributions.transforms
    return transforms.TanhTransform()


@op_to_torch_transform.register(ops.AtanhOp)
def atanh_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.AtanhOp`` to the inverse of a torch ``TanhTransform``.

    The result is obtained from the ``.inv`` attribute of a new
    ``TanhTransform``. Neither ``op`` nor ``name_to_dim`` is consulted.
    """
    tanh_transform = torch.distributions.transforms.TanhTransform()
    return tanh_transform.inv


@to_data.register(Unary[ops.TransformOp, Union[Unary, Variable]])
def transform_to_data(expr, name_to_dim=None):
    """Convert a ``Unary`` transform expression into a torch ``Transform``.

    The outermost op of ``expr`` is converted with ``op_to_torch_transform``.
    When ``expr.arg`` is itself a ``Unary`` it is converted recursively via
    ``to_data`` and the two are chained in a ``ComposeTransform`` with the
    inner transform first.

    :param expr: expression whose ``op`` and ``arg`` attributes are read.
    :param name_to_dim: forwarded unchanged to the nested conversions.
    :raises NotImplementedError: if ``expr.op`` is not an ``ops.TransformOp``.
    """
    if not isinstance(expr.op, ops.TransformOp):
        raise NotImplementedError("cannot convert to data: {}".format(expr))

    outer = op_to_torch_transform(expr.op, name_to_dim=name_to_dim)
    if not isinstance(expr.arg, Unary):
        # Innermost argument is not a Unary: a single transform suffices.
        return outer

    inner = to_data(expr.arg, name_to_dim=name_to_dim)
    return torch.distributions.transforms.ComposeTransform([inner, outer])


###############################################
# Converting PyTorch Distributions to funsors
###############################################

@to_funsor.register(torch.distributions.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Handler registered for the base torch ``Transform``; always fails.

    NOTE(review): presumably reached only for transforms without a more
    specific ``to_funsor`` registration -- confirm against the dispatcher.

    :param tfm: the transform that could not be converted.
    :raises NotImplementedError: always.
    """
    message = "{} is not a currently supported transform".format(tfm)
    raise NotImplementedError(message)


@to_funsor.register(torch.distributions.transforms.ExpTransform)
def exptransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert an ``ExpTransform`` into ``ops.exp`` applied to a ``Variable``.

    :param tfm: the transform being converted (only its type matters here).
    :param output: domain passed to the ``Variable`` constructor.
    :param dim_to_name: accepted for signature compatibility; unused here.
    :param real_inputs: optional mapping; when non-empty its first key is
        used as the variable name, otherwise the name is ``"value"``.
    :return: ``ops.exp(Variable(name, output))``.
    """
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.exp(Variable(name, output))


@to_funsor.register(torch.distributions.transforms.TanhTransform)
def tanhtransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a ``TanhTransform`` into ``ops.tanh`` applied to a ``Variable``.

    :param tfm: the transform being converted (only its type matters here).
    :param output: domain passed to the ``Variable`` constructor.
    :param dim_to_name: accepted for signature compatibility; unused here.
    :param real_inputs: optional mapping; when non-empty its first key is
        used as the variable name, otherwise the name is ``"value"``.
    :return: ``ops.tanh(Variable(name, output))``.
    """
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.tanh(Variable(name, output))


@to_funsor.register(torch.distributions.transforms.SigmoidTransform)
def sigmoidtransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a ``SigmoidTransform`` into ``ops.sigmoid`` of a ``Variable``.

    :param tfm: the transform being converted (only its type matters here).
    :param output: domain passed to the ``Variable`` constructor.
    :param dim_to_name: accepted for signature compatibility; unused here.
    :param real_inputs: optional mapping; when non-empty its first key is
        used as the variable name, otherwise the name is ``"value"``.
    :return: ``ops.sigmoid(Variable(name, output))``.
    """
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.sigmoid(Variable(name, output))


@to_funsor.register(torch.distributions.transforms._InverseTransform)
def inversetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert an inverse torch transform by inverting the forward funsor op.

    The wrapped forward transform (``tfm._inv``) is converted with
    ``to_funsor``; the result must be a ``Unary`` term, whose op is replaced
    by ``op.inv`` and re-applied to the same argument.

    :param tfm: inverse transform whose ``_inv`` attribute is converted.
    :param output: forwarded to the nested ``to_funsor`` call.
    :param dim_to_name: forwarded to the nested ``to_funsor`` call.
    :param real_inputs: forwarded to the nested ``to_funsor`` call.
    """
    forward = to_funsor(
        tfm._inv,
        output=output,
        dim_to_name=dim_to_name,
        real_inputs=real_inputs,
    )
    assert isinstance(forward, Unary)
    inverse_op = forward.op.inv
    return inverse_op(forward.arg)


@to_funsor.register(torch.distributions.transforms.ComposeTransform)
def composetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a ``ComposeTransform`` into a nested funsor expression.

    Starting from ``Variable(name, output)``, each element of ``tfm.parts``
    is converted with ``to_funsor`` and then applied to the expression built
    so far by substituting it for the variable called ``name``.

    :param tfm: composed transform whose ``parts`` are iterated in order.
    :param output: domain of the starting variable; also forwarded to each
        nested ``to_funsor`` call.
    :param dim_to_name: forwarded to each nested ``to_funsor`` call.
    :param real_inputs: optional mapping; when non-empty its first key is
        used as the variable name, otherwise the name is ``"value"``.
    :return: the composed expression (the bare variable if ``parts`` is empty).
    """
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    expr = Variable(name, output)
    for part in tfm.parts:
        part_expr = to_funsor(part, output=output, dim_to_name=dim_to_name, real_inputs=real_inputs)
        # Feed the expression built so far in as the argument of this part.
        expr = part_expr(**{name: expr})
    return expr


to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)
Expand Down
Loading