Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add biomedical entity normalization #3180

Closed
wants to merge 30 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
641a3c0
Initial version (already adapted to recent Flair API changes)
Mar 14, 2023
9779abf
Revise mention text pre-processing: define general interface and adap…
Mar 14, 2023
8da7d75
Refactor entity linking model structure
Mar 15, 2023
e34c831
Update documentation
Mar 22, 2023
f54925c
Introduce separate methods for pre-processing (1) entity mentions fro…
Mar 23, 2023
90a0acb
Merge branch 'master' into bio-entity-normalization
alanakbik Apr 21, 2023
f1f51fd
Fix formatting
alanakbik Apr 21, 2023
f2f21d3
feat(test): biomedical entity linking
Apr 26, 2023
82c1b8b
fix(requirements): add faiss
Apr 26, 2023
2e3cda3
fix(test): hold on w/ automatic tests for now
Apr 26, 2023
adb231e
fix(bionel): start major refactoring
Apr 26, 2023
c80f1be
fix(bionel): major refactor
Apr 27, 2023
d10d297
fix(bionel): assign entity type
May 2, 2023
25ba2dd
fix(biencoder): set sparse encoder and weight
May 2, 2023
4525d3b
fix(bionel): address comments
May 11, 2023
3a5913d
fix(candidate_generator): container for search result
May 12, 2023
734d895
fix(predict): default annotation layer iff not provided by user
May 19, 2023
d79f871
fix(label): scores can be >= or <=
May 19, 2023
118fb95
fix(candidate): parametrize database name
May 19, 2023
1fcfddf
feat(candidate_generator): cache sparse encoder
May 22, 2023
9322c1b
fix(candidate_generator): minor improvements
May 23, 2023
071f51e
feat(linking_candidate): pretty print
May 24, 2023
a23f360
fix(candidate_generator): check sparse encoder for sparse search
May 24, 2023
ce29290
chore: crystal clear dictionary name
Jun 1, 2023
0d65336
feat(candidate_generator): add sparse index
Jun 1, 2023
02812f0
fix(candidate_generator): KISS: sparse search w/ scipy sparse matrices
Jun 2, 2023
ca6eee8
Minor update to comments and documentation
Jul 12, 2023
6c8f219
Fix tests and type annotations
Jul 12, 2023
2fa43cc
Merge branch 'master' into bio-entity-normalization
Jul 12, 2023
d90d92d
Merge
Jul 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 102 additions & 61 deletions flair/models/biomedical_entity_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@
logger = logging.getLogger("flair")


PRETRAINED_MODELS = [
PRETRAINED_DENSE_MODELS = [
"cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
]

# Dense + sparse retrieval
PRETRAINED_HYBRID_MODELS = [
"dmis-lab/biosyn-sapbert-bc5cdr-disease",
"dmis-lab/biosyn-sapbert-ncbi-disease",
"dmis-lab/biosyn-sapbert-bc5cdr-chemical",
"dmis-lab/biosyn-biobert-bc5cdr-disease",
"dmis-lab/biosyn-biobert-ncbi-disease",
"dmis-lab/biosyn-biobert-bc5cdr-chemical",
"dmis-lab/biosyn-biobert-bc2gn",
"dmis-lab/biosyn-sapbert-bc2gn",
]
PRETRAINED_HYBRID_MODELS = {
"dmis-lab/biosyn-sapbert-bc5cdr-disease": "disease",
"dmis-lab/biosyn-sapbert-ncbi-disease": "disease",
"dmis-lab/biosyn-sapbert-bc5cdr-chemical": "chemical",
"dmis-lab/biosyn-biobert-bc5cdr-disease": "disease",
"dmis-lab/biosyn-biobert-ncbi-disease": "disease",
"dmis-lab/biosyn-biobert-bc5cdr-chemical": "chemical",
"dmis-lab/biosyn-biobert-bc2gn": "gene",
"dmis-lab/biosyn-sapbert-bc2gn": "gene",
}

PRETRAINED_MODELS = PRETRAINED_HYBRID_MODELS + PRETRAINED_MODELS
PRETRAINED_MODELS = list(PRETRAINED_HYBRID_MODELS) + PRETRAINED_DENSE_MODELS

# just in case we add: fuzzy search, Levenshtein, ...
STRING_MATCHING_MODELS = ["exact-string-match"]
Expand All @@ -60,6 +60,13 @@

ENTITY_TYPES = ["disease", "chemical", "gene", "species"]

ENTITY_TYPE_TO_LABELS = {
"disease": "diseases",
"gene": "genes",
"species": "species",
"chemical": "chemical",
}

ENTITY_TYPE_TO_HYBRID_MODEL = {
"disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease",
"chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical",
Expand All @@ -80,6 +87,13 @@
"chemical": "ctd-chemical",
}

ENTITY_TYPE_TO_ANNOTATION_LAYER = {
"disease": "diseases",
"gene": "genes",
"chemical": "chemicals",
"species": "species",
}

BIOMEDICAL_DICTIONARIES = {
"ctd-disease": CTD_DISEASE_DICTIONARY,
"ctd-chemical": CTD_CHEMICAL_DICTIONARY,
Expand Down Expand Up @@ -438,7 +452,7 @@ def __call__(self, mentions: List[str]) -> torch.Tensor:
def save(self, path: Path) -> None:
with path.open("wb") as fout:
pickle.dump(self.encoder, fout)
logger.info("Sparse encoder saved in %s", path)
# logger.info("Sparse encoder saved in %s", path)

@classmethod
def load(cls, path: Path) -> "BigramTfIDFVectorizer":
Expand All @@ -448,7 +462,7 @@ def load(cls, path: Path) -> "BigramTfIDFVectorizer":
newVectorizer = cls()
with open(path, "rb") as fin:
newVectorizer.encoder = pickle.load(fin)
logger.info("Sparse encoder loaded from %s", path)
# logger.info("Sparse encoder loaded from %s", path)

return newVectorizer

Expand Down Expand Up @@ -785,7 +799,7 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str
if embeddings_cache_file.exists():

with embeddings_cache_file.open("rb") as fp:
logger.info("Load cached emebddings from %s", embeddings_cache_file)
logger.info("Load cached emebddings from: %s", embeddings_cache_file)
embeddings = pickle.load(fp)

else:
Expand Down Expand Up @@ -946,9 +960,16 @@ class BiomedicalEntityLinker:
entity / concept to these mentions according to a knowledge base / dictionary.
"""

def __init__(self, candidate_generator: AbstractCandidateGenerator, preprocessor: AbstractEntityPreprocessor):
def __init__(
self,
candidate_generator: AbstractCandidateGenerator,
preprocessor: AbstractEntityPreprocessor,
entity_type: str,
):
self.preprocessor = preprocessor
self.candidate_generator = candidate_generator
self.entity_type = entity_type
self.annotation_layer = ENTITY_TYPE_TO_ANNOTATION_LAYER[self.entity_type]
sg-wbi marked this conversation as resolved.
Show resolved Hide resolved

def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, str, int]) -> EntityLinkingLabel:
"""
Expand Down Expand Up @@ -983,7 +1004,8 @@ def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, st
)

def extract_mentions(
self, sentences: List[Sentence], input_entity_annotation_layer: Optional[str] = None
self,
sentences: List[Sentence],
) -> Tuple[List[int], List[Span], List[str]]:
"""
Unpack all mentions in sentences for batch search.
Expand All @@ -994,7 +1016,7 @@ def extract_mentions(
data_points = []
mentions = []
for i, sentence in enumerate(sentences):
for entity in sentence.get_labels(input_entity_annotation_layer):
for entity in sentence.get_labels(self.annotation_layer):
source.append(i)
data_points.append(entity.data_point)
mentions.append(
Expand All @@ -1003,20 +1025,21 @@ def extract_mentions(
else entity.data_point.text,
)

assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`"
sg-wbi marked this conversation as resolved.
Show resolved Hide resolved

return source, data_points, mentions

def predict(
self,
sentences: Union[List[Sentence], Sentence],
input_entity_annotation_layer: str = None,
# input_entity_annotation_layer: str = None,
top_k: int = 1,
) -> None:
"""
Predicts the best matching top-k entity / concept identifiers of all named entites annotated
with tag input_entity_annotation_layer.

:param sentences: One or more sentences to run the prediction on
:param input_entity_annotation_layer: Entity type to run the prediction on
:param top_k: Number of best-matching entity / concept identifiers which should be predicted
per entity mention
"""
Expand All @@ -1028,11 +1051,9 @@ def predict(
self.preprocessor.initialize(sentences)

# Build label name
label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen"
# label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen"

source, data_points, mentions = self.extract_mentions(
sentences=sentences, input_entity_annotation_layer=input_entity_annotation_layer
)
source, data_points, mentions = self.extract_mentions(sentences=sentences)

# Retrieve top-k concept / entity candidates
predictions = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k)
Expand All @@ -1041,7 +1062,7 @@ def predict(
for i, data_point, prediction in zip(source, data_points, predictions):

sentences[i].add_label(
typename=label_name,
typename=self.annotation_layer,
value_or_label=self.build_entity_linking_label(prediction=prediction, data_point=data_point),
)

Expand All @@ -1057,6 +1078,7 @@ def load(
preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()),
force_hybrid_search: bool = False,
sparse_weight: float = DEFAULT_SPARSE_WEIGHT,
entity_type: Optional[str] = None,
):
"""
Loads a model for biomedical named entity normalization.
Expand All @@ -1069,11 +1091,15 @@ def load(
)

if isinstance(model_name_or_path, str):
model_name_or_path = cls.__get_model_path(
model_name_or_path, entity_type = cls.__get_model_path_and_entity_type(
model_name_or_path=model_name_or_path,
entity_type=entity_type,
hybrid_search=hybrid_search,
force_hybrid_search=force_hybrid_search,
)
else:
assert entity_type is not None, "When using a custom model you must specify `entity_type`"
assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! Must be one of: {ENTITY_TYPES}"

if model_name_or_path == "exact-string-match":
candidate_generator = ExactMatchCandidateGenerator.load(dictionary_name_or_path)
Expand All @@ -1089,62 +1115,77 @@ def load(
preprocessor=preprocessor,
)

logger.info("Load model `%s` with dictionary `%s`", model_name_or_path, dictionary_name_or_path)
logger.info(
"BiomedicalEntityLinker predicts: Entity type: %s with Dictionary `%s`",
entity_type,
dictionary_name_or_path,
)

return cls(candidate_generator=candidate_generator, preprocessor=preprocessor)
return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type)

@staticmethod
def __get_model_path(
model_name_or_path: Union[str, Path], hybrid_search: bool = False, force_hybrid_search: bool = False
) -> str:
def __get_model_path_and_entity_type(
model_name_or_path: Union[str, Path],
entity_type: Optional[str] = None,
hybrid_search: bool = False,
force_hybrid_search: bool = False,
) -> Tuple[str, str]:
"""
Try to figure out what model the user wants
"""

if isinstance(model_name_or_path, str):

if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES:
raise ValueError(
f"""Unknown model `{model_name_or_path}`! \n
Available entity types are: {ENTITY_TYPES} \n
If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`"""
)
if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES:
raise ValueError(
f"""Unknown model `{model_name_or_path}`! \n
Available entity types are: {ENTITY_TYPES} \n
If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`"""
)

if hybrid_search:
# load model by entity_type
if model_name_or_path in ENTITY_TYPES:
# check if we have a hybrid pre-trained model
if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL:
model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path]
else:
# check if user really wants to use hybrid search anyway
if not force_hybrid_search:
raise ValueError(
f"""
Model for entity type `{model_name_or_path}` was not trained for hybrid search!
If you want to proceed anyway please pass `force_hybrid_search=True`:
we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`.
"""
)
model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path]
if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext":
assert entity_type is not None, f"For model {model_name_or_path} you must specify `entity_type`"

entity_type = None
if hybrid_search:
# load model by entity_type
if model_name_or_path in ENTITY_TYPES:
# check if we have a hybrid pre-trained model
if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL:
entity_type = model_name_or_path
model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path]
else:
if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search:
# check if user really wants to use hybrid search anyway
if not force_hybrid_search:
raise ValueError(
f"""
Model `{model_name_or_path}` was not trained for hybrid search!
Model for entity type `{model_name_or_path}` was not trained for hybrid search!
If you want to proceed anyway please pass `force_hybrid_search=True`:
we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`.
"""
)

model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path]
else:
if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search:
raise ValueError(
f"""
Model `{model_name_or_path}` was not trained for hybrid search!
If you want to proceed anyway please pass `force_hybrid_search=True`:
we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`.
"""
)
entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path]

else:
if model_name_or_path in ENTITY_TYPES:
model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path]

return model_name_or_path
assert entity_type is not None, f"Impossible to determine entity type for model `{model_name_or_path}`"

return model_name_or_path, entity_type

@staticmethod
def __get_dictionary_path(
model_name_or_path: str, dictionary_name_or_path: Optional[Union[str, Path]] = None
model_name_or_path: str,
dictionary_name_or_path: Optional[Union[str, Path]] = None,
) -> str:
"""
Try to figure out what dictionary (depending on the model) the user wants
Expand Down