Commit e7c4b1e
update
zhangyuqin1998 committed May 31, 2024
1 parent 16eaedd commit e7c4b1e
Showing 8 changed files with 53 additions and 46 deletions.
10 changes: 9 additions & 1 deletion docs/trainer.md
@@ -576,7 +576,15 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
the following configs are supported:
enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with a single `allreduce_avg` collective when scaling gradients in data parallel, which improves performance (see the sketch below). ONLY supported for auto mode now.
gradient_sync_after_accumulate: moves gradient synchronization from the backward pass into the optimizer step when gradient accumulation is enabled, which reduces the number of syncs and improves performance at the cost of higher memory usage. ONLY supported for auto mode now.
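As an illustration of what the first flag changes, here is a hedged sketch of the default `allreduce_sum + scale` pattern in data parallel; the fused `allreduce_avg` variant is applied by the auto-parallel pass rather than by user code, and `dp_world_size` is an assumed name for the data-parallel group size:

```python
import paddle.distributed as dist

def scale_gradient_sum_then_scale(grad, dp_world_size):
    """Default pattern: an allreduce_sum collective followed by a separate scale kernel."""
    dist.all_reduce(grad)                     # SUM across data-parallel ranks (in place)
    return grad.scale_(1.0 / dp_world_size)   # extra elementwise pass over every gradient tensor

# With enable_allreduce_avg_in_gradinent_scale, the auto-parallel pass fuses these two steps
# into a single allreduce_avg collective, removing the extra scale pass over the gradients.
```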
--context_parallel_degree
Context parallelism is a parallel method that splits the training data along the sequence dimension.
It uses Ring FlashAttention to keep the attention results correct after splitting: the complete attention scores are obtained through ring communication and iterative updates (a sketch of this accumulation is shown below).
The default value is -1, which disables context parallelism.
(`int`, optional, defaults to `-1`)
(Note: this method requires changes to the model structure; currently only LLaMA is supported.)
(Note: this method incurs a large communication overhead; it is recommended only for very long sequences, e.g. 1024k.)
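For intuition, the following is a minimal NumPy sketch (single head, non-causal, no flash kernels, hypothetical function name) of the iterative update that ring attention performs: each rank keeps its local query chunk while key/value chunks circulate around the ring, and partial results are merged with online-softmax rescaling. This is illustrative only, not the PaddleNLP implementation:

```python
import numpy as np

def ring_attention_sketch(q, kv_blocks):
    """q: [seq_q, dim] local query chunk; kv_blocks: (k, v) chunks arriving one ring step at a time."""
    dim = q.shape[-1]
    out = np.zeros_like(q, dtype=np.float64)
    row_max = np.full((q.shape[0], 1), -np.inf)
    row_sum = np.zeros((q.shape[0], 1))
    for k, v in kv_blocks:                       # one iteration per received KV chunk
        scores = q @ k.T / np.sqrt(dim)
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        probs = np.exp(scores - new_max)
        rescale = np.exp(row_max - new_max)      # correct previously accumulated results
        row_sum = row_sum * rescale + probs.sum(axis=-1, keepdims=True)
        out = out * rescale + probs @ v
        row_max = new_max
    return out / row_sum                         # matches softmax(q @ K.T / sqrt(dim)) @ V over all chunks
```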
--recompute
Whether to use recomputation (activation checkpointing) during training. It saves GPU memory.
The forward pass is recomputed when gradients are needed, reducing the memory held by intermediate activations (see the sketch below).
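A minimal sketch of what recomputation does at the layer level, assuming Paddle's `paddle.distributed.fleet.utils.recompute` utility; the trainer wires this up automatically, so this is for illustration only:

```python
import paddle
from paddle.distributed.fleet.utils import recompute

class Block(paddle.nn.Layer):
    def __init__(self, hidden=1024):
        super().__init__()
        self.linear = paddle.nn.Linear(hidden, hidden)

    def forward(self, x, use_recompute=True):
        if use_recompute:
            # activations inside _inner are dropped after forward and recomputed during backward
            return recompute(self._inner, x)
        return self._inner(x)

    def _inner(self, x):
        return paddle.nn.functional.relu(self.linear(x))
```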
8 changes: 2 additions & 6 deletions llm/llama/run_trainer_tp2cp2.sh
@@ -33,17 +33,13 @@ unset PADDLE_ELASTIC_TIMEOUT

max_seq_length=1024

master=127.0.0.1
port=36677

max_steps=10000
max_steps=1000
log_dir=seq_${max_seq_length}_log
echo "log_dir:${log_dir}"
rm -rf $log_dir

export PYTHONPATH=../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--master $master:$port \
--gpus "3,4,5,7" \
--log_dir "./$log_dir" \
run_pretrain.py \
@@ -78,7 +74,7 @@ python -u -m paddle.distributed.launch \
--recompute_use_reentrant true \
--data_cache "./data_cache" \
--pipeline_parallel_degree 1 \
--cp_parallel_degree 2 \
--context_parallel_degree 2 \
--tensor_parallel_degree 2 \
--sequence_parallel false \
--skip_profile_timer true \
6 changes: 3 additions & 3 deletions llm/run_pretrain.py
@@ -485,15 +485,15 @@ def main():
config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

config.sep_parallel_degree = training_args.sep_parallel_degree
config.cp_parallel_degree = training_args.cp_parallel_degree
config.context_parallel_degree = training_args.context_parallel_degree
if config.sequence_parallel:
assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel."
assert (
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
assert (
config.seq_length % config.cp_parallel_degree == 0
), f"seq_length:{config.seq_length} must be divisible by cp_parallel_degree {config.cp_parallel_degree}"
config.seq_length % config.context_parallel_degree == 0
), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
8 changes: 4 additions & 4 deletions paddlenlp/trainer/trainer.py
@@ -764,8 +764,8 @@ def train(
trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
if self.args.sep_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.sep_parallel_degree
if self.args.cp_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.cp_parallel_degree
if self.args.context_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.context_parallel_degree

# the numel is approximate, because tensor parallel ranks still hold their own bias or layer_norm weights without splitting them,
# so the trainable numel is a little bigger than the real value.
logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
@@ -900,7 +900,7 @@ def _inner_training_loop(
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
if self.args.use_hybrid_parallel and self.args.cp_parallel_degree > 1:
if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
inputs = split_inputs_sequence_dim_load_balance(inputs)
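# Editor's note (illustrative, not part of this diff): split_inputs_sequence_dim_load_balance
# splits each input along the sequence axis across context-parallel ranks in a load-balanced way.
# With causal ring attention a naive contiguous split leaves later ranks with more work, so each
# rank takes chunk i and chunk (2*cp - 1 - i). A hypothetical equivalent for a single tensor:
#
#     def split_seq_load_balance(x, cp_rank, cp_degree, axis=1):
#         chunks = paddle.split(x, 2 * cp_degree, axis=axis)
#         return paddle.concat([chunks[cp_rank], chunks[2 * cp_degree - 1 - cp_rank]], axis=axis)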

self.timers and self.timers("read-data").stop()
os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
@@ -1765,7 +1765,7 @@ def _wrap_model(self, model, training=True):
in_sharding_parallel_mode = self.sharding is not None
in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
in_sep_parallel_mode = self.args.sep_parallel_degree > 1
in_cp_parallel_mode = self.args.cp_parallel_degree > 1
in_cp_parallel_mode = self.args.context_parallel_degree > 1

# Multi-gpu training
if (
31 changes: 17 additions & 14 deletions paddlenlp/trainer/training_args.py
@@ -230,9 +230,9 @@ class TrainingArguments:
The paddle sequence parallel strategy. It can reduce the GPU memory used by activations to 1/sep, and it is orthogonal to
the data parallel, sharding stage1, tensor parallel and pipeline parallel strategies.
)
cp_parallel_degree (`int`, *optional*, defaults to `-1`)(
The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
context_parallel_degree (`int`, *optional*, defaults to `-1`)(
Context parallelism is a parallel method that segments training data in the sequence dimension.
This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates.
)
data_parallel_config (`str`, *optional*)(
Some additional configs which affect data parallel performance, we provide some option to config it.
@@ -587,7 +587,7 @@ class TrainingArguments:
)
},
)
cp_parallel_degree: int = field(
context_parallel_degree: int = field(
default=-1,
metadata={
"help": (
@@ -931,7 +931,7 @@ def __post_init__(self):
if world_size > 1:
tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
sep_parallel_degree = max(self.sep_parallel_degree, 1)
cp_parallel_degree = max(self.cp_parallel_degree, 1)
context_parallel_degree = max(self.context_parallel_degree, 1)
pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
@@ -941,7 +941,10 @@ def __post_init__(self):
if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
self.sharding_parallel_degree = world_size // (
tensor_parallel_degree * sep_parallel_degree * cp_parallel_degree * pipeline_parallel_degree
tensor_parallel_degree
* sep_parallel_degree
* context_parallel_degree
* pipeline_parallel_degree
)

sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
@@ -953,7 +956,7 @@ def __post_init__(self):
sharding_parallel_degree
* tensor_parallel_degree
* sep_parallel_degree
* cp_parallel_degree
* context_parallel_degree
* pipeline_parallel_degree
)
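# Editor's note (illustrative, not part of this diff): with world_size = 8, tensor_parallel_degree = 2,
# sep_parallel_degree = 1, context_parallel_degree = 2 and pipeline_parallel_degree = 1, the derived
# sharding degree is 8 // (2 * 1 * 2 * 1) == 2, and 8 % (2 * 2 * 1 * 2 * 1) == 0, so the parallel
# degrees exactly cover all ranks.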

@@ -962,22 +965,22 @@ def __post_init__(self):
or tensor_parallel_degree > 1
or pipeline_parallel_degree > 1
or self.sep_parallel_degree > 1
or self.cp_parallel_degree > 1
or self.context_parallel_degree > 1
):
self.use_hybrid_parallel = True
self.sharding_parallel_degree = sharding_parallel_degree
self.tensor_parallel_degree = tensor_parallel_degree
self.pipeline_parallel_degree = pipeline_parallel_degree
self.sep_parallel_degree = sep_parallel_degree
self.cp_parallel_degree = cp_parallel_degree
self.context_parallel_degree = context_parallel_degree

if not self.use_hybrid_parallel:
self.sharding = []
self.sharding_parallel_degree = -1
self.tensor_parallel_degree = -1
self.pipeline_parallel_degree = -1
self.sep_parallel_degree = -1
self.cp_parallel_degree = -1
self.context_parallel_degree = -1

if self.hybrid_parallel_topo_order is None:
self.hybrid_parallel_topo_order = "pp_first"
@@ -1180,7 +1183,7 @@ def is_segment_parallel_supported():
"sharding_degree": self.sharding_parallel_degree,
"sep_degree": self.sep_parallel_degree
if self.sep_parallel_degree > 1
else self.cp_parallel_degree,
else self.context_parallel_degree,
"order": order,
}
else:
@@ -1264,7 +1267,7 @@ def is_segment_parallel_supported():
elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
self.cp_parallel_degree = max(self.cp_parallel_degree, 1)
self.context_parallel_degree = max(self.context_parallel_degree, 1)
self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
@@ -1276,7 +1279,7 @@ def is_segment_parallel_supported():
self.sharding_parallel_degree = world_size // (
self.tensor_parallel_degree
* self.sep_parallel_degree
* self.cp_parallel_degree
* self.context_parallel_degree
* self.pipeline_parallel_degree
)

@@ -1289,7 +1292,7 @@ def is_segment_parallel_supported():
self.sharding_parallel_degree
* self.tensor_parallel_degree
* self.sep_parallel_degree
* self.cp_parallel_degree
* self.context_parallel_degree
* self.pipeline_parallel_degree
)

2 changes: 1 addition & 1 deletion paddlenlp/transformers/configuration_utils.py
@@ -467,7 +467,7 @@ def __init__(self, **kwargs):
self.tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)
# Parameters for sep and cp
self.sep_parallel_degree = kwargs.pop("sep_parallel_degree", -1)
self.cp_parallel_degree = kwargs.pop("cp_parallel_degree", -1)
self.context_parallel_degree = kwargs.pop("context_parallel_degree", -1)
# If set to True, this option is used with fleet.meta_parallel.ParallelCrossEntropy
# to calculate cross-entropy loss for parallel model.
self.tensor_parallel_output = kwargs.pop("tensor_parallel_output", False)
14 changes: 7 additions & 7 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -62,15 +62,15 @@ def fusion_rope(
position_ids,
past_key_value,
rotary_emb,
cp_parallel_degree=-1,
context_parallel_degree=-1,
):
if get_env_device() != "gcu":
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
if cp_parallel_degree > 1:
if context_parallel_degree > 1:
assert get_env_device() == "gpu", "context parallel only supports CUDA devices for now"
kv_seq_len *= cp_parallel_degree
kv_seq_len *= context_parallel_degree
if get_env_device() != "gcu":
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() == "npu":
@@ -156,7 +156,7 @@ def fusion_flash_attention(
if version != "0.0.0" and version <= "2.5.2":
if alibi is not None:
raise ValueError("Flash Attention doesn't support alibi")
if config.cp_parallel_degree > 1:
if config.context_parallel_degree > 1:
raise ValueError(f"Context parallel is not implemented in version {version}")
attn_output, attn_weights = flash_attention(
query_states,
@@ -170,7 +170,7 @@ def fusion_flash_attention(
alibi = alibi.reshape([bsz, num_heads, 1, -1])
attention_mask = attention_mask.cast(alibi.dtype) + alibi
if get_env_device() == "npu":
if config.cp_parallel_degree > 1:
if config.context_parallel_degree > 1:
raise ValueError("Context parallel is not implemented for npu")

Check warning on line 174 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L173-L174

Added lines #L173 - L174 were not covered by tests
attn_output = core.eager._run_custom_op(
"flash_attention_npu",
@@ -186,7 +186,7 @@ def fusion_flash_attention(
npu_is_casual,
)[0]
elif get_env_device() == "gcu":
if config.cp_parallel_degree > 1:
if config.context_parallel_degree > 1:
raise ValueError("Context parallel is not implemented for gcu")

Check warning on line 190 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L189-L190

Added lines #L189 - L190 were not covered by tests
attn_output = core.eager._run_custom_op(
"fused_sdp_flash_attention_gcu",
@@ -199,7 +199,7 @@ def fusion_flash_attention(
True,
)[0]
else:
if config.cp_parallel_degree > 1:
if config.context_parallel_degree > 1:
attn_output = RingFlashAttention.apply(
query_states,
key_states,
20 changes: 10 additions & 10 deletions paddlenlp/transformers/llama/modeling.py
@@ -233,7 +233,7 @@ def scaled_dot_product_attention(
# Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

else:
if config.cp_parallel_degree > 1:
if config.context_parallel_degree > 1:
raise ValueError("Context parallel requires `use_flash_attention=True`")

# [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
@@ -935,7 +935,7 @@ def forward(
if self.reshard_layer is not None:
batch_size, seq_length, _, _ = query_states.shape
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
if self.config.cp_parallel_degree > 1:
if self.config.context_parallel_degree > 1:
batch_size, seq_length, _, _ = query_states.shape
group = fleet.get_hybrid_communicate_group().get_sep_parallel_group()
chunk_size = seq_length // 2
@@ -955,12 +955,12 @@ def forward(
position_ids,
past_key_value,
self.rotary_emb,
self.config.cp_parallel_degree,
self.config.context_parallel_degree,
)

else:
if self.config.cp_parallel_degree > 1:
kv_seq_len *= self.config.cp_parallel_degree
if self.config.context_parallel_degree > 1:
kv_seq_len *= self.config.context_parallel_degree
if self.config.use_long_sequence_strategies:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)
cos = cos[None, :, None, :]
@@ -1529,7 +1529,7 @@ def forward(
# [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
inputs_embeds = ScatterOp.apply(inputs_embeds)

if self.config.cp_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
raise NotImplementedError("Ring FlashAttention doesn't support attention_mask or alibi")
# embed positions
if attention_mask is None:
@@ -1674,7 +1674,7 @@ def forward(self, prediction_scores, masked_lm_labels):
with paddle.amp.auto_cast(False):
masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))

if self.config.sep_parallel_degree > 1 or self.config.cp_parallel_degree > 1:
if self.config.sep_parallel_degree > 1 or self.config.context_parallel_degree > 1:
_hcg = fleet.get_hybrid_communicate_group()
masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group())
# skip ignore_index which loss == 0
@@ -1747,9 +1747,9 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sep_parallel_degree > 1:
assert seq_length % self.config.sep_parallel_degree == 0
seq_length = seq_length // self.config.sep_parallel_degree
if self.config.cp_parallel_degree > 1:
assert seq_length % self.config.cp_parallel_degree == 0
seq_length = seq_length // self.config.cp_parallel_degree
if self.config.context_parallel_degree > 1:
assert seq_length % self.config.context_parallel_degree == 0
seq_length = seq_length // self.config.context_parallel_degree
hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size])

if tensor_parallel_output is None:
