[Bug Fix] Fix merge parameters in pp (#8239)
* update merge pp

* update des

* renew
Southpika authored Apr 16, 2024
1 parent 0790824 commit c3ec984
Showing 2 changed files with 42 additions and 9 deletions.
3 changes: 2 additions & 1 deletion llm/docs/finetune.md
@@ -202,7 +202,8 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
 
 ```
 python merge_tp_and_pp_params.py \
-    --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100
+    --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100 \
+    --pp 2 --tp 4
 ```
 
 <summary>&emsp; Script argument descriptions</summary><div>
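Note: the new `--tp`/`--pp` values must match the shard layout inside the checkpoint directory. A minimal sketch of the filenames the script will look for with `--tp 4 --pp 2` (the naming scheme is taken from `validate_model_file` in the Python diff below; the degrees here are illustrative):

```python
# Shard filenames expected for a checkpoint saved with tp=4, pp=2.
# Mirrors the naming scheme checked by validate_model_file() below.
tp_degree, pp_degree = 4, 2
expected = [
    f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams"
    for tp in range(tp_degree)
    for pp in range(pp_degree)
]
print(expected[0], "...", expected[-1])
# model_state.tp00_pp00.pdparams ... model_state.tp03_pp01.pdparams
```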
48 changes: 40 additions & 8 deletions llm/merge_tp_and_pp_params.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
 # limitations under the License.
 import importlib
 import os
+import re
 
 import paddle
 
@@ -27,9 +28,33 @@ def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of model.")
     parser.add_argument("--device", type=str, default="gpu", help="Device")
+    parser.add_argument("--pipeline_parallel_degree", "--pp", type=int, required=True, help="pp degree")
+    parser.add_argument("--tensor_parallel_degree", "--tp", type=int, required=True, help="tp degree")
     return parser.parse_args()
 
 
+def validate_model_file(path: str, tp_degree: int, pp_degree: int) -> None:
+    files = os.listdir(path)
+    pattern = r"model_state\.tp0\d*_pp0\d*\.pdparams|model_state\.tp0\d*\.pdparams|model_state\.pp0\d*\.pdparams"
+    if pp_degree == 0:
+        target_files = [f"model_state.tp{tp:0>2d}.pdparams" for tp in range(tp_degree)]
+    elif tp_degree == 0:
+        target_files = [f"model_state.pp{pp:0>2d}.pdparams" for pp in range(pp_degree)]
+    else:
+        target_files = [
+            f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams" for tp in range(tp_degree) for pp in range(pp_degree)
+        ]
+
+    exist_required_files = []
+    for file in files:
+        if re.match(pattern, file):
+            exist_required_files.append(file)
+
+    missing_files = set(target_files) - set(exist_required_files)
+    if len(missing_files) > 0:
+        raise FileNotFoundError(f"Please check your pp/tp degree, missing files {list(missing_files)}")
+
+
 def load_tp_params(tp_degree, path):
     tp_state_dict_list = []
     for tp in range(tp_degree):
@@ -102,23 +127,30 @@ def main():
     paddle.set_device(args.device)
     config = AutoConfig.from_pretrained(args.model_name_or_path)
     init_class = config["architectures"][0]
-    import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
+    if args.pipeline_parallel_degree > 1:
+        # using pp
+        import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-15]]}.modeling_pp")
+    else:
+        # tp only
+        import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
     model_class = getattr(import_class, init_class)
 
-    if config.tensor_parallel_degree > 1:
-        if config.pipeline_parallel_degree > 1:
+    validate_model_file(args.model_name_or_path, args.tensor_parallel_degree, args.pipeline_parallel_degree)
+
+    if args.tensor_parallel_degree > 1:
+        if args.pipeline_parallel_degree > 1:
             tp_state_dict_list = load_tp_and_pp_params(
-                config.tensor_parallel_degree, config.pipeline_parallel_degree, args.model_name_or_path
+                args.tensor_parallel_degree, args.pipeline_parallel_degree, args.model_name_or_path
             )
         else:
-            tp_state_dict_list = load_tp_params(config.tensor_parallel_degree, args.model_name_or_path)
+            tp_state_dict_list = load_tp_params(args.tensor_parallel_degree, args.model_name_or_path)
         state_dict_to_save = merge_tensor_parallel(
             model_class=model_class, state_dict_list=tp_state_dict_list, config=config
         )
         logger.info("Saving")
         paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
-    elif config.pipeline_parallel_degree > 1:
-        state_dict_to_save = load_pp_params(config.pipeline_parallel_degree, args.model_name_or_path)
+    elif args.pipeline_parallel_degree > 1:
+        state_dict_to_save = load_pp_params(args.pipeline_parallel_degree, args.model_name_or_path)
         logger.info("Saving")
         paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
     else:
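A note on the slice constants in `main()`: pipeline-parallel checkpoints record the `*Pipe` architecture class, so `init_class[:-15]` strips the 15-character `ForCausalLMPipe` suffix to recover the model-family key for `MAPPING_NAMES`, while `[:-11]` strips the 11-character `ForCausalLM` suffix in the tensor-parallel-only case. A quick illustration, assuming a Llama checkpoint (the class names here are examples, not read from a real config):

```python
# Pipeline-parallel runs record the *Pipe class in config["architectures"].
pp_arch = "LlamaForCausalLMPipe"
print(pp_arch[:-15])  # Llama  (len("ForCausalLMPipe") == 15)

tp_arch = "LlamaForCausalLM"
print(tp_arch[:-11])  # Llama  (len("ForCausalLM") == 11)
```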

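And a self-contained sketch of what the new pre-merge validation catches, re-creating the missing-shard computation from `validate_model_file` against a toy checkpoint directory (degrees and filenames are made up for the demo):

```python
import os
import tempfile

# Toy tp=2, pp=2 checkpoint directory with one shard deliberately missing.
with tempfile.TemporaryDirectory() as ckpt:
    for name in [
        "model_state.tp00_pp00.pdparams",
        "model_state.tp00_pp01.pdparams",
        "model_state.tp01_pp00.pdparams",
    ]:
        open(os.path.join(ckpt, name), "w").close()

    # The same set difference validate_model_file() raises FileNotFoundError on.
    target_files = [
        f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams"
        for tp in range(2)
        for pp in range(2)
    ]
    missing = set(target_files) - set(os.listdir(ckpt))
    print(missing)  # {'model_state.tp01_pp01.pdparams'}
```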