[python] Geneformer updates for July 2024 LTS #961

Merged
merged 52 commits into from
Jul 5, 2024
52 commits
ab747d5
first cut
mlin Jan 10, 2024
af87e86
polish
mlin Jan 10, 2024
6214f8d
Merge remote-tracking branch 'origin/main' into mlin/geneformer-strea…
mlin Jan 10, 2024
00c1f4d
smaller docker image
mlin Jan 28, 2024
82d8246
wip
mlin Jan 31, 2024
92402db
readme
mlin Jan 31, 2024
73f9b91
owls
mlin Jan 31, 2024
e176ea7
gz
mlin Jan 31, 2024
1992cf3
Merge remote-tracking branch 'origin/main' into mlin/geneformer-healt…
mlin Feb 19, 2024
45deef9
update label blocklist
mlin Feb 19, 2024
f8587ca
GeneformerTokenizer special_tokens option
mlin Jun 21, 2024
9b2bcd3
add gene_mapping
mlin Jun 21, 2024
e502210
Merge remote-tracking branch 'origin/main' into mlin/geneformer-healt…
mlin Jun 23, 2024
9cdf861
lint
mlin Jun 23, 2024
9bee474
lint
mlin Jun 23, 2024
f77163c
bump docker versions
mlin Jun 23, 2024
ccdc054
Merge branch 'mlin/geneformer-updates-jun24' into mlin/geneformer-hea…
mlin Jun 23, 2024
bd855ed
tokenizer_kwargs passthrough
mlin Jun 23, 2024
602c005
lint
mlin Jun 23, 2024
9933309
lint
mlin Jun 23, 2024
682f52a
lint
mlin Jun 23, 2024
7065636
tokenizer_kwargs passthrough WDL
mlin Jun 23, 2024
8fe66f9
workaround
mlin Jun 24, 2024
a23f562
tolerance
mlin Jun 24, 2024
4b996f8
sum mapped columns
mlin Jun 24, 2024
51aa427
buildspec.yml
mlin Jun 25, 2024
0490341
buildspec.yml
mlin Jun 25, 2024
9ad6e24
buildspec.yml
mlin Jun 25, 2024
02497c4
gf-95m from S3
mlin Jun 25, 2024
fa1fbe5
fix
mlin Jun 25, 2024
8e3167c
fix
mlin Jun 25, 2024
28dbbd1
fix
mlin Jun 25, 2024
2171ab2
fix
mlin Jun 25, 2024
c4fedb2
fix
mlin Jun 25, 2024
cfda1b5
fix
mlin Jun 25, 2024
4388b4a
ontology updates
mlin Jun 25, 2024
305d124
fix
mlin Jun 25, 2024
52d9df2
handle unknown
mlin Jun 25, 2024
105e716
passthrough model_type
mlin Jun 26, 2024
26ed332
posterity comments
mlin Jun 26, 2024
833db5d
Merge remote-tracking branch 'origin/mlin/geneformer-updates-jun24' i…
mlin Jun 27, 2024
ae04e38
use sparse binary matrix to consolidate gene counts
mlin Jun 27, 2024
79e0cec
reskip
mlin Jun 27, 2024
4ea5fcd
fix
mlin Jun 27, 2024
f076e83
lint
mlin Jun 28, 2024
fe856e3
Merge remote-tracking branch 'origin/main' into mlin/geneformer-healt…
mlin Jun 28, 2024
1ae2d91
update readme
mlin Jun 28, 2024
e044a53
update Geneformer upstream version
mlin Jul 3, 2024
de6eea6
Merge remote-tracking branch 'origin/main' into mlin/geneformer-healt…
mlin Jul 3, 2024
ea1137a
strike geneformer from requirements-dev.txt due to python>=3.10 requi…
mlin Jul 5, 2024
abbcd4b
install geneformer in CI (python >=3.10 only)
mlin Jul 5, 2024
e12a102
lint
mlin Jul 5, 2024
3 changes: 3 additions & 0 deletions .github/workflows/py-unittests.yml
@@ -41,6 +41,9 @@ jobs:
pip install --use-pep517 accumulation-tree # Geneformer dependency needs --use-pep517 for Cython
GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ./api/python/cellxgene_census/scripts/requirements-dev.txt
pip install -e './api/python/cellxgene_census/[experimental]'
- name: Install Geneformer (python >=3.10 only)
run: pip install git+https://huggingface.co/ctheodoris/Geneformer@471eefc
if: matrix.python-version != '3.8' && matrix.python-version != '3.9'
- name: Report Dependency Versions
run: pip list
- name: Test with pytest (API, main tests)
1 change: 0 additions & 1 deletion api/python/cellxgene_census/scripts/requirements-dev.txt
@@ -5,5 +5,4 @@ twine
coverage
nbqa
transformers[torch]
git+https://huggingface.co/ctheodoris/Geneformer@8df5dc1
owlready2
@@ -66,6 +66,7 @@ def gen() -> Generator[Dict[str, Any], None, None]:
self.X(self.layer_name).blockwise(axis=0, reindex_disable_on_axis=[1], size=self.block_size).scipy()
):
assert isinstance(Xblock, scipy.sparse.csr_matrix)
assert Xblock.shape[0] == len(block_cell_joinids)
for i, cell_joinid in enumerate(block_cell_joinids):
yield self.cell_item(cell_joinid, Xblock.getrow(i))
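
The new shape assertion guards against a block whose row count does not match the list of cell joinids. The following is a minimal standalone sketch of that per-row iteration pattern; `iter_cell_rows` and the toy data are hypothetical and stand in for the builder's `blockwise(...).scipy()` loop, which is not reproduced here.

```python
import numpy as np
import scipy.sparse


def iter_cell_rows(Xblock, block_cell_joinids):
    """Yield (cell_joinid, 1-row CSR matrix) pairs for one block of cells.

    Hypothetical helper mirroring the loop in the diff above.
    """
    assert isinstance(Xblock, scipy.sparse.csr_matrix)
    # Each row of the block must correspond to exactly one cell joinid.
    assert Xblock.shape[0] == len(block_cell_joinids)
    for i, cell_joinid in enumerate(block_cell_joinids):
        yield int(cell_joinid), Xblock.getrow(i)


# Toy block: 3 cells x 4 genes (made-up values, not Census data).
toy_block = scipy.sparse.csr_matrix(np.array([[0, 2, 0, 1], [3, 0, 0, 0], [0, 0, 5, 5]]))
toy_joinids = np.array([10, 11, 12])
rows = list(iter_cell_rows(toy_block, toy_joinids))
```

Without the assertion, a short block would surface as an index error from `getrow` and a long block would silently drop its extra rows.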

@@ -1,5 +1,5 @@
import pickle
from typing import Any, Dict, Optional, Sequence, Set
from typing import Any, Dict, List, Optional, Sequence, Set

import numpy as np
import numpy.typing as npt
@@ -14,7 +14,7 @@ class GeneformerTokenizer(CellDatasetBuilder):
cell in CELLxGENE Census ExperimentAxisQuery results (human).

This class requires the Geneformer package to be installed separately with:
`pip install git+https://huggingface.co/ctheodoris/Geneformer@8df5dc1`
`pip install git+https://huggingface.co/ctheodoris/Geneformer@471eefc`

Example usage:

@@ -44,11 +44,18 @@ class GeneformerTokenizer(CellDatasetBuilder):

obs_column_names: Set[str]
max_input_tokens: int

# set of gene soma_joinids corresponding to genes modeled by Geneformer:
model_gene_ids: npt.NDArray[np.int64]
model_gene_tokens: npt.NDArray[np.int64] # token for each model_gene_id
model_gene_medians: npt.NDArray[np.float64] # float for each model_gene_id
special_token: bool

# Newer versions of Geneformer have a consolidated gene list (gene_mapping_file), meaning the
# counts for one or more Census genes are to be summed to get the count for one Geneformer
# gene. model_gene_map is a sparse binary matrix to map a cell vector (or multi-cell matrix) of
# Census gene counts onto Geneformer gene counts. model_gene_map[i,j] is 1 iff the i'th Census
# gene count contributes to the j'th Geneformer gene count.
model_gene_map: scipy.sparse.coo_matrix
model_gene_tokens: npt.NDArray[np.int64] # Geneformer token for each column of model_gene_map
model_gene_medians: npt.NDArray[np.float64] # float for each column of model_gene_map
model_cls_token: Optional[np.int64] = None
model_sep_token: Optional[np.int64] = None

def __init__(
self,
@@ -57,25 +64,33 @@ def __init__(
obs_column_names: Optional[Sequence[str]] = None,
obs_attributes: Optional[Sequence[str]] = None,
max_input_tokens: int = 2048,
special_token: bool = False,
token_dictionary_file: str = "",
gene_median_file: str = "",
gene_mapping_file: str = "",
**kwargs: Any,
) -> None:
"""- `experiment`: Census Experiment to query
"""Initialize GeneformerTokenizer.

Args:
- `experiment`: Census Experiment to query
- `obs_query`: obs AxisQuery defining the set of Census cells to process (default all)
- `obs_column_names`: obs dataframe columns (cell metadata) to propagate into attributes
of each Dataset item
- `max_input_tokens`: maximum length of Geneformer input token sequence (default 2048)
- `special_token`: whether to affix separator tokens to the sequence (default False)
- `token_dictionary_file`, `gene_median_file`: pickle files supplying the mapping of
Ensembl human gene IDs onto Geneformer token numbers and median expression values.
By default, these will be loaded from the Geneformer package.
- `gene_mapping_file`: optional pickle file with mapping for Census gene IDs to model's
gene IDs
"""
if obs_attributes: # old name of obs_column_names
obs_column_names = obs_attributes

self.max_input_tokens = max_input_tokens
self.special_token = special_token
self.obs_column_names = set(obs_column_names) if obs_column_names else set()
self._load_geneformer_data(experiment, token_dictionary_file, gene_median_file)
self._load_geneformer_data(experiment, token_dictionary_file, gene_median_file, gene_mapping_file)
super().__init__(
experiment,
measurement_name="RNA",
@@ -88,14 +103,21 @@ def _load_geneformer_data(
experiment: tiledbsoma.Experiment,
token_dictionary_file: str,
gene_median_file: str,
gene_mapping_file: str,
) -> None:
"""Load (1) the experiment's genes dataframe and (2) Geneformer's static data
files for gene tokens and median expression; then, intersect them to compute
self.model_gene_{ids,tokens,medians}.
"""
# TODO: this work could be reused for all queries on this experiment

genes_df = experiment.ms["RNA"].var.read(column_names=["soma_joinid", "feature_id"]).concat().to_pandas()
genes_df = (
experiment.ms["RNA"]
.var.read(column_names=["soma_joinid", "feature_id"])
.concat()
.to_pandas()
.set_index("soma_joinid")
)

if not (token_dictionary_file and gene_median_file):
try:
@@ -104,7 +126,7 @@ def _load_geneformer_data(
# pyproject.toml can't express Geneformer git+https dependency
raise ImportError(
"Please install Geneformer with: "
"pip install git+https://huggingface.co/ctheodoris/Geneformer@8df5dc1"
"pip install git+https://huggingface.co/ctheodoris/Geneformer@471eefc"
) from None
if not token_dictionary_file:
token_dictionary_file = geneformer.tokenizer.TOKEN_DICTIONARY_FILE
@@ -115,34 +137,58 @@ def _load_geneformer_data(
with open(gene_median_file, "rb") as f:
gene_median_dict = pickle.load(f)

gene_mapping = None
if gene_mapping_file:
with open(gene_mapping_file, "rb") as f:
gene_mapping = pickle.load(f)

# compute model_gene_{ids,tokens,medians} by joining genes_df with Geneformer's
# dicts
model_gene_ids = []
model_gene_tokens = []
model_gene_medians = []
map_data = []
map_i = []
map_j = []
model_gene_id_by_ensg: Dict[str, int] = {}
model_gene_count = 0
model_gene_tokens: List[np.int64] = []
model_gene_medians: List[np.float64] = []
for gene_id, row in genes_df.iterrows():
ensg = row["feature_id"] # ENSG... gene id, which keys Geneformer's dicts
if gene_mapping is not None:
ensg = gene_mapping.get(ensg, ensg)
if ensg in gene_token_dict:
model_gene_ids.append(gene_id)
model_gene_tokens.append(gene_token_dict[ensg])
model_gene_medians.append(gene_median_dict[ensg])
self.model_gene_ids = np.array(model_gene_ids, dtype=np.int64)
if ensg not in model_gene_id_by_ensg:
model_gene_id_by_ensg[ensg] = model_gene_count
model_gene_count += 1
model_gene_tokens.append(gene_token_dict[ensg])
model_gene_medians.append(gene_median_dict[ensg])
map_data.append(1)
map_i.append(gene_id)
map_j.append(model_gene_id_by_ensg[ensg])

self.model_gene_map = scipy.sparse.coo_matrix(
(map_data, (map_i, map_j)), shape=(genes_df.index.max() + 1, model_gene_count), dtype=bool
)
self.model_gene_tokens = np.array(model_gene_tokens, dtype=np.int64)
self.model_gene_medians = np.array(model_gene_medians, dtype=np.float64)

assert len(np.unique(self.model_gene_ids)) == len(self.model_gene_ids)
assert len(np.unique(self.model_gene_tokens)) == len(self.model_gene_tokens)
assert np.all(self.model_gene_medians > 0)
# Geneformer models protein-coding and miRNA genes, so the intersection should
# be somewhere a little north of 20K.
assert len(self.model_gene_ids) > 20_000
# be north of 18K.
assert (
model_gene_count > 18_000
), f"Mismatch between Census gene IDs and Geneformer token dicts (only {model_gene_count} common genes)"

# Precompute a vector by which we'll multiply each cell's expression vector.
# The denominator normalizes by Geneformer's median expression values.
# The numerator 10K factor follows Geneformer's tokenizer; theoretically it doesn't
# affect the rank order, but is probably intended to help with numerical precision.
self.model_gene_medians_factor = 10_000.0 / self.model_gene_medians

if self.special_token:
self.model_cls_token = gene_token_dict["<cls>"]
self.model_sep_token = gene_token_dict["<sep>"]
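
To illustrate how the sparse binary `model_gene_map` consolidates counts, here is a small self-contained sketch of the same construction on toy data. The gene IDs, tokens, and mapping are invented for illustration, and the sketch uses the `@` operator where the diff uses `*` on SciPy sparse matrices; it is not the tokenizer's actual loading code.

```python
from typing import Dict, List

import numpy as np
import scipy.sparse

# Hypothetical toy inputs (not real Census or Geneformer data).
census_genes = {0: "ENSG_A", 1: "ENSG_B", 2: "ENSG_C", 3: "ENSG_D"}  # soma_joinid -> feature_id
gene_mapping = {"ENSG_B": "ENSG_A"}  # ENSG_B is consolidated into ENSG_A
gene_token_dict = {"ENSG_A": 7, "ENSG_C": 9}  # ENSG_D is not modeled

map_i: List[int] = []  # row index: Census gene soma_joinid
map_j: List[int] = []  # column index: consolidated model gene
column_by_ensg: Dict[str, int] = {}
tokens: List[int] = []
for gene_id, ensg in census_genes.items():
    ensg = gene_mapping.get(ensg, ensg)  # unmapped IDs pass through unchanged
    if ensg in gene_token_dict:
        if ensg not in column_by_ensg:
            # First time we see this model gene: allocate a new column.
            column_by_ensg[ensg] = len(column_by_ensg)
            tokens.append(gene_token_dict[ensg])
        map_i.append(gene_id)
        map_j.append(column_by_ensg[ensg])

gene_map = scipy.sparse.coo_matrix(
    (np.ones(len(map_i), dtype=bool), (map_i, map_j)),
    shape=(max(census_genes) + 1, len(column_by_ensg)),
    dtype=bool,
)

# One cell's counts over the four Census genes; the matrix product sums the
# counts of every Census gene that maps to the same model gene.
counts = scipy.sparse.csr_matrix(np.array([[2.0, 3.0, 4.0, 5.0]]))
model_counts = (counts @ gene_map).toarray()
```

In this toy case the counts for `ENSG_A` and `ENSG_B` land in the same column, and the unmodeled `ENSG_D` contributes nothing because its row of the map is empty.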

def __enter__(self) -> "GeneformerTokenizer":
super().__enter__()
# On context entry, load the necessary cell metadata (obs_df)
@@ -156,21 +202,29 @@ def cell_item(self, cell_joinid: int, cell_Xrow: scipy.sparse.csr_matrix) -> Dic
"""Given the expression vector for one cell, compute the Dataset item providing
the Geneformer inputs (token sequence and metadata).
"""
# project cell_Xrow onto model_gene_ids and normalize by row sum.
# notice we divide by the total count of the complete row (not only of the projected
# Apply model_gene_map to cell_Xrow and normalize with row sum & gene medians.
# Notice we divide by the total count of the complete row (not only of the projected
# values); this follows Geneformer's internal tokenizer.
model_counts = cell_Xrow[:, self.model_gene_ids].multiply(1.0 / cell_Xrow.sum())
assert isinstance(model_counts, scipy.sparse.csr_matrix), type(model_counts)
# assert len(model_counts.data) == np.count_nonzero(model_counts.data)
model_expr = model_counts.multiply(self.model_gene_medians_factor)
model_expr = (cell_Xrow * self.model_gene_map).multiply(self.model_gene_medians_factor / cell_Xrow.sum())
assert isinstance(model_expr, scipy.sparse.coo_matrix), type(model_expr)
# assert len(model_expr.data) == np.count_nonzero(model_expr.data)
assert model_expr.shape == (1, self.model_gene_map.shape[1])

# figure the resulting tokens in descending order of model_expr
# (use sparse model_expr.{col,data} to naturally exclude undetected genes)
token_order = model_expr.col[np.argsort(-model_expr.data)[: self.max_input_tokens]]
input_ids = self.model_gene_tokens[token_order]

if self.special_token:
# affix special tokens, dropping the last two gene tokens if necessary
if len(input_ids) == self.max_input_tokens:
input_ids = input_ids[:-1]
assert self.model_cls_token is not None
input_ids = np.insert(input_ids, 0, self.model_cls_token)
if len(input_ids) == self.max_input_tokens:
input_ids = input_ids[:-1]
assert self.model_sep_token is not None
input_ids = np.append(input_ids, self.model_sep_token)

ans = {"input_ids": input_ids, "length": len(input_ids)}
# add the requested obs attributes
for attr in self.obs_column_names:
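
The ranking and special-token steps in `cell_item` can be sketched on toy data as follows. `rank_tokens` is a hypothetical helper, the medians and token numbers are invented, and truncating the gene tokens to `max_input_tokens - 2` is a simplification that should give the same result as the two conditional drops in the diff whenever `max_input_tokens >= 2`.

```python
import numpy as np
import scipy.sparse


def rank_tokens(model_counts, total_count, medians, tokens, max_input_tokens,
                cls_token=None, sep_token=None):
    """Return tokens ordered by descending median-normalized expression."""
    factor = 10_000.0 / medians
    # Normalize by the total count of the full row and by per-gene medians.
    expr = scipy.sparse.coo_matrix(model_counts.multiply(factor / total_count))
    # Only stored (nonzero) entries are ranked, so undetected genes are excluded.
    order = expr.col[np.argsort(-expr.data)[:max_input_tokens]]
    input_ids = tokens[order]
    if cls_token is not None and sep_token is not None:
        # Make room for the two special tokens without exceeding the limit.
        input_ids = input_ids[: max_input_tokens - 2]
        input_ids = np.concatenate(([cls_token], input_ids, [sep_token]))
    return input_ids


# Toy data: 3 model genes, one cell (made-up values).
medians = np.array([1.0, 2.0, 4.0])
tokens = np.array([7, 9, 11], dtype=np.int64)
counts = scipy.sparse.csr_matrix(np.array([[4.0, 0.0, 8.0]]))
total = 20.0  # pretend the full row, including unmodeled genes, sums to 20

ids_plain = rank_tokens(counts, total, medians, tokens, max_input_tokens=4)
ids_special = rank_tokens(counts, total, medians, tokens, max_input_tokens=3,
                          cls_token=2, sep_token=3)
```

Here the first gene ranks ahead of the third despite its lower raw count, because its median is smaller; the second gene has no count and never appears in the output.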
@@ -1,3 +1,5 @@
import sys

import datasets
import pytest
import tiledbsoma
@@ -67,7 +69,7 @@ def test_GeneformerTokenizer_correctness(tmpdir: Path) -> None:
ad.write_h5ad(h5ad_dir.join("tokenizeme.h5ad"))
# run geneformer.TranscriptomeTokenizer to get "true" tokenizations
# see: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/tokenizer.py
TranscriptomeTokenizer({}).tokenize_data(h5ad_dir, tmpdir, "tk", file_format="h5ad")
TranscriptomeTokenizer({}).tokenize_data(h5ad_dir, str(tmpdir), "tk", file_format="h5ad")
true_tokens = [it["input_ids"] for it in datasets.load_from_disk(tmpdir.join("tk.dataset"))]

# check GeneformerTokenizer sequences against geneformer.TranscriptomeTokenizer's
@@ -88,6 +90,7 @@ def test_GeneformerTokenizer_correctness(tmpdir: Path) -> None:
assert identical / len(cell_ids) >= EXACT_THRESHOLD
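
The assertion above checks the fraction of cells whose token sequence exactly matches the reference tokenizer. A minimal sketch of that comparison is shown below; `exact_match_fraction` is a hypothetical helper, and the threshold value is an assumption because the diff does not show how `EXACT_THRESHOLD` is defined.

```python
import numpy as np

EXACT_THRESHOLD = 0.99  # assumed value for illustration only


def exact_match_fraction(ours, truth):
    """Fraction of cells whose token sequences are element-wise identical."""
    assert len(ours) == len(truth)
    identical = sum(1 for a, b in zip(ours, truth) if np.array_equal(a, b))
    return identical / len(ours)


# Toy sequences: the first cell matches, the second differs in one token.
toy_ours = [np.array([7, 11]), np.array([9])]
toy_truth = [np.array([7, 11]), np.array([5])]
toy_fraction = exact_match_fraction(toy_ours, toy_truth)
```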


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
@pytest.mark.experimental
@pytest.mark.live_corpus
def test_GeneformerTokenizer_docstring_example() -> None:
46 changes: 27 additions & 19 deletions tools/models/geneformer/Dockerfile
@@ -1,41 +1,49 @@
# Builds a docker image with:
# - PyTorch+CUDA
# - CUDA+PyTorch
# - Geneformer
# - cellxgene_census
# - our Census-Geneformer training scripts
FROM nvcr.io/nvidia/pytorch:23.10-py3
FROM nvcr.io/nvidia/cuda:11.8.0-runtime-ubuntu22.04

# Set the tiledbsoma version used to write the embeddings SparseNDArray, to ensure
# compatibility with the Census embeddings curator
ARG EMBEDDINGS_TILEDBSOMA_VERSION=1.4.4
ARG GENEFORMER_VERSION=8df5dc1

RUN apt update && apt install -y python3-venv git-lfs pigz
RUN apt update && apt install -y build-essential python3-pip python3-venv git-lfs pigz libcurl4-openssl-dev
RUN git lfs install
ENV GIT_SSL_NO_VERIFY=true
RUN pip install \
transformers[torch] \
"cellxgene_census[experimental] @ git+https://github.com/chanzuckerberg/cellxgene-census.git#subdirectory=api/python/cellxgene_census" \
git+https://huggingface.co/ctheodoris/Geneformer@${GENEFORMER_VERSION}
RUN pip install owlready2 boto3

ENV GIT_SSL_NO_VERIFY=true
RUN pip install --upgrade pip setuptools setuptools_scm
RUN pip install torch torchdata --index-url https://download.pytorch.org/whl/cu118
# ^^^ match the base image CUDA version!
RUN pip install owlready2 boto3 transformers[torch]
# workaround for unknown problem blocking `import geneformer`:
# https://github.com/microsoft/TaskMatrix/issues/116#issuecomment-1565431850
RUN pip uninstall -y transformer-engine
# smoke test
RUN python3 -c 'import geneformer; import cellxgene_census; cellxgene_census.open_soma()'

# Set the tiledbsoma version used to write the embeddings SparseNDArray, to ensure
# compatibility with the Census embeddings curator
ARG EMBEDDINGS_TILEDBSOMA_VERSION=1.9.5
ARG CELLXGENE_CENSUS_VERSION=main
ARG GENEFORMER_VERSION=471eefc

RUN mkdir /census-geneformer
WORKDIR /census-geneformer
# clone Geneformer separately to get LFS files
RUN git clone https://github.com/chanzuckerberg/cellxgene-census.git \
&& git -C cellxgene-census checkout ${CELLXGENE_CENSUS_VERSION}
RUN pip install cellxgene-census/api/python/cellxgene_census
RUN git clone --recursive https://huggingface.co/ctheodoris/Geneformer \
&& git -C Geneformer checkout ${GENEFORMER_VERSION}
RUN pip install -e Geneformer

# prepare a venv with tiledbsoma ${EMBEDDINGS_TILEDBSOMA_VERSION}
# smoke test
RUN python3 -c 'import geneformer; import cellxgene_census; from cellxgene_census.experimental.ml.huggingface import GeneformerTokenizer; cellxgene_census.open_soma()'

# prepare a venv with pinned tiledbsoma ${EMBEDDINGS_TILEDBSOMA_VERSION}, which our embeddings
# generation step will use to output a TileDB array compatible with the Census embeddings curator.
RUN python3 -m venv --system-site-packages embeddings_tiledbsoma_venv && \
. embeddings_tiledbsoma_venv/bin/activate && \
pip install tiledbsoma==${EMBEDDINGS_TILEDBSOMA_VERSION}

COPY *.py .
COPY helpers ./helpers
COPY *.py ./
COPY finetune-geneformer.config.yml .

# FIXME: eliminate once model is published in Geneformer repo
COPY gf-95m/ ./gf-95m/