Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize gaussian performance #331

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,31 @@ def as_tensor(self):
prototype = next(iter(self.parts.values()))
for i in _find_intervals(self.parts.keys(), self.shape[-1]):
if i not in self.parts:
self.parts[i] = ops.new_zeros(prototype, self.shape[:-1] + (i[1] - i[0],))
self.parts[i] = ops.new_zeros(prototype, (i[1] - i[0],))

# Concatenate parts.
parts = [v for k, v in sorted(self.parts.items())]
parts = [ops.expand(v, self.shape[:-1] + v.shape[-1:]) for k, v in sorted(self.parts.items())]
result = ops.cat(-1, *parts)
if not get_tracing_state():
assert result.shape == self.shape
return result

def __add__(self, other):
    """
    Add two block vectors, merging their parts slice-by-slice.

    Blocks present in only one operand are copied through; blocks present
    in both are summed elementwise. The two block partitions must be
    compatible: any pair of overlapping slices must be identical, since
    e.g. ``[0:3],[3:10]`` versus ``[0:4],[4:10]`` would otherwise be
    silently mis-added (reviewer-flagged issue; the incompatible case is
    rejected with an assertion rather than converted via ``.as_tensor()``).
    """
    assert isinstance(other, BlockVector)
    shape = broadcast_shape(self.shape, other.shape)
    result = BlockVector(shape)
    a, b = self.parts, other.parts
    a_keys, b_keys = set(a), set(b)
    # Keys are (start, stop) interval tuples (see as_tensor, which uses
    # ``i[1] - i[0]``). Two distinct intervals must not overlap.
    assert all(k1 == k2 or k1[1] <= k2[0] or k2[1] <= k1[0]
               for k1 in a_keys for k2 in b_keys), \
        "Cannot add BlockVectors with incompatible block structure"
    for j in a_keys - b_keys:
        result.parts[j] = a[j]
    for j in a_keys & b_keys:
        result.parts[j] = a[j] + b[j]
    for j in b_keys - a_keys:
        result.parts[j] = b[j]
    return result


class BlockMatrix(object):
"""
Expand Down Expand Up @@ -155,21 +171,39 @@ def as_tensor(self):
for i in rows:
for j in cols:
if j not in self.parts[i]:
shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
shape = (i[1] - i[0], j[1] - j[0])
self.parts[i][j] = ops.new_zeros(prototype, shape)

# Concatenate parts.
# TODO This could be optimized into a single .reshape().cat().reshape() if
# all inputs are contiguous, thereby saving a memcopy.
columns = {i: ops.cat(-1, *[v for j, v in sorted(part.items())])
columns = {i: ops.cat(-1, *[ops.expand(v, self.shape[:-2] + v.shape[-2:]) for j, v in sorted(part.items())])
for i, part in self.parts.items()}
result = ops.cat(-2, *[v for i, v in sorted(columns.items())])

if not get_tracing_state():
assert result.shape == self.shape
return result


def align_gaussian(new_inputs, old):
def __add__(self, other):
    """
    Add two block matrices, merging row blocks and summing shared cells.

    Row blocks present in only one operand are copied through; cells
    present in both are summed elementwise.
    """
    assert isinstance(other, BlockMatrix)
    shape = broadcast_shape(self.shape, other.shape)
    result = BlockMatrix(shape)
    for i in set(self.parts) | set(other.parts):
        # Use .get() so a row block present in only one operand neither
        # raises KeyError (plain dict) nor mutates the operands by
        # inserting empty rows (defaultdict) — reviewer-flagged issue.
        a = self.parts.get(i, {})
        b = other.parts.get(i, {})
        a_keys, b_keys = set(a), set(b)
        for j in a_keys - b_keys:
            result.parts[i][j] = a[j]
        for j in a_keys & b_keys:
            result.parts[i][j] = a[j] + b[j]
        for j in b_keys - a_keys:
            result.parts[i][j] = b[j]
    return result


def align_gaussian(new_inputs, old, try_keeping_block_form=False):
"""
Align data of a Gaussian distribution to a new ``inputs`` shape.
"""
Expand Down Expand Up @@ -212,8 +246,9 @@ def align_gaussian(new_inputs, old):
old_slice2 = slice(offset2, offset2 + num_elements2)
new_slice2 = slice(new_offset2, new_offset2 + num_elements2)
precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2]
info_vec = info_vec.as_tensor()
precision = precision.as_tensor()
if not try_keeping_block_form:
info_vec = info_vec.as_tensor()
precision = precision.as_tensor()

return info_vec, precision

Expand Down Expand Up @@ -635,12 +670,21 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
# Align data.
inputs = lhs.inputs.copy()
inputs.update(rhs.inputs)
lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs)
rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs)
lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs, try_keeping_block_form=True)
keep_block = isinstance(lhs_info_vec, BlockVector)
rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs, try_keeping_block_form=keep_block)

if keep_block and not isinstance(rhs_info_vec, BlockVector):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you'll also need to test whether the two block forms are compatible. E.g. the following two block forms are not compatible:

  • [0:1], [1:3]
  • [0:2], [2:3]

lhs_info_vec = lhs_info_vec.as_tensor()
lhs_precision = lhs_precision.as_tensor()

# Fuse aligned Gaussians.
info_vec = lhs_info_vec + rhs_info_vec
precision = lhs_precision + rhs_precision

if isinstance(info_vec, BlockVector):
info_vec = info_vec.as_tensor()
precision = precision.as_tensor()
return Gaussian(info_vec, precision, inputs)


Expand Down
7 changes: 7 additions & 0 deletions funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _cholesky(x):
"""
if x.shape[-1] == 1:
return np.sqrt(x)

return np.linalg.cholesky(x)


Expand All @@ -82,6 +83,9 @@ def _cholesky_inverse(x):

@ops.cholesky_solve.register(array, array)
def _cholesky_solve(x, y):
    # Fast path: a 1x1 Cholesky factor is a scalar, so L @ L.T = y * y and
    # the solve reduces to an elementwise division.
    if y.shape[-1] != 1:
        return cho_solve((y, True), x)
    return x / (y * y)


Expand Down Expand Up @@ -222,6 +226,9 @@ def _sum(x, dim):

@ops.triangular_solve.register(array, array)
def _triangular_solve(x, y, upper=False, transpose=False):
if y.shape[-1] == 1:
return x / y

assert np.ndim(x) >= 2 and np.ndim(y) >= 2
n, m = x.shape[-2:]
assert y.shape[-2:] == (n, n)
Expand Down
17 changes: 15 additions & 2 deletions funsor/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
################################################################################

ops.abs.register(torch.Tensor)(torch.abs)
ops.cholesky_solve.register(torch.Tensor, torch.Tensor)(torch.cholesky_solve)
ops.clamp.register(torch.Tensor, object, object)(torch.clamp)
ops.exp.register(torch.Tensor)(torch.exp)
ops.full_like.register(torch.Tensor, object)(torch.full_like)
Expand Down Expand Up @@ -69,11 +68,22 @@ def _cholesky_inverse(x):
"""
Like :func:`torch.cholesky_inverse` but supports batching and gradients.
"""
if x.dim() == 2:
if x.size(-1) == 1:
return (x * x).reciprocal()
elif x.dim() == 2:
return x.cholesky_inverse()

return torch.eye(x.size(-1)).cholesky_solve(x)


@ops.cholesky_solve.register(torch.Tensor, torch.Tensor)
def _cholesky_solve(x, y):
    # Fast path: a 1x1 Cholesky factor is a scalar, so L @ L.T = y * y and
    # the solve reduces to an elementwise division.
    if y.shape[-1] != 1:
        return x.cholesky_solve(y)
    return x / y.pow(2)


@ops.detach.register(torch.Tensor)
def _detach(x):
    # Thin adapter registering torch's autograd-detach as the backend
    # implementation of ops.detach for torch.Tensor.
    return x.detach()
Expand Down Expand Up @@ -224,4 +234,7 @@ def _sum(x, dim):

@ops.triangular_solve.register(torch.Tensor, torch.Tensor)
def _triangular_solve(x, y, upper=False, transpose=False):
    # Trivial 1x1 triangular system: the solution is an elementwise division.
    if y.size(-1) == 1:
        return x / y
    # NOTE(review): Tensor.triangular_solve is deprecated in recent torch
    # releases in favor of torch.linalg.solve_triangular — confirm the
    # project's minimum supported torch version before migrating.
    return x.triangular_solve(y, upper, transpose).solution