-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize gaussian performance #331
Changes from all commits
8e3a7ef
b50e63b
ce09b11
f9b05b6
c588e4b
95e8683
32ae224
b3ec65b
a3e7848
7d24882
6fd166d
e1648c8
635f9c2
408d594
5f5b2c9
8665e24
50b6fcc
eeb32ea
90dafdf
8f3e83d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,15 +114,31 @@ def as_tensor(self): | |
prototype = next(iter(self.parts.values())) | ||
for i in _find_intervals(self.parts.keys(), self.shape[-1]): | ||
if i not in self.parts: | ||
self.parts[i] = ops.new_zeros(prototype, self.shape[:-1] + (i[1] - i[0],)) | ||
self.parts[i] = ops.new_zeros(prototype, (i[1] - i[0],)) | ||
|
||
# Concatenate parts. | ||
parts = [v for k, v in sorted(self.parts.items())] | ||
parts = [ops.expand(v, self.shape[:-1] + v.shape[-1:]) for k, v in sorted(self.parts.items())] | ||
result = ops.cat(-1, *parts) | ||
if not get_tracing_state(): | ||
assert result.shape == self.shape | ||
return result | ||
|
||
def __add__(self, other): | ||
assert isinstance(other, BlockVector) | ||
shape = broadcast_shape(self.shape, other.shape) | ||
vector = BlockVector(shape) | ||
a = self.parts | ||
b = other.parts | ||
a_keys = set(a.keys()) | ||
b_keys = set(b.keys()) | ||
for j in a_keys - b_keys: | ||
vector.parts[j] = a[j] | ||
for j in a_keys & b_keys: | ||
vector.parts[j] = a[j] + b[j] | ||
for j in b_keys - a_keys: | ||
vector.parts[j] = b[j] | ||
return vector | ||
|
||
|
||
class BlockMatrix(object): | ||
""" | ||
|
@@ -155,21 +171,39 @@ def as_tensor(self): | |
for i in rows: | ||
for j in cols: | ||
if j not in self.parts[i]: | ||
shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0]) | ||
shape = (i[1] - i[0], j[1] - j[0]) | ||
self.parts[i][j] = ops.new_zeros(prototype, shape) | ||
|
||
# Concatenate parts. | ||
# TODO This could be optimized into a single .reshape().cat().reshape() if | ||
# all inputs are contiguous, thereby saving a memcopy. | ||
columns = {i: ops.cat(-1, *[v for j, v in sorted(part.items())]) | ||
columns = {i: ops.cat(-1, *[ops.expand(v, self.shape[:-2] + v.shape[-2:]) for j, v in sorted(part.items())]) | ||
for i, part in self.parts.items()} | ||
result = ops.cat(-2, *[v for i, v in sorted(columns.items())]) | ||
|
||
if not get_tracing_state(): | ||
assert result.shape == self.shape | ||
return result | ||
|
||
|
||
def align_gaussian(new_inputs, old): | ||
def __add__(self, other): | ||
assert isinstance(other, BlockMatrix) | ||
shape = broadcast_shape(self.shape, other.shape) | ||
matrix = BlockMatrix(shape) | ||
for part in set(self.parts.keys()) | set(other.parts.keys()): | ||
a = self.parts[part] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Won't this error if `part` is not in `self.parts` (i.e. it appears only in `other.parts`)? |
||
b = other.parts[part] | ||
a_keys = set(a.keys()) | ||
b_keys = set(b.keys()) | ||
for j in a_keys - b_keys: | ||
matrix.parts[part][j] = a[j] | ||
for j in a_keys & b_keys: | ||
matrix.parts[part][j] = a[j] + b[j] | ||
for j in b_keys - a_keys: | ||
matrix.parts[part][j] = b[j] | ||
return matrix | ||
|
||
|
||
def align_gaussian(new_inputs, old, try_keeping_block_form=False): | ||
""" | ||
Align data of a Gaussian distribution to a new ``inputs`` shape. | ||
""" | ||
|
@@ -212,8 +246,9 @@ def align_gaussian(new_inputs, old): | |
old_slice2 = slice(offset2, offset2 + num_elements2) | ||
new_slice2 = slice(new_offset2, new_offset2 + num_elements2) | ||
precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2] | ||
info_vec = info_vec.as_tensor() | ||
precision = precision.as_tensor() | ||
if not try_keeping_block_form: | ||
info_vec = info_vec.as_tensor() | ||
precision = precision.as_tensor() | ||
|
||
return info_vec, precision | ||
|
||
|
@@ -635,12 +670,21 @@ def eager_add_gaussian_gaussian(op, lhs, rhs): | |
# Align data. | ||
inputs = lhs.inputs.copy() | ||
inputs.update(rhs.inputs) | ||
lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs) | ||
rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs) | ||
lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs, try_keeping_block_form=True) | ||
keep_block = isinstance(lhs_info_vec, BlockVector) | ||
rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs, try_keeping_block_form=keep_block) | ||
|
||
if keep_block and not isinstance(rhs_info_vec, BlockVector): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you'll also need to test whether the two block forms are compatible. E.g. the following two block forms are not compatible:
|
||
lhs_info_vec = lhs_info_vec.as_tensor() | ||
lhs_precision = lhs_precision.as_tensor() | ||
|
||
# Fuse aligned Gaussians. | ||
info_vec = lhs_info_vec + rhs_info_vec | ||
precision = lhs_precision + rhs_precision | ||
|
||
if isinstance(info_vec, BlockVector): | ||
info_vec = info_vec.as_tensor() | ||
precision = precision.as_tensor() | ||
return Gaussian(info_vec, precision, inputs) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we'll also want either (1) an assertion that the two block forms are compatible, or (2) a branch to convert both via
.as_tensor()
and add the results. Two block forms are incompatible if any of their slices overlap, e.g. [0:3],[3:10]
versus [0:4],[4:10]
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember that I was worried about this issue too but later found that there is no test for it. So I assumed that each block corresponds to a variable and we only add two BlockVectors when their common variables have consistent dimensions. If that is true, I'll add an assertion. Otherwise, you are right that we need a branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either way, we need a check for that consistency. So adding a branch has no harm after all. I'll do it.