How do I load a pre-trained model? #438
@btaba Thank you for your work. Could you let us know when the updates for this issue will land? In addition, are there any plans to support deploying the inference function or params/networks to C++, e.g. via TorchScript? Thank you.
Hello, thank you for creating the issue @btaba from the original discussion. Will this sort of functionality, or @jihan1218's solution, become part of the main branch at some point? Thanks!
Should be fixed in b164655. Here's an example:
You can recover the inference fn without training like so:
And the params can be loaded from the checkpoint using orbax.
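The example code from this comment was not captured in this thread. As a rough stand-in for the pattern being described (all names here, including `train_fn`, are hypothetical placeholders, not the real brax/orbax API): run the trainer for zero timesteps to obtain `make_inference_fn` without actually training, then feed it params restored from a checkpoint instead of the freshly initialized ones:

```python
# Hypothetical sketch of recovering an inference function without training.
# `train_fn` below is a toy stand-in for a brax-style trainer: it builds the
# policy constructor and initial params, and returns
# (make_inference_fn, params, metrics).

def train_fn(environment, num_timesteps):
    def make_inference_fn(params):
        # The returned policy scales observations by a single weight.
        def inference_fn(obs):
            return [params["scale"] * o for o in obs]
        return inference_fn
    params = {"scale": 1.0}  # untrained initial params
    return make_inference_fn, params, {}

# Recover the policy constructor without training (num_timesteps=0)...
make_inference_fn, _, _ = train_fn(environment=None, num_timesteps=0)

# ...and plug in params restored from a checkpoint (stand-in dict here),
# rather than the params returned by the zero-step "training" run.
restored_params = {"scale": 2.0}
inference_fn = make_inference_fn(restored_params)
print(inference_fn([1.0, 2.0]))  # -> [2.0, 4.0]
```

The key point of the pattern is that the network/policy *constructor* does not depend on training at all, so a zero-timestep call is enough to recover it; only the weights come from the checkpoint.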
@btaba Hi. I ran my code with orbax_checkpointer as in the example above and got '_CHECKPOINT_METADATA', '_METADATA', and 'checkpoint' files. Then, when I tried to restore the checkpoint, I got an error.
Hi @Rian-Jo, are you saving the full params in your checkpoint? It looks like the "policy" key is missing. Take a look at the colab example here to see what the diff is: https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb
@btaba I got it. Thank you so much!! One more question: is there any difference between the params returned by the train function and their saved/loaded counterparts? Simulation playback with the former and the latter looks different. And is there also any difference between the make_inference_fn from the train function and the one from jax.xla_computation? Thank you.
Can you be a bit more specific about what the "saved/loaded params" are?
Hi @btaba,
I have three jit_inference_fn.
I use (c) from case 3 in C++. Should these three behave the same? The results I get from each are different. (Case 2 is when the params are saved and loaded, or train_fn is called with 'num_timesteps=0'.)
Meanwhile, in the quadruped example, what is the meaning of editing the state value (qvel)? I understand qvel is the sensed value. How can random noise kicked into the floating-base state (qvel[:2]) ensure the legs kick? I imagined that if random noise is added to the floating-base state and pipeline_step runs, the floating base would jump continuously in the randomly noised direction, not the legs.
Thank you.
Hi @Rian-Jo, I have a similar issue when loading the policy:

```
Traceback (most recent call last):
  File "/root/vast/scott-yang/VNL-Brax-Imitation/train.py", line 365, in main
    make_inference_fn, params, _ = train_fn(
                                   ^^^^^^^^^
  File "/root/vast/scott-yang/VNL-Brax-Imitation/ppo_imitation/train.py", line 431, in train
    (normalizer_params, init_params) = orbax_checkpointer.restore(
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 211, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 642, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 812, in restore
    return tree_utils.deserialize_tree(restored_item, item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/tree/utils.py", line 176, in deserialize_tree
    return jax.tree_util.tree_map_with_path(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/jax/_src/tree_util.py", line 1170, in tree_map_with_path
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/jax/_src/tree_util.py", line 1170, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
                             ^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/tree/utils.py", line 173, in _reconstruct_from_keypath
    result = result[key_name]
             ~~~~~~^^^^^^^^^^
KeyError: 'policy'
```

May I ask what the fix is for saving the checkpoint? I am saving the checkpoint according to the example provided:

```python
from etils import epath
from flax.training import orbax_utils
from orbax import checkpoint as ocp

ckpt_path = epath.Path('/tmp/some-env/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
    # save checkpoints
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    path = ckpt_path / f'{current_step}'
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)
```

Thank you!
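The `KeyError: 'policy'` in the traceback above occurs when the restore target expects a 'policy' entry that the saved pytree does not contain (i.e. only part of the params was saved). As a loose illustration (this is not orbax's internals; `restore_into` is a made-up helper), a tree-based restore that walks the target structure's keys fails in exactly this way:

```python
# Toy model of a pytree restore: pull values out of `saved` by walking the
# keys of the `target` structure. If the target expects a key the checkpoint
# never contained, the lookup raises KeyError for that key.

def restore_into(saved, target):
    """Recursively extract values from `saved` using the keys of `target`."""
    if isinstance(target, dict):
        return {k: restore_into(saved[k], v) for k, v in target.items()}
    return saved  # leaf: take the saved value as-is

# Checkpoint saved without the policy subtree (illustrative shapes only).
saved = {"normalizer": {"count": 1}}
# Restore target expects both normalizer and policy.
target = {"normalizer": {"count": 0}, "policy": {}}

try:
    restore_into(saved, target)
except KeyError as e:
    print(e)  # -> 'policy'
```

The practical takeaway matching @btaba's earlier comment: make sure the *full* params structure (normalizer and policy together) is what gets passed to the checkpointer's save call, so the saved tree and the restore target have the same keys.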
Discussed in #403
Originally posted by eleninisioti October 11, 2023
There is a notebook that explains how to save and load models (https://github.com/google/brax/blob/main/notebooks/training.ipynb), but there testing happens right after training: it calls `make_inference_fn(params)`, which requires first running `make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)`, i.e. the whole training process. My question is how I can test without running the training process, simply by loading the params. How can I get `inference_fn` without training?