Support global backend #318
Conversation
@fehiepsi FYI to help motivate this PR: So it would be nice if, either in this PR or a follow-up PR, you could make …
Understood! I totally agree with (and love) that design decision. Let me address it in this PR. Thanks for explaining, Fritz!
@fritzo I tried to make pyro/torch as "optional" as possible, but there are still several places that require torch/pyro (see my top comment). Most of them require the port of …
@fehiepsi I think it's fine to move/copy backend-agnostic logic from …

@eb8680 how close do you think …
@fehiepsi since this is a big refactoring and you are blocked, we could split the refactoring into multiple PRs if you want. This PR does most of the work but keeps "torch" as a hard dependency; then a follow-up PR cleans up once the blockers are removed. Splitting up should reduce merge conflicts with any of @eb8680's refactoring of the distributions code.
There's still a lot of work for me to do on refactoring …

@fehiepsi within this PR, I think it's fine not to finish breaking up the backends completely, and not to bother with anything in … If you're still blocked after this is merged, I suggest focusing on performance improvements for …
@fritzo @eb8680 I already marked the blocked tests as pytest-skip. I'll port …

@fritzo Let me port the einsum stuff first, then I am happy to chat about further refactoring.

@eb8680 About performance, there is a "constant" cost (which is large) in funsor (due to function dispatch, I guess). I don't expect that …
We can discuss further in #315, but this seems like something the PyTorch JIT might actually help with?

Thanks for your suggestion!

@eb8680 Yeah, I think it will help. Let me try it!
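The "constant" per-call cost mentioned above can be made concrete with a tiny micro-benchmark. This is only an illustration of dispatch overhead in general, not funsor's actual dispatch mechanism: it compares a plain function call against a `functools.singledispatch` call doing the same work.

```python
import timeit
from functools import singledispatch


def add_direct(x, y):
    return x + y


@singledispatch
def add_dispatched(x, y):
    raise NotImplementedError(type(x))


@add_dispatched.register(int)
def _add_int(x, y):
    return x + y


# Both compute the same result, but every dispatched call pays a
# constant extra cost for the runtime type lookup.
direct = timeit.timeit(lambda: add_direct(1, 2), number=100_000)
dispatched = timeit.timeit(lambda: add_dispatched(1, 2), number=100_000)
print(f"direct: {direct:.4f}s, dispatched: {dispatched:.4f}s")
```

For tiny operations this fixed overhead dominates, which is why tracing/JIT compilation (which resolves dispatch once, ahead of time) can help.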
```python
set_backend(get_backend())
```
It seems a little safer to execute this in `funsor.util`, where `set_backend()` is defined. That way anything depending on `get_backend()` will execute after this statement. Do you have a reason for locating it here in `__init__.py`?
There is no reason other than letting the backend be set by default. I'll move it there.
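For concreteness, the arrangement discussed here could look like the sketch below. Only `get_backend`, `set_backend`, the `FUNSOR_BACKEND` environment variable, and the backend names come from this thread; the module-level `_BACKEND` variable and all other details are my own assumptions, not funsor's actual implementation.

```python
import os

_BACKEND = None  # assumed module-level state in funsor.util


def get_backend():
    # Fall back to the FUNSOR_BACKEND environment variable, then "numpy".
    return _BACKEND or os.environ.get("FUNSOR_BACKEND", "numpy")


def set_backend(backend):
    global _BACKEND
    if backend not in ("numpy", "torch", "jax"):
        raise ValueError(f"Unknown backend: {backend}")
    _BACKEND = backend
    # A real implementation would lazily import funsor.torch / funsor.jax here.


# Executed at module import time, so any code that depends on
# get_backend() runs after the backend has been initialized.
set_backend(get_backend())
```

Keeping the initialization call in the same module as the definitions guarantees the ordering the reviewer asks about: nothing can import `get_backend` before the default backend has been set.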
Awesome work @fehiepsi, I'm looking forward to a fully agnostic library!
I have a couple of comments; feel free to address them in a follow-up PR.
Also, generally I am a little confused about the three different ways to set the backend:

- environment variable
- `set_backend()`
- inferred from context (e.g. use "numpy" even when using the "torch" backend)

I think it's fine for these to proliferate, but it would be good to have a big docstring somewhere explaining all the nuances and giving a complete picture. Maybe `set_backend.__doc__` would be a good place, or maybe `get_backend.__doc__`?
```python
def pytest_runtest_setup(item):
    np.random.seed(0)
    backend = get_backend()
    if backend == "torch":
        import pyro

        pyro.set_rng_seed(0)
        pyro.enable_validation(True)
    elif backend == "jax":
        from jax.config import config

        config.update('jax_platform_name', 'cpu')
```
I assume `set_backend()` should never be called by a test; rather, the backend should be set by the `FUNSOR_BACKEND` environment variable. It might be worth documenting this, or better, adding a check like:

```python
def _disallow_set_backend(*args):
    raise ValueError("set_backend() cannot be called during tests")


def pytest_runtest_setup(item):
    ...
    import funsor.util
    funsor.util.set_backend = _disallow_set_backend
```
Thanks, I'll incorporate it in the next PR.
@fritzo I'm going to merge this. I'll add more doc to …

Thanks a lot for your review and suggestions, @fritzo!
Addresses #207.
Following up the discussions at #317, this PR supports a global backend in funsor.
To make reviewing easier, I list the important changes in this PR here:
- `jax` and `numpy` implementation. I also take this chance to reorder those ops according to their names (to make it easier to compare different backends or to find a specific implementation).
- `numpy` for the default implementation of numeric ops.
- `funsor.torch` or `funsor.jax` are lazily imported, depending on the environment variable `FUNSOR_BACKEND` or the usage of `funsor.set_backend(...)`.
- `set_backend`. There is a problem: once we have changed the backend to JAX, we can't revert to the `numpy` backend.
- `x.dim()`, which is not available for ndarray.
- (`amax`, `amin`) for those ops (like the current master branch). I guess Pyro devs would prefer PyTorch names over NumPy names for numeric ops.

Some small changes:
- Replace `torch._C._get_tracing_state` by `funsor.util.get_tracing_state`.
- `funsor.util.is_nn_module`.
- `ops.is_numeric_array` to check if a variable is a numeric array. We can rename it to `ops.is_tensor` to match `torch.is_tensor` if preferred.
- `ops.detach` to detach gradients. This is a no-op for the numpy backend.
- `ops.log`, which gives warnings at the value 0.
- `tuple` or iterator for the testing utilities `randn`, `zeros`, ... (because I found many places in test files that use the pattern `torch.zeros(1, 2, 3)` instead of `torch.zeros((1, 2, 3))`).
- `recursion_reinterpret.register(torch.Tensor)` is implemented in the `torch.py` file instead of the `interpreter.py` file (similarly for `children`, `to_funsor`, `allclose`).
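The `torch.zeros(1, 2, 3)` vs `torch.zeros((1, 2, 3))` mismatch mentioned in the list above can be smoothed over in a test helper. Here is a minimal sketch (the helper name `randn` comes from the thread; this NumPy-backed implementation is an assumption) that accepts both calling conventions:

```python
import numpy as np


def randn(*shape):
    # Accept both randn(2, 3) (torch style) and randn((2, 3)) (numpy style)
    # by unwrapping a single tuple/list argument into a flat shape.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    return np.random.randn(*shape)
```

With this normalization, test files written against either convention work unchanged.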
Tests pass with all backends …

Tests that only work with torch/pyro:

- `test/pyro/`
- `test/test_minipyro.py`
- `test/test_adjoint.py`: requires `einsum.adjoint.require_backward`
- `test/test_distributions.py`: refactoring
- `test/test_einsum.py`
- `test/test_memoize.py`: requires some einsum stuff
- `test/test_optimizer.py`: requires distributions and einsum

TODO:

- `numpy` and `jax` backend.
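On the PyTorch-vs-NumPy naming point above (`amax`, `amin`): one way to keep PyTorch-style names and signatures while backing them with NumPy is to register a per-backend implementation under a single op name. The `functools.singledispatch` mechanism below is only an illustration of the idea, not funsor's actual op machinery.

```python
import numpy as np
from functools import singledispatch


@singledispatch
def amax(x, dim=None, keepdims=False):
    # PyTorch-style name and keyword arguments (dim, keepdims);
    # concrete array types register their own implementations.
    raise NotImplementedError(type(x))


@amax.register(np.ndarray)
def _amax_numpy(x, dim=None, keepdims=False):
    # NumPy spells the same reduction np.amax with an `axis` keyword.
    return np.amax(x, axis=dim, keepdims=keepdims)


# A torch backend would lazily register torch.Tensor the same way,
# delegating to the corresponding torch reduction.
```

This keeps call sites backend-agnostic: they always write `ops.amax(x, dim=0)`, and the registered implementation translates to the backend's native name.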