From 3e44d65d5ab3ecfa7bb0b1830398abd53fb12673 Mon Sep 17 00:00:00 2001
From: tongjilibo
Date: Tue, 5 Mar 2024 09:11:52 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=A4=BA=E4=BE=8B=E4=B8=AD?=
 =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E9=94=99=E8=AF=AF(#165)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 examples/seq2seq/task_question_answer_generation_by_seq2seq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/seq2seq/task_question_answer_generation_by_seq2seq.py b/examples/seq2seq/task_question_answer_generation_by_seq2seq.py
index bb43c7d8..3e4a4718 100644
--- a/examples/seq2seq/task_question_answer_generation_by_seq2seq.py
+++ b/examples/seq2seq/task_question_answer_generation_by_seq2seq.py
@@ -145,7 +145,7 @@ def predict(self, inputs, output_ids, states):
     def generate(self, passage, topk=1, topp=0.95):
         token_ids, segment_ids = tokenizer.encode(passage, maxlen=max_p_len)
         a_ids = self.random_sample([token_ids, segment_ids], n=1, topp=topp)[0]  # 基于随机采样
-        token_ids += list(a_ids)
+        token_ids += list(a_ids.cpu().numpy())
         segment_ids += [1] * len(a_ids)
         q_ids = self.beam_search([token_ids, segment_ids], topk=topk)[0]  # 基于beam search
         return (tokenizer.decode(q_ids.cpu().numpy()), tokenizer.decode(a_ids.cpu().numpy()))