Skip to content

Commit

Permalink
[RLlib] Quick-fix for default RLModules in combination with a user-pr…
Browse files Browse the repository at this point in the history
…ovided config-sub-dict (instead of a full `DefaultModelConfig`). (ray-project#47965)

Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 189477f commit 7d2dbce
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/torch/bc_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BCTorchRLModule(TorchRLModule):
@override(RLModule)
def setup(self):
# __sphinx_doc_begin__
# Build models from catalog
# Build models from catalog.
self.encoder = self.catalog.build_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)

Expand Down
10 changes: 7 additions & 3 deletions rllib/core/rl_module/multi_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass, field
import dataclasses
import logging
import pprint
from typing import (
Expand Down Expand Up @@ -666,7 +666,11 @@ def build(self, module_id: Optional[ModuleID] = None) -> RLModule:
observation_space=self.observation_space,
action_space=self.action_space,
inference_only=self.inference_only,
model_config=self.model_config,
model_config=(
dataclasses.asdict(self.model_config)
if dataclasses.is_dataclass(self.model_config)
else self.model_config
),
rl_module_specs=self.rl_module_specs,
)
# Older custom model might still require the old `MultiRLModuleConfig` under
Expand Down Expand Up @@ -859,7 +863,7 @@ def get_rl_module_config(self):
"module2: [RLModuleSpec], ..}, inference_only=..)",
error=False,
)
@dataclass
@dataclasses.dataclass
class MultiRLModuleConfig:
inference_only: bool = False
modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion rllib/core/rl_module/rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def build(self) -> "RLModule":
observation_space=self.observation_space,
action_space=self.action_space,
inference_only=self.inference_only,
model_config=self.model_config,
model_config=self._get_model_config(),
catalog_class=self.catalog_class,
)
# Older custom model might still require the old `RLModuleConfig` under
Expand Down

0 comments on commit 7d2dbce

Please sign in to comment.