-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
176 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from .evaluator import Evaluator | ||
from .oracles import Oracle | ||
from .benchmark_deprecated import BenchmarkGroup | ||
from .tdc_hf import tdc_hf_interface | ||
from .model_server.tdc_hf import tdc_hf_interface | ||
from tdc.utils.knowledge_graph import KnowledgeGraph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import numpy as np | ||
import scipy.sparse as sp | ||
|
||
from geneformer import TranscriptomeTokenizer | ||
from ...utils.load import pd_load, download_wrapper | ||
|
||
class GeneformerTokenizer(TranscriptomeTokenizer):
    """
    Uses Geneformer Utils to parse zero-shot model server requests for tokenizing single-cell gene expression data.

    Geneformer tokenizer source code: https://github.com/jkobject/geneformer/blob/main/geneformer/tokenizer.py
    """

    def __init__(self, path=None, custom_attr_name_dict=None, nproc=1,):
        """
        Download the Geneformer dictionaries and load them onto the instance.

        Args:
            path: Directory the dictionaries are downloaded to and loaded
                from. Falls back to "./data" when falsy.
            custom_attr_name_dict: Optional dict whose keys name the
                ``adata.obs`` columns to collect alongside the tokens.
            nproc: Stored on the instance; not used by the code in this class.

        NOTE(review): the parent initializer is not invoked here; the
        dictionaries are loaded through the TDC helpers instead -- confirm
        this is intended.
        """
        path = path or "./data"
        resource_names = (
            "geneformer_gene_median_dictionary",
            "geneformer_gene_name_id_dict",
            "geneformer_token_dictionary",
        )
        # Fetch every resource first, then load, mirroring the original order.
        for resource_name in resource_names:
            download_wrapper(resource_name, path, [resource_name])
        self.gene_median_dict = pd_load("geneformer_gene_median_dictionary", path=path)
        self.gene_name_id_dict = pd_load("geneformer_gene_name_id_dict", path=path)
        self.gene_token_dict = pd_load("geneformer_token_dictionary", path=path)
        self.custom_attr_name_dict = custom_attr_name_dict
        self.nproc = nproc

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    @staticmethod
    def rank_genes(gene_vector, gene_tokens):
        """
        Rank gene expression vector.

        Args:
            gene_vector: NumPy array of (median-scaled) expression values.
            gene_tokens: NumPy array of tokens aligned with ``gene_vector``.

        Returns:
            ``gene_tokens`` reordered by descending ``gene_vector`` value.
        """
        # A static method: the function takes no ``cls``/``self``, so binding it
        # as a classmethod made every ``self.rank_genes(a, b)`` call fail.
        # sort by median-scaled gene values
        sorted_indices = np.argsort(-gene_vector)
        return gene_tokens[sorted_indices]

    def tokenize_cell_vectors(self, cell_vector_adata, target_sum=10_000, chunk_size=512, ensembl_id="ensembl_id"):
        """
        Tokenizing single-cell gene expression vectors formatted as anndata types.

        Args:
            cell_vector_adata: AnnData-like object. ``var[ensembl_id]`` is read
                for gene identifiers, ``obs['ncounts']`` for per-cell totals,
                and ``obs['filter_pass']`` (optional) to select cells.
            target_sum: Scale applied after dividing by the per-cell counts.
            chunk_size: Number of cells processed per iteration.
            ensembl_id: Name of the ``var`` column holding gene identifiers.

        Returns:
            A tuple ``(tokenized_cells, file_cell_metadata)``.
            ``tokenized_cells`` is a list with one token array per selected
            cell; ``file_cell_metadata`` is a dict of lists keyed like
            ``custom_attr_name_dict``, or ``None`` when no custom attributes
            were configured.
        """
        adata = cell_vector_adata

        # Initialise up front so the return value is defined even when no
        # cell is selected and the chunk loop never runs.
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }
        else:
            file_cell_metadata = None

        # Work on a plain array so the integer locations below are always
        # positional, whatever index ``adata.var`` carries.
        gene_ids = np.asarray(adata.var[ensembl_id])
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(gene_id, False) for gene_id in gene_ids]
        )[0]
        coding_miRNA_ids = gene_ids[coding_miRNA_loc]
        norm_factor_vector = np.array(
            [self.gene_median_dict[gene_id] for gene_id in coding_miRNA_ids]
        )
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[gene_id] for gene_id in coding_miRNA_ids]
        )

        if "filter_pass" in adata.obs:
            filter_pass_loc = np.where(
                [flag == 1 for flag in adata.obs["filter_pass"]]
            )[0]
        else:
            print(
                "The anndata object has no column attribute 'filter_pass'; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        for start in range(0, len(filter_pass_loc), chunk_size):
            idx = filter_pass_loc[start:start + chunk_size]
            chunk_obs = adata[idx].obs

            n_counts = chunk_obs['ncounts'].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
            # The division may yield a dense result; the per-row ``.data`` and
            # ``.indices`` used below require CSR storage.
            X_norm = sp.csr_matrix(X_norm)

            for row in range(X_norm.shape[0]):
                cell = X_norm[row]
                tokenized_cells.append(
                    self.rank_genes(cell.data, coding_miRNA_tokens[cell.indices])
                )

            # add custom attributes for subview to dict
            if file_cell_metadata is not None:
                for k in file_cell_metadata:
                    file_cell_metadata[k] += chunk_obs[k].tolist()

        return tokenized_cells, file_cell_metadata
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
import sys | ||
|
||
import unittest | ||
import shutil | ||
import pytest | ||
|
||
# temporary solution for relative imports in case TDC is not installed | ||
# if TDC is installed, no need to use the following line | ||
sys.path.append( | ||
os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
# TODO: add verification for the generation other than simple integration | ||
|
||
from tdc.resource import cellxgene_census | ||
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer | ||
|
||
class TestModelServer(unittest.TestCase):
    """Integration test for the model-server Geneformer tokenizer."""

    def setUp(self):
        """Print the working directory and create a CensusResource."""
        print(os.getcwd())
        self.resource = cellxgene_census.CensusResource()

    def testGeneformerTokenizer(self):
        """Tokenize a PerturbOutcome dataset and check tokens were produced."""
        # genes = ['ENSG00000161798', 'ENSG00000188229']
        # cell_types = ['mucus secreting cell', 'neuroendocrine cell']
        # obs_cols = ["dataset_id", "assay", "suspension_type", "sex", "tissue_general", "tissue", "cell_type", "ncounts"]
        # adata = self.resource.gget_czi_cellxgene(
        #     ensembl=True,
        #     gene=genes,
        #     cell_type=cell_types,
        #     column_names=obs_cols,
        # )
        # TODO: scperturb is using chembl, NOT ensembl. geneformer assumes ensembl. can fix by going back to cellxgene and not normalizing
        from tdc.multi_pred.perturboutcome import PerturbOutcome
        test_loader = PerturbOutcome(
            name="scperturb_drug_AissaBenevolenskaya2021")
        adata = test_loader.adata
        print(type(adata.var))
        print(adata.var.columns)
        print(type(adata.obs))
        print(adata.obs.columns)
        print("initializing tokenizer")
        tokenizer = GeneformerTokenizer()
        print("testing tokenizer")
        # The tokenizer returns a 2-tuple, which is always truthy; unpack it
        # and check the token list itself so the assertion can actually fail.
        tokenized_cells, cell_metadata = tokenizer.tokenize_cell_vectors(adata)
        self.assertGreater(len(tokenized_cells), 0)
        # No custom attributes were requested, so no metadata is collected.
        self.assertIsNone(cell_metadata)

    def tearDown(self):
        """Remove the ./data download directory, if it exists (best effort)."""
        print(os.getcwd())
        # ignore_errors keeps cleanup best-effort without a bare ``except``
        # that would also hide KeyboardInterrupt/SystemExit.
        shutil.rmtree(os.path.join(os.getcwd(), "data"), ignore_errors=True)
|
||
|