-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #323 from mims-harvard/geneformer_server
Model server with Geneformer pilot: tokenizer
- Loading branch information
Showing
11 changed files
with
253 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from .evaluator import Evaluator | ||
from .oracles import Oracle | ||
from .benchmark_deprecated import BenchmarkGroup | ||
from .tdc_hf import tdc_hf_interface | ||
from .model_server.tdc_hf import tdc_hf_interface | ||
from tdc.utils.knowledge_graph import KnowledgeGraph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import numpy as np | ||
import scipy.sparse as sp | ||
|
||
from geneformer import TranscriptomeTokenizer | ||
from ...utils.load import pd_load, download_wrapper | ||
|
||
|
||
class GeneformerTokenizer(TranscriptomeTokenizer):
    """
    Uses Geneformer Utils to parse zero-shot model server requests for tokenizing single-cell gene expression data.

    Tokenizer source code: https://github.com/amva13/geneformer/blob/main/geneformer/tokenizer.py
    """

    def __init__(
        self,
        path=None,
        custom_attr_name_dict=None,
        nproc=1,
    ):
        """
        Download and load the Geneformer dictionaries used for tokenization.

        The parent ``__init__`` is not called; every attribute this class
        reads is assigned below from the downloaded resources.

        Args:
            path: Directory the dictionaries are downloaded to and loaded
                from. Defaults to ``"./data"`` when falsy.
            custom_attr_name_dict: Optional mapping whose keys name the
                ``adata.obs`` columns to collect alongside the tokens.
            nproc: Stored on the instance as ``self.nproc``.
        """
        path = path or "./data"
        for resource in (
                "geneformer_gene_median_dictionary",
                "geneformer_gene_name_id_dict",
                "geneformer_token_dictionary",
        ):
            download_wrapper(resource, path, [resource])

        self.gene_median_dict = pd_load("geneformer_gene_median_dictionary",
                                        path=path)
        self.gene_name_id_dict = pd_load("geneformer_gene_name_id_dict",
                                         path=path)
        self.gene_token_dict = pd_load("geneformer_token_dictionary", path=path)
        self.custom_attr_name_dict = custom_attr_name_dict
        self.nproc = nproc

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict.fromkeys(self.gene_keys, True)

    @classmethod
    def rank_genes(cls, gene_vector, gene_tokens):
        """
        Rank gene expression vector.

        Returns ``gene_tokens`` reordered so that the token belonging to the
        largest value of ``gene_vector`` comes first.
        """
        # sort by median-scaled gene values (negated for descending order)
        sorted_indices = np.argsort(-gene_vector)
        return gene_tokens[sorted_indices]

    def tokenize_cell_vectors(self,
                              cell_vector_adata,
                              target_sum=10_000,
                              chunk_size=512,
                              ensembl_id="ensembl_id"):
        """
        Tokenizing single-cell gene expression vectors formatted as anndata types.

        Args:
            cell_vector_adata: AnnData-like object. ``var[ensembl_id]`` holds
                the gene identifiers and ``obs['ncounts']`` the per-cell
                totals used for normalization. If ``obs`` has a
                ``filter_pass`` column, only rows where it equals 1 are
                tokenized.
            target_sum: Scale applied after dividing by ``ncounts``.
            chunk_size: Number of cells processed per slice of the data.
            ensembl_id: Name of the ``var`` column with gene identifiers.

        Returns:
            A tuple ``(tokenized_cells, file_cell_metadata)``.
            ``tokenized_cells`` is a list with one token array per processed
            cell; ``file_cell_metadata`` maps each custom attribute key to
            the collected ``obs`` values, or is ``None`` when no custom
            attributes were configured.
        """
        adata = cell_vector_adata

        # Defined before the loop so the return value exists even when no
        # cell is processed (e.g. every cell fails ``filter_pass``).
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict
            }
        else:
            file_cell_metadata = None

        # Work on a plain array: the locations from np.where are positions,
        # and array indexing is positional regardless of the var index.
        gene_ids = np.asarray(adata.var[ensembl_id])
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(gene, False) for gene in gene_ids])[0]
        coding_miRNA_ids = gene_ids[coding_miRNA_loc]
        norm_factor_vector = np.array(
            [self.gene_median_dict[gene] for gene in coding_miRNA_ids])
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[gene] for gene in coding_miRNA_ids])

        if "filter_pass" in adata.obs:
            filter_pass_loc = np.where(
                [flag == 1 for flag in adata.obs["filter_pass"]])[0]
        else:
            print(
                "The anndata object has no column attribute 'filter_pass'; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        for start in range(0, len(filter_pass_loc), chunk_size):
            idx = filter_pass_loc[start:start + chunk_size]
            chunk = adata[idx]

            # column vector so the division broadcasts across genes
            n_counts = chunk.obs['ncounts'].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            X_norm = sp.csr_matrix(X_view / n_counts * target_sum /
                                   norm_factor_vector)

            for row_number in range(X_norm.shape[0]):
                # slice each row once; .data/.indices cover its stored entries
                row = X_norm[row_number]
                tokenized_cells.append(
                    self.rank_genes(row.data,
                                    coding_miRNA_tokens[row.indices]))

            # add custom attributes for subview to dict
            if file_cell_metadata is not None:
                for key in file_cell_metadata:
                    file_cell_metadata[key] += chunk.obs[key].tolist()

        return tokenized_cells, file_cell_metadata
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
import sys | ||
|
||
import unittest | ||
import shutil | ||
import pytest | ||
|
||
# temporary solution for relative imports in case TDC is not installed | ||
# if TDC is installed, no need to use the following line | ||
sys.path.append( | ||
os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
# TODO: add verification for the generation other than simple integration | ||
|
||
from tdc.resource import cellxgene_census | ||
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer | ||
|
||
import requests | ||
|
||
|
||
def get_target_from_chembl(chembl_id, timeout=30):
    """
    Look up the UniProt cross-reference of a ChEMBL target.

    Args:
        chembl_id: ChEMBL identifier inserted into the target endpoint URL.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the caller indefinitely.

    Returns:
        The ``xref_id`` of the first target-component cross-reference whose
        ``xref_src_db`` is ``'UniProt'``, or ``None`` if there is none.

    Raises:
        ValueError: If the response status code is not 200.
    """
    # Query ChEMBL API for target information
    chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/target/{chembl_id}.json"
    response = requests.get(chembl_url, timeout=timeout)

    if response.status_code != 200:
        raise ValueError(f"ChEMBL ID {chembl_id} not found or invalid.")

    data = response.json()
    # Extract UniProt ID from the ChEMBL target info
    for component in data.get('target_components', []):
        for xref in component.get('target_component_xrefs', []):
            if xref.get('xref_src_db') == 'UniProt':
                return xref['xref_id']
    return None
|
||
|
||
def get_ensembl_from_uniprot(uniprot_id, timeout=30):
    """
    Look up an Ensembl identifier for a UniProt accession.

    Args:
        uniprot_id: UniProt accession inserted into the entry URL.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The Ensembl identifier from the first Ensembl cross-reference of the
        entry, or ``None`` if the entry lists none. When the cross-reference
        carries a ``GeneId`` property, that value is returned (unchanged);
        otherwise the cross-reference ``id`` is returned.

    Raises:
        ValueError: If the response status code is not 200.
    """
    # Query UniProt API to get Ensembl ID from UniProt ID
    uniprot_url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
    response = requests.get(uniprot_url, timeout=timeout)

    if response.status_code != 200:
        raise ValueError(f"UniProt ID {uniprot_id} not found or invalid.")

    data = response.json()

    # NOTE(review): rest.uniprot.org entries are expected to list
    # cross-references under 'uniProtKBCrossReferences' with a 'database'
    # field and key/value 'properties' -- confirm against the API docs.
    for xref in data.get('uniProtKBCrossReferences', []):
        if xref.get('database') != 'Ensembl':
            continue
        for prop in xref.get('properties', []):
            if prop.get('key') == 'GeneId' and prop.get('value'):
                return prop['value']
        if xref.get('id'):
            return xref['id']

    # Legacy layout: 'dbReferences' entries tagged with 'type'.
    for xref in data.get('dbReferences', []):
        if xref.get('type') == 'Ensembl':
            return xref['id']
    return None
|
||
|
||
def get_ensembl_id_from_chembl_id(chembl_id):
    """
    Map a ChEMBL identifier to an Ensembl identifier via UniProt.

    The result is meant to be stored directly as an identifier (for example
    in an ``ensembl_id`` column), so only the bare ID is returned; diagnostic
    messages are printed instead of being returned as the value.

    Args:
        chembl_id: ChEMBL identifier to resolve.

    Returns:
        The Ensembl identifier, or ``None`` when no mapping was found or a
        lookup failed.
    """
    try:
        # Step 1: Get UniProt ID from ChEMBL
        uniprot_id = get_target_from_chembl(chembl_id)
        if not uniprot_id:
            print(f"No UniProt ID found for ChEMBL ID {chembl_id}")
            return None

        # Step 2: Get Ensembl ID from UniProt
        ensembl_id = get_ensembl_from_uniprot(uniprot_id)
        if not ensembl_id:
            print(f"No Ensembl ID found for UniProt ID {uniprot_id}")
            return None

        return ensembl_id
    except Exception as e:
        # Best-effort lookup: report the failure and leave the ID missing
        # rather than aborting the caller.
        print(str(e))
        return None
|
||
|
||
class TestModelServer(unittest.TestCase):
    """Integration test for the model-server Geneformer tokenizer."""

    def setUp(self):
        print(os.getcwd())
        self.resource = cellxgene_census.CensusResource()

    def testGeneformerTokenizer(self):
        """Tokenize a scPerturb dataset end to end and check cells come back."""
        import anndata
        from tdc.multi_pred.perturboutcome import PerturbOutcome
        test_loader = PerturbOutcome(
            name="scperturb_drug_AissaBenevolenskaya2021")
        adata = test_loader.adata
        print("swapping obs and var because scperturb violated convention...")
        adata_flipped = anndata.AnnData(adata.X.T)
        adata_flipped.obs = adata.var
        adata_flipped.var = adata.obs
        adata = adata_flipped
        print("swap complete")
        print("adding ensembl ids...")
        adata.var["ensembl_id"] = adata.var["chembl-ID"].apply(
            get_ensembl_id_from_chembl_id)
        print("added ensembl_id column")

        print(type(adata.var))
        print(adata.var.columns)
        print(type(adata.obs))
        print(adata.obs.columns)
        print("initializing tokenizer")
        tokenizer = GeneformerTokenizer()
        print("testing tokenizer")
        # The tokenizer returns a 2-tuple, which is always truthy, so the
        # check has to look at the tokenized cells themselves.
        tokenized_cells, _ = tokenizer.tokenize_cell_vectors(adata)
        self.assertGreater(len(tokenized_cells), 0)

    def tearDown(self):
        # Best-effort removal of the download directory; a missing or
        # undeletable directory must not fail the test run.
        try:
            print(os.getcwd())
            shutil.rmtree(os.path.join(os.getcwd(), "data"))
        except OSError:
            pass