Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug Fix]Fix merge parameters in pp #8239

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llm/docs/finetune.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./

```
python merge_tp_and_pp_params.py \
--model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100
--model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100 \
--pp 2 --tp 4
```

<summary>&emsp; 脚本参数介绍</summary><div>
Expand Down
48 changes: 40 additions & 8 deletions llm/merge_tp_and_pp_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
# limitations under the License.
import importlib
import os
import re

import paddle

Expand All @@ -27,9 +28,33 @@ def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of model.")
parser.add_argument("--device", type=str, default="gpu", help="Device")
parser.add_argument("--pipeline_parallel_degree", "--pp", type=int, required=True, help="pp degree")
parser.add_argument("--tensor_parallel_degree", "--tp", type=int, required=True, help="tp degree")
return parser.parse_args()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改一下readme



def validate_model_file(path: str, tp_degree: int, pp_degree: int) -> None:
files = os.listdir(path)
pattern = r"model_state\.tp0\d*_pp0\d*\.pdparams|model_state\.tp0\d*\.pdparams|model_state\.pp0\d*\.pdparams"
if pp_degree == 0:
target_files = [f"model_state.tp{tp:0>2d}.pdparams" for tp in range(tp_degree)]
elif tp_degree == 0:
target_files = [f"model_state.pp{pp:0>2d}.pdparams" for pp in range(pp_degree)]
else:
target_files = [
f"model_state.tp{tp:0>2d}_pp{pp:0>2d}.pdparams" for tp in range(tp_degree) for pp in range(pp_degree)
]

exist_required_files = []
for file in files:
if re.match(pattern, file):
exist_required_files.append(file)

missing_files = set(target_files) - set(exist_required_files)
if len(missing_files) > 0:
raise FileNotFoundError(f"Please check your pp/tp degree, missing files {list(missing_files)}")


def load_tp_params(tp_degree, path):
tp_state_dict_list = []
for tp in range(tp_degree):
Expand Down Expand Up @@ -102,23 +127,30 @@ def main():
paddle.set_device(args.device)
config = AutoConfig.from_pretrained(args.model_name_or_path)
init_class = config["architectures"][0]
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
if args.pipeline_parallel_degree > 1:
# using pp
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-15]]}.modeling_pp")
else:
# tp only
import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
model_class = getattr(import_class, init_class)

if config.tensor_parallel_degree > 1:
if config.pipeline_parallel_degree > 1:
validate_model_file(args.model_name_or_path, args.tensor_parallel_degree, args.pipeline_parallel_degree)

if args.tensor_parallel_degree > 1:
if args.pipeline_parallel_degree > 1:
tp_state_dict_list = load_tp_and_pp_params(
config.tensor_parallel_degree, config.pipeline_parallel_degree, args.model_name_or_path
args.tensor_parallel_degree, args.pipeline_parallel_degree, args.model_name_or_path
)
else:
tp_state_dict_list = load_tp_params(config.tensor_parallel_degree, args.model_name_or_path)
tp_state_dict_list = load_tp_params(args.tensor_parallel_degree, args.model_name_or_path)
state_dict_to_save = merge_tensor_parallel(
model_class=model_class, state_dict_list=tp_state_dict_list, config=config
)
logger.info("Saving")
paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
elif config.pipeline_parallel_degree > 1:
state_dict_to_save = load_pp_params(config.pipeline_parallel_degree, args.model_name_or_path)
elif args.pipeline_parallel_degree > 1:
state_dict_to_save = load_pp_params(args.pipeline_parallel_degree, args.model_name_or_path)
logger.info("Saving")
paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
else:
Expand Down
Loading