Skip to content

Commit

Permalink
[RLlib] Issue 31323: BC/MARWIL/CQL do work with multi-GPU (but config…
Browse files Browse the repository at this point in the history
… validation prevents them from running in this mode). (#31393)
  • Loading branch information
sven1977 authored and AmeerHajAli committed Jan 12, 2023
1 parent 422e636 commit 687b1f0
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 7 deletions.
3 changes: 0 additions & 3 deletions rllib/algorithms/cql/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

if self.num_gpus > 1:
raise ValueError("`num_gpus` > 1 not yet supported for CQL!")

# CQL-torch performs the optimizer steps inside the loss function.
# Using the multi-GPU optimizer will therefore not work (see multi-GPU
# check above) and we must use the simple optimizer for now.
Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/tests/test_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_cql_compilation(self):
evaluation_num_workers=2,
)
.rollouts(num_rollout_workers=0)
.reporting(min_time_s_per_iteration=0.0)
.reporting(min_time_s_per_iteration=0)
)
num_iterations = 4

Expand Down
3 changes: 0 additions & 3 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ def validate(self) -> None:
if self.beta < 0.0 or self.beta > 1.0:
raise ValueError("`beta` must be within 0.0 and 1.0!")

if self.num_gpus > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

if self.postprocess_inputs is False and self.beta > 0.0:
raise ValueError(
"`postprocess_inputs` must be True for MARWIL (to "
Expand Down

0 comments on commit 687b1f0

Please sign in to comment.