From f871c11e82bfb007fc31379d02d3535a6a45ddda Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 2 Jun 2022 11:20:25 -0700 Subject: [PATCH 1/4] Prototype `torch.distributed` --- rl_games/common/a2c_common.py | 88 +++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 35 deletions(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 9610ca29..3323ca6c 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -20,6 +20,7 @@ from tensorboardX import SummaryWriter import torch from torch import nn +import torch.distributed as dist from time import sleep @@ -68,11 +69,15 @@ def __init__(self, base_name, params): self.rank_size = 1 self.curr_frames = 0 if self.multi_gpu: - from rl_games.distributed.hvd_wrapper import HorovodWrapper - self.hvd = HorovodWrapper() - self.config = self.hvd.update_algo_config(config) - self.rank = self.hvd.rank - self.rank_size = self.hvd.rank_size + self.rank = int(os.getenv("LOCAL_RANK", "0")) + self.rank_size = int(os.getenv("WORLD_SIZE", "1")) + dist.init_process_group("nccl", rank=self.rank, world_size=self.rank_size) + + self.device_name = 'cuda:' + str(self.rank) + config['device'] = self.device_name + if self.rank != 0: + config['print_stats'] = False + config['lr_schedule'] = None self.use_diagnostics = config.get('use_diagnostics', False) @@ -251,19 +256,27 @@ def __init__(self, base_name, params): def trancate_gradients_and_step(self): if self.multi_gpu: - self.optimizer.synchronize() + # batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220 + all_grads_list = [] + for param in self.model.parameters(): + if param.grad is not None: + all_grads_list.append(param.grad.view(-1)) + all_grads = torch.cat(all_grads_list) + dist.all_reduce(all_grads, op=dist.ReduceOp.SUM) + offset = 0 + for param in self.model.parameters(): + if param.grad is not None: + param.grad.data.copy_( + all_grads[offset : offset + param.numel()].view_as(param.grad.data) / self.rank_size + ) + offset += param.numel() if self.truncate_grads: self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) - if self.multi_gpu: - with self.optimizer.skip_synchronize(): - self.scaler.step(self.optimizer) - self.scaler.update() - else: - self.scaler.step(self.optimizer) - self.scaler.update() + self.scaler.step(self.optimizer) + self.scaler.update() def load_networks(self, params): builder = model_builder.ModelBuilder() @@ -308,8 +321,8 @@ def set_train(self): def update_lr(self, lr): if self.multi_gpu: - lr_tensor = torch.tensor([lr]) - self.hvd.broadcast_value(lr_tensor, 'learning_rate') + lr_tensor = torch.tensor([lr], device=self.device) + dist.broadcast(lr_tensor, 0) lr = lr_tensor.item() for param_group in self.optimizer.param_groups: @@ -802,7 +815,8 @@ def train_epoch(self): av_kls = torch_ext.mean_list(ep_kls) if self.multi_gpu: - av_kls = self.hvd.average_value(av_kls, 'ep_kls') + dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) + av_kls /= self.rank_size self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) self.update_lr(self.last_lr) @@ -887,19 +901,20 @@ def train(self): self.obs = self.env_reset() if self.multi_gpu: - self.hvd.setup_algo(self) + # + print("====================broadcasting parameters") + model_params = [self.model.state_dict()] + dist.broadcast_object_list(model_params, 0) + self.model.load_state_dict(model_params[0]) while True: epoch_num = self.update_epoch() step_time, play_time, update_time, sum_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul = self.train_epoch() - if self.multi_gpu: - self.hvd.sync_stats(self) - # cleaning memory to optimize space self.dataset.update_values_dict(None) total_time += sum_time - curr_frames = self.curr_frames + curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames self.frame += curr_frames should_exit = False if self.rank == 0: @@ -958,10 +973,10 @@ def train(self): print('MAX EPOCHS NUM!') should_exit = True update_time = 0 - if self.multi_gpu: - should_exit_t = torch.tensor(should_exit).float() - self.hvd.broadcast_value(should_exit_t, 'should_exit') - should_exit = should_exit_t.bool().item() + # if self.multi_gpu: + # should_exit_t = torch.tensor(should_exit).float() + # dist.broadcast(should_exit_t, 0) + # should_exit = should_exit_t.bool().item() if should_exit: return self.last_mean_rewards, epoch_num @@ -1040,9 +1055,10 @@ def train_epoch(self): self.dataset.update_mu_sigma(cmu, csigma) av_kls = torch_ext.mean_list(ep_kls) - if self.multi_gpu: - av_kls = self.hvd.average_value(av_kls, 'ep_kls') + dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) + av_kls /= self.rank_size + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) self.update_lr(self.last_lr) @@ -1128,7 +1144,11 @@ def train(self): self.curr_frames = self.batch_size_envs if self.multi_gpu: - self.hvd.setup_algo(self) + # + print("====================broadcasting parameters") + model_params = [self.model.state_dict()] + dist.broadcast_object_list(model_params, 0) + self.model.load_state_dict(model_params[0]) while True: epoch_num = self.update_epoch() @@ -1136,8 +1156,6 @@ def train(self): total_time += sum_time frame = self.frame // self.num_agents - if self.multi_gpu: - self.hvd.sync_stats(self) # cleaning memory to optimize space self.dataset.update_values_dict(None) should_exit = False @@ -1147,7 +1165,7 @@ def train(self): # do we need scaled_time? scaled_time = self.num_agents * sum_time scaled_play_time = self.num_agents * play_time - curr_frames = self.curr_frames + curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames self.frame += curr_frames if self.print_stats: fps_step = curr_frames / step_time @@ -1203,10 +1221,10 @@ def train(self): should_exit = True update_time = 0 - if self.multi_gpu: - should_exit_t = torch.tensor(should_exit).float() - self.hvd.broadcast_value(should_exit_t, 'should_exit') - should_exit = should_exit_t.float().item() + # if self.multi_gpu: + # should_exit_t = torch.tensor(should_exit).float() + # dist.broadcast(should_exit_t, 0) + # should_exit = should_exit_t.float().item() if should_exit: return self.last_mean_rewards, epoch_num From fbe50c29a5640ecccdfba2067fc50c02edd5cbc5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 3 Jun 2022 09:12:54 -0700 Subject: [PATCH 2/4] Fix --- rl_games/algos_torch/a2c_continuous.py | 1 - rl_games/algos_torch/a2c_discrete.py | 1 - rl_games/algos_torch/central_value.py | 8 +-- rl_games/common/a2c_common.py | 2 +- rl_games/distributed/hvd_wrapper.py | 75 -------------------------- rl_games/torch_runner.py | 8 ++- runner.py | 7 ++- 7 files changed, 13 insertions(+), 89 deletions(-) delete mode 100644 rl_games/distributed/hvd_wrapper.py diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 489e8e0e..7a7c762d 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -48,7 +48,6 @@ def __init__(self, base_name, params): 'writter' : self.writer, 'max_epochs' : self.max_epochs, 'multi_gpu' : self.multi_gpu, - 'hvd': self.hvd if self.multi_gpu else None } self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 6f912788..3f50ae9e 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -49,7 +49,6 @@ def __init__(self, base_name, params): 'writter' : self.writer, 'max_epochs' : self.max_epochs, 'multi_gpu' : self.multi_gpu, - 'hvd': self.hvd if self.multi_gpu else None } self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index e1074ce0..d07c31a1 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -1,5 +1,6 @@ import torch from torch import nn +import torch.distributed as dist import gym import numpy as np from rl_games.algos_torch import torch_ext @@ -9,7 +10,7 @@ from rl_games.common import schedulers class CentralValueTrain(nn.Module): - def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu, hvd): + def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu): nn.Module.__init__(self) self.ppo_device = ppo_device self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len @@ -19,7 +20,6 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng self.value_size = value_size self.max_epochs = max_epochs self.multi_gpu = multi_gpu - self.hvd = hvd self.truncate_grads = config.get('truncate_grads', False) self.config = config self.normalize_input = config['normalize_input'] @@ -77,8 +77,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng def update_lr(self, lr): if self.multi_gpu: - lr_tensor = torch.tensor([lr]) - self.hvd.broadcast_value(lr_tensor, 'cv_learning_rate') + lr_tensor = torch.tensor([lr], device=self.device) + dist.broadcast(lr_tensor, 0) lr = lr_tensor.item() for param_group in self.optimizer.param_groups: diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 3323ca6c..2ea94264 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -901,7 +901,7 @@ def train(self): self.obs = self.env_reset() if self.multi_gpu: - # + torch.cuda.set_device(self.rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] dist.broadcast_object_list(model_params, 0) diff --git a/rl_games/distributed/hvd_wrapper.py b/rl_games/distributed/hvd_wrapper.py deleted file mode 100644 index 07d78233..00000000 --- a/rl_games/distributed/hvd_wrapper.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import horovod.torch as hvd -import os - - -class HorovodWrapper: - def __init__(self): - hvd.init() - self.rank = hvd.rank() - self.rank_size = hvd.size() - print('Starting horovod with rank: {0}, size: {1}'.format(self.rank, self.rank_size)) - #self.device_name = 'cpu' - self.device_name = 'cuda:' + str(self.rank) - - def update_algo_config(self, config): - config['device'] = self.device_name - if self.rank != 0: - config['print_stats'] = False - config['lr_schedule'] = None - return config - - def setup_algo(self, algo): - hvd.broadcast_parameters(algo.model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(algo.optimizer, root_rank=0) - algo.optimizer = hvd.DistributedOptimizer(algo.optimizer, named_parameters=algo.model.named_parameters()) - - self.sync_stats(algo) - - if algo.has_central_value: - hvd.broadcast_parameters(algo.central_value_net.model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(algo.central_value_net.optimizer, root_rank=0) - - algo.central_value_net.optimizer = hvd.DistributedOptimizer(algo.central_value_net.optimizer, named_parameters=algo.central_value_net.model.named_parameters()) - - # allreduce doesn't work in expected way. need to fix it in the future - def sync_recursive(self, values, name): - if isinstance(values, torch.Tensor): - values.data = hvd.allreduce(values, name=name) - else: - for k, v in values.items(): - self.sync_recursive(v, name+'/'+k) - - def sync_stats(self, algo): - stats_dict = algo.get_stats_weights(model_stats=False) - #self.sync_recursive(stats_dict, 'stats') - if algo.normalize_input: - algo.model.running_mean_std.running_mean = hvd.allreduce(algo.model.running_mean_std.running_mean, 'normalize_input/running_mean') - algo.model.running_mean_std.running_var = hvd.allreduce(algo.model.running_mean_std.running_var, 'normalize_input/running_var') - if algo.normalize_value: - algo.model.value_mean_std.running_mean = hvd.allreduce(algo.model.value_mean_std.running_mean, 'normalize_value/running_mean') - algo.model.value_mean_std.running_var = hvd.allreduce(algo.model.value_mean_std.running_var, 'normalize_value/running_var') - if algo.has_central_value: - cv_net = algo.central_value_net - if cv_net.normalize_input: - cv_net.model.running_mean_std.running_mean = hvd.allreduce(cv_net.model.running_mean_std.running_mean, 'cval/normalize_input/running_mean') - cv_net.model.running_mean_std.running_var = hvd.allreduce(cv_net.model.running_mean_std.running_var, 'cval/normalize_input/running_var') - if cv_net.normalize_value: - cv_net.model.value_mean_std.running_mean = hvd.allreduce(cv_net.model.value_mean_std.running_mean, 'cval/normalize_value/running_mean') - cv_net.model.value_mean_std.running_var = hvd.allreduce(cv_net.model.value_mean_std.running_var, 'cval/normalize_value/running_var') - algo.curr_frames = hvd.allreduce(torch.tensor(algo.curr_frames), average=False).item() - - def broadcast_value(self, val, name): - hvd.broadcast_parameters({name: val}, root_rank=0) - - def is_root(self): - return self.rank == 0 - - def average_stats(self, stats_dict): - res_dict = {} - for k,v in stats_dict.items(): - res_dict[k] = self.metric_average(v, k) - - def average_value(self, val, name): - avg_tensor = hvd.allreduce(val, name=name) - return avg_tensor diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 331fd510..a2c38ba9 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -1,3 +1,4 @@ +import os import time import numpy as np import random @@ -60,10 +61,7 @@ def load_config(self, params): self.seed = int(time.time()) if params["config"].get('multi_gpu', False): - import horovod.torch as hvd - - hvd.init() - self.seed += hvd.rank() + self.seed += int(os.getenv("LOCAL_RANK", "0")) print(f"self.seed = {self.seed}") self.algo_params = params['algo'] @@ -82,7 +80,7 @@ def load_config(self, params): params['config']['env_config']['seed'] = self.seed else: if params["config"].get('multi_gpu', False): - params['config']['env_config']['seed'] += hvd.rank() + params['config']['env_config']['seed'] += int(os.getenv("LOCAL_RANK", "0")) config = params['config'] config['reward_shaper'] = tr_helpers.DefaultRewardsShaper(**config['reward_shaper']) diff --git a/runner.py b/runner.py index b8de1395..b1613f00 100644 --- a/runner.py +++ b/runner.py @@ -52,7 +52,8 @@ except yaml.YAMLError as exc: print(exc) - if args["track"]: + rank = int(os.getenv("LOCAL_RANK", "0")) + if args["track"] and rank == 0: import wandb wandb.init( @@ -67,4 +68,6 @@ runner.run(args) ray.shutdown() - + + if args["track"] and rank == 0: + wandb.finish() From c1702323efe64620df27cb832710b0febbbed653 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 3 Jun 2022 09:18:38 -0700 Subject: [PATCH 3/4] add docs --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 5fc9c168..76520a22 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,14 @@ poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_t ``` +## Multi GPU + +We use `torchrun` to orchestrate any multi-gpu runs. + +```bash +torchrun --standalone --nnodes=1 --nproc_per_node=2 runner.py --train --file rl_games/configs/ppo_cartpole.yaml +``` + ## Config Parameters | Field | Example Value | Default | Description | From 0cacb8e6e51ee4556e6abcf06b1492a07d75c804 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 3 Jun 2022 10:58:55 -0700 Subject: [PATCH 4/4] Address comments --- rl_games/common/a2c_common.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 2ea94264..6c437ed5 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -973,10 +973,10 @@ def train(self): print('MAX EPOCHS NUM!') should_exit = True update_time = 0 - # if self.multi_gpu: - # should_exit_t = torch.tensor(should_exit).float() - # dist.broadcast(should_exit_t, 0) - # should_exit = should_exit_t.bool().item() + if self.multi_gpu: + should_exit_t = torch.tensor(should_exit).float() + dist.broadcast(should_exit_t, 0) + should_exit = should_exit_t.bool().item() if should_exit: return self.last_mean_rewards, epoch_num @@ -1221,10 +1221,10 @@ def train(self): should_exit = True update_time = 0 - # if self.multi_gpu: - # should_exit_t = torch.tensor(should_exit).float() - # dist.broadcast(should_exit_t, 0) - # should_exit = should_exit_t.float().item() + if self.multi_gpu: + should_exit_t = torch.tensor(should_exit).float() + dist.broadcast(should_exit_t, 0) + should_exit = should_exit_t.float().item() if should_exit: return self.last_mean_rewards, epoch_num