-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce resources to streamline the combination of jax transformation with OOP pattern #44
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diegoferigo
force-pushed
the
fix/trace_leak
branch
4 times, most recently
from
August 3, 2023 22:13
12b4bbb
to
5ee6189
Compare
diegoferigo
force-pushed
the
fix/trace_leak
branch
from
August 4, 2023 11:09
5ee6189
to
9a0c1f2
Compare
diegoferigo
force-pushed
the
fix/trace_leak
branch
from
August 4, 2023 11:54
9a0c1f2
to
3a59866
Compare
diegoferigo
changed the title
Fix trace leak by introducing a new pattern for jit-compiling class methods
Introduce resources to streamline the combination of jax transformation with OOP pattern
Aug 4, 2023
flferretti
approved these changes
Aug 4, 2023
traversaro
approved these changes
Aug 4, 2023
diegoferigo
force-pushed
the
fix/trace_leak
branch
from
August 4, 2023 15:26
3a59866
to
48c54f0
Compare
diegoferigo
force-pushed
the
fix/trace_leak
branch
from
October 9, 2023 07:42
48c54f0
to
c1f5b63
Compare
Given that the OOP decorators alter considerably downstream code and we're not yet sure they are 100% compatible with our long-term goals, I'll proceed with caution by merging this and next PRs into the If everything keeps looking good, the new APIs will become the default ones. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
JAX endorses a functional programming pattern in which algorithms are executed stateless over some data structure that is passed as input and returned (possibly updated) as output. This approach has many benefits, but often a OOP might result more user friendly.
We already adopted such OOP pattern by using
jax_dataclasses
to create pytrees. These pytrees can have dataclass attributes (fields) containing static/dynamic data, and algorithms can be implemented as dataclass methods. If these methods werestaticmethods
, the functional pattern would be preserved, however it'd be users' responsibility to correctly propagating the state. If the methods, instead, are not static, they might update the data that, if not done correctly, could trigger jit recompilations and other problems like trace leaks (#43).We already apply such OOP pattern to our classes and so far it has worked decently well, but recent jax versions became more picky about trace leaks when this pattern is used in conjunction with jit compilation.
The main problem seems that jax thinks we have trace leaks because instead of having dataclass methods returning a tuple
(output, state)
, they just returnoutput
and updatestate
directly from the method. This breaks the functional pattern that jax expects.Furthermore, applying algorithms of parallel objects was not straightforward since it required creating lambdas/closures and pass them to
jax.vmap
. This is quite confusing, particularly for new users.This PR tries to solve these limitations. The idea is to keep using OOP, but handle transparently some internal jax details to simplify developers' and users' life. In particular, the following is the desiderata:
jax.jit
(addressing also leaked traces).jax_dataclasses
to create custom PyTrees.self
is a pytree so it's ok), therefore different objects can re-use the first jit-compiled method.jax.vmap
if they have been parallelized (they have all fields with the batch dimension as first axis).There are few caveat to consider when such approach is implemented. In this PR, I only introduce the tooling and a test to achieve these goals. I'll update all the jaxsim modules in a new PR.