Hi all!
I've been working on a library for a bit -- and recently I've been considering how to guide users on using closures within the library.
One common idiom in my library is to wrap a function into a Pytree:
    @lib.lift
    def some_fn():
        x = ...
The lifting here raises the function into a Pytree datatype (let's call it Lifted) which supports a set of interfaces that I use during transformations. These interfaces often do some computation, and then return the Lifted instance stored inside another datatype.
Sometimes, users try to close over values inside of these lifted functions -- when the values are JAX tracers, they get stored in the Python closure for the object -- and eventually I get tracer leaks. I think the cause is that the Lifted instance gets returned -- note that Lifted is defined as a Pytree and normally works fine (I've been careful with flatten and unflatten so that tracers are always treated as dynamic data).
However, I'd really like to allow users to close over arrays -- so I tried to alleviate this problem by using jax.closure_convert -- e.g. when someone wants to use the interfaces on Lifted -- I closure convert under the hood, use the converted function -- and pass in the captured tracer arrays as arguments.
However, I'm still getting tracer leaks.
Was closure_convert meant to be used in this situation? If yes, any guesses as to why I'm still encountering leaks?
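For concreteness, here is a minimal sketch of the pattern described above, not the actual library code. The names `outer` and `inner` are made up for illustration; the only real API used is `jax.closure_convert(fun, *example_args)`, which returns a converted function plus a list of hoisted values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def outer(x, y):
    # Under jit, `y` is a tracer, so `inner` closes over a tracer.
    def inner(z):
        return z * y

    # `converted` accepts the original arguments followed by the hoisted values.
    converted, aux_args = jax.closure_convert(inner, x)
    return converted(x, *aux_args)


result = outer(jnp.float32(2.0), jnp.float32(3.0))
```

Note that `jax.closure_convert` is documented as a utility for higher-order custom derivatives, so whether every captured value ends up in `aux_args` in other settings is worth checking rather than assuming.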
Sorry I'm being quite vague wrt the actual implementation (source code is closed, for now). If necessary, I could hop on a call with someone and explain / show code.
Edit: re -- can you use closure_convert to define Pytree-compat environments for closures? E.g. where the "static" information is the Python callable -- and the dynamic info is the closed over Pytree environment?
If yes -- has someone done this somewhere? I'd love to inspect it.
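To make the question concrete, here is a rough sketch of what such a Pytree-compat closure could look like if the environment is passed explicitly rather than recovered from a Python closure. `EnvClosure`, `scale`, and `apply` are hypothetical names, and this is an assumption about the shape of the idea, not something taken from an existing library.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class EnvClosure:
    """Hypothetical container: `fn` is static metadata, `env` is dynamic data."""

    def __init__(self, fn, env):
        self.fn = fn
        self.env = env

    def __call__(self, *args):
        # The environment is handed to the function as an explicit first argument.
        return self.fn(self.env, *args)

    def tree_flatten(self):
        # Children (dynamic): the environment. Aux data (static): the callable.
        return (self.env,), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        (env,) = children
        return cls(fn, env)


def scale(env, x):
    return env["w"] * x


@jax.jit
def apply(closure, x):
    return closure(x)


closure = EnvClosure(scale, {"w": jnp.asarray(3.0)})
result = apply(closure, jnp.asarray(2.0))
```

JAX also ships `jax.tree_util.Partial`, a pytree-compatible `functools.partial` that treats the function as static and the bound arguments as leaves, which may cover simpler cases.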
Edit 2: I inspected the .__closure__ dunder to understand what sort of data is being held by my original closure, then the transformed version, compared to the auxiliary arguments that come out of closure_convert:
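A minimal sketch of this kind of inspection, in plain Python (this is not the original snippet; `closure_contents` and `make_scaler` are hypothetical names). Under a JAX trace, the values in the resulting dict would be whatever tracers the function captured.

```python
def closure_contents(fn):
    """Map each free variable name of `fn` to the object held in its closure cell."""
    cells = fn.__closure__ or ()
    names = fn.__code__.co_freevars  # same order as the cells in __closure__
    return dict(zip(names, (cell.cell_contents for cell in cells)))


def make_scaler(w):
    def scaler(x):
        return w * x

    return scaler


captured = closure_contents(make_scaler(3.0))
```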
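So e.g. -- it actually seems like the transformed variant doesn't hold any DynamicJaxprTracer objects -- but I can't be completely sure.
Edit 3: I'm fascinated by the prospect of lifting closures to a Pytree-compat representation. I tried the following thing:
where Pytree is a metaclass which registers the dataclass as a Pytree -- surprisingly, this works.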
This seems like an awful kludge -- but honestly, I'd be surprised if one of the maintainers hasn't tried this before -- what are the sharp edges here?
E.g. I was considering defining __call__ for the closure by mutating the callable.__closure__ cells before invoking it. So then calling the closure would basically "put the arrays back in" before running the code. Then perhaps you'd also need to reset the environment (because Python doesn't support a native closure conversion transform on its closures, as far as I can tell).
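Roughly what I have in mind for the `__call__` idea, as a plain-Python sketch. `RebindableClosure` and `make_scaler` are hypothetical names. It assumes Python 3.7+ (where `cell_contents` is writable) and that every cell is already filled, since reading an empty cell raises `ValueError`.

```python
class RebindableClosure:
    """Hypothetical wrapper that temporarily rewrites a function's closure cells."""

    def __init__(self, fn):
        self.fn = fn
        self.names = fn.__code__.co_freevars  # same order as fn.__closure__

    def __call__(self, env, *args):
        cells = self.fn.__closure__ or ()
        saved = [cell.cell_contents for cell in cells]
        try:
            # "Put the values back in" before running the code.
            for name, cell in zip(self.names, cells):
                if name in env:
                    cell.cell_contents = env[name]
            return self.fn(*args)
        finally:
            # Reset the environment so nothing stays stored in the closure.
            for cell, old in zip(cells, saved):
                cell.cell_contents = old


def make_scaler(w):
    def scaler(x):
        return w * x

    return scaler


scaler = make_scaler(3.0)
rebound = RebindableClosure(scaler)
swapped = rebound({"w": 10.0}, 2.0)
restored = scaler(2.0)
```

The mutation is visible to anything else holding the same function while the call is running, which is one of the sharp edges I'm worried about.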