From a01e06090fcd770b4f83b31b05b91ce2c5e46a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Tue, 14 Mar 2023 11:04:41 +0100 Subject: [PATCH 01/58] Initial version (already adapted to recent Flair API changes) --- flair/data.py | 65 +- flair/datasets/__init__.py | 8 + flair/models/biomedical_entity_linking.py | 1004 +++++++++++++++++ .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 34 + 4 files changed, 1107 insertions(+), 4 deletions(-) create mode 100644 flair/models/biomedical_entity_linking.py create mode 100644 resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md diff --git a/flair/data.py b/flair/data.py index 10e74e782..ce4d8dabf 100644 --- a/flair/data.py +++ b/flair/data.py @@ -261,7 +261,6 @@ def labeled_identifier(self): def unlabeled_identifier(self): return f"{self.data_point.unlabeled_identifier}" - class DataPoint: """This is the parent class of all data points in Flair. @@ -336,11 +335,13 @@ def get_metadata(self, key: str) -> typing.Any: def has_metadata(self, key: str) -> bool: return key in self._metadata - def add_label(self, typename: str, value: str, score: float = 1.0): + def add_label(self, typename: str, value_or_label: Union[str, Label], score: float = 1.0): + label = value_or_label if isinstance(value_or_label, Label) else Label(self, value_or_label, score) + if typename not in self.annotation_layers: - self.annotation_layers[typename] = [Label(self, value, score)] + self.annotation_layers[typename] = [label] else: - self.annotation_layers[typename].append(Label(self, value, score)) + self.annotation_layers[typename].append(label) return self @@ -431,6 +432,62 @@ def __len__(self) -> int: raise NotImplementedError +class EntityLinkingLabel(Label): + def __init__( + self, + data_point: DataPoint, + id: str, + concept_name: str, + score: float = 1.0, + additional_ids: Optional[Union[List[str], str]] = None, + database: Optional[str] = None, + ): + super().__init__(data_point, id, score) + self.concept_name = concept_name + if isinstance(additional_ids, str): + additional_ids = [additional_ids] + self.additional_ids = additional_ids + self.database = database + + def spawn(self, value: str, score: float = 1.0): + return EntityLinkingLabel( + data_point=self.data_point, + id=value, + score=score, + concept_name=self.concept_name, + additional_ids=self.additional_ids, + database=self.database + ) + + def __str__(self): + if self.additional_ids is None: + return f"{self.database}:{self._value} {self.concept_name} ({round(self._score, 2)})" + else: + return f"{self.database}:{self._value} ({', '.join(self.additional_ids)}) {self.concept_name} ({round(self._score, 2)})" + + def __repr__(self): + if self.additional_ids is None: + return f"{self.database}:{self._value} {self.concept_name} [{self.data_point.text}] ({round(self._score, 2)})" + else: + return f"{self.database}:{self._value} ({', '.join(self.additional_ids)}) {self.concept_name} [{self.data_point.text}] ({round(self._score, 2)})" + + def __len__(self): + return len(self.data_point) + + def __eq__(self, other): + return ( + self.value == other.value + and self.data_point == other.data_point + and self.concept_name == other.concept_name + and self.database == other.database + and self.score == other.score + ) + + @property + def identifier(self): + return f"{self.value}" + + DT = typing.TypeVar("DT", bound=DataPoint) DT2 = typing.TypeVar("DT2", bound=DataPoint) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index d8edd4445..ed0037d70 100644 --- a/flair/datasets/__init__.py 
+++ b/flair/datasets/__init__.py @@ -90,6 +90,10 @@ LOCTEXT, MIRNA, NCBI_DISEASE, + NEL_NCBI_HUMAN_GENE_DICT, + NEL_NCBI_TAXONOMY_DICT, + NEL_CTD_CHEMICAL_DICT, + NEL_CTD_DISEASE_DICT, OSIRIS, PDR, S800, @@ -390,6 +394,10 @@ "LINNEAUS", "LOCTEXT", "MIRNA", + "NEL_NCBI_HUMAN_GENE_DICT", + "NEL_NCBI_TAXONOMY_DICT", + "NEL_CTD_CHEMICAL_DICT", + "NEL_CTD_DISEASE_DICT", "NCBI_DISEASE", "ONTONOTES", "OSIRIS", diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py new file mode 100644 index 000000000..e79b63372 --- /dev/null +++ b/flair/models/biomedical_entity_linking.py @@ -0,0 +1,1004 @@ +import os +import stat +import re +import pickle +import logging +import numpy as np +import torch +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm +from transformers import ( + PreTrainedTokenizer, + PreTrainedModel, +) +from huggingface_hub import hf_hub_url, cached_download +from string import punctuation +import flair +from flair.data import Sentence, EntityLinkingLabel +from flair.datasets import ( + NEL_CTD_CHEMICAL_DICT, + NEL_CTD_DISEASE_DICT, + NEL_NCBI_HUMAN_GENE_DICT, + NEL_NCBI_TAXONOMY_DICT, +) +from flair.file_utils import cached_path +from typing import List, Tuple, Union +from pathlib import Path +import subprocess +import tempfile +import faiss +from flair.embeddings import TransformerDocumentEmbeddings +from string import punctuation + +log = logging.getLogger("flair") + + +class BigramTfIDFVectorizer: + """ + Class to encode a list of mentions into a sparse tensor. + + Slightly modified from Sung et al. 2020 + Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 + """ + + def __init__(self) -> None: + self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) + + def transform(self, mentions: list) -> torch.Tensor: + vec = self.encoder.transform(mentions).toarray() + vec = torch.FloatTensor(vec) + return vec + + def __call__(self, mentions: list): + return self.transform(mentions) + + def save_encoder(self, path: Path): + with open(path, "wb") as fout: + pickle.dump(self.encoder, fout) + log.info("Sparse encoder saved in {}".format(path)) + + @classmethod + def load(cls, path: Path): + newVectorizer = cls() + with open(path, "rb") as fin: + newVectorizer.encoder = pickle.load(fin) + log.info("Sparse encoder loaded from {}".format(path)) + + return newVectorizer + + +class TextPreprocess: + """ + Text Preprocess module + Support lowercase, removing punctuation, typo correction + + Slightly modifed from Sung et al. 
2020 + Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + """ + + def __init__( + self, + lowercase: bool = True, + remove_punctuation: bool = True, + ) -> None: + """ + :param typo_path str: path of known typo dictionary + """ + self.lowercase = lowercase + self.rmv_puncts = remove_punctuation + self.punctuation = punctuation + self.rmv_puncts_regex = re.compile( + r"[\s{}]+".format(re.escape(self.punctuation)) + ) + + def remove_punctuation(self, phrase: str) -> str: + phrase = self.rmv_puncts_regex.split(phrase) + phrase = " ".join(phrase).strip() + + return phrase + + def run(self, text: str) -> str: + if self.lowercase: + text = text.lower() + + if self.rmv_puncts: + text = self.remove_punctuation(text) + + text = text.strip() + + return text + + +class Ab3P: + """ + Module for the Abbreviation Resolver Ab3P + https://github.com/ncbi-nlp/Ab3P + """ + + def __init__(self, ab3p_path: Path, word_data_dir: Path) -> None: + self.ab3p_path = ab3p_path + self.word_data_dir = word_data_dir + + @classmethod + def load(cls, ab3p_path: Path = None): + data_dir = os.path.join(flair.cache_root, "ab3p") + if not os.path.exists(data_dir): + os.mkdir(os.path.join(data_dir)) + word_data_dir = os.path.join(data_dir, "word_data/") + if not os.path.exists(word_data_dir): + os.mkdir(word_data_dir) + if ab3p_path is None: + ab3p_path = cls.download_ab3p(data_dir, word_data_dir) + return cls(ab3p_path, word_data_dir) + + @classmethod + def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: + # download word data for Ab3P if not already downloaded + ab3p_url = ( + "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" + ) + ab3p_files = [ + "Ab3P_prec.dat", + "Lf1chSf", + "SingTermFreq.dat", + "cshset_wrdset3.ad", + "cshset_wrdset3.ct", + "cshset_wrdset3.ha", + "cshset_wrdset3.nm", + "cshset_wrdset3.str", + "hshset_Lf1chSf.ad", + "hshset_Lf1chSf.ha", + "hshset_Lf1chSf.nm", + "hshset_Lf1chSf.str", + "hshset_stop.ad", + "hshset_stop.ha", + "hshset_stop.nm", + "hshset_stop.str", + "stop", + ] + for file in ab3p_files: + data_path = cached_path(ab3p_url + file, word_data_dir) + # download ab3p executable + ab3p_path = cached_path( + "https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir + ) + os.chmod(ab3p_path, stat.S_IEXEC) + return ab3p_path + + def build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> dict: + abbreviation_dict = {} + # tempfile to store the data to pass to the ab3p executable + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: + for sentence in sentences: + temp_file.write(sentence.to_tokenized_string() + "\n") + temp_file.flush() + # temporarily create path file in the current working directory for Ab3P + with open(os.path.join(os.getcwd(), "path_Ab3P"), "w") as path_file: + path_file.write(self.word_data_dir + "\n") + # run ab3p with the temp file containing the dataset + try: + result = subprocess.run( + [self.ab3p_path, temp_file.name], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except: + log.error( + """The abbreviation resolver Ab3P could not be run on your system. To ensure maximum accuracy, please + install Ab3P yourself. See https://github.com/ncbi-nlp/Ab3P""" + ) + else: + line = result.stdout.decode("utf-8") + if "Path file for type cshset does not exist!" in line: + log.error( + "Error when using Ab3P for abbreviation resolution. 
A file named path_Ab3p needs to exist in your current directory containing the path to the WordData directory for Ab3P to work!" + ) + elif "Cannot open" in line: + log.error( + "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" + ) + elif "failed to open" in line: + log.error( + "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" + ) + + lines = line.split("\n") + for line in lines: + if len(line.split("|")) == 3: + sf, lf, _ = line.split("|") + sf = sf.strip().lower() + lf = lf.strip().lower() + abbreviation_dict[sf] = lf + finally: + # remove the path file + os.remove(os.path.join(os.getcwd(), "path_Ab3P")) + + return abbreviation_dict + + +class DictionaryDataset: + """ + A class used to load dictionary data from a custom dictionary file. + Every line in the file must be formatted as follows: + concept_unique_id||concept_name + with one line per concept name. Multiple synonyms for the same concept should + be in seperate lines with the same concept_unique_id. + + Slightly modifed from Sung et al. 2020 + Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/data_loader.py#L89 + """ + + def __init__( + self, dictionary_path: Union[Path, str], load_into_memory: True + ) -> None: + """ + :param dictionary_path str: The path of the dictionary + """ + log.info("Loading Dictionary from {}".format(dictionary_path)) + if load_into_memory: + self.data = self.load_data(dictionary_path) + else: + self.data = self.get_data(dictionary_path) + + def load_data(self, dictionary_path) -> np.ndarray: + data = [] + with open(dictionary_path, mode="r", encoding="utf-8") as f: + lines = f.readlines() + for line in tqdm(lines, desc="Loading dictionary"): + line = line.strip() + if line == "": + continue + cui, name = line.split("||") + name = name.lower() + data.append((name, cui)) + + data = np.array(data) + return data + + # generator version + def get_data(self, dictionary_path): + data = [] + with open(dictionary_path, mode="r", encoding="utf-8") as f: + lines = f.readlines() + for line in tqdm(lines, desc="Loading dictionary"): + line = line.strip() + if line == "": + continue + cui, name = line.split("||") + name = name.lower() + yield (name, cui) + + +class BiEncoderEntityLiker: + """ + A class to load a model and use it to encode a dictionary and entities + """ + + def __init__( + self, use_sparse_embeds: bool, max_length: int, index_use_cuda: bool + ) -> None: + self.use_sparse_embeds = use_sparse_embeds + self.max_length = max_length + + self.tokenizer = None + self.encoder = None + + self.sparse_encoder = None + self.sparse_weight = None + + self.index_use_cuda = index_use_cuda and flair.device.type == "cuda" + + self.dense_dictionary_index = None + self.dict_sparse_embeds = None + + self.dictionary = None + + def load_model( + self, + model_name_or_path: Union[str, Path], + dictionary_name_or_path: str, + batch_size: int = 1024, + use_cosine: bool = True, + ): + """ + Load the model and embed the dictionary + :param model_name_or_path: The path of the model + :param dictionary_name_or_path: The path of the dictionary + :param batch_size: The batch size for embedding the dictionary + """ + self.load_dense_encoder(model_name_or_path) + self.use_cosine = True + + if self.use_sparse_embeds: + self.load_sparse_encoder(model_name_or_path) + self.load_sparse_weight(model_name_or_path) + + self.embed_dictionary( + 
model_name_or_path=model_name_or_path, + dictionary_name_or_path=dictionary_name_or_path, + batch_size=batch_size, + ) + + return self + + def load_dense_encoder( + self, model_name_or_path: Union[str, Path] + ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + + self.encoder = TransformerDocumentEmbeddings( + model_name_or_path, is_token_embedding=False + ) + + return self.encoder + + def load_sparse_encoder( + self, model_name_or_path: Union[str, Path] + ) -> BigramTfIDFVectorizer: + sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") + # check file exists + if not os.path.isfile(sparse_encoder_path): + # download from huggingface hub and cache it + sparse_encoder_url = hf_hub_url( + model_name_or_path, filename="sparse_encoder.pk" + ) + sparse_encoder_path = cached_download( + url=sparse_encoder_url, + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) + + self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + + return self.sparse_encoder + + def load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> torch.Tensor: + sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") + # check file exists + if not os.path.isfile(sparse_weight_path): + # download from huggingface hub and cache it + sparse_weight_url = hf_hub_url( + model_name_or_path, filename="sparse_weight.pt" + ) + sparse_weight_path = cached_download( + url=sparse_weight_url, + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) + + self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu") + + return self.sparse_weight + + def get_sparse_weight(self) -> torch.Tensor: + assert self.sparse_weight is not None + + return self.sparse_weight + + def embed_sparse(self, names: list) -> np.ndarray: + """ + Embedding data into sparse representations + :param names np.array: An array of names + :returns sparse_embeds np.array: A list of sparse embeddings + """ + sparse_embeds = self.sparse_encoder(names) + sparse_embeds = sparse_embeds.numpy() + + return sparse_embeds + + def embed_dense( + self, + names: np.ndarray, + show_progress: bool = False, + batch_size: int = 2048, + ) -> np.ndarray: + """ + Embedding data into dense representations for SapBert + :param names: np.array of names + :param show_progress: bool to toggle progress bar + :param batch_size: batch size + :return dense_embeds: list of dense embeddings of the names + """ + self.encoder.eval() # prevent dropout + + dense_embeds = [] + + with torch.no_grad(): + if show_progress: + iterations = tqdm( + range(0, len(names), batch_size), + desc="Calculating dense embeddings for dictionary", + ) + else: + iterations = range(0, len(names), batch_size) + + for start in iterations: + + # make batch + end = min(start + batch_size, len(names)) + batch = [Sentence(name) for name in names[start:end]] + + # embed batch + self.encoder.embed(batch) + + dense_embeds += [ + name.embedding.cpu().detach().numpy() for name in batch + ] + + if flair.device.type == "cuda": + torch.cuda.empty_cache() + + return np.array(dense_embeds) + + def get_sparse_similarity_scores( + self, + query_embeds: np.ndarray, + dict_embeds: np.ndarray, + cosine: bool = False, + normalise: bool = False, + ) -> np.ndarray: + """ + Return score matrix + :param query_embeds: 2d numpy array of query embeddings + :param dict_embeds: 2d numpy array of query embeddings + :param score_matrix: 2d numpy array of scores + """ + if cosine: + score_matrix = cosine_similarity(query_embeds, dict_embeds) + else: + 
score_matrix = np.matmul(query_embeds, dict_embeds.T) + + if normalise: + score_matrix = (score_matrix - score_matrix.min()) / ( + score_matrix.max() - score_matrix.min() + ) + + return score_matrix + + def retrieve_sparse_topk_candidates( + self, + score_matrix: np.ndarray, + topk: int, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Return sorted topk idxes (descending order) + :param score_matrix: 2d numpy array of scores + :param topk: number of candidates to retrieve + :return res: d numpy array of ids [# of query , # of dict] + :return scores: numpy array of top scores + """ + + def indexing_2d(arr, cols): + rows = np.repeat( + np.arange(0, cols.shape[0])[:, np.newaxis], cols.shape[1], axis=1 + ) + return arr[rows, cols] + + # get topk indexes without sorting + topk_idxs = np.argpartition(score_matrix, -topk)[:, -topk:] + + # get topk indexes with sorting + topk_score_matrix = indexing_2d(score_matrix, topk_idxs) + topk_argidxs = np.argsort(-topk_score_matrix) + topk_idxs = indexing_2d(topk_idxs, topk_argidxs) + topk_scores = indexing_2d(score_matrix, topk_idxs) + + return topk_idxs, topk_scores + + def get_predictions( + self, + mention: str, + topk: int, + batch_size: int = 1024, + ) -> np.ndarray: + """ + Return the topk predictions for a mention and their scores + :param mention: string of the mention to find candidates for + :param topk: number of candidates + :return res: d numpy array of ids [# of query , # of dict] + :return scores: numpy array of top predictions and their scores + """ + # get dense embeds for mention + mention_dense_embeds = self.embed_dense(names=[mention]) + + if self.use_cosine: + # normalize mention vector + faiss.normalize_L2(mention_dense_embeds) + + assert ( + self.dense_dictionary_index is not None + ), "Index not built yet, please run load_model to embed your dictionary before calling get_predictions" + + # if using sparse embeds: calculate hybrid scores with dense and sparse embeds + if self.use_sparse_embeds: + assert ( + self.dict_sparse_embeds is not None + ), "Index not built yet, please run load_model to embed your dictionary before calling get_predictions" + # search for more than topk candidates to use them when combining with sparse scores + # get candidates from dense embeddings + dense_scores, dense_ids = self.dense_dictionary_index.search( + x=mention_dense_embeds, k=topk + 10 + ) + # get sparse embeds for mention + mention_sparse_embeds = self.embed_sparse(names=[mention]) + if self.use_cosine: + # normalize mention vector + faiss.normalize_L2(mention_sparse_embeds) + + # get candidates from sprase embeddings + sparse_weight = self.get_sparse_weight().item() + sparse_score_matrix = self.get_sparse_similarity_scores( + query_embeds=mention_sparse_embeds, dict_embeds=self.dict_sparse_embeds + ) + sparse_ids, sparse_distances = self.retrieve_sparse_topk_candidates( + score_matrix=sparse_score_matrix, topk=topk + 10 + ) + + # combine dense and sparse scores + hybrid_ids = [] + hybrid_scores = [] + # for every embedded mention + for ( + top_dense_ids, + top_dense_scores, + top_sparse_ids, + top_sparse_distances, + ) in zip(dense_ids, dense_scores, sparse_ids, sparse_distances): + ids = top_dense_ids + distances = top_dense_scores + for sparse_id, sparse_distance in zip( + top_sparse_ids, top_sparse_distances + ): + if sparse_id not in ids: + ids = np.append(ids, sparse_id) + distances = np.append( + distances, sparse_weight * sparse_distance + ) + else: + index = np.where(ids == sparse_id)[0][0] + distances[index] = ( + sparse_weight * 
sparse_distance + ) + distances[index] + + sorted_indizes = np.argsort(-distances) + ids = ids[sorted_indizes][:topk] + distances = distances[sorted_indizes][:topk] + hybrid_ids.append(ids.tolist()) + hybrid_scores.append(distances.tolist()) + + return [ + np.append(self.dictionary[ind], score) + for ind, score in zip(hybrid_ids, hybrid_scores) + ] + # use only dense embeds + else: + dense_distances, dense_ids = self.dense_dictionary_index.search( + x=mention_dense_embeds, k=topk + ) + return [ + np.append(self.dictionary[ind], score) + for ind, score in zip(dense_ids, dense_distances) + ] + + def embed_dictionary( + self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int + ): + # check for embedded dictionary in cache + dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] + cache_folder = os.path.join(flair.cache_root, "datasets") + file_name = f"bio_nen_{model_name_or_path.split('/')[-1]}_{dictionary_name}" + cached_dictionary_path = os.path.join( + cache_folder, + f"{file_name}.pk", + ) + self.load_dictionary(dictionary_name_or_path) + + # If exists, load the cached dictionary indices + if os.path.exists(cached_dictionary_path): + self.load_cached_dictionary(cached_dictionary_path) + + # else, load and embed + else: + + # get names from dictionary and remove punctuation + punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(punctuation))) + dictionary_names = [] + for row in self.dictionary: + name = punctuation_regex.split(row[0]) + name = " ".join(name).strip().lower() + dictionary_names.append(name) + + # create dense and sparse embeddings + if self.use_sparse_embeds: + self.dict_dense_embeds = self.embed_dense( + names=dictionary_names, + batch_size=batch_size, + show_progress=True + ) + self.dict_sparse_embeds = self.embed_sparse(names=dictionary_names) + + # create only dense embeddings + else: + self.dict_dense_embeds = self.embed_dense( + names=dictionary_names, show_progress=True + ) + self.dict_sparse_embeds = None + + # build dense index + dimension = self.dict_dense_embeds.shape[1] + if self.use_cosine: + # to use cosine similarity, we normalize the vectors and then use inner product + faiss.normalize_L2(self.dict_dense_embeds) + + self.dense_dictionary_index = faiss.IndexFlatIP(dimension) + if self.index_use_cuda: + self.dense_dictionary_index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), 0, self.dense_dictionary_index + ) + self.dense_dictionary_index.add(self.dict_dense_embeds) + + if self.index_use_cuda: + # create a cpu version of the index to cache (IO does not work with gpu index) + cache_index_dense = faiss.index_gpu_to_cpu(self.dense_dictionary_index) + # cache dictionary + self.cache_dictionary( + cache_index_dense, cache_folder, cached_dictionary_path + ) + else: + self.cache_dictionary( + self.dense_dictionary_index, cache_folder, cached_dictionary_path + ) + + def load_cached_dictionary(self, cached_dictionary_path: str): + with open(cached_dictionary_path, "rb") as cached_file: + cached_dictionary = pickle.load(cached_file) + log.info( + "Loaded dictionary from cached file {}".format(cached_dictionary_path) + ) + + (self.dictionary, self.dict_sparse_embeds, self.dense_dictionary_index,) = ( + cached_dictionary["dictionary"], + cached_dictionary["sparse_dictionary_embeds"], + cached_dictionary["dense_dictionary_index"], + ) + if self.index_use_cuda: + self.dense_dictionary_index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), 0, self.dense_dictionary_index + ) + + def 
load_dictionary(self, dictionary_name_or_path: str): + # use provided dictionary + if dictionary_name_or_path == "ctd-disease": + self.dictionary = NEL_CTD_DISEASE_DICT().data + elif dictionary_name_or_path == "ctd-chemical": + self.dictionary = NEL_CTD_CHEMICAL_DICT().data + elif dictionary_name_or_path == "ncbi-gene": + self.dictionary = NEL_NCBI_HUMAN_GENE_DICT().data + elif dictionary_name_or_path == "ncbi-taxonomy": + self.dictionary = NEL_NCBI_TAXONOMY_DICT().data + # use custom dictionary file + else: + self.dictionary = DictionaryDataset( + dictionary_path=dictionary_name_or_path + ).data + + def cache_dictionary(self, cache_index_dense, cache_folder, cached_dictionary_path): + cached_dictionary = { + "dictionary": self.dictionary, + "sparse_dictionary_embeds": self.dict_sparse_embeds, + "dense_dictionary_index": cache_index_dense, + } + if not os.path.exists(cache_folder): + os.mkdir(cache_folder) + with open(cached_dictionary_path, "wb") as cache_file: + pickle.dump(cached_dictionary, cache_file) + print("Saving dictionary into cached file {}".format(cache_folder)) + + +class ExactStringMatchEntityLinker: + def __init__(self): + self.dictionary = None + + def load_model( + self, + dictionary_name_or_path: str, + ): + # use provided dictionary + if dictionary_name_or_path == "ctd-disease": + dictionary_data = NEL_CTD_DISEASE_DICT().data + elif dictionary_name_or_path == "ctd-chemical": + dictionary_data = NEL_CTD_CHEMICAL_DICT().data + elif dictionary_name_or_path == "ncbi-gene": + dictionary_data = NEL_NCBI_HUMAN_GENE_DICT().data + elif dictionary_name_or_path == "ncbi-taxonomy": + dictionary_data = NEL_NCBI_TAXONOMY_DICT().data + # use custom dictionary file + else: + dictionary_data = DictionaryDataset( + dictionary_path=dictionary_name_or_path + ).data + + # make dictionary from array of tuples (entity, entity_id) + self.dictionary = {name: cui for name, cui in dictionary_data} + + def get_predictions(self, mention: str, topk) -> np.ndarray: + if mention in self.dictionary: + return np.array([(mention, self.dictionary[mention], 1.0)]) + else: + return [] + + +class MultiBiEncoderEntityLinker: + """ + Biomedical Entity Linker for HunFlair + Can predict top k entities on sentences annotated with biomedical entity mentions + """ + + def __init__( + self, + models: List[BiEncoderEntityLiker], + ab3p=None, + ) -> None: + """ + Initalize class, called by classmethod load + :param models: list of objects containing the dense and sparse encoders + :param ab3p_path: path to ab3p model + """ + self.models = models + self.text_preprocessor = TextPreprocess() + self.ab3p = ab3p + + @classmethod + def load( + cls, + model_names: Union[List[str], str], + dictionary_names_or_paths: Union[str, Path, List[str], List[Path]] = None, + use_sparse_and_dense_embeds: bool = True, + max_length=25, + batch_size=1024, + index_use_cuda=False, + use_cosine: bool = True, + use_ab3p: bool = True, + ab3p_path: Path = None, + ): + """ + Load a model for biomedical named entity normalization on sentences annotated with + biomedical entity mentions + :param model_names: List of names of pretrained models to use. 
Possible values for pretrained models are:
+        chemical, disease, gene, sapbert-bc5cdr-disease, sapbert-ncbi-disease, sapbert-bc5cdr-chemical, biobert-bc5cdr-disease,
+        biobert-ncbi-disease, biobert-bc5cdr-chemical, biosyn-biobert-bc2gn, biosyn-sapbert-bc2gn, sapbert, exact-string-match
+        :param dictionary_names_or_paths: Name of one of the provided dictionaries listing all possible ids and their synonyms
+        or a path to a dictionary file with each line in the format: id||name, with one line for each name of a concept.
+        Possible values for dictionaries are: chemical, ctd-chemical, disease, ctd-disease, gene, ncbi-gene,
+        taxonomy and ncbi-taxonomy
+        :param use_sparse_and_dense_embeds: If True, uses a combination of sparse and dense embeddings for the dictionary and the mentions
+        If False, uses only dense embeddings
+        :param batch_size: Batch size for the dense encoder
+        :param index_use_cuda: If True, uses GPU for the dense encoder
+        :param use_cosine: If True, uses cosine similarity for the dense encoder. If False, uses inner product
+        :param use_ab3p: If True, uses ab3p to resolve abbreviations
+        :param ab3p_path: Optional: path to ab3p on your machine
+        """
+        # validate input: check that amount of models and dictionaries match
+        if isinstance(model_names, str):
+            model_names = [model_names]
+
+        # case one dictionary for all models
+        if isinstance(dictionary_names_or_paths, str) or isinstance(
+            dictionary_names_or_paths, Path
+        ):
+            dictionary_names_or_paths = [dictionary_names_or_paths] * len(model_names)
+        # case no dictionary provided
+        elif dictionary_names_or_paths is None:
+            dictionary_names_or_paths = [None] * len(model_names)
+        # case one model, multiple dictionaries
+        elif len(model_names) == 1:
+            model_names = model_names * len(dictionary_names_or_paths)
+        # case mismatching amount of models and dictionaries
+        elif len(model_names) != len(dictionary_names_or_paths):
+            raise ValueError(
+                "Amount of models and dictionaries must match.
Got {} models and {} dictionaries".format( + len(model_names), len(dictionary_names_or_paths) + ) + ) + assert len(model_names) == len(dictionary_names_or_paths) + + models = [] + + for model_name, dictionary_name_or_path in zip( + model_names, dictionary_names_or_paths + ): + # get the paths for the model and dictionary + model_path = cls.__get_model_path(model_name, use_sparse_and_dense_embeds) + dictionary_path = cls.__get_dictionary_path( + dictionary_name_or_path, model_name=model_name + ) + + if model_path == "exact-string-match": + model = ExactStringMatchEntityLinker() + model.load_model(dictionary_path) + + else: + model = BiEncoderEntityLiker( + use_sparse_embeds=use_sparse_and_dense_embeds, + max_length=max_length, + index_use_cuda=index_use_cuda, + ) + + model.load_model( + model_name_or_path=model_path, + dictionary_name_or_path=dictionary_path, + batch_size=batch_size, + use_cosine=use_cosine, + ) + + models.append(model) + + # load ab3p model + ab3p = Ab3P.load(ab3p_path) if use_ab3p else None + + return cls(models, ab3p) + + @staticmethod + def __get_model_path( + model_name: str, use_sparse_and_dense_embeds + ) -> BiEncoderEntityLiker: + model_name = model_name.lower() + model_path = model_name + + # if a provided model is used, + # modify model name to huggingface path + + if model_name in [ + "sapbert-bc5cdr-disease", + "sapbert-ncbi-disease", + "sapbert-bc5cdr-chemical", + "biobert-bc5cdr-disease", + "biobert-ncbi-disease", + "biobert-bc5cdr-chemical", + "biosyn-biobert-bc2gn", + "biosyn-sapbert-bc2gn", + ]: + model_path = "dmis-lab/biosyn-" + model_name + elif model_name == "sapbert": + model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" + elif model_name == "exact-string-match": + model_path = "exact-string-match" + elif use_sparse_and_dense_embeds: + if model_name == "disease": + model_path = "dmis-lab/biosyn-sapbert-bc5cdr-disease" + elif model_name == "chemical": + model_path = "dmis-lab/biosyn-sapbert-bc5cdr-chemical" + elif model_name == "gene": + model_path = "dmis-lab/biosyn-sapbert-bc2gn" + else: + if model_name == "disease": + model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" + elif model_name == "chemical": + model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" + elif model_name == "gene": + raise ValueError( + "No trained model for gene entity linking using only dense embeddings." + ) + return model_path + + @staticmethod + def __get_dictionary_path(dictionary_path: str, model_name: str): + # determine dictionary to use + if dictionary_path == "disease": + dictionary_path = "ctd-disease" + if dictionary_path == "chemical": + dictionary_path = "ctd-chemical" + if dictionary_path == "gene": + dictionary_path = "ncbi-gene" + if dictionary_path == "taxonomy": + dictionary_path = "ncbi-taxonomy" + if dictionary_path is None: + # disease + if model_name in [ + "sapbert-bc5cdr-disease", + "sapbert-ncbi-disease", + "biobert-bc5cdr-disease", + "biobert-ncbi-disease", + "disease", + ]: + dictionary_path = "ctd-disease" + # chemical + elif model_name in [ + "sapbert-bc5cdr-chemical", + "biobert-bc5cdr-chemical", + "chemical", + ]: + dictionary_path = "ctd-chemical" + # gene + elif model_name in ["gene", "biosyn-biobert-bc2gn", "biosyn-sapbert-bc2gn"]: + dictionary_path = "ncbi-gene" + # error + else: + log.error( + """When using a custom model you need to specify a dictionary. + Available options are: 'disease', 'chemical', 'gene' and 'taxonomy'. 
+ Or provide a path to a dictionary file.""" + ) + raise ValueError("Invalid dictionary") + + return dictionary_path + + def predict( + self, + sentences: Union[List[Sentence], Sentence], + input_entity_annotation_layer: str = None, + topk: int = 1, + ) -> None: + """ + On one or more sentences, predict the cui on all named entites annotated with a tag of type input_entity_annotation_layer. + Annotates the top k predictions. + :param sentences: one or more sentences to run the predictions on + :param input_entity_annotation_layer: only entities with in this annotation layer will be annotated + :param topk: number of predicted cui candidates to add to annotation + :param abbreviation_dict: dictionary with abbreviations and their expanded form or a boolean value indicating whether + abbreviations should be expanded using Ab3P + """ + # make sure sentences is a list of sentences + if not isinstance(sentences, list): + sentences = [sentences] + + # use Ab3P to build abbreviation dictionary + if self.ab3p is not None: + abbreviation_dict = self.ab3p.build_abbreviation_dict(sentences) + + for model in self.models: + + for sentence in sentences: + for entity in sentence.get_labels(input_entity_annotation_layer): + # preprocess mention + if abbreviation_dict is not None: + parsed_tokens = [] + for token in entity.data_point.tokens: + token = self.text_preprocessor.run(token.text) + if token in abbreviation_dict: + parsed_tokens.append(abbreviation_dict[token.lower()]) + elif len(token) != 0: + parsed_tokens.append(token) + mention = " ".join(parsed_tokens) + else: + mention = self.text_preprocessor.run(entity.span.text) + + # get predictions from dictionary + predictions = model.get_predictions(mention, topk) + + # add predictions to entity + label_name = ( + (input_entity_annotation_layer + "_nen") + if (input_entity_annotation_layer is not None) + else "nen" + ) + for prediction in predictions: + # if concept unique id is made up of mulitple ids, seperated by '|' + # seperate it into cui and additional_labels + cui = prediction[1] + if "|" in cui: + labels = cui.split("|") + cui = labels[0] + additional_labels = labels[1:] + else: + additional_labels = None + # determine database: + if ":" in cui: + cui_parts = cui.split(":") + database = ":".join(cui_parts[0:-1]) + cui = cui_parts[-1] + else: + database = None + sentence.add_label( + typename=label_name, + value_or_label=EntityLinkingLabel( + data_point=entity.data_point, + id=cui, + concept_name=prediction[0], + additional_ids=additional_labels, + database=database, + score=prediction[2].astype(float), + ), + ) diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md new file mode 100644 index 000000000..9af93fb6e --- /dev/null +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -0,0 +1,34 @@ +# HunFlair Tutorial 3: Entity Linking + +After adding Named Entity Recognition tags to your sentence, you can run Named Entity Linking on these annotations. 
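+
+Note: the linker class used below is defined in `flair/models/biomedical_entity_linking.py` in this patch. Since `flair/models/__init__.py` is not changed here, `MultiBiEncoderEntityLinker` may not be re-exported from `flair.models`; in that case, import it from the module directly before running the example below:
+
+```python
+from flair.models.biomedical_entity_linking import MultiBiEncoderEntityLinker
+```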
+```python +from flair.models import MultiTagger +from flair.tokenization import SciSpacyTokenizer +from flair.data import Sentence + +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer()) + +ner_tagger = MultiTagger.load("hunflair-disease") +ner_tagger.predict(sentence) + +nen_tagger = MultiBiEncoderEntityLinker.load("disease") +nen_tagger.predict(sentence) + +for tag in sentence.get_labels(): + print(tag) +``` +This should print: +~~~ +Disease [Behavioral abnormalities (1,2)] (0.6736) +Disease [Fragile X Syndrome (10,11,12)] (0.99) +MESH:D001523 (DO:DOID:150) behavior disorders (0.98) +MESH:D005600 (DO:DOID:14261, OMIM:300624, OMIM:309548) fragile x syndrome (1.1) +~~~ +This output contains the NER disease annotations and it's entity linking annotations with ids from (often more than one) database. +We have preconfigured combinations of models and dictionaries for "disease", "chemical" and "gene". You can also provide your own model and dictionary: + +```python +nen_tagger = MultiBiEncoderEntityLinker.load("name_or_path_to_your_model", dictionary_names_or_paths="name_or_path_to_your_dictionary") +nen_tagger = MultiBiEncoderEntityLinker.load("path_to_custom_disease_model", dictionary_names_or_paths="disease") +```` +You can use any combination of provided models, provided dictionaries and your own. From 22bc93aa62de630f2953683adbdfd90c6e5f5069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Tue, 14 Mar 2023 15:50:28 +0100 Subject: [PATCH 02/58] Revise mention text pre-processing: define general interface and adapt basic text and Ab3P pre-processing to the new structure; fix bug in Ab3P abbreviation detection --- flair/data.py | 5 +- flair/models/biomedical_entity_linking.py | 367 ++++++++++++++-------- 2 files changed, 240 insertions(+), 132 deletions(-) diff --git a/flair/data.py b/flair/data.py index ce4d8dabf..4928b0f57 100644 --- a/flair/data.py +++ b/flair/data.py @@ -461,9 +461,10 @@ def spawn(self, value: str, score: float = 1.0): def __str__(self): if self.additional_ids is None: - return f"{self.database}:{self._value} {self.concept_name} ({round(self._score, 2)})" + return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ + f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" else: - return f"{self.database}:{self._value} ({', '.join(self.additional_ids)}) {self.concept_name} ({round(self._score, 2)})" + return f" ({', '.join(self.additional_ids)}) {self.concept_name} ({round(self._score, 2)})" def __repr__(self): if self.additional_ids is None: diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index e79b63372..2cae48059 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -1,145 +1,198 @@ -import os -import stat -import re -import pickle +import flair +import faiss import logging import numpy as np +import os +import pickle +import re +import subprocess +import stat +import string +import tempfile import torch -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -from tqdm import tqdm -from transformers import ( - PreTrainedTokenizer, - PreTrainedModel, -) -from huggingface_hub import hf_hub_url, cached_download -from string import punctuation -import flair -from flair.data import Sentence, EntityLinkingLabel + +from collections import defaultdict +from flair.data import Sentence, 
EntityLinkingLabel, DataPoint, Label from flair.datasets import ( NEL_CTD_CHEMICAL_DICT, NEL_CTD_DISEASE_DICT, NEL_NCBI_HUMAN_GENE_DICT, NEL_NCBI_TAXONOMY_DICT, ) +from flair.embeddings import TransformerDocumentEmbeddings from flair.file_utils import cached_path -from typing import List, Tuple, Union +from huggingface_hub import hf_hub_url, cached_download from pathlib import Path -import subprocess -import tempfile -import faiss -from flair.embeddings import TransformerDocumentEmbeddings -from string import punctuation +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm +from transformers import ( + PreTrainedTokenizer, + PreTrainedModel +) +from typing import List, Tuple, Union, Optional, Dict log = logging.getLogger("flair") -class BigramTfIDFVectorizer: +class MentionPreprocessor: """ - Class to encode a list of mentions into a sparse tensor. + A mention preprocessor is used to transform / clean an entity mention (recognized by + an entity recognition model in the original text). This can include removing certain characters + (e.g. punctuation) or converting characters (e.g. HTML-encoded characters) as well as + (more sophisticated) domain-specific procedures. - Slightly modified from Sung et al. 2020 - Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 + This class provides the basic interface for such transformations and should be extended by + subclasses that implement the concrete transformations. """ + def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> str: + """ + Process the given entity mention and applies the transformation procedure to it. - def __init__(self) -> None: - self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) - - def transform(self, mentions: list) -> torch.Tensor: - vec = self.encoder.transform(mentions).toarray() - vec = torch.FloatTensor(vec) - return vec - - def __call__(self, mentions: list): - return self.transform(mentions) - - def save_encoder(self, path: Path): - with open(path, "wb") as fout: - pickle.dump(self.encoder, fout) - log.info("Sparse encoder saved in {}".format(path)) - - @classmethod - def load(cls, path: Path): - newVectorizer = cls() - with open(path, "rb") as fin: - newVectorizer.encoder = pickle.load(fin) - log.info("Sparse encoder loaded from {}".format(path)) + :param entity_mention: entity mention either given as DataPoint or str + :param sentence: sentence in which the entity mentioned occurred + """ + raise NotImplementedError() - return newVectorizer + def initialize(self, sentences: List[Sentence]) -> None: + """ + Initializes the pre-processor for a batch of sentences, which is may be necessary for + more sophisticated transformations. + """ + # Do nothing by default + pass -class TextPreprocess: +class BasicMentionPreprocessor(MentionPreprocessor): """ - Text Preprocess module - Support lowercase, removing punctuation, typo correction + Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction + and removing of punctuation characters. - Slightly modifed from Sung et al. 2020 - Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + Implementation is adapted from: + Slightly modifed from Sung et al. 
2020 + Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 """ def __init__( self, lowercase: bool = True, remove_punctuation: bool = True, + punctuation_symbols: str = string.punctuation ) -> None: """ - :param typo_path str: path of known typo dictionary + :param lowercase: Indicates whether to perform lowercasing or not (True by default) + :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) + :param punctuation_symbols: String containing all punctuation symbols that should be removed + (default is given by string.punctuation) """ self.lowercase = lowercase - self.rmv_puncts = remove_punctuation - self.punctuation = punctuation + self.remove_punctuation = remove_punctuation self.rmv_puncts_regex = re.compile( - r"[\s{}]+".format(re.escape(self.punctuation)) + r"[\s{}]+".format(re.escape(punctuation_symbols)) ) - def remove_punctuation(self, phrase: str) -> str: - phrase = self.rmv_puncts_regex.split(phrase) - phrase = " ".join(phrase).strip() - - return phrase + def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> str: + mention_text = entity_mention if isinstance(entity_mention, str) else entity_mention.text - def run(self, text: str) -> str: if self.lowercase: - text = text.lower() + mention_text = mention_text.lower() - if self.rmv_puncts: - text = self.remove_punctuation(text) + if self.remove_punctuation: + mention_text = self.rmv_puncts_regex.split(mention_text) + mention_text = " ".join(mention_text).strip() - text = text.strip() + mention_text = mention_text.strip() - return text + return mention_text -class Ab3P: +class Ab3PMentionPreprocessor(MentionPreprocessor): """ - Module for the Abbreviation Resolver Ab3P - https://github.com/ncbi-nlp/Ab3P + Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, + given in: + https://github.com/ncbi-nlp/Ab3P + + Ab3P applies a set of rules reflecting simple patterns such as Alpha Beta (AB) as well as more involved cases. + The algorithm is described in detail in the following paper: + + Abbreviation definition identification based on automatic precision estimates. + Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. + PubMed ID: 18817555 """ - def __init__(self, ab3p_path: Path, word_data_dir: Path) -> None: + def __init__( + self, + ab3p_path: Path, + word_data_dir: Path, + mention_preprocessor: Optional[MentionPreprocessor] = None + ) -> None: + """ + :param ab3p_path: Path to the folder containing the Ab3P implementation + :param word_data_dir: Path to the word data directory + :param mention_preprocessor: Mention text preprocessor that is used before trying to link + the mention text to an abbreviation. + + """ self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir + self.mention_preprocessor = mention_preprocessor + + def initialize(self, sentences: List[Sentence]) -> None: + self.abbreviation_dict = self._build_abbreviation_dict(sentences) + + def process(self, entity_mention: Union[Label, str], sentence: Sentence) -> str: + sentence_text = sentence.to_tokenized_string().strip() + + tokens = ( + [token.text for token in entity_mention.data_point.tokens] + if isinstance(entity_mention, Label) + else [entity_mention] # FIXME: Maybe split mention on whitespaces here?? 
+ ) + + parsed_tokens = [] + for token in tokens: + if self.mention_preprocessor is not None: + token = self.mention_preprocessor.process(token, sentence) + + if sentence_text in self.abbreviation_dict: + if token.lower() in self.abbreviation_dict[sentence_text]: + parsed_tokens.append(self.abbreviation_dict[sentence_text][token.lower()]) + continue + + if len(token) != 0: + parsed_tokens.append(token) + + return " ".join(parsed_tokens) @classmethod - def load(cls, ab3p_path: Path = None): - data_dir = os.path.join(flair.cache_root, "ab3p") - if not os.path.exists(data_dir): - os.mkdir(os.path.join(data_dir)) - word_data_dir = os.path.join(data_dir, "word_data/") - if not os.path.exists(word_data_dir): - os.mkdir(word_data_dir) + def load( + cls, + ab3p_path: Path = None, + mention_preprocessor: Optional[MentionPreprocessor] = None + ): + data_dir = flair.cache_root / "ab3p" + if not data_dir.exists(): + data_dir.mkdir(parents=True) + + word_data_dir = data_dir / "word_data" + if not word_data_dir.exists(): + word_data_dir.mkdir() + if ab3p_path is None: ab3p_path = cls.download_ab3p(data_dir, word_data_dir) - return cls(ab3p_path, word_data_dir) + + return cls(ab3p_path, word_data_dir, mention_preprocessor) @classmethod def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: - # download word data for Ab3P if not already downloaded - ab3p_url = ( - "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" - ) + """ + Downloads the Ab3P tool and all necessary data files. + """ + + # Download word data for Ab3P if not already downloaded + ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" + ab3p_files = [ "Ab3P_prec.dat", "Lf1chSf", @@ -160,25 +213,42 @@ def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: "stop", ] for file in ab3p_files: - data_path = cached_path(ab3p_url + file, word_data_dir) - # download ab3p executable + cached_path(ab3p_url + file, word_data_dir) + + # Download Ab3P executable ab3p_path = cached_path( "https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir ) - os.chmod(ab3p_path, stat.S_IEXEC) + + ab3p_path.chmod(ab3p_path.stat().st_mode | stat.S_IXUSR) return ab3p_path - def build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> dict: - abbreviation_dict = {} - # tempfile to store the data to pass to the ab3p executable + def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: + """ + Processes the given sentences with the Ab3P tool. 
The function returns a dictionary + containing the abbreviations found for each sentence, e.g.: + + { + "Respiratory syncytial viruses ( RSV ) are a subgroup of the paramyxoviruses.": + {"RSV": "Respiratory syncytial viruses"}, + "Rous sarcoma virus ( RSV ) is a retrovirus.": + {"RSV": "Rous sarcoma virus"} + } + + """ + abbreviation_dict = defaultdict(dict) + + # Create a temp file which holds the sentences we want to process with ab3p with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: for sentence in sentences: temp_file.write(sentence.to_tokenized_string() + "\n") temp_file.flush() - # temporarily create path file in the current working directory for Ab3P + + # Temporarily create path file in the current working directory for Ab3P with open(os.path.join(os.getcwd(), "path_Ab3P"), "w") as path_file: - path_file.write(self.word_data_dir + "\n") - # run ab3p with the temp file containing the dataset + path_file.write(str(self.word_data_dir) + "/\n") + + # Run ab3p with the temp file containing the dataset try: result = subprocess.run( [self.ab3p_path, temp_file.name], @@ -206,12 +276,22 @@ def build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> dict: ) lines = line.split("\n") + cur_sentence = None for line in lines: if len(line.split("|")) == 3: + if cur_sentence is None: + continue + sf, lf, _ = line.split("|") sf = sf.strip().lower() lf = lf.strip().lower() - abbreviation_dict[sf] = lf + abbreviation_dict[cur_sentence][sf] = lf + + elif len(line.strip()) > 0: + cur_sentence = line + else: + cur_sentence = None + finally: # remove the path file os.remove(os.path.join(os.getcwd(), "path_Ab3P")) @@ -219,6 +299,41 @@ def build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> dict: return abbreviation_dict +class BigramTfIDFVectorizer: + """ + Class to encode a list of mentions into a sparse tensor. + + Slightly modified from Sung et al. 2020 + Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 + """ + + def __init__(self) -> None: + self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) + + def transform(self, mentions: list) -> torch.Tensor: + vec = self.encoder.transform(mentions).toarray() + vec = torch.FloatTensor(vec) + return vec + + def __call__(self, mentions: list): + return self.transform(mentions) + + def save_encoder(self, path: Path): + with open(path, "wb") as fout: + pickle.dump(self.encoder, fout) + log.info("Sparse encoder saved in {}".format(path)) + + @classmethod + def load(cls, path: Path): + newVectorizer = cls() + with open(path, "rb") as fin: + newVectorizer.encoder = pickle.load(fin) + log.info("Sparse encoder loaded from {}".format(path)) + + return newVectorizer + + class DictionaryDataset: """ A class used to load dictionary data from a custom dictionary file. 
@@ -273,15 +388,15 @@ def get_data(self, dictionary_path): yield (name, cui) -class BiEncoderEntityLiker: - """ - A class to load a model and use it to encode a dictionary and entities - """ +class BiEncoderEntityLinker: def __init__( - self, use_sparse_embeds: bool, max_length: int, index_use_cuda: bool + self, + use_sparse_embeddings: bool, + max_length: int, + index_use_cuda: bool ) -> None: - self.use_sparse_embeds = use_sparse_embeds + self.use_sparse_embeds = use_sparse_embeddings self.max_length = max_length self.tokenizer = None @@ -311,7 +426,7 @@ def load_model( :param batch_size: The batch size for embedding the dictionary """ self.load_dense_encoder(model_name_or_path) - self.use_cosine = True + self.use_cosine = use_cosine if self.use_sparse_embeds: self.load_sparse_encoder(model_name_or_path) @@ -602,7 +717,7 @@ def embed_dictionary( else: # get names from dictionary and remove punctuation - punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(punctuation))) + punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) dictionary_names = [] for row in self.dictionary: name = punctuation_regex.split(row[0]) @@ -737,8 +852,8 @@ class MultiBiEncoderEntityLinker: def __init__( self, - models: List[BiEncoderEntityLiker], - ab3p=None, + models: List[BiEncoderEntityLinker], + preprocessor: Optional[MentionPreprocessor] = None, ) -> None: """ Initalize class, called by classmethod load @@ -746,8 +861,7 @@ def __init__( :param ab3p_path: path to ab3p model """ self.models = models - self.text_preprocessor = TextPreprocess() - self.ab3p = ab3p + self.preprocessor = preprocessor @classmethod def load( @@ -820,8 +934,8 @@ def load( model.load_model(dictionary_path) else: - model = BiEncoderEntityLiker( - use_sparse_embeds=use_sparse_and_dense_embeds, + model = BiEncoderEntityLinker( + use_sparse_embeddings=use_sparse_and_dense_embeds, max_length=max_length, index_use_cuda=index_use_cuda, ) @@ -836,14 +950,16 @@ def load( models.append(model) # load ab3p model - ab3p = Ab3P.load(ab3p_path) if use_ab3p else None + preprocessor = Ab3PMentionPreprocessor.load( + mention_preprocessor=BasicMentionPreprocessor() + ) #Ab3P.load(ab3p_path) if use_ab3p else None - return cls(models, ab3p) + return cls(models, preprocessor) @staticmethod def __get_model_path( model_name: str, use_sparse_and_dense_embeds - ) -> BiEncoderEntityLiker: + ) -> BiEncoderEntityLinker: model_name = model_name.lower() model_path = model_name @@ -944,29 +1060,20 @@ def predict( if not isinstance(sentences, list): sentences = [sentences] - # use Ab3P to build abbreviation dictionary - if self.ab3p is not None: - abbreviation_dict = self.ab3p.build_abbreviation_dict(sentences) + if self.preprocessor is not None: + self.preprocessor.initialize(sentences) for model in self.models: - for sentence in sentences: for entity in sentence.get_labels(input_entity_annotation_layer): - # preprocess mention - if abbreviation_dict is not None: - parsed_tokens = [] - for token in entity.data_point.tokens: - token = self.text_preprocessor.run(token.text) - if token in abbreviation_dict: - parsed_tokens.append(abbreviation_dict[token.lower()]) - elif len(token) != 0: - parsed_tokens.append(token) - mention = " ".join(parsed_tokens) - else: - mention = self.text_preprocessor.run(entity.span.text) + mention_text = ( + self.preprocessor.process(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text + ) # get predictions from dictionary - predictions = model.get_predictions(mention, topk) + 
predictions = model.get_predictions(mention_text, topk) # add predictions to entity label_name = ( From acf5fb6318772340a0b0a0fbbaea7e0728a35176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Wed, 15 Mar 2023 15:23:24 +0100 Subject: [PATCH 03/58] Refactor entity linking model structure --- flair/models/biomedical_entity_linking.py | 1002 ++++++++++----------- 1 file changed, 489 insertions(+), 513 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 2cae48059..ab8f6f122 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -26,10 +26,6 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm -from transformers import ( - PreTrainedTokenizer, - PreTrainedModel -) from typing import List, Tuple, Union, Optional, Dict log = logging.getLogger("flair") @@ -37,27 +33,30 @@ class MentionPreprocessor: """ - A mention preprocessor is used to transform / clean an entity mention (recognized by - an entity recognition model in the original text). This can include removing certain characters - (e.g. punctuation) or converting characters (e.g. HTML-encoded characters) as well as - (more sophisticated) domain-specific procedures. + A mention preprocessor is used to transform / clean an entity mention (recognized by + an entity recognition model in the original text). This can include removing certain characters + (e.g. punctuation) or converting characters (e.g. HTML-encoded characters) as well as + (more sophisticated) domain-specific procedures. - This class provides the basic interface for such transformations and should be extended by - subclasses that implement the concrete transformations. + This class provides the basic interface for such transformations and should be extended by + subclasses that implement the concrete transformations. """ def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> str: """ - Process the given entity mention and applies the transformation procedure to it. + Processes the given entity mention and applies the transformation procedure to it. - :param entity_mention: entity mention either given as DataPoint or str - :param sentence: sentence in which the entity mentioned occurred + :param entity_mention: entity mention either given as DataPoint or str + :param sentence: sentence in which the entity mentioned occurred + :result: Cleaned / transformed string representation of the given entity mention """ raise NotImplementedError() def initialize(self, sentences: List[Sentence]) -> None: """ - Initializes the pre-processor for a batch of sentences, which is may be necessary for - more sophisticated transformations. + Initializes the pre-processor for a batch of sentences, which is may be necessary for + more sophisticated transformations. + + :param sentences: List of sentences that will be processed. """ # Do nothing by default pass @@ -65,13 +64,12 @@ def initialize(self, sentences: List[Sentence]) -> None: class BasicMentionPreprocessor(MentionPreprocessor): """ - Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction - and removing of punctuation characters. + Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction + and removing of punctuation characters. - Implementation is adapted from: - Slightly modifed from Sung et al. 
2020 - Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + Implementation is adapted from: + Sung et al. 2020, Biomedical Entity Representations with Synonym Marginalization + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 """ def __init__( @@ -81,10 +79,12 @@ def __init__( punctuation_symbols: str = string.punctuation ) -> None: """ - :param lowercase: Indicates whether to perform lowercasing or not (True by default) - :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) - :param punctuation_symbols: String containing all punctuation symbols that should be removed - (default is given by string.punctuation) + Initializes the mention preprocessor. + + :param lowercase: Indicates whether to perform lowercasing or not (True by default) + :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) + :param punctuation_symbols: String containing all punctuation symbols that should be removed + (default is given by string.punctuation) """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation @@ -109,16 +109,16 @@ def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> class Ab3PMentionPreprocessor(MentionPreprocessor): """ - Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, - given in: - https://github.com/ncbi-nlp/Ab3P + Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, + given in: + https://github.com/ncbi-nlp/Ab3P - Ab3P applies a set of rules reflecting simple patterns such as Alpha Beta (AB) as well as more involved cases. - The algorithm is described in detail in the following paper: + Ab3P applies a set of rules reflecting simple patterns such as Alpha Beta (AB) as well as more involved cases. + The algorithm is described in detail in the following paper: - Abbreviation definition identification based on automatic precision estimates. - Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. - PubMed ID: 18817555 + Abbreviation definition identification based on automatic precision estimates. + Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. + PubMed ID: 18817555 """ def __init__( @@ -128,10 +128,12 @@ def __init__( mention_preprocessor: Optional[MentionPreprocessor] = None ) -> None: """ - :param ab3p_path: Path to the folder containing the Ab3P implementation - :param word_data_dir: Path to the word data directory - :param mention_preprocessor: Mention text preprocessor that is used before trying to link - the mention text to an abbreviation. + Creates the mention pre-processor + + :param ab3p_path: Path to the folder containing the Ab3P implementation + :param word_data_dir: Path to the word data directory + :param mention_preprocessor: Mention text preprocessor that is used before trying to link + the mention text to an abbreviation. 
""" self.ab3p_path = ab3p_path @@ -169,7 +171,7 @@ def process(self, entity_mention: Union[Label, str], sentence: Sentence) -> str: def load( cls, ab3p_path: Path = None, - mention_preprocessor: Optional[MentionPreprocessor] = None + preprocessor: Optional[MentionPreprocessor] = None ): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): @@ -182,7 +184,7 @@ def load( if ab3p_path is None: ab3p_path = cls.download_ab3p(data_dir, word_data_dir) - return cls(ab3p_path, word_data_dir, mention_preprocessor) + return cls(ab3p_path, word_data_dir, preprocessor) @classmethod def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: @@ -225,16 +227,15 @@ def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: """ - Processes the given sentences with the Ab3P tool. The function returns a dictionary - containing the abbreviations found for each sentence, e.g.: - - { - "Respiratory syncytial viruses ( RSV ) are a subgroup of the paramyxoviruses.": - {"RSV": "Respiratory syncytial viruses"}, - "Rous sarcoma virus ( RSV ) is a retrovirus.": - {"RSV": "Rous sarcoma virus"} - } - + Processes the given sentences with the Ab3P tool. The function returns a (nested) dictionary + containing the abbreviations found for each sentence, e.g.: + + { + "Respiratory syncytial viruses ( RSV ) are a subgroup of the paramyxoviruses.": + {"RSV": "Respiratory syncytial viruses"}, + "Rous sarcoma virus ( RSV ) is a retrovirus.": + {"RSV": "Rous sarcoma virus"} + } """ abbreviation_dict = defaultdict(dict) @@ -301,31 +302,31 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict class BigramTfIDFVectorizer: """ - Class to encode a list of mentions into a sparse tensor. + Helper class to encode a list of entity mentions or dictionary entries into a sparse tensor. - Slightly modified from Sung et al. 
2020 - Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 + Implementation adapted from: + Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 + https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 """ def __init__(self) -> None: self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) - def transform(self, mentions: list) -> torch.Tensor: + def transform(self, mentions: List[str]) -> torch.Tensor: vec = self.encoder.transform(mentions).toarray() vec = torch.FloatTensor(vec) return vec - def __call__(self, mentions: list): + def __call__(self, mentions: List[str]) -> torch.Tensor: return self.transform(mentions) - def save_encoder(self, path: Path): - with open(path, "wb") as fout: + def save_encoder(self, path: Path) -> None: + with path.open("wb") as fout: pickle.dump(self.encoder, fout) log.info("Sparse encoder saved in {}".format(path)) @classmethod - def load(cls, path: Path): + def load(cls, path: Path) -> "BigramTfIDFVectorizer": newVectorizer = cls() with open(path, "rb") as fin: newVectorizer.encoder = pickle.load(fin) @@ -348,7 +349,9 @@ class DictionaryDataset: """ def __init__( - self, dictionary_path: Union[Path, str], load_into_memory: True + self, + dictionary_path: Union[Path, str], + load_into_memory: bool = True ) -> None: """ :param dictionary_path str: The path of the dictionary @@ -359,7 +362,7 @@ def __init__( else: self.data = self.get_data(dictionary_path) - def load_data(self, dictionary_path) -> np.ndarray: + def load_data(self, dictionary_path: str) -> np.ndarray: data = [] with open(dictionary_path, mode="r", encoding="utf-8") as f: lines = f.readlines() @@ -387,70 +390,146 @@ def get_data(self, dictionary_path): name = name.lower() yield (name, cui) + @classmethod + def load(cls, dictionary_name_or_path: str): + # use provided dictionary + if dictionary_name_or_path == "ctd-disease": + return NEL_CTD_DISEASE_DICT() + elif dictionary_name_or_path == "ctd-chemical": + return NEL_CTD_CHEMICAL_DICT() + elif dictionary_name_or_path == "ncbi-gene": + return NEL_NCBI_HUMAN_GENE_DICT() + elif dictionary_name_or_path == "ncbi-taxonomy": + return NEL_NCBI_TAXONOMY_DICT() + # use custom dictionary file + else: + return DictionaryDataset(dictionary_path=dictionary_name_or_path) -class BiEncoderEntityLinker: - def __init__( - self, - use_sparse_embeddings: bool, - max_length: int, - index_use_cuda: bool - ) -> None: - self.use_sparse_embeds = use_sparse_embeddings - self.max_length = max_length - - self.tokenizer = None - self.encoder = None +class EntityRetrieverModel: + """ + An entity retriever model is used to find the top-k entities / concepts of a knowledge base / + dictionary for a given entity mention in text. + """ - self.sparse_encoder = None - self.sparse_weight = None + def get_top_k( + self, + entity_mention: str, + top_k: int + ) -> List[Tuple[str, str, float]]: + """ + Returns the top-k entity / concept identifiers for the given entity mention. - self.index_use_cuda = index_use_cuda and flair.device.type == "cuda" + :param entity_mention: Entity mention text under investigation + :param top_k: Number of (best-matching) entities from the knowledge base to return + :result: List of tuples highlighting the top-k entities. Each tuple has the following + structure (entity / concept name, concept ids, score). 
+ """ + raise NotImplementedError() - self.dense_dictionary_index = None - self.dict_sparse_embeds = None - self.dictionary = None +class ExactStringMatchingRetrieverModel(EntityRetrieverModel): + """ + Implementation of an entity retriever model which uses exact string matching to + find the entity / concept identifier for a given entity mention. + """ + def __init__(self, dictionary: DictionaryDataset): + # Build index which maps concept / entity names to concept / entity ids + self.name_to_id_index = {name: cui for name, cui in dictionary.data} + @classmethod def load_model( + cls, + dictionary_name_or_path: str, + ): + # Load dictionary + return cls(DictionaryDataset.load(dictionary_name_or_path)) + + def get_top_k( + self, + entity_mention: str, + top_k: int + ) -> List[Tuple[str, str, float]]: + """ + Returns the top-k entity / concept identifiers for the given entity mention. Note that + the model either return the entity with an identical name in the knowledge base / dictionary + or none. + + :param entity_mention: Entity mention under investigation + :param top_k: Number of (best-matching) entities from the knowledge base to return + :result: List of tuples highlighting the top-k entities. Each tuple has the following + structure (entity / concept name, concept ids, score). + """ + if entity_mention in self.name_to_id_index: + return [(entity_mention, self.name_to_id_index[entity_mention], 1.0)] + else: + return [] + + +class BiEncoderEntityRetrieverModel(EntityRetrieverModel): + """ + Implementation of EntityRetrieverModel which uses dense (transformer-based) embeddings and (optionally) + sparse character-based representations, for normalizing an entity mention to specific identifiers + in a knowledge base / dictionary. + + To this end, the model embeds the entity mention text and all concept names from the knowledge + base and outputs the k best-matching concepts based on embedding similarity. + """ + + def __init__( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: str, - batch_size: int = 1024, - use_cosine: bool = True, - ): + use_sparse_embeddings: bool, + use_cosine: bool, + max_length: int, + batch_size: int, + index_use_cuda: bool, + top_k_extra_dense: int = 10, + top_k_extra_sparse: int = 10 + ) -> None: """ - Load the model and embed the dictionary - :param model_name_or_path: The path of the model - :param dictionary_name_or_path: The path of the dictionary - :param batch_size: The batch size for embedding the dictionary + Initializes the BiEncoderEntityRetrieverModel. + + :param model_name_or_path: Name of or path to the transformer model to be used. + :param dictionary_name_or_path: Name of or path to the transformer model to be used. + :param use_sparse_embeddings: Indicates whether to use sparse embeddings or not + :param use_cosine: Indicates whether to use cosine similarity (instead of inner product) + :param max_length: Maximal number of tokens used for embedding an entity mention / concept name + :param batch_size: Batch size used during embedding of the dictionary and top-k prediction + :param index_use_cuda: Indicates whether to use CUDA while indexing the dictionary / knowledge base + :param top_k_extra_sparse: Number of extra entities (resp. their sparse embeddings) which should be + retrieved while combining sparse and dense scores + :param top_k_extra_dense: Number of extra entities (resp. 
their dense embeddings) which should be + retrieved while combining sparse and dense scores """ - self.load_dense_encoder(model_name_or_path) + self.use_sparse_embeds = use_sparse_embeddings self.use_cosine = use_cosine + self.max_length = max_length + self.batch_size = batch_size + self.top_k_extra_dense = top_k_extra_dense + self.top_k_extra_sparse = top_k_extra_sparse + self.index_use_cuda = index_use_cuda and flair.device.type == "cuda" + + # Load dense encoder + self.dense_encoder = TransformerDocumentEmbeddings( + model=model_name_or_path, + is_token_embedding=False + ) + # Load sparse encoder if self.use_sparse_embeds: - self.load_sparse_encoder(model_name_or_path) - self.load_sparse_weight(model_name_or_path) + #FIXME: What happens if sparse encoder isn't pre-trained??? + self._load_sparse_encoder(model_name_or_path) + self._load_sparse_weight(model_name_or_path) - self.embed_dictionary( + self._embed_dictionary( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path, batch_size=batch_size, ) - return self - - def load_dense_encoder( - self, model_name_or_path: Union[str, Path] - ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: - - self.encoder = TransformerDocumentEmbeddings( - model_name_or_path, is_token_embedding=False - ) - - return self.encoder - - def load_sparse_encoder( + def _load_sparse_encoder( self, model_name_or_path: Union[str, Path] ) -> BigramTfIDFVectorizer: sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") @@ -469,7 +548,7 @@ def load_sparse_encoder( return self.sparse_encoder - def load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> torch.Tensor: + def _load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> float: sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") # check file exists if not os.path.isfile(sparse_weight_path): @@ -482,40 +561,43 @@ def load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> torch.Tens cache_dir=flair.cache_root / "models" / model_name_or_path, ) - self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu") + self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() return self.sparse_weight - def get_sparse_weight(self) -> torch.Tensor: - assert self.sparse_weight is not None - - return self.sparse_weight - - def embed_sparse(self, names: list) -> np.ndarray: + def _embed_sparse(self, entity_names: np.ndarray) -> np.ndarray: """ - Embedding data into sparse representations - :param names np.array: An array of names - :returns sparse_embeds np.array: A list of sparse embeddings + Embeds the given numpy array of entity names, either originating from the knowledge base + or recognized in a text, into sparse representations. 
+ + :param entity_names: An array of entity / concept names + :returns sparse_embeds np.array: Numpy array containing the sparse embeddings """ - sparse_embeds = self.sparse_encoder(names) + sparse_embeds = self.sparse_encoder(entity_names) sparse_embeds = sparse_embeds.numpy() + if self.use_cosine: + faiss.normalize_L2(sparse_embeds) + return sparse_embeds - def embed_dense( + def _embed_dense( self, names: np.ndarray, - show_progress: bool = False, batch_size: int = 2048, + show_progress: bool = False ) -> np.ndarray: """ - Embedding data into dense representations for SapBert - :param names: np.array of names + Embeds the given numpy array of entity / concept names, either originating from the + knowledge base or recognized in a text, into dense representations using a + TransformerDocumentEmbedding model. + + :param names: Numpy array of entity / concept names + :param batch_size: Batch size used while embedding the name :param show_progress: bool to toggle progress bar - :param batch_size: batch size - :return dense_embeds: list of dense embeddings of the names + :return: Numpy array containing the dense embeddings of the names """ - self.encoder.eval() # prevent dropout + self.dense_encoder.eval() # prevent dropout dense_embeds = [] @@ -529,13 +611,12 @@ def embed_dense( iterations = range(0, len(names), batch_size) for start in iterations: - - # make batch + # Create batch end = min(start + batch_size, len(names)) batch = [Sentence(name) for name in names[start:end]] # embed batch - self.encoder.embed(batch) + self.dense_encoder.embed(batch) dense_embeds += [ name.embedding.cpu().detach().numpy() for name in batch @@ -544,45 +625,131 @@ def embed_dense( if flair.device.type == "cuda": torch.cuda.empty_cache() - return np.array(dense_embeds) + dense_embeds = np.array(dense_embeds) + if self.use_cosine: + faiss.normalize_L2(dense_embeds) + + return dense_embeds - def get_sparse_similarity_scores( - self, - query_embeds: np.ndarray, - dict_embeds: np.ndarray, - cosine: bool = False, - normalise: bool = False, - ) -> np.ndarray: + def _embed_dictionary( + self, + model_name_or_path: str, + dictionary_name_or_path: str, + batch_size: int + ): """ - Return score matrix - :param query_embeds: 2d numpy array of query embeddings - :param dict_embeds: 2d numpy array of query embeddings - :param score_matrix: 2d numpy array of scores + Computes the embeddings for the given knowledge base / dictionary. """ - if cosine: - score_matrix = cosine_similarity(query_embeds, dict_embeds) + # Load dictionary + self.dictionary = DictionaryDataset.load(dictionary_name_or_path).data + + # Check for embedded dictionary in cache + dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] + file_name = f"bio_nen_{model_name_or_path.split('/')[-1]}_{dictionary_name}" + + cache_folder = flair.cache_root / "datasets" + emb_dictionary_cache_file = cache_folder / f"{file_name}.pk" + + # If exists, load the cached dictionary indices + if emb_dictionary_cache_file.exists(): + self._load_cached_dense_emb_dictionary(emb_dictionary_cache_file) + else: - score_matrix = np.matmul(query_embeds, dict_embeds.T) + # get names from dictionary and remove punctuation + # FIXME: Why doing this here???? 
+ punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) - if normalise: - score_matrix = (score_matrix - score_matrix.min()) / ( - score_matrix.max() - score_matrix.min() + dictionary_names = [] + for row in self.dictionary: + name = punctuation_regex.split(row[0]) + name = " ".join(name).strip().lower() + dictionary_names.append(name) + dictionary_names = np.array(dictionary_names) + + # Compute dense embeddings (if necessary) + self.dict_dense_embeddings = self._embed_dense( + names=dictionary_names, + batch_size=batch_size, + show_progress=True ) - return score_matrix + # To use cosine similarity, we normalize the vectors and then use inner product + if self.use_cosine: + faiss.normalize_L2(self.dict_dense_embeddings) + + # Compute sparse embeddings (if necessary) + if self.use_sparse_embeds: + self.dict_sparse_embeddings = self._embed_sparse(entity_names=dictionary_names) + else: + self.dict_sparse_embeddings = None + + # Build dense embedding index using faiss + dimension = self.dict_dense_embeddings.shape[1] + self.dense_dictionary_index = faiss.IndexFlatIP(dimension) + self.dense_dictionary_index.add(self.dict_dense_embeddings) + + # Store the pre-computed index on disk for later re-use + cached_dictionary = { + "dictionary": self.dictionary, + "sparse_dictionary_embeds": self.dict_sparse_embeddings, + "dense_dictionary_index": self.dense_dictionary_index, + } + + if not cache_folder.exists(): + cache_folder.mkdir(parents=True) + + log.info(f"Saving dictionary into cached file {cache_folder}") + with emb_dictionary_cache_file.open("wb") as cache_file: + pickle.dump(cached_dictionary, cache_file) + + # If we use CUDA - move index to GPU + if self.index_use_cuda: + self.dense_dictionary_index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), 0, self.dense_dictionary_index + ) + + def _load_cached_dense_emb_dictionary(self, cached_dictionary_path: Path): + """ + Loads pre-computed dense dictionary embedding from disk. + """ + with cached_dictionary_path.open("rb") as cached_file: + log.info("Loaded dictionary from cached file {}".format(cached_dictionary_path)) + cached_dictionary = pickle.load(cached_file) + + self.dictionary, self.dict_sparse_embeddings, self.dense_dictionary_index = ( + cached_dictionary["dictionary"], + cached_dictionary["sparse_dictionary_embeds"], + cached_dictionary["dense_dictionary_index"] + ) + + if self.index_use_cuda: + self.dense_dictionary_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, self.dense_dictionary_index) def retrieve_sparse_topk_candidates( - self, - score_matrix: np.ndarray, - topk: int, - ) -> Tuple[np.ndarray, np.ndarray]: + self, + mention_embeddings: np.ndarray, + dict_concept_embeddings: np.ndarray, + top_k: int, + normalise: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: """ - Return sorted topk idxes (descending order) + Returns top-k indexes (in descending order) for the given entity mentions resp. mention + embeddings. 
+ :param score_matrix: 2d numpy array of scores - :param topk: number of candidates to retrieve + :param top_k: number of candidates to retrieve :return res: d numpy array of ids [# of query , # of dict] :return scores: numpy array of top scores """ + if self.use_cosine: + score_matrix = cosine_similarity(mention_embeddings, dict_concept_embeddings) + else: + score_matrix = np.matmul(mention_embeddings, dict_concept_embeddings.T) + + if normalise: + score_matrix = (score_matrix - score_matrix.min()) / ( + score_matrix.max() - score_matrix.min() + ) def indexing_2d(arr, cols): rows = np.repeat( @@ -590,70 +757,64 @@ def indexing_2d(arr, cols): ) return arr[rows, cols] - # get topk indexes without sorting - topk_idxs = np.argpartition(score_matrix, -topk)[:, -topk:] + # Get topk indexes without sorting + topk_idxs = np.argpartition(score_matrix, -top_k)[:, -top_k:] - # get topk indexes with sorting + # Get topk indexes with sorting topk_score_matrix = indexing_2d(score_matrix, topk_idxs) topk_argidxs = np.argsort(-topk_score_matrix) topk_idxs = indexing_2d(topk_idxs, topk_argidxs) topk_scores = indexing_2d(score_matrix, topk_idxs) - return topk_idxs, topk_scores + return (topk_idxs, topk_scores) - def get_predictions( - self, - mention: str, - topk: int, - batch_size: int = 1024, - ) -> np.ndarray: + def get_top_k( + self, + entity_mention: str, + top_k: int + ) -> List[Tuple[str, str, float]]: """ - Return the topk predictions for a mention and their scores - :param mention: string of the mention to find candidates for - :param topk: number of candidates - :return res: d numpy array of ids [# of query , # of dict] - :return scores: numpy array of top predictions and their scores + Returns the top-k entities for a given entity mention. + + :param entity_mention: Entity mention text under investigation + :param top_k: Number of (best-matching) entities from the knowledge base to return + :result: List of tuples highlighting the top-k entities. Each tuple has the following + structure (entity / concept name, concept ids, score). 
""" - # get dense embeds for mention - mention_dense_embeds = self.embed_dense(names=[mention]) - if self.use_cosine: - # normalize mention vector - faiss.normalize_L2(mention_dense_embeds) + # Compute dense embedding for the given entity mention + mention_dense_embeds = self._embed_dense( + names=np.array([entity_mention]), + batch_size=self.batch_size + ) - assert ( - self.dense_dictionary_index is not None - ), "Index not built yet, please run load_model to embed your dictionary before calling get_predictions" + # Search for more than top-k candidates if combining them with sparse scores + top_k_dense = top_k if not self.use_sparse_embeds else top_k + self.top_k_extra_dense - # if using sparse embeds: calculate hybrid scores with dense and sparse embeds + # Get candidates from dense embeddings + dense_scores, dense_ids = self.dense_dictionary_index.search( + x=mention_dense_embeds, + k=top_k_dense + ) + + # If using sparse embeds: calculate hybrid scores with dense and sparse embeds if self.use_sparse_embeds: - assert ( - self.dict_sparse_embeds is not None - ), "Index not built yet, please run load_model to embed your dictionary before calling get_predictions" - # search for more than topk candidates to use them when combining with sparse scores - # get candidates from dense embeddings - dense_scores, dense_ids = self.dense_dictionary_index.search( - x=mention_dense_embeds, k=topk + 10 - ) - # get sparse embeds for mention - mention_sparse_embeds = self.embed_sparse(names=[mention]) - if self.use_cosine: - # normalize mention vector - faiss.normalize_L2(mention_sparse_embeds) + # Get sparse embeddings for the entity mention + mention_sparse_embeds = self._embed_sparse(entity_names=np.array([entity_mention])) - # get candidates from sprase embeddings - sparse_weight = self.get_sparse_weight().item() - sparse_score_matrix = self.get_sparse_similarity_scores( - query_embeds=mention_sparse_embeds, dict_embeds=self.dict_sparse_embeds - ) + # Get candidates from sparse embeddings sparse_ids, sparse_distances = self.retrieve_sparse_topk_candidates( - score_matrix=sparse_score_matrix, topk=topk + 10 + mention_embeddings=mention_sparse_embeds, + dict_concept_embeddings=self.dict_sparse_embeddings, + top_k=top_k + self.top_k_extra_sparse ) - # combine dense and sparse scores + # Combine dense and sparse scores + sparse_weight = self.sparse_weight hybrid_ids = [] hybrid_scores = [] - # for every embedded mention + + # For every embedded mention for ( top_dense_ids, top_dense_scores, @@ -662,6 +823,7 @@ def get_predictions( ) in zip(dense_ids, dense_scores, sparse_ids, sparse_distances): ids = top_dense_ids distances = top_dense_scores + for sparse_id, sparse_distance in zip( top_sparse_ids, top_sparse_distances ): @@ -677,289 +839,169 @@ def get_predictions( ) + distances[index] sorted_indizes = np.argsort(-distances) - ids = ids[sorted_indizes][:topk] - distances = distances[sorted_indizes][:topk] + ids = ids[sorted_indizes][:top_k] + distances = distances[sorted_indizes][:top_k] hybrid_ids.append(ids.tolist()) hybrid_scores.append(distances.tolist()) - return [ - np.append(self.dictionary[ind], score) - for ind, score in zip(hybrid_ids, hybrid_scores) - ] - # use only dense embeds else: - dense_distances, dense_ids = self.dense_dictionary_index.search( - x=mention_dense_embeds, k=topk - ) - return [ - np.append(self.dictionary[ind], score) - for ind, score in zip(dense_ids, dense_distances) - ] + # Use only dense embedding results + hybrid_ids = dense_ids + hybrid_scores = dense_scores - def 
embed_dictionary( - self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int - ): - # check for embedded dictionary in cache - dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] - cache_folder = os.path.join(flair.cache_root, "datasets") - file_name = f"bio_nen_{model_name_or_path.split('/')[-1]}_{dictionary_name}" - cached_dictionary_path = os.path.join( - cache_folder, - f"{file_name}.pk", - ) - self.load_dictionary(dictionary_name_or_path) - - # If exists, load the cached dictionary indices - if os.path.exists(cached_dictionary_path): - self.load_cached_dictionary(cached_dictionary_path) - - # else, load and embed - else: - - # get names from dictionary and remove punctuation - punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) - dictionary_names = [] - for row in self.dictionary: - name = punctuation_regex.split(row[0]) - name = " ".join(name).strip().lower() - dictionary_names.append(name) - - # create dense and sparse embeddings - if self.use_sparse_embeds: - self.dict_dense_embeds = self.embed_dense( - names=dictionary_names, - batch_size=batch_size, - show_progress=True - ) - self.dict_sparse_embeds = self.embed_sparse(names=dictionary_names) - - # create only dense embeddings - else: - self.dict_dense_embeds = self.embed_dense( - names=dictionary_names, show_progress=True - ) - self.dict_sparse_embeds = None - - # build dense index - dimension = self.dict_dense_embeds.shape[1] - if self.use_cosine: - # to use cosine similarity, we normalize the vectors and then use inner product - faiss.normalize_L2(self.dict_dense_embeds) - - self.dense_dictionary_index = faiss.IndexFlatIP(dimension) - if self.index_use_cuda: - self.dense_dictionary_index = faiss.index_cpu_to_gpu( - faiss.StandardGpuResources(), 0, self.dense_dictionary_index - ) - self.dense_dictionary_index.add(self.dict_dense_embeds) - - if self.index_use_cuda: - # create a cpu version of the index to cache (IO does not work with gpu index) - cache_index_dense = faiss.index_gpu_to_cpu(self.dense_dictionary_index) - # cache dictionary - self.cache_dictionary( - cache_index_dense, cache_folder, cached_dictionary_path - ) - else: - self.cache_dictionary( - self.dense_dictionary_index, cache_folder, cached_dictionary_path - ) - - def load_cached_dictionary(self, cached_dictionary_path: str): - with open(cached_dictionary_path, "rb") as cached_file: - cached_dictionary = pickle.load(cached_file) - log.info( - "Loaded dictionary from cached file {}".format(cached_dictionary_path) - ) - - (self.dictionary, self.dict_sparse_embeds, self.dense_dictionary_index,) = ( - cached_dictionary["dictionary"], - cached_dictionary["sparse_dictionary_embeds"], - cached_dictionary["dense_dictionary_index"], - ) - if self.index_use_cuda: - self.dense_dictionary_index = faiss.index_cpu_to_gpu( - faiss.StandardGpuResources(), 0, self.dense_dictionary_index - ) - - def load_dictionary(self, dictionary_name_or_path: str): - # use provided dictionary - if dictionary_name_or_path == "ctd-disease": - self.dictionary = NEL_CTD_DISEASE_DICT().data - elif dictionary_name_or_path == "ctd-chemical": - self.dictionary = NEL_CTD_CHEMICAL_DICT().data - elif dictionary_name_or_path == "ncbi-gene": - self.dictionary = NEL_NCBI_HUMAN_GENE_DICT().data - elif dictionary_name_or_path == "ncbi-taxonomy": - self.dictionary = NEL_NCBI_TAXONOMY_DICT().data - # use custom dictionary file - else: - self.dictionary = DictionaryDataset( - dictionary_path=dictionary_name_or_path - ).data - - def 
cache_dictionary(self, cache_index_dense, cache_folder, cached_dictionary_path): - cached_dictionary = { - "dictionary": self.dictionary, - "sparse_dictionary_embeds": self.dict_sparse_embeds, - "dense_dictionary_index": cache_index_dense, - } - if not os.path.exists(cache_folder): - os.mkdir(cache_folder) - with open(cached_dictionary_path, "wb") as cache_file: - pickle.dump(cached_dictionary, cache_file) - print("Saving dictionary into cached file {}".format(cache_folder)) - - -class ExactStringMatchEntityLinker: - def __init__(self): - self.dictionary = None - - def load_model( - self, - dictionary_name_or_path: str, - ): - # use provided dictionary - if dictionary_name_or_path == "ctd-disease": - dictionary_data = NEL_CTD_DISEASE_DICT().data - elif dictionary_name_or_path == "ctd-chemical": - dictionary_data = NEL_CTD_CHEMICAL_DICT().data - elif dictionary_name_or_path == "ncbi-gene": - dictionary_data = NEL_NCBI_HUMAN_GENE_DICT().data - elif dictionary_name_or_path == "ncbi-taxonomy": - dictionary_data = NEL_NCBI_TAXONOMY_DICT().data - # use custom dictionary file - else: - dictionary_data = DictionaryDataset( - dictionary_path=dictionary_name_or_path - ).data - - # make dictionary from array of tuples (entity, entity_id) - self.dictionary = {name: cui for name, cui in dictionary_data} - - def get_predictions(self, mention: str, topk) -> np.ndarray: - if mention in self.dictionary: - return np.array([(mention, self.dictionary[mention], 1.0)]) - else: - return [] + return [ + tuple(self.dictionary[entity_index].reshape(1, -1)[0]) + (score[0],) + for entity_index, score in zip(hybrid_ids, hybrid_scores) + ] -class MultiBiEncoderEntityLinker: +class BiomedicalEntityLinker: """ - Biomedical Entity Linker for HunFlair - Can predict top k entities on sentences annotated with biomedical entity mentions + Entity linking model which expects text/sentences with annotated entity mentions and predicts + entity / concept to these mentions according to a knowledge base / dictionary. """ - def __init__( + self, + retriever_model: EntityRetrieverModel, + mention_preprocessor: MentionPreprocessor + ): + self.preprocessor = mention_preprocessor + self.retriever_model = retriever_model + + def predict( self, - models: List[BiEncoderEntityLinker], - preprocessor: Optional[MentionPreprocessor] = None, + sentences: Union[List[Sentence], Sentence], + input_entity_annotation_layer: str = None, + top_k: int = 1, ) -> None: """ - Initalize class, called by classmethod load - :param models: list of objects containing the dense and sparse encoders - :param ab3p_path: path to ab3p model + Predicts the best matching top-k entity / concept identifiers of all named entites annotated + with tag input_entity_annotation_layer. + + :param sentences: One or more sentences to run the prediction on + :param input_entity_annotation_layer: Entity type to run the prediction on + :param top_k: Number of best-matching entity / concept identifiers per entity mention """ - self.models = models - self.preprocessor = preprocessor + # make sure sentences is a list of sentences + if not isinstance(sentences, list): + sentences = [sentences] + + if self.preprocessor is not None: + self.preprocessor.initialize(sentences) + + # Build label name + label_name = ( + input_entity_annotation_layer + "_nen" + if (input_entity_annotation_layer is not None) + else "nen" + ) + + # For every sentence .. + for sentence in sentences: + # ... 
process every mentioned entity + for entity in sentence.get_labels(input_entity_annotation_layer): + # Pre-process entity mention (if necessary) + mention_text = ( + self.preprocessor.process(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text + ) + # Retrieve top-k concept / entity candidates + predictions = self.retriever_model.get_top_k(mention_text, top_k) + + # Add a label annotation for each candidate + for prediction in predictions: + # if concept unique id is made up of mulitple ids, seperated by '|' + # seperate it into cui and additional_labels + cui = prediction[1] + if "|" in cui: + labels = cui.split("|") + cui = labels[0] + additional_labels = labels[1:] + else: + additional_labels = None + # determine database: + if ":" in cui: + cui_parts = cui.split(":") + database = ":".join(cui_parts[0:-1]) + cui = cui_parts[-1] + else: + database = None + sentence.add_label( + typename=label_name, + value_or_label=EntityLinkingLabel( + data_point=entity.data_point, + id=cui, + concept_name=prediction[0], + additional_ids=additional_labels, + database=database, + score=prediction[2], + ), + ) @classmethod def load( cls, - model_names: Union[List[str], str], - dictionary_names_or_paths: Union[str, Path, List[str], List[Path]] = None, - use_sparse_and_dense_embeds: bool = True, - max_length=25, - batch_size=1024, - index_use_cuda=False, + model_name_or_path: Union[str, Path], + dictionary_name_or_path: Union[str, Path] = None, + use_sparse_embeddings: bool = True, + max_length: int = 25, + batch_size: int = 1024, + index_use_cuda: bool = False, use_cosine: bool = True, - use_ab3p: bool = True, - ab3p_path: Path = None, + preprocessor: MentionPreprocessor = Ab3PMentionPreprocessor.load(preprocessor=BasicMentionPreprocessor()) ): """ - Load a model for biomedical named entity normalization on sentences annotated with - biomedical entity mentions - :param model_names: List of names of pretrained models to use. Possible values for pretrained models are: - chemical, disease, gene, sapbert-bc5cdr-dissaease, sapbert-ncbi-disease, sapbert-bc5cdr-chemical, biobert-bc5cdr-disease, - biobert-ncbi-disease, biobert-bc5cdr-chemical, biosyn-biobert-bc2gn, biosyn-sapbert-bc2gn, sapbert, exact-string-match - :param dictionary_path: Name of one of the provided dictionaries listing all possible ids and their synonyms - or a path to a dictionary file with each line in the format: id||name, with one line for each name of a concept. - Possible values for dictionaries are: chemical, ctd-chemical, disease, bc5cdr-disease, gene, cnbci-gene, - taxonomy and ncbi-taxonomy - :param use_sparse_and_dense_embeds: If True, uses a combinations of sparse and dense embeddings for the dictionary and the mentions - If False, uses only dense embeddings + Loads a model for biomedical named entity normalization. + + :param model_name_or_path: Name of or path to a pretrained model to use. Possible values for pretrained + models are: + chemical, disease, gene, sapbert-bc5cdr-dissaease, sapbert-ncbi-disease, sapbert-bc5cdr-chemical, + biobert-bc5cdr-disease,biobert-ncbi-disease, biobert-bc5cdr-chemical, biosyn-biobert-bc2gn, + biosyn-sapbert-bc2gn, sapbert, exact-string-match + :param dictionary_name_or_path: Name of or path to a dictionary listing all possible entity / concept + identifiers and their concept names / synonyms. 
Pre-defined dictionaries are: + chemical, ctd-chemical, disease, bc5cdr-disease, gene, cnbci-gene, taxonomy and ncbi-taxonomy + :param use_sparse_embeddings: Indicates whether to use sparse embeddings for inference. If True, + uses a combinations of sparse and dense embeddings. If False, uses only dense embeddings + :param: max_length: Maximal number of tokens for an entity mention or concept name :param batch_size: Batch size for the dense encoder - :param index_use_cuda: If True, uses GPU for the dense encoder - :param use_cosine: If True, uses cosine similarity for the dense encoder. If False, uses inner product + :param index_use_cuda: If True, uses GPU for the dense encoding + :param use_cosine: If True, uses cosine similarity for the dense encoder. If False, inner product is used. :param use_ab3p: If True, uses ab3p to resolve abbreviations :param ab3p_path: Optional: oath to ab3p on your machine """ - # validate input: check that amount of models and dictionaries match - if isinstance(model_names, str): - model_names = [model_names] - - # case one dictionary for all models - if isinstance(dictionary_names_or_paths, str) or isinstance( - dictionary_names_or_paths, Path - ): - dictionary_names_or_paths = [dictionary_names_or_paths] * len(model_names) - # case no dictionary provided - elif dictionary_names_or_paths is None: - dictionary_names_or_paths = [None] * len(model_names) - # case one model, multiple dictionaries - elif len(model_names) == 1: - model_names = model_names * len(dictionary_names_or_paths) - # case mismatching amount of models and dictionaries - elif len(model_names) != len(dictionary_names_or_paths): - raise ValueError( - "Amount of models and dictionaries must match. Got {} models and {} dictionaries".format( - len(model_names), len(dictionary_names_or_paths) - ) - ) - assert len(model_names) == len(dictionary_names_or_paths) - - models = [] - - for model_name, dictionary_name_or_path in zip( - model_names, dictionary_names_or_paths - ): - # get the paths for the model and dictionary - model_path = cls.__get_model_path(model_name, use_sparse_and_dense_embeds) - dictionary_path = cls.__get_dictionary_path( - dictionary_name_or_path, model_name=model_name - ) - - if model_path == "exact-string-match": - model = ExactStringMatchEntityLinker() - model.load_model(dictionary_path) - + dictionary_path = dictionary_name_or_path + if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): + dictionary_path = cls.__get_dictionary_path(dictionary_name_or_path, model_name_or_path) + + retriever_model = None + if isinstance(model_name_or_path, str): + if model_name_or_path == "exact-string-match": + retriever_model = ExactStringMatchingRetrieverModel.load_model(dictionary_path) else: - model = BiEncoderEntityLinker( - use_sparse_embeddings=use_sparse_and_dense_embeds, - max_length=max_length, - index_use_cuda=index_use_cuda, - ) - - model.load_model( + model_path = cls.__get_model_path(model_name_or_path, use_sparse_embeddings) + retriever_model = BiEncoderEntityRetrieverModel( model_name_or_path=model_path, dictionary_name_or_path=dictionary_path, - batch_size=batch_size, + use_sparse_embeddings=use_sparse_embeddings, use_cosine=use_cosine, + max_length=max_length, + batch_size=batch_size, + index_use_cuda=index_use_cuda, ) - models.append(model) - - # load ab3p model - preprocessor = Ab3PMentionPreprocessor.load( - mention_preprocessor=BasicMentionPreprocessor() - ) #Ab3P.load(ab3p_path) if use_ab3p else None - - return cls(models, preprocessor) + return 
cls( + retriever_model=retriever_model, + mention_preprocessor=preprocessor + ) @staticmethod def __get_model_path( - model_name: str, use_sparse_and_dense_embeds - ) -> BiEncoderEntityLinker: + model_name: str, + use_sparse_and_dense_embeds: bool + ) -> str: + model_name = model_name.lower() model_path = model_name @@ -1000,7 +1042,10 @@ def __get_model_path( return model_path @staticmethod - def __get_dictionary_path(dictionary_path: str, model_name: str): + def __get_dictionary_path( + dictionary_path: str, + model_name: str + ) -> str: # determine dictionary to use if dictionary_path == "disease": dictionary_path = "ctd-disease" @@ -1040,72 +1085,3 @@ def __get_dictionary_path(dictionary_path: str, model_name: str): raise ValueError("Invalid dictionary") return dictionary_path - - def predict( - self, - sentences: Union[List[Sentence], Sentence], - input_entity_annotation_layer: str = None, - topk: int = 1, - ) -> None: - """ - On one or more sentences, predict the cui on all named entites annotated with a tag of type input_entity_annotation_layer. - Annotates the top k predictions. - :param sentences: one or more sentences to run the predictions on - :param input_entity_annotation_layer: only entities with in this annotation layer will be annotated - :param topk: number of predicted cui candidates to add to annotation - :param abbreviation_dict: dictionary with abbreviations and their expanded form or a boolean value indicating whether - abbreviations should be expanded using Ab3P - """ - # make sure sentences is a list of sentences - if not isinstance(sentences, list): - sentences = [sentences] - - if self.preprocessor is not None: - self.preprocessor.initialize(sentences) - - for model in self.models: - for sentence in sentences: - for entity in sentence.get_labels(input_entity_annotation_layer): - mention_text = ( - self.preprocessor.process(entity, sentence) - if self.preprocessor is not None - else entity.data_point.text - ) - - # get predictions from dictionary - predictions = model.get_predictions(mention_text, topk) - - # add predictions to entity - label_name = ( - (input_entity_annotation_layer + "_nen") - if (input_entity_annotation_layer is not None) - else "nen" - ) - for prediction in predictions: - # if concept unique id is made up of mulitple ids, seperated by '|' - # seperate it into cui and additional_labels - cui = prediction[1] - if "|" in cui: - labels = cui.split("|") - cui = labels[0] - additional_labels = labels[1:] - else: - additional_labels = None - # determine database: - if ":" in cui: - cui_parts = cui.split(":") - database = ":".join(cui_parts[0:-1]) - cui = cui_parts[-1] - else: - database = None - sentence.add_label( - typename=label_name, - value_or_label=EntityLinkingLabel( - data_point=entity.data_point, - id=cui, - concept_name=prediction[0], - additional_ids=additional_labels, - database=database, - score=prediction[2].astype(float), - ), - ) From 56c89ba89778be10f25074aef52c2f47d9815f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Wed, 22 Mar 2023 17:41:02 +0100 Subject: [PATCH 04/58] Update documentation --- flair/data.py | 38 ++++++++---- flair/models/biomedical_entity_linking.py | 61 +++++++++++-------- .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 27 ++++---- 3 files changed, 77 insertions(+), 49 deletions(-) diff --git a/flair/data.py b/flair/data.py index 4928b0f57..d9067c4e5 100644 --- a/flair/data.py +++ b/flair/data.py @@ -433,26 +433,44 @@ def __len__(self) -> int: class EntityLinkingLabel(Label): + """ + Label 
class models entity linking annotations. Each entity linking label has a data point it refers + to as well as the identifier and name of the concept / entity from a knowledge base or ontology. + + Optionally, additional concepts identifier and the database name can be provided. + """ + def __init__( self, data_point: DataPoint, - id: str, + concept_id: str, concept_name: str, score: float = 1.0, additional_ids: Optional[Union[List[str], str]] = None, database: Optional[str] = None, ): - super().__init__(data_point, id, score) + """ + Initializes the label instance. + + :param data_point: Data point / span the label refers to + :param concept_id: Identifier of the entity / concept from the knowledge base / ontology + :param concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology + :param score: Matching score of the entity / concept according to the entity mention + :param additional_ids: List of additional identifiers for the concept / entity in the KB / ontology + :param database: Name of the knowlege base / ontology + """ + super().__init__(data_point, concept_id, score) self.concept_name = concept_name + self.database = database + if isinstance(additional_ids, str): additional_ids = [additional_ids] self.additional_ids = additional_ids - self.database = database def spawn(self, value: str, score: float = 1.0): return EntityLinkingLabel( data_point=self.data_point, - id=value, + concept_id=value, score=score, concept_name=self.concept_name, additional_ids=self.additional_ids, @@ -460,17 +478,12 @@ def spawn(self, value: str, score: float = 1.0): ) def __str__(self): - if self.additional_ids is None: - return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ + return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" - else: - return f" ({', '.join(self.additional_ids)}) {self.concept_name} ({round(self._score, 2)})" def __repr__(self): - if self.additional_ids is None: - return f"{self.database}:{self._value} {self.concept_name} [{self.data_point.text}] ({round(self._score, 2)})" - else: - return f"{self.database}:{self._value} ({', '.join(self.additional_ids)}) {self.concept_name} [{self.data_point.text}] ({round(self._score, 2)})" + return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ + f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" def __len__(self): return len(self.data_point) @@ -480,6 +493,7 @@ def __eq__(self, other): self.value == other.value and self.data_point == other.data_point and self.concept_name == other.concept_name + and self.identifier == other.identifier and self.database == other.database and self.score == other.score ) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index ab8f6f122..c08cb0fde 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -11,6 +11,7 @@ import tempfile import torch +from abc import ABC, abstractmethod from collections import defaultdict from flair.data import Sentence, EntityLinkingLabel, DataPoint, Label from flair.datasets import ( @@ -26,7 +27,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm -from typing import List, Tuple, Union, Optional, Dict +from typing import List, Tuple, Union, Optional, Dict, Iterable log = logging.getLogger("flair") @@ -341,7 +342,7 @@ class 
DictionaryDataset: Every line in the file must be formatted as follows: concept_unique_id||concept_name with one line per concept name. Multiple synonyms for the same concept should - be in seperate lines with the same concept_unique_id. + be in separate lines with the same concept_unique_id. Slightly modifed from Sung et al. 2020 Biomedical Entity Representations with Synonym Marginalization @@ -354,18 +355,20 @@ def __init__( load_into_memory: bool = True ) -> None: """ - :param dictionary_path str: The path of the dictionary + :param dictionary_path str: Path to the dictionary file + :param load_into_memory bool: Indicates whether the dictionary entries should be loaded in + memory or not (Default True) """ - log.info("Loading Dictionary from {}".format(dictionary_path)) + log.info("Loading dictionary from {}".format(dictionary_path)) if load_into_memory: self.data = self.load_data(dictionary_path) else: self.data = self.get_data(dictionary_path) - def load_data(self, dictionary_path: str) -> np.ndarray: + def load_data(self, dictionary_path: Union[Path, str]) -> np.ndarray: data = [] - with open(dictionary_path, mode="r", encoding="utf-8") as f: - lines = f.readlines() + with open(dictionary_path, mode="r", encoding="utf-8") as file: + lines = file.readlines() for line in tqdm(lines, desc="Loading dictionary"): line = line.strip() if line == "": @@ -378,7 +381,7 @@ def load_data(self, dictionary_path: str) -> np.ndarray: return data # generator version - def get_data(self, dictionary_path): + def get_data(self, dictionary_path: Union[Path, str]) -> Iterable[Tuple]: data = [] with open(dictionary_path, mode="r", encoding="utf-8") as f: lines = f.readlines() @@ -391,27 +394,29 @@ def get_data(self, dictionary_path): yield (name, cui) @classmethod - def load(cls, dictionary_name_or_path: str): - # use provided dictionary - if dictionary_name_or_path == "ctd-disease": - return NEL_CTD_DISEASE_DICT() - elif dictionary_name_or_path == "ctd-chemical": - return NEL_CTD_CHEMICAL_DICT() - elif dictionary_name_or_path == "ncbi-gene": - return NEL_NCBI_HUMAN_GENE_DICT() - elif dictionary_name_or_path == "ncbi-taxonomy": - return NEL_NCBI_TAXONOMY_DICT() - # use custom dictionary file + def load(cls, dictionary_name_or_path: Union[Path, str]): + if isinstance(dictionary_name_or_path, str): + # use provided dictionary + if dictionary_name_or_path == "ctd-disease": + return NEL_CTD_DISEASE_DICT() + elif dictionary_name_or_path == "ctd-chemical": + return NEL_CTD_CHEMICAL_DICT() + elif dictionary_name_or_path == "ncbi-gene": + return NEL_NCBI_HUMAN_GENE_DICT() + elif dictionary_name_or_path == "ncbi-taxonomy": + return NEL_NCBI_TAXONOMY_DICT() else: + # use custom dictionary file return DictionaryDataset(dictionary_path=dictionary_name_or_path) -class EntityRetrieverModel: +class EntityRetrieverModel(ABC): """ An entity retriever model is used to find the top-k entities / concepts of a knowledge base / dictionary for a given entity mention in text. 
""" + @abstractmethod def get_top_k( self, entity_mention: str, @@ -880,7 +885,8 @@ def predict( :param sentences: One or more sentences to run the prediction on :param input_entity_annotation_layer: Entity type to run the prediction on - :param top_k: Number of best-matching entity / concept identifiers per entity mention + :param top_k: Number of best-matching entity / concept identifiers which should be predicted + per entity mention """ # make sure sentences is a list of sentences if not isinstance(sentences, list): @@ -912,8 +918,8 @@ def predict( # Add a label annotation for each candidate for prediction in predictions: - # if concept unique id is made up of mulitple ids, seperated by '|' - # seperate it into cui and additional_labels + # if concept identifier is made up of multiple ids, separated by '|' + # separate it into cui and additional_labels cui = prediction[1] if "|" in cui: labels = cui.split("|") @@ -921,6 +927,7 @@ def predict( additional_labels = labels[1:] else: additional_labels = None + # determine database: if ":" in cui: cui_parts = cui.split(":") @@ -928,17 +935,19 @@ def predict( cui = cui_parts[-1] else: database = None + sentence.add_label( typename=label_name, value_or_label=EntityLinkingLabel( data_point=entity.data_point, - id=cui, + concept_id=cui, concept_name=prediction[0], additional_ids=additional_labels, database=database, score=prediction[2], ), ) + @classmethod def load( cls, @@ -968,8 +977,8 @@ def load( :param batch_size: Batch size for the dense encoder :param index_use_cuda: If True, uses GPU for the dense encoding :param use_cosine: If True, uses cosine similarity for the dense encoder. If False, inner product is used. - :param use_ab3p: If True, uses ab3p to resolve abbreviations - :param ab3p_path: Optional: oath to ab3p on your machine + :param preprocessor: Implementation of MentionPreprocessor to use for pre-processing the entity + mention text and dictionary entries """ dictionary_path = dictionary_name_or_path if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index 9af93fb6e..2925054f0 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -2,16 +2,17 @@ After adding Named Entity Recognition tags to your sentence, you can run Named Entity Linking on these annotations. 
```python
-from flair.models import MultiTagger
+from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
+from flair.nn import Classifier
 from flair.tokenization import SciSpacyTokenizer
 from flair.data import Sentence
 
 sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer())
 
-ner_tagger = MultiTagger.load("hunflair-disease")
+ner_tagger = Classifier.load("hunflair-disease")
 ner_tagger.predict(sentence)
 
-nen_tagger = MultiBiEncoderEntityLinker.load("disease")
+nen_tagger = BiomedicalEntityLinker.load("disease")
 nen_tagger.predict(sentence)
 
 for tag in sentence.get_labels():
@@ -19,16 +20,20 @@ for tag in sentence.get_labels():
 ```
 This should print:
 ~~~
-Disease [Behavioral abnormalities (1,2)] (0.6736)
-Disease [Fragile X Syndrome (10,11,12)] (0.99)
-MESH:D001523 (DO:DOID:150) behavior disorders (0.98)
-MESH:D005600 (DO:DOID:14261, OMIM:300624, OMIM:309548) fragile x syndrome (1.1)
+Span[0:2]: "Behavioral abnormalities" → Disease (0.6736)
+Span[0:2]: "Behavioral abnormalities" → behavior disorders - MESH:D001523 (0.9772)
+Span[9:12]: "Fragile X Syndrome" → Disease (0.99)
+Span[9:12]: "Fragile X Syndrome" → fragile x syndrome - MESH:D005600 (1.0976)
 ~~~
-This output contains the NER disease annotations and it's entity linking annotations with ids from (often more than one) database.
-We have preconfigured combinations of models and dictionaries for "disease", "chemical" and "gene". You can also provide your own model and dictionary:
+The output contains both the NER disease annotations and their entity / concept identifiers according to
+a knowledge base or ontology. We have pre-configured combinations of models and dictionaries for
+"disease", "chemical" and "gene".
+You can also provide your own model and dictionary:
 ```python
-nen_tagger = MultiBiEncoderEntityLinker.load("name_or_path_to_your_model", dictionary_names_or_paths="name_or_path_to_your_dictionary")
-nen_tagger = MultiBiEncoderEntityLinker.load("path_to_custom_disease_model", dictionary_names_or_paths="disease")
+from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
+
+nen_tagger = BiomedicalEntityLinker.load("name_or_path_to_your_model", dictionary_name_or_path="name_or_path_to_your_dictionary")
+nen_tagger = BiomedicalEntityLinker.load("path_to_custom_disease_model", dictionary_name_or_path="disease")
 ````
 You can use any combination of provided models, provided dictionaries and your own.
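If you bring your own dictionary, the file must contain one entry per line in the format `concept_id||concept_name`, with synonyms of the same concept on separate lines (see the `DictionaryDataset` docstring earlier in this patch). The following minimal sketch illustrates that format and queries it with the exact string matching retriever; the file name and its entries are invented for illustration (the MeSH identifiers are taken from the example output above), and the API reflects the state of this patch:

```python
from flair.models.biomedical_entity_linking import (
    DictionaryDataset,
    ExactStringMatchingRetrieverModel,
)

# Minimal custom dictionary: one "concept_id||concept_name" pair per line,
# synonyms of the same concept on separate lines
with open("my_disease_dictionary.txt", "w", encoding="utf-8") as f:
    f.write("MESH:D005600||fragile x syndrome\n")
    f.write("MESH:D005600||fragile x mental retardation syndrome\n")
    f.write("MESH:D001523||behavior disorders\n")

dictionary = DictionaryDataset("my_disease_dictionary.txt")
retriever = ExactStringMatchingRetrieverModel(dictionary)

# Exact string matching either returns the identically named concept or nothing;
# dictionary names are lowercased on loading, so the query is given in lowercase
name, concept_id, score = retriever.get_top_k("fragile x syndrome", top_k=1)[0]
print(name, concept_id, score)  # fragile x syndrome MESH:D005600 1.0
```

For fuzzy matching with the transformer-based bi-encoder, the same file can instead be passed via `dictionary_name_or_path` to `BiomedicalEntityLinker.load`, as shown above.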
From 51fe95135dd7fffbbf0e2b1f7fc79be32c51fa1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Thu, 23 Mar 2023 15:45:46 +0100 Subject: [PATCH 05/58] Introduce separate methods for pre-processing (1) entity mentions from text and (2) entity / concept names from an knowledge base or ontology --- flair/models/biomedical_entity_linking.py | 108 ++++++++++++---------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index c08cb0fde..41bd8eb3c 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from collections import defaultdict -from flair.data import Sentence, EntityLinkingLabel, DataPoint, Label +from flair.data import Sentence, EntityLinkingLabel, Label from flair.datasets import ( NEL_CTD_CHEMICAL_DICT, NEL_CTD_DISEASE_DICT, @@ -32,26 +32,36 @@ log = logging.getLogger("flair") -class MentionPreprocessor: +class EntityPreprocessor: """ - A mention preprocessor is used to transform / clean an entity mention (recognized by - an entity recognition model in the original text). This can include removing certain characters + A entity pre-processor is used to transform / clean an entity mention (recognized by + an entity recognition model in the original text). This may include removing certain characters (e.g. punctuation) or converting characters (e.g. HTML-encoded characters) as well as (more sophisticated) domain-specific procedures. This class provides the basic interface for such transformations and should be extended by - subclasses that implement the concrete transformations. + subclasses that implement concrete transformations. """ - def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> str: + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """ Processes the given entity mention and applies the transformation procedure to it. - :param entity_mention: entity mention either given as DataPoint or str + :param entity_mention: entity mention under investigation :param sentence: sentence in which the entity mentioned occurred :result: Cleaned / transformed string representation of the given entity mention """ raise NotImplementedError() + def process_entry(self, entity_name: str) -> str: + """ + Processes the given entity name (originating from a knowledge base / ontology) and + applies the transformation procedure to it. + + :param entity_name: entity mention given as DataPoint + :result: Cleaned / transformed string representation of the given entity mention + """ + raise NotImplementedError() + def initialize(self, sentences: List[Sentence]) -> None: """ Initializes the pre-processor for a batch of sentences, which is may be necessary for @@ -63,7 +73,7 @@ def initialize(self, sentences: List[Sentence]) -> None: pass -class BasicMentionPreprocessor(MentionPreprocessor): +class BasicEntityPreprocessor(EntityPreprocessor): """ Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction and removing of punctuation characters. 
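To make the interface split introduced by this commit concrete, here is a minimal, illustrative usage sketch (not taken from the patch itself) of `BasicEntityPreprocessor`, assuming the behaviour defined in the surrounding hunks (lowercasing plus punctuation removal); the example string is invented:

```python
from flair.models.biomedical_entity_linking import BasicEntityPreprocessor

preprocessor = BasicEntityPreprocessor()

# process_entry() cleans a concept name coming from the knowledge base / ontology
print(preprocessor.process_entry("Fragile X Syndrome (FXS)"))
# -> fragile x syndrome fxs

# process_mention() applies the same cleaning to an entity mention found in text;
# it receives the mention's Label plus the sentence it occurred in, so that
# context-dependent pre-processors (e.g. Ab3P abbreviation expansion) can use both.
```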
@@ -93,22 +103,21 @@ def __init__( r"[\s{}]+".format(re.escape(punctuation_symbols)) ) - def process(self, entity_mention: Union[DataPoint, str], sentence: Sentence) -> str: - mention_text = entity_mention if isinstance(entity_mention, str) else entity_mention.text - + def process_entry(self, entity_name: str) -> str: if self.lowercase: - mention_text = mention_text.lower() + entity_name = entity_name.lower() if self.remove_punctuation: - mention_text = self.rmv_puncts_regex.split(mention_text) - mention_text = " ".join(mention_text).strip() + entity_name = self.rmv_puncts_regex.split(entity_name) + entity_name = " ".join(entity_name).strip() - mention_text = mention_text.strip() + return entity_name.strip() - return mention_text + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: + return self.process_entry(entity_mention.data_point.text) -class Ab3PMentionPreprocessor(MentionPreprocessor): +class Ab3PEntityPreprocessor(EntityPreprocessor): """ Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, given in: @@ -126,37 +135,31 @@ def __init__( self, ab3p_path: Path, word_data_dir: Path, - mention_preprocessor: Optional[MentionPreprocessor] = None + preprocessor: Optional[EntityPreprocessor] = None ) -> None: """ Creates the mention pre-processor :param ab3p_path: Path to the folder containing the Ab3P implementation :param word_data_dir: Path to the word data directory - :param mention_preprocessor: Mention text preprocessor that is used before trying to link + :param preprocessor: Entity mention text preprocessor that is used before trying to link the mention text to an abbreviation. - """ self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir - self.mention_preprocessor = mention_preprocessor + self.preprocessor = preprocessor def initialize(self, sentences: List[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) - def process(self, entity_mention: Union[Label, str], sentence: Sentence) -> str: + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: sentence_text = sentence.to_tokenized_string().strip() - - tokens = ( - [token.text for token in entity_mention.data_point.tokens] - if isinstance(entity_mention, Label) - else [entity_mention] # FIXME: Maybe split mention on whitespaces here?? 
- ) + tokens = [token.text for token in entity_mention.data_point.tokens] parsed_tokens = [] for token in tokens: - if self.mention_preprocessor is not None: - token = self.mention_preprocessor.process(token, sentence) + if self.preprocessor is not None: + token = self.preprocessor.process_entry(token) if sentence_text in self.abbreviation_dict: if token.lower() in self.abbreviation_dict[sentence_text]: @@ -168,11 +171,19 @@ def process(self, entity_mention: Union[Label, str], sentence: Sentence) -> str: return " ".join(parsed_tokens) + def process_entry(self, entity_name: str) -> str: + # Ab3P works on sentence-level and not on a single entity mention / name + # - so we just apply the wrapped text pre-processing here (if configured) + if self.preprocessor is not None: + return self.preprocessor.process_entry(entity_name) + + return entity_name + @classmethod def load( cls, ab3p_path: Path = None, - preprocessor: Optional[MentionPreprocessor] = None + preprocessor: Optional[EntityPreprocessor] = None ): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): @@ -457,7 +468,7 @@ def get_top_k( ) -> List[Tuple[str, str, float]]: """ Returns the top-k entity / concept identifiers for the given entity mention. Note that - the model either return the entity with an identical name in the knowledge base / dictionary + the model either returns the entity with an identical name in the knowledge base / dictionary or none. :param entity_mention: Entity mention under investigation @@ -491,7 +502,8 @@ def __init__( batch_size: int, index_use_cuda: bool, top_k_extra_dense: int = 10, - top_k_extra_sparse: int = 10 + top_k_extra_sparse: int = 10, + preprocessor: Optional[EntityPreprocessor] = BasicEntityPreprocessor() ) -> None: """ Initializes the BiEncoderEntityRetrieverModel. @@ -507,6 +519,8 @@ def __init__( retrieved while combining sparse and dense scores :param top_k_extra_dense: Number of extra entities (resp. their dense embeddings) which should be retrieved while combining sparse and dense scores + :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from + the knowledge base """ self.use_sparse_embeds = use_sparse_embeddings self.use_cosine = use_cosine @@ -515,6 +529,7 @@ def __init__( self.top_k_extra_dense = top_k_extra_dense self.top_k_extra_sparse = top_k_extra_sparse self.index_use_cuda = index_use_cuda and flair.device.type == "cuda" + self.preprocessor = preprocessor # Load dense encoder self.dense_encoder = TransformerDocumentEmbeddings( @@ -660,20 +675,19 @@ def _embed_dictionary( self._load_cached_dense_emb_dictionary(emb_dictionary_cache_file) else: - # get names from dictionary and remove punctuation - # FIXME: Why doing this here???? 
- punctuation_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) - - dictionary_names = [] + # get all concept names from the dictionary + concept_names = [] for row in self.dictionary: - name = punctuation_regex.split(row[0]) - name = " ".join(name).strip().lower() - dictionary_names.append(name) - dictionary_names = np.array(dictionary_names) + concept_name = row[0] + if self.preprocessor is not None: + concept_name = self.preprocessor.process_entry(concept_name) + concept_names.append(concept_name) + + concept_names = np.array(concept_names) # Compute dense embeddings (if necessary) self.dict_dense_embeddings = self._embed_dense( - names=dictionary_names, + names=concept_names, batch_size=batch_size, show_progress=True ) @@ -684,7 +698,7 @@ def _embed_dictionary( # Compute sparse embeddings (if necessary) if self.use_sparse_embeds: - self.dict_sparse_embeddings = self._embed_sparse(entity_names=dictionary_names) + self.dict_sparse_embeddings = self._embed_sparse(entity_names=concept_names) else: self.dict_sparse_embeddings = None @@ -868,7 +882,7 @@ class BiomedicalEntityLinker: def __init__( self, retriever_model: EntityRetrieverModel, - mention_preprocessor: MentionPreprocessor + mention_preprocessor: EntityPreprocessor ): self.preprocessor = mention_preprocessor self.retriever_model = retriever_model @@ -908,7 +922,7 @@ def predict( for entity in sentence.get_labels(input_entity_annotation_layer): # Pre-process entity mention (if necessary) mention_text = ( - self.preprocessor.process(entity, sentence) + self.preprocessor.process_mention(entity, sentence) if self.preprocessor is not None else entity.data_point.text ) @@ -958,7 +972,7 @@ def load( batch_size: int = 1024, index_use_cuda: bool = False, use_cosine: bool = True, - preprocessor: MentionPreprocessor = Ab3PMentionPreprocessor.load(preprocessor=BasicMentionPreprocessor()) + preprocessor: EntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=BasicEntityPreprocessor()) ): """ Loads a model for biomedical named entity normalization. From 19f74fb9eafe9d058a0be8e8037b3ed41b217784 Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Fri, 21 Apr 2023 21:04:16 +0200 Subject: [PATCH 06/58] Fix formatting --- flair/data.py | 15 +- flair/datasets/__init__.py | 4 +- flair/models/biomedical_entity_linking.py | 243 +++++++--------------- 3 files changed, 91 insertions(+), 171 deletions(-) diff --git a/flair/data.py b/flair/data.py index d9067c4e5..1b4176423 100644 --- a/flair/data.py +++ b/flair/data.py @@ -261,6 +261,7 @@ def labeled_identifier(self): def unlabeled_identifier(self): return f"{self.data_point.unlabeled_identifier}" + class DataPoint: """This is the parent class of all data points in Flair. 
@@ -474,16 +475,20 @@ def spawn(self, value: str, score: float = 1.0): score=score, concept_name=self.concept_name, additional_ids=self.additional_ids, - database=self.database + database=self.database, ) def __str__(self): - return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ - f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + return ( + f"{self.data_point.unlabeled_identifier}{flair._arrow} " + f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + ) def __repr__(self): - return f"{self.data_point.unlabeled_identifier}{flair._arrow} " \ - f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + return ( + f"{self.data_point.unlabeled_identifier}{flair._arrow} " + f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + ) def __len__(self): return len(self.data_point) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index ed0037d70..c203b802b 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -90,10 +90,10 @@ LOCTEXT, MIRNA, NCBI_DISEASE, - NEL_NCBI_HUMAN_GENE_DICT, - NEL_NCBI_TAXONOMY_DICT, NEL_CTD_CHEMICAL_DICT, NEL_CTD_DISEASE_DICT, + NEL_NCBI_HUMAN_GENE_DICT, + NEL_NCBI_TAXONOMY_DICT, OSIRIS, PDR, S800, diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 41bd8eb3c..dbdbd0858 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -1,19 +1,26 @@ -import flair -import faiss import logging -import numpy as np import os import pickle import re -import subprocess import stat import string +import subprocess import tempfile -import torch - from abc import ABC, abstractmethod from collections import defaultdict -from flair.data import Sentence, EntityLinkingLabel, Label +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import faiss +import numpy as np +import torch +from huggingface_hub import cached_download, hf_hub_url +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm + +import flair +from flair.data import EntityLinkingLabel, Label, Sentence from flair.datasets import ( NEL_CTD_CHEMICAL_DICT, NEL_CTD_DISEASE_DICT, @@ -22,12 +29,6 @@ ) from flair.embeddings import TransformerDocumentEmbeddings from flair.file_utils import cached_path -from huggingface_hub import hf_hub_url, cached_download -from pathlib import Path -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity -from tqdm import tqdm -from typing import List, Tuple, Union, Optional, Dict, Iterable log = logging.getLogger("flair") @@ -42,6 +43,7 @@ class EntityPreprocessor: This class provides the basic interface for such transformations and should be extended by subclasses that implement concrete transformations. """ + def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """ Processes the given entity mention and applies the transformation procedure to it. @@ -84,10 +86,7 @@ class BasicEntityPreprocessor(EntityPreprocessor): """ def __init__( - self, - lowercase: bool = True, - remove_punctuation: bool = True, - punctuation_symbols: str = string.punctuation + self, lowercase: bool = True, remove_punctuation: bool = True, punctuation_symbols: str = string.punctuation ) -> None: """ Initializes the mention preprocessor. 
@@ -99,9 +98,7 @@ def __init__( """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation - self.rmv_puncts_regex = re.compile( - r"[\s{}]+".format(re.escape(punctuation_symbols)) - ) + self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(punctuation_symbols))) def process_entry(self, entity_name: str) -> str: if self.lowercase: @@ -131,12 +128,7 @@ class Ab3PEntityPreprocessor(EntityPreprocessor): PubMed ID: 18817555 """ - def __init__( - self, - ab3p_path: Path, - word_data_dir: Path, - preprocessor: Optional[EntityPreprocessor] = None - ) -> None: + def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[EntityPreprocessor] = None) -> None: """ Creates the mention pre-processor @@ -180,11 +172,7 @@ def process_entry(self, entity_name: str) -> str: return entity_name @classmethod - def load( - cls, - ab3p_path: Path = None, - preprocessor: Optional[EntityPreprocessor] = None - ): + def load(cls, ab3p_path: Path = None, preprocessor: Optional[EntityPreprocessor] = None): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -201,7 +189,7 @@ def load( @classmethod def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: """ - Downloads the Ab3P tool and all necessary data files. + Downloads the Ab3P tool and all necessary data files. """ # Download word data for Ab3P if not already downloaded @@ -230,9 +218,7 @@ def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: cached_path(ab3p_url + file, word_data_dir) # Download Ab3P executable - ab3p_path = cached_path( - "https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir - ) + ab3p_path = cached_path("https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir) ab3p_path.chmod(ab3p_path.stat().st_mode | stat.S_IXUSR) return ab3p_path @@ -360,11 +346,7 @@ class DictionaryDataset: https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/data_loader.py#L89 """ - def __init__( - self, - dictionary_path: Union[Path, str], - load_into_memory: bool = True - ) -> None: + def __init__(self, dictionary_path: Union[Path, str], load_into_memory: bool = True) -> None: """ :param dictionary_path str: Path to the dictionary file :param load_into_memory bool: Indicates whether the dictionary entries should be loaded in @@ -428,11 +410,7 @@ class EntityRetrieverModel(ABC): """ @abstractmethod - def get_top_k( - self, - entity_mention: str, - top_k: int - ) -> List[Tuple[str, str, float]]: + def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: """ Returns the top-k entity / concept identifiers for the given entity mention. @@ -449,6 +427,7 @@ class ExactStringMatchingRetrieverModel(EntityRetrieverModel): Implementation of an entity retriever model which uses exact string matching to find the entity / concept identifier for a given entity mention. """ + def __init__(self, dictionary: DictionaryDataset): # Build index which maps concept / entity names to concept / entity ids self.name_to_id_index = {name: cui for name, cui in dictionary.data} @@ -461,11 +440,7 @@ def load_model( # Load dictionary return cls(DictionaryDataset.load(dictionary_name_or_path)) - def get_top_k( - self, - entity_mention: str, - top_k: int - ) -> List[Tuple[str, str, float]]: + def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: """ Returns the top-k entity / concept identifiers for the given entity mention. 
Note that the model either returns the entity with an identical name in the knowledge base / dictionary @@ -503,24 +478,24 @@ def __init__( index_use_cuda: bool, top_k_extra_dense: int = 10, top_k_extra_sparse: int = 10, - preprocessor: Optional[EntityPreprocessor] = BasicEntityPreprocessor() + preprocessor: Optional[EntityPreprocessor] = BasicEntityPreprocessor(), ) -> None: """ - Initializes the BiEncoderEntityRetrieverModel. - - :param model_name_or_path: Name of or path to the transformer model to be used. - :param dictionary_name_or_path: Name of or path to the transformer model to be used. - :param use_sparse_embeddings: Indicates whether to use sparse embeddings or not - :param use_cosine: Indicates whether to use cosine similarity (instead of inner product) - :param max_length: Maximal number of tokens used for embedding an entity mention / concept name - :param batch_size: Batch size used during embedding of the dictionary and top-k prediction - :param index_use_cuda: Indicates whether to use CUDA while indexing the dictionary / knowledge base - :param top_k_extra_sparse: Number of extra entities (resp. their sparse embeddings) which should be - retrieved while combining sparse and dense scores - :param top_k_extra_dense: Number of extra entities (resp. their dense embeddings) which should be - retrieved while combining sparse and dense scores - :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from - the knowledge base + Initializes the BiEncoderEntityRetrieverModel. + + :param model_name_or_path: Name of or path to the transformer model to be used. + :param dictionary_name_or_path: Name of or path to the transformer model to be used. + :param use_sparse_embeddings: Indicates whether to use sparse embeddings or not + :param use_cosine: Indicates whether to use cosine similarity (instead of inner product) + :param max_length: Maximal number of tokens used for embedding an entity mention / concept name + :param batch_size: Batch size used during embedding of the dictionary and top-k prediction + :param index_use_cuda: Indicates whether to use CUDA while indexing the dictionary / knowledge base + :param top_k_extra_sparse: Number of extra entities (resp. their sparse embeddings) which should be + retrieved while combining sparse and dense scores + :param top_k_extra_dense: Number of extra entities (resp. their dense embeddings) which should be + retrieved while combining sparse and dense scores + :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from + the knowledge base """ self.use_sparse_embeds = use_sparse_embeddings self.use_cosine = use_cosine @@ -532,14 +507,11 @@ def __init__( self.preprocessor = preprocessor # Load dense encoder - self.dense_encoder = TransformerDocumentEmbeddings( - model=model_name_or_path, - is_token_embedding=False - ) + self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) # Load sparse encoder if self.use_sparse_embeds: - #FIXME: What happens if sparse encoder isn't pre-trained??? + # FIXME: What happens if sparse encoder isn't pre-trained??? 
self._load_sparse_encoder(model_name_or_path) self._load_sparse_weight(model_name_or_path) @@ -549,16 +521,12 @@ def __init__( batch_size=batch_size, ) - def _load_sparse_encoder( - self, model_name_or_path: Union[str, Path] - ) -> BigramTfIDFVectorizer: + def _load_sparse_encoder(self, model_name_or_path: Union[str, Path]) -> BigramTfIDFVectorizer: sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") # check file exists if not os.path.isfile(sparse_encoder_path): # download from huggingface hub and cache it - sparse_encoder_url = hf_hub_url( - model_name_or_path, filename="sparse_encoder.pk" - ) + sparse_encoder_url = hf_hub_url(model_name_or_path, filename="sparse_encoder.pk") sparse_encoder_path = cached_download( url=sparse_encoder_url, cache_dir=flair.cache_root / "models" / model_name_or_path, @@ -573,9 +541,7 @@ def _load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> float: # check file exists if not os.path.isfile(sparse_weight_path): # download from huggingface hub and cache it - sparse_weight_url = hf_hub_url( - model_name_or_path, filename="sparse_weight.pt" - ) + sparse_weight_url = hf_hub_url(model_name_or_path, filename="sparse_weight.pt") sparse_weight_path = cached_download( url=sparse_weight_url, cache_dir=flair.cache_root / "models" / model_name_or_path, @@ -601,12 +567,7 @@ def _embed_sparse(self, entity_names: np.ndarray) -> np.ndarray: return sparse_embeds - def _embed_dense( - self, - names: np.ndarray, - batch_size: int = 2048, - show_progress: bool = False - ) -> np.ndarray: + def _embed_dense(self, names: np.ndarray, batch_size: int = 2048, show_progress: bool = False) -> np.ndarray: """ Embeds the given numpy array of entity / concept names, either originating from the knowledge base or recognized in a text, into dense representations using a @@ -638,9 +599,7 @@ def _embed_dense( # embed batch self.dense_encoder.embed(batch) - dense_embeds += [ - name.embedding.cpu().detach().numpy() for name in batch - ] + dense_embeds += [name.embedding.cpu().detach().numpy() for name in batch] if flair.device.type == "cuda": torch.cuda.empty_cache() @@ -651,14 +610,9 @@ def _embed_dense( return dense_embeds - def _embed_dictionary( - self, - model_name_or_path: str, - dictionary_name_or_path: str, - batch_size: int - ): + def _embed_dictionary(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): """ - Computes the embeddings for the given knowledge base / dictionary. + Computes the embeddings for the given knowledge base / dictionary. """ # Load dictionary self.dictionary = DictionaryDataset.load(dictionary_name_or_path).data @@ -687,9 +641,7 @@ def _embed_dictionary( # Compute dense embeddings (if necessary) self.dict_dense_embeddings = self._embed_dense( - names=concept_names, - batch_size=batch_size, - show_progress=True + names=concept_names, batch_size=batch_size, show_progress=True ) # To use cosine similarity, we normalize the vectors and then use inner product @@ -729,7 +681,7 @@ def _embed_dictionary( def _load_cached_dense_emb_dictionary(self, cached_dictionary_path: Path): """ - Loads pre-computed dense dictionary embedding from disk. + Loads pre-computed dense dictionary embedding from disk. 
""" with cached_dictionary_path.open("rb") as cached_file: log.info("Loaded dictionary from cached file {}".format(cached_dictionary_path)) @@ -738,19 +690,21 @@ def _load_cached_dense_emb_dictionary(self, cached_dictionary_path: Path): self.dictionary, self.dict_sparse_embeddings, self.dense_dictionary_index = ( cached_dictionary["dictionary"], cached_dictionary["sparse_dictionary_embeds"], - cached_dictionary["dense_dictionary_index"] + cached_dictionary["dense_dictionary_index"], ) if self.index_use_cuda: - self.dense_dictionary_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, self.dense_dictionary_index) + self.dense_dictionary_index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), 0, self.dense_dictionary_index + ) def retrieve_sparse_topk_candidates( - self, - mention_embeddings: np.ndarray, - dict_concept_embeddings: np.ndarray, - top_k: int, - normalise: bool = False, - ) -> Tuple[np.ndarray, np.ndarray]: + self, + mention_embeddings: np.ndarray, + dict_concept_embeddings: np.ndarray, + top_k: int, + normalise: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: """ Returns top-k indexes (in descending order) for the given entity mentions resp. mention embeddings. @@ -766,14 +720,10 @@ def retrieve_sparse_topk_candidates( score_matrix = np.matmul(mention_embeddings, dict_concept_embeddings.T) if normalise: - score_matrix = (score_matrix - score_matrix.min()) / ( - score_matrix.max() - score_matrix.min() - ) + score_matrix = (score_matrix - score_matrix.min()) / (score_matrix.max() - score_matrix.min()) def indexing_2d(arr, cols): - rows = np.repeat( - np.arange(0, cols.shape[0])[:, np.newaxis], cols.shape[1], axis=1 - ) + rows = np.repeat(np.arange(0, cols.shape[0])[:, np.newaxis], cols.shape[1], axis=1) return arr[rows, cols] # Get topk indexes without sorting @@ -787,11 +737,7 @@ def indexing_2d(arr, cols): return (topk_idxs, topk_scores) - def get_top_k( - self, - entity_mention: str, - top_k: int - ) -> List[Tuple[str, str, float]]: + def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: """ Returns the top-k entities for a given entity mention. 
@@ -802,19 +748,13 @@ def get_top_k( """ # Compute dense embedding for the given entity mention - mention_dense_embeds = self._embed_dense( - names=np.array([entity_mention]), - batch_size=self.batch_size - ) + mention_dense_embeds = self._embed_dense(names=np.array([entity_mention]), batch_size=self.batch_size) # Search for more than top-k candidates if combining them with sparse scores top_k_dense = top_k if not self.use_sparse_embeds else top_k + self.top_k_extra_dense # Get candidates from dense embeddings - dense_scores, dense_ids = self.dense_dictionary_index.search( - x=mention_dense_embeds, - k=top_k_dense - ) + dense_scores, dense_ids = self.dense_dictionary_index.search(x=mention_dense_embeds, k=top_k_dense) # If using sparse embeds: calculate hybrid scores with dense and sparse embeds if self.use_sparse_embeds: @@ -825,7 +765,7 @@ def get_top_k( sparse_ids, sparse_distances = self.retrieve_sparse_topk_candidates( mention_embeddings=mention_sparse_embeds, dict_concept_embeddings=self.dict_sparse_embeddings, - top_k=top_k + self.top_k_extra_sparse + top_k=top_k + self.top_k_extra_sparse, ) # Combine dense and sparse scores @@ -843,19 +783,13 @@ def get_top_k( ids = top_dense_ids distances = top_dense_scores - for sparse_id, sparse_distance in zip( - top_sparse_ids, top_sparse_distances - ): + for sparse_id, sparse_distance in zip(top_sparse_ids, top_sparse_distances): if sparse_id not in ids: ids = np.append(ids, sparse_id) - distances = np.append( - distances, sparse_weight * sparse_distance - ) + distances = np.append(distances, sparse_weight * sparse_distance) else: index = np.where(ids == sparse_id)[0][0] - distances[index] = ( - sparse_weight * sparse_distance - ) + distances[index] + distances[index] = (sparse_weight * sparse_distance) + distances[index] sorted_indizes = np.argsort(-distances) ids = ids[sorted_indizes][:top_k] @@ -879,11 +813,8 @@ class BiomedicalEntityLinker: Entity linking model which expects text/sentences with annotated entity mentions and predicts entity / concept to these mentions according to a knowledge base / dictionary. """ - def __init__( - self, - retriever_model: EntityRetrieverModel, - mention_preprocessor: EntityPreprocessor - ): + + def __init__(self, retriever_model: EntityRetrieverModel, mention_preprocessor: EntityPreprocessor): self.preprocessor = mention_preprocessor self.retriever_model = retriever_model @@ -910,11 +841,7 @@ def predict( self.preprocessor.initialize(sentences) # Build label name - label_name = ( - input_entity_annotation_layer + "_nen" - if (input_entity_annotation_layer is not None) - else "nen" - ) + label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" # For every sentence .. for sentence in sentences: @@ -972,7 +899,7 @@ def load( batch_size: int = 1024, index_use_cuda: bool = False, use_cosine: bool = True, - preprocessor: EntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=BasicEntityPreprocessor()) + preprocessor: EntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=BasicEntityPreprocessor()), ): """ Loads a model for biomedical named entity normalization. 
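As a hedged usage sketch of the label-name logic above: with the default call (no `input_entity_annotation_layer`), the linking results should end up in a separate "nen" annotation layer, so they can be read independently of the NER tags (sentence, model names and tokenizer follow the tutorial; only the layer handling is new here):

```python
from flair.data import Sentence
from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
from flair.nn import Classifier
from flair.tokenization import SciSpacyTokenizer

sentence = Sentence(
    "Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome",
    use_tokenizer=SciSpacyTokenizer(),
)

ner_tagger = Classifier.load("hunflair-disease")
ner_tagger.predict(sentence)

nen_tagger = BiomedicalEntityLinker.load("disease")
nen_tagger.predict(sentence)

# linking labels only; passing input_entity_annotation_layer="xyz" to predict()
# would store them under "xyz_nen" instead
for label in sentence.get_labels("nen"):
    print(label)
```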
@@ -1014,17 +941,10 @@ def load( index_use_cuda=index_use_cuda, ) - return cls( - retriever_model=retriever_model, - mention_preprocessor=preprocessor - ) + return cls(retriever_model=retriever_model, mention_preprocessor=preprocessor) @staticmethod - def __get_model_path( - model_name: str, - use_sparse_and_dense_embeds: bool - ) -> str: - + def __get_model_path(model_name: str, use_sparse_and_dense_embeds: bool) -> str: model_name = model_name.lower() model_path = model_name @@ -1059,16 +979,11 @@ def __get_model_path( elif model_name == "chemical": model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" elif model_name == "gene": - raise ValueError( - "No trained model for gene entity linking using only dense embeddings." - ) + raise ValueError("No trained model for gene entity linking using only dense embeddings.") return model_path @staticmethod - def __get_dictionary_path( - dictionary_path: str, - model_name: str - ) -> str: + def __get_dictionary_path(dictionary_path: str, model_name: str) -> str: # determine dictionary to use if dictionary_path == "disease": dictionary_path = "ctd-disease" From e5342d554e8d8859018ea1caa01de4fd4523f1e1 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Wed, 26 Apr 2023 12:49:40 +0200 Subject: [PATCH 07/58] feat(test): biomedical entity linking --- tests/test_biomedical_entity_linking.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_biomedical_entity_linking.py diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py new file mode 100644 index 000000000..60d25b187 --- /dev/null +++ b/tests/test_biomedical_entity_linking.py @@ -0,0 +1,20 @@ +from flair.data import Sentence +from flair.models.biomedical_entity_linking import BiomedicalEntityLinker +from flair.nn import Classifier +from flair.tokenization import SciSpacyTokenizer + + +def test_biomedical_entity_linking(): + sentence = Sentence( + "Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer() + ) + ner_tagger = Classifier.load("hunflair-disease") + ner_tagger.predict(sentence) + nen_tagger = BiomedicalEntityLinker.load("disease") + nen_tagger.predict(sentence) + for tag in sentence.get_labels(): + print(tag) + + +if __name__ == "__main__": + test_biomedical_entity_linking() From 4ef59242f2da9f422e545a9af6d0e53aad44dfd5 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Wed, 26 Apr 2023 18:54:30 +0200 Subject: [PATCH 08/58] fix(test): hold on w/ automatic tests for now --- tests/test_biomedical_entity_linking.py | 80 ++++++++++++++++++------- 1 file changed, 60 insertions(+), 20 deletions(-) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 60d25b187..d25f76ac6 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,20 +1,60 @@ -from flair.data import Sentence -from flair.models.biomedical_entity_linking import BiomedicalEntityLinker -from flair.nn import Classifier -from flair.tokenization import SciSpacyTokenizer - - -def test_biomedical_entity_linking(): - sentence = Sentence( - "Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer() - ) - ner_tagger = Classifier.load("hunflair-disease") - ner_tagger.predict(sentence) - nen_tagger = BiomedicalEntityLinker.load("disease") - nen_tagger.predict(sentence) - for tag in sentence.get_labels(): - print(tag) - - -if __name__ == "__main__": - 
test_biomedical_entity_linking() +# from flair.data import Sentence +# from flair.models.biomedical_entity_linking import BioNelDictionary +# from flair.nn import Classifier +# from flair.tokenization import SciSpacyTokenizer + +# def test_bionel_dictionary(): +# """ +# Check data in dictionary is what we expect. +# Hard to define a good test as dictionaries are DYNAMIC, +# i.e. they can change over time +# """ + +# dictionary = BioNelDictionary.load("disease") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) + +# dictionary = BioNelDictionary.load("ctd-disease") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BioNelDictionary.load("ctd-chemical") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BioNelDictionary.load("chemical") +# _, identifier = next(dictionary.stream()) +# assert identifier.startswith("MESH:") + +# dictionary = BioNelDictionary.load("ncbi-taxonomy") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BioNelDictionary.load("species") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BioNelDictionary.load("ncbi-gene") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + +# dictionary = BioNelDictionary.load("gene") +# _, identifier = next(dictionary.stream()) +# assert identifier.isdigit() + + +# def test_biomedical_entity_linking(): +# sentence = Sentence( +# "Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer() +# ) +# ner_tagger = Classifier.load("hunflair-disease") +# ner_tagger.predict(sentence) +# nen_tagger = BiomedicalEntityLinker.load("disease") +# nen_tagger.predict(sentence) +# for tag in sentence.get_labels(): +# print(tag) + + +# if __name__ == "__main__": +# test_bionel_dictionary() +# test_biomedical_entity_linking() From abc42b5e887b1ff7caade3111bf89e1609bdfb7c Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Wed, 26 Apr 2023 19:04:17 +0200 Subject: [PATCH 09/58] fix(bionel): start major refactoring - improve name consistency - make code more pythonic - dictionaries always do lazy loading - consistency in dictionary parsing: always yield (cui,name) - clean up loading w/ CONSTANTS (easily swap models) - allow access to sparse and dense search --- flair/models/biomedical_entity_linking.py | 777 +++++++++++----------- 1 file changed, 405 insertions(+), 372 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index dbdbd0858..9726ecb14 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -8,8 +8,9 @@ import tempfile from abc import ABC, abstractmethod from collections import defaultdict +from enum import Enum, auto from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union import faiss import numpy as np @@ -27,13 +28,88 @@ NEL_NCBI_HUMAN_GENE_DICT, NEL_NCBI_TAXONOMY_DICT, ) +from flair.datasets.biomedical import PreprocessedBioNelDictionary from flair.embeddings import TransformerDocumentEmbeddings from flair.file_utils import cached_path -log = logging.getLogger("flair") +logger = logging.getLogger("flair") + +BIOMEDICAL_NEL_DICTIONARIES = { + "ctd-disease": NEL_CTD_DISEASE_DICT, + "ctd-chemical": NEL_CTD_CHEMICAL_DICT, 
+ "ncbi-gene": NEL_NCBI_HUMAN_GENE_DICT, + "ncbi-taxonomy": NEL_NCBI_TAXONOMY_DICT, +} + +PRETRAINED_MODELS = [ + "cambridgeltl/SapBERT-from-PubMedBERT-fulltext", +] + +# Dense + sparse retrieval +PRETRAINED_HYBRID_MODELS = [ + "biosyn-sapbert-bc5cdr-disease", + "biosyn-sapbert-ncbi-disease", + "biosyn-sapbert-bc5cdr-chemical", + "biosyn-biobert-bc5cdr-disease", + "biosyn-biobert-ncbi-disease", + "biosyn-biobert-bc5cdr-chemical", + "biosyn-biobert-bc2gn", + "biosyn-sapbert-bc2gn", +] + +PRETRAINED_MODELS = PRETRAINED_HYBRID_MODELS + PRETRAINED_MODELS + +# just in case we add: fuzzy search, Levenstein, ... +STRING_MATCHING_MODELS = ["exact-string-match"] + +MODELS = PRETRAINED_MODELS + STRING_MATCHING_MODELS + +ENTITY_TYPES = ["disease", "chemical", "gene", "species"] + +ENTITY_TYPE_TO_HYBRID_MODEL = { + "disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease", + "chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical", + "gene": "dmis-lab/biosyn-sapbert-bc2gn", +} + +# for now we always fall back to SapBERT, +# but we should train our own models at some point +ENTITY_TYPE_TO_DENSE_MODEL = { + entity_type: "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" for entity_type in ENTITY_TYPES +} + +DEFAULT_SPARSE_WEIGHT = 0.5 + +ENTITY_TYPE_TO_NEL_DICTIONARY = { + "gene": "ncbi-gene", + "species": "ncbi-taxonomy", + "disease": "ctd-disease", + "chemical": "ctd-chemical", +} + +MODEL_NAME_TO_NEL_DICTIONARY = { + "biosyn-sapbert-bc5cdr-disease": "ctd-disease", + "biosyn-sapbert-ncbi-disease": "ctd-disease", + "biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", + "biosyn-biobert-bc5cdr-disease": "ctd-chemical", + "biosyn-biobert-ncbi-disease": "ctd-disease", + "biosyn-biobert-bc5cdr-chemical": "ctd-chemical", + "biosyn-biobert-bc2gn": "ncbi-gene", + "biosyn-sapbert-bc2gn": "ncbi-gene", +} + + +class SimilarityMetric(Enum): + """ + Available similarity metrics + """ + INNER_PRODUCT = faiss.METRIC_INNER_PRODUCT + # L2 = faiss.METRIC_L2 + COSINE = auto() -class EntityPreprocessor: + +class BioNelPreprocessor(ABC): """ A entity pre-processor is used to transform / clean an entity mention (recognized by an entity recognition model in the original text). This may include removing certain characters @@ -44,6 +120,7 @@ class EntityPreprocessor: subclasses that implement concrete transformations. """ + @abstractmethod def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """ Processes the given entity mention and applies the transformation procedure to it. @@ -52,8 +129,8 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: :param sentence: sentence in which the entity mentioned occurred :result: Cleaned / transformed string representation of the given entity mention """ - raise NotImplementedError() + @abstractmethod def process_entry(self, entity_name: str) -> str: """ Processes the given entity name (originating from a knowledge base / ontology) and @@ -64,18 +141,17 @@ def process_entry(self, entity_name: str) -> str: """ raise NotImplementedError() - def initialize(self, sentences: List[Sentence]) -> None: + @abstractmethod + def initialize(self, sentences: List[Sentence]): """ Initializes the pre-processor for a batch of sentences, which is may be necessary for more sophisticated transformations. :param sentences: List of sentences that will be processed. 
""" - # Do nothing by default - pass -class BasicEntityPreprocessor(EntityPreprocessor): +class BasicBioNelPreprocessor(BioNelPreprocessor): """ Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction and removing of punctuation characters. @@ -100,6 +176,9 @@ def __init__( self.remove_punctuation = remove_punctuation self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(punctuation_symbols))) + def initialize(self, sentences): + pass + def process_entry(self, entity_name: str) -> str: if self.lowercase: entity_name = entity_name.lower() @@ -114,7 +193,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: return self.process_entry(entity_mention.data_point.text) -class Ab3PEntityPreprocessor(EntityPreprocessor): +class Ab3PPreprocessor(BioNelPreprocessor): """ Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, given in: @@ -128,7 +207,7 @@ class Ab3PEntityPreprocessor(EntityPreprocessor): PubMed ID: 18817555 """ - def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[EntityPreprocessor] = None) -> None: + def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[BioNelPreprocessor] = None) -> None: """ Creates the mention pre-processor @@ -140,6 +219,7 @@ def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[ self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir self.preprocessor = preprocessor + self.abbreviation_dict = {} def initialize(self, sentences: List[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) @@ -172,7 +252,7 @@ def process_entry(self, entity_name: str) -> str: return entity_name @classmethod - def load(cls, ab3p_path: Path = None, preprocessor: Optional[EntityPreprocessor] = None): + def load(cls, ab3p_path: Path = None, preprocessor: Optional[BioNelPreprocessor] = None): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -248,29 +328,31 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict path_file.write(str(self.word_data_dir) + "/\n") # Run ab3p with the temp file containing the dataset + # https://pylint.pycqa.org/en/latest/user_guide/messages/warning/subprocess-run-check.html try: result = subprocess.run( [self.ab3p_path, temp_file.name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + check=True, ) - except: - log.error( + except subprocess.CalledProcessError: + logger.error( """The abbreviation resolver Ab3P could not be run on your system. To ensure maximum accuracy, please install Ab3P yourself. See https://github.com/ncbi-nlp/Ab3P""" ) else: line = result.stdout.decode("utf-8") if "Path file for type cshset does not exist!" in line: - log.error( + logger.error( "Error when using Ab3P for abbreviation resolution. A file named path_Ab3p needs to exist in your current directory containing the path to the WordData directory for Ab3P to work!" ) elif "Cannot open" in line: - log.error( + logger.error( "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" ) elif "failed to open" in line: - log.error( + logger.error( "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" 
) @@ -310,30 +392,46 @@ class BigramTfIDFVectorizer: def __init__(self) -> None: self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) - def transform(self, mentions: List[str]) -> torch.Tensor: - vec = self.encoder.transform(mentions).toarray() + def fit(self, names: List[str]): + """ + Fit vectorizer + """ + self.encoder.fit(names) + return self + + def transform(self, names: List[str]) -> torch.Tensor: + """ + Convert string names to sparse vectors + """ + vec = self.encoder.transform(names).toarray() vec = torch.FloatTensor(vec) return vec def __call__(self, mentions: List[str]) -> torch.Tensor: + """ + Short for `transform` + """ return self.transform(mentions) def save_encoder(self, path: Path) -> None: with path.open("wb") as fout: pickle.dump(self.encoder, fout) - log.info("Sparse encoder saved in {}".format(path)) + logger.info("Sparse encoder saved in %s", path) @classmethod def load(cls, path: Path) -> "BigramTfIDFVectorizer": + """ + Instantiate from path + """ newVectorizer = cls() with open(path, "rb") as fin: newVectorizer.encoder = pickle.load(fin) - log.info("Sparse encoder loaded from {}".format(path)) + logger.info("Sparse encoder loaded from %s", path) return newVectorizer -class DictionaryDataset: +class BioNelDictionary: """ A class used to load dictionary data from a custom dictionary file. Every line in the file must be formatted as follows: @@ -346,101 +444,91 @@ class DictionaryDataset: https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/data_loader.py#L89 """ - def __init__(self, dictionary_path: Union[Path, str], load_into_memory: bool = True) -> None: + def __init__(self, reader): + self.reader = reader + + @classmethod + def load(cls, dictionary_name_or_path: Union[Path, str]) -> "BioNelDictionary": """ - :param dictionary_path str: Path to the dictionary file - :param load_into_memory bool: Indicates whether the dictionary entries should be loaded in - memory or not (Default True) + Load dictionary: either pre-definded or from path """ - log.info("Loading dictionary from {}".format(dictionary_path)) - if load_into_memory: - self.data = self.load_data(dictionary_path) - else: - self.data = self.get_data(dictionary_path) - - def load_data(self, dictionary_path: Union[Path, str]) -> np.ndarray: - data = [] - with open(dictionary_path, mode="r", encoding="utf-8") as file: - lines = file.readlines() - for line in tqdm(lines, desc="Loading dictionary"): - line = line.strip() - if line == "": - continue - cui, name = line.split("||") - name = name.lower() - data.append((name, cui)) - - data = np.array(data) - return data - - # generator version - def get_data(self, dictionary_path: Union[Path, str]) -> Iterable[Tuple]: - data = [] - with open(dictionary_path, mode="r", encoding="utf-8") as f: - lines = f.readlines() - for line in tqdm(lines, desc="Loading dictionary"): - line = line.strip() - if line == "": - continue - cui, name = line.split("||") - name = name.lower() - yield (name, cui) - @classmethod - def load(cls, dictionary_name_or_path: Union[Path, str]): if isinstance(dictionary_name_or_path, str): - # use provided dictionary - if dictionary_name_or_path == "ctd-disease": - return NEL_CTD_DISEASE_DICT() - elif dictionary_name_or_path == "ctd-chemical": - return NEL_CTD_CHEMICAL_DICT() - elif dictionary_name_or_path == "ncbi-gene": - return NEL_NCBI_HUMAN_GENE_DICT() - elif dictionary_name_or_path == "ncbi-taxonomy": - return NEL_NCBI_TAXONOMY_DICT() + if ( + dictionary_name_or_path not in ENTITY_TYPE_TO_NEL_DICTIONARY + and 
dictionary_name_or_path not in BIOMEDICAL_NEL_DICTIONARIES + ): + raise ValueError( + f"""Unkwnon dictionary `{dictionary_name_or_path}`, + Available dictionaries are: {tuple(BIOMEDICAL_NEL_DICTIONARIES)} \n + If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" + ) + + dictionary_name_or_path = ENTITY_TYPE_TO_NEL_DICTIONARY.get( + dictionary_name_or_path, dictionary_name_or_path + ) + + reader = BIOMEDICAL_NEL_DICTIONARIES[dictionary_name_or_path]() + else: # use custom dictionary file - return DictionaryDataset(dictionary_path=dictionary_name_or_path) + reader = PreprocessedBioNelDictionary(path=dictionary_name_or_path) + + return cls(reader=reader) + + def get_database_names(self) -> List[str]: + """ + List all database names covered by dictionary, e.g. MESH, OMIM + """ + return self.reader.get_database_names() -class EntityRetrieverModel(ABC): + def stream(self) -> Iterator[Tuple[str, str]]: + """ + Stream preprocessed dictionary + """ + + for entry in self.reader.stream(): + yield entry + + +class EntityRetriever(ABC): """ An entity retriever model is used to find the top-k entities / concepts of a knowledge base / dictionary for a given entity mention in text. """ @abstractmethod - def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: """ Returns the top-k entity / concept identifiers for the given entity mention. - :param entity_mention: Entity mention text under investigation + :param entity_mentions: Entity mention text under investigation :param top_k: Number of (best-matching) entities from the knowledge base to return :result: List of tuples highlighting the top-k entities. Each tuple has the following structure (entity / concept name, concept ids, score). """ - raise NotImplementedError() -class ExactStringMatchingRetrieverModel(EntityRetrieverModel): +class ExactStringMatchingRetriever(EntityRetriever): """ Implementation of an entity retriever model which uses exact string matching to find the entity / concept identifier for a given entity mention. """ - def __init__(self, dictionary: DictionaryDataset): + def __init__(self, dictionary: BioNelDictionary): # Build index which maps concept / entity names to concept / entity ids - self.name_to_id_index = {name: cui for name, cui in dictionary.data} + self.name_to_id_index = dict(dictionary.data) @classmethod - def load_model( - cls, - dictionary_name_or_path: str, - ): + def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverModel": + """ + Compatibility function + """ # Load dictionary - return cls(DictionaryDataset.load(dictionary_name_or_path)) + return cls(BioNelDictionary.load(dictionary_name_or_path)) - def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: """ Returns the top-k entity / concept identifiers for the given entity mention. Note that the model either returns the entity with an identical name in the knowledge base / dictionary @@ -451,13 +539,11 @@ def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, flo :result: List of tuples highlighting the top-k entities. Each tuple has the following structure (entity / concept name, concept ids, score). 
""" - if entity_mention in self.name_to_id_index: - return [(entity_mention, self.name_to_id_index[entity_mention], 1.0)] - else: - return [] + + return [(em, self.name_to_id_index.get(em), 1.0) for em in entity_mentions] -class BiEncoderEntityRetrieverModel(EntityRetrieverModel): +class BiEncoderEntityRetriever(EntityRetriever): """ Implementation of EntityRetrieverModel which uses dense (transformer-based) embeddings and (optionally) sparse character-based representations, for normalizing an entity mention to specific identifiers @@ -471,58 +557,59 @@ def __init__( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: str, - use_sparse_embeddings: bool, - use_cosine: bool, - max_length: int, - batch_size: int, - index_use_cuda: bool, - top_k_extra_dense: int = 10, - top_k_extra_sparse: int = 10, - preprocessor: Optional[EntityPreprocessor] = BasicEntityPreprocessor(), - ) -> None: + hybrid_search: bool = False, + max_length: int = 25, + index_batch_size: int = 1024, + preprocessor: BioNelPreprocessor = Ab3PPreprocessor.load(preprocessor=BasicBioNelPreprocessor()), + similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + sparse_weight: Optional[float] = None, + ): """ Initializes the BiEncoderEntityRetrieverModel. :param model_name_or_path: Name of or path to the transformer model to be used. :param dictionary_name_or_path: Name of or path to the transformer model to be used. - :param use_sparse_embeddings: Indicates whether to use sparse embeddings or not + :param hybrid_search: Indicates whether to use sparse embeddings or not :param use_cosine: Indicates whether to use cosine similarity (instead of inner product) :param max_length: Maximal number of tokens used for embedding an entity mention / concept name - :param batch_size: Batch size used during embedding of the dictionary and top-k prediction - :param index_use_cuda: Indicates whether to use CUDA while indexing the dictionary / knowledge base - :param top_k_extra_sparse: Number of extra entities (resp. their sparse embeddings) which should be - retrieved while combining sparse and dense scores - :param top_k_extra_dense: Number of extra entities (resp. their dense embeddings) which should be - retrieved while combining sparse and dense scores - :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from - the knowledge base + :param index_batch_size: Batch size used during embedding of the dictionary and top-k prediction + :param similarity_metric: which metric to use to compute similarity + :param sparse_weight: default sparse weight + :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from the knowledge base """ - self.use_sparse_embeds = use_sparse_embeddings - self.use_cosine = use_cosine - self.max_length = max_length - self.batch_size = batch_size - self.top_k_extra_dense = top_k_extra_dense - self.top_k_extra_sparse = top_k_extra_sparse - self.index_use_cuda = index_use_cuda and flair.device.type == "cuda" self.preprocessor = preprocessor + self.similarity_metric = similarity_metric + self.max_length = max_length + self.index_batch_size = index_batch_size + self.hybrid_search = hybrid_search + self.sparse_weight = sparse_weight # Load dense encoder self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) - # Load sparse encoder - if self.use_sparse_embeds: - # FIXME: What happens if sparse encoder isn't pre-trained??? 
- self._load_sparse_encoder(model_name_or_path) - self._load_sparse_weight(model_name_or_path) + # Load dictionary + self.dictionary = BioNelDictionary.load(dictionary_name_or_path) - self._embed_dictionary( + self.embeddings = self._load_emebddings( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path, - batch_size=batch_size, + batch_size=self.index_batch_size, ) - def _load_sparse_encoder(self, model_name_or_path: Union[str, Path]) -> BigramTfIDFVectorizer: + # Build dense embedding index using faiss + dimension = self.embeddings["dense"].shape[1] + self.dense_index = faiss.IndexFlatIP(dimension) + self.dense_index.add(self.embeddings["dense"]) + + self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None + if self.hybrid_search: + self._set_sparse_encoder(model_name_or_path=model_name_or_path) + + def _set_sparse_encoder(self, model_name_or_path: Union[str, Path]) -> BigramTfIDFVectorizer: + sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") + sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") + # check file exists if not os.path.isfile(sparse_encoder_path): # download from huggingface hub and cache it @@ -534,10 +621,6 @@ def _load_sparse_encoder(self, model_name_or_path: Union[str, Path]) -> BigramTf self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) - return self.sparse_encoder - - def _load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> float: - sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") # check file exists if not os.path.isfile(sparse_weight_path): # download from huggingface hub and cache it @@ -551,7 +634,7 @@ def _load_sparse_weight(self, model_name_or_path: Union[str, Path]) -> float: return self.sparse_weight - def _embed_sparse(self, entity_names: np.ndarray) -> np.ndarray: + def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ Embeds the given numpy array of entity names, either originating from the knowledge base or recognized in a text, into sparse representations. 
@@ -559,15 +642,15 @@ def _embed_sparse(self, entity_names: np.ndarray) -> np.ndarray: :param entity_names: An array of entity / concept names :returns sparse_embeds np.array: Numpy array containing the sparse embeddings """ - sparse_embeds = self.sparse_encoder(entity_names) + sparse_embeds = self.sparse_encoder(inputs) sparse_embeds = sparse_embeds.numpy() - if self.use_cosine: + if self.similarity_metric == SimilarityMetric.COS: faiss.normalize_L2(sparse_embeds) return sparse_embeds - def _embed_dense(self, names: np.ndarray, batch_size: int = 2048, show_progress: bool = False) -> np.ndarray: + def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: """ Embeds the given numpy array of entity / concept names, either originating from the knowledge base or recognized in a text, into dense representations using a @@ -585,16 +668,16 @@ def _embed_dense(self, names: np.ndarray, batch_size: int = 2048, show_progress: with torch.no_grad(): if show_progress: iterations = tqdm( - range(0, len(names), batch_size), + range(0, len(inputs), batch_size), desc="Calculating dense embeddings for dictionary", ) else: - iterations = range(0, len(names), batch_size) + iterations = range(0, len(inputs), batch_size) for start in iterations: # Create batch - end = min(start + batch_size, len(names)) - batch = [Sentence(name) for name in names[start:end]] + end = min(start + batch_size, len(inputs)) + batch = [Sentence(name) for name in inputs[start:end]] # embed batch self.dense_encoder.embed(batch) @@ -605,104 +688,58 @@ def _embed_dense(self, names: np.ndarray, batch_size: int = 2048, show_progress: torch.cuda.empty_cache() dense_embeds = np.array(dense_embeds) - if self.use_cosine: - faiss.normalize_L2(dense_embeds) return dense_embeds - def _embed_dictionary(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): + def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): """ Computes the embeddings for the given knowledge base / dictionary. 
""" - # Load dictionary - self.dictionary = DictionaryDataset.load(dictionary_name_or_path).data # Check for embedded dictionary in cache dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] file_name = f"bio_nen_{model_name_or_path.split('/')[-1]}_{dictionary_name}" cache_folder = flair.cache_root / "datasets" - emb_dictionary_cache_file = cache_folder / f"{file_name}.pk" - # If exists, load the cached dictionary indices - if emb_dictionary_cache_file.exists(): - self._load_cached_dense_emb_dictionary(emb_dictionary_cache_file) + embeddings_cache_file = cache_folder / f"{file_name}.pk" - else: - # get all concept names from the dictionary - concept_names = [] - for row in self.dictionary: - concept_name = row[0] - if self.preprocessor is not None: - concept_name = self.preprocessor.process_entry(concept_name) - concept_names.append(concept_name) + # If exists, load the cached dictionary indices + if embeddings_cache_file.exists(): - concept_names = np.array(concept_names) + with embeddings_cache_file.open("rb") as fp: + logger.info("Load cached emebddings from %s", embeddings_cache_file) + embeddings = pickle.load(fp) - # Compute dense embeddings (if necessary) - self.dict_dense_embeddings = self._embed_dense( - names=concept_names, batch_size=batch_size, show_progress=True - ) + else: - # To use cosine similarity, we normalize the vectors and then use inner product - if self.use_cosine: - faiss.normalize_L2(self.dict_dense_embeddings) + cache_folder.mkdir(parents=True, exist_ok=True) - # Compute sparse embeddings (if necessary) - if self.use_sparse_embeds: - self.dict_sparse_embeddings = self._embed_sparse(entity_names=concept_names) - else: - self.dict_sparse_embeddings = None + names = self.dictionary.to_names(preprocessor=self.preprocessor) - # Build dense embedding index using faiss - dimension = self.dict_dense_embeddings.shape[1] - self.dense_dictionary_index = faiss.IndexFlatIP(dimension) - self.dense_dictionary_index.add(self.dict_dense_embeddings) + # Compute dense embeddings (if necessary) + dense_embeddings = self.embed_dense(inputs=names, batch_size=batch_size, show_progress=True) + sparse_embeddings = self.embed_sparse(inputs=names) if self.hybrid_search else None # Store the pre-computed index on disk for later re-use - cached_dictionary = { - "dictionary": self.dictionary, - "sparse_dictionary_embeds": self.dict_sparse_embeddings, - "dense_dictionary_index": self.dense_dictionary_index, + embeddings = { + "dense": dense_embeddings, + "sparse": sparse_embeddings, } - if not cache_folder.exists(): - cache_folder.mkdir(parents=True) + logger.info("Caching preprocessed dictionary into %s", cache_folder) + with embeddings_cache_file.open("wb") as fp: + pickle.dump(embeddings, fp) - log.info(f"Saving dictionary into cached file {cache_folder}") - with emb_dictionary_cache_file.open("wb") as cache_file: - pickle.dump(cached_dictionary, cache_file) + if self.similarity_metric == SimilarityMetric.COS: + faiss.normalize_L2(embeddings["dense"]) - # If we use CUDA - move index to GPU - if self.index_use_cuda: - self.dense_dictionary_index = faiss.index_cpu_to_gpu( - faiss.StandardGpuResources(), 0, self.dense_dictionary_index - ) + return embeddings - def _load_cached_dense_emb_dictionary(self, cached_dictionary_path: Path): - """ - Loads pre-computed dense dictionary embedding from disk. 
- """ - with cached_dictionary_path.open("rb") as cached_file: - log.info("Loaded dictionary from cached file {}".format(cached_dictionary_path)) - cached_dictionary = pickle.load(cached_file) - - self.dictionary, self.dict_sparse_embeddings, self.dense_dictionary_index = ( - cached_dictionary["dictionary"], - cached_dictionary["sparse_dictionary_embeds"], - cached_dictionary["dense_dictionary_index"], - ) - - if self.index_use_cuda: - self.dense_dictionary_index = faiss.index_cpu_to_gpu( - faiss.StandardGpuResources(), 0, self.dense_dictionary_index - ) - - def retrieve_sparse_topk_candidates( + def search_sparse( self, - mention_embeddings: np.ndarray, - dict_concept_embeddings: np.ndarray, - top_k: int, + entity_mentions: List[str], + top_k: int = 1, normalise: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: """ @@ -714,10 +751,13 @@ def retrieve_sparse_topk_candidates( :return res: d numpy array of ids [# of query , # of dict] :return scores: numpy array of top scores """ - if self.use_cosine: - score_matrix = cosine_similarity(mention_embeddings, dict_concept_embeddings) + + mention_embeddings = self.sparse_encoder(entity_mentions) + + if self.similarity_metric == SimilarityMetric.COSINE: + score_matrix = cosine_similarity(mention_embeddings, self.embeddings["sparse"]) else: - score_matrix = np.matmul(mention_embeddings, dict_concept_embeddings.T) + score_matrix = np.matmul(mention_embeddings, self.embeddings["sparse"].T) if normalise: score_matrix = (score_matrix - score_matrix.min()) / (score_matrix.max() - score_matrix.min()) @@ -737,75 +777,80 @@ def indexing_2d(arr, cols): return (topk_idxs, topk_scores) - def get_top_k(self, entity_mention: str, top_k: int) -> List[Tuple[str, str, float]]: + def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: """ - Returns the top-k entities for a given entity mention. - - :param entity_mention: Entity mention text under investigation - :param top_k: Number of (best-matching) entities from the knowledge base to return - :result: List of tuples highlighting the top-k entities. Each tuple has the following - structure (entity / concept name, concept ids, score). 
+ Dense search via FAISS index """ # Compute dense embedding for the given entity mention - mention_dense_embeds = self._embed_dense(names=np.array([entity_mention]), batch_size=self.batch_size) - - # Search for more than top-k candidates if combining them with sparse scores - top_k_dense = top_k if not self.use_sparse_embeds else top_k + self.top_k_extra_dense + mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.index_batch_size) # Get candidates from dense embeddings - dense_scores, dense_ids = self.dense_dictionary_index.search(x=mention_dense_embeds, k=top_k_dense) - - # If using sparse embeds: calculate hybrid scores with dense and sparse embeds - if self.use_sparse_embeds: - # Get sparse embeddings for the entity mention - mention_sparse_embeds = self._embed_sparse(entity_names=np.array([entity_mention])) - - # Get candidates from sparse embeddings - sparse_ids, sparse_distances = self.retrieve_sparse_topk_candidates( - mention_embeddings=mention_sparse_embeds, - dict_concept_embeddings=self.dict_sparse_embeddings, - top_k=top_k + self.top_k_extra_sparse, - ) + dists, ids = self.dense_index.search(x=mention_dense_embeds, top_k=top_k) - # Combine dense and sparse scores - sparse_weight = self.sparse_weight - hybrid_ids = [] - hybrid_scores = [] - - # For every embedded mention - for ( - top_dense_ids, - top_dense_scores, - top_sparse_ids, - top_sparse_distances, - ) in zip(dense_ids, dense_scores, sparse_ids, sparse_distances): - ids = top_dense_ids - distances = top_dense_scores - - for sparse_id, sparse_distance in zip(top_sparse_ids, top_sparse_distances): - if sparse_id not in ids: - ids = np.append(ids, sparse_id) - distances = np.append(distances, sparse_weight * sparse_distance) - else: - index = np.where(ids == sparse_id)[0][0] - distances[index] = (sparse_weight * sparse_distance) + distances[index] + return dists, ids - sorted_indizes = np.argsort(-distances) - ids = ids[sorted_indizes][:top_k] - distances = distances[sorted_indizes][:top_k] - hybrid_ids.append(ids.tolist()) - hybrid_scores.append(distances.tolist()) + def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: + """ + Returns the top-k entities for a given entity mention. - else: - # Use only dense embedding results - hybrid_ids = dense_ids - hybrid_scores = dense_scores + :param entity_mentions: Entity mentions (search queries) + :param top_k: Number of (best-matching) entities from the knowledge base to return + :result: List of tuples w/ the top-k entities: (concept name, concept ids, score). 
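+
+        Example of the returned structure (illustrative values only):
+            [("concept name", "DATABASE:concept_id", 0.95), ...]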
+ """ - return [ - tuple(self.dictionary[entity_index].reshape(1, -1)[0]) + (score[0],) - for entity_index, score in zip(hybrid_ids, hybrid_scores) - ] + # dense + + # # If using sparse embeds: calculate hybrid scores with dense and sparse embeds + # if self.use_sparse_embeds: + # # Get sparse embeddings for the entity mention + # mention_sparse_embeds = self.embed_sparse(entity_names=np.array([entity_mention])) + + # # Get candidates from sparse embeddings + # sparse_ids, sparse_distances = self.search_sparse( + # mention_embeddings=mention_sparse_embeds, + # dict_concept_embeddings=self.dict_sparse_embeddings, + # top_k=top_k + self.top_k_extra_sparse, + # ) + + # # Combine dense and sparse scores + # sparse_weight = self.sparse_weight + # hybrid_ids = [] + # hybrid_scores = [] + + # # For every embedded mention + # for ( + # top_dense_ids, + # top_dense_scores, + # top_sparse_ids, + # top_sparse_distances, + # ) in zip(dense_ids, dense_scores, sparse_ids, sparse_distances): + # ids = top_dense_ids + # distances = top_dense_scores + + # for sparse_id, sparse_distance in zip(top_sparse_ids, top_sparse_distances): + # if sparse_id not in ids: + # ids = np.append(ids, sparse_id) + # distances = np.append(distances, sparse_weight * sparse_distance) + # else: + # index = np.where(ids == sparse_id)[0][0] + # distances[index] = (sparse_weight * sparse_distance) + distances[index] + + # sorted_indizes = np.argsort(-distances) + # ids = ids[sorted_indizes][:top_k] + # distances = distances[sorted_indizes][:top_k] + # hybrid_ids.append(ids.tolist()) + # hybrid_scores.append(distances.tolist()) + + # else: + # # Use only dense embedding results + # hybrid_ids = dense_ids + # hybrid_scores = dense_scores + + # return [ + # tuple(self.dictionary[entity_index].reshape(1, -1)[0]) + (score[0],) + # for entity_index, score in zip(hybrid_ids, hybrid_scores) + # ] class BiomedicalEntityLinker: @@ -814,7 +859,7 @@ class BiomedicalEntityLinker: entity / concept to these mentions according to a knowledge base / dictionary. """ - def __init__(self, retriever_model: EntityRetrieverModel, mention_preprocessor: EntityPreprocessor): + def __init__(self, retriever_model: EntityRetriever, mention_preprocessor: BioNelPreprocessor): self.preprocessor = mention_preprocessor self.retriever_model = retriever_model @@ -855,7 +900,7 @@ def predict( ) # Retrieve top-k concept / entity candidates - predictions = self.retriever_model.get_top_k(mention_text, top_k) + predictions = self.retriever_model.search(mention_text, top_k) # Add a label annotation for each candidate for prediction in predictions: @@ -894,32 +939,17 @@ def load( cls, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] = None, - use_sparse_embeddings: bool = True, + hybrid_search: bool = True, max_length: int = 25, - batch_size: int = 1024, - index_use_cuda: bool = False, - use_cosine: bool = True, - preprocessor: EntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=BasicEntityPreprocessor()), + index_batch_size: int = 1024, + similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + preprocessor: BioNelPreprocessor = Ab3PPreprocessor.load(preprocessor=BasicBioNelPreprocessor()), + default_sparse_encoder: bool = False, + sparse_weight: float = DEFAULT_SPARSE_WEIGHT, ): """ Loads a model for biomedical named entity normalization. - - :param model_name_or_path: Name of or path to a pretrained model to use. 
Possible values for pretrained - models are: - chemical, disease, gene, sapbert-bc5cdr-dissaease, sapbert-ncbi-disease, sapbert-bc5cdr-chemical, - biobert-bc5cdr-disease,biobert-ncbi-disease, biobert-bc5cdr-chemical, biosyn-biobert-bc2gn, - biosyn-sapbert-bc2gn, sapbert, exact-string-match - :param dictionary_name_or_path: Name of or path to a dictionary listing all possible entity / concept - identifiers and their concept names / synonyms. Pre-defined dictionaries are: - chemical, ctd-chemical, disease, bc5cdr-disease, gene, cnbci-gene, taxonomy and ncbi-taxonomy - :param use_sparse_embeddings: Indicates whether to use sparse embeddings for inference. If True, - uses a combinations of sparse and dense embeddings. If False, uses only dense embeddings - :param: max_length: Maximal number of tokens for an entity mention or concept name - :param batch_size: Batch size for the dense encoder - :param index_use_cuda: If True, uses GPU for the dense encoding - :param use_cosine: If True, uses cosine similarity for the dense encoder. If False, inner product is used. - :param preprocessor: Implementation of MentionPreprocessor to use for pre-processing the entity - mention text and dictionary entries + See __init__ method for detailed docstring on arguments """ dictionary_path = dictionary_name_or_path if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): @@ -928,98 +958,101 @@ def load( retriever_model = None if isinstance(model_name_or_path, str): if model_name_or_path == "exact-string-match": - retriever_model = ExactStringMatchingRetrieverModel.load_model(dictionary_path) + retriever_model = ExactStringMatchingRetriever.load(dictionary_path) else: - model_path = cls.__get_model_path(model_name_or_path, use_sparse_embeddings) - retriever_model = BiEncoderEntityRetrieverModel( + model_path = cls.__get_model_path( + model_name_or_path=model_name_or_path, + hybrid_search=hybrid_search, + default_sparse_encoder=default_sparse_encoder, + ) + retriever_model = BiEncoderEntityRetriever( model_name_or_path=model_path, dictionary_name_or_path=dictionary_path, - use_sparse_embeddings=use_sparse_embeddings, - use_cosine=use_cosine, + hybrid_search=hybrid_search, + similarity_metric=similarity_metric, max_length=max_length, - batch_size=batch_size, - index_use_cuda=index_use_cuda, + index_batch_size=index_batch_size, + sparse_weight=sparse_weight, + preprocessor=preprocessor, ) return cls(retriever_model=retriever_model, mention_preprocessor=preprocessor) @staticmethod - def __get_model_path(model_name: str, use_sparse_and_dense_embeds: bool) -> str: - model_name = model_name.lower() - model_path = model_name - - # if a provided model is used, - # modify model name to huggingface path - - if model_name in [ - "sapbert-bc5cdr-disease", - "sapbert-ncbi-disease", - "sapbert-bc5cdr-chemical", - "biobert-bc5cdr-disease", - "biobert-ncbi-disease", - "biobert-bc5cdr-chemical", - "biosyn-biobert-bc2gn", - "biosyn-sapbert-bc2gn", - ]: - model_path = "dmis-lab/biosyn-" + model_name - elif model_name == "sapbert": - model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" - elif model_name == "exact-string-match": - model_path = "exact-string-match" - elif use_sparse_and_dense_embeds: - if model_name == "disease": - model_path = "dmis-lab/biosyn-sapbert-bc5cdr-disease" - elif model_name == "chemical": - model_path = "dmis-lab/biosyn-sapbert-bc5cdr-chemical" - elif model_name == "gene": - model_path = "dmis-lab/biosyn-sapbert-bc2gn" - else: - if model_name == "disease": - model_path = 
"cambridgeltl/SapBERT-from-PubMedBERT-fulltext" - elif model_name == "chemical": - model_path = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" - elif model_name == "gene": - raise ValueError("No trained model for gene entity linking using only dense embeddings.") - return model_path + def __get_model_path( + model_name_or_path: Union[str, Path], hybrid_search: bool = False, default_sparse_encoder: bool = False + ) -> str: + """ + Try to figure out what model the user wants + """ + + if isinstance(model_name_or_path, str): + + model_name_or_path = model_name_or_path.lower() + + if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: + raise ValueError( + f"""Unknown model `{model_name_or_path}`! \n + Available entity types are: {ENTITY_TYPES} \n + If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" + ) + + if hybrid_search: + + # load model by entity_type + if model_name_or_path in ENTITY_TYPES: + # check if we have a hybrid pre-trained model + if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: + model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] + else: + # check if user really wants to use hybrid search anyway + if not default_sparse_encoder: + raise ValueError( + f"""Model for entity type `{model_name_or_path}` was not trained for hybrid search! \n + If you want to proceed anyway please pass `default_sparse_encoder=True`: + we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. + """ + ) + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + else: + if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not default_sparse_encoder: + raise ValueError( + f"""Model `{model_name_or_path}` was not trained for hybrid search! \n + If you want to proceed anyway please pass `default_sparse_encoder=True`: + we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. + """ + ) + + return model_name_or_path @staticmethod - def __get_dictionary_path(dictionary_path: str, model_name: str) -> str: - # determine dictionary to use - if dictionary_path == "disease": - dictionary_path = "ctd-disease" - if dictionary_path == "chemical": - dictionary_path = "ctd-chemical" - if dictionary_path == "gene": - dictionary_path = "ncbi-gene" - if dictionary_path == "taxonomy": - dictionary_path = "ncbi-taxonomy" - if dictionary_path is None: - # disease - if model_name in [ - "sapbert-bc5cdr-disease", - "sapbert-ncbi-disease", - "biobert-bc5cdr-disease", - "biobert-ncbi-disease", - "disease", - ]: - dictionary_path = "ctd-disease" - # chemical - elif model_name in [ - "sapbert-bc5cdr-chemical", - "biobert-bc5cdr-chemical", - "chemical", - ]: - dictionary_path = "ctd-chemical" - # gene - elif model_name in ["gene", "biosyn-biobert-bc2gn", "biosyn-sapbert-bc2gn"]: - dictionary_path = "ncbi-gene" - # error + def __get_dictionary_path(model_name: str, dictionary_name_or_path: Optional[Union[str, Path]] = None) -> str: + """ + Try to figure out what dictionary (depending on the model) the user wants + """ + + if model_name in STRING_MATCHING_MODELS and dictionary_name_or_path is None: + raise ValueError("When using a string-matchin retriever you must specify `dictionary_name_or_path`!") + + if dictionary_name_or_path not in MODELS and dictionary_name_or_path not in ENTITY_TYPES: + raise ValueError( + f"""Unknown dictionary `{dictionary_name_or_path}`! 
\n + Available entity types are: {ENTITY_TYPES} \n + If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)` + """ + ) + + if dictionary_name_or_path is not None: + if dictionary_name_or_path in ENTITY_TYPES: + dictionary_name_or_path = ENTITY_TYPE_TO_NEL_DICTIONARY[dictionary_name_or_path] else: - log.error( - """When using a custom model you need to specify a dictionary. - Available options are: 'disease', 'chemical', 'gene' and 'taxonomy'. - Or provide a path to a dictionary file.""" - ) - raise ValueError("Invalid dictionary") + if model_name in MODEL_NAME_TO_NEL_DICTIONARY: + dictionary_name_or_path = MODEL_NAME_TO_NEL_DICTIONARY[dictionary_name_or_path] + else: + raise ValueError( + """When using a custom model you need to specify a dictionary. + Available options are: 'disease', 'chemical', 'gene' and 'species'. + Or provide a path to a dictionary file.""" + ) - return dictionary_path + return dictionary_name_or_path From bdc3e8aed04a745dbd418754f90d9fc43c252e02 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Thu, 27 Apr 2023 18:28:31 +0200 Subject: [PATCH 10/58] fix(bionel): major refactor - yet better naming - add batched search - fix dicionary loading --- flair/datasets/__init__.py | 8 +- flair/models/biomedical_entity_linking.py | 633 +++++++++++++--------- tests/test_biomedical_entity_linking.py | 50 +- 3 files changed, 404 insertions(+), 287 deletions(-) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index c203b802b..094169c6b 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -37,6 +37,8 @@ CLL, CRAFT, CRAFT_V4, + CTD_CHEMICAL_DICTIONARY, + CTD_DISEASE_DICTIONARY, DECA, FSU, GELLUS, @@ -90,10 +92,8 @@ LOCTEXT, MIRNA, NCBI_DISEASE, - NEL_CTD_CHEMICAL_DICT, - NEL_CTD_DISEASE_DICT, - NEL_NCBI_HUMAN_GENE_DICT, - NEL_NCBI_TAXONOMY_DICT, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, OSIRIS, PDR, S800, diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 9726ecb14..344ff187c 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -15,31 +15,25 @@ import faiss import numpy as np import torch -from huggingface_hub import cached_download, hf_hub_url +from huggingface_hub import hf_hub_download from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm import flair -from flair.data import EntityLinkingLabel, Label, Sentence +from flair.data import EntityLinkingLabel, Label, Sentence, Span from flair.datasets import ( - NEL_CTD_CHEMICAL_DICT, - NEL_CTD_DISEASE_DICT, - NEL_NCBI_HUMAN_GENE_DICT, - NEL_NCBI_TAXONOMY_DICT, + CTD_CHEMICAL_DICTIONARY, + CTD_DISEASE_DICTIONARY, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, ) -from flair.datasets.biomedical import PreprocessedBioNelDictionary +from flair.datasets.biomedical import ParsedBiomedicalEntityLinkingDictionary from flair.embeddings import TransformerDocumentEmbeddings from flair.file_utils import cached_path logger = logging.getLogger("flair") -BIOMEDICAL_NEL_DICTIONARIES = { - "ctd-disease": NEL_CTD_DISEASE_DICT, - "ctd-chemical": NEL_CTD_CHEMICAL_DICT, - "ncbi-gene": NEL_NCBI_HUMAN_GENE_DICT, - "ncbi-taxonomy": NEL_NCBI_TAXONOMY_DICT, -} PRETRAINED_MODELS = [ "cambridgeltl/SapBERT-from-PubMedBERT-fulltext", @@ -47,14 +41,14 @@ # Dense + sparse retrieval PRETRAINED_HYBRID_MODELS = [ - "biosyn-sapbert-bc5cdr-disease", - 
"biosyn-sapbert-ncbi-disease", - "biosyn-sapbert-bc5cdr-chemical", - "biosyn-biobert-bc5cdr-disease", - "biosyn-biobert-ncbi-disease", - "biosyn-biobert-bc5cdr-chemical", - "biosyn-biobert-bc2gn", - "biosyn-sapbert-bc2gn", + "dmis-lab/biosyn-sapbert-bc5cdr-disease", + "dmis-lab/biosyn-sapbert-ncbi-disease", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease", + "dmis-lab/biosyn-biobert-ncbi-disease", + "dmis-lab/biosyn-biobert-bc5cdr-chemical", + "dmis-lab/biosyn-biobert-bc2gn", + "dmis-lab/biosyn-sapbert-bc2gn", ] PRETRAINED_MODELS = PRETRAINED_HYBRID_MODELS + PRETRAINED_MODELS @@ -78,38 +72,45 @@ entity_type: "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" for entity_type in ENTITY_TYPES } -DEFAULT_SPARSE_WEIGHT = 0.5 -ENTITY_TYPE_TO_NEL_DICTIONARY = { +ENTITY_TYPE_TO_DICTIONARY = { "gene": "ncbi-gene", "species": "ncbi-taxonomy", "disease": "ctd-disease", "chemical": "ctd-chemical", } -MODEL_NAME_TO_NEL_DICTIONARY = { - "biosyn-sapbert-bc5cdr-disease": "ctd-disease", - "biosyn-sapbert-ncbi-disease": "ctd-disease", - "biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", - "biosyn-biobert-bc5cdr-disease": "ctd-chemical", - "biosyn-biobert-ncbi-disease": "ctd-disease", - "biosyn-biobert-bc5cdr-chemical": "ctd-chemical", - "biosyn-biobert-bc2gn": "ncbi-gene", - "biosyn-sapbert-bc2gn": "ncbi-gene", +BIOMEDICAL_DICTIONARIES = { + "ctd-disease": CTD_DISEASE_DICTIONARY, + "ctd-chemical": CTD_CHEMICAL_DICTIONARY, + "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, + "ncbi-taxonomy": NCBI_TAXONOMY_DICTIONARY, } +MODEL_NAME_TO_DICTIONARY = { + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "ctd-disease", + "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-disease", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemical", + "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-disease", + "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-biobert-bc2gn": "ncbi-gene", + "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", +} + + +DEFAULT_SPARSE_WEIGHT = 0.5 + class SimilarityMetric(Enum): - """ - Available similarity metrics - """ + """Similarity metrics""" INNER_PRODUCT = faiss.METRIC_INNER_PRODUCT # L2 = faiss.METRIC_L2 COSINE = auto() -class BioNelPreprocessor(ABC): +class AbstractEntityPreprocessor(ABC): """ A entity pre-processor is used to transform / clean an entity mention (recognized by an entity recognition model in the original text). This may include removing certain characters @@ -120,6 +121,14 @@ class BioNelPreprocessor(ABC): subclasses that implement concrete transformations. """ + @property + @abstractmethod + def name(self) -> str: + """ + Define preprocessor name. + This is needed to correctly cache different multiple version of the dictionary + """ + @abstractmethod def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """ @@ -131,7 +140,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """ @abstractmethod - def process_entry(self, entity_name: str) -> str: + def process_entity_name(self, entity_name: str) -> str: """ Processes the given entity name (originating from a knowledge base / ontology) and applies the transformation procedure to it. 
@@ -139,7 +148,6 @@ def process_entry(self, entity_name: str) -> str: :param entity_name: entity mention given as DataPoint :result: Cleaned / transformed string representation of the given entity mention """ - raise NotImplementedError() @abstractmethod def initialize(self, sentences: List[Sentence]): @@ -151,7 +159,7 @@ def initialize(self, sentences: List[Sentence]): """ -class BasicBioNelPreprocessor(BioNelPreprocessor): +class EntityPreprocessor(AbstractEntityPreprocessor): """ Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction and removing of punctuation characters. @@ -162,7 +170,9 @@ class BasicBioNelPreprocessor(BioNelPreprocessor): """ def __init__( - self, lowercase: bool = True, remove_punctuation: bool = True, punctuation_symbols: str = string.punctuation + self, + lowercase: bool = True, + remove_punctuation: bool = True, ) -> None: """ Initializes the mention preprocessor. @@ -174,12 +184,17 @@ def __init__( """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation - self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(punctuation_symbols))) + self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) + + @property + def name(self): + + return "biosyn" def initialize(self, sentences): pass - def process_entry(self, entity_name: str) -> str: + def process_entity_name(self, entity_name: str) -> str: if self.lowercase: entity_name = entity_name.lower() @@ -190,10 +205,10 @@ def process_entry(self, entity_name: str) -> str: return entity_name.strip() def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: - return self.process_entry(entity_mention.data_point.text) + return self.process_entity_name(entity_mention.data_point.text) -class Ab3PPreprocessor(BioNelPreprocessor): +class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): """ Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, given in: @@ -207,7 +222,9 @@ class Ab3PPreprocessor(BioNelPreprocessor): PubMed ID: 18817555 """ - def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[BioNelPreprocessor] = None) -> None: + def __init__( + self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[AbstractEntityPreprocessor] = None + ) -> None: """ Creates the mention pre-processor @@ -221,6 +238,11 @@ def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[ self.preprocessor = preprocessor self.abbreviation_dict = {} + @property + def name(self): + + return f"ab3p_{self.preprocessor.name}" + def initialize(self, sentences: List[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) @@ -231,7 +253,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: parsed_tokens = [] for token in tokens: if self.preprocessor is not None: - token = self.preprocessor.process_entry(token) + token = self.preprocessor.process_entity_name(token) if sentence_text in self.abbreviation_dict: if token.lower() in self.abbreviation_dict[sentence_text]: @@ -243,16 +265,16 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: return " ".join(parsed_tokens) - def process_entry(self, entity_name: str) -> str: + def process_entity_name(self, entity_name: str) -> str: # Ab3P works on sentence-level and not on a single entity mention / name # - so we just apply the wrapped text pre-processing here (if configured) if self.preprocessor is not None: - return 
self.preprocessor.process_entry(entity_name) + return self.preprocessor.process_entity_name(entity_name) return entity_name @classmethod - def load(cls, ab3p_path: Path = None, preprocessor: Optional[BioNelPreprocessor] = None): + def load(cls, ab3p_path: Path = None, preprocessor: Optional[AbstractEntityPreprocessor] = None): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -389,7 +411,7 @@ class BigramTfIDFVectorizer: https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 """ - def __init__(self) -> None: + def __init__(self): self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) def fit(self, names: List[str]): @@ -413,7 +435,7 @@ def __call__(self, mentions: List[str]) -> torch.Tensor: """ return self.transform(mentions) - def save_encoder(self, path: Path) -> None: + def save(self, path: Path) -> None: with path.open("wb") as fout: pickle.dump(self.encoder, fout) logger.info("Sparse encoder saved in %s", path) @@ -431,7 +453,7 @@ def load(cls, path: Path) -> "BigramTfIDFVectorizer": return newVectorizer -class BioNelDictionary: +class BiomedicalEntityLinkingDictionary: """ A class used to load dictionary data from a custom dictionary file. Every line in the file must be formatted as follows: @@ -448,31 +470,29 @@ def __init__(self, reader): self.reader = reader @classmethod - def load(cls, dictionary_name_or_path: Union[Path, str]) -> "BioNelDictionary": + def load(cls, dictionary_name_or_path: Union[Path, str]) -> "EntityLinkingDictionary": """ Load dictionary: either pre-definded or from path """ if isinstance(dictionary_name_or_path, str): if ( - dictionary_name_or_path not in ENTITY_TYPE_TO_NEL_DICTIONARY - and dictionary_name_or_path not in BIOMEDICAL_NEL_DICTIONARIES + dictionary_name_or_path not in ENTITY_TYPE_TO_DICTIONARY + and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES ): raise ValueError( f"""Unkwnon dictionary `{dictionary_name_or_path}`, - Available dictionaries are: {tuple(BIOMEDICAL_NEL_DICTIONARIES)} \n + Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)} \n If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" ) - dictionary_name_or_path = ENTITY_TYPE_TO_NEL_DICTIONARY.get( - dictionary_name_or_path, dictionary_name_or_path - ) + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) - reader = BIOMEDICAL_NEL_DICTIONARIES[dictionary_name_or_path]() + reader = BIOMEDICAL_DICTIONARIES[dictionary_name_or_path]() else: # use custom dictionary file - reader = PreprocessedBioNelDictionary(path=dictionary_name_or_path) + reader = ParsedBiomedicalEntityLinkingDictionary(path=dictionary_name_or_path) return cls(reader=reader) @@ -492,7 +512,7 @@ def stream(self) -> Iterator[Tuple[str, str]]: yield entry -class EntityRetriever(ABC): +class AbstractCandidateGenerator(ABC): """ An entity retriever model is used to find the top-k entities / concepts of a knowledge base / dictionary for a given entity mention in text. @@ -510,13 +530,13 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, """ -class ExactStringMatchingRetriever(EntityRetriever): +class ExactMatchCandidateGenerator(AbstractCandidateGenerator): """ Implementation of an entity retriever model which uses exact string matching to find the entity / concept identifier for a given entity mention. 
""" - def __init__(self, dictionary: BioNelDictionary): + def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): # Build index which maps concept / entity names to concept / entity ids self.name_to_id_index = dict(dictionary.data) @@ -526,7 +546,7 @@ def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverMode Compatibility function """ # Load dictionary - return cls(BioNelDictionary.load(dictionary_name_or_path)) + return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: """ @@ -543,7 +563,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, return [(em, self.name_to_id_index.get(em), 1.0) for em in entity_mentions] -class BiEncoderEntityRetriever(EntityRetriever): +class BiEncoderCandidateGenerator(AbstractCandidateGenerator): """ Implementation of EntityRetrieverModel which uses dense (transformer-based) embeddings and (optionally) sparse character-based representations, for normalizing an entity mention to specific identifiers @@ -557,12 +577,13 @@ def __init__( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: str, - hybrid_search: bool = False, + similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), max_length: int = 25, index_batch_size: int = 1024, - preprocessor: BioNelPreprocessor = Ab3PPreprocessor.load(preprocessor=BasicBioNelPreprocessor()), - similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + hybrid_search: bool = False, sparse_weight: Optional[float] = None, + force_hybrid_search: bool = False, ): """ Initializes the BiEncoderEntityRetrieverModel. 
@@ -577,18 +598,27 @@ def __init__( :param sparse_weight: default sparse weight :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from the knowledge base """ + self.model_name_or_path = model_name_or_path + self.dictionary_name_or_path = dictionary_name_or_path self.preprocessor = preprocessor self.similarity_metric = similarity_metric self.max_length = max_length self.index_batch_size = index_batch_size self.hybrid_search = hybrid_search self.sparse_weight = sparse_weight - - # Load dense encoder - self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) + self.force_hybrid_search = force_hybrid_search # Load dictionary - self.dictionary = BioNelDictionary.load(dictionary_name_or_path) + self.dictionary = list(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path).stream()) + + # Load encoders + self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) + self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None + self.sparse_weight: Optional[float] = None + if self.hybrid_search: + self._set_sparse_weigth_and_encoder( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) self.embeddings = self._load_emebddings( model_name_or_path=model_name_or_path, @@ -596,43 +626,88 @@ def __init__( batch_size=self.index_batch_size, ) - # Build dense embedding index using faiss - dimension = self.embeddings["dense"].shape[1] - self.dense_index = faiss.IndexFlatIP(dimension) - self.dense_index.add(self.embeddings["dense"]) + self.dense_index = self.build_dense_index(self.embeddings["dense"]) - self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None - if self.hybrid_search: - self._set_sparse_encoder(model_name_or_path=model_name_or_path) + @property + def higher_is_better(self): + """ + Determine if similarity is proportional to score. + E.g. for L2 lower is better + """ + + return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] + + # separate method to allow more sophisticated logic in the future, + # e.g. ANN with IndexIP, HNSW... + def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: + """Initialize FAISS index""" + + dense_index = faiss.IndexFlatIP(embeddings.shape[1]) + dense_index.add(embeddings) + + return dense_index - def _set_sparse_encoder(self, model_name_or_path: Union[str, Path]) -> BigramTfIDFVectorizer: + def _fit_and_cache_sparse_encoder(self, sparse_encoder_path: str, sparse_weight_path: str): + """Fit sparse encoder to current dictionary""" + + sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + logger.info( + "Hybrid model has no pretrained sparse encoder. 
Fit to dictionary `%s` (sparse_weight=%s)", + self.dictionary_name_or_path, + sparse_weight, + ) + sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary]) + sparse_encoder.save(sparse_encoder_path) + torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) + + def _set_sparse_weigth_and_encoder( + self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] + ): sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") - # check file exists - if not os.path.isfile(sparse_encoder_path): - # download from huggingface hub and cache it - sparse_encoder_url = hf_hub_url(model_name_or_path, filename="sparse_encoder.pk") - sparse_encoder_path = cached_download( - url=sparse_encoder_url, - cache_dir=flair.cache_root / "models" / model_name_or_path, - ) + if isinstance(model_name_or_path, str): - self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + # check file exists + if model_name_or_path in PRETRAINED_HYBRID_MODELS: - # check file exists - if not os.path.isfile(sparse_weight_path): - # download from huggingface hub and cache it - sparse_weight_url = hf_hub_url(model_name_or_path, filename="sparse_weight.pt") - sparse_weight_path = cached_download( - url=sparse_weight_url, - cache_dir=flair.cache_root / "models" / model_name_or_path, - ) + if not os.path.exists(sparse_encoder_path): + sparse_encoder_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_encoder.pk", + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) - self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() + if not os.path.exists(sparse_weight_path): + sparse_weight_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_weight.pt", + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) + else: + if self.force_hybrid_search: + if not os.path.exists(sparse_encoder_path) and not os.path.exists(sparse_weight_path): + self._fit_and_cache_sparse_encoder( + sparse_encoder_path=sparse_encoder_path, sparse_weight_path=sparse_weight_path + ) + else: + raise ValueError( + f"Hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + ) + else: + if not os.path.exists(sparse_encoder_path) and not os.path.exists(sparse_weight_path): + if self.force_hybrid_search: + self._fit_and_cache_sparse_encoder( + sparse_encoder_path=sparse_encoder_path, sparse_weight_path=sparse_weight_path + ) + else: + raise ValueError( + f"Hybrid model has no pretrained sparse encoder. 
Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + ) - return self.sparse_weight + self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ @@ -645,7 +720,7 @@ def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: sparse_embeds = self.sparse_encoder(inputs) sparse_embeds = sparse_embeds.numpy() - if self.similarity_metric == SimilarityMetric.COS: + if self.similarity_metric == SimilarityMetric.COSINE: faiss.normalize_L2(sparse_embeds) return sparse_embeds @@ -669,7 +744,7 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: if show_progress: iterations = tqdm( range(0, len(inputs), batch_size), - desc="Calculating dense embeddings for dictionary", + desc="Embed inputs", ) else: iterations = range(0, len(inputs), batch_size) @@ -702,7 +777,9 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str cache_folder = flair.cache_root / "datasets" - embeddings_cache_file = cache_folder / f"{file_name}.pk" + pp_name = self.preprocessor.name if self.preprocessor is not None else "null" + + embeddings_cache_file = cache_folder / f"{file_name}_pp={pp_name}.pk" # If exists, load the cached dictionary indices if embeddings_cache_file.exists(): @@ -715,10 +792,11 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str cache_folder.mkdir(parents=True, exist_ok=True) - names = self.dictionary.to_names(preprocessor=self.preprocessor) + names = [self.preprocessor.process_entity_name(name) for name, cui in self.dictionary] # Compute dense embeddings (if necessary) dense_embeddings = self.embed_dense(inputs=names, batch_size=batch_size, show_progress=True) + sparse_embeddings = self.embed_sparse(inputs=names) if self.hybrid_search else None # Store the pre-computed index on disk for later re-use @@ -727,11 +805,11 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str "sparse": sparse_embeddings, } - logger.info("Caching preprocessed dictionary into %s", cache_folder) + logger.info("Caching dictionary emebddings into %s", embeddings_cache_file) with embeddings_cache_file.open("wb") as fp: pickle.dump(embeddings, fp) - if self.similarity_metric == SimilarityMetric.COS: + if self.similarity_metric == SimilarityMetric.COSINE: faiss.normalize_L2(embeddings["dense"]) return embeddings @@ -775,7 +853,7 @@ def indexing_2d(arr, cols): topk_idxs = indexing_2d(topk_idxs, topk_argidxs) topk_scores = indexing_2d(score_matrix, topk_idxs) - return (topk_idxs, topk_scores) + return topk_scores, topk_idxs def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: """ @@ -785,11 +863,53 @@ def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.n # Compute dense embedding for the given entity mention mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.index_batch_size) + if self.similarity_metric == SimilarityMetric.COSINE: + faiss.normalize_L2(mention_dense_embeds) + # Get candidates from dense embeddings - dists, ids = self.dense_index.search(x=mention_dense_embeds, top_k=top_k) + dists, ids = self.dense_index.search(mention_dense_embeds, top_k) return dists, ids + def combine_dense_and_sparse_results( + self, + dense_ids: np.ndarray, + dense_scores: np.ndarray, + sparse_ids: 
np.ndarray,
+        sparse_scores: np.ndarray,
+        top_k: int = 1,
+    ):
+        """
+        Expand dense results with sparse ones (that are not already in the dense results)
+        and re-weight the score as: dense_score + sparse_weight * sparse_score
+        """
+
+        hybrid_ids = []
+        hybrid_scores = []
+        for i in range(dense_ids.shape[0]):
+
+            mention_ids = dense_ids[i]
+            mention_scores = dense_scores[i]
+
+            mention_sparse_ids = sparse_ids[i]
+            mention_sparse_scores = sparse_scores[i]
+
+            for sparse_id, sparse_score in zip(mention_sparse_ids, mention_sparse_scores):
+                if sparse_id not in mention_ids:
+                    mention_ids = np.append(mention_ids, sparse_id)
+                    mention_scores = np.append(mention_scores, self.sparse_weight * sparse_score)
+                else:
+                    index = np.where(mention_ids == sparse_id)[0][0]
+                    mention_scores[index] += self.sparse_weight * sparse_score
+
+            rerank_indices = np.argsort(-mention_scores if self.higher_is_better else mention_scores)
+            mention_ids = mention_ids[rerank_indices][:top_k]
+            mention_scores = mention_scores[rerank_indices][:top_k]
+            hybrid_ids.append(mention_ids.tolist())
+            hybrid_scores.append(mention_scores.tolist())
+
+        return hybrid_scores, hybrid_ids
+
     def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]:
         """
         Returns the top-k entities for a given entity mention.
@@ -799,58 +919,25 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str,
         :result: List of tuples w/ the top-k entities: (concept name, concept ids, score).
         """
 
-        # dense
-
-        # # If using sparse embeds: calculate hybrid scores with dense and sparse embeds
-        # if self.use_sparse_embeds:
-        #     # Get sparse embeddings for the entity mention
-        #     mention_sparse_embeds = self.embed_sparse(entity_names=np.array([entity_mention]))
-
-        #     # Get candidates from sparse embeddings
-        #     sparse_ids, sparse_distances = self.search_sparse(
-        #         mention_embeddings=mention_sparse_embeds,
-        #         dict_concept_embeddings=self.dict_sparse_embeddings,
-        #         top_k=top_k + self.top_k_extra_sparse,
-        #     )
-
-        #     # Combine dense and sparse scores
-        #     sparse_weight = self.sparse_weight
-        #     hybrid_ids = []
-        #     hybrid_scores = []
-
-        #     # For every embedded mention
-        #     for (
-        #         top_dense_ids,
-        #         top_dense_scores,
-        #         top_sparse_ids,
-        #         top_sparse_distances,
-        #     ) in zip(dense_ids, dense_scores, sparse_ids, sparse_distances):
-        #         ids = top_dense_ids
-        #         distances = top_dense_scores
-
-        #         for sparse_id, sparse_distance in zip(top_sparse_ids, top_sparse_distances):
-        #             if sparse_id not in ids:
-        #                 ids = np.append(ids, sparse_id)
-        #                 distances = np.append(distances, sparse_weight * sparse_distance)
-        #             else:
-        #                 index = np.where(ids == sparse_id)[0][0]
-        #                 distances[index] = (sparse_weight * sparse_distance) + distances[index]
-
-        #         sorted_indizes = np.argsort(-distances)
-        #         ids = ids[sorted_indizes][:top_k]
-        #         distances = distances[sorted_indizes][:top_k]
-        #         hybrid_ids.append(ids.tolist())
-        #         hybrid_scores.append(distances.tolist())
-
-        #     else:
-        #         # Use only dense embedding results
-        #         hybrid_ids = dense_ids
-        #         hybrid_scores = dense_scores
-
-        #     return [
-        #         tuple(self.dictionary[entity_index].reshape(1, -1)[0]) + (score[0],)
-        #         for entity_index, score in zip(hybrid_ids, hybrid_scores)
-        #     ]
+        scores, ids = self.search_dense(entity_mentions=entity_mentions, top_k=top_k)
+
+        if self.hybrid_search:
+
+            sparse_scores, sparse_ids = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k)
+
+            scores, ids = self.combine_dense_and_sparse_results(
+                dense_ids=ids,
+                dense_scores=scores,
+                sparse_scores=sparse_scores,
+                
sparse_ids=sparse_ids, + top_k=top_k, + ) + + return [ + tuple(self.dictionary[i]) + (score,) + for mention_ids, mention_scores in zip(ids, scores) + for i, score in zip(mention_ids, mention_scores) + ] class BiomedicalEntityLinker: @@ -859,9 +946,64 @@ class BiomedicalEntityLinker: entity / concept to these mentions according to a knowledge base / dictionary. """ - def __init__(self, retriever_model: EntityRetriever, mention_preprocessor: BioNelPreprocessor): - self.preprocessor = mention_preprocessor - self.retriever_model = retriever_model + def __init__(self, candidate_generator: AbstractCandidateGenerator, preprocessor: AbstractEntityPreprocessor): + self.preprocessor = preprocessor + self.candidate_generator = candidate_generator + + def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, str, int]) -> EntityLinkingLabel: + """ + Create entity linking label from retriever model result + """ + + # if concept identifier is made up of multiple ids, separated by '|' + # separate it into cui and additional_labels + cui = prediction[1] + if "|" in cui: + labels = cui.split("|") + cui = labels[0] + additional_labels = labels[1:] + else: + additional_labels = None + + # determine database: + if ":" in cui: + cui_parts = cui.split(":") + database = ":".join(cui_parts[0:-1]) + cui = cui_parts[-1] + else: + database = None + + return EntityLinkingLabel( + data_point=data_point, + concept_id=cui, + concept_name=prediction[0], + additional_ids=additional_labels, + database=database, + score=prediction[2], + ) + + def extract_mentions( + self, sentences: List[Sentence], input_entity_annotation_layer: Optional[str] = None + ) -> Tuple[List[int], List[Span], List[str]]: + """ + Unpack all mentions in sentences for batch search. + Output is list of (sentence index, mention text). + """ + + source = [] + data_points = [] + mentions = [] + for i, sentence in enumerate(sentences): + for entity in sentence.get_labels(input_entity_annotation_layer): + source.append(i) + data_points.append(entity.data_point) + mentions.append( + self.preprocessor.process_mention(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text, + ) + + return source, data_points, mentions def predict( self, @@ -888,99 +1030,72 @@ def predict( # Build label name label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" - # For every sentence .. - for sentence in sentences: - # ... 
process every mentioned entity - for entity in sentence.get_labels(input_entity_annotation_layer): - # Pre-process entity mention (if necessary) - mention_text = ( - self.preprocessor.process_mention(entity, sentence) - if self.preprocessor is not None - else entity.data_point.text - ) + source, data_points, mentions = self.extract_mentions( + sentences=sentences, input_entity_annotation_layer=input_entity_annotation_layer + ) - # Retrieve top-k concept / entity candidates - predictions = self.retriever_model.search(mention_text, top_k) - - # Add a label annotation for each candidate - for prediction in predictions: - # if concept identifier is made up of multiple ids, separated by '|' - # separate it into cui and additional_labels - cui = prediction[1] - if "|" in cui: - labels = cui.split("|") - cui = labels[0] - additional_labels = labels[1:] - else: - additional_labels = None + # Retrieve top-k concept / entity candidates + predictions = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) - # determine database: - if ":" in cui: - cui_parts = cui.split(":") - database = ":".join(cui_parts[0:-1]) - cui = cui_parts[-1] - else: - database = None - - sentence.add_label( - typename=label_name, - value_or_label=EntityLinkingLabel( - data_point=entity.data_point, - concept_id=cui, - concept_name=prediction[0], - additional_ids=additional_labels, - database=database, - score=prediction[2], - ), - ) + # Add a label annotation for each candidate + for i, data_point, prediction in zip(source, data_points, predictions): + + sentences[i].add_label( + typename=label_name, + value_or_label=self.build_entity_linking_label(prediction=prediction, data_point=data_point), + ) @classmethod def load( cls, model_name_or_path: Union[str, Path], - dictionary_name_or_path: Union[str, Path] = None, + dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, max_length: int = 25, index_batch_size: int = 1024, similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - preprocessor: BioNelPreprocessor = Ab3PPreprocessor.load(preprocessor=BasicBioNelPreprocessor()), - default_sparse_encoder: bool = False, + preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), + force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, ): """ Loads a model for biomedical named entity normalization. 
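+
+        Illustrative usage (mirroring the tests; assumes `sentences` already carries
+        recognized entity mentions):
+
+            >>> gene_linker = BiomedicalEntityLinker.load("gene", hybrid_search=False)
+            >>> gene_linker.predict(sentences, top_k=5)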
See __init__ method for detailed docstring on arguments """ - dictionary_path = dictionary_name_or_path + if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): - dictionary_path = cls.__get_dictionary_path(dictionary_name_or_path, model_name_or_path) + dictionary_name_or_path = cls.__get_dictionary_path( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) - retriever_model = None if isinstance(model_name_or_path, str): - if model_name_or_path == "exact-string-match": - retriever_model = ExactStringMatchingRetriever.load(dictionary_path) - else: - model_path = cls.__get_model_path( - model_name_or_path=model_name_or_path, - hybrid_search=hybrid_search, - default_sparse_encoder=default_sparse_encoder, - ) - retriever_model = BiEncoderEntityRetriever( - model_name_or_path=model_path, - dictionary_name_or_path=dictionary_path, - hybrid_search=hybrid_search, - similarity_metric=similarity_metric, - max_length=max_length, - index_batch_size=index_batch_size, - sparse_weight=sparse_weight, - preprocessor=preprocessor, - ) + model_name_or_path = cls.__get_model_path( + model_name_or_path=model_name_or_path, + hybrid_search=hybrid_search, + force_hybrid_search=force_hybrid_search, + ) - return cls(retriever_model=retriever_model, mention_preprocessor=preprocessor) + if model_name_or_path == "exact-string-match": + candidate_generator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) + else: + candidate_generator = BiEncoderCandidateGenerator( + model_name_or_path=model_name_or_path, + dictionary_name_or_path=dictionary_name_or_path, + hybrid_search=hybrid_search, + similarity_metric=similarity_metric, + max_length=max_length, + index_batch_size=index_batch_size, + sparse_weight=sparse_weight, + preprocessor=preprocessor, + ) + + logger.info("Load model `%s` with dictionary `%s`", model_name_or_path, dictionary_name_or_path) + + return cls(candidate_generator=candidate_generator, preprocessor=preprocessor) @staticmethod def __get_model_path( - model_name_or_path: Union[str, Path], hybrid_search: bool = False, default_sparse_encoder: bool = False + model_name_or_path: Union[str, Path], hybrid_search: bool = False, force_hybrid_search: bool = False ) -> str: """ Try to figure out what model the user wants @@ -988,8 +1103,6 @@ def __get_model_path( if isinstance(model_name_or_path, str): - model_name_or_path = model_name_or_path.lower() - if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( f"""Unknown model `{model_name_or_path}`! \n @@ -998,7 +1111,6 @@ def __get_model_path( ) if hybrid_search: - # load model by entity_type if model_name_or_path in ENTITY_TYPES: # check if we have a hybrid pre-trained model @@ -1006,53 +1118,52 @@ def __get_model_path( model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] else: # check if user really wants to use hybrid search anyway - if not default_sparse_encoder: + if not force_hybrid_search: raise ValueError( - f"""Model for entity type `{model_name_or_path}` was not trained for hybrid search! \n - If you want to proceed anyway please pass `default_sparse_encoder=True`: + f""" + Model for entity type `{model_name_or_path}` was not trained for hybrid search! + If you want to proceed anyway please pass `force_hybrid_search=True`: we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. 
""" ) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] else: - if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not default_sparse_encoder: + if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: raise ValueError( - f"""Model `{model_name_or_path}` was not trained for hybrid search! \n - If you want to proceed anyway please pass `default_sparse_encoder=True`: + f""" + Model `{model_name_or_path}` was not trained for hybrid search! + If you want to proceed anyway please pass `force_hybrid_search=True`: we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. """ ) - return model_name_or_path + else: + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + + return model_name_or_path @staticmethod - def __get_dictionary_path(model_name: str, dictionary_name_or_path: Optional[Union[str, Path]] = None) -> str: + def __get_dictionary_path( + model_name_or_path: str, dictionary_name_or_path: Optional[Union[str, Path]] = None + ) -> str: """ Try to figure out what dictionary (depending on the model) the user wants """ - if model_name in STRING_MATCHING_MODELS and dictionary_name_or_path is None: + if model_name_or_path in STRING_MATCHING_MODELS and dictionary_name_or_path is None: raise ValueError("When using a string-matchin retriever you must specify `dictionary_name_or_path`!") - if dictionary_name_or_path not in MODELS and dictionary_name_or_path not in ENTITY_TYPES: - raise ValueError( - f"""Unknown dictionary `{dictionary_name_or_path}`! \n - Available entity types are: {ENTITY_TYPES} \n - If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)` - """ - ) - if dictionary_name_or_path is not None: if dictionary_name_or_path in ENTITY_TYPES: - dictionary_name_or_path = ENTITY_TYPE_TO_NEL_DICTIONARY[dictionary_name_or_path] + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[dictionary_name_or_path] + else: + if model_name_or_path in MODEL_NAME_TO_DICTIONARY: + dictionary_name_or_path = MODEL_NAME_TO_DICTIONARY[model_name_or_path] + elif model_name_or_path in ENTITY_TYPE_TO_DICTIONARY: + dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[model_name_or_path] else: - if model_name in MODEL_NAME_TO_NEL_DICTIONARY: - dictionary_name_or_path = MODEL_NAME_TO_NEL_DICTIONARY[dictionary_name_or_path] - else: - raise ValueError( - """When using a custom model you need to specify a dictionary. - Available options are: 'disease', 'chemical', 'gene' and 'species'. - Or provide a path to a dictionary file.""" - ) + raise ValueError( + f"When using a custom model you need to specify a dictionary. Available options are: {ENTITY_TYPES}. Or provide a path to a dictionary file." + ) return dictionary_name_or_path diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index d25f76ac6..a285af5c7 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,60 +1,66 @@ # from flair.data import Sentence -# from flair.models.biomedical_entity_linking import BioNelDictionary +# from flair.models.biomedical_entity_linking import ( +# BiomedicalEntityLinker, +# BiomedicalEntityLinkingDictionary, +# ) # from flair.nn import Classifier -# from flair.tokenization import SciSpacyTokenizer -# def test_bionel_dictionary(): + +# def test_bel_dictionary(): # """ # Check data in dictionary is what we expect. # Hard to define a good test as dictionaries are DYNAMIC, # i.e. 
they can change over time # """ -# dictionary = BioNelDictionary.load("disease") +# dictionary = BiomedicalEntityLinkingDictionary.load("disease") # _, identifier = next(dictionary.stream()) # assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) -# dictionary = BioNelDictionary.load("ctd-disease") +# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-disease") # _, identifier = next(dictionary.stream()) # assert identifier.startswith("MESH:") -# dictionary = BioNelDictionary.load("ctd-chemical") +# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemical") # _, identifier = next(dictionary.stream()) # assert identifier.startswith("MESH:") -# dictionary = BioNelDictionary.load("chemical") +# dictionary = BiomedicalEntityLinkingDictionary.load("chemical") # _, identifier = next(dictionary.stream()) # assert identifier.startswith("MESH:") -# dictionary = BioNelDictionary.load("ncbi-taxonomy") +# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") # _, identifier = next(dictionary.stream()) # assert identifier.isdigit() -# dictionary = BioNelDictionary.load("species") +# dictionary = BiomedicalEntityLinkingDictionary.load("species") # _, identifier = next(dictionary.stream()) # assert identifier.isdigit() -# dictionary = BioNelDictionary.load("ncbi-gene") +# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") # _, identifier = next(dictionary.stream()) # assert identifier.isdigit() -# dictionary = BioNelDictionary.load("gene") +# dictionary = BiomedicalEntityLinkingDictionary.load("gene") # _, identifier = next(dictionary.stream()) # assert identifier.isdigit() # def test_biomedical_entity_linking(): -# sentence = Sentence( -# "Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer() -# ) -# ner_tagger = Classifier.load("hunflair-disease") -# ner_tagger.predict(sentence) -# nen_tagger = BiomedicalEntityLinker.load("disease") -# nen_tagger.predict(sentence) -# for tag in sentence.get_labels(): -# print(tag) + +# sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# tagger = Classifier.load("hunflair") +# tagger.predict(sentence) + +# disease_linker = BiomedicalEntityLinker.load("disease", hybrid_search=True) +# disease_linker.predict(sentence) + +# gene_linker = BiomedicalEntityLinker.load("gene", hybrid_search=False) + +# breakpoint() # if __name__ == "__main__": -# test_bionel_dictionary() -# test_biomedical_entity_linking() +# # test_bel_dictionary() +# test_biomedical_entity_linking() From 48e8ae7dd6b6be30e699e760ee8bec750e6931d6 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Tue, 2 May 2023 19:39:15 +0200 Subject: [PATCH 11/58] fix(bionel): assign entity type - predict only on mentions of give entity type --- flair/models/biomedical_entity_linking.py | 163 ++++++++++++++-------- 1 file changed, 102 insertions(+), 61 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 344ff187c..1298eef6b 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -35,23 +35,23 @@ logger = logging.getLogger("flair") -PRETRAINED_MODELS = [ +PRETRAINED_DENSE_MODELS = [ "cambridgeltl/SapBERT-from-PubMedBERT-fulltext", ] # Dense + sparse retrieval -PRETRAINED_HYBRID_MODELS = [ - "dmis-lab/biosyn-sapbert-bc5cdr-disease", - "dmis-lab/biosyn-sapbert-ncbi-disease", - "dmis-lab/biosyn-sapbert-bc5cdr-chemical", - 
"dmis-lab/biosyn-biobert-bc5cdr-disease", - "dmis-lab/biosyn-biobert-ncbi-disease", - "dmis-lab/biosyn-biobert-bc5cdr-chemical", - "dmis-lab/biosyn-biobert-bc2gn", - "dmis-lab/biosyn-sapbert-bc2gn", -] +PRETRAINED_HYBRID_MODELS = { + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-sapbert-ncbi-disease": "disease", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-biobert-ncbi-disease": "disease", + "dmis-lab/biosyn-biobert-bc5cdr-chemical": "chemical", + "dmis-lab/biosyn-biobert-bc2gn": "gene", + "dmis-lab/biosyn-sapbert-bc2gn": "gene", +} -PRETRAINED_MODELS = PRETRAINED_HYBRID_MODELS + PRETRAINED_MODELS +PRETRAINED_MODELS = list(PRETRAINED_HYBRID_MODELS) + PRETRAINED_DENSE_MODELS # just in case we add: fuzzy search, Levenstein, ... STRING_MATCHING_MODELS = ["exact-string-match"] @@ -60,6 +60,13 @@ ENTITY_TYPES = ["disease", "chemical", "gene", "species"] +ENTITY_TYPE_TO_LABELS = { + "disease": "diseases", + "gene": "genes", + "species": "species", + "chemical": "chemical", +} + ENTITY_TYPE_TO_HYBRID_MODEL = { "disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease", "chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical", @@ -80,6 +87,13 @@ "chemical": "ctd-chemical", } +ENTITY_TYPE_TO_ANNOTATION_LAYER = { + "disease": "diseases", + "gene": "genes", + "chemical": "chemicals", + "species": "species", +} + BIOMEDICAL_DICTIONARIES = { "ctd-disease": CTD_DISEASE_DICTIONARY, "ctd-chemical": CTD_CHEMICAL_DICTIONARY, @@ -438,7 +452,7 @@ def __call__(self, mentions: List[str]) -> torch.Tensor: def save(self, path: Path) -> None: with path.open("wb") as fout: pickle.dump(self.encoder, fout) - logger.info("Sparse encoder saved in %s", path) + # logger.info("Sparse encoder saved in %s", path) @classmethod def load(cls, path: Path) -> "BigramTfIDFVectorizer": @@ -448,7 +462,7 @@ def load(cls, path: Path) -> "BigramTfIDFVectorizer": newVectorizer = cls() with open(path, "rb") as fin: newVectorizer.encoder = pickle.load(fin) - logger.info("Sparse encoder loaded from %s", path) + # logger.info("Sparse encoder loaded from %s", path) return newVectorizer @@ -785,7 +799,7 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str if embeddings_cache_file.exists(): with embeddings_cache_file.open("rb") as fp: - logger.info("Load cached emebddings from %s", embeddings_cache_file) + logger.info("Load cached emebddings from: %s", embeddings_cache_file) embeddings = pickle.load(fp) else: @@ -946,9 +960,16 @@ class BiomedicalEntityLinker: entity / concept to these mentions according to a knowledge base / dictionary. 
""" - def __init__(self, candidate_generator: AbstractCandidateGenerator, preprocessor: AbstractEntityPreprocessor): + def __init__( + self, + candidate_generator: AbstractCandidateGenerator, + preprocessor: AbstractEntityPreprocessor, + entity_type: str, + ): self.preprocessor = preprocessor self.candidate_generator = candidate_generator + self.entity_type = entity_type + self.annotation_layer = ENTITY_TYPE_TO_ANNOTATION_LAYER[self.entity_type] def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, str, int]) -> EntityLinkingLabel: """ @@ -983,7 +1004,8 @@ def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, st ) def extract_mentions( - self, sentences: List[Sentence], input_entity_annotation_layer: Optional[str] = None + self, + sentences: List[Sentence], ) -> Tuple[List[int], List[Span], List[str]]: """ Unpack all mentions in sentences for batch search. @@ -994,7 +1016,7 @@ def extract_mentions( data_points = [] mentions = [] for i, sentence in enumerate(sentences): - for entity in sentence.get_labels(input_entity_annotation_layer): + for entity in sentence.get_labels(self.annotation_layer): source.append(i) data_points.append(entity.data_point) mentions.append( @@ -1003,12 +1025,14 @@ def extract_mentions( else entity.data_point.text, ) + assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" + return source, data_points, mentions def predict( self, sentences: Union[List[Sentence], Sentence], - input_entity_annotation_layer: str = None, + # input_entity_annotation_layer: str = None, top_k: int = 1, ) -> None: """ @@ -1016,7 +1040,6 @@ def predict( with tag input_entity_annotation_layer. :param sentences: One or more sentences to run the prediction on - :param input_entity_annotation_layer: Entity type to run the prediction on :param top_k: Number of best-matching entity / concept identifiers which should be predicted per entity mention """ @@ -1028,11 +1051,9 @@ def predict( self.preprocessor.initialize(sentences) # Build label name - label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" + # label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" - source, data_points, mentions = self.extract_mentions( - sentences=sentences, input_entity_annotation_layer=input_entity_annotation_layer - ) + source, data_points, mentions = self.extract_mentions(sentences=sentences) # Retrieve top-k concept / entity candidates predictions = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) @@ -1041,7 +1062,7 @@ def predict( for i, data_point, prediction in zip(source, data_points, predictions): sentences[i].add_label( - typename=label_name, + typename=self.annotation_layer, value_or_label=self.build_entity_linking_label(prediction=prediction, data_point=data_point), ) @@ -1057,6 +1078,7 @@ def load( preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, + entity_type: Optional[str] = None, ): """ Loads a model for biomedical named entity normalization. 
@@ -1069,11 +1091,15 @@ def load( ) if isinstance(model_name_or_path, str): - model_name_or_path = cls.__get_model_path( + model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( model_name_or_path=model_name_or_path, + entity_type=entity_type, hybrid_search=hybrid_search, force_hybrid_search=force_hybrid_search, ) + else: + assert entity_type is not None, "When using a custom model you must specify `entity_type`" + assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! Must be one of: {ENTITY_TYPES}" if model_name_or_path == "exact-string-match": candidate_generator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) @@ -1089,62 +1115,77 @@ def load( preprocessor=preprocessor, ) - logger.info("Load model `%s` with dictionary `%s`", model_name_or_path, dictionary_name_or_path) + logger.info( + "BiomedicalEntityLinker predicts: Entity type: %s with Dictionary `%s`", + entity_type, + dictionary_name_or_path, + ) - return cls(candidate_generator=candidate_generator, preprocessor=preprocessor) + return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) @staticmethod - def __get_model_path( - model_name_or_path: Union[str, Path], hybrid_search: bool = False, force_hybrid_search: bool = False - ) -> str: + def __get_model_path_and_entity_type( + model_name_or_path: Union[str, Path], + entity_type: Optional[str] = None, + hybrid_search: bool = False, + force_hybrid_search: bool = False, + ) -> Tuple[str, str]: """ Try to figure out what model the user wants """ - if isinstance(model_name_or_path, str): - - if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: - raise ValueError( - f"""Unknown model `{model_name_or_path}`! \n - Available entity types are: {ENTITY_TYPES} \n - If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" - ) + if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: + raise ValueError( + f"""Unknown model `{model_name_or_path}`! \n + Available entity types are: {ENTITY_TYPES} \n + If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" + ) - if hybrid_search: - # load model by entity_type - if model_name_or_path in ENTITY_TYPES: - # check if we have a hybrid pre-trained model - if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: - model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] - else: - # check if user really wants to use hybrid search anyway - if not force_hybrid_search: - raise ValueError( - f""" - Model for entity type `{model_name_or_path}` was not trained for hybrid search! - If you want to proceed anyway please pass `force_hybrid_search=True`: - we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. 
- """ - ) - model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext": + assert entity_type is not None, f"For model {model_name_or_path} you must specify `entity_type`" + + entity_type = None + if hybrid_search: + # load model by entity_type + if model_name_or_path in ENTITY_TYPES: + # check if we have a hybrid pre-trained model + if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: + entity_type = model_name_or_path + model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] else: - if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: + # check if user really wants to use hybrid search anyway + if not force_hybrid_search: raise ValueError( f""" - Model `{model_name_or_path}` was not trained for hybrid search! + Model for entity type `{model_name_or_path}` was not trained for hybrid search! If you want to proceed anyway please pass `force_hybrid_search=True`: we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. """ ) - + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] else: + if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: + raise ValueError( + f""" + Model `{model_name_or_path}` was not trained for hybrid search! + If you want to proceed anyway please pass `force_hybrid_search=True`: + we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. + """ + ) + entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] + + else: + if model_name_or_path in ENTITY_TYPES: model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] - return model_name_or_path + assert entity_type is not None, f"Impossible to determine entity type for model `{model_name_or_path}`" + + return model_name_or_path, entity_type @staticmethod def __get_dictionary_path( - model_name_or_path: str, dictionary_name_or_path: Optional[Union[str, Path]] = None + model_name_or_path: str, + dictionary_name_or_path: Optional[Union[str, Path]] = None, ) -> str: """ Try to figure out what dictionary (depending on the model) the user wants From e6b57eb1c9239374de2c7e61fc79b0c1c2ddf72d Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Tue, 2 May 2023 20:14:00 +0200 Subject: [PATCH 12/58] fix(biencoder): set sparse encoder and weight --- flair/models/biomedical_entity_linking.py | 44 ++++++++++------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 1298eef6b..887b9929d 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -661,18 +661,17 @@ def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: return dense_index - def _fit_and_cache_sparse_encoder(self, sparse_encoder_path: str, sparse_weight_path: str): + def _fit_sparse_encoder(self): """Fit sparse encoder to current dictionary""" - sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT logger.info( "Hybrid model has no pretrained sparse encoder. 
Fit to dictionary `%s` (sparse_weight=%s)", self.dictionary_name_or_path, - sparse_weight, + self.sparse_weight, ) - sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary]) - sparse_encoder.save(sparse_encoder_path) - torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) + self.sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary]) + # sparse_encoder.save(Path(sparse_encoder_path)) + # torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) def _set_sparse_weigth_and_encoder( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] @@ -692,6 +691,7 @@ def _set_sparse_weigth_and_encoder( filename="sparse_encoder.pk", cache_dir=flair.cache_root / "models" / model_name_or_path, ) + self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) if not os.path.exists(sparse_weight_path): sparse_weight_path = hf_hub_download( @@ -699,29 +699,23 @@ def _set_sparse_weigth_and_encoder( filename="sparse_weight.pt", cache_dir=flair.cache_root / "models" / model_name_or_path, ) + self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() else: if self.force_hybrid_search: - if not os.path.exists(sparse_encoder_path) and not os.path.exists(sparse_weight_path): - self._fit_and_cache_sparse_encoder( - sparse_encoder_path=sparse_encoder_path, sparse_weight_path=sparse_weight_path - ) + self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + self._fit_sparse_encoder() else: raise ValueError( - f"Hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + f"A: Hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" ) else: - if not os.path.exists(sparse_encoder_path) and not os.path.exists(sparse_weight_path): - if self.force_hybrid_search: - self._fit_and_cache_sparse_encoder( - sparse_encoder_path=sparse_encoder_path, sparse_weight_path=sparse_weight_path - ) - else: - raise ValueError( - f"Hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" - ) - - self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) - self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() + if self.force_hybrid_search: + self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + self._fit_sparse_encoder() + else: + raise ValueError( + f"Local hybrid model has no pretrained sparse encoder. 
Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + ) def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ @@ -793,7 +787,7 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str pp_name = self.preprocessor.name if self.preprocessor is not None else "null" - embeddings_cache_file = cache_folder / f"{file_name}_pp={pp_name}.pk" + embeddings_cache_file = cache_folder / f"{file_name}-{pp_name}.pk" # If exists, load the cached dictionary indices if embeddings_cache_file.exists(): @@ -1113,6 +1107,7 @@ def load( index_batch_size=index_batch_size, sparse_weight=sparse_weight, preprocessor=preprocessor, + force_hybrid_search=force_hybrid_search, ) logger.info( @@ -1144,7 +1139,6 @@ def __get_model_path_and_entity_type( if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext": assert entity_type is not None, f"For model {model_name_or_path} you must specify `entity_type`" - entity_type = None if hybrid_search: # load model by entity_type if model_name_or_path in ENTITY_TYPES: From 0d3cec2ecfcfc782794ffb7da6f507881190f585 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Thu, 11 May 2023 18:47:16 +0200 Subject: [PATCH 13/58] fix(bionel): address comments - fix mypy typing - fix typos - update docstrings - rm faiss from requirements - better naming - allow user to specify annotation layer in predict - allow no mentions --- flair/data.py | 72 ++-- flair/datasets/__init__.py | 12 +- flair/models/biomedical_entity_linking.py | 467 +++++++++++----------- 3 files changed, 272 insertions(+), 279 deletions(-) diff --git a/flair/data.py b/flair/data.py index 1b4176423..8069ecf53 100644 --- a/flair/data.py +++ b/flair/data.py @@ -6,7 +6,7 @@ from collections import Counter, defaultdict from operator import itemgetter from pathlib import Path -from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast +from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast import torch from deprecated.sphinx import deprecated @@ -433,61 +433,71 @@ def __len__(self) -> int: raise NotImplementedError -class EntityLinkingLabel(Label): - """ - Label class models entity linking annotations. Each entity linking label has a data point it refers - to as well as the identifier and name of the concept / entity from a knowledge base or ontology. - - Optionally, additional concepts identifier and the database name can be provided. - """ +class EntityLinkingCandidate: + """Represent a single candidate returned by a CandidateGenerator""" def __init__( self, - data_point: DataPoint, concept_id: str, concept_name: str, + database_name: str, score: float = 1.0, additional_ids: Optional[Union[List[str], str]] = None, - database: Optional[str] = None, ): """ - Initializes the label instance. 
- - :param data_point: Data point / span the label refers to :param concept_id: Identifier of the entity / concept from the knowledge base / ontology :param concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology :param score: Matching score of the entity / concept according to the entity mention :param additional_ids: List of additional identifiers for the concept / entity in the KB / ontology - :param database: Name of the knowlege base / ontology + :param database_name: Name of the knowlege base / ontology """ - super().__init__(data_point, concept_id, score) + self.concept_id = concept_id self.concept_name = concept_name - self.database = database - - if isinstance(additional_ids, str): - additional_ids = [additional_ids] + self.database_name = database_name + self.score = score self.additional_ids = additional_ids - def spawn(self, value: str, score: float = 1.0): - return EntityLinkingLabel( - data_point=self.data_point, - concept_id=value, - score=score, - concept_name=self.concept_name, - additional_ids=self.additional_ids, - database=self.database, - ) + +class EntityLinkingLabel(Label): + """ + Label class models entity linking annotations. Each entity linking label has a data point it refers + to as well as the identifier and name of the concept / entity from a knowledge base or ontology. + Optionally, additional concepts identifier and the database name can be provided. + """ + + def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidate]): + """ + Initializes the label instance. + :param data_point: Data point / span the label refers to + :param candidates: **sorted** list of candidates from candidate generator + """ + + def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x > y): + for i, el in enumerate(lst[1:]): + if comparison(key(el), key(lst[i])): + return False + return True + + # candidates must be sorted, regardless if higher is better or not + assert is_sorted(candidates, key=lambda x: x.score) or is_sorted( + candidates, key=lambda x: x.score, comparison=lambda x, y: x < y + ), "List of candidates must be sorted!" 
+ + super().__init__(data_point, candidates[0].concept_id, candidates[0].score) + self.candidates = candidates + self.concept_name = self.candidates[0].concept_name + self.database_name = self.candidates[0].database_name def __str__(self): return ( f"{self.data_point.unlabeled_identifier}{flair._arrow} " - f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" ) def __repr__(self): return ( f"{self.data_point.unlabeled_identifier}{flair._arrow} " - f"{self.concept_name} - {self.database}:{self._value} ({round(self._score, 4)})" + f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" ) def __len__(self): @@ -499,7 +509,7 @@ def __eq__(self, other): and self.data_point == other.data_point and self.concept_name == other.concept_name and self.identifier == other.identifier - and self.database == other.database + and self.database_name == other.database_name and self.score == other.score ) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 094169c6b..c008cb0dc 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -37,8 +37,8 @@ CLL, CRAFT, CRAFT_V4, - CTD_CHEMICAL_DICTIONARY, - CTD_DISEASE_DICTIONARY, + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, DECA, FSU, GELLUS, @@ -394,10 +394,10 @@ "LINNEAUS", "LOCTEXT", "MIRNA", - "NEL_NCBI_HUMAN_GENE_DICT", - "NEL_NCBI_TAXONOMY_DICT", - "NEL_CTD_CHEMICAL_DICT", - "NEL_CTD_DISEASE_DICT", + "NCBI_GENE_HUMAN_DICTIONARY", + "NCBI_TAXONOMY_DICTIONARY", + "CTD_DISEASES_DICTIONARY", + "CTD_CHEMICALS_DICTIONARY", "NCBI_DISEASE", "ONTONOTES", "OSIRIS", diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 887b9929d..c181d5455 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -6,13 +6,13 @@ import string import subprocess import tempfile +import warnings from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum, auto from pathlib import Path from typing import Dict, Iterator, List, Optional, Tuple, Union -import faiss import numpy as np import torch from huggingface_hub import hf_hub_download @@ -21,17 +21,29 @@ from tqdm import tqdm import flair -from flair.data import EntityLinkingLabel, Label, Sentence, Span +from flair.data import EntityLinkingCandidate, EntityLinkingLabel, Label, Sentence, Span from flair.datasets import ( - CTD_CHEMICAL_DICTIONARY, - CTD_DISEASE_DICTIONARY, + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, NCBI_GENE_HUMAN_DICTIONARY, NCBI_TAXONOMY_DICTIONARY, ) -from flair.datasets.biomedical import ParsedBiomedicalEntityLinkingDictionary +from flair.datasets.biomedical import ( + AbstractBiomedicalEntityLinkingDictionary, + ParsedBiomedicalEntityLinkingDictionary, +) from flair.embeddings import TransformerDocumentEmbeddings from flair.file_utils import cached_path +FAISS_VERSION = "1.7.4" + +try: + import faiss +except ImportError as error: + raise ImportError( + f"You need to install to run the biomedical entity linking: `pip faiss faiss-cpu=={FAISS_VERSION}`" + ) from error + logger = logging.getLogger("flair") @@ -83,8 +95,8 @@ ENTITY_TYPE_TO_DICTIONARY = { "gene": "ncbi-gene", "species": "ncbi-taxonomy", - "disease": "ctd-disease", - "chemical": "ctd-chemical", + "disease": "ctd-diseases", + "chemical": "ctd-chemicals", } ENTITY_TYPE_TO_ANNOTATION_LAYER = { @@ -95,8 +107,8 @@ } 
BIOMEDICAL_DICTIONARIES = { - "ctd-disease": CTD_DISEASE_DICTIONARY, - "ctd-chemical": CTD_CHEMICAL_DICTIONARY, + "ctd-diseases": CTD_DISEASES_DICTIONARY, + "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, "ncbi-taxonomy": NCBI_TAXONOMY_DICTIONARY, } @@ -126,20 +138,15 @@ class SimilarityMetric(Enum): class AbstractEntityPreprocessor(ABC): """ - A entity pre-processor is used to transform / clean an entity mention (recognized by - an entity recognition model in the original text). This may include removing certain characters - (e.g. punctuation) or converting characters (e.g. HTML-encoded characters) as well as - (more sophisticated) domain-specific procedures. - - This class provides the basic interface for such transformations and should be extended by - subclasses that implement concrete transformations. + A pre-processor used to transform / clean both entity mentions and entity names + This class provides the basic interface for such transformations + and must provide a `name` attribute to uniquely identify the type of preprocessing applied. """ @property @abstractmethod def name(self) -> str: """ - Define preprocessor name. This is needed to correctly cache different multiple version of the dictionary """ @@ -175,26 +182,17 @@ def initialize(self, sentences: List[Sentence]): class EntityPreprocessor(AbstractEntityPreprocessor): """ - Basic implementation of MentionPreprocessor, which supports lowercasing, typo correction - and removing of punctuation characters. - - Implementation is adapted from: + Entity preprocessor adapted from: Sung et al. 2020, Biomedical Entity Representations with Synonym Marginalization https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 """ - def __init__( - self, - lowercase: bool = True, - remove_punctuation: bool = True, - ) -> None: + def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): """ Initializes the mention preprocessor. :param lowercase: Indicates whether to perform lowercasing or not (True by default) :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) - :param punctuation_symbols: String containing all punctuation symbols that should be removed - (default is given by string.punctuation) """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation @@ -224,16 +222,11 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): """ - Implementation of MentionPreprocessor which utilizes Ab3P, an (biomedical)abbreviation definition detector, - given in: - https://github.com/ncbi-nlp/Ab3P - - Ab3P applies a set of rules reflecting simple patterns such as Alpha Beta (AB) as well as more involved cases. - The algorithm is described in detail in the following paper: - + Entity preprocessor which uses Ab3P, an (biomedical)abbreviation definition detector: Abbreviation definition identification based on automatic precision estimates. Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. PubMed ID: 18817555 + https://github.com/ncbi-nlp/Ab3P """ def __init__( @@ -244,8 +237,7 @@ def __init__( :param ab3p_path: Path to the folder containing the Ab3P implementation :param word_data_dir: Path to the word data directory - :param preprocessor: Entity mention text preprocessor that is used before trying to link - the mention text to an abbreviation. 
+ :param preprocessor: Basic entity preprocessor """ self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir @@ -304,9 +296,7 @@ def load(cls, ab3p_path: Path = None, preprocessor: Optional[AbstractEntityPrepr @classmethod def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: - """ - Downloads the Ab3P tool and all necessary data files. - """ + """Downloads the Ab3P tool and all necessary data files.""" # Download word data for Ab3P if not already downloaded ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" @@ -350,6 +340,9 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict "Rous sarcoma virus ( RSV ) is a retrovirus.": {"RSV": "Rous sarcoma virus"} } + + :param sentences: list of senternces + :result abbreviation_dict: abbreviations and their resolution detected in each input sentence """ abbreviation_dict = defaultdict(dict) @@ -416,78 +409,24 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict return abbreviation_dict -class BigramTfIDFVectorizer: - """ - Helper class to encode a list of entity mentions or dictionary entries into a sparse tensor. - - Implementation adapted from: - Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 - https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 - """ - - def __init__(self): - self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) - - def fit(self, names: List[str]): - """ - Fit vectorizer - """ - self.encoder.fit(names) - return self - - def transform(self, names: List[str]) -> torch.Tensor: - """ - Convert string names to sparse vectors - """ - vec = self.encoder.transform(names).toarray() - vec = torch.FloatTensor(vec) - return vec - - def __call__(self, mentions: List[str]) -> torch.Tensor: - """ - Short for `transform` - """ - return self.transform(mentions) - - def save(self, path: Path) -> None: - with path.open("wb") as fout: - pickle.dump(self.encoder, fout) - # logger.info("Sparse encoder saved in %s", path) - - @classmethod - def load(cls, path: Path) -> "BigramTfIDFVectorizer": - """ - Instantiate from path - """ - newVectorizer = cls() - with open(path, "rb") as fin: - newVectorizer.encoder = pickle.load(fin) - # logger.info("Sparse encoder loaded from %s", path) - - return newVectorizer - - class BiomedicalEntityLinkingDictionary: """ - A class used to load dictionary data from a custom dictionary file. - Every line in the file must be formatted as follows: - concept_unique_id||concept_name - with one line per concept name. Multiple synonyms for the same concept should - be in separate lines with the same concept_unique_id. - - Slightly modifed from Sung et al. 2020 - Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/data_loader.py#L89 + Load dictionary: either pre-definded or from path + Every line in the file must be formatted as follows: concept_id||concept_name + If multiple concept ids are associated to a given name + they must be separated by a `|`. 
""" - def __init__(self, reader): + def __init__( + self, reader: Union[AbstractBiomedicalEntityLinkingDictionary, ParsedBiomedicalEntityLinkingDictionary] + ): self.reader = reader @classmethod - def load(cls, dictionary_name_or_path: Union[Path, str]) -> "EntityLinkingDictionary": - """ - Load dictionary: either pre-definded or from path - """ + def load( + cls, dictionary_name_or_path: Union[Path, str], database_name: Optional[str] = None + ) -> "EntityLinkingDictionary": + """Load dictionary: either pre-definded or from path""" if isinstance(dictionary_name_or_path, str): if ( @@ -495,9 +434,9 @@ def load(cls, dictionary_name_or_path: Union[Path, str]) -> "EntityLinkingDictio and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES ): raise ValueError( - f"""Unkwnon dictionary `{dictionary_name_or_path}`, - Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)} \n - If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" + f"Unkwnon dictionary `{dictionary_name_or_path}`!" + f" Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)}" + " If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`" ) dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) @@ -506,20 +445,22 @@ def load(cls, dictionary_name_or_path: Union[Path, str]) -> "EntityLinkingDictio else: # use custom dictionary file - reader = ParsedBiomedicalEntityLinkingDictionary(path=dictionary_name_or_path) + assert ( + database_name is not None + ), "When providing a path to a custom dictionary you must specify the `database_name`!" + reader = ParsedBiomedicalEntityLinkingDictionary(path=dictionary_name_or_path, database_name=database_name) return cls(reader=reader) - def get_database_names(self) -> List[str]: - """ - List all database names covered by dictionary, e.g. MESH, OMIM - """ + @property + def database_name(self) -> str: + """Database name of the dictionary""" - return self.reader.get_database_names() + return self.reader.database_name def stream(self) -> Iterator[Tuple[str, str]]: """ - Stream preprocessed dictionary + Stream entries from preprocessed dictionary """ for entry in self.reader.stream(): @@ -528,26 +469,23 @@ def stream(self) -> Iterator[Tuple[str, str]]: class AbstractCandidateGenerator(ABC): """ - An entity retriever model is used to find the top-k entities / concepts of a knowledge base / - dictionary for a given entity mention in text. + Base class for a candidate genertor """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: """ - Returns the top-k entity / concept identifiers for the given entity mention. + Returns the top-k entity / concept identifiers for the each entity mention. - :param entity_mentions: Entity mention text under investigation - :param top_k: Number of (best-matching) entities from the knowledge base to return - :result: List of tuples highlighting the top-k entities. Each tuple has the following - structure (entity / concept name, concept ids, score). + :param entity_mentions: Entity mentions + :param top_k: Number of best-matching entities from the knowledge base to return + :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). 
""" class ExactMatchCandidateGenerator(AbstractCandidateGenerator): """ - Implementation of an entity retriever model which uses exact string matching to - find the entity / concept identifier for a given entity mention. + Candidate generator using exact string matching as search criterion """ def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): @@ -556,35 +494,66 @@ def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): @classmethod def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverModel": - """ - Compatibility function - """ - # Load dictionary + """Compatibility function""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) - def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: """ - Returns the top-k entity / concept identifiers for the given entity mention. Note that - the model either returns the entity with an identical name in the knowledge base / dictionary - or none. - - :param entity_mention: Entity mention under investigation - :param top_k: Number of (best-matching) entities from the knowledge base to return - :result: List of tuples highlighting the top-k entities. Each tuple has the following - structure (entity / concept name, concept ids, score). + Returns the top-k entity / concept identifiers for the each entity mention. + + :param entity_mentions: Entity mentions + :param top_k: Number of best-matching entities from the knowledge base to return + :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ - return [(em, self.name_to_id_index.get(em), 1.0) for em in entity_mentions] + return [[(em, self.name_to_id_index.get(em), 1.0)] for em in entity_mentions] -class BiEncoderCandidateGenerator(AbstractCandidateGenerator): +class BigramTfIDFVectorizer: + """ + Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level + Implementation adapted from: + Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 + https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 """ - Implementation of EntityRetrieverModel which uses dense (transformer-based) embeddings and (optionally) - sparse character-based representations, for normalizing an entity mention to specific identifiers - in a knowledge base / dictionary. - To this end, the model embeds the entity mention text and all concept names from the knowledge - base and outputs the k best-matching concepts based on embedding similarity. 
+ def __init__(self): + self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) + + def fit(self, names: List[str]): + """Learn vocabulary""" + self.encoder.fit(names) + return self + + def transform(self, names: List[str]) -> torch.Tensor: + """Convert strings to sparse vectors""" + vec = self.encoder.transform(names).toarray() + vec = torch.FloatTensor(vec) + return vec + + def __call__(self, mentions: List[str]) -> torch.Tensor: + """Short for `transform`""" + return self.transform(mentions) + + @classmethod + def load(cls, path: Path) -> "BigramTfIDFVectorizer": + """Instantiate from path""" + newVectorizer = cls() + + with open(path, "rb") as fin: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + newVectorizer.encoder = pickle.load(fin) + # logger.info("Sparse encoder loaded from %s", path) + + return newVectorizer + + +class BiEncoderCandidateGenerator(AbstractCandidateGenerator): + """ + Candidate generator using both dense (transformer-based) + and (optionally) sparse vector representations, + to search candidates in a knowledge base / dictionary. """ def __init__( @@ -594,36 +563,43 @@ def __init__( similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), max_length: int = 25, - index_batch_size: int = 1024, + batch_size: int = 1024, hybrid_search: bool = False, sparse_weight: Optional[float] = None, force_hybrid_search: bool = False, + dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, ): """ Initializes the BiEncoderEntityRetrieverModel. :param model_name_or_path: Name of or path to the transformer model to be used. :param dictionary_name_or_path: Name of or path to the transformer model to be used. 
- :param hybrid_search: Indicates whether to use sparse embeddings or not - :param use_cosine: Indicates whether to use cosine similarity (instead of inner product) - :param max_length: Maximal number of tokens used for embedding an entity mention / concept name - :param index_batch_size: Batch size used during embedding of the dictionary and top-k prediction :param similarity_metric: which metric to use to compute similarity + :param preprocessor: Preprocessing for entity mentions and names + :param max_length: Maximum number of input tokens to transformer model + :param batch_size: how many entity mentions/names to embed in one forward pass + :param hybrid_search: Indicates whether to use sparse embeddings or not :param sparse_weight: default sparse weight - :param preprocessor: Preprocessing strategy to clean and transform entity / concept names from the knowledge base + :param force_hybrid_search: if pre-trained model is not hybrid (dense+sparse) fit a sparse encoder + :param dictionary: optionally pass a dictionary """ self.model_name_or_path = model_name_or_path self.dictionary_name_or_path = dictionary_name_or_path self.preprocessor = preprocessor self.similarity_metric = similarity_metric self.max_length = max_length - self.index_batch_size = index_batch_size + self.batch_size = batch_size self.hybrid_search = hybrid_search self.sparse_weight = sparse_weight self.force_hybrid_search = force_hybrid_search - # Load dictionary - self.dictionary = list(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path).stream()) + # allow to pass custom dictionary + if dictionary is not None: + self.dictionary = dictionary + else: + self.dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path) + + self.dictionary_data = list(self.dictionary.stream()) # Load encoders self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) @@ -634,10 +610,10 @@ def __init__( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) - self.embeddings = self._load_emebddings( + self.embeddings = self._load_embeddings( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path, - batch_size=self.index_batch_size, + batch_size=self.batch_size, ) self.dense_index = self.build_dense_index(self.embeddings["dense"]) @@ -646,7 +622,7 @@ def __init__( def higher_is_better(self): """ Determine if similarity is proportional to score. - E.g. for L2 lower is better + E.g. for L2 lower is better, while INNER_PRODUCT higher is better """ return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] @@ -686,14 +662,17 @@ def _set_sparse_weigth_and_encoder( if model_name_or_path in PRETRAINED_HYBRID_MODELS: if not os.path.exists(sparse_encoder_path): + sparse_encoder_path = hf_hub_download( repo_id=model_name_or_path, filename="sparse_encoder.pk", cache_dir=flair.cache_root / "models" / model_name_or_path, ) + self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) if not os.path.exists(sparse_weight_path): + sparse_weight_path = hf_hub_download( repo_id=model_name_or_path, filename="sparse_weight.pt", @@ -719,8 +698,7 @@ def _set_sparse_weigth_and_encoder( def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ - Embeds the given numpy array of entity names, either originating from the knowledge base - or recognized in a text, into sparse representations. + Create sparse embeddings from array of entity mentions/names. 
:param entity_names: An array of entity / concept names :returns sparse_embeds np.array: Numpy array containing the sparse embeddings @@ -735,9 +713,8 @@ def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: """ - Embeds the given numpy array of entity / concept names, either originating from the - knowledge base or recognized in a text, into dense representations using a - TransformerDocumentEmbedding model. + + Create dense embeddings from array of entity mentions/names. :param names: Numpy array of entity / concept names :param batch_size: Batch size used while embedding the name @@ -752,7 +729,7 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: if show_progress: iterations = tqdm( range(0, len(inputs), batch_size), - desc="Embed inputs", + desc=f"Embedding `{self.dictionary.database_name}` dictionary:", ) else: iterations = range(0, len(inputs), batch_size) @@ -774,10 +751,8 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: return dense_embeds - def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): - """ - Computes the embeddings for the given knowledge base / dictionary. - """ + def _load_embeddings(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): + """Compute and cache the embeddings for the given knowledge base / dictionary.""" # Check for embedded dictionary in cache dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] @@ -800,7 +775,7 @@ def _load_emebddings(self, model_name_or_path: str, dictionary_name_or_path: str cache_folder.mkdir(parents=True, exist_ok=True) - names = [self.preprocessor.process_entity_name(name) for name, cui in self.dictionary] + names = [self.preprocessor.process_entity_name(name) for name, cui in self.dictionary_data] # Compute dense embeddings (if necessary) dense_embeddings = self.embed_dense(inputs=names, batch_size=batch_size, show_progress=True) @@ -829,13 +804,11 @@ def search_sparse( normalise: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: """ - Returns top-k indexes (in descending order) for the given entity mentions resp. mention - embeddings. 
+ Find candidates with sparse representations - :param score_matrix: 2d numpy array of scores + :param entity_mentions: list of entity mentions (queries) :param top_k: number of candidates to retrieve - :return res: d numpy array of ids [# of query , # of dict] - :return scores: numpy array of top scores + :param normalise: normalise scores """ mention_embeddings = self.sparse_encoder(entity_mentions) @@ -865,11 +838,14 @@ def indexing_2d(arr, cols): def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: """ - Dense search via FAISS index + Find candidates with dense representations (FAISS) + + :param entity_mentions: list of entity mentions (queries) + :param top_k: number of candidates to retrieve """ # Compute dense embedding for the given entity mention - mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.index_batch_size) + mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.batch_size) if self.similarity_metric == SimilarityMetric.COSINE: faiss.normalize_L2(mention_dense_embeds) @@ -918,13 +894,13 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids - def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, float]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: """ - Returns the top-k entities for a given entity mention. + Returns the top-k entity / concept identifiers for the each entity mention. - :param entity_mentions: Entity mentions (search queries) - :param top_k: Number of (best-matching) entities from the knowledge base to return - :result: List of tuples w/ the top-k entities: (concept name, concept ids, score). + :param entity_mentions: Entity mentions + :param top_k: Number of best-matching entities from the knowledge base to return + :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ scores, ids = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) @@ -942,17 +918,13 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[Tuple[str, str, ) return [ - tuple(self.dictionary[i]) + (score,) + [tuple(self.dictionary_data[i]) + (score,) for i, score in zip(mention_ids, mention_scores)] for mention_ids, mention_scores in zip(ids, scores) - for i, score in zip(mention_ids, mention_scores) ] class BiomedicalEntityLinker: - """ - Entity linking model which expects text/sentences with annotated entity mentions and predicts - entity / concept to these mentions according to a knowledge base / dictionary. 
- """ + """Entity linking model for the biomedical domain""" def __init__( self, @@ -963,70 +935,68 @@ def __init__( self.preprocessor = preprocessor self.candidate_generator = candidate_generator self.entity_type = entity_type - self.annotation_layer = ENTITY_TYPE_TO_ANNOTATION_LAYER[self.entity_type] + self.annotation_layers = [ENTITY_TYPE_TO_ANNOTATION_LAYER.get(self.entity_type, "ner")] - def build_entity_linking_label(self, data_point: Span, prediction: Tuple[str, str, int]) -> EntityLinkingLabel: - """ - Create entity linking label from retriever model result - """ + def build_candidate(self, candidate: Tuple[str, str, float]): + """Get nice container with all info about entity linking candidate""" + + concept_name = candidate[0] + concept_id = candidate[1] + score = candidate[2] + database_name = self.candidate_generator.dictionary.database_name - # if concept identifier is made up of multiple ids, separated by '|' - # separate it into cui and additional_labels - cui = prediction[1] - if "|" in cui: - labels = cui.split("|") - cui = labels[0] + if "|" in concept_id: + labels = concept_id.split("|") + concept_id = labels[0] additional_labels = labels[1:] else: additional_labels = None - # determine database: - if ":" in cui: - cui_parts = cui.split(":") - database = ":".join(cui_parts[0:-1]) - cui = cui_parts[-1] - else: - database = None - - return EntityLinkingLabel( - data_point=data_point, - concept_id=cui, - concept_name=prediction[0], + return EntityLinkingCandidate( + concept_id=concept_id, + concept_name=concept_name, + score=score, additional_ids=additional_labels, - database=database, - score=prediction[2], + database_name=database_name, ) def extract_mentions( self, sentences: List[Sentence], + annotation_layers: Optional[List[str]] = None, ) -> Tuple[List[int], List[Span], List[str]]: - """ - Unpack all mentions in sentences for batch search. - Output is list of (sentence index, mention text). - """ + """Unpack all mentions in sentences for batch search.""" source = [] data_points = [] mentions = [] + mention_annotationa_layers = [] + + # get all valid annotation layers: pre-determined and user input + annotation_layers = ( + self.annotation_layers + annotation_layers if annotation_layers is not None else self.annotation_layers + ) + for i, sentence in enumerate(sentences): - for entity in sentence.get_labels(self.annotation_layer): - source.append(i) - data_points.append(entity.data_point) - mentions.append( - self.preprocessor.process_mention(entity, sentence) - if self.preprocessor is not None - else entity.data_point.text, - ) + for annotation_layer in annotation_layers: + for entity in sentence.get_labels(annotation_layer): + source.append(i) + data_points.append(entity.data_point) + mentions.append( + self.preprocessor.process_mention(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text, + ) + mention_annotationa_layers.append(annotation_layer) - assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" + # assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" - return source, data_points, mentions + return source, data_points, mentions, mention_annotationa_layers def predict( self, sentences: Union[List[Sentence], Sentence], - # input_entity_annotation_layer: str = None, + annotation_layers: Optional[List[str]] = None, top_k: int = 1, ) -> None: """ @@ -1034,8 +1004,8 @@ def predict( with tag input_entity_annotation_layer. 
:param sentences: One or more sentences to run the prediction on - :param top_k: Number of best-matching entity / concept identifiers which should be predicted - per entity mention + :param annotation_layers: list of annotation layers to extract entity mentions + :param top_k: Number of best-matching entity / concept identifiers """ # make sure sentences is a list of sentences if not isinstance(sentences, list): @@ -1047,17 +1017,23 @@ def predict( # Build label name # label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" - source, data_points, mentions = self.extract_mentions(sentences=sentences) + source, data_points, mentions, mentions_annotation_layers = self.extract_mentions( + sentences=sentences, annotation_layers=annotation_layers + ) # Retrieve top-k concept / entity candidates - predictions = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) + candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) # Add a label annotation for each candidate - for i, data_point, prediction in zip(source, data_points, predictions): + for i, data_point, mention_candidates, mentions_annotation_layer in zip( + source, data_points, candidates, mentions_annotation_layers + ): sentences[i].add_label( - typename=self.annotation_layer, - value_or_label=self.build_entity_linking_label(prediction=prediction, data_point=data_point), + typename=mentions_annotation_layer, + value_or_label=EntityLinkingLabel( + data_point=data_point, candidates=[self.build_candidate(c) for c in mention_candidates] + ), ) @classmethod @@ -1067,12 +1043,13 @@ def load( dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, max_length: int = 25, - index_batch_size: int = 1024, + batch_size: int = 1024, similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, + dictionary: Optional[List[Tuple[str, str]]] = None, ): """ Loads a model for biomedical named entity normalization. @@ -1104,16 +1081,18 @@ def load( hybrid_search=hybrid_search, similarity_metric=similarity_metric, max_length=max_length, - index_batch_size=index_batch_size, + batch_size=batch_size, sparse_weight=sparse_weight, preprocessor=preprocessor, force_hybrid_search=force_hybrid_search, + dictionary=dictionary, ) logger.info( - "BiomedicalEntityLinker predicts: Entity type: %s with Dictionary `%s`", - entity_type, + "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s) with %s classes", dictionary_name_or_path, + entity_type, + len(candidate_generator.dictionary_data), ) return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) @@ -1131,9 +1110,9 @@ def __get_model_path_and_entity_type( if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( - f"""Unknown model `{model_name_or_path}`! \n - Available entity types are: {ENTITY_TYPES} \n - If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`""" + f"Unknown model `{model_name_or_path}`!" + f" Available entity types are: {ENTITY_TYPES}" + " If you want to pass a local path please use the `Path` class, i.e. 
`model_name_or_path=Path(my_path)`" ) if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext": @@ -1172,7 +1151,9 @@ def __get_model_path_and_entity_type( if model_name_or_path in ENTITY_TYPES: model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] - assert entity_type is not None, f"Impossible to determine entity type for model `{model_name_or_path}`" + assert ( + entity_type is not None + ), f"Impossible to determine entity type for model `{model_name_or_path}`: please specify via `entity_type`" return model_name_or_path, entity_type @@ -1186,7 +1167,9 @@ def __get_dictionary_path( """ if model_name_or_path in STRING_MATCHING_MODELS and dictionary_name_or_path is None: - raise ValueError("When using a string-matchin retriever you must specify `dictionary_name_or_path`!") + raise ValueError( + "When using a string-matching candidate generator you must specify `dictionary_name_or_path`!" + ) if dictionary_name_or_path is not None: if dictionary_name_or_path in ENTITY_TYPES: From 99a109f153fa64ba68f50f75877bad5840b42868 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 12 May 2023 18:07:53 +0200 Subject: [PATCH 14/58] fix(candidate_generator): container for search result --- flair/models/biomedical_entity_linking.py | 65 ++++++++++++----------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index c181d5455..1c9fcf6cb 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -473,7 +473,7 @@ class AbstractCandidateGenerator(ABC): """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -482,6 +482,29 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). 
""" + def build_candidate(self, candidate: Tuple[str, str, float]) -> EntityLinkingCandidate: + """Get nice container with all info about entity linking candidate""" + + concept_name = candidate[0] + concept_id = candidate[1] + score = candidate[2] + database_name = self.dictionary.database_name + + if "|" in concept_id: + labels = concept_id.split("|") + concept_id = labels[0] + additional_labels = labels[1:] + else: + additional_labels = None + + return EntityLinkingCandidate( + concept_id=concept_id, + concept_name=concept_name, + score=score, + additional_ids=additional_labels, + database_name=database_name, + ) + class ExactMatchCandidateGenerator(AbstractCandidateGenerator): """ @@ -490,14 +513,14 @@ class ExactMatchCandidateGenerator(AbstractCandidateGenerator): def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): # Build index which maps concept / entity names to concept / entity ids - self.name_to_id_index = dict(dictionary.data) + self.name_to_id_index = dict(list(dictionary.stream())) @classmethod def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverModel": """Compatibility function""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -506,7 +529,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ - return [[(em, self.name_to_id_index.get(em), 1.0)] for em in entity_mentions] + return [[self.build_candidate((em, self.name_to_id_index.get(em), 1.0))] for em in entity_mentions] class BigramTfIDFVectorizer: @@ -894,7 +917,7 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, str, float]]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. 
@@ -918,7 +941,10 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, ) return [ - [tuple(self.dictionary_data[i]) + (score,) for i, score in zip(mention_ids, mention_scores)] + [ + self.build_candidate(tuple(self.dictionary_data[i]) + (score,)) + for i, score in zip(mention_ids, mention_scores) + ] for mention_ids, mention_scores in zip(ids, scores) ] @@ -937,29 +963,6 @@ def __init__( self.entity_type = entity_type self.annotation_layers = [ENTITY_TYPE_TO_ANNOTATION_LAYER.get(self.entity_type, "ner")] - def build_candidate(self, candidate: Tuple[str, str, float]): - """Get nice container with all info about entity linking candidate""" - - concept_name = candidate[0] - concept_id = candidate[1] - score = candidate[2] - database_name = self.candidate_generator.dictionary.database_name - - if "|" in concept_id: - labels = concept_id.split("|") - concept_id = labels[0] - additional_labels = labels[1:] - else: - additional_labels = None - - return EntityLinkingCandidate( - concept_id=concept_id, - concept_name=concept_name, - score=score, - additional_ids=additional_labels, - database_name=database_name, - ) - def extract_mentions( self, sentences: List[Sentence], @@ -1031,9 +1034,7 @@ def predict( sentences[i].add_label( typename=mentions_annotation_layer, - value_or_label=EntityLinkingLabel( - data_point=data_point, candidates=[self.build_candidate(c) for c in mention_candidates] - ), + value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), ) @classmethod From 301988e877e7d037bbac02f587ac5342c2a5601a Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 19 May 2023 13:20:58 +0200 Subject: [PATCH 15/58] fix(predict): default annotation layer iff not provided by use - fix typo --- flair/models/biomedical_entity_linking.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 1c9fcf6cb..f07be8dfd 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -973,12 +973,10 @@ def extract_mentions( source = [] data_points = [] mentions = [] - mention_annotationa_layers = [] + mention_annotation_layers = [] - # get all valid annotation layers: pre-determined and user input - annotation_layers = ( - self.annotation_layers + annotation_layers if annotation_layers is not None else self.annotation_layers - ) + # use default annotation layers only if are not provided + annotation_layers = annotation_layers if annotation_layers is not None else self.annotation_layers for i, sentence in enumerate(sentences): for annotation_layer in annotation_layers: @@ -990,11 +988,11 @@ def extract_mentions( if self.preprocessor is not None else entity.data_point.text, ) - mention_annotationa_layers.append(annotation_layer) + mention_annotation_layers.append(annotation_layer) # assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" - return source, data_points, mentions, mention_annotationa_layers + return source, data_points, mentions, mention_annotation_layers def predict( self, From c14d6ce9b8787eb6beecc44898deeceefca3fe35 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 19 May 2023 13:21:30 +0200 Subject: [PATCH 16/58] fix(label): scores can be >= or <= --- flair/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/data.py b/flair/data.py index 8069ecf53..ada25b0ea 100644 --- a/flair/data.py +++ b/flair/data.py @@ -472,7 
+472,7 @@ def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidat :param candidates: **sorted** list of candidates from candidate generator """ - def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x > y): + def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x >= y): for i, el in enumerate(lst[1:]): if comparison(key(el), key(lst[i])): return False @@ -480,7 +480,7 @@ def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x > y): # candidates must be sorted, regardless if higher is better or not assert is_sorted(candidates, key=lambda x: x.score) or is_sorted( - candidates, key=lambda x: x.score, comparison=lambda x, y: x < y + candidates, key=lambda x: x.score, comparison=lambda x, y: x <= y ), "List of candidates must be sorted!" super().__init__(data_point, candidates[0].concept_id, candidates[0].score) From c66789a48469ce0b4062ff233578cccd3d2d459c Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 19 May 2023 13:24:05 +0200 Subject: [PATCH 17/58] fix(candidate): parametrize database name --- flair/models/biomedical_entity_linking.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index f07be8dfd..b688a59d6 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -482,13 +482,12 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ - def build_candidate(self, candidate: Tuple[str, str, float]) -> EntityLinkingCandidate: + def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> EntityLinkingCandidate: """Get nice container with all info about entity linking candidate""" concept_name = candidate[0] concept_id = candidate[1] score = candidate[2] - database_name = self.dictionary.database_name if "|" in concept_id: labels = concept_id.split("|") @@ -513,6 +512,7 @@ class ExactMatchCandidateGenerator(AbstractCandidateGenerator): def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): # Build index which maps concept / entity names to concept / entity ids + self.dictionary = dictionary self.name_to_id_index = dict(list(dictionary.stream())) @classmethod @@ -529,7 +529,14 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). 
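`ExactMatchCandidateGenerator` now also keeps a reference to the dictionary so that `build_candidate` can be handed the database name. The retrieval itself remains a plain lookup table; the sketch below (with invented names and IDs) shows the core idea of exact string matching with a fixed score of 1.0 and `None` for unseen mentions.

```python
# Sketch of exact-match candidate generation: a name -> concept-id lookup table built
# from (name, id) pairs; an exact hit gets score 1.0, a miss yields id None.
from typing import List, Optional, Tuple

dictionary_entries = [  # illustrative (name, concept_id) pairs
    ("headache", "D006261"),
    ("migraine disorder", "D008881"),
]

name_to_id = dict(dictionary_entries)


def exact_match_search(mentions: List[str]) -> List[Tuple[str, Optional[str], float]]:
    # one (name, id, score) hit per mention, analogous to top_k=1
    return [(mention, name_to_id.get(mention), 1.0) for mention in mentions]


print(exact_match_search(["headache", "cephalalgia"]))
# -> [('headache', 'D006261', 1.0), ('cephalalgia', None, 1.0)]
```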
""" - return [[self.build_candidate((em, self.name_to_id_index.get(em), 1.0))] for em in entity_mentions] + return [ + [ + self.build_candidate( + candidate=(em, self.name_to_id_index.get(em), 1.0), database_name=self.dictionary.database_name + ) + ] + for em in entity_mentions + ] class BigramTfIDFVectorizer: @@ -942,7 +949,9 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink return [ [ - self.build_candidate(tuple(self.dictionary_data[i]) + (score,)) + self.build_candidate( + candidate=tuple(self.dictionary_data[i]) + (score,), database_name=self.dictionary.database_name + ) for i, score in zip(mention_ids, mention_scores) ] for mention_ids, mention_scores in zip(ids, scores) From 70c0c7d6759eaa23668384d5141c12dee94acef0 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Mon, 22 May 2023 17:46:47 +0200 Subject: [PATCH 18/58] feat(candidate_generator): cache sparse encoder - better naming - unique cache name --- flair/models/biomedical_entity_linking.py | 121 +++++++++++++--------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index b688a59d6..ee51cd39d 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import Dict, Iterator, List, Optional, Tuple, Union +import joblib import numpy as np import torch from huggingface_hub import hf_hub_download @@ -565,15 +566,19 @@ def __call__(self, mentions: List[str]) -> torch.Tensor: """Short for `transform`""" return self.transform(mentions) + def save(self, path: Path): + """Save vectorizer to disk""" + joblib.dump(self.encoder, str(path)) + @classmethod def load(cls, path: Path) -> "BigramTfIDFVectorizer": """Instantiate from path""" newVectorizer = cls() - with open(path, "rb") as fin: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - newVectorizer.encoder = pickle.load(fin) + # with open(path, "rb") as fin: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + newVectorizer.encoder = joblib.load(str(path)) # logger.info("Sparse encoder loaded from %s", path) return newVectorizer @@ -657,6 +662,16 @@ def higher_is_better(self): return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] + def _get_cache_name(self, model_name_or_path: str, dictionary_name_or_path: str) -> str: + """Fixed name for caching""" + + # Check for embedded dictionary in cache + dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] + file_name = f"{model_name_or_path.split('/')[-1]}_{dictionary_name}" + pp_name = self.preprocessor.name if self.preprocessor is not None else "null" + + return f"{file_name}-{pp_name}" + # separate method to allow more sophisticated logic in the future, # e.g. ANN with IndexIP, HNSW... 
def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: @@ -667,7 +682,7 @@ def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: return dense_index - def _fit_sparse_encoder(self): + def _fit_sparse_encoder(self) -> BigramTfIDFVectorizer: """Fit sparse encoder to current dictionary""" logger.info( @@ -675,10 +690,39 @@ def _fit_sparse_encoder(self): self.dictionary_name_or_path, self.sparse_weight, ) - self.sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary]) + sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary_data]) # sparse_encoder.save(Path(sparse_encoder_path)) # torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) + return sparse_encoder + + def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]): + """If necessary fit and cache sparse encoder""" + + if self.force_hybrid_search: + + self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + + if isinstance(model_name_or_path, str): + cache_name = self._get_cache_name( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + path = flair.cache_root / "models" / f"{cache_name}-sparse-encoder.pk" + else: + path = model_name_or_path / "sparse_encoder.pk" + + if path.exists(): + self.sparse_encoder = BigramTfIDFVectorizer.load(path) + else: + self.sparse_encoder = self._fit_sparse_encoder() + logger.info("Save fitted sparse encoder to %s", path) + self.sparse_encoder.save(path) + else: + model_type = "Hybrid model" if isinstance(model_name_or_path, str) else "Local hybrid model" + raise ValueError( + f"{model_type} has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + ) + def _set_sparse_weigth_and_encoder( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] ): @@ -686,45 +730,29 @@ def _set_sparse_weigth_and_encoder( sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") - if isinstance(model_name_or_path, str): - - # check file exists - if model_name_or_path in PRETRAINED_HYBRID_MODELS: + if isinstance(model_name_or_path, str) and model_name_or_path in PRETRAINED_HYBRID_MODELS: + if not os.path.exists(sparse_encoder_path): - if not os.path.exists(sparse_encoder_path): - - sparse_encoder_path = hf_hub_download( - repo_id=model_name_or_path, - filename="sparse_encoder.pk", - cache_dir=flair.cache_root / "models" / model_name_or_path, - ) + sparse_encoder_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_encoder.pk", + cache_dir=flair.cache_root / "models" / model_name_or_path, + ) - self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) - if not os.path.exists(sparse_weight_path): + if not os.path.exists(sparse_weight_path): - sparse_weight_path = hf_hub_download( - repo_id=model_name_or_path, - filename="sparse_weight.pt", - cache_dir=flair.cache_root / "models" / model_name_or_path, - ) - self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() - else: - if self.force_hybrid_search: - self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT - self._fit_sparse_encoder() - else: - raise 
ValueError( - f"A: Hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" - ) - else: - if self.force_hybrid_search: - self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT - self._fit_sparse_encoder() - else: - raise ValueError( - f"Local hybrid model has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + sparse_weight_path = hf_hub_download( + repo_id=model_name_or_path, + filename="sparse_weight.pt", + cache_dir=flair.cache_root / "models" / model_name_or_path, ) + self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() + else: + self._handle_sparse_encoder( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ @@ -784,21 +812,18 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: def _load_embeddings(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): """Compute and cache the embeddings for the given knowledge base / dictionary.""" - # Check for embedded dictionary in cache - dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] - file_name = f"bio_nen_{model_name_or_path.split('/')[-1]}_{dictionary_name}" - cache_folder = flair.cache_root / "datasets" + cache_name = self._get_cache_name( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) - pp_name = self.preprocessor.name if self.preprocessor is not None else "null" - - embeddings_cache_file = cache_folder / f"{file_name}-{pp_name}.pk" + embeddings_cache_file = cache_folder / f"{cache_name}-embeddings.pk" # If exists, load the cached dictionary indices if embeddings_cache_file.exists(): with embeddings_cache_file.open("rb") as fp: - logger.info("Load cached emebddings from: %s", embeddings_cache_file) + logger.info("BiEncoderCandidateGenerator: load cached emebddings `%s`", f"{cache_name}-embeddings.pk") embeddings = pickle.load(fp) else: From 37a24580c5d2a34a100e13a83f32861bcf61dadc Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Tue, 23 May 2023 12:12:11 +0200 Subject: [PATCH 19/58] fix(candidate_generator): minor improvements - add option to time search - change error to warning if pre-trained model is not hybrid - check if there are mentions to predict --- flair/models/biomedical_entity_linking.py | 87 +++++++++++++---------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index ee51cd39d..f3e445d8b 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -6,6 +6,7 @@ import string import subprocess import tempfile +import time import warnings from abc import ABC, abstractmethod from collections import defaultdict @@ -42,7 +43,7 @@ import faiss except ImportError as error: raise ImportError( - f"You need to install to run the biomedical entity linking: `pip faiss faiss-cpu=={FAISS_VERSION}`" + f"You need to install faiss to run the biomedical entity linking: `pip faiss faiss-cpu=={FAISS_VERSION}`" ) from error logger = logging.getLogger("flair") @@ -474,7 +475,7 @@ class AbstractCandidateGenerator(ABC): """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> 
List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -521,7 +522,7 @@ def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverMode """Compatibility function""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) - def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -717,11 +718,11 @@ def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionar self.sparse_encoder = self._fit_sparse_encoder() logger.info("Save fitted sparse encoder to %s", path) self.sparse_encoder.save(path) - else: - model_type = "Hybrid model" if isinstance(model_name_or_path, str) else "Local hybrid model" - raise ValueError( - f"{model_type} has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" - ) + # else: + # model_type = "Hybrid model" if isinstance(model_name_or_path, str) else "Local hybrid model" + # raise ValueError( + # f"{model_type} has no pretrained sparse encoder. Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" + # ) def _set_sparse_weigth_and_encoder( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] @@ -787,7 +788,7 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: if show_progress: iterations = tqdm( range(0, len(inputs), batch_size), - desc=f"Embedding `{self.dictionary.database_name}` dictionary:", + desc=f"Embedding `{self.dictionary.database_name}` dictionary", ) else: iterations = range(0, len(inputs), batch_size) @@ -857,6 +858,7 @@ def search_sparse( entity_mentions: List[str], top_k: int = 1, normalise: bool = False, + timeit: bool = False ) -> Tuple[np.ndarray, np.ndarray]: """ Find candidates with sparse representations @@ -868,10 +870,14 @@ def search_sparse( mention_embeddings = self.sparse_encoder(entity_mentions) + start = time.time() if self.similarity_metric == SimilarityMetric.COSINE: score_matrix = cosine_similarity(mention_embeddings, self.embeddings["sparse"]) else: score_matrix = np.matmul(mention_embeddings, self.embeddings["sparse"].T) + elapsed = round(time.time() - start, 2) + if timeit: + logger.info(f"BiEncoderCandidateGenerator: sparse search with {len(entity_mentions)} query took ~{elapsed}") if normalise: score_matrix = (score_matrix - score_matrix.min()) / (score_matrix.max() - score_matrix.min()) @@ -889,9 +895,10 @@ def indexing_2d(arr, cols): topk_idxs = indexing_2d(topk_idxs, topk_argidxs) topk_scores = indexing_2d(score_matrix, topk_idxs) + return topk_scores, topk_idxs - def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: + def search_dense(self, entity_mentions: List[str], top_k: int = 1, timeit : bool = False) -> Tuple[np.ndarray, np.ndarray]: """ Find candidates with dense representations (FAISS) @@ -905,8 +912,12 @@ def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.n if self.similarity_metric == SimilarityMetric.COSINE: faiss.normalize_L2(mention_dense_embeds) + start = 
time.time() # Get candidates from dense embeddings dists, ids = self.dense_index.search(mention_dense_embeds, top_k) + elapsed = round(time.time() - start, 2) + if timeit: + logger.info(f"BiEncoderCandidateGenerator: dense search with {len(entity_mentions)} query took ~{elapsed}") return dists, ids @@ -949,7 +960,7 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids - def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -958,11 +969,11 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ - scores, ids = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) + scores, ids = self.search_dense(entity_mentions=entity_mentions, top_k=top_k, timeit=timeit) - if self.hybrid_search: + if self.hybrid_search and self.sparse_encoder is not None: - sparse_scores, sparse_ids = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k) + sparse_scores, sparse_ids = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k, timeit=timeit) scores, ids = self.combine_dense_and_sparse_results( dense_ids=ids, @@ -1033,6 +1044,7 @@ def predict( sentences: Union[List[Sentence], Sentence], annotation_layers: Optional[List[str]] = None, top_k: int = 1, + timeit:bool = False ) -> None: """ Predicts the best matching top-k entity / concept identifiers of all named entites annotated @@ -1049,25 +1061,24 @@ def predict( if self.preprocessor is not None: self.preprocessor.initialize(sentences) - # Build label name - # label_name = input_entity_annotation_layer + "_nen" if (input_entity_annotation_layer is not None) else "nen" - source, data_points, mentions, mentions_annotation_layers = self.extract_mentions( sentences=sentences, annotation_layers=annotation_layers ) - # Retrieve top-k concept / entity candidates - candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) + # no mentions: nothing to do here + if len(mentions) > 0: + # Retrieve top-k concept / entity candidates + candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k, timeit=timeit) - # Add a label annotation for each candidate - for i, data_point, mention_candidates, mentions_annotation_layer in zip( - source, data_points, candidates, mentions_annotation_layers - ): + # Add a label annotation for each candidate + for i, data_point, mention_candidates, mentions_annotation_layer in zip( + source, data_points, candidates, mentions_annotation_layers + ): - sentences[i].add_label( - typename=mentions_annotation_layer, - value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), - ) + sentences[i].add_label( + typename=mentions_annotation_layer, + value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), + ) @classmethod def load( @@ -1161,22 +1172,20 @@ def __get_model_path_and_entity_type( else: # check if user really wants to use hybrid search anyway if not force_hybrid_search: - raise ValueError( - f""" - Model for entity type `{model_name_or_path}` was not trained for hybrid search! - If you want to proceed anyway please pass `force_hybrid_search=True`: - we will fit a sparse encoder for you. 
The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. - """ + logger.warning( + "Model for entity type `%s` was not trained for hybrid search: no sparse search will be performed." + " If you want to use sparse search please pass `force_hybrid_search=True`:" + " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", + model_name_or_path, DEFAULT_SPARSE_WEIGHT ) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] else: if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: - raise ValueError( - f""" - Model `{model_name_or_path}` was not trained for hybrid search! - If you want to proceed anyway please pass `force_hybrid_search=True`: - we will fit a sparse encoder for you. The default value of `sparse_weight` is `{DEFAULT_SPARSE_WEIGHT}`. - """ + logger.warning( + "Model `%s` was not trained for hybrid search: no sparse search will be performed." + " If you want to use sparse search please pass `force_hybrid_search=True`:" + " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", + model_name_or_path, DEFAULT_SPARSE_WEIGHT ) entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] From bc801fc7cc302aedd6a4acaef43ef4fbc0eae818 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Wed, 24 May 2023 18:02:55 +0200 Subject: [PATCH 20/58] feat(linking_candidate): pretty print --- flair/data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/flair/data.py b/flair/data.py index ada25b0ea..c799c9dd9 100644 --- a/flair/data.py +++ b/flair/data.py @@ -457,6 +457,15 @@ def __init__( self.score = score self.additional_ids = additional_ids + def __str__(self) -> str: + string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}" + if self.additional_ids is not None: + string += f" - {self.additional_ids}" + return string + + def __repr__(self) -> str: + return str(self) + class EntityLinkingLabel(Label): """ From 3677658a5c31669cad929a48b745c17ddc2d7801 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Wed, 24 May 2023 18:03:50 +0200 Subject: [PATCH 21/58] fix(candidate_generator): check sparse encoder for sparse search --- flair/models/biomedical_entity_linking.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index f3e445d8b..77eae6ac3 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -868,6 +868,8 @@ def search_sparse( :param normalise: normalise scores """ + assert self.sparse_encoder is not None, "BiEncoderCandidateGenerator has no `sparse_encoder`! 
Pass `force_hybrid_search=True` at initialization" + mention_embeddings = self.sparse_encoder(entity_mentions) start = time.time() @@ -877,7 +879,7 @@ def search_sparse( score_matrix = np.matmul(mention_embeddings, self.embeddings["sparse"].T) elapsed = round(time.time() - start, 2) if timeit: - logger.info(f"BiEncoderCandidateGenerator: sparse search with {len(entity_mentions)} query took ~{elapsed}") + logger.info("BiEncoderCandidateGenerator: sparse search with %s query took ~%s", len(entity_mentions), elapsed) if normalise: score_matrix = (score_matrix - score_matrix.min()) / (score_matrix.max() - score_matrix.min()) @@ -917,7 +919,7 @@ def search_dense(self, entity_mentions: List[str], top_k: int = 1, timeit : bool dists, ids = self.dense_index.search(mention_dense_embeds, top_k) elapsed = round(time.time() - start, 2) if timeit: - logger.info(f"BiEncoderCandidateGenerator: dense search with {len(entity_mentions)} query took ~{elapsed}") + logger.info("BiEncoderCandidateGenerator: dense search with %s query took ~%s", len(entity_mentions), elapsed) return dists, ids @@ -1173,7 +1175,7 @@ def __get_model_path_and_entity_type( # check if user really wants to use hybrid search anyway if not force_hybrid_search: logger.warning( - "Model for entity type `%s` was not trained for hybrid search: no sparse search will be performed." + "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for hybrid search: no sparse search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", model_name_or_path, DEFAULT_SPARSE_WEIGHT @@ -1182,7 +1184,7 @@ def __get_model_path_and_entity_type( else: if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: logger.warning( - "Model `%s` was not trained for hybrid search: no sparse search will be performed." + "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. 
The default value of `sparse_weight` is `%s`.", model_name_or_path, DEFAULT_SPARSE_WEIGHT From 414b5a84a7a2ff12c1179242ea0e2403700862c5 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Thu, 1 Jun 2023 19:46:37 +0200 Subject: [PATCH 22/58] feat(candidate_generator): add sparse index --- flair/models/biomedical_entity_linking.py | 317 +++++++++++++--------- 1 file changed, 196 insertions(+), 121 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 77eae6ac3..5bf3763f0 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -1,6 +1,5 @@ import logging import os -import pickle import re import stat import string @@ -19,7 +18,6 @@ import torch from huggingface_hub import hf_hub_download from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm import flair @@ -38,14 +36,24 @@ from flair.file_utils import cached_path FAISS_VERSION = "1.7.4" +HNSWLIB_VERSION = "1.17.2" try: import faiss except ImportError as error: raise ImportError( - f"You need to install faiss to run the biomedical entity linking: `pip faiss faiss-cpu=={FAISS_VERSION}`" + f"You need to install faiss to run the biomedical entity linking: `pip install faiss-cpu=={FAISS_VERSION}`" ) from error + +try: + import hnswlib +except ImportError as error: + raise ImportError( + f"You need to install faiss to run the biomedical entity linking: `pip install hnswlib=={HNSWLIB_VERSION}`" + ) from error + + logger = logging.getLogger("flair") @@ -138,6 +146,26 @@ class SimilarityMetric(Enum): COSINE = auto() +HNSWLIB_METRIC = {SimilarityMetric.INNER_PRODUCT: "ip", SimilarityMetric.COSINE: "cosine"} + + +def timeit(func): + """ + # This function shows the execution time of + # the function object passed + """ + + def wrap_func(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + elapsed = round(time.time() - start, 4) + class_name, func_name = func.__qualname__.split(".") + logger.info("%s: %s took ~%s", class_name, func_name , elapsed) + return result + + return wrap_func + + class AbstractEntityPreprocessor(ABC): """ A pre-processor used to transform / clean both entity mentions and entity names @@ -475,7 +503,7 @@ class AbstractCandidateGenerator(ABC): """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -522,7 +550,7 @@ def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverMode """Compatibility function""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) - def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. 
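Patch 22 replaces the ad-hoc `timeit: bool` flags with a decorator that logs the wall-clock time of the wrapped method. A small usage sketch of that pattern follows; the function name and logger configuration are illustrative only, and the decorator body is a simplified variant of the one added in the patch.

```python
# Sketch of decorator-based timing: wrap a callable, log its wall-clock duration.
import functools
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("flair")


def timeit(func):
    """Log the execution time of the wrapped callable."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        logger.info("%s took ~%ss", func.__qualname__, round(time.time() - start, 4))
        return result

    return wrapper


@timeit
def slow_search(queries):
    time.sleep(0.1)  # stand-in for an index lookup
    return [q.upper() for q in queries]


slow_search(["ibuprofen", "aspirin"])
```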
@@ -628,6 +656,8 @@ def __init__( self.hybrid_search = hybrid_search self.sparse_weight = sparse_weight self.force_hybrid_search = force_hybrid_search + if self.force_hybrid_search: + self.hybrid_search = True # allow to pass custom dictionary if dictionary is not None: @@ -642,18 +672,16 @@ def __init__( self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None self.sparse_weight: Optional[float] = None if self.hybrid_search: - self._set_sparse_weigth_and_encoder( + self.sparse_encoder, self.sparse_weight = self._get_sparse_encoder_and_weight( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) - self.embeddings = self._load_embeddings( + self.indices = self._load_indices( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path, batch_size=self.batch_size, ) - self.dense_index = self.build_dense_index(self.embeddings["dense"]) - @property def higher_is_better(self): """ @@ -673,23 +701,13 @@ def _get_cache_name(self, model_name_or_path: str, dictionary_name_or_path: str) return f"{file_name}-{pp_name}" - # separate method to allow more sophisticated logic in the future, - # e.g. ANN with IndexIP, HNSW... - def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: - """Initialize FAISS index""" - - dense_index = faiss.IndexFlatIP(embeddings.shape[1]) - dense_index.add(embeddings) - - return dense_index - + @timeit def _fit_sparse_encoder(self) -> BigramTfIDFVectorizer: """Fit sparse encoder to current dictionary""" logger.info( - "Hybrid model has no pretrained sparse encoder. Fit to dictionary `%s` (sparse_weight=%s)", + "BiEncoderCandidateGenerator: hybrid model has no pretrained sparse encoder. Fit to dictionary `%s`", self.dictionary_name_or_path, - self.sparse_weight, ) sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary_data]) # sparse_encoder.save(Path(sparse_encoder_path)) @@ -700,33 +718,26 @@ def _fit_sparse_encoder(self) -> BigramTfIDFVectorizer: def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]): """If necessary fit and cache sparse encoder""" - if self.force_hybrid_search: + if isinstance(model_name_or_path, str): + cache_name = self._get_cache_name( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + path = flair.cache_root / "models" / f"{cache_name}-sparse-encoder.pk" + else: + path = model_name_or_path / "sparse_encoder.pk" - self.sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + if path.exists(): + sparse_encoder = BigramTfIDFVectorizer.load(path) + else: + sparse_encoder = self._fit_sparse_encoder() + logger.info("Save fitted sparse encoder to %s", path) + sparse_encoder.save(path) - if isinstance(model_name_or_path, str): - cache_name = self._get_cache_name( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path - ) - path = flair.cache_root / "models" / f"{cache_name}-sparse-encoder.pk" - else: - path = model_name_or_path / "sparse_encoder.pk" + return sparse_encoder - if path.exists(): - self.sparse_encoder = BigramTfIDFVectorizer.load(path) - else: - self.sparse_encoder = self._fit_sparse_encoder() - logger.info("Save fitted sparse encoder to %s", path) - self.sparse_encoder.save(path) - # else: - # model_type = "Hybrid model" if isinstance(model_name_or_path, str) else "Local hybrid model" - # raise ValueError( - # f"{model_type} has no pretrained sparse encoder. 
Please pass `force_hybrid_search=True` if you want to fit a sparse model to dictionary `{dictionary_name_or_path}`" - # ) - - def _set_sparse_weigth_and_encoder( + def _get_sparse_encoder_and_weight( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] - ): + ) -> Tuple[BigramTfIDFVectorizer, float]: sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") @@ -740,7 +751,7 @@ def _set_sparse_weigth_and_encoder( cache_dir=flair.cache_root / "models" / model_name_or_path, ) - self.sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) if not os.path.exists(sparse_weight_path): @@ -749,12 +760,16 @@ def _set_sparse_weigth_and_encoder( filename="sparse_weight.pt", cache_dir=flair.cache_root / "models" / model_name_or_path, ) - self.sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() + sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() else: - self._handle_sparse_encoder( + sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT + sparse_encoder = self._handle_sparse_encoder( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) + return sparse_encoder, sparse_weight + + @timeit def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ Create sparse embeddings from array of entity mentions/names. @@ -770,6 +785,7 @@ def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: return sparse_embeds + @timeit def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: """ @@ -788,7 +804,7 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: if show_progress: iterations = tqdm( range(0, len(inputs), batch_size), - desc=f"Embedding `{self.dictionary.database_name}` dictionary", + desc=f"Embedding `{self.dictionary.database_name}`", ) else: iterations = range(0, len(inputs), batch_size) @@ -810,56 +826,131 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: return dense_embeds - def _load_embeddings(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int): - """Compute and cache the embeddings for the given knowledge base / dictionary.""" + # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... + @timeit + def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: + """Initialize FAISS index""" + + index = faiss.IndexFlatIP(embeddings.shape[1]) + index.add(embeddings) + + return index + + # separate method to allow more sophisticated logic in the future... 
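The dense index remains a flat inner-product FAISS index over the embedded dictionary names; with L2-normalised vectors the inner product equals cosine similarity. A self-contained sketch of that build-and-query pattern, with random vectors standing in for the transformer embeddings produced by `embed_dense`:

```python
# Sketch of flat inner-product search with FAISS; normalising both sides turns the
# inner product into cosine similarity. Data is random and purely illustrative.
import faiss
import numpy as np

rng = np.random.default_rng(42)
dictionary_embeddings = rng.random((1000, 128), dtype=np.float32)
query_embeddings = rng.random((3, 128), dtype=np.float32)

faiss.normalize_L2(dictionary_embeddings)
faiss.normalize_L2(query_embeddings)

index = faiss.IndexFlatIP(dictionary_embeddings.shape[1])
index.add(dictionary_embeddings)

top_k = 5
scores, ids = index.search(query_embeddings, top_k)  # both arrays have shape (3, top_k)
print(ids[0], scores[0])
```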
+ @timeit + def build_sparse_index(self, embeddings: np.ndarray) -> hnswlib.Index: + """Initialize Annoy index""" + + metric = HNSWLIB_METRIC[self.similarity_metric] + + ###################################### + # ANNOY + ###################################### + # index = annoy.AnnoyIndex(embeddings.shape[1], metric) + # # See https://github.com/spotify/annoy#tradeoffs + # n_trees = int(embeddings.shape[0] / 100) + # for i, v in enumerate(embeddings.tolist()): + # index.add_item(i, v) + # index.build(n_trees, n_jobs=self.cores) + + ###################################### + # HNSWLIB + ###################################### + index = hnswlib.Index(space=metric, dim=embeddings.shape[1]) + index.init_index(max_elements=embeddings.shape[0], ef_construction=200, M=16) + index.add_items(embeddings, np.arange(embeddings.shape[0])) + index.set_ef(50) # ef should always be > k + + return index + + def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int) -> Dict: + """Load cached indices if available, otherwise compute embeddings, build index and cache""" - cache_folder = flair.cache_root / "datasets" cache_name = self._get_cache_name( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) - embeddings_cache_file = cache_folder / f"{cache_name}-embeddings.pk" + cache_folder = flair.cache_root / "datasets" / cache_name - # If exists, load the cached dictionary indices - if embeddings_cache_file.exists(): + cache_folder.mkdir(parents=True, exist_ok=True) - with embeddings_cache_file.open("rb") as fp: - logger.info("BiEncoderCandidateGenerator: load cached emebddings `%s`", f"{cache_name}-embeddings.pk") - embeddings = pickle.load(fp) + indices = {} + preprocessed_names = None - else: + for index_type in ["sparse", "dense"]: - cache_folder.mkdir(parents=True, exist_ok=True) + if index_type == "sparse" and not self.hybrid_search: + continue - names = [self.preprocessor.process_entity_name(name) for name, cui in self.dictionary_data] + file_name = f"index-{index_type}.bin" - # Compute dense embeddings (if necessary) - dense_embeddings = self.embed_dense(inputs=names, batch_size=batch_size, show_progress=True) + index_cache_file = cache_folder / file_name - sparse_embeddings = self.embed_sparse(inputs=names) if self.hybrid_search else None + # If exists, load the cached dictionary indices + if index_cache_file.exists(): - # Store the pre-computed index on disk for later re-use - embeddings = { - "dense": dense_embeddings, - "sparse": sparse_embeddings, - } + logger.info( + "BiEncoderCandidateGenerator: %s: load cached %s index from `%s`", + self.dictionary.database_name, + index_type, + cache_name, + ) - logger.info("Caching dictionary emebddings into %s", embeddings_cache_file) - with embeddings_cache_file.open("wb") as fp: - pickle.dump(embeddings, fp) + if index_type == "dense": + indices[index_type] = faiss.read_index(str(index_cache_file)) + else: - if self.similarity_metric == SimilarityMetric.COSINE: - faiss.normalize_L2(embeddings["dense"]) + dimension = len(self.sparse_encoder.encoder.vocabulary_) + index = hnswlib.Index(space=HNSWLIB_VERSION[dimension], dim=dimension) + index.load_index(str(index_cache_file)) + # index = annoy.AnnoyIndex(dimension, HNSWLIB_METRIC[self.similarity_metric]) + # index.load(str(index_cache_file)) + # indices[index_type] = index - return embeddings + else: - def search_sparse( - self, - entity_mentions: List[str], - top_k: int = 1, - normalise: bool = False, - timeit: bool = False - ) -> 
Tuple[np.ndarray, np.ndarray]: + logger.info( + "BiEncoderCandidateGenerator: %s: build %s index. This will take some time...", + self.dictionary.database_name, + index_type, + ) + + if preprocessed_names is None: + preprocessed_names = [ + self.preprocessor.process_entity_name(name) for name, cui in self.dictionary_data + ] + + embeddings = ( + self.embed_dense(inputs=preprocessed_names, batch_size=batch_size, show_progress=True) + if index_type == "dense" + else self.embed_sparse(inputs=preprocessed_names) + ) + + if self.similarity_metric == SimilarityMetric.COSINE: + faiss.normalize_L2(embeddings) + + index = ( + self.build_dense_index(embeddings) if index_type == "dense" else self.build_sparse_index(embeddings) + ) + + if index_type == "dense": + faiss.write_index(index, str(index_cache_file)) + else: + index.save_index(str(index_cache_file)) + + indices[index_type] = index + + logger.info( + "BiEncoderCandidateGenerator: %s: cached %s index in %s", + self.dictionary.database_name, + index_type, + cache_name, + ) + + return indices + + @timeit + def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: """ Find candidates with sparse representations @@ -868,39 +959,26 @@ def search_sparse( :param normalise: normalise scores """ - assert self.sparse_encoder is not None, "BiEncoderCandidateGenerator has no `sparse_encoder`! Pass `force_hybrid_search=True` at initialization" + assert ( + self.sparse_encoder is not None + ), "BiEncoderCandidateGenerator has no `sparse_encoder`! Pass `force_hybrid_search=True` at initialization" mention_embeddings = self.sparse_encoder(entity_mentions) - start = time.time() - if self.similarity_metric == SimilarityMetric.COSINE: - score_matrix = cosine_similarity(mention_embeddings, self.embeddings["sparse"]) - else: - score_matrix = np.matmul(mention_embeddings, self.embeddings["sparse"].T) - elapsed = round(time.time() - start, 2) - if timeit: - logger.info("BiEncoderCandidateGenerator: sparse search with %s query took ~%s", len(entity_mentions), elapsed) - - if normalise: - score_matrix = (score_matrix - score_matrix.min()) / (score_matrix.max() - score_matrix.min()) + # idxs = [] + # dists = [] + # for v in mention_embeddings.tolist(): + # vidxs, vdists = self.indices["sparse"].get_nns_by_vector(v, top_k, include_distances=True) + # idxs.append(vidxs) + # dists.append(vdists) + # np.vstack(dists), np.vstack(idxs) - def indexing_2d(arr, cols): - rows = np.repeat(np.arange(0, cols.shape[0])[:, np.newaxis], cols.shape[1], axis=1) - return arr[rows, cols] + idxs, dists = self.indices["sparse"].knn_query(mention_embeddings, k=top_k) - # Get topk indexes without sorting - topk_idxs = np.argpartition(score_matrix, -top_k)[:, -top_k:] + return idxs, dists - # Get topk indexes with sorting - topk_score_matrix = indexing_2d(score_matrix, topk_idxs) - topk_argidxs = np.argsort(-topk_score_matrix) - topk_idxs = indexing_2d(topk_idxs, topk_argidxs) - topk_scores = indexing_2d(score_matrix, topk_idxs) - - - return topk_scores, topk_idxs - - def search_dense(self, entity_mentions: List[str], top_k: int = 1, timeit : bool = False) -> Tuple[np.ndarray, np.ndarray]: + @timeit + def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: """ Find candidates with dense representations (FAISS) @@ -914,14 +992,10 @@ def search_dense(self, entity_mentions: List[str], top_k: int = 1, timeit : bool if self.similarity_metric == SimilarityMetric.COSINE: faiss.normalize_L2(mention_dense_embeds) - 
start = time.time() # Get candidates from dense embeddings - dists, ids = self.dense_index.search(mention_dense_embeds, top_k) - elapsed = round(time.time() - start, 2) - if timeit: - logger.info("BiEncoderCandidateGenerator: dense search with %s query took ~%s", len(entity_mentions), elapsed) + dists, ids = self.indices["dense"].search(mention_dense_embeds, top_k) - return dists, ids + return ids, dists def combine_dense_and_sparse_results( self, @@ -962,7 +1036,7 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids - def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ Returns the top-k entity / concept identifiers for the each entity mention. @@ -971,11 +1045,11 @@ def search(self, entity_mentions: List[str], top_k: int, timeit: bool = False) - :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). """ - scores, ids = self.search_dense(entity_mentions=entity_mentions, top_k=top_k, timeit=timeit) + ids, scores = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) if self.hybrid_search and self.sparse_encoder is not None: - sparse_scores, sparse_ids = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k, timeit=timeit) + sparse_ids, sparse_scores = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k) scores, ids = self.combine_dense_and_sparse_results( dense_ids=ids, @@ -1046,7 +1120,6 @@ def predict( sentences: Union[List[Sentence], Sentence], annotation_layers: Optional[List[str]] = None, top_k: int = 1, - timeit:bool = False ) -> None: """ Predicts the best matching top-k entity / concept identifiers of all named entites annotated @@ -1070,7 +1143,7 @@ def predict( # no mentions: nothing to do here if len(mentions) > 0: # Retrieve top-k concept / entity candidates - candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k, timeit=timeit) + candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) # Add a label annotation for each candidate for i, data_point, mention_candidates, mentions_annotation_layer in zip( @@ -1178,7 +1251,8 @@ def __get_model_path_and_entity_type( "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for hybrid search: no sparse search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", - model_name_or_path, DEFAULT_SPARSE_WEIGHT + model_name_or_path, + DEFAULT_SPARSE_WEIGHT, ) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] else: @@ -1187,7 +1261,8 @@ def __get_model_path_and_entity_type( "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. 
The default value of `sparse_weight` is `%s`.", - model_name_or_path, DEFAULT_SPARSE_WEIGHT + model_name_or_path, + DEFAULT_SPARSE_WEIGHT, ) entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] From 1783eabeb5ca758ff2ed3f33a8c565a8a13256b6 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 2 Jun 2023 16:16:30 +0200 Subject: [PATCH 23/58] fix(candidate_generator): KISS: sparse search w/ scipy sparse matrices --- flair/models/biomedical_entity_linking.py | 231 ++++++++++------------ 1 file changed, 103 insertions(+), 128 deletions(-) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 5bf3763f0..c348640c9 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -15,9 +15,12 @@ import joblib import numpy as np +import scipy import torch from huggingface_hub import hf_hub_download +from scipy.sparse._csr import csr_matrix from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm import flair @@ -36,7 +39,6 @@ from flair.file_utils import cached_path FAISS_VERSION = "1.7.4" -HNSWLIB_VERSION = "1.17.2" try: import faiss @@ -46,14 +48,6 @@ ) from error -try: - import hnswlib -except ImportError as error: - raise ImportError( - f"You need to install faiss to run the biomedical entity linking: `pip install hnswlib=={HNSWLIB_VERSION}`" - ) from error - - logger = logging.getLogger("flair") @@ -146,9 +140,6 @@ class SimilarityMetric(Enum): COSINE = auto() -HNSWLIB_METRIC = {SimilarityMetric.INNER_PRODUCT: "ip", SimilarityMetric.COSINE: "cosine"} - - def timeit(func): """ # This function shows the execution time of @@ -160,7 +151,7 @@ def wrap_func(*args, **kwargs): result = func(*args, **kwargs) elapsed = round(time.time() - start, 4) class_name, func_name = func.__qualname__.split(".") - logger.info("%s: %s took ~%s", class_name, func_name , elapsed) + logger.info("%s: %s took ~%s", class_name, func_name, elapsed) return result return wrap_func @@ -585,13 +576,12 @@ def fit(self, names: List[str]): self.encoder.fit(names) return self - def transform(self, names: List[str]) -> torch.Tensor: + def transform(self, names: List[str]) -> csr_matrix: """Convert strings to sparse vectors""" - vec = self.encoder.transform(names).toarray() - vec = torch.FloatTensor(vec) - return vec + embeddings = self.encoder.transform(names) + return embeddings - def __call__(self, mentions: List[str]) -> torch.Tensor: + def __call__(self, mentions: List[str]) -> np.ndarray: """Short for `transform`""" return self.transform(mentions) @@ -665,7 +655,9 @@ def __init__( else: self.dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path) - self.dictionary_data = list(self.dictionary.stream()) + self.dictionary_data = [ + (self.preprocessor.process_entity_name(name), cui) for name, cui in self.dictionary.stream() + ] # Load encoders self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) @@ -679,7 +671,6 @@ def __init__( self.indices = self._load_indices( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path, - batch_size=self.batch_size, ) @property @@ -702,7 +693,7 @@ def _get_cache_name(self, model_name_or_path: str, dictionary_name_or_path: str) return f"{file_name}-{pp_name}" @timeit - def _fit_sparse_encoder(self) -> BigramTfIDFVectorizer: + def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: """Fit sparse encoder to current 
dictionary""" logger.info( @@ -729,8 +720,8 @@ def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionar if path.exists(): sparse_encoder = BigramTfIDFVectorizer.load(path) else: - sparse_encoder = self._fit_sparse_encoder() - logger.info("Save fitted sparse encoder to %s", path) + sparse_encoder = self.fit_sparse_encoder() + # logger.info("Save fitted sparse encoder to %s", path) sparse_encoder.save(path) return sparse_encoder @@ -769,23 +760,16 @@ def _get_sparse_encoder_and_weight( return sparse_encoder, sparse_weight - @timeit - def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: + def embed_sparse(self, inputs: np.ndarray) -> csr_matrix: """ Create sparse embeddings from array of entity mentions/names. :param entity_names: An array of entity / concept names - :returns sparse_embeds np.array: Numpy array containing the sparse embeddings + :returns sparse_embeds csr_matrix: Scipy sparse CSR matrix """ - sparse_embeds = self.sparse_encoder(inputs) - sparse_embeds = sparse_embeds.numpy() - if self.similarity_metric == SimilarityMetric.COSINE: - faiss.normalize_L2(sparse_embeds) - - return sparse_embeds + return self.sparse_encoder(inputs) - @timeit def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: """ @@ -827,43 +811,42 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: return dense_embeds # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... - @timeit - def build_dense_index(self, embeddings: np.ndarray) -> faiss.Index: - """Initialize FAISS index""" + def get_dense_index(self, names: List[str], path: Path) -> faiss.Index: + """Load or create dense index and save it to disk""" + + if path.exists(): + + index = faiss.read_index(str(path)) + + else: + + embeddings = self.embed_dense(inputs=names, batch_size=self.batch_size, show_progress=True) - index = faiss.IndexFlatIP(embeddings.shape[1]) - index.add(embeddings) + index = faiss.IndexFlatIP(embeddings.shape[1]) + index.add(embeddings) + + if self.similarity_metric == SimilarityMetric.COSINE: + faiss.normalize_L2(embeddings) + + faiss.write_index(index, str(path)) return index - # separate method to allow more sophisticated logic in the future... 
- @timeit - def build_sparse_index(self, embeddings: np.ndarray) -> hnswlib.Index: - """Initialize Annoy index""" - - metric = HNSWLIB_METRIC[self.similarity_metric] - - ###################################### - # ANNOY - ###################################### - # index = annoy.AnnoyIndex(embeddings.shape[1], metric) - # # See https://github.com/spotify/annoy#tradeoffs - # n_trees = int(embeddings.shape[0] / 100) - # for i, v in enumerate(embeddings.tolist()): - # index.add_item(i, v) - # index.build(n_trees, n_jobs=self.cores) - - ###################################### - # HNSWLIB - ###################################### - index = hnswlib.Index(space=metric, dim=embeddings.shape[1]) - index.init_index(max_elements=embeddings.shape[0], ef_construction=200, M=16) - index.add_items(embeddings, np.arange(embeddings.shape[0])) - index.set_ef(50) # ef should always be > k + def get_sparse_index(self, names: List[str], path: Path) -> csr_matrix: + """Load or create sparse index and save it to disk""" + + if path.exists(): + index = scipy.sparse.load_npz(str(path)) + else: + index = self.embed_sparse(inputs=names) + + scipy.sparse.save_npz(str(path), index) + # index.save_index # HNSWLIB + # index.save # ANNOY return index - def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str, batch_size: int) -> Dict: + def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) -> Dict: """Load cached indices if available, otherwise compute embeddings, build index and cache""" cache_name = self._get_cache_name( @@ -875,76 +858,34 @@ def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str, b cache_folder.mkdir(parents=True, exist_ok=True) indices = {} - preprocessed_names = None + + logger.info( + "BiEncoderCandidateGenerator: initialize %s %s", + self.dictionary.database_name, + "indices" if self.hybrid_search else "index", + ) for index_type in ["sparse", "dense"]: if index_type == "sparse" and not self.hybrid_search: continue - file_name = f"index-{index_type}.bin" + extension = "bin" if index_type == "dense" else "npz" + + file_name = f"index-{index_type}.{extension}" index_cache_file = cache_folder / file_name - # If exists, load the cached dictionary indices - if index_cache_file.exists(): + if index_type == "dense": - logger.info( - "BiEncoderCandidateGenerator: %s: load cached %s index from `%s`", - self.dictionary.database_name, - index_type, - cache_name, + indices[index_type] = self.get_dense_index( + names=[n for n, _ in self.dictionary_data], path=index_cache_file ) - if index_type == "dense": - indices[index_type] = faiss.read_index(str(index_cache_file)) - else: - - dimension = len(self.sparse_encoder.encoder.vocabulary_) - index = hnswlib.Index(space=HNSWLIB_VERSION[dimension], dim=dimension) - index.load_index(str(index_cache_file)) - # index = annoy.AnnoyIndex(dimension, HNSWLIB_METRIC[self.similarity_metric]) - # index.load(str(index_cache_file)) - # indices[index_type] = index - else: - logger.info( - "BiEncoderCandidateGenerator: %s: build %s index. 
This will take some time...", - self.dictionary.database_name, - index_type, - ) - - if preprocessed_names is None: - preprocessed_names = [ - self.preprocessor.process_entity_name(name) for name, cui in self.dictionary_data - ] - - embeddings = ( - self.embed_dense(inputs=preprocessed_names, batch_size=batch_size, show_progress=True) - if index_type == "dense" - else self.embed_sparse(inputs=preprocessed_names) - ) - - if self.similarity_metric == SimilarityMetric.COSINE: - faiss.normalize_L2(embeddings) - - index = ( - self.build_dense_index(embeddings) if index_type == "dense" else self.build_sparse_index(embeddings) - ) - - if index_type == "dense": - faiss.write_index(index, str(index_cache_file)) - else: - index.save_index(str(index_cache_file)) - - indices[index_type] = index - - logger.info( - "BiEncoderCandidateGenerator: %s: cached %s index in %s", - self.dictionary.database_name, - index_type, - cache_name, + indices[index_type] = self.get_sparse_index( + names=[n for n, _ in self.dictionary_data], path=index_cache_file ) return indices @@ -958,22 +899,28 @@ def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np. :param top_k: number of candidates to retrieve :param normalise: normalise scores """ - assert ( self.sparse_encoder is not None ), "BiEncoderCandidateGenerator has no `sparse_encoder`! Pass `force_hybrid_search=True` at initialization" mention_embeddings = self.sparse_encoder(entity_mentions) - # idxs = [] - # dists = [] - # for v in mention_embeddings.tolist(): - # vidxs, vdists = self.indices["sparse"].get_nns_by_vector(v, top_k, include_distances=True) - # idxs.append(vidxs) - # dists.append(vdists) - # np.vstack(dists), np.vstack(idxs) + if self.similarity_metric == SimilarityMetric.COSINE: + score_matrix = cosine_similarity(mention_embeddings, self.indices["sparse"], dense_output=False) + elif self.similarity_metric == SimilarityMetric.INNER_PRODUCT: + score_matrix = mention_embeddings.dot(self.indices["sparse"].T) + + score_matrix = score_matrix.toarray() + + num_mentions = score_matrix.shape[0] - idxs, dists = self.indices["sparse"].knn_query(mention_embeddings, k=top_k) + unsorted_indices = np.argpartition(score_matrix, -top_k)[:, -top_k:] + unsorted_scores = score_matrix[np.arange(num_mentions)[:, None], unsorted_indices] + + sorted_score_matrix_indices = np.argsort(-unsorted_scores) + + idxs = unsorted_indices[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] + dists = unsorted_scores[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] return idxs, dists @@ -1304,3 +1251,31 @@ def __get_dictionary_path( ) return dictionary_name_or_path + + # @timeit + # def build_sparse_index(self, embeddings: csr_matrix) -> csr_matrix: + # """Initialize sparse index""" + + # index = embeddings + + # ###################################### + # # ANNOY + # ###################################### + # # metric = ANNOY_METRIC[self.similarity_metric] + # # index = annoy.AnnoyIndex(embeddings.shape[1], metric) + # # # See https://github.com/spotify/annoy#tradeoffs + # # n_trees = int(embeddings.shape[0] / 100) + # # for i, v in enumerate(embeddings.tolist()): + # # index.add_item(i, v) + # # index.build(n_trees, n_jobs=min(mp.cpu_count(), 8)) + + # ###################################### + # # HNSWLIB + # ###################################### + # # metric = HNSWLIB_METRIC[self.similarity_metric] + # # index = hnswlib.Index(space=metric, dim=embeddings.shape[1]) + # # index.init_index(max_elements=embeddings.shape[0], 
ef_construction=200, M=16) + # # index.add_items(embeddings, np.arange(embeddings.shape[0])) + # # index.set_ef(50) # ef should always be > k + + # return index From 8c908baa9e232e3fc58d21486766e4a871e711a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Wed, 12 Jul 2023 14:54:16 +0200 Subject: [PATCH 24/58] Minor update to comments and documentation --- flair/data.py | 2 +- flair/models/biomedical_entity_linking.py | 128 ++++++------------ .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 38 ++++-- 3 files changed, 73 insertions(+), 95 deletions(-) diff --git a/flair/data.py b/flair/data.py index c799c9dd9..6e197ff65 100644 --- a/flair/data.py +++ b/flair/data.py @@ -458,7 +458,7 @@ def __init__( self.additional_ids = additional_ids def __str__(self) -> str: - string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}" + string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}" if self.additional_ids is not None: string += f" - {self.additional_ids}" return string diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index c348640c9..7ce894b16 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -142,8 +142,7 @@ class SimilarityMetric(Enum): def timeit(func): """ - # This function shows the execution time of - # the function object passed + This function shows the execution time of the function object passed """ def wrap_func(*args, **kwargs): @@ -206,6 +205,9 @@ class EntityPreprocessor(AbstractEntityPreprocessor): Entity preprocessor adapted from: Sung et al. 2020, Biomedical Entity Representations with Synonym Marginalization https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + + The preprocessor provides basic string transformation options including lower-casing, + removal of punctuations symbols, etc. """ def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): @@ -221,7 +223,6 @@ def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): @property def name(self): - return "biosyn" def initialize(self, sentences): @@ -243,7 +244,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): """ - Entity preprocessor which uses Ab3P, an (biomedical)abbreviation definition detector: + Entity preprocessor which uses Ab3P, an (biomedical) abbreviation definition detector: Abbreviation definition identification based on automatic precision estimates. Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. 
PubMed ID: 18817555 @@ -267,7 +268,6 @@ def __init__( @property def name(self): - return f"ab3p_{self.preprocessor.name}" def initialize(self, sentences: List[Sentence]) -> None: @@ -347,6 +347,7 @@ def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: # Download Ab3P executable ab3p_path = cached_path("https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir) + # Make Ab3P executable ab3p_path.chmod(ab3p_path.stat().st_mode | stat.S_IXUSR) return ab3p_path @@ -362,12 +363,12 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict {"RSV": "Rous sarcoma virus"} } - :param sentences: list of senternces + :param sentences: list of sentences :result abbreviation_dict: abbreviations and their resolution detected in each input sentence """ abbreviation_dict = defaultdict(dict) - # Create a temp file which holds the sentences we want to process with ab3p + # Create a temp file which holds the sentences we want to process with Ab3P with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: for sentence in sentences: temp_file.write(sentence.to_tokenized_string() + "\n") @@ -377,7 +378,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict with open(os.path.join(os.getcwd(), "path_Ab3P"), "w") as path_file: path_file.write(str(self.word_data_dir) + "/\n") - # Run ab3p with the temp file containing the dataset + # Run Ab3P with the temp file containing the dataset # https://pylint.pycqa.org/en/latest/user_guide/messages/warning/subprocess-run-check.html try: result = subprocess.run( @@ -432,10 +433,14 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict class BiomedicalEntityLinkingDictionary: """ - Load dictionary: either pre-definded or from path - Every line in the file must be formatted as follows: concept_id||concept_name - If multiple concept ids are associated to a given name - they must be separated by a `|`. + Class to load named entity dictionaries: either pre-defined or from a path on disk. + For the latter, every line in the file must be formatted as follows: + + concept_id||concept_name + + If multiple concept ids are associated to a given name they must be separated by a `|`, e.g. + + 7157||TP53|tumor protein p53 """ def __init__( @@ -490,17 +495,18 @@ def stream(self) -> Iterator[Tuple[str, str]]: class AbstractCandidateGenerator(ABC): """ - Base class for a candidate genertor + Base class for a candidate generator, i.e. given a mention of an entity, find matching + entries from the dictionary. """ @abstractmethod def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ - Returns the top-k entity / concept identifiers for the each entity mention. + Returns the top-k entity / concept identifiers for each entity mention. :param entity_mentions: Entity mentions :param top_k: Number of best-matching entities from the knowledge base to return - :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). 
+ :result: List containing a list of entity linking candidates per entity mention from the input """ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> EntityLinkingCandidate: @@ -542,14 +548,6 @@ def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverMode return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: - """ - Returns the top-k entity / concept identifiers for the each entity mention. - - :param entity_mentions: Entity mentions - :param top_k: Number of best-matching entities from the knowledge base to return - :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). - """ - return [ [ self.build_candidate( @@ -564,6 +562,7 @@ class BigramTfIDFVectorizer: """ Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level Implementation adapted from: + Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 """ @@ -605,8 +604,7 @@ def load(cls, path: Path) -> "BigramTfIDFVectorizer": class BiEncoderCandidateGenerator(AbstractCandidateGenerator): """ - Candidate generator using both dense (transformer-based) - and (optionally) sparse vector representations, + Candidate generator using both dense (transformer-based) and (optionally) sparse vector representations, to search candidates in a knowledge base / dictionary. """ @@ -629,13 +627,13 @@ def __init__( :param model_name_or_path: Name of or path to the transformer model to be used. :param dictionary_name_or_path: Name of or path to the transformer model to be used. :param similarity_metric: which metric to use to compute similarity - :param preprocessor: Preprocessing for entity mentions and names + :param preprocessor: Preprocessing strategy for entity mentions and names :param max_length: Maximum number of input tokens to transformer model - :param batch_size: how many entity mentions/names to embed in one forward pass + :param batch_size: Number of entity mentions/names to embed in one forward pass :param hybrid_search: Indicates whether to use sparse embeddings or not - :param sparse_weight: default sparse weight + :param sparse_weight: Weight to balance sparse and dense similarity scores (default sparse weight) :param force_hybrid_search: if pre-trained model is not hybrid (dense+sparse) fit a sparse encoder - :param dictionary: optionally pass a dictionary + :param dictionary: optionally pass a custom dictionary """ self.model_name_or_path = model_name_or_path self.dictionary_name_or_path = dictionary_name_or_path @@ -706,7 +704,9 @@ def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: return sparse_encoder - def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]): + def _handle_sparse_encoder( + self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] + ) -> BigramTfIDFVectorizer: """If necessary fit and cache sparse encoder""" if isinstance(model_name_or_path, str): @@ -729,13 +729,11 @@ def _handle_sparse_encoder(self, model_name_or_path: Union[str, Path], dictionar def _get_sparse_encoder_and_weight( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] ) -> Tuple[BigramTfIDFVectorizer, float]: - sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") 
sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") if isinstance(model_name_or_path, str) and model_name_or_path in PRETRAINED_HYBRID_MODELS: if not os.path.exists(sparse_encoder_path): - sparse_encoder_path = hf_hub_download( repo_id=model_name_or_path, filename="sparse_encoder.pk", @@ -745,7 +743,6 @@ def _get_sparse_encoder_and_weight( sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) if not os.path.exists(sparse_weight_path): - sparse_weight_path = hf_hub_download( repo_id=model_name_or_path, filename="sparse_weight.pt", @@ -760,19 +757,18 @@ def _get_sparse_encoder_and_weight( return sparse_encoder, sparse_weight - def embed_sparse(self, inputs: np.ndarray) -> csr_matrix: + def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: """ Create sparse embeddings from array of entity mentions/names. - :param entity_names: An array of entity / concept names - :returns sparse_embeds csr_matrix: Scipy sparse CSR matrix + :param inputs: Numpy array of entity / concept names + :returns Numpy array containing the sparse embeddings of the names """ return self.sparse_encoder(inputs) def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: """ - Create dense embeddings from array of entity mentions/names. :param names: Numpy array of entity / concept names @@ -815,11 +811,9 @@ def get_dense_index(self, names: List[str], path: Path) -> faiss.Index: """Load or create dense index and save it to disk""" if path.exists(): - index = faiss.read_index(str(path)) else: - embeddings = self.embed_dense(inputs=names, batch_size=self.batch_size, show_progress=True) index = faiss.IndexFlatIP(embeddings.shape[1]) @@ -854,7 +848,6 @@ def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) - ) cache_folder = flair.cache_root / "datasets" / cache_name - cache_folder.mkdir(parents=True, exist_ok=True) indices = {} @@ -866,7 +859,6 @@ def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) - ) for index_type in ["sparse", "dense"]: - if index_type == "sparse" and not self.hybrid_search: continue @@ -877,13 +869,11 @@ def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) - index_cache_file = cache_folder / file_name if index_type == "dense": - indices[index_type] = self.get_dense_index( names=[n for n, _ in self.dictionary_data], path=index_cache_file ) else: - indices[index_type] = self.get_sparse_index( names=[n for n, _ in self.dictionary_data], path=index_cache_file ) @@ -895,9 +885,8 @@ def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np. 
""" Find candidates with sparse representations - :param entity_mentions: list of entity mentions (queries) - :param top_k: number of candidates to retrieve - :param normalise: normalise scores + :param entity_mentions: list of entity mentions (~ queries) + :param top_k: number of candidates to retrieve per mention """ assert ( self.sparse_encoder is not None @@ -929,7 +918,7 @@ def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.n """ Find candidates with dense representations (FAISS) - :param entity_mentions: list of entity mentions (queries) + :param entity_mentions: list of entity mentions (~ queries) :param top_k: number of candidates to retrieve """ @@ -953,14 +942,13 @@ def combine_dense_and_sparse_results( top_k: int = 1, ): """ - Expand dense resutls with sparse ones (that are not already in the dense) - and re-weight the score as: dense_score + sparse_weight * sparse_scores + Expand dense results with sparse ones (that are not already in the dense) and re-weight the + score as: dense_score + sparse_weight * sparse_scores """ hybrid_ids = [] hybrid_scores = [] for i in range(dense_ids.shape[0]): - mention_ids = dense_ids[i] mention_scores = dense_scores[i] @@ -985,17 +973,16 @@ def combine_dense_and_sparse_results( def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: """ - Returns the top-k entity / concept identifiers for the each entity mention. + Returns the top-k entity / concept identifiers for each entity mention. :param entity_mentions: Entity mentions :param top_k: Number of best-matching entities from the knowledge base to return - :result: list of tuples in the form: (entity / concept name, concept ids, similarity score). + :result: List containing a list of entity linking candidates per entity mention from the input """ ids, scores = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) if self.hybrid_search and self.sparse_encoder is not None: - sparse_ids, sparse_scores = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k) scores, ids = self.combine_dense_and_sparse_results( @@ -1035,7 +1022,7 @@ def extract_mentions( self, sentences: List[Sentence], annotation_layers: Optional[List[str]] = None, - ) -> Tuple[List[int], List[Span], List[str]]: + ) -> Tuple[List[int], List[Span], List[str], List[str]]: """Unpack all mentions in sentences for batch search.""" source = [] @@ -1069,11 +1056,11 @@ def predict( top_k: int = 1, ) -> None: """ - Predicts the best matching top-k entity / concept identifiers of all named entites annotated + Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. 
:param sentences: One or more sentences to run the prediction on - :param annotation_layers: list of annotation layers to extract entity mentions + :param annotation_layers: List of annotation layers to extract entity mentions :param top_k: Number of best-matching entity / concept identifiers """ # make sure sentences is a list of sentences @@ -1096,7 +1083,6 @@ def predict( for i, data_point, mention_candidates, mentions_annotation_layer in zip( source, data_points, candidates, mentions_annotation_layers ): - sentences[i].add_label( typename=mentions_annotation_layer, value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), @@ -1251,31 +1237,3 @@ def __get_dictionary_path( ) return dictionary_name_or_path - - # @timeit - # def build_sparse_index(self, embeddings: csr_matrix) -> csr_matrix: - # """Initialize sparse index""" - - # index = embeddings - - # ###################################### - # # ANNOY - # ###################################### - # # metric = ANNOY_METRIC[self.similarity_metric] - # # index = annoy.AnnoyIndex(embeddings.shape[1], metric) - # # # See https://github.com/spotify/annoy#tradeoffs - # # n_trees = int(embeddings.shape[0] / 100) - # # for i, v in enumerate(embeddings.tolist()): - # # index.add_item(i, v) - # # index.build(n_trees, n_jobs=min(mp.cpu_count(), 8)) - - # ###################################### - # # HNSWLIB - # ###################################### - # # metric = HNSWLIB_METRIC[self.similarity_metric] - # # index = hnswlib.Index(space=metric, dim=embeddings.shape[1]) - # # index.init_index(max_elements=embeddings.shape[0], ef_construction=200, M=16) - # # index.add_items(embeddings, np.arange(embeddings.shape[0])) - # # index.set_ef(50) # ef should always be > k - - # return index diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index 2925054f0..b7c814570 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -1,29 +1,49 @@ # HunFlair Tutorial 3: Entity Linking -After adding Named Entity Recognition tags to your sentence, you can run Named Entity Linking on these annotations. +After adding named entity recognition tags to your sentence, you can run named entity linking on these annotations. 

 ```python
 from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
 from flair.nn import Classifier
 from flair.tokenization import SciSpacyTokenizer
 from flair.data import Sentence
 
-sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", use_tokenizer=SciSpacyTokenizer())
+sentence = Sentence(
+    "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
+    "a neurodegenerative disease, which is exacerbated by exposure to high "
+    "levels of mercury in dolphin populations.",
+    use_tokenizer=SciSpacyTokenizer()
+)
 
-ner_tagger = Classifier.load("hunflair-disease")
+ner_tagger = Classifier.load("hunflair")
 ner_tagger.predict(sentence)
 
 nen_tagger = BiomedicalEntityLinker.load("disease")
 nen_tagger.predict(sentence)
 
+nen_tagger = BiomedicalEntityLinker.load("gene")
+nen_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("chemical")
+nen_tagger.predict(sentence)
+
+nen_tagger = BiomedicalEntityLinker.load("species", entity_type="species")
+nen_tagger.predict(sentence)
+
 for tag in sentence.get_labels():
     print(tag)
 ```
 This should print:
 ~~~
-Span[0:2]: "Behavioral abnormalities" → Disease (0.6736)
-Span[0:2]: "Behavioral abnormalities" → behavior disorders - MESH:D001523 (0.9772)
-Span[9:12]: "Fragile X Syndrome" → Disease (0.99)
-Span[9:12]: "Fragile X Syndrome" → fragile x syndrome - MESH:D005600 (1.0976)
+Span[4:5]: "ABCD1" → Gene (0.9575)
+Span[4:5]: "ABCD1" → abcd1 - NCBI-GENE-HUMAN:215 (14.5503)
+Span[7:11]: "X-linked adrenoleukodystrophy" → Disease (0.9867)
+Span[7:11]: "X-linked adrenoleukodystrophy" → x linked adrenoleukodystrophy - CTD-DISEASES:MESH:D000326 (13.9717)
+Span[13:15]: "neurodegenerative disease" → Disease (0.8865)
+Span[13:15]: "neurodegenerative disease" → neurodegenerative disease - CTD-DISEASES:MESH:D019636 (14.2779)
+Span[25:26]: "mercury" → Chemical (0.9456)
+Span[25:26]: "mercury" → mercury - CTD-CHEMICALS:MESH:D008628 (14.9185)
+Span[27:28]: "dolphin" → Species (0.8082)
+Span[27:28]: "dolphin" → marine dolphins - NCBI-TAXONOMY:9726 (14.473)
 ~~~
 The output contains both the NER disease annotations and their entity / concept identifiers according to a knowledge base or ontology. We have pre-configured combinations of models and dictionaries for
@@ -33,7 +53,7 @@ You can also provide your own model and dictionary:
 ```python
 from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
 
-nen_tagger = BiomedicalEntityLinker.load("name_or_path_to_your_model", dictionary_names_or_paths="name_or_path_to_your_dictionary")
-nen_tagger = BiomedicalEntityLinker.load("path_to_custom_disease_model", dictionary_names_or_paths="disease")
+nen_tagger = BiomedicalEntityLinker.load("name_or_path_to_your_model", dictionary_name_or_path="name_or_path_to_your_dictionary")
+nen_tagger = BiomedicalEntityLinker.load("path_to_custom_disease_model", dictionary_name_or_path="disease")
 ````
 You can use any combination of provided models, provided dictionaries and your own.
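
For the custom-dictionary case above, the following sketch illustrates one possible usage. It assumes a hypothetical local file `my_disease_dictionary.txt` in the `concept_id||concept_name` format documented for `BiomedicalEntityLinkingDictionary`, and it follows the `load` behaviour introduced in this patch series: local dictionary files are passed as `Path` objects, while pre-defined dictionaries and entity types are referenced by name. Treat it as an illustrative sketch rather than a definitive recipe.

```python
from pathlib import Path

from flair.data import Sentence
from flair.models.biomedical_entity_linking import BiomedicalEntityLinker
from flair.nn import Classifier

# Hypothetical local dictionary, one "concept_id||concept_name" entry per line, e.g.:
#   MESH:D005600||fragile x syndrome
#   MESH:D001523||behavior disorders
custom_dictionary = Path("my_disease_dictionary.txt")

# Tag disease mentions first, then link them against the custom dictionary.
sentence = Sentence("Fragile X syndrome is a developmental disorder.")
Classifier.load("hunflair-disease").predict(sentence)

nen_tagger = BiomedicalEntityLinker.load("disease", dictionary_name_or_path=custom_dictionary)
nen_tagger.predict(sentence, top_k=5)

for tag in sentence.get_labels():
    print(tag)
```
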
From c6273e66d5d925007beaefeba894a68cfb83727e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Wed, 12 Jul 2023 17:40:20 +0200 Subject: [PATCH 25/58] Fix tests and type annotations --- flair/embeddings/document.py | 3 +- flair/models/biomedical_entity_linking.py | 125 +++++++++++++--------- tests/test_datasets_biomedical.py | 8 +- 3 files changed, 82 insertions(+), 54 deletions(-) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index c1e73442e..bf611521b 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast import torch @@ -29,7 +30,7 @@ class TransformerDocumentEmbeddings(DocumentEmbeddings, TransformerEmbeddings): def __init__( self, - model: str = "bert-base-uncased", # set parameters with different default values + model: Union[str, Path] = "bert-base-uncased", # set parameters with different default values layers: str = "-1", layer_mean: bool = False, is_token_embedding: bool = False, diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 7ce894b16..78434c784 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -11,14 +11,14 @@ from collections import defaultdict from enum import Enum, auto from pathlib import Path -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Type, Union, cast import joblib import numpy as np import scipy import torch from huggingface_hub import hf_hub_download -from scipy.sparse._csr import csr_matrix +from scipy.sparse import csr_matrix from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm @@ -110,7 +110,7 @@ "species": "species", } -BIOMEDICAL_DICTIONARIES = { +BIOMEDICAL_DICTIONARIES: Dict[str, Type] = { "ctd-diseases": CTD_DISEASES_DICTIONARY, "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, "ncbi-gene": NCBI_GENE_HUMAN_DICTIONARY, @@ -233,8 +233,8 @@ def process_entity_name(self, entity_name: str) -> str: entity_name = entity_name.lower() if self.remove_punctuation: - entity_name = self.rmv_puncts_regex.split(entity_name) - entity_name = " ".join(entity_name).strip() + name_parts = self.rmv_puncts_regex.split(entity_name) + entity_name = " ".join(name_parts).strip() return entity_name.strip() @@ -264,7 +264,7 @@ def __init__( self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir self.preprocessor = preprocessor - self.abbreviation_dict = {} + self.abbreviation_dict: Dict[str, Dict[str, str]] = {} @property def name(self): @@ -275,7 +275,7 @@ def initialize(self, sentences: List[Sentence]) -> None: def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: sentence_text = sentence.to_tokenized_string().strip() - tokens = [token.text for token in entity_mention.data_point.tokens] + tokens = [token.text for token in cast(Span, entity_mention.data_point).tokens] parsed_tokens = [] for token in tokens: @@ -366,7 +366,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict :param sentences: list of sentences :result abbreviation_dict: abbreviations and their resolution detected in each input sentence """ - abbreviation_dict = defaultdict(dict) + abbreviation_dict: Dict = defaultdict(dict) # Create a temp file which holds the sentences we want to process with Ab3P with 
tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: @@ -451,10 +451,12 @@ def __init__( @classmethod def load( cls, dictionary_name_or_path: Union[Path, str], database_name: Optional[str] = None - ) -> "EntityLinkingDictionary": + ) -> "BiomedicalEntityLinkingDictionary": """Load dictionary: either pre-definded or from path""" if isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cast(str, dictionary_name_or_path) + if ( dictionary_name_or_path not in ENTITY_TYPE_TO_DICTIONARY and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES @@ -462,12 +464,13 @@ def load( raise ValueError( f"Unkwnon dictionary `{dictionary_name_or_path}`!" f" Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)}" - " If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`" + " If you want to pass a local path please use the `Path` class, " + "i.e. `model_name_or_path=Path(my_path)`" ) dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) - reader = BIOMEDICAL_DICTIONARIES[dictionary_name_or_path]() + reader = BIOMEDICAL_DICTIONARIES[str(dictionary_name_or_path)]() else: # use custom dictionary file @@ -543,19 +546,26 @@ def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): self.name_to_id_index = dict(list(dictionary.stream())) @classmethod - def load(cls, dictionary_name_or_path: str) -> "ExactStringMatchingRetrieverModel": + def load(cls, dictionary_name_or_path: Union[str, Path]) -> "ExactMatchCandidateGenerator": """Compatibility function""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: - return [ - [ + candidates: List[List[EntityLinkingCandidate]] = [] + for mention in entity_mentions: + dict_entry = self.name_to_id_index.get(mention) + if not dict_entry: + candidates.append([]) + continue + + candidates.append([ self.build_candidate( - candidate=(em, self.name_to_id_index.get(em), 1.0), database_name=self.dictionary.database_name + candidate=(mention, dict_entry, 1.0), + database_name=self.dictionary.database_name ) - ] - for em in entity_mentions - ] + ]) + + return candidates class BigramTfIDFVectorizer: @@ -575,12 +585,12 @@ def fit(self, names: List[str]): self.encoder.fit(names) return self - def transform(self, names: List[str]) -> csr_matrix: + def transform(self, names: Union[List[str], np.ndarray]) -> csr_matrix: """Convert strings to sparse vectors""" embeddings = self.encoder.transform(names) return embeddings - def __call__(self, mentions: List[str]) -> np.ndarray: + def __call__(self, mentions: Union[List[str], np.ndarray]) -> np.ndarray: """Short for `transform`""" return self.transform(mentions) @@ -589,7 +599,7 @@ def save(self, path: Path): joblib.dump(self.encoder, str(path)) @classmethod - def load(cls, path: Path) -> "BigramTfIDFVectorizer": + def load(cls, path: Union[Path, str]) -> "BigramTfIDFVectorizer": """Instantiate from path""" newVectorizer = cls() @@ -611,7 +621,7 @@ class BiEncoderCandidateGenerator(AbstractCandidateGenerator): def __init__( self, model_name_or_path: Union[str, Path], - dictionary_name_or_path: str, + dictionary_name_or_path: Union[str, Path], similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), max_length: int = 25, @@ -653,14 +663,14 @@ def __init__( else: 
self.dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path) - self.dictionary_data = [ + self.dictionary_data: List[Tuple[str, str]] = [ (self.preprocessor.process_entity_name(name), cui) for name, cui in self.dictionary.stream() ] # Load encoders self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None - self.sparse_weight: Optional[float] = None + if self.hybrid_search: self.sparse_encoder, self.sparse_weight = self._get_sparse_encoder_and_weight( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path @@ -680,12 +690,12 @@ def higher_is_better(self): return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] - def _get_cache_name(self, model_name_or_path: str, dictionary_name_or_path: str) -> str: + def _get_cache_name(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> str: """Fixed name for caching""" # Check for embedded dictionary in cache dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] - file_name = f"{model_name_or_path.split('/')[-1]}_{dictionary_name}" + file_name = f"{str(model_name_or_path).split('/')[-1]}_{dictionary_name}" pp_name = self.preprocessor.name if self.preprocessor is not None else "null" return f"{file_name}-{pp_name}" @@ -733,6 +743,8 @@ def _get_sparse_encoder_and_weight( sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") if isinstance(model_name_or_path, str) and model_name_or_path in PRETRAINED_HYBRID_MODELS: + model_name_or_path = cast(str, model_name_or_path) + if not os.path.exists(sparse_encoder_path): sparse_encoder_path = hf_hub_download( repo_id=model_name_or_path, @@ -764,6 +776,8 @@ def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: :param inputs: Numpy array of entity / concept names :returns Numpy array containing the sparse embeddings of the names """ + if self.sparse_encoder is None: + raise AssertionError("Error while using the model") return self.sparse_encoder(inputs) @@ -802,19 +816,17 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: if flair.device.type == "cuda": torch.cuda.empty_cache() - dense_embeds = np.array(dense_embeds) - - return dense_embeds + return np.array(dense_embeds) # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... 
- def get_dense_index(self, names: List[str], path: Path) -> faiss.Index: + def get_dense_index(self, names: np.ndarray, path: Path) -> faiss.Index: """Load or create dense index and save it to disk""" if path.exists(): index = faiss.read_index(str(path)) else: - embeddings = self.embed_dense(inputs=names, batch_size=self.batch_size, show_progress=True) + embeddings = self.embed_dense(inputs=np.array(names), batch_size=self.batch_size, show_progress=True) index = faiss.IndexFlatIP(embeddings.shape[1]) index.add(embeddings) @@ -826,7 +838,7 @@ def get_dense_index(self, names: List[str], path: Path) -> faiss.Index: return index - def get_sparse_index(self, names: List[str], path: Path) -> csr_matrix: + def get_sparse_index(self, names: np.ndarray, path: Path) -> csr_matrix: """Load or create sparse index and save it to disk""" if path.exists(): @@ -840,7 +852,7 @@ def get_sparse_index(self, names: List[str], path: Path) -> csr_matrix: return index - def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) -> Dict: + def _load_indices(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> Dict: """Load cached indices if available, otherwise compute embeddings, build index and cache""" cache_name = self._get_cache_name( @@ -863,19 +875,20 @@ def _load_indices(self, model_name_or_path: str, dictionary_name_or_path: str) - continue extension = "bin" if index_type == "dense" else "npz" - file_name = f"index-{index_type}.{extension}" index_cache_file = cache_folder / file_name + names = np.array([n for n, _ in self.dictionary_data]) + if index_type == "dense": indices[index_type] = self.get_dense_index( - names=[n for n, _ in self.dictionary_data], path=index_cache_file + names=names, path=index_cache_file ) else: indices[index_type] = self.get_sparse_index( - names=[n for n, _ in self.dictionary_data], path=index_cache_file + names=names, path=index_cache_file ) return indices @@ -996,7 +1009,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink return [ [ self.build_candidate( - candidate=tuple(self.dictionary_data[i]) + (score,), database_name=self.dictionary.database_name + candidate=self.dictionary_data[i] + (score,), database_name=self.dictionary.database_name ) for i, score in zip(mention_ids, mention_scores) ] @@ -1101,12 +1114,16 @@ def load( force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, - dictionary: Optional[List[Tuple[str, str]]] = None, - ): + dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, + ) -> "BiomedicalEntityLinker": """ Loads a model for biomedical named entity normalization. See __init__ method for detailed docstring on arguments """ + if not isinstance(model_name_or_path, str): + raise AssertionError(f"String matching model name has to be an " + f"string (and not {type(model_name_or_path)}") + model_name_or_path = cast(str, model_name_or_path) if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): dictionary_name_or_path = cls.__get_dictionary_path( @@ -1125,7 +1142,7 @@ def load( assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! 
Must be one of: {ENTITY_TYPES}" if model_name_or_path == "exact-string-match": - candidate_generator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) + candidate_generator: AbstractCandidateGenerator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) else: candidate_generator = BiEncoderCandidateGenerator( model_name_or_path=model_name_or_path, @@ -1141,10 +1158,9 @@ def load( ) logger.info( - "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s) with %s classes", + "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, - entity_type, - len(candidate_generator.dictionary_data), + entity_type ) return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) @@ -1155,7 +1171,7 @@ def __get_model_path_and_entity_type( entity_type: Optional[str] = None, hybrid_search: bool = False, force_hybrid_search: bool = False, - ) -> Tuple[str, str]: + ) -> Tuple[Union[str, Path], str]: """ Try to figure out what model the user wants """ @@ -1172,7 +1188,9 @@ def __get_model_path_and_entity_type( if hybrid_search: # load model by entity_type - if model_name_or_path in ENTITY_TYPES: + if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + model_name_or_path = cast(str, model_name_or_path) + # check if we have a hybrid pre-trained model if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: entity_type = model_name_or_path @@ -1181,7 +1199,8 @@ def __get_model_path_and_entity_type( # check if user really wants to use hybrid search anyway if not force_hybrid_search: logger.warning( - "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for hybrid search: no sparse search will be performed." + "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for" + " hybrid search: no sparse search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", model_name_or_path, @@ -1191,16 +1210,20 @@ def __get_model_path_and_entity_type( else: if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: logger.warning( - "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse search will be performed." + "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse" + " search will be performed." " If you want to use sparse search please pass `force_hybrid_search=True`:" " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", model_name_or_path, DEFAULT_SPARSE_WEIGHT, ) + + model_name_or_path = cast(str, model_name_or_path) entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] else: - if model_name_or_path in ENTITY_TYPES: + if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + model_name_or_path = cast(str, model_name_or_path) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] assert ( @@ -1213,7 +1236,7 @@ def __get_model_path_and_entity_type( def __get_dictionary_path( model_name_or_path: str, dictionary_name_or_path: Optional[Union[str, Path]] = None, - ) -> str: + ) -> Union[str, Path]: """ Try to figure out what dictionary (depending on the model) the user wants """ @@ -1223,7 +1246,9 @@ def __get_dictionary_path( "When using a string-matching candidate generator you must specify `dictionary_name_or_path`!" 
) - if dictionary_name_or_path is not None: + if dictionary_name_or_path is not None and isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cast(str, dictionary_name_or_path) + if dictionary_name_or_path in ENTITY_TYPES: dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[dictionary_name_or_path] else: diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py index c515e068c..fbff952ae 100644 --- a/tests/test_datasets_biomedical.py +++ b/tests/test_datasets_biomedical.py @@ -182,7 +182,7 @@ def assert_conll_writer_output( assert contents == expected_output -def test_filter_nested_entities(caplog): +def test_filter_nested_entities(recwarn): entities_per_document = { "d0": [Entity((0, 1), "t0"), Entity((2, 3), "t1")], "d1": [Entity((0, 6), "t0"), Entity((2, 3), "t1"), Entity((4, 5), "t2")], @@ -204,9 +204,11 @@ def test_filter_nested_entities(caplog): } dataset = InternalBioNerDataset(documents={}, entities_per_document=entities_per_document) - caplog.set_level(logging.WARNING) filter_nested_entities(dataset) - assert "WARNING: Corpus modified by filtering nested entities." in caplog.text + + assert len(recwarn.list) == 1 + assert isinstance(recwarn.list[0].message, UserWarning) + assert "Corpus modified by filtering nested entities." in recwarn.list[0].message.args[0] for key, entities in dataset.entities_per_document.items(): assert key in target From 9a3fff63d4381ce4dbad7dfd10b55da98f22d283 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 28 Aug 2023 14:50:25 +0200 Subject: [PATCH 26/58] code format and fix ruff & most mypy errors. --- flair/data.py | 37 ++-- flair/embeddings/document.py | 3 +- flair/models/biomedical_entity_linking.py | 251 +++++++++------------- 3 files changed, 118 insertions(+), 173 deletions(-) diff --git a/flair/data.py b/flair/data.py index 6e197ff65..329d8550d 100644 --- a/flair/data.py +++ b/flair/data.py @@ -6,7 +6,7 @@ from collections import Counter, defaultdict from operator import itemgetter from pathlib import Path -from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast +from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast import torch from deprecated.sphinx import deprecated @@ -434,7 +434,7 @@ def __len__(self) -> int: class EntityLinkingCandidate: - """Represent a single candidate returned by a CandidateGenerator""" + """Represent a single candidate returned by a CandidateGenerator.""" def __init__( self, @@ -444,12 +444,14 @@ def __init__( score: float = 1.0, additional_ids: Optional[Union[List[str], str]] = None, ): - """ - :param concept_id: Identifier of the entity / concept from the knowledge base / ontology - :param concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology - :param score: Matching score of the entity / concept according to the entity mention - :param additional_ids: List of additional identifiers for the concept / entity in the KB / ontology - :param database_name: Name of the knowlege base / ontology + """Represent a single candidate returned by a CandidateGenerator. 
+ + Args: + concept_id: Identifier of the entity / concept from the knowledge base / ontology + concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology + score: Matching score of the entity / concept according to the entity mention + additional_ids: List of additional identifiers for the concept / entity in the KB / ontology + database_name: Name of the knowlege base / ontology """ self.concept_id = concept_id self.concept_name = concept_name @@ -468,24 +470,23 @@ def __repr__(self) -> str: class EntityLinkingLabel(Label): - """ - Label class models entity linking annotations. Each entity linking label has a data point it refers + """Label class models entity linking annotations. + + Each entity linking label has a data point it refers to as well as the identifier and name of the concept / entity from a knowledge base or ontology. Optionally, additional concepts identifier and the database name can be provided. """ def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidate]): - """ - Initializes the label instance. - :param data_point: Data point / span the label refers to - :param candidates: **sorted** list of candidates from candidate generator + """Initializes the label instance. + + Args: + data_point: Data point / span the label refers to + candidates: **sorted** list of candidates from candidate generator. """ def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x >= y): - for i, el in enumerate(lst[1:]): - if comparison(key(el), key(lst[i])): - return False - return True + return all(not comparison(key(el), key(lst[i])) for i, el in enumerate(lst[1:])) # candidates must be sorted, regardless if higher is better or not assert is_sorted(candidates, key=lambda x: x.score) or is_sorted( diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index bf611521b..c1e73442e 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -1,5 +1,4 @@ import logging -from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast import torch @@ -30,7 +29,7 @@ class TransformerDocumentEmbeddings(DocumentEmbeddings, TransformerEmbeddings): def __init__( self, - model: Union[str, Path] = "bert-base-uncased", # set parameters with different default values + model: str = "bert-base-uncased", # set parameters with different default values layers: str = "-1", layer_mean: bool = False, is_token_embedding: bool = False, diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 78434c784..616513083 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -133,7 +133,7 @@ class SimilarityMetric(Enum): - """Similarity metrics""" + """Similarity metrics.""" INNER_PRODUCT = faiss.METRIC_INNER_PRODUCT # L2 = faiss.METRIC_L2 @@ -141,9 +141,7 @@ class SimilarityMetric(Enum): def timeit(func): - """ - This function shows the execution time of the function object passed - """ + """This function shows the execution time of the function object passed.""" def wrap_func(*args, **kwargs): start = time.time() @@ -157,8 +155,8 @@ def wrap_func(*args, **kwargs): class AbstractEntityPreprocessor(ABC): - """ - A pre-processor used to transform / clean both entity mentions and entity names + """A pre-processor used to transform / clean both entity mentions and entity names. 
+ This class provides the basic interface for such transformations and must provide a `name` attribute to uniquely identify the type of preprocessing applied. """ @@ -166,14 +164,11 @@ class AbstractEntityPreprocessor(ABC): @property @abstractmethod def name(self) -> str: - """ - This is needed to correctly cache different multiple version of the dictionary - """ + """This is needed to correctly cache different multiple version of the dictionary.""" @abstractmethod def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: - """ - Processes the given entity mention and applies the transformation procedure to it. + """Processes the given entity mention and applies the transformation procedure to it. :param entity_mention: entity mention under investigation :param sentence: sentence in which the entity mentioned occurred @@ -182,44 +177,45 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: @abstractmethod def process_entity_name(self, entity_name: str) -> str: - """ - Processes the given entity name (originating from a knowledge base / ontology) and - applies the transformation procedure to it. + """Processes the given entity name and applies the transformation procedure to it. - :param entity_name: entity mention given as DataPoint - :result: Cleaned / transformed string representation of the given entity mention + Args: + entity_name: entity mention given as DataPoint + Returns: + Cleaned / transformed string representation of the given entity mention """ @abstractmethod def initialize(self, sentences: List[Sentence]): - """ - Initializes the pre-processor for a batch of sentences, which is may be necessary for - more sophisticated transformations. + """Initializes the pre-processor for a batch of sentences. - :param sentences: List of sentences that will be processed. + This may be necessary for more sophisticated transformations. + + Args: + sentences: List of sentences that will be processed. """ class EntityPreprocessor(AbstractEntityPreprocessor): - """ - Entity preprocessor adapted from: + """Entity preprocessor using Synonym Marginalization. + + Adapted from: Sung et al. 2020, Biomedical Entity Representations with Synonym Marginalization - https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5 + https://github.com/dmis-lab/BioSyn/blob/master/src/biosyn/preprocesser.py#L5. The preprocessor provides basic string transformation options including lower-casing, removal of punctuations symbols, etc. """ def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): - """ - Initializes the mention preprocessor. + """Initializes the mention preprocessor. :param lowercase: Indicates whether to perform lowercasing or not (True by default) :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation - self.rmv_puncts_regex = re.compile(r"[\s{}]+".format(re.escape(string.punctuation))) + self.rmv_puncts_regex = re.compile(rf"[\s{re.escape(string.punctuation)}]+") @property def name(self): @@ -243,19 +239,18 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): - """ - Entity preprocessor which uses Ab3P, an (biomedical) abbreviation definition detector: - Abbreviation definition identification based on automatic precision estimates. - Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. 
- PubMed ID: 18817555 - https://github.com/ncbi-nlp/Ab3P + """Entity preprocessor which uses Ab3P, an (biomedical) abbreviation definition detector. + + Abbreviation definition identification based on automatic precision estimates. + Sohn S, Comeau DC, Kim W, Wilbur WJ. BMC Bioinformatics. 2008 Sep 25;9:402. + PubMed ID: 18817555 + https://github.com/ncbi-nlp/Ab3P. """ def __init__( self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[AbstractEntityPreprocessor] = None ) -> None: - """ - Creates the mention pre-processor + """Creates the mention pre-processor. :param ab3p_path: Path to the folder containing the Ab3P implementation :param word_data_dir: Path to the word data directory @@ -282,10 +277,9 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: if self.preprocessor is not None: token = self.preprocessor.process_entity_name(token) - if sentence_text in self.abbreviation_dict: - if token.lower() in self.abbreviation_dict[sentence_text]: - parsed_tokens.append(self.abbreviation_dict[sentence_text][token.lower()]) - continue + if sentence_text in self.abbreviation_dict and token.lower() in self.abbreviation_dict[sentence_text]: + parsed_tokens.append(self.abbreviation_dict[sentence_text][token.lower()]) + continue if len(token) != 0: parsed_tokens.append(token) @@ -301,7 +295,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name @classmethod - def load(cls, ab3p_path: Path = None, preprocessor: Optional[AbstractEntityPreprocessor] = None): + def load(cls, ab3p_path: Optional[Path] = None, preprocessor: Optional[AbstractEntityPreprocessor] = None): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -318,7 +312,6 @@ def load(cls, ab3p_path: Path = None, preprocessor: Optional[AbstractEntityPrepr @classmethod def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: """Downloads the Ab3P tool and all necessary data files.""" - # Download word data for Ab3P if not already downloaded ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" @@ -352,9 +345,9 @@ def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: return ab3p_path def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: - """ - Processes the given sentences with the Ab3P tool. The function returns a (nested) dictionary - containing the abbreviations found for each sentence, e.g.: + """Processes the given sentences with the Ab3P tool. + + The function returns a (nested) dictionary containing the abbreviations found for each sentence, e.g.: { "Respiratory syncytial viruses ( RSV ) are a subgroup of the paramyxoviruses.": @@ -383,8 +376,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict try: result = subprocess.run( [self.ab3p_path, temp_file.name], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, check=True, ) except subprocess.CalledProcessError: @@ -398,11 +390,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict logger.error( "Error when using Ab3P for abbreviation resolution. A file named path_Ab3p needs to exist in your current directory containing the path to the WordData directory for Ab3P to work!" ) - elif "Cannot open" in line: - logger.error( - "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" 
- ) - elif "failed to open" in line: + elif "Cannot open" in line or "failed to open" in line: logger.error( "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" ) @@ -432,8 +420,9 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict class BiomedicalEntityLinkingDictionary: - """ - Class to load named entity dictionaries: either pre-defined or from a path on disk. + """Class to load named entity dictionaries. + + Loading either pre-defined or from a path on disk. For the latter, every line in the file must be formatted as follows: concept_id||concept_name @@ -452,8 +441,7 @@ def __init__( def load( cls, dictionary_name_or_path: Union[Path, str], database_name: Optional[str] = None ) -> "BiomedicalEntityLinkingDictionary": - """Load dictionary: either pre-definded or from path""" - + """Load dictionary: either pre-definded or from path.""" if isinstance(dictionary_name_or_path, str): dictionary_name_or_path = cast(str, dictionary_name_or_path) @@ -483,29 +471,23 @@ def load( @property def database_name(self) -> str: - """Database name of the dictionary""" - + """Database name of the dictionary.""" return self.reader.database_name def stream(self) -> Iterator[Tuple[str, str]]: - """ - Stream entries from preprocessed dictionary - """ - - for entry in self.reader.stream(): - yield entry + """Stream entries from preprocessed dictionary.""" + yield from self.reader.stream() class AbstractCandidateGenerator(ABC): - """ - Base class for a candidate generator, i.e. given a mention of an entity, find matching - entries from the dictionary. + """Base class for a candidate generator. + + Given a mention of an entity, find matching entries from the dictionary. """ @abstractmethod def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: - """ - Returns the top-k entity / concept identifiers for each entity mention. + """Returns the top-k entity / concept identifiers for each entity mention. 
:param entity_mentions: Entity mentions :param top_k: Number of best-matching entities from the knowledge base to return @@ -513,8 +495,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink """ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> EntityLinkingCandidate: - """Get nice container with all info about entity linking candidate""" - + """Get nice container with all info about entity linking candidate.""" concept_name = candidate[0] concept_id = candidate[1] score = candidate[2] @@ -536,9 +517,7 @@ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) class ExactMatchCandidateGenerator(AbstractCandidateGenerator): - """ - Candidate generator using exact string matching as search criterion - """ + """Candidate generator using exact string matching as search criterion.""" def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): # Build index which maps concept / entity names to concept / entity ids @@ -547,30 +526,31 @@ def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): @classmethod def load(cls, dictionary_name_or_path: Union[str, Path]) -> "ExactMatchCandidateGenerator": - """Compatibility function""" + """Compatibility function.""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: candidates: List[List[EntityLinkingCandidate]] = [] - for mention in entity_mentions: + for mention in entity_mentions: dict_entry = self.name_to_id_index.get(mention) if not dict_entry: candidates.append([]) continue - candidates.append([ - self.build_candidate( - candidate=(mention, dict_entry, 1.0), - database_name=self.dictionary.database_name - ) - ]) + candidates.append( + [ + self.build_candidate( + candidate=(mention, dict_entry, 1.0), database_name=self.dictionary.database_name + ) + ] + ) return candidates class BigramTfIDFVectorizer: - """ - Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level + """Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level. + Implementation adapted from: Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 @@ -581,26 +561,26 @@ def __init__(self): self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) def fit(self, names: List[str]): - """Learn vocabulary""" + """Learn vocabulary.""" self.encoder.fit(names) return self def transform(self, names: Union[List[str], np.ndarray]) -> csr_matrix: - """Convert strings to sparse vectors""" + """Convert strings to sparse vectors.""" embeddings = self.encoder.transform(names) return embeddings def __call__(self, mentions: Union[List[str], np.ndarray]) -> np.ndarray: - """Short for `transform`""" + """Short for `transform`.""" return self.transform(mentions) def save(self, path: Path): - """Save vectorizer to disk""" + """Save vectorizer to disk.""" joblib.dump(self.encoder, str(path)) @classmethod def load(cls, path: Union[Path, str]) -> "BigramTfIDFVectorizer": - """Instantiate from path""" + """Instantiate from path.""" newVectorizer = cls() # with open(path, "rb") as fin: @@ -613,10 +593,7 @@ def load(cls, path: Union[Path, str]) -> "BigramTfIDFVectorizer": class BiEncoderCandidateGenerator(AbstractCandidateGenerator): - """ - Candidate generator using both dense (transformer-based) and (optionally) sparse vector representations, - to search candidates in a knowledge base / dictionary. 
- """ + """Candidate generator using both dense and (optionally) sparse vector representations, to search candidates.""" def __init__( self, @@ -631,8 +608,7 @@ def __init__( force_hybrid_search: bool = False, dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, ): - """ - Initializes the BiEncoderEntityRetrieverModel. + """Initializes the BiEncoderEntityRetrieverModel. :param model_name_or_path: Name of or path to the transformer model to be used. :param dictionary_name_or_path: Name of or path to the transformer model to be used. @@ -668,7 +644,7 @@ def __init__( ] # Load encoders - self.dense_encoder = TransformerDocumentEmbeddings(model=model_name_or_path, is_token_embedding=False) + self.dense_encoder = TransformerDocumentEmbeddings(model=str(model_name_or_path), is_token_embedding=False) self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None if self.hybrid_search: @@ -683,16 +659,14 @@ def __init__( @property def higher_is_better(self): - """ - Determine if similarity is proportional to score. - E.g. for L2 lower is better, while INNER_PRODUCT higher is better - """ + """Determine if similarity is proportional to score. + E.g. for L2 lower is better, while INNER_PRODUCT higher is better. + """ return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] def _get_cache_name(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> str: - """Fixed name for caching""" - + """Fixed name for caching.""" # Check for embedded dictionary in cache dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] file_name = f"{str(model_name_or_path).split('/')[-1]}_{dictionary_name}" @@ -702,8 +676,7 @@ def _get_cache_name(self, model_name_or_path: Union[str, Path], dictionary_name_ @timeit def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: - """Fit sparse encoder to current dictionary""" - + """Fit sparse encoder to current dictionary.""" logger.info( "BiEncoderCandidateGenerator: hybrid model has no pretrained sparse encoder. Fit to dictionary `%s`", self.dictionary_name_or_path, @@ -717,8 +690,7 @@ def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: def _handle_sparse_encoder( self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] ) -> BigramTfIDFVectorizer: - """If necessary fit and cache sparse encoder""" - + """If necessary fit and cache sparse encoder.""" if isinstance(model_name_or_path, str): cache_name = self._get_cache_name( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path @@ -770,8 +742,7 @@ def _get_sparse_encoder_and_weight( return sparse_encoder, sparse_weight def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: - """ - Create sparse embeddings from array of entity mentions/names. + """Create sparse embeddings from array of entity mentions/names. :param inputs: Numpy array of entity / concept names :returns Numpy array containing the sparse embeddings of the names @@ -782,8 +753,7 @@ def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: return self.sparse_encoder(inputs) def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: - """ - Create dense embeddings from array of entity mentions/names. + """Create dense embeddings from array of entity mentions/names. 
:param names: Numpy array of entity / concept names :param batch_size: Batch size used while embedding the name @@ -820,8 +790,7 @@ def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... def get_dense_index(self, names: np.ndarray, path: Path) -> faiss.Index: - """Load or create dense index and save it to disk""" - + """Load or create dense index and save it to disk.""" if path.exists(): index = faiss.read_index(str(path)) @@ -839,8 +808,7 @@ def get_dense_index(self, names: np.ndarray, path: Path) -> faiss.Index: return index def get_sparse_index(self, names: np.ndarray, path: Path) -> csr_matrix: - """Load or create sparse index and save it to disk""" - + """Load or create sparse index and save it to disk.""" if path.exists(): index = scipy.sparse.load_npz(str(path)) else: @@ -853,8 +821,7 @@ def get_sparse_index(self, names: np.ndarray, path: Path) -> csr_matrix: return index def _load_indices(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> Dict: - """Load cached indices if available, otherwise compute embeddings, build index and cache""" - + """Load cached indices if available, otherwise compute embeddings, build index and cache.""" cache_name = self._get_cache_name( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) @@ -882,21 +849,16 @@ def _load_indices(self, model_name_or_path: Union[str, Path], dictionary_name_or names = np.array([n for n, _ in self.dictionary_data]) if index_type == "dense": - indices[index_type] = self.get_dense_index( - names=names, path=index_cache_file - ) + indices[index_type] = self.get_dense_index(names=names, path=index_cache_file) else: - indices[index_type] = self.get_sparse_index( - names=names, path=index_cache_file - ) + indices[index_type] = self.get_sparse_index(names=names, path=index_cache_file) return indices @timeit def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: - """ - Find candidates with sparse representations + """Find candidates with sparse representations. :param entity_mentions: list of entity mentions (~ queries) :param top_k: number of candidates to retrieve per mention @@ -928,13 +890,11 @@ def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np. @timeit def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: - """ - Find candidates with dense representations (FAISS) + """Find candidates with dense representations (FAISS). :param entity_mentions: list of entity mentions (~ queries) :param top_k: number of candidates to retrieve """ - # Compute dense embedding for the given entity mention mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.batch_size) @@ -954,11 +914,10 @@ def combine_dense_and_sparse_results( sparse_scores: np.ndarray, top_k: int = 1, ): - """ - Expand dense results with sparse ones (that are not already in the dense) and re-weight the - score as: dense_score + sparse_weight * sparse_scores - """ + """Expand dense results with sparse ones ans re-weight them. + Re-weight the score as: dense_score + sparse_weight * sparse_scores. 
+ """ hybrid_ids = [] hybrid_scores = [] for i in range(dense_ids.shape[0]): @@ -985,14 +944,12 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: - """ - Returns the top-k entity / concept identifiers for each entity mention. + """Returns the top-k entity / concept identifiers for each entity mention. :param entity_mentions: Entity mentions :param top_k: Number of best-matching entities from the knowledge base to return :result: List containing a list of entity linking candidates per entity mention from the input """ - ids, scores = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) if self.hybrid_search and self.sparse_encoder is not None: @@ -1018,7 +975,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink class BiomedicalEntityLinker: - """Entity linking model for the biomedical domain""" + """Entity linking model for the biomedical domain.""" def __init__( self, @@ -1037,7 +994,6 @@ def extract_mentions( annotation_layers: Optional[List[str]] = None, ) -> Tuple[List[int], List[Span], List[str], List[str]]: """Unpack all mentions in sentences for batch search.""" - source = [] data_points = [] mentions = [] @@ -1068,9 +1024,7 @@ def predict( annotation_layers: Optional[List[str]] = None, top_k: int = 1, ) -> None: - """ - Predicts the best matching top-k entity / concept identifiers of all named entities annotated - with tag input_entity_annotation_layer. + """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. :param sentences: One or more sentences to run the prediction on :param annotation_layers: List of annotation layers to extract entity mentions @@ -1116,13 +1070,12 @@ def load( entity_type: Optional[str] = None, dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, ) -> "BiomedicalEntityLinker": - """ - Loads a model for biomedical named entity normalization. - See __init__ method for detailed docstring on arguments + """Loads a model for biomedical named entity normalization. + + See __init__ method for detailed docstring on arguments. """ if not isinstance(model_name_or_path, str): - raise AssertionError(f"String matching model name has to be an " - f"string (and not {type(model_name_or_path)}") + raise AssertionError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") model_name_or_path = cast(str, model_name_or_path) if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): @@ -1158,9 +1111,7 @@ def load( ) logger.info( - "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", - dictionary_name_or_path, - entity_type + "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, entity_type ) return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) @@ -1172,10 +1123,7 @@ def __get_model_path_and_entity_type( hybrid_search: bool = False, force_hybrid_search: bool = False, ) -> Tuple[Union[str, Path], str]: - """ - Try to figure out what model the user wants - """ - + """Try to figure out what model the user wants.""" if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( f"Unknown model `{model_name_or_path}`!" 
@@ -1237,10 +1185,7 @@ def __get_dictionary_path( model_name_or_path: str, dictionary_name_or_path: Optional[Union[str, Path]] = None, ) -> Union[str, Path]: - """ - Try to figure out what dictionary (depending on the model) the user wants - """ - + """Try to figure out what dictionary (depending on the model) the user wants.""" if model_name_or_path in STRING_MATCHING_MODELS and dictionary_name_or_path is None: raise ValueError( "When using a string-matching candidate generator you must specify `dictionary_name_or_path`!" From fd9507f6cde9f90bea078c0f6f0ad8c628f724b7 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 4 Sep 2023 16:28:16 +0200 Subject: [PATCH 27/58] refine interface of BiomedicalEntityLinker --- flair/data.py | 67 +------------- flair/models/biomedical_entity_linking.py | 103 +++++++++------------ pyproject.toml | 1 + tests/test_biomedical_entity_linking.py | 105 ++++++++++++---------- 4 files changed, 104 insertions(+), 172 deletions(-) diff --git a/flair/data.py b/flair/data.py index 329d8550d..3721e6076 100644 --- a/flair/data.py +++ b/flair/data.py @@ -336,8 +336,8 @@ def get_metadata(self, key: str) -> typing.Any: def has_metadata(self, key: str) -> bool: return key in self._metadata - def add_label(self, typename: str, value_or_label: Union[str, Label], score: float = 1.0): - label = value_or_label if isinstance(value_or_label, Label) else Label(self, value_or_label, score) + def add_label(self, typename: str, value: str, score: float = 1.0): + label = Label(self, value, score) if typename not in self.annotation_layers: self.annotation_layers[typename] = [label] @@ -441,7 +441,6 @@ def __init__( concept_id: str, concept_name: str, database_name: str, - score: float = 1.0, additional_ids: Optional[Union[List[str], str]] = None, ): """Represent a single candidate returned by a CandidateGenerator. @@ -456,11 +455,10 @@ def __init__( self.concept_id = concept_id self.concept_name = concept_name self.database_name = database_name - self.score = score self.additional_ids = additional_ids def __str__(self) -> str: - string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name} - {self.score}" + string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name}" if self.additional_ids is not None: string += f" - {self.additional_ids}" return string @@ -469,65 +467,6 @@ def __repr__(self) -> str: return str(self) -class EntityLinkingLabel(Label): - """Label class models entity linking annotations. - - Each entity linking label has a data point it refers - to as well as the identifier and name of the concept / entity from a knowledge base or ontology. - Optionally, additional concepts identifier and the database name can be provided. - """ - - def __init__(self, data_point: DataPoint, candidates: List[EntityLinkingCandidate]): - """Initializes the label instance. - - Args: - data_point: Data point / span the label refers to - candidates: **sorted** list of candidates from candidate generator. - """ - - def is_sorted(lst, key=lambda x: x, comparison=lambda x, y: x >= y): - return all(not comparison(key(el), key(lst[i])) for i, el in enumerate(lst[1:])) - - # candidates must be sorted, regardless if higher is better or not - assert is_sorted(candidates, key=lambda x: x.score) or is_sorted( - candidates, key=lambda x: x.score, comparison=lambda x, y: x <= y - ), "List of candidates must be sorted!" 
- - super().__init__(data_point, candidates[0].concept_id, candidates[0].score) - self.candidates = candidates - self.concept_name = self.candidates[0].concept_name - self.database_name = self.candidates[0].database_name - - def __str__(self): - return ( - f"{self.data_point.unlabeled_identifier}{flair._arrow} " - f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" - ) - - def __repr__(self): - return ( - f"{self.data_point.unlabeled_identifier}{flair._arrow} " - f"{self.concept_name} - {self.database_name}:{self._value} ({round(self._score, 4)})" - ) - - def __len__(self): - return len(self.data_point) - - def __eq__(self, other): - return ( - self.value == other.value - and self.data_point == other.data_point - and self.concept_name == other.concept_name - and self.identifier == other.identifier - and self.database_name == other.database_name - and self.score == other.score - ) - - @property - def identifier(self): - return f"{self.value}" - - DT = typing.TypeVar("DT", bound=DataPoint) DT2 = typing.TypeVar("DT2", bound=DataPoint) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 616513083..be93a734d 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -24,7 +24,7 @@ from tqdm import tqdm import flair -from flair.data import EntityLinkingCandidate, EntityLinkingLabel, Label, Sentence, Span +from flair.data import EntityLinkingCandidate, Label, Sentence, Span from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, @@ -47,10 +47,8 @@ f"You need to install faiss to run the biomedical entity linking: `pip install faiss-cpu=={FAISS_VERSION}`" ) from error - logger = logging.getLogger("flair") - PRETRAINED_DENSE_MODELS = [ "cambridgeltl/SapBERT-from-PubMedBERT-fulltext", ] @@ -74,19 +72,12 @@ MODELS = PRETRAINED_MODELS + STRING_MATCHING_MODELS -ENTITY_TYPES = ["disease", "chemical", "gene", "species"] - -ENTITY_TYPE_TO_LABELS = { - "disease": "diseases", - "gene": "genes", - "species": "species", - "chemical": "chemical", -} +ENTITY_TYPES = ["diseases", "chemical", "genes", "species"] ENTITY_TYPE_TO_HYBRID_MODEL = { - "disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease", + "diseases": "dmis-lab/biosyn-sapbert-bc5cdr-disease", "chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical", - "gene": "dmis-lab/biosyn-sapbert-bc2gn", + "genes": "dmis-lab/biosyn-sapbert-bc2gn", } # for now we always fall back to SapBERT, @@ -95,21 +86,13 @@ entity_type: "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" for entity_type in ENTITY_TYPES } - ENTITY_TYPE_TO_DICTIONARY = { - "gene": "ncbi-gene", + "genes": "ncbi-gene", "species": "ncbi-taxonomy", - "disease": "ctd-diseases", + "diseases": "ctd-diseases", "chemical": "ctd-chemicals", } -ENTITY_TYPE_TO_ANNOTATION_LAYER = { - "disease": "diseases", - "gene": "genes", - "chemical": "chemicals", - "species": "species", -} - BIOMEDICAL_DICTIONARIES: Dict[str, Type] = { "ctd-diseases": CTD_DISEASES_DICTIONARY, "ctd-chemicals": CTD_CHEMICALS_DICTIONARY, @@ -128,7 +111,6 @@ "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", } - DEFAULT_SPARSE_WEIGHT = 0.5 @@ -433,7 +415,7 @@ class BiomedicalEntityLinkingDictionary: """ def __init__( - self, reader: Union[AbstractBiomedicalEntityLinkingDictionary, ParsedBiomedicalEntityLinkingDictionary] + self, reader: AbstractBiomedicalEntityLinkingDictionary ): self.reader = reader @@ -478,6 +460,9 @@ def stream(self) -> Iterator[Tuple[str, str]]: """Stream entries from 
preprocessed dictionary.""" yield from self.reader.stream() + def __getitem__(self, item: str) -> EntityLinkingCandidate: + return self.reader[item] + class AbstractCandidateGenerator(ABC): """Base class for a candidate generator. @@ -486,7 +471,7 @@ class AbstractCandidateGenerator(ABC): """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. :param entity_mentions: Entity mentions @@ -498,7 +483,6 @@ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) """Get nice container with all info about entity linking candidate.""" concept_name = candidate[0] concept_id = candidate[1] - score = candidate[2] if "|" in concept_id: labels = concept_id.split("|") @@ -510,7 +494,6 @@ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) return EntityLinkingCandidate( concept_id=concept_id, concept_name=concept_name, - score=score, additional_ids=additional_labels, database_name=database_name, ) @@ -529,23 +512,16 @@ def load(cls, dictionary_name_or_path: Union[str, Path]) -> "ExactMatchCandidate """Compatibility function.""" return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) - def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: - candidates: List[List[EntityLinkingCandidate]] = [] + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + results: List[List[Tuple[str, float]]] = [] for mention in entity_mentions: dict_entry = self.name_to_id_index.get(mention) - if not dict_entry: - candidates.append([]) + if dict_entry is None: + results.append([]) continue + results.append([(dict_entry, 1.0)]) - candidates.append( - [ - self.build_candidate( - candidate=(mention, dict_entry, 1.0), database_name=self.dictionary.database_name - ) - ] - ) - - return candidates + return results class BigramTfIDFVectorizer: @@ -943,7 +919,7 @@ def combine_dense_and_sparse_results( return hybrid_scores, hybrid_ids - def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLinkingCandidate]]: + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. 
:param entity_mentions: Entity mentions @@ -962,12 +938,9 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[EntityLink sparse_ids=sparse_ids, top_k=top_k, ) - return [ [ - self.build_candidate( - candidate=self.dictionary_data[i] + (score,), database_name=self.dictionary.database_name - ) + (self.dictionary_data[i][1].split("|")[0], score) for i, score in zip(mention_ids, mention_scores) ] for mention_ids, mention_scores in zip(ids, scores) @@ -982,19 +955,24 @@ def __init__( candidate_generator: AbstractCandidateGenerator, preprocessor: AbstractEntityPreprocessor, entity_type: str, + label_type: str, ): self.preprocessor = preprocessor self.candidate_generator = candidate_generator self.entity_type = entity_type - self.annotation_layers = [ENTITY_TYPE_TO_ANNOTATION_LAYER.get(self.entity_type, "ner")] + self.annotation_layers = [self.entity_type] + self._label_type = label_type + + @property + def label_type(self): + return self._label_type def extract_mentions( self, sentences: List[Sentence], annotation_layers: Optional[List[str]] = None, - ) -> Tuple[List[int], List[Span], List[str], List[str]]: + ) -> Tuple[List[Span], List[str], List[str]]: """Unpack all mentions in sentences for batch search.""" - source = [] data_points = [] mentions = [] mention_annotation_layers = [] @@ -1002,10 +980,9 @@ def extract_mentions( # use default annotation layers only if are not provided annotation_layers = annotation_layers if annotation_layers is not None else self.annotation_layers - for i, sentence in enumerate(sentences): + for sentence in sentences: for annotation_layer in annotation_layers: for entity in sentence.get_labels(annotation_layer): - source.append(i) data_points.append(entity.data_point) mentions.append( self.preprocessor.process_mention(entity, sentence) @@ -1016,7 +993,7 @@ def extract_mentions( # assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" - return source, data_points, mentions, mention_annotation_layers + return data_points, mentions, mention_annotation_layers def predict( self, @@ -1037,7 +1014,7 @@ def predict( if self.preprocessor is not None: self.preprocessor.initialize(sentences) - source, data_points, mentions, mentions_annotation_layers = self.extract_mentions( + data_points, mentions, mentions_annotation_layers = self.extract_mentions( sentences=sentences, annotation_layers=annotation_layers ) @@ -1047,24 +1024,23 @@ def predict( candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) # Add a label annotation for each candidate - for i, data_point, mention_candidates, mentions_annotation_layer in zip( - source, data_points, candidates, mentions_annotation_layers + for data_point, mention_candidates, mentions_annotation_layer in zip( + data_points, candidates, mentions_annotation_layers ): - sentences[i].add_label( - typename=mentions_annotation_layer, - value_or_label=EntityLinkingLabel(data_point=data_point, candidates=mention_candidates), - ) + for (candidate_id, confidence) in mention_candidates: + data_point.add_label(self.label_type, candidate_id, confidence) @classmethod def load( cls, model_name_or_path: Union[str, Path], + label_type: str, dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, max_length: int = 25, batch_size: int = 1024, similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), + preprocessor: 
AbstractEntityPreprocessor = EntityPreprocessor(), force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, @@ -1114,7 +1090,12 @@ def load( "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, entity_type ) - return cls(candidate_generator=candidate_generator, preprocessor=preprocessor, entity_type=entity_type) + return cls( + candidate_generator=candidate_generator, + preprocessor=preprocessor, + entity_type=entity_type, + label_type=label_type, + ) @staticmethod def __get_model_path_and_entity_type( diff --git a/pyproject.toml b/pyproject.toml index e4986413c..62da09bfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ filterwarnings = [ "ignore:The class LayoutLMv3FeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use LayoutLMv3ImageProcessor instead.", # huggingface layoutlmv3 has deprecated calls. "ignore:pkg_resources", # huggingface has deprecated calls. 'ignore:Deprecated call to `pkg_resources', # huggingface has deprecated calls. + 'ignore:distutils Version classes are deprecated.', # faiss uses deprecated distutils. ] markers = [ "integration", diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index a285af5c7..ebd225dad 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,66 +1,77 @@ -# from flair.data import Sentence -# from flair.models.biomedical_entity_linking import ( -# BiomedicalEntityLinker, -# BiomedicalEntityLinkingDictionary, -# ) -# from flair.nn import Classifier +from flair.data import Sentence +from flair.models.biomedical_entity_linking import ( + BiomedicalEntityLinker, + BiomedicalEntityLinkingDictionary, +) +from flair.nn import Classifier -# def test_bel_dictionary(): -# """ -# Check data in dictionary is what we expect. -# Hard to define a good test as dictionaries are DYNAMIC, -# i.e. they can change over time -# """ +def test_bel_dictionary(): + """Check data in dictionary is what we expect. -# dictionary = BiomedicalEntityLinkingDictionary.load("disease") -# _, identifier = next(dictionary.stream()) -# assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) + Hard to define a good test as dictionaries are DYNAMIC, + i.e. they can change over time. 
+ """ + dictionary = BiomedicalEntityLinkingDictionary.load("diseases") + _, identifier = next(dictionary.stream()) + assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) -# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-disease") -# _, identifier = next(dictionary.stream()) -# assert identifier.startswith("MESH:") + dictionary = BiomedicalEntityLinkingDictionary.load("ctd-diseases") + _, identifier = next(dictionary.stream()) + assert identifier.startswith("MESH:") -# dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemical") -# _, identifier = next(dictionary.stream()) -# assert identifier.startswith("MESH:") + dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemicals") + _, identifier = next(dictionary.stream()) + assert identifier.startswith("MESH:") -# dictionary = BiomedicalEntityLinkingDictionary.load("chemical") -# _, identifier = next(dictionary.stream()) -# assert identifier.startswith("MESH:") + dictionary = BiomedicalEntityLinkingDictionary.load("chemical") + _, identifier = next(dictionary.stream()) + assert identifier.startswith("MESH:") -# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") -# _, identifier = next(dictionary.stream()) -# assert identifier.isdigit() + dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") + _, identifier = next(dictionary.stream()) + assert identifier.isdigit() -# dictionary = BiomedicalEntityLinkingDictionary.load("species") -# _, identifier = next(dictionary.stream()) -# assert identifier.isdigit() + dictionary = BiomedicalEntityLinkingDictionary.load("species") + _, identifier = next(dictionary.stream()) + assert identifier.isdigit() -# dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") -# _, identifier = next(dictionary.stream()) -# assert identifier.isdigit() + dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") + _, identifier = next(dictionary.stream()) + assert identifier.isdigit() -# dictionary = BiomedicalEntityLinkingDictionary.load("gene") -# _, identifier = next(dictionary.stream()) -# assert identifier.isdigit() + dictionary = BiomedicalEntityLinkingDictionary.load("genes") + _, identifier = next(dictionary.stream()) + assert identifier.isdigit() -# def test_biomedical_entity_linking(): +def test_biomedical_entity_linking(): + sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") -# sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + tagger = Classifier.load("hunflair") + tagger.predict(sentence) -# tagger = Classifier.load("hunflair") -# tagger.predict(sentence) + disease_linker = BiomedicalEntityLinker.load("diseases", "diseases-nel", hybrid_search=True) + disease_dictionary = disease_linker.candidate_generator.dictionary + disease_linker.predict(sentence) -# disease_linker = BiomedicalEntityLinker.load("disease", hybrid_search=True) -# disease_linker.predict(sentence) + gene_linker = BiomedicalEntityLinker.load("genes", "genes-nel", hybrid_search=False, entity_type="genes") + gene_dictionary = gene_linker.candidate_generator.dictionary -# gene_linker = BiomedicalEntityLinker.load("gene", hybrid_search=False) + gene_linker.predict(sentence) -# breakpoint() + print("Diseases") + for span in sentence.get_spans(disease_linker.entity_type): + print(f"Span: {span.text}") + for candidate_label in span.get_labels(disease_linker.label_type): + candidate = disease_dictionary[candidate_label.value] + print(f"Candidate: {candidate.concept_name}") + 
print("Genes") + for span in sentence.get_spans(gene_linker.entity_type): + print(f"Span: {span.text}") + for candidate_label in span.get_labels(gene_linker.label_type): + candidate = gene_dictionary[candidate_label.value] + print(f"Candidate: {candidate.concept_name}") -# if __name__ == "__main__": -# # test_bel_dictionary() -# test_biomedical_entity_linking() + breakpoint() # noqa: T100 From a374040f35c5a0052af8124f6b7c3cdedea7ffcf Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 4 Sep 2023 19:07:34 +0200 Subject: [PATCH 28/58] refactor knowledgebase datasets --- flair/data.py | 33 +- flair/datasets/__init__.py | 12 +- flair/datasets/knowledgebase.py | 425 ++++++++++++++++++ flair/models/biomedical_entity_linking.py | 29 +- .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 21 +- tests/test_biomedical_entity_linking.py | 6 +- 6 files changed, 484 insertions(+), 42 deletions(-) create mode 100644 flair/datasets/knowledgebase.py diff --git a/flair/data.py b/flair/data.py index 3721e6076..b52395df4 100644 --- a/flair/data.py +++ b/flair/data.py @@ -433,34 +433,45 @@ def __len__(self) -> int: raise NotImplementedError -class EntityLinkingCandidate: - """Represent a single candidate returned by a CandidateGenerator.""" +class Concept: + """A Concept as part of a knowledgebase or ontology.""" def __init__( self, concept_id: str, concept_name: str, database_name: str, - additional_ids: Optional[Union[List[str], str]] = None, + additional_ids: Optional[List[str]] = None, + synonyms: Optional[List[str]] = None, + description: Optional[str] = None, ): - """Represent a single candidate returned by a CandidateGenerator. + """A Concept as part of a knowledgebase or ontology. Args: - concept_id: Identifier of the entity / concept from the knowledge base / ontology - concept_name: (Canonical) name of the entity / concept from the knowledge base / ontology - score: Matching score of the entity / concept according to the entity mention + concept_id: Identifier of the concept from the knowledgebase / ontology + concept_name: (Canonical) name of the concept from the knowledgebase / ontology additional_ids: List of additional identifiers for the concept / entity in the KB / ontology - database_name: Name of the knowlege base / ontology + database_name: Name of the knowledgebase / ontology + synonyms: A list of synonyms for this entry + description: A description about the Concept to describe """ self.concept_id = concept_id self.concept_name = concept_name self.database_name = database_name - self.additional_ids = additional_ids + self.description = description + if additional_ids is None: + self.additional_ids = [] + else: + self.additional_ids = additional_ids + if synonyms is None: + self.synonyms = [] + else: + self.synonyms = synonyms def __str__(self) -> str: string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name}" - if self.additional_ids is not None: - string += f" - {self.additional_ids}" + if self.additional_ids: + string += f" - {'|'.join(self.additional_ids)}" return string def __repr__(self) -> str: diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index c008cb0dc..21493365d 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -7,6 +7,14 @@ StringDataset, ) +from .knowledgebase import ( + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, +) + + # Expose all biomedical data sets used for the evaluation of BioBERT # - # - @@ -37,8 +45,6 @@ CLL, CRAFT, 
CRAFT_V4, - CTD_CHEMICALS_DICTIONARY, - CTD_DISEASES_DICTIONARY, DECA, FSU, GELLUS, @@ -92,8 +98,6 @@ LOCTEXT, MIRNA, NCBI_DISEASE, - NCBI_GENE_HUMAN_DICTIONARY, - NCBI_TAXONOMY_DICTIONARY, OSIRIS, PDR, S800, diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py new file mode 100644 index 000000000..22c78c0ad --- /dev/null +++ b/flair/datasets/knowledgebase.py @@ -0,0 +1,425 @@ +import csv +from pathlib import Path +from typing import Dict, Iterable, Iterator, Optional, Tuple, Union + +import flair +from flair.data import Concept +from flair.file_utils import cached_path, unpack_file + + +class KnowledgebaseLinkingDictionary: + """Base class for downloading and reading of dictionaries for knowledgebase entity linking. + + A dictionary represents all entities of a knowledge base and their associated ids. + """ + + def __init__( + self, + candidates: Iterable[Concept], + dataset_name: Optional[str] = None, + ): + """Initialize the Knowledgebase linking dictionary. + + Args: + candidates: A iterable sequence of all Candidates contained in the knowledge base. + """ + # this dataset name + if dataset_name is None: + dataset_name = self.__class__.__name__.lower() + self._dataset_name = dataset_name + + candidates = list(candidates) + + self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} + self._text_to_index = { + text: candidate.concept_id + for candidate in candidates + for text in [candidate.concept_name, *candidate.synonyms] + } + + @property + def database_name(self) -> str: + """Name of the database represented by the dictionary.""" + return self._dataset_name + + @property + def text_to_index(self) -> Dict[str, str]: + return self._text_to_index + + def __getitem__(self, item: str) -> Concept: + return self._idx_to_candidates[item] + + +class HunerEntityLinkingDictionary(KnowledgebaseLinkingDictionary): + """Base dictionary with data already in huner format. + + Every line in the file must be formatted as follows: + + concept_id||concept_name + + If multiple concept ids are associated to a given name they have to be separated by a `|`, e.g. + + 7157||TP53|tumor protein p53 + """ + + def __init__(self, path: Path, dataset_name: str): + self.dataset_file = path + self._dataset_name = dataset_name + super().__init__(self._load_candidates(), dataset_name=dataset_name) + + def _load_candidates(self): + with open(self.dataset_file) as fp: + for line in fp: + line = line.strip() + if line == "": + continue + assert "||" in line, "Preprocessed EntityLinkingDictionary must have lines in the format: `cui||name`" + cui, name = line.split("||", 1) + name = name.lower() + cui, *additional_ids = cui.split("|") + yield Concept( + concept_id=cui, + concept_name=name, + database_name=self._dataset_name, + additional_ids=additional_ids, + ) + + +class CTD_DISEASES_DICTIONARY(KnowledgebaseLinkingDictionary): + """Dictionary for named entity linking on diseases using the Comparative Toxicogenomics Database (CTD). 
+ + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_file(data_file), dataset_name="CTD-DISEASES") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "CTD_diseases.tsv" + data_url = "https://ctdbase.org/reports/CTD_diseases.tsv.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file, keep=False) + + return result_file + + def parse_file(self, original_file: Path) -> Iterator[Concept]: + columns = [ + "symbol", + "identifier", + "alternative_identifiers", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + "slim_mappings", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") + + for row in reader: + identifier = row["identifier"] + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:C" and not additional_identifiers: + return None + + symbol = row["symbol"] + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] + definition = row["definition"] + + yield Concept( + concept_id=identifier, + concept_name=symbol, + database_name="CTD-DISEASES", + additional_ids=additional_identifiers, + synonyms=synonyms, + description=definition, + ) + + +class CTD_CHEMICALS_DICTIONARY(KnowledgebaseLinkingDictionary): + """Dictionary for named entity linking on chemicals using the Comparative Toxicogenomics Database (CTD). + + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_file(data_file), dataset_name="CTD-CHEMICALS") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "CTD_chemicals.tsv" + data_url = "https://ctdbase.org/reports/CTD_chemicals.tsv.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file) + + return result_file + + def parse_file(self, original_file: Path) -> Iterator[Concept]: + columns = [ + "symbol", + "identifier", + "casrn", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") + + for row in reader: + identifier = row["identifier"] + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:D013749": + # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
+ continue + + symbol = row["symbol"] + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != "" and s != symbol] + definition = row["definition"] + + yield Concept( + concept_id=identifier, + concept_name=symbol, + database_name="CTD-CHEMICALS", + additional_ids=additional_identifiers, + synonyms=synonyms, + description=definition, + ) + + +class NCBI_GENE_HUMAN_DICTIONARY(KnowledgebaseLinkingDictionary): + """Dictionary for named entity linking on diseases using the NCBI Gene ontology. + + Note that this dictionary only represents human genes - gene from different species + aren't included! + + Fur further information can be found at https://www.ncbi.nlm.nih.gov/gene/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-GENE-HUMAN") + + def _is_invalid_name(self, name: Optional[str]) -> bool: + """Determine if a name should be skipped.""" + if name is None: + return False + EMPTY_ENTRY_TEXT = [ + "when different from all specified ones in Gene.", + "Record to support submission of GeneRIFs for a gene not in Gene", + ] + + newentry = name == "NEWENTRY" + empty = name == "" + text_comment = any(e in name for e in EMPTY_ENTRY_TEXT) + + return any([newentry, empty, text_comment]) + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "Homo_sapiens.gene_info" + data_url = "https://ftp.ncbi.nih.gov/gene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file) + + return result_file + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + synonym_fields = ( + "Symbol_from_nomenclature_authority", + "Full_name_from_nomenclature_authority", + "description", + "Synonyms", + "Other_designations", + ) + field_names = [ + "tax_id", + "GeneID", + "Symbol", + "LocusTag", + "Synonyms", + "dbXrefs", + "chromosome", + "map_location", + "description", + "type_of_gene", + "Symbol_from_nomenclature_authority", + "Full_name_from_nomenclature_authority", + "Nomenclature_status", + "Other_designations", + "Modification_date", + "Feature_type", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=field_names, delimiter="\t") + + for row in reader: + identifier = row["GeneID"] + symbol = row["Symbol"] + + if self._is_invalid_name(symbol): + continue + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:D013749": + # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
+ continue + + synonyms = [] + for synonym_field in synonym_fields: + synonyms.extend([name.replace("'", "") for name in row.get(synonym_field, "").split("|")]) + synonyms = sorted([sym for sym in set(synonyms) if not self._is_invalid_name(sym)]) + + yield Concept( + concept_id=identifier, + concept_name=symbol, + database_name="NCBI-GENE-HUMAN", + additional_ids=additional_identifiers, + synonyms=synonyms, + ) + + +class NCBI_TAXONOMY_DICTIONARY(KnowledgebaseLinkingDictionary): + """Dictionary for named entity linking on organisms / species using the NCBI taxonomy ontology. + + Further information about the ontology can be found at https://www.ncbi.nlm.nih.gov/taxonomy + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-TAXONOMY") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "names.dmp" + data_url = "https://ftp.ncbi.nih.gov/pub/taxonomy/new_taxdump/new_taxdump.tar.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file) + + return result_file + + def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + ncbi_taxonomy_synset = [ + "genbank common name", + "common name", + "scientific name", + "equivalent name", + "synonym", + "acronym", + "blast name", + "genbank", + "genbank synonym", + "genbank acronym", + "includes", + "type material", + ] + main_field = "scientific name" + + with open(original_file, encoding="utf-8") as f: + curr_identifier = None + curr_synonyms = [] + curr_name = None + + for line in f: + # parse line + parsed_line = {} + elements = [e.strip() for e in line.strip().split("|")] + parsed_line["identifier"] = elements[0] + parsed_line["name"] = elements[1] if elements[2] == "" else elements[2] + parsed_line["field"] = elements[3] + + if parsed_line["name"] in ["all", "root"]: + continue + + if parsed_line["field"] in ["authority", "in-part", "type material"]: + continue + + if parsed_line["field"] not in ncbi_taxonomy_synset: + raise ValueError(f"Field {parsed_line['field']} unknown!") + + if curr_identifier is None: + curr_identifier = parsed_line["identifier"] + + if curr_identifier == parsed_line["identifier"]: + synonym = parsed_line["name"] + if parsed_line["field"] == main_field: + curr_name = synonym + else: + curr_synonyms.append(synonym) + elif curr_identifier != parsed_line["identifier"]: + assert curr_name is not None + yield Concept( + concept_id=curr_identifier, + concept_name=curr_name, + database_name="NCBI-TAXONOMY", + ) + + curr_identifier = parsed_line["identifier"] + curr_synonyms = [] + curr_name = None + synonym = parsed_line["name"] + if parsed_line["field"] == main_field: + curr_name = synonym + else: + curr_synonyms.append(synonym) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index be93a734d..19e766559 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -24,7 +24,7 @@ from tqdm import tqdm import flair -from flair.data import EntityLinkingCandidate, Label, Sentence, Span +from flair.data import Concept, Label, Sentence, Span from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, 
CTD_DISEASES_DICTIONARY, @@ -32,7 +32,7 @@ NCBI_TAXONOMY_DICTIONARY, ) from flair.datasets.biomedical import ( - AbstractBiomedicalEntityLinkingDictionary, + KnowledgebaseLinkingDictionary, ParsedBiomedicalEntityLinkingDictionary, ) from flair.embeddings import TransformerDocumentEmbeddings @@ -414,9 +414,7 @@ class BiomedicalEntityLinkingDictionary: 7157||TP53|tumor protein p53 """ - def __init__( - self, reader: AbstractBiomedicalEntityLinkingDictionary - ): + def __init__(self, reader: KnowledgebaseLinkingDictionary): self.reader = reader @classmethod @@ -460,7 +458,7 @@ def stream(self) -> Iterator[Tuple[str, str]]: """Stream entries from preprocessed dictionary.""" yield from self.reader.stream() - def __getitem__(self, item: str) -> EntityLinkingCandidate: + def __getitem__(self, item: str) -> Concept: return self.reader[item] @@ -479,7 +477,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, :result: List containing a list of entity linking candidates per entity mention from the input """ - def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> EntityLinkingCandidate: + def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> Concept: """Get nice container with all info about entity linking candidate.""" concept_name = candidate[0] concept_id = candidate[1] @@ -491,7 +489,7 @@ def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) else: additional_labels = None - return EntityLinkingCandidate( + return Concept( concept_id=concept_id, concept_name=concept_name, additional_ids=additional_labels, @@ -939,15 +937,12 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, top_k=top_k, ) return [ - [ - (self.dictionary_data[i][1].split("|")[0], score) - for i, score in zip(mention_ids, mention_scores) - ] + [(self.dictionary_data[i][1].split("|")[0], score) for i, score in zip(mention_ids, mention_scores)] for mention_ids, mention_scores in zip(ids, scores) ] -class BiomedicalEntityLinker: +class EntityMentionLinker: """Entity linking model for the biomedical domain.""" def __init__( @@ -967,6 +962,10 @@ def __init__( def label_type(self): return self._label_type + @property + def dictionary(self) -> KnowledgebaseLinkingDictionary: + return self.candidate_generator.dictionary + def extract_mentions( self, sentences: List[Sentence], @@ -1027,7 +1026,7 @@ def predict( for data_point, mention_candidates, mentions_annotation_layer in zip( data_points, candidates, mentions_annotation_layers ): - for (candidate_id, confidence) in mention_candidates: + for candidate_id, confidence in mention_candidates: data_point.add_label(self.label_type, candidate_id, confidence) @classmethod @@ -1045,7 +1044,7 @@ def load( sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, - ) -> "BiomedicalEntityLinker": + ) -> "EntityMentionLinker": """Loads a model for biomedical named entity normalization. See __init__ method for detailed docstring on arguments. 
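A compact usage sketch of the refined interface, assembled from the `load`, `predict`, `label_type`, `entity_type` and `dictionary` members shown above and from the updated test further below. The `"diseases-nel"` label type and the `top_k` value are arbitrary example choices, not fixed by the patch.

```python
# Sketch only: mirrors the updated test rather than the older tutorial calls.
from flair.data import Sentence
from flair.models.biomedical_entity_linking import EntityMentionLinker
from flair.nn import Classifier

sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")

# Run NER first so that the "diseases" annotation layer exists on the sentence.
Classifier.load("hunflair").predict(sentence)

# "diseases-nel" is an arbitrary label type under which candidate ids will be stored.
linker = EntityMentionLinker.load("diseases", "diseases-nel", hybrid_search=True)
linker.predict(sentence, top_k=3)

for span in sentence.get_spans(linker.entity_type):
    for label in span.get_labels(linker.label_type):
        concept = linker.dictionary[label.value]  # resolve the predicted id to a Concept
        print(span.text, "->", concept.concept_id, concept.concept_name, label.score)
```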
diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index b7c814570..973e6c328 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -1,8 +1,9 @@ # HunFlair Tutorial 3: Entity Linking -After adding named entity recognition tags to your sentence, you can run named entity linking on these annotations. +After adding named entity recognition tags to your sentence, you can run named entity linking on these annotations. + ```python -from flair.models.biomedical_entity_linking import BiomedicalEntityLinker +from flair.models.biomedical_entity_linking import EntityMentionLinker from flair.nn import Classifier from flair.tokenization import SciSpacyTokenizer from flair.data import Sentence @@ -17,16 +18,16 @@ sentence = Sentence( ner_tagger = Classifier.load("hunflair") ner_tagger.predict(sentence) -nen_tagger = BiomedicalEntityLinker.load("disease") +nen_tagger = EntityMentionLinker.load("disease") nen_tagger.predict(sentence) -nen_tagger = BiomedicalEntityLinker.load("gene") +nen_tagger = EntityMentionLinker.load("gene") nen_tagger.predict(sentence) -nen_tagger = BiomedicalEntityLinker.load("chemical") +nen_tagger = EntityMentionLinker.load("chemical") nen_tagger.predict(sentence) -nen_tagger = BiomedicalEntityLinker.load("species", entity_type="species") +nen_tagger = EntityMentionLinker.load("species", entity_type="species") nen_tagger.predict(sentence) for tag in sentence.get_labels(): @@ -50,10 +51,12 @@ a knowledge base or ontology. We have pre-configured combinations of models and "disease", "chemical" and "gene". You can also provide your own model and dictionary: + ```python -from flair.models.biomedical_entity_linking import BiomedicalEntityLinker +from flair.models.biomedical_entity_linking import EntityMentionLinker -nen_tagger = BiomedicalEntityLinker.load("name_or_path_to_your_model", dictionary_names_or_path="name_or_path_to_your_dictionary") -nen_tagger = BiomedicalEntityLinker.load("path_to_custom_disease_model", dictionary_names_or_path="disease") +nen_tagger = EntityMentionLinker.load("name_or_path_to_your_model", + dictionary_names_or_path="name_or_path_to_your_dictionary") +nen_tagger = EntityMentionLinker.load("path_to_custom_disease_model", dictionary_names_or_path="disease") ```` You can use any combination of provided models, provided dictionaries and your own. 
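For the custom-dictionary case mentioned at the end of the tutorial, the plain-text format read by `HunerEntityLinkingDictionary` (introduced in this patch series) looks roughly as follows. The identifiers and names below are invented purely for illustration, the file name is arbitrary, and the import path assumes the new `flair/datasets/knowledgebase.py` module.

```python
# Illustration only: a minimal custom dictionary file in the `concept_id||concept_name`
# format parsed by HunerEntityLinkingDictionary (additional ids are separated by `|`
# in front of the `||`).
from pathlib import Path

from flair.datasets.knowledgebase import HunerEntityLinkingDictionary

dictionary_file = Path("my_custom_dictionary.txt")
dictionary_file.write_text(
    "D001||example disease one\n"
    "D002|D003||example disease two\n"
)

dictionary = HunerEntityLinkingDictionary(path=dictionary_file, dataset_name="MY-DISEASES")
print(dictionary["D001"].concept_name)    # example disease one
print(dictionary["D002"].additional_ids)  # ['D003']
```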
diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index ebd225dad..6c9e9b8a8 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,7 +1,7 @@ from flair.data import Sentence from flair.models.biomedical_entity_linking import ( - BiomedicalEntityLinker, BiomedicalEntityLinkingDictionary, + EntityMentionLinker, ) from flair.nn import Classifier @@ -51,11 +51,11 @@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = BiomedicalEntityLinker.load("diseases", "diseases-nel", hybrid_search=True) + disease_linker = EntityMentionLinker.load("diseases", "diseases-nel", hybrid_search=True) disease_dictionary = disease_linker.candidate_generator.dictionary disease_linker.predict(sentence) - gene_linker = BiomedicalEntityLinker.load("genes", "genes-nel", hybrid_search=False, entity_type="genes") + gene_linker = EntityMentionLinker.load("genes", "genes-nel", hybrid_search=False, entity_type="genes") gene_dictionary = gene_linker.candidate_generator.dictionary gene_linker.predict(sentence) From 2c12f149299b458a168973140889157ab82b8266 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 11 Sep 2023 10:24:00 +0200 Subject: [PATCH 29/58] refactor CandidateSearchIndex --- flair/data.py | 1 + flair/datasets/__init__.py | 18 +- flair/datasets/knowledgebase.py | 21 +- flair/models/biomedical_entity_linking.py | 650 +++++++--------------- tests/test_biomedical_entity_linking.py | 36 +- 5 files changed, 245 insertions(+), 481 deletions(-) diff --git a/flair/data.py b/flair/data.py index b52395df4..e2be6162f 100644 --- a/flair/data.py +++ b/flair/data.py @@ -839,6 +839,7 @@ def __init__( # log a warning if the dataset is empty if text == "": log.warning("Warning: An empty Sentence was created! 
Are there empty strings in your dataset?") + breakpoint() @property def unlabeled_identifier(self): diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 21493365d..ba066545b 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -7,14 +7,6 @@ StringDataset, ) -from .knowledgebase import ( - CTD_CHEMICALS_DICTIONARY, - CTD_DISEASES_DICTIONARY, - NCBI_GENE_HUMAN_DICTIONARY, - NCBI_TAXONOMY_DICTIONARY, -) - - # Expose all biomedical data sets used for the evaluation of BioBERT # - # - @@ -156,6 +148,14 @@ WSD_WORDNET_GLOSS_TAGGED, ZELDA, ) +from .knowledgebase import ( + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, + HunerEntityLinkingDictionary, + KnowledgebaseLinkingDictionary, +) # Expose all relation extraction datasets from .ocr import SROIE, OcrJsonDataset @@ -323,6 +323,7 @@ "SentenceDataset", "MongoDataset", "StringDataset", + "KnowledgebaseLinkingDictionary", "AGNEWS", "ANAT_EM", "AZDZ", @@ -350,6 +351,7 @@ "FSU", "GELLUS", "GPRO", + "HunerEntityLinkingDictionary", "HUNER_CELL_LINE", "HUNER_CELL_LINE_CELL_FINDER", "HUNER_CELL_LINE_CLL", diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py index 22c78c0ad..2870e31c0 100644 --- a/flair/datasets/knowledgebase.py +++ b/flair/datasets/knowledgebase.py @@ -1,6 +1,6 @@ import csv from pathlib import Path -from typing import Dict, Iterable, Iterator, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union import flair from flair.data import Concept @@ -46,6 +46,10 @@ def database_name(self) -> str: def text_to_index(self) -> Dict[str, str]: return self._text_to_index + @property + def candidates(self) -> List[Concept]: + return list(self._idx_to_candidates.values()) + def __getitem__(self, item: str) -> Concept: return self._idx_to_candidates[item] @@ -97,6 +101,7 @@ def __init__( ): if base_path is None: base_path = flair.cache_root / "datasets" + base_path = Path(base_path) dataset_name = self.__class__.__name__.lower() @@ -166,6 +171,7 @@ def __init__( ): if base_path is None: base_path = flair.cache_root / "datasets" + base_path = Path(base_path) dataset_name = self.__class__.__name__.lower() @@ -238,6 +244,7 @@ def __init__( ): if base_path is None: base_path = flair.cache_root / "datasets" + base_path = Path(base_path) dataset_name = self.__class__.__name__.lower() @@ -251,6 +258,7 @@ def _is_invalid_name(self, name: Optional[str]) -> bool: """Determine if a name should be skipped.""" if name is None: return False + name = name.strip() EMPTY_ENTRY_TEXT = [ "when different from all specified ones in Gene.", "Record to support submission of GeneRIFs for a gene not in Gene", @@ -258,9 +266,10 @@ def _is_invalid_name(self, name: Optional[str]) -> bool: newentry = name == "NEWENTRY" empty = name == "" + minus = name == "-" text_comment = any(e in name for e in EMPTY_ENTRY_TEXT) - return any([newentry, empty, text_comment]) + return any([newentry, empty, minus, text_comment]) def download_dictionary(self, data_dir: Path) -> Path: result_file = data_dir / "Homo_sapiens.gene_info" @@ -272,7 +281,7 @@ def download_dictionary(self, data_dir: Path) -> Path: return result_file - def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + def parse_dictionary(self, original_file: Path) -> Iterator[Concept]: synonym_fields = ( "Symbol_from_nomenclature_authority", "Full_name_from_nomenclature_authority", @@ -318,6 +327,8 @@ def parse_dictionary(self, 
original_file: Path) -> Iterator[Tuple[str, str]]: for synonym_field in synonym_fields: synonyms.extend([name.replace("'", "") for name in row.get(synonym_field, "").split("|")]) synonyms = sorted([sym for sym in set(synonyms) if not self._is_invalid_name(sym)]) + if symbol in synonyms: + synonyms.remove(symbol) yield Concept( concept_id=identifier, @@ -340,7 +351,7 @@ def __init__( ): if base_path is None: base_path = flair.cache_root / "datasets" - + base_path = Path(base_path) dataset_name = self.__class__.__name__.lower() data_folder = base_path / dataset_name @@ -359,7 +370,7 @@ def download_dictionary(self, data_dir: Path) -> Path: return result_file - def parse_dictionary(self, original_file: Path) -> Iterator[Tuple[str, str]]: + def parse_dictionary(self, original_file: Path) -> Iterator[Concept]: ncbi_taxonomy_synset = [ "genbank common name", "common name", diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 19e766559..6d47932b4 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -28,14 +28,12 @@ from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, + HunerEntityLinkingDictionary, NCBI_GENE_HUMAN_DICTIONARY, NCBI_TAXONOMY_DICTIONARY, -) -from flair.datasets.biomedical import ( KnowledgebaseLinkingDictionary, - ParsedBiomedicalEntityLinkingDictionary, ) -from flair.embeddings import TransformerDocumentEmbeddings +from flair.embeddings import TransformerDocumentEmbeddings, DocumentTFIDFEmbeddings, DocumentEmbeddings from flair.file_utils import cached_path FAISS_VERSION = "1.7.4" @@ -117,8 +115,7 @@ class SimilarityMetric(Enum): """Similarity metrics.""" - INNER_PRODUCT = faiss.METRIC_INNER_PRODUCT - # L2 = faiss.METRIC_L2 + INNER_PRODUCT = auto() COSINE = auto() @@ -136,7 +133,7 @@ def wrap_func(*args, **kwargs): return wrap_func -class AbstractEntityPreprocessor(ABC): +class EntityPreprocessor(ABC): """A pre-processor used to transform / clean both entity mentions and entity names. This class provides the basic interface for such transformations @@ -178,7 +175,7 @@ def initialize(self, sentences: List[Sentence]): """ -class EntityPreprocessor(AbstractEntityPreprocessor): +class BioSynEntityPreprocessor(EntityPreprocessor): """Entity preprocessor using Synonym Marginalization. Adapted from: @@ -220,7 +217,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: return self.process_entity_name(entity_mention.data_point.text) -class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): +class Ab3PEntityPreprocessor(EntityPreprocessor): """Entity preprocessor which uses Ab3P, an (biomedical) abbreviation definition detector. Abbreviation definition identification based on automatic precision estimates. @@ -229,9 +226,7 @@ class Ab3PEntityPreprocessor(AbstractEntityPreprocessor): https://github.com/ncbi-nlp/Ab3P. """ - def __init__( - self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[AbstractEntityPreprocessor] = None - ) -> None: + def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[EntityPreprocessor] = None) -> None: """Creates the mention pre-processor. 
:param ab3p_path: Path to the folder containing the Ab3P implementation @@ -277,7 +272,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name @classmethod - def load(cls, ab3p_path: Optional[Path] = None, preprocessor: Optional[AbstractEntityPreprocessor] = None): + def load(cls, ab3p_path: Optional[Path] = None, preprocessor: Optional[EntityPreprocessor] = None): data_dir = flair.cache_root / "ab3p" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -419,9 +414,9 @@ def __init__(self, reader: KnowledgebaseLinkingDictionary): @classmethod def load( - cls, dictionary_name_or_path: Union[Path, str], database_name: Optional[str] = None - ) -> "BiomedicalEntityLinkingDictionary": - """Load dictionary: either pre-definded or from path.""" + cls, dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None + ) -> "KnowledgebaseLinkingDictionary": + """Load dictionary: either pre-defined or from path.""" if isinstance(dictionary_name_or_path, str): dictionary_name_or_path = cast(str, dictionary_name_or_path) @@ -430,7 +425,7 @@ def load( and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES ): raise ValueError( - f"Unkwnon dictionary `{dictionary_name_or_path}`!" + f"Unknown dictionary `{dictionary_name_or_path}`!" f" Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)}" " If you want to pass a local path please use the `Path` class, " "i.e. `model_name_or_path=Path(my_path)`" @@ -443,72 +438,67 @@ def load( else: # use custom dictionary file assert ( - database_name is not None - ), "When providing a path to a custom dictionary you must specify the `database_name`!" - reader = ParsedBiomedicalEntityLinkingDictionary(path=dictionary_name_or_path, database_name=database_name) + dataset_name is not None + ), "When providing a path to a custom dictionary you must specify the `dataset_name`!" + reader = HunerEntityLinkingDictionary(path=dictionary_name_or_path, dataset_name=dataset_name) - return cls(reader=reader) + return reader @property def database_name(self) -> str: """Database name of the dictionary.""" return self.reader.database_name - def stream(self) -> Iterator[Tuple[str, str]]: - """Stream entries from preprocessed dictionary.""" - yield from self.reader.stream() - def __getitem__(self, item: str) -> Concept: return self.reader[item] -class AbstractCandidateGenerator(ABC): +class CandidateSearchIndex(ABC): """Base class for a candidate generator. Given a mention of an entity, find matching entries from the dictionary. """ @abstractmethod - def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: - """Returns the top-k entity / concept identifiers for each entity mention. + def index( + self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None + ) -> None: + """Index a dictionary to prepare for search. - :param entity_mentions: Entity mentions - :param top_k: Number of best-matching entities from the knowledge base to return - :result: List containing a list of entity linking candidates per entity mention from the input + Args: + dictionary: The data to index. + preprocessor: If given, preprocess the concept name and synonyms before indexing. 
""" - def build_candidate(self, candidate: Tuple[str, str, float], database_name: str) -> Concept: - """Get nice container with all info about entity linking candidate.""" - concept_name = candidate[0] - concept_id = candidate[1] + @abstractmethod + def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: + """Returns the top-k entity / concept identifiers for each entity mention. - if "|" in concept_id: - labels = concept_id.split("|") - concept_id = labels[0] - additional_labels = labels[1:] - else: - additional_labels = None + Args: + entity_mentions: Entity mentions + top_k: Number of best-matching entities from the knowledge base to return - return Concept( - concept_id=concept_id, - concept_name=concept_name, - additional_ids=additional_labels, - database_name=database_name, - ) + Returns: + List containing a list of entity linking candidates per entity mention from the input + """ -class ExactMatchCandidateGenerator(AbstractCandidateGenerator): +class ExactMatchCandidateSearchIndex(CandidateSearchIndex): """Candidate generator using exact string matching as search criterion.""" - def __init__(self, dictionary: BiomedicalEntityLinkingDictionary): - # Build index which maps concept / entity names to concept / entity ids - self.dictionary = dictionary - self.name_to_id_index = dict(list(dictionary.stream())) + def __init__(self): + self.name_to_id_index: Dict[str, str] = {} - @classmethod - def load(cls, dictionary_name_or_path: Union[str, Path]) -> "ExactMatchCandidateGenerator": - """Compatibility function.""" - return cls(BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path)) + def index( + self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None + ) -> None: + def p(text: str) -> str: + return preprocessor.process_entity_name(text) if preprocessor is not None else text + + for candidate in dictionary.candidates: + self.name_to_id_index[p(candidate.concept_name)] = candidate.concept_id + for synonym in candidate.synonyms: + self.name_to_id_index[p(synonym)] = candidate.concept_id def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: results: List[List[Tuple[str, float]]] = [] @@ -522,424 +512,180 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, return results -class BigramTfIDFVectorizer: - """Wrapper for sklearn TfIDFVectorizer w/ fixed ngram range at the character level. 
- - Implementation adapted from: - - Sung et al.: Biomedical Entity Representations with Synonym Marginalization, 2020 - https://github.com/dmis-lab/BioSyn/tree/master/src/biosyn/sparse_encoder.py#L8 - """ - - def __init__(self): - self.encoder = TfidfVectorizer(analyzer="char", ngram_range=(1, 2)) - - def fit(self, names: List[str]): - """Learn vocabulary.""" - self.encoder.fit(names) - return self - - def transform(self, names: Union[List[str], np.ndarray]) -> csr_matrix: - """Convert strings to sparse vectors.""" - embeddings = self.encoder.transform(names) - return embeddings - - def __call__(self, mentions: Union[List[str], np.ndarray]) -> np.ndarray: - """Short for `transform`.""" - return self.transform(mentions) - - def save(self, path: Path): - """Save vectorizer to disk.""" - joblib.dump(self.encoder, str(path)) - - @classmethod - def load(cls, path: Union[Path, str]) -> "BigramTfIDFVectorizer": - """Instantiate from path.""" - newVectorizer = cls() - - # with open(path, "rb") as fin: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - newVectorizer.encoder = joblib.load(str(path)) - # logger.info("Sparse encoder loaded from %s", path) - - return newVectorizer - - -class BiEncoderCandidateGenerator(AbstractCandidateGenerator): +class SemanticCandidateSearchIndex(CandidateSearchIndex): """Candidate generator using both dense and (optionally) sparse vector representations, to search candidates.""" def __init__( self, - model_name_or_path: Union[str, Path], - dictionary_name_or_path: Union[str, Path], + embeddings: List[DocumentEmbeddings], similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - preprocessor: AbstractEntityPreprocessor = Ab3PEntityPreprocessor.load(preprocessor=EntityPreprocessor()), - max_length: int = 25, - batch_size: int = 1024, - hybrid_search: bool = False, - sparse_weight: Optional[float] = None, - force_hybrid_search: bool = False, - dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, + weights: Optional[List[float]] = None, + batch_size: int = 128, + show_progress: bool = True, ): - """Initializes the BiEncoderEntityRetrieverModel. - - :param model_name_or_path: Name of or path to the transformer model to be used. - :param dictionary_name_or_path: Name of or path to the transformer model to be used. 
- :param similarity_metric: which metric to use to compute similarity - :param preprocessor: Preprocessing strategy for entity mentions and names - :param max_length: Maximum number of input tokens to transformer model - :param batch_size: Number of entity mentions/names to embed in one forward pass - :param hybrid_search: Indicates whether to use sparse embeddings or not - :param sparse_weight: Weight to balance sparse and dense similarity scores (default sparse weight) - :param force_hybrid_search: if pre-trained model is not hybrid (dense+sparse) fit a sparse encoder - :param dictionary: optionally pass a custom dictionary - """ - self.model_name_or_path = model_name_or_path - self.dictionary_name_or_path = dictionary_name_or_path - self.preprocessor = preprocessor - self.similarity_metric = similarity_metric - self.max_length = max_length - self.batch_size = batch_size - self.hybrid_search = hybrid_search - self.sparse_weight = sparse_weight - self.force_hybrid_search = force_hybrid_search - if self.force_hybrid_search: - self.hybrid_search = True - - # allow to pass custom dictionary - if dictionary is not None: - self.dictionary = dictionary - else: - self.dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path) - - self.dictionary_data: List[Tuple[str, str]] = [ - (self.preprocessor.process_entity_name(name), cui) for name, cui in self.dictionary.stream() - ] - - # Load encoders - self.dense_encoder = TransformerDocumentEmbeddings(model=str(model_name_or_path), is_token_embedding=False) - self.sparse_encoder: Optional[BigramTfIDFVectorizer] = None + """Initializes the EncoderCandidateSearchIndex. - if self.hybrid_search: - self.sparse_encoder, self.sparse_weight = self._get_sparse_encoder_and_weight( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path - ) - - self.indices = self._load_indices( - model_name_or_path=model_name_or_path, - dictionary_name_or_path=dictionary_name_or_path, - ) - - @property - def higher_is_better(self): - """Determine if similarity is proportional to score. - - E.g. for L2 lower is better, while INNER_PRODUCT higher is better. + Args: + embeddings: A list of embeddings used for search. + weights: Weight the embedding's importance. + similarity_metric: The metric used to define similarity. + batch_size: The batch size used for indexing embeddings. + show_progress: show the progress while indexing. """ - return self.similarity_metric in [SimilarityMetric.COSINE, SimilarityMetric.INNER_PRODUCT] + if weights is None: + weights = [1.0 for _ in embeddings] + if len(weights) != len(embeddings): + raise ValueError("Weights have to be of the same length as embeddings") - def _get_cache_name(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> str: - """Fixed name for caching.""" - # Check for embedded dictionary in cache - dictionary_name = os.path.splitext(os.path.basename(dictionary_name_or_path))[0] - file_name = f"{str(model_name_or_path).split('/')[-1]}_{dictionary_name}" - pp_name = self.preprocessor.name if self.preprocessor is not None else "null" - - return f"{file_name}-{pp_name}" + self.embeddings = embeddings + self.weights = weights + self.similarity_metric = similarity_metric + self.show_progress = show_progress + self.batch_size = batch_size - @timeit - def fit_sparse_encoder(self) -> BigramTfIDFVectorizer: - """Fit sparse encoder to current dictionary.""" - logger.info( - "BiEncoderCandidateGenerator: hybrid model has no pretrained sparse encoder. 
Fit to dictionary `%s`", - self.dictionary_name_or_path, - ) - sparse_encoder = BigramTfIDFVectorizer().fit([name for name, cui in self.dictionary_data]) - # sparse_encoder.save(Path(sparse_encoder_path)) - # torch.save(torch.FloatTensor(self.sparse_weight), sparse_weight_path) + self.ids: List[str] = [] + self._precomputed_embeddings: np.ndarray = np.array([]) - return sparse_encoder + @classmethod + def bi_encoder( + cls, + model_name_or_path: str, + hybrid_search: bool, + similarity_metric: SimilarityMetric, + batch_size: int = 128, + show_progress: bool = True, + sparse_weight: float = 0.5, + preprocessor: Optional[EntityPreprocessor] = None, + dictionary: Optional[KnowledgebaseLinkingDictionary] = None, + ) -> "SemanticCandidateSearchIndex": + embeddings: List[DocumentEmbeddings] = [TransformerDocumentEmbeddings(model_name_or_path)] + weights = [1.0] + if hybrid_search: + if dictionary is None: + raise ValueError("Require dictionary to be set on hybrid search.") - def _handle_sparse_encoder( - self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] - ) -> BigramTfIDFVectorizer: - """If necessary fit and cache sparse encoder.""" - if isinstance(model_name_or_path, str): - cache_name = self._get_cache_name( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path - ) - path = flair.cache_root / "models" / f"{cache_name}-sparse-encoder.pk" - else: - path = model_name_or_path / "sparse_encoder.pk" + texts = [] - if path.exists(): - sparse_encoder = BigramTfIDFVectorizer.load(path) - else: - sparse_encoder = self.fit_sparse_encoder() - # logger.info("Save fitted sparse encoder to %s", path) - sparse_encoder.save(path) - - return sparse_encoder - - def _get_sparse_encoder_and_weight( - self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path] - ) -> Tuple[BigramTfIDFVectorizer, float]: - sparse_encoder_path = os.path.join(model_name_or_path, "sparse_encoder.pk") - sparse_weight_path = os.path.join(model_name_or_path, "sparse_weight.pt") - - if isinstance(model_name_or_path, str) and model_name_or_path in PRETRAINED_HYBRID_MODELS: - model_name_or_path = cast(str, model_name_or_path) - - if not os.path.exists(sparse_encoder_path): - sparse_encoder_path = hf_hub_download( - repo_id=model_name_or_path, - filename="sparse_encoder.pk", - cache_dir=flair.cache_root / "models" / model_name_or_path, - ) + for candidate in dictionary.candidates: + texts.append(candidate.concept_name) + texts.extend(candidate.synonyms) - sparse_encoder = BigramTfIDFVectorizer.load(path=sparse_encoder_path) + if preprocessor is not None: + texts = [preprocessor.process_entity_name(t) for t in texts] - if not os.path.exists(sparse_weight_path): - sparse_weight_path = hf_hub_download( - repo_id=model_name_or_path, - filename="sparse_weight.pt", - cache_dir=flair.cache_root / "models" / model_name_or_path, + embeddings.append( + DocumentTFIDFEmbeddings( + [Sentence(t) for t in texts], + analyzer="char", + ngram_range=(1, 2), ) - sparse_weight = torch.load(sparse_weight_path, map_location="cpu").item() - else: - sparse_weight = self.sparse_weight if self.sparse_weight is not None else DEFAULT_SPARSE_WEIGHT - sparse_encoder = self._handle_sparse_encoder( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) + weights = [1.0, sparse_weight] + return cls( + embeddings, + similarity_metric=similarity_metric, + weights=weights, + batch_size=batch_size, + show_progress=show_progress, + ) - return 
sparse_encoder, sparse_weight - - def embed_sparse(self, inputs: np.ndarray) -> np.ndarray: - """Create sparse embeddings from array of entity mentions/names. - - :param inputs: Numpy array of entity / concept names - :returns Numpy array containing the sparse embeddings of the names - """ - if self.sparse_encoder is None: - raise AssertionError("Error while using the model") - - return self.sparse_encoder(inputs) - - def embed_dense(self, inputs: np.ndarray, batch_size: int = 1024, show_progress: bool = False) -> np.ndarray: - """Create dense embeddings from array of entity mentions/names. + def index( + self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None + ) -> None: + def p(text: str) -> str: + return preprocessor.process_entity_name(text) if preprocessor is not None else text - :param names: Numpy array of entity / concept names - :param batch_size: Batch size used while embedding the name - :param show_progress: bool to toggle progress bar - :return: Numpy array containing the dense embeddings of the names - """ - self.dense_encoder.eval() # prevent dropout + texts: List[str] = [] + self.ids = [] + for candidate in dictionary.candidates: + texts.append(p(candidate.concept_name)) + self.ids.append(candidate.concept_id) + for synonym in candidate.synonyms: + texts.append(p(synonym)) + self.ids.append(candidate.concept_id) - dense_embeds = [] + precomputed_embeddings = [] with torch.no_grad(): - if show_progress: + if self.show_progress: iterations = tqdm( - range(0, len(inputs), batch_size), - desc=f"Embedding `{self.dictionary.database_name}`", + range(0, len(texts), self.batch_size), + desc=f"Embedding `{dictionary.database_name}`", ) else: - iterations = range(0, len(inputs), batch_size) + iterations = range(0, len(texts), self.batch_size) for start in iterations: - # Create batch - end = min(start + batch_size, len(inputs)) - batch = [Sentence(name) for name in inputs[start:end]] - - # embed batch - self.dense_encoder.embed(batch) - - dense_embeds += [name.embedding.cpu().detach().numpy() for name in batch] - + end = min(start + self.batch_size, len(texts)) + batch = [Sentence(name) for name in texts[start:end]] + + for embedding in self.embeddings: + embedding.embed(batch) + + for sent in batch: + embs = [] + for embedding, weight in zip(self.embeddings, self.weights): + emb = sent.get_embedding(embedding.get_names()) + if self.similarity_metric == SimilarityMetric.COSINE: + emb = emb / torch.norm(emb) + embs.append(emb * weight) + + precomputed_embeddings.append(torch.cat(embs, dim=0).cpu().numpy()) + sent.clear_embeddings() if flair.device.type == "cuda": torch.cuda.empty_cache() - return np.array(dense_embeds) - - # separate method to allow more sophisticated logic in the future, e.g.: ANN with HNSW, PQ... 
- def get_dense_index(self, names: np.ndarray, path: Path) -> faiss.Index: - """Load or create dense index and save it to disk.""" - if path.exists(): - index = faiss.read_index(str(path)) - - else: - embeddings = self.embed_dense(inputs=np.array(names), batch_size=self.batch_size, show_progress=True) - - index = faiss.IndexFlatIP(embeddings.shape[1]) - index.add(embeddings) - - if self.similarity_metric == SimilarityMetric.COSINE: - faiss.normalize_L2(embeddings) - - faiss.write_index(index, str(path)) - - return index - - def get_sparse_index(self, names: np.ndarray, path: Path) -> csr_matrix: - """Load or create sparse index and save it to disk.""" - if path.exists(): - index = scipy.sparse.load_npz(str(path)) - else: - index = self.embed_sparse(inputs=names) - - scipy.sparse.save_npz(str(path), index) - # index.save_index # HNSWLIB - # index.save # ANNOY - - return index - - def _load_indices(self, model_name_or_path: Union[str, Path], dictionary_name_or_path: Union[str, Path]) -> Dict: - """Load cached indices if available, otherwise compute embeddings, build index and cache.""" - cache_name = self._get_cache_name( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path - ) - - cache_folder = flair.cache_root / "datasets" / cache_name - cache_folder.mkdir(parents=True, exist_ok=True) - - indices = {} - - logger.info( - "BiEncoderCandidateGenerator: initialize %s %s", - self.dictionary.database_name, - "indices" if self.hybrid_search else "index", - ) - - for index_type in ["sparse", "dense"]: - if index_type == "sparse" and not self.hybrid_search: - continue - - extension = "bin" if index_type == "dense" else "npz" - file_name = f"index-{index_type}.{extension}" - - index_cache_file = cache_folder / file_name - - names = np.array([n for n, _ in self.dictionary_data]) - - if index_type == "dense": - indices[index_type] = self.get_dense_index(names=names, path=index_cache_file) - - else: - indices[index_type] = self.get_sparse_index(names=names, path=index_cache_file) - - return indices - - @timeit - def search_sparse(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: - """Find candidates with sparse representations. - - :param entity_mentions: list of entity mentions (~ queries) - :param top_k: number of candidates to retrieve per mention - """ - assert ( - self.sparse_encoder is not None - ), "BiEncoderCandidateGenerator has no `sparse_encoder`! Pass `force_hybrid_search=True` at initialization" - - mention_embeddings = self.sparse_encoder(entity_mentions) - - if self.similarity_metric == SimilarityMetric.COSINE: - score_matrix = cosine_similarity(mention_embeddings, self.indices["sparse"], dense_output=False) - elif self.similarity_metric == SimilarityMetric.INNER_PRODUCT: - score_matrix = mention_embeddings.dot(self.indices["sparse"].T) - - score_matrix = score_matrix.toarray() - - num_mentions = score_matrix.shape[0] - - unsorted_indices = np.argpartition(score_matrix, -top_k)[:, -top_k:] - unsorted_scores = score_matrix[np.arange(num_mentions)[:, None], unsorted_indices] - - sorted_score_matrix_indices = np.argsort(-unsorted_scores) - - idxs = unsorted_indices[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] - dists = unsorted_scores[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] - - return idxs, dists - - @timeit - def search_dense(self, entity_mentions: List[str], top_k: int = 1) -> Tuple[np.ndarray, np.ndarray]: - """Find candidates with dense representations (FAISS). 
- - :param entity_mentions: list of entity mentions (~ queries) - :param top_k: number of candidates to retrieve - """ - # Compute dense embedding for the given entity mention - mention_dense_embeds = self.embed_dense(inputs=np.array(entity_mentions), batch_size=self.batch_size) - - if self.similarity_metric == SimilarityMetric.COSINE: - faiss.normalize_L2(mention_dense_embeds) - - # Get candidates from dense embeddings - dists, ids = self.indices["dense"].search(mention_dense_embeds, top_k) - - return ids, dists - - def combine_dense_and_sparse_results( - self, - dense_ids: np.ndarray, - dense_scores: np.ndarray, - sparse_ids: np.ndarray, - sparse_scores: np.ndarray, - top_k: int = 1, - ): - """Expand dense results with sparse ones ans re-weight them. + self._precomputed_embeddings = np.stack(precomputed_embeddings, axis=0) - Re-weight the score as: dense_score + sparse_weight * sparse_scores. - """ - hybrid_ids = [] - hybrid_scores = [] - for i in range(dense_ids.shape[0]): - mention_ids = dense_ids[i] - mention_scores = dense_scores[i] - - mention_spare_ids = sparse_ids[i] - mention_sparse_scores = sparse_scores[i] - - for sparse_id, sparse_score in zip(mention_spare_ids, mention_sparse_scores): - if sparse_id not in mention_ids: - mention_ids = np.append(mention_ids, sparse_id) - mention_scores = np.append(mention_scores, self.sparse_weight * sparse_score) - else: - index = np.where(mention_ids == sparse_id)[0][0] - mention_scores[index] += self.sparse_weight * sparse_score + def emb_search(self, entity_mentions: List[str]) -> np.ndarray: + embeddings = [] - rerank_indices = np.argsort(-mention_scores if self.higher_is_better else mention_scores) - mention_ids = mention_ids[rerank_indices][:top_k] - mention_scores = mention_scores[rerank_indices][:top_k] - hybrid_ids.append(mention_ids.tolist()) - hybrid_scores.append(mention_scores.tolist()) + with torch.no_grad(): + for start in range(0, len(entity_mentions), self.batch_size): + end = min(start + self.batch_size, len(entity_mentions)) + batch = [Sentence(name) for name in entity_mentions[start:end]] + + for embedding in self.embeddings: + embedding.embed(batch) + + for sent in batch: + embs = [] + for embedding in self.embeddings: + emb = sent.get_embedding(embedding.get_names()) + if self.similarity_metric == SimilarityMetric.COSINE: + emb = emb / torch.norm(emb) + embs.append(emb) + + embeddings.append(torch.cat(embs, dim=0).cpu().numpy()) + sent.clear_embeddings() + if flair.device.type == "cuda": + torch.cuda.empty_cache() - return hybrid_scores, hybrid_ids + return np.stack(embeddings, axis=0) def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. 
- :param entity_mentions: Entity mentions - :param top_k: Number of best-matching entities from the knowledge base to return - :result: List containing a list of entity linking candidates per entity mention from the input + Args: + entity_mentions: Entity mentions + top_k: Number of best-matching entities from the knowledge base to return + + Returns: + List containing a list of entity linking candidates per entity mention from the input """ - ids, scores = self.search_dense(entity_mentions=entity_mentions, top_k=top_k) - if self.hybrid_search and self.sparse_encoder is not None: - sparse_ids, sparse_scores = self.search_sparse(entity_mentions=entity_mentions, top_k=top_k) + mention_embs = self.emb_search(entity_mentions) + all_scores = mention_embs @ self._precomputed_embeddings.T + selected_indices = np.argsort(all_scores, axis=1)[:, :top_k] + scores = np.take_along_axis(all_scores, selected_indices, axis=1) - scores, ids = self.combine_dense_and_sparse_results( - dense_ids=ids, - dense_scores=scores, - sparse_scores=sparse_scores, - sparse_ids=sparse_ids, - top_k=top_k, + results = [] + for i in range(selected_indices.shape[0]): + results.append( + [(self.ids[selected_indices[i, j]], float(scores[i, j])) for j in range(selected_indices.shape[1])] ) - return [ - [(self.dictionary_data[i][1].split("|")[0], score) for i, score in zip(mention_ids, mention_scores)] - for mention_ids, mention_scores in zip(ids, scores) - ] + + return results class EntityMentionLinker: @@ -947,16 +693,18 @@ class EntityMentionLinker: def __init__( self, - candidate_generator: AbstractCandidateGenerator, - preprocessor: AbstractEntityPreprocessor, + candidate_generator: CandidateSearchIndex, + preprocessor: EntityPreprocessor, entity_type: str, label_type: str, + dictionary: KnowledgebaseLinkingDictionary, ): self.preprocessor = preprocessor self.candidate_generator = candidate_generator self.entity_type = entity_type self.annotation_layers = [self.entity_type] self._label_type = label_type + self._dictionary = dictionary @property def label_type(self): @@ -964,7 +712,7 @@ def label_type(self): @property def dictionary(self) -> KnowledgebaseLinkingDictionary: - return self.candidate_generator.dictionary + return self._dictionary def extract_mentions( self, @@ -1036,14 +784,14 @@ def load( label_type: str, dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, - max_length: int = 25, - batch_size: int = 1024, + batch_size: int = 128, similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - preprocessor: AbstractEntityPreprocessor = EntityPreprocessor(), + preprocessor: EntityPreprocessor = BioSynEntityPreprocessor(), force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, - dictionary: Optional[BiomedicalEntityLinkingDictionary] = None, + dictionary: Optional[KnowledgebaseLinkingDictionary] = None, + dataset_name: Optional[str] = None, ) -> "EntityMentionLinker": """Loads a model for biomedical named entity normalization. 
@@ -1053,10 +801,12 @@ def load( raise AssertionError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") model_name_or_path = cast(str, model_name_or_path) - if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): - dictionary_name_or_path = cls.__get_dictionary_path( - model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path - ) + if dictionary is None: + if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): + dictionary_name_or_path = cls.__get_dictionary_path( + model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path + ) + dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path, dataset_name=dataset_name) if isinstance(model_name_or_path, str): model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( @@ -1070,21 +820,20 @@ def load( assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! Must be one of: {ENTITY_TYPES}" if model_name_or_path == "exact-string-match": - candidate_generator: AbstractCandidateGenerator = ExactMatchCandidateGenerator.load(dictionary_name_or_path) + candidate_generator: CandidateSearchIndex = ExactMatchCandidateSearchIndex() else: - candidate_generator = BiEncoderCandidateGenerator( - model_name_or_path=model_name_or_path, - dictionary_name_or_path=dictionary_name_or_path, + candidate_generator = SemanticCandidateSearchIndex.bi_encoder( + model_name_or_path=str(model_name_or_path), hybrid_search=hybrid_search, similarity_metric=similarity_metric, - max_length=max_length, batch_size=batch_size, sparse_weight=sparse_weight, preprocessor=preprocessor, - force_hybrid_search=force_hybrid_search, dictionary=dictionary, ) + candidate_generator.index(dictionary, preprocessor) + logger.info( "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, entity_type ) @@ -1094,6 +843,7 @@ def load( preprocessor=preprocessor, entity_type=entity_type, label_type=label_type, + dictionary=dictionary, ) @staticmethod diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 6c9e9b8a8..a87f2703c 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -13,36 +13,36 @@ def test_bel_dictionary(): i.e. they can change over time. 
""" dictionary = BiomedicalEntityLinkingDictionary.load("diseases") - _, identifier = next(dictionary.stream()) - assert identifier.startswith(("MESH:", "OMIM:", "DO:DOID")) + candidate = dictionary.candidates[0] + assert candidate.concept_id.startswith(("MESH:", "OMIM:", "DO:DOID")) dictionary = BiomedicalEntityLinkingDictionary.load("ctd-diseases") - _, identifier = next(dictionary.stream()) - assert identifier.startswith("MESH:") + candidate = dictionary.candidates[0] + assert candidate.concept_id.startswith("MESH:") dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemicals") - _, identifier = next(dictionary.stream()) - assert identifier.startswith("MESH:") + candidate = dictionary.candidates[0] + assert candidate.concept_id.startswith("MESH:") dictionary = BiomedicalEntityLinkingDictionary.load("chemical") - _, identifier = next(dictionary.stream()) - assert identifier.startswith("MESH:") + candidate = dictionary.candidates[0] + assert candidate.concept_id.startswith("MESH:") dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") - _, identifier = next(dictionary.stream()) - assert identifier.isdigit() + candidate = dictionary.candidates[0] + assert candidate.concept_id.isdigit() dictionary = BiomedicalEntityLinkingDictionary.load("species") - _, identifier = next(dictionary.stream()) - assert identifier.isdigit() + candidate = dictionary.candidates[0] + assert candidate.concept_id.isdigit() dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") - _, identifier = next(dictionary.stream()) - assert identifier.isdigit() + candidate = dictionary.candidates[0] + assert candidate.concept_id.isdigit() dictionary = BiomedicalEntityLinkingDictionary.load("genes") - _, identifier = next(dictionary.stream()) - assert identifier.isdigit() + candidate = dictionary.candidates[0] + assert candidate.concept_id.isdigit() def test_biomedical_entity_linking(): @@ -52,11 +52,11 @@ def test_biomedical_entity_linking(): tagger.predict(sentence) disease_linker = EntityMentionLinker.load("diseases", "diseases-nel", hybrid_search=True) - disease_dictionary = disease_linker.candidate_generator.dictionary + disease_dictionary = disease_linker.dictionary disease_linker.predict(sentence) gene_linker = EntityMentionLinker.load("genes", "genes-nel", hybrid_search=False, entity_type="genes") - gene_dictionary = gene_linker.candidate_generator.dictionary + gene_dictionary = gene_linker.dictionary gene_linker.predict(sentence) From e398213079c53f502c6be783163d32f7aca9cebe Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 18 Sep 2023 15:11:02 +0200 Subject: [PATCH 30/58] fix rebase errors --- flair/data.py | 1 - flair/datasets/knowledgebase.py | 2 +- flair/models/biomedical_entity_linking.py | 23 +++-------------------- tests/test_datasets_biomedical.py | 7 +++---- 4 files changed, 7 insertions(+), 26 deletions(-) diff --git a/flair/data.py b/flair/data.py index e2be6162f..b52395df4 100644 --- a/flair/data.py +++ b/flair/data.py @@ -839,7 +839,6 @@ def __init__( # log a warning if the dataset is empty if text == "": log.warning("Warning: An empty Sentence was created! 
Are there empty strings in your dataset?") - breakpoint() @property def unlabeled_identifier(self): diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py index 2870e31c0..493107fcf 100644 --- a/flair/datasets/knowledgebase.py +++ b/flair/datasets/knowledgebase.py @@ -1,6 +1,6 @@ import csv from pathlib import Path -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Optional, Union import flair from flair.data import Concept diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 6d47932b4..a5e9fa159 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -6,21 +6,14 @@ import subprocess import tempfile import time -import warnings from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum, auto from pathlib import Path -from typing import Dict, Iterator, List, Optional, Tuple, Type, Union, cast +from typing import Dict, List, Optional, Tuple, Type, Union, cast -import joblib import numpy as np -import scipy import torch -from huggingface_hub import hf_hub_download -from scipy.sparse import csr_matrix -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm import flair @@ -28,23 +21,14 @@ from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, - HunerEntityLinkingDictionary, NCBI_GENE_HUMAN_DICTIONARY, NCBI_TAXONOMY_DICTIONARY, + HunerEntityLinkingDictionary, KnowledgebaseLinkingDictionary, ) -from flair.embeddings import TransformerDocumentEmbeddings, DocumentTFIDFEmbeddings, DocumentEmbeddings +from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings from flair.file_utils import cached_path -FAISS_VERSION = "1.7.4" - -try: - import faiss -except ImportError as error: - raise ImportError( - f"You need to install faiss to run the biomedical entity linking: `pip install faiss-cpu=={FAISS_VERSION}`" - ) from error - logger = logging.getLogger("flair") PRETRAINED_DENSE_MODELS = [ @@ -673,7 +657,6 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, Returns: List containing a list of entity linking candidates per entity mention from the input """ - mention_embs = self.emb_search(entity_mentions) all_scores = mention_embs @ self._precomputed_embeddings.T selected_indices = np.argsort(all_scores, axis=1)[:, :top_k] diff --git a/tests/test_datasets_biomedical.py b/tests/test_datasets_biomedical.py index fbff952ae..2ba4cf3d6 100644 --- a/tests/test_datasets_biomedical.py +++ b/tests/test_datasets_biomedical.py @@ -182,7 +182,7 @@ def assert_conll_writer_output( assert contents == expected_output -def test_filter_nested_entities(recwarn): +def test_filter_nested_entities(caplog): entities_per_document = { "d0": [Entity((0, 1), "t0"), Entity((2, 3), "t1")], "d1": [Entity((0, 6), "t0"), Entity((2, 3), "t1"), Entity((4, 5), "t2")], @@ -204,11 +204,10 @@ def test_filter_nested_entities(recwarn): } dataset = InternalBioNerDataset(documents={}, entities_per_document=entities_per_document) + caplog.set_level(logging.WARNING) filter_nested_entities(dataset) - assert len(recwarn.list) == 1 - assert isinstance(recwarn.list[0].message, UserWarning) - assert "Corpus modified by filtering nested entities." 
in recwarn.list[0].message.args[0] + assert "WARNING: Corpus modified by filtering nested entities." in caplog.text for key, entities in dataset.entities_per_document.items(): assert key in target From 51a1a57fb163b60b39df4c33ff7f275b2c8b4601 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 18 Sep 2023 17:54:47 +0200 Subject: [PATCH 31/58] WIP: add save functionality --- flair/class_utils.py | 19 ++ flair/datasets/knowledgebase.py | 4 +- flair/models/biomedical_entity_linking.py | 309 +++++++++--------- flair/nn/model.py | 14 +- .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 14 +- tests/test_biomedical_entity_linking.py | 26 +- 6 files changed, 201 insertions(+), 185 deletions(-) create mode 100644 flair/class_utils.py diff --git a/flair/class_utils.py b/flair/class_utils.py new file mode 100644 index 000000000..842a53387 --- /dev/null +++ b/flair/class_utils.py @@ -0,0 +1,19 @@ +import inspect +from typing import Iterable, Optional, Type, TypeVar + +T = TypeVar("T") + + +def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: + for subclass in cls.__subclasses__(): + yield from get_non_abstract_subclasses(subclass) + if inspect.isabstract(subclass): + continue + yield subclass + + +def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]: + for sub_cls in get_non_abstract_subclasses(cls): + if sub_cls.__name__ == cls_name: + return sub_cls + raise ValueError(f"Could not find any class with name '{cls_name}'") diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py index 493107fcf..142893c30 100644 --- a/flair/datasets/knowledgebase.py +++ b/flair/datasets/knowledgebase.py @@ -66,8 +66,8 @@ class HunerEntityLinkingDictionary(KnowledgebaseLinkingDictionary): 7157||TP53|tumor protein p53 """ - def __init__(self, path: Path, dataset_name: str): - self.dataset_file = path + def __init__(self, path: Union[str, Path], dataset_name: str): + self.dataset_file = Path(path) self._dataset_name = dataset_name super().__init__(self._load_candidates(), dataset_name=dataset_name) diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index a5e9fa159..6a7c8220f 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -1,3 +1,4 @@ +import inspect import logging import os import re @@ -5,19 +6,19 @@ import string import subprocess import tempfile -import time from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum, auto from pathlib import Path -from typing import Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import numpy as np import torch from tqdm import tqdm import flair -from flair.data import Concept, Label, Sentence, Span +from flair.class_utils import get_state_subclass_by_name +from flair.data import Label, Sentence, Span from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, @@ -96,67 +97,77 @@ DEFAULT_SPARSE_WEIGHT = 0.5 -class SimilarityMetric(Enum): - """Similarity metrics.""" +def load_dictionary( + dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None +) -> KnowledgebaseLinkingDictionary: + """Load dictionary: either pre-defined or from path.""" + if isinstance(dictionary_name_or_path, str) and ( + dictionary_name_or_path in ENTITY_TYPE_TO_DICTIONARY or dictionary_name_or_path in BIOMEDICAL_DICTIONARIES + ): + dictionary_name_or_path = 
ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) - INNER_PRODUCT = auto() - COSINE = auto() + return BIOMEDICAL_DICTIONARIES[str(dictionary_name_or_path)]() + if dataset_name is None: + raise ValueError("When loading a custom dictionary, you need to specify a dataset_name!") + return HunerEntityLinkingDictionary(path=dictionary_name_or_path, dataset_name=dataset_name) -def timeit(func): - """This function shows the execution time of the function object passed.""" - def wrap_func(*args, **kwargs): - start = time.time() - result = func(*args, **kwargs) - elapsed = round(time.time() - start, 4) - class_name, func_name = func.__qualname__.split(".") - logger.info("%s: %s took ~%s", class_name, func_name, elapsed) - return result +class SimilarityMetric(Enum): + """Similarity metrics.""" - return wrap_func + INNER_PRODUCT = auto() + COSINE = auto() class EntityPreprocessor(ABC): - """A pre-processor used to transform / clean both entity mentions and entity names. + """A pre-processor used to transform / clean both entity mentions and entity names.""" - This class provides the basic interface for such transformations - and must provide a `name` attribute to uniquely identify the type of preprocessing applied. - """ + def initialize(self, sentences: List[Sentence]) -> None: + """Initializes the pre-processor for a batch of sentences. - @property - @abstractmethod - def name(self) -> str: - """This is needed to correctly cache different multiple version of the dictionary.""" + This may be necessary for more sophisticated transformations. + + Args: + sentences: List of sentences that will be processed. + """ - @abstractmethod def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: """Processes the given entity mention and applies the transformation procedure to it. - :param entity_mention: entity mention under investigation - :param sentence: sentence in which the entity mentioned occurred - :result: Cleaned / transformed string representation of the given entity mention + Usually just forwards the entity_mention to :meth:`EntityPreprocessor.process_entity_name`, but can be implemented + to preprocess mentions on a sentence level instead. + + Args: + entity_mention: entity mention under investigation + sentence: sentence in which the entity mentioned occurred + + Returns: + Cleaned / transformed string representation of the given entity mention """ + return self.process_entity_name(entity_mention.data_point.text) @abstractmethod def process_entity_name(self, entity_name: str) -> str: """Processes the given entity name and applies the transformation procedure to it. Args: - entity_name: entity mention given as DataPoint + entity_name: the text of the entity mention + Returns: Cleaned / transformed string representation of the given entity mention """ - @abstractmethod - def initialize(self, sentences: List[Sentence]): - """Initializes the pre-processor for a batch of sentences. + @classmethod + def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": + if inspect.isabstract(cls): + cls_name = state_dict.pop("__cls__", None) + return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict) + else: + return cls(**state_dict) - This may be necessary for more sophisticated transformations. - - Args: - sentences: List of sentences that will be processed. 
- """ + def _get_state(self) -> Dict[str, Any]: + return {"__cls__": self.__class__.__name__} class BioSynEntityPreprocessor(EntityPreprocessor): @@ -173,20 +184,14 @@ class BioSynEntityPreprocessor(EntityPreprocessor): def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): """Initializes the mention preprocessor. - :param lowercase: Indicates whether to perform lowercasing or not (True by default) - :param remove_punctuation: Indicates whether to perform removal punctuations symbols (True by default) + Args: + lowercase: Indicates whether to perform lowercasing or not + remove_punctuation: Indicates whether to perform removal punctuations symbols """ self.lowercase = lowercase self.remove_punctuation = remove_punctuation self.rmv_puncts_regex = re.compile(rf"[\s{re.escape(string.punctuation)}]+") - @property - def name(self): - return "biosyn" - - def initialize(self, sentences): - pass - def process_entity_name(self, entity_name: str) -> str: if self.lowercase: entity_name = entity_name.lower() @@ -197,8 +202,12 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name.strip() - def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: - return self.process_entity_name(entity_mention.data_point.text) + def _get_state(self) -> Dict[str, Any]: + return { + **super()._get_state(), + "lowercase": self.lowercase, + "remove_punctuation": self.remove_punctuation, + } class Ab3PEntityPreprocessor(EntityPreprocessor): @@ -213,19 +222,16 @@ class Ab3PEntityPreprocessor(EntityPreprocessor): def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[EntityPreprocessor] = None) -> None: """Creates the mention pre-processor. - :param ab3p_path: Path to the folder containing the Ab3P implementation - :param word_data_dir: Path to the word data directory - :param preprocessor: Basic entity preprocessor + Args: + ab3p_path: Path to the folder containing the Ab3P implementation + word_data_dir: Path to the word data directory + preprocessor: Basic entity preprocessor """ self.ab3p_path = ab3p_path self.word_data_dir = word_data_dir self.preprocessor = preprocessor self.abbreviation_dict: Dict[str, Dict[str, str]] = {} - @property - def name(self): - return f"ab3p_{self.preprocessor.name}" - def initialize(self, sentences: List[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) @@ -256,8 +262,8 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name @classmethod - def load(cls, ab3p_path: Optional[Path] = None, preprocessor: Optional[EntityPreprocessor] = None): - data_dir = flair.cache_root / "ab3p" + def load_biosyn(cls, preprocessor: Optional[EntityPreprocessor] = None): + data_dir = flair.cache_root / "ab3p_biosyn" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -265,13 +271,12 @@ def load(cls, ab3p_path: Optional[Path] = None, preprocessor: Optional[EntityPre if not word_data_dir.exists(): word_data_dir.mkdir() - if ab3p_path is None: - ab3p_path = cls.download_ab3p(data_dir, word_data_dir) + ab3p_path = cls._download_biosyn_ab3p(data_dir, word_data_dir) return cls(ab3p_path, word_data_dir, preprocessor) @classmethod - def download_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: + def _download_biosyn_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: """Downloads the Ab3P tool and all necessary data files.""" # Download word data for Ab3P if not already downloaded ab3p_url = 
"https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" @@ -317,8 +322,11 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict {"RSV": "Rous sarcoma virus"} } - :param sentences: list of sentences - :result abbreviation_dict: abbreviations and their resolution detected in each input sentence + Args: + sentences: list of sentences + + Returns: + abbreviation_dict: abbreviations and their resolution detected in each input sentence """ abbreviation_dict: Dict = defaultdict(dict) @@ -379,62 +387,23 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict return abbreviation_dict - -class BiomedicalEntityLinkingDictionary: - """Class to load named entity dictionaries. - - Loading either pre-defined or from a path on disk. - For the latter, every line in the file must be formatted as follows: - - concept_id||concept_name - - If multiple concept ids are associated to a given name they must be separated by a `|`, e.g. - - 7157||TP53|tumor protein p53 - """ - - def __init__(self, reader: KnowledgebaseLinkingDictionary): - self.reader = reader + def _get_state(self) -> Dict[str, Any]: + return { + **super()._get_state(), + "ab3p_path": str(self.ab3p_path), + "word_data_dir": str(self.word_data_dir), + "preprocessor": None if self.preprocessor is None else self.preprocessor._get_state(), + } @classmethod - def load( - cls, dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None - ) -> "KnowledgebaseLinkingDictionary": - """Load dictionary: either pre-defined or from path.""" - if isinstance(dictionary_name_or_path, str): - dictionary_name_or_path = cast(str, dictionary_name_or_path) - - if ( - dictionary_name_or_path not in ENTITY_TYPE_TO_DICTIONARY - and dictionary_name_or_path not in BIOMEDICAL_DICTIONARIES - ): - raise ValueError( - f"Unknown dictionary `{dictionary_name_or_path}`!" - f" Available dictionaries are: {tuple(BIOMEDICAL_DICTIONARIES)}" - " If you want to pass a local path please use the `Path` class, " - "i.e. `model_name_or_path=Path(my_path)`" - ) - - dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY.get(dictionary_name_or_path, dictionary_name_or_path) - - reader = BIOMEDICAL_DICTIONARIES[str(dictionary_name_or_path)]() - - else: - # use custom dictionary file - assert ( - dataset_name is not None - ), "When providing a path to a custom dictionary you must specify the `dataset_name`!" 
-            reader = HunerEntityLinkingDictionary(path=dictionary_name_or_path, dataset_name=dataset_name)
-
-        return reader
-
-    @property
-    def database_name(self) -> str:
-        """Database name of the dictionary."""
-        return self.reader.database_name
-
-    def __getitem__(self, item: str) -> Concept:
-        return self.reader[item]
+    @classmethod
+    def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor":
+        return cls(
+            ab3p_path=Path(state_dict["ab3p_path"]),
+            word_data_dir=Path(state_dict["word_data_dir"]),
+            preprocessor=None
+            if state_dict["preprocessor"] is None
+            else EntityPreprocessor._from_state(state_dict["preprocessor"]),
+        )


 class CandidateSearchIndex(ABC):
@@ -466,11 +435,27 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,
             List containing a list of entity linking candidates per entity mention from the input
         """

+    @classmethod
+    def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex":
+        if inspect.isabstract(cls):
+            cls_name = state_dict.pop("__cls__", None)
+            return get_state_subclass_by_name(cls, cls_name)._from_state(state_dict)
+        else:
+            return cls(**state_dict)
+
+    def _get_state(self) -> Dict[str, Any]:
+        return {"__cls__": self.__class__.__name__}
+

 class ExactMatchCandidateSearchIndex(CandidateSearchIndex):
     """Candidate generator using exact string matching as search criterion."""

     def __init__(self):
+        """Candidate generator using exact string matching as search criterion.
+
+        The internal name-to-id index is populated via :meth:`index` or
+        restored from a saved state when loading an initialized index.
+        """
         self.name_to_id_index: Dict[str, str] = {}

     def index(
@@ -495,6 +480,18 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,

         return results

+    @classmethod
+    def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex":
+        index = cls()
+        index.name_to_id_index = state_dict["name_to_id_index"]
+        return index
+
+    def _get_state(self) -> Dict[str, Any]:
+        return {
+            **super()._get_state(),
+            "name_to_id_index": self.name_to_id_index,
+        }
+

 class SemanticCandidateSearchIndex(CandidateSearchIndex):
     """Candidate generator using both dense and (optionally) sparse vector representations, to search candidates."""
@@ -670,6 +667,31 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str,

         return results

+    @classmethod
+    def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex":
+        index = cls(
+            embeddings=[DocumentEmbeddings.load_embedding(emb) for emb in state_dict["embeddings"]],
+            similarity_metric=SimilarityMetric(state_dict["similarity_metric"]),
+            weights=state_dict["weights"],
+            batch_size=state_dict["batch_size"],
+            show_progress=state_dict["show_progress"],
+        )
+        index.ids = state_dict["ids"]
+        index._precomputed_embeddings = state_dict["precomputed_embeddings"]
+        return index
+
+    def _get_state(self) -> Dict[str, Any]:
+        return {
+            **super()._get_state(),
+            "embeddings": [emb.save_embeddings() for emb in self.embeddings],
+            "similarity_metric": self.similarity_metric.value,
+            "weights": self.weights,
+            "batch_size": self.batch_size,
+            "show_progress": self.show_progress,
+            "ids": self.ids,
+            "precomputed_embeddings": self._precomputed_embeddings,
+        }
+

 class EntityMentionLinker:
     """Entity linking model for the biomedical domain."""
@@ -678,14 +700,13 @@ def __init__(
         self,
         candidate_generator: CandidateSearchIndex,
         preprocessor: EntityPreprocessor,
-        entity_type: str,
+        entity_label_type: str,
         label_type: str,
         dictionary: KnowledgebaseLinkingDictionary,
     ):
         self.preprocessor
= preprocessor self.candidate_generator = candidate_generator - self.entity_type = entity_type - self.annotation_layers = [self.entity_type] + self.entity_label_type = entity_label_type self._label_type = label_type self._dictionary = dictionary @@ -700,42 +721,32 @@ def dictionary(self) -> KnowledgebaseLinkingDictionary: def extract_mentions( self, sentences: List[Sentence], - annotation_layers: Optional[List[str]] = None, - ) -> Tuple[List[Span], List[str], List[str]]: + ) -> Tuple[List[Span], List[str]]: """Unpack all mentions in sentences for batch search.""" data_points = [] mentions = [] - mention_annotation_layers = [] - - # use default annotation layers only if are not provided - annotation_layers = annotation_layers if annotation_layers is not None else self.annotation_layers for sentence in sentences: - for annotation_layer in annotation_layers: - for entity in sentence.get_labels(annotation_layer): - data_points.append(entity.data_point) - mentions.append( - self.preprocessor.process_mention(entity, sentence) - if self.preprocessor is not None - else entity.data_point.text, - ) - mention_annotation_layers.append(annotation_layer) - - # assert len(mentions) > 0, f"There are no entity mentions of type `{self.entity_type}`" + for entity in sentence.get_labels(self.entity_label_type): + data_points.append(entity.data_point) + mentions.append( + self.preprocessor.process_mention(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text, + ) - return data_points, mentions, mention_annotation_layers + return data_points, mentions def predict( self, sentences: Union[List[Sentence], Sentence], - annotation_layers: Optional[List[str]] = None, top_k: int = 1, ) -> None: """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. 
- :param sentences: One or more sentences to run the prediction on - :param annotation_layers: List of annotation layers to extract entity mentions - :param top_k: Number of best-matching entity / concept identifiers + Args: + sentences: One or more sentences to run the prediction on + top_k: Number of best-matching entity / concept identifiers """ # make sure sentences is a list of sentences if not isinstance(sentences, list): @@ -744,9 +755,7 @@ def predict( if self.preprocessor is not None: self.preprocessor.initialize(sentences) - data_points, mentions, mentions_annotation_layers = self.extract_mentions( - sentences=sentences, annotation_layers=annotation_layers - ) + data_points, mentions = self.extract_mentions(sentences=sentences) # no mentions: nothing to do here if len(mentions) > 0: @@ -754,14 +763,12 @@ def predict( candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) # Add a label annotation for each candidate - for data_point, mention_candidates, mentions_annotation_layer in zip( - data_points, candidates, mentions_annotation_layers - ): + for data_point, mention_candidates in zip(data_points, candidates): for candidate_id, confidence in mention_candidates: data_point.add_label(self.label_type, candidate_id, confidence) @classmethod - def load( + def build( cls, model_name_or_path: Union[str, Path], label_type: str, @@ -789,7 +796,7 @@ def load( dictionary_name_or_path = cls.__get_dictionary_path( model_name_or_path=model_name_or_path, dictionary_name_or_path=dictionary_name_or_path ) - dictionary = BiomedicalEntityLinkingDictionary.load(dictionary_name_or_path, dataset_name=dataset_name) + dictionary = load_dictionary(dictionary_name_or_path, dataset_name=dataset_name) if isinstance(model_name_or_path, str): model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( @@ -824,7 +831,7 @@ def load( return cls( candidate_generator=candidate_generator, preprocessor=preprocessor, - entity_type=entity_type, + entity_label_type=entity_type, label_type=label_type, dictionary=dictionary, ) diff --git a/flair/nn/model.py b/flair/nn/model.py index b339f1994..96b2c2d92 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -13,6 +13,7 @@ from tqdm import tqdm import flair +from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset from flair.embeddings import Embeddings @@ -137,7 +138,7 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model": # if this class is abstract, go through all inheriting classes and try to fetch and load the model if inspect.isabstract(cls): # get all non-abstract subclasses - subclasses = get_non_abstract_subclasses(cls) + subclasses = list(get_non_abstract_subclasses(cls)) # try to fetch the model for each subclass. 
if fetching is possible, load model and return it for model_cls in subclasses: @@ -976,14 +977,3 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "DefaultClassifie from typing import cast return cast("DefaultClassifier", super().load(model_path=model_path)) - - -def get_non_abstract_subclasses(cls): - all_subclasses = [] - for subclass in cls.__subclasses__(): - all_subclasses.extend(get_non_abstract_subclasses(subclass)) - if inspect.isabstract(subclass): - continue - all_subclasses.append(subclass) - - return all_subclasses diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index 973e6c328..6e7d9790c 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -18,16 +18,16 @@ sentence = Sentence( ner_tagger = Classifier.load("hunflair") ner_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.load("disease") +nen_tagger = EntityMentionLinker.build("disease") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.load("gene") +nen_tagger = EntityMentionLinker.build("gene") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.load("chemical") +nen_tagger = EntityMentionLinker.build("chemical") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.load("species", entity_type="species") +nen_tagger = EntityMentionLinker.build("species", entity_type="species") nen_tagger.predict(sentence) for tag in sentence.get_labels(): @@ -55,8 +55,8 @@ You can also provide your own model and dictionary: ```python from flair.models.biomedical_entity_linking import EntityMentionLinker -nen_tagger = EntityMentionLinker.load("name_or_path_to_your_model", - dictionary_names_or_path="name_or_path_to_your_dictionary") -nen_tagger = EntityMentionLinker.load("path_to_custom_disease_model", dictionary_names_or_path="disease") +nen_tagger = EntityMentionLinker.build("name_or_path_to_your_model", + dictionary_names_or_path="name_or_path_to_your_dictionary") +nen_tagger = EntityMentionLinker.build("path_to_custom_disease_model", dictionary_names_or_path="disease") ```` You can use any combination of provided models, provided dictionaries and your own. diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index a87f2703c..b4b4475ec 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,7 +1,7 @@ from flair.data import Sentence from flair.models.biomedical_entity_linking import ( - BiomedicalEntityLinkingDictionary, EntityMentionLinker, + load_dictionary, ) from flair.nn import Classifier @@ -12,35 +12,35 @@ def test_bel_dictionary(): Hard to define a good test as dictionaries are DYNAMIC, i.e. they can change over time. 
""" - dictionary = BiomedicalEntityLinkingDictionary.load("diseases") + dictionary = load_dictionary("diseases") candidate = dictionary.candidates[0] assert candidate.concept_id.startswith(("MESH:", "OMIM:", "DO:DOID")) - dictionary = BiomedicalEntityLinkingDictionary.load("ctd-diseases") + dictionary = load_dictionary("ctd-diseases") candidate = dictionary.candidates[0] assert candidate.concept_id.startswith("MESH:") - dictionary = BiomedicalEntityLinkingDictionary.load("ctd-chemicals") + dictionary = load_dictionary("ctd-chemicals") candidate = dictionary.candidates[0] assert candidate.concept_id.startswith("MESH:") - dictionary = BiomedicalEntityLinkingDictionary.load("chemical") + dictionary = load_dictionary("chemical") candidate = dictionary.candidates[0] assert candidate.concept_id.startswith("MESH:") - dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-taxonomy") + dictionary = load_dictionary("ncbi-taxonomy") candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() - dictionary = BiomedicalEntityLinkingDictionary.load("species") + dictionary = load_dictionary("species") candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() - dictionary = BiomedicalEntityLinkingDictionary.load("ncbi-gene") + dictionary = load_dictionary("ncbi-gene") candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() - dictionary = BiomedicalEntityLinkingDictionary.load("genes") + dictionary = load_dictionary("genes") candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() @@ -51,24 +51,24 @@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = EntityMentionLinker.load("diseases", "diseases-nel", hybrid_search=True) + disease_linker = EntityMentionLinker.build("diseases", "diseases-nel", hybrid_search=True) disease_dictionary = disease_linker.dictionary disease_linker.predict(sentence) - gene_linker = EntityMentionLinker.load("genes", "genes-nel", hybrid_search=False, entity_type="genes") + gene_linker = EntityMentionLinker.build("genes", "genes-nel", hybrid_search=False, entity_type="genes") gene_dictionary = gene_linker.dictionary gene_linker.predict(sentence) print("Diseases") - for span in sentence.get_spans(disease_linker.entity_type): + for span in sentence.get_spans(disease_linker.entity_label_type): print(f"Span: {span.text}") for candidate_label in span.get_labels(disease_linker.label_type): candidate = disease_dictionary[candidate_label.value] print(f"Candidate: {candidate.concept_name}") print("Genes") - for span in sentence.get_spans(gene_linker.entity_type): + for span in sentence.get_spans(gene_linker.entity_label_type): print(f"Span: {span.text}") for candidate_label in span.get_labels(gene_linker.label_type): candidate = gene_dictionary[candidate_label.value] From ce85e3d8dbd9be204c6cebab9d3ef46daf2bff1b Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 2 Oct 2023 15:25:55 +0200 Subject: [PATCH 32/58] add load & save functionality --- flair/data.py | 10 ++++++++ flair/datasets/knowledgebase.py | 23 ++++++++++++++++- flair/models/biomedical_entity_linking.py | 31 ++++++++++++++++++++++- 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/flair/data.py b/flair/data.py index b52395df4..0f6fb69a2 100644 --- a/flair/data.py +++ b/flair/data.py @@ -477,6 +477,16 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + def to_dict(self) -> Dict[str, Any]: + return { + "concept_id": self.concept_id, + "concept_name": 
self.concept_name, + "database_name": self.database_name, + "additional_ids": self.additional_ids, + "synonyms": self.synonyms, + "description": self.description, + } + DT = typing.TypeVar("DT", bound=DataPoint) DT2 = typing.TypeVar("DT2", bound=DataPoint) diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py index 142893c30..2863dae39 100644 --- a/flair/datasets/knowledgebase.py +++ b/flair/datasets/knowledgebase.py @@ -1,6 +1,6 @@ import csv from pathlib import Path -from typing import Dict, Iterable, Iterator, List, Optional, Union +from typing import Dict, Iterable, Iterator, List, Optional, Union, Any import flair from flair.data import Concept @@ -53,6 +53,27 @@ def candidates(self) -> List[Concept]: def __getitem__(self, item: str) -> Concept: return self._idx_to_candidates[item] + def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": + return InMemoryEntityLinkingDictionary(list(self._idx_to_candidates.values()), self._dataset_name) + + +class InMemoryEntityLinkingDictionary(KnowledgebaseLinkingDictionary): + def __init__(self, candidates: List[Concept], dataset_name: str): + self._dataset_name = dataset_name + super().__init__(candidates, dataset_name=dataset_name) + + def to_state(self) -> Dict[str, Any]: + return { + "dataset_name": self._dataset_name, + "candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()], + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary": + return cls( + dataset_name=state["dataset_name"], candidates=[Concept(**candidate) for candidate in state["candidates"]] + ) + class HunerEntityLinkingDictionary(KnowledgebaseLinkingDictionary): """Base dictionary with data already in huner format. diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/biomedical_entity_linking.py index 6a7c8220f..3cc37d52e 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/biomedical_entity_linking.py @@ -27,8 +27,9 @@ HunerEntityLinkingDictionary, KnowledgebaseLinkingDictionary, ) +from flair.datasets.knowledgebase import InMemoryEntityLinkingDictionary from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings -from flair.file_utils import cached_path +from flair.file_utils import cached_path, load_torch_state logger = logging.getLogger("flair") @@ -767,6 +768,34 @@ def predict( for candidate_id, confidence in mention_candidates: data_point.add_label(self.label_type, candidate_id, confidence) + @staticmethod + def _fetch_model(model_name: str) -> str: + if Path(model_name).exists(): + return model_name + + raise NotImplementedError() + + def save(self, model_path: Union[str, Path]) -> None: + pass + + @classmethod + def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker": + if isinstance(model_path, str): + model_path = cls._fetch_model(model_path) + + if isinstance(model_path, dict): + state = model_path + else: + state = load_torch_state(str(model_path)) + + candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) + preprocessor = EntityPreprocessor._from_state("entity_preprocessor") + entity_label_type = state["entity_label_type"] + label_type = state["label_type"] + dictionary = InMemoryEntityLinkingDictionary.from_state(state["dictionary"]) + + return cls(candidate_generator, preprocessor, entity_label_type, label_type, dictionary) + @classmethod def build( cls, From 750ba7d27d029502662078d3faebd6c8f1c37216 Mon Sep 
17 00:00:00 2001 From: Benedikt Fuchs Date: Sat, 23 Dec 2023 23:14:02 +0100 Subject: [PATCH 33/58] fix naming --- .gitignore | 1 + flair/data.py | 4 +- flair/datasets/__init__.py | 14 +- flair/datasets/entity_linking.py | 454 ++++++++++++++++- flair/datasets/knowledgebase.py | 457 ------------------ flair/models/__init__.py | 2 + ...y_linking.py => entity_mention_linking.py} | 90 ++-- 7 files changed, 516 insertions(+), 506 deletions(-) delete mode 100644 flair/datasets/knowledgebase.py rename flair/models/{biomedical_entity_linking.py => entity_mention_linking.py} (93%) diff --git a/.gitignore b/.gitignore index 86261f980..53e248b6d 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,4 @@ venv.bak/ resources/taggers/ regression_train/ +/doc_build/ diff --git a/flair/data.py b/flair/data.py index 0f6fb69a2..1089230d5 100644 --- a/flair/data.py +++ b/flair/data.py @@ -433,7 +433,7 @@ def __len__(self) -> int: raise NotImplementedError -class Concept: +class EntityCandidate: """A Concept as part of a knowledgebase or ontology.""" def __init__( @@ -477,7 +477,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> Dict[str, typing.Any]: return { "concept_id": self.concept_id, "concept_name": self.concept_name, diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index ba066545b..56c7e4dd0 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -133,6 +133,10 @@ # word sense disambiguation # Expose all entity linking datasets from .entity_linking import ( + CTD_CHEMICALS_DICTIONARY, + CTD_DISEASES_DICTIONARY, + NCBI_GENE_HUMAN_DICTIONARY, + NCBI_TAXONOMY_DICTIONARY, NEL_ENGLISH_AIDA, NEL_ENGLISH_AQUAINT, NEL_ENGLISH_IITB, @@ -147,14 +151,8 @@ WSD_UFSAC, WSD_WORDNET_GLOSS_TAGGED, ZELDA, -) -from .knowledgebase import ( - CTD_CHEMICALS_DICTIONARY, - CTD_DISEASES_DICTIONARY, - NCBI_GENE_HUMAN_DICTIONARY, - NCBI_TAXONOMY_DICTIONARY, + EntityLinkingDictionary, HunerEntityLinkingDictionary, - KnowledgebaseLinkingDictionary, ) # Expose all relation extraction datasets @@ -323,7 +321,7 @@ "SentenceDataset", "MongoDataset", "StringDataset", - "KnowledgebaseLinkingDictionary", + "EntityLinkingDictionary", "AGNEWS", "ANAT_EM", "AZDZ", diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index a515e0f3a..a73ded8f6 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -2,12 +2,12 @@ import logging import os from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union import requests import flair -from flair.data import Corpus, MultiCorpus, Sentence +from flair.data import Corpus, EntityCandidate, MultiCorpus, Sentence from flair.datasets.sequence_labeling import ColumnCorpus, MultiFileColumnCorpus from flair.file_utils import cached_path, unpack_file from flair.splitter import SegtokSentenceSplitter, SentenceSplitter @@ -15,6 +15,456 @@ log = logging.getLogger("flair") +class EntityLinkingDictionary: + """Base class for downloading and reading of dictionaries for entity entity linking. + + A dictionary represents all entities of a knowledge base and their associated ids. + """ + + def __init__( + self, + candidates: Iterable[EntityCandidate], + dataset_name: Optional[str] = None, + ): + """Initialize the entity linking dictionary. + + Args: + candidates: A iterable sequence of all Candidates contained in the knowledge base. 
+ """ + # this dataset name + if dataset_name is None: + dataset_name = self.__class__.__name__.lower() + self._dataset_name = dataset_name + + candidates = list(candidates) + + self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} + self._text_to_index = { + text: candidate.concept_id + for candidate in candidates + for text in [candidate.concept_name, *candidate.synonyms] + } + + @property + def database_name(self) -> str: + """Name of the database represented by the dictionary.""" + return self._dataset_name + + @property + def text_to_index(self) -> Dict[str, str]: + return self._text_to_index + + @property + def candidates(self) -> List[EntityCandidate]: + return list(self._idx_to_candidates.values()) + + def __getitem__(self, item: str) -> EntityCandidate: + return self._idx_to_candidates[item] + + def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": + return InMemoryEntityLinkingDictionary(list(self._idx_to_candidates.values()), self._dataset_name) + + +class InMemoryEntityLinkingDictionary(EntityLinkingDictionary): + def __init__(self, candidates: List[EntityCandidate], dataset_name: str): + self._dataset_name = dataset_name + super().__init__(candidates, dataset_name=dataset_name) + + def to_state(self) -> Dict[str, Any]: + return { + "dataset_name": self._dataset_name, + "candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()], + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary": + return cls( + dataset_name=state["dataset_name"], + candidates=[EntityCandidate(**candidate) for candidate in state["candidates"]], + ) + + +class HunerEntityLinkingDictionary(EntityLinkingDictionary): + """Base dictionary with data already in huner format. + + Every line in the file must be formatted as follows: + + concept_id||concept_name + + If multiple concept ids are associated to a given name they have to be separated by a `|`, e.g. + + 7157||TP53|tumor protein p53 + """ + + def __init__(self, path: Union[str, Path], dataset_name: str): + self.dataset_file = Path(path) + self._dataset_name = dataset_name + super().__init__(self._load_candidates(), dataset_name=dataset_name) + + def _load_candidates(self): + with open(self.dataset_file) as fp: + for line in fp: + line = line.strip() + if line == "": + continue + assert "||" in line, "Preprocessed EntityLinkingDictionary must have lines in the format: `cui||name`" + cui, name = line.split("||", 1) + name = name.lower() + cui, *additional_ids = cui.split("|") + yield EntityCandidate( + concept_id=cui, + concept_name=name, + database_name=self._dataset_name, + additional_ids=additional_ids, + ) + + +class CTD_DISEASES_DICTIONARY(EntityLinkingDictionary): + """Dictionary for named entity linking on diseases using the Comparative Toxicogenomics Database (CTD). 
+ + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + base_path = Path(base_path) + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_file(data_file), dataset_name="CTD-DISEASES") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "CTD_diseases.tsv" + data_url = "https://ctdbase.org/reports/CTD_diseases.tsv.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file, keep=False) + + return result_file + + def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: + columns = [ + "symbol", + "identifier", + "alternative_identifiers", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + "slim_mappings", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") + + for row in reader: + identifier = row["identifier"] + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:C" and not additional_identifiers: + return None + + symbol = row["symbol"] + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] + definition = row["definition"] + + yield EntityCandidate( + concept_id=identifier, + concept_name=symbol, + database_name="CTD-DISEASES", + additional_ids=additional_identifiers, + synonyms=synonyms, + description=definition, + ) + + +class CTD_CHEMICALS_DICTIONARY(EntityLinkingDictionary): + """Dictionary for named entity linking on chemicals using the Comparative Toxicogenomics Database (CTD). + + Fur further information can be found at https://ctdbase.org/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + base_path = Path(base_path) + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_file(data_file), dataset_name="CTD-CHEMICALS") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "CTD_chemicals.tsv" + data_url = "https://ctdbase.org/reports/CTD_chemicals.tsv.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file) + + return result_file + + def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: + columns = [ + "symbol", + "identifier", + "casrn", + "definition", + "parent_identifiers", + "tree_numbers", + "parent_tree_numbers", + "synonyms", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") + + for row in reader: + identifier = row["identifier"] + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:D013749": + # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
+ continue + + symbol = row["symbol"] + + synonyms = [s for s in row.get("synonyms", "").split("|") if s != "" and s != symbol] + definition = row["definition"] + + yield EntityCandidate( + concept_id=identifier, + concept_name=symbol, + database_name="CTD-CHEMICALS", + additional_ids=additional_identifiers, + synonyms=synonyms, + description=definition, + ) + + +class NCBI_GENE_HUMAN_DICTIONARY(EntityLinkingDictionary): + """Dictionary for named entity linking on diseases using the NCBI Gene ontology. + + Note that this dictionary only represents human genes - gene from different species + aren't included! + + Fur further information can be found at https://www.ncbi.nlm.nih.gov/gene/ + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + base_path = Path(base_path) + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-GENE-HUMAN") + + def _is_invalid_name(self, name: Optional[str]) -> bool: + """Determine if a name should be skipped.""" + if name is None: + return False + name = name.strip() + EMPTY_ENTRY_TEXT = [ + "when different from all specified ones in Gene.", + "Record to support submission of GeneRIFs for a gene not in Gene", + ] + + newentry = name == "NEWENTRY" + empty = name == "" + minus = name == "-" + text_comment = any(e in name for e in EMPTY_ENTRY_TEXT) + + return any([newentry, empty, minus, text_comment]) + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "Homo_sapiens.gene_info" + data_url = "https://ftp.ncbi.nih.gov/gene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file) + + return result_file + + def parse_dictionary(self, original_file: Path) -> Iterator[EntityCandidate]: + synonym_fields = ( + "Symbol_from_nomenclature_authority", + "Full_name_from_nomenclature_authority", + "description", + "Synonyms", + "Other_designations", + ) + field_names = [ + "tax_id", + "GeneID", + "Symbol", + "LocusTag", + "Synonyms", + "dbXrefs", + "chromosome", + "map_location", + "description", + "type_of_gene", + "Symbol_from_nomenclature_authority", + "Full_name_from_nomenclature_authority", + "Nomenclature_status", + "Other_designations", + "Modification_date", + "Feature_type", + ] + + with open(original_file, encoding="utf-8") as f: + reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=field_names, delimiter="\t") + + for row in reader: + identifier = row["GeneID"] + symbol = row["Symbol"] + + if self._is_invalid_name(symbol): + continue + additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] + + if identifier == "MESH:D013749": + # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
+ continue + + synonyms = [] + for synonym_field in synonym_fields: + synonyms.extend([name.replace("'", "") for name in row.get(synonym_field, "").split("|")]) + synonyms = sorted([sym for sym in set(synonyms) if not self._is_invalid_name(sym)]) + if symbol in synonyms: + synonyms.remove(symbol) + + yield EntityCandidate( + concept_id=identifier, + concept_name=symbol, + database_name="NCBI-GENE-HUMAN", + additional_ids=additional_identifiers, + synonyms=synonyms, + ) + + +class NCBI_TAXONOMY_DICTIONARY(EntityLinkingDictionary): + """Dictionary for named entity linking on organisms / species using the NCBI taxonomy ontology. + + Further information about the ontology can be found at https://www.ncbi.nlm.nih.gov/taxonomy + """ + + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + ): + if base_path is None: + base_path = flair.cache_root / "datasets" + base_path = Path(base_path) + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + + data_file = self.download_dictionary(data_folder) + + super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-TAXONOMY") + + def download_dictionary(self, data_dir: Path) -> Path: + result_file = data_dir / "names.dmp" + data_url = "https://ftp.ncbi.nih.gov/pub/taxonomy/new_taxdump/new_taxdump.tar.gz" + + if not result_file.exists(): + data_path = cached_path(data_url, data_dir) + unpack_file(data_path, unpack_to=result_file.parent) + + return result_file + + def parse_dictionary(self, original_file: Path) -> Iterator[EntityCandidate]: + ncbi_taxonomy_synset = [ + "genbank common name", + "common name", + "scientific name", + "equivalent name", + "synonym", + "acronym", + "blast name", + "genbank", + "genbank synonym", + "genbank acronym", + "includes", + "type material", + ] + main_field = "scientific name" + with open(original_file, encoding="utf-8") as f: + curr_identifier = None + curr_synonyms = [] + curr_name = None + + for line in f: + # parse line + parsed_line = {} + elements = [e.strip() for e in line.strip().split("|")] + parsed_line["identifier"] = elements[0] + parsed_line["name"] = elements[1] if elements[2] == "" else elements[2] + parsed_line["field"] = elements[3] + + if parsed_line["name"] in ["all", "root"]: + continue + + if parsed_line["field"] in ["authority", "in-part", "type material"]: + continue + + if parsed_line["field"] not in ncbi_taxonomy_synset: + raise ValueError(f"Field {parsed_line['field']} unknown!") + + if curr_identifier is None: + curr_identifier = parsed_line["identifier"] + + if curr_identifier == parsed_line["identifier"]: + synonym = parsed_line["name"] + if parsed_line["field"] == main_field: + curr_name = synonym + else: + curr_synonyms.append(synonym) + elif curr_identifier != parsed_line["identifier"]: + assert curr_name is not None + yield EntityCandidate( + concept_id=curr_identifier, + concept_name=curr_name, + database_name="NCBI-TAXONOMY", + ) + + curr_identifier = parsed_line["identifier"] + curr_synonyms = [] + curr_name = None + synonym = parsed_line["name"] + if parsed_line["field"] == main_field: + curr_name = synonym + else: + curr_synonyms.append(synonym) + + class ZELDA(MultiFileColumnCorpus): def __init__( self, diff --git a/flair/datasets/knowledgebase.py b/flair/datasets/knowledgebase.py deleted file mode 100644 index 2863dae39..000000000 --- a/flair/datasets/knowledgebase.py +++ /dev/null @@ -1,457 +0,0 @@ -import csv -from pathlib import Path -from typing import Dict, Iterable, Iterator, List, Optional, Union, Any - 
-import flair -from flair.data import Concept -from flair.file_utils import cached_path, unpack_file - - -class KnowledgebaseLinkingDictionary: - """Base class for downloading and reading of dictionaries for knowledgebase entity linking. - - A dictionary represents all entities of a knowledge base and their associated ids. - """ - - def __init__( - self, - candidates: Iterable[Concept], - dataset_name: Optional[str] = None, - ): - """Initialize the Knowledgebase linking dictionary. - - Args: - candidates: A iterable sequence of all Candidates contained in the knowledge base. - """ - # this dataset name - if dataset_name is None: - dataset_name = self.__class__.__name__.lower() - self._dataset_name = dataset_name - - candidates = list(candidates) - - self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} - self._text_to_index = { - text: candidate.concept_id - for candidate in candidates - for text in [candidate.concept_name, *candidate.synonyms] - } - - @property - def database_name(self) -> str: - """Name of the database represented by the dictionary.""" - return self._dataset_name - - @property - def text_to_index(self) -> Dict[str, str]: - return self._text_to_index - - @property - def candidates(self) -> List[Concept]: - return list(self._idx_to_candidates.values()) - - def __getitem__(self, item: str) -> Concept: - return self._idx_to_candidates[item] - - def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": - return InMemoryEntityLinkingDictionary(list(self._idx_to_candidates.values()), self._dataset_name) - - -class InMemoryEntityLinkingDictionary(KnowledgebaseLinkingDictionary): - def __init__(self, candidates: List[Concept], dataset_name: str): - self._dataset_name = dataset_name - super().__init__(candidates, dataset_name=dataset_name) - - def to_state(self) -> Dict[str, Any]: - return { - "dataset_name": self._dataset_name, - "candidates": [candidate.to_dict() for candidate in self._idx_to_candidates.values()], - } - - @classmethod - def from_state(cls, state: Dict[str, Any]) -> "InMemoryEntityLinkingDictionary": - return cls( - dataset_name=state["dataset_name"], candidates=[Concept(**candidate) for candidate in state["candidates"]] - ) - - -class HunerEntityLinkingDictionary(KnowledgebaseLinkingDictionary): - """Base dictionary with data already in huner format. - - Every line in the file must be formatted as follows: - - concept_id||concept_name - - If multiple concept ids are associated to a given name they have to be separated by a `|`, e.g. - - 7157||TP53|tumor protein p53 - """ - - def __init__(self, path: Union[str, Path], dataset_name: str): - self.dataset_file = Path(path) - self._dataset_name = dataset_name - super().__init__(self._load_candidates(), dataset_name=dataset_name) - - def _load_candidates(self): - with open(self.dataset_file) as fp: - for line in fp: - line = line.strip() - if line == "": - continue - assert "||" in line, "Preprocessed EntityLinkingDictionary must have lines in the format: `cui||name`" - cui, name = line.split("||", 1) - name = name.lower() - cui, *additional_ids = cui.split("|") - yield Concept( - concept_id=cui, - concept_name=name, - database_name=self._dataset_name, - additional_ids=additional_ids, - ) - - -class CTD_DISEASES_DICTIONARY(KnowledgebaseLinkingDictionary): - """Dictionary for named entity linking on diseases using the Comparative Toxicogenomics Database (CTD). 
- - Fur further information can be found at https://ctdbase.org/ - """ - - def __init__( - self, - base_path: Optional[Union[str, Path]] = None, - ): - if base_path is None: - base_path = flair.cache_root / "datasets" - base_path = Path(base_path) - - dataset_name = self.__class__.__name__.lower() - - data_folder = base_path / dataset_name - - data_file = self.download_dictionary(data_folder) - - super().__init__(self.parse_file(data_file), dataset_name="CTD-DISEASES") - - def download_dictionary(self, data_dir: Path) -> Path: - result_file = data_dir / "CTD_diseases.tsv" - data_url = "https://ctdbase.org/reports/CTD_diseases.tsv.gz" - - if not result_file.exists(): - data_path = cached_path(data_url, data_dir) - unpack_file(data_path, unpack_to=result_file, keep=False) - - return result_file - - def parse_file(self, original_file: Path) -> Iterator[Concept]: - columns = [ - "symbol", - "identifier", - "alternative_identifiers", - "definition", - "parent_identifiers", - "tree_numbers", - "parent_tree_numbers", - "synonyms", - "slim_mappings", - ] - - with open(original_file, encoding="utf-8") as f: - reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") - - for row in reader: - identifier = row["identifier"] - additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - - if identifier == "MESH:C" and not additional_identifiers: - return None - - symbol = row["symbol"] - - synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] - definition = row["definition"] - - yield Concept( - concept_id=identifier, - concept_name=symbol, - database_name="CTD-DISEASES", - additional_ids=additional_identifiers, - synonyms=synonyms, - description=definition, - ) - - -class CTD_CHEMICALS_DICTIONARY(KnowledgebaseLinkingDictionary): - """Dictionary for named entity linking on chemicals using the Comparative Toxicogenomics Database (CTD). - - Fur further information can be found at https://ctdbase.org/ - """ - - def __init__( - self, - base_path: Optional[Union[str, Path]] = None, - ): - if base_path is None: - base_path = flair.cache_root / "datasets" - base_path = Path(base_path) - - dataset_name = self.__class__.__name__.lower() - - data_folder = base_path / dataset_name - - data_file = self.download_dictionary(data_folder) - - super().__init__(self.parse_file(data_file), dataset_name="CTD-CHEMICALS") - - def download_dictionary(self, data_dir: Path) -> Path: - result_file = data_dir / "CTD_chemicals.tsv" - data_url = "https://ctdbase.org/reports/CTD_chemicals.tsv.gz" - - if not result_file.exists(): - data_path = cached_path(data_url, data_dir) - unpack_file(data_path, unpack_to=result_file) - - return result_file - - def parse_file(self, original_file: Path) -> Iterator[Concept]: - columns = [ - "symbol", - "identifier", - "casrn", - "definition", - "parent_identifiers", - "tree_numbers", - "parent_tree_numbers", - "synonyms", - ] - - with open(original_file, encoding="utf-8") as f: - reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=columns, delimiter="\t") - - for row in reader: - identifier = row["identifier"] - additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - - if identifier == "MESH:D013749": - # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
- continue - - symbol = row["symbol"] - - synonyms = [s for s in row.get("synonyms", "").split("|") if s != "" and s != symbol] - definition = row["definition"] - - yield Concept( - concept_id=identifier, - concept_name=symbol, - database_name="CTD-CHEMICALS", - additional_ids=additional_identifiers, - synonyms=synonyms, - description=definition, - ) - - -class NCBI_GENE_HUMAN_DICTIONARY(KnowledgebaseLinkingDictionary): - """Dictionary for named entity linking on diseases using the NCBI Gene ontology. - - Note that this dictionary only represents human genes - gene from different species - aren't included! - - Fur further information can be found at https://www.ncbi.nlm.nih.gov/gene/ - """ - - def __init__( - self, - base_path: Optional[Union[str, Path]] = None, - ): - if base_path is None: - base_path = flair.cache_root / "datasets" - base_path = Path(base_path) - - dataset_name = self.__class__.__name__.lower() - - data_folder = base_path / dataset_name - - data_file = self.download_dictionary(data_folder) - - super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-GENE-HUMAN") - - def _is_invalid_name(self, name: Optional[str]) -> bool: - """Determine if a name should be skipped.""" - if name is None: - return False - name = name.strip() - EMPTY_ENTRY_TEXT = [ - "when different from all specified ones in Gene.", - "Record to support submission of GeneRIFs for a gene not in Gene", - ] - - newentry = name == "NEWENTRY" - empty = name == "" - minus = name == "-" - text_comment = any(e in name for e in EMPTY_ENTRY_TEXT) - - return any([newentry, empty, minus, text_comment]) - - def download_dictionary(self, data_dir: Path) -> Path: - result_file = data_dir / "Homo_sapiens.gene_info" - data_url = "https://ftp.ncbi.nih.gov/gene/DATA/GENE_INFO/Mammalia/Homo_sapiens.gene_info.gz" - - if not result_file.exists(): - data_path = cached_path(data_url, data_dir) - unpack_file(data_path, unpack_to=result_file) - - return result_file - - def parse_dictionary(self, original_file: Path) -> Iterator[Concept]: - synonym_fields = ( - "Symbol_from_nomenclature_authority", - "Full_name_from_nomenclature_authority", - "description", - "Synonyms", - "Other_designations", - ) - field_names = [ - "tax_id", - "GeneID", - "Symbol", - "LocusTag", - "Synonyms", - "dbXrefs", - "chromosome", - "map_location", - "description", - "type_of_gene", - "Symbol_from_nomenclature_authority", - "Full_name_from_nomenclature_authority", - "Nomenclature_status", - "Other_designations", - "Modification_date", - "Feature_type", - ] - - with open(original_file, encoding="utf-8") as f: - reader = csv.DictReader(filter(lambda r: r[0] != "#", f), fieldnames=field_names, delimiter="\t") - - for row in reader: - identifier = row["GeneID"] - symbol = row["Symbol"] - - if self._is_invalid_name(symbol): - continue - additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - - if identifier == "MESH:D013749": - # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
- continue - - synonyms = [] - for synonym_field in synonym_fields: - synonyms.extend([name.replace("'", "") for name in row.get(synonym_field, "").split("|")]) - synonyms = sorted([sym for sym in set(synonyms) if not self._is_invalid_name(sym)]) - if symbol in synonyms: - synonyms.remove(symbol) - - yield Concept( - concept_id=identifier, - concept_name=symbol, - database_name="NCBI-GENE-HUMAN", - additional_ids=additional_identifiers, - synonyms=synonyms, - ) - - -class NCBI_TAXONOMY_DICTIONARY(KnowledgebaseLinkingDictionary): - """Dictionary for named entity linking on organisms / species using the NCBI taxonomy ontology. - - Further information about the ontology can be found at https://www.ncbi.nlm.nih.gov/taxonomy - """ - - def __init__( - self, - base_path: Optional[Union[str, Path]] = None, - ): - if base_path is None: - base_path = flair.cache_root / "datasets" - base_path = Path(base_path) - dataset_name = self.__class__.__name__.lower() - - data_folder = base_path / dataset_name - - data_file = self.download_dictionary(data_folder) - - super().__init__(self.parse_dictionary(data_file), dataset_name="NCBI-TAXONOMY") - - def download_dictionary(self, data_dir: Path) -> Path: - result_file = data_dir / "names.dmp" - data_url = "https://ftp.ncbi.nih.gov/pub/taxonomy/new_taxdump/new_taxdump.tar.gz" - - if not result_file.exists(): - data_path = cached_path(data_url, data_dir) - unpack_file(data_path, unpack_to=result_file) - - return result_file - - def parse_dictionary(self, original_file: Path) -> Iterator[Concept]: - ncbi_taxonomy_synset = [ - "genbank common name", - "common name", - "scientific name", - "equivalent name", - "synonym", - "acronym", - "blast name", - "genbank", - "genbank synonym", - "genbank acronym", - "includes", - "type material", - ] - main_field = "scientific name" - - with open(original_file, encoding="utf-8") as f: - curr_identifier = None - curr_synonyms = [] - curr_name = None - - for line in f: - # parse line - parsed_line = {} - elements = [e.strip() for e in line.strip().split("|")] - parsed_line["identifier"] = elements[0] - parsed_line["name"] = elements[1] if elements[2] == "" else elements[2] - parsed_line["field"] = elements[3] - - if parsed_line["name"] in ["all", "root"]: - continue - - if parsed_line["field"] in ["authority", "in-part", "type material"]: - continue - - if parsed_line["field"] not in ncbi_taxonomy_synset: - raise ValueError(f"Field {parsed_line['field']} unknown!") - - if curr_identifier is None: - curr_identifier = parsed_line["identifier"] - - if curr_identifier == parsed_line["identifier"]: - synonym = parsed_line["name"] - if parsed_line["field"] == main_field: - curr_name = synonym - else: - curr_synonyms.append(synonym) - elif curr_identifier != parsed_line["identifier"]: - assert curr_name is not None - yield Concept( - concept_id=curr_identifier, - concept_name=curr_name, - database_name="NCBI-TAXONOMY", - ) - - curr_identifier = parsed_line["identifier"] - curr_synonyms = [] - curr_name = None - synonym = parsed_line["name"] - if parsed_line["field"] == main_field: - curr_name = synonym - else: - curr_synonyms.append(synonym) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index 118ff0794..93a262150 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,5 @@ from .clustering import ClusteringModel +from .entity_mention_linking import EntityMentionLinker from .entity_linker_model import SpanClassifier from .language_model import LanguageModel from .lemmatizer_model import 
Lemmatizer @@ -15,6 +16,7 @@ from .word_tagger_model import TokenClassifier, WordTagger __all__ = [ + "EntityMentionLinker", "SpanClassifier", "LanguageModel", "Lemmatizer", diff --git a/flair/models/biomedical_entity_linking.py b/flair/models/entity_mention_linking.py similarity index 93% rename from flair/models/biomedical_entity_linking.py rename to flair/models/entity_mention_linking.py index 3cc37d52e..18907bb45 100644 --- a/flair/models/biomedical_entity_linking.py +++ b/flair/models/entity_mention_linking.py @@ -14,22 +14,25 @@ import numpy as np import torch +from torch.utils.data import Dataset from tqdm import tqdm import flair from flair.class_utils import get_state_subclass_by_name -from flair.data import Label, Sentence, Span +from flair.data import DT, Dictionary, Label, Sentence, Span from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, NCBI_GENE_HUMAN_DICTIONARY, NCBI_TAXONOMY_DICTIONARY, + EntityLinkingDictionary, HunerEntityLinkingDictionary, - KnowledgebaseLinkingDictionary, ) -from flair.datasets.knowledgebase import InMemoryEntityLinkingDictionary +from flair.datasets.entity_linking import InMemoryEntityLinkingDictionary from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings -from flair.file_utils import cached_path, load_torch_state +from flair.embeddings.base import load_embeddings +from flair.file_utils import cached_path +from flair.training_utils import Result logger = logging.getLogger("flair") @@ -88,11 +91,12 @@ "dmis-lab/biosyn-sapbert-bc5cdr-disease": "ctd-disease", "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-disease", "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemical", "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-disease", "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemical", "dmis-lab/biosyn-biobert-bc2gn": "ncbi-gene", - "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", } DEFAULT_SPARSE_WEIGHT = 0.5 @@ -100,7 +104,7 @@ def load_dictionary( dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None -) -> KnowledgebaseLinkingDictionary: +) -> EntityLinkingDictionary: """Load dictionary: either pre-defined or from path.""" if isinstance(dictionary_name_or_path, str) and ( dictionary_name_or_path in ENTITY_TYPE_TO_DICTIONARY or dictionary_name_or_path in BIOMEDICAL_DICTIONARIES @@ -414,9 +418,7 @@ class CandidateSearchIndex(ABC): """ @abstractmethod - def index( - self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None - ) -> None: + def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: """Index a dictionary to prepare for search. 
Args: @@ -459,9 +461,7 @@ def __init__(self): """ self.name_to_id_index: Dict[str, str] = {} - def index( - self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None - ) -> None: + def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: def p(text: str) -> str: return preprocessor.process_entity_name(text) if preprocessor is not None else text @@ -538,7 +538,7 @@ def bi_encoder( show_progress: bool = True, sparse_weight: float = 0.5, preprocessor: Optional[EntityPreprocessor] = None, - dictionary: Optional[KnowledgebaseLinkingDictionary] = None, + dictionary: Optional[EntityLinkingDictionary] = None, ) -> "SemanticCandidateSearchIndex": embeddings: List[DocumentEmbeddings] = [TransformerDocumentEmbeddings(model_name_or_path)] weights = [1.0] @@ -571,9 +571,7 @@ def bi_encoder( show_progress=show_progress, ) - def index( - self, dictionary: KnowledgebaseLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None - ) -> None: + def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: def p(text: str) -> str: return preprocessor.process_entity_name(text) if preprocessor is not None else text @@ -671,7 +669,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, @classmethod def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": index = cls( - embeddings=[DocumentEmbeddings.load_embedding(emb) for emb in state_dict["embeddings"]], + embeddings=[load_embeddings(emb) for emb in state_dict["embeddings"]], similarity_metric=SimilarityMetric(state_dict["similarity_metric"]), weights=state_dict["weights"], batch_size=state_dict["batch_size"], @@ -694,7 +692,7 @@ def _get_state(self) -> Dict[str, Any]: } -class EntityMentionLinker: +class EntityMentionLinker(flair.nn.Model): """Entity linking model for the biomedical domain.""" def __init__( @@ -703,20 +701,21 @@ def __init__( preprocessor: EntityPreprocessor, entity_label_type: str, label_type: str, - dictionary: KnowledgebaseLinkingDictionary, + dictionary: EntityLinkingDictionary, ): self.preprocessor = preprocessor self.candidate_generator = candidate_generator self.entity_label_type = entity_label_type self._label_type = label_type self._dictionary = dictionary + super().__init__() @property def label_type(self): return self._label_type @property - def dictionary(self) -> KnowledgebaseLinkingDictionary: + def dictionary(self) -> EntityLinkingDictionary: return self._dictionary def extract_mentions( @@ -773,29 +772,29 @@ def _fetch_model(model_name: str) -> str: if Path(model_name).exists(): return model_name - raise NotImplementedError() - - def save(self, model_path: Union[str, Path]) -> None: - pass + raise NotImplementedError @classmethod - def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker": - if isinstance(model_path, str): - model_path = cls._fetch_model(model_path) - - if isinstance(model_path, dict): - state = model_path - else: - state = load_torch_state(str(model_path)) - + def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) - preprocessor = EntityPreprocessor._from_state("entity_preprocessor") + preprocessor = EntityPreprocessor._from_state(state["entity_preprocessor"]) entity_label_type = state["entity_label_type"] label_type = state["label_type"] dictionary = 
InMemoryEntityLinkingDictionary.from_state(state["dictionary"]) return cls(candidate_generator, preprocessor, entity_label_type, label_type, dictionary) + def _get_state_dict(self): + """Returns the state dictionary for this model.""" + return { + **super()._get_state_dict(), + "label_type": self.label_type, + "entity_label_type": self.entity_label_type, + "entity_preprocessor": self.preprocessor._get_state(), + "candidate_search_index": self.candidate_generator._get_state(), + "dictionary": self.dictionary.to_in_memory_dictionary().to_state(), + } + @classmethod def build( cls, @@ -809,15 +808,14 @@ def build( force_hybrid_search: bool = False, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, - dictionary: Optional[KnowledgebaseLinkingDictionary] = None, + dictionary: Optional[EntityLinkingDictionary] = None, dataset_name: Optional[str] = None, ) -> "EntityMentionLinker": """Loads a model for biomedical named entity normalization. - See __init__ method for detailed docstring on arguments. """ if not isinstance(model_name_or_path, str): - raise AssertionError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") + raise ValueError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") model_name_or_path = cast(str, model_name_or_path) if dictionary is None: @@ -956,3 +954,21 @@ def __get_dictionary_path( ) return dictionary_name_or_path + + def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + raise NotImplementedError("The EntityLinker cannot be trained") + + def evaluate( + self, + data_points: Union[List[DT], Dataset], + gold_label_type: str, + out_path: Optional[Union[str, Path]] = None, + embedding_storage_mode: str = "none", + mini_batch_size: int = 32, + main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + exclude_labels: List[str] = [], + gold_label_dictionary: Optional[Dictionary] = None, + return_loss: bool = True, + **kwargs, + ) -> Result: + raise NotImplementedError("Evaluation is currently not implemented for EntityLinking") From c5fa9c9478d9d1392c5794b05cd74f0e9e059034 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 24 Dec 2023 14:30:15 +0100 Subject: [PATCH 34/58] add hf model download --- flair/file_utils.py | 38 +++++++++++++++++++++ flair/models/entity_mention_linking.py | 21 ++++++++++-- flair/models/sequence_tagger_model.py | 45 ++----------------------- tests/test_biomedical_entity_linking.py | 8 ++--- 4 files changed, 63 insertions(+), 49 deletions(-) diff --git a/flair/file_utils.py b/flair/file_utils.py index a3db6458d..dfb0049b7 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -21,6 +21,7 @@ import torch from botocore import UNSIGNED from botocore.config import Config +from requests import HTTPError from tqdm import tqdm as _tqdm import flair @@ -144,6 +145,43 @@ def unzip_file(file: Union[str, Path], unzip_to: Union[str, Path]): zipObj.extractall(Path(unzip_to)) +def hf_download(model_name: str) -> str: + hf_model_name = "pytorch_model.bin" + revision = "main" + + if "@" in model_name: + model_name_split = model_name.split("@") + revision = model_name_split[-1] + model_name = model_name_split[0] + + # use model name as subfolder + model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name + + # Lazy import + from huggingface_hub.file_download import hf_hub_download + + try: + return hf_hub_download( + repo_id=model_name, + filename=hf_model_name, + revision=revision, + 
library_name="flair", + library_version=flair.__version__, + cache_dir=flair.cache_root / "models" / model_folder, + ) + except HTTPError: + # output information + logger.error("-" * 80) + logger.error( + f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" + ) + logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") + logger.error(" -> Alternatively, point to a model file on your local drive.") + logger.error("-" * 80) + Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid + raise + + def unpack_file(file: Path, unpack_to: Path, mode: Optional[str] = None, keep: bool = True): """Unpacks an archive file to the given location. diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 18907bb45..74099ae50 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -31,7 +31,7 @@ from flair.datasets.entity_linking import InMemoryEntityLinkingDictionary from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings from flair.embeddings.base import load_embeddings -from flair.file_utils import cached_path +from flair.file_utils import cached_path, hf_download from flair.training_utils import Result logger = logging.getLogger("flair") @@ -772,7 +772,24 @@ def _fetch_model(model_name: str) -> str: if Path(model_name).exists(): return model_name - raise NotImplementedError + bio_base_repo = "helpmefindaname" + + hf_model_map = { + "bio-gene": f"{bio_base_repo}/flair-eml-sapbert-bc2gn-gene", + "bio-disease": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-disease", + "bio-chemical": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-chemical", + "bio-species": f"{bio_base_repo}/flair-eml-species-exact-match", + "bio-gene-exact-match": f"{bio_base_repo}/flair-eml-gene-exact-match", + "bio-disease-exact-match": f"{bio_base_repo}/flair-eml-disease-exact-match", + "bio-chemical-exact-match": f"{bio_base_repo}/flair-eml-chemical-exact-match", + "bio-species-exact-match": f"{bio_base_repo}/flair-eml-species-exact-match", + } + + if model_name in hf_model_map: + model_name = hf_model_map[model_name] + + return hf_download(model_name) + @classmethod def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index c6defd24a..2f1bab67e 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -14,7 +14,7 @@ from flair.data import Dictionary, Label, Sentence, Span, get_spans_from_bio from flair.datasets import DataLoader, FlairDatapointDataset from flair.embeddings import TokenEmbeddings -from flair.file_utils import cached_path, unzip_file +from flair.file_utils import cached_path, unzip_file, hf_download from flair.models.sequence_tagger_utils.crf import CRF from flair.models.sequence_tagger_utils.viterbi import ViterbiDecoder, ViterbiLoss from flair.training_utils import store_embeddings @@ -775,9 +775,7 @@ def _fetch_model(model_name) -> str: # get mapped name hf_model_name = huggingface_model_map[model_name] - # use mapped name instead - model_name = hf_model_name - get_from_model_hub = True + model_path = hf_download(hf_model_name) # if not, check if model key is remapped to direct download location. 
If so, download model elif model_name in hu_model_map: @@ -838,44 +836,7 @@ def _fetch_model(model_name) -> str: # for all other cases (not local file or special download location), use HF model hub else: - get_from_model_hub = True - - # if not a local file, get from model hub - if get_from_model_hub: - hf_model_name = "pytorch_model.bin" - revision = "main" - - if "@" in model_name: - model_name_split = model_name.split("@") - revision = model_name_split[-1] - model_name = model_name_split[0] - - # use model name as subfolder - model_folder = model_name.split("/", maxsplit=1)[1] if "/" in model_name else model_name - - # Lazy import - from huggingface_hub.file_download import hf_hub_download - - try: - model_path = hf_hub_download( - repo_id=model_name, - filename=hf_model_name, - revision=revision, - library_name="flair", - library_version=flair.__version__, - cache_dir=flair.cache_root / "models" / model_folder, - ) - except HTTPError: - # output information - log.error("-" * 80) - log.error( - f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" - ) - log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") - log.error(" -> Alternatively, point to a model file on your local drive.") - log.error("-" * 80) - Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid - raise + model_path = hf_download(model_name) return model_path diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index b4b4475ec..7a165041d 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,5 +1,5 @@ from flair.data import Sentence -from flair.models.biomedical_entity_linking import ( +from flair.models.entity_mention_linking import ( EntityMentionLinker, load_dictionary, ) @@ -51,11 +51,11 @@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = EntityMentionLinker.build("diseases", "diseases-nel", hybrid_search=True) + disease_linker = EntityMentionLinker.load("bio-disease") disease_dictionary = disease_linker.dictionary disease_linker.predict(sentence) - gene_linker = EntityMentionLinker.build("genes", "genes-nel", hybrid_search=False, entity_type="genes") + gene_linker = EntityMentionLinker.load("bio-genes") gene_dictionary = gene_linker.dictionary gene_linker.predict(sentence) @@ -73,5 +73,3 @@ def test_biomedical_entity_linking(): for candidate_label in span.get_labels(gene_linker.label_type): candidate = gene_dictionary[candidate_label.value] print(f"Candidate: {candidate.concept_name}") - - breakpoint() # noqa: T100 From 94db5b32d006b09a6c9f7c6a35bac7cd82dffba3 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 24 Dec 2023 14:59:14 +0100 Subject: [PATCH 35/58] fix ruff & mypy errors --- flair/models/__init__.py | 2 +- flair/models/entity_mention_linking.py | 8 ++------ flair/models/sequence_tagger_model.py | 5 +---- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index 93a262150..ac69e19aa 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,6 +1,6 @@ from .clustering import ClusteringModel -from .entity_mention_linking import EntityMentionLinker from .entity_linker_model import SpanClassifier +from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel from .lemmatizer_model import 
Lemmatizer from .multitask_model import MultitaskModel diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 74099ae50..12d4f06ab 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -92,7 +92,6 @@ "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-disease", "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", - "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemical", "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-disease", "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemical", @@ -669,7 +668,7 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, @classmethod def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": index = cls( - embeddings=[load_embeddings(emb) for emb in state_dict["embeddings"]], + embeddings=cast(List[DocumentEmbeddings], [load_embeddings(emb) for emb in state_dict["embeddings"]]), similarity_metric=SimilarityMetric(state_dict["similarity_metric"]), weights=state_dict["weights"], batch_size=state_dict["batch_size"], @@ -790,7 +789,6 @@ def _fetch_model(model_name: str) -> str: return hf_download(model_name) - @classmethod def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) @@ -828,9 +826,7 @@ def build( dictionary: Optional[EntityLinkingDictionary] = None, dataset_name: Optional[str] = None, ) -> "EntityMentionLinker": - """Loads a model for biomedical named entity normalization. - - """ + """Loads a model for biomedical named entity normalization.""" if not isinstance(model_name_or_path, str): raise ValueError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") model_name_or_path = cast(str, model_name_or_path) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 2f1bab67e..2bcc2c5cc 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -2,7 +2,6 @@ import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union, cast -from urllib.error import HTTPError import torch import torch.nn @@ -14,7 +13,7 @@ from flair.data import Dictionary, Label, Sentence, Span, get_spans_from_bio from flair.datasets import DataLoader, FlairDatapointDataset from flair.embeddings import TokenEmbeddings -from flair.file_utils import cached_path, unzip_file, hf_download +from flair.file_utils import cached_path, hf_download, unzip_file from flair.models.sequence_tagger_utils.crf import CRF from flair.models.sequence_tagger_utils.viterbi import ViterbiDecoder, ViterbiLoss from flair.training_utils import store_embeddings @@ -764,8 +763,6 @@ def _fetch_model(model_name) -> str: cache_dir = Path("models") - get_from_model_hub = False - # check if model name is a valid local file if Path(model_name).exists(): model_path = model_name From 1d55fec589d1801ae6227094b1545b370f03e2a5 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 24 Dec 2023 15:02:23 +0100 Subject: [PATCH 36/58] fix entity linking test --- tests/test_biomedical_entity_linking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 7a165041d..fa71b129a 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ 
-55,7 +55,7 @@ def test_biomedical_entity_linking(): disease_dictionary = disease_linker.dictionary disease_linker.predict(sentence) - gene_linker = EntityMentionLinker.load("bio-genes") + gene_linker = EntityMentionLinker.load("bio-gene") gene_dictionary = gene_linker.dictionary gene_linker.predict(sentence) From a4a27dcae36be9d1c2496965245e388f128f5be9 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Tue, 2 Jan 2024 12:46:38 +0100 Subject: [PATCH 37/58] fixed selection of knowledge base identifiers for entity_mention_linking --- flair/models/entity_mention_linking.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 12d4f06ab..be367718b 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -654,14 +654,14 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, """ mention_embs = self.emb_search(entity_mentions) all_scores = mention_embs @ self._precomputed_embeddings.T - selected_indices = np.argsort(all_scores, axis=1)[:, :top_k] - scores = np.take_along_axis(all_scores, selected_indices, axis=1) + indices_top_k = np.argpartition(all_scores, kth=-top_k, axis=1)[:, -top_k:] + mention_numbers = np.tile(np.arange(len(entity_mentions)), (top_k, 1)).T + positions_top_k = np.argsort(all_scores[mention_numbers, indices_top_k], axis=1) + sorted_indices_top_k = indices_top_k[mention_numbers, positions_top_k] results = [] - for i in range(selected_indices.shape[0]): - results.append( - [(self.ids[selected_indices[i, j]], float(scores[i, j])) for j in range(selected_indices.shape[1])] - ) + for i in range(sorted_indices_top_k.shape[0]): + results.append([(self.ids[j], float(all_scores[i, j])) for j in sorted_indices_top_k[i, :]]) return results From 6689297c4a930bb953a6fe102a368b6f61749de6 Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Fri, 5 Jan 2024 15:48:35 +0100 Subject: [PATCH 38/58] fix(dictionary): corrections in file parsing --- flair/datasets/entity_linking.py | 43 +++++++++++++------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index a73ded8f6..9c1196797 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -8,7 +8,8 @@ import flair from flair.data import Corpus, EntityCandidate, MultiCorpus, Sentence -from flair.datasets.sequence_labeling import ColumnCorpus, MultiFileColumnCorpus +from flair.datasets.sequence_labeling import (ColumnCorpus, + MultiFileColumnCorpus) from flair.file_utils import cached_path, unpack_file from flair.splitter import SegtokSentenceSplitter, SentenceSplitter @@ -39,11 +40,14 @@ def __init__( candidates = list(candidates) self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} - self._text_to_index = { - text: candidate.concept_id - for candidate in candidates - for text in [candidate.concept_name, *candidate.synonyms] - } + + # one name can map to multiple concepts + self._text_to_index: Dict[str, List] = {} + for candidate in candidates: + for text in [candidate.concept_name, *candidate.synonyms]: + if text not in self._text_to_index: + self._text_to_index[text] = [] + self._text_to_index[text].append(candidate.concept_id) @property def database_name(self) -> str: @@ -51,7 +55,7 @@ def database_name(self) -> str: return self._dataset_name @property - def text_to_index(self) -> Dict[str, str]: + def 
text_to_index(self) -> Dict[str, List]: return self._text_to_index @property @@ -109,7 +113,6 @@ def _load_candidates(self): continue assert "||" in line, "Preprocessed EntityLinkingDictionary must have lines in the format: `cui||name`" cui, name = line.split("||", 1) - name = name.lower() cui, *additional_ids = cui.split("|") yield EntityCandidate( concept_id=cui, @@ -171,13 +174,11 @@ def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: identifier = row["identifier"] additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - if identifier == "MESH:C" and not additional_identifiers: - return None + if any(i == "MESH:C" for i in [identifier, *additional_identifiers]): + continue symbol = row["symbol"] - synonyms = [s for s in row.get("synonyms", "").split("|") if s != ""] - definition = row["definition"] yield EntityCandidate( concept_id=identifier, @@ -185,7 +186,6 @@ def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: database_name="CTD-DISEASES", additional_ids=additional_identifiers, synonyms=synonyms, - description=definition, ) @@ -240,14 +240,12 @@ def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: identifier = row["identifier"] additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - if identifier == "MESH:D013749": - # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. - continue + # if identifier == "MESH:D013749": + # # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. + # continue symbol = row["symbol"] - synonyms = [s for s in row.get("synonyms", "").split("|") if s != "" and s != symbol] - definition = row["definition"] yield EntityCandidate( concept_id=identifier, @@ -255,7 +253,6 @@ def parse_file(self, original_file: Path) -> Iterator[EntityCandidate]: database_name="CTD-CHEMICALS", additional_ids=additional_identifiers, synonyms=synonyms, - description=definition, ) @@ -347,11 +344,6 @@ def parse_dictionary(self, original_file: Path) -> Iterator[EntityCandidate]: if self._is_invalid_name(symbol): continue - additional_identifiers = [i for i in row.get("alternative_identifiers", "").split("|") if i != ""] - - if identifier == "MESH:D013749": - # This MeSH ID was used by MeSH when this chemical was part of the MeSH controlled vocabulary. 
- continue synonyms = [] for synonym_field in synonym_fields: @@ -364,7 +356,6 @@ def parse_dictionary(self, original_file: Path) -> Iterator[EntityCandidate]: concept_id=identifier, concept_name=symbol, database_name="NCBI-GENE-HUMAN", - additional_ids=additional_identifiers, synonyms=synonyms, ) @@ -447,12 +438,14 @@ def parse_dictionary(self, original_file: Path) -> Iterator[EntityCandidate]: curr_name = synonym else: curr_synonyms.append(synonym) + elif curr_identifier != parsed_line["identifier"]: assert curr_name is not None yield EntityCandidate( concept_id=curr_identifier, concept_name=curr_name, database_name="NCBI-TAXONOMY", + synonyms=curr_synonyms, ) curr_identifier = parsed_line["identifier"] From c2c53a9e91311aa5693fbbafadf7e6a11cfe1342 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Thu, 11 Jan 2024 14:12:27 +0100 Subject: [PATCH 39/58] fix: preprocessing, candidate generator, linker - preprocessing: ensure no empty strings after processing - preprocessing: ensure Ab3P works - generator: separate sparse and dense search - generator: constant with sparse weight for pre-trained models --- flair/models/entity_mention_linking.py | 352 ++++++++++++++----------- 1 file changed, 205 insertions(+), 147 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index be367718b..dbb8143a8 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -1,3 +1,4 @@ +import copy import inspect import logging import os @@ -14,12 +15,13 @@ import numpy as np import torch +from scipy import sparse from torch.utils.data import Dataset from tqdm import tqdm import flair from flair.class_utils import get_state_subclass_by_name -from flair.data import DT, Dictionary, Label, Sentence, Span +from flair.data import DT, Dictionary, Sentence from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, @@ -52,8 +54,20 @@ "dmis-lab/biosyn-sapbert-bc2gn": "gene", } +# fetched from original repo to avoid download +HYBRID_MODELS_SPARSE_WEIGHT = { + "dmis-lab/biosyn-sapbert-bc5cdr-disease": 0.09762775897979736, + "dmis-lab/biosyn-sapbert-ncbi-disease": 0.40971508622169495, + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": 0.07534809410572052, + "dmis-lab/biosyn-biobert-bc5cdr-disease": 1.5729279518127441, + "dmis-lab/biosyn-biobert-ncbi-disease": 1.7646825313568115, + "dmis-lab/biosyn-biobert-bc2gn": 1.5786927938461304, + "dmis-lab/biosyn-sapbert-bc2gn": 0.0288906991481781, +} + PRETRAINED_MODELS = list(PRETRAINED_HYBRID_MODELS) + PRETRAINED_DENSE_MODELS + # just in case we add: fuzzy search, Levenstein, ... 
STRING_MATCHING_MODELS = ["exact-string-match"] @@ -101,6 +115,16 @@ DEFAULT_SPARSE_WEIGHT = 0.5 +class SimilarityMetric(Enum): + """Similarity metrics.""" + + INNER_PRODUCT = auto() + COSINE = auto() + + +PRETRAINED_MODEL_TO_SIMILARITY_METRIC = {m: SimilarityMetric.INNER_PRODUCT for m in PRETRAINED_MODELS} + + def load_dictionary( dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None ) -> EntityLinkingDictionary: @@ -117,13 +141,6 @@ def load_dictionary( return HunerEntityLinkingDictionary(path=dictionary_name_or_path, dataset_name=dataset_name) -class SimilarityMetric(Enum): - """Similarity metrics.""" - - INNER_PRODUCT = auto() - COSINE = auto() - - class EntityPreprocessor(ABC): """A pre-processor used to transform / clean both entity mentions and entity names.""" @@ -136,7 +153,7 @@ def initialize(self, sentences: List[Sentence]) -> None: sentences: List of sentences that will be processed. """ - def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: + def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = None) -> str: """Processes the given entity mention and applies the transformation procedure to it. Usually just forwards the entity_mention to :meth:`EntityPreprocessor.process_entity_name`, but can be implemented @@ -149,7 +166,7 @@ def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: Returns: Cleaned / transformed string representation of the given entity mention """ - return self.process_entity_name(entity_mention.data_point.text) + return self.process_entity_name(entity_mention) @abstractmethod def process_entity_name(self, entity_name: str) -> str: @@ -197,6 +214,8 @@ def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): self.rmv_puncts_regex = re.compile(rf"[\s{re.escape(string.punctuation)}]+") def process_entity_name(self, entity_name: str) -> str: + original = copy.deepcopy(entity_name) + if self.lowercase: entity_name = entity_name.lower() @@ -204,7 +223,12 @@ def process_entity_name(self, entity_name: str) -> str: name_parts = self.rmv_puncts_regex.split(entity_name) entity_name = " ".join(name_parts).strip() - return entity_name.strip() + entity_name = entity_name.strip() + + # NOTE: Avoid emtpy string if mentions are just punctutations (e.g. `-` or `(`) + entity_name = original if len(entity_name) == 0 else entity_name + + return entity_name def _get_state(self) -> Dict[str, Any]: return { @@ -223,7 +247,12 @@ class Ab3PEntityPreprocessor(EntityPreprocessor): https://github.com/ncbi-nlp/Ab3P. """ - def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[EntityPreprocessor] = None) -> None: + def __init__( + self, + ab3p_path: Optional[Path] = None, + word_data_dir: Optional[Path] = None, + preprocessor: Optional[EntityPreprocessor] = None, + ) -> None: """Creates the mention pre-processor. 
Args: @@ -231,31 +260,36 @@ def __init__(self, ab3p_path: Path, word_data_dir: Path, preprocessor: Optional[ word_data_dir: Path to the word data directory preprocessor: Basic entity preprocessor """ - self.ab3p_path = ab3p_path - self.word_data_dir = word_data_dir + if ab3p_path is not None and word_data_dir is not None: + self.ab3p_path = ab3p_path + self.word_data_dir = word_data_dir + else: + self.ab3p_path, self.word_data_dir = self._get_biosyn_ab3p_paths() self.preprocessor = preprocessor self.abbreviation_dict: Dict[str, Dict[str, str]] = {} def initialize(self, sentences: List[Sentence]) -> None: self.abbreviation_dict = self._build_abbreviation_dict(sentences) - def process_mention(self, entity_mention: Label, sentence: Sentence) -> str: - sentence_text = sentence.to_tokenized_string().strip() - tokens = [token.text for token in cast(Span, entity_mention.data_point).tokens] + def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = None) -> str: + assert ( + sentence is not None + ), "Ab3P requires the sentence where `entity_mention` was found for abbreviation resolution" + + original = copy.deepcopy(entity_mention) - parsed_tokens = [] - for token in tokens: - if self.preprocessor is not None: - token = self.preprocessor.process_entity_name(token) + sentence_text = sentence.to_original_text() - if sentence_text in self.abbreviation_dict and token.lower() in self.abbreviation_dict[sentence_text]: - parsed_tokens.append(self.abbreviation_dict[sentence_text][token.lower()]) - continue + if entity_mention in self.abbreviation_dict.get(sentence_text, {}): + entity_mention = self.abbreviation_dict[sentence_text][entity_mention] + + if self.preprocessor is not None: + entity_mention = self.preprocessor.process_entity_name(entity_mention) - if len(token) != 0: - parsed_tokens.append(token) + # NOTE: Avoid emtpy string if mentions are just punctutations (e.g. 
`-` or `(`) + entity_mention = original if len(entity_mention) == 0 else entity_mention - return " ".join(parsed_tokens) + return entity_mention def process_entity_name(self, entity_name: str) -> str: # Ab3P works on sentence-level and not on a single entity mention / name @@ -265,8 +299,7 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - @classmethod - def load_biosyn(cls, preprocessor: Optional[EntityPreprocessor] = None): + def _get_biosyn_ab3p_paths(self) -> Tuple[Path, Path]: data_dir = flair.cache_root / "ab3p_biosyn" if not data_dir.exists(): data_dir.mkdir(parents=True) @@ -275,12 +308,11 @@ def load_biosyn(cls, preprocessor: Optional[EntityPreprocessor] = None): if not word_data_dir.exists(): word_data_dir.mkdir() - ab3p_path = cls._download_biosyn_ab3p(data_dir, word_data_dir) + ab3p_path = self._download_biosyn_ab3p(data_dir, word_data_dir) - return cls(ab3p_path, word_data_dir, preprocessor) + return ab3p_path, word_data_dir - @classmethod - def _download_biosyn_ab3p(cls, data_dir: Path, word_data_dir: Path) -> Path: + def _download_biosyn_ab3p(self, data_dir: Path, word_data_dir: Path) -> Path: """Downloads the Ab3P tool and all necessary data files.""" # Download word data for Ab3P if not already downloaded ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" @@ -337,7 +369,7 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict # Create a temp file which holds the sentences we want to process with Ab3P with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: for sentence in sentences: - temp_file.write(sentence.to_tokenized_string() + "\n") + temp_file.write(sentence.to_original_text() + "\n") temp_file.flush() # Temporarily create path file in the current working directory for Ab3P @@ -376,8 +408,8 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict continue sf, lf, _ = line.split("|") - sf = sf.strip().lower() - lf = lf.strip().lower() + sf = sf.strip() + lf = lf.strip() abbreviation_dict[cur_sentence][sf] = lf elif len(line.strip()) > 0: @@ -498,34 +530,32 @@ class SemanticCandidateSearchIndex(CandidateSearchIndex): def __init__( self, - embeddings: List[DocumentEmbeddings], - similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - weights: Optional[List[float]] = None, + embeddings: Dict[str, DocumentEmbeddings], + hybrid_search: bool, + similarity_metric: SimilarityMetric = SimilarityMetric.INNER_PRODUCT, + sparse_weight: float = DEFAULT_SPARSE_WEIGHT, batch_size: int = 128, show_progress: bool = True, ): - """Initializes the EncoderCandidateSearchIndex. + """Initializes the SemanticCandidateSearchIndex. Args: embeddings: A list of embeddings used for search. - weights: Weight the embedding's importance. + hybrid_search: combine sparse and dense embeddings + sparse_weight: Weight for sparse embeddings. similarity_metric: The metric used to define similarity. batch_size: The batch size used for indexing embeddings. show_progress: show the progress while indexing. 
""" - if weights is None: - weights = [1.0 for _ in embeddings] - if len(weights) != len(embeddings): - raise ValueError("Weights have to be of the same length as embeddings") - self.embeddings = embeddings - self.weights = weights + self.hybrid_search = hybrid_search + self.sparse_weight = sparse_weight self.similarity_metric = similarity_metric self.show_progress = show_progress self.batch_size = batch_size self.ids: List[str] = [] - self._precomputed_embeddings: np.ndarray = np.array([]) + self._precomputed_embeddings: Dict[str, np.ndarray] = {"sparse": np.array([]), "dense": np.array([])} @classmethod def bi_encoder( @@ -539,8 +569,12 @@ def bi_encoder( preprocessor: Optional[EntityPreprocessor] = None, dictionary: Optional[EntityLinkingDictionary] = None, ) -> "SemanticCandidateSearchIndex": - embeddings: List[DocumentEmbeddings] = [TransformerDocumentEmbeddings(model_name_or_path)] - weights = [1.0] + # NOTE: ensure correct similarity metric for pretrained model + if model_name_or_path in PRETRAINED_MODELS: + similarity_metric = PRETRAINED_MODEL_TO_SIMILARITY_METRIC[model_name_or_path] + + embeddings: Dict[str, DocumentEmbeddings] = {"dense": TransformerDocumentEmbeddings(model_name_or_path)} + if hybrid_search: if dictionary is None: raise ValueError("Require dictionary to be set on hybrid search.") @@ -554,20 +588,25 @@ def bi_encoder( if preprocessor is not None: texts = [preprocessor.process_entity_name(t) for t in texts] - embeddings.append( - DocumentTFIDFEmbeddings( - [Sentence(t) for t in texts], - analyzer="char", - ngram_range=(1, 2), - ) + embeddings["sparse"] = DocumentTFIDFEmbeddings( + [Sentence(t) for t in texts], + analyzer="char", + ngram_range=(1, 2), ) - weights = [1.0, sparse_weight] + + sparse_weight = ( + sparse_weight + if model_name_or_path not in HYBRID_MODELS_SPARSE_WEIGHT + else HYBRID_MODELS_SPARSE_WEIGHT[model_name_or_path] + ) + return cls( embeddings, similarity_metric=similarity_metric, - weights=weights, + sparse_weight=sparse_weight, batch_size=batch_size, show_progress=show_progress, + hybrid_search=hybrid_search, ) def index(self, dictionary: EntityLinkingDictionary, preprocessor: Optional[EntityPreprocessor] = None) -> None: @@ -583,8 +622,7 @@ def p(text: str) -> str: texts.append(p(synonym)) self.ids.append(candidate.concept_id) - precomputed_embeddings = [] - + dense_embeddings = [] with torch.no_grad(): if self.show_progress: iterations = tqdm( @@ -598,49 +636,60 @@ def p(text: str) -> str: end = min(start + self.batch_size, len(texts)) batch = [Sentence(name) for name in texts[start:end]] - for embedding in self.embeddings: - embedding.embed(batch) - + self.embeddings["dense"].embed(batch) for sent in batch: - embs = [] - for embedding, weight in zip(self.embeddings, self.weights): - emb = sent.get_embedding(embedding.get_names()) - if self.similarity_metric == SimilarityMetric.COSINE: - emb = emb / torch.norm(emb) - embs.append(emb * weight) - - precomputed_embeddings.append(torch.cat(embs, dim=0).cpu().numpy()) + emb = sent.get_embedding() + if self.similarity_metric == SimilarityMetric.COSINE: + emb = emb / torch.norm(emb) + dense_embeddings.append(emb.cpu().numpy()) sent.clear_embeddings() if flair.device.type == "cuda": torch.cuda.empty_cache() - self._precomputed_embeddings = np.stack(precomputed_embeddings, axis=0) + self._precomputed_embeddings["dense"] = np.stack(dense_embeddings, axis=0) - def emb_search(self, entity_mentions: List[str]) -> np.ndarray: - embeddings = [] + if self.hybrid_search: + sparse_embs = [] + batch = 
[Sentence(name) for name in texts] + self.embeddings["sparse"].embed(batch) + for sent in batch: + sparse_emb = sent.get_embedding() + if self.similarity_metric == SimilarityMetric.COSINE: + sparse_emb = sparse_emb / torch.norm(sparse_emb) + sparse_embs.append(sparse_emb.cpu().numpy()) + sent.clear_embeddings() + self._precomputed_embeddings["sparse"] = np.stack(sparse_embs, axis=0) + + def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: + query_embeddings: Dict[str, List] = {"dense": []} + + inputs = [Sentence(name) for name in entity_mentions] with torch.no_grad(): for start in range(0, len(entity_mentions), self.batch_size): end = min(start + self.batch_size, len(entity_mentions)) - batch = [Sentence(name) for name in entity_mentions[start:end]] - - for embedding in self.embeddings: - embedding.embed(batch) - + batch = inputs[start:end] + self.embeddings["dense"].embed(batch) for sent in batch: - embs = [] - for embedding in self.embeddings: - emb = sent.get_embedding(embedding.get_names()) - if self.similarity_metric == SimilarityMetric.COSINE: - emb = emb / torch.norm(emb) - embs.append(emb) - - embeddings.append(torch.cat(embs, dim=0).cpu().numpy()) + emb = sent.get_embedding() + if self.similarity_metric == SimilarityMetric.COSINE: + emb = emb / torch.norm(emb) + query_embeddings["dense"].append(emb.cpu().numpy()) sent.clear_embeddings() if flair.device.type == "cuda": torch.cuda.empty_cache() - return np.stack(embeddings, axis=0) + if self.hybrid_search: + query_embeddings["sparse"] = [] + self.embeddings["sparse"].embed(inputs) + for sent in inputs: + sparse_emb = sent.get_embedding() + if self.similarity_metric == SimilarityMetric.COSINE: + sparse_emb = sparse_emb / torch.norm(sparse_emb) + query_embeddings["sparse"].append(sparse_emb.cpu().numpy()) + sent.clear_embeddings() + + return {k: np.stack(v, axis=0) for k, v in query_embeddings.items()} def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, float]]]: """Returns the top-k entity / concept identifiers for each entity mention. 
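The hunk below rewrites search() around this hybrid setup. As a standalone illustration (not part of the patch), the following Python sketch shows the idea with toy arrays: dense inner-product scores are combined with sparse scores scaled by a sparse weight, and the top-k candidates per mention are taken with argpartition followed by a sort of only those k columns. The array shapes and the 0.5 weight are invented for the example.

import numpy as np
from scipy import sparse

rng = np.random.default_rng(42)
top_k, sparse_weight = 2, 0.5                      # toy value mirroring DEFAULT_SPARSE_WEIGHT
mention_dense = rng.random((3, 8))                 # 3 mentions, dense embedding dim 8
index_dense = rng.random((10, 8))                  # 10 dictionary names
mention_sparse = sparse.random(3, 50, density=0.2, format="csr", random_state=42)
index_sparse = sparse.random(10, 50, density=0.2, format="csr", random_state=7)

# dense inner-product scores plus weighted sparse scores
scores = mention_dense @ index_dense.T
scores += sparse_weight * mention_sparse.dot(index_sparse.T).toarray()

# unordered top-k columns via argpartition, then sort only those k columns per row
rows = np.arange(scores.shape[0])[:, None]
unsorted = np.argpartition(scores, -top_k, axis=1)[:, -top_k:]
order = np.argsort(-scores[rows, unsorted], axis=1)
topk_idxs = unsorted[rows, order]
topk_scores = scores[rows, topk_idxs]
print(topk_idxs)
print(topk_scores)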
@@ -652,26 +701,39 @@ def search(self, entity_mentions: List[str], top_k: int) -> List[List[Tuple[str, Returns: List containing a list of entity linking candidates per entity mention from the input """ - mention_embs = self.emb_search(entity_mentions) - all_scores = mention_embs @ self._precomputed_embeddings.T - indices_top_k = np.argpartition(all_scores, kth=-top_k, axis=1)[:, -top_k:] - mention_numbers = np.tile(np.arange(len(entity_mentions)), (top_k, 1)).T - positions_top_k = np.argsort(all_scores[mention_numbers, indices_top_k], axis=1) - sorted_indices_top_k = indices_top_k[mention_numbers, positions_top_k] + mention_embs = self.embed(entity_mentions) + + scores = mention_embs["dense"] @ self._precomputed_embeddings["dense"].T + + if self.hybrid_search: + query = sparse.csr_matrix(mention_embs["sparse"]) + index = sparse.csr_matrix(self._precomputed_embeddings["sparse"]) + sparse_scores = query.dot(index.T).toarray() + scores += self.sparse_weight * sparse_scores + + num_mentions = scores.shape[0] + unsorted_indices = np.argpartition(scores, -top_k)[:, -top_k:] + unsorted_scores = scores[np.arange(num_mentions)[:, None], unsorted_indices] + sorted_score_matrix_indices = np.argsort(-unsorted_scores) + topk_idxs = unsorted_indices[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] + topk_scores = unsorted_scores[np.arange(num_mentions)[:, None], sorted_score_matrix_indices] results = [] - for i in range(sorted_indices_top_k.shape[0]): - results.append([(self.ids[j], float(all_scores[i, j])) for j in sorted_indices_top_k[i, :]]) + for i in range(num_mentions): + results.append([(self.ids[j], s) for j, s in zip(topk_idxs[i, :], topk_scores[i, :])]) return results @classmethod - def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": + def _from_state(cls, state_dict: Dict[str, Any]) -> "SemanticCandidateSearchIndex": index = cls( - embeddings=cast(List[DocumentEmbeddings], [load_embeddings(emb) for emb in state_dict["embeddings"]]), + embeddings=cast( + Dict[str, DocumentEmbeddings], {k: load_embeddings(emb) for k, emb in state_dict["embeddings"].items()} + ), similarity_metric=SimilarityMetric(state_dict["similarity_metric"]), - weights=state_dict["weights"], + sparse_weight=state_dict["sparse_weight"], batch_size=state_dict["batch_size"], + hybrid_search=state_dict["hybrid_search"], show_progress=state_dict["show_progress"], ) index.ids = state_dict["ids"] @@ -681,11 +743,12 @@ def _from_state(cls, state_dict: Dict[str, Any]) -> "CandidateSearchIndex": def _get_state(self) -> Dict[str, Any]: return { **super()._get_state(), - "embeddings": [emb.save_embeddings() for emb in self.embeddings], + "embeddings": {k: emb.save_embeddings() for k, emb in self.embeddings.items()}, "similarity_metric": self.similarity_metric.value, - "weights": self.weights, + "sparse_weight": self.sparse_weight, "batch_size": self.batch_size, "show_progress": self.show_progress, + "hybrid_search": self.hybrid_search, "ids": self.ids, "precomputed_embeddings": self._precomputed_embeddings, } @@ -717,25 +780,6 @@ def label_type(self): def dictionary(self) -> EntityLinkingDictionary: return self._dictionary - def extract_mentions( - self, - sentences: List[Sentence], - ) -> Tuple[List[Span], List[str]]: - """Unpack all mentions in sentences for batch search.""" - data_points = [] - mentions = [] - - for sentence in sentences: - for entity in sentence.get_labels(self.entity_label_type): - data_points.append(entity.data_point) - mentions.append( - 
self.preprocessor.process_mention(entity, sentence) - if self.preprocessor is not None - else entity.data_point.text, - ) - - return data_points, mentions - def predict( self, sentences: Union[List[Sentence], Sentence], @@ -754,7 +798,17 @@ def predict( if self.preprocessor is not None: self.preprocessor.initialize(sentences) - data_points, mentions = self.extract_mentions(sentences=sentences) + data_points = [] + mentions = [] + + for sentence in sentences: + for entity in sentence.get_labels(self.entity_label_type): + data_points.append(entity.data_point) + mentions.append( + self.preprocessor.process_mention(entity, sentence) + if self.preprocessor is not None + else entity.data_point.text, + ) # no mentions: nothing to do here if len(mentions) > 0: @@ -818,9 +872,8 @@ def build( dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, batch_size: int = 128, - similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, - preprocessor: EntityPreprocessor = BioSynEntityPreprocessor(), - force_hybrid_search: bool = False, + similarity_metric: SimilarityMetric = SimilarityMetric.INNER_PRODUCT, + preprocessor: Optional[EntityPreprocessor] = None, sparse_weight: float = DEFAULT_SPARSE_WEIGHT, entity_type: Optional[str] = None, dictionary: Optional[EntityLinkingDictionary] = None, @@ -843,12 +896,17 @@ def build( model_name_or_path=model_name_or_path, entity_type=entity_type, hybrid_search=hybrid_search, - force_hybrid_search=force_hybrid_search, ) else: assert entity_type is not None, "When using a custom model you must specify `entity_type`" assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! Must be one of: {ENTITY_TYPES}" + preprocessor = ( + preprocessor + if preprocessor is not None + else Ab3PEntityPreprocessor(preprocessor=BioSynEntityPreprocessor()) + ) + if model_name_or_path == "exact-string-match": candidate_generator: CandidateSearchIndex = ExactMatchCandidateSearchIndex() else: @@ -865,7 +923,7 @@ def build( candidate_generator.index(dictionary, preprocessor) logger.info( - "BiomedicalEntityLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, entity_type + "EntityMentionLinker predicts: Dictionary `%s` (entity type: %s)", dictionary_name_or_path, entity_type ) return cls( @@ -881,7 +939,6 @@ def __get_model_path_and_entity_type( model_name_or_path: Union[str, Path], entity_type: Optional[str] = None, hybrid_search: bool = False, - force_hybrid_search: bool = False, ) -> Tuple[Union[str, Path], str]: """Try to figure out what model the user wants.""" if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: @@ -898,36 +955,31 @@ def __get_model_path_and_entity_type( # load model by entity_type if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: model_name_or_path = cast(str, model_name_or_path) + entity_type = model_name_or_path # check if we have a hybrid pre-trained model if model_name_or_path in ENTITY_TYPE_TO_HYBRID_MODEL: - entity_type = model_name_or_path model_name_or_path = ENTITY_TYPE_TO_HYBRID_MODEL[model_name_or_path] else: - # check if user really wants to use hybrid search anyway - if not force_hybrid_search: - logger.warning( - "BiEncoderCandidateGenerator: model for entity type `%s` was not trained for" - " hybrid search: no sparse search will be performed." - " If you want to use sparse search please pass `force_hybrid_search=True`:" - " we will fit a sparse encoder for you. 
The default value of `sparse_weight` is `%s`.", - model_name_or_path, - DEFAULT_SPARSE_WEIGHT, - ) + logger.warning( + "EntityMentionLinker: `hybrid_search=True` but model for entity type `%s` was not trained for hybrid search." + " Results may be poor.", + model_name_or_path, + ) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] else: - if model_name_or_path not in PRETRAINED_HYBRID_MODELS and not force_hybrid_search: + if model_name_or_path not in PRETRAINED_HYBRID_MODELS: logger.warning( - "BiEncoderCandidateGenerator: model `%s` was not trained for hybrid search: no sparse" - " search will be performed." - " If you want to use sparse search please pass `force_hybrid_search=True`:" - " we will fit a sparse encoder for you. The default value of `sparse_weight` is `%s`.", + "EntityMentionLinker: `hybrid_search=True` but model `%s` was not trained for hybrid search." + " Results may be poor.", model_name_or_path, - DEFAULT_SPARSE_WEIGHT, ) - - model_name_or_path = cast(str, model_name_or_path) - entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] + assert ( + entity_type is not None + ), f"For non-hybrid model `{model_name_or_path}` with `hybrid_search=True` you must specify `entity_type`" + else: + model_name_or_path = cast(str, model_name_or_path) + entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] else: if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: @@ -971,6 +1023,12 @@ def __get_dictionary_path( def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: raise NotImplementedError("The EntityLinker cannot be trained") + @classmethod + def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLinker": + from typing import cast + + return cast("EntityMentionLinker", super().load(model_path=model_path)) + def evaluate( self, data_points: Union[List[DT], Dataset], From 1efe956c03586eded24808c6d2339218af888fed Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Thu, 11 Jan 2024 14:12:54 +0100 Subject: [PATCH 40/58] feat(tests): test preprocessing --- tests/test_biomedical_entity_linking.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index fa71b129a..5d2b5e0e3 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,5 +1,7 @@ from flair.data import Sentence from flair.models.entity_mention_linking import ( + Ab3PEntityPreprocessor, + BioSynEntityPreprocessor, EntityMentionLinker, load_dictionary, ) @@ -45,6 +47,36 @@ def test_bel_dictionary(): assert candidate.concept_id.isdigit() +def test_biosyn_preprocessing(): + """Check preprocessing does not produce empty strings.""" + preprocessor = BioSynEntityPreprocessor() + + # NOTE: Avoid emtpy string if mentions are just punctutations (e.g. 
`-` or `(`) + for s in ["-", "(", ")", "9"]: + assert len(preprocessor.process_mention(s)) > 0 + assert len(preprocessor.process_entity_name(s)) > 0 + + +def test_abbrevitation_resolution(): + """Test abbreviation resolution works correctly.""" + preprocessor = Ab3PEntityPreprocessor(preprocessor=BioSynEntityPreprocessor()) + + sentences = [ + Sentence("Features of ARCL type II overlap with those of Wrinkly skin syndrome (WSS)."), + Sentence("Weaver-Smith syndrome (WSS) is a Mendelian disorder of the epigenetic machinery."), + ] + + preprocessor.initialize(sentences) + + mentions = ["WSS", "WSS"] + for idx, (mention, sentence) in enumerate(zip(mentions, sentences)): + mention = preprocessor.process_mention(mention, sentence) + if idx == 0: + assert mention == "wrinkly skin syndrome" + elif idx == 1: + assert mention == "weaver-smith syndrome" + + def test_biomedical_entity_linking(): sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") From 5da7494b362087a3e45a3276a16d4f38a3d097d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Thu, 11 Jan 2024 19:31:43 +0100 Subject: [PATCH 41/58] Minor fix in pre-processing pipeline --- flair/models/entity_mention_linking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index dbb8143a8..7c8270775 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -805,7 +805,7 @@ def predict( for entity in sentence.get_labels(self.entity_label_type): data_points.append(entity.data_point) mentions.append( - self.preprocessor.process_mention(entity, sentence) + self.preprocessor.process_mention(entity.data_point.text, sentence) if self.preprocessor is not None else entity.data_point.text, ) From 4cd86b02f7651cb9178eefff35c3e84e58501289 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Fri, 12 Jan 2024 12:11:57 +0100 Subject: [PATCH 42/58] Add support for label and entity type definition + fix tests --- flair/models/entity_mention_linking.py | 52 ++++++++++++++++++++----- tests/test_biomedical_entity_linking.py | 27 ++++++------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 7c8270775..157a7794d 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -11,7 +11,8 @@ from collections import defaultdict from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, Sequence, Set +from collections.abc import Iterable import numpy as np import torch @@ -761,13 +762,23 @@ def __init__( self, candidate_generator: CandidateSearchIndex, preprocessor: EntityPreprocessor, - entity_label_type: str, + entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], label_type: str, dictionary: EntityLinkingDictionary, ): + """ + Initializes an entity mention linker + + Args: + candidate_generator: Strategy to find matching entities for a given mention + preprocessor: Pre-processing strategy to transform / clean entity mentions + entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. 
to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. To use all labels from 'ner', pass 'ner' + label_type: The label under which the predictions of the linker should be stored + dictionary: The dictionary listing all entities + """ self.preprocessor = preprocessor self.candidate_generator = candidate_generator - self.entity_label_type = entity_label_type + self.entity_label_types = entity_label_types self._label_type = label_type self._dictionary = dictionary super().__init__() @@ -784,17 +795,30 @@ def predict( self, sentences: Union[List[Sentence], Sentence], top_k: int = 1, + pred_label_type: Optional[str] = None, + entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Optional[Set[str]]]]] = None ) -> None: """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. Args: sentences: One or more sentences to run the prediction on top_k: Number of best-matching entity / concept identifiers + entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. To use all labels from 'ner', pass 'ner' + pred_label_type: The label under which the predictions of the linker should be stored """ # make sure sentences is a list of sentences if not isinstance(sentences, list): sentences = [sentences] + # Make sure entity label types are represented as dict + entity_label_types = entity_label_types if entity_label_types is not None else self.entity_label_types + if isinstance(entity_label_types, str): + entity_label_types = {entity_label_types: []} + elif isinstance(entity_label_types, Iterable): + entity_label_types = {label: [] for label in entity_label_types} + + pred_label_type = pred_label_type if pred_label_type is not None else self.label_type + if self.preprocessor is not None: self.preprocessor.initialize(sentences) @@ -802,7 +826,17 @@ def predict( mentions = [] for sentence in sentences: - for entity in sentence.get_labels(self.entity_label_type): + # Collect all entities based on entity type labels configuration + entities = [] + for label_type, entity_types in entity_label_types.items(): + labels = sentence.get_labels(label_type) + if len(entity_types) > 0: + labels = [label for label in labels if label.value in entity_types] + + entities.extend(labels) + + # Preprocess entity mentions + for entity in entities: data_points.append(entity.data_point) mentions.append( self.preprocessor.process_mention(entity.data_point.text, sentence) @@ -818,7 +852,7 @@ def predict( # Add a label annotation for each candidate for data_point, mention_candidates in zip(data_points, candidates): for candidate_id, confidence in mention_candidates: - data_point.add_label(self.label_type, candidate_id, confidence) + data_point.add_label(pred_label_type, candidate_id, confidence) @staticmethod def _fetch_model(model_name: str) -> str: @@ -847,18 +881,18 @@ def _fetch_model(model_name: str) -> str: def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "EntityMentionLinker": candidate_generator = CandidateSearchIndex._from_state(state["candidate_search_index"]) preprocessor = EntityPreprocessor._from_state(state["entity_preprocessor"]) - entity_label_type = state["entity_label_type"] + entity_label_types = 
state["entity_label_types"] label_type = state["label_type"] dictionary = InMemoryEntityLinkingDictionary.from_state(state["dictionary"]) - return cls(candidate_generator, preprocessor, entity_label_type, label_type, dictionary) + return cls(candidate_generator, preprocessor, entity_label_types, label_type, dictionary) def _get_state_dict(self): """Returns the state dictionary for this model.""" return { **super()._get_state_dict(), "label_type": self.label_type, - "entity_label_type": self.entity_label_type, + "entity_label_types": self.entity_label_types, "entity_preprocessor": self.preprocessor._get_state(), "candidate_search_index": self.candidate_generator._get_state(), "dictionary": self.dictionary.to_in_memory_dictionary().to_state(), @@ -929,7 +963,7 @@ def build( return cls( candidate_generator=candidate_generator, preprocessor=preprocessor, - entity_label_type=entity_type, + entity_label_types=entity_type, label_type=label_type, dictionary=dictionary, ) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 5d2b5e0e3..1d02ffc6b 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -74,7 +74,7 @@ def test_abbrevitation_resolution(): if idx == 0: assert mention == "wrinkly skin syndrome" elif idx == 1: - assert mention == "weaver-smith syndrome" + assert mention == "weaver smith syndrome" def test_biomedical_entity_linking(): @@ -83,25 +83,20 @@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = EntityMentionLinker.load("bio-disease") + disease_linker = EntityMentionLinker.load("masaenger/bio-nen-disease") disease_dictionary = disease_linker.dictionary - disease_linker.predict(sentence) + disease_linker.predict(sentence, pred_label_type="disease-nen", entity_label_types="diseases") - gene_linker = EntityMentionLinker.load("bio-gene") + gene_linker = EntityMentionLinker.load("masaenger/bio-nen-gene") gene_dictionary = gene_linker.dictionary - - gene_linker.predict(sentence) + gene_linker.predict(sentence, pred_label_type="gene-nen", entity_label_types="genes") print("Diseases") - for span in sentence.get_spans(disease_linker.entity_label_type): - print(f"Span: {span.text}") - for candidate_label in span.get_labels(disease_linker.label_type): - candidate = disease_dictionary[candidate_label.value] - print(f"Candidate: {candidate.concept_name}") + for label in sentence.get_labels("disease-nen"): + candidate = disease_dictionary[label.value] + print(f"Candidate: {candidate.concept_name}") print("Genes") - for span in sentence.get_spans(gene_linker.entity_label_type): - print(f"Span: {span.text}") - for candidate_label in span.get_labels(gene_linker.label_type): - candidate = gene_dictionary[candidate_label.value] - print(f"Candidate: {candidate.concept_name}") + for label in sentence.get_labels("gene-nen"): + candidate = gene_dictionary[label.value] + print(f"Candidate: {candidate.concept_name}") From aaa5f393a172190e1072a41cf55122d8c7e09833 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Fri, 12 Jan 2024 16:42:11 +0100 Subject: [PATCH 43/58] fix: formatting and type checking --- flair/datasets/entity_linking.py | 3 +- flair/models/entity_mention_linking.py | 46 ++++++++++++------------- tests/test_biomedical_entity_linking.py | 7 +++- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 9c1196797..5e6f55bdf 100644 --- 
a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -8,8 +8,7 @@ import flair from flair.data import Corpus, EntityCandidate, MultiCorpus, Sentence -from flair.datasets.sequence_labeling import (ColumnCorpus, - MultiFileColumnCorpus) +from flair.datasets.sequence_labeling import ColumnCorpus, MultiFileColumnCorpus from flair.file_utils import cached_path, unpack_file from flair.splitter import SegtokSentenceSplitter, SentenceSplitter diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 157a7794d..fe65cee7b 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -9,10 +9,10 @@ import tempfile from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Iterable from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, Sequence, Set -from collections.abc import Iterable +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast import numpy as np import torch @@ -45,14 +45,14 @@ # Dense + sparse retrieval PRETRAINED_HYBRID_MODELS = { - "dmis-lab/biosyn-sapbert-bc5cdr-disease": "disease", - "dmis-lab/biosyn-sapbert-ncbi-disease": "disease", + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "diseases", + "dmis-lab/biosyn-sapbert-ncbi-disease": "diseases", "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "chemical", - "dmis-lab/biosyn-biobert-bc5cdr-disease": "disease", - "dmis-lab/biosyn-biobert-ncbi-disease": "disease", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "diseases", + "dmis-lab/biosyn-biobert-ncbi-disease": "diseases", "dmis-lab/biosyn-biobert-bc5cdr-chemical": "chemical", - "dmis-lab/biosyn-biobert-bc2gn": "gene", - "dmis-lab/biosyn-sapbert-bc2gn": "gene", + "dmis-lab/biosyn-biobert-bc2gn": "genes", + "dmis-lab/biosyn-sapbert-bc2gn": "genes", } # fetched from original repo to avoid download @@ -762,19 +762,18 @@ def __init__( self, candidate_generator: CandidateSearchIndex, preprocessor: EntityPreprocessor, - entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], + entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]], label_type: str, dictionary: EntityLinkingDictionary, ): - """ - Initializes an entity mention linker - - Args: - candidate_generator: Strategy to find matching entities for a given mention - preprocessor: Pre-processing strategy to transform / clean entity mentions - entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. To use all labels from 'ner', pass 'ner' - label_type: The label under which the predictions of the linker should be stored - dictionary: The dictionary listing all entities + """Initializes an entity mention linker. + + Args: + candidate_generator: Strategy to find matching entities for a given mention + preprocessor: Pre-processing strategy to transform / clean entity mentions + entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. 
To use all labels from 'ner', pass 'ner' + label_type: The label under which the predictions of the linker should be stored + dictionary: The dictionary listing all entities """ self.preprocessor = preprocessor self.candidate_generator = candidate_generator @@ -796,7 +795,7 @@ def predict( sentences: Union[List[Sentence], Sentence], top_k: int = 1, pred_label_type: Optional[str] = None, - entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Optional[Set[str]]]]] = None + entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Set[str]]]] = None, ) -> None: """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. @@ -813,9 +812,9 @@ def predict( # Make sure entity label types are represented as dict entity_label_types = entity_label_types if entity_label_types is not None else self.entity_label_types if isinstance(entity_label_types, str): - entity_label_types = {entity_label_types: []} + entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}}) elif isinstance(entity_label_types, Iterable): - entity_label_types = {label: [] for label in entity_label_types} + entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types}) pred_label_type = pred_label_type if pred_label_type is not None else self.label_type @@ -1016,9 +1015,10 @@ def __get_model_path_and_entity_type( entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] else: - if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + if isinstance(model_name_or_path, str): model_name_or_path = cast(str, model_name_or_path) - model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + if model_name_or_path in ENTITY_TYPES: + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] assert ( entity_type is not None diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 1d02ffc6b..52e2412e5 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -89,7 +89,7 @@ def test_biomedical_entity_linking(): gene_linker = EntityMentionLinker.load("masaenger/bio-nen-gene") gene_dictionary = gene_linker.dictionary - gene_linker.predict(sentence, pred_label_type="gene-nen", entity_label_types="genes") + gene_linker.predict(sentence, pred_label_type="gene-nen", entity_label_types="genes") print("Diseases") for label in sentence.get_labels("disease-nen"): @@ -100,3 +100,8 @@ def test_biomedical_entity_linking(): for label in sentence.get_labels("gene-nen"): candidate = gene_dictionary[label.value] print(f"Candidate: {candidate.concept_name}") + + +if __name__ == "__main__": + test_abbrevitation_resolution() + test_biomedical_entity_linking() From def9905d53170b89b99de47f55f52980604f2353 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 21 Jan 2024 20:48:56 +0100 Subject: [PATCH 44/58] add batchsize to prediction instead of only embedding to reduce memory usage --- flair/datasets/entity_linking.py | 3 +++ flair/models/entity_mention_linking.py | 27 ++++++++++++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 5e6f55bdf..bc9744461 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -64,6 +64,9 @@ def candidates(self) -> List[EntityCandidate]: def __getitem__(self, item: str) -> EntityCandidate: return self._idx_to_candidates[item] + def 
__contains__(self, item: str) -> bool: + return item in self._idx_to_candidates + def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": return InMemoryEntityLinkingDictionary(list(self._idx_to_candidates.values()), self._dataset_name) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index fe65cee7b..4eb1eb35b 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -215,7 +215,7 @@ def __init__(self, lowercase: bool = True, remove_punctuation: bool = True): self.rmv_puncts_regex = re.compile(rf"[\s{re.escape(string.punctuation)}]+") def process_entity_name(self, entity_name: str) -> str: - original = copy.deepcopy(entity_name) + original = entity_name if self.lowercase: entity_name = entity_name.lower() @@ -277,7 +277,7 @@ def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = No sentence is not None ), "Ab3P requires the sentence where `entity_mention` was found for abbreviation resolution" - original = copy.deepcopy(entity_mention) + original = entity_mention sentence_text = sentence.to_original_text() @@ -672,11 +672,11 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: batch = inputs[start:end] self.embeddings["dense"].embed(batch) for sent in batch: - emb = sent.get_embedding() + emb = sent.get_embedding(self.embeddings["dense"].get_names()) if self.similarity_metric == SimilarityMetric.COSINE: emb = emb / torch.norm(emb) query_embeddings["dense"].append(emb.cpu().numpy()) - sent.clear_embeddings() + sent.clear_embeddings(self.embeddings["dense"].get_names()) if flair.device.type == "cuda": torch.cuda.empty_cache() @@ -684,11 +684,11 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: query_embeddings["sparse"] = [] self.embeddings["sparse"].embed(inputs) for sent in inputs: - sparse_emb = sent.get_embedding() + sparse_emb = sent.get_embedding(self.embeddings["sparse"].get_names()) if self.similarity_metric == SimilarityMetric.COSINE: sparse_emb = sparse_emb / torch.norm(sparse_emb) query_embeddings["sparse"].append(sparse_emb.cpu().numpy()) - sent.clear_embeddings() + sent.clear_embeddings(self.embeddings["sparse"].get_names()) return {k: np.stack(v, axis=0) for k, v in query_embeddings.items()} @@ -765,6 +765,7 @@ def __init__( entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]], label_type: str, dictionary: EntityLinkingDictionary, + batch_size: int = 1024, ): """Initializes an entity mention linker. 
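For orientation, here is a usage sketch of the API that patches 42-44 converge on, mirroring the updated test above. The model identifier "masaenger/bio-nen-disease" and the label-type names are taken from that test and are assumptions that may change before release.

from flair.data import Sentence
from flair.models.entity_mention_linking import EntityMentionLinker
from flair.nn import Classifier

sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")

# tag mentions first, then link only the disease spans
ner_tagger = Classifier.load("hunflair")
ner_tagger.predict(sentence)

linker = EntityMentionLinker.load("masaenger/bio-nen-disease")
linker.predict(sentence, pred_label_type="disease-nen", entity_label_types="diseases", top_k=3)

for label in sentence.get_labels("disease-nen"):
    candidate = linker.dictionary[label.value]
    print(label.data_point.text, "->", candidate.concept_id, candidate.concept_name)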
@@ -780,6 +781,7 @@ def __init__( self.entity_label_types = entity_label_types self._label_type = label_type self._dictionary = dictionary + self.batch_size = batch_size super().__init__() @property @@ -843,13 +845,12 @@ def predict( else entity.data_point.text, ) - # no mentions: nothing to do here - if len(mentions) > 0: - # Retrieve top-k concept / entity candidates - candidates = self.candidate_generator.search(entity_mentions=mentions, top_k=top_k) + # Retrieve top-k concept / entity candidates + for i in range(0, len(mentions), self.batch_size): + candidates = self.candidate_generator.search(entity_mentions=mentions[i : i + self.batch_size], top_k=top_k) # Add a label annotation for each candidate - for data_point, mention_candidates in zip(data_points, candidates): + for data_point, mention_candidates in zip(data_points[i : i + self.batch_size], candidates): for candidate_id, confidence in mention_candidates: data_point.add_label(pred_label_type, candidate_id, confidence) @@ -883,8 +884,9 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "Entity entity_label_types = state["entity_label_types"] label_type = state["label_type"] dictionary = InMemoryEntityLinkingDictionary.from_state(state["dictionary"]) + batch_size = state.get("batch_size", 128) - return cls(candidate_generator, preprocessor, entity_label_types, label_type, dictionary) + return cls(candidate_generator, preprocessor, entity_label_types, label_type, dictionary, batch_size=batch_size) def _get_state_dict(self): """Returns the state dictionary for this model.""" @@ -895,6 +897,7 @@ def _get_state_dict(self): "entity_preprocessor": self.preprocessor._get_state(), "candidate_search_index": self.candidate_generator._get_state(), "dictionary": self.dictionary.to_in_memory_dictionary().to_state(), + "batch_size": self.batch_size, } @classmethod From 03092c298afe3032f62c2d574d38394cc057bf64 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 21 Jan 2024 22:38:02 +0100 Subject: [PATCH 45/58] add evaluation function --- flair/models/entity_mention_linking.py | 49 +++++++++++++++++++++----- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 4eb1eb35b..196c8284d 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -22,7 +22,7 @@ import flair from flair.class_utils import get_state_subclass_by_name -from flair.data import DT, Dictionary, Sentence +from flair.data import DT, Dictionary, Sentence, _iter_dataset from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, @@ -755,7 +755,7 @@ def _get_state(self) -> Dict[str, Any]: } -class EntityMentionLinker(flair.nn.Model): +class EntityMentionLinker(flair.nn.Model[Sentence]): """Entity linking model for the biomedical domain.""" def __init__( @@ -798,6 +798,7 @@ def predict( top_k: int = 1, pred_label_type: Optional[str] = None, entity_label_types: Optional[Union[str, Sequence[str], Dict[str, Set[str]]]] = None, + batch_size: Optional[int] = None, ) -> None: """Predicts the best matching top-k entity / concept identifiers of all named entities annotated with tag input_entity_annotation_layer. 
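The evaluate() method added in the following hunks reports a simple accuracy@k: a gold-annotated mention counts as a hit if any of its top-k predicted identifiers matches a gold identifier. A minimal toy illustration of that computation (all identifiers below are invented):

# toy gold and top-k predicted identifier sets, one entry per mention
gold = [{"MESH:D005600"}, {"NCBI:2332"}, {"MESH:D012345"}]
predicted = [{"MESH:D005600", "MESH:D003711"}, {"NCBI:9999"}, set()]

hits = sum(1 for g, p in zip(gold, predicted) if g & p)
accuracy_at_k = hits / len(gold)
print(f"Accuracy@k: {accuracy_at_k:0.2%}")  # Accuracy@k: 33.33%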
@@ -810,6 +811,8 @@ def predict( # make sure sentences is a list of sentences if not isinstance(sentences, list): sentences = [sentences] + if batch_size is None: + batch_size = self.batch_size # Make sure entity label types are represented as dict entity_label_types = entity_label_types if entity_label_types is not None else self.entity_label_types @@ -846,11 +849,11 @@ def predict( ) # Retrieve top-k concept / entity candidates - for i in range(0, len(mentions), self.batch_size): - candidates = self.candidate_generator.search(entity_mentions=mentions[i : i + self.batch_size], top_k=top_k) + for i in range(0, len(mentions), batch_size): + candidates = self.candidate_generator.search(entity_mentions=mentions[i : i + batch_size], top_k=top_k) # Add a label annotation for each candidate - for data_point, mention_candidates in zip(data_points[i : i + self.batch_size], candidates): + for data_point, mention_candidates in zip(data_points[i : i + batch_size], candidates): for candidate_id, confidence in mention_candidates: data_point.add_label(pred_label_type, candidate_id, confidence) @@ -1068,15 +1071,45 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "EntityMentionLin def evaluate( self, - data_points: Union[List[DT], Dataset], + data_points: Union[List[Sentence], Dataset], gold_label_type: str, out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: str = "none", mini_batch_size: int = 32, - main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: Tuple[str, str] = ("accuracy", "f1-score"), exclude_labels: List[str] = [], gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, + k: int = 1, **kwargs, ) -> Result: - raise NotImplementedError("Evaluation is currently not implemented for EntityLinking") + if gold_label_dictionary is not None: + raise NotImplementedError("evaluating an EntityMentionLinker with a gold_label_dictionary is not supported") + + if isinstance(data_points, Dataset): + data_points = list(_iter_dataset(data_points)) + + self.predict(data_points, top_k=k, pred_label_type="predicted", entity_label_types=gold_label_type, batch_size=mini_batch_size) + + hits = 0 + total = 0 + for sentence in data_points: + spans = sentence.get_spans(gold_label_type) + for span in spans: + exps = set(exp.value for exp in span.get_labels(gold_label_type) if exp.value not in exclude_labels) + + predictions = set(pred.value for pred in span.get_labels("predicted")) + total += 1 + if exps & predictions: + hits += 1 + sentence.remove_labels("predicted") + accuracy = hits / total + + detailed_results = f"Accuracy@{k}: {accuracy:0.2%}" + scores = {"accuracy": accuracy, f"accuracy@{k}": accuracy, "loss": 0.0} + + return Result( + main_score=accuracy, + detailed_results=detailed_results, + scores=scores + ) From bf88d0a862b1abcd48980e978876a32f4d170eda Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Fri, 26 Jan 2024 18:36:26 +0100 Subject: [PATCH 46/58] fix(linker): extraction of entity mentions in predict - normalize entity types: diseases->disease, genes-gene - predict: compatibility with Classifier.load('hunflair').label_type --- flair/models/entity_mention_linking.py | 150 ++++++++++++------ .../HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md | 24 +-- 2 files changed, 119 insertions(+), 55 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 196c8284d..79ee011a3 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py 
@@ -1,4 +1,3 @@ -import copy import inspect import logging import os @@ -9,7 +8,6 @@ import tempfile from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Iterable from enum import Enum, auto from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast @@ -22,7 +20,7 @@ import flair from flair.class_utils import get_state_subclass_by_name -from flair.data import DT, Dictionary, Sentence, _iter_dataset +from flair.data import DT, Dictionary, Label, Sentence, _iter_dataset from flair.datasets import ( CTD_CHEMICALS_DICTIONARY, CTD_DISEASES_DICTIONARY, @@ -45,14 +43,14 @@ # Dense + sparse retrieval PRETRAINED_HYBRID_MODELS = { - "dmis-lab/biosyn-sapbert-bc5cdr-disease": "diseases", - "dmis-lab/biosyn-sapbert-ncbi-disease": "diseases", + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-sapbert-ncbi-disease": "disease", "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "chemical", - "dmis-lab/biosyn-biobert-bc5cdr-disease": "diseases", - "dmis-lab/biosyn-biobert-ncbi-disease": "diseases", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "disease", + "dmis-lab/biosyn-biobert-ncbi-disease": "disease", "dmis-lab/biosyn-biobert-bc5cdr-chemical": "chemical", - "dmis-lab/biosyn-biobert-bc2gn": "genes", - "dmis-lab/biosyn-sapbert-bc2gn": "genes", + "dmis-lab/biosyn-biobert-bc2gn": "gene", + "dmis-lab/biosyn-sapbert-bc2gn": "gene", } # fetched from original repo to avoid download @@ -74,12 +72,12 @@ MODELS = PRETRAINED_MODELS + STRING_MATCHING_MODELS -ENTITY_TYPES = ["diseases", "chemical", "genes", "species"] +ENTITY_TYPES = ["disease", "chemical", "gene", "species"] ENTITY_TYPE_TO_HYBRID_MODEL = { - "diseases": "dmis-lab/biosyn-sapbert-bc5cdr-disease", + "disease": "dmis-lab/biosyn-sapbert-bc5cdr-disease", "chemical": "dmis-lab/biosyn-sapbert-bc5cdr-chemical", - "genes": "dmis-lab/biosyn-sapbert-bc2gn", + "gene": "dmis-lab/biosyn-sapbert-bc2gn", } # for now we always fall back to SapBERT, @@ -89,9 +87,9 @@ } ENTITY_TYPE_TO_DICTIONARY = { - "genes": "ncbi-gene", + "gene": "ncbi-gene", "species": "ncbi-taxonomy", - "diseases": "ctd-diseases", + "disease": "ctd-diseases", "chemical": "ctd-chemicals", } @@ -103,13 +101,13 @@ } MODEL_NAME_TO_DICTIONARY = { - "dmis-lab/biosyn-sapbert-bc5cdr-disease": "ctd-disease", - "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-disease", - "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-sapbert-bc5cdr-disease": "ctd-diseases", + "dmis-lab/biosyn-sapbert-ncbi-disease": "ctd-diseases", + "dmis-lab/biosyn-sapbert-bc5cdr-chemical": "ctd-chemicals", "dmis-lab/biosyn-sapbert-bc2gn": "ncbi-gene", - "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemical", - "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-disease", - "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemical", + "dmis-lab/biosyn-biobert-bc5cdr-disease": "ctd-chemicals", + "dmis-lab/biosyn-biobert-ncbi-disease": "ctd-diseases", + "dmis-lab/biosyn-biobert-bc5cdr-chemical": "ctd-chemicals", "dmis-lab/biosyn-biobert-bc2gn": "ncbi-gene", } @@ -126,6 +124,18 @@ class SimilarityMetric(Enum): PRETRAINED_MODEL_TO_SIMILARITY_METRIC = {m: SimilarityMetric.INNER_PRODUCT for m in PRETRAINED_MODELS} +def normalize_entity_type(entity_type: str) -> str: + """Normalize entity type to ease interoperability.""" + entity_type = entity_type.lower() + + if entity_type == "diseases": + entity_type = "disease" + elif entity_type == "genes": + entity_type = "gene" + + return entity_type + + def 
load_dictionary( dictionary_name_or_path: Union[Path, str], dataset_name: Optional[str] = None ) -> EntityLinkingDictionary: @@ -435,7 +445,7 @@ def _get_state(self) -> Dict[str, Any]: @classmethod def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": return cls( - ab3p_path=Path(state_dict["ad3p_path"]), + ab3p_path=Path(state_dict["ab3p_path"]), word_data_dir=Path(state_dict["word_data_dir"]), preprocessor=None if state_dict["preprocessor"] is None @@ -772,18 +782,45 @@ def __init__( Args: candidate_generator: Strategy to find matching entities for a given mention preprocessor: Pre-processing strategy to transform / clean entity mentions - entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. To use all labels from 'ner', pass 'ner' + entity_label_types: A label type or sequence of label types of the required entities. + You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. + E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. + To use all labels from 'ner', pass 'ner' label_type: The label under which the predictions of the linker should be stored dictionary: The dictionary listing all entities + batch_size: Batch size to encode mentions/dictionary names """ self.preprocessor = preprocessor self.candidate_generator = candidate_generator - self.entity_label_types = entity_label_types + self.entity_label_types = self.get_entity_label_types(entity_label_types) self._label_type = label_type self._dictionary = dictionary self.batch_size = batch_size super().__init__() + def get_entity_label_types( + self, entity_label_types: Union[str, Sequence[str], Dict[str, Set[str]]] + ) -> Dict[str, Set[str]]: + """Find out what NER labels to extract from sentence. + + Args: + entity_label_types: A label type or sequence of label types of the required entities. + You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. + E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. 
+ To use all labels from 'ner', pass 'ner' + """ + if isinstance(entity_label_types, str): + entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}}) + elif isinstance(entity_label_types, Sequence): + entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types}) + + entity_label_types = { + label: {normalize_entity_type(e) for e in entity_types} + for label, entity_types in entity_label_types.items() + } + + return entity_label_types + @property def label_type(self): return self._label_type @@ -792,6 +829,28 @@ def label_type(self): def dictionary(self) -> EntityLinkingDictionary: return self._dictionary + def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict[str, Set[str]]) -> List[Label]: + """Extract tagged mentions from sentences.""" + entities_mentions: List[Label] = [] + + if all(len(sentence.get_labels(lt)) == 0 for lt in entity_label_types): + # TODO: This is a hacky workaround for the fact that + # `Classifier.load('hunflair)` `label_type='diseases'`, + # Remove once the new unified (i.e multi-entity-type) NER is merged + # See: https://github.com/flairNLP/flair/pull/3387 + entity_types = {e for sublist in entity_label_types.values() for e in sublist} + entities_mentions = [ + label for label in sentence.get_labels() if normalize_entity_type(label.value) in entity_types + ] + else: + for label_type, entity_types in entity_label_types.items(): + labels = sentence.get_labels(label_type) + if len(entity_types) > 0: + labels = [label for label in labels if normalize_entity_type(label.value) in entity_types] + entities_mentions.extend(labels) + + return entities_mentions + def predict( self, sentences: Union[List[Sentence], Sentence], @@ -805,8 +864,12 @@ def predict( Args: sentences: One or more sentences to run the prediction on top_k: Number of best-matching entity / concept identifiers - entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. To use all labels from 'ner', pass 'ner' + entity_label_types: A label type or sequence of label types of the required entities. + You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. + E.g. to use only 'disease' and 'chemical' labels from a NER-tagger: `{'ner': {'disease', 'chemical'}}`. 
+ To use all labels from 'ner', pass 'ner' pred_label_type: The label under which the predictions of the linker should be stored + batch_size: Batch size to encode mentions/dictionary names """ # make sure sentences is a list of sentences if not isinstance(sentences, list): @@ -815,11 +878,11 @@ def predict( batch_size = self.batch_size # Make sure entity label types are represented as dict - entity_label_types = entity_label_types if entity_label_types is not None else self.entity_label_types - if isinstance(entity_label_types, str): - entity_label_types = cast(Dict[str, Set[str]], {entity_label_types: {}}) - elif isinstance(entity_label_types, Iterable): - entity_label_types = cast(Dict[str, Set[str]], {label: {} for label in entity_label_types}) + entity_label_types = ( + self.get_entity_label_types(entity_label_types) + if entity_label_types is not None + else self.entity_label_types + ) pred_label_type = pred_label_type if pred_label_type is not None else self.label_type @@ -831,16 +894,10 @@ def predict( for sentence in sentences: # Collect all entities based on entity type labels configuration - entities = [] - for label_type, entity_types in entity_label_types.items(): - labels = sentence.get_labels(label_type) - if len(entity_types) > 0: - labels = [label for label in labels if label.value in entity_types] - - entities.extend(labels) + entities_mentions = self.extract_entities_mentions(sentence, entity_label_types) # Preprocess entity mentions - for entity in entities: + for entity in entities_mentions: data_points.append(entity.data_point) mentions.append( self.preprocessor.process_mention(entity.data_point.text, sentence) @@ -907,7 +964,7 @@ def _get_state_dict(self): def build( cls, model_name_or_path: Union[str, Path], - label_type: str, + label_type: str = "link", dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, batch_size: int = 128, @@ -968,7 +1025,7 @@ def build( return cls( candidate_generator=candidate_generator, preprocessor=preprocessor, - entity_label_types=entity_type, + entity_label_types={"ner": {entity_type}}, label_type=label_type, dictionary=dictionary, ) @@ -989,6 +1046,7 @@ def __get_model_path_and_entity_type( if model_name_or_path == "cambridgeltl/SapBERT-from-PubMedBERT-fulltext": assert entity_type is not None, f"For model {model_name_or_path} you must specify `entity_type`" + entity_type = normalize_entity_type(entity_type) if hybrid_search: # load model by entity_type @@ -1089,7 +1147,13 @@ def evaluate( if isinstance(data_points, Dataset): data_points = list(_iter_dataset(data_points)) - self.predict(data_points, top_k=k, pred_label_type="predicted", entity_label_types=gold_label_type, batch_size=mini_batch_size) + self.predict( + data_points, + top_k=k, + pred_label_type="predicted", + entity_label_types=gold_label_type, + batch_size=mini_batch_size, + ) hits = 0 total = 0 @@ -1108,8 +1172,4 @@ def evaluate( detailed_results = f"Accuracy@{k}: {accuracy:0.2%}" scores = {"accuracy": accuracy, f"accuracy@{k}": accuracy, "loss": 0.0} - return Result( - main_score=accuracy, - detailed_results=detailed_results, - scores=scores - ) + return Result(main_score=accuracy, detailed_results=detailed_results, scores=scores) diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index 6e7d9790c..e7cf79499 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -18,23 +18,25 @@ sentence = 
Sentence( ner_tagger = Classifier.load("hunflair") ner_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.build("disease") +nen_tagger = EntityMentionLinker.load("disease-linker") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.build("gene") +nen_tagger = EntityMentionLinker.load("gene-linker") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.build("chemical") +nen_tagger = EntityMentionLinker.load("chemical-linker") nen_tagger.predict(sentence) -nen_tagger = EntityMentionLinker.build("species", entity_type="species") +nen_tagger = EntityMentionLinker.load("species-linker") nen_tagger.predict(sentence) for tag in sentence.get_labels(): print(tag) ``` + This should print: -~~~ + +``` Span[4:5]: "ABCD1" → Gene (0.9575) Span[4:5]: "ABCD1" → abcd1 - NCBI-GENE-HUMAN:215 (14.5503) Span[7:11]: "X-linked adrenoleukodystrophy" → Disease (0.9867) @@ -45,10 +47,11 @@ Span[25:26]: "mercury" → Chemical (0.9456) Span[25:26]: "mercury" → mercury - CTD-CHEMICALS:MESH:D008628 (14.9185) Span[27:28]: "dolphin" → Species (0.8082) Span[27:28]: "dolphin" → marine dolphins - NCBI-TAXONOMY:9726 (14.473) -~~~ -The output contains both the NER disease annotations and their entity / concept identifiers according to -a knowledge base or ontology. We have pre-configured combinations of models and dictionaries for -"disease", "chemical" and "gene". +``` + +The output contains both the NER disease annotations and their entity / concept identifiers according to +a knowledge base or ontology. We have pre-configured combinations of models and dictionaries for +"disease", "chemical" and "gene". You can also provide your own model and dictionary: @@ -58,5 +61,6 @@ from flair.models.biomedical_entity_linking import EntityMentionLinker nen_tagger = EntityMentionLinker.build("name_or_path_to_your_model", dictionary_names_or_path="name_or_path_to_your_dictionary") nen_tagger = EntityMentionLinker.build("path_to_custom_disease_model", dictionary_names_or_path="disease") -```` +``` + You can use any combination of provided models, provided dictionaries and your own. 
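For reference, the predictions added by `EntityMentionLinker.predict` are stored as labels under the linker's `label_type`, so they can be read back from each tagged span. A minimal sketch of the end-to-end flow introduced in this patch; it assumes the `flair.nn.Classifier` and `flair.models.entity_mention_linking.EntityMentionLinker` import paths used elsewhere in this series:

```python
from flair.data import Sentence
from flair.nn import Classifier
from flair.models.entity_mention_linking import EntityMentionLinker

# Tag entity mentions first, then link them to knowledge-base identifiers.
sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")

ner_tagger = Classifier.load("hunflair")
ner_tagger.predict(sentence)

linker = EntityMentionLinker.load("disease-linker")
linker.predict(sentence)

# Each span now carries its NER label plus one label per linked concept identifier.
for span in sentence.get_spans():
    for label in span.get_labels(linker.label_type):
        print(span.text, label.value, label.score)
```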
From 7b824a1a3d03ac02a2d374a4c03cb67e8c74a62a Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Mon, 29 Jan 2024 13:06:37 +0100 Subject: [PATCH 47/58] fix(predict): ensure mentions extraction works with legacy classifier - fix(preprocessing): rm path from a3bp-preprocessor state --- flair/models/entity_mention_linking.py | 40 ++++++++++++-------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 79ee011a3..0de63152a 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -437,16 +437,12 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict def _get_state(self) -> Dict[str, Any]: return { **super()._get_state(), - "ab3p_path": str(self.ab3p_path), - "word_data_dir": str(self.word_data_dir), "preprocessor": None if self.preprocessor is None else self.preprocessor._get_state(), } @classmethod def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": return cls( - ab3p_path=Path(state_dict["ab3p_path"]), - word_data_dir=Path(state_dict["word_data_dir"]), preprocessor=None if state_dict["preprocessor"] is None else EntityPreprocessor._from_state(state_dict["preprocessor"]), @@ -796,6 +792,7 @@ def __init__( self._label_type = label_type self._dictionary = dictionary self.batch_size = batch_size + self._warned_legacy_sequence_tagger = False super().__init__() def get_entity_label_types( @@ -833,11 +830,17 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict """Extract tagged mentions from sentences.""" entities_mentions: List[Label] = [] - if all(len(sentence.get_labels(lt)) == 0 for lt in entity_label_types): - # TODO: This is a hacky workaround for the fact that - # `Classifier.load('hunflair)` `label_type='diseases'`, - # Remove once the new unified (i.e multi-entity-type) NER is merged - # See: https://github.com/flairNLP/flair/pull/3387 + # NOTE: This is a hacky workaround for the fact that + # the `label_type`s in `Classifier.load('hunflair)` are + # 'diseases', 'genes', 'species', 'chemical' instead of 'ner'. + # We warn users once they need to update SequenceTagger model + # See: https://github.com/flairNLP/flair/pull/3387 + if any(label in ["diseases", "genes", "species", "chemical"] for label in sentence.annotation_layers): + if not self._warned_legacy_sequence_tagger: + logger.warn( + "The tagger `Classifier.load('hunflair') is deprecated. Please update to: `Classifier.load('hunflair2')`." 
+ ) + self._warned_legacy_sequence_tagger = True entity_types = {e for sublist in entity_label_types.values() for e in sublist} entities_mentions = [ label for label in sentence.get_labels() if normalize_entity_type(label.value) in entity_types @@ -919,17 +922,12 @@ def _fetch_model(model_name: str) -> str: if Path(model_name).exists(): return model_name - bio_base_repo = "helpmefindaname" - + bio_base_repo = "hunflair" hf_model_map = { - "bio-gene": f"{bio_base_repo}/flair-eml-sapbert-bc2gn-gene", - "bio-disease": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-disease", - "bio-chemical": f"{bio_base_repo}/flair-eml-sapbert-bc5cdr-chemical", - "bio-species": f"{bio_base_repo}/flair-eml-species-exact-match", - "bio-gene-exact-match": f"{bio_base_repo}/flair-eml-gene-exact-match", - "bio-disease-exact-match": f"{bio_base_repo}/flair-eml-disease-exact-match", - "bio-chemical-exact-match": f"{bio_base_repo}/flair-eml-chemical-exact-match", - "bio-species-exact-match": f"{bio_base_repo}/flair-eml-species-exact-match", + "gene-linker": f"{bio_base_repo}/biosyn-sapbert-bc2gn", + "disease-linker": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-disease", + "chemical-linker": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-chemical", + "species-linker": f"{bio_base_repo}/sapbert-ncbi-taxonomy", } if model_name in hf_model_map: @@ -1040,7 +1038,6 @@ def __get_model_path_and_entity_type( if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( f"Unknown model `{model_name_or_path}`!" - f" Available entity types are: {ENTITY_TYPES}" " If you want to pass a local path please use the `Path` class, i.e. `model_name_or_path=Path(my_path)`" ) @@ -1113,7 +1110,8 @@ def __get_dictionary_path( dictionary_name_or_path = ENTITY_TYPE_TO_DICTIONARY[model_name_or_path] else: raise ValueError( - f"When using a custom model you need to specify a dictionary. Available options are: {ENTITY_TYPES}. Or provide a path to a dictionary file." + f"When using a custom model you need to specify a dictionary. Available options are: {list(ENTITY_TYPE_TO_DICTIONARY.values())}. " + "Or provide a path to a dictionary file." 
) return dictionary_name_or_path From c8cc120a9e3f854b69da8dc500a2bec2e347332d Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Mon, 29 Jan 2024 13:06:43 +0100 Subject: [PATCH 48/58] chore: update tests --- tests/test_biomedical_entity_linking.py | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 52e2412e5..5fe115fd5 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -78,30 +78,30 @@ def test_abbrevitation_resolution(): def test_biomedical_entity_linking(): - sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in dolphin populations.", + ) tagger = Classifier.load("hunflair") tagger.predict(sentence) - disease_linker = EntityMentionLinker.load("masaenger/bio-nen-disease") - disease_dictionary = disease_linker.dictionary - disease_linker.predict(sentence, pred_label_type="disease-nen", entity_label_types="diseases") + for entity_type in ["disease", "chemical", "gene", "species"]: + linker = EntityMentionLinker.load(f"{entity_type}-linker") + linker.predict(sentence) + + for span in sentence.get_spans(): + print(span) - gene_linker = EntityMentionLinker.load("masaenger/bio-nen-gene") - gene_dictionary = gene_linker.dictionary - gene_linker.predict(sentence, pred_label_type="gene-nen", entity_label_types="genes") - print("Diseases") - for label in sentence.get_labels("disease-nen"): - candidate = disease_dictionary[label.value] - print(f"Candidate: {candidate.concept_name}") +def test_legacy_sequence_tagger(): + sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") - print("Genes") - for label in sentence.get_labels("gene-nen"): - candidate = gene_dictionary[label.value] - print(f"Candidate: {candidate.concept_name}") + legacy_tagger = Classifier.load("hunflair") + legacy_tagger.predict(sentence) + disease_linker = EntityMentionLinker.load("hunflair/biosyn-sapbert-ncbi-disease") + disease_linker.predict(sentence, pred_label_type="disease-nen") -if __name__ == "__main__": - test_abbrevitation_resolution() - test_biomedical_entity_linking() + assert disease_linker._warned_legacy_sequence_tagger From 94aaca1066c2a20c1a2eec61fb6cd5e0eb2684d5 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Mon, 29 Jan 2024 15:05:45 +0100 Subject: [PATCH 49/58] fix(tests): normalized entity type name --- tests/test_biomedical_entity_linking.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 5fe115fd5..bcccead99 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -14,7 +14,7 @@ def test_bel_dictionary(): Hard to define a good test as dictionaries are DYNAMIC, i.e. they can change over time. 
""" - dictionary = load_dictionary("diseases") + dictionary = load_dictionary("disease") candidate = dictionary.candidates[0] assert candidate.concept_id.startswith(("MESH:", "OMIM:", "DO:DOID")) @@ -42,7 +42,7 @@ def test_bel_dictionary(): candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() - dictionary = load_dictionary("genes") + dictionary = load_dictionary("gene") candidate = dictionary.candidates[0] assert candidate.concept_id.isdigit() @@ -105,3 +105,7 @@ def test_legacy_sequence_tagger(): disease_linker.predict(sentence, pred_label_type="disease-nen") assert disease_linker._warned_legacy_sequence_tagger + + +if __name__ == "__main__": + test_bel_dictionary() From 75ea402bbd47ab28f4581ba3ade9e3d2f2b094a9 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Mon, 29 Jan 2024 15:06:12 +0100 Subject: [PATCH 50/58] fix(logging): deprecated logger.warn --- flair/models/entity_mention_linking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 0de63152a..a739b6cc2 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -837,7 +837,7 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict # See: https://github.com/flairNLP/flair/pull/3387 if any(label in ["diseases", "genes", "species", "chemical"] for label in sentence.annotation_layers): if not self._warned_legacy_sequence_tagger: - logger.warn( + logger.warning( "The tagger `Classifier.load('hunflair') is deprecated. Please update to: `Classifier.load('hunflair2')`." ) self._warned_legacy_sequence_tagger = True From d9805befda2f925120fcddc20e5769d6ae098885 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 12 Jan 2024 13:58:33 +0100 Subject: [PATCH 51/58] improve typing and run black --- flair/datasets/entity_linking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index bc9744461..330785d34 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -41,7 +41,7 @@ def __init__( self._idx_to_candidates = {candidate.concept_id: candidate for candidate in candidates} # one name can map to multiple concepts - self._text_to_index: Dict[str, List] = {} + self._text_to_index: Dict[str, List[str]] = {} for candidate in candidates: for text in [candidate.concept_name, *candidate.synonyms]: if text not in self._text_to_index: @@ -54,7 +54,7 @@ def database_name(self) -> str: return self._dataset_name @property - def text_to_index(self) -> Dict[str, List]: + def text_to_index(self) -> Dict[str, List[str]]: return self._text_to_index @property From 8ab67d2bbbcdd275368ab5d362d49c662a4a8c3e Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 2 Feb 2024 16:05:26 +0100 Subject: [PATCH 52/58] add datasets for nel-bioner evaluation --- flair/datasets/entity_linking.py | 402 ++++++++++++++++++++++++- flair/models/entity_mention_linking.py | 5 +- requirements.txt | 3 +- 3 files changed, 405 insertions(+), 5 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 330785d34..91b4c8315 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -1,13 +1,18 @@ +import abc +import bisect import csv import logging import os +import re from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Union import requests +from bioc import 
biocxml, pubtator import flair -from flair.data import Corpus, EntityCandidate, MultiCorpus, Sentence +from flair.data import Corpus, EntityCandidate, MultiCorpus, Sentence, Token +from flair.datasets import FlairDatapointDataset from flair.datasets.sequence_labeling import ColumnCorpus, MultiFileColumnCorpus from flair.file_utils import cached_path, unpack_file from flair.splitter import SegtokSentenceSplitter, SentenceSplitter @@ -2205,3 +2210,398 @@ def __init__( banned_sentences=banned_sentences, sample_missing_splits=sample_missing_splits, ) + + +class BigbioCorpus(Corpus, abc.ABC): + def __init__( + self, + base_path: Optional[Union[str, Path]] = None, + label_type: str = "el", + norm_keys: List[str] = ["db_name", "db_id"], + **kwargs, + ) -> None: + self.label_type = label_type + self.norm_keys = norm_keys + base_path = flair.cache_root / "datasets" if not base_path else Path(base_path) + + dataset_name = self.__class__.__name__.lower() + + data_folder = base_path / dataset_name + paths = self._download_dataset(data_folder) + + super().__init__( + train=self._files_to_dataset(paths["train"]) if "train" in paths else None, + dev=self._files_to_dataset(paths["dev"]) if "dev" in paths else None, + test=self._files_to_dataset(paths["test"]) if "test" in paths else None, + **kwargs, + ) + + @abc.abstractmethod + def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + pass + + @abc.abstractmethod + def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + pass + + def _dict_to_sentences(self, entry: Dict[str, Any]) -> List[Sentence]: + entities = [entity for entity in entry["entities"] if entity["normalized"]] + + tokenized_passages = [ + Sentence(passage["text"][0], start_position=passage["offsets"][0][0]) for passage in entry["passages"] + ] + start_ids = [ + sentence.start_position + token.start_position for sentence in tokenized_passages for token in sentence + ] + end_ids = [ + sentence.start_position + token.end_position for sentence in tokenized_passages for token in sentence + ] + + for entity in entities: + for start, end in entity["offsets"]: + if start not in start_ids: + assert start not in end_ids + start_ids.append(start) + end_ids.append(start) + if end not in end_ids: + assert end not in start_ids + end_ids.append(end) + start_ids.append(end) + start_ids.sort() + end_ids.sort() + passage_sentences = [] + n_tokens = len(start_ids) + for passage in entry["passages"]: + token_offset = passage["offsets"][0][0] + start_idx = bisect.bisect_left(start_ids, token_offset) + end_idx = bisect.bisect_right(end_ids, passage["offsets"][0][1]) + offsets = zip(start_ids[start_idx:end_idx], end_ids[start_idx:end_idx]) + passage_tokens = [ + Token(passage["text"][0][start - token_offset : end - token_offset]) for start, end in offsets + ] + for i, idx in enumerate(range(start_idx, end_idx)): + if idx + 1 < n_tokens: + passage_tokens[i].whitespace_after = start_ids[idx + 1] - end_ids[idx] + passage_sentences.append(Sentence(passage_tokens, start_position=token_offset)) + for token, start, end in zip( + passage_sentences[-1], start_ids[start_idx:end_idx], end_ids[start_idx:end_idx] + ): + assert token.start_position + passage_sentences[-1].start_position == start + assert token.end_position + passage_sentences[-1].start_position == end + + start_id_to_token = { + token.start_position + sentence.start_position: (sentence, i) + for sentence in passage_sentences + for i, token in enumerate(sentence) + } + end_id_to_token = { + token.end_position + 
sentence.start_position: (sentence, i) + for sentence in passage_sentences + for i, token in enumerate(sentence) + } + for entity in entities: + mention_ids = [":".join([n[key] for key in self.norm_keys]) for n in entity["normalized"]] + assert len(entity["offsets"]) == len(entity["text"]) + for (start, end), text in zip(entity["offsets"], entity["text"]): + assert start in start_id_to_token + assert end in end_id_to_token + sent_s, start_token_idx = start_id_to_token[start] + sent_e, end_token_idx = end_id_to_token[end] + assert sent_s is sent_e + + for mention_id in mention_ids: + sent_s[start_token_idx : end_token_idx + 1].add_label(self.label_type, mention_id) + return passage_sentences + + def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDataset: + if isinstance(paths, Path): + paths = [paths] + all_sentences = [] + for path in paths: + for entry in self._file_to_dicts(path): + all_sentences.extend(self._dict_to_sentences(entry)) + return FlairDatapointDataset(all_sentences) + + +class BIGBIO_NCBI_DISEASE(BigbioCorpus): + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-diseases", **kwargs) -> None: + super().__init__(base_path, label_type, **kwargs) + + def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + download_urls = { + "train": ( + "NCBItrainset_corpus.txt", + "https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBItrainset_corpus.zip", + ), + "dev": ( + "NCBIdevelopset_corpus.txt", + "https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBIdevelopset_corpus.zip", + ), + "test": ( + "NCBItestset_corpus.txt", + "https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBItestset_corpus.zip", + ), + } + results_files: Dict[str, Union[Path, List[Path]]] = {} + + for split, (filename, url) in download_urls.items(): + result_path = data_folder / filename + results_files[split] = result_path + + if result_path.exists(): + continue + + path = cached_path(url, data_folder) + unpack_file(path, data_folder) + + return results_files + + def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + with open(filepath) as f: + for doc in pubtator.iterparse(f): + unified_example = { + "id": doc.pmid, + "document_id": doc.pmid, + "passages": [ + { + "text": [doc.title], + "offsets": [[0, len(doc.title)]], + }, + { + "text": [doc.abstract], + "offsets": [ + [ + # +1 assumes the title and abstract will be joined by a space. + len(doc.title) + 1, + len(doc.title) + 1 + len(doc.abstract), + ] + ], + }, + ], + } + + unified_entities = [] + for i, entity in enumerate(doc.annotations): + # We need a unique identifier for this entity, so build it from the document id and entity id + unified_entity_id = "_".join([doc.pmid, entity.id, str(i)]) + # The user can provide a callable that returns the database name. 
+ normalized = [] + + for x in entity.id.split("|"): + if x.startswith(("OMIM", "omim")): + normalized.append({"db_name": "OMIM", "db_id": x.strip().split(":")[-1]}) + elif "+" in x: + normalized.extend( + [ + { + "db_name": "MESH", + "db_id": y.split(":")[-1].strip(), + } + for y in x.split("+") + ] + ) + else: + normalized.append({"db_name": "MESH", "db_id": x.split(":")[-1].strip()}) + + unified_entities.append( + { + "id": unified_entity_id, + "type": entity.type, + "text": [entity.text], + "offsets": [[entity.start, entity.end]], + "normalized": normalized, + } + ) + + unified_example["entities"] = unified_entities + + yield unified_example + + +class BIGBIO_BC5CDR_CHEMICAL(BigbioCorpus): + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-chemical", **kwargs) -> None: + super().__init__(base_path, label_type, **kwargs) + + def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + url = "https://huggingface.co/datasets/bigbio/bc5cdr/resolve/main/CDR_Data.zip" + + path = cached_path(url, data_folder) + data_path = data_folder / "CDR_Data" / "CDR.Corpus.v010516" + if not data_path.exists(): + unpack_file(path, data_folder) + assert data_folder.exists() + + results_files: Dict[str, Union[Path, List[Path]]] = { + "train": data_path / "CDR_TrainingSet.BioC.xml", + "dev": data_path / "CDR_DevelopmentSet.BioC.xml", + "test": data_path / "CDR_TestSet.BioC.xml", + } + return results_files + + def _get_bioc_entity(self, span, db_id_key="MESH"): + offsets = [(loc.offset, loc.offset + loc.length) for loc in span.locations] + + text = span.text + + if len(offsets) > 1: + i = 0 + texts = [] + for start, end in offsets: + chunk_len = end - start + texts.append(text[i : chunk_len + i]) + i += chunk_len + while i < len(text) and text[i] == " ": + i += 1 + else: + texts = [text] + db_ids = span.infons[db_id_key] if db_id_key else "-1" + + # some entities are not linked and + # some entities are linked to multiple normalized ids + db_ids_list = [] if db_ids == "-1" else db_ids.split("|") + + normalized = [{"db_name": db_id_key, "db_id": db_id} for db_id in db_ids_list] + + return { + "id": span.id, + "offsets": offsets, + "text": texts, + "type": span.infons["type"], + "normalized": normalized, + } + + def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + reader = biocxml.BioCXMLDocumentReader(str(filepath)) + + for i, xdoc in enumerate(reader): + data = { + "document_id": xdoc.id, + "entities": [], + "passages": [], + } + + char_start = 0 + # passages must not overlap and spans must cover the entire document + for passage in xdoc.passages: + offsets = [[char_start, char_start + len(passage.text)]] + char_start = char_start + len(passage.text) + 1 + data["passages"].append( + { + "type": passage.infons["type"], + "text": [passage.text], + "offsets": offsets, + } + ) + + # entities + for passage in xdoc.passages: + for span in passage.annotations: + ent = self._get_bioc_entity(span, db_id_key="MESH") + if ent["type"].lower() == "chemical": + data["entities"].append(ent) + + yield data + + +class BIGBIO_GNORMPLUS(BigbioCorpus): + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-genes", **kwargs) -> None: + self._re_tax_id = re.compile(r"(?P\d+)\([tT]ax:(?P\d+)\)") + super().__init__(base_path, label_type, norm_keys=["db_id"], **kwargs) + + def _download_dataset(self, data_folder: Path) -> Dict[str, Union[Path, List[Path]]]: + url = 
"https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/tmTools/download/GNormPlus/GNormPlusCorpus.zip" + + path = cached_path(url, data_folder) + data_path = data_folder / "GNormPlusCorpus" + if not data_path.exists(): + unpack_file(path, data_folder) + assert data_folder.exists() + + results_files: Dict[str, Union[Path, List[Path]]] = { + "train": [data_path / "BC2GNtrain.BioC.xml", data_path / "NLMIAT.BioC.xml"], + "test": data_path / "BC2GNtest.BioC.xml", + } + return results_files + + def _parse_bioc_entity(self, span, db_id_key="NCBIGene", insert_tax_id=False): + offsets = [(loc.offset, loc.offset + loc.length) for loc in span.locations] + + text = span.text + + if len(offsets) > 1: + i = 0 + texts = [] + for start, end in offsets: + chunk_len = end - start + texts.append(text[i : chunk_len + i]) + i += chunk_len + while i < len(text) and text[i] == " ": + i += 1 + else: + texts = [text] + _type = span.infons["type"] + + # parse db ids + normalized = [] + if _type in span.infons: + for _id in span.infons[_type].split(","): + match = self._re_tax_id.match(_id) + if match: + _id = match.group("db_id") + + n = {"db_name": db_id_key, "db_id": _id} + if insert_tax_id: + n["tax_id"] = match.group("tax_id") if match else None + + normalized.append(n) + return { + "offsets": offsets, + "text": texts, + "type": _type, + "normalized": normalized, + } + + def _adjust_entity_offsets(self, text: str, entities: List[Dict]): + for entity in entities: + start, end = entity["offsets"][0] + entity_mention = entity["text"][0] + if text[start:end] != entity_mention: + if text[start - 1 : end - 1] == entity_mention: + entity["offsets"] = [(start - 1, end - 1)] + elif text[start : end - 1] == entity_mention: + entity["offsets"] = [(start, end - 1)] + + def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: + with filepath.open("r") as f: + collection = biocxml.load(f) + + for document in collection.documents: + text = " ".join([passage.text for passage in document.passages]) + entities = [ + self._parse_bioc_entity(entity) for passage in document.passages for entity in passage.annotations + ] + + # Some of the entities have a off-by-one error. Correct these annotations! + self._adjust_entity_offsets(text, entities) + + # passage offsets/lengths do not connect, recalculate them for this schema. 
+ passage_spans = [] + start = 0 + for passage in document.passages: + end = start + len(passage.text) + passage_spans.append((start, end)) + start = end + 1 + + features = { + "passages": [ + { + "type": passage.infons["type"], + "text": [passage.text], + "offsets": [span], + } + for passage, span in zip(document.passages, passage_spans) + ], + "entities": entities, + } + + yield features diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index a739b6cc2..6037b6abb 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -1158,9 +1158,9 @@ def evaluate( for sentence in data_points: spans = sentence.get_spans(gold_label_type) for span in spans: - exps = set(exp.value for exp in span.get_labels(gold_label_type) if exp.value not in exclude_labels) + exps = {exp.value for exp in span.get_labels(gold_label_type) if exp.value not in exclude_labels} - predictions = set(pred.value for pred in span.get_labels("predicted")) + predictions = {pred.value for pred in span.get_labels("predicted")} total += 1 if exps & predictions: hits += 1 @@ -1169,5 +1169,4 @@ def evaluate( detailed_results = f"Accuracy@{k}: {accuracy:0.2%}" scores = {"accuracy": accuracy, f"accuracy@{k}": accuracy, "loss": 0.0} - return Result(main_score=accuracy, detailed_results=detailed_results, scores=scores) diff --git a/requirements.txt b/requirements.txt index 83f09e2ec..87c8face9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,5 @@ transformer-smaller-training-vocab>=0.2.3 transformers[sentencepiece]>=4.18.0,<5.0.0 urllib3<2.0.0,>=1.0.0 # pin below 2 to make dependency resolution faster. wikipedia-api>=0.5.7 -semver<4.0.0,>=3.0.0 \ No newline at end of file +semver<4.0.0,>=3.0.0 +bioc<3.0.0,>=2.0.0 \ No newline at end of file From 654ed8d9bb528f12b07c9c52d49e75f42a876a1f Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 2 Feb 2024 16:09:33 +0100 Subject: [PATCH 53/58] mark heavy test as integration test --- tests/test_biomedical_entity_linking.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index bcccead99..00e2b6cf7 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -1,3 +1,5 @@ +import pytest + from flair.data import Sentence from flair.models.entity_mention_linking import ( Ab3PEntityPreprocessor, @@ -77,6 +79,7 @@ def test_abbrevitation_resolution(): assert mention == "weaver smith syndrome" +@pytest.mark.integration() def test_biomedical_entity_linking(): sentence = Sentence( "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " From efb1018f45d2914315d63adee3a2b6380d3b9b53 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 2 Feb 2024 23:08:27 +0100 Subject: [PATCH 54/58] add metadata to labels for cnadidate names --- flair/data.py | 71 ++++++++++++++------------ flair/models/entity_mention_linking.py | 2 +- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/flair/data.py b/flair/data.py index 1089230d5..30317dc23 100644 --- a/flair/data.py +++ b/flair/data.py @@ -213,10 +213,11 @@ class Label: Default value for the score is 1.0. 
""" - def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0) -> None: + def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0, **metadata) -> None: self._value = value self._score = score self.data_point: DataPoint = data_point + self.metadata = metadata super().__init__() def set_value(self, value: str, score: float = 1.0): @@ -235,14 +236,14 @@ def to_dict(self): return {"value": self.value, "confidence": self.score} def __str__(self) -> str: - return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value} ({round(self._score, 4)})" + return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value}{self.metadata_str} ({round(self._score, 4)})" @property def shortstring(self): return f'"{self.data_point.text}"/{self._value}' def __repr__(self) -> str: - return f"'{self.data_point.unlabeled_identifier}'/'{self._value}' ({round(self._score, 4)})" + return f"'{self.data_point.unlabeled_identifier}'/'{self._value}'{self.metadata_str} ({round(self._score, 4)})" def __eq__(self, other): return self.value == other.value and self.score == other.score and self.data_point == other.data_point @@ -253,6 +254,13 @@ def __hash__(self): def __lt__(self, other): return self.data_point < other.data_point + @property + def metadata_str(self) -> str: + if not self.metadata: + return "" + rep = "/".join(f"{k}={v}" for k, v in self.metadata.items()) + return f"/{rep}" + @property def labeled_identifier(self): return f"{self.data_point.unlabeled_identifier}/{self.value}" @@ -336,8 +344,8 @@ def get_metadata(self, key: str) -> typing.Any: def has_metadata(self, key: str) -> bool: return key in self._metadata - def add_label(self, typename: str, value: str, score: float = 1.0): - label = Label(self, value, score) + def add_label(self, typename: str, value: str, score: float = 1.0, **metadata): + label = Label(self, value, score, **metadata) if typename not in self.annotation_layers: self.annotation_layers[typename] = [label] @@ -346,8 +354,8 @@ def add_label(self, typename: str, value: str, score: float = 1.0): return self - def set_label(self, typename: str, value: str, score: float = 1.0): - self.annotation_layers[typename] = [Label(self, value, score)] + def set_label(self, typename: str, value: str, score: float = 1.0, **metadata): + self.annotation_layers[typename] = [Label(self, value, score, **metadata)] return self def remove_labels(self, typename: str): @@ -377,28 +385,25 @@ def labels(self) -> List[Label]: def unlabeled_identifier(self): raise NotImplementedError - def _printout_labels(self, main_label=None, add_score: bool = True): + def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True): all_labels = [] keys = [main_label] if main_label is not None else self.annotation_layers.keys() - if add_score: - for key in keys: - all_labels.extend( - [ - f"{label.value} ({round(label.score, 4)})" - for label in self.get_labels(key) - if label.data_point == self - ] - ) - labels = "; ".join(all_labels) - if labels != "": - labels = flair._arrow + labels - else: - for key in keys: - all_labels.extend([f"{label.value}" for label in self.get_labels(key) if label.data_point == self]) - labels = "/".join(all_labels) - if labels != "": - labels = "/" + labels - return labels + + sep = "; " if add_score else "/" + sent_sep = flair._arrow if add_score else "/" + for key in keys: + for label in self.get_labels(key): + if label.data_point is not self: + continue + value = label.value + if add_metadata: + value = 
f"{value}{label.metadata_str}" + if add_score: + value = f"{value} ({label.score:.04f})" + all_labels.append(value) + if not all_labels: + return "" + return sent_sep + sep.join(all_labels) def __str__(self) -> str: return self.unlabeled_identifier + self._printout_labels() @@ -497,18 +502,18 @@ def __init__(self, sentence) -> None: super().__init__() self.sentence: Sentence = sentence - def add_label(self, typename: str, value: str, score: float = 1.0): - super().add_label(typename, value, score) - self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score)) + def add_label(self, typename: str, value: str, score: float = 1.0, **metadata): + super().add_label(typename, value, score, **metadata) + self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata)) - def set_label(self, typename: str, value: str, score: float = 1.0): + def set_label(self, typename: str, value: str, score: float = 1.0, **metadata): if len(self.annotation_layers.get(typename, [])) > 0: # First we remove any existing labels for this PartOfSentence in self.sentence self.sentence.annotation_layers[typename] = [ label for label in self.sentence.annotation_layers.get(typename, []) if label.data_point != self ] - self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score)) - super().set_label(typename, value, score) + self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata)) + super().set_label(typename, value, score, **metadata) return self def remove_labels(self, typename: str): diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 6037b6abb..3d1d85b6c 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -915,7 +915,7 @@ def predict( # Add a label annotation for each candidate for data_point, mention_candidates in zip(data_points[i : i + batch_size], candidates): for candidate_id, confidence in mention_candidates: - data_point.add_label(pred_label_type, candidate_id, confidence) + data_point.add_label(pred_label_type, candidate_id, confidence, name=self.dictionary[candidate_id].concept_name) @staticmethod def _fetch_model(model_name: str) -> str: From 44c1413d4775466f51caa984bc49248af1bc0d8c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 4 Feb 2024 13:21:32 +0100 Subject: [PATCH 55/58] fix black ruff and mypy --- flair/data.py | 12 ++++++------ flair/models/entity_mention_linking.py | 26 ++++++++++++++------------ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/flair/data.py b/flair/data.py index 30317dc23..dd8304405 100644 --- a/flair/data.py +++ b/flair/data.py @@ -599,21 +599,21 @@ def __len__(self) -> int: def __repr__(self) -> str: return self.__str__() - def add_label(self, typename: str, value: str, score: float = 1.0): + def add_label(self, typename: str, value: str, score: float = 1.0, **metadata): # The Token is a special _PartOfSentence in that it may be initialized without a Sentence. 
# therefore, labels get added only to the Sentence if it exists if self.sentence: - super().add_label(typename=typename, value=value, score=score) + super().add_label(typename=typename, value=value, score=score, **metadata) else: - DataPoint.add_label(self, typename=typename, value=value, score=score) + DataPoint.add_label(self, typename=typename, value=value, score=score, **metadata) - def set_label(self, typename: str, value: str, score: float = 1.0): + def set_label(self, typename: str, value: str, score: float = 1.0, **metadata): # The Token is a special _PartOfSentence in that it may be initialized without a Sentence. # Therefore, labels get set only to the Sentence if it exists if self.sentence: - super().set_label(typename=typename, value=value, score=score) + super().set_label(typename=typename, value=value, score=score, **metadata) else: - DataPoint.set_label(self, typename=typename, value=value, score=score) + DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata) def to_dict(self, tag_type: Optional[str] = None): return { diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 3d1d85b6c..1bb58e435 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -443,9 +443,11 @@ def _get_state(self) -> Dict[str, Any]: @classmethod def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor": return cls( - preprocessor=None - if state_dict["preprocessor"] is None - else EntityPreprocessor._from_state(state_dict["preprocessor"]), + preprocessor=( + None + if state_dict["preprocessor"] is None + else EntityPreprocessor._from_state(state_dict["preprocessor"]) + ), ) @@ -601,11 +603,7 @@ def bi_encoder( ngram_range=(1, 2), ) - sparse_weight = ( - sparse_weight - if model_name_or_path not in HYBRID_MODELS_SPARSE_WEIGHT - else HYBRID_MODELS_SPARSE_WEIGHT[model_name_or_path] - ) + sparse_weight = HYBRID_MODELS_SPARSE_WEIGHT.get(model_name_or_path, sparse_weight) return cls( embeddings, @@ -903,9 +901,11 @@ def predict( for entity in entities_mentions: data_points.append(entity.data_point) mentions.append( - self.preprocessor.process_mention(entity.data_point.text, sentence) - if self.preprocessor is not None - else entity.data_point.text, + ( + self.preprocessor.process_mention(entity.data_point.text, sentence) + if self.preprocessor is not None + else entity.data_point.text + ), ) # Retrieve top-k concept / entity candidates @@ -915,7 +915,9 @@ def predict( # Add a label annotation for each candidate for data_point, mention_candidates in zip(data_points[i : i + batch_size], candidates): for candidate_id, confidence in mention_candidates: - data_point.add_label(pred_label_type, candidate_id, confidence, name=self.dictionary[candidate_id].concept_name) + data_point.add_label( + pred_label_type, candidate_id, confidence, name=self.dictionary[candidate_id].concept_name + ) @staticmethod def _fetch_model(model_name: str) -> str: From d7fa7ddef64c2c6bc05c717cb8bf2cee97572704 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Wed, 7 Feb 2024 00:44:50 +0100 Subject: [PATCH 56/58] make test more memory efficient by only loading the smallest model --- tests/test_biomedical_entity_linking.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_biomedical_entity_linking.py b/tests/test_biomedical_entity_linking.py index 00e2b6cf7..40fd11391 100644 --- a/tests/test_biomedical_entity_linking.py +++ b/tests/test_biomedical_entity_linking.py @@ -90,9 +90,8 
@@ def test_biomedical_entity_linking(): tagger = Classifier.load("hunflair") tagger.predict(sentence) - for entity_type in ["disease", "chemical", "gene", "species"]: - linker = EntityMentionLinker.load(f"{entity_type}-linker") - linker.predict(sentence) + linker = EntityMentionLinker.load("disease-linker") + linker.predict(sentence) for span in sentence.get_spans(): print(span) From ba833f0bc88e884c9acbefeb93f19fb5e9ae7f59 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Wed, 7 Feb 2024 18:24:24 +0100 Subject: [PATCH 57/58] chore(docs): add docstrings fro datasets --- flair/datasets/entity_linking.py | 41 ++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 91b4c8315..3351b11cb 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -29,12 +29,13 @@ class EntityLinkingDictionary: def __init__( self, candidates: Iterable[EntityCandidate], - dataset_name: Optional[str] = None, + dataset_name: Optional[str] = None, # used as prefix to `EntityCandidate.concept_id`, e.g. NCBI Gene:2 ): """Initialize the entity linking dictionary. Args: candidates: A iterable sequence of all Candidates contained in the knowledge base. + dataset_name: string to prefix concept IDs. To be used for custom dictionaries. """ # this dataset name if dataset_name is None: @@ -76,6 +77,8 @@ def to_in_memory_dictionary(self) -> "InMemoryEntityLinkingDictionary": return InMemoryEntityLinkingDictionary(list(self._idx_to_candidates.values()), self._dataset_name) +# NOTE: EntityLinkingDictionary are lazy-loaded from a preprocessed file. +# Use this class to load into memory all candidates class InMemoryEntityLinkingDictionary(EntityLinkingDictionary): def __init__(self, candidates: List[EntityCandidate], dataset_name: str): self._dataset_name = dataset_name @@ -2212,7 +2215,17 @@ def __init__( ) -class BigbioCorpus(Corpus, abc.ABC): +# TODO: Adapt this following: https://github.com/flairNLP/flair/pull/3146 +class BigBioEntityLinkingCorpus(Corpus, abc.ABC): + """This class implements an adapter to data sets implemented in the BigBio framework: + + https://github.com/bigscience-workshop/biomedical + + The BigBio framework harmonizes over 120 biomedical data sets and provides a uniform + programming api to access them. This adapter allows to use all named entity recognition + data sets by using the bigbio_kb schema. 
+ """ + def __init__( self, base_path: Optional[Union[str, Path]] = None, @@ -2323,7 +2336,13 @@ def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDat return FlairDatapointDataset(all_sentences) -class BIGBIO_NCBI_DISEASE(BigbioCorpus): +class BIGBIO_NCBI_DISEASE(BigBioEntityLinkingCorpus): + """This class implents the adapter for the NCBI Disease corpus: + + - Reference: https://www.sciencedirect.com/science/article/pii/S1532046413001974 + - Link: https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/ + """ + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-diseases", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) @@ -2418,7 +2437,13 @@ def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: yield unified_example -class BIGBIO_BC5CDR_CHEMICAL(BigbioCorpus): +class BIGBIO_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus): + """This class implents the adapter for the BC5CDR corpus (only chemical annotations): + + - Reference: https://academic.oup.com/database/article/doi/10.1093/database/baw068/2630414 + - Link: https://biocreative.bioinformatics.udel.edu/tasks/biocreative-v/track-3-cdr/ + """ + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-chemical", **kwargs) -> None: super().__init__(base_path, label_type, **kwargs) @@ -2503,7 +2528,13 @@ def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: yield data -class BIGBIO_GNORMPLUS(BigbioCorpus): +class BIGBIO_GNORMPLUS(BigBioEntityLinkingCorpus): + """This class implents the adapter for the GNormPlus corpus: + + - Reference: https://www.hindawi.com/journals/bmri/2015/918710/ + - Link: https://www.ncbi.nlm.nih.gov/research/bionlp/Tools/gnormplus/ + """ + def __init__(self, base_path: Optional[Union[str, Path]] = None, label_type: str = "el-genes", **kwargs) -> None: self._re_tax_id = re.compile(r"(?P\d+)\([tT]ax:(?P\d+)\)") super().__init__(base_path, label_type, norm_keys=["db_id"], **kwargs) From 44d73a25fb181a7a7e850c4feae35ce2500194e2 Mon Sep 17 00:00:00 2001 From: Samuele Garda Date: Thu, 8 Feb 2024 09:47:29 +0100 Subject: [PATCH 58/58] fix(bigbio): better naming & fix ruff --- flair/datasets/entity_linking.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py index 3351b11cb..20f2caefd 100644 --- a/flair/datasets/entity_linking.py +++ b/flair/datasets/entity_linking.py @@ -2217,9 +2217,9 @@ def __init__( # TODO: Adapt this following: https://github.com/flairNLP/flair/pull/3146 class BigBioEntityLinkingCorpus(Corpus, abc.ABC): - """This class implements an adapter to data sets implemented in the BigBio framework: + """This class implements an adapter to data sets implemented in the BigBio framework. - https://github.com/bigscience-workshop/biomedical + See: https://github.com/bigscience-workshop/biomedical The BigBio framework harmonizes over 120 biomedical data sets and provides a uniform programming api to access them. This adapter allows to use all named entity recognition @@ -2336,9 +2336,10 @@ def _files_to_dataset(self, paths: Union[Path, List[Path]]) -> FlairDatapointDat return FlairDatapointDataset(all_sentences) -class BIGBIO_NCBI_DISEASE(BigBioEntityLinkingCorpus): - """This class implents the adapter for the NCBI Disease corpus: +class BIGBIO_EL_NCBI_DISEASE(BigBioEntityLinkingCorpus): + """This class implents the adapter for the NCBI Disease corpus. 
+ See: - Reference: https://www.sciencedirect.com/science/article/pii/S1532046413001974 - Link: https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/ """ @@ -2437,9 +2438,10 @@ def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: yield unified_example -class BIGBIO_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus): - """This class implents the adapter for the BC5CDR corpus (only chemical annotations): +class BIGBIO_EL_BC5CDR_CHEMICAL(BigBioEntityLinkingCorpus): + """This class implents the adapter for the BC5CDR corpus (only chemical annotations). + See: - Reference: https://academic.oup.com/database/article/doi/10.1093/database/baw068/2630414 - Link: https://biocreative.bioinformatics.udel.edu/tasks/biocreative-v/track-3-cdr/ """ @@ -2528,9 +2530,10 @@ def _file_to_dicts(self, filepath: Path) -> Iterator[Dict[str, Any]]: yield data -class BIGBIO_GNORMPLUS(BigBioEntityLinkingCorpus): - """This class implents the adapter for the GNormPlus corpus: +class BIGBIO_EL_GNORMPLUS(BigBioEntityLinkingCorpus): + """This class implents the adapter for the GNormPlus corpus. + See: - Reference: https://www.hindawi.com/journals/bmri/2015/918710/ - Link: https://www.ncbi.nlm.nih.gov/research/bionlp/Tools/gnormplus/ """
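As a closing illustration of how the pieces in this series fit together, the Accuracy@k evaluation added to `EntityMentionLinker.evaluate` can be driven directly by one of the new BigBio gold corpora. This is a rough sketch, not part of the patch series; it assumes the final names introduced above (`BIGBIO_EL_NCBI_DISEASE` with its default `label_type="el-diseases"`, and the `gold_label_type`/`k` arguments of `evaluate`):

```python
from flair.datasets.entity_linking import BIGBIO_EL_NCBI_DISEASE
from flair.models.entity_mention_linking import EntityMentionLinker

# Gold mentions and their concept identifiers come from the corpus itself,
# so no NER tagger is needed: the linker predicts on the annotated spans.
corpus = BIGBIO_EL_NCBI_DISEASE()
linker = EntityMentionLinker.load("disease-linker")

result = linker.evaluate(
    corpus.test,                    # a Dataset; evaluate() also accepts a list of Sentences
    gold_label_type="el-diseases",  # label type under which the corpus stores gold concept IDs
    k=5,                            # a span counts as a hit if any top-5 candidate matches a gold ID
)
print(result.detailed_results)      # e.g. "Accuracy@5: ..."
```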