[BugFix] fix test_sample_generate bug #8157

Merged · 1 commit · Mar 20, 2024
1 change: 1 addition & 0 deletions Makefile
@@ -45,6 +45,7 @@ unit-test:

.PHONY: install
install:
pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
pip install -r requirements-dev.txt
pip install -r requirements.txt
pip install -r paddlenlp/experimental/autonlp/requirements.txt
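Together with the paddlepaddle==2.5.1 pin removed from requirements-dev.txt below, this switches the dev install to the nightly CPU develop wheel, which is published under the placeholder version 0.0.0. A quick post-install sanity check might look like the following (a minimal sketch; the commit hash varies per nightly build):

import paddle

# Nightly develop wheels report a placeholder version of 0.0.0;
# the commit hash distinguishes one nightly build from another.
print(paddle.__version__)      # expected: "0.0.0"
print(paddle.version.commit)   # git SHA of the develop build
paddle.utils.run_check()       # runs a small program to verify the install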
12 changes: 0 additions & 12 deletions paddlenlp/generation/utils.py
@@ -1209,18 +1209,6 @@ def sample(

# multinomial already supports fp16 and bf16; fixes issue: https://github.com/PaddlePaddle/Paddle/issues/51852
next_tokens = paddle.multinomial(probs)
# # multinomial not support fp16 and bf16 currently, issue: https://github.com/PaddlePaddle/Paddle/issues/51852
# if probs.dtype == paddle.bfloat16 and top_k == 1:
# probs = probs.astype("float32")
# next_tokens = paddle.unsqueeze(paddle.argmax(probs, axis=-1), -1)
# else:
# # next_tokens = paddle.multinomial(probs)
# probs = probs.cpu()
# from paddlenlp.transformers.utils import device_guard

# with device_guard("cpu"):
# next_tokens = paddle.multinomial(probs)
# next_tokens = next_tokens.cuda()

if self.config.tensor_parallel_degree > 1:
# Maybe no need to broadcast if seed is set correctly.
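The deleted block dates from before paddle.multinomial accepted half-precision inputs: probabilities were either cast to float32 for an argmax (top_k == 1) or round-tripped through the CPU. With the upstream fix referenced in the surviving comment, sampling stays on-device. A minimal sketch of the path the simplified code relies on (assuming a Paddle build that includes the fix for issue 51852):

import paddle
import paddle.nn.functional as F

paddle.seed(124)
logits = paddle.randn([2, 8])        # [batch_size, vocab_size]
probs = F.softmax(logits, axis=-1)

# Works for fp32, and, on builds with the fix, for fp16/bf16 probs too —
# no float32 cast or CPU round-trip needed anymore.
next_tokens = paddle.multinomial(probs)
print(next_tokens)                   # one sampled token id per batch row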
4 changes: 4 additions & 0 deletions paddlenlp/transformers/blenderbot/modeling.py
@@ -193,6 +193,8 @@ def __init__(
normalize_before=True,
weight_attr=None,
bias_attr=None,
*args,
**kwargs,
):
super(BlenderbotDecoderLayer, self).__init__(
d_model=d_model,
@@ -205,6 +207,8 @@ normalize_before=True,
normalize_before=normalize_before,
weight_attr=weight_attr,
bias_attr=bias_attr,
*args,
**kwargs,
)

def forward(self, tgt, memory=None, tgt_mask=None, memory_mask=None, cache=None):
4 changes: 4 additions & 0 deletions paddlenlp/transformers/blenderbot_small/modeling.py
@@ -126,6 +126,8 @@ def __init__(
normalize_before=True,
weight_attr=None,
bias_attr=None,
*args,
**kwargs,
):
super(BlenderbotSmallDecoderLayer, self).__init__(
d_model=d_model,
@@ -138,6 +140,8 @@ normalize_before=True,
normalize_before=normalize_before,
weight_attr=weight_attr,
bias_attr=bias_attr,
*args,
**kwargs,
)

def forward(self, tgt, memory=None, tgt_mask=None, memory_mask=None, cache=None):
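Both Blenderbot decoder layers receive the identical change: accepting and forwarding *args/**kwargs lets any extra constructor parameters of the parent class (paddle.nn.TransformerDecoderLayer) pass through without each subclass having to enumerate them. A generic illustration of the pattern (hypothetical classes, not the actual PaddleNLP code):

class Base:
    def __init__(self, d_model, dropout=0.1, layer_norm_eps=1e-5):
        self.d_model = d_model
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps

class DecoderLayer(Base):
    # Forwarding *args/**kwargs keeps this subclass working even when
    # Base gains new keyword arguments (e.g. layer_norm_eps).
    def __init__(self, d_model, dropout=0.1, *args, **kwargs):
        super().__init__(d_model, dropout=dropout, *args, **kwargs)

layer = DecoderLayer(512, layer_norm_eps=1e-6)  # reaches Base unchanged
print(layer.layer_norm_eps)                     # 1e-06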
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -1,4 +1,3 @@
paddlepaddle==2.5.1
paddleocr<2.7
pre-commit
pytest
7 changes: 6 additions & 1 deletion tests/transformers/test_generation_utils.py
@@ -74,6 +74,10 @@ def _get_input_ids_and_config(self):
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
# For test_sample_generate runs such as: NVIDIA_TF32_OVERRIDE=0 FLAGS_cudnn_deterministic=1 python3.10 -m pytest -svv tests/transformers/bloom/test_modeling.py::BloomModelTest_0::test_sample_generate
# there is a serious memory bug with this tensor slice, which reuses the original tensor's memory pointer on cold start.
# Here we clone the tensor to avoid the problem.
input_ids = input_ids.clone()
attention_mask = attention_mask[:max_batch_size, :sequence_length].unsqueeze([1, 2])

attention_mask = attention_mask * attention_mask.transpose([0, 1, 3, 2])
@@ -270,6 +274,7 @@ def _sample_generate(
logits_warper,
process_kwargs,
):

with paddle.no_grad():
output_generate = model.generate(
input_ids,
@@ -440,9 +445,9 @@ def test_greedy_generate(self):
self.assertListEqual(output_greedy[0].tolist(), output_generate[0].tolist())

def test_sample_generate(self):

for model_class in self.all_generative_model_classes.keys():
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
input_ids = input_ids.clone()
paddle.seed(124)
model = self._make_model_instance(config, model_class)
model.eval()
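The core of the fix: in Paddle (as in NumPy), a basic slice such as input_ids[:max_batch_size, :sequence_length] can be a view that shares the parent tensor's storage, and the PR's comment attributes the flaky test to that aliasing on the cold-start path. Calling .clone() materializes an independent copy. A minimal sketch of the pattern (illustrative only; the reported bug is specific to the deterministic-sampling setup above):

import paddle

input_ids = paddle.randint(0, 100, shape=[4, 16], dtype="int64")

# A basic slice may be a view sharing the parent tensor's storage...
sliced = input_ids[:2, :8]

# ...whereas .clone() guarantees a fresh, independently owned buffer,
# which is what the test now creates before generation.
independent = input_ids[:2, :8].clone()
print(sliced.shape, independent.shape)  # [2, 8] [2, 8]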