From 3b13e273019fae6290a69126e48274d0a7d8893c Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 11:55:28 -0800 Subject: [PATCH 01/24] add an inputs argument to to_funsor --- funsor/tensor.py | 27 ++++++++++++++----------- funsor/terms.py | 52 +++++++++++++++++++----------------------------- 2 files changed, 35 insertions(+), 44 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index cd06d5fbf..ec440fd7d 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -405,18 +405,21 @@ def backend(self): return "numpy" -@dispatch(numeric_array) -def to_funsor(x): - return Tensor(x) - - -@dispatch(numeric_array, Domain) -def to_funsor(x, output): - result = Tensor(x, dtype=output.dtype) - if result.output != output: - raise ValueError("Invalid shape: expected {}, actual {}" - .format(output.shape, result.output.shape)) - return result +# TODO move these registrations to backend-specific files +@to_funsor.register(torch.Tensor) +@to_funsor.register(np.ndarray) +@to_funsor.register(np.generic) +def tensor_to_funsor(x, output=None, inputs=None): + if output is None and inputs is None: + return Tensor(x) + if output is not None and inputs is None: + result = Tensor(x, dtype=output.dtype) + if result.output != output: + raise ValueError("Invalid shape: expected {}, actual {}" + .format(output.shape, result.output.shape)) + return result + if inputs is not None: + raise NotImplementedError("TODO") def align_tensor(new_inputs, x, expand=False): diff --git a/funsor/terms.py b/funsor/terms.py index 4f095be74..f96d1cb8b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -756,14 +756,15 @@ def _(arg, indent, out): interpreter.children.register(Funsor)(interpreter.children_funsor) -@dispatch(object) -def to_funsor(x): +@singledispatch +def to_funsor(x, output=None, inputs=None): """ Convert to a :class:`Funsor` . Only :class:`Funsor` s and scalars are accepted. :param x: An object. :param funsor.domains.Domain output: An optional output hint. + :param OrderedDict inputs: An optional inputs hint. :return: A Funsor equivalent to ``x``. :rtype: Funsor :raises: ValueError @@ -771,36 +772,24 @@ def to_funsor(x): raise ValueError("Cannot convert to Funsor: {}".format(repr(x))) -@dispatch(object, Domain) -def to_funsor(x, output): - raise ValueError("Cannot convert to Funsor: {}".format(repr(x))) - - -@dispatch(object, object) -def to_funsor(x, output): - raise TypeError("Invalid Domain: {}".format(repr(output))) - - -@dispatch(Funsor) -def to_funsor(x): - return x - - -@dispatch(Funsor, Domain) -def to_funsor(x, output): - if x.output != output: +@to_funsor.register(Funsor) +def funsor_to_funsor(x, output=None, inputs=None): + if output is not None and x.output != output: raise ValueError("Output mismatch: {} vs {}".format(x.output, output)) + if inputs is not None and x.inputs != inputs: + raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, inputs)) return x @singledispatch -def to_data(x): +def to_data(x, inputs=None): """ Extract a python object from a :class:`Funsor`. Raises a ``ValueError`` if free variables remain or if the funsor is lazy. :param x: An object, possibly a :class:`Funsor`. + :param OrderedDict inputs: An optional inputs hint. :return: A non-funsor equivalent to ``x``. :raises: ValueError if any free variables remain. :raises: PatternMissingError if funsor is not fully evaluated. 
@@ -809,8 +798,8 @@ def to_data(x): @to_data.register(Funsor) -def _to_data_funsor(x): - if x.inputs: +def _to_data_funsor(x, inputs=None): + if inputs is None and x.inputs: raise ValueError(f"cannot convert {type(x)} to data due to lazy inputs: {set(x.inputs)}") raise PatternMissingError(r"cannot convert to a non-Funsor: {repr(x)}") @@ -839,8 +828,10 @@ def eager_subs(self, subs): return subs[0][1] -@dispatch(str, Domain) -def to_funsor(name, output): +@to_funsor.register(str) +def name_to_funsor(name, output=None): + if output is None: + raise ValueError(f"Missing output: {name}") return Variable(name, output) @@ -1099,13 +1090,10 @@ def eager_unary(self, op): return Number(op(self.data), dtype) -@dispatch(numbers.Number) -def to_funsor(x): - return Number(x) - - -@dispatch(numbers.Number, Domain) -def to_funsor(x, output): +@to_funsor.register(numbers.Number) +def number_to_funsor(x, output=None): + if output is None: + return Number(x) if output.shape: raise ValueError("Cannot create Number with shape {}".format(output.shape)) return Number(x, output.dtype) From bc56c2e780b0a877bf972c6f085463bcc8d22170 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 15:08:48 -0800 Subject: [PATCH 02/24] implement funsor_to_tensor and tensor_to_funsor --- funsor/tensor.py | 53 +++++++++++++++++++++++++++++++++++++++--------- funsor/terms.py | 18 ++++++++-------- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index ec440fd7d..9d82fe3f0 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -409,17 +409,34 @@ def backend(self): @to_funsor.register(torch.Tensor) @to_funsor.register(np.ndarray) @to_funsor.register(np.generic) -def tensor_to_funsor(x, output=None, inputs=None): - if output is None and inputs is None: - return Tensor(x) - if output is not None and inputs is None: +def tensor_to_funsor(x, output=None, dim_to_name=None): + if not dim_to_name: + output = output if output is not None else reals(*x.shape) result = Tensor(x, dtype=output.dtype) if result.output != output: raise ValueError("Invalid shape: expected {}, actual {}" .format(output.shape, result.output.shape)) return result - if inputs is not None: - raise NotImplementedError("TODO") + else: + assert output is not None # TODO attempt to infer output + # logic very similar to pyro.ops.packed.pack + # this should not touch memory, only reshape + # pack the tensor according to the dim => (name, domain) mapping in inputs + packed_inputs = OrderedDict( + [dim_to_name[dim - len(x.shape)] for dim, size in enumerate(x.shape) + if size > 1 and dim < len(x.shape) - len(output.shape)] + ) + if any(size > 1 for size in output.shape): + # pack outputs into a single dimension + x = x.reshape(x.shape[:-len(output.shape)] + (-1,)) + x = x.squeeze() + if output.shape and all(size == 1 for size in output.shape): + # handle special case: all output dims are 1 + x = x.unsqueeze(-1) + x = x.reshape(x.shape[:-1] + output.shape) + # unpack dims into final domain shapes + x = x.reshape(sum([d.shape for d in packed_inputs.values()], ()) + output.shape) + return Tensor(x, packed_inputs, dtype=output.dtype) def align_tensor(new_inputs, x, expand=False): @@ -487,10 +504,26 @@ def align_tensors(*args, **kwargs): @to_data.register(Tensor) -def _to_data_tensor(x): - if x.inputs: - raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}") - return x.data +def _to_data_tensor(x, name_to_dim=None): + if not name_to_dim: + if x.inputs: + raise ValueError(f"cannot convert Tensor to 
data due to lazy inputs: {set(x.inputs)}") + return x.data + else: + # logic very similar to pyro.ops.packed.unpack + # first collapse input domains into single dimensions + data = x.data.reshape(tuple(d.num_elements for d in x.inputs.values()) + x.output.shape) + # permute packed dimensions to correct order + unsorted_dims = [name_to_dim[name] for name in x.inputs] + dims = sorted(unsorted_dims) + permutation = [unsorted_dims.index(dim) for dim in dims] + \ + list(range(len(dims), len(dims) + len(x.output.shape))) + data = data.permute(*permutation) + # expand + batch_shape = [1] * -min(dims) + for dim, size in zip(dims, data.shape): + batch_shape[dim] = size + return data.reshape(batch_shape + x.output.shape) @eager.register(Binary, Op, Tensor, Number) diff --git a/funsor/terms.py b/funsor/terms.py index f96d1cb8b..e5fc5ec5e 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -757,14 +757,14 @@ def _(arg, indent, out): @singledispatch -def to_funsor(x, output=None, inputs=None): +def to_funsor(x, output=None, dim_to_name=None): """ Convert to a :class:`Funsor` . Only :class:`Funsor` s and scalars are accepted. :param x: An object. :param funsor.domains.Domain output: An optional output hint. - :param OrderedDict inputs: An optional inputs hint. + :param OrderedDict dim_to_name: An optional inputs hint. :return: A Funsor equivalent to ``x``. :rtype: Funsor :raises: ValueError @@ -773,23 +773,23 @@ def to_funsor(x, output=None, inputs=None): @to_funsor.register(Funsor) -def funsor_to_funsor(x, output=None, inputs=None): +def funsor_to_funsor(x, output=None, dim_to_name=None): if output is not None and x.output != output: raise ValueError("Output mismatch: {} vs {}".format(x.output, output)) - if inputs is not None and x.inputs != inputs: - raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, inputs)) + if dim_to_name is not None and list(x.inputs.keys()) != [v[0] for v in dim_to_name.values()]: + raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, dim_to_name)) return x @singledispatch -def to_data(x, inputs=None): +def to_data(x, name_to_dim=None): """ Extract a python object from a :class:`Funsor`. Raises a ``ValueError`` if free variables remain or if the funsor is lazy. :param x: An object, possibly a :class:`Funsor`. - :param OrderedDict inputs: An optional inputs hint. + :param OrderedDict name_to_dim: An optional inputs hint. :return: A non-funsor equivalent to ``x``. :raises: ValueError if any free variables remain. :raises: PatternMissingError if funsor is not fully evaluated. 
@@ -798,8 +798,8 @@ def to_data(x, inputs=None): @to_data.register(Funsor) -def _to_data_funsor(x, inputs=None): - if inputs is None and x.inputs: +def _to_data_funsor(x, name_to_dim=None): + if name_to_dim is None and x.inputs: raise ValueError(f"cannot convert {type(x)} to data due to lazy inputs: {set(x.inputs)}") raise PatternMissingError(r"cannot convert to a non-Funsor: {repr(x)}") From fa04123549bf4e5243458c40a842a29d26bcf98f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 16:00:48 -0800 Subject: [PATCH 03/24] use to_funsor and to_data in funsor.pyro.convert --- funsor/pyro/convert.py | 37 ++++++++++--------------------------- funsor/tensor.py | 14 +++++++------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/funsor/pyro/convert.py b/funsor/pyro/convert.py index b04a0a6ff..4ed4992d6 100644 --- a/funsor/pyro/convert.py +++ b/funsor/pyro/convert.py @@ -28,12 +28,12 @@ from funsor.cnf import Contraction from funsor.delta import Delta from funsor.distributions import BernoulliLogits, MultivariateNormal, Normal -from funsor.domains import bint, reals +from funsor.domains import Domain, bint, reals from funsor.gaussian import Gaussian from funsor.interpreter import gensym from funsor.ops import cholesky from funsor.tensor import Tensor, align_tensors -from funsor.terms import Funsor, Independent, Variable, eager +from funsor.terms import Funsor, Independent, Variable, eager, to_data, to_funsor # Conversion functions use fixed names for Pyro batch dims, but # accept an event_inputs tuple for custom event dim names. @@ -63,21 +63,12 @@ def tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype="real"): assert isinstance(event_output, int) and event_output >= 0 inputs_shape = tensor.shape[:tensor.dim() - event_output] output_shape = tensor.shape[tensor.dim() - event_output:] - dim_to_name = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME - - # Squeeze shape of inputs. - inputs = OrderedDict() - squeezed_shape = [] - for dim, size in enumerate(inputs_shape): - if size > 1: - name = dim_to_name[dim - len(inputs_shape)] - inputs[name] = bint(size) - squeezed_shape.append(size) - squeezed_shape = torch.Size(squeezed_shape) - if squeezed_shape != inputs_shape: - tensor = tensor.reshape(squeezed_shape + output_shape) - - return Tensor(tensor, inputs, dtype) + dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME + dim_to_name = OrderedDict(zip( + range(-len(inputs_shape), 0), + zip(dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):], + map(bint, inputs_shape)))) + return to_funsor(tensor, Domain(dtype=dtype, shape=output_shape), dim_to_name) def funsor_to_tensor(funsor_, ndims, event_inputs=()): @@ -99,16 +90,8 @@ def funsor_to_tensor(funsor_, ndims, event_inputs=()): if event_inputs: dim_to_name = DIM_TO_NAME + event_inputs name_to_dim = dict(zip(dim_to_name, range(-len(dim_to_name), 0))) - names = tuple(sorted(funsor_.inputs, key=name_to_dim.__getitem__)) - tensor = funsor_.align(names).data - if names: - # Unsqueeze shape of inputs. 
- dims = list(map(name_to_dim.__getitem__, names)) - inputs_shape = [1] * (-dims[0]) - for dim, size in zip(dims, tensor.shape): - inputs_shape[dim] = size - inputs_shape = torch.Size(inputs_shape) - tensor = tensor.reshape(inputs_shape + funsor_.output.shape) + tensor = to_data(funsor_, name_to_dim) + if ndims != tensor.dim(): tensor = tensor.reshape((1,) * (ndims - tensor.dim()) + tensor.shape) assert tensor.dim() == ndims diff --git a/funsor/tensor.py b/funsor/tensor.py index 9d82fe3f0..be9dff33d 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -419,23 +419,23 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): return result else: assert output is not None # TODO attempt to infer output + assert all(isinstance(k, int) and isinstance(v[0], str) and isinstance(v[1], Domain) + for k, v in dim_to_name.items()) # logic very similar to pyro.ops.packed.pack # this should not touch memory, only reshape # pack the tensor according to the dim => (name, domain) mapping in inputs packed_inputs = OrderedDict( - [dim_to_name[dim - len(x.shape)] for dim, size in enumerate(x.shape) + [dim_to_name[dim + len(output.shape) - len(x.shape)] for dim, size in enumerate(x.shape) if size > 1 and dim < len(x.shape) - len(output.shape)] ) if any(size > 1 for size in output.shape): # pack outputs into a single dimension x = x.reshape(x.shape[:-len(output.shape)] + (-1,)) x = x.squeeze() - if output.shape and all(size == 1 for size in output.shape): + if all(size == 1 for size in output.shape): # handle special case: all output dims are 1 x = x.unsqueeze(-1) x = x.reshape(x.shape[:-1] + output.shape) - # unpack dims into final domain shapes - x = x.reshape(sum([d.shape for d in packed_inputs.values()], ()) + output.shape) return Tensor(x, packed_inputs, dtype=output.dtype) @@ -505,14 +505,14 @@ def align_tensors(*args, **kwargs): @to_data.register(Tensor) def _to_data_tensor(x, name_to_dim=None): - if not name_to_dim: + if not name_to_dim or not x.inputs: if x.inputs: raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}") return x.data else: # logic very similar to pyro.ops.packed.unpack # first collapse input domains into single dimensions - data = x.data.reshape(tuple(d.num_elements for d in x.inputs.values()) + x.output.shape) + data = x.data.reshape(tuple(d.dtype for d in x.inputs.values()) + x.output.shape) # permute packed dimensions to correct order unsorted_dims = [name_to_dim[name] for name in x.inputs] dims = sorted(unsorted_dims) @@ -523,7 +523,7 @@ def _to_data_tensor(x, name_to_dim=None): batch_shape = [1] * -min(dims) for dim, size in zip(dims, data.shape): batch_shape[dim] = size - return data.reshape(batch_shape + x.output.shape) + return data.reshape(tuple(batch_shape) + x.output.shape) @eager.register(Binary, Op, Tensor, Number) From 095156951d0356eb4ac7f603cf74788b20d928c0 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 16:38:39 -0800 Subject: [PATCH 04/24] nit --- funsor/tensor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index be9dff33d..adff1cf6f 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -424,10 +424,11 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): # logic very similar to pyro.ops.packed.pack # this should not touch memory, only reshape # pack the tensor according to the dim => (name, domain) mapping in inputs - packed_inputs = OrderedDict( - [dim_to_name[dim + len(output.shape) - len(x.shape)] for dim, size in enumerate(x.shape) - if size > 1 
and dim < len(x.shape) - len(output.shape)] - ) + packed_inputs = OrderedDict() + for dim, size in enumerate(x.shape): + if size > 1 and dim < len(x.shape) - len(output.shape): + name, domain = dim_to_name[dim + len(output.shape) - len(x.shape)] + packed_inputs[name] = domain if domain.dtype > 1 else bint(size) if any(size > 1 for size in output.shape): # pack outputs into a single dimension x = x.reshape(x.shape[:-len(output.shape)] + (-1,)) From 84d9fc06fc95233fba14a1ea922927ff0d1d46b2 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 16:42:07 -0800 Subject: [PATCH 05/24] tweak --- funsor/tensor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index adff1cf6f..dec52a0e5 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -425,10 +425,11 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): # this should not touch memory, only reshape # pack the tensor according to the dim => (name, domain) mapping in inputs packed_inputs = OrderedDict() - for dim, size in enumerate(x.shape): - if size > 1 and dim < len(x.shape) - len(output.shape): - name, domain = dim_to_name[dim + len(output.shape) - len(x.shape)] - packed_inputs[name] = domain if domain.dtype > 1 else bint(size) + for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape): + if size == 1: + continue # TODO broadcast domain and shape here + name, domain = dim_to_name[dim + len(output.shape) - len(x.shape)] + packed_inputs[name] = domain if domain.dtype > 1 else bint(size) if any(size > 1 for size in output.shape): # pack outputs into a single dimension x = x.reshape(x.shape[:-len(output.shape)] + (-1,)) From a32566f03c812417a0efcad667568593ff09ef26 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 11 Feb 2020 17:12:06 -0800 Subject: [PATCH 06/24] attempt at generic distribution conversion --- funsor/distributions.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/funsor/distributions.py b/funsor/distributions.py index 3c60190ed..9507e4e2d 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -15,7 +15,7 @@ from funsor.gaussian import Gaussian from funsor.interpreter import gensym, interpretation from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack -from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor +from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_data, to_funsor def numbers_to_tensors(*args): @@ -602,6 +602,25 @@ def eager_vonmises(loc, concentration, value): return VonMises.eager_log_prob(loc=loc, concentration=concentration, value=value) +@to_funsor.register(dist.TorchDistribution) +def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): + import funsor.distributions # TODO find a better way to do this lookup + funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) + params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) + for param_name in funsor_dist._ast_fields if param_name != 'value'] + return funsor_dist(*params) + + +@to_data.register(Distribution) +def distribution_to_data(funsor_dist, name_to_dim=None): + pyro_dist = funsor_dist.dist_class + assert 'value' not in name_to_dim + assert funsor_dist.inputs['value'].shape == () # TODO convert properly + params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) + for param_name in funsor_dist._ast_fields if param_name != 'value'] + return pyro_dist(*params) + + __all__ = [ 
'Bernoulli', 'BernoulliLogits', From 6be8a3447f4ec693de830e78fabfa6966c18ecf3 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 13 Feb 2020 11:38:16 -0800 Subject: [PATCH 07/24] address comment --- funsor/tensor.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index dec52a0e5..857864c11 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -430,14 +430,9 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): continue # TODO broadcast domain and shape here name, domain = dim_to_name[dim + len(output.shape) - len(x.shape)] packed_inputs[name] = domain if domain.dtype > 1 else bint(size) - if any(size > 1 for size in output.shape): - # pack outputs into a single dimension - x = x.reshape(x.shape[:-len(output.shape)] + (-1,)) - x = x.squeeze() - if all(size == 1 for size in output.shape): - # handle special case: all output dims are 1 - x = x.unsqueeze(-1) - x = x.reshape(x.shape[:-1] + output.shape) + shape = tuple(d.size for d in packed_inputs.values()) + output.shape + if x.shape != shape: + x = x.reshape(shape) return Tensor(x, packed_inputs, dtype=output.dtype) @@ -506,12 +501,14 @@ def align_tensors(*args, **kwargs): @to_data.register(Tensor) -def _to_data_tensor(x, name_to_dim=None): +def tensor_to_data(x, name_to_dim=None): if not name_to_dim or not x.inputs: if x.inputs: raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}") return x.data else: + assert all(isinstance(k, str) and isinstance(v, int) and v <= 0 + for k, v in name_to_dim.items()) # logic very similar to pyro.ops.packed.unpack # first collapse input domains into single dimensions data = x.data.reshape(tuple(d.dtype for d in x.inputs.values()) + x.output.shape) From 630abf5236d59c62dac0ca5ef840c5225d732a53 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 13 Feb 2020 11:48:09 -0800 Subject: [PATCH 08/24] remove domains from dim_to_name --- funsor/pyro/convert.py | 3 +-- funsor/tensor.py | 11 +++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/funsor/pyro/convert.py b/funsor/pyro/convert.py index 4ed4992d6..d6492a40f 100644 --- a/funsor/pyro/convert.py +++ b/funsor/pyro/convert.py @@ -66,8 +66,7 @@ def tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype="real"): dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME dim_to_name = OrderedDict(zip( range(-len(inputs_shape), 0), - zip(dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):], - map(bint, inputs_shape)))) + dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) return to_funsor(tensor, Domain(dtype=dtype, shape=output_shape), dim_to_name) diff --git a/funsor/tensor.py b/funsor/tensor.py index 857864c11..f3150c6a3 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -419,17 +419,16 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): return result else: assert output is not None # TODO attempt to infer output - assert all(isinstance(k, int) and isinstance(v[0], str) and isinstance(v[1], Domain) + assert all(isinstance(k, int) and isinstance(v, str) for k, v in dim_to_name.items()) # logic very similar to pyro.ops.packed.pack # this should not touch memory, only reshape - # pack the tensor according to the dim => (name, domain) mapping in inputs + # pack the tensor according to the dim => name mapping in inputs packed_inputs = OrderedDict() for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape): - if size == 1: - continue # TODO broadcast domain and shape here - name, 
domain = dim_to_name[dim + len(output.shape) - len(x.shape)] - packed_inputs[name] = domain if domain.dtype > 1 else bint(size) + name = dim_to_name.get(dim + len(output.shape) - len(x.shape), None) + if name is not None: + packed_inputs[name] = bint(size) shape = tuple(d.size for d in packed_inputs.values()) + output.shape if x.shape != shape: x = x.reshape(shape) From bbbdafaaedc1a6aa401fa2e6805c65f194587edc Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 13 Feb 2020 11:51:35 -0800 Subject: [PATCH 09/24] add dim_to_name docstring comment --- funsor/terms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index e5fc5ec5e..4aa3071e7 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -764,7 +764,7 @@ def to_funsor(x, output=None, dim_to_name=None): :param x: An object. :param funsor.domains.Domain output: An optional output hint. - :param OrderedDict dim_to_name: An optional inputs hint. + :param OrderedDict dim_to_name: An optional mapping from negative batch dimensions to name strings. :return: A Funsor equivalent to ``x``. :rtype: Funsor :raises: ValueError @@ -776,7 +776,7 @@ def to_funsor(x, output=None, dim_to_name=None): def funsor_to_funsor(x, output=None, dim_to_name=None): if output is not None and x.output != output: raise ValueError("Output mismatch: {} vs {}".format(x.output, output)) - if dim_to_name is not None and list(x.inputs.keys()) != [v[0] for v in dim_to_name.values()]: + if dim_to_name is not None and list(x.inputs.keys()) != list(dim_to_name.values()): raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, dim_to_name)) return x From 01cde8f9e6178b97c0c2f1d824807281a0e3be07 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 13 Feb 2020 11:58:24 -0800 Subject: [PATCH 10/24] assert batch dim negativity --- funsor/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index f3150c6a3..139645368 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -419,7 +419,7 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): return result else: assert output is not None # TODO attempt to infer output - assert all(isinstance(k, int) and isinstance(v, str) + assert all(isinstance(k, int) and k < 0 and isinstance(v, str) for k, v in dim_to_name.items()) # logic very similar to pyro.ops.packed.pack # this should not touch memory, only reshape @@ -506,7 +506,7 @@ def tensor_to_data(x, name_to_dim=None): raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}") return x.data else: - assert all(isinstance(k, str) and isinstance(v, int) and v <= 0 + assert all(isinstance(k, str) and isinstance(v, int) and v < 0 for k, v in name_to_dim.items()) # logic very similar to pyro.ops.packed.unpack # first collapse input domains into single dimensions From 1a3a88f09f1bec9ff74c4e81b64949f804353962 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 13 Feb 2020 14:07:44 -0800 Subject: [PATCH 11/24] consider even named dims of size 1 empty in tensor_to_funsor --- funsor/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index 139645368..3b4cd3074 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -427,7 +427,7 @@ def tensor_to_funsor(x, output=None, dim_to_name=None): packed_inputs = OrderedDict() for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape): name = dim_to_name.get(dim + len(output.shape) - len(x.shape), None) - if name is not None: + if name is not None and size 
> 1: packed_inputs[name] = bint(size) shape = tuple(d.size for d in packed_inputs.values()) + output.shape if x.shape != shape: From 4956775b21f6497e5222a100a56180ccbd3e2b00 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 14 Feb 2020 17:25:31 -0800 Subject: [PATCH 12/24] sketch new distribution wrapper --- funsor/distributions.py | 69 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/funsor/distributions.py b/funsor/distributions.py index 9507e4e2d..df9406810 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -96,10 +96,79 @@ def eager_log_prob(cls, **params): return Tensor(data, inputs) +def _dummy_tensor(domain): + return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape) + + +class Distribution2(Funsor): + """ + Different design for the Distribution Funsor wrapper, + closer to Gaussian or Delta in which the value is a fresh input. + """ + dist_class = dist.Distribution # defined by derived classes + + def __init__(self, *args, name='value'): + params = tuple(zip(self._ast_fields, args)) + inputs = OrderedDict() + for param_name, value in params: + assert isinstance(param_name, str) + assert isinstance(value, Funsor) + inputs.update(value.inputs) + assert isinstance(name, str) and name not in inputs + inputs[name] = self._infer_value_shape(cls, **params) + output = reals() + fresh = frozenset({name}) + bound = frozenset() + super().__init__(inputs, output, fresh, bound) + self.params = params + self.name = name + + @classmethod + def _infer_value_shape(cls, **kwargs): + # rely on the underlying distribution's logic to infer the event_shape + instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs}) + out_shape = instance.event_shape + if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): + out_dtype = instance.support.upper_bound + else: + out_dtype = 'real' + return Domain(dtype=out_dtype, shape=out_shape) + + def eager_subs(self, subs): + name, sub = subs[0] + if isinstance(sub, Tensor): + inputs, tensors = align_tensors(*self.params.values()) + data = self.dist_class(**params).log_prob(value) + return Tensor(data, inputs) + elif isinstance(sub, (Variable, str)): # TODO change name param + return + else: + raise NotImplementedError("not implemented") + ################################################################################ # Distribution Wrappers ################################################################################ +def make_dist(pyro_dist_class, param_names=()): + + import makefun + + if not param_names: + param_names = tuple(pyro_dist_class.arg_constraints.keys()) + assert all(name in pyro_dist_class.arg_constraints for name in param_names) + + @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, value='value')") + def dist_init(*args, **kwargs): + return super().__init__(*args, **kwargs) + + dist_class = DistributionMeta(pyro_dist_class.__name__, (Distribution,), { + 'dist_class': pyro_dist_class, + '__init__': dist_init, + }) + + return dist_class + + class BernoulliProbs(Distribution): """ Wraps :class:`pyro.distributions.Bernoulli` . 
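A note on the factory sketched above: make_dist programmatically builds a funsor wrapper class from a Pyro distribution class, with the value treated as a fresh input rather than a constructor argument. A minimal usage sketch, assuming the funsor.distributions2 module layout that the following patches converge on (this mirrors the eager test usage added later in the series; the positional 'value' string binds to the generated name parameter):

    from collections import OrderedDict

    import torch

    import funsor
    import funsor.distributions2 as dist
    from funsor.domains import bint
    from funsor.tensor import Tensor

    funsor.set_backend("torch")

    # One named batch dimension "i" of size 5.
    inputs = OrderedDict(i=bint(5))
    concentration1 = Tensor(torch.randn(5).exp(), inputs)
    concentration0 = Tensor(torch.randn(5).exp(), inputs)
    value = Tensor(torch.rand(5), inputs)

    # Beta is generated by make_dist(pyro.distributions.Beta). Substituting
    # a Tensor for the fresh "value" input hits eager_subs, which calls the
    # wrapped Beta(...).log_prob(...) and returns a funsor Tensor.
    log_prob = dist.Beta(concentration1, concentration0, 'value')(value=value)
    assert set(log_prob.inputs) == {"i"}  # batched log-density over "i"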
From 8dc858560e574e7434a67e960750acf388f374a5 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 15:06:45 -0800 Subject: [PATCH 13/24] split new version into second file --- funsor/distributions.py | 92 +------------------------------ funsor/distributions2.py | 114 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 90 deletions(-) create mode 100644 funsor/distributions2.py diff --git a/funsor/distributions.py b/funsor/distributions.py index df9406810..15b50a811 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -6,7 +6,6 @@ import pyro.distributions as dist import torch -from pyro.distributions.util import broadcast_shape import funsor.delta import funsor.ops as ops @@ -15,7 +14,8 @@ from funsor.gaussian import Gaussian from funsor.interpreter import gensym, interpretation from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack -from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_data, to_funsor +from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor +from funsor.util import broadcast_shape def numbers_to_tensors(*args): @@ -96,79 +96,10 @@ def eager_log_prob(cls, **params): return Tensor(data, inputs) -def _dummy_tensor(domain): - return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape) - - -class Distribution2(Funsor): - """ - Different design for the Distribution Funsor wrapper, - closer to Gaussian or Delta in which the value is a fresh input. - """ - dist_class = dist.Distribution # defined by derived classes - - def __init__(self, *args, name='value'): - params = tuple(zip(self._ast_fields, args)) - inputs = OrderedDict() - for param_name, value in params: - assert isinstance(param_name, str) - assert isinstance(value, Funsor) - inputs.update(value.inputs) - assert isinstance(name, str) and name not in inputs - inputs[name] = self._infer_value_shape(cls, **params) - output = reals() - fresh = frozenset({name}) - bound = frozenset() - super().__init__(inputs, output, fresh, bound) - self.params = params - self.name = name - - @classmethod - def _infer_value_shape(cls, **kwargs): - # rely on the underlying distribution's logic to infer the event_shape - instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs}) - out_shape = instance.event_shape - if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): - out_dtype = instance.support.upper_bound - else: - out_dtype = 'real' - return Domain(dtype=out_dtype, shape=out_shape) - - def eager_subs(self, subs): - name, sub = subs[0] - if isinstance(sub, Tensor): - inputs, tensors = align_tensors(*self.params.values()) - data = self.dist_class(**params).log_prob(value) - return Tensor(data, inputs) - elif isinstance(sub, (Variable, str)): # TODO change name param - return - else: - raise NotImplementedError("not implemented") - ################################################################################ # Distribution Wrappers ################################################################################ -def make_dist(pyro_dist_class, param_names=()): - - import makefun - - if not param_names: - param_names = tuple(pyro_dist_class.arg_constraints.keys()) - assert all(name in pyro_dist_class.arg_constraints for name in param_names) - - @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, value='value')") - def dist_init(*args, **kwargs): - return super().__init__(*args, **kwargs) - - dist_class = 
DistributionMeta(pyro_dist_class.__name__, (Distribution,), { - 'dist_class': pyro_dist_class, - '__init__': dist_init, - }) - - return dist_class - - class BernoulliProbs(Distribution): """ Wraps :class:`pyro.distributions.Bernoulli` . @@ -671,25 +602,6 @@ def eager_vonmises(loc, concentration, value): return VonMises.eager_log_prob(loc=loc, concentration=concentration, value=value) -@to_funsor.register(dist.TorchDistribution) -def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): - import funsor.distributions # TODO find a better way to do this lookup - funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) - params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) - for param_name in funsor_dist._ast_fields if param_name != 'value'] - return funsor_dist(*params) - - -@to_data.register(Distribution) -def distribution_to_data(funsor_dist, name_to_dim=None): - pyro_dist = funsor_dist.dist_class - assert 'value' not in name_to_dim - assert funsor_dist.inputs['value'].shape == () # TODO convert properly - params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) - for param_name in funsor_dist._ast_fields if param_name != 'value'] - return pyro_dist(*params) - - __all__ = [ 'Bernoulli', 'BernoulliLogits', diff --git a/funsor/distributions2.py b/funsor/distributions2.py new file mode 100644 index 000000000..234b9a046 --- /dev/null +++ b/funsor/distributions2.py @@ -0,0 +1,114 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections import OrderedDict + +import pyro.distributions as dist +import torch +from pyro.distributions.util import broadcast_shape + +import funsor.delta +import funsor.ops as ops +from funsor.affine import is_affine +from funsor.domains import bint, reals +from funsor.gaussian import Gaussian +from funsor.interpreter import gensym, interpretation +from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack +from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_data, to_funsor + + +def _dummy_tensor(domain): + return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape) + + +class Distribution2(Funsor): + """ + Different design for the Distribution Funsor wrapper, + closer to Gaussian or Delta in which the value is a fresh input. 
+ """ + dist_class = dist.Distribution # defined by derived classes + + def __init__(self, *args, name='value'): + params = tuple(zip(self._ast_fields, args)) + inputs = OrderedDict() + for param_name, value in params: + assert isinstance(param_name, str) + assert isinstance(value, Funsor) + inputs.update(value.inputs) + assert isinstance(name, str) and name not in inputs + inputs[name] = self._infer_value_shape(cls, **params) + output = reals() + fresh = frozenset({name}) + bound = frozenset() + super().__init__(inputs, output, fresh, bound) + self.params = params + self.name = name + + @classmethod + def _infer_value_shape(cls, **kwargs): + # rely on the underlying distribution's logic to infer the event_shape + instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs}) + out_shape = instance.event_shape + if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): + out_dtype = instance.support.upper_bound + else: + out_dtype = 'real' + return Domain(dtype=out_dtype, shape=out_shape) + + def eager_subs(self, subs): + name, sub = subs[0] + if isinstance(sub, (Number, Tensor)): + inputs, tensors = align_tensors(*self.params.values()) + data = self.dist_class(**tensors).log_prob(sub.data) + return Tensor(data, inputs) + elif isinstance(sub, (Variable, str)): + return type(self)(*self._ast_values, name=sub.name if isinstance(sub, Variable) else sub) + else: + raise NotImplementedError("not implemented") + +################################################################################ +# Distribution Wrappers +################################################################################ + +def make_dist(pyro_dist_class, param_names=()): + + import makefun + + if not param_names: + param_names = tuple(pyro_dist_class.arg_constraints.keys()) + assert all(name in pyro_dist_class.arg_constraints for name in param_names) + + @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')") + def dist_init(*args, **kwargs): + return super().__init__(*args, **kwargs) + + dist_class = FunsorMeta(pyro_dist_class.__name__, (Distribution2,), { + 'dist_class': pyro_dist_class, + '__init__': dist_init, + }) + + return dist_class + + +for pyro_dist in (dist.Categorical, dist.Bernoulli, dist.Normal): + locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class) + + +@to_funsor.register(dist.TorchDistribution) +def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): + import funsor.distributions # TODO find a better way to do this lookup + funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) + params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) + for param_name in funsor_dist._ast_fields if param_name != 'value'] + return funsor_dist(*params) + + +@to_data.register(Distribution2) +def distribution_to_data(funsor_dist, name_to_dim=None): + pyro_dist = funsor_dist.dist_class + assert 'value' not in name_to_dim + assert funsor_dist.inputs['value'].shape == () # TODO convert properly + params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) + for param_name in funsor_dist._ast_fields if param_name != 'value'] + return pyro_dist(*params) From f84a03104695b798dec70951eba72e5ec463ca99 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 15:10:38 -0800 Subject: [PATCH 14/24] lint --- funsor/distributions2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index 
234b9a046..69109fc5b 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -11,7 +11,7 @@ import funsor.delta import funsor.ops as ops from funsor.affine import is_affine -from funsor.domains import bint, reals +from funsor.domains import Domain, bint, reals from funsor.gaussian import Gaussian from funsor.interpreter import gensym, interpretation from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack @@ -37,7 +37,7 @@ def __init__(self, *args, name='value'): assert isinstance(value, Funsor) inputs.update(value.inputs) assert isinstance(name, str) and name not in inputs - inputs[name] = self._infer_value_shape(cls, **params) + inputs[name] = self._infer_value_shape(**params) output = reals() fresh = frozenset({name}) bound = frozenset() @@ -67,6 +67,7 @@ def eager_subs(self, subs): else: raise NotImplementedError("not implemented") + ################################################################################ # Distribution Wrappers ################################################################################ @@ -91,7 +92,7 @@ def dist_init(*args, **kwargs): return dist_class -for pyro_dist in (dist.Categorical, dist.Bernoulli, dist.Normal): +for pyro_dist_class in (dist.Categorical, dist.Bernoulli, dist.Normal): locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class) From 809e44cef91c216b0550f78535d89b54b3a82983 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 15:34:27 -0800 Subject: [PATCH 15/24] most basic beta density test passes --- funsor/distributions2.py | 60 ++++---- test/test_distributions2.py | 277 ++++++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 26 deletions(-) create mode 100644 test/test_distributions2.py diff --git a/funsor/distributions2.py b/funsor/distributions2.py index 69109fc5b..40d935f81 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -30,9 +30,9 @@ class Distribution2(Funsor): dist_class = dist.Distribution # defined by derived classes def __init__(self, *args, name='value'): - params = tuple(zip(self._ast_fields, args)) + params = OrderedDict(zip(self._ast_fields, args)) inputs = OrderedDict() - for param_name, value in params: + for param_name, value in params.items(): assert isinstance(param_name, str) assert isinstance(value, Funsor) inputs.update(value.inputs) @@ -48,7 +48,7 @@ def __init__(self, *args, name='value'): @classmethod def _infer_value_shape(cls, **kwargs): # rely on the underlying distribution's logic to infer the event_shape - instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs}) + instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()}) out_shape = instance.event_shape if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): out_dtype = instance.support.upper_bound @@ -60,7 +60,7 @@ def eager_subs(self, subs): name, sub = subs[0] if isinstance(sub, (Number, Tensor)): inputs, tensors = align_tensors(*self.params.values()) - data = self.dist_class(**tensors).log_prob(sub.data) + data = self.dist_class(*tensors).log_prob(sub.data) return Tensor(data, inputs) elif isinstance(sub, (Variable, str)): return type(self)(*self._ast_values, name=sub.name if isinstance(sub, Variable) else sub) @@ -81,8 +81,8 @@ def make_dist(pyro_dist_class, param_names=()): assert all(name in pyro_dist_class.arg_constraints for name in param_names) @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')") - def dist_init(*args, **kwargs): - return 
super().__init__(*args, **kwargs) + def dist_init(self, *args, **kwargs): + return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name='value') dist_class = FunsorMeta(pyro_dist_class.__name__, (Distribution2,), { 'dist_class': pyro_dist_class, @@ -92,24 +92,32 @@ def dist_init(*args, **kwargs): return dist_class -for pyro_dist_class in (dist.Categorical, dist.Bernoulli, dist.Normal): +# @to_funsor.register(dist.TorchDistribution) +# def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): +# import funsor.distributions # TODO find a better way to do this lookup +# funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) +# params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) +# for param_name in funsor_dist._ast_fields if param_name != 'value'] +# return funsor_dist(*params) +# +# +# @to_data.register(Distribution2) +# def distribution_to_data(funsor_dist, name_to_dim=None): +# pyro_dist = funsor_dist.dist_class +# assert 'value' not in name_to_dim +# assert funsor_dist.inputs['value'].shape == () # TODO convert properly +# params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) +# for param_name in funsor_dist._ast_fields if param_name != 'value'] +# return pyro_dist(*params) + + +_wrapped_pyro_dists = [ + dist.Beta, + # dist.Bernoulli, + # dist.Categorical, + # dist.Poisson, + # dist.Normal, +] + +for pyro_dist_class in _wrapped_pyro_dists: locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class) - - -@to_funsor.register(dist.TorchDistribution) -def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): - import funsor.distributions # TODO find a better way to do this lookup - funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) - params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) - for param_name in funsor_dist._ast_fields if param_name != 'value'] - return funsor_dist(*params) - - -@to_data.register(Distribution2) -def distribution_to_data(funsor_dist, name_to_dim=None): - pyro_dist = funsor_dist.dist_class - assert 'value' not in name_to_dim - assert funsor_dist.inputs['value'].shape == () # TODO convert properly - params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) - for param_name in funsor_dist._ast_fields if param_name != 'value'] - return pyro_dist(*params) diff --git a/test/test_distributions2.py b/test/test_distributions2.py new file mode 100644 index 000000000..e6b7801fa --- /dev/null +++ b/test/test_distributions2.py @@ -0,0 +1,277 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import math +from collections import OrderedDict + +import pyro +import pytest +import torch + +import funsor +import funsor.distributions2 as dist +from funsor.cnf import Contraction, GaussianMixture +from funsor.delta import Delta +from funsor.domains import bint, reals +from funsor.interpreter import interpretation, reinterpret +from funsor.pyro.convert import dist_to_funsor +from funsor.tensor import Einsum, Tensor +from funsor.terms import Independent, Variable, lazy +from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor +from funsor.util import get_backend + +funsor.set_backend("torch") + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_beta_density(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals(), reals()) + def beta(concentration1, concentration0, value): + return torch.distributions.Beta(concentration1, concentration0).log_prob(value) + + check_funsor(beta, {'concentration1': reals(), 'concentration0': reals(), 'value': reals()}, reals()) + + concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs) + concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.rand(batch_shape), inputs) + expected = beta(concentration1, concentration0, value) + check_funsor(expected, inputs, reals()) + + actual = dist.Beta(concentration1, concentration0, 'value')(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) +def test_bernoulli_probs_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals()) + def bernoulli(probs, value): + return torch.distributions.Bernoulli(probs).log_prob(value) + + check_funsor(bernoulli, {'probs': reals(), 'value': reals()}, reals()) + + probs = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.rand(batch_shape).round(), inputs) + expected = bernoulli(probs, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.BernoulliProbs(probs, value) + elif syntax == 'lazy': + actual = dist.BernoulliProbs(probs, d)(value=value) + elif syntax == 'generic': + actual = dist.Bernoulli(probs=probs)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) +def test_bernoulli_logits_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals()) + def bernoulli(logits, value): + return torch.distributions.Bernoulli(logits=logits).log_prob(value) + + check_funsor(bernoulli, {'logits': reals(), 'value': reals()}, reals()) + + logits = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.rand(batch_shape).round(), inputs) + expected = bernoulli(logits, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.BernoulliLogits(logits, value) + elif 
syntax == 'lazy': + actual = dist.BernoulliLogits(logits, d)(value=value) + elif syntax == 'generic': + actual = dist.Bernoulli(logits=logits)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('eager', [False, True]) +def test_binomial_density(batch_shape, eager): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + max_count = 10 + + @funsor.function(reals(), reals(), reals(), reals()) + def binomial(total_count, probs, value): + return torch.distributions.Binomial(total_count, probs).log_prob(value) + + check_funsor(binomial, {'total_count': reals(), 'probs': reals(), 'value': reals()}, reals()) + + value_data = random_tensor(inputs, bint(max_count)).data.float() + total_count_data = value_data + random_tensor(inputs, bint(max_count)).data.float() + value = Tensor(value_data, inputs) + total_count = Tensor(total_count_data, inputs) + probs = Tensor(torch.rand(batch_shape), inputs) + expected = binomial(total_count, probs, value) + check_funsor(expected, inputs, reals()) + + m = Variable('value', reals()) + actual = dist.Binomial(total_count, probs, value) if eager else \ + dist.Binomial(total_count, probs, m)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('size', [4]) +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_categorical_density(size, batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.of_shape(reals(size), bint(size)) + def categorical(probs, value): + return probs[value].log() + + check_funsor(categorical, {'probs': reals(size), 'value': bint(size)}, reals()) + + probs_data = torch.randn(batch_shape + (size,)).exp() + probs_data /= probs_data.sum(-1, keepdim=True) + probs = Tensor(probs_data, inputs) + value = random_tensor(inputs, bint(size)) + expected = categorical(probs, value) + check_funsor(expected, inputs, reals()) + + actual = dist.Categorical(probs, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str) +def test_dirichlet_density(batch_shape, event_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(*event_shape), reals(*event_shape), reals()) + def dirichlet(concentration, value): + return torch.distributions.Dirichlet(concentration).log_prob(value) + + check_funsor(dirichlet, {'concentration': reals(*event_shape), 'value': reals(*event_shape)}, reals()) + + concentration = Tensor(torch.randn(batch_shape + event_shape).exp(), inputs) + value_data = torch.rand(batch_shape + event_shape) + value_data = value_data / value_data.sum(-1, keepdim=True) + value = Tensor(value_data, inputs) + expected = dirichlet(concentration, value) + check_funsor(expected, inputs, reals()) + actual = dist.Dirichlet(concentration, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +def test_normal_density(batch_shape): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for 
k, v in zip(batch_dims, batch_shape)) + + @funsor.of_shape(reals(), reals(), reals()) + def normal(loc, scale, value): + return -((value - loc) ** 2) / (2 * scale ** 2) - scale.log() - math.log(math.sqrt(2 * math.pi)) + + check_funsor(normal, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals()) + + loc = Tensor(torch.randn(batch_shape), inputs) + scale = Tensor(torch.randn(batch_shape).exp(), inputs) + value = Tensor(torch.randn(batch_shape), inputs) + expected = normal(loc, scale, value) + check_funsor(expected, inputs, reals()) + + actual = dist.Normal(loc, scale, value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy']) +def test_poisson_probs_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals()) + def poisson(rate, value): + return torch.distributions.Poisson(rate).log_prob(value) + + check_funsor(poisson, {'rate': reals(), 'value': reals()}, reals()) + + rate = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.randn(batch_shape).exp().round(), inputs) + expected = poisson(rate, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.Poisson(rate, value) + elif syntax == 'lazy': + actual = dist.Poisson(rate, d)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy']) +def test_gamma_probs_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals(), reals()) + def gamma(concentration, rate, value): + return torch.distributions.Gamma(concentration, rate).log_prob(value) + + check_funsor(gamma, {'concentration': reals(), 'rate': reals(), 'value': reals()}, reals()) + + concentration = Tensor(torch.rand(batch_shape), inputs) + rate = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.randn(batch_shape).exp(), inputs) + expected = gamma(concentration, rate, value) + check_funsor(expected, inputs, reals()) + + d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.Gamma(concentration, rate, value) + elif syntax == 'lazy': + actual = dist.Gamma(concentration, rate, d)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) + + +@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) +@pytest.mark.parametrize('syntax', ['eager', 'lazy']) +def test_von_mises_probs_density(batch_shape, syntax): + batch_dims = ('i', 'j', 'k')[:len(batch_shape)] + inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) + + @funsor.function(reals(), reals(), reals(), reals()) + def von_mises(loc, concentration, value): + return pyro.distributions.VonMises(loc, concentration).log_prob(value) + + check_funsor(von_mises, {'concentration': reals(), 'loc': reals(), 'value': reals()}, reals()) + + concentration = Tensor(torch.rand(batch_shape), inputs) + loc = Tensor(torch.rand(batch_shape), inputs) + value = Tensor(torch.randn(batch_shape).abs(), inputs) + expected = von_mises(loc, concentration, value) + check_funsor(expected, inputs, reals()) + 
+ d = Variable('value', reals()) + if syntax == 'eager': + actual = dist.VonMises(loc, concentration, value) + elif syntax == 'lazy': + actual = dist.VonMises(loc, concentration, d)(value=value) + check_funsor(actual, inputs, reals()) + assert_close(actual, expected) From ed3543cddca6f25845f0bdbda89bab7fa7b5fed4 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 16:14:09 -0800 Subject: [PATCH 16/24] basic density tests pass --- funsor/distributions2.py | 41 +++++++++++--- test/test_distributions2.py | 103 ++++++++++++------------------------ 2 files changed, 68 insertions(+), 76 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index 40d935f81..be6a0f4b4 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -51,7 +51,7 @@ def _infer_value_shape(cls, **kwargs): instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()}) out_shape = instance.event_shape if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): - out_dtype = instance.support.upper_bound + out_dtype = instance.support.upper_bound + 1 else: out_dtype = 'real' return Domain(dtype=out_dtype, shape=out_shape) @@ -111,13 +111,38 @@ def dist_init(self, *args, **kwargs): # return pyro_dist(*params) +class BernoulliProbs(dist.Bernoulli): + def __init__(self, probs, validate_args=None): + return super().__init__(probs=probs, validate_args=validate_args) + + +class BernoulliLogits(dist.Bernoulli): + def __init__(self, logits, validate_args=None): + return super().__init__(logits=logits, validate_args=validate_args) + + +class CategoricalProbs(dist.Categorical): + def __init__(self, probs, validate_args=None): + return super().__init__(probs=probs, validate_args=validate_args) + + +class CategoricalLogits(dist.Categorical): + def __init__(self, logits, validate_args=None): + return super().__init__(logits=logits, validate_args=validate_args) + + _wrapped_pyro_dists = [ - dist.Beta, - # dist.Bernoulli, - # dist.Categorical, - # dist.Poisson, - # dist.Normal, + (dist.Beta, ()), + (BernoulliProbs, ('probs',)), + (BernoulliLogits, ('logits',)), + (CategoricalProbs, ('probs',)), + (CategoricalLogits, ('logits',)), + (dist.Poisson, ()), + (dist.Gamma, ()), + (dist.VonMises, ()), + (dist.Dirichlet, ()), + (dist.Normal, ()), ] -for pyro_dist_class in _wrapped_pyro_dists: - locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class) +for pyro_dist_class, param_names in _wrapped_pyro_dists: + locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class, param_names) diff --git a/test/test_distributions2.py b/test/test_distributions2.py index e6b7801fa..5087476ed 100644 --- a/test/test_distributions2.py +++ b/test/test_distributions2.py @@ -45,14 +45,13 @@ def beta(concentration1, concentration0, value): @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) -def test_bernoulli_probs_density(batch_shape, syntax): +def test_bernoulli_probs_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @funsor.function(reals(), reals(), reals()) def bernoulli(probs, value): - return torch.distributions.Bernoulli(probs).log_prob(value) + return torch.distributions.Bernoulli(probs=probs).log_prob(value) check_funsor(bernoulli, {'probs': reals(), 'value': reals()}, reals()) @@ -61,20 +60,13 @@ def bernoulli(probs, value): expected = bernoulli(probs, value) 
check_funsor(expected, inputs, reals()) - d = Variable('value', reals()) - if syntax == 'eager': - actual = dist.BernoulliProbs(probs, value) - elif syntax == 'lazy': - actual = dist.BernoulliProbs(probs, d)(value=value) - elif syntax == 'generic': - actual = dist.Bernoulli(probs=probs)(value=value) + actual = dist.BernoulliProbs(probs, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic']) -def test_bernoulli_logits_density(batch_shape, syntax): +def test_bernoulli_logits_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ -89,65 +81,55 @@ def bernoulli(logits, value): expected = bernoulli(logits, value) check_funsor(expected, inputs, reals()) - d = Variable('value', reals()) - if syntax == 'eager': - actual = dist.BernoulliLogits(logits, value) - elif syntax == 'lazy': - actual = dist.BernoulliLogits(logits, d)(value=value) - elif syntax == 'generic': - actual = dist.Bernoulli(logits=logits)(value=value) + actual = dist.BernoulliLogits(logits, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) +@pytest.mark.parametrize('size', [4]) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('eager', [False, True]) -def test_binomial_density(batch_shape, eager): +def test_categorical_probs_density(size, batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) - max_count = 10 - @funsor.function(reals(), reals(), reals(), reals()) - def binomial(total_count, probs, value): - return torch.distributions.Binomial(total_count, probs).log_prob(value) + @funsor.function(reals(size), bint(size), reals()) + def categorical_probs(probs, value): + return torch.distributions.Categorical(probs=probs).log_prob(value) - check_funsor(binomial, {'total_count': reals(), 'probs': reals(), 'value': reals()}, reals()) + check_funsor(categorical_probs, {'probs': reals(size), 'value': bint(size)}, reals()) - value_data = random_tensor(inputs, bint(max_count)).data.float() - total_count_data = value_data + random_tensor(inputs, bint(max_count)).data.float() - value = Tensor(value_data, inputs) - total_count = Tensor(total_count_data, inputs) - probs = Tensor(torch.rand(batch_shape), inputs) - expected = binomial(total_count, probs, value) + probs_data = torch.randn(batch_shape + (size,)).exp() + probs_data /= probs_data.sum(-1, keepdim=True) + probs = Tensor(probs_data, inputs) + value = random_tensor(inputs, bint(size)) + expected = categorical_probs(probs, value) check_funsor(expected, inputs, reals()) - m = Variable('value', reals()) - actual = dist.Binomial(total_count, probs, value) if eager else \ - dist.Binomial(total_count, probs, m)(value=value) + actual = dist.CategoricalProbs(probs, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @pytest.mark.parametrize('size', [4]) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -def test_categorical_density(size, batch_shape): +def test_categorical_logits_density(size, batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) - @funsor.of_shape(reals(size), bint(size)) - def categorical(probs, value): - 
return probs[value].log() + @funsor.function(reals(size), bint(size), reals()) + def categorical_logits(logits, value): + return torch.distributions.Categorical(logits=logits).log_prob(value) - check_funsor(categorical, {'probs': reals(size), 'value': bint(size)}, reals()) + check_funsor(categorical_logits, {'logits': reals(size), 'value': bint(size)}, reals()) - probs_data = torch.randn(batch_shape + (size,)).exp() - probs_data /= probs_data.sum(-1, keepdim=True) - probs = Tensor(probs_data, inputs) + logits_data = torch.randn(batch_shape + (size,)) + logits_data /= logits_data.sum(-1, keepdim=True) + logits = Tensor(logits_data, inputs) value = random_tensor(inputs, bint(size)) - expected = categorical(probs, value) + expected = categorical_logits(logits, value) check_funsor(expected, inputs, reals()) - actual = dist.Categorical(probs, value) + actual = dist.CategoricalLogits(logits, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -170,7 +152,7 @@ def dirichlet(concentration, value): value = Tensor(value_data, inputs) expected = dirichlet(concentration, value) check_funsor(expected, inputs, reals()) - actual = dist.Dirichlet(concentration, value) + actual = dist.Dirichlet(concentration, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -192,14 +174,13 @@ def normal(loc, scale, value): expected = normal(loc, scale, value) check_funsor(expected, inputs, reals()) - actual = dist.Normal(loc, scale, value) + actual = dist.Normal(loc, scale, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('syntax', ['eager', 'lazy']) -def test_poisson_probs_density(batch_shape, syntax): +def test_poisson_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ -214,18 +195,13 @@ def poisson(rate, value): expected = poisson(rate, value) check_funsor(expected, inputs, reals()) - d = Variable('value', reals()) - if syntax == 'eager': - actual = dist.Poisson(rate, value) - elif syntax == 'lazy': - actual = dist.Poisson(rate, d)(value=value) + actual = dist.Poisson(rate, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('syntax', ['eager', 'lazy']) -def test_gamma_probs_density(batch_shape, syntax): +def test_gamma_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ -241,18 +217,13 @@ def gamma(concentration, rate, value): expected = gamma(concentration, rate, value) check_funsor(expected, inputs, reals()) - d = Variable('value', reals()) - if syntax == 'eager': - actual = dist.Gamma(concentration, rate, value) - elif syntax == 'lazy': - actual = dist.Gamma(concentration, rate, d)(value=value) + actual = dist.Gamma(concentration, rate, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) -@pytest.mark.parametrize('syntax', ['eager', 'lazy']) -def test_von_mises_probs_density(batch_shape, syntax): +def test_von_mises_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) @@ 
-268,10 +239,6 @@ def von_mises(loc, concentration, value): expected = von_mises(loc, concentration, value) check_funsor(expected, inputs, reals()) - d = Variable('value', reals()) - if syntax == 'eager': - actual = dist.VonMises(loc, concentration, value) - elif syntax == 'lazy': - actual = dist.VonMises(loc, concentration, d)(value=value) + actual = dist.VonMises(loc, concentration, 'value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) From 0d84d09ee64c67c0f1d887744bd79a13eba80493 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 16:34:17 -0800 Subject: [PATCH 17/24] tweak generic to_funsor/to_data implementations --- funsor/distributions2.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index be6a0f4b4..c18250b07 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -92,23 +92,21 @@ def dist_init(self, *args, **kwargs): return dist_class -# @to_funsor.register(dist.TorchDistribution) -# def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): -# import funsor.distributions # TODO find a better way to do this lookup -# funsor_dist = getattr(funsor.distributions, type(pyro_dist).__name__) -# params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) -# for param_name in funsor_dist._ast_fields if param_name != 'value'] -# return funsor_dist(*params) -# -# -# @to_data.register(Distribution2) -# def distribution_to_data(funsor_dist, name_to_dim=None): -# pyro_dist = funsor_dist.dist_class -# assert 'value' not in name_to_dim -# assert funsor_dist.inputs['value'].shape == () # TODO convert properly -# params = [to_data(getattr(pyro_dist, param_name), name_to_dim=name_to_dim) -# for param_name in funsor_dist._ast_fields if param_name != 'value'] -# return pyro_dist(*params) +@to_funsor.register(dist.TorchDistribution) +def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): + import funsor.distributions2 # TODO find a better way to do this lookup + funsor_dist = getattr(funsor.distributions2, type(pyro_dist).__name__) + params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name) + for param_name in funsor_dist_class._ast_fields if param_name != 'name'] + return funsor_dist_class(*params) + + +@to_data.register(Distribution2) +def distribution_to_data(funsor_dist, name_to_dim=None): + pyro_dist_class = funsor_dist.dist_class + params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim) + for param_name in funsor_dist._ast_fields if param_name != 'name'] + return pyro_dist_class(*params) class BernoulliProbs(dist.Bernoulli): From 06e51d4b139e8466963867668821c51d615ecb1b Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 16:36:32 -0800 Subject: [PATCH 18/24] standardize test --- test/test_distributions2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_distributions2.py b/test/test_distributions2.py index 5087476ed..781d51890 100644 --- a/test/test_distributions2.py +++ b/test/test_distributions2.py @@ -162,9 +162,9 @@ def test_normal_density(batch_shape): batch_dims = ('i', 'j', 'k')[:len(batch_shape)] inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape)) - @funsor.of_shape(reals(), reals(), reals()) + @funsor.function(reals(), reals(), reals(), reals()) def normal(loc, scale, value): - return -((value - loc) ** 2) / (2 * scale ** 2) - scale.log() - math.log(math.sqrt(2 * math.pi)) 
+        return torch.distributions.Normal(loc, scale).log_prob(value)
 
     check_funsor(normal, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals())
 
From 55805b90655bd4f5506dd1d7fc1148d1b546bf79 Mon Sep 17 00:00:00 2001
From: Eli
Date: Mon, 2 Mar 2020 16:46:34 -0800
Subject: [PATCH 19/24] check event shape in to_data

---
 funsor/distributions2.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/funsor/distributions2.py b/funsor/distributions2.py
index c18250b07..5c7515eaf 100644
--- a/funsor/distributions2.py
+++ b/funsor/distributions2.py
@@ -95,7 +95,7 @@ def dist_init(self, *args, **kwargs):
 @to_funsor.register(dist.TorchDistribution)
 def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
     import funsor.distributions2  # TODO find a better way to do this lookup
-    funsor_dist = getattr(funsor.distributions2, type(pyro_dist).__name__)
+    funsor_dist_class = getattr(funsor.distributions2, type(pyro_dist).__name__)
     params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name)
               for param_name in funsor_dist_class._ast_fields if param_name != 'name']
     return funsor_dist_class(*params)
@@ -106,7 +106,12 @@ def distribution_to_data(funsor_dist, name_to_dim=None):
     pyro_dist_class = funsor_dist.dist_class
     params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
               for param_name in funsor_dist._ast_fields if param_name != 'name']
-    return pyro_dist_class(*params)
+    pyro_dist = pyro_dist_class(*params)
+    funsor_event_shape = funsor_dist.inputs[funsor_dist.name].shape
+    pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
+    if pyro_dist.event_shape != funsor_event_shape:
+        raise ValueError("Event shapes don't match: expected {}, actual {}".format(funsor_event_shape, pyro_dist.event_shape))
+    return pyro_dist
 

From f1f9e0e7e63f6222bf639e6421ae19b6666e88fc Mon Sep 17 00:00:00 2001
From: Eli
Date: Mon, 2 Mar 2020 17:02:44 -0800
Subject: [PATCH 20/24] add metaclass to handle default name

---
 funsor/distributions2.py    | 11 +++++++++--
 test/test_distributions2.py | 20 ++++++++++----------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/funsor/distributions2.py b/funsor/distributions2.py
index 5c7515eaf..959cda764 100644
--- a/funsor/distributions2.py
+++ b/funsor/distributions2.py
@@ -22,7 +22,14 @@ def _dummy_tensor(domain):
     return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape)
 
 
-class Distribution2(Funsor):
+class DistributionMeta2(FunsorMeta):
+    def __call__(cls, *args, name=None):
+        if len(args) < len(cls._ast_fields):
+            args = args + (name if name is not None else 'value',)
+        return super(DistributionMeta2, cls).__call__(*args)
+
+
+class Distribution2(Funsor, metaclass=DistributionMeta2):
     """
     Different design for the Distribution Funsor wrapper,
     closer to Gaussian or Delta, in which the value is a fresh input.
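DistributionMeta2 above canonicalizes construction so that the value name always arrives as the final positional argument, defaulting to 'value' when the caller omits it. A minimal sketch of that intended behavior, using a hypothetical DefaultNameMeta and Toy class in place of real funsor terms:

    # Sketch only: DefaultNameMeta and Toy are illustrative stand-ins.
    class DefaultNameMeta(type):
        def __call__(cls, *args, name=None):
            # cls._ast_fields lists the constructor fields; 'name' comes last.
            if len(args) < len(cls._ast_fields):
                args = args + (name if name is not None else 'value',)
            return super().__call__(*args)

    class Toy(metaclass=DefaultNameMeta):
        _ast_fields = ('loc', 'scale', 'name')

        def __init__(self, loc, scale, name):
            self.loc, self.scale, self.name = loc, scale, name

    assert Toy(0.0, 1.0).name == 'value'        # default filled in by the metaclass
    assert Toy(0.0, 1.0, name='x').name == 'x'  # explicit keyword respected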
@@ -82,7 +89,7 @@ def make_dist(pyro_dist_class, param_names=()): @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')") def dist_init(self, *args, **kwargs): - return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name='value') + return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name=kwargs['name']) dist_class = FunsorMeta(pyro_dist_class.__name__, (Distribution2,), { 'dist_class': pyro_dist_class, diff --git a/test/test_distributions2.py b/test/test_distributions2.py index 781d51890..08e67f515 100644 --- a/test/test_distributions2.py +++ b/test/test_distributions2.py @@ -39,7 +39,7 @@ def beta(concentration1, concentration0, value): expected = beta(concentration1, concentration0, value) check_funsor(expected, inputs, reals()) - actual = dist.Beta(concentration1, concentration0, 'value')(value=value) + actual = dist.Beta(concentration1, concentration0, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -60,7 +60,7 @@ def bernoulli(probs, value): expected = bernoulli(probs, value) check_funsor(expected, inputs, reals()) - actual = dist.BernoulliProbs(probs, 'value')(value=value) + actual = dist.BernoulliProbs(probs, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -81,7 +81,7 @@ def bernoulli(logits, value): expected = bernoulli(logits, value) check_funsor(expected, inputs, reals()) - actual = dist.BernoulliLogits(logits, 'value')(value=value) + actual = dist.BernoulliLogits(logits, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -105,7 +105,7 @@ def categorical_probs(probs, value): expected = categorical_probs(probs, value) check_funsor(expected, inputs, reals()) - actual = dist.CategoricalProbs(probs, 'value')(value=value) + actual = dist.CategoricalProbs(probs, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -129,7 +129,7 @@ def categorical_logits(logits, value): expected = categorical_logits(logits, value) check_funsor(expected, inputs, reals()) - actual = dist.CategoricalLogits(logits, 'value')(value=value) + actual = dist.CategoricalLogits(logits, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -152,7 +152,7 @@ def dirichlet(concentration, value): value = Tensor(value_data, inputs) expected = dirichlet(concentration, value) check_funsor(expected, inputs, reals()) - actual = dist.Dirichlet(concentration, 'value')(value=value) + actual = dist.Dirichlet(concentration, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -174,7 +174,7 @@ def normal(loc, scale, value): expected = normal(loc, scale, value) check_funsor(expected, inputs, reals()) - actual = dist.Normal(loc, scale, 'value')(value=value) + actual = dist.Normal(loc, scale, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -195,7 +195,7 @@ def poisson(rate, value): expected = poisson(rate, value) check_funsor(expected, inputs, reals()) - actual = dist.Poisson(rate, 'value')(value=value) + actual = dist.Poisson(rate, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -217,7 +217,7 @@ def gamma(concentration, rate, value): expected = gamma(concentration, rate, value) check_funsor(expected, inputs, reals()) - actual = dist.Gamma(concentration, rate, 
'value')(value=value) + actual = dist.Gamma(concentration, rate, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) @@ -239,6 +239,6 @@ def von_mises(loc, concentration, value): expected = von_mises(loc, concentration, value) check_funsor(expected, inputs, reals()) - actual = dist.VonMises(loc, concentration, 'value')(value=value) + actual = dist.VonMises(loc, concentration, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) From b5ac6209651a7707e8ec36bd07ef66d6c9a2af29 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 17:55:26 -0800 Subject: [PATCH 21/24] add a to_funsor test for normal --- funsor/distributions2.py | 8 +++++++- test/test_distributions2.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index 959cda764..d45fc7651 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -52,6 +52,11 @@ def __init__(self, *args, name='value'): self.params = params self.name = name + def __getattribute__(self, attr): + if attr in type(self)._ast_fields and attr != 'name': + return self.params[attr] + return super().__getattribute__(attr) + @classmethod def _infer_value_shape(cls, **kwargs): # rely on the underlying distribution's logic to infer the event_shape @@ -99,7 +104,7 @@ def dist_init(self, *args, **kwargs): return dist_class -@to_funsor.register(dist.TorchDistribution) +@to_funsor.register(torch.distributions.Distribution) def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): import funsor.distributions2 # TODO find a better way to do this lookup funsor_dist_class = getattr(funsor.distributions2, type(pyro_dist).__name__) @@ -152,6 +157,7 @@ def __init__(self, logits, validate_args=None): (dist.VonMises, ()), (dist.Dirichlet, ()), (dist.Normal, ()), + (dist.MultivariateNormal, ('loc', 'scale_tril')), ] for pyro_dist_class, param_names in _wrapped_pyro_dists: diff --git a/test/test_distributions2.py b/test/test_distributions2.py index 08e67f515..c748331c1 100644 --- a/test/test_distributions2.py +++ b/test/test_distributions2.py @@ -242,3 +242,30 @@ def von_mises(loc, concentration, value): actual = dist.VonMises(loc, concentration, name='value')(value=value) check_funsor(actual, inputs, reals()) assert_close(actual, expected) + + +@pytest.mark.parametrize("event_shape", [ + (), # (5,), (4, 3), +], ids=str) +@pytest.mark.parametrize("batch_shape", [ + (), (2,), (2, 3), +], ids=str) +def test_normal_funsor_normal(batch_shape, event_shape): + loc = torch.randn(batch_shape + event_shape) + scale = torch.randn(batch_shape + event_shape).exp() + d = pyro.distributions.Normal(loc, scale).to_event(len(event_shape)) + value = d.sample() + name_to_dim = OrderedDict( + (f'{v}', v) for v in range(-len(batch_shape), 0) if batch_shape[v] > 1) + dim_to_name = OrderedDict((v, k) for k, v in name_to_dim.items()) + f = funsor.to_funsor(d, reals(), dim_to_name=dim_to_name) + d2 = funsor.to_data(f, name_to_dim=name_to_dim) + assert type(d) == type(d2) + assert d.batch_shape == d2.batch_shape + assert d.event_shape == d2.event_shape + expected_log_prob = d.log_prob(value) + actual_log_prob = d2.log_prob(value) + assert_close(actual_log_prob, expected_log_prob) + expected_funsor_log_prob = funsor.to_funsor(actual_log_prob, reals(), dim_to_name) + actual_funsor_log_prob = f(value=funsor.to_funsor(value, reals(*event_shape), dim_to_name)) + assert_close(actual_funsor_log_prob, 
expected_funsor_log_prob) From 2030536b8eee38f12147488b38a1d8c610295108 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 18:10:30 -0800 Subject: [PATCH 22/24] add makefun to dependencies --- funsor/distributions2.py | 3 +-- setup.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index d45fc7651..377ff106d 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -4,6 +4,7 @@ import math from collections import OrderedDict +import makefun import pyro.distributions as dist import torch from pyro.distributions.util import broadcast_shape @@ -86,8 +87,6 @@ def eager_subs(self, subs): def make_dist(pyro_dist_class, param_names=()): - import makefun - if not param_names: param_names = tuple(pyro_dist_class.arg_constraints.keys()) assert all(name in pyro_dist_class.arg_constraints for name in param_names) diff --git a/setup.py b/setup.py index 23a39f65a..78b25e1c9 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ author_email='fritzo@uber.com', python_requires=">=3.6", install_requires=[ + 'makefun', 'multipledispatch', 'numpy>=1.7', 'opt_einsum>=2.3.2', From 3c1f24963008a1c93475d0b82a730cbdc9795d4d Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 18:39:27 -0800 Subject: [PATCH 23/24] add more to_funsor sketches --- funsor/distributions2.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index 377ff106d..c93faad44 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -5,18 +5,14 @@ from collections import OrderedDict import makefun -import pyro.distributions as dist import torch +import pyro.distributions as dist from pyro.distributions.util import broadcast_shape -import funsor.delta import funsor.ops as ops -from funsor.affine import is_affine from funsor.domains import Domain, bint, reals -from funsor.gaussian import Gaussian -from funsor.interpreter import gensym, interpretation -from funsor.tensor import Tensor, align_tensors, ignore_jit_warnings, stack -from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_data, to_funsor +from funsor.tensor import Tensor, align_tensors +from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, eager, to_data, to_funsor def _dummy_tensor(domain): @@ -112,6 +108,27 @@ def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): return funsor_dist_class(*params) +@to_funsor.register(torch.distributions.Independent) +def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None): + result = to_funsor(pyro_dist.base_dist, dim_to_name=dim_to_name) + for i in range(pyro_dist.reinterpreted_batch_ndims): + name = ... # XXX what is this? read off from result? 
+ result = funsor.terms.Independent(result, "value", name, "value") + return result + + +@to_funsor.register(pyro.distributions.MaskedDistribution) +def maskeddist_to_funsor(pyro_dist, output=None, dim_to_name=None): + mask = to_funsor(pyro_dist._mask.float(), output=output, dim_to_name=dim_to_name) + funsor_base_dist = to_funsor(pyro_dist.base_dist, output=output, dim_to_name=dim_to_name) + return mask * funsor_base_dist + + +@to_funsor.register(torch.distributions.TransformedDistribution) +def transformeddist_to_funsor(pyro_dist, output=None, dim_to_name=None): + raise NotImplementedError("TODO") + + @to_data.register(Distribution2) def distribution_to_data(funsor_dist, name_to_dim=None): pyro_dist_class = funsor_dist.dist_class From c913eebbf9757a4e0a57a53adfbd486fa41f5203 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 2 Mar 2020 18:46:27 -0800 Subject: [PATCH 24/24] shuffle code around --- funsor/distributions2.py | 58 +++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/funsor/distributions2.py b/funsor/distributions2.py index c93faad44..b8a67c228 100644 --- a/funsor/distributions2.py +++ b/funsor/distributions2.py @@ -7,6 +7,7 @@ import makefun import torch import pyro.distributions as dist +from pyro.distributions.torch_distribution import MaskedDistribution from pyro.distributions.util import broadcast_shape import funsor.ops as ops @@ -77,27 +78,9 @@ def eager_subs(self, subs): raise NotImplementedError("not implemented") -################################################################################ -# Distribution Wrappers -################################################################################ - -def make_dist(pyro_dist_class, param_names=()): - - if not param_names: - param_names = tuple(pyro_dist_class.arg_constraints.keys()) - assert all(name in pyro_dist_class.arg_constraints for name in param_names) - - @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')") - def dist_init(self, *args, **kwargs): - return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name=kwargs['name']) - - dist_class = FunsorMeta(pyro_dist_class.__name__, (Distribution2,), { - 'dist_class': pyro_dist_class, - '__init__': dist_init, - }) - - return dist_class - +###################################### +# Converting distributions to funsors +###################################### @to_funsor.register(torch.distributions.Distribution) def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): @@ -117,7 +100,7 @@ def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None): return result -@to_funsor.register(pyro.distributions.MaskedDistribution) +@to_funsor.register(MaskedDistribution) def maskeddist_to_funsor(pyro_dist, output=None, dim_to_name=None): mask = to_funsor(pyro_dist._mask.float(), output=output, dim_to_name=dim_to_name) funsor_base_dist = to_funsor(pyro_dist.base_dist, output=output, dim_to_name=dim_to_name) @@ -129,6 +112,10 @@ def transformeddist_to_funsor(pyro_dist, output=None, dim_to_name=None): raise NotImplementedError("TODO") +########################################################### +# Converting distribution funsors to PyTorch distributions +########################################################### + @to_data.register(Distribution2) def distribution_to_data(funsor_dist, name_to_dim=None): pyro_dist_class = funsor_dist.dist_class @@ -142,6 +129,33 @@ def distribution_to_data(funsor_dist, name_to_dim=None): return pyro_dist 
+@to_data.register(Independent) +def indep_to_data(funsor_dist, name_to_dim=None): + raise NotImplementedError("TODO") + + +################################################################################ +# Distribution Wrappers +################################################################################ + +def make_dist(pyro_dist_class, param_names=()): + + if not param_names: + param_names = tuple(pyro_dist_class.arg_constraints.keys()) + assert all(name in pyro_dist_class.arg_constraints for name in param_names) + + @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')") + def dist_init(self, *args, **kwargs): + return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name=kwargs['name']) + + dist_class = DistributionMeta2(pyro_dist_class.__name__, (Distribution2,), { + 'dist_class': pyro_dist_class, + '__init__': dist_init, + }) + + return dist_class + + class BernoulliProbs(dist.Bernoulli): def __init__(self, probs, validate_args=None): return super().__init__(probs=probs, validate_args=validate_args)
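Taken together, the to_funsor and to_data registrations built up in this series are meant to support a round trip between Pyro distributions and distribution funsors. The following usage sketch is distilled from test_normal_funsor_normal above; it assumes funsor.distributions2 has been imported so these registrations are active, and that the converters behave as the tests expect:

    from collections import OrderedDict

    import pyro.distributions as dist
    import torch

    import funsor
    import funsor.distributions2  # noqa: F401  activates the registrations
    from funsor.domains import reals

    loc = torch.randn(2)
    scale = torch.randn(2).exp()
    d = dist.Normal(loc, scale)  # batch_shape == (2,)

    # Name the single batch dimension 'i' in both directions.
    dim_to_name = OrderedDict([(-1, 'i')])
    name_to_dim = OrderedDict([('i', -1)])

    f = funsor.to_funsor(d, reals(), dim_to_name=dim_to_name)  # funsor Normal with inputs i, value
    d2 = funsor.to_data(f, name_to_dim=name_to_dim)            # back to a Pyro Normal

    value = d.sample()
    assert torch.allclose(d.log_prob(value), d2.log_prob(value))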