[RLlib] Cleanup examples folder 17: Add example for custom RLModule with an LSTM. #46276

Merged
25 changes: 16 additions & 9 deletions rllib/BUILD
@@ -3104,15 +3104,6 @@ py_test(

# subdirectory: rl_modules/
# ....................................
py_test(
name = "examples/rl_modules/custom_rl_module",
main = "examples/rl_modules/custom_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/custom_rl_module.py"],
args = ["--enable-new-api-stack", "--stop-iters=3"],
)

py_test(
name = "examples/rl_modules/action_masking_rlm",
main = "examples/rl_modules/action_masking_rlm.py",
@@ -3130,6 +3121,22 @@ py_test(
srcs = ["examples/rl_modules/autoregressive_actions_rlm.py"],
args = ["--enable-new-api-stack"],
)
py_test(
name = "examples/rl_modules/custom_cnn_rl_module",
main = "examples/rl_modules/custom_cnn_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/custom_cnn_rl_module.py"],
args = ["--enable-new-api-stack", "--stop-iters=3"],
)
py_test(
name = "examples/rl_modules/custom_lstm_rl_module",
main = "examples/rl_modules/custom_lstm_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/custom_lstm_rl_module.py"],
args = ["--as-test", "--enable-new-api-stack"],
)

#@OldAPIStack @HybridAPIStack
py_test(
192 changes: 192 additions & 0 deletions rllib/examples/rl_modules/classes/lstm_containing_rlm.py
@@ -0,0 +1,192 @@
from typing import Any

import numpy as np

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchCategorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class LSTMContainingRLModule(TorchRLModule):
"""An example TorchRLModule that contains an LSTM layer.

.. testcode::

import numpy as np
import gymnasium as gym
import torch
import tree  # `pip install dm_tree`

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModuleConfig

B = 10 # batch size
T = 5 # seq len
f = 25 # feature dim
CELL = 32 # LSTM cell size

# Construct the RLModule.
rl_module_config = RLModuleConfig(
observation_space=gym.spaces.Box(-1.0, 1.0, (f,), np.float32),
action_space=gym.spaces.Discrete(4),
model_config_dict={"lstm_cell_size": CELL}
)
my_net = LSTMContainingRLModule(rl_module_config)

# Create some dummy input.
obs = torch.from_numpy(
    np.random.random_sample(size=(B, T, f)).astype(np.float32)
)
state_in = my_net.get_initial_state()
# Repeat state_in across batch.
state_in = tree.map_structure(
lambda s: torch.from_numpy(s).unsqueeze(0).repeat(B, 1), state_in
)
input_dict = {
Columns.OBS: obs,
Columns.STATE_IN: state_in,
}

# Run through all 3 forward passes.
print(my_net.forward_inference(input_dict))
print(my_net.forward_exploration(input_dict))
print(my_net.forward_train(input_dict))

# Print out the number of parameters.
num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
print(f"num params = {num_all_params}")
"""

@override(TorchRLModule)
def setup(self):
"""Use this method to create all the model components that you require.

Feel free to access the following useful properties in this class:
- `self.config.model_config_dict`: The config dict for this RLModule class,
which should contain flexible settings, for example: {"hiddens": [256, 256]}.
- `self.config.observation|action_space`: The observation and action space that
this RLModule is subject to. Note that the observation space might not be the
exact space coming from your env, but that it might have already gone through
preprocessing by a connector pipeline (for example, flattening,
frame-stacking, or mean/std-filtering).
"""
# Assume a simple Box(1D) tensor as input shape.
in_size = self.config.observation_space.shape[0]

# Get the LSTM cell size from our RLModuleConfig's (self.config)
# `model_config_dict` property:
self._lstm_cell_size = self.config.model_config_dict.get("lstm_cell_size", 256)
self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=False)
in_size = self._lstm_cell_size

# Build a sequential stack.
layers = []
# Get the dense layer pre-stack configuration from the same config dict.
dense_layers = self.config.model_config_dict.get("dense_layers", [128, 128])
for out_size in dense_layers:
# Dense layer.
layers.append(nn.Linear(in_size, out_size))
# ReLU activation.
layers.append(nn.ReLU())
in_size = out_size

self._fc_net = nn.Sequential(*layers)

# Logits layer (no bias, no activation).
self._logits = nn.Linear(in_size, self.config.action_space.n)
# Single-node value layer.
self._values = nn.Linear(in_size, 1)

@override(TorchRLModule)
def get_initial_state(self) -> Any:
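# Note: RLlib typically calls this on the EnvRunner side at the start of a
# new episode; connector pipelines then batch the returned tensors into
# `Columns.STATE_IN` for the forward passes below.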
return {
"h": np.zeros(shape=(self._lstm_cell_size,), dtype=np.float32),
"c": np.zeros(shape=(self._lstm_cell_size,), dtype=np.float32),
}

@override(TorchRLModule)
def _forward_inference(self, batch, **kwargs):
# Compute the basic 1D feature tensor (inputs to policy- and value-heads).
_, state_out, logits = self._compute_features_state_out_and_logits(batch)

# Return logits as ACTION_DIST_INPUTS (categorical distribution).
# Note that the default `GetActions` connector piece (in the EnvRunner) will
# take care of argmax-"sampling" from the logits to yield the inference (greedy)
# action.
return {
Columns.STATE_OUT: state_out,
Columns.ACTION_DIST_INPUTS: logits,
}

@override(TorchRLModule)
def _forward_exploration(self, batch, **kwargs):
# Exact same as `_forward_inference`.
# Note that the default `GetActions` connector piece (in the EnvRunner) will
# take care of stochastic sampling from the Categorical defined by the logits
# to yield the exploration action.
return self._forward_inference(batch, **kwargs)

@override(TorchRLModule)
def _forward_train(self, batch, **kwargs):
# Compute the basic 1D feature tensor (inputs to policy- and value-heads).
features, state_out, logits = self._compute_features_state_out_and_logits(batch)
# Besides the action logits, we also have to return value predictions here
# (to be used inside the loss function).
values = self._values(features).squeeze(-1)
return {
Columns.STATE_OUT: state_out,
Columns.ACTION_DIST_INPUTS: logits,
Columns.VF_PREDS: values,
}

# TODO (sven): We still need to define the distribution to use here, even
# though we have a pretty standard action space (Discrete), which should simply
# always map to a categorical dist. by default.
@override(TorchRLModule)
def get_inference_action_dist_cls(self):
return TorchCategorical

@override(TorchRLModule)
def get_exploration_action_dist_cls(self):
return TorchCategorical

@override(TorchRLModule)
def get_train_action_dist_cls(self):
return TorchCategorical

# TODO (sven): In order for this RLModule to work with PPO, we must define
# our own `_compute_values()` method. This would become more obvious if we simply
# subclassed the `PPOTorchRLModule` directly here (which we didn't do for
# simplicity and to keep some generality). We might even get rid of algo-
# specific RLModule subclasses altogether in the future and replace them
# by mere algo-specific APIs (w/o any actual implementations).
def _compute_values(self, batch):
obs = batch[Columns.OBS]
state_in = batch[Columns.STATE_IN]
h, c = state_in["h"], state_in["c"]
# Unsqueeze the layer dim (we only have 1 LSTM layer).
features, _ = self._lstm(
obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major
(h.unsqueeze(0), c.unsqueeze(0)),
)
# Make batch-major again.
features = features.permute(1, 0, 2)
# Push through our FC net.
features = self._fc_net(features)
return self._values(features).squeeze(-1)

def _compute_features_state_out_and_logits(self, batch):
obs = batch[Columns.OBS]
state_in = batch[Columns.STATE_IN]
h, c = state_in["h"], state_in["c"]
# Unsqueeze the layer dim (we only have 1 LSTM layer).
features, (h, c) = self._lstm(
obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major
(h.unsqueeze(0), c.unsqueeze(0)),
)
# Make batch-major again.
features = features.permute(1, 0, 2)
# Push through our FC net.
features = self._fc_net(features)
logits = self._logits(features)
return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits
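For context on how such a module gets plugged into training, here is a minimal, hedged sketch of configuring PPO with it through a `SingleAgentRLModuleSpec`. The exact builder calls can vary slightly across Ray versions, and the env and cell size below are illustrative assumptions, not part of this PR:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.lstm_containing_rlm import (
    LSTMContainingRLModule,
)

config = (
    PPOConfig()
    # Activate the new API stack (RLModule + Learner, EnvRunner + ConnectorV2).
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Illustrative env; a partially observable task shows off the LSTM better.
    .environment("CartPole-v1")
    .rl_module(
        rl_module_spec=SingleAgentRLModuleSpec(
            module_class=LSTMContainingRLModule,
            model_config_dict={"lstm_cell_size": 256, "dense_layers": [128, 128]},
        ),
    )
)
algo = config.build()
print(algo.train())

Running `rllib/examples/rl_modules/custom_lstm_rl_module.py --as-test --enable-new-api-stack` (the args registered in the BUILD file above) exercises the same setup end to end.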
rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py
@@ -9,7 +9,6 @@
from ray.rllib.models.torch.torch_distributions import TorchCategorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

torch, nn = try_import_torch()

@@ -24,6 +23,31 @@ class TinyAtariCNN(TorchRLModule):
and n 1x1 filters, where n is the number of actions in the (discrete) action space.
Simple reshaping (no flattening or extra linear layers necessary) leads to the
action logits, which can be used directly inside a distribution or loss.

.. testcode::

import numpy as np
import gymnasium as gym
import torch

from ray.rllib.core.rl_module.rl_module import RLModuleConfig

rl_module_config = RLModuleConfig(
observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
action_space=gym.spaces.Discrete(4),
)
my_net = TinyAtariCNN(rl_module_config)

B = 10
w = 42
h = 42
c = 4
data = torch.from_numpy(
np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
)
print(my_net.forward_inference({"obs": data}))
print(my_net.forward_exploration({"obs": data}))
print(my_net.forward_train({"obs": data}))

num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
print(f"num params = {num_all_params}")

"""

@override(TorchRLModule)
@@ -124,8 +148,8 @@ def _forward_train(self, batch, **kwargs):
# simplicity and to keep some generality). We might even get rid of algo-
# specific RLModule subclasses altogether in the future and replace them
# by mere algo-specific APIs (w/o any actual implementations).
def _compute_values(self, batch, device):
obs = convert_to_torch_tensor(batch[Columns.OBS], device=device)
def _compute_values(self, batch):
obs = batch[Columns.OBS]
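# The batch now arrives as torch tensors already on the Learner's device,
# so the previous `convert_to_torch_tensor(..., device=device)` call (and
# the `device` argument) are no longer needed.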
features = self._base_cnn_stack(obs.permute(0, 3, 1, 2))
features = torch.squeeze(features, dim=[-1, -2])
return self._values(features).squeeze(-1)
@@ -156,29 +180,3 @@ def get_exploration_action_dist_cls(self):
@override(RLModule)
def get_inference_action_dist_cls(self):
return TorchCategorical


if __name__ == "__main__":
import numpy as np
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleConfig

rl_module_config = RLModuleConfig(
observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
action_space=gym.spaces.Discrete(4),
)
my_net = TinyAtariCNN(rl_module_config)

B = 10
w = 42
h = 42
c = 4
data = torch.from_numpy(
np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
)
print(my_net.forward_inference({"obs": data}))
print(my_net.forward_exploration({"obs": data}))
print(my_net.forward_train({"obs": data}))

num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
print(f"num params = {num_all_params}")
rllib/examples/rl_modules/custom_cnn_rl_module.py
@@ -1,9 +1,9 @@
"""Example of implementing and configuring a custom (torch) RLModule.
"""Example of implementing and configuring a custom (torch) CNN containing RLModule.

This example:
- demonstrates how you can subclass the TorchRLModule base class and setup your
own neural network architecture by overriding `setup()`.
- how to override the 3 forward methods: `_forward_inference()`,
- demonstrates how you can subclass the TorchRLModule base class and set up your
own CNN-stack architecture by overriding the `setup()` method.
- shows how to override the 3 forward methods: `_forward_inference()`,
`_forward_exploration()`, and `_forward_train()` to implement your own custom forward
logic(s). You will also learn when each of these 3 methods is called by RLlib or
by the users of your RLModule.
@@ -56,7 +56,7 @@

from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn import TinyAtariCNN
from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,