Skip to content

Commit

Permalink
Add generic to_funsor conversion methods for funsor.distributions (#321)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Mar 26, 2020
1 parent 69ea45a commit be19281
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 152 deletions.
131 changes: 127 additions & 4 deletions funsor/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,25 @@
import functools
import inspect
import math
import typing
from collections import OrderedDict

import makefun
import pyro.distributions as dist
from pyro.distributions.torch_distribution import MaskedDistribution
import torch
import torch.distributions.constraints as constraints

import funsor.delta
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.cnf import GaussianMixture
from funsor.domains import Domain, reals
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.ops import cholesky
from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack
from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, to_funsor
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, eager, to_data, to_funsor
from funsor.util import broadcast_shape


Expand Down Expand Up @@ -52,7 +56,9 @@ class DistributionMeta(FunsorMeta):
def __call__(cls, *args, **kwargs):
    """Build a distribution funsor from positional and keyword parameters.

    Positional arguments are paired with ``cls._ast_fields`` in order and
    merged into ``kwargs``. Every field other than ``'value'`` is converted
    with ``to_funsor`` against the domain returned by
    ``cls._infer_param_domain``. ``value`` defaults to the string ``'value'``
    and is converted against the domain returned by
    ``cls._infer_value_domain``.
    """
    kwargs.update(zip(cls._ast_fields, args))
    value = kwargs.pop('value', 'value')
    # Convert each parameter exactly once, passing its inferred output domain.
    # A preliminary conversion without ``output`` would make the ``shape``
    # lookup below read the already-converted object instead of the raw input.
    kwargs = OrderedDict(
        (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
        for k in cls._ast_fields if k != 'value')
    # The value domain is derived from the parameters' output domains.
    value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
    args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
    return super(DistributionMeta, cls).__call__(*args)
Expand Down Expand Up @@ -104,7 +110,7 @@ def __getattribute__(self, attr):
return super().__getattribute__(attr)

@classmethod
@functools.lru_cache(maxsize=None)
@functools.lru_cache(maxsize=5000)
def _infer_value_domain(cls, **kwargs):
# rely on the underlying distribution's logic to infer the event_shape given param domains
instance = cls.dist_class(**{k: _dummy_tensor(domain) for k, domain in kwargs.items()}, validate_args=False)
Expand All @@ -115,6 +121,23 @@ def _infer_value_domain(cls, **kwargs):
out_dtype = 'real'
return Domain(dtype=out_dtype, shape=out_shape)

@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
    """Infer the output domain of parameter ``name`` from its constraint.

    The constraint is looked up in ``cls.dist_class.arg_constraints``.

    :param str name: parameter name.
    :param tuple raw_shape: shape of the raw parameter; only its trailing
        one or two sizes are read. Must be hashable because results are
        memoized with ``lru_cache``.
    :return: ``reals(raw_shape[-1])`` for simplex and real-vector
        constraints, and for real-valued ``logits`` whose sibling ``probs``
        constraint is a simplex; ``reals(*raw_shape[-2:])`` for
        lower-Cholesky and positive-definite constraints; otherwise ``None``.
    """
    arg_constraints = cls.dist_class.arg_constraints
    support = arg_constraints.get(name, None)
    if isinstance(support, (constraints._Simplex, constraints._RealVector)):
        return reals(raw_shape[-1])
    if isinstance(support, (constraints._LowerCholesky, constraints._PositiveDefinite)):
        return reals(*raw_shape[-2:])
    # ``logits`` is declared as plain real; it is treated as vector-valued
    # only when the matching ``probs`` is a simplex. Use ``.get`` so that a
    # distribution declaring ``logits`` without ``probs`` yields ``None``
    # instead of raising KeyError.
    if isinstance(support, constraints._Real) and name == "logits" and \
            isinstance(arg_constraints.get("probs"), constraints._Simplex):
        return reals(raw_shape[-1])
    return None


################################################################################
# Distribution Wrappers
Expand Down Expand Up @@ -183,7 +206,7 @@ def __init__(self, logits, validate_args=None):

# Multinomial and related dists have dependent bint dtypes, so we just make them 'real'
# See issue: https://github.com/pyro-ppl/funsor/issues/322
# A single bounded cache: stacking an unbounded ``lru_cache`` on top of this
# one would keep every entry alive and defeat the 5000-entry limit.
@functools.lru_cache(maxsize=5000)
def _multinomial_infer_value_domain(cls, **kwargs):
    """Return a real-valued domain shaped like the distribution's event shape.

    ``kwargs`` maps parameter names to domains. Each domain is replaced by
    ``_dummy_tensor(domain)`` to instantiate ``cls.dist_class`` with
    ``validate_args=False``, and the instance's ``event_shape`` is wrapped
    with ``reals``.
    """
    dummy_params = {k: _dummy_tensor(domain) for k, domain in kwargs.items()}
    instance = cls.dist_class(**dummy_params, validate_args=False)
    return reals(*instance.event_shape)
Expand All @@ -194,6 +217,106 @@ def _multinomial_infer_value_domain(cls, **kwargs):
DirichletMultinomial._infer_value_domain = classmethod(_multinomial_infer_value_domain)


###############################################
# Converting PyTorch Distributions to funsors
###############################################

@to_funsor.register(torch.distributions.Distribution)
def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a torch distribution to the funsor distribution of the same name.

    The funsor class is looked up on ``funsor.distributions`` by the type
    name of ``pyro_dist``, after dropping any ``_PyroWrapper_`` prefix. Each
    field of that class except ``'value'`` is read off ``pyro_dist`` and
    converted with ``to_funsor``. ``output`` is accepted but not used here.
    """
    import funsor.distributions  # TODO find a better way to do this lookup
    class_name = type(pyro_dist).__name__.split("_PyroWrapper_")[-1]
    funsor_dist_class = getattr(funsor.distributions, class_name)

    params = []
    for param_name in funsor_dist_class._ast_fields:
        if param_name == 'value':
            continue
        raw_param = getattr(pyro_dist, param_name)
        # Parameters without a ``shape`` attribute are treated as shape ().
        param_domain = funsor_dist_class._infer_param_domain(
            param_name, getattr(raw_param, "shape", ()))
        params.append(to_funsor(raw_param, output=param_domain, dim_to_name=dim_to_name))
    return funsor_dist_class(*params)


@to_funsor.register(torch.distributions.Independent)
def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a ``torch.distributions.Independent`` to nested funsor ``Independent`` terms.

    Every entry of ``dim_to_name`` is shifted left by
    ``pyro_dist.reinterpreted_batch_ndims``, and the freed rightmost dims are
    named ``_pyro_event_dim_{i}``. The base distribution is converted with
    that mapping, then wrapped in ``funsor.terms.Independent`` once per
    reinterpreted dim, from ``-1`` down to ``-reinterpreted_batch_ndims``.

    ``dim_to_name=None`` (the default) is treated as an empty mapping.
    ``output`` is accepted but not used here.
    """
    ndims = pyro_dist.reinterpreted_batch_ndims
    # The default is None, which has no ``.items()``; fall back to an empty
    # mapping so the default call does not raise AttributeError.
    if dim_to_name is None:
        dim_to_name = OrderedDict()
    shifted = OrderedDict((dim - ndims, name) for dim, name in dim_to_name.items())
    shifted.update((i, f"_pyro_event_dim_{i}") for i in range(-ndims, 0))
    result = to_funsor(pyro_dist.base_dist, dim_to_name=shifted)
    for i in reversed(range(-ndims, 0)):
        result = funsor.terms.Independent(result, "value", f"_pyro_event_dim_{i}", "value")
    return result


@to_funsor.register(MaskedDistribution)
def maskeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a ``MaskedDistribution`` to ``mask * base_dist`` as funsors.

    The mask (``pyro_dist._mask`` cast to float) and the base distribution
    are both converted with the same ``output`` and ``dim_to_name``.
    """
    conversion_kwargs = dict(output=output, dim_to_name=dim_to_name)
    float_mask = pyro_dist._mask.float()
    mask = to_funsor(float_mask, **conversion_kwargs)
    funsor_base_dist = to_funsor(pyro_dist.base_dist, **conversion_kwargs)
    return mask * funsor_base_dist


@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a Bernoulli via its logits parameterization.

    A ``_PyroWrapper_BernoulliLogits`` is rebuilt from ``pyro_dist.logits``
    and handed to the generic ``torchdistribution_to_funsor`` conversion.
    """
    logits_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
    return torchdistribution_to_funsor(logits_dist, output=output, dim_to_name=dim_to_name)


@to_funsor.register(torch.distributions.TransformedDistribution)
def transformeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Placeholder: converting a ``TransformedDistribution`` is not supported yet.

    :raises NotImplementedError: always.
    """
    message = "TODO implement conversion of TransformedDistribution"
    raise NotImplementedError(message)


@to_funsor.register(torch.distributions.MultivariateNormal)
def torchmvn_to_funsor(pyro_dist, output=None, dim_to_name=None, real_inputs=OrderedDict()):
    """Convert a ``MultivariateNormal``, optionally renaming its real inputs.

    With empty ``real_inputs`` the generic distribution funsor is returned
    unchanged. Otherwise the funsor is applied to ``value="value"`` and
    unpacked into its two ``terms``, and the second term is rebuilt as a
    ``Gaussian`` whose inputs are its non-real inputs followed by
    ``real_inputs``.

    The shared default ``real_inputs`` is only read here, never mutated.
    """
    funsor_dist = torchdistribution_to_funsor(pyro_dist, output=output, dim_to_name=dim_to_name)
    if len(real_inputs) == 0:
        return funsor_dist

    discrete, gaussian = funsor_dist(value="value").terms
    # Keep only the non-real inputs, then append the caller's real inputs.
    inputs = OrderedDict()
    for input_name, domain in gaussian.inputs.items():
        if domain.dtype != 'real':
            inputs[input_name] = domain
    inputs.update(real_inputs)
    renamed = Gaussian(gaussian.info_vec, gaussian.precision, inputs)
    return discrete + renamed


###########################################################
# Converting distribution funsors to PyTorch distributions
###########################################################

@to_data.register(Distribution)
def distribution_to_data(funsor_dist, name_to_dim=None):
    """Convert a distribution funsor back to an instance of its ``dist_class``.

    Every field except ``'value'`` is converted with ``to_data`` and passed
    positionally to ``funsor_dist.dist_class``. The result is then widened
    with ``to_event`` by however many dims the funsor's value shape exceeds
    the distribution's event shape (never a negative count).

    :raises ValueError: if the final event shape differs from the shape of
        ``funsor_dist.value.output``.
    """
    pyro_dist_class = funsor_dist.dist_class
    params = []
    for param_name in funsor_dist._ast_fields:
        if param_name != 'value':
            params.append(to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim))
    pyro_dist = pyro_dist_class(*params)

    funsor_event_shape = funsor_dist.value.output.shape
    extra_event_dims = max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)
    pyro_dist = pyro_dist.to_event(extra_event_dims)
    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist


@to_data.register(Independent[typing.Union[Independent, Distribution], str, str, str])
def indep_to_data(funsor_dist, name_to_dim=None):
    """Placeholder: converting an ``Independent`` funsor to data is not supported yet.

    :raises NotImplementedError: always.
    """
    message = "TODO implement conversion of Independent"
    raise NotImplementedError(message)


@to_data.register(Gaussian)
def gaussian_to_data(funsor_dist, name_to_dim=None, normalized=False):
    """Convert a ``Gaussian`` funsor to a ``dist.MultivariateNormal``.

    With ``normalized=True`` the call is delegated to ``to_data`` on
    ``funsor_dist.log_normalizer + funsor_dist``. Otherwise the location is
    obtained by a Cholesky solve of ``info_vec`` against ``precision``, and
    both location and precision are converted over the non-real inputs only.
    """
    if normalized:
        return to_data(funsor_dist.log_normalizer + funsor_dist, name_to_dim=name_to_dim)

    # Solve against the Cholesky factor of the precision; the trailing
    # column dim is added for the solve and removed afterwards.
    info_column = funsor_dist.info_vec.unsqueeze(-1)
    precision_chol = cholesky(funsor_dist.precision)
    loc = info_column.cholesky_solve(precision_chol).squeeze(-1)

    int_inputs = OrderedDict()
    for input_name, domain in funsor_dist.inputs.items():
        if domain.dtype != "real":
            int_inputs[input_name] = domain

    loc = to_data(Tensor(loc, int_inputs), name_to_dim)
    precision = to_data(Tensor(funsor_dist.precision, int_inputs), name_to_dim)
    return dist.MultivariateNormal(loc, precision_matrix=precision)


@to_data.register(GaussianMixture)
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    """Convert a ``GaussianMixture`` to a ``(Categorical, MultivariateNormal)`` pair.

    The two ``terms`` are unpacked as ``discrete`` and ``gaussian``. The
    categorical's logits are ``discrete + gaussian.log_normalizer`` converted
    with ``to_data``; the second element is ``to_data(gaussian)``.
    """
    discrete, gaussian = funsor_dist.terms
    mixture_logits = to_data(discrete + gaussian.log_normalizer, name_to_dim=name_to_dim)
    cat = dist.Categorical(logits=mixture_logits)
    mvn = to_data(gaussian, name_to_dim=name_to_dim)
    return cat, mvn


################################################
# Backend-agnostic distribution patterns
################################################
Expand Down
Loading

0 comments on commit be19281

Please sign in to comment.