Skip to content

Commit

Permalink
Merge pull request #3388 from flairNLP/bf/bio-entity-normalization
Browse files Browse the repository at this point in the history
Entity Mention Linker
  • Loading branch information
alanakbik authored Feb 8, 2024
2 parents ba948fd + 44d73a2 commit 17e2895
Show file tree
Hide file tree
Showing 15 changed files with 2,417 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@ venv.bak/

resources/taggers/
regression_train/
/doc_build/
19 changes: 19 additions & 0 deletions flair/class_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import inspect
from typing import Iterable, Optional, Type, TypeVar

T = TypeVar("T")


def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
    """Yield every concrete (non-abstract) direct or indirect subclass of ``cls``.

    The class hierarchy is walked depth-first; the descendants of a subclass are
    yielded before the subclass itself. ``cls`` itself is never yielded.

    Each class is visited at most once. With multiple inheritance a class can be
    reached through several parents; without this bookkeeping it (and its whole
    subtree) would be reported once per path.

    Args:
        cls: The root class whose subclass hierarchy is searched.

    Yields:
        Subclasses of ``cls`` for which ``inspect.isabstract`` is false.
    """
    visited = set()

    def _walk(parent):
        for subclass in parent.__subclasses__():
            if subclass in visited:
                # Already handled via another base class.
                continue
            visited.add(subclass)
            # Abstract classes are skipped themselves, but their descendants
            # may well be concrete, so always descend first.
            yield from _walk(subclass)
            if not inspect.isabstract(subclass):
                yield subclass

    yield from _walk(cls)


def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]:
    """Find a concrete subclass of ``cls`` by its class name.

    Args:
        cls: The root class whose non-abstract subclasses are searched.
        cls_name: The ``__name__`` of the subclass to look for.

    Returns:
        The first non-abstract subclass whose ``__name__`` equals ``cls_name``.

    Raises:
        ValueError: If no non-abstract subclass carries that name.
    """
    matching = (
        candidate for candidate in get_non_abstract_subclasses(cls) if candidate.__name__ == cls_name
    )
    found = next(matching, None)
    if found is None:
        raise ValueError(f"Could not find any class with name '{cls_name}'")
    return found
142 changes: 102 additions & 40 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,11 @@ class Label:
Default value for the score is 1.0.
"""

def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0) -> None:
def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0, **metadata) -> None:
self._value = value
self._score = score
self.data_point: DataPoint = data_point
self.metadata = metadata
super().__init__()

def set_value(self, value: str, score: float = 1.0):
Expand All @@ -235,14 +236,14 @@ def to_dict(self):
return {"value": self.value, "confidence": self.score}

def __str__(self) -> str:
return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value} ({round(self._score, 4)})"
return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value}{self.metadata_str} ({round(self._score, 4)})"

@property
def shortstring(self):
return f'"{self.data_point.text}"/{self._value}'

def __repr__(self) -> str:
return f"'{self.data_point.unlabeled_identifier}'/'{self._value}' ({round(self._score, 4)})"
return f"'{self.data_point.unlabeled_identifier}'/'{self._value}'{self.metadata_str} ({round(self._score, 4)})"

def __eq__(self, other):
return self.value == other.value and self.score == other.score and self.data_point == other.data_point
Expand All @@ -253,6 +254,13 @@ def __hash__(self):
def __lt__(self, other):
return self.data_point < other.data_point

@property
def metadata_str(self) -> str:
    """Render the label metadata as ``/key=value`` segments.

    Returns:
        An empty string when ``self.metadata`` is empty (or otherwise falsy);
        else one ``/key=value`` segment per metadata entry, concatenated in
        the iteration order of ``self.metadata.items()``.
    """
    meta = self.metadata
    if meta:
        return "".join(f"/{key}={val}" for key, val in meta.items())
    return ""

@property
def labeled_identifier(self):
return f"{self.data_point.unlabeled_identifier}/{self.value}"
Expand Down Expand Up @@ -336,16 +344,18 @@ def get_metadata(self, key: str) -> typing.Any:
def has_metadata(self, key: str) -> bool:
return key in self._metadata

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
label = Label(self, value, score, **metadata)

if typename not in self.annotation_layers:
self.annotation_layers[typename] = [Label(self, value, score)]
self.annotation_layers[typename] = [label]
else:
self.annotation_layers[typename].append(Label(self, value, score))
self.annotation_layers[typename].append(label)

return self

def set_label(self, typename: str, value: str, score: float = 1.0):
self.annotation_layers[typename] = [Label(self, value, score)]
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
self.annotation_layers[typename] = [Label(self, value, score, **metadata)]
return self

def remove_labels(self, typename: str):
Expand Down Expand Up @@ -375,28 +385,25 @@ def labels(self) -> List[Label]:
def unlabeled_identifier(self):
raise NotImplementedError

def _printout_labels(self, main_label=None, add_score: bool = True):
def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True):
all_labels = []
keys = [main_label] if main_label is not None else self.annotation_layers.keys()
if add_score:
for key in keys:
all_labels.extend(
[
f"{label.value} ({round(label.score, 4)})"
for label in self.get_labels(key)
if label.data_point == self
]
)
labels = "; ".join(all_labels)
if labels != "":
labels = flair._arrow + labels
else:
for key in keys:
all_labels.extend([f"{label.value}" for label in self.get_labels(key) if label.data_point == self])
labels = "/".join(all_labels)
if labels != "":
labels = "/" + labels
return labels

sep = "; " if add_score else "/"
sent_sep = flair._arrow if add_score else "/"
for key in keys:
for label in self.get_labels(key):
if label.data_point is not self:
continue
value = label.value
if add_metadata:
value = f"{value}{label.metadata_str}"
if add_score:
value = f"{value} ({label.score:.04f})"
all_labels.append(value)
if not all_labels:
return ""
return sent_sep + sep.join(all_labels)

def __str__(self) -> str:
return self.unlabeled_identifier + self._printout_labels()
Expand Down Expand Up @@ -431,6 +438,61 @@ def __len__(self) -> int:
raise NotImplementedError


class EntityCandidate:
    """A single concept (entry) of a knowledge base or ontology."""

    def __init__(
        self,
        concept_id: str,
        concept_name: str,
        database_name: str,
        additional_ids: Optional[List[str]] = None,
        synonyms: Optional[List[str]] = None,
        description: Optional[str] = None,
    ):
        """Create a concept entry of a knowledge base or ontology.

        Args:
            concept_id: Identifier of the concept in the knowledge base / ontology.
            concept_name: (Canonical) name of the concept in the knowledge base / ontology.
            database_name: Name of the knowledge base / ontology.
            additional_ids: Further identifiers of the concept in the knowledge base / ontology.
                Stored as given (not copied); defaults to an empty list.
            synonyms: Synonyms for this entry. Stored as given (not copied); defaults to an empty list.
            description: Free-text description of the concept.
        """
        self.concept_id = concept_id
        self.concept_name = concept_name
        self.database_name = database_name
        self.description = description
        # Use a fresh list per instance when the caller passes nothing.
        self.additional_ids = [] if additional_ids is None else additional_ids
        self.synonyms = [] if synonyms is None else synonyms

    def __str__(self) -> str:
        """Return ``<database>:<id> - <name>``, followed by the additional ids if there are any."""
        segments = [f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name}"]
        if self.additional_ids:
            segments.append("|".join(self.additional_ids))
        return " - ".join(segments)

    def __repr__(self) -> str:
        """Return the same text as ``str(self)``."""
        return str(self)

    def to_dict(self) -> Dict[str, typing.Any]:
        """Return the candidate's fields as a plain dictionary keyed by attribute name."""
        field_names = (
            "concept_id",
            "concept_name",
            "database_name",
            "additional_ids",
            "synonyms",
            "description",
        )
        return {name: getattr(self, name) for name in field_names}


DT = typing.TypeVar("DT", bound=DataPoint)
DT2 = typing.TypeVar("DT2", bound=DataPoint)

Expand All @@ -440,18 +502,18 @@ def __init__(self, sentence) -> None:
super().__init__()
self.sentence: Sentence = sentence

def add_label(self, typename: str, value: str, score: float = 1.0):
super().add_label(typename, value, score)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score))
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
super().add_label(typename, value, score, **metadata)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata))

def set_label(self, typename: str, value: str, score: float = 1.0):
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
if len(self.annotation_layers.get(typename, [])) > 0:
# First we remove any existing labels for this PartOfSentence in self.sentence
self.sentence.annotation_layers[typename] = [
label for label in self.sentence.annotation_layers.get(typename, []) if label.data_point != self
]
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score))
super().set_label(typename, value, score)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata))
super().set_label(typename, value, score, **metadata)
return self

def remove_labels(self, typename: str):
Expand Down Expand Up @@ -537,21 +599,21 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return self.__str__()

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# therefore, labels get added only to the Sentence if it exists
if self.sentence:
super().add_label(typename=typename, value=value, score=score)
super().add_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.add_label(self, typename=typename, value=value, score=score)
DataPoint.add_label(self, typename=typename, value=value, score=score, **metadata)

def set_label(self, typename: str, value: str, score: float = 1.0):
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# Therefore, labels get set only to the Sentence if it exists
if self.sentence:
super().set_label(typename=typename, value=value, score=score)
super().set_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.set_label(self, typename=typename, value=value, score=score)
DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)

def to_dict(self, tag_type: Optional[str] = None):
return {
Expand Down
12 changes: 12 additions & 0 deletions flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@
# word sense disambiguation
# Expose all entity linking datasets
from .entity_linking import (
CTD_CHEMICALS_DICTIONARY,
CTD_DISEASES_DICTIONARY,
NCBI_GENE_HUMAN_DICTIONARY,
NCBI_TAXONOMY_DICTIONARY,
NEL_ENGLISH_AIDA,
NEL_ENGLISH_AQUAINT,
NEL_ENGLISH_IITB,
Expand All @@ -147,6 +151,8 @@
WSD_UFSAC,
WSD_WORDNET_GLOSS_TAGGED,
ZELDA,
EntityLinkingDictionary,
HunerEntityLinkingDictionary,
)

# Expose all relation extraction datasets
Expand Down Expand Up @@ -315,6 +321,7 @@
"SentenceDataset",
"MongoDataset",
"StringDataset",
"EntityLinkingDictionary",
"AGNEWS",
"ANAT_EM",
"AZDZ",
Expand Down Expand Up @@ -342,6 +349,7 @@
"FSU",
"GELLUS",
"GPRO",
"HunerEntityLinkingDictionary",
"HUNER_CELL_LINE",
"HUNER_CELL_LINE_CELL_FINDER",
"HUNER_CELL_LINE_CLL",
Expand Down Expand Up @@ -390,6 +398,10 @@
"LINNEAUS",
"LOCTEXT",
"MIRNA",
"NCBI_GENE_HUMAN_DICTIONARY",
"NCBI_TAXONOMY_DICTIONARY",
"CTD_DISEASES_DICTIONARY",
"CTD_CHEMICALS_DICTIONARY",
"NCBI_DISEASE",
"ONTONOTES",
"OSIRIS",
Expand Down
Loading

0 comments on commit 17e2895

Please sign in to comment.