feat: Adding embedding cache for gdc case #76

Merged · 4 commits · Jul 25, 2024

Changes from all commits
bdikit/download.py: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
    "bdi-cl-v0.2": "https://nyu.box.com/shared/static/1vdc28kzbpoj6ey95bksaww541p9gj31.pt",
}

BDIKIT_EMBEDDINGS_CACHE_DIR = os.path.join(BDIKIT_CACHE_DIR, "embeddings")


def download_file_url(url: str, destination: str):
    # start the download stream
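For orientation, this constant feeds the cache-file path assembled in check_gdc_cache (added to bdikit/utils.py below); a minimal sketch of the resulting layout, with placeholder model name and hash:

import os
from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR

# Placeholder values, for illustration only:
model_name = "bdi-cl-v0.2"                # basename taken from model_path
table_hash = "<sha256-of-the-gdc-table>"  # produced by hash_dataframe()

# Path shape used by check_gdc_cache() below:
embedding_file = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name, table_hash)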
bdikit/models/contrastive_learning/cl_api.py: 15 additions & 3 deletions
@@ -13,9 +13,9 @@
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from bdikit.download import get_cached_model_or_download
from bdikit.utils import check_gdc_cache, write_embeddings_to_cache


dir_path = os.path.dirname(os.path.realpath(__file__))
GDC_TABLE_PATH = os.path.join(dir_path, "../../../../resource/gdc_table.csv")
DEFAULT_CL_MODEL = "bdi-cl-v0.2"


@@ -106,14 +106,26 @@ def _sample_to_15_rows(self, table: pd.DataFrame):
        return table

    def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:

        embedding_file, embeddings = check_gdc_cache(table, self.model_path)

        if embeddings is not None:
            print(f"Table features loaded for {len(table.columns)} columns")
            return embeddings

        tables = []
        for _, column in enumerate(table.columns):
            curr_table = pd.DataFrame(table[column])
            curr_table = self._sample_to_15_rows(curr_table)
            tables.append(curr_table)
        vectors = self._inference_on_tables(tables)
        print(f"Table features extracted from {len(table.columns)} columns")
        embeddings = [vec[-1] for vec in vectors]

        if embedding_file is not None:
            write_embeddings_to_cache(embedding_file, embeddings)

        return embeddings

    def _inference_on_tables(self, tables: List[pd.DataFrame]) -> List[List]:
        total = len(tables)
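The net effect of the change above, as a hedged sketch of the read-through cache (compute_embeddings is a hypothetical stand-in for the matcher's internal inference, not a bdikit function):

import pandas as pd
from bdikit.utils import check_gdc_cache, write_embeddings_to_cache

def load_table_tokens_sketch(table: pd.DataFrame, model_path: str):
    # Cache hit: the input is the bundled GDC table and embeddings are on disk.
    embedding_file, embeddings = check_gdc_cache(table, model_path)
    if embeddings is not None:
        return embeddings

    embeddings = compute_embeddings(table)  # hypothetical inference step

    # embedding_file is non-None only when the input matched the GDC table,
    # so arbitrary user tables are never written to the cache.
    if embedding_file is not None:
        write_embeddings_to_cache(embedding_file, embeddings)
    return embeddings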
bdikit/utils.py: 79 additions & 0 deletions
@@ -1,7 +1,15 @@
import json
from os.path import join, dirname
import pandas as pd
import hashlib
import os
from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR

GDC_SCHEMA_PATH = join(dirname(__file__), "./resource/gdc_schema.json")
GDC_TABLE_PATH = join(dirname(__file__), "./resource/gdc_table.csv")

__gdc_df = None
__gdc_hash = None


def read_gdc_schema():
@@ -70,3 +78,74 @@ def get_gdc_layered_metadata():
        metadata[key] = (subschema, data)

    return metadata


def hash_dataframe(df: pd.DataFrame) -> str:

    hash_object = hashlib.sha256()

    columns_string = ",".join(df.columns) + "\n"
    hash_object.update(columns_string.encode())

    for row in df.itertuples(index=False, name=None):
        row_string = ",".join(map(str, row)) + "\n"
        hash_object.update(row_string.encode())

    return hash_object.hexdigest()
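A quick sanity check of the hashing contract: both column names and row values feed the digest, so equal content hashes equally and any edit to either produces a new hash.

import pandas as pd
from bdikit.utils import hash_dataframe

a = pd.DataFrame({"x": [1, 2], "y": ["u", "v"]})
b = pd.DataFrame({"x": [1, 2], "y": ["u", "v"]})
assert hash_dataframe(a) == hash_dataframe(b)  # same content, same digest

c = a.rename(columns={"y": "z"})
assert hash_dataframe(a) != hash_dataframe(c)  # header change, new digest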


def write_embeddings_to_cache(embedding_file: str, embeddings: list):

    os.makedirs(os.path.dirname(embedding_file), exist_ok=True)

    with open(embedding_file, "w") as file:
        for vec in embeddings:
            file.write(",".join([str(val) for val in vec]) + "\n")
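A round-trip sketch under an assumed temporary path; the read side mirrors the parsing done in check_gdc_cache below.

import os
import tempfile
from bdikit.utils import write_embeddings_to_cache

vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
path = os.path.join(tempfile.mkdtemp(), "demo-embeddings")  # hypothetical location
write_embeddings_to_cache(path, vectors)

with open(path) as f:
    loaded = [[float(v) for v in line.split(",")] for line in f if line.strip()]
assert loaded == vectors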


def load_gdc_data():
    global __gdc_df, __gdc_hash
    if __gdc_df is None or __gdc_hash is None:
        __gdc_df = pd.read_csv(GDC_TABLE_PATH)
        __gdc_hash = hash_dataframe(__gdc_df)


def get_gdc_dataframe():
    global __gdc_df
    load_gdc_data()

    return __gdc_df


def check_gdc_cache(table: pd.DataFrame, model_path: str):
    global __gdc_df, __gdc_hash
    load_gdc_data()

    table_hash = hash_dataframe(table)

    df_hash_file = None
    features = None

    # check if table for computing embedding is the same as the GDC table we have in resources
    if table_hash == __gdc_hash:
        model_name = model_path.split("/")[-1]
        cache_model_path = os.path.join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name)
        df_hash_file = os.path.join(cache_model_path, __gdc_hash)

        # Found file in cache
        if os.path.isfile(df_hash_file):
            try:
                # Load embeddings from disk
                with open(df_hash_file, "r") as file:
                    features = [
                        [float(val) for val in vec.split(",")]
                        for vec in file.read().split("\n")
                        if vec.strip()
                    ]
                if len(features) != len(__gdc_df.columns):
                    features = None
                    raise ValueError("Mismatch in the number of features")
            except Exception as e:
                print(f"Error loading features from cache: {e}")
                features = None

    return df_hash_file, features
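The return contract, illustrated with the bundled table (model_path is only split for its basename, so a plain model name works as a placeholder): a cache path is returned whenever the input is the GDC table, while features stays None until a first run has populated the file.

from bdikit.utils import check_gdc_cache, get_gdc_dataframe

gdc_table = get_gdc_dataframe()
path, features = check_gdc_cache(gdc_table, "bdi-cl-v0.2")  # placeholder model_path
print(path is not None)      # True: hash matches the bundled GDC table
print(features is not None)  # True only after embeddings were cached once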