From 9b9018071b74bf022671f042c81899878451676b Mon Sep 17 00:00:00 2001 From: Namhyeok Date: Wed, 2 Jun 2021 13:35:22 +0900 Subject: [PATCH 1/3] predict span using top10 start&end position --- pororo/models/brainbert/BrainRoBERTa.py | 36 ++++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/pororo/models/brainbert/BrainRoBERTa.py b/pororo/models/brainbert/BrainRoBERTa.py index 7154983..80e9a9e 100644 --- a/pororo/models/brainbert/BrainRoBERTa.py +++ b/pororo/models/brainbert/BrainRoBERTa.py @@ -241,22 +241,26 @@ def predict_span( tokens, return_logits=True, ).squeeze() # T x 2 - # first predict start position, - # then predict end position among the remaining logits - start = logits[:, 0].argmax().item() - mask = (torch.arange( - logits.size(0), dtype=torch.long, device=self.device) >= start) - end = (mask * logits[:, 1]).argmax().item() - # end position is shifted during training, so we add 1 back - answer_tokens = tokens[start:end + 1] - - answer = "" - if len(answer_tokens) >= 1: - decoded = self.decode(answer_tokens) - if isinstance(decoded, str): - answer = decoded - - return (answer, (start, end + 1)) + # predict top 10 start positions + # then predict top 10 end position among the remaining logits + results = [] + starts = logits[:, 0].argsort(descending=True).tolist() + for start in starts: + mask = (torch.arange( + logits.size(0), dtype=torch.long, device=self.device) >= start) + ends = (mask * logits[:, 1]).argsort(descending=True).tolist() + # end position is shifted during training, so we add 1 back + for end in ends: + answer_tokens = tokens[start:end + 1] + answer = "" + if len(answer_tokens) >= 1: + decoded = self.decode(answer_tokens) + if isinstance(decoded, str): + answer = decoded + score = (logits[:,0][start] + logits[:,1][end]).item() + results.append((answer, (start, end + 1), score)) + + return sorted(results,key=lambda x:x[2],reverse=True)[0] @torch.no_grad() def predict_tags( From 87e5951ee0c52ab3636bcddfecb84c30ed6a40d9 Mon Sep 17 00:00:00 2001 From: Namhyeok Date: Wed, 2 Jun 2021 13:36:06 +0900 Subject: [PATCH 2/3] score output --- pororo/tasks/machine_reading_comprehension.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pororo/tasks/machine_reading_comprehension.py b/pororo/tasks/machine_reading_comprehension.py index 4c00b14..acfe9cd 100644 --- a/pororo/tasks/machine_reading_comprehension.py +++ b/pororo/tasks/machine_reading_comprehension.py @@ -113,4 +113,5 @@ def predict( return ( span, pair_result[1], + pair_result[2], ) From 091e69e2514c3b2fcca958557cccbb7745f38edc Mon Sep 17 00:00:00 2001 From: Namhyeok Date: Wed, 2 Jun 2021 13:51:19 +0900 Subject: [PATCH 3/3] logit output --- pororo/models/brainbert/BrainRoBERTa.py | 4 ++-- pororo/tasks/machine_reading_comprehension.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pororo/models/brainbert/BrainRoBERTa.py b/pororo/models/brainbert/BrainRoBERTa.py index 80e9a9e..584d079 100644 --- a/pororo/models/brainbert/BrainRoBERTa.py +++ b/pororo/models/brainbert/BrainRoBERTa.py @@ -258,9 +258,9 @@ def predict_span( if isinstance(decoded, str): answer = decoded score = (logits[:,0][start] + logits[:,1][end]).item() - results.append((answer, (start, end + 1), score)) + results.append((answer, (start, end + 1), (logits[:,0][start].item(),logits[:,1][end].item()), score)) - return sorted(results,key=lambda x:x[2],reverse=True)[0] + return sorted(results,key=lambda x:x[3],reverse=True)[0] @torch.no_grad() def predict_tags( diff --git a/pororo/tasks/machine_reading_comprehension.py b/pororo/tasks/machine_reading_comprehension.py index acfe9cd..e2969b3 100644 --- a/pororo/tasks/machine_reading_comprehension.py +++ b/pororo/tasks/machine_reading_comprehension.py @@ -114,4 +114,5 @@ def predict( span, pair_result[1], pair_result[2], + pair_result[3], )