[RLlib] Examples folder cleanup: ModelV2 -> RLModule wrapper for migrating to new API stack. (#47425)
sven1977 authored Aug 30, 2024
1 parent f0a81a6 commit 3c950a1
Showing 15 changed files with 255 additions and 43 deletions.
51 changes: 23 additions & 28 deletions rllib/BUILD
@@ -3459,20 +3459,20 @@ py_test(
# subdirectory: rl_modules/
# ....................................
py_test(
name = "examples/rl_modules/action_masking_rlm",
main = "examples/rl_modules/action_masking_rlm.py",
name = "examples/rl_modules/action_masking_rl_module",
main = "examples/rl_modules/action_masking_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/action_masking_rlm.py"],
srcs = ["examples/rl_modules/action_masking_rl_module.py"],
args = ["--enable-new-api-stack", "--stop-iters=5"],
)

py_test(
name = "examples/rl_modules/autoregressive_actions_rlm",
main = "examples/rl_modules/autoregressive_actions_rlm.py",
name = "examples/rl_modules/autoregressive_actions_rl_module",
main = "examples/rl_modules/autoregressive_actions_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/autoregressive_actions_rlm.py"],
srcs = ["examples/rl_modules/autoregressive_actions_rl_module.py"],
args = ["--enable-new-api-stack"],
)
py_test(
@@ -3501,6 +3501,13 @@ py_test(
srcs = ["examples/rl_modules/classes/mobilenet_rlm.py"],
)

py_test(
name = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint",
main = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py",
tags = ["team:rllib", "examples"],
size = "large",
srcs = ["examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py"],
)
py_test(
name = "examples/rl_modules/pretraining_single_agent_training_multi_agent",
main = "examples/rl_modules/pretraining_single_agent_training_multi_agent.py",
@@ -3510,6 +3517,7 @@ py_test(
args = ["--enable-new-api-stack", "--num-agents=2", "--stop-iters-pretraining=5", "--stop-iters=20", "--stop-reward=150.0"],
)

#@OldAPIStack
py_test(
name = "examples/autoregressive_action_dist_tf",
main = "examples/autoregressive_action_dist.py",
@@ -3519,6 +3527,7 @@ py_test(
args = ["--as-test", "--framework=tf", "--stop-reward=150", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/autoregressive_action_dist_torch",
main = "examples/autoregressive_action_dist.py",
@@ -3528,6 +3537,7 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_tf2",
main = "examples/cartpole_lstm.py",
@@ -3537,6 +3547,7 @@ py_test(
args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_torch",
main = "examples/cartpole_lstm.py",
@@ -3546,6 +3557,7 @@ py_test(
args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
@@ -3555,6 +3567,7 @@ py_test(
args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
@@ -3564,6 +3577,7 @@ py_test(
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r",
main = "examples/cartpole_lstm.py",
@@ -3613,6 +3627,7 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=6.0"]
)

#@OldAPIStack
py_test(
name = "examples/metrics/custom_metrics_and_callbacks",
main = "examples/metrics/custom_metrics_and_callbacks.py",
@@ -3622,6 +3637,7 @@ py_test(
args = ["--stop-iters=2"]
)

#@OldAPIStack
py_test(
name = "examples/custom_model_loss_and_metrics_ppo_tf",
main = "examples/custom_model_loss_and_metrics.py",
@@ -3633,6 +3649,7 @@ py_test(
args = ["--run=PPO", "--stop-iters=1", "--framework=tf","--input-files=tests/data/cartpole"]
)

#@OldAPIStack
py_test(
name = "examples/custom_model_loss_and_metrics_ppo_torch",
main = "examples/custom_model_loss_and_metrics.py",
@@ -3644,28 +3661,6 @@ py_test(
args = ["--run=PPO", "--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
)

py_test(
name = "examples/custom_model_loss_and_metrics_pg_tf",
main = "examples/custom_model_loss_and_metrics.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
# Include the json data file.
data = ["tests/data/cartpole/small.json"],
srcs = ["examples/custom_model_loss_and_metrics.py"],
args = ["--stop-iters=1", "--framework=tf", "--input-files=tests/data/cartpole"]
)

py_test(
name = "examples/custom_model_loss_and_metrics_pg_torch",
main = "examples/custom_model_loss_and_metrics.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
# Include the json data file.
data = ["tests/data/cartpole/small.json"],
srcs = ["examples/custom_model_loss_and_metrics.py"],
args = ["--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
)

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2",
main = "examples/custom_recurrent_rnn_tokenizer.py",
8 changes: 8 additions & 0 deletions rllib/core/rl_module/apis/__init__.py
@@ -0,0 +1,8 @@
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


__all__ = [
"TargetNetworkAPI",
"ValueFunctionAPI",
]
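
Editor's note: the two interfaces exported above let an RLModule advertise optional capabilities to the rest of the stack (for example, PPO's loss queries ValueFunctionAPI.compute_values()). The sketch below is not part of this commit; it shows roughly how a hypothetical MyValueAwareModule (made-up name and layer sizes, Discrete action space assumed) could implement ValueFunctionAPI on this Ray version. Exact signatures may differ in later releases.

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyValueAwareModule(TorchRLModule, ValueFunctionAPI):
    """Hypothetical module sketching how ValueFunctionAPI can be implemented."""

    def setup(self):
        super().setup()
        # Small shared encoder with separate policy and value heads (illustrative only).
        obs_dim = self.config.observation_space.shape[0]
        num_actions = self.config.action_space.n  # Assumes a Discrete action space.
        self._encoder = torch.nn.Linear(obs_dim, 64)
        self._pi_head = torch.nn.Linear(64, num_actions)
        self._vf_head = torch.nn.Linear(64, 1)

    def _forward_inference(self, batch, **kwargs):
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return {Columns.ACTION_DIST_INPUTS: self._pi_head(features)}

    def _forward_exploration(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    def _forward_train(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    # ValueFunctionAPI: value-based losses (e.g. PPO's) call this for V(s) estimates.
    def compute_values(self, batch):
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return self._vf_head(features).squeeze(-1)
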
104 changes: 104 additions & 0 deletions rllib/examples/rl_modules/classes/modelv2_to_rlm.py
@@ -0,0 +1,104 @@
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import (
TorchCategorical,
TorchDiagGaussian,
TorchMultiCategorical,
TorchMultiDistribution,
TorchSquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical as OldTorchCategorical,
TorchDiagGaussian as OldTorchDiagGaussian,
TorchMultiActionDistribution as OldTorchMultiActionDistribution,
TorchMultiCategorical as OldTorchMultiCategorical,
TorchSquashedGaussian as OldTorchSquashedGaussian,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class ModelV2ToRLModule(TorchRLModule, ValueFunctionAPI):
"""An RLModule containing a (old stack) ModelV2, provided by a policy checkpoint."""

@override(TorchRLModule)
def setup(self):
super().setup()

# Get the policy checkpoint from the `model_config_dict`.
policy_checkpoint_dir = self.config.model_config_dict.get(
"policy_checkpoint_dir"
)
if policy_checkpoint_dir is None:
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`policy_checkpoint_dir` key pointing to the policy checkpoint "
"directory! You can find this dir under the Algorithm's checkpoint dir "
"in subdirectory: [algo checkpoint dir]/policies/[policy ID, e.g. "
"`default_policy`]."
)

# Create a temporary policy object.
policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
self._model_v2 = policy.model

# Translate the action dist classes from the old API stack to the new.
self._action_dist_class = self._translate_dist_class(policy.dist_class)

# Erase the torch policy from memory, so it can be garbage collected.
del policy

def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
nn_output, state_out = self._model_v2(batch)
# Interpret the NN output as action logits.
output = {Columns.ACTION_DIST_INPUTS: nn_output}
# Add the `state_out` to the `output`, new API stack style.
if state_out:
output[Columns.STATE_OUT] = {}
for i, o in enumerate(state_out):
output[Columns.STATE_OUT][i] = o

return output

def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch, **kwargs)

def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
out = self._forward_inference(batch, **kwargs)
out[Columns.ACTION_LOGP] = self._action_dist_class(
out[Columns.ACTION_DIST_INPUTS]
).logp(batch[Columns.ACTIONS])
out[Columns.VF_PREDS] = self._model_v2.value_function()
return out

def compute_values(self, batch: Dict[str, Any]):
self._model_v2(batch)
return self._model_v2.value_function()

def get_inference_action_dist_cls(self):
return self._action_dist_class

def get_exploration_action_dist_cls(self):
return self._action_dist_class

def get_train_action_dist_cls(self):
return self._action_dist_class

def _translate_dist_class(self, old_dist_class):
map_ = {
OldTorchCategorical: TorchCategorical,
OldTorchDiagGaussian: TorchDiagGaussian,
OldTorchMultiActionDistribution: TorchMultiDistribution,
OldTorchMultiCategorical: TorchMultiCategorical,
OldTorchSquashedGaussian: TorchSquashedGaussian,
}
if old_dist_class not in map_:
raise ValueError(
f"ModelV2ToRLModule does NOT support {old_dist_class} action "
f"distributions yet!"
)

return map_[old_dist_class]
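
Editor's note (not part of the diff): the wrapper only needs the path of a single policy inside an old API stack Algorithm checkpoint, i.e. the policies/<policy_id> subdirectory mentioned in the error message above. A minimal standalone construction might look like the sketch below; the checkpoint path is a placeholder, and the full end-to-end flow is in the example script that follows.

import pathlib

import gymnasium as gym

from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule

# Placeholder: the `policies/default_policy` subdirectory of an old API stack
# Algorithm checkpoint produced by `algo.save()`.
policy_dir = pathlib.Path("/tmp/my_old_stack_checkpoint") / "policies" / "default_policy"

env = gym.make("CartPole-v1")
rl_module = ModelV2ToRLModule(
    config=RLModuleConfig(
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"policy_checkpoint_dir": str(policy_dir)},
    ),
)
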
120 changes: 120 additions & 0 deletions rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py
@@ -0,0 +1,120 @@
import pathlib

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleConfig, RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import batch


if __name__ == "__main__":
# Configure and train an old stack default ModelV2.
config = (
PPOConfig()
# Old API stack.
.api_stack(
enable_env_runner_and_connector_v2=False,
enable_rl_module_and_learner=False,
)
.environment("CartPole-v1")
.training(
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
)
)
algo_old_stack = config.build()

min_return_old_stack = 100.0
while True:
results = algo_old_stack.train()
print(results)
if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack:
print(
f"Reached episode return of {min_return_old_stack} -> stopping "
"old API stack training."
)
break

checkpoint = algo_old_stack.save()
policy_path = (
pathlib.Path(checkpoint.checkpoint.path) / "policies" / "default_policy"
)
assert policy_path.is_dir()
algo_old_stack.stop()

print("done")

# Move the old API stack (trained) ModelV2 into the new API stack's RLModule.
# Run a simple CartPole inference experiment.
env = gym.make("CartPole-v1", render_mode="human")
rl_module = ModelV2ToRLModule(
config=RLModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"policy_checkpoint_dir": policy_path},
),
)

obs, _ = env.reset()
env.render()
done = False
episode_return = 0.0
while not done:
output = rl_module.forward_inference({"obs": torch.from_numpy(batch([obs]))})
action_logits = output["action_dist_inputs"].detach().numpy()[0]
action = np.argmax(action_logits)
obs, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_return += reward
env.render()

print(f"Ran episode with trained ModelV2: return={episode_return}")

# Continue training with the (checkpointed) ModelV2.

# We change the original (old API stack) `config` into a new API stack one:
config = config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
).rl_module(
rl_module_spec=RLModuleSpec(
module_class=ModelV2ToRLModule,
model_config_dict={"policy_checkpoint_dir": policy_path},
),
)

# Build the new stack algo.
algo_new_stack = config.build()

# Train until a higher return.
min_return_new_stack = 450.0
passed = False
for i in range(50):
results = algo_new_stack.train()
print(results)
# Make sure that the model's weights from the old API stack training
# were properly transferred to the new API RLModule wrapper. Thus, even
# after only one iteration of new stack training, we already expect the
# return to be higher than it was at the end of the old stack training.
assert results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack
if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_new_stack:
print(
f"Reached episode return of {min_return_new_stack} -> stopping "
"new API stack training."
)
passed = True
break

if not passed:
raise ValueError(
"Continuing training on the new stack did not succeed! Last return: "
f"{results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
)
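
Editor's note, a possible follow-up not shown in the commit: once the new API stack run has passed, the algorithm can be checkpointed like any other, reusing the same save() / .checkpoint.path pattern the script applies to the old stack algo above.

# Persist the new API stack state for later use, then shut the algo down.
new_checkpoint = algo_new_stack.save()
print("New API stack checkpoint written to:", new_checkpoint.checkpoint.path)
algo_new_stack.stop()
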
2 changes: 0 additions & 2 deletions rllib/tuned_examples/bc/cartpole_recording.py
@@ -31,11 +31,9 @@
}
)
.training(
gamma=0.99,
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
use_kl_loss=True,
)
.evaluation(
evaluation_num_env_runners=1,