Profiling GaussianHMM with NumPyro backend #315
Conversation
I'd also check out the case of a long time dimension, say 1000, so you can exercise the parallel scan.
This sounds like a good next step to me.
@fritzo With
@fehiepsi very interesting! This suggests your
@fehiepsi another thing I've been meaning to do is implement a 2D cat operation in PyTorch, rather than cat'ting a list of cats. I think we could do that using a bunch of reshapes and save time.
@fritzo I added 2D cat and lazy evaluation of `info_vec`/precision in `eager_add(gaussian, gaussian)` and found that they made funsor 1.5x faster than before. :D
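A sketch of the 2D-cat idea discussed above: instead of cat'ting a list of cats, stack all equally-shaped blocks once and rearrange them with reshapes. This is illustrative only, not funsor's actual implementation; the `cat2d_*` names are made up, and the fast path assumes every block in the grid has the same shape.

```python
import torch

def cat2d_naive(blocks):
    # "Cat of cats": concatenate each row along the last dim,
    # then concatenate the rows along the second-to-last dim.
    return torch.cat([torch.cat(row, dim=-1) for row in blocks], dim=-2)

def cat2d_reshape(blocks):
    # Alternative using a single stack plus reshapes/permute.
    # Assumes every block has the same shape (r, c).
    R, C = len(blocks), len(blocks[0])
    r, c = blocks[0][0].shape
    flat = torch.stack([b for row in blocks for b in row])   # (R*C, r, c)
    grid = flat.reshape(R, C, r, c).permute(0, 2, 1, 3)      # (R, r, C, c)
    return grid.reshape(R * r, C * c)                        # (R*r, C*c)
```

Both functions produce the same result; the reshape version trades many small `cat` kernel launches for one `stack` and a copy in `reshape`.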
…future refactoring
@fritzo @neerajprad I just profiled GaussianHMM with the numpyro backend. Here are some observations so far:
This is a pretty good motivation to support a numpy backend. :D I'll summarize the things that need to change in detail below. stuff can be backend-agnostic
and depending on
is much simpler than adding
diverging stuff
What do you think about that diverging stuff? Do you have an idea to block
@fehiepsi indeed those are encouraging numbers!
Sure, let's name it `BACKEND = os.environ.get("FUNSOR_BACKEND", "torch")`. Would this be used in tests or in all of Funsor?
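A minimal sketch of how that environment-variable switch might be used (hypothetical helper names; only the `BACKEND` spelling and default come from the discussion above):

```python
import os

def get_backend():
    # Read the backend flag suggested above, defaulting to "torch".
    return os.environ.get("FUNSOR_BACKEND", "torch")

def backend_zeros(shape):
    # Hypothetical backend-agnostic op: import lazily so only the
    # selected backend needs to be installed.
    if get_backend() == "numpy":
        import numpy as np
        return np.zeros(shape)
    import torch
    return torch.zeros(shape)
```

Reading the flag through a function (rather than once at import time) makes it easy to flip backends inside tests.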
Good to know. Two pervasive options are:
Have you tried running under `FUNSOR_DEBUG=2 pytest -vs test/test_gaussian.py`?
@fehiepsi also when profiling, could you compare times when we're computing gradients? E.g.

```python
init_dist = dist.Normal(pyro.param("init_loc", ...),
                        pyro.param("init_scale", ...)).to_event(1)
trans_matrix = pyro.param("trans_matrix", torch.eye(dim))
...
%%timeit
hmm = GaussianHMM(init_dist, trans_matrix, ...)
hmm.log_prob(data).sum().backward()
```
@eb8680 Thanks! I'll follow up with your refactoring.
This is the first time I've heard about them. I will use them next time. Thanks!
Sure, let me do it. Thanks!
I only intended to use it in future tests, but having a global
Regarding differences like:
We can easily implement the generic |
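One way a generic op could be implemented is by dispatching on the concrete array type rather than a global flag. This is a hedged sketch under that assumption; `generic_cat` is an illustrative name, not funsor's API:

```python
import numpy as np

def generic_cat(parts, axis=0):
    # Illustrative: pick the backend implementation from the type of the
    # first argument, so the same high-level code runs under either backend.
    first = parts[0]
    if isinstance(first, np.ndarray):
        return np.concatenate(parts, axis=axis)
    import torch  # imported lazily so numpy-only environments still work
    if isinstance(first, torch.Tensor):
        return torch.cat(parts, dim=axis)
    raise TypeError(f"no cat implementation for {type(first)}")
```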
Hi @neerajprad, Eli already made the PR pyro-ppl/pyro#2307, and I think much of that will be moved to funsor. We'll need to learn more from that PR, then we can discuss how to use funsor in NumPyro (after seeing that jit/grad works in GaussianHMM, I don't expect this job will be complicated :D).
I'll make a separate PR for those optimizations.
I just did profiling of `funsor.pyro.hmm.GaussianHMM` and `pyro.distributions.GaussianHMM` and have some observations:
- `funsor` is slow compared to `pyro`, especially for small hidden/obs dims. With `batch_dim, time_dim, obs_dim, hidden_dim = 5, 6, 3, 2`, `pyro` takes 5ms while `funsor` takes 35ms to evaluate log_prob. This overhead of `funsor` seems to be constant. I increased `obs_dim, hidden_dim` to `30, 20` and verified that.
- `torch.cat` takes a large amount of time in `funsor` (e.g. with `batch_dim, time_dim = 50, 60`, this op takes 7ms of the total 70ms). Similarly, `torch.pad` takes a portion of time in `pyro` (but the time is less than `torch.cat` in `funsor`). I think the reason is we replace `pad` by `new_zeros + cat` in `funsor`.
- (`cat`, `pad`) seems to be similar between the two versions. It seems to me that at some points, `funsor` switches dimensions. I guess this is due to the fact that in funsor, we store dimensions in a set. For example, I looked at the args of triangular solve and got (this is not important I believe)

Profiling code
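The `pad` vs `new_zeros + cat` difference observed above can be illustrated with a small example; both spellings produce the same tensor, but the `cat` version allocates the zeros separately and then copies:

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 4)

# Pyro-style: a single fused pad op (append 2 zero columns on the last dim).
padded = F.pad(x, (0, 2))

# Funsor-style equivalent: allocate zeros, then concatenate
# (an extra allocation plus a copy per call).
padded_via_cat = torch.cat([x, x.new_zeros(3, 2)], dim=-1)
```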
With the above observations, I think that using `jax.jit` will resolve this "constant" overhead. So I would like to go ahead and implement GaussianHMM in NumPyro. This requires modifications of the current `funsor.pyro` to make it work with NumPyro. All changes in this PR are just for profiling, not to merge. Hopefully, `jax.jit` will work well with the lazy evaluation mechanisms in `funsor`.
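The intuition behind the `jax.jit` point can be sketched with a toy stand-in for the HMM's log_prob (a diagonal Gaussian log-density; the function and names here are illustrative, not the actual NumPyro implementation):

```python
import jax
import jax.numpy as jnp

def gaussian_log_prob(x, loc, scale):
    # Toy stand-in for the HMM's log_prob: diagonal Gaussian log-density.
    z = (x - loc) / scale
    return jnp.sum(-0.5 * z ** 2 - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi))

# jit traces the Python once and compiles it, so the per-call Python
# ("constant") overhead is paid only on the first call; later calls with
# the same shapes reuse the compiled computation.
fast_log_prob = jax.jit(gaussian_log_prob)
```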
Tasks