diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 5a83dd39c9..1b5da7cf3b 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -383,7 +383,7 @@ def decode_one_batch(
         ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
 
     return ans
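A note on the replacement line in the hunk above: the old `else` branch raised a `NameError` because `lm_scale_str` is only bound in the `if` branch, and `[[] * lattice.shape[0]]` never builds one empty hypothesis list per sequence, since multiplying an empty list yields an empty list and the whole expression always evaluates to `[[]]`. A quick standalone illustration of the pitfall (plain Python, nothing icefall-specific):

    n = 3

    # [] * n is always [], so wrapping it in a list gives one empty list,
    # not n of them.
    assert [[] * n] == [[]]

    # A comprehension creates n independent empty lists.
    assert [[] for _ in range(n)] == [[], [], []]

    # [[]] * n has the right length but aliases a single inner list:
    # appending through one slot mutates all of them.
    shared = [[]] * n
    shared[0].append("x")
    assert shared == [["x"], ["x"], ["x"]]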
diff --git a/egs/librispeech/ASR/conformer_ctc/run-multi-node-multi-gpu.sh b/egs/librispeech/ASR/conformer_ctc/run-multi-node-multi-gpu.sh
new file mode 100755
index 0000000000..e1fb6f428a
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/run-multi-node-multi-gpu.sh
@@ -0,0 +1,125 @@
+#!/usr/bin/env bash
+#
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# This script is the entry point for starting model training
+# with multiple nodes and multiple GPUs.
+#
+# See the usage() function below for instructions on running it.
+
+set -e
+
+cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
+
+# DDP related parameters
+master_addr=
+node_rank=
+num_nodes=
+master_port=1234
+
+# Training script parameters.
+# You can add more if you like.
+#
+# Use ./conformer_ctc/train.py --help to see all of them.
+#
+# If you add more parameters here, remember to append them to the
+# train.py invocation at the end of this file.
+#
+max_duration=200
+bucketing_sampler=1
+full_libri=1
+start_epoch=0
+num_epochs=2
+exp_dir=conformer_ctc/exp3
+lang_dir=data/lang_bpe_500
+
+. $cur_dir/../shared/parse_options.sh
+
+function usage() {
+  echo "Usage:"
+  echo ""
+  echo "  $0 \\"
+  echo "    --master-addr <IP address of master node> \\"
+  echo "    --master-port <port of master node> \\"
+  echo "    --node-rank <rank of this node> \\"
+  echo "    --num-nodes <number of nodes>"
+  echo ""
+  echo "  --master-addr   The IP address of the master node."
+  echo "  --master-port   The port of the master node."
+  echo "  --node-rank     Rank of this node."
+  echo "  --num-nodes     Number of nodes in DDP training."
+  echo ""
+  echo "Usage example:"
+  echo "Suppose you want to use DDP with two machines:"
+  echo "  (1) Machine 1 has 4 GPUs. You want to use"
+  echo "      GPUs 0, 1, and 3 for training."
+  echo "      IP of machine 1 is: 10.177.41.71"
+  echo "  (2) Machine 2 has 4 GPUs. You want to use"
+  echo "      GPUs 0, 2, and 3 for training."
+  echo "      IP of machine 2 is: 10.177.41.72"
+  echo "Select machine 1 as the master node and assume that"
+  echo "port 1234 is free on machine 1."
+  echo ""
+  echo "On machine 1, you run:"
+  echo ""
+  echo "  export CUDA_VISIBLE_DEVICES=\"0,1,3\""
+  echo "  ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
+  echo ""
+  echo "On machine 2, you run:"
+  echo ""
+  echo "  export CUDA_VISIBLE_DEVICES=\"0,2,3\""
+  echo "  ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
+  echo ""
+  echo "Note 1:"
+  echo "  Use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
+  echo ""
+  echo "Note 2:"
+  echo "  With torch < 1.9.0, every node has to use the same number of GPUs for training."
+  echo "  With torch >= 1.9.0, different nodes can use different numbers of GPUs for training."
+  exit 1
+}
+
+default='\033[0m'
+bold='\033[1m'
+red='\033[31m'
+
+function error() {
+  printf "${bold}${red}[ERROR]${default} $1\n"
+}
+
+[ ! -z "$CUDA_VISIBLE_DEVICES" ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
+[ ! -z "$master_addr" ] || ( echo; error "Please set --master-addr"; echo; usage )
+[ ! -z "$master_port" ] || ( echo; error "Please set --master-port"; echo; usage )
+[ ! -z "$node_rank" ] || ( echo; error "Please set --node-rank"; echo; usage )
+[ ! -z "$num_nodes" ] || ( echo; error "Please set --num-nodes"; echo; usage )
+
+# Number of GPUs this node uses
+num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
+
+echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
+echo "num_gpus: $num_gpus"
+echo "master_addr: $master_addr"
+
+export MASTER_ADDR=$master_addr
+export MASTER_PORT=$master_port
+
+set -x
+
+python -m torch.distributed.launch \
+  --use_env \
+  --nproc_per_node $num_gpus \
+  --nnodes $num_nodes \
+  --node_rank $node_rank \
+  --master_addr $master_addr \
+  --master_port $master_port \
+  \
+  $cur_dir/train.py \
+    --use-multi-node true \
+    --master-port $master_port \
+    --max-duration $max_duration \
+    --bucketing-sampler $bucketing_sampler \
+    --full-libri $full_libri \
+    --start-epoch $start_epoch \
+    --num-epochs $num_epochs \
+    --exp-dir $exp_dir \
+    --lang-dir $lang_dir
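A note on the mechanics of the launch command above: with `--use_env`, `torch.distributed.launch` starts `--nproc_per_node` copies of `train.py` on each node and passes rank information through environment variables instead of a `--local_rank` command-line argument. A minimal standalone sketch (not part of this patch) of what every launched worker can read from its environment:

    import os

    # Set by torch.distributed.launch when --use_env is given.
    rank = int(os.environ["RANK"])              # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of workers
    master_addr = os.environ["MASTER_ADDR"]     # address of the rendezvous node
    master_port = os.environ["MASTER_PORT"]

    print(
        f"rank={rank} local_rank={local_rank} world_size={world_size} "
        f"master={master_addr}:{master_port}"
    )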
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 80b2d924a7..81d6d4a10e 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -39,7 +39,13 @@
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -54,6 +60,17 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="""True if using multi-node multi-GPU training.
+        You are not supposed to set it directly.
+        See ./conformer_ctc/run-multi-node-multi-gpu.sh
+        for details.
+        """,
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -92,6 +109,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The directory where all training related files,
+        e.g., checkpoints and logs, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as lexicon.txt.
+        """,
+    )
+
     return parser
 
 
@@ -106,12 +140,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-          files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-          "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to
           select the model that has the lowest training loss. It is
          updated during the training.
@@ -621,9 +649,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
 
+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -640,7 +676,8 @@ def run(rank, world_size, args):
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
@@ -665,7 +702,7 @@ def run(rank, world_size, args):
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])
 
     optimizer = Noam(
         model.parameters(),
@@ -726,9 +763,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # Multi-node multi-GPU: processes were created by torch.distributed.launch
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return
 
     world_size = args.world_size
     assert world_size >= 1
+
     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index f06e013f60..dd3f1085a5 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -41,6 +41,7 @@ dl_dir=$PWD/download
 #   data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
+  500
 )
 
 # All files generated by this script are saved in "data".
@@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe
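The central point of the `run()` changes is the split between the global rank and the local rank. In the two-machine example from the launch script, the workers on machine 2 have global ranks 3, 4, and 5, while the CUDA devices visible on that machine are still numbered 0, 1, and 2, so the device and the `DDP` `device_ids` must be chosen by the local rank. A standalone sketch of that rule (`pick_device` is a hypothetical helper, not icefall code):

    import torch

    def pick_device(rank: int, local_rank: int) -> torch.device:
        """Return the device a worker should use.

        rank is the global rank; local_rank counts workers on this node.
        """
        if not torch.cuda.is_available():
            return torch.device("cpu")
        # torch.device("cuda", rank) would be out of range on any node but
        # the first, e.g., rank 4 where only cuda:0..cuda:2 are visible.
        return torch.device("cuda", local_rank)

    print(pick_device(rank=4, local_rank=1))  # -> cuda:1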
diff --git a/icefall/dist.py b/icefall/dist.py
index 203c7c563d..9dc9031778 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -21,14 +21,46 @@
 from torch import distributed as dist
 
 
-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if not is_multi_node:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")
 
 
 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
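Taken together, these helpers let a training script behave identically whether it was started by `mp.spawn()` on a single node or by `torch.distributed.launch` across several nodes. Below is a hypothetical minimal consumer, shown only to illustrate the intended call order; detecting multi-node mode from the `RANK` environment variable is an assumption made for this sketch, the patch itself relies on the `--use-multi-node` flag instead:

    import os

    import torch

    from icefall.dist import (
        cleanup_dist,
        get_local_rank,
        get_rank,
        get_world_size,
        setup_dist,
    )

    def main():
        # Assumption for this sketch: a launcher such as
        # torch.distributed.launch exports RANK; mp.spawn does not.
        is_multi_node = "RANK" in os.environ

        rank = get_rank()
        world_size = get_world_size()
        setup_dist(rank, world_size, master_port=12354, is_multi_node=is_multi_node)

        device = torch.device("cuda", get_local_rank())
        # ... build the model, move it to `device`, wrap it in
        # DDP(model, device_ids=[get_local_rank()]), run training ...

        cleanup_dist()

    if __name__ == "__main__":
        main()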