Add optional input arguments to to_funsor and to_data #316

Merged · 10 commits · Feb 18, 2020
37 changes: 10 additions & 27 deletions funsor/pyro/convert.py
@@ -28,12 +28,12 @@
from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.distributions import BernoulliLogits, MultivariateNormal, Normal
-from funsor.domains import bint, reals
+from funsor.domains import Domain, bint, reals
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.ops import cholesky
from funsor.tensor import Tensor, align_tensors
-from funsor.terms import Funsor, Independent, Variable, eager
+from funsor.terms import Funsor, Independent, Variable, eager, to_data, to_funsor

# Conversion functions use fixed names for Pyro batch dims, but
# accept an event_inputs tuple for custom event dim names.
@@ -63,21 +63,12 @@ def tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype="real"):
    assert isinstance(event_output, int) and event_output >= 0
    inputs_shape = tensor.shape[:tensor.dim() - event_output]
    output_shape = tensor.shape[tensor.dim() - event_output:]
-    dim_to_name = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME
-
-    # Squeeze shape of inputs.
-    inputs = OrderedDict()
-    squeezed_shape = []
-    for dim, size in enumerate(inputs_shape):
-        if size > 1:
-            name = dim_to_name[dim - len(inputs_shape)]
-            inputs[name] = bint(size)
-            squeezed_shape.append(size)
-    squeezed_shape = torch.Size(squeezed_shape)
-    if squeezed_shape != inputs_shape:
-        tensor = tensor.reshape(squeezed_shape + output_shape)
-
-    return Tensor(tensor, inputs, dtype)
+    dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME
+    dim_to_name = OrderedDict(zip(
+        range(-len(inputs_shape), 0),
+        zip(dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):],
+            map(bint, inputs_shape))))
+    return to_funsor(tensor, Domain(dtype=dtype, shape=output_shape), dim_to_name)
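
The convert.py entry point keeps its signature; only the internals now route through the generic to_funsor hint. A minimal sketch of the new code path (shapes and the event-dim name here are made up, not from this PR):

    import torch
    from funsor.pyro.convert import tensor_to_funsor

    data = torch.randn(2, 3, 4)
    # Name the rightmost batch dim "time"; keep one event dim of size 4.
    f = tensor_to_funsor(data, event_inputs=("time",), event_output=1)
    # The remaining batch dim gets a reserved name from DIM_TO_NAME;
    # the output is reals(4).
    assert f.output.shape == (4,)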


def funsor_to_tensor(funsor_, ndims, event_inputs=()):
@@ -99,16 +90,8 @@ def funsor_to_tensor(funsor_, ndims, event_inputs=()):
    if event_inputs:
        dim_to_name = DIM_TO_NAME + event_inputs
        name_to_dim = dict(zip(dim_to_name, range(-len(dim_to_name), 0)))
-    names = tuple(sorted(funsor_.inputs, key=name_to_dim.__getitem__))
-    tensor = funsor_.align(names).data
-    if names:
-        # Unsqueeze shape of inputs.
-        dims = list(map(name_to_dim.__getitem__, names))
-        inputs_shape = [1] * (-dims[0])
-        for dim, size in zip(dims, tensor.shape):
-            inputs_shape[dim] = size
-        inputs_shape = torch.Size(inputs_shape)
-        tensor = tensor.reshape(inputs_shape + funsor_.output.shape)
+    tensor = to_data(funsor_, name_to_dim)

    if ndims != tensor.dim():
        tensor = tensor.reshape((1,) * (ndims - tensor.dim()) + tensor.shape)
    assert tensor.dim() == ndims
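
Round-tripping through the new to_data path should restore the batch dims that tensor_to_funsor squeezed; a hedged sketch under the same made-up shapes:

    import torch
    from funsor.pyro.convert import funsor_to_tensor, tensor_to_funsor

    f = tensor_to_funsor(torch.randn(2, 3, 4), event_inputs=("time",), event_output=1)
    t = funsor_to_tensor(f, ndims=3, event_inputs=("time",))
    assert t.shape == (2, 3, 4)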
70 changes: 54 additions & 16 deletions funsor/tensor.py
@@ -405,18 +405,40 @@ def backend(self):
return "numpy"


@dispatch(numeric_array)
def to_funsor(x):
return Tensor(x)


@dispatch(numeric_array, Domain)
def to_funsor(x, output):
result = Tensor(x, dtype=output.dtype)
if result.output != output:
raise ValueError("Invalid shape: expected {}, actual {}"
.format(output.shape, result.output.shape))
return result
+# TODO move these registrations to backend-specific files
+@to_funsor.register(torch.Tensor)
+@to_funsor.register(np.ndarray)
+@to_funsor.register(np.generic)
+def tensor_to_funsor(x, output=None, dim_to_name=None):
+    if not dim_to_name:
+        output = output if output is not None else reals(*x.shape)
+        result = Tensor(x, dtype=output.dtype)
+        if result.output != output:
+            raise ValueError("Invalid shape: expected {}, actual {}"
+                             .format(output.shape, result.output.shape))
+        return result
+    else:
+        assert output is not None  # TODO attempt to infer output
+        assert all(isinstance(k, int) and isinstance(v[0], str) and isinstance(v[1], Domain)
+                   for k, v in dim_to_name.items())
+        # logic very similar to pyro.ops.packed.pack
+        # this should not touch memory, only reshape
+        # pack the tensor according to the dim => (name, domain) mapping in inputs
+        packed_inputs = OrderedDict()
+        for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape):
+            if size == 1:
+                continue  # TODO broadcast domain and shape here
+            name, domain = dim_to_name[dim + len(output.shape) - len(x.shape)]
+            packed_inputs[name] = domain if domain.dtype > 1 else bint(size)
+        if any(size > 1 for size in output.shape):
+            # pack outputs into a single dimension
+            x = x.reshape(x.shape[:-len(output.shape)] + (-1,))
+        x = x.squeeze()
+        if all(size == 1 for size in output.shape):
+            # handle special case: all output dims are 1
+            x = x.unsqueeze(-1)
+        x = x.reshape(x.shape[:-1] + output.shape)
+        return Tensor(x, packed_inputs, dtype=output.dtype)
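
For orientation, a sketch of the new calling convention for the tensor registration (names and sizes are hypothetical): dim_to_name maps negative batch dims to (name, Domain) pairs, and size-1 batch dims are squeezed away rather than named:

    import torch
    from collections import OrderedDict
    from funsor.domains import bint, reals
    from funsor.terms import to_funsor

    x = torch.randn(2, 1, 3, 4)
    dim_to_name = OrderedDict([(-3, ("a", bint(2))),
                               (-2, ("b", bint(1))),
                               (-1, ("c", bint(3)))])
    f = to_funsor(x, reals(4), dim_to_name)
    assert dict(f.inputs) == {"a": bint(2), "c": bint(3)}  # "b" is squeezed out
    assert f.output == reals(4)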


def align_tensor(new_inputs, x, expand=False):
@@ -484,10 +506,26 @@ def align_tensors(*args, **kwargs):


@to_data.register(Tensor)
-def _to_data_tensor(x):
-    if x.inputs:
-        raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}")
-    return x.data
+def _to_data_tensor(x, name_to_dim=None):
+    if not name_to_dim or not x.inputs:
+        if x.inputs:
+            raise ValueError(f"cannot convert Tensor to data due to lazy inputs: {set(x.inputs)}")
+        return x.data
+    else:
+        # logic very similar to pyro.ops.packed.unpack
+        # first collapse input domains into single dimensions
+        data = x.data.reshape(tuple(d.dtype for d in x.inputs.values()) + x.output.shape)
+        # permute packed dimensions to correct order
+        unsorted_dims = [name_to_dim[name] for name in x.inputs]
+        dims = sorted(unsorted_dims)
+        permutation = [unsorted_dims.index(dim) for dim in dims] + \
+            list(range(len(dims), len(dims) + len(x.output.shape)))
+        data = data.permute(*permutation)
+        # expand
+        batch_shape = [1] * -min(dims)
+        for dim, size in zip(dims, data.shape):
+            batch_shape[dim] = size
+        return data.reshape(tuple(batch_shape) + x.output.shape)
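
And the inverse: name_to_dim places each named input at a fixed negative batch dim, reintroducing singleton dims in between (continuing the hypothetical names above):

    import torch
    from collections import OrderedDict
    from funsor.domains import bint, reals
    from funsor.terms import to_data, to_funsor

    x = torch.randn(2, 3, 4)
    f = to_funsor(x, reals(4), OrderedDict([(-2, ("a", bint(2))),
                                            (-1, ("c", bint(3)))]))
    y = to_data(f, name_to_dim={"a": -3, "c": -1})
    assert y.shape == (2, 1, 3, 4)  # "a" at dim -3, singleton at -2, "c" at -1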


@eager.register(Binary, Op, Tensor, Number)
52 changes: 20 additions & 32 deletions funsor/terms.py
@@ -756,51 +756,40 @@ def _(arg, indent, out):
interpreter.children.register(Funsor)(interpreter.children_funsor)


-@dispatch(object)
-def to_funsor(x):
+@singledispatch
+def to_funsor(x, output=None, dim_to_name=None):
    """
    Convert to a :class:`Funsor` .
    Only :class:`Funsor` s and scalars are accepted.

    :param x: An object.
+    :param funsor.domains.Domain output: An optional output hint.
+    :param OrderedDict dim_to_name: An optional inputs hint.
    :return: A Funsor equivalent to ``x``.
    :rtype: Funsor
    :raises: ValueError
    """
    raise ValueError("Cannot convert to Funsor: {}".format(repr(x)))


-@dispatch(object, Domain)
-def to_funsor(x, output):
-    raise ValueError("Cannot convert to Funsor: {}".format(repr(x)))
-
-
-@dispatch(object, object)
-def to_funsor(x, output):
-    raise TypeError("Invalid Domain: {}".format(repr(output)))
-
-
-@dispatch(Funsor)
-def to_funsor(x):
-    return x
-
-
-@dispatch(Funsor, Domain)
-def to_funsor(x, output):
-    if x.output != output:
+@to_funsor.register(Funsor)
+def funsor_to_funsor(x, output=None, dim_to_name=None):
+    if output is not None and x.output != output:
        raise ValueError("Output mismatch: {} vs {}".format(x.output, output))
+    if dim_to_name is not None and list(x.inputs.keys()) != [v[0] for v in dim_to_name.values()]:
+        raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, dim_to_name))
    return x
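
Sketch of the Funsor pass-through with the new optional hints (behavior assumed from the code above):

    from funsor.domains import reals
    from funsor.terms import Variable, to_funsor

    v = Variable("x", reals())
    assert to_funsor(v) is v           # no hints: identity
    assert to_funsor(v, reals()) is v  # matching output hint: still identity
    # to_funsor(v, reals(2))  # raises ValueError: Output mismatch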


@singledispatch
-def to_data(x):
+def to_data(x, name_to_dim=None):
    """
    Extract a python object from a :class:`Funsor`.

    Raises a ``ValueError`` if free variables remain or if the funsor is lazy.

    :param x: An object, possibly a :class:`Funsor`.
+    :param OrderedDict name_to_dim: An optional inputs hint.
    :return: A non-funsor equivalent to ``x``.
    :raises: ValueError if any free variables remain.
    :raises: PatternMissingError if funsor is not fully evaluated.
@@ -809,8 +798,8 @@ def to_data(x):


@to_data.register(Funsor)
-def _to_data_funsor(x):
-    if x.inputs:
+def _to_data_funsor(x, name_to_dim=None):
+    if name_to_dim is None and x.inputs:
        raise ValueError(f"cannot convert {type(x)} to data due to lazy inputs: {set(x.inputs)}")
    raise PatternMissingError(f"cannot convert to a non-Funsor: {repr(x)}")

@@ -839,8 +828,10 @@ def eager_subs(self, subs):
        return subs[0][1]


-@dispatch(str, Domain)
-def to_funsor(name, output):
+@to_funsor.register(str)
+def name_to_funsor(name, output=None):
+    if output is None:
+        raise ValueError(f"Missing output: {name}")
    return Variable(name, output)
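
Strings still convert to Variables, but the output domain, though now an optional argument, must be supplied (sketch):

    from funsor.domains import reals
    from funsor.terms import to_funsor

    x = to_funsor("x", reals())  # Variable("x", reals())
    # to_funsor("x")  # raises ValueError: Missing output: x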


@@ -1099,13 +1090,10 @@ def eager_unary(self, op):
        return Number(op(self.data), dtype)


-@dispatch(numbers.Number)
-def to_funsor(x):
-    return Number(x)
-
-
-@dispatch(numbers.Number, Domain)
-def to_funsor(x, output):
+@to_funsor.register(numbers.Number)
+def number_to_funsor(x, output=None):
+    if output is None:
+        return Number(x)
    if output.shape:
        raise ValueError("Cannot create Number with shape {}".format(output.shape))
    return Number(x, output.dtype)
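
Numbers behave as before, with or without an output hint (sketch):

    from funsor.domains import bint, reals
    from funsor.terms import to_funsor

    to_funsor(0.5)         # Number(0.5)
    to_funsor(1, bint(2))  # Number(1, 2), a bounded integer
    # to_funsor(1.0, reals(3))  # raises ValueError: Cannot create Number with shape (3,)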