Skip to content

Commit

Permalink
修复示例中类型错误(#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tongjilibo committed Mar 5, 2024
1 parent a40a48a commit 3e44d65
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def predict(self, inputs, output_ids, states):
def generate(self, passage, topk=1, topp=0.95):
token_ids, segment_ids = tokenizer.encode(passage, maxlen=max_p_len)
a_ids = self.random_sample([token_ids, segment_ids], n=1, topp=topp)[0] # 基于随机采样
token_ids += list(a_ids)
token_ids += list(a_ids.cpu().numpy())
segment_ids += [1] * len(a_ids)
q_ids = self.beam_search([token_ids, segment_ids], topk=topk)[0] # 基于beam search
return (tokenizer.decode(q_ids.cpu().numpy()), tokenizer.decode(a_ids.cpu().numpy()))
Expand Down

1 comment on commit 3e44d65

@ocyisheng
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

token_ids += list(a_ids) 老哥这行代码多余了吧?
删除后我测试通过

Please sign in to comment.