[RLlib] Examples folder cleanup: ModelV2 -> RLModule wrapper for migrating to new API stack. (#47425)
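In short, this commit adds a `ModelV2ToRLModule` wrapper that loads an old API stack ModelV2 from a policy checkpoint and exposes it as a new API stack RLModule. A minimal usage sketch, condensed from the example script added below (the checkpoint path is hypothetical):

import gymnasium as gym

from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule

env = gym.make("CartPole-v1")
# Point to the [algo checkpoint dir]/policies/default_policy directory of an
# old API stack checkpoint (hypothetical path).
rl_module = ModelV2ToRLModule(
    config=RLModuleConfig(
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={
            "policy_checkpoint_dir": "/tmp/my_old_stack_ppo/policies/default_policy"
        },
    ),
)

From there, the wrapper can be queried via `forward_inference()` or plugged into a new API stack training config via `RLModuleSpec`, as the files below show.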
Showing 15 changed files with 255 additions and 43 deletions.
rllib/core/rl_module/apis/__init__.py (8 additions, 0 deletions)
@@ -0,0 +1,8 @@
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


__all__ = [
    "TargetNetworkAPI",
    "ValueFunctionAPI",
]
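With these re-exports, both APIs can be imported directly from the package root rather than from their individual submodules; the `modelv2_to_rlm.py` module below relies on this for `ValueFunctionAPI`. A minimal sketch, assuming only the re-exports shown above:

from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI

# Equivalent to the longer, per-submodule form:
# from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI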
File renamed without changes.
File renamed without changes.
rllib/examples/rl_modules/classes/modelv2_to_rlm.py (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import (
    TorchCategorical,
    TorchDiagGaussian,
    TorchMultiCategorical,
    TorchMultiDistribution,
    TorchSquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical as OldTorchCategorical,
    TorchDiagGaussian as OldTorchDiagGaussian,
    TorchMultiActionDistribution as OldTorchMultiActionDistribution,
    TorchMultiCategorical as OldTorchMultiCategorical,
    TorchSquashedGaussian as OldTorchSquashedGaussian,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class ModelV2ToRLModule(TorchRLModule, ValueFunctionAPI):
    """An RLModule containing an (old stack) ModelV2, provided by a policy checkpoint."""

    @override(TorchRLModule)
    def setup(self):
        super().setup()

        # Get the policy checkpoint directory from the `model_config_dict`.
        policy_checkpoint_dir = self.config.model_config_dict.get(
            "policy_checkpoint_dir"
        )
        if policy_checkpoint_dir is None:
            raise ValueError(
                "The `model_config_dict` of your RLModule must contain a "
                "`policy_checkpoint_dir` key pointing to the policy checkpoint "
                "directory! You can find this dir under the Algorithm's checkpoint dir "
                "in subdirectory: [algo checkpoint dir]/policies/[policy ID, e.g. "
                "`default_policy`]."
            )

        # Create a temporary policy object from the checkpoint.
        policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
        self._model_v2 = policy.model

        # Translate the action dist classes from the old API stack to the new.
        self._action_dist_class = self._translate_dist_class(policy.dist_class)

        # Erase the torch policy from memory, so it can be garbage collected.
        del policy

    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        nn_output, state_out = self._model_v2(batch)
        # Interpret the NN output as action logits.
        output = {Columns.ACTION_DIST_INPUTS: nn_output}
        # Add the `state_out` to the `output`, new API stack style.
        if state_out:
            output[Columns.STATE_OUT] = {}
            for i, o in enumerate(state_out):
                output[Columns.STATE_OUT][i] = o

        return output

    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward_inference(batch, **kwargs)

    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        out = self._forward_inference(batch, **kwargs)
        out[Columns.ACTION_LOGP] = self._action_dist_class(
            out[Columns.ACTION_DIST_INPUTS]
        ).logp(batch[Columns.ACTIONS])
        out[Columns.VF_PREDS] = self._model_v2.value_function()
        return out

    def compute_values(self, batch: Dict[str, Any]):
        # Run a forward pass, then return the ModelV2's value head output.
        self._model_v2(batch)
        return self._model_v2.value_function()

    def get_inference_action_dist_cls(self):
        return self._action_dist_class

    def get_exploration_action_dist_cls(self):
        return self._action_dist_class

    def get_train_action_dist_cls(self):
        return self._action_dist_class

    def _translate_dist_class(self, old_dist_class):
        # Map old API stack action distribution classes to their new stack equivalents.
        map_ = {
            OldTorchCategorical: TorchCategorical,
            OldTorchDiagGaussian: TorchDiagGaussian,
            OldTorchMultiActionDistribution: TorchMultiDistribution,
            OldTorchMultiCategorical: TorchMultiCategorical,
            OldTorchSquashedGaussian: TorchSquashedGaussian,
        }
        if old_dist_class not in map_:
            raise ValueError(
                f"ModelV2ToRLModule does NOT support {old_dist_class} action "
                f"distributions yet!"
            )

        return map_[old_dist_class]
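To keep training such a wrapped ModelV2 on the new API stack, the wrapper is plugged into an AlgorithmConfig through an `RLModuleSpec`; the example script added below does exactly this end to end. A condensed sketch (checkpoint path hypothetical):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=ModelV2ToRLModule,
            model_config_dict={
                "policy_checkpoint_dir": "/tmp/my_old_stack_ppo/policies/default_policy"
            },
        ),
    )
)
algo = config.build()
results = algo.train()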
rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
import pathlib

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleConfig, RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import batch


if __name__ == "__main__":
    # Configure and train an old stack default ModelV2.
    config = (
        PPOConfig()
        # Old API stack.
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .environment("CartPole-v1")
        .training(
            lr=0.0003,
            num_sgd_iter=6,
            vf_loss_coeff=0.01,
        )
    )
    algo_old_stack = config.build()

    min_return_old_stack = 100.0
    while True:
        results = algo_old_stack.train()
        print(results)
        if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack:
            print(
                f"Reached episode return of {min_return_old_stack} -> stopping "
                "old API stack training."
            )
            break

    checkpoint = algo_old_stack.save()
    policy_path = (
        pathlib.Path(checkpoint.checkpoint.path) / "policies" / "default_policy"
    )
    assert policy_path.is_dir()
    algo_old_stack.stop()

    print("done")

    # Move the old API stack (trained) ModelV2 into the new API stack's RLModule.
    # Run a simple CartPole inference experiment.
    env = gym.make("CartPole-v1", render_mode="human")
    rl_module = ModelV2ToRLModule(
        config=RLModuleConfig(
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config_dict={"policy_checkpoint_dir": policy_path},
        ),
    )

    obs, _ = env.reset()
    env.render()
    done = False
    episode_return = 0.0
    while not done:
        output = rl_module.forward_inference({"obs": torch.from_numpy(batch([obs]))})
        action_logits = output["action_dist_inputs"].detach().numpy()[0]
        action = np.argmax(action_logits)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_return += reward
        env.render()

    print(f"Ran episode with trained ModelV2: return={episode_return}")

    # Continue training with the (checkpointed) ModelV2.

    # We change the original (old API stack) `config` into a new API stack one:
    config = config.api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    ).rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=ModelV2ToRLModule,
            model_config_dict={"policy_checkpoint_dir": policy_path},
        ),
    )

    # Build the new stack algo.
    algo_new_stack = config.build()

    # Train until a higher return.
    min_return_new_stack = 450.0
    passed = False
    for i in range(50):
        results = algo_new_stack.train()
        print(results)
        # Make sure that the model's weights from the old API stack training
        # were properly transferred to the new API RLModule wrapper. Thus, even
        # after only one iteration of new stack training, we already expect the
        # return to be higher than it was at the end of the old stack training.
        assert results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack
        if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_new_stack:
            print(
                f"Reached episode return of {min_return_new_stack} -> stopping "
                "new API stack training."
            )
            passed = True
            break

    if not passed:
        raise ValueError(
            "Continuing training on the new stack did not succeed! Last return: "
            f"{results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
        )