Support global backend #318

Merged
merged 23 commits into from
Feb 19, 2020

Conversation

fehiepsi
Member

@fehiepsi fehiepsi commented Feb 14, 2020

Addresses #207

Following up on the discussion in #317, this PR adds support for a global backend in funsor.

To make reviewing easier, here are the important changes in this PR:

  • Separate out the jax and numpy implementations. I also took this chance to order the ops by name (to make it easier to compare backends or to find a specific implementation).
  • Use numpy for the default implementation of numeric ops. funsor.torch or funsor.jax is lazily imported, depending on the environment variable FUNSOR_BACKEND or a call to funsor.set_backend(...).
  • The backend can be changed with the utility set_backend. There is one problem: once the backend has been changed to JAX, we can't revert to the numpy backend.
  • Port the Pyro Tensordot implementation to funsor. The Pyro implementation contains calls to x.dim(), which is not available for ndarray.
  • Although numpy is the default backend, I still use PyTorch names (except for amax and amin) for these ops, as on the current master branch. I guess Pyro devs would prefer PyTorch names to NumPy names for numeric ops.

Some small changes:

  • replace torch._C._get_tracing_state by funsor.util.get_tracing_state
  • add funsor.util.is_nn_module
  • add ops.is_numeric_array to check whether a variable is a numeric array. We can rename it to ops.is_tensor to match torch.is_tensor if preferred.
  • add ops.detach to detach gradients. This is a no-op for the numpy backend.
  • fix ops.log, which emitted warnings at the value 0.
  • support both a tuple and separate positional arguments as the shape for the testing utilities randn, zeros, ... (because many places in the test files use the pattern torch.zeros(1, 2, 3) instead of torch.zeros((1, 2, 3))).
  • dispatch for recursion_reinterpret.register(torch.Tensor) is implemented in torch.py file, instead of interpreter.py file. (similar for children, to_funsor, allclose)
  • others are cleanup and reordering changes
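The shape handling for the testing utilities mentioned above could look roughly like this. It is a sketch under assumptions: `normalize_shape` is a hypothetical helper name, and a real utility would pass the resulting tuple on to the active backend's array constructor.

```python
def normalize_shape(*shape):
    """Accept either f(1, 2, 3) or f((1, 2, 3)) and return a tuple of ints."""
    # A single tuple/list argument is treated as the whole shape.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = shape[0]
    return tuple(int(s) for s in shape)
```

With such a helper, both calling styles found in the test files map to the same shape tuple, so the tests do not need to be rewritten when the backend changes.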

Tests pass with all backends

  • test_affine
  • test_alpha_conversion
  • test_cnf
  • test_delta
  • test_gaussian
  • test_import
  • test_integrate
  • test_joint
  • test_sum_product
  • test_tensor
  • test_terms

Tests only work with torch/pyro

  • test/pyro/
  • test/test_minipyro.py
  • test/test_adjoint.py: requires einsum.adjoint.require_backward
  • test/test_distributions.py: refactoring
  • test/test_einsum.py
  • test/test_memoize.py: requires some einsum functionality
  • test/test_optimizer.py: requires distributions and einsum

TODO:

  • separate numpy and jax backend.
  • revise DICE logic

@fehiepsi fehiepsi added the WIP label Feb 14, 2020
@fritzo
Member

fritzo commented Feb 15, 2020

@fehiepsi FYI to help motivate this PR:
As part of Eli's #316 and pyro-ppl/pyro#2307 we discussed the possibility of moving some of the poutines (EnumMessenger, MarkovMessenger, ...) from Pyro to Funsor, so they could be immediately used also in NumPyro. Then we would make Pyro depend on Funsor, which would be ok (non-cyclic) if both torch and jax were optional/extras dependencies of Funsor, whence also Pyro could be an optional dependency. I.e. Pyro would always depend on Funsor, but Funsor would only optionally/extras depend on Pyro.

So it would be nice if, either in this PR or a follow-up PR, you could make pyro-ppl an optional dependency.
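One common way to treat torch, jax or pyro as optional is to check for the package before importing it. The snippet below is a minimal sketch of that pattern, not code from this PR; `backend_available` is a hypothetical helper name.

```python
import importlib.util


def backend_available(module_name):
    """Return True if the top-level package `module_name` can be imported."""
    # find_spec returns None for a missing top-level package and does not
    # import the package itself.
    return importlib.util.find_spec(module_name) is not None
```

A library could call `backend_available("torch")` before loading torch-specific code and raise a clear error pointing at the matching extras install when the package is missing.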

@fehiepsi
Member Author

Understood! I totally agree with (and love) that decision. Let me address it in this PR. Thanks for explaining, Fritz!

@fehiepsi fehiepsi removed the WIP label Feb 18, 2020
@fehiepsi
Member Author

@fritzo I tried to make pyro/torch as "optional" as possible, but there are still several places that require torch/pyro (see my topic comment). Most of them require porting https://github.com/pyro-ppl/pyro/tree/dev/pyro/ops/einsum to the numpy and jax backends, which will be addressed in #314 (I can go ahead with the port, or I could wait for the refactoring of pyro.ops.packed). Others require funsor.distributions and funsor.pyro, which are in the process of being refactored, so I should wait for a while.

@fritzo
Member

fritzo commented Feb 18, 2020

@fehiepsi I think it's fine to move/copy backend-agnostic logic from pyro.ops.einsum into Funsor, IIUC. Then once Pyro depends on Funsor we can delete the original code?

@eb8680 how close do you think funsor.distributions and funsor.pyro are to being backend-agnostic? Should you and @fehiepsi and I meet this week to plan refactoring?

@fritzo
Member

fritzo commented Feb 18, 2020

@fehiepsi since this is a big refactoring and you are blocked, we could split the refactoring into multiple PRs if you want. This PR does most of the work but keeps "torch" as a hard dependency; then a follow-up PR cleans up once the blockers are removed. Splitting up should reduce merge conflicts with any of @eb8680's refactoring of the distributions code.

@eb8680
Member

eb8680 commented Feb 18, 2020

how close do you think funsor.distributions and funsor.pyro are to being backend-agnostic? Should you and @fehiepsi and I meet this week to plan refactoring?

There's still a lot of work for me to do on refactoring funsor.distributions, implementing to_funsor/to_data for distributions and finishing pyro-ppl/pyro#2307 so it'll be a few days.

@fehiepsi within this PR, I think it's fine to not finish breaking up the backends completely, and to not bother with anything in funsor.distributions or funsor.pyro yet so we don't do unnecessary work. I agree with @fritzo's suggestion of breaking up your refactoring into multiple PRs. I think it's fine to push ahead with #314 in a simplified form - I'll leave a comment there.

If you're still blocked after this is merged, I suggest focusing on performance improvements for Gaussian and Tensor in #315 so that we can push ahead on replacing pyro.ops.gaussian and the Pyro HMM distributions.

@fehiepsi
Member Author

@fritzo @eb8680 I have already marked the blocked tests to be skipped by pytest. I'll port pyro.ops.einsum here in #314 and then remove some of the marks. For funsor.distributions and funsor.pyro, let's address them in a future PR (one approach is to use a string to let funsor choose the correct distribution backend, as in #315). The only blocker for this PR now is a bug somewhere that makes test_adjoint fail for the torch backend. I am debugging it. :)

@fritzo Let me port the einsum stuff first, then I am happy to chat about further refactoring.

@eb8680 About performance: there is a large "constant" cost in funsor (due to function dispatch, I guess). I don't expect funsor to match the performance of pyro.ops.gaussian with the torch backend, but I'll do more profiling and incorporate some optimizations in a PR.

@eb8680
Member

eb8680 commented Feb 18, 2020

there is a "constant" cost (which is large) in funsor (due to function dispatch I guess)

We can discuss further in #315 but this seems like something the PyTorch JIT might actually help with?

@fehiepsi
Member Author

Thanks for suggesting FUNSOR_DEBUG=1, @fritzo! It helped me a lot in isolating the bug. 😆

@fehiepsi
Member Author

the PyTorch JIT might actually help with

@eb8680 Yeah, I think it will help. Let me try it!

)

set_backend(get_backend())
Member

It seems a little safer to execute this in funsor.util where set_backend() is defined. That way anything depending on get_backend() will execute after this statement. Do you have a reason for locating it here in __init__.py?

Member Author

There is no reason other than having the backend set by default. I'll move it there.

Member

@fritzo fritzo left a comment

Awesome work, @fehiepsi! I'm looking forward to a fully agnostic library!

I have a couple of comments, feel free to address them in a follow-up PR.

Also generally I am a little confused about the three different ways to set backend:

  • environment variable
  • set_backend()
  • inferred from context (e.g. use "numpy" even when using "torch" backend)

I think it's fine for these to proliferate, but it would be good to have a big docstring somewhere explaining all the nuances and giving a complete picture. Maybe set_backend.__doc__ would be a good place, or maybe get_backend.__doc__?

Comment on lines 9 to +20
  def pytest_runtest_setup(item):
-     pyro.set_rng_seed(0)
-     pyro.enable_validation(True)
      np.random.seed(0)
+     backend = get_backend()
+     if backend == "torch":
+         import pyro
+
+         pyro.set_rng_seed(0)
+         pyro.enable_validation(True)
+     elif backend == "jax":
+         from jax.config import config
+
+         config.update('jax_platform_name', 'cpu')
Member

I assume set_backend() should never be called by a test; rather, the backend should be set by the FUNSOR_BACKEND environment variable. It might be worth documenting this or, better, adding a check like

def _disallow_set_backend(*args):
    raise ValueError("set_backend() cannot be called during tests")

def pytest_runtest_setup(item):
    ...
    import funsor.util
    funsor.util.set_backend = _disallow_set_backend
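One caveat with the guard suggested above is that rebinding a module attribute only affects lookups made through the module afterwards. The self-contained sketch below demonstrates this with a hypothetical stand-in object (`util`) in place of funsor.util; everything except `_disallow_set_backend` is made up for illustration.

```python
import types

# Hypothetical stand-in for the module that defines set_backend().
util = types.SimpleNamespace(backend="numpy")


def _real_set_backend(name):
    util.backend = name


util.set_backend = _real_set_backend

# Mimics `from <module> import set_backend` executed before the patch.
early_reference = util.set_backend


def _disallow_set_backend(*args):
    raise ValueError("set_backend() cannot be called during tests")


# The patch from the suggestion above: rebind the module attribute.
util.set_backend = _disallow_set_backend


def call_is_blocked(fn):
    """Return True if calling fn("jax") raises the guard's ValueError."""
    try:
        fn("jax")
    except ValueError:
        return True
    return False
```

Here a call through `util.set_backend` is blocked, while `early_reference` still reaches the original function. A test file that imported the name directly before the setup hook ran would therefore bypass the guard.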

Member Author

Thanks, I'll incorporate it in the next PR.

@fehiepsi
Member Author

@fritzo I'm going to merge this. I'll add more documentation to set_backend in the next PR. I think users will most likely use set_backend rather than the environment variable (which is only useful for testing, unless users want to play with import os; os.environ... stuff). As for "inferred from context", I believe doing so would cause errors (e.g. the einsum backend depends on FUNSOR_BACKEND).

@fehiepsi fehiepsi merged commit 812923a into pyro-ppl:master Feb 19, 2020
@fehiepsi
Member Author

Thanks a lot for your review and suggestions, @fritzo !
