Skip to content

Commit

Permalink
Merge pull request #323 from mims-harvard/geneformer_server
Browse files Browse the repository at this point in the history
Model server with Geneformer pilot: tokenizer
  • Loading branch information
amva13 authored Oct 20, 2024
2 parents cf88543 + e03614b commit 7ab1fc1
Show file tree
Hide file tree
Showing 11 changed files with 253 additions and 5 deletions.
10 changes: 8 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ workflows:
jobs:
test-3.9:
docker:
- image: circleci/python:3.9
- image: circleci/python:3.10

working_directory: ~/repo

Expand All @@ -26,6 +26,12 @@ jobs:
# fallback to using the latest cache if no exact match is found
- v1-py3-dependencies-

- run:
name: Install git-lfs
command: |
sudo apt-get install git-lfs
git lfs install
- run:
name: install dependencies
command: |
Expand All @@ -49,7 +55,7 @@ jobs:
no_output_timeout: 30m
command: |
. venv/bin/activate
pytest --ignore=tdc/test/dev_tests/ --ignore=tdc/test/test_resources.py --ignore=tdc/test/test_dataloaders.py
pytest --ignore=tdc/test/dev_tests/ --ignore=tdc/test/test_resources.py --ignore=tdc/test/test_dataloaders.py --ignore=tdc/test/test_model_server.py
- store_artifacts:
path: test-reports
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/conda-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ jobs:
- name: Set up Python version
uses: actions/setup-python@v1
with:
python-version: '3.9'
python-version: '3.10'

- name: Install git-lfs
run: |
sudo apt-get install git-lfs
git lfs install
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2
Expand Down
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- mygene=3.2.2
- numpy=1.26.4
- openpyxl=3.0.10
- python=3.9.13
- python=3.10
- pip=23.3.1
- pandas=2.1.4
- requests=2.31.0
Expand Down Expand Up @@ -40,6 +40,7 @@ dependencies:
- torchvision==0.16.1
- transformers==4.43.4
- yapf==0.40.2
- git+https://github.com/amva13/geneformer.git@main#egg=geneformer

variables:
KMP_DUPLICATE_LIB_OK: "TRUE"
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ pydantic>=2.6.3,<3.0.0
rdkit>=2023.9.5,<2024.3.1
tiledbsoma>=1.7.2,<2.0.0
yapf>=0.40.2,<1.0.0

# github packages
git+https://github.com/amva13/geneformer.git@main#egg=geneformer
2 changes: 1 addition & 1 deletion tdc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .evaluator import Evaluator
from .oracles import Oracle
from .benchmark_deprecated import BenchmarkGroup
from .tdc_hf import tdc_hf_interface
from .model_server.tdc_hf import tdc_hf_interface
from tdc.utils.knowledge_graph import KnowledgeGraph
6 changes: 6 additions & 0 deletions tdc/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,9 @@ def get_task2category():
"pinnacle_output8": "zip",
"pinnacle_output9": "zip",
"pinnacle_output10": "zip",
"geneformer_gene_median_dictionary": "pkl",
"geneformer_gene_name_id_dict": "pkl",
"geneformer_token_dictionary": "pkl",
}

name2id = {
Expand Down Expand Up @@ -1124,6 +1127,9 @@ def get_task2category():
"pinnacle_output8": 10431074,
"pinnacle_output9": 10431075,
"pinnacle_output10": 10431081,
"geneformer_gene_median_dictionary": 10626278,
"geneformer_gene_name_id_dict": 10626276,
"geneformer_token_dictionary": 10626277,
}

oracle2type = {
Expand Down
Empty file added tdc/model_server/__init__.py
Empty file.
File renamed without changes.
Empty file.
117 changes: 117 additions & 0 deletions tdc/model_server/tokenizers/geneformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import scipy.sparse as sp

from geneformer import TranscriptomeTokenizer
from ...utils.load import pd_load, download_wrapper


class GeneformerTokenizer(TranscriptomeTokenizer):
"""
Uses Geneformer Utils to parse zero-shot model server requests for tokenizing single-cell gene expression data.
Tokenizer source code: https://github.com/amva13/geneformer/blob/main/geneformer/tokenizer.py
"""

def __init__(
self,
path=None,
custom_attr_name_dict=None,
nproc=1,
):
path = path or "./data"
download_wrapper("geneformer_gene_median_dictionary", path,
["geneformer_gene_median_dictionary"])
download_wrapper("geneformer_gene_name_id_dict", path,
["geneformer_gene_name_id_dict"])
download_wrapper("geneformer_token_dictionary", path,
["geneformer_token_dictionary"])
self.gene_median_dict = pd_load("geneformer_gene_median_dictionary",
path=path)
self.gene_name_id_dict = pd_load("geneformer_gene_name_id_dict",
path=path)
self.gene_token_dict = pd_load("geneformer_token_dictionary", path=path)
self.custom_attr_name_dict = custom_attr_name_dict
self.nproc = nproc

# gene keys for full vocabulary
self.gene_keys = list(self.gene_median_dict.keys())

# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
self.genelist_dict = dict(
zip(self.gene_keys, [True] * len(self.gene_keys)))

@classmethod
def rank_genes(cls, gene_vector, gene_tokens):
"""
Rank gene expression vector.
"""
# sort by median-scaled gene values
sorted_indices = np.argsort(-gene_vector)
return gene_tokens[sorted_indices]

def tokenize_cell_vectors(self,
cell_vector_adata,
target_sum=10_000,
chunk_size=512,
ensembl_id="ensembl_id"):
"""
Tokenizing single-cell gene expression vectors formatted as anndata types.
"""
adata = cell_vector_adata
if self.custom_attr_name_dict is not None:
file_cell_metadata = {
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
}

coding_miRNA_loc = np.where([
self.genelist_dict.get(i, False) for i in adata.var[ensembl_id]
])[0]
norm_factor_vector = np.array([
self.gene_median_dict[i]
for i in adata.var[ensembl_id][coding_miRNA_loc]
])
coding_miRNA_ids = adata.var[ensembl_id][coding_miRNA_loc]
coding_miRNA_tokens = np.array(
[self.gene_token_dict[i] for i in coding_miRNA_ids])

try:
_ = adata.obs["filter_pass"]
except KeyError:
var_exists = False
else:
var_exists = True

if var_exists:
filter_pass_loc = np.where(
[i == 1 for i in adata.obs["filter_pass"]])[0]
elif not var_exists:
print(
f"The anndata object has no column attribute 'filter_pass'; tokenizing all cells."
)
filter_pass_loc = np.array([i for i in range(adata.shape[0])])

tokenized_cells = []

for i in range(0, len(filter_pass_loc), chunk_size):
idx = filter_pass_loc[i:i + chunk_size]

n_counts = adata[idx].obs['ncounts'].values[:, None]
X_view = adata[idx, coding_miRNA_loc].X
X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
X_norm = sp.csr_matrix(X_norm)

tokenized_cells += [
self.rank_genes(X_norm[i].data,
coding_miRNA_tokens[X_norm[i].indices])
for i in range(X_norm.shape[0])
]

# add custom attributes for subview to dict
if self.custom_attr_name_dict is not None:
for k in file_cell_metadata.keys():
file_cell_metadata[k] += adata[idx].obs[k].tolist()
else:
file_cell_metadata = None

return tokenized_cells, file_cell_metadata
110 changes: 110 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-

import os
import sys

import unittest
import shutil
import pytest

# temporary solution for relative imports in case TDC is not installed
# if TDC is installed, no need to use the following line
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
# TODO: add verification for the generation other than simple integration

from tdc.resource import cellxgene_census
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer

import requests


def get_target_from_chembl(chembl_id):
# Query ChEMBL API for target information
chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/target/{chembl_id}.json"
response = requests.get(chembl_url)

if response.status_code == 200:
data = response.json()
# Extract UniProt ID from the ChEMBL target info
for component in data.get('target_components', []):
for xref in component.get('target_component_xrefs', []):
if xref['xref_src_db'] == 'UniProt':
return xref['xref_id']
else:
raise ValueError(f"ChEMBL ID {chembl_id} not found or invalid.")
return None


def get_ensembl_from_uniprot(uniprot_id):
# Query UniProt API to get Ensembl ID from UniProt ID
uniprot_url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
response = requests.get(uniprot_url)

if response.status_code == 200:
data = response.json()
# Extract Ensembl Gene ID from the cross-references
for xref in data.get('dbReferences', []):
if xref['type'] == 'Ensembl':
return xref['id']
else:
raise ValueError(f"UniProt ID {uniprot_id} not found or invalid.")
return None


def get_ensembl_id_from_chembl_id(chembl_id):
try:
# Step 1: Get UniProt ID from ChEMBL
uniprot_id = get_target_from_chembl(chembl_id)
if not uniprot_id:
return f"No UniProt ID found for ChEMBL ID {chembl_id}"

# Step 2: Get Ensembl ID from UniProt
ensembl_id = get_ensembl_from_uniprot(uniprot_id)
if not ensembl_id:
return f"No Ensembl ID found for UniProt ID {uniprot_id}"

return f"Ensembl ID for ChEMBL ID {chembl_id}: {ensembl_id}"
except Exception as e:
return str(e)


class TestModelServer(unittest.TestCase):

def setUp(self):
print(os.getcwd())
self.resource = cellxgene_census.CensusResource()

def testGeneformerTokenizer(self):
import anndata
from tdc.multi_pred.perturboutcome import PerturbOutcome
test_loader = PerturbOutcome(
name="scperturb_drug_AissaBenevolenskaya2021")
adata = test_loader.adata
print("swapping obs and var because scperturb violated convention...")
adata_flipped = anndata.AnnData(adata.X.T)
adata_flipped.obs = adata.var
adata_flipped.var = adata.obs
adata = adata_flipped
print("swap complete")
print("adding ensembl ids...")
adata.var["ensembl_id"] = adata.var["chembl-ID"].apply(
get_ensembl_id_from_chembl_id)
print("added ensembl_id column")

print(type(adata.var))
print(adata.var.columns)
print(type(adata.obs))
print(adata.obs.columns)
print("initializing tokenizer")
tokenizer = GeneformerTokenizer()
print("testing tokenizer")
x = tokenizer.tokenize_cell_vectors(adata)
assert x

def tearDown(self):
try:
print(os.getcwd())
shutil.rmtree(os.path.join(os.getcwd(), "data"))
except:
pass

0 comments on commit 7ab1fc1

Please sign in to comment.