Support multi-node multi-GPU training. #63

Open · wants to merge 4 commits into master
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/conformer_ctc/decode.py
@@ -383,7 +383,7 @@ def decode_one_batch(
         ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
     return ans


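This one-line fix addresses a stale-key bug: in the `else` branch, `lm_scale_str` still holds whatever value it had when the preceding loop finished, so every iteration overwrote the same dictionary entry; keying by the current `lm_scale` gives one entry per scale. A minimal standalone sketch of the pattern (names are illustrative):

    lm_scale_list = [0.1, 0.2, 0.3]

    # Buggy: lm_scale_str is left over from an earlier loop, so every
    # iteration writes to the same key.
    lm_scale_str = "0.3"  # stale value from a previous loop
    bad = {}
    for lm_scale in lm_scale_list:
        bad[lm_scale_str] = []
    assert list(bad) == ["0.3"]

    # Fixed: key on the current loop variable.
    good = {}
    for lm_scale in lm_scale_list:
        good[f"{lm_scale}"] = []
    assert list(good) == ["0.1", "0.2", "0.3"]

Incidentally, `[[] * lattice.shape[0]]` always evaluates to `[[]]`, since multiplying an empty list by an integer yields an empty list; `[[] for _ in range(lattice.shape[0])]` may have been intended, but that expression is unchanged by this diff.
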
Expand Down
125 changes: 125 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/run-multi-node-multi-gpu.sh
@@ -0,0 +1,125 @@
#!/usr/bin/env bash
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.

set -e

cur_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

# DDP related parameters
master_addr=
node_rank=
num_nodes=
master_port=1234

# Training script parameters
# You can add more if you like
#
# Use ./conformer_ctc/train.py --help to see more
#
# If you add more parameters here, remember to append them to the
# end of this file.
#
max_duration=200
bucketing_sampler=1
full_libri=1
start_epoch=0
num_epochs=2
exp_dir=conformer_ctc/exp3
lang_dir=data/lang_bpe_500

. $cur_dir/../shared/parse_options.sh

function usage() {
  echo "Usage: "
  echo ""
  echo "  $0 \\"
  echo "    --master-addr <IP of master> \\"
  echo "    --master-port <port of master> \\"
  echo "    --node-rank <rank of this node> \\"
  echo "    --num-nodes <number of nodes>"
  echo ""
  echo "  --master-addr  The IP address of the master node."
  echo "  --master-port  The port of the master node."
  echo "  --node-rank    Rank of this node."
  echo "  --num-nodes    Number of nodes in DDP training."
  echo ""
  echo "Usage example:"
  echo "Suppose you want to use DDP with two machines:"
  echo "  (1) Machine 1 has 4 GPUs. You want to use"
  echo "      GPUs 0, 1, and 3 for training."
  echo "      The IP of machine 1 is: 10.177.41.71"
  echo "  (2) Machine 2 has 4 GPUs. You want to use"
  echo "      GPUs 0, 2, and 3 for training."
  echo "      The IP of machine 2 is: 10.177.41.72"
  echo "You select machine 1 as the master node and"
  echo "assume that port 1234 is free on machine 1."
  echo ""
  echo "On machine 1, run:"
  echo ""
  echo "  export CUDA_VISIBLE_DEVICES=\"0,1,3\""
  echo "  ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
  echo ""
  echo "On machine 2, run:"
  echo ""
  echo "  export CUDA_VISIBLE_DEVICES=\"0,2,3\""
  echo "  ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
  echo ""
  echo "Note 1:"
  echo "  Use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
  echo ""
  echo "Note 2:"
  echo "  With torch < 1.9.0, every node has to use the same number of GPUs for training."
  echo "  With torch >= 1.9.0, different nodes can use different numbers of GPUs for training."
  exit 1
}

default='\033[0m'
bold='\033[1m'
red='\033[31m'

function error() {
  printf "${bold}${red}[ERROR]${default} %s\n" "$1"
}

[ ! -z "$CUDA_VISIBLE_DEVICES" ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
[ ! -z "$master_addr" ] || ( echo; error "Please set --master-addr"; echo; usage )
[ ! -z "$master_port" ] || ( echo; error "Please set --master-port"; echo; usage )
[ ! -z "$node_rank" ] || ( echo; error "Please set --node-rank"; echo; usage )
[ ! -z "$num_nodes" ] || ( echo; error "Please set --num-nodes"; echo; usage )

# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")

echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"

export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port

set -x

python -m torch.distributed.launch \
  --use_env \
  --nproc_per_node $num_gpus \
  --nnodes $num_nodes \
  --node_rank $node_rank \
  --master_addr $master_addr \
  --master_port $master_port \
  \
  $cur_dir/train.py \
    --use-multi-node true \
    --master-port $master_port \
    --max-duration $max_duration \
    --bucketing-sampler $bucketing_sampler \
    --full-libri $full_libri \
    --start-epoch $start_epoch \
    --num-epochs $num_epochs \
    --exp-dir $exp_dir \
    --lang-dir $lang_dir
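
For context (not part of the diff): with `--use_env`, `torch.distributed.launch` starts one process per GPU on each node and exports `MASTER_ADDR`, `MASTER_PORT`, `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` into each worker's environment instead of passing a `--local_rank` command-line argument; the helpers added to `icefall/dist.py` below read these variables. A tiny probe script (hypothetical, for illustration) shows what each worker sees:

    # probe_env.py -- run with, e.g.:
    #   python -m torch.distributed.launch --use_env --nproc_per_node 2 probe_env.py
    import os

    print(
        "RANK:", os.environ.get("RANK"),                # global rank across all nodes
        "| LOCAL_RANK:", os.environ.get("LOCAL_RANK"),  # rank within this node
        "| WORLD_SIZE:", os.environ.get("WORLD_SIZE"),  # total number of processes
    )
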
69 changes: 59 additions & 10 deletions egs/librispeech/ASR/conformer_ctc/train.py
@@ -39,7 +39,13 @@
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -54,6 +60,17 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="""True if using multi-node multi-GPU.
+        You are not supposed to set it directly.
+        See ./conformer_ctc/run-multi-node-multi-gpu.sh
+        for details.
+        """,
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -92,6 +109,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as lexicon.txt
+        """,
+    )
+
     return parser


@@ -106,12 +140,6 @@ def get_params() -> AttributeDict:

     Explanation of options saved in `params`:

-    - exp_dir: It specifies the directory where all training related
-               files, e.g., checkpoints, log, etc, are saved
-
-    - lang_dir: It contains language related input files such as
-                "lexicon.txt"
-
     - best_train_loss: Best training loss so far. It is used to select
                        the model that has the lowest training loss. It is
                        updated during the training.
@@ -621,9 +649,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))

+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -640,7 +676,8 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")

     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
@@ -665,7 +702,7 @@ def run(rank, world_size, args):

     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
@@ -726,9 +763,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # for multi-node multi-GPU training
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return
+
     world_size = args.world_size
     assert world_size >= 1

     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:
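The changes to `run()` and `main()` reflect the division of labor between the two launch modes. Under `torch.distributed.launch` the worker processes already exist (one per GPU), so each one reads its identity from the environment and calls `run()` directly; in the single-node case the script forks its own workers with `mp.spawn`. The `rank` to `local_rank` substitutions in `run()` matter only in the multi-node case: CUDA device indices are per-machine, so a process with global rank 5 on the second of two 4-GPU nodes must bind to `cuda:1` (its local rank), not `cuda:5`. A condensed sketch of the dispatch (illustrative, simplified from the diff above):

    import os
    import torch.multiprocessing as mp

    def main_sketch(args, run):
        if args.use_multi_node:
            # One process per GPU already exists; the launcher tells us who we are.
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            run(rank=rank, world_size=world_size, args=args)
            return
        if args.world_size > 1:
            # Single node: fork one worker per GPU ourselves.
            mp.spawn(
                run, args=(args.world_size, args),
                nprocs=args.world_size, join=True,
            )
        else:
            run(rank=0, world_size=1, args=args)
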
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/prepare.sh
@@ -41,6 +41,7 @@ dl_dir=$PWD/download
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
+  500
 )

 # All files generated by this script are saved in "data".
@@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe
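A note on these two `prepare.sh` changes: with `vocab_sizes=(5000 500)` the script now also generates `data/lang_bpe_500`, which the new `run-multi-node-multi-gpu.sh` passes via `--lang-dir`; and since `--lang-dir` is now a proper command-line option of `train.py` (defaulting to `data/lang_bpe`), the hard-coded `lang_bpe -> lang_bpe_5000` symlink is presumably no longer needed.
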
46 changes: 39 additions & 7 deletions icefall/dist.py
@@ -21,14 +21,46 @@
 from torch import distributed as dist


-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if not is_multi_node:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")


 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
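
Taken together, these helpers give `setup_dist()` two entry paths. A usage sketch (illustrative only; an actual call needs CUDA and NCCL available):

    # Single-node, started via mp.spawn: no launcher env vars exist yet, so
    # setup_dist() fills in MASTER_ADDR/MASTER_PORT itself and passes
    # rank/world_size to init_process_group explicitly.
    setup_dist(rank=0, world_size=2, master_port=12354)

    # Multi-node, started via torch.distributed.launch --use_env: the launcher
    # has already exported MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, so
    # init_process_group("nccl") reads everything from the environment and the
    # rank/world_size arguments are ignored.
    setup_dist(rank=get_rank(), world_size=get_world_size(), is_multi_node=True)

One detail worth noting: `get_rank()` prefers the launcher's `RANK` environment variable, falls back to an initialized process group, and otherwise returns 0 (a non-distributed process is rank 0), while `get_local_rank()` only consults `LOCAL_RANK`, so it is meaningful only on the launcher path.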