[RLlib] Actually save the optimizer state for tf learners #34252

Merged: 3 commits, Apr 12, 2023
Changes from 1 commit
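Note (not part of the PR text): the fix is needed because tf.keras.optimizers.serialize() only captures the optimizer's configuration (learning rate, betas, and so on), not its slot variables such as Adam's moment estimates, so a checkpoint built from it silently drops the optimizer state. A minimal sketch of that behavior, with an illustrative optimizer:

import tensorflow as tf

optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
config = tf.keras.optimizers.serialize(optim)
# config is roughly {"class_name": "Adam", "config": {"learning_rate": 0.001, ...}}:
# hyperparameters only, with no per-variable moment estimates anywhere in it.
print(config)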
46 changes: 40 additions & 6 deletions rllib/core/learner/tf/tf_learner.py
@@ -78,10 +78,15 @@ def configure_optimizer_per_module(
     ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]:
         module = self._module[module_id]
         lr = self._optimizer_config["lr"]
+        optim = tf.keras.optimizers.Adam(learning_rate=lr)
         pair: ParamOptimizerPair = (
             self.get_parameters(module),
-            tf.keras.optimizers.Adam(learning_rate=lr),
+            optim,
         )
+        # this isn't strictly necessary, but makes it so that if a checkpoint is
+        # computed before training actually starts, then it will be the same in
+        # shape / size as a checkpoint after training starts.
+        optim.build(module.trainable_variables)
         return pair

     @override(Learner)
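Note (not part of the diff): a short sketch of what the eager optim.build() call above buys, using the same Keras optimizer calls this PR relies on (build() and variables(), whose exact availability depends on the TF/Keras version); the variable and loss are illustrative.

import tensorflow as tf

var = tf.Variable(tf.zeros([4]))
optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
optim.build([var])  # create the slot variables up front, as configure_optimizer_per_module now does
shapes_before_training = [v.shape for v in optim.variables()]

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
optim.apply_gradients(zip(tape.gradient(loss, [var]), [var]))

# A checkpoint written before the first update has the same shape / size as one
# written after training starts.
assert shapes_before_training == [v.shape for v in optim.variables()]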
@@ -144,16 +149,22 @@ def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
         path = pathlib.Path(path)
         path.mkdir(parents=True, exist_ok=True)
         for name, optim in self._named_optimizers.items():
-            state = tf.keras.optimizers.serialize(optim)
-            state = tf.nest.map_structure(convert_numpy_to_python_primitives, state)
-            with open(path / f"{name}.json", "w") as f:
-                json.dump(state, f)
+            hparams = tf.keras.optimizers.serialize(optim)
+            hparams = tf.nest.map_structure(convert_numpy_to_python_primitives, hparams)
+            with open(path / f"{name}_hparams.json", "w") as f:
+                json.dump(hparams, f)
+            state = optim.variables()
+            serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state]
+            contents = tf.strings.join(serialized_tensors, separator="tensor: ")
+            tf.io.write_file(str(path / f"{name}_state.txt"), contents)

     @override(Learner)
     def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
         path = pathlib.Path(path)
         for name in self._named_optimizers.keys():
-            with open(path / f"{name}.json", "r") as f:
+            # This loads in a new optimizer initialized with the same hparams
+            # as the one that was saved.
+            with open(path / f"{name}_hparams.json", "r") as f:
                 state = json.load(f)
             new_optim = tf.keras.optimizers.deserialize(state)
             old_optim = self._named_optimizers[name]
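Note (not part of the diff): the save path above and the load path below round-trip the optimizer variables through a single text file. A self-contained sketch of that scheme; the file path and tensors are made up:

import tensorflow as tf

tensors = [tf.constant([1.0, 2.0]), tf.constant([[3.0, 4.0]])]

# Save: one serialized string per tensor, joined with the "tensor: " separator.
serialized = [tf.io.serialize_tensor(t) for t in tensors]
tf.io.write_file("/tmp/example_state.txt", tf.strings.join(serialized, separator="tensor: "))

# Load: split on the separator again and parse each chunk back with its known dtype.
chunks = tf.strings.split(tf.io.read_file("/tmp/example_state.txt"), sep="tensor: ")
restored = [tf.io.parse_tensor(c, tf.float32) for c in chunks]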
@@ -163,6 +174,29 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
             for param_ref in param_seq:
                 self._optimizer_parameters[new_optim].append(param_ref)

+            # This loads in the actual state of the optimizer.
+            contents = tf.io.read_file(str(path / f"{name}_state.txt"))
+            serialized_tensors = tf.strings.split(contents, sep="tensor: ")
+            unserialized_optim_state = []
+            for serialized_tensor, old_optim_tensor in zip(
+                serialized_tensors, old_optim.variables()
+            ):
+                unserialized_optim_state.append(
+                    tf.io.parse_tensor(serialized_tensor, old_optim_tensor.dtype)
+                )
+            # delete the old optimizer / free its memory
+            del old_optim
+            # these are the variables that the optimizer is supposed to optimize over
+            variable_list = [
+                self._params[param_ref]
+                for param_ref in self._optimizer_parameters[new_optim]
+            ]
+            # initialize the optimizer with the variables that it is supposed to
+            # optimize over
+            new_optim.build(variable_list)
+            # set the state of the optimizer to the state that was saved
+            new_optim.set_weights(unserialized_optim_state)
+
     @override(Learner)
     def set_weights(self, weights: Mapping[str, Any]) -> None:
         self._module.set_state(weights)
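Note (not part of the diff): a hedged smoke test of what _save_optimizers now writes per optimizer. The built learner instance and the checkpoint directory are hypothetical; _save_optimizers, _named_optimizers, and the <name>_hparams.json / <name>_state.txt layout come from the code above.

import json
import pathlib
import tensorflow as tf

ckpt = pathlib.Path("/tmp/tf_learner_optimizers")  # hypothetical checkpoint dir
learner._save_optimizers(ckpt)  # "learner" is assumed to be an already-built TfLearner

for name, optim in learner._named_optimizers.items():
    # The hyperparameters should round-trip through JSON into a fresh optimizer ...
    with open(ckpt / f"{name}_hparams.json") as f:
        tf.keras.optimizers.deserialize(json.load(f))
    # ... and the state file should hold one serialized tensor per optimizer variable.
    chunks = tf.strings.split(tf.io.read_file(str(ckpt / f"{name}_state.txt")), sep="tensor: ")
    assert int(tf.size(chunks)) == len(optim.variables())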