-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] - Fix APPO RLModule inference-only problems. #45111
Merged
sven1977
merged 5 commits into
ray-project:master
from
simonsays1980:fix-appo-inference-only-modules
May 3, 2024
Merged
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
a76863e
Made APPO RLModule always a learner module.
simonsays1980 c2b5db9
Added a base module for APPO to enable straight inheritance of '_infe…
simonsays1980 52e0a27
Removed setup from 'APPTfRLModule' and inherited from 'APPORLModule' …
simonsays1980 95d1934
Removed weight updates from 'APPORLModule' to framework-specific modu…
simonsays1980 a650368
Removed checking for 'inference-only' from 'forward_train' as the pa…
simonsays1980 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
This file holds framework-agnostic components for APPO's RLModules. | ||
""" | ||
|
||
import abc | ||
|
||
from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule | ||
from ray.rllib.utils.annotations import ExperimentalAPI | ||
|
||
# TODO (simon): Write a light-weight version of this class for the `TFRLModule` | ||
|
||
|
||
@ExperimentalAPI | ||
class APPORLModule(PPORLModule, abc.ABC): | ||
def setup(self): | ||
super().setup() | ||
|
||
# If the module is not for inference only, set up the target networks. | ||
if not self.inference_only: | ||
catalog = self.config.get_catalog() | ||
# Old pi and old encoder are the "target networks" that are used for | ||
# the stabilization of the updates of the current pi and encoder. | ||
self.old_pi = catalog.build_pi_head(framework=self.framework) | ||
self.old_encoder = catalog.build_actor_critic_encoder( | ||
framework=self.framework | ||
) | ||
self.old_pi.load_state_dict(self.pi.state_dict()) | ||
self.old_encoder.load_state_dict(self.encoder.state_dict()) | ||
# We do not train the targets. | ||
self.old_pi.requires_grad_(False) | ||
self.old_encoder.requires_grad_(False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
from ray.rllib.algorithms.appo.appo import ( | ||
OLD_ACTION_DIST_LOGITS_KEY, | ||
) | ||
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule | ||
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule | ||
from ray.rllib.core.columns import Columns | ||
from ray.rllib.core.models.base import ACTOR | ||
|
@@ -14,20 +15,9 @@ | |
from ray.rllib.utils.nested_dict import NestedDict | ||
|
||
|
||
class APPOTorchRLModule(PPOTorchRLModule, RLModuleWithTargetNetworksInterface): | ||
@override(PPOTorchRLModule) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
def setup(self): | ||
super().setup() | ||
catalog = self.config.get_catalog() | ||
# Old pi and old encoder are the "target networks" that are used for | ||
# the stabilization of the updates of the current pi and encoder. | ||
self.old_pi = catalog.build_pi_head(framework=self.framework) | ||
self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework) | ||
self.old_pi.load_state_dict(self.pi.state_dict()) | ||
self.old_encoder.load_state_dict(self.encoder.state_dict()) | ||
self.old_pi.trainable = False | ||
self.old_encoder.trainable = False | ||
|
||
class APPOTorchRLModule( | ||
PPOTorchRLModule, RLModuleWithTargetNetworksInterface, APPORLModule | ||
): | ||
@override(RLModuleWithTargetNetworksInterface) | ||
def get_target_network_pairs(self): | ||
return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)] | ||
|
@@ -42,8 +32,26 @@ def output_specs_train(self) -> List[str]: | |
|
||
@override(PPOTorchRLModule) | ||
def _forward_train(self, batch: NestedDict): | ||
if self.inference_only: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Maybe we should move this error into the parent RLModule class' main method? |
||
raise RuntimeError( | ||
"Trying to train a module that is not a learner module. Set the " | ||
"flag `inference_only=False` when building the module." | ||
) | ||
|
||
outs = super()._forward_train(batch) | ||
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR] | ||
old_action_dist_logits = self.old_pi(old_pi_inputs_encoded) | ||
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits | ||
return outs | ||
|
||
@override(PPOTorchRLModule) | ||
def _set_inference_only_state_dict_keys(self) -> None: | ||
# Get the model_parameters from the `PPOTorchRLModule`. | ||
super()._set_inference_only_state_dict_keys() | ||
# Get the model_parameters. | ||
state_dict = self.state_dict() | ||
# Note, these keys are only known to the learner module. Furthermore, | ||
# we want this to be run once during setup and not for each worker. | ||
self._inference_only_state_dict_keys["unexpected_keys"].extend( | ||
[name for name in state_dict if "old" in name] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!