init commit

amva13 committed Oct 20, 2024
1 parent cf88543 commit 1353929
Showing 8 changed files with 176 additions and 2 deletions.
4 changes: 3 additions & 1 deletion environment.yml
@@ -12,7 +12,7 @@ dependencies:
- mygene=3.2.2
- numpy=1.26.4
- openpyxl=3.0.10
- python=3.9.13
- python=3.10
- pip=23.3.1
- pandas=2.1.4
- requests=2.31.0
@@ -43,3 +43,5 @@ dependencies:

variables:
KMP_DUPLICATE_LIB_OK: "TRUE"

# install geneformer via script https://github.com/jkobject/geneformer/tree/main
2 changes: 1 addition & 1 deletion tdc/__init__.py
@@ -1,5 +1,5 @@
from .evaluator import Evaluator
from .oracles import Oracle
from .benchmark_deprecated import BenchmarkGroup
from .tdc_hf import tdc_hf_interface
from .model_server.tdc_hf import tdc_hf_interface
from tdc.utils.knowledge_graph import KnowledgeGraph
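
Because the package __init__ still re-exports the symbol, downstream code keeps the same top-level import even though the module now lives under model_server; a one-line sketch of the unchanged public path:

from tdc import tdc_hf_interface  # now backed by tdc/model_server/tdc_hf.py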
6 changes: 6 additions & 0 deletions tdc/metadata.py
@@ -937,6 +937,9 @@ def get_task2category():
"pinnacle_output8": "zip",
"pinnacle_output9": "zip",
"pinnacle_output10": "zip",
"geneformer_gene_median_dictionary": "pkl",
"geneformer_gene_name_id_dict": "pkl",
"geneformer_token_dictionary": "pkl",
}

name2id = {
@@ -1124,6 +1127,9 @@ def get_task2category():
"pinnacle_output8": 10431074,
"pinnacle_output9": 10431075,
"pinnacle_output10": 10431081,
"geneformer_gene_median_dictionary": 10626278,
"geneformer_gene_name_id_dict": 10626276,
"geneformer_token_dictionary": 10626277,
}

oracle2type = {
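
The first hunk registers the file format of each new Geneformer artifact and the second (name2id) maps it to its Harvard Dataverse file ID. A minimal sketch of how a downloader could resolve these entries; the format map's name and the helper below are illustrative, not TDC's actual loader:

name2format = {"geneformer_gene_median_dictionary": "pkl"}  # assumed name for the format map
name2id = {"geneformer_gene_median_dictionary": 10626278}

def artifact_url(name):
    # standard Dataverse file-access endpoint; the mapped integer is the file ID
    return "https://dataverse.harvard.edu/api/access/datafile/" + str(name2id[name])
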
Empty file added tdc/model_server/__init__.py
Empty file.
File renamed without changes: tdc/tdc_hf.py → tdc/model_server/tdc_hf.py
Empty file.
109 changes: 109 additions & 0 deletions tdc/model_server/tokenizers/geneformer.py
@@ -0,0 +1,109 @@
import numpy as np
import scipy.sparse as sp

from geneformer import TranscriptomeTokenizer
from ...utils.load import pd_load, download_wrapper

class GeneformerTokenizer(TranscriptomeTokenizer):
"""
    Uses Geneformer utilities to parse zero-shot model-server requests for tokenizing single-cell gene expression data.
Geneformer tokenizer source code: https://github.com/jkobject/geneformer/blob/main/geneformer/tokenizer.py
"""

    def __init__(self, path=None, custom_attr_name_dict=None, nproc=1):
path = path or "./data"
download_wrapper("geneformer_gene_median_dictionary", path, ["geneformer_gene_median_dictionary"])
download_wrapper("geneformer_gene_name_id_dict", path, ["geneformer_gene_name_id_dict"])
download_wrapper("geneformer_token_dictionary", path, ["geneformer_token_dictionary"])
self.gene_median_dict = pd_load("geneformer_gene_median_dictionary", path=path)
self.gene_name_id_dict = pd_load("geneformer_gene_name_id_dict", path=path)
self.gene_token_dict = pd_load("geneformer_token_dictionary", path=path)
self.custom_attr_name_dict = custom_attr_name_dict
self.nproc = nproc

# gene keys for full vocabulary
self.gene_keys = list(self.gene_median_dict.keys())

# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    @staticmethod
    def rank_genes(gene_vector, gene_tokens):
        """
        Rank gene tokens by descending median-scaled expression value.
        """
# sort by median-scaled gene values
sorted_indices = np.argsort(-gene_vector)
return gene_tokens[sorted_indices]

    def tokenize_cell_vectors(self, cell_vector_adata, target_sum=10_000, chunk_size=512, ensembl_id="ensembl_id"):
        """
        Tokenize single-cell gene expression vectors stored in an AnnData object.
        Returns a tuple of (tokenized_cells, cell_metadata).
        """
adata = cell_vector_adata
if self.custom_attr_name_dict is not None:
file_cell_metadata = {
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
}

coding_miRNA_loc = np.where(
[self.genelist_dict.get(i, False) for i in adata.var[ensembl_id]]
)[0]
norm_factor_vector = np.array(
[
self.gene_median_dict[i]
for i in adata.var[ensembl_id][coding_miRNA_loc]
]
)
coding_miRNA_ids = adata.var[ensembl_id][coding_miRNA_loc]
coding_miRNA_tokens = np.array(
[self.gene_token_dict[i] for i in coding_miRNA_ids]
)

        # tokenize only cells flagged by an optional "filter_pass" obs column; otherwise keep all cells
        if "filter_pass" in adata.obs:
            filter_pass_loc = np.where(
                [i == 1 for i in adata.obs["filter_pass"]]
            )[0]
        else:
            print(
                "The AnnData object has no 'filter_pass' obs column; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        for i in range(0, len(filter_pass_loc), chunk_size):
            idx = filter_pass_loc[i:i + chunk_size]

            # normalize each cell by its total counts (expects an "ncounts" obs column),
            # rescale to target_sum, then divide by the per-gene median factors
            n_counts = adata[idx].obs["ncounts"].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            X_norm = X_view / n_counts * target_sum / norm_factor_vector
            X_norm = sp.csr_matrix(X_norm)  # sparse rows expose the .data / .indices used below

            # rank each cell's nonzero genes by descending normalized expression
            tokenized_cells += [
                self.rank_genes(X_norm[j].data, coding_miRNA_tokens[X_norm[j].indices])
                for j in range(X_norm.shape[0])
            ]

            # accumulate custom attributes for this chunk of cells
            if self.custom_attr_name_dict is not None:
                for k in file_cell_metadata.keys():
                    file_cell_metadata[k] += adata[idx].obs[k].tolist()
            else:
                file_cell_metadata = None

        return tokenized_cells, file_cell_metadata
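
For reference, a minimal usage sketch of the new tokenizer, assuming an AnnData object whose var table has an "ensembl_id" column and whose obs table has an "ncounts" column (the input file name is hypothetical):

import anndata as ad
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer

adata = ad.read_h5ad("cells.h5ad")  # hypothetical file with var["ensembl_id"] and obs["ncounts"]
tokenizer = GeneformerTokenizer(path="./data")
tokenized_cells, cell_metadata = tokenizer.tokenize_cell_vectors(adata)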

57 changes: 57 additions & 0 deletions tdc/test/test_model_server.py
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-

import os
import sys

import unittest
import shutil
import pytest

# temporary solution for relative imports in case TDC is not installed
# if TDC is installed, no need to use the following line
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
# TODO: add verification for the generation other than simple integration

from tdc.resource import cellxgene_census
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer

class TestModelServer(unittest.TestCase):

def setUp(self):
print(os.getcwd())
self.resource = cellxgene_census.CensusResource()

def testGeneformerTokenizer(self):
# genes = ['ENSG00000161798', 'ENSG00000188229']
# cell_types = ['mucus secreting cell', 'neuroendocrine cell']
# obs_cols = ["dataset_id", "assay", "suspension_type", "sex", "tissue_general", "tissue", "cell_type", "ncounts"]
# adata = self.resource.gget_czi_cellxgene(
# ensembl=True,
# gene=genes,
# cell_type=cell_types,
# column_names=obs_cols,
# )
# TODO: scperturb is using chembl, NOT ensembl. geneformer assumes ensembl. can fix by going back to cellxgene and not normalizing
from tdc.multi_pred.perturboutcome import PerturbOutcome
test_loader = PerturbOutcome(
name="scperturb_drug_AissaBenevolenskaya2021")
adata = test_loader.adata
print(type(adata.var))
print(adata.var.columns)
print(type(adata.obs))
print(adata.obs.columns)
print("initializing tokenizer")
tokenizer = GeneformerTokenizer()
print("testing tokenizer")
x = tokenizer.tokenize_cell_vectors(adata)
assert x

    def tearDown(self):
        try:
            print(os.getcwd())
            shutil.rmtree(os.path.join(os.getcwd(), "data"))
        except OSError:
            # nothing to clean up if the data directory was never created
            pass
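
# A conventional main guard (assumed addition, not part of this commit) so the suite
# can also be run directly with python in addition to pytest:
if __name__ == "__main__":
    unittest.main()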

