From 8cc3c2d4d15ccfeb5bde69365ac48d26ae342fb8 Mon Sep 17 00:00:00 2001
From: EduardoPena
Date: Tue, 16 Jul 2024 16:08:27 -0400
Subject: [PATCH 1/3] Adding embedding cache for gdc case

---
 .../contrastive_learning/cl_api.py            | 54 ++++++++++++++++++-
 bdikit/utils.py                               | 26 +++++++++
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py b/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
index 0c53a218..1a7c6a8c 100644
--- a/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
+++ b/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
@@ -13,11 +13,20 @@
 from sklearn.metrics.pairwise import cosine_similarity
 from tqdm import tqdm
 from bdikit.download import get_cached_model_or_download
+from bdikit.utils import hash_dataframe, write_embeddings_to_cache
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
 GDC_TABLE_PATH = os.path.join(dir_path, "../../../../resource/gdc_table.csv")
 DEFAULT_CL_MODEL = "bdi-cl-v0.2"
 
+default_os_cache_dir = os.getenv(
+    "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
+)
+BDIKIT_CACHE_DIR = os.getenv(
+    "BDIKIT_CACHE", os.path.join(default_os_cache_dir, "bdikit")
+)
+BDIKIT_EMBEDDINGS_CACHE_DIR = os.path.join(BDIKIT_CACHE_DIR, "embeddings")
+
 
 class ContrastiveLearningAPI:
     def __init__(
@@ -105,7 +114,45 @@ def _sample_to_15_rows(self, table: pd.DataFrame):
             table = unique_rows.sample(n=15, random_state=1)
         return table
 
+    def _check_gdc_cache(self, table: pd.DataFrame):
+
+        gdc_df = pd.read_csv(GDC_TABLE_PATH)
+        gdc_hash = hash_dataframe(gdc_df)
+
+        table_hash = hash_dataframe(table)
+
+        df_hash_file = None
+        features = None
+
+        # check if table for computing embedding is the same as the GDC table we have in resources
+        if table_hash == gdc_hash:
+            df_hash_file = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, gdc_hash)
+            # Found file in cache
+            if os.path.isfile(df_hash_file):
+                try:
+                    # Load embeddings from disk
+                    with open(df_hash_file, "r") as file:
+                        features = [
+                            [float(val) for val in vec.split(",")]
+                            for vec in file.read().split("\n")
+                            if vec.strip()
+                        ]
+                        if len(features) != len(gdc_df.columns):
+                            features = None
+                            raise ValueError("Mismatch in the number of features")
+                except Exception as e:
+                    print(f"Error loading features from cache: {e}")
+                    features = None
+        return df_hash_file, features
+
     def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
+
+        embedding_file, embeddings = self._check_gdc_cache(table)
+
+        if embeddings is not None:
+            print(f"Table features loaded for {len(table.columns)} columns")
+            return embeddings
+
         tables = []
         for _, column in enumerate(table.columns):
             curr_table = pd.DataFrame(table[column])
@@ -113,7 +160,12 @@ def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
             tables.append(curr_table)
         vectors = self._inference_on_tables(tables)
         print(f"Table features extracted from {len(table.columns)} columns")
-        return [vec[-1] for vec in vectors]
+        embeddings = [vec[-1] for vec in vectors]
+
+        if embedding_file is not None:
+            write_embeddings_to_cache(embedding_file, embeddings)
+
+        return embeddings
 
     def _inference_on_tables(self, tables: List[pd.DataFrame]) -> List[List]:
         total = len(tables)
diff --git a/bdikit/utils.py b/bdikit/utils.py
index 8827987d..f65f7421 100644
--- a/bdikit/utils.py
+++ b/bdikit/utils.py
@@ -1,5 +1,8 @@
 import json
 from os.path import join, dirname
+import pandas as pd
+import hashlib
+import os
 
 GDC_SCHEMA_PATH = join(dirname(__file__), "./resource/gdc_schema.json")
 
@@ -70,3 +73,26 @@ def get_gdc_layered_metadata():
         metadata[key] = (subschema, data)
 
     return metadata
+
+
+def hash_dataframe(df: pd.DataFrame) -> str:
+
+    hash_object = hashlib.sha256()
+
+    columns_string = ",".join(df.columns) + "\n"
+    hash_object.update(columns_string.encode())
+
+    for row in df.itertuples(index=False, name=None):
+        row_string = ",".join(map(str, row)) + "\n"
+        hash_object.update(row_string.encode())
+
+    return hash_object.hexdigest()
+
+
+def write_embeddings_to_cache(embedding_file: str, embeddings: list):
+
+    os.makedirs(os.path.dirname(embedding_file), exist_ok=True)
+
+    with open(embedding_file, "w") as file:
+        for vec in embeddings:
+            file.write(",".join([str(val) for val in vec]) + "\n")

From 6e9427cbce3cbeae1b549349ca8f8c7e4f6e22a6 Mon Sep 17 00:00:00 2001
From: EduardoPena
Date: Mon, 22 Jul 2024 10:46:14 -0400
Subject: [PATCH 2/3] Providing global access to GDC data and using that for
 cache lookup

---
 bdikit/download.py                            |  2 +
 .../contrastive_learning/cl_api.py            | 46 ++---------------
 bdikit/utils.py                               | 50 +++++++++++++++++++
 3 files changed, 55 insertions(+), 43 deletions(-)

diff --git a/bdikit/download.py b/bdikit/download.py
index 27d85c90..27f86570 100644
--- a/bdikit/download.py
+++ b/bdikit/download.py
@@ -17,6 +17,8 @@
     "bdi-cl-v0.2": "https://nyu.box.com/shared/static/1vdc28kzbpoj6ey95bksaww541p9gj31.pt",
 }
 
+BDIKIT_EMBEDDINGS_CACHE_DIR = os.path.join(BDIKIT_CACHE_DIR, "embeddings")
+
 
 def download_file_url(url: str, destination: str):
     # start the download stream
diff --git a/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py b/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
index 1a7c6a8c..ec0c9f70 100644
--- a/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
+++ b/bdikit/mapping_algorithms/scope_reducing/_algorithms/contrastive_learning/cl_api.py
@@ -13,19 +13,10 @@
 from sklearn.metrics.pairwise import cosine_similarity
 from tqdm import tqdm
 from bdikit.download import get_cached_model_or_download
-from bdikit.utils import hash_dataframe, write_embeddings_to_cache
+from bdikit.utils import check_gdc_cache, write_embeddings_to_cache
 
-dir_path = os.path.dirname(os.path.realpath(__file__))
-GDC_TABLE_PATH = os.path.join(dir_path, "../../../../resource/gdc_table.csv")
-DEFAULT_CL_MODEL = "bdi-cl-v0.2"
-default_os_cache_dir = os.getenv(
-    "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
-)
-BDIKIT_CACHE_DIR = os.getenv(
-    "BDIKIT_CACHE", os.path.join(default_os_cache_dir, "bdikit")
-)
-BDIKIT_EMBEDDINGS_CACHE_DIR = os.path.join(BDIKIT_CACHE_DIR, "embeddings")
+DEFAULT_CL_MODEL = "bdi-cl-v0.2"
 
 
 class ContrastiveLearningAPI:
     def __init__(
@@ -114,40 +105,9 @@ def _sample_to_15_rows(self, table: pd.DataFrame):
             table = unique_rows.sample(n=15, random_state=1)
         return table
 
-    def _check_gdc_cache(self, table: pd.DataFrame):
-
-        gdc_df = pd.read_csv(GDC_TABLE_PATH)
-        gdc_hash = hash_dataframe(gdc_df)
-
-        table_hash = hash_dataframe(table)
-
-        df_hash_file = None
-        features = None
-
-        # check if table for computing embedding is the same as the GDC table we have in resources
-        if table_hash == gdc_hash:
-            df_hash_file = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, gdc_hash)
-            # Found file in cache
-            if os.path.isfile(df_hash_file):
-                try:
-                    # Load embeddings from disk
-                    with open(df_hash_file, "r") as file:
-                        features = [
-                            [float(val) for val in vec.split(",")]
-                            for vec in file.read().split("\n")
-                            if vec.strip()
-                        ]
-                        if len(features) != len(gdc_df.columns):
-                            features = None
-                            raise ValueError("Mismatch in the number of features")
-                except Exception as e:
-                    print(f"Error loading features from cache: {e}")
-                    features = None
-        return df_hash_file, features
-
     def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
 
-        embedding_file, embeddings = self._check_gdc_cache(table)
+        embedding_file, embeddings = check_gdc_cache(table)
 
         if embeddings is not None:
             print(f"Table features loaded for {len(table.columns)} columns")
diff --git a/bdikit/utils.py b/bdikit/utils.py
index f65f7421..5d5ed8be 100644
--- a/bdikit/utils.py
+++ b/bdikit/utils.py
@@ -3,8 +3,13 @@
 import pandas as pd
 import hashlib
 import os
+from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR
 
 GDC_SCHEMA_PATH = join(dirname(__file__), "./resource/gdc_schema.json")
+GDC_TABLE_PATH = join(dirname(__file__), "./resource/gdc_table.csv")
+
+__gdc_df = None
+__gdc_hash = None
 
 
 def read_gdc_schema():
@@ -96,3 +101,48 @@ def write_embeddings_to_cache(embedding_file: str, embeddings: list):
     with open(embedding_file, "w") as file:
         for vec in embeddings:
             file.write(",".join([str(val) for val in vec]) + "\n")
+
+
+def load_gdc_data():
+    global __gdc_df, __gdc_hash
+    if __gdc_df is None or __gdc_hash is None:
+        __gdc_df = pd.read_csv(GDC_TABLE_PATH)
+        __gdc_hash = hash_dataframe(__gdc_df)
+
+
+def get_gdc_dataframe():
+    global __gdc_df
+    load_gdc_data()
+
+    return __gdc_df
+
+
+def check_gdc_cache(table: pd.DataFrame):
+    global __gdc_df, __gdc_hash
+    load_gdc_data()
+
+    table_hash = hash_dataframe(table)
+
+    df_hash_file = None
+    features = None
+
+    # check if table for computing embedding is the same as the GDC table we have in resources
+    if table_hash == __gdc_hash:
+        df_hash_file = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, __gdc_hash)
+        # Found file in cache
+        if os.path.isfile(df_hash_file):
+            try:
+                # Load embeddings from disk
+                with open(df_hash_file, "r") as file:
+                    features = [
+                        [float(val) for val in vec.split(",")]
+                        for vec in file.read().split("\n")
+                        if vec.strip()
+                    ]
+                    if len(features) != len(__gdc_df.columns):
+                        features = None
+                        raise ValueError("Mismatch in the number of features")
+            except Exception as e:
+                print(f"Error loading features from cache: {e}")
+                features = None
+    return df_hash_file, features

From f3810ca3c3b0176fba0f3ea92429f77b0e4562e1 Mon Sep 17 00:00:00 2001
From: EduardoPena
Date: Thu, 25 Jul 2024 11:32:03 -0400
Subject: [PATCH 3/3] Rebase

---
 bdikit/models/contrastive_learning/cl_api.py | 2 +-
 bdikit/utils.py                              | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/bdikit/models/contrastive_learning/cl_api.py b/bdikit/models/contrastive_learning/cl_api.py
index 976b7623..d34795e4 100644
--- a/bdikit/models/contrastive_learning/cl_api.py
+++ b/bdikit/models/contrastive_learning/cl_api.py
@@ -107,7 +107,7 @@ def _sample_to_15_rows(self, table: pd.DataFrame):
 
     def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:
 
-        embedding_file, embeddings = check_gdc_cache(table)
+        embedding_file, embeddings = check_gdc_cache(table, self.model_path)
 
         if embeddings is not None:
             print(f"Table features loaded for {len(table.columns)} columns")
diff --git a/bdikit/utils.py b/bdikit/utils.py
index 5d5ed8be..8cbde80d 100644
--- a/bdikit/utils.py
+++ b/bdikit/utils.py
@@ -117,7 +117,7 @@ def get_gdc_dataframe():
 
     return __gdc_df
 
 
-def check_gdc_cache(table: pd.DataFrame):
+def check_gdc_cache(table: pd.DataFrame, model_path: str):
     global __gdc_df, __gdc_hash
     load_gdc_data()
@@ -128,7 +128,10 @@
 
     # check if table for computing embedding is the same as the GDC table we have in resources
    if table_hash == __gdc_hash:
-        df_hash_file = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, __gdc_hash)
+        model_name = os.path.basename(model_path)
+        cache_model_path = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name)
+        df_hash_file = os.path.join(cache_model_path, __gdc_hash)
+
         # Found file in cache
         if os.path.isfile(df_hash_file):
             try:
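
A minimal usage sketch of the cache flow this series introduces, assuming all three patches are applied. The sample model path, the placeholder vectors, and the assertion are illustrative assumptions for demonstration, not code from the patches:

import pandas as pd
from bdikit.utils import (
    check_gdc_cache,
    get_gdc_dataframe,
    hash_dataframe,
    write_embeddings_to_cache,
)

# The cache only activates when the input hashes identically to the bundled
# GDC table, so we use that table itself here.
gdc_table = get_gdc_dataframe()
print(hash_dataframe(gdc_table))  # SHA-256 over the header line plus all rows

# After PATCH 3/3 the cache path is also scoped by model name, so embeddings
# produced by different models do not collide. The path below is hypothetical.
model_path = "/tmp/models/bdi-cl-v0.2.pt"

embedding_file, embeddings = check_gdc_cache(gdc_table, model_path)

if embeddings is None and embedding_file is not None:
    # First run: no cache file yet. In a real run these vectors come from
    # _inference_on_tables; zeros stand in here purely for illustration.
    embeddings = [[0.0] * 4 for _ in gdc_table.columns]
    write_embeddings_to_cache(embedding_file, embeddings)

# A second lookup now hits the on-disk file and returns the stored vectors,
# one per column of the GDC table.
_, cached = check_gdc_cache(gdc_table, model_path)
assert cached is not None and len(cached) == len(gdc_table.columns)

Because the cache key is the content hash of the whole table, any change to resource/gdc_table.csv automatically invalidates previously cached embeddings rather than serving stale vectors.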