Release/2.8 (#8437)
* [XPU] llama add xpu support (#8282)

* [XPU] llama add xpu support

* fix

* use try import

* fix

* refine

* refine

* refine

* refine

* update (#8399)

* [LLM] Support fuse attention q, k, v weights  (#8202)

1. add use-interface & fuse action

1.1. modify 1., code order

2. switch to name_mapping

3. solve tp branch

3.2 follow hui, handle qkv separately

3.3 handle pdparams

3.4 from torch

3.5 abandon low_cpu_mem_usage

3.6 solve shard branch

* 3.6.1 solve shard branch after rebase develop

* code clean

* remove debug comment

* Redefine fuse and split functions

* Redefine fuse and split functions

* comment and fix

* update method

* update QKV fuse and split

* support fuse weights in multi-files

* add precision compare

* simplify function call

* support use_fast_ffn

* clean modeling and configuration

* add test for gpt and opt

* fix tp_actions get

* add fast_ffn test

* add Qwen2Moe

* Revert "add Qwen2Moe"

This reverts commit 113b883.

* add test for split

* update doc

* update filter_dict_keys

---------

Co-authored-by: Zii <[email protected]>
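
As a rough illustration of the fuse/split idea this commit redefines (a minimal NumPy sketch, not the PaddleNLP implementation; the names fuse_qkv and split_qkv are made up here): fusing concatenates the three projection weights along the output dimension so one matmul produces q, k and v together, and splitting is the exact inverse.

    import numpy as np

    def fuse_qkv(q_w, k_w, v_w):
        # Stack the three projection matrices along the output dimension,
        # so a single matmul yields [q | k | v] in one pass.
        return np.concatenate([q_w, k_w, v_w], axis=-1)

    def split_qkv(qkv_w, q_dim, k_dim):
        # Inverse operation: cut the fused matrix back into q, k and v blocks.
        return np.split(qkv_w, [q_dim, q_dim + k_dim], axis=-1)

    hidden = 8
    q, k, v = (np.random.rand(hidden, hidden) for _ in range(3))
    fused = fuse_qkv(q, k, v)                      # shape (8, 24)
    q2, k2, v2 = split_qkv(fused, hidden, hidden)  # round trip
    assert all(np.allclose(a, b) for a, b in [(q, q2), (k, k2), (v, v2)])

The precision-compare step listed in the commit message would then verify that outputs from fused and unfused weights match.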

* [LLM] Fix fuse or split with same key (#8378)

* fix fuse or split with same key

* fix

* fix eps

* update format

* [LLM] add decay steps option for finetuning (#8251)

* [LLM] add memory stats to logger of trainer (#8269)

* [Distributed] fix lora (#8325)

* [LLM] fix lora target modules on llama (#8372)

* [Distributed] metric calculation supports tp logits (#8370)

* Update model_utils.py

* Update model_utils.py

* Update model_utils.py

---------

Co-authored-by: Jianbang Yang <[email protected]>
Co-authored-by: DrownFish19 <[email protected]>
Co-authored-by: Zii <[email protected]>
Co-authored-by: Tian <[email protected]>
5 people authored May 24, 2024
1 parent 8879f79 commit bbf945b
Showing 16 changed files with 886 additions and 47 deletions.
6 changes: 3 additions & 3 deletions llm/finetune_generation.py
@@ -140,7 +140,7 @@ def main():
         if not training_args.autotuner_benchmark:
             model = AutoModelForCausalLMPipe.from_pretrained(
                 model_args.model_name_or_path,
-                tensor_parallel_output=False,
+                tensor_parallel_output=training_args.tensor_parallel_output,
                 tensor_parallel_degree=training_args.tensor_parallel_degree,
                 tensor_parallel_rank=training_args.tensor_parallel_rank,
                 use_flash_attention=model_args.use_flash_attention,
@@ -152,7 +152,7 @@
             # NOTE(gongenlei): new add autotuner_benchmark
             model_config = AutoConfig.from_pretrained(
                 model_args.model_name_or_path,
-                tensor_parallel_output=False,
+                tensor_parallel_output=training_args.tensor_parallel_output,
                 tensor_parallel_degree=training_args.tensor_parallel_degree,
                 tensor_parallel_rank=training_args.tensor_parallel_rank,
                 dtype=dtype,
@@ -163,7 +163,7 @@
     else:
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,

11 changes: 11 additions & 0 deletions llm/run_pretrain.py
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device


 def add_start_docstrings(*docstr):
@@ -483,6 +484,16 @@ def main():
             config.num_attention_heads % config.sep_parallel_degree == 0
         ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK, not use accumulate_steps optimization
+            pass
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model
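
The guarded import above keeps the XPU-specific tuning optional: if paddle_xpu is not installed, pre-training simply runs without the accumulate-steps optimization. A minimal sketch of the same pattern factored into a helper (the helper name is hypothetical; LinearConfig and its two methods are taken from the diff and assumed to exist only in XPU builds):

    def maybe_enable_xpu_accumulate_opt(accumulation_steps: int) -> bool:
        # Enable the paddle_xpu linear accumulate-steps optimization if available.
        if accumulation_steps <= 1:
            return False
        try:
            from paddle_xpu.layers.nn.linear import LinearConfig
        except ImportError:
            return False  # fine to run without the optimization
        LinearConfig.enable_accumulate_steps_opt()
        LinearConfig.set_accumulate_steps(accumulation_steps)
        return True
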
9 changes: 9 additions & 0 deletions llm/utils.py
@@ -125,9 +125,11 @@ def get_lora_target_modules(model):
             ".*v_proj.*",
             ".*k_proj.*",
             ".*o_proj.*",
+            ".*qkv_proj.*",
             ".*gate_proj.*",
             ".*down_proj.*",
             ".*up_proj.*",
+            ".*gate_up_fused_proj.*",
         ]
     elif model.base_model_prefix == "opt":
         target_modules = [
@@ -209,6 +211,13 @@ def prediction_step(
             # keepdim in order to maintain the same shape as logits
             if isinstance(logits, (list, tuple)):
                 logits = logits[0]
+            # all gather logits when enabling tensor_parallel_output
+            if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                gathered_logits = []
+                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
+                logits = paddle.concat(gathered_logits, axis=-1)
             return (loss, logits.argmax(axis=-1, keepdim=True), labels)

         loss = None
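
With a column-parallel LM head, each tensor-parallel rank holds only a slice of the vocabulary dimension, so taking argmax over a single shard could pick the wrong token; the all_gather plus concat above rebuilds the full-vocab logits first. A single-process NumPy sketch of why the concatenation is along the last axis (the shard values are made up):

    import numpy as np

    full_logits = np.array([[0.1, 0.9, 0.2, 0.4]])   # batch of 1, vocab size 4
    shards = np.split(full_logits, 2, axis=-1)       # what two tp ranks would each hold
    rebuilt = np.concatenate(shards, axis=-1)        # what all_gather + concat restores
    assert (rebuilt.argmax(axis=-1) == full_logits.argmax(axis=-1)).all()
    # rank 1 alone would report index 1 of its own shard (global id 3), which is wrong
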
2 changes: 1 addition & 1 deletion paddlenlp/peft/lora/lora_layers.py
@@ -539,7 +539,7 @@ def forward(self, input: paddle.Tensor):
             result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
         else:
             res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
-            result_mp = res_mp + self.bias
+            result_mp = (res_mp + self.bias) if self.bias is not None else res_mp

         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
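
The one-line change above matters because the LoRA column-parallel layer can be built without a bias; self.bias is then None and the unconditional add raises. A tiny sketch of the guarded pattern outside any LoRA layer (tensor shape and values are arbitrary):

    import paddle

    res_mp = paddle.ones([2, 4])
    bias = None  # e.g. the layer was created with has_bias=False
    result_mp = (res_mp + bias) if bias is not None else res_mp  # no error when bias is None
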
18 changes: 16 additions & 2 deletions paddlenlp/trainer/trainer.py
@@ -39,6 +39,8 @@
 import paddle.distributed as dist
 import paddle.nn as nn
 from packaging import version
+from paddle import framework
+from paddle.base import core
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
@@ -1257,6 +1259,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

+            divisor = 2**30
+            # TODO(@gexiao): replace these codes with unified APIs in Paddle
+            current_device = framework._current_expected_place_()
+            if str(current_device) != "Place(cpu)":
+                device_id = current_device.get_device_id()
+                current_memory_allocated = core.device_memory_stat_current_value("Allocated", device_id)
+                current_memory_reserved = core.device_memory_stat_current_value("Reserved", device_id)
+                max_memory_allocated = core.device_memory_stat_peak_value("Allocated", device_id)
+                max_memory_reserved = core.device_memory_stat_peak_value("Reserved", device_id)
+                logs["current_memory_allocated"] = current_memory_allocated / divisor
+                logs["current_memory_reserved"] = current_memory_reserved / divisor
+                logs["max_memory_allocated"] = max_memory_allocated / divisor
+                logs["max_memory_reserved"] = max_memory_reserved / divisor
+
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
@@ -1614,8 +1630,6 @@ def _load_rng_state(self, checkpoint):
         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])

-        core = paddle.framework.core
-
         core.default_cpu_generator().set_state(checkpoint_rng_state["cpu"])
         if core.is_compiled_with_cuda():
             if not len(checkpoint_rng_state["cuda"]) == core.get_cuda_device_count():
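
The new logging block reports allocator statistics in GiB by dividing byte counts by 2**30 (so 3,221,225,472 bytes logs as 3.0). A standalone sketch of the same conversion, reusing only the helpers that appear in the diff and skipping CPU-only runs:

    from paddle import framework
    from paddle.base import core

    GIB = 2**30
    place = framework._current_expected_place_()
    if str(place) != "Place(cpu)":
        device_id = place.get_device_id()
        allocated = core.device_memory_stat_current_value("Allocated", device_id) / GIB
        peak = core.device_memory_stat_peak_value("Allocated", device_id) / GIB
        print(f"current_memory_allocated: {allocated:.2f} GiB, max_memory_allocated: {peak:.2f} GiB")
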
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -787,6 +787,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )

     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
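
A quick way to confirm the new option and its default, assuming a PaddleNLP build that already contains this commit; in the llm scripts the flag is supplied on the command line or in the JSON config like any other TrainingArguments field:

    from dataclasses import fields

    from paddlenlp.trainer import TrainingArguments

    # The new field defaults to False, matching the previously hard-coded behaviour.
    tp_out = next(f for f in fields(TrainingArguments) if f.name == "tensor_parallel_output")
    print(tp_out.default, tp_out.metadata["help"])

When it is set to True together with tensor_parallel_degree > 1, the prediction_step change in llm/utils.py gathers the per-rank logits shards before metrics are computed.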