
How do I load a pre-trained model? #438

Closed
btaba opened this issue Jan 2, 2024 Discussed in #403 · 9 comments

btaba commented Jan 2, 2024

Discussed in #403

Originally posted by eleninisioti October 11, 2023
There is a notebook that explains how to save and load models (https://github.com/google/brax/blob/main/notebooks/training.ipynb), but there testing happens right after training: it calls make_inference_fn(params), which requires first running make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress), i.e. the whole training process.

My question is how I can test without running the training process, simply by loading the params. How can I get inference_fn without training?

btaba self-assigned this Feb 11, 2024

Rian-Jo commented Mar 6, 2024

@btaba Thank you for your work.

Could you let us know when updates for this issue are expected?

In addition, is there any plan to support deploying the inference function or params/networks to C++, e.g. via TorchScript?

Thank you.

@willthibault

Hello,

Thank you for creating this issue from the original discussion, @btaba.

Will this sort of functionality or @jihan1218's solution become part of the main branch at some point?

Thanks!


btaba commented May 29, 2024

This should be fixed in b164655.

Here's an example:

import functools

from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from orbax import checkpoint as ocp

ckpt_path = epath.Path('/tmp/some-env/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  # save checkpoints
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)


train_fn = functools.partial(
    ppo.train, num_timesteps=100_000_000,
    policy_params_fn=policy_params_fn,
    restore_checkpoint_path=ckpt_path / '11141120'  # to restart from a previous checkpoint
)

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

You can recover the inference fn without training like so:

make_inference_fn, params, _ = ppo.train(environment=env, num_timesteps=0)

And the params can be loaded from the checkpoint using orbax.
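
A minimal sketch, assuming the ckpt_path and policy_params_fn above (the step directory '11141120' is just the illustrative name used earlier), and assuming the checkpoint written by policy_params_fn has the same pytree structure as the params returned by ppo.train: initialize the networks with num_timesteps=0, then restore the saved params into that structure with orbax and build the inference function.

from orbax import checkpoint as ocp
import jax

make_inference_fn, params, _ = ppo.train(environment=env, num_timesteps=0)

orbax_checkpointer = ocp.PyTreeCheckpointer()
# Restore into the same pytree structure as the freshly initialized params.
params = orbax_checkpointer.restore(ckpt_path / '11141120', item=params)

jit_inference_fn = jax.jit(make_inference_fn(params))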

btaba closed this as completed May 29, 2024

Rian-Jo commented Jun 10, 2024

Hi @btaba,

I ran my code with orbax_checkpointer like the example above and got '_CHECKPOINT_METADATA', '_METADATA', and 'checkpoint' files. Then, when I tried to restore the checkpoint, I got an error like this:

    136 print("*********************************")
    138 evalEnv = envs.GetEnvironment(self.envName)
--> 140 make_inference_fn, self.params, self.metrics = train_fn(environment=self.env, progress_fn=ProgressCallback, eval_env=evalEnv, policy_params_fn=PolicyCallback)
    142 # create data frame with train_rewards and steps
    143 df = pd.DataFrame({'Steps': trainSteps, 'Rewards': trainRewards, 'RewardsError': trainRewardsErr})

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/brax/training/agents/ppo/train.py:399, in train(environment, num_timesteps, episode_length, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn, restore_checkpoint_path)
    397   orbax_checkpointer = ocp.PyTreeCheckpointer()
    398   target = training_state.normalizer_params, init_params
--> 399   (normalizer_params, init_params) = orbax_checkpointer.restore(
    400       restore_checkpoint_path, item=target
    401   )
    402   training_state = training_state.replace(
    403       normalizer_params=normalizer_params, params=init_params
    404   )
    406 training_state = jax.device_put_replicated(
    407     training_state,
    408     jax.local_devices()[:local_devices_to_use])

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:211, in Checkpointer.restore(self, directory, *args, **kwargs)
    209 logging.info('Restoring item from %s.', directory)
    210 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 211 restored = self._handler.restore(directory, args=ckpt_args)
    212 logging.info('Finished restoring checkpoint from %s.', directory)
    213 multihost.sync_global_processes(
    214     multihost.unique_barrier_key(
    215         'Checkpointer:restore',
   (...)
    219     processes=self._active_processes,
    220 )

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:637, in PyTreeCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, legacy_transform_fn, args)
    628 if (
    629     (directory / _METADATA_FILE).exists()
    630     and transforms is None
    631     and legacy_transform_fn is None
    632 ):
    633   args = BasePyTreeRestoreArgs(
    634       item,
    635       restore_args=restore_args,
    636   )
--> 637   return self._handler_impl.restore(directory, args=args)
    639 logging.debug('directory=%s, restore_args=%s', directory, restore_args)
    640 if not directory.exists():

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py:800, in BasePyTreeCheckpointHandler.restore(self, directory, args)
    794   logging.debug(
    795       'ts_metrics: %s',
    796       json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')),
    797   )
    799 if item is not None:
--> 800   return utils.deserialize_tree(restored_item, item)
    801 return restored_item

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/utils.py:269, in deserialize_tree(serialized, target, keep_empty_nodes)
    266     result = result[key_name]
    267   return result
--> 269 return jax.tree_util.tree_map_with_path(
    270     _reconstruct_from_keypath,
    271     target,
    272     is_leaf=is_empty_or_leaf if keep_empty_nodes else None,
    273 )

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py:1001, in tree_map_with_path(f, tree, is_leaf, *rest)
    999 keypath_leaves = list(zip(*keypath_leaves))
   1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py:1001, in <genexpr>(.0)
    999 keypath_leaves = list(zip(*keypath_leaves))
   1000 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1001 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/.pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/utils.py:266, in deserialize_tree.<locals>._reconstruct_from_keypath(keypath, _)
    264   if not isinstance(result, list) and key_name not in result:
    265     key_name = str(key_name)
--> 266   result = result[key_name]
    267 return result

KeyError: 'policy'

Can you give me any hints on how to solve the problem?

Thank you!


btaba commented Jun 11, 2024

Hi @Rian-Jo, are you saving the full params in your checkpoint? It looks like the "policy" key is missing. Take a look at the colab example here to see what the difference is:

https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb
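
One possible way to see that diff (an illustrative snippet, not from the thread; ckpt_path and the step directory '11141120' are the example names from earlier): restore the checkpoint without an item target and print its structure, then compare it with the (normalizer_params, init_params) target that ppo.train tries to restore, visible in the traceback above.

from orbax import checkpoint as ocp
import jax

# Restoring without `item=` returns the raw pytree exactly as it was saved,
# so its keys can be compared against what the restore target expects.
raw = ocp.PyTreeCheckpointer().restore(ckpt_path / '11141120')
print(jax.tree_util.tree_structure(raw))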


Rian-Jo commented Jun 11, 2024

@btaba I got it. Thank you so much!!

One more question.

Is there any difference between the params returned by the train function and the saved/loaded params? Simulation playback with the first and the last seems different. Also, is there any difference between the make_inference_fn from the train function and one wrapped with jax.xla_computation?

Thank you.


btaba commented Jun 11, 2024

Can you be a bit more specific about what the "saved/loaded params" are?
What is the make_inference_fn from jax.xla_computation? I'm not sure I understand the question.


Rian-Jo commented Jun 12, 2024

Hi @btaba,


I have three jit_inference_fn:

1. trained make_inference_fn

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
jit_inference_fn = jax.jit(make_inference_fn(params))
model.save_params(model_path, params)

2. restored make_inference_fn with saved params

make_inference_fn, params, _ = train_fn(environment=env, num_timesteps=0)
params = model.load_params(model_path)
jit_inference_fn = jax.jit(make_inference_fn(params))

3. a wrapped version of jit_inference_fn using jax.xla_computation

c = jax.xla_computation(jit_inference_fn)(...)

I use c from case 3 in C++.

Should these three behave the same? The results I get from each are different (a sketch comparing cases 1 and 2 directly follows below).
I wonder whether something changes during the process:

(case 2) when the params are saved and loaded, or when train_fn is called with num_timesteps=0
(case 3) when the function is wrapped with xla_computation
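
A minimal comparison sketch, assuming the env, model_path, and make_inference_fn from the cases above (params_trained and params_loaded are illustrative names): evaluate the trained and the reloaded params on the same observation with the same PRNG key. Passing deterministic=True removes action sampling noise, so any remaining difference points at the save/load round trip rather than at sampling.

import jax
import jax.numpy as jp
from brax.io import model

params_trained = params                        # params returned by train_fn (case 1)
params_loaded = model.load_params(model_path)  # params reloaded from disk (case 2)

# Deterministic policies built from each set of params.
policy_trained = jax.jit(make_inference_fn(params_trained, deterministic=True))
policy_loaded = jax.jit(make_inference_fn(params_loaded, deterministic=True))

rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
act_rng, _ = jax.random.split(rng)

act_trained, _ = policy_trained(state.obs, act_rng)
act_loaded, _ = policy_loaded(state.obs, act_rng)
print(jp.max(jp.abs(act_trained - act_loaded)))  # expect ~0 if the round trip is lossless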


Meanwhile, in the quadruped example (https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb), what is the meaning of editing the state value (qvel)? I understand that qvel is the sensed value. How can a random noise kick applied to the floating-base state (qvel[:2]) ensure that the legs kick?

I imagined that if random noise is added to the floating-base state and pipeline_step runs, the floating base would keep jumping in the randomly noised direction, not the legs.

...
  def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch
    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)

    # kick
    push_interval = 10
    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)
    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])
    kick *= jp.mod(state.info['step'], push_interval) == 0
    qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error
    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])
    state = state.tree_replace({'pipeline_state.qvel': qvel})

    # physics step
    motor_targets = self._default_pose + action * self._action_scale
    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
    x, xd = pipeline_state.x, pipeline_state.xd
...

Thank you.


scott-yj-yang commented Jul 18, 2024

Hi @Rian-Jo, I have a similar issue when loading the policy:

Traceback (most recent call last):
  File "/root/vast/scott-yang/VNL-Brax-Imitation/train.py", line 365, in main
    make_inference_fn, params, _ = train_fn(
                                   ^^^^^^^^^
  File "/root/vast/scott-yang/VNL-Brax-Imitation/ppo_imitation/train.py", line 431, in train
    (normalizer_params, init_params) = orbax_checkpointer.restore(
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 211, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 642, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 812, in restore
    return tree_utils.deserialize_tree(restored_item, item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/tree/utils.py", line 176, in deserialize_tree
    return jax.tree_util.tree_map_with_path(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/jax/_src/tree_util.py", line 1170, in tree_map_with_path
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/jax/_src/tree_util.py", line 1170, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
                             ^^^^^^
  File "/root/anaconda3/envs/vnl/lib/python3.11/site-packages/orbax/checkpoint/tree/utils.py", line 173, in _reconstruct_from_keypath
    result = result[key_name]
             ~~~~~~^^^^^^^^^^
KeyError: 'policy'

May I ask what the fix for saving the checkpoint is? I am saving the checkpoint according to the example provided:

from orbax import checkpoint as ocp
from flax.training import orbax_utils

ckpt_path = epath.Path('/tmp/some-env/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  # save checkpoints
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)

Thank you!
