Skip to content

Commit

Permalink
[RLlib] New API stack: (Multi)RLModule overhaul vol 05 (deprecate Spe…
Browse files Browse the repository at this point in the history
…cs, SpecDict, TensorSpec). (ray-project#47915)

Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 8ebc194 commit ce265dd
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 155 deletions.
25 changes: 0 additions & 25 deletions rllib/algorithms/ppo/ppo_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import List

from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
Expand Down Expand Up @@ -49,30 +48,6 @@ def get_initial_state(self) -> dict:
else:
return {}

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
def get_non_inference_attributes(self) -> List[str]:
Expand Down
112 changes: 1 addition & 111 deletions rllib/core/rl_module/rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,11 @@
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.models.specs.checker import (
check_input_specs,
check_output_specs,
convert_to_canonical_format,
)
from ray.rllib.models.distributions import Distribution
from ray.rllib.utils.annotations import (
ExperimentalAPI,
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import Deprecated
Expand Down Expand Up @@ -476,9 +470,7 @@ def __init__(
framework=self.framework
)

# Make sure, `setup()` is only called once, no matter what. In some cases
# of multiple inheritance (and with our __post_init__ functionality in place,
# this might get called twice.
# Make sure, `setup()` is only called once, no matter what.
if hasattr(self, "_is_setup") and self._is_setup:
raise RuntimeError(
"`RLModule.setup()` called twice within your RLModule implementation "
Expand Down Expand Up @@ -568,8 +560,6 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
return {}

@check_input_specs("_input_specs_inference")
@check_output_specs("_output_specs_inference")
def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
Expand Down Expand Up @@ -599,8 +589,6 @@ def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
return self._forward(batch, **kwargs)

@check_input_specs("_input_specs_exploration")
@check_output_specs("_output_specs_exploration")
def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
Expand Down Expand Up @@ -630,8 +618,6 @@ def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any
"""
return self._forward(batch, **kwargs)

@check_input_specs("_input_specs_train")
@check_output_specs("_output_specs_train")
def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during training called from the learner.
Expand Down Expand Up @@ -743,48 +729,6 @@ def get_ctor_args_and_kwargs(self):
}, # **kwargs
)

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_inference(self) -> SpecType:
"""Returns the output specs of the `forward_inference()` method.
Override this method to customize the output specs of the inference call.
The default implementation requires the `forward_inference()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> SpecType:
"""Returns the output specs of the `forward_exploration()` method.
Override this method to customize the output specs of the exploration call.
The default implementation requires the `forward_exploration()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
return {}

def input_specs_inference(self) -> SpecType:
"""Returns the input specs of the forward_inference method."""
return self._default_input_specs()

def input_specs_exploration(self) -> SpecType:
"""Returns the input specs of the forward_exploration method."""
return self._default_input_specs()

def input_specs_train(self) -> SpecType:
"""Returns the input specs of the forward_train method."""
return self._default_input_specs()

def _default_input_specs(self) -> SpecType:
"""Returns the default input specs."""
return [Columns.OBS]

def as_multi_rl_module(self) -> "MultiRLModule":
"""Returns a multi-agent wrapper around this module."""
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
Expand Down Expand Up @@ -846,57 +790,3 @@ def input_specs_train(self) -> SpecType:
def _default_input_specs(self) -> SpecType:
"""Returns the default input specs."""
return [Columns.OBS]


@Deprecated(
old="RLModule(config=[RLModuleConfig object])",
new="RLModule(observation_space=.., action_space=.., inference_only=.., "
"model_config=.., catalog_class=..)",
error=False,
)
@dataclass
class RLModuleConfig:
observation_space: gym.Space = None
action_space: gym.Space = None
inference_only: bool = False
learner_only: bool = False
model_config_dict: Dict[str, Any] = field(default_factory=dict)
catalog_class: Type["Catalog"] = None

def get_catalog(self) -> Optional["Catalog"]:
if self.catalog_class is not None:
return self.catalog_class(
observation_space=self.observation_space,
action_space=self.action_space,
model_config_dict=self.model_config_dict,
)
return None

def to_dict(self):
catalog_class_path = (
serialize_type(self.catalog_class) if self.catalog_class else ""
)
return {
"observation_space": gym_space_to_dict(self.observation_space),
"action_space": gym_space_to_dict(self.action_space),
"inference_only": self.inference_only,
"learner_only": self.learner_only,
"model_config_dict": self.model_config_dict,
"catalog_class_path": catalog_class_path,
}

@classmethod
def from_dict(cls, d: Dict[str, Any]):
catalog_class = (
None
if d["catalog_class_path"] == ""
else deserialize_type(d["catalog_class_path"])
)
return cls(
observation_space=gym_space_from_dict(d["observation_space"]),
action_space=gym_space_from_dict(d["action_space"]),
inference_only=d["inference_only"],
learner_only=d["learner_only"],
model_config_dict=d["model_config_dict"],
catalog_class=catalog_class,
)
18 changes: 0 additions & 18 deletions rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
Expand Down Expand Up @@ -91,23 +90,6 @@ def setup(self):
)
self.vf = vf_config.build(framework=self.framework)

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTIONS]

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS, Columns.ACTIONS, Columns.ACTION_LOGP]

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.ACTION_DIST_INPUTS,
Columns.ACTIONS,
Columns.ACTION_LOGP,
Columns.VF_PREDS,
]

@abstractmethod
def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""Computes the policy outputs given a batch of data.
Expand Down
2 changes: 1 addition & 1 deletion rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ray.air.constants import TRAINING_ITERATION
from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR
from ray.rllib.common import SupportedFileType
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core import DEFAULT_MODULE_ID, Columns
from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
from ray.rllib.train import load_experiments_from_file
from ray.rllib.utils.annotations import OldAPIStack
Expand Down

0 comments on commit ce265dd

Please sign in to comment.