[RLlib] Actually save the optimizer state for tf learners #34252

Merged 3 commits on Apr 12, 2023
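For context, a minimal TF-only sketch of the problem this PR addresses (not part of the diff; variable names are illustrative, and optimizer APIs differ slightly across TF/Keras versions): tf.keras.optimizers.serialize() only captures an optimizer's hyperparameters, while the state Adam accumulates during training (its iteration counter and per-parameter moment slots) lives in optimizer.variables() and has to be saved and restored separately.

import tensorflow as tf

# A toy variable plus one gradient step, so Adam accumulates slot state.
var = tf.Variable([1.0, 2.0, 3.0])
optim = tf.keras.optimizers.Adam(learning_rate=0.01)
optim.apply_gradients([(tf.constant([0.1, 0.1, 0.1]), var)])

# serialize() records only the class name and config (learning rate, betas, ...),
hparams = tf.keras.optimizers.serialize(optim)

# ... whereas the accumulated state lives in the optimizer's variables and is
# lost if only the serialized config is written to disk.
print([v.shape for v in optim.variables()])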
115 changes: 107 additions & 8 deletions rllib/core/learner/tf/tf_learner.py
@@ -78,10 +78,15 @@ def configure_optimizer_per_module(
) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]:
module = self._module[module_id]
lr = self._optimizer_config["lr"]
optim = tf.keras.optimizers.Adam(learning_rate=lr)
pair: ParamOptimizerPair = (
    self.get_parameters(module),
    optim,
)
# This isn't strictly necessary, but it ensures that a checkpoint computed
# before training actually starts has the same shape / size as a checkpoint
# computed after training starts.
optim.build(module.trainable_variables)
return pair

@override(Learner)
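Not part of the diff: a short sketch of what the optim.build(...) call above buys, assuming a recent TF 2.x Keras optimizer like the one used here. Building the optimizer eagerly creates its state variables (iteration counter plus Adam's m/v slots), so a checkpoint written before the first gradient step already contains the same set of optimizer variables as one written after training begins.

import tensorflow as tf

var = tf.Variable(tf.zeros([4]))
optim = tf.keras.optimizers.Adam(learning_rate=1e-3)

print(len(optim.variables()))  # little or no state before build (version-dependent)
optim.build([var])             # eagerly create the iteration counter and m/v slots
print(len(optim.variables()))  # now matches what a trained optimizer would checkpoint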
@@ -139,30 +144,124 @@ def load_state(
with self._strategy.scope():
super().load_state(path)

def _save_optimizer_hparams(
self,
path: pathlib.Path,
optim: "tf.keras.optimizers.Optimizer",
optim_name: str,
) -> None:
"""Save the hyperparameters of optim to path/optim_name_hparams.json.

Args:
path: The path to the directory to save the hyperparameters to.
optim: The optimizer to save the hyperparameters of.
optim_name: The name of the optimizer.

"""
hparams = tf.keras.optimizers.serialize(optim)
hparams = tf.nest.map_structure(convert_numpy_to_python_primitives, hparams)
with open(path / f"{optim_name}_hparams.json", "w") as f:
json.dump(hparams, f)
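Not part of the diff: roughly what the hyperparameter side of this looks like with a stock Adam optimizer. convert_numpy_to_python_primitives is RLlib's helper; the stand-in below only illustrates why such a conversion is needed before json.dump, since some serialized config values can come back as NumPy scalars.

import json
import numpy as np
import tensorflow as tf

optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
hparams = tf.keras.optimizers.serialize(optim)  # roughly {"class_name": "Adam", "config": {...}}

# Stand-in for convert_numpy_to_python_primitives: json cannot encode NumPy
# scalars, so cast any of them to plain Python types before dumping.
def to_python(x):
    return x.item() if isinstance(x, (np.generic, np.ndarray)) else x

hparams = tf.nest.map_structure(to_python, hparams)
print(json.dumps(hparams)[:80])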

def _save_optimizer_state(
self,
path: pathlib.Path,
optim: "tf.keras.optimizers.Optimizer",
optim_name: str,
) -> None:
"""Save the state variables of optim to path/optim_name_state.txt.

Args:
path: The path to the directory to save the state to.
optim: The optimizer to save the state of.
optim_name: The name of the optimizer.

"""
state = optim.variables()
serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state]
contents = tf.strings.join(serialized_tensors, separator="tensor: ")
tf.io.write_file(str(path / f"{optim_name}_state.txt"), contents)
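Not part of the diff: a self-contained sketch of the serialize / join / split / parse round trip that _save_optimizer_state above and _load_optimizer_state below perform, using throwaway tensors in place of real optimizer variables (the file path is illustrative).

import tensorflow as tf

state = [tf.constant([1.0, 2.0, 3.0]), tf.constant([[4, 5], [6, 7]], dtype=tf.int64)]

# Save: serialize each tensor to a string and join on the same sentinel
# separator the learner uses.
contents = tf.strings.join(
    [tf.io.serialize_tensor(t) for t in state], separator="tensor: "
)
tf.io.write_file("/tmp/example_state.txt", contents)

# Load: split on the sentinel and parse each piece back, supplying the dtype
# of the tensor it belongs to (here taken from the originals).
serialized = tf.strings.split(tf.io.read_file("/tmp/example_state.txt"), sep="tensor: ")
restored = [tf.io.parse_tensor(s, t.dtype) for s, t in zip(serialized, state)]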

@override(Learner)
def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
path.mkdir(parents=True, exist_ok=True)
for name, optim in self._named_optimizers.items():
self._save_optimizer_hparams(path, optim, name)
self._save_optimizer_state(path, optim, name)
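Not part of the diff: with the two helpers above, each named optimizer is expected to leave a pair of files under the checkpoint directory (the directory and optimizer name below are hypothetical).

import pathlib

path = pathlib.Path("/tmp/learner_checkpoint")
optim_name = "default_policy"  # hypothetical key in self._named_optimizers

hparams_file = path / f"{optim_name}_hparams.json"  # optimizer class + config
state_file = path / f"{optim_name}_state.txt"       # serialized optimizer variables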

def _load_optimizer_from_hparams(
self, path: pathlib.Path, optim_name: str
) -> "tf.keras.optimizers.Optimizer":
"""Load an optimizer from the hyperparameters saved at path/optim_name_hparams.json.

Args:
path: The path to the directory to load the hyperparameters from.
optim_name: The name of the optimizer.

Returns:
The optimizer loaded from the hyperparameters.

"""
with open(path / f"{optim_name}_hparams.json", "r") as f:
state = json.load(f)
return tf.keras.optimizers.deserialize(state)
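Not part of the diff: a quick check of the round trip this helper relies on, assuming a stock Adam optimizer. The deserialized optimizer carries the saved configuration but owns no slot variables until it is built or takes its first step.

import tensorflow as tf

optim = tf.keras.optimizers.Adam(learning_rate=0.005)
restored = tf.keras.optimizers.deserialize(tf.keras.optimizers.serialize(optim))

print(type(restored).__name__)        # Adam
print(float(restored.learning_rate))  # 0.005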

def _load_optimizer_state(
self,
path: pathlib.Path,
optim: "tf.keras.optimizers.Optimizer",
optim_name: str,
) -> None:
"""Load the state of optim from the state saved at path/optim_name_state.txt.

Args:
path: The path to the directory to load the state from.
optim: The optimizer to load the state into.
optim_name: The name of the optimizer.

"""
contents = tf.io.read_file(str(path / f"{optim_name}_state.txt"))
serialized_tensors = tf.strings.split(contents, sep="tensor: ")
unserialized_optim_state = []
for serialized_tensor, optim_tensor in zip(
serialized_tensors, optim.variables()
):
unserialized_optim_state.append(
tf.io.parse_tensor(serialized_tensor, optim_tensor.dtype)
)

# set the state of the optimizer to the state that was saved
optim.set_weights(unserialized_optim_state)

@override(Learner)
def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
for name in self._named_optimizers.keys():
new_optim = self._load_optimizer_from_hparams(path, name)
old_optim = self._named_optimizers[name]

# Replace the old optimizer with the new one in the learner's state.
self._named_optimizers[name] = new_optim
param_seq = self._optimizer_parameters.pop(old_optim)
self._optimizer_parameters[new_optim] = []
for param_ref in param_seq:
self._optimizer_parameters[new_optim].append(param_ref)

# delete the old optimizer / free its memory
del old_optim
# these are the variables that the optimizer is supposed to optimize over
variable_list = [
self._params[param_ref]
for param_ref in self._optimizer_parameters[new_optim]
]
# initialize the optimizer with the variables that it is supposed to
# optimize over
new_optim.build(variable_list)

# This loads in the actual state of the optimizer.
self._load_optimizer_state(path, new_optim, name)

@override(Learner)
def set_weights(self, weights: Mapping[str, Any]) -> None:
self._module.set_state(weights)
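Taken together, the load path amounts to: recreate each optimizer from its saved hyperparameters, build it on the module's variables, then overwrite the freshly created state with the saved one. A minimal TF-only sketch of that sequence (not RLlib code; assumes a recent TF 2.x Keras optimizer and an in-memory "checkpoint" instead of files):

import tensorflow as tf

# "Training" produced an optimizer whose state we want to checkpoint.
var = tf.Variable([1.0, 2.0, 3.0])
optim = tf.keras.optimizers.Adam(learning_rate=0.01)
optim.apply_gradients([(tf.constant([0.1, 0.2, 0.3]), var)])

hparams = tf.keras.optimizers.serialize(optim)
state = [tf.io.serialize_tensor(v) for v in optim.variables()]

# Load path, mirroring _load_optimizers / _load_optimizer_state above.
new_optim = tf.keras.optimizers.deserialize(hparams)  # same class and config
new_optim.build([var])                                # create empty state variables
new_optim.set_weights(
    [tf.io.parse_tensor(s, v.dtype) for s, v in zip(state, new_optim.variables())]
)

# The restored optimizer now carries the original iteration count and moments.
for a, b in zip(optim.variables(), new_optim.variables()):
    tf.debugging.assert_near(tf.cast(a, tf.float32), tf.cast(b, tf.float32))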