Optimize gaussian performance #331

Closed
wants to merge 20 commits

Conversation

fehiepsi (Member) commented May 16, 2020

This small PR adds some optimizations explored in #315:

  • Lazily add two BlockVectors/BlockMatrixes. Currently, we convert them to numeric arrays and then add the results (see the sketch after this list).
  • Some optimizations to Cholesky/triangular_solve for scalar (1x1) matrices.
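A minimal sketch of what lazy block-wise addition looks like, using a simplified stand-in for funsor's BlockVector (the class and names here are illustrative only, not funsor's actual API):

```python
import torch


class SimpleBlockVector:
    """Simplified stand-in for funsor's BlockVector (illustrative only)."""

    def __init__(self, shape):
        self.shape = shape
        self.parts = {}  # maps (start, stop) slice bounds to a tensor part

    def as_tensor(self):
        # Densify: concatenate the parts in slice order.
        return torch.cat([self.parts[key] for key in sorted(self.parts)], dim=-1)

    def __add__(self, other):
        # Lazy add: combine parts block-by-block instead of densifying
        # both operands first.
        assert self.shape == other.shape
        result = SimpleBlockVector(self.shape)
        for key in set(self.parts) | set(other.parts):
            # NOTE: direct indexing assumes both operands share the same
            # block layout; the review discussion below is about exactly this.
            result.parts[key] = self.parts[key] + other.parts[key]
        return result


x = SimpleBlockVector((5,))
x.parts[(0, 2)], x.parts[(2, 5)] = torch.ones(2), torch.zeros(3)
y = SimpleBlockVector((5,))
y.parts[(0, 2)], y.parts[(2, 5)] = torch.ones(2), torch.ones(3)
print((x + y).as_tensor())  # tensor([2., 2., 1., 1., 1.])
```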

I explored a 2D cat to convert BlockMatrix to a tensor, but it does not improve performance. The number of cat ops does go down (for the Gaussian HMM with time_dim=6000, from 160 calls to 91), but the .reshape ops that follow them become expensive.
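For the scalar-matrix optimization in the second bullet, the intuition is that a 1x1 Cholesky factorization is just a square root and a 1x1 triangular solve is just a division. A hedged sketch of such a dispatch (written against the modern torch.linalg API; funsor's actual ops layer may differ):

```python
import torch


def cholesky(x):
    # A 1x1 Cholesky factor is just an elementwise sqrt, which avoids
    # dispatching to the much heavier batched factorization kernel.
    if x.size(-1) == 1:
        return x.sqrt()
    return torch.linalg.cholesky(x)


def triangular_solve(b, a, upper=False):
    # Solving the 1x1 triangular system a @ x = b is just a division
    # (broadcasts over batch dims and over the columns of b).
    if a.size(-1) == 1:
        return b / a
    return torch.linalg.solve_triangular(a, b, upper=upper)
```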

Testing GaussianHMM with batch_dim, time_dim, obs_dim, hidden_dim = 5, 6000, 3, 2, the time to evaluate log_prob drops from 192ms to 160ms.

Profiling code:

```python
import torch

import funsor
funsor.set_backend("torch")

from funsor.pyro.hmm import GaussianHMM
from funsor.testing import random_mvn
from pyro.distributions.util import broadcast_shape

batch_dim, time_dim, obs_dim, hidden_dim = 5, 6000, 3, 2

init_shape = (batch_dim,)
trans_mat_shape = trans_mvn_shape = obs_mat_shape = obs_mvn_shape = (batch_dim, time_dim)
init_dist = random_mvn(init_shape, hidden_dim)
trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
obs_dist = random_mvn(obs_mvn_shape, obs_dim)

actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
shape = broadcast_shape(init_shape + (1,),
                        trans_mat_shape, trans_mvn_shape,
                        obs_mat_shape, obs_mvn_shape)
data = obs_dist.expand(shape).sample()
assert data.shape == actual_dist.shape()

# %time is an IPython magic; run this line in IPython/Jupyter.
%time actual_log_prob = actual_dist.log_prob(data)
```

@fehiepsi added the Blocked (blocked by other issues) label on May 16, 2020.
@eb8680 added and removed the Blocked label on May 17, 2020.
@fehiepsi removed the Blocked label on Jul 11, 2020.
fritzo (Member) left a comment:
Nice speedup. I believe there is a subtler compatibility condition though.

```python
shape = broadcast_shape(self.shape, other.shape)
matrix = BlockMatrix(shape)
for part in set(self.parts.keys()) | set(other.parts.keys()):
    a = self.parts[part]
```
fritzo (Member):

Won't this error if part in other.parts but part not in self.parts? And conversely for the following line?
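One way to avoid the KeyError is to merge with dict.get so a part present on only one side is copied through. A hypothetical helper illustrating the idea (not necessarily the fix that was merged):

```python
def merged_parts(lhs_parts, rhs_parts):
    # Merge two block-part dicts; a key present on only one side is copied
    # through instead of raising KeyError via direct indexing.
    out = {}
    for part in set(lhs_parts) | set(rhs_parts):
        a, b = lhs_parts.get(part), rhs_parts.get(part)
        out[part] = a + b if a is not None and b is not None else (b if a is None else a)
    return out


# e.g. merged_parts({"x": 1.0, "y": 2.0}, {"y": 3.0}) == {"x": 1.0, "y": 5.0}
```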

```python
keep_block = isinstance(lhs_info_vec, BlockVector)
rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs, try_keeping_block_form=keep_block)

if keep_block and not isinstance(rhs_info_vec, BlockVector):
```
fritzo (Member):

I think you'll also need to test whether the two block forms are compatible. E.g., the following two block forms are not compatible (see the sketch after this list):

  • [0:1], [1:3]
  • [0:2], [2:3]
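A hypothetical helper making the compatibility condition concrete (illustration only, not part of the PR): two layouts are compatible when any two slices from opposite sides are either identical or disjoint, i.e. no partial overlaps.

```python
def blocks_compatible(lhs_slices, rhs_slices):
    # Slices are (start, stop) pairs. Incompatible means some slice from
    # one layout partially overlaps a slice from the other.
    for a0, a1 in lhs_slices:
        for b0, b1 in rhs_slices:
            overlaps = max(a0, b0) < min(a1, b1)
            if overlaps and (a0, a1) != (b0, b1):
                return False
    return True


assert blocks_compatible([(0, 1), (1, 3)], [(0, 1), (1, 3)])      # same layout
assert not blocks_compatible([(0, 1), (1, 3)], [(0, 2), (2, 3)])  # example above
```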

```python
        result = ops.cat(-1, *parts)
        if not get_tracing_state():
            assert result.shape == self.shape
        return result

    def __add__(self, other):
        assert isinstance(other, BlockVector)
```
fritzo (Member):

I think we'll also want either (1) an assertion that the two block forms are compatible, or (2) a branch that converts both via .as_tensor() and adds the results. Two block forms are incompatible if any of their slices partially overlap, e.g. [0:3], [3:10] versus [0:4], [4:10].
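Reusing the hypothetical SimpleBlockVector and blocks_compatible sketches above, option (2) might look like this (illustrative only, not the code that was merged):

```python
def add_block_vectors(lhs, rhs):
    # Take the cheap lazy path when the layouts line up; otherwise densify
    # both operands and add the results (option (2) above).
    if blocks_compatible(list(lhs.parts), list(rhs.parts)):
        return lhs + rhs  # stays in block form
    return lhs.as_tensor() + rhs.as_tensor()  # plain dense tensor
```

Note the mixed return type in this sketch: callers that always need a dense tensor can call .as_tensor() on the block-form result.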

fehiepsi (Member, Author) replied:

I remember being worried about this issue too, but I later found there is no test for it. So I assumed that each block corresponds to a variable and that we only add two BlockVectors when their common variables have consistent dimensions. If that is true, I'll add an assertion. Otherwise, you are right that we need a branch.

fehiepsi (Member, Author) replied:

Either way, we need a check for that consistency, so adding a branch does no harm after all. I'll do it.
