diff --git a/rllib/algorithms/bc/torch/bc_torch_rl_module.py b/rllib/algorithms/bc/torch/bc_torch_rl_module.py index a547047d7f41..d06c323b124e 100644 --- a/rllib/algorithms/bc/torch/bc_torch_rl_module.py +++ b/rllib/algorithms/bc/torch/bc_torch_rl_module.py @@ -11,7 +11,7 @@ class BCTorchRLModule(TorchRLModule): @override(RLModule) def setup(self): # __sphinx_doc_begin__ - # Build models from catalog + # Build models from catalog. self.encoder = self.catalog.build_encoder(framework=self.framework) self.pi = self.catalog.build_pi_head(framework=self.framework) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index a447d084533b..43eddb909dea 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass, field +import dataclasses import logging import pprint from typing import ( @@ -666,7 +666,11 @@ def build(self, module_id: Optional[ModuleID] = None) -> RLModule: observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=( + dataclasses.asdict(self.model_config) + if dataclasses.is_dataclass(self.model_config) + else self.model_config + ), rl_module_specs=self.rl_module_specs, ) # Older custom model might still require the old `MultiRLModuleConfig` under @@ -859,7 +863,7 @@ def get_rl_module_config(self): "module2: [RLModuleSpec], ..}, inference_only=..)", error=False, ) -@dataclass +@dataclasses.dataclass class MultiRLModuleConfig: inference_only: bool = False modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index f1fb5b337cc5..42aa0a780ed4 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -98,7 +98,7 @@ def build(self) -> "RLModule": observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=self._get_model_config(), catalog_class=self.catalog_class, ) # Older custom model might still require the old `RLModuleConfig` under