[RLlib] Implemented the new spec checker. (ray-project#31226)
Signed-off-by: tmynn <[email protected]>
kouroshHakha authored and tamohannes committed Jan 25, 2023
1 parent 96bc324 commit 80ebf5e
Showing 19 changed files with 760 additions and 395 deletions.
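In brief, the commit replaces the single check_specs decorator with separate check_input_specs and check_output_specs decorators (imported from ray.rllib.models.specs.checker) and renames ModelSpec to SpecDict, as the hunks below show. The following is a minimal sketch of the new usage pattern, pieced together from the encoder.py and rl_module.py changes; the class, dimensions, and example call are invented for illustration and are not part of the commit.

import torch
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.specs_torch import TorchTensorSpec
from ray.rllib.models.specs.checker import check_input_specs, check_output_specs

class ToyEncoder:
    # Hypothetical model, not part of this commit; it only illustrates the pattern.
    def __init__(self, input_dim=4, output_dim=8):
        # The decorators below look up these specs on the instance by attribute name.
        self._input_spec = SpecDict({"obs": TorchTensorSpec("b, h", h=input_dim)})
        self._output_spec = SpecDict({"embedding": TorchTensorSpec("b, h", h=output_dim)})
        self._net = torch.nn.Linear(input_dim, output_dim)

    @check_input_specs("_input_spec")
    @check_output_specs("_output_spec")
    def forward(self, input_dict):
        # The input dict is checked against self._input_spec before this body runs;
        # the returned dict is checked against self._output_spec afterwards.
        return {"embedding": self._net(input_dict["obs"])}

# Example call (assumes the checker accepts a plain dict, as in the encoder hunks below):
# out = ToyEncoder().forward({"obs": torch.randn(10, 4)})

The decorators reference the spec by instance-attribute name rather than by value, which is why the rl_module.py hunks point at attributes such as "_input_specs_inference".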
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -1916,10 +1916,10 @@ py_test(

# test ModelSpec
py_test(
name = "test_model_spec",
name = "test_spec_dict",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["models/specs/tests/test_model_spec.py"]
srcs = ["models/specs/tests/test_spec_dict.py"]
)

# test TorchVectorEncoder
20 changes: 10 additions & 10 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -8,7 +8,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.specs.specs_dict import ModelSpec
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.specs_torch import TorchTensorSpec
from ray.rllib.models.torch.torch_distributions import (
TorchCategorical,
@@ -227,12 +227,12 @@ def get_initial_state(self) -> NestedDict:
return {}

@override(RLModule)
def input_specs_inference(self) -> ModelSpec:
def input_specs_inference(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_inference(self) -> ModelSpec:
return ModelSpec({SampleBatch.ACTION_DIST: TorchDeterministic})
def output_specs_inference(self) -> SpecDict:
return SpecDict({SampleBatch.ACTION_DIST: TorchDeterministic})

@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
@@ -255,7 +255,7 @@ def input_specs_exploration(self):
return self.shared_encoder.input_spec()

@override(RLModule)
def output_specs_exploration(self) -> ModelSpec:
def output_specs_exploration(self) -> SpecDict:
specs = {SampleBatch.ACTION_DIST: self.__get_action_dist_type()}
if self._is_discrete:
specs[SampleBatch.ACTION_DIST_INPUTS] = {
@@ -267,7 +267,7 @@ def output_specs_exploration(self) -> ModelSpec:
"scale": TorchTensorSpec("b, h", h=self.config.action_space.shape[0]),
}

return ModelSpec(specs)
return SpecDict(specs)

@override(RLModule)
def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
@@ -299,7 +299,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
return output

@override(RLModule)
def input_specs_train(self) -> ModelSpec:
def input_specs_train(self) -> SpecDict:
if self._is_discrete:
action_spec = TorchTensorSpec("b")
else:
@@ -310,12 +310,12 @@ def input_specs_train(self) -> ModelSpec:
spec_dict.update({SampleBatch.ACTIONS: action_spec})
if SampleBatch.OBS in spec_dict:
spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS]
spec = ModelSpec(spec_dict)
spec = SpecDict(spec_dict)
return spec

@override(RLModule)
def output_specs_train(self) -> ModelSpec:
spec = ModelSpec(
def output_specs_train(self) -> SpecDict:
spec = SpecDict(
{
SampleBatch.ACTION_DIST: self.__get_action_dist_type(),
SampleBatch.ACTION_LOGP: TorchTensorSpec("b", dtype=torch.float32),
22 changes: 12 additions & 10 deletions rllib/core/rl_module/encoder.py
@@ -7,7 +7,8 @@

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.checker import check_input_specs, check_output_specs
from ray.rllib.models.specs.specs_torch import TorchTensorSpec
from ray.rllib.models.torch.primitives import FCNet

@@ -66,12 +67,13 @@ def get_inital_state(self):
return []

def input_spec(self):
return ModelSpec()
return SpecDict()

def output_spec(self):
return ModelSpec()
return SpecDict()

@check_specs(input_spec="_input_spec", output_spec="_output_spec")
@check_input_specs("_input_spec")
@check_output_specs("_output_spec")
def forward(self, input_dict):
return self._forward(input_dict)

@@ -91,12 +93,12 @@ def __init__(self, config: FCConfig) -> None:
)

def input_spec(self):
return ModelSpec(
return SpecDict(
{SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)}
)

def output_spec(self):
return ModelSpec(
return SpecDict(
{"embedding": TorchTensorSpec("b, h", h=self.config.output_dim)}
)

@@ -125,7 +127,7 @@ def get_inital_state(self):

def input_spec(self):
config = self.config
return ModelSpec(
return SpecDict(
{
# bxt is just a name for better readability to indicate a padded batch
SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim),
@@ -142,7 +144,7 @@ def input_spec(self):

def output_spec(self):
config = self.config
return ModelSpec(
return SpecDict(
{
"embedding": TorchTensorSpec("bxt, h", h=config.output_dim),
"state_out": {
@@ -181,10 +183,10 @@ def __init__(self, config: EncoderConfig) -> None:
super().__init__(config)

def input_spec(self):
return ModelSpec()
return SpecDict()

def output_spec(self):
return ModelSpec()
return SpecDict()

def _forward(self, input_dict):
return input_dict
18 changes: 9 additions & 9 deletions rllib/core/rl_module/marl_module.py
@@ -5,7 +5,7 @@
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import override

from ray.rllib.models.specs.specs_dict import ModelSpec
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.core.rl_module import RLModule

@@ -243,32 +243,32 @@ def __getitem__(self, module_id: ModuleID) -> RLModule:
return self._rl_modules[module_id]

@override(RLModule)
def output_specs_train(self) -> ModelSpec:
def output_specs_train(self) -> SpecDict:
return self._get_specs_for_modules("output_specs_train")

@override(RLModule)
def output_specs_inference(self) -> ModelSpec:
def output_specs_inference(self) -> SpecDict:
return self._get_specs_for_modules("output_specs_inference")

@override(RLModule)
def output_specs_exploration(self) -> ModelSpec:
def output_specs_exploration(self) -> SpecDict:
return self._get_specs_for_modules("output_specs_exploration")

@override(RLModule)
def input_specs_train(self) -> ModelSpec:
def input_specs_train(self) -> SpecDict:
return self._get_specs_for_modules("input_specs_train")

@override(RLModule)
def input_specs_inference(self) -> ModelSpec:
def input_specs_inference(self) -> SpecDict:
return self._get_specs_for_modules("input_specs_inference")

@override(RLModule)
def input_specs_exploration(self) -> ModelSpec:
def input_specs_exploration(self) -> SpecDict:
return self._get_specs_for_modules("input_specs_exploration")

def _get_specs_for_modules(self, method_name: str) -> ModelSpec:
def _get_specs_for_modules(self, method_name: str) -> SpecDict:
"""Returns a ModelSpec from the given method_name for all modules."""
return ModelSpec(
return SpecDict(
{
module_id: getattr(module, method_name)()
for module_id, module in self._rl_modules.items()
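The _get_specs_for_modules helper above assembles one spec keyed by module ID from each sub-module's spec method. A rough sketch of the kind of structure this produces; the module IDs and dimensions here are invented for illustration and are not taken from the commit.

from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.specs_torch import TorchTensorSpec

# What input_specs_train() of a two-module multi-agent module might resolve to:
# one entry per module ID, each holding that module's own spec.
marl_input_specs = SpecDict(
    {
        "policy_0": SpecDict({"obs": TorchTensorSpec("b, h", h=4)}),
        "policy_1": SpecDict({"obs": TorchTensorSpec("b, h", h=6)}),
    }
)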
40 changes: 20 additions & 20 deletions rllib/core/rl_module/rl_module.py
@@ -11,7 +11,8 @@
OverrideToImplementCustomLogic_CallToSuperRecommended,
)

from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs
from ray.rllib.models.specs.typing import SpecType
from ray.rllib.models.specs.checker import check_input_specs, check_output_specs
from ray.rllib.models.distributions import Distribution
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.nested_dict import NestedDict
@@ -188,46 +189,45 @@ def get_initial_state(self) -> NestedDict:
return {}

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_inference(self) -> ModelSpec:
def output_specs_inference(self) -> SpecType:
"""Returns the output specs of the forward_inference method.
Override this method to customize the output specs of the inference call.
The default implementation requires the forward_inference to return a dict that
has `action_dist` key and its value is an instance of `Distribution`.
This assumption must always hold.
"""
return ModelSpec({"action_dist": Distribution})
return {"action_dist": Distribution}

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> ModelSpec:
def output_specs_exploration(self) -> SpecType:
"""Returns the output specs of the forward_exploration method.
Override this method to customize the output specs of the exploration call.
The default implementation requires the forward_exploration to return a dict
that has `action_dist` key and its value is an instance of
`Distribution`. This assumption must always hold.
"""
return ModelSpec({"action_dist": Distribution})
return {"action_dist": Distribution}

def output_specs_train(self) -> ModelSpec:
def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
return ModelSpec()
return {}

def input_specs_inference(self) -> ModelSpec:
def input_specs_inference(self) -> SpecType:
"""Returns the input specs of the forward_inference method."""
return ModelSpec()
return {}

def input_specs_exploration(self) -> ModelSpec:
def input_specs_exploration(self) -> SpecType:
"""Returns the input specs of the forward_exploration method."""
return ModelSpec()
return {}

def input_specs_train(self) -> ModelSpec:
def input_specs_train(self) -> SpecType:
"""Returns the input specs of the forward_train method."""
return ModelSpec()
return {}

@check_specs(
input_spec="_input_specs_inference", output_spec="_output_specs_inference"
)
@check_input_specs("_input_specs_inference")
@check_output_specs("_output_specs_inference")
def forward_inference(self, batch: SampleBatchType, **kwargs) -> Mapping[str, Any]:
"""Forward-pass during evaluation, called from the sampler. This method should
not be overridden. Instead, override the _forward_inference method.
@@ -247,9 +247,8 @@ def forward_inference(self, batch: SampleBatchType, **kwargs) -> Mapping[str, An
def _forward_inference(self, batch: NestedDict, **kwargs) -> Mapping[str, Any]:
"""Forward-pass during evaluation. See forward_inference for details."""

@check_specs(
input_spec="_input_specs_exploration", output_spec="_output_specs_exploration"
)
@check_input_specs("_input_specs_exploration")
@check_output_specs("_output_specs_exploration")
def forward_exploration(
self, batch: SampleBatchType, **kwargs
) -> Mapping[str, Any]:
@@ -271,7 +270,8 @@ def forward_exploration(
def _forward_exploration(self, batch: NestedDict, **kwargs) -> Mapping[str, Any]:
"""Forward-pass during exploration. See forward_exploration for details."""

@check_specs(input_spec="_input_specs_train", output_spec="_output_specs_train")
@check_input_specs("_input_specs_train")
@check_output_specs("_output_specs_train")
def forward_train(
self,
batch: SampleBatchType,
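The rl_module.py hunk above also loosens the return type of the spec methods from ModelSpec to SpecType, and the defaults now return plain dicts (e.g. {"action_dist": Distribution}) instead of wrapping them in a spec object. A hedged sketch of how a custom module might override these methods under the new typing; the class name and the observation spec are illustrative, not taken from the commit.

from ray.rllib.models.distributions import Distribution
from ray.rllib.models.specs.typing import SpecType
from ray.rllib.models.specs.specs_torch import TorchTensorSpec
from ray.rllib.policy.sample_batch import SampleBatch

class MyModuleSpecs:
    # Illustrative only: the spec methods as a custom RLModule might override them.
    def input_specs_inference(self) -> SpecType:
        # A plain dict can be returned here, mirroring the defaults above.
        return {SampleBatch.OBS: TorchTensorSpec("b, h", h=4)}

    def output_specs_inference(self) -> SpecType:
        # Types such as Distribution are usable as specs, per the default implementation.
        return {"action_dist": Distribution}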
31 changes: 16 additions & 15 deletions rllib/core/testing/tf/bc_module.py
@@ -5,7 +5,8 @@

from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
from ray.rllib.models.specs.specs_dict import ModelSpec
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.typing import SpecType
from ray.rllib.models.specs.specs_tf import TFTensorSpecs
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
@@ -33,28 +34,28 @@ def __init__(
self._output_dim = output_dim

@override(RLModule)
def input_specs_exploration(self) -> ModelSpec:
return ModelSpec(self._default_inputs())
def input_specs_exploration(self) -> SpecType:
return self._default_inputs()

@override(RLModule)
def input_specs_inference(self) -> ModelSpec:
return ModelSpec(self._default_inputs())
def input_specs_inference(self) -> SpecType:
return self._default_inputs()

@override(RLModule)
def input_specs_train(self) -> ModelSpec:
return ModelSpec(self._default_inputs())
def input_specs_train(self) -> SpecType:
return self._default_inputs()

@override(RLModule)
def output_specs_exploration(self) -> ModelSpec:
return ModelSpec(self._default_outputs())
def output_specs_exploration(self) -> SpecType:
return self._default_outputs()

@override(RLModule)
def output_specs_inference(self) -> ModelSpec:
return ModelSpec(self._default_outputs())
def output_specs_inference(self) -> SpecType:
return self._default_outputs()

@override(RLModule)
def output_specs_train(self) -> ModelSpec:
return ModelSpec(self._default_outputs())
def output_specs_train(self) -> SpecType:
return self._default_outputs()

@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
@@ -104,9 +105,9 @@ def from_model_config(

return cls(**config)

def _default_inputs(self) -> ModelSpec:
def _default_inputs(self) -> SpecDict:
obs_dim = self._input_dim
return {"obs": TFTensorSpecs("b, do", do=obs_dim)}

def _default_outputs(self) -> ModelSpec:
def _default_outputs(self) -> SpecDict:
return {"action_dist": tfp.distributions.Distribution}