[RLlib] - Fix APPO RLModule inference-only problems. #45111

Merged
Changes from 3 commits
31 changes: 31 additions & 0 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -0,0 +1,31 @@
"""
This file holds framework-agnostic components for APPO's RLModules.
"""

import abc

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.utils.annotations import ExperimentalAPI

# TODO (simon): Write a light-weight version of this class for the `TFRLModule`


@ExperimentalAPI
class APPORLModule(PPORLModule, abc.ABC):
def setup(self):
super().setup()

# If the module is not for inference only, set up the target networks.
if not self.inference_only:
catalog = self.config.get_catalog()
# Old pi and old encoder are the "target networks" that are used for
# the stabilization of the updates of the current pi and encoder.
self.old_pi = catalog.build_pi_head(framework=self.framework)
self.old_encoder = catalog.build_actor_critic_encoder(
framework=self.framework
)
self.old_pi.load_state_dict(self.pi.state_dict())
self.old_encoder.load_state_dict(self.encoder.state_dict())
# We do not train the targets.
self.old_pi.requires_grad_(False)
self.old_encoder.requires_grad_(False)
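
The new `setup()` builds the target ("old") networks only for learner modules, copies the current weights into them, and freezes their gradients. A minimal sketch of that same pattern outside RLlib, using a hypothetical toy module (`ToyAPPOModule`, plain PyTorch, made-up layer sizes):

import copy

import torch.nn as nn


class ToyAPPOModule(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, inference_only: bool = False):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, 64)
        self.pi = nn.Linear(64, act_dim)
        self.inference_only = inference_only
        if not inference_only:
            # Target ("old") networks exist only on learner modules.
            self.old_encoder = copy.deepcopy(self.encoder)
            self.old_pi = copy.deepcopy(self.pi)
            # Targets are never updated by gradients, only by weight syncing.
            self.old_encoder.requires_grad_(False)
            self.old_pi.requires_grad_(False)


learner_module = ToyAPPOModule(4, 2, inference_only=False)  # has old_* target nets
rollout_module = ToyAPPOModule(4, 2, inference_only=True)   # lighter, no targets

Skipping the target networks on inference-only modules keeps the copies that are synced to rollout workers smaller, which is the point of the `inference_only` flag.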
15 changes: 2 additions & 13 deletions rllib/algorithms/appo/tf/appo_tf_rl_module.py
@@ -1,6 +1,7 @@
from typing import List

from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR
@@ -15,19 +16,7 @@
_, tf, _ = try_import_tf()


class APPOTfRLModule(PPOTfRLModule, RLModuleWithTargetNetworksInterface):
def setup(self):
Reviewer comment (Contributor): Nice!

super().setup()
catalog = self.config.get_catalog()
# old pi and old encoder are the "target networks" that are used for
# the stabilization of the updates of the current pi and encoder.
self.old_pi = catalog.build_pi_head(framework=self.framework)
self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework)
self.old_pi.set_weights(self.pi.get_weights())
self.old_encoder.set_weights(self.encoder.get_weights())
self.old_pi.trainable = False
self.old_encoder.trainable = False

class APPOTfRLModule(PPOTfRLModule, RLModuleWithTargetNetworksInterface, APPORLModule):
@override(RLModuleWithTargetNetworksInterface)
def get_target_network_pairs(self):
return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)]
36 changes: 22 additions & 14 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -3,6 +3,7 @@
from ray.rllib.algorithms.appo.appo import (
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR
@@ -14,20 +15,9 @@
from ray.rllib.utils.nested_dict import NestedDict


class APPOTorchRLModule(PPOTorchRLModule, RLModuleWithTargetNetworksInterface):
@override(PPOTorchRLModule)
Reviewer comment (Contributor): Nice!

def setup(self):
super().setup()
catalog = self.config.get_catalog()
# Old pi and old encoder are the "target networks" that are used for
# the stabilization of the updates of the current pi and encoder.
self.old_pi = catalog.build_pi_head(framework=self.framework)
self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework)
self.old_pi.load_state_dict(self.pi.state_dict())
self.old_encoder.load_state_dict(self.encoder.state_dict())
self.old_pi.trainable = False
self.old_encoder.trainable = False

class APPOTorchRLModule(
PPOTorchRLModule, RLModuleWithTargetNetworksInterface, APPORLModule
):
@override(RLModuleWithTargetNetworksInterface)
def get_target_network_pairs(self):
return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)]
@@ -42,8 +32,26 @@ def output_specs_train(self) -> List[str]:

@override(PPOTorchRLModule)
def _forward_train(self, batch: NestedDict):
if self.inference_only:
Reviewer comment (Contributor): nit: Maybe we should move this error into the parent RLModule class's main method, `def forward_train(self, ...)`?

raise RuntimeError(
"Trying to train a module that is not a learner module. Set the "
"flag `inference_only=False` when building the module."
)

outs = super()._forward_train(batch)
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self.old_pi(old_pi_inputs_encoded)
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs
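
One way the nit above could be addressed, sketched here with a hypothetical `BaseRLModule` standing in for RLlib's actual base class (names and signatures are assumptions, not the real API): keep the guard in a public `forward_train()` that dispatches to the framework-specific `_forward_train()`, so subclasses never repeat the check.

class BaseRLModule:
    inference_only: bool = False

    def forward_train(self, batch):
        # Centralized guard: every subclass's _forward_train() goes through here.
        if self.inference_only:
            raise RuntimeError(
                "Trying to train a module that was built with "
                "`inference_only=True`. Rebuild it with `inference_only=False`."
            )
        return self._forward_train(batch)

    def _forward_train(self, batch):
        raise NotImplementedError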

@override(PPOTorchRLModule)
def _set_inference_only_state_dict_keys(self) -> None:
# Get the model_parameters from the `PPOTorchRLModule`.
super()._set_inference_only_state_dict_keys()
# Get the model_parameters.
state_dict = self.state_dict()
# Note, these keys are only known to the learner module. Furthermore,
# we want this to be run once during setup and not for each worker.
self._inference_only_state_dict_keys["unexpected_keys"].extend(
[name for name in state_dict if "old" in name]
)
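
The `unexpected_keys` list collects parameter names that an inference-only module should not receive; the APPO subclass adds every target-network ("old") key on top of what PPO already filters. A small illustration of the combined filter, with fabricated key names (only the filtering logic mirrors the code above):

state_dict_keys = [
    "encoder.actor_encoder.weight",
    "encoder.critic_encoder.weight",
    "pi.weight",
    "vf.weight",
    "old_encoder.actor_encoder.weight",
    "old_pi.weight",
]
# PPO-level filter: value function and critic encoder ...
unexpected = [k for k in state_dict_keys
              if "vf" in k or k.startswith("encoder.critic_encoder")]
# ... plus the APPO-level filter for the target networks.
unexpected += [k for k in state_dict_keys if "old" in k]
inference_only_keys = [k for k in state_dict_keys if k not in unexpected]
# -> ["encoder.actor_encoder.weight", "pi.weight"]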
6 changes: 4 additions & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -141,7 +141,9 @@ def _set_inference_only_state_dict_keys(self) -> None:
# Note, these keys are only known to the learner module. Furthermore,
# we want this to be run once during setup and not for each worker.
self._inference_only_state_dict_keys["unexpected_keys"] = [
name for name in state_dict if "vf" in name or "critic_encoder" in name
name
for name in state_dict
if "vf" in name or name.startswith("encoder.critic_encoder")
]
# Do we use a separate encoder for the actor and critic?
# if not self.config.model_config_dict.get("vf_share_layers", True):
@@ -153,7 +155,7 @@ def _set_inference_only_state_dict_keys(self) -> None:
self._inference_only_state_dict_keys["expected_keys"] = {
name: name.replace("actor_encoder", "encoder")
for name in state_dict
if "actor_encoder" in name
if name.startswith("encoder.actor_encoder")
}

@override(TorchRLModule)
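
The switch from a plain substring check (`"critic_encoder" in name`) to a prefix check narrows the filter to the learner's own critic encoder. With APPO's new `old_*` target networks in the same state dict, a substring match would also pick up target-network keys that the APPO subclass already handles itself. An illustrative comparison (key names are made up):

keys = [
    "encoder.critic_encoder.net.weight",      # learner-only critic encoder
    "old_encoder.critic_encoder.net.weight",  # APPO target-network copy
]
print([k for k in keys if "critic_encoder" in k])                   # matches both
print([k for k in keys if k.startswith("encoder.critic_encoder")])  # matches only the first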