diff --git a/flair/data.py b/flair/data.py
index 69d85baf9..38585fdaa 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -629,25 +629,9 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
 class Span(_PartOfSentence):
     """This class represents one textual span consisting of Tokens."""
 
-    def __new__(self, tokens: List[Token]):
-        # check if the span already exists. If so, return it
-        unlabeled_identifier = self._make_unlabeled_identifier(tokens)
-        if unlabeled_identifier in tokens[0].sentence._known_spans:
-            span = tokens[0].sentence._known_spans[unlabeled_identifier]
-            return span
-
-        # else make a new span
-        else:
-            span = super().__new__(self)
-            span.initialized = False
-            tokens[0].sentence._known_spans[unlabeled_identifier] = span
-            return span
-
     def __init__(self, tokens: List[Token]) -> None:
-        if not self.initialized:
-            super().__init__(tokens[0].sentence)
-            self.tokens = tokens
-            self.initialized: bool = True
+        super().__init__(tokens[0].sentence)
+        self.tokens = tokens
 
     @property
     def start_position(self) -> int:
@@ -696,26 +680,10 @@ def to_dict(self, tag_type: Optional[str] = None):
 
 
 class Relation(_PartOfSentence):
-    def __new__(self, first: Span, second: Span):
-        # check if the relation already exists. If so, return it
-        unlabeled_identifier = self._make_unlabeled_identifier(first, second)
-        if unlabeled_identifier in first.sentence._known_spans:
-            span = first.sentence._known_spans[unlabeled_identifier]
-            return span
-
-        # else make a new relation
-        else:
-            span = super().__new__(self)
-            span.initialized = False
-            first.sentence._known_spans[unlabeled_identifier] = span
-            return span
-
     def __init__(self, first: Span, second: Span) -> None:
-        if not self.initialized:
-            super().__init__(sentence=first.sentence)
-            self.first: Span = first
-            self.second: Span = second
-            self.initialized: bool = True
+        super().__init__(sentence=first.sentence)
+        self.first: Span = first
+        self.second: Span = second
 
     def __repr__(self) -> str:
         return str(self)
@@ -793,7 +761,7 @@ def __init__(
         self.tokens: List[Token] = []
 
         # private field for all known spans
-        self._known_spans: Dict[str, _PartOfSentence] = {}
+        self._known_parts: Dict[str, _PartOfSentence] = {}
 
         self.language_code: Optional[str] = language_code
 
@@ -870,7 +838,7 @@ def get_relations(self, label_type: Optional[str] = None) -> List[Relation]:
 
     def get_spans(self, label_type: Optional[str] = None) -> List[Span]:
         spans: List[Span] = []
-        for potential_span in self._known_spans.values():
+        for potential_span in self._known_parts.values():
             if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)):
                 spans.append(potential_span)
         return sorted(spans)
@@ -1047,8 +1015,7 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
         }
 
     def get_span(self, start: int, stop: int) -> Span:
-        span_slice = slice(start, stop)
-        return self[span_slice]
+        return self[start:stop]
 
     @typing.overload
     def __getitem__(self, idx: int) -> Token: ...
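Reviewer note: with the `__new__` overrides gone, deduplication now lives in `Sentence.__getitem__` (next hunk), backed by the renamed `_known_parts` cache. Below is a minimal sketch of the intended caching behavior, not part of the patch; the sentence text and label values are arbitrary examples:

```python
from flair.data import Sentence

sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .")

# Identical slices resolve to the same cached Span object, so labels added
# through either reference accumulate on a single data point.
span_a = sentence[0:4]
span_b = sentence[0:4]
assert span_a is span_b

span_a.add_label("ner", "ORG")
assert sentence[0:4].get_label("ner").value == "ORG"

# Integer subscripts are unchanged and still return plain Tokens.
token = sentence[0]
```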
@@ -1056,9 +1023,26 @@ def __getitem__(self, idx: int) -> Token: ...
 
     @typing.overload
     def __getitem__(self, s: slice) -> Span: ...
 
+    @typing.overload
+    def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation: ...
+
     def __getitem__(self, subscript):
-        if isinstance(subscript, slice):
-            return Span(self.tokens[subscript])
+        if isinstance(subscript, tuple):
+            first, second = subscript
+            if not (isinstance(first, Span) and isinstance(second, Span)):
+                raise TypeError("tuple subscripts must contain exactly two Spans")
+            identifier = Relation._make_unlabeled_identifier(first, second)
+            if identifier not in self._known_parts:
+                self._known_parts[identifier] = Relation(first, second)
+
+            return self._known_parts[identifier]
+        elif isinstance(subscript, slice):
+            identifier = Span._make_unlabeled_identifier(self.tokens[subscript])
+
+            if identifier not in self._known_parts:
+                self._known_parts[identifier] = Span(self.tokens[subscript])
+
+            return self._known_parts[identifier]
         else:
             return self.tokens[subscript]
@@ -1210,11 +1194,11 @@ def remove_labels(self, typename: str):
             token.remove_labels(typename)
 
         # labels also need to be deleted at all known spans
-        for span in self._known_spans.values():
+        for span in self._known_parts.values():
             span.remove_labels(typename)
 
         # remove spans without labels
-        self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0}
+        self._known_parts = {k: v for k, v in self._known_parts.items() if len(v.labels) > 0}
 
         # delete labels at object itself
         super().remove_labels(typename)
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 669954849..dfc4a07eb 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -26,7 +26,6 @@
     Corpus,
     FlairDataset,
     MultiCorpus,
-    Relation,
     Sentence,
     Token,
     get_spans_from_bio,
@@ -731,9 +730,7 @@ def _convert_lines_to_sentence(
                     tail_end = int(indices[3])
                     label = indices[4]
                     # head and tail span indices are 1-indexed and end index is inclusive
-                    relation = Relation(
-                        first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end]
-                    )
+                    relation = sentence[sentence[head_start - 1 : head_end], sentence[tail_start - 1 : tail_end]]
                     remapped = self._remap_label(label)
                     if remapped != "O":
                         relation.add_label(typename="relation", value=remapped)
diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py
index 35c244d96..499b3017e 100644
--- a/flair/models/regexp_tagger.py
+++ b/flair/models/regexp_tagger.py
@@ -41,7 +41,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span:
         """
         span_start: int = self.__tokens_start_pos.index(span[0])
         span_end: int = self.__tokens_end_pos.index(span[1])
-        return Span(self.tokens[span_start : span_end + 1])
+        return self.sentence[span_start : span_end + 1]
 
 
 class RegexpTagger:
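Reviewer note: the tuple overload above is what lets the model and dataset code below drop direct `Relation(...)` construction. A rough sketch of the new lookup, again with an arbitrary example sentence:

```python
from flair.data import Sentence

sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .")

# Indexing with a pair of Spans returns the cached Relation, creating it on
# first access; repeated lookups hand back the same object.
relation = sentence[sentence[0:4], sentence[7:8]]
relation.add_label("relation", "located_in")

assert relation is sentence[sentence[0:4], sentence[7:8]]
assert len(sentence.get_relations("relation")) == 1
```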
diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py
index 53ccabac3..9e9f438e8 100644
--- a/flair/models/relation_classifier_model.py
+++ b/flair/models/relation_classifier_model.py
@@ -372,11 +372,9 @@ def _entity_pair_permutations(
         """
         valid_entities: List[_Entity] = list(self._valid_entities(sentence))
 
-        # Use a dictionary to find gold relation annotations for a given entity pair
-        relation_to_gold_label: Dict[str, str] = {
-            relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value
-            for relation in sentence.get_relations(self.label_type)
-        }
+        # Ensure that all existing relations without a label have it set to the zero_tag_value.
+        for relation in sentence.get_relations(self.label_type):
+            relation.set_label(self.label_type, relation.get_label(self.label_type, self.zero_tag_value).value)
 
         # Yield head and tail entity pairs from the cross product of all entities
         for head, tail in itertools.product(valid_entities, repeat=2):
@@ -393,9 +391,10 @@
                 continue
 
             # Obtain gold label, if existing
-            original_relation: Relation = Relation(first=head.span, second=tail.span)
-            gold_label: Optional[str] = relation_to_gold_label.get(original_relation.unlabeled_identifier)
-
+            gold_relation = sentence[head.span, tail.span]
+            gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value="O").value
+            if gold_label == "O":
+                gold_label = None
             yield head, tail, gold_label
 
     def _encode_sentence(
@@ -479,7 +478,7 @@ def _encode_sentence_for_inference(
                 tail=tail,
                 gold_label=gold_label if gold_label is not None else self.zero_tag_value,
             )
-            original_relation: Relation = Relation(first=head.span, second=tail.span)
+            original_relation: Relation = sentence[head.span, tail.span]
             yield masked_sentence, original_relation
 
     def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py
index 795e8a517..bb4f43154 100644
--- a/flair/models/relation_extractor_model.py
+++ b/flair/models/relation_extractor_model.py
@@ -82,7 +82,7 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]:
             ):
                 continue
 
-            relation = Relation(span_1, span_2)
+            relation = sentence[span_1, span_2]
             if self.training and self.train_on_gold_pairs_only and relation.get_label(self.label_type).value == "O":
                 continue
             entity_pairs.append(relation)
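Reviewer note: the TARS changes below replace manual `Span` construction with slicing on the sentence, which leans on the offset between 1-based `Token.idx` values and 0-based slice positions. A small sketch of that arithmetic (example sentence is hypothetical):

```python
from flair.data import Sentence

# Pre-tokenized input avoids tokenizer surprises in this illustration.
sentence = Sentence(["spans", "are", "cached"])
token = sentence[1]  # the Token "are"

# Token.idx is 1-based while slice positions are 0-based, hence the
# shift by one in the new start_pos/end_pos computations.
assert token.idx == 2
assert sentence[token.idx - 1 : token.idx].text == "are"
```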
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index a7a41bdb5..fd1896bd7 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -412,10 +412,9 @@ def _get_tars_formatted_sentence(self, label, sentence):
 
         for entity_label in sentence.get_labels(self.label_type):
             if entity_label.value == label:
-                new_span = Span(
-                    [tars_sentence.get_token(token.idx + label_length) for token in entity_label.data_point]
-                )
-                new_span.add_label(self.static_label_type, value="entity")
+                start_pos = entity_label.data_point[0].idx + label_length - 1
+                end_pos = entity_label.data_point[-1].idx + label_length
+                tars_sentence[start_pos:end_pos].add_label(self.static_label_type, value="entity")
 
         tars_sentence.copy_context_from_sentence(sentence)
         return tars_sentence
@@ -572,19 +571,16 @@ def predict(
 
                     already_set_indices: List[int] = []
 
-                    sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
-                    sorted_x.reverse()
-                    for tuple in sorted_x:
-                        # get the span and its label
-                        label = tuple[0]
-
+                    sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1), reverse=True)
+                    for label, _ in sorted_x:
+                        span = typing.cast(Span, label.data_point)
                         label_length = (
                             0 if not self.prefix else len(label.value.split(" ")) + len(self.separator.split(" "))
                         )
 
                         # determine whether tokens in this span already have a label
                         tag_this = True
-                        for token in label.data_point:
+                        for token in span:
                             corresponding_token = sentence.get_token(token.idx - label_length)
                             if corresponding_token is None:
                                 tag_this = False
@@ -596,9 +592,10 @@
                         # only add if all tokens have no label
                         if tag_this:
                             # make and add a corresponding predicted span
-                            predicted_span = Span(
-                                [sentence.get_token(token.idx - label_length) for token in label.data_point]
-                            )
+                            start_pos = span.tokens[0].idx - label_length - 1
+                            end_pos = span.tokens[-1].idx - label_length
+
+                            predicted_span = sentence[start_pos:end_pos]
                             predicted_span.add_label(label_name, value=label.value, score=label.score)
 
                         # set indices so that no token can be tagged twice
diff --git a/tests/test_labels.py b/tests/test_labels.py
index 210a21588..0357725b7 100644
--- a/tests/test_labels.py
+++ b/tests/test_labels.py
@@ -189,9 +189,9 @@ def test_relation_tags():
     sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .")
 
     # create two relation label
-    Relation(sentence[0:4], sentence[7:8]).add_label("rel", "located in")
-    Relation(sentence[0:2], sentence[3:4]).add_label("rel", "university of")
-    Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition")
+    sentence[sentence[0:4], sentence[7:8]].add_label("rel", "located in")
+    sentence[sentence[0:2], sentence[3:4]].add_label("rel", "university of")
+    sentence[sentence[0:2], sentence[3:4]].add_label("syntactic", "apposition")
 
     # there should be two relation labels
     labels: List[Label] = sentence.get_labels("rel")
diff --git a/tests/test_sentence.py b/tests/test_sentence.py
index 3e3142264..ad5e85470 100644
--- a/tests/test_sentence.py
+++ b/tests/test_sentence.py
@@ -1,3 +1,6 @@
+import copy
+import pickle
+
 from flair.data import Sentence
 
 
@@ -73,3 +76,37 @@ def test_start_end_position_pretokenized() -> None:
         (10, 18),
         (19, 20),
     ]
+
+
+def test_spans_support_deepcopy() -> None:
+    sentence = Sentence(["I", "live", "in", "Vienna", "."])
+    sentence[3:4].add_label("ner", "LOC")
+
+    _ = copy.deepcopy(sentence)
+
+
+def test_spans_support_pickle() -> None:
+    sentence = Sentence(["I", "live", "in", "Vienna", "."])
+    sentence[3:4].add_label("ner", "LOC")
+
+    pickle_data = pickle.dumps(sentence)
+    _ = pickle.loads(pickle_data)
+
+
+def test_relations_support_deepcopy() -> None:
+    sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
+    sentence[0:1].add_label("ner", "LOC")
+    sentence[5:6].add_label("ner", "LOC")
+    sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")
+
+    _ = copy.deepcopy(sentence)
+
+
+def test_relations_support_pickle() -> None:
+    sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
+    sentence[0:1].add_label("ner", "LOC")
+    sentence[5:6].add_label("ner", "LOC")
+    sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")
+
+    pickle_data = pickle.dumps(sentence)
+    _ = pickle.loads(pickle_data)
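Reviewer note: these tests capture the motivation for the whole change: a `__new__` override with required arguments breaks `copy.deepcopy` and `pickle`, which call `__new__` without them when reconstructing objects. A quick round-trip sketch along the same lines (label values are arbitrary):

```python
import pickle

from flair.data import Sentence

sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
sentence[0:1].add_label("ner", "LOC")
sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")

# Labels on cached spans and relations should survive the round trip.
restored = pickle.loads(pickle.dumps(sentence))
assert restored.get_spans("ner")[0].get_label("ner").value == "LOC"
assert restored.get_relations("rel")[0].get_label("rel").value == "capital"
```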