diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 280516685016..f21c79163a6e 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -527,8 +527,6 @@ def __init__(self, config: Dict, *args, **kwargs): "`class YourTrainable(WandbTrainableMixin)`." ) - super().__init__(config, *args, **kwargs) - _config = config.copy() try: @@ -540,6 +538,8 @@ def __init__(self, config: Dict, *args, **kwargs): "containing at least a `project` specification." ) + super().__init__(_config, *args, **kwargs) + api_key_file = wandb_config.pop("api_key_file", None) if api_key_file: api_key_file = os.path.expanduser(api_key_file) diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index b6ad19003e4e..296ba48d2033 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -437,6 +437,29 @@ def train_fn(config): self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id) self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name) + def testWandbMixinRllib(self): + """Test compatibility with RLLib configuration dicts""" + # Local import to avoid tune dependency on rllib + try: + from ray.rllib.agents.ppo import PPOTrainer + except ImportError: + self.skipTest("ray[rllib] not available") + return + + class WandbPPOTrainer(_MockWandbTrainableMixin, PPOTrainer): + pass + + config = { + "env": "CartPole-v0", + "wandb": { + "project": "test_project", + "api_key": "1234", + }, + } + + # Test that trainer object can be initialized + WandbPPOTrainer(config) + if __name__ == "__main__": import pytest