diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py
index 31f7f12cc31e..9c51e565315f 100644
--- a/python/ray/rllib/agents/dqn/dqn.py
+++ b/python/ray/rllib/agents/dqn/dqn.py
@@ -12,6 +12,7 @@
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
+from ray.tune.trial import Resources
 
 logger = logging.getLogger(__name__)
 
@@ -141,6 +142,21 @@ class DQNAgent(Agent):
     _policy_graph = DQNPolicyGraph
     _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
 
+    @classmethod
+    @override(Agent)
+    def default_resource_request(cls, config):
+        cf = dict(cls._default_config, **config)
+        Agent._validate_config(cf)
+        if cf["optimizer_class"] == "AsyncReplayOptimizer":
+            extra = cf["optimizer"]["num_replay_buffer_shards"]
+        else:
+            extra = 0
+        return Resources(
+            cpu=cf["num_cpus_for_driver"],
+            gpu=cf["num_gpus"],
+            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"] + extra,
+            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
+
     @override(Agent)
     def _init(self):
         self._validate_config()
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index a2cfee61a0d7..f1ae1bf71639 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -225,7 +225,8 @@ def _step(self):
         return sample_timesteps, train_timesteps
 
 
-@ray.remote(num_cpus=0)
+# reserve 1 CPU so that our method calls don't get stalled
+@ray.remote(num_cpus=1)
 class ReplayActor(object):
     """A replay buffer shard.
 
diff --git a/python/ray/rllib/tests/test_supported_spaces.py b/python/ray/rllib/tests/test_supported_spaces.py
index 7d59a04fb223..93a8366bf56c 100644
--- a/python/ray/rllib/tests/test_supported_spaces.py
+++ b/python/ray/rllib/tests/test_supported_spaces.py
@@ -105,7 +105,7 @@ def check_support_multiagent(alg, config):
 
 class ModelSupportedSpaces(unittest.TestCase):
     def setUp(self):
-        ray.init(num_cpus=4)
+        ray.init(num_cpus=10)
 
     def tearDown(self):
         ray.shutdown()
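
For context, a minimal standalone sketch (not part of the patch) of the resource arithmetic that the new default_resource_request performs, assuming a hypothetical Ape-X style config with 8 workers and 4 replay buffer shards; the numbers are illustrative, only the config keys match those referenced in the diff.

# Sketch of the resource calculation added in DQNAgent.default_resource_request.
# The config values below are made up for illustration.
config = {
    "num_cpus_for_driver": 1,
    "num_gpus": 1,
    "num_workers": 8,
    "num_cpus_per_worker": 1,
    "num_gpus_per_worker": 0,
    "optimizer_class": "AsyncReplayOptimizer",
    "optimizer": {"num_replay_buffer_shards": 4},
}

# Replay shards now reserve one CPU each (see the ReplayActor change above),
# so they are counted as extra CPUs when the async replay optimizer is used.
extra = (config["optimizer"]["num_replay_buffer_shards"]
         if config["optimizer_class"] == "AsyncReplayOptimizer" else 0)

resources = {
    "cpu": config["num_cpus_for_driver"],
    "gpu": config["num_gpus"],
    "extra_cpu": config["num_cpus_per_worker"] * config["num_workers"] + extra,
    "extra_gpu": config["num_gpus_per_worker"] * config["num_workers"],
}
print(resources)  # {'cpu': 1, 'gpu': 1, 'extra_cpu': 12, 'extra_gpu': 0}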