I am trying to measure the FLOPs required for Dataset Distillation using the FRePo repository (https://github.com/yongchao97/FRePo), which is implemented in JAX. I have tried several approaches based on the profiler described in the JAX documentation, which uses trace statements, and I have used a similar approach in PyTorch. However, I cannot find a way to access the FLOP counts I need for my calculations. What is the best way to measure and track FLOPs in JAX for a larger workload, such as one epoch of a neural network training loop?
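For context, here is a minimal sketch of one alternative to the profiler, assuming a recent JAX version: ask XLA for its static cost estimate of a jitted function via `jax.jit(fn).lower(*args).compile().cost_analysis()`. The model, `train_step`, shapes, and `steps_per_epoch` below are hypothetical placeholders, not FRePo code. The number is XLA's compile-time estimate for one call, not a hardware measurement, so per-epoch FLOPs are approximated by multiplying by the number of steps. Operations inside loops such as `lax.scan` or `lax.fori_loop` may be counted only once rather than per iteration, so it is worth sanity-checking against a hand calculation.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical two-layer MLP standing in for the real model.
    h = jnp.tanh(x @ params["w1"])
    logits = h @ params["w2"]
    return jnp.mean((logits - y) ** 2)


def train_step(params, x, y, lr=0.1):
    # Hypothetical SGD step: forward pass, backward pass, parameter update.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def compiled_flops(fn, *args):
    """Return XLA's static FLOP estimate for one call of jax.jit(fn)(*args)."""
    compiled = jax.jit(fn).lower(*args).compile()
    cost = compiled.cost_analysis()
    # Some older JAX versions return a list of dicts instead of a single dict.
    if isinstance(cost, (list, tuple)):
        cost = cost[0]
    return float(cost.get("flops", 0.0))


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(k1, (784, 256)),
    "w2": jax.random.normal(k2, (256, 10)),
}
x = jnp.ones((128, 784))
y = jnp.ones((128, 10))

step_flops = compiled_flops(train_step, params, x, y)
steps_per_epoch = 60000 // 128  # hypothetical dataset size / batch size
print(f"Estimated FLOPs per step:  {step_flops:.3e}")
print(f"Estimated FLOPs per epoch: {step_flops * steps_per_epoch:.3e}")
```

Compiling the whole step once and scaling by the step count avoids tracing every iteration of the epoch, but it assumes every step has the same shapes and control flow.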
Thanks in advance! :)