diff --git a/bdikit/download.py b/bdikit/download.py
index 27d85c90..27f86570 100644
--- a/bdikit/download.py
+++ b/bdikit/download.py
@@ -17,6 +17,8 @@
     "bdi-cl-v0.2": "https://nyu.box.com/shared/static/1vdc28kzbpoj6ey95bksaww541p9gj31.pt",
 }
 
+BDIKIT_EMBEDDINGS_CACHE_DIR = os.path.join(BDIKIT_CACHE_DIR, "embeddings")
+
 
 def download_file_url(url: str, destination: str):
     # start the download stream
diff --git a/bdikit/models/contrastive_learning/cl_api.py b/bdikit/models/contrastive_learning/cl_api.py
index 15ddddb8..d34795e4 100644
--- a/bdikit/models/contrastive_learning/cl_api.py
+++ b/bdikit/models/contrastive_learning/cl_api.py
@@ -13,9 +13,9 @@
 from sklearn.metrics.pairwise import cosine_similarity
 from tqdm import tqdm
 from bdikit.download import get_cached_model_or_download
+from bdikit.utils import check_gdc_cache, write_embeddings_to_cache
+
 
-dir_path = os.path.dirname(os.path.realpath(__file__))
-GDC_TABLE_PATH = os.path.join(dir_path, "../../../../resource/gdc_table.csv")
 
 DEFAULT_CL_MODEL = "bdi-cl-v0.2"
 
@@ -106,6 +106,13 @@ def _sample_to_15_rows(self, table: pd.DataFrame):
         return table
 
     def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
+
+        embedding_file, embeddings = check_gdc_cache(table, self.model_path)
+
+        if embeddings is not None:
+            print(f"Table features loaded for {len(table.columns)} columns")
+            return embeddings
+
         tables = []
         for _, column in enumerate(table.columns):
             curr_table = pd.DataFrame(table[column])
@@ -113,7 +120,12 @@ def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
             tables.append(curr_table)
         vectors = self._inference_on_tables(tables)
         print(f"Table features extracted from {len(table.columns)} columns")
-        return [vec[-1] for vec in vectors]
+        embeddings = [vec[-1] for vec in vectors]
+
+        if embedding_file is not None:
+            write_embeddings_to_cache(embedding_file, embeddings)
+
+        return embeddings
 
     def _inference_on_tables(self, tables: List[pd.DataFrame]) -> List[List]:
         total = len(tables)
diff --git a/bdikit/utils.py b/bdikit/utils.py
index 8827987d..8cbde80d 100644
--- a/bdikit/utils.py
+++ b/bdikit/utils.py
@@ -1,7 +1,15 @@
 import json
 from os.path import join, dirname
+import pandas as pd
+import hashlib
+import os
+from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR
 
 GDC_SCHEMA_PATH = join(dirname(__file__), "./resource/gdc_schema.json")
+GDC_TABLE_PATH = join(dirname(__file__), "./resource/gdc_table.csv")
+
+__gdc_df = None
+__gdc_hash = None
 
 
 def read_gdc_schema():
@@ -70,3 +78,74 @@ def get_gdc_layered_metadata():
         metadata[key] = (subschema, data)
 
     return metadata
+
+
+def hash_dataframe(df: pd.DataFrame) -> str:
+
+    hash_object = hashlib.sha256()
+
+    columns_string = ",".join(df.columns) + "\n"
+    hash_object.update(columns_string.encode())
+
+    for row in df.itertuples(index=False, name=None):
+        row_string = ",".join(map(str, row)) + "\n"
+        hash_object.update(row_string.encode())
+
+    return hash_object.hexdigest()
+
+
+def write_embeddings_to_cache(embedding_file: str, embeddings: list):
+
+    os.makedirs(os.path.dirname(embedding_file), exist_ok=True)
+
+    with open(embedding_file, "w") as file:
+        for vec in embeddings:
+            file.write(",".join([str(val) for val in vec]) + "\n")
+
+
+def load_gdc_data():
+    global __gdc_df, __gdc_hash
+    if __gdc_df is None or __gdc_hash is None:
+        __gdc_df = pd.read_csv(GDC_TABLE_PATH)
+        __gdc_hash = hash_dataframe(__gdc_df)
+
+
+def get_gdc_dataframe():
+    global __gdc_df
+    load_gdc_data()
+
+    return __gdc_df
+
+
+def check_gdc_cache(table: pd.DataFrame, model_path: str):
+    global __gdc_df, __gdc_hash
+    load_gdc_data()
+
+    table_hash = hash_dataframe(table)
+
+    df_hash_file = None
+    features = None
+
+    # check if table for computing embedding is the same as the GDC table we have in resources
+    if table_hash == __gdc_hash:
+        model_name = model_path.split("/")[-1]
+        cache_model_path = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name)
+        df_hash_file = os.path.join(cache_model_path, __gdc_hash)
+
+        # Found file in cache
+        if os.path.isfile(df_hash_file):
+            try:
+                # Load embeddings from disk
+                with open(df_hash_file, "r") as file:
+                    features = [
+                        [float(val) for val in vec.split(",")]
+                        for vec in file.read().split("\n")
+                        if vec.strip()
+                    ]
+                    if len(features) != len(__gdc_df.columns):
+                        features = None
+                        raise ValueError("Mismatch in the number of features")
+            except Exception as e:
+                print(f"Error loading features from cache: {e}")
+                features = None
+    return df_hash_file, features
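
Usage sketch (illustrative only, not part of the patch above): the snippet below assumes the packaged GDC table is readable via get_gdc_dataframe() and uses a hypothetical local model path; it walks the cold-cache and warm-cache paths that _load_table_tokens now takes through check_gdc_cache and write_embeddings_to_cache.

import numpy as np
from bdikit.utils import get_gdc_dataframe, check_gdc_cache, write_embeddings_to_cache

gdc_df = get_gdc_dataframe()
model_path = "/tmp/models/bdi-cl-v0.2"  # hypothetical local checkpoint path

# Cold cache: the table hash matches the packaged GDC table, so a target cache
# file path is returned, but no features are loaded yet.
embedding_file, features = check_gdc_cache(gdc_df, model_path)
assert features is None

# Persist one vector per column (random stand-ins for real model embeddings).
write_embeddings_to_cache(embedding_file, [np.random.rand(16) for _ in gdc_df.columns])

# Warm cache: the same call now loads the vectors from disk instead of
# triggering a new round of inference.
embedding_file, features = check_gdc_cache(gdc_df, model_path)
assert features is not None and len(features) == len(gdc_df.columns)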