[RLlib] Add examples and docs for Catalog. (#33898) (#34267)
ArturNiederfahrenhorst authored Apr 18, 2023
1 parent f40c4c8 commit 174abe6
Showing 15 changed files with 522 additions and 14 deletions.
1 change: 1 addition & 0 deletions doc/source/_includes/rllib/rlmodules_rollout.rst
@@ -0,0 +1 @@
.. note:: This doc relates to the `RLModule API <rllib-rlmodule.html>`__ and is therefore experimental.
3 changes: 2 additions & 1 deletion doc/source/_toc.yml
@@ -289,7 +289,7 @@ parts:
   title: Ray RLlib
   sections:
     - file: rllib/rllib-training
-    - file: rllib/core-concepts
+    - file: rllib/key-concepts
     - file: rllib/rllib-env
     - file: rllib/rllib-algorithms
     - file: rllib/user-guides
@@ -301,6 +301,7 @@
     - file: rllib/rllib-sample-collection
     - file: rllib/rllib-replay-buffers
     - file: rllib/rllib-offline
+    - file: rllib/rllib-catalogs
     - file: rllib/rllib-connector
     - file: rllib/rllib-rlmodule
     - file: rllib/rllib-fault-tolerance
137 changes: 137 additions & 0 deletions doc/source/rllib/doc_code/catalog_guide.py
@@ -0,0 +1,137 @@
# flake8: noqa
"""
This file holds several examples of the Catalog API that are used in the
Catalog guide.
"""


# 1) Basic interaction with Catalogs in RLlib.
# __sphinx_doc_basic_interaction_begin__
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
# __sphinx_doc_basic_interaction_end__
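
# Quick sanity check (a sketch; assumes RLlib's default torch mappings for
# Discrete action spaces): CartPole's Discrete(2) space should yield a
# categorical action distribution class.
from gymnasium.spaces import Discrete

assert isinstance(env.action_space, Discrete)
print(action_dist_class)  # Expect a (Torch) categorical distribution class.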


# 2) Basic workflow that uses the Catalog base class and RLlib's ModelConfigs
# to build models and an action distribution, and then steps through an
# environment.

# __sphinx_doc_modelsworkflow_begin__
import gymnasium as gym
import torch

from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
# Therefore, we need `env.action_space.n` action distribution inputs.
expected_action_dist_input_dims = (env.action_space.n,)
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Build a suitable head model for the action distribution.
head_config = MLPHeadConfig(
    input_dims=catalog.latent_dims, output_dims=expected_action_dist_input_dims
)
head = head_config.build(framework="torch")
# Now we are ready to interact with the environment.
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])
# __sphinx_doc_modelsworkflow_end__
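
# Sketch: extend the single step above into a full episode rollout, reusing
# the same input batch format (no state or sequence lengths needed, since the
# default encoder is not recurrent).
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    input_batch = {
        SampleBatch.OBS: torch.Tensor([obs]),
        STATE_IN: None,
        SampleBatch.SEQ_LENS: None,
    }
    encoding = encoder(input_batch)[ENCODER_OUT]
    action_dist = action_dist_class.from_logits(head(encoding))
    obs, reward, terminated, truncated, info = env.step(
        action_dist.sample().numpy()[0]
    )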


# 3) Basic workflow that uses the PPOCatalog to build models and an action
# distribution, and then steps through an environment.

# __sphinx_doc_ppo_models_begin__
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Now we are ready to interact with the environment.
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])
# __sphinx_doc_ppo_models_end__
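
# Sketch: analogous to the policy head, the PPOCatalog can also build a value
# function head (assumes `build_vf_head` and the CRITIC key mirror the ACTOR
# usage above). The critic branch of the actor-critic encoder feeds it.
from ray.rllib.core.models.base import CRITIC

vf_head = catalog.build_vf_head(framework="torch")
value_estimate = vf_head(encoder(input_batch)[ENCODER_OUT][CRITIC])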


# 4) Demonstrates how to specify a custom Catalog for an RLModule through an
# AlgorithmConfig.

# __sphinx_doc_algo_configs_begin__
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)


config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .rl_module(_enable_rl_module_api=True)
)

# Specify the catalog to use for the PPORLModule.
config = config.rl_module(
    rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog)
)
# This is how RLlib constructs a PPORLModule.
# It will print "Hi from within PPORLModule!".
ppo = config.build()
# __sphinx_doc_algo_configs_end__
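
# Sketch: the built Algorithm trains as usual; every RLModule construction
# (e.g., on rollout workers) triggers the greeting printed by MyPPOCatalog.
results = ppo.train()
# Assumes the standard result dict key for mean episode return.
print(results["episode_reward_mean"])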
16 changes: 16 additions & 0 deletions doc/source/rllib/images/catalog/catalog_and_rlm_diagram.svg
