forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
distributed checkpoint average script (NVIDIA#7721)
* dist checkpoint average script Signed-off-by: Yi Dong <[email protected]> * change name Signed-off-by: Yi Dong <[email protected]> * change log message Signed-off-by: Yi Dong <[email protected]> * address comments Signed-off-by: Yi Dong <[email protected]> --------- Signed-off-by: Yi Dong <[email protected]> Signed-off-by: maxime burchi <[email protected]>
- Loading branch information
Showing
1 changed file
with
160 additions
and
0 deletions.
There are no files selected for viewing
160 changes: 160 additions & 0 deletions
160
scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Example: python scripts/checkpoint_averaging/distributed_checkpoint_averaging.py \ | ||
--name_prefix=<checkpoint name> \ | ||
--checkpoint_dir=<folder with mp_rank_X subfolders containing checkpoints> \ | ||
--steps <optionally, a list of checkpoint steps to average; if not provided, all the checkpoints are averaged> | ||
This will generate a new directory inside <checkpoint_dir> named <checkpoint name>-averaged (or <checkpoint name>-<step1>_<step2>_...-averaged when --steps is given). | ||
""" | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import shutil | ||
|
||
import zarr | ||
|
||
# Configure the root logger at import time so that the INFO-level progress
# messages emitted by main() are actually displayed.
logging.basicConfig(level=logging.INFO)
|
||
|
||
def _select_checkpoint_dirs(checkpoint_dir, steps):
    """Return the names of the checkpoint sub-directories that should be averaged.

    Args:
        checkpoint_dir: Folder containing one sub-directory per checkpoint.
        steps: Optional list of integer steps. When given, only directories
            whose name contains ``-step=<step>-`` for one of the steps are kept.

    Returns:
        Sorted list of directory names (relative to ``checkpoint_dir``).
    """
    selected = []
    # Sort so that the "first" checkpoint (whose extra files are copied) is
    # deterministic; os.listdir() returns entries in arbitrary order.
    for ckpt_dir in sorted(os.listdir(checkpoint_dir)):
        logging.info("Processing %s", ckpt_dir)
        # Stray files next to the checkpoints are not checkpoints.
        if not os.path.isdir(os.path.join(checkpoint_dir, ckpt_dir)):
            continue
        # NOTE(review): presumably the "-last" copy duplicating a step checkpoint - confirm.
        if ckpt_dir.endswith('0-last'):
            continue
        # Output of a previous run of this script; never feed it back into the average.
        if ckpt_dir.endswith('-averaged'):
            continue
        # any() keeps a directory from being added twice when steps repeat.
        if steps is None or any(f"-step={step}-" in ckpt_dir for step in steps):
            selected.append(ckpt_dir)
    return selected


def main():
    """Average the weights of several distributed checkpoints into a new one.

    Command-line arguments:
        --name_prefix: Name of the final checkpoint (``-averaged`` is appended).
        --checkpoint_dir: Folder containing all the distributed checkpoints.
        --steps: Optional list of checkpoint steps to average (default: all).

    Every sub-directory of a checkpoint is opened as a zarr array and averaged
    element-wise across the selected checkpoints, except ``*._extra_state`` and
    ``optimizer.*`` entries. Those, and any plain files, are copied unchanged
    from the first selected checkpoint. The result is written to
    ``<checkpoint_dir>/<name_prefix>[-<steps>]-averaged``.

    Raises:
        ValueError: If no checkpoint matches, if a weight is missing from some
            of the checkpoints, or if a weight has an integer dtype.
    """
    parser = argparse.ArgumentParser()
    # Both paths are required: without them the script can only fail later
    # with an unrelated TypeError.
    parser.add_argument(
        '--name_prefix', required=True, help='Name of the final checkpoint. Will append -averaged automatically.',
    )
    parser.add_argument(
        '--checkpoint_dir', required=True, help='Folder containing all the distributed checkpoints.',
    )
    # list of checkpoint steps to average
    parser.add_argument(
        '--steps',
        nargs='+',
        type=int,
        help='List of checkpoint steps to average. If not specified, will average all.',
    )

    args = parser.parse_args()

    if args.steps is not None:
        logging.info(f"Will average only steps {args.steps}")

    checkpoint_paths = _select_checkpoint_dirs(args.checkpoint_dir, args.steps)
    n = len(checkpoint_paths)
    if n == 0:
        # Fail loudly instead of reporting a "saved" checkpoint that does not exist.
        raise ValueError(f"No checkpoints to average were found in {args.checkpoint_dir}")

    logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

    # running element-wise sums of the weights, keyed by array (sub-directory) name
    avg_weights = {}
    # number of checkpoints each weight was found in, to detect incomplete checkpoints
    seen_counts = {}
    # items that are copied verbatim (from the first checkpoint) into the new checkpoint folder
    copy_items = []
    for ix, path in enumerate(checkpoint_paths):
        full_path = os.path.join(args.checkpoint_dir, path)

        for item in os.listdir(full_path):
            item_path = os.path.join(full_path, item)

            # Not averaged: plain files, transformer engine states
            # (._extra_state) and optimizer states (no point in averaging them).
            is_weight = (
                os.path.isdir(item_path)
                and not item.endswith('._extra_state')
                and not item.startswith('optimizer.')
            )
            if not is_weight:
                if ix == 0:
                    copy_items.append(item_path)
                continue

            array_z = zarr.open(item_path, mode='r')
            if item not in avg_weights:
                logging.info(f"Initialized average weights dict with: {item}")
                avg_weights[item] = array_z
            else:
                logging.info(f"Updated average weights dict with weight: {item}")
                sum_array = avg_weights[item][:] + array_z[:]
                avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)
            seen_counts[item] = seen_counts.get(item, 0) + 1

    # Dividing by n is only correct when every checkpoint contributed the weight.
    incomplete = sorted(name for name, count in seen_counts.items() if count != n)
    if incomplete:
        raise ValueError(f"Weights missing from some of the {n} checkpoints: {incomplete}")

    for k in avg_weights:
        logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
        # True division cannot be stored back into a signed or unsigned integer array.
        if avg_weights[k].dtype.kind in ('i', 'u'):
            raise ValueError("Int type not supported")
        array_z = avg_weights[k][:] / n
        avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)

    # Save model
    if args.steps is None:
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-averaged')
    else:
        steps_combined = '_'.join([str(x) for x in args.steps])
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-' + steps_combined + '-averaged')

    # Make sure the target is a directory: shutil.copy() onto a missing path
    # would otherwise create a regular file named like the checkpoint.
    os.makedirs(ckpt_name, exist_ok=True)

    # save avg_weights
    for k in avg_weights:
        logging.info(f"Saving {k} to {ckpt_name}")
        zarr.save(os.path.join(ckpt_name, k), avg_weights[k])

    # copy other files
    for item in copy_items:
        if os.path.isfile(item):
            # copy single file
            logging.info(f"Copying file {item} to {ckpt_name}")
            shutil.copy(item, ckpt_name)
        else:
            # copy directory
            logging.info(f"Copying directory {item} to {ckpt_name}")
            shutil.copytree(item, os.path.join(ckpt_name, os.path.basename(item)), dirs_exist_ok=True)

    logging.info(f"Averaged distributed checkpoint saved as : {ckpt_name}")
|
||
|
||
# Run the averaging only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()