fixed _all_scores_for_token method to correctly calculate token probability
MdMotahar authored and helpmefindaname committed Oct 11, 2024
1 parent c674212 commit a0b3ea6
Showing 2 changed files with 36 additions and 1 deletion.
flair/models/sequence_tagger_model.py (1 addition & 1 deletion)
@@ -589,7 +589,7 @@ def _all_scores_for_token(self, sentences: List[Sentence], score_tensor: torch.T
        previous = 0
        for length in lengths:
            prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
-            previous = length
+            previous += length
        return prob_tags_per_sentence

    def _get_state_dict(self):
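Why the one-character fix matters: `previous = length` resets the slice offset to the length of the sentence just processed instead of the running total, so the split is coincidentally correct for batches of one or two sentences but hands back the wrong score rows from the third sentence onward. A minimal, self-contained sketch of the two behaviors (plain Python stand-ins, not flair code):

# Stand-in data: 9 per-token score rows for three sentences of lengths 4, 3, 2.
scores = list(range(9))
lengths = [4, 3, 2]

def split_buggy(scores, lengths):
    out, previous = [], 0
    for length in lengths:
        out.append(scores[previous : previous + length])
        previous = length  # bug: resets to the last sentence's length
    return out

def split_fixed(scores, lengths):
    out, previous = [], 0
    for length in lengths:
        out.append(scores[previous : previous + length])
        previous += length  # fix: accumulate the running offset
    return out

print(split_buggy(scores, lengths))  # [[0, 1, 2, 3], [4, 5, 6], [3, 4]] -- third slice is wrong
print(split_fixed(scores, lengths))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8]]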
tests/models/test_sequence_tagger.py (35 additions & 0 deletions)
@@ -1,4 +1,6 @@
import pytest
+import torch
+import torch.nn.functional as F

import flair
from flair.embeddings import FlairEmbeddings, WordEmbeddings
@@ -121,3 +123,36 @@ def test_train_load_use_tagger_disjunct_tags(
        loaded_model.predict([example_sentence, self.empty_sentence])
        loaded_model.predict([self.empty_sentence])
        del loaded_model

+    @pytest.mark.integration()
+    def test_all_token_prob_distribution(self, embeddings, corpus):
+        tag_dictionary = corpus.make_label_dictionary("ner", add_unk=False)
+        model = self.build_model(embeddings, tag_dictionary)
+
+        # get features from forward propagation
+        sentences = [corpus.train[i] for i in range(len(corpus.train))]
+
+        # reverse sort all sequences by their length
+        sentences = sorted(sentences, key=len, reverse=True)
+
+        with torch.no_grad():
+            sentence_tensor, lengths = model._prepare_tensors(sentences)
+            features = model.forward(sentence_tensor, lengths)
+
+        # remove previously predicted labels of this type
+        for sentence in sentences:
+            sentence.remove_labels(model.label_type)
+
+        softmax_batch = F.softmax(features, dim=1).cpu()
+        lengths = [len(sentence) for sentence in sentences]
+        all_tokens_prob_distrib = model._all_scores_for_token(sentences, softmax_batch, lengths)
+
+        for i, sen_tokens_prob_distribution in enumerate(all_tokens_prob_distrib):
+            assert len(sen_tokens_prob_distribution) == lengths[i]
+            for token_prob_distrib, token in zip(sen_tokens_prob_distribution, sentences[i]):
+                assert len(token_prob_distrib) == len(model.label_dictionary)
+                score_sum = 0.0
+                for token_label in token_prob_distrib:
+                    assert token_label.data_point == token
+                    score_sum += token_label.score
+                assert abs(score_sum - 1.0) < 1.0e-5
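For context on what this fixes downstream: `_all_scores_for_token` produces the per-token distributions that `predict` can attach to each token. A hedged end-to-end sketch; it assumes the public `flair/ner-english` model, the `return_probabilities_for_all_classes` flag on `predict()`, and `Token.get_tags_proba_dist()` for retrieval (all part of flair's public API in recent releases, though none of it is shown in this diff, so verify against your installed version):

from flair.data import Sentence
from flair.models import SequenceTagger

# Assumed names: the "flair/ner-english" model, the predict() flag, and
# get_tags_proba_dist() are flair public API, not part of this commit.
tagger = SequenceTagger.load("flair/ner-english")

# Three sentences in one batch: the old bug only surfaced from the
# third sentence of a batch onward.
sentences = [
    Sentence("George Washington went to Washington ."),
    Sentence("Berlin is a city ."),
    Sentence("Hello world ."),
]
tagger.predict(sentences, return_probabilities_for_all_classes=True)

for sentence in sentences:
    for token in sentence:
        distribution = token.get_tags_proba_dist("ner")
        # with the fix, each token's scores over the full tag set sum to ~1.0
        print(token.text, round(sum(label.score for label in distribution), 5))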
