diff --git a/.github/workflows/docsite-build-deploy.yml b/.github/workflows/docsite-build-deploy.yml index 2e1efb8c2..811f2e27c 100644 --- a/.github/workflows/docsite-build-deploy.yml +++ b/.github/workflows/docsite-build-deploy.yml @@ -23,7 +23,7 @@ jobs: python -m pip install -U pip setuptools wheel pip install -r ./api/python/cellxgene_census/scripts/requirements-dev.txt pip install -r ./docs/requirements.txt - pip install -e ./api/python/cellxgene_census/ + pip install -e './api/python/cellxgene_census/[experimental]' mkdir -p docsite touch docsite/.nojekyll diff --git a/.github/workflows/py-unittests.yml b/.github/workflows/py-unittests.yml index 13a03bc0d..8431d5bff 100644 --- a/.github/workflows/py-unittests.yml +++ b/.github/workflows/py-unittests.yml @@ -23,13 +23,25 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies + if: matrix.python-version != '3.10' run: | python -m pip install -U pip setuptools wheel pip install -r ./api/python/cellxgene_census/scripts/requirements-dev.txt - pip install -e ./api/python/cellxgene_census/ + pip install -e './api/python/cellxgene_census/' + - name: Install dependencies (with experimental) + if: matrix.python-version == '3.10' + run: | + python -m pip install -U pip setuptools wheel + pip install -r ./api/python/cellxgene_census/scripts/requirements-dev.txt + pip install -e './api/python/cellxgene_census/[experimental]' - name: Test with pytest (API) + if: matrix.python-version != '3.10' run: | PYTHONPATH=. coverage run --parallel-mode -m pytest -v -rP --durations=20 ./api/python/cellxgene_census/tests/ + - name: Test with pytest (API, with experimental) + if: matrix.python-version == '3.10' + run: | + PYTHONPATH=. coverage run --parallel-mode -m pytest -v -rP --durations=20 --experimental ./api/python/cellxgene_census/tests/ - uses: actions/upload-artifact@v3 if: matrix.os == 'ubuntu-latest' with: diff --git a/api/python/cellxgene_census/pyproject.toml b/api/python/cellxgene_census/pyproject.toml index d105bce2b..d9e982cfc 100644 --- a/api/python/cellxgene_census/pyproject.toml +++ b/api/python/cellxgene_census/pyproject.toml @@ -42,6 +42,13 @@ dependencies= [ "certifi", ] +[project.optional-dependencies] +experimental = [ + "torch==2.0.1", + "torchdata==0.6.1", + "scikit-learn==1.2.2", +] + [project.urls] homepage = "https://github.com/chanzuckerberg/cellxgene-census" repository = "https://github.com/chanzuckerberg/cellxgene-census" @@ -69,6 +76,7 @@ plugins = "numpy.typing.mypy_plugin" markers = [ "live_corpus: runs on the live CELLxGENE Census data corpus and small enough to run in CI", "expensive: too expensive to run regularly or in CI", + "experimental: tests for the `experimental` package", ] [tool.ruff] diff --git a/api/python/cellxgene_census/src/cellxgene_census/_open.py b/api/python/cellxgene_census/src/cellxgene_census/_open.py index 96902227c..2b126c3ea 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_open.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_open.py @@ -33,7 +33,17 @@ def _open_soma(locator: CensusLocator, context: Optional[soma.options.SOMATileDBContext] = None) -> soma.Collection: """Private. 
Merge config defaults and return open census as a soma Collection/context."""
-    s3_region = locator.get("s3_region")
+    context = _build_soma_tiledb_context(locator.get("s3_region"), context)
+    return soma.open(locator["uri"], mode="r", soma_type=soma.Collection, context=context)
+
+
+def _build_soma_tiledb_context(
+    s3_region: Optional[str] = None, context: Optional[soma.options.SOMATileDBContext] = None
+) -> soma.options.SOMATileDBContext:
+    """
+    Private. Build a SOMATileDBContext with sensible defaults. If a user-defined context is provided, only its
+    `vfs.s3.region` config option is updated.
+    """
 
     if not context:
         # if no user-defined context, cellxgene_census defaults take precedence over SOMA defaults
@@ -50,8 +60,7 @@ def _open_soma(locator: CensusLocator, context: Optional[soma.options.SOMATileDB
         tiledb_config = context.tiledb_ctx.config()
         tiledb_config["vfs.s3.region"] = s3_region
         context = context.replace(tiledb_config=tiledb_config)
-
-    return soma.open(locator["uri"], mode="r", soma_type=soma.Collection, context=context)
+    return context
 
 
 def open_soma(
diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
new file mode 100644
index 000000000..8c89536ae
--- /dev/null
+++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
@@ -0,0 +1,452 @@
+import logging
+import os
+import sys
+from datetime import timedelta
+from time import time
+from typing import Any, Dict, Iterator, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import somacore
+import tiledbsoma as soma
+import torch
+import torchdata.datapipes.iter as pipes
+from attr import attrs
+from scipy import sparse
+from sklearn.preprocessing import LabelEncoder
+from somacore.query import _fast_csr
+from torch import Tensor
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import Dataset
+
+import cellxgene_census
+from cellxgene_census._open import _build_soma_tiledb_context
+
+ObsDatum = Tuple[Tensor, Tensor]  # (X, obs); X may be a dense or sparse COO tensor
+
+Encoders = Dict[str, LabelEncoder]
+
+pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch")
+pytorch_logger.setLevel(logging.INFO)
+
+
+@attrs
+class Stats:
+    n_obs: int = 0
+    """The total number of obs rows retrieved"""
+
+    nnz: int = 0
+    """The total number of values retrieved"""
+
+    elapsed: int = 0
+    """The total elapsed time in seconds for retrieving all batches"""
+
+    n_soma_batches: int = 0
+    """The number of batches retrieved"""
+
+    def __str__(self) -> str:
+        return (
+            f"n_soma_batches={self.n_soma_batches}, n_obs={self.n_obs}, nnz={self.nnz}, "
+            f"elapsed={timedelta(seconds=self.elapsed)}"
+        )
+
+
+def _open_experiment(
+    uri: str, aws_region: Optional[str] = None, soma_buffer_bytes: Optional[int] = None
+) -> soma.Experiment:
+    context = _build_soma_tiledb_context(aws_region)
+
+    if soma_buffer_bytes is not None:
+        context = context.replace(
+            tiledb_config={
+                "py.init_buffer_bytes": soma_buffer_bytes,
+                "soma.init_buffer_bytes": soma_buffer_bytes,
+            }
+        )
+
+    return soma.Experiment.open(uri, context=context)
+
+
+class _ObsAndXIterator(Iterator[ObsDatum]):
+    """
+    Iterates through a set of obs and related X rows, specified as `soma_joinid`s. Encapsulates the batch-based data
+    fetching from TileDB-SOMA objects, providing row-based iteration.
+    """
+
+    obs_tables_iter: somacore.ReadIter[pa.Table]
+    """Iterates the TileDB-SOMA batches (tables) of obs data"""
+
+    obs_batch_: pd.DataFrame = pd.DataFrame()
+    """The current TileDB-SOMA batch of obs data"""
+
+    X_batch: Optional[sparse.csr_matrix] = None
+    """All X data for the soma_joinids of the current obs batch"""
+
+    i: int = -1
+    """Index into current obs TileDB-SOMA batch"""
+
+    def __init__(
+        self,
+        obs_joinids: pa.Array,
+        var_joinids: pa.Array,
+        exp_uri: str,
+        aws_region: Optional[str],
+        measurement_name: str,
+        X_layer_name: str,
+        batch_size: int,
+        encoders: Dict[str, LabelEncoder],
+        stats: Stats,
+        obs_column_names: Sequence[str],
+        sparse_X: bool,
+        soma_buffer_bytes: Optional[int] = None,
+    ) -> None:
+        self.obs_joinids = obs_joinids
+        self.var_joinids = var_joinids
+        self.batch_size = batch_size
+        self.sparse_X = sparse_X
+
+        # holding reference to SOMA object prevents it from being closed
+        self.exp = _open_experiment(exp_uri, aws_region, soma_buffer_bytes=soma_buffer_bytes)
+        self.X: soma.SparseNDArray = self.exp.ms[measurement_name].X[X_layer_name]
+        self.obs_tables_iter = self.exp.obs.read(
+            coords=(obs_joinids,), batch_size=somacore.BatchSize(), column_names=obs_column_names
+        )
+        self.encoders = encoders
+        self.stats = stats
+
+    def __next__(self) -> ObsDatum:
+        # read the next torch batch, possibly across multiple soma batches
+        obs: pd.DataFrame = pd.DataFrame()
+        X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)))
+
+        while len(obs) < self.batch_size:
+            try:
+                # request only the rows still needed to fill out this torch batch
+                obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs))
+                obs = pd.concat([obs, obs_partial], axis=0)
+                X = sparse.vstack([X, X_partial])
+            except StopIteration:
+                break
+
+        if len(obs) == 0:
+            raise StopIteration
+
+        obs_encoded = pd.DataFrame(
+            data={"soma_joinid": obs.soma_joinid}, columns=obs.columns, dtype=np.int32, index=obs.index
+        )
+        for col, enc in self.encoders.items():
+            obs_encoded[col] = enc.transform(obs[col])
+
+        obs_tensor = torch.Tensor(obs_encoded.to_numpy()).int()
+        if not self.sparse_X:
+            X_tensor = torch.Tensor(X.todense())
+        else:
+            coo = X.tocoo()
+
+            X_tensor = torch.sparse_coo_tensor(
+                # Note: The `np.array` seems unnecessary, but PyTorch warns bare array is "extremely slow"
+                indices=torch.Tensor(np.array([coo.row, coo.col])),
+                values=coo.data,
+                size=coo.shape,
+            )
+
+        if self.batch_size == 1:
+            X_tensor = X_tensor[0]
+            obs_tensor = obs_tensor[0]
+
+        return X_tensor, obs_tensor
+
+    def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, sparse.csr_matrix]:
+        safe_batch_size = min(batch_size, len(self.obs_batch) - self.i)
+        slice_ = slice(self.i, self.i + safe_batch_size)
+        obs_rows = self.obs_batch.iloc[slice_]
+        X_csr_scipy = self.X_batch[slice_]
+        self.i += safe_batch_size
+        return obs_rows, X_csr_scipy
+
+    @property
+    # TODO: Retrieve next batch asynchronously, so it's available before the current batch's iteration ends
+    def obs_batch(self) -> pd.DataFrame:
+        """
+        Returns the current SOMA batch of obs rows.
+        If the current SOMA batch has been fully iterated, loads the next SOMA batch of both obs and X data and returns
+        the new obs batch (only).
+        Raises StopIteration if there are no more SOMA batches to retrieve.
+        """
+        if 0 <= self.i < len(self.obs_batch_):
+            return self.obs_batch_
+
+        pytorch_logger.debug("Retrieving next TileDB-SOMA batch...")
+        start_time = time()
+        # If no more batch to iterate through, raise StopIteration, as all iterators do when at end
+        obs_table = next(self.obs_tables_iter)
+        self.obs_batch_ = obs_table.to_pandas()
+        # handle case of empty result (first batch has 0 rows)
+        if len(self.obs_batch_) == 0:
+            raise StopIteration
+        self.X_batch = _fast_csr.read_scipy_csr(self.X, obs_table["soma_joinid"].combine_chunks(), self.var_joinids)
+        self.i = 0
+        self.stats.n_obs += self.X_batch.shape[0]
+        self.stats.nnz += self.X_batch.nnz
+        self.stats.elapsed += int(time() - start_time)
+        self.stats.n_soma_batches += 1
+        # TODO: also show per-batch stats (non-cumulative)
+        pytorch_logger.debug(f"Retrieved batch: shape={self.X_batch.shape}, cum_stats: {self.stats}")
+        return self.obs_batch_
+
+
+class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsDatum]]):  # type: ignore
+    """
+    An iterable-style PyTorch data pipe that reads obs and X data from a SOMA Experiment, and returns an iterator of
+    tuples of torch tensors:
+
+        (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]),   # X data
+         tensor([2415,    0,    0], dtype=torch.int32))  # obs data, encoded
+
+    Supports batching via `batch_size` param:
+        DataLoader(..., batch_size=3, ...):
+
+        (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.],   # X batch
+                 [0., 0., 0., 0., 0., 0., 0., 0., 0.],
+                 [0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
+         tensor([[2415,    0,    0],                     # obs batch
+                 [2416,    0,    4],
+                 [2417,    0,    3]], dtype=torch.int32))
+
+    Obs attribute values are encoded as categoricals. Values can be decoded by obtaining the encoder for a given
+    attribute:
+
+        exp_data_pipe.obs_encoders()["<obs_attr_name>"].inverse_transform(encoded_values)
+    """
+
+    _query: Optional[soma.ExperimentAxisQuery]
+    """In multi-processing mode (i.e. num_workers > 0), this ExperimentAxisQuery object will *not* be pickled;
+    each worker will instantiate a new query"""
+
+    _obs_joinids_partitioned: Optional[pa.Array]
+
+    _obs_and_x_iter: Optional[_ObsAndXIterator]
+
+    _encoders: Optional[Encoders]
+
+    _stats: Stats
+
+    # TODO: Consider adding another convenience method wrapper to construct this object whose signature is more closely
+    # aligned with get_anndata() params (i.e. "exploded" AxisQuery params).
+    def __init__(
+        self,
+        experiment: soma.Experiment,
+        measurement_name: str,
+        X_name: str,
+        obs_query: Optional[soma.AxisQuery] = None,
+        var_query: Optional[soma.AxisQuery] = None,
+        obs_column_names: Sequence[str] = (),
+        batch_size: int = 1,
+        sparse_X: bool = False,
+        num_workers: int = 0,
+        soma_buffer_bytes: Optional[int] = None,
+    ) -> None:
+        self.exp_uri = experiment.uri
+        self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region")
+        self.measurement_name = measurement_name
+        self.layer_name = X_name
+        self.obs_query = obs_query
+        self.var_query = var_query
+        self.obs_column_names = obs_column_names
+        self.batch_size = batch_size
+        self.sparse_X = sparse_X
+        # TODO: This will control the SOMA batch sizes, and could be replaced with a row count once TileDB-SOMA
+        # supports `batch_size` param. It affects both the obs and X read operations.
+        self.soma_buffer_bytes = soma_buffer_bytes
+        self._query = None
+        self._stats = Stats()
+        self._encoders = None
+
+        if "soma_joinid" not in self.obs_column_names:
+            self.obs_column_names = ["soma_joinid", *self.obs_column_names]
+
+    def _init(self) -> None:
+        if self._query is not None:
+            return
+
+        # TODO: support multiple layers, per somacore.query.query.ExperimentAxisQuery.to_anndata()
+        # TODO: handle closing of `_query` (and transitively, `exp`) when iterator is no longer in use; may need to be
+        #  used as a ContextManager, but not clear how we can do that when used by DataLoader
+        exp = _open_experiment(self.exp_uri, self.aws_region, soma_buffer_bytes=self.soma_buffer_bytes)
+
+        self._query = exp.axis_query(
+            measurement_name=self.measurement_name,
+            obs_query=self.obs_query,
+            var_query=self.var_query,
+        )
+
+        obs_joinids = self._query.obs_joinids()
+
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info:
+            # multi-processing mode
+
+            if worker_info.num_workers > 0 and self.sparse_X:
+                raise NotImplementedError(
+                    "torch does not work with sparse tensors in multi-processing mode "
+                    "(see https://github.com/pytorch/pytorch/issues/20248)"
+                )
+
+            partition, num_partitions = worker_info.id, worker_info.num_workers
+
+        else:
+            partition, num_partitions = 0, 1
+
+        if num_partitions is not None:
+            # partitioned data loading
+            # NOTE: Could alternately use a `worker_init_fn` to split the workload among the workers
+            partition_size = len(obs_joinids) // num_partitions
+            partition_start = partition_size * partition
+            partition_end_excl = min(len(obs_joinids), partition_start + partition_size)
+            self._obs_joinids_partitioned = obs_joinids[partition_start:partition_end_excl]
+
+            pytorch_logger.debug(
+                f"Process {os.getpid()} handling partition {partition + 1} of {num_partitions}, "
+                f"range={partition_start}:{partition_end_excl}, "
+                f"partition_size={partition_size}"
+            )
+
+    def __iter__(self) -> Iterator[ObsDatum]:
+        self._init()
+        assert self._query is not None
+
+        return _ObsAndXIterator(
+            obs_joinids=self._obs_joinids_partitioned or self._query.obs_joinids(),
+            var_joinids=self._query.var_joinids(),
+            exp_uri=self.exp_uri,
+            aws_region=self.aws_region,
+            measurement_name=self.measurement_name,
+            X_layer_name=self.layer_name,
+            batch_size=self.batch_size,
+            encoders=self.obs_encoders(),
+            stats=self._stats,
+            obs_column_names=self.obs_column_names,
+            sparse_X=self.sparse_X,
+            soma_buffer_bytes=self.soma_buffer_bytes,
+        )
+
+    def __len__(self) -> int:
+        self._init()
+        assert self._query is not None
+
+        return int(self._query.n_obs)
+
+    def __getitem__(self, index: int) -> ObsDatum:
+        raise NotImplementedError("IterDataPipe can only be iterated")
+
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        # Don't pickle `_query`
+        del state["_query"]
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self._query = None
+
+    def obs_encoders(self) -> Encoders:
+        if self._encoders is not None:
+            return self._encoders
+
+        self._init()
+        assert self._query is not None
+
+        obs = self._query.obs(column_names=self.obs_column_names).concat()
+        self._encoders = {}
+        for col in self.obs_column_names:
+            if obs[col].type in (pa.string(), pa.large_string()):
+                enc = LabelEncoder()
+                enc.fit(obs[col].combine_chunks().unique())
+                self._encoders[col] = enc
+
+        return self._encoders
+
+    def stats(self) -> Stats:
+        return self._stats
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        self._init()
+        assert self._query is not None
+
+        return self._query.n_obs, self._query.n_vars
+
+
+# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling
+def collate_noop(x: Any) -> Any:
+    return x
+
+
+# TODO: Move into somacore.ExperimentAxisQuery
+def experiment_dataloader(
+    datapipe: pipes.IterDataPipe,
+    num_workers: int = 0,
+    **dataloader_kwargs: Any,
+) -> DataLoader:
+    """
+    Factory method for PyTorch DataLoader. Provides a safer, more convenient interface for instantiating a DataLoader
+    that works with the ExperimentDataPipe, since not all of DataLoader's params can be used (batch_size, sampler,
+    batch_sampler, collate_fn).
+    """
+
+    unsupported_dataloader_args = ["batch_size", "sampler", "batch_sampler", "collate_fn"]
+    if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
+        raise ValueError(f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported")
+
+    return DataLoader(
+        datapipe,
+        batch_size=None,  # batching is handled by our ExperimentDataPipe
+        num_workers=num_workers,
+        # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches
+        collate_fn=collate_noop,
+        **dataloader_kwargs,
+    )
+
+
+# For testing only
+if __name__ == "__main__":
+    import tiledbsoma as soma
+
+    (
+        census_uri_arg,
+        organism_arg,
+        measurement_name_arg,
+        layer_name_arg,
+        obs_value_filter_arg,
+        column_names_arg,
+        sparse_X_arg,
+        torch_batch_size_arg,
+        num_workers_arg,
+    ) = sys.argv[1:]
+
+    census = cellxgene_census.open_soma(uri=census_uri_arg)
+
+    exp_datapipe = ExperimentDataPipe(
+        experiment=census["census_data"]["homo_sapiens"],
+        measurement_name=measurement_name_arg,
+        X_name=layer_name_arg,
+        obs_query=soma.AxisQuery(value_filter=(obs_value_filter_arg or None)),
+        var_query=soma.AxisQuery(coords=(slice(1, 9),)),
+        obs_column_names=column_names_arg.split(","),
+        batch_size=int(torch_batch_size_arg),
+        soma_buffer_bytes=2**12,
+        sparse_X=sparse_X_arg.lower() == "sparse",
+    )
+
+    dp_shuffle = exp_datapipe.shuffle(buffer_size=len(exp_datapipe))
+    dp_train, dp_test = dp_shuffle.random_split(weights={"train": 0.7, "test": 0.3}, seed=1234)
+
+    dl = experiment_dataloader(dp_train, num_workers=int(num_workers_arg))
+
+    i = 0
+    datum = None
+    for i, datum in enumerate(dl):
+        if (i + 1) % 1000 == 0:
+            print(f"Received {i + 1} torch batches, {exp_datapipe.stats()}:\n{datum}")
+    print(f"Received {i + 1} torch batches, {exp_datapipe.stats()}:\n{datum}")
diff --git a/api/python/cellxgene_census/tests/README.md b/api/python/cellxgene_census/tests/README.md
index 7a67bda98..612cd83f0 100644
--- a/api/python/cellxgene_census/tests/README.md
+++ b/api/python/cellxgene_census/tests/README.md
@@ -14,22 +14,27 @@ Then run the tests:
 
 ## Pytest Marks
 
-There are two Pytest marks you can use from the command line:
+There are various Pytest marks you can use from the command line:
 
 - live_corpus: tests that directly access the `latest` version of the Census. Enabled by default.
 - expensive: tests that are expensive (ie., cpu, memory, time). Disabled by default - enable with `--expensive`. Some of these tests are _very_ expensive, ie., require a very large memory host to succeed.
+- experimental: tests for code in the `experimental` package. Disabled by default - enable with `--experimental`. These tests require installing the optional Python packages, via `pip install -e './api/python/cellxgene_census/[experimental]'`.
 
 By default, only relatively cheap & fast tests are run. To enable `expensive` tests:
 
 > pytest --expensive ...
+To enable `experimental` tests:
+
+> pytest --experimental ...
+
 To disable `live_corpus` tests:
 
 > pytest -m 'not live_corpus'
 
 You can also combine them, e.g.,
 
-> pytest -m 'not live_corpus' --expensive
+> pytest -m 'not live_corpus' --expensive --experimental
 
 # Acceptance (expensive) tests
diff --git a/api/python/cellxgene_census/tests/__init__.py b/api/python/cellxgene_census/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/python/cellxgene_census/tests/conftest.py b/api/python/cellxgene_census/tests/conftest.py
index 40f6ec6f2..cc50c43e7 100644
--- a/api/python/cellxgene_census/tests/conftest.py
+++ b/api/python/cellxgene_census/tests/conftest.py
@@ -1,12 +1,32 @@
 import pytest
 
+TEST_MARKERS_SKIPPED_BY_DEFAULT = ["expensive", "experimental"]
+
 
 def pytest_addoption(parser: pytest.Parser) -> None:
-    parser.addoption(
-        "--expensive", action="store_true", dest="expensive", default=False, help="enable 'expensive' decorated tests"
-    )
+    for test_option in TEST_MARKERS_SKIPPED_BY_DEFAULT:
+        parser.addoption(
+            f"--{test_option}",
+            action="store_true",
+            dest=test_option,
+            default=False,
+            help=f"enable '{test_option}' decorated tests",
+        )
 
 
 def pytest_configure(config: pytest.Config) -> None:
-    if not config.option.expensive:
-        config.option.markexpr = "not expensive"
+    """
+    Exclude tests marked with any of the TEST_MARKERS_SKIPPED_BY_DEFAULT values, unless the corresponding explicit
+    flag is specified by the user.
+    """
+    excluded_markexprs = []
+
+    for test_option in TEST_MARKERS_SKIPPED_BY_DEFAULT:
+        if not vars(config.option).get(test_option, False):
+            excluded_markexprs.append(test_option)
+
+    if excluded_markexprs:
+        # the exclusions must apply even when no -m expression was given, so only the joining "and" is conditional
+        if config.option.markexpr:
+            config.option.markexpr += " and "
+        config.option.markexpr += " and ".join([f"not {m}" for m in excluded_markexprs])
diff --git a/api/python/cellxgene_census/tests/experimental/__init__.py b/api/python/cellxgene_census/tests/experimental/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/python/cellxgene_census/tests/experimental/ml/__init__.py b/api/python/cellxgene_census/tests/experimental/ml/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
new file mode 100644
index 000000000..dc6735ff7
--- /dev/null
+++ b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -0,0 +1,360 @@
+import pathlib
+from typing import Callable, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pyarrow as pa
+import pytest
+import tiledbsoma as soma
+from scipy import sparse
+from scipy.sparse import coo_matrix, spmatrix
+from somacore import AxisQuery
+from tiledbsoma import Experiment, _factory
+from tiledbsoma._collection import CollectionBase
+
+# conditionally import torch, as it will not be available in all test environments
+try:
+    from torch import Tensor
+
+    from cellxgene_census.experimental.ml.pytorch import ExperimentDataPipe, experiment_dataloader
+except ImportError:
+    # this should only occur when not running `experimental`-marked tests
+    pass
+
+
+def pytorch_x_value_gen(shape: Tuple[int, int]) -> spmatrix:
+    checkerboard_of_ones = coo_matrix(np.indices(shape).sum(axis=0) % 2)
+    return checkerboard_of_ones
+
+
+@pytest.fixture
+def X_layer_names() -> List[str]:
+    return ["raw"]
+
+
+@pytest.fixture
+def obsp_layer_names() -> Optional[List[str]]:
+    return None
+
+
+@pytest.fixture
+def varp_layer_names() -> 
Optional[List[str]]: + return None + + +@pytest.fixture +def X_value_gen() -> Callable[[Tuple[int, int]], sparse.spmatrix]: + def _x_value_gen(shape: Tuple[int, int]) -> sparse.coo_matrix: + return sparse.random( + shape[0], + shape[1], + density=0.1, + format="coo", + dtype=np.float32, + random_state=np.random.default_rng(), + ) + + return _x_value_gen + + +def add_dataframe(coll: CollectionBase, key: str, sz: int) -> None: + df = coll.add_new_dataframe( + key, + schema=pa.schema( + [ + ("soma_joinid", pa.int64()), + ("label", pa.large_string()), + ] + ), + index_column_names=["soma_joinid"], + ) + df.write( + pa.Table.from_pydict( + { + "soma_joinid": [i for i in range(sz)], + "label": [str(i) for i in range(sz)], + } + ) + ) + + +def add_sparse_array( + coll: CollectionBase, key: str, shape: Tuple[int, int], value_gen: Callable[[Tuple[int, int]], sparse.spmatrix] +) -> None: + a = coll.add_new_sparse_ndarray(key, type=pa.float32(), shape=shape) + tensor = pa.SparseCOOTensor.from_scipy(value_gen(shape)) + a.write(tensor) + + +@pytest.fixture(scope="function") +def soma_experiment( + tmp_path: pathlib.Path, + n_obs: int, + n_vars: int, + X_layer_names: Sequence[str], + X_value_gen: Callable[[Tuple[int, int]], sparse.spmatrix], + obsp_layer_names: Sequence[str], + varp_layer_names: Sequence[str], +) -> soma.Experiment: + with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp: + add_dataframe(exp, "obs", n_obs) + ms = exp.add_new_collection("ms") + rna = ms.add_new_collection("RNA", soma.Measurement) + add_dataframe(rna, "var", n_vars) + rna_x = rna.add_new_collection("X", soma.Collection) + for X_layer_name in X_layer_names: + add_sparse_array(rna_x, X_layer_name, (n_obs, n_vars), X_value_gen) + + if obsp_layer_names: + obsp = rna.add_new_collection("obsp") + for obsp_layer_name in obsp_layer_names: + add_sparse_array(obsp, obsp_layer_name, (n_obs, n_obs), X_value_gen) + + if varp_layer_names: + varp = rna.add_new_collection("varp") + for varp_layer_name in varp_layer_names: + add_sparse_array(varp, varp_layer_name, (n_vars, n_vars), X_value_gen) + return _factory.open((tmp_path / "exp").as_posix()) + + +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_non_batched(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"] + ) + row_iter = iter(exp_data_pipe) + + row = next(row_iter) + assert row[0].int().tolist() == [0, 1, 0] + assert row[1].tolist() == [0, 0] + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_batching__all_batches_full_size(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], batch_size=3 + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].int().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].tolist() == [[0, 0], [1, 1], [2, 2]] + + batch = next(batch_iter) + assert batch[0].int().tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert batch[1].tolist() == [[3, 3], [4, 4], [5, 5]] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.experimental +# noinspection PyTestParametrized 
+@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(5, 3, ("raw",), pytorch_x_value_gen)]) +def test_batching__partial_final_batch_size(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], batch_size=3 + ) + batch_iter = iter(exp_data_pipe) + + next(batch_iter) + batch = next(batch_iter) + assert batch[0].int().tolist() == [[1, 0, 1], [0, 1, 0]] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(3, 3, ("raw",), pytorch_x_value_gen)]) +def test_batching__exactly_one_batch(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], batch_size=3 + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].int().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].tolist() == [[0, 0], [1, 1], [2, 2]] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_batching__empty_query_result(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_query=AxisQuery(coords=([],)), + obs_column_names=["label"], + batch_size=3, + ) + batch_iter = iter(exp_data_pipe) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_sparse_output__non_batched(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + sparse_X=True, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[1], Tensor) + assert batch[0].to_dense().tolist() == [0, 1, 0] + + +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_sparse_output__batched(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + sparse_X=True, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[1], Tensor) + assert batch[0].to_dense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + + +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(3, 3, ("raw",), pytorch_x_value_gen)]) +def test_encoders(soma_experiment: Experiment) -> None: + exp_data_pipe = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[1], Tensor) + + labels_encoded = batch[1][:, 1] + labels_decoded = exp_data_pipe.obs_encoders()["label"].inverse_transform(labels_encoded) + assert labels_decoded.tolist() == ["0", "1", "2"] + + +@pytest.mark.experimental +# 
noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(3, 3, ("raw",), pytorch_x_value_gen)]) +def test_experiment_dataloader__non_batched(soma_experiment: Experiment) -> None: + dp = ExperimentDataPipe(soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"]) + dl = experiment_dataloader(dp) + torch_data = [row for row in dl] + + row = torch_data[0] + assert row[0].to_dense().tolist() == [0, 1, 0] + assert row[1].tolist() == [0, 0] + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(6, 3, ("raw",), pytorch_x_value_gen)]) +def test_experiment_dataloader__batched(soma_experiment: Experiment) -> None: + dp = ExperimentDataPipe( + soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"], batch_size=3 + ) + dl = experiment_dataloader(dp) + torch_data = [row for row in dl] + + batch = torch_data[0] + assert batch[0].to_dense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].tolist() == [[0, 0], [1, 1], [2, 2]] + + +@pytest.mark.experimental +@pytest.mark.skip(reason="Not implemented") +def test_experiment_dataloader__multiprocess_sparse_matrix__fails() -> None: + pass + + +@pytest.mark.experimental +@pytest.mark.skip(reason="Not implemented") +def test_experiment_dataloader__multiprocess_dense_matrix__ok() -> None: + pass + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(10, 1, ("raw",), pytorch_x_value_gen)]) +def test_experiment_dataloader__splitting(soma_experiment: Experiment) -> None: + dp = ExperimentDataPipe(soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"]) + dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=1234) + dl = experiment_dataloader(dp_train) + + all_rows = list(iter(dl)) + assert len(all_rows) == 7 + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(10, 1, ("raw",), pytorch_x_value_gen)]) +def test_experiment_dataloader__shuffling(soma_experiment: Experiment) -> None: + dp = ExperimentDataPipe(soma_experiment, measurement_name="RNA", X_name="raw", obs_column_names=["label"]) + dp = dp.shuffle() + dl = experiment_dataloader(dp) + + data1 = list(iter(dl)) + data2 = list(iter(dl)) + + data1_soma_joinids = [row[1][0].item() for row in data1] + data2_soma_joinids = [row[1][0].item() for row in data2] + assert data1_soma_joinids != data2_soma_joinids + + +@pytest.mark.experimental +# noinspection PyTestParametrized,DuplicatedCode +@pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(3, 3, ("raw",), pytorch_x_value_gen)]) +def test_experiment_dataloader__multiprocess_pickling(soma_experiment: Experiment) -> None: + """ + If the DataPipe is accessed prior to multiprocessing (num_workers > 0), its internal _query will be + initialized. But since it cannot be pickled, we must ensure it is ignored during pickling in multiprocessing mode. + This test verifies the correct pickling behavior is in place. 
+ """ + + dp = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + num_workers=1, + sparse_X=True, + ) + dl = experiment_dataloader(dp) + dp.obs_encoders() # trigger query building + row = next(iter(dl)) # trigger multiprocessing + + assert row is not None + + +@pytest.mark.experimental +@pytest.mark.skip(reason="TODO") +def test_experiment_data_loader__unsupported_params__fails() -> None: + pass diff --git a/api/python/cellxgene_census/tests/test_open.py b/api/python/cellxgene_census/tests/test_open.py index 87cf1626b..7819d14a7 100644 --- a/api/python/cellxgene_census/tests/test_open.py +++ b/api/python/cellxgene_census/tests/test_open.py @@ -200,6 +200,7 @@ def test_opening_census_without_anon_access_fails_with_bogus_creds() -> None: cellxgene_census.open_soma(census_version="latest", context=soma.SOMATileDBContext()) +@pytest.mark.live_corpus def test_can_open_with_anonymous_access() -> None: """ With anonymous access, `open_soma` must be able to access the census even with bogus credentials diff --git a/api/python/notebooks/ml_demo/pytorch_lr_classifier.py b/api/python/notebooks/ml_demo/pytorch_lr_classifier.py new file mode 100644 index 000000000..711108117 --- /dev/null +++ b/api/python/notebooks/ml_demo/pytorch_lr_classifier.py @@ -0,0 +1,127 @@ +import numpy as np +import tiledbsoma as soma +import torch + +import cellxgene_census +from cellxgene_census.experimental.ml.pytorch import experiment_dataloader, ExperimentDataPipe + +# TODO: Convert this to a notebook + + +class LogisticRegression(torch.nn.Module): + def __init__(self, input_dim, output_dim, dropout_prob=0.2): + super(LogisticRegression, self).__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + self.dropout = torch.nn.Dropout(p=dropout_prob) + + def forward(self, x): + outputs = self.dropout(self.linear(x)) + return outputs + + +def train_epoch(model, train_dataloader, loss_fn, optimizer, device): + model.train() + train_loss = 0 + train_correct = 0 + train_total = 0 + y_true = [] + y_pred = [] + for batch in train_dataloader: + optimizer.zero_grad() + X_batch, y_batch = batch + + # exclude the soma_joinid of the y_batch, which is in the first column + # TODO: allow correct types to be specified to experiment_dataloader() + y_batch = y_batch[:, 1:].flatten().long() + + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + + outputs = model(X_batch) + probs = torch.nn.functional.softmax(outputs, 1) + preds = torch.argmax(probs, axis=1).cpu() + + train_correct += (preds == y_batch.cpu()).sum().item() + train_total += len(preds) + + loss = loss_fn(outputs, y_batch) + train_loss += loss.item() + loss.backward() + optimizer.step() + + y_true.extend(y_batch.cpu().numpy()) + y_pred.extend(preds.numpy()) + + train_loss /= train_total + train_accuracy = train_correct / train_total + return train_loss, train_accuracy + + +def train(model, model_filename, train_dataloader, device, learning_rate, weight_decay, num_epochs, precision=7): + loss_fn = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + losses = [] + + for epoch in range(num_epochs): + train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device) + losses.append(round(train_loss, precision)) + print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}") + + torch.save(model.state_dict(), model_filename) + + print("Saved model to file", model_filename) + 
return model + + +def run(): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + census = cellxgene_census.open_soma() + + predicted_label = "cell_type" + obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True" + var_value_filter = "" + + exp_dp = ExperimentDataPipe( + census["census_data"]["homo_sapiens"], + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(value_filter=(obs_value_filter or None)), + var_query=soma.AxisQuery(value_filter=(var_value_filter or None)), + obs_column_names=[predicted_label], + batch_size=16, + ) + + dp = exp_dp.shuffle(buffer_size=len(exp_dp)) + dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=RANDOM_SEED) + + dl_train = experiment_dataloader( + dp_train, + # >= 1 uses multiprocessing to load data + num_workers=0, + ) + + pred_field_encoder = exp_dp.obs_encoders()[predicted_label] + output_dim = len(pred_field_encoder.classes_) + input_dim = exp_dp.shape[1] + + model = LogisticRegression(input_dim, output_dim).to(device) + model = train( + model, + "model.pt", + dl_train, + device, + learning_rate=1e-4, + weight_decay=0.0, + num_epochs=3, + ) + return model + + +RANDOM_SEED = 123 +np.random.seed(RANDOM_SEED) +torch.manual_seed(RANDOM_SEED) + + +if __name__ == "__main__": + run()
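
---

Reviewer note: a minimal end-to-end sketch of the new `experimental.ml` API introduced in this PR, for orientation only. It assumes the `experimental` extra is installed (`pip install -e './api/python/cellxgene_census/[experimental]'`); the value filter, column name, and batch size below are illustrative and not part of the PR itself.

```python
import tiledbsoma as soma

import cellxgene_census
from cellxgene_census.experimental.ml.pytorch import ExperimentDataPipe, experiment_dataloader

census = cellxgene_census.open_soma()

# Iterable-style data pipe over a (hypothetical) small obs slice of the Census
datapipe = ExperimentDataPipe(
    census["census_data"]["homo_sapiens"],
    measurement_name="RNA",
    X_name="raw",
    obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue'"),
    obs_column_names=["cell_type"],
    batch_size=16,
)

# experiment_dataloader() rejects batch_size/sampler/batch_sampler/collate_fn kwargs,
# since batching is handled by the data pipe and collation must be a no-op
loader = experiment_dataloader(datapipe)

X_batch, obs_batch = next(iter(loader))

# obs columns come back label-encoded as [soma_joinid, cell_type]; decode with the fitted encoder
cell_types = datapipe.obs_encoders()["cell_type"].inverse_transform(obs_batch[:, 1].numpy())
print(X_batch.shape, cell_types[:5])
```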