
Profiling GaussianHMM with NumPyro backend #315

Closed

Conversation

@fehiepsi (Member) commented Feb 10, 2020

I just profiled funsor.pyro.hmm.GaussianHMM against pyro.distributions.GaussianHMM and have some observations:

  • funsor is slow compared to pyro, especially for small hidden/obs dims. With batch_dim, time_dim, obs_dim, hidden_dim = 5, 6, 3, 2, pyro takes 5ms while funsor takes 35ms to evaluate log_prob.
  • The overhead in funsor seems to be constant. I increased obs_dim, hidden_dim to 30, 20 and verified this.
  • torch.cat takes a large amount of time in funsor (e.g. with batch_dim, time_dim = 50, 60, this op takes 7ms out of the total 70ms). Similarly, torch.pad takes a portion of the time in pyro, though less than torch.cat does in funsor. I think the reason is that we replace pad by new_zeros + cat in funsor (see the sketch after the timing output below).
  • Time for the numerical calculations (other than cat and pad) seems to be similar between the two versions. It seems to me that at some point funsor switches dimensions. I guess this is because in funsor we store dimensions in a set. For example, I looked at the args of triangular_solve and got the following (this is not important, I believe):
funsor torch.Size([5, 3, 2, 4]) torch.Size([5, 3, 2, 2])
funsor torch.Size([5, 3, 2, 1]) torch.Size([5, 3, 2, 2])
funsor torch.Size([5, 1, 2, 4]) torch.Size([5, 1, 2, 2])
funsor torch.Size([5, 1, 2, 1]) torch.Size([5, 1, 2, 2])
funsor torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 2, 2])
funsor torch.Size([1, 5, 2, 1]) torch.Size([1, 5, 2, 2])
funsor torch.Size([5, 2, 2]) torch.Size([5, 2, 2])
funsor torch.Size([5, 2, 1]) torch.Size([5, 2, 2])
funsor torch.Size([5, 2, 1]) torch.Size([5, 2, 2])
CPU times: user 287 ms, sys: 30.9 ms, total: 318 ms
Wall time: 28.3 ms
pyro torch.Size([5, 3, 2, 4]) torch.Size([5, 3, 2, 2])
pyro torch.Size([5, 3, 2, 1]) torch.Size([5, 3, 2, 2])
pyro torch.Size([5, 1, 2, 4]) torch.Size([5, 1, 2, 2])
pyro torch.Size([5, 1, 2, 1]) torch.Size([5, 1, 2, 2])
pyro torch.Size([5, 1, 2, 4]) torch.Size([5, 1, 2, 2])
pyro torch.Size([5, 1, 2, 1]) torch.Size([5, 1, 2, 2])
pyro torch.Size([5, 2, 2]) torch.Size([5, 2, 2])
pyro torch.Size([5, 2, 1]) torch.Size([5, 2, 2])
pyro torch.Size([5, 2, 1]) torch.Size([5, 2, 2])
CPU times: user 70.4 ms, sys: 1.17 ms, total: 71.6 ms
Wall time: 5.84 ms
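To illustrate the pad-vs-cat point above, here is a toy comparison (the shapes are made up) of the two ways to left-pad along the last dim: pyro calls pad directly, while funsor currently materializes the zeros and cats them on.

import torch

x = torch.randn(5, 6, 3)

# pyro-style: pad the last dim with 2 zeros on the left
padded_via_pad = torch.nn.functional.pad(x, (2, 0))
# funsor-style: build the zeros explicitly and cat
padded_via_cat = torch.cat([x.new_zeros(5, 6, 2), x], dim=-1)

assert torch.equal(padded_via_pad, padded_via_cat)
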
Profiling code
import pyro.distributions as dist
import torch
from pyro.distributions.util import broadcast_shape

from funsor.pyro.hmm import GaussianHMM
from funsor.testing import random_mvn

batch_dim, time_dim, obs_dim, hidden_dim = 5, 6, 3, 2

init_shape = (batch_dim,)
trans_mat_shape = trans_mvn_shape = obs_mat_shape = obs_mvn_shape = (batch_dim, time_dim)
init_dist = random_mvn(init_shape, hidden_dim)
trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
obs_dist = random_mvn(obs_mvn_shape, obs_dim)

actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
assert actual_dist.batch_shape == expected_dist.batch_shape
assert actual_dist.event_shape == expected_dist.event_shape

shape = broadcast_shape(init_shape + (1,),
                        trans_mat_shape, trans_mvn_shape,
                        obs_mat_shape, obs_mvn_shape)
data = obs_dist.expand(shape).sample()
assert data.shape == actual_dist.shape()

%time actual_log_prob = actual_dist.log_prob(data)
%time expected_log_prob = expected_dist.log_prob(data)

With the above observations, I think that using jax.jit will resolve this "constant" overhead. So I would like to go ahead and implement GaussianHMM in NumPyro. This requires modifications of the current funsor.pyro to make it work with NumPyro. All changes in this PR are just for profiling, not to merge. Hopefully, jax.jit will work well with the lazy evaluation mechanisms in funsor.
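As a toy illustration of the jax.jit pattern (this is not the funsor code, just the mechanism): the Python-level work is paid once at trace time, and later calls reuse the compiled XLA program.

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_prob_fn(x):
    # stand-in for an expensive log_prob with lots of Python dispatch
    return jnp.sum(norm.logpdf(x))

fast_log_prob = jax.jit(log_prob_fn)
x = jnp.ones(1000)
fast_log_prob(x)  # first call traces and compiles
fast_log_prob(x)  # subsequent calls run the cached XLA computation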

Tasks

  • todo

@fritzo (Member) commented Feb 10, 2020

I'd also check out the case of a long time dimension, say 1000, so you can exercise the parallel scan.

@fritzo (Member) commented Feb 10, 2020

implement GaussianHMM in NumPyro ... requires modifications of current funsor.pyro to make it work with NumPyro

This sounds like a good next step to me.

@fehiepsi (Member Author) commented Feb 10, 2020

I'd also check out the case of long time dimension say 1000, so you can exercise parallel scan.

@fritzo With time_dim = 6000, torch.cat now takes 60ms out of the total 180ms, which is the slowest of all ops (triangular_solve only took 15ms). Other profiling results seem as expected to me (funsor took 180ms while pyro took 50ms).

@fritzo (Member) commented Feb 10, 2020

@fehiepsi very interesting! This suggests your cat -> pad refactoring will indeed help!

@fritzo (Member) commented Feb 10, 2020

@fehiepsi another thing I've been meaning to do is implement a 2d cat operation in PyTorch, rather than cat'ting a list of cats. I think we could do that using a bunch of reshapes and save time.
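For the record, a rough sketch of what that could look like for the special case of a grid of equal-size blocks (cat2d and the shapes here are hypothetical, not existing funsor code):

import torch

def cat2d(blocks):
    # blocks: an R x C grid (list of lists) of tensors, each of shape (r, c)
    R, C = len(blocks), len(blocks[0])
    r, c = blocks[0][0].shape
    stacked = torch.stack([b for row in blocks for b in row])  # (R*C, r, c)
    return stacked.reshape(R, C, r, c).permute(0, 2, 1, 3).reshape(R * r, C * c)

blocks = [[torch.randn(2, 3) for _ in range(4)] for _ in range(5)]
assert torch.equal(cat2d(blocks),
                   torch.cat([torch.cat(row, dim=-1) for row in blocks], dim=-2))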

@fehiepsi (Member Author)

implement a 2d cat operation in PyTorch

@fritzo I added 2D cat and lazy evaluation of info_vec/precision in eager_add(gaussian, gaussian) and found that they made funsor 1.5x faster than before. :D

@fehiepsi (Member Author) commented Feb 13, 2020

@fritzo @neerajprad I just profiled GaussianHMM with the numpyro backend. Here are some observations so far:

  • the numpyro backend is pretty fast; with time_dim=6000, it is 20x faster than pyro and 80x faster than funsor on CPU (see this gist). With smaller time_dim, it is even faster.
  • with the numpyro backend and the above time_dim, GPU is 1.5x faster than CPU. However, compilation time on GPU is pretty slow (a minute or so :( ).

This is pretty good motivation to support a numpy backend. :D I'll summarize the things that need to change in detail below.

things that can be backend-agnostic

  • funsor.distributions: following the "opt-einsum" approach, I created the dict
BACKEND_TO_DISTRIBUTION_BACKEND = {
    "torch": "pyro.distributions",
    "numpy": "numpyro.distributions"}

and depending on Tensor.backend, we will use either the Pyro distributions' log_prob or the NumPyro distributions' log_prob.

  • test/pyro/*.py: I found that using the pattern
randn = partial(randn, backend="numpy")
random_tensor = partial(random_tensor, backend="numpy")
random_mvn = partial(random_mvn, backend="numpy")
dist = import_module(BACKEND_TO_DISTRIBUTION_BACKEND["numpy"])

is much simpler than adding a backend pytest parametrization to each test. Can we use an environment flag BACKEND=numpy to distinguish two tests?

  • funsor.pyro.convert: except for the isinstance(pyro_dist, dist.Multivariate) statements, the remaining implementation can be backend-agnostic.

things that diverge

  • We need both the pyro distribution and the numpyro distribution in the dispatching of dist_to_funsor, e.g. @dist_to_funsor.register(dist.Bernoulli). I don't have a solution for this yet.
  • NumPyro distributions do not have .expand(...) method. <- we can support this in numpyro if needed.
  • In NumPyro, values are validated using the validate_sample decorator. An out-of-support log_prob will return NaN (instead of throwing an error as in Pyro - the reason is that under jit, we can't check for out-of-support values). I am not sure how to make the behavior consistent between Pyro and NumPyro here.
  • In NumPyro, distributions need a random key to sample. Currently, we don't have a mechanism to provide a key in funsor. Should I add a keyword key=None here?
  • FunsorDistribution is a subclass of either pyro.distributions.Distribution or numpyro.distributions.Distribution. It seems that we need to maintain two FunsorDistribution classes here (though most of the code is the same).

What do you think about those diverging points? Do you have an idea for how to keep the numpyro import optional? I guess we can check if numpyro is available, and only then add the dispatches for it, as in the sketch below.
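One possible pattern (just a sketch with stand-in names - the real dist_to_funsor lives in funsor.pyro.convert): try the import once and only register the numpyro dispatches when it succeeds.

from functools import singledispatch

try:
    import numpyro.distributions as numpyro_dist
except ImportError:
    numpyro_dist = None

@singledispatch
def dist_to_funsor(backend_dist, event_inputs=()):
    raise ValueError("cannot convert {}".format(type(backend_dist)))

if numpyro_dist is not None:
    # these registrations only exist when numpyro is installed
    @dist_to_funsor.register(numpyro_dist.Normal)
    def _normal_to_funsor(backend_dist, event_inputs=()):
        ...  # convert to a funsor here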

caution

  • Unlike torch.expand, jax.numpy.broadcast_to doesn't accept -1 dims and causes a fatal error at runtime (without a traceback). I spent a lot of time fixing that issue (it is hard to debug in funsor - I got lost in the forest of INTERPRETER stuff... =.=).

@eb8680 (Member) commented Feb 13, 2020

@fehiepsi you should be able to focus on speeding up Tensor and Gaussian and not worry too much about the distribution-related backend API divergence - I'll be refactoring funsor.distributions and the funsor<->data conversion functionality in funsor as part of #316 and related work.

@fritzo (Member) commented Feb 13, 2020

@fehiepsi indeed those are encouraging numbers!

Can we use an environment flag BACKEND=numpy to distinguish two tests?

Sure, let's name it FUNSOR_BACKEND as in

BACKEND = os.environ.get("FUNSOR_BACKEND", "torch")

Would this be used in tests or in all of Funsor?

Unlike torch.expand, jax.numpy.broadcast_to doesn't accept -1

Good to know. Two pervasive options are:

  1. avoid -1 in ops.expand, and rewrite lots of funsor code
  2. wrap the jax implementation of ops.expand to replace -1 by 1, e.g.
    @ops.expand.register(jax array or whatever)
    def jax_expand(x, shape):
        shape = tuple(1 if size == -1 else size for size in shape)
        return jax.numpy.broadcast_to(x, shape)

it is hard to debug in funsor - I got lost in the forest of INTERPRETER

Have you tried running under FUNSOR_DEBUG=1 or FUNSOR_DEBUG=2? E.g.

FUNSOR_DEBUG=2 pytest -vs test/test_gaussian.py

@fritzo (Member) commented Feb 13, 2020

@fehiepsi also when profiling, could you compare times when we're computing gradients? E.g.

init_dist = dist.Normal(pyro.param("init_loc", ...),
                        pyro.param("init_scale", ...)).to_event(1)
trans_matrix = pyro.param("trans_matrix", torch.eye(dim))
...

%%timeit
hmm = GaussianHMM(init_dist, trans_matrix, ...)
hmm.log_prob(data).sum().backward()

@fehiepsi (Member Author)

@eb8680 Thanks! I'll follow up on your refactoring.

Have you tried running under FUNSOR_DEBUG=1 or FUNSOR_DEBUG=2?

This is the first time I've heard of them. I will use them next time. Thanks!

wrap the jax implementation of ops.expand to replace -1 by 1

Sure, let me do it. Thanks!

Would this be used in tests or in all of Funsor?

I only intended to use it in future tests, but having a global BACKEND would be convenient. I think it is a good time to do it now. I'll sketch a PR to incorporate some fixes/enhancements that I caught in this PR.

@neerajprad (Member)

Regarding differences like:

NumPyro distributions do not have .expand(...) method. <- we can support this in numpyro if needed.

We can easily implement the generic Distribution.expand from Pyro, and we are actually using a simple form of it under plate. The other differences seem very minor, and I am unsure how much of the broadcasting and enumeration machinery we will need to implement in numpyro, but please feel free to create a separate issue there to discuss what changes, if any, are needed to integrate with funsor. I think @eb8680 will probably have a better idea about this with #316.

@fehiepsi (Member Author)

Hi @neerajprad, Eli has already made the PR pyro-ppl/pyro#2307, and I think much of that will be moved to funsor. We'll need to learn more from that PR, then we can discuss how to use funsor in NumPyro (after seeing that jit/grad works with GaussianHMM, I don't expect this job will be complicated :D).

@eb8680 mentioned this pull request Feb 18, 2020
@fehiepsi (Member Author)

I'll make a separate PR for those optimizations.
