Sft flash mask #8664

Merged: 11 commits, merged Jun 28, 2024

Changes from all commits
4 changes: 2 additions & 2 deletions llm/alignment/dpo/dpo_argument.py
@@ -87,8 +87,8 @@ class DPOModelArgument:
"help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
},
)
use_attn_mask_start_row_indices: bool = field(
default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}
flash_mask: bool = field(
default=False, metadata={"help": "Whether to use flash mask in flash attention."}
)
virtual_pp_degree: int = field(
default=1,
16 changes: 16 additions & 0 deletions llm/alignment/dpo/run_dpo.py
@@ -17,6 +17,7 @@
import os
import sys
import time
import inspect
from functools import partial

import paddle
@@ -36,8 +37,14 @@
preference_collate_fn,
preprocess_preference_data,
)
from paddlenlp.transformers import (
LlamaForCausalLM,
LlamaForCausalLMPipe,
)
from paddlenlp.utils.log import logger

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]


def main():
"""main"""
@@ -124,6 +131,15 @@ def main():
ref_model = AutoModelForCausalLM.from_config(ref_config)
model.set_state_dict(ref_model.state_dict())

if model_args.flash_mask and not model.config.use_flash_attention:
logger.warning(
"`flash_mask` must use with zero padding and flash attention."
)
model.config.use_flash_attention = True

if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
raise NotImplementedError(f"{model.__class__} not support flash mask.")

if model_args.tokenizer_name_or_path is not None:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
else:
2 changes: 1 addition & 1 deletion llm/config/llama/dpo_argument.json
@@ -27,7 +27,7 @@
"sharding_parallel_degree": 1,
"sharding": "stage1",
"use_flash_attention": true,
"use_attn_mask_start_row_indices":false,
"flash_mask":true,
"recompute": false,
"recompute_granularity": "full",
"dpo_beta": 0.1,
2 changes: 1 addition & 1 deletion llm/config/qwen/dpo_argument.json
@@ -27,7 +27,7 @@
"sharding_parallel_degree": 1,
"sharding": "stage1",
"use_flash_attention": true,
"use_attn_mask_start_row_indices":false,
"flash_mask":false,
"recompute": false,
"recompute_granularity": "full",
"dpo_beta": 0.1,
23 changes: 20 additions & 3 deletions llm/run_finetune.py
@@ -14,6 +14,7 @@
import json
import os
import sys
import inspect
from functools import partial

import paddle
@@ -51,13 +52,17 @@
AutoTokenizer,
Llama3Tokenizer,
LlamaTokenizer,
LlamaForCausalLM,
LlamaForCausalLMPipe,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.utils.log import logger

# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]


def main():
# Arguments
@@ -77,6 +82,7 @@ def main():
raise ValueError(
"--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
)


# Setup GPU & distributed training
paddle.set_device(training_args.device)
@@ -160,6 +166,16 @@ def main():
# NOTE(gongenlei): new add autotuner_benchmark
model = model_class.from_config(model_config, dtype=dtype)

if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
logger.warning(
"`flash_mask` must use with zero padding and flash attention."
)
data_args.zero_padding = True
model.config.use_flash_attention = True

if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
raise NotImplementedError(f"{model.__class__} not support flash mask.")

if training_args.do_train and model_args.neftune:
# Inspired by https://github.com/neelsjain/NEFTune
if hasattr(model, "get_input_embeddings"):
@@ -329,12 +345,12 @@ def neft_post_hook(module, input, output):
"Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
)
train_ds = (
train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
if train_ds is not None
else None
)
ptq_ds = (
ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
if ptq_ds is not None
else None
)
@@ -345,7 +361,7 @@ def neft_post_hook(module, input, output):
)
eval_zero_padding = False
dev_ds = (
dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding))
dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask))
if dev_ds is not None
else None
)
@@ -498,6 +514,7 @@ def compute_metrics_do_generation(eval_preds):
padding=padding,
max_label_length=max_length,
return_tensors="np",
return_attention_mask=not model_args.flash_mask,
pad_to_multiple_of=data_args.pad_to_multiple_of,
),
do_generation=data_args.eval_with_do_generation,
3 changes: 3 additions & 0 deletions llm/utils/argument.py
@@ -209,6 +209,9 @@ class ModelArgument:
aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"})
neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"})
neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
flash_mask: bool = field(
default=False, metadata={"help": "Whether to use flash_mask in flash attention."}
)


@dataclass
22 changes: 17 additions & 5 deletions llm/utils/data.py
@@ -173,11 +173,12 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
return tokenized_source, labels


def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
if tokenizer.chat_template is not None:
return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)
return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask)

tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)

if is_test:
return {
**tokenized_source,
@@ -194,12 +195,17 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, zero_pad
if "position_ids" in tokenized_source:
features["position_ids"] = list(range(seq_length))
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
if flash_mask:
features["attn_mask_startend_row_indices"] = (
[seq_length] * seq_length
)
else:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

return features


def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
"""convert multi-rounds conversation example

Args:
@@ -227,7 +233,13 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z
seq_length = len(input_ids)
features = {"input_ids": input_ids, "labels": labels}
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
if flash_mask:
features["attn_mask_startend_row_indices"] = (
[seq_length] * seq_length
)
else:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)


if "position_ids" in rounds_inputs:
rounds_inputs["position_ids"] = rounds_inputs["position_ids"][:-1]
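An illustrative aside on the `attn_mask_startend_row_indices` feature built above: the NumPy sketch below is not part of the diff and assumes the convention that entry `j` holds the first query row that may no longer attend to key column `j`, applied on top of the usual causal constraint. Under that assumption, the `[seq_length] * seq_length` value produced for a single zero-padding sample is equivalent to the dense `np.tri(seq_length, seq_length, dtype=bool)` mask it replaces.

```python
import numpy as np


def dense_mask_from_row_indices(row_indices):
    """Expand row indices into a dense boolean mask (illustrative only).

    Assumed semantics: query row i may attend to key column j
    iff j <= i (causal) and i < row_indices[j].
    """
    seq_len = len(row_indices)
    rows = np.arange(seq_len)[:, None]  # query positions i
    cols = np.arange(seq_len)[None, :]  # key positions j
    return (cols <= rows) & (rows < np.asarray(row_indices)[None, :])


seq_length = 4
indices = [seq_length] * seq_length  # the value built in convert_example_common
assert np.array_equal(
    dense_mask_from_row_indices(indices),
    np.tri(seq_length, seq_length, dtype=bool),
)
```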
28 changes: 28 additions & 0 deletions paddlenlp/data/data_collator.py
@@ -370,6 +370,11 @@
if return_tensors is None:
return_tensors = self.return_tensors
labels = [feature["labels"] for feature in batch] if "labels" in batch[0].keys() else None
use_attn_mask_startend_row_indices = (

[feature["attn_mask_startend_row_indices"] for feature in batch]
if "attn_mask_startend_row_indices" in batch[0].keys()
else None
)
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
# same length to return tensors.
if labels is not None:
@@ -396,6 +401,29 @@
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
if use_attn_mask_startend_row_indices is not None:
if self.max_length is not None:
max_length = self.max_length

else:
max_length = max(len(l) for l in use_attn_mask_startend_row_indices)
if self.pad_to_multiple_of is not None:
max_length = (

(max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of
)

for feature in batch:
pad_len = max_length - len(feature["attn_mask_startend_row_indices"])
remainder = np.zeros([1, pad_len], dtype=np.int32)
feature["attn_mask_startend_row_indices"] = (

np.concatenate(
[remainder, np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32) + pad_len],
axis=-1,
)
if padding_side == "left"
else np.concatenate(
[np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32), remainder], axis=-1
)
)

batch = self.tokenizer.pad(
batch,
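To make the padding branch above concrete, here is a minimal sketch (not part of the diff); it uses a 1-D row-index feature for readability, whereas the collator itself works with a `[1, seq_len]` array. Pad columns receive index 0, so no query row may attend to them; with left padding the real tokens shift right by `pad_len`, so their cut-off rows shift by the same amount.

```python
import numpy as np

row_indices = np.array([3, 3, 3], dtype=np.int32)  # one feature of length 3
pad_len = 2                                        # padded out to length 5

left_padded = np.concatenate([np.zeros(pad_len, dtype=np.int32), row_indices + pad_len])
right_padded = np.concatenate([row_indices, np.zeros(pad_len, dtype=np.int32)])

print(left_padded)   # [0 0 5 5 5]  real tokens and their cut-off rows shift right
print(right_padded)  # [3 3 3 0 0]  pad columns are never attended to
```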
10 changes: 5 additions & 5 deletions paddlenlp/datasets/zero_padding_dataset.py
@@ -28,14 +28,14 @@
"chosen_labels",
"rejected_labels",
"response_indexs",
"attn_mask_start_row_indices",
"attn_mask_startend_row_indices",
]

@classmethod
def _pad_batch_records(cls, batch_records):
# Only consider supported input keys
input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
if "attn_mask_start_row_indices" not in input_keys and "attention_mask" not in input_keys:
if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:

input_keys.append("attention_mask")
batched_features = {key: [] for key in input_keys}
sequence_sum = 0
@@ -57,9 +57,9 @@

seq_length = len(record["input_ids"])
# If attention_mask is not given, assume it's causal mask
if "attn_mask_start_row_indices" in record:
attn_mask_start_row_indices = [i + sequence_sum for i in record["attn_mask_start_row_indices"]]
batched_features["attn_mask_start_row_indices"].extend(attn_mask_start_row_indices)
if "attn_mask_startend_row_indices" in record:
attn_mask_startend_row_indices = [i + sequence_sum for i in record["attn_mask_startend_row_indices"]]
batched_features["attn_mask_startend_row_indices"].extend(attn_mask_startend_row_indices)

else:
attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
batched_features["attention_mask"].append(attention_mask)
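The `sequence_sum` offset above is what turns per-sample row indices into a block-causal mask when samples are packed for zero padding. A small sketch with hypothetical sample lengths (same assumed semantics as before, not part of the diff):

```python
import numpy as np

# Two samples of lengths 3 and 2 packed into one sequence of length 5.
records = [
    {"attn_mask_startend_row_indices": [3, 3, 3]},
    {"attn_mask_startend_row_indices": [2, 2]},
]

packed, sequence_sum = [], 0
for record in records:
    sample_indices = record["attn_mask_startend_row_indices"]
    packed.extend(i + sequence_sum for i in sample_indices)
    sequence_sum += len(sample_indices)  # stands in for len(record["input_ids"])

print(packed)  # [3, 3, 3, 5, 5]: keys of sample 1 are cut off from row 3 onward,
               # so queries of sample 2 (rows 3-4) never attend across samples.
```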
2 changes: 2 additions & 0 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -211,6 +211,8 @@
else:
if attn_mask_startend_row_indices is not None:
assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
if len(attn_mask_startend_row_indices.shape) == 2:
attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

attn_output = F.flash_attention_with_sparse_mask(
query_states,
key_states,
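The guard added above simply promotes a 2-D index tensor to 3-D before the sparse-mask kernel is called. A brief sketch with hypothetical sizes (the inserted axis is assumed to be a broadcastable head dimension; not part of the diff):

```python
import paddle

indices = paddle.full([2, 8], fill_value=8, dtype="int32")  # [batch, seq_len]
if len(indices.shape) == 2:
    indices = paddle.unsqueeze(indices, axis=1)             # -> [2, 1, 8]
print(indices.shape)
```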
8 changes: 8 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -1906,6 +1906,14 @@
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if attn_mask_startend_row_indices is not None and attention_mask is not None:
logger.warning(

"You have provided both attn_mask_startend_row_indices and attention_mask. "
"The attn_mask_startend_row_indices will be used."
)
attention_mask = None


outputs = self.llama(
input_ids, # [bs, seq_len]
position_ids=position_ids,
10 changes: 5 additions & 5 deletions paddlenlp/trl/dpo_trainer.py
@@ -177,8 +177,8 @@
}
if "attention_mask" in batch:
dpo_inputs["attention_mask"] = batch["attention_mask"]
if "attn_mask_start_row_indices" in batch:
dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"]
if "attn_mask_startend_row_indices" in batch:
dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"]

if self.reference_free:
reference_chosen_logps, reference_rejected_logps = None, None
else:
@@ -194,8 +194,8 @@
}
if "attention_mask" in batch:
dpo_inputs["attention_mask"] = batch["attention_mask"]
if "attn_mask_start_row_indices" in batch:
dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"]
if "attn_mask_startend_row_indices" in batch:
dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"]

if self.reference_free:
reference_chosen_logps, reference_rejected_logps = None, None
else:
@@ -522,7 +522,7 @@
else:
first_stage_keys = [
"input_ids",
"attn_mask_start_row_indices",
"attn_mask_startend_row_indices",
"position_ids",
]

25 changes: 14 additions & 11 deletions paddlenlp/trl/trl_data.py
@@ -159,8 +159,8 @@
}

# attention mask
if model_args.use_attn_mask_start_row_indices:
output_dict["attn_mask_start_row_indices"] = (
if model_args.flash_mask:
output_dict["attn_mask_startend_row_indices"] = (

[seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
)
else:
@@ -183,14 +183,14 @@
"response_indexs": [],
}
sequence = batch[0]
if "attn_mask_start_row_indices" in sequence:
input_dict["attn_mask_start_row_indices"] = []
use_attn_mask_start_row_indices = True
if "attn_mask_startend_row_indices" in sequence:
input_dict["attn_mask_startend_row_indices"] = []
use_attn_mask_startend_row_indices = True

elif "attention_mask" in sequence:
input_dict["attention_mask"] = []
use_attn_mask_start_row_indices = False
use_attn_mask_startend_row_indices = False

else:
raise ValueError("attention_mask and attn_mask_start_row_indices are both None.")
raise ValueError("attention_mask and attn_mask_startend_row_indices are both None.")

Check warning on line 193 in paddlenlp/trl/trl_data.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/trl_data.py#L193

Added line #L193 was not covered by tests

for i, sequence in enumerate(batch):
difference = max_seq_len - len(sequence["input_ids"])
@@ -199,9 +199,12 @@
input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference)
input_dict["chosen_labels"].append(sequence["chosen_labels"] + [0] * difference)
input_dict["rejected_labels"].append(sequence["rejected_labels"] + [0] * difference)
if use_attn_mask_start_row_indices:
input_dict["attn_mask_start_row_indices"].append(
[sequence["attn_mask_start_row_indices"] + [sequence["attn_mask_start_row_indices"][-1]] * difference]
if use_attn_mask_startend_row_indices:
input_dict["attn_mask_startend_row_indices"].append(

[
sequence["attn_mask_startend_row_indices"]
+ [sequence["attn_mask_startend_row_indices"][-1]] * difference
]
)
else:
input_dict["attention_mask"].append(
Expand All @@ -225,7 +228,7 @@
for key in input_dict:
if key == "attention_mask":
input_dict[key] = np.array(input_dict[key], dtype=bool)
elif key == "attn_mask_start_row_indices":
elif key == "attn_mask_startend_row_indices":

input_dict[key] = np.array(input_dict[key], dtype=np.int32)
else:
input_dict[key] = np.array(input_dict[key])
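For the DPO layout built in `output_dict["attn_mask_startend_row_indices"]` above, a self-contained sketch (hypothetical lengths, same assumed row-index semantics as earlier, not part of the diff) showing that the rejected response can attend to the prompt but not to the chosen response:

```python
import numpy as np

prompt_len, chosen_len, rejected_len = 3, 2, 2
seq_len = prompt_len + chosen_len + rejected_len

indices = np.array(
    [seq_len] * prompt_len
    + [prompt_len + chosen_len] * chosen_len
    + [seq_len] * rejected_len,
    dtype=np.int32,
)

rows = np.arange(seq_len)[:, None]
cols = np.arange(seq_len)[None, :]
mask = (cols <= rows) & (rows < indices[None, :])  # assumed: causal + per-column cut-off row

rejected_rows = slice(prompt_len + chosen_len, seq_len)
assert mask[rejected_rows, :prompt_len].all()                             # sees the prompt
assert not mask[rejected_rows, prompt_len:prompt_len + chosen_len].any()  # not the chosen answer
```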