[RLlib] Introduce experimental larger than GPU train batch size feature for torch (ray-project#34189)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst authored Apr 25, 2023
1 parent 74ddaaa commit 72268e8
Showing 6 changed files with 214 additions and 16 deletions.
11 changes: 10 additions & 1 deletion rllib/BUILD
@@ -2525,7 +2525,16 @@ py_test(
name = "tests/test_gpus",
tags = ["team:rllib", "tests_dir"],
size = "large",
srcs = ["tests/test_gpus.py"]
srcs = ["tests/test_gpus.py"],
args = ["TestGPUs"]
)

py_test(
name = "tests/test_gpus_large_batch",
tags = ["team:rllib", "tests_dir"],
size = "large",
srcs = ["tests/test_gpus.py"],
args = ["TestGPUsLargeBatch"]
)

py_test(
16 changes: 16 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -429,6 +429,7 @@ def __init__(self, algo_class=None):
self._disable_action_flattening = False
self._disable_execution_plan_api = True
self._disable_initialize_loss_from_dummy_batch = False
self._load_only_minibatch_onto_device = False

# Has this config object been frozen (cannot alter its attributes anymore).
self._is_frozen = False
@@ -964,6 +965,14 @@ def validate(self) -> None:
f"config.framework({self.framework_str})!"
)

if (
self.simple_optimizer or self.framework_str != "torch"
) and self._load_only_minibatch_onto_device:
raise ValueError(
"`_load_only_minibatch_onto_device` is only supported for "
"framework='torch' and without the simple optimizer "
f"(got framework='{self.framework_str}', simple_optimizer={self.simple_optimizer})."
)

# Detect if specified env is an Atari env.
if self.is_atari is None:
self.is_atari = self._detect_atari_env()
@@ -1625,6 +1634,10 @@ def training(
_enable_learner_api: Whether to enable the LearnerGroup and Learner
for training. This API uses ray.train to run the training loop which
allows for a more flexible distributed training.
_load_only_minibatch_onto_device: Whether to load only the minibatch onto
the given device. This is useful for train batches that don't fit on
the given device while the mini-batches and their gradients do. This
experimental setting is only supported for torch.
Returns:
This updated AlgorithmConfig object.
@@ -2460,6 +2473,7 @@ def experimental(
_disable_action_flattening: Optional[bool] = NotProvided,
_disable_execution_plan_api: Optional[bool] = NotProvided,
_disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided,
_load_only_minibatch_onto_device: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's experimental settings.
@@ -2503,6 +2517,8 @@ def experimental(
self._disable_initialize_loss_from_dummy_batch = (
_disable_initialize_loss_from_dummy_batch
)
if _load_only_minibatch_onto_device is not NotProvided:
self._load_only_minibatch_onto_device = _load_only_minibatch_onto_device

return self

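For illustration, a minimal sketch of how the new experimental flag could be enabled from a user config (assuming PPO on torch; the env name, batch sizes, and GPU count are placeholders, not taken from this commit):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .framework("torch")  # the flag is validated to require torch
    .resources(num_gpus=1)
    .training(
        # Deliberately larger than what fits into GPU memory at once.
        train_batch_size=1_000_000,
        # One minibatch (plus gradients and optimizer state) must still fit on the GPU.
        sgd_minibatch_size=10_000,
        num_sgd_iter=1,
    )
    # Keep the full train batch on the CPU; move only minibatches to the GPU.
    .experimental(_load_only_minibatch_onto_device=True)
)
algo = config.build()

As the `validate()` change above shows, combinations with `simple_optimizer=True` or a non-torch framework are rejected with a `ValueError`.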
23 changes: 20 additions & 3 deletions rllib/policy/sample_batch.py
@@ -790,15 +790,32 @@ def _zero_pad_in_place(path, value):
return self

@ExperimentalAPI
def to_device(self, device, framework="torch"):
def to_device(self, device, framework="torch", copy=False):
"""Moves tensors inside this batch to a given device.
Depending on the copy flag, this will either return a new batch
or modify this batch in-place.
Args:
device: The device to move the tensors to.
framework: The framework to use for the device (e.g. "torch").
copy: If False, modify batch in place. If True, return a new batch.
Returns:
A batch with all tensors moved to the given device.
"""
"""TODO: transfer batch to given device as framework tensor."""
if framework == "torch":
assert torch is not None
if copy:
target = SampleBatch()
else:
target = self
for k, v in self.items():
self[k] = convert_to_torch_tensor(v, device)
target[k] = convert_to_torch_tensor(v, device)
else:
raise NotImplementedError
return self
return target

@PublicAPI
def size_bytes(self) -> int:
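A short usage sketch of the new `copy` flag (assuming a CUDA device with index 0; the batch contents are illustrative only):

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

cpu_batch = SampleBatch({"obs": np.zeros((100_000, 4), dtype=np.float32)})
minibatch = cpu_batch[0:128]  # slicing keeps the data on the CPU
# copy=True returns a new SampleBatch on the GPU and leaves `cpu_batch` untouched;
# the default copy=False converts the batch's own tensors to device tensors in place.
gpu_minibatch = minibatch.to_device(0, copy=True)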
36 changes: 33 additions & 3 deletions rllib/policy/torch_policy.py
@@ -536,7 +536,14 @@ def load_batch_into_buffer(
)

# 3) Load splits into the given buffer (consisting of n GPUs).
slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
if not self.config.get("_load_only_minibatch_onto_device", False):
# We usually want to load the full batch onto the device here, which is
# much faster than loading the batch slice-by-slice.
# However, if the batch is too large, it may be favorable to load the
# batch slice-by-slice.
slices = [
slice.to_device(self.devices[i]) for i, slice in enumerate(slices)
]
self._loaded_batches[buffer_index] = slices

# Return loaded samples per-device.
@@ -604,8 +611,20 @@ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
)
batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

# If `_load_only_minibatch_onto_device` is True, the main batch always
# remains on the CPU (it is probably too big to fit on the GPU). In this
# case, each individual update step has to copy the freshly determined
# sub-slice to the GPU; these sub-slices must be small enough to fit on
# the GPU.
copy_batch_to_device = self.config.get(
"_load_only_minibatch_onto_device", False
)

# Do the (maybe parallelized) gradient calculation step.
tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)
tower_outputs = self._multi_gpu_parallel_grad_calc(
device_batches, copy_batch_to_device=copy_batch_to_device
)

# Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
all_grads = []
@@ -1061,7 +1080,9 @@ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
return postprocessed_batch

def _multi_gpu_parallel_grad_calc(
self, sample_batches: List[SampleBatch]
self,
sample_batches: List[SampleBatch],
copy_batch_to_device: bool = False,
) -> List[Tuple[List[TensorType], GradInfoDict]]:
"""Performs a parallelized loss and gradient calculation over the batch.
@@ -1073,6 +1094,10 @@ def _multi_gpu_parallel_grad_calc(
Args:
sample_batches: A list of SampleBatch shards to
calculate loss and gradients for.
copy_batch_to_device: Whether to create a copy of the batch that is then
moved to the GPU, leaving the original batch untouched. For a large
batch, this allows moving the mini-batches to the GPU one by one and
freeing them after each step.
Returns:
A list (one item per device) of 2-tuples, each with 1) gradient
@@ -1083,6 +1108,11 @@ def _multi_gpu_parallel_grad_calc(
results = {}
grad_enabled = torch.is_grad_enabled()

if copy_batch_to_device:
sample_batches = [
batch.to_device(i, copy=True) for i, batch in enumerate(sample_batches)
]

def _worker(shard_idx, model, sample_batch, device):
torch.set_grad_enabled(grad_enabled)
try:
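For context, a simplified sketch (not the actual policy code) of the data flow this enables: the full train batch stays on the CPU and each SGD step copies only one minibatch to the GPU, freeing it afterwards. The helper name and the `apply_sgd_step` callback are hypothetical:

def sgd_over_cpu_batch(full_cpu_batch, device, minibatch_size, apply_sgd_step):
    """Run SGD over a SampleBatch that is too large for the GPU (hypothetical helper)."""
    for start in range(0, len(full_cpu_batch), minibatch_size):
        cpu_slice = full_cpu_batch[start:start + minibatch_size]
        # copy=True keeps the CPU slice intact and yields a fresh GPU-resident batch.
        gpu_slice = cpu_slice.to_device(device, copy=True)
        apply_sgd_step(gpu_slice)  # hypothetical callback: loss, backward, optimizer step
        del gpu_slice  # drop the reference so the GPU memory can be reclaimed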
42 changes: 38 additions & 4 deletions rllib/policy/torch_policy_v2.py
@@ -734,7 +734,14 @@ def load_batch_into_buffer(
)

# 3) Load splits into the given buffer (consisting of n GPUs).
slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
if not self.config.get("_load_only_minibatch_onto_device", False):
# We usually want to load the full batch onto the device here, which is
# much faster than loading the batch slice-by-slice.
# However, if the batch is too large, it may be favorable to load the
# batch slice-by-slice.
slices = [
slice.to_device(self.devices[i]) for i, slice in enumerate(slices)
]
self._loaded_batches[buffer_index] = slices

# Return loaded samples per-device.
@@ -749,7 +756,11 @@ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:

@override(Policy)
@DeveloperAPI
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
def learn_on_loaded_batch(
self,
offset: int = 0,
buffer_index: int = 0,
):
if not self._loaded_batches[buffer_index]:
raise ValueError(
"Must call Policy.load_batch_into_buffer() before "
@@ -803,8 +814,20 @@ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
)
batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

# If `_load_only_minibatch_onto_device` is True, the main batch always
# remains on the CPU (it is probably too big to fit on the GPU). In this
# case, each individual update step has to copy the freshly determined
# sub-slice to the GPU; these sub-slices must be small enough to fit on
# the GPU.
copy_batch_to_device = self.config.get(
"_load_only_minibatch_onto_device", False
)

# Do the (maybe parallelized) gradient calculation step.
tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)
tower_outputs = self._multi_gpu_parallel_grad_calc(
device_batches, copy_batch_to_device=copy_batch_to_device
)

# Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
all_grads = []
@@ -1188,7 +1211,9 @@ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
return postprocessed_batch

def _multi_gpu_parallel_grad_calc(
self, sample_batches: List[SampleBatch]
self,
sample_batches: List[SampleBatch],
copy_batch_to_device: bool = False,
) -> List[Tuple[List[TensorType], GradInfoDict]]:
"""Performs a parallelized loss and gradient calculation over the batch.
@@ -1200,6 +1225,10 @@ def _multi_gpu_parallel_grad_calc(
Args:
sample_batches: A list of SampleBatch shards to
calculate loss and gradients for.
copy_batch_to_device: Whether to create a copy of the batch that is then
moved to the GPU, leaving the original batch untouched. For a large
batch, this allows moving the mini-batches to the GPU one by one and
freeing them after each step.
Returns:
A list (one item per device) of 2-tuples, each with 1) gradient
@@ -1210,6 +1239,11 @@ def _multi_gpu_parallel_grad_calc(
results = {}
grad_enabled = torch.is_grad_enabled()

if copy_batch_to_device:
sample_batches = [
batch.to_device(i, copy=True) for i, batch in enumerate(sample_batches)
]

def _worker(shard_idx, model, sample_batch, device):
torch.set_grad_enabled(grad_enabled)
try:
102 changes: 97 additions & 5 deletions rllib/tests/test_gpus.py
@@ -1,13 +1,19 @@
import copy
import unittest

import numpy as np
import torch

import ray
from ray import air
from ray import tune
from ray.rllib.algorithms.a2c.a2c import A2CConfig
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.qmix import QMixConfig
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune

torch, _ = try_import_torch()


class TestGPUs(unittest.TestCase):
@@ -111,8 +117,94 @@ def test_gpus_in_local_mode(self):
ray.shutdown()


class TestGPUsLargeBatch(unittest.TestCase):
def test_larger_train_batch_size_multi_gpu_train_one_step(self):
# Tests that a `train_batch_size` larger than GPU memory can be used with
# the experimental setting `_load_only_minibatch_onto_device` and
# multi_gpu_train_one_step.

# These values are chosen so that one minibatch plus the optimizer
# variables fit on the device, but the whole sample batch is already too
# large for the GPU itself.
sgd_minibatch_size = int(1e4)
train_batch_size = int(sgd_minibatch_size * 1e5)

# Fake CartPole episode of n time steps.
CARTPOLE_FAKE_BATCH = SampleBatch(
{
SampleBatch.OBS: np.zeros((train_batch_size, 4), dtype=np.float32),
SampleBatch.ACTIONS: np.zeros((train_batch_size,), dtype=np.float32),
SampleBatch.PREV_ACTIONS: np.zeros(
(train_batch_size,), dtype=np.float32
),
SampleBatch.REWARDS: np.zeros((train_batch_size,), dtype=np.float32),
SampleBatch.PREV_REWARDS: np.zeros(
(train_batch_size,), dtype=np.float32
),
"value_targets": np.zeros((train_batch_size,), dtype=np.float32),
SampleBatch.TERMINATEDS: np.array([False] * train_batch_size),
SampleBatch.TRUNCATEDS: np.array([False] * train_batch_size),
"advantages": np.zeros((train_batch_size,), dtype=np.float32),
SampleBatch.VF_PREDS: np.zeros((train_batch_size,), dtype=np.float32),
SampleBatch.ACTION_DIST_INPUTS: np.zeros(
(train_batch_size, 2), dtype=np.float32
),
SampleBatch.ACTION_LOGP: np.zeros(
(train_batch_size,), dtype=np.float32
),
SampleBatch.EPS_ID: np.zeros((train_batch_size,), dtype=np.int64),
SampleBatch.AGENT_INDEX: np.zeros((train_batch_size,), dtype=np.int64),
}
)

# Sanity check that the full batch is indeed too large for the GPU (OOM).
try:
batch_copy = copy.deepcopy(CARTPOLE_FAKE_BATCH)
batch_copy.to_device(0)
raise ValueError(
"We should not be able to move this batch to the device. "
"If this error occurs, this means that this test cannot fail "
"inside multi_gpu_train_one_step."
)
except torch.cuda.OutOfMemoryError:
pass

for config_class in (PPOConfig, QMixConfig):
config = (
config_class()
.environment(env="CartPole-v1")
.framework("torch")
.resources(num_gpus=1)
.rollouts(num_rollout_workers=0)
.training(
train_batch_size=train_batch_size,
num_sgd_iter=1,
sgd_minibatch_size=sgd_minibatch_size,
# This setting makes it so that we don't load a batch of
# size `train_batch_size` onto the device, but only
# minibatches.
_load_only_minibatch_onto_device=True,
)
)

algorithm = config.build()
policy = algorithm.get_policy()

# Sanity check that we are covering both TorchPolicy and TorchPolicyV2.
if config_class is QMixConfig:
assert isinstance(policy, TorchPolicy)
elif config_class is PPOConfig:
assert isinstance(policy, TorchPolicyV2)

policy.load_batch_into_buffer(CARTPOLE_FAKE_BATCH)
policy.learn_on_loaded_batch()


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
# A specific TestCase class to run can be passed as the first CLI argument;
# if none is given, all unittest.TestCase classes in this file are run.
class_ = sys.argv[1] if len(sys.argv) > 1 else None
sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))

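With the updated `__main__` block, a single test class can be selected from the command line; this is what the `args` entries added to `rllib/BUILD` rely on. For example, the following is roughly equivalent to `python rllib/tests/test_gpus.py TestGPUsLargeBatch` (path relative to the Ray repository root):

import sys

import pytest

# Run only the new large-batch test class from this file.
sys.exit(pytest.main(["-v", "rllib/tests/test_gpus.py::TestGPUsLargeBatch"]))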