I think this is working as intended. The device memory pool only grows when an allocation needs more space than it currently holds, and once it has grown it never shrinks. After the initial allocations, the pool stops expanding because new allocations fit into the space it has already reserved. Does that make sense?

Side note, your `move_np_to_jnp` function:

    def move_np_to_jnp(state):
        return jax.device_put(state)
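To make the "grows when needed, never shrinks" behaviour concrete, here is a toy model of such a pool. It is a hypothetical sketch, not JAX's or XLA's actual GPU allocator: the class name `GrowOnlyPool` and its methods are made up for illustration, and it ignores fragmentation and alignment. It shows why freed buffers do not give memory back to the device, yet later allocations of a similar size cause no further growth.

```python
class GrowOnlyPool:
    """Toy model of a memory pool that grows on demand and never shrinks.

    Hypothetical illustration only: this is not JAX's or XLA's real GPU
    allocator, and it ignores fragmentation and alignment entirely.
    """

    def __init__(self):
        self.reserved = 0  # bytes the pool has claimed from the "device"
        self.in_use = 0    # bytes currently held by live buffers

    def allocate(self, nbytes):
        free = self.reserved - self.in_use
        # Grow the pool only by the shortfall when the request does not fit.
        if nbytes > free:
            self.reserved += nbytes - free
        self.in_use += nbytes
        return nbytes

    def free(self, nbytes):
        if nbytes > self.in_use:
            raise ValueError("freeing more bytes than are in use")
        # Freed bytes return to the pool, not to the device, so
        # `reserved` is left unchanged.
        self.in_use -= nbytes


pool = GrowOnlyPool()
a = pool.allocate(100)
b = pool.allocate(50)
pool.free(a)
pool.free(b)
print(pool.reserved, pool.in_use)  # 150 0: nothing handed back to the device
c = pool.allocate(120)
print(pool.reserved, pool.in_use)  # 150 120: reused existing space, no growth
```

In this model, anything that only looks at `reserved` would report the memory as still taken after `free`, even though the pool can hand that space out again to new arrays.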
-
I've been playing around on Google Colab with GPU instances.
When I call jax.device_put() on a converted NumPy array, some memory seems to be allocated that cannot be recovered, even after running jax.clear_backends() and gc.collect().
I'm not sure if this is intended behavior; I wasn't able to find anything about it in the docs.
The JAX version is the one prepackaged with Google Colab.