Skip to content

Commit

Permalink
[RLlib] Update autoregressive actions example. (ray-project#47829)
Browse files (browse the repository at this point in the history)
Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent 554195d commit 041874d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
Original file line number | Diff line number | Diff line change
Expand Up @@ -3260,7 +3260,7 @@ py_test(
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/autoregressive_action_dist.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
args = ["--as-test", "--framework=torch", "--stop-reward=-0.012", "--num-cpus=4"]
)

#@OldAPIStack
Expand Down
4 changes: 3 additions & 1 deletion rllib/examples/autoregressive_action_dist.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -147,7 +147,9 @@ def get_cli_args():
config = (
get_trainable_cls(args.run)
.get_default_config()
.environment(CorrelatedActionsEnv)
# Batch-norm models have not been migrated to the RL Module API yet.
.api_stack(enable_rl_module_and_learner=False)
.environment(AutoRegressiveActionEnv)
.framework(args.framework)
.training(gamma=0.5)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
Expand Down
12 changes: 7 additions & 5 deletions rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
import abc
from abc import abstractmethod
from typing import Any, Dict
from typing import Dict

from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
Expand Down Expand Up @@ -233,10 +234,11 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
return outs

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, TensorType], embeddings=None):
# Encoder forward pass to get `embeddings`, if necessary.
if embeddings is None:
embeddings = self.encoder(batch)[ENCODER_OUT]
def compute_values(self, batch: Dict[str, TensorType]):

# Encoder forward pass.
encoder_outs = self.encoder(batch)[ENCODER_OUT]

# Value head forward pass.
vf_out = self.vf(embeddings)
# Squeeze out last dimension (single node value head).
Expand Down

0 comments on commit 041874d

Please sign in to comment.