May be related to felipeangelimvieira/prophetverse#94
Hi all, I'm seeing my RAM fill up over time when using `jax.device_get()` to move data to the host.

Here's what I'm doing: I have a `while True` loop that 1) loads a model (in this case a UNet), 2) trains it, 3) copies the model to the host and saves a checkpoint, and 4) resets the model weights to a base model. Here's a simplified version of this code:

With each iteration of the `while True` loop, my RAM usage grows by around 4 GB (the size of the UNet). If I comment out
`get_params_to_save`, I don't encounter the issue. However, it isn't 100% consistent: sometimes my memory goes up, and sometimes it stays flat 🤔.

My suspicion is that `jax.device_get()` keeps a reference to `unet_state.params`, which means `unet_state` is not properly garbage collected inside `train()` (or at least not consistently). I have tried to garbage collect explicitly (using `gc.collect()`) without luck. I assume `gc.collect()` only operates on Python objects and has no influence on C/JAX structures?

I can replace the line
with
which seems to solve the issue. So I was wondering: what is the difference between these two lines? Could I trigger JAX's internal garbage collection manually to ensure all objects created inside `train()` are garbage collected? I'd be super happy about any insight on this! 🙏

Side note: you may ask why I'm not simply loading the model inside `train()` and then discarding it. The reason is other parts of the codebase, as well as performance considerations.
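To narrow this down, the loop can be reduced to a minimal sketch like the one below. It is not the original code: it assumes the parameters are a pytree of `jax.Array` leaves, and `make_params`, `params_to_host`, `free_device_arrays` and `count_live_arrays` are hypothetical helpers standing in for the UNet and `get_params_to_save`. Step 3 copies the pytree to host with `jax.device_get` and then forces NumPy-owned copies with `np.array(..., copy=True)`, so the checkpoint data does not depend on the device arrays. Step 4 frees the device buffers explicitly with `jax.Array.delete()` rather than waiting for garbage collection. `jax.live_arrays()` reports the arrays JAX still considers alive, which is a cheap way to see whether anything accumulates between iterations.

```python
import gc

import jax
import jax.numpy as jnp
import numpy as np


def make_params(key, size=1024):
    """Hypothetical stand-in for the UNet parameters: a small pytree of jax.Arrays."""
    k1, k2 = jax.random.split(key)
    return {
        "dense": {"w": jax.random.normal(k1, (size, size)), "b": jnp.zeros(size)},
        "out": {"w": jax.random.normal(k2, (size, 8))},
    }


def params_to_host(params):
    """Hypothetical get_params_to_save: copy every leaf into host memory.

    jax.device_get returns NumPy arrays; np.array(..., copy=True) makes sure each
    leaf owns its own buffer instead of possibly viewing device-backed memory.
    """
    host = jax.device_get(params)
    return jax.tree_util.tree_map(lambda x: np.array(x, copy=True), host)


def free_device_arrays(params):
    """Release the buffers behind a pytree of jax.Arrays right away.

    Only safe if nothing else (e.g. a base-model copy) still uses these arrays:
    any later access to a deleted array raises an error.
    """
    for leaf in jax.tree_util.tree_leaves(params):
        if isinstance(leaf, jax.Array) and not leaf.is_deleted():
            leaf.delete()


def count_live_arrays():
    """Number of arrays JAX currently keeps alive on the default backend."""
    return len(jax.live_arrays())


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    for step in range(3):
        key, sub = jax.random.split(key)
        params = make_params(sub)             # 1) "load" the model (training, step 2, omitted)
        host_params = params_to_host(params)  # 3) copy to host; save the checkpoint here
        free_device_arrays(params)            # 4) drop device buffers before the next reset
        del params, host_params
        gc.collect()
        print(f"iteration {step}: {count_live_arrays()} live arrays")
```

If the live-array count grows from one iteration to the next, some Python object is still holding on to old arrays. If it stays flat while RAM keeps growing, the growth is more likely on the host side, for example the NumPy copies being kept alive somewhere.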