diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 9b6911d..a8b3773 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,10 +1,72 @@ -name: python-tiledbsoma-ml +name: python-tiledbsoma-ml CI on: + pull_request: + branches: ["*"] + paths-ignore: + - 'scripts/**' + push: + branches: [main] + paths-ignore: + - 'scripts/**' + workflow_dispatch: jobs: - job: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Restore pre-commit cache + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: Install pre-commit + run: pip -v install pre-commit + + - name: Run pre-commit hooks on all files + run: pre-commit run -v -a + + tests: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip wheel pytest pytest-cov setuptools + pip install . + + - name: Run tests + run: pytest -v --cov=src --cov-report=xml tests + + build: + # for now, just do a test build to ensure that it works + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Do build + run: | + pip install --upgrade build pip wheel setuptools setuptools-scm + python -m build . diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml new file mode 100644 index 0000000..61a5001 --- /dev/null +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -0,0 +1,46 @@ +name: python-tiledbsoma-ml past tiledbsoma compat # Latest tiledbsoma version covered by another workflow + +on: + pull_request: + branches: ["*"] + paths-ignore: + - "scripts/**" + - "notebooks/**" + push: + branches: [main] + paths-ignore: + - "scripts/**" + - "notebooks/**" + +jobs: + unit_tests: + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] # could add 'macos-latest', but the matrix is already huge... + python-version: ["3.9", "3.10", "3.11"] # TODO: add 3.12 when tiledbsoma releases wheels for it. + pkg-version: + - "tiledbsoma~=1.9.0 'numpy<2.0.0'" + - "tiledbsoma~=1.10.0 'numpy<2.0.0'" + - "tiledbsoma~=1.11.0" + - "tiledbsoma~=1.12.0" + - "tiledbsoma~=1.13.0" + - "tiledbsoma~=1.14.0" + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip pytest setuptools + pip install ${{ matrix.pkg-version }} . 
+ + - name: Run tests + run: pytest -v tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..efa407c --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6385ca9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/psf/black + rev: "24.8.0" + hooks: + - id: black + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.5 + hooks: + - id: ruff + name: "ruff for tiledbsoma_ml" + args: ["--config=pyproject.toml"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + pass_filenames: false + args: ["--config-file=pyproject.toml", "src"] + additional_dependencies: + - attrs + - numpy + - pandas-stubs>=2 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..957c631 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,24 @@ + +# Change Log + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/) +and this project adheres to [Semantic Versioning](http://semver.org/). + +## [Unreleased] - yyyy-mm-dd + +Port and enhance contribution from the Chan Zuckerberg Initiative Foundation +[CELLxGENE](https://cellxgene.cziscience.com/) project. + +This is not a one-for-one migration of the contributed code. Substantial changes have +been made to enhance the package's utility (e.g., multi-GPU support), improve API usability, and more. + +### Added + +- Initial commits via [PR #1](https://github.com/single-cell-data/TileDB-SOMA-ML/pull/1) +- Refine package dependency pins and compatibility tests via [PR #2](https://github.com/single-cell-data/TileDB-SOMA-ML/pull/2) + +### Changed + +### Fixed diff --git a/README.md b/README.md new file mode 100644 index 0000000..576b1b5 --- /dev/null +++ b/README.md @@ -0,0 +1,45 @@ + +# tiledbsoma_ml + +A Python package containing ML tools for use with `tiledbsoma`. + +## Description + +The package currently contains a prototype PyTorch `IterableDataset` for use with the +[`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) +API. + +## Getting Started + +### Installing + +Install using your favorite package installer. For example, with pip: + +> pip install tiledbsoma-ml + +Developers may install an editable version, from source, in the usual manner: + +> pip install -e . + +### Documentation + +TBD. For now, see the usage sketch at the end of this README and the tutorial notebooks under [`notebooks/`](notebooks/). + +## Builds + +This is a pure Python package. To build a wheel, ensure you have the `build` package installed, and then: + +> python -m build . + +## Version History + +See the [CHANGELOG.md](CHANGELOG.md) file. + +## License + +This project is licensed under the MIT License. + +## Acknowledgements + +The SOMA team is grateful to the Chan Zuckerberg Initiative Foundation [CELLxGENE Census](https://cellxgene.cziscience.com) +team for their initial contribution. 
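+## Example
+
+A minimal usage sketch of the `ExperimentAxisQueryIterableDataset` and `experiment_dataloader` APIs, as demonstrated in the tutorial notebooks; the Census URI and `obs` query are borrowed from those notebooks purely for illustration:
+
+```python
+import tiledbsoma as soma
+
+import tiledbsoma_ml as soma_ml
+
+# Open a SOMA Experiment; here, a public CELLxGENE Census release on S3.
+experiment = soma.open(
+    "s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/",
+    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"}),
+)
+
+with experiment.axis_query(
+    measurement_name="RNA",
+    obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue' and is_primary_data == True"),
+) as query:
+    # Wrap the query in an IterableDataset that yields (X, obs) batches of 128 rows.
+    dataset = soma_ml.ExperimentAxisQueryIterableDataset(
+        query,
+        X_name="raw",
+        obs_column_names=["cell_type"],
+        batch_size=128,
+        shuffle=True,
+    )
+    # experiment_dataloader() constructs a torch.utils.data.DataLoader, disallowing
+    # arguments (shuffle, batch_size, collate_fn, etc.) that the dataset handles itself.
+    loader = soma_ml.experiment_dataloader(dataset)
+    for X_batch, obs_batch in loader:
+        ...  # X_batch: batch of expression rows; obs_batch: the matching obs columns
+```
+
+Per the tutorial notebooks, `X` batches arrive as NumPy arrays (convert with `torch.from_numpy` for training), and `obs` batches are indexable by column name (e.g., `obs_batch["cell_type"]`).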
diff --git a/apis/python/src/tiledbsoma/ml/__init__.py b/apis/python/src/tiledbsoma/ml/__init__.py deleted file mode 100644 index 8450cdc..0000000 --- a/apis/python/src/tiledbsoma/ml/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census.""" - -from .encoders import BatchEncoder, Encoder, LabelEncoder -from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader - -__all__ = [ - "Stats", - "ExperimentDataPipe", - "experiment_dataloader", - "Encoder", - "LabelEncoder", - "BatchEncoder", -] diff --git a/apis/python/src/tiledbsoma/ml/encoders.py b/apis/python/src/tiledbsoma/ml/encoders.py deleted file mode 100644 index 3d4fc4d..0000000 --- a/apis/python/src/tiledbsoma/ml/encoders.py +++ /dev/null @@ -1,130 +0,0 @@ -import abc -import functools -from typing import List - -import numpy.typing as npt -import pandas as pd -from sklearn.preprocessing import LabelEncoder as SklearnLabelEncoder - - -class Encoder(abc.ABC): - """Base class for ``obs`` encoders. - - To define a custom encoder, five methods must be implemented: - - - ``fit``: defines how the encoder will be fitted to the data. - - ``transform``: defines how the encoder will be applied to the data - in order to create an ``obs`` tensor. - - ``inverse_transform``: defines how to decode the encoded values back - to the original values. - - ``name``: The name of the encoder. This will be used as the key in the - dictionary of encoders. Each encoder passed to a :class:`.pytorch.ExperimentDataPipe` must have a unique name. - - ``columns``: List of columns in ``obs`` that the encoder will be applied to. - - See the implementation of :class:`LabelEncoder` for an example. - """ - - @abc.abstractmethod - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with obs.""" - pass - - @abc.abstractmethod - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" - pass - - @abc.abstractmethod - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - pass - - @property - @abc.abstractmethod - def name(self) -> str: - """Name of the encoder.""" - pass - - @property - @abc.abstractmethod - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - pass - - -class LabelEncoder(Encoder): - """Default encoder based on :class:`sklearn.preprocessing.LabelEncoder`.""" - - def __init__(self, col: str) -> None: - self._encoder = SklearnLabelEncoder() - self.col = col - - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with ``obs``.""" - self._encoder.fit(obs[self.col].unique()) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" - return self._encoder.transform(df[self.col]) # type: ignore - - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - return self._encoder.inverse_transform(encoded_values) # type: ignore - - @property - def name(self) -> str: - """Name of the encoder.""" - return self.col - - @property - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - return [self.col] - - @property - def classes_(self): # type: ignore - """Classes of the 
encoder.""" - return self._encoder.classes_ - - -class BatchEncoder(Encoder): - """An encoder that concatenates and encodes several ``obs`` columns.""" - - def __init__(self, cols: List[str], name: str = "batch"): - self.cols = cols - from sklearn.preprocessing import LabelEncoder - - self._name = name - self._encoder = LabelEncoder() - - def _join_cols(self, df: pd.DataFrame): # type: ignore - return functools.reduce(lambda a, b: a + b, [df[c].astype(str) for c in self.cols]) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" - arr = self._join_cols(df) - return self._encoder.transform(arr) # type: ignore - - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - return self._encoder.inverse_transform(encoded_values) # type: ignore - - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with ``obs``.""" - arr = self._join_cols(obs) - self._encoder.fit(arr.unique()) - - @property - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - return self.cols - - @property - def name(self) -> str: - """Name of the encoder.""" - return self._name - - @property - def classes_(self): # type: ignore - """Classes of the encoder.""" - return self._encoder.classes_ diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py deleted file mode 100644 index 5bef673..0000000 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ /dev/null @@ -1,869 +0,0 @@ -import gc -import itertools -import logging -import os -import typing -from contextlib import contextmanager -from datetime import timedelta -from math import ceil -from time import time -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import numpy as np -import numpy.typing as npt -import pandas as pd -import psutil -import tiledbsoma as soma -import torch -import torchdata.datapipes.iter as pipes -from attr import define -from numpy.random import Generator -from pyarrow import Table -from scipy import sparse -from torch import Tensor -from torch import distributed as dist -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset - -from ... import get_default_soma_context -from ..util._eager_iter import _EagerIterator -from .encoders import Encoder, LabelEncoder - -pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch") - -# TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) -ObsAndXDatum = Tuple[Tensor, Tensor] -"""Return type of ``ExperimentDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). -The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" - - -# "Chunk" of X data, returned by each `Method` above -ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix] - - -@define -class _SOMAChunk: - """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X`` - matrix. - - Lifecycle: - experimental - """ - - obs: pd.DataFrame - X: ChunkX - stats: "Stats" - - def __len__(self) -> int: - return len(self.obs) - - -Encoders = Dict[str, Encoder] -"""A dictionary of ``Encoder``s keyed by the ``obs`` column name.""" - - -@define -class Stats: - """Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. 
This is useful for assessing the read - throughput of SOMA data. - - Lifecycle: - experimental - """ - - n_obs: int = 0 - """The total number of obs rows retrieved""" - - nnz: int = 0 - """The total number of values retrieved""" - - elapsed: float = 0 - """The total elapsed time in seconds for retrieving all batches""" - - n_soma_chunks: int = 0 - """The number of chunks retrieved""" - - def __str__(self) -> str: - return f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " f"elapsed={timedelta(seconds=self.elapsed)}" - - def __add__(self, other: "Stats") -> "Stats": - self.n_obs += other.n_obs - self.nnz += other.nnz - self.elapsed += other.elapsed - self.n_soma_chunks += other.n_soma_chunks - return self - - -@contextmanager -def _open_experiment( - uri: str, - aws_region: Optional[str] = None, -) -> soma.Experiment: - """Internal method for opening a SOMA ``Experiment`` as a context manager.""" - context = get_default_soma_context().replace(tiledb_config={"vfs.s3.region": aws_region} if aws_region else {}) - - with soma.Experiment.open(uri, context=context) as exp: - yield exp - - -def _tables_to_np( - tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] -) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: - for tbl, indices in tables: - row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns) - nnz = len(data) - dense_matrix = np.zeros(shape, dtype=data.dtype) - dense_matrix[row_indices, col_indices] = data - yield dense_matrix, indices, nnz - - -class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): - """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, - not intended for public use. - """ - - X: soma.SparseNDArray - """A handle to the full X data of the SOMA ``Experiment``""" - - obs_joinids_chunks_iter: Iterator[npt.NDArray[np.int64]] - - var_joinids: npt.NDArray[np.int64] - """The ``var`` joinids to be retrieved from the SOMA ``Experiment``""" - - def __init__( - self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, - return_sparse_X: bool = False, - ): - self.obs = obs - self.X = X - self.obs_column_names = obs_column_names - if shuffle_chunk_count is not None: - assert shuffle_rng is not None - - # At the start of this step, `obs_joinids_chunked` is a list of one dimensional - # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. - # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is - # from a random section of `obs`. - # We then take `shuffle_chunk_count` of these in order, concatenate them into - # a larger numpy array and shuffle this larger numpy array. - # The result is again a list of numpy arrays. 
- self.obs_joinids_chunks_iter = ( - shuffle_rng.permutation(np.concatenate(grouped_chunks)) - for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) - ) - else: - self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) - self.var_joinids = var_joinids - self.shuffle_chunk_count = shuffle_chunk_count - self.return_sparse_X = return_sparse_X - - def __next__(self) -> _SOMAChunk: - pytorch_logger.debug("Retrieving next SOMA chunk...") - start_time = time() - - # If no more chunks to iterate through, raise StopIteration, as all iterators do when at end - obs_joinids_chunk = next(self.obs_joinids_chunks_iter) - - if "soma_joinid" not in self.obs_column_names: - cols = ["soma_joinid", *self.obs_column_names] - else: - cols = list(self.obs_column_names) - - obs_batch = ( - self.obs.read( - coords=(obs_joinids_chunk,), - column_names=cols, - ) - .concat() - .to_pandas() - .set_index("soma_joinid") - ) - assert obs_batch.shape[0] == obs_joinids_chunk.shape[0] - - # handle case of empty result (first batch has 0 rows) - if len(obs_batch) == 0: - raise StopIteration - - # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled - obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False) - - # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix, - # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block) - blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise( - axis=0, size=len(obs_joinids_chunk), eager=False - ) - - X_batch: ChunkX - if not self.return_sparse_X: - res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids)))) - X_batch, nnz = res[0], res[2] - else: - X_batch = next(blockwise_iter.scipy(compress=True))[0] - nnz = X_batch.nnz - - assert obs_batch.shape[0] == X_batch.shape[0] - - end_time = time() - stats = Stats() - stats.n_obs += X_batch.shape[0] - stats.nnz += nnz - stats.elapsed += end_time - start_time - stats.n_soma_chunks += 1 - - pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") - return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) - - -def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: - """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. - TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. - """ - i = 0 - result = [] - while i < len(arr_list): - if (i + sublist_len) >= len(arr_list): - result.append(arr_list[i:]) - else: - result.append(arr_list[i : i + sublist_len]) - - i += sublist_len - - return result - - -def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103 - proc = psutil.Process(os.getpid()) - - pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() - start = time() - gc.collect() - gc_elapsed = time() - start - post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() - - pytorch_logger.debug(f"gc: pre={pre_gc}") - pytorch_logger.debug(f"gc: post={post_gc}") - - return pre_gc, post_gc, gc_elapsed - - -class _ObsAndXIterator(Iterator[ObsAndXDatum]): - """Iterates through a set of ``obs`` and corresponding ``X`` rows, where the rows to be returned are specified by - the ``obs_tables_iter`` argument. For the specified ``obs` rows, the corresponding ``X`` data is loaded and - joined together. It is returned from this iterator as 2-tuples of ``X`` and obs Tensors. 
- - Internally manages the retrieval of data in SOMA-sized chunks, fetching the next chunk of SOMA data as needed. - Supports fetching the data in an eager manner, where the next SOMA chunk is fetched while the current chunk is - being read. This is an internal class, not intended for public use. - """ - - soma_chunk_iter: Iterator[_SOMAChunk] - """The iterator for SOMA chunks of paired obs and X data""" - - soma_chunk: Optional[_SOMAChunk] - """The current SOMA chunk of obs and X data""" - - i: int = -1 - """Index into current obs ``SOMA`` chunk""" - - def __init__( - self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - batch_size: int, - encoders: List[Encoder], - stats: Stats, - return_sparse_X: bool, - use_eager_fetch: bool, - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, - ) -> None: - self.soma_chunk_iter = _ObsAndXSOMAIterator( - obs, - X, - obs_column_names, - obs_joinids_chunked, - var_joinids, - shuffle_chunk_count, - shuffle_rng, - return_sparse_X=return_sparse_X, - ) - if use_eager_fetch: - self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) - self.soma_chunk = None - self.var_joinids = var_joinids - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.encoders = encoders - self.stats = stats - self.gc_elapsed = 0.0 - self.max_process_mem_usage_bytes = 0 - self.X_dtype = X.schema[2].type.to_pandas_dtype() - - def __next__(self) -> ObsAndXDatum: - """Read the next torch batch, possibly across multiple soma chunks.""" - obss: list[pd.DataFrame] = [] - Xs: list[ChunkX] = [] - n_obs = 0 - - while n_obs < self.batch_size: - try: - obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs) - n_obs += len(obs_partial) - obss.append(obs_partial) - Xs.append(X_partial) - except StopIteration: - break - - if len(Xs) == 0: # If we ran out of data - raise StopIteration - else: - if self.return_sparse_X: - X = sparse.vstack(Xs) - else: - X = np.concatenate(Xs, axis=0) - obs = pd.concat(obss, axis=0) - - obs_encoded = pd.DataFrame() - - # Add the soma_joinid to the original obs, in case that is requested by the encoders. - obs["soma_joinid"] = obs.index - - for enc in self.encoders: - obs_encoded[enc.name] = enc.transform(obs) - - # `to_numpy()` avoids copying the numpy array data - obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) - - if not self.return_sparse_X: - X_tensor = torch.from_numpy(X) - else: - coo = X.tocoo() - - X_tensor = torch.sparse_coo_tensor( - # Note: The `np.array` seems unnecessary, but PyTorch warns bare array is "extremely slow" - indices=torch.from_numpy(np.array([coo.row, coo.col])), - values=coo.data, - size=coo.shape, - ) - - if self.batch_size == 1: - X_tensor = X_tensor[0] - obs_tensor = obs_tensor[0] - - return X_tensor, obs_tensor - - def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]: - """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may - contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current - SOMA chunk are fewer than the requested ``batch_size``. 
- """ - if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): - # GC memory from previous soma_chunk - self.soma_chunk = None - pre_gc, _, gc_elapsed = run_gc() - self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss) - - self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) - self.stats += self.soma_chunk.stats - self.gc_elapsed += gc_elapsed - self.i = 0 - - pytorch_logger.debug( - f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}" - ) - - obs_batch = self.soma_chunk.obs - X_chunk = self.soma_chunk.X - - safe_batch_size = min(batch_size, len(obs_batch) - self.i) - slice_ = slice(self.i, self.i + safe_batch_size) - assert slice_.stop <= obs_batch.shape[0] - - obs_rows = obs_batch.iloc[slice_] - assert obs_rows.index.is_unique - assert safe_batch_size == obs_rows.shape[0] - - X_batch = X_chunk[slice_] - - assert obs_rows.shape[0] == X_batch.shape[0] - - self.i += safe_batch_size - - return obs_rows, X_batch - - -class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore - r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a - :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` axes. Provides an - iterator over these data when the object is passed to Python's built-in ``iter`` function. - - >>> for batch in iter(ExperimentDataPipe(...)): - X_batch, y_batch = batch - - The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each - iteration. If the ``batch_size`` is 1, then each Tensor will have rank 1: - - >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data - tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded - - For larger ``batch_size`` values, the returned Tensors will have rank 2: - - >>> DataLoader(..., batch_size=3, ...): - (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch - [0., 0., 0., 0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), - tensor([[2415, 0, 0], # obs batch - [2416, 0, 4], - [2417, 0, 3]], dtype=torch.int64)) - - The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or sparse - :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, this will reduce memory usage. - - The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. The first - element is always the ``soma_joinid`` of the ``obs`` :class:`pandas.DataFrame` (or, equivalently, the - ``soma_dim_0`` of the ``X`` matrix). The remaining elements are the ``obs`` columns specified by - ``obs_column_names``, and string-typed columns are encoded as integer values. If needed, these values can be decoded - by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` method: - - >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) - - Lifecycle: - experimental - """ - - _initialized: bool - - _obs_joinids: Optional[npt.NDArray[np.int64]] - - _var_joinids: Optional[npt.NDArray[np.int64]] - - _encoders: List[Encoder] - - _stats: Stats - - _shuffle_rng: Optional[Generator] - - # TODO: Consider adding another convenience method wrapper to construct this object whose signature is more closely - # aligned with get_anndata() params (i.e. "exploded" AxisQuery params). 
- def __init__( - self, - experiment: soma.Experiment, - measurement_name: str = "RNA", - X_name: str = "raw", - obs_query: Optional[soma.AxisQuery] = None, - var_query: Optional[soma.AxisQuery] = None, - obs_column_names: Sequence[str] = (), - batch_size: int = 1, - shuffle: bool = True, - seed: Optional[int] = None, - return_sparse_X: bool = False, - soma_chunk_size: Optional[int] = 64, - use_eager_fetch: bool = True, - encoders: Optional[List[Encoder]] = None, - shuffle_chunk_count: Optional[int] = 2000, - ) -> None: - r"""Construct a new ``ExperimentDataPipe``. - - Args: - experiment: - The :class:`tiledbsoma.Experiment` from which to read data. - measurement_name: - The name of the :class:`tiledbsoma.Measurement` to read. Defaults to ``"RNA"``. - X_name: - The name of the X layer to read. Defaults to ``"raw"``. - obs_query: - The query used to filter along the ``obs`` axis. If not specified, all ``obs`` and ``X`` data will - be returned, which can be very large. - var_query: - The query used to filter along the ``var`` axis. If not specified, all ``var`` columns (genes/features) - will be returned. - obs_column_names: - The names of the ``obs`` columns to return. The ``soma_joinid`` index "column" does not need to be - specified and will always be returned. If not specified, only the ``soma_joinid`` will be returned. - If custom encoders are passed, this parameter must not be used, since the columns will be inferred - automatically from the encoders. - batch_size: - The number of rows of ``obs`` and ``X`` data to return in each iteration. Defaults to ``1``. A value of - ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). - shuffle: - Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. - For performance reasons, shuffling is not performed globally across all rows, but rather in chunks. - More specifically, we select ``shuffle_chunk_count`` non-contiguous chunks across all the observations - in the query, concatenate the chunks and shuffle the associated observations. - The randomness of the shuffling is therefore determined by the - (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have been determined - to yield a good trade-off between randomness and performance. Further tuning may be required for - different type of models. Note that memory usage is correlated to the product - ``soma_chunk_size * shuffle_chunk_count``. - seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using - :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker - processes. - return_sparse_X: - Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. As ``X`` data is - very sparse, setting this to ``True`` will reduce memory usage, if the model supports use of sparse - :class:`torch.Tensor`\ s. Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still - experimental in PyTorch. - soma_chunk_size: - The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of - this class's behavior: 1) The maximum memory utilization, with larger values providing - better read performance, but also requiring more memory; 2) The granularity of the global shuffling - step (see ``shuffle`` parameter for details). 
The default value of 64 works well in conjunction - with the default ``shuffle_chunk_count`` value. - use_eager_fetch: - Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made - available for processing via the iterator. This allows network (or filesystem) requests to be made in - parallel with client-side processing of the SOMA data, potentially improving overall performance at the - cost of doubling memory utilization. Defaults to ``True``. - shuffle_chunk_count: - The number of contiguous blocks (chunks) of rows sampled to then concatenate and shuffle. - Larger numbers correspond to more randomness per training batch. - If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``. - encoders: - Specify custom encoders to be used. If not specified, a LabelEncoder will be created and - used for each column in ``obs_column_names``. If specified, only columns for which an encoder - has been registered will be returned in the ``obs`` tensor. Each encoder needs to have a unique name. - If this parameter is specified, the ``obs_column_names`` parameter must not be used, - since the columns will be inferred automatically from the encoders. - - Lifecycle: - experimental - """ - self.exp_uri = experiment.uri - self.aws_region = experiment.context.tiledb_config.get("vfs.s3.region") - self.measurement_name = measurement_name - self.layer_name = X_name - self.obs_query = obs_query - self.var_query = var_query - self.obs_column_names = obs_column_names - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.soma_chunk_size = soma_chunk_size - self.use_eager_fetch = use_eager_fetch - self._stats = Stats() - self._encoders = encoders or [] - self._obs_joinids = None - self._var_joinids = None - self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None - self._shuffle_rng = np.random.default_rng(seed) if shuffle else None - self._initialized = False - self.max_process_mem_usage_bytes = 0 - - if obs_column_names and encoders: - raise ValueError( - "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically." - ) - - if encoders: - # Check if names are unique - if len(encoders) != len({enc.name for enc in encoders}): - raise ValueError("Encoders must have unique names") - - self.obs_column_names = list(dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders]))) - - def _init(self) -> None: - if self._initialized: - return - - pytorch_logger.debug("Initializing ExperimentDataPipe") - - with _open_experiment(self.exp_uri, self.aws_region) as exp: - query = exp.axis_query( - measurement_name=self.measurement_name, - obs_query=self.obs_query, - var_query=self.var_query, - ) - - # The to_numpy() call is a workaround for a possible bug in TileDB-SOMA: - # https://github.com/single-cell-data/TileDB-SOMA/issues/1456 - self._obs_joinids = query.obs_joinids().to_numpy() - self._var_joinids = query.var_joinids().to_numpy() - - self._encoders = self._build_obs_encoders(query) - - self._initialized = True - - @staticmethod - def _subset_ids_to_partition( - ids_chunked: List[npt.NDArray[np.int64]], - partition_index: int, - num_partitions: int, - ) -> List[npt.NDArray[np.int64]]: - """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's distributed rank and world - size. 
- """ - # subset to a single partition - # typing does not reflect that is actually a List of 2D NDArrays - partition_indices = np.array_split(range(len(ids_chunked)), num_partitions) - partition = [ids_chunked[i] for i in partition_indices[partition_index]] - - if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0: - pytorch_logger.debug( - f"Process {os.getpid()} handling partition {partition_index + 1} of {num_partitions}, " - f"partition_size={sum([len(chunk) for chunk in partition])}" - ) - - return partition - - @staticmethod - def _compute_partitions( - loader_partition: int, - loader_partitions: int, - dist_partition: int, - num_dist_partitions: int, - ) -> Tuple[int, int]: - # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload - total_partitions = num_dist_partitions * loader_partitions - partition = dist_partition * loader_partitions + loader_partition - return partition, total_partitions - - def __iter__(self) -> Iterator[ObsAndXDatum]: - self._init() - assert self._obs_joinids is not None - assert self._var_joinids is not None - - if self.soma_chunk_size is None: - # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data sparsity, 8 bytes for the - # X value and 8 bytes for the sparse matrix indices, and a 100% working memory overhead (2x). - X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 - self.soma_chunk_size = int((1 * 1024**3) / X_row_memory_size) - pytorch_logger.debug(f"Using {self.soma_chunk_size=}") - - if ( - self.return_sparse_X - and torch.utils.data.get_worker_info() - and torch.utils.data.get_worker_info().num_workers > 0 - ): - raise NotImplementedError( - "torch does not work with sparse tensors in multi-processing mode " - "(see https://github.com/pytorch/pytorch/issues/20248)" - ) - - # chunk the obs joinids into batches of size soma_chunk_size - obs_joinids_chunked = self._chunk_ids(self._obs_joinids, self.soma_chunk_size) - - # globally shuffle the chunks, if requested - if self._shuffle_rng: - self._shuffle_rng.shuffle(obs_joinids_chunked) - - # subset to a single partition, as needed for distributed training and multi-processing datat loading - worker_info = torch.utils.data.get_worker_info() - partition, partitions = self._compute_partitions( - loader_partition=worker_info.id if worker_info else 0, - loader_partitions=worker_info.num_workers if worker_info else 1, - dist_partition=dist.get_rank() if dist.is_initialized() else 0, - num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, - ) - obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition( - obs_joinids_chunked, partition, partitions - ) - - with _open_experiment(self.exp_uri, self.aws_region) as exp: - obs_and_x_iter = _ObsAndXIterator( - obs=exp.obs, - X=exp.ms[self.measurement_name].X[self.layer_name], - obs_column_names=self.obs_column_names, - obs_joinids_chunked=obs_joinids_chunked_partition, - var_joinids=self._var_joinids, - batch_size=self.batch_size, - encoders=self._encoders, - stats=self._stats, - return_sparse_X=self.return_sparse_X, - use_eager_fetch=self.use_eager_fetch, - shuffle_rng=self._shuffle_rng, - shuffle_chunk_count=self._shuffle_chunk_count, - ) - - yield from obs_and_x_iter - - self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes - pytorch_logger.debug( - "max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" - ) - - @staticmethod - def _chunk_ids(ids: npt.NDArray[np.int64], 
chunk_size: int) -> List[npt.NDArray[np.int64]]: - num_chunks = max(1, ceil(len(ids) / chunk_size)) - pytorch_logger.debug(f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}") - return np.array_split(ids, num_chunks) - - def __len__(self) -> int: - self._init() - assert self._obs_joinids is not None - - div, rem = divmod(len(self._obs_joinids), self.batch_size) - return div + bool(rem) - - def __getitem__(self, index: int) -> ObsAndXDatum: - raise NotImplementedError("IterDataPipe can only be iterated") - - def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> List[Encoder]: - pytorch_logger.debug("Initializing encoders") - - encoders = [] - - if "soma_joinid" not in self.obs_column_names: - cols = ["soma_joinid", *self.obs_column_names] - else: - cols = list(self.obs_column_names) - - obs = query.obs(column_names=cols).concat().to_pandas() - - if self._encoders: - # Fit all the custom encoders with obs - for enc in self._encoders: - enc.fit(obs) - encoders.append(enc) - else: - # Create one LabelEncoder for each column, and fit it with obs - for col in self.obs_column_names: - enc = LabelEncoder(col) - enc.fit(obs) - encoders.append(enc) - - return encoders - - # TODO: This does not work in multiprocessing mode, as child process's stats are not collected - def stats(self) -> Stats: - """Get data loading stats for this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Returns: - The :class:`cellxgene_census.experimental.ml.pytorch.Stats` object for this - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Lifecycle: - experimental - """ - return self._stats - - @property - def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode - (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the partition of the data assigned to the active process. - - Returns: - A 2-tuple of ``int``s, for obs and var counts, respectively. - - Lifecycle: - experimental - """ - self._init() - assert self._obs_joinids is not None - assert self._var_joinids is not None - - return len(self._obs_joinids), len(self._var_joinids) - - @property - def obs_encoders(self) -> Encoders: - """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on ``obs`` column names, - which were used to encode the ``obs`` column values. - - These encoders can be used to decode the encoded values as follows: - - >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) - - Returns: - A ``Dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing.LabelEncoder` objects. - """ - self._init() - assert self._encoders is not None - - return {enc.name: enc for enc in self._encoders} - - -# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling -def _collate_noop(x: Any) -> Any: - return x - - -# TODO: Move into somacore.ExperimentAxisQuery -def experiment_dataloader( - datapipe: pipes.IterDataPipe, - num_workers: int = 0, - **dataloader_kwargs: Any, -) -> DataLoader: - """Factory method for :class:`torch.utils.data.DataLoader`. 
This method can be used to safely instantiate a - :class:`torch.utils.data.DataLoader` that works with :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`, - since some of the :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a - :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, - ``collate_fn``). - - Args: - datapipe: - An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe` or any other - :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - num_workers: - Number of worker processes to use for data loading. If ``0``, data will be loaded in the main process. - **dataloader_kwargs: - Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, - except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, and ``collate_fn``, which are not - supported when using :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. - - Returns: - A :class:`torch.utils.data.DataLoader`. - - Raises: - ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, or ``collate_fn`` params - are passed as keyword arguments. - - Lifecycle: - experimental - """ - unsupported_dataloader_args = [ - "shuffle", - "batch_size", - "sampler", - "batch_sampler", - "collate_fn", - ] - if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): - raise ValueError(f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported") - - if num_workers > 0: - _init_multiprocessing() - - return DataLoader( - datapipe, - batch_size=None, # batching is handled by our ExperimentDataPipe - num_workers=num_workers, - # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches - collate_fn=_collate_noop, - # shuffling is handled by our ExperimentDataPipe - shuffle=False, - **dataloader_kwargs, - ) - - -def _init_multiprocessing() -> None: - """Ensures use of "spawn" for starting child processes with multiprocessing. - - Forked processes are known to be problematic: - https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks - Also, CUDA does not support forked child processes: - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing - - """ - torch.multiprocessing.set_start_method("fork", force=True) - orig_start_method = torch.multiprocessing.get_start_method() - if orig_start_method != "spawn": - if orig_start_method: - pytorch_logger.warning( - "switching torch multiprocessing start method from " - f'"{torch.multiprocessing.get_start_method()}" to "spawn"' - ) - torch.multiprocessing.set_start_method("spawn", force=True) diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb new file mode 100644 index 0000000..daeedf5 --- /dev/null +++ b/notebooks/tutorial_lightning.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a model with PyTorch Lightning\n", + "\n", + "This tutorial provides a quick overview of training a toy model with Lightning, using the `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. 
This is intended only to demonstrate the use of the `ExperimentAxisQueryIterableDataset`, and not as an example of how to train a biologically useful model.\n", + "\n", + "For more information on these APIs, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).\n", + "\n", + "**Prerequisites**\n", + "\n", + "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n", + "\n", + "> pip install tiledbsoma_ml scikit-learn\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SOMA Experiment query as training data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], + "source": [ + "import pytorch_lightning as pl\n", + "import tiledbsoma as soma\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Lightning module" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class LogisticRegressionLightning(pl.LightningModule):\n", + " def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n", + " super(LogisticRegressionLightning, self).__init__()\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + " self.cell_type_encoder = cell_type_encoder\n", + " self.learning_rate = learning_rate\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " X_batch, y_batch = batch\n", + " # X_batch = X_batch.float()\n", + " X_batch = torch.from_numpy(X_batch).float().to(self.device)\n", + "\n", + " # Perform prediction\n", + " outputs = self(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = 
torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute loss\n", + " y_batch = torch.from_numpy(\n", + " self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n", + " ).to(self.device)\n", + " loss = self.loss_fn(outputs, y_batch.long())\n", + "\n", + " # Compute accuracy\n", + " train_correct = (predictions == y_batch).sum().item()\n", + " train_accuracy = train_correct / len(predictions)\n", + "\n", + " # Log loss and accuracy\n", + " self.log(\"train_loss\", loss, prog_bar=True)\n", + " self.log(\"train_accuracy\", train_accuracy, prog_bar=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------\n", + "0 | linear | Linear | 726 K | train\n", + "1 | loss_fn | CrossEntropyLoss | 0 | train\n", + "-----------------------------------------------------\n", + "726 K Trainable params\n", + "0 Non-trainable params\n", + "726 K Total params\n", + "2.905 Total estimated model params size (MB)\n", + "2 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. 
In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=20` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n" + ] + } + ], + "source": [ + "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "# Initialize the PyTorch Lightning model\n", + "model = LogisticRegressionLightning(\n", + " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", + ")\n", + "\n", + "# Define the PyTorch Lightning Trainer\n", + "trainer = pl.Trainer(max_epochs=20)\n", + "\n", + "# set precision\n", + "torch.set_float32_matmul_precision(\"high\")\n", + "\n", + "# Train the model\n", + "trainer.fit(model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb new file mode 100644 index 0000000..37e17e8 --- /dev/null +++ b/notebooks/tutorial_multiworker.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-process training\n", + "\n", + "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` includes both:\n", + "* using the `torch.utils.data.DataLoader` with 1 or more workers (i.e., with an argument of `num_workers=1` or greater)\n", + "* using a multi-process training configuration, such as `DistributedDataParallel`\n", + "\n", + "In these configurations, `ExperimentAxisQueryIterableDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n", + "\n", + "1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n", + "2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. 
This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], + "source": [ + "import tiledbsoma as soma\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + " \n", + "\n", + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + " model.train()\n", + " train_loss = 0\n", + " train_correct = 0\n", + " train_total = 0\n", + "\n", + " for X_batch, y_batch in train_dataloader:\n", + " optimizer.zero_grad()\n", + "\n", + " X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + " # Perform prediction\n", + " outputs = model(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute the loss and perform back propagation\n", + " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + " train_correct += (predictions == y_batch).sum().item()\n", + " train_total += len(predictions)\n", + "\n", + " loss = loss_fn(outputs, y_batch.long())\n", + " train_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss /= train_total\n", + " train_accuracy = train_correct / train_total\n", + " return train_loss, train_accuracy" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Multi-worker DataLoader\n", + "\n", + "If you use a multi-worker data loader (i.e., `num_workers` with a value other than `0`), and `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.\n", + "\n", + "The same approach should be taken for parallel training, e.g., when using DDP or DP.\n", + "\n", + "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterDataset` will automatically increment the epoch count each time the iterator completes." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n", + "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n", + "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n", + "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n", + "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n", + "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n", + "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n", + "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n", + "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n", + "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n", + "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n", + "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n", + "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n", + "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n", + "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n", + "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n", + "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n", + "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n", + "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n", + "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = 
LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "\n", + "# Define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n", + "# that a different shuffle is applied on each epoch.\n", + "experiment_dataloader = soma_ml.experiment_dataloader(\n", + "    experiment_dataset, num_workers=2, persistent_workers=True\n", + ")\n", + "\n", + "for epoch in range(20):\n", + "    experiment_dataset.set_epoch(epoch)\n", + "    train_loss, train_accuracy = train_epoch(\n", + "        model, experiment_dataloader, loss_fn, optimizer, device\n", + "    )\n", + "    print(\n", + "        f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n", + "    )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb new file mode 100644 index 0000000..70c62e3 --- /dev/null +++ b/notebooks/tutorial_pytorch.ipynb @@ -0,0 +1,609 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a PyTorch Model\n", + "\n", + "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma_ml.ExperimentAxisQueryIterDataPipe` class and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. It is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n", + "\n", + "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n", + "\n", + "**Prerequisites**\n", + "\n", + "Install the `tiledbsoma-ml` package, for example:\n", + "\n", + "> pip install tiledbsoma-ml\n", + "\n", + "\n", + "**Contents**\n", + "\n", + "* [Create a DataLoader](#Create-a-DataLoader)\n", + "* [Define the model](#Define-the-model)\n", + "* [Train the model](#Train-the-model)\n", + "* [Make predictions with the model](#Make-predictions-with-the-model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an ExperimentAxisQueryIterDataPipe\n", + "\n", + "To train a model in PyTorch on Census data, first open a SOMA Experiment and create an `ExperimentAxisQueryIterDataPipe`. This example uses a recent CZI Census release, accessed directly from S3.\n", + "\n", + "We are also going to create an encoder for the `obs` labels at the same time, and fit it on the `cell_type` labels. In this example we use the `LabelEncoder` from `scikit-learn`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import tiledbsoma as soma\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + "    CZI_Census_Homo_Sapiens_URL,\n", + "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + "    experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n", + "        query,\n", + "        X_name=\"raw\",\n", + "        obs_column_names=[\"cell_type\"],\n", + "        batch_size=128,\n", + "        shuffle=True,\n", + "    )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `ExperimentAxisQueryIterDataPipe` class explained\n", + "\n", + "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment`, returning it to the caller as a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n", + "\n", + "### `ExperimentAxisQueryIterDataPipe` parameters explained\n", + "\n", + "The constructor requires a single positional parameter, `query`, a `soma.ExperimentAxisQuery` that defines the data to iterate over.\n", + "\n", + "To retrieve a subset of the Experiment's data, along either the `obs` or `var` axis, specify query filters via the `obs_query` and `var_query` parameters of `experiment.axis_query()`, which are both `soma.AxisQuery` objects.\n", + "\n", + "The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.\n", + "\n", + "The `batch_size` parameter specifies the number of `obs` rows (cells) returned in each batch. You may omit this parameter if you want single rows (`batch_size=1`, the default).\n", + "\n", + "The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:\n", + "* You should use this flag instead of the `DataLoader` `shuffle` flag, primarily for performance reasons.\n", + "* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is an alternate mechanism one can use to perform shuffling of an `IterableDataset`. However, the `Shuffler` will not \"globally\" randomize the training data, as it only \"locally\" randomizes the ordering of the training data within fixed-size \"windows\". 
Due to the layout of Census data, a given \"window\" of Census data may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can inspect the shape of the full dataset, without causing the full dataset to be loaded. The `shape` property returns the number of batches on the first dimension:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(118, 60530)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_dataset.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split the dataset\n", + "\n", + "You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create the DataLoader\n", + "\n", + "With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(train_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryIterDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage." 
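, + "\n", + "For example, other standard `DataLoader` keyword arguments can still be passed through `experiment_dataloader`. A minimal sketch, mirroring the multi-worker tutorial (the specific values here are illustrative):\n", + "\n", + "```python\n", + "# Multi-process loading is supported; batching and shuffling remain the\n", + "# dataset's responsibility, so no batch_size/shuffle/sampler arguments.\n", + "experiment_dataloader = soma_ml.experiment_dataloader(\n", + "    train_dataset,\n", + "    num_workers=2,\n", + "    persistent_workers=True,\n", + ")\n", + "```"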
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the model\n", + "\n", + "With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "\n", + "class LogisticRegression(torch.nn.Module):\n", + "    def __init__(self, input_dim, output_dim):\n", + "        super(LogisticRegression, self).__init__()  # noqa: UP008\n", + "        self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + "    def forward(self, x):\n", + "        outputs = torch.sigmoid(self.linear(x))\n", + "        return outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define a function to train the model for a single epoch:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + "    model.train()\n", + "    train_loss = 0\n", + "    train_correct = 0\n", + "    train_total = 0\n", + "\n", + "    for X_batch, y_batch in train_dataloader:\n", + "        optimizer.zero_grad()\n", + "\n", + "        X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + "        # Perform prediction\n", + "        outputs = model(X_batch)\n", + "\n", + "        # Determine the predicted label\n", + "        probabilities = torch.nn.functional.softmax(outputs, 1)\n", + "        predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + "        # Compute the loss and perform back propagation\n", + "        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + "        train_correct += (predictions == y_batch).sum().item()\n", + "        train_total += len(predictions)\n", + "\n", + "        loss = loss_fn(outputs, y_batch.long())\n", + "        train_loss += loss.item()\n", + "        loss.backward()\n", + "        optimizer.step()\n", + "\n", + "    train_loss /= train_total\n", + "    train_accuracy = train_correct / train_total\n", + "    return train_loss, train_accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the loop `for X_batch, y_batch in train_dataloader:`. Since the `experiment_dataset` was configured with `batch_size=128`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([[0., 0., 0.,  ..., 1., 0., 0.],\n", + "        [0., 0., 2.,  ..., 0., 3., 0.],\n", + "        [0., 0., 0.,  ..., 0., 0., 0.],\n", + "        ...,\n", + "        [0., 0., 0.,  ..., 0., 0., 0.],\n", + "        [0., 1., 0.,  ..., 0., 0., 0.],\n", + "        [0., 0., 0.,  ..., 0., 0., 8.]])\n", + "    \n", + "```\n", + "\n", + "For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([0., 0., 0.,  ..., 1., 0., 0.])\n", + "```\n", + "    \n", + "`y_batch` is a Pandas `DataFrame` containing the `obs` columns specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). In `train_epoch`, the string labels are integer-encoded using the `cell_type_encoder` fitted earlier. 
The encoded labels will look like this:\n", + "\n", + "```\n", + "tensor([1, 1, 3, ..., 2, 1, 4])\n", + "\n", + "```\n", + "Note that the cell type values are integer-encoded, and can be decoded using `cell_type_encoder.inverse_transform` (more on this below).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model\n", + "\n", + "Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method, and then iterate through the desired number of training epochs. Note how the `train_dataloader` is passed into `train_epoch`, where for each epoch it will provide a new iterator through the training dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n", + "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n", + "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n", + "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n", + "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n", + "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n", + "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n", + "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n", + "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n", + "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n", + "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n", + "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n", + "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n", + "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n", + "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n", + "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n", + "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n", + "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n", + "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n", + "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "for epoch in range(20):\n", + "    train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n", + "    print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make predictions with the model\n", + "\n", + "To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the \"test\" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split."
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(test_dataset)\n", + "X_batch, y_batch = next(iter(experiment_dataloader))\n", + "X_batch = torch.from_numpy(X_batch)\n", + "y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we invoke the model on the `X_batch` input data and extract the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 8, 1, 1, 1, 1, 1, 1, 8, 8, 5, 1, 7, 8, 1, 1, 1, 1, 7,\n", + " 7, 8, 1, 1, 5, 5, 1, 8, 1, 1, 1, 7, 8, 7, 7, 7, 8, 7,\n", + " 5, 1, 1, 8, 1, 5, 8, 5, 1, 11, 1, 7, 1, 1, 5, 5, 1, 11,\n", + " 1, 6, 8, 5, 1, 8, 11, 8, 1, 8, 1, 8, 1, 5, 1, 1, 1, 8,\n", + " 8, 7, 5, 1, 1, 8, 1, 7, 2, 1, 7, 1, 5, 1, 1, 7, 1, 8,\n", + " 1, 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 7, 7, 5, 7, 8, 5, 1,\n", + " 5, 1, 5, 5, 5, 1, 1, 1, 8, 5, 1, 1, 7, 8, 1, 1, 1, 1,\n", + " 8, 1], device='cuda:0')" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "\n", + "model.to(device)\n", + "outputs = model(X_batch.to(device))\n", + "\n", + "probabilities = torch.nn.functional.softmax(outputs, 1)\n", + "predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + "display(predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n", + "\n", + "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." 
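, + "\n", + "A minimal sketch of that round trip, using Python's `pickle` module (the file name is illustrative):\n", + "\n", + "```python\n", + "import pickle\n", + "\n", + "# At training time: persist the fitted encoder alongside the model weights.\n", + "with open(\"cell_type_encoder.pkl\", \"wb\") as f:\n", + "    pickle.dump(cell_type_encoder, f)\n", + "\n", + "# At inference time: restore the encoder, then decode integer predictions\n", + "# with cell_type_encoder.inverse_transform (as in the next cell).\n", + "with open(\"cell_type_encoder.pkl\", \"rb\") as f:\n", + "    cell_type_encoder = pickle.load(f)\n", + "```"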
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n", + " 'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n", + " 'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'keratinocyte',\n", + " 'capillary endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell'], dtype=object)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n", + "\n", + "display(predicted_cell_types)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we create a Pandas DataFrame to examine the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
actual cell typepredicted cell type
0leukocyteleukocyte
1basal cellbasal cell
2basal cellbasal cell
3basal cellbasal cell
4basal cellbasal cell
.........
123fibroblastbasal cell
124basal cellbasal cell
125keratinocytebasal cell
126leukocyteleukocyte
127basal cellbasal cell
\n", + "

128 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " actual cell type predicted cell type\n", + "0 leukocyte leukocyte\n", + "1 basal cell basal cell\n", + "2 basal cell basal cell\n", + "3 basal cell basal cell\n", + "4 basal cell basal cell\n", + ".. ... ...\n", + "123 fibroblast basal cell\n", + "124 basal cell basal cell\n", + "125 keratinocyte basal cell\n", + "126 leukocyte leukocyte\n", + "127 basal cell basal cell\n", + "\n", + "[128 rows x 2 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "display(\n", + " pd.DataFrame(\n", + " {\n", + " \"actual cell type\": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),\n", + " \"predicted cell type\": predicted_cell_types,\n", + " }\n", + " )\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tiledbsoma-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7901253 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,76 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "tiledbsoma-ml" +dynamic = ["version"] +dependencies = [ + "attrs>=22.2", + "tiledbsoma>=1.9.0", + "torch>=2.0", + "torchdata<=0.9", + "numpy", + "numba", + "pandas", + "pyarrow", + "scipy" +] +requires-python = ">= 3.9" +description = "Machine learning tools for use with tiledbsoma" +readme = "README.md" +authors = [ + {name = "TileDB, Inc.", email = "help@tiledb.io"}, + {name = "The Chan Zuckerberg Initiative Foundation", email = "soma@chanzuckerberg.com" }, +] +maintainers = [ + {name = "TileDB, Inc.", email="help@tiledb.io"}, +] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Operating System :: Unix", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[project.urls] +Repository = "https://github.com/TileDB-Inc/TileDB-SOMA-ML.git" +Issues = "https://github.com/TileDB-Inc/TileDB-SOMA-ML/issues" +Changelog = "https://github.com/TileDB-Inc/TileDB-SOMA-ML/blob/main/CHANGELOG.md" + +[tool.setuptools.dynamic] +version = {attr = "tiledbsoma_ml.__version__"} + +[tool.setuptools.package-data] +"tiledbsoma_ml" = ["py.typed"] + +[tool.setuptools_scm] +root = "../../.." 
+ + [tool.mypy] + show_error_codes = true + ignore_missing_imports = true + warn_unreachable = true + strict = true + python_version = 3.9 + plugins = "numpy.typing.mypy_plugin" + + [tool.ruff] + lint.select = ["E", "F", "B", "I"] + lint.ignore = ["E501"] # line too long + lint.extend-select = ["I001"] # unsorted-imports + fix = true + target-version = "py39" + line-length = 120 diff --git a/scripts/rehome-census.sh b/scripts/rehome-census.sh new file mode 100755 index 0000000..b742b2b --- /dev/null +++ b/scripts/rehome-census.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +# +# Create a fresh TileDB-SOMA-ML clone, with some "re-homed" history: +# 1. Reproduce the Git history of `api/python/cellxgene_census/src/cellxgene_census/experimental/ml` in the CELLxGENE Census repo: +#    - This was developed between May 2023 and July 2024 +#    - A few files that are not relevant to the "PyTorch loaders" work are omitted (namely the "huggingface" subdirectory) +# 2. Insert one commit moving these files to `apis/python/src/tiledbsoma/ml` (which they were copied to, in the TileDB-SOMA repo, in July 2024) +# 3. Replay the `bkmartinjr/experimentdatapipe` branch of TileDB-SOMA (developed between July 2024 and September 2024) on top of this +#    - Just the commits that touch the `apis/python/src/tiledbsoma/ml` directory, or a few other relevant paths (e.g. `other_packages`, where they were moved later in that branch's development) + +set -ex + +pip install git-filter-repo + +# Create a Census clone, filter to files/commits relevant to PyTorch loaders: +git clone -o origin https://github.com/chanzuckerberg/cellxgene-census census-ml && cd census-ml +ml=api/python/cellxgene_census/src/cellxgene_census/experimental/ml +git filter-repo \ +  --path $ml/__init__.py \ +  --path $ml/pytorch.py \ +  --path $ml/encoders.py \ +  --path $ml/util +cd .. + +# Create a TileDB-SOMA clone, filter to files/commits relevant to PyTorch loaders: +git clone -o origin -b bkmartinjr/experimentdatapipe git@github.com:single-cell-data/TileDB-SOMA.git soma-pytorch && cd soma-pytorch +git branch -m main +renames=() +for p in CHANGELOG.md README.md notebooks pyproject.toml src tests; do +  renames+=(--path-rename "other_packages/python/tiledbsoma_ml/$p:$p") +done +git filter-repo --force \ +  --path other_packages \ +  --path apis/python/src/tiledbsoma/ml \ +  --path .github/workflows/python-tiledbsoma-ml.yml \ +  "${renames[@]}" +cd .. 
+ +# Initialize TileDB-SOMA-ML clone, fetch filtered Census and TileDB-SOMA branches from the adjacent directories above: +git clone https://github.com/ryan-williams/TileDB-SOMA-ML soma-ml && cd soma-ml +git remote add c ../census-ml && git fetch c +git remote add t ../soma-pytorch && git fetch t +git reset --hard c/main + +# From the filtered Census HEAD, `git mv` the files to where the TileDB-SOMA branch ported them +tdbs=apis/python/src/tiledbsoma +mkdir -p $tdbs +git mv $ml $tdbs/ + +# Cherry-pick the root commit of the TileDB-SOMA port +root="$(git rev-list --max-parents=0 t/main)" +git cherry-pick $root +# Ensure all files match the TileDB-SOMA root commit +git status --porcelain | grep '^UU' | cut -c4- | xargs git checkout --theirs -- +# Verify there are no diffs vs TileDB-SOMA root commit +git diff --exit-code $root + +# Rebase `$root..t/main` (the rest of the filtered TileDB-SOMA commits) onto cherry-picked HEAD +git reset --hard t/main +git rebase --onto "HEAD@{1}" $root diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py new file mode 100644 index 0000000..263608f --- /dev/null +++ b/src/tiledbsoma_ml/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +"""An API to support machine learning applications built on SOMA.""" + +from .pytorch import ( + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, + experiment_dataloader, +) + +__version__ = "0.1.0-dev" + +__all__ = [ + "ExperimentAxisQueryIterDataPipe", + "ExperimentAxisQueryIterableDataset", + "experiment_dataloader", +] diff --git a/src/tiledbsoma_ml/py.typed b/src/tiledbsoma_ml/py.typed new file mode 100644 index 0000000..288a150 --- /dev/null +++ b/src/tiledbsoma_ml/py.typed @@ -0,0 +1,2 @@ +# Marker file to indicate that this package contains Python typing information, +# and that mypy can use it to typecheck client code. diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py new file mode 100644 index 0000000..0e3bb33 --- /dev/null +++ b/src/tiledbsoma_ml/pytorch.py @@ -0,0 +1,1331 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +import contextlib +import gc +import itertools +import logging +import math +import os +import sys +import time +from contextlib import contextmanager +from itertools import islice +from math import ceil +from typing import ( + TYPE_CHECKING, + Any, + ContextManager, + Dict, + Iterable, + Iterator, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import attrs +import numba +import numpy as np +import numpy.typing as npt +import pandas as pd +import pyarrow as pa +import scipy.sparse as sparse +import tiledbsoma as soma +import torch +import torchdata +from somacore.query._eager_iter import EagerIterator as _EagerIterator +from typing_extensions import Self, TypeAlias + +logger = logging.getLogger("tiledbsoma_ml.pytorch") + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +if TYPE_CHECKING: + # Python 3.8 does not support subscripting types, so work-around by + # restricting this to when we are running a type checker. TODO: remove + # the conditional when Python 3.8 support is dropped. 
+ NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] + XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix] +else: + NDArrayNumber: TypeAlias = np.ndarray + XDatum: TypeAlias = Union[np.ndarray, sparse.csr_matrix] + +XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame] +"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, +which pairs a slice of ``X`` rows with a cooresponding slice of ``obs``. In the default case, +the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` +respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is +returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` +will be returned with rank 1; in all other cases, objects are returned with rank 2.""" + + +@attrs.define(frozen=True, kw_only=True) +class _ExperimentLocator: + """State required to open the Experiment. + + Necessary as we will likely be invoked across multiple processes. + + Private implementation class. + """ + + uri: str + tiledb_timestamp_ms: int + tiledb_config: Dict[str, Union[str, float]] + + @classmethod + def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator": + return _ExperimentLocator( + uri=experiment.uri, + tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, + tiledb_config=experiment.context.tiledb_config, + ) + + @contextmanager + def open_experiment(self) -> Iterator[soma.Experiment]: + context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config) + with soma.Experiment.open( + self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context + ) as exp: + yield exp + + +class ExperimentAxisQueryIterable(Iterable[XObsDatum]): + """An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as + selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator + produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and + :class:`pandas.DataFrame`, respectively. + + Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and + :class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset` + and `ExperimentAxisQueryIterDataPipe` for more details on usage. + + Lifecycle: + experimental + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str, + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + shuffle: bool = True, + io_batch_size: int = 2**16, + shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, + seed: int | None = None, + use_eager_fetch: bool = True, + ): + """ + Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a + Pandas :class:`pandas.DataFrame`, respectively. + + Args: + query: + A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data to iterate over. + X_name: + The name of the X layer to read. + obs_column_names: + The names of the ``obs`` columns to return. At least one column name must be specified. + Default is ``('soma_joinid',)``. + batch_size: + The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. 
A value of + ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will + result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows + this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` + batching, but higher performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s + ``batch_size`` parameter to ``None``. + shuffle: + Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. + io_batch_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts: + 1. Maximum memory utilization, larger values provide better read performance, but require more memory. + 2. The number of rows read prior to shuffling (see the ``shuffle`` parameter for details). + The default value of 65,536 provides high performance but may need to be reduced in memory-limited hosts + or when using a large number of :class:`DataLoader` workers. + shuffle_chunk_size: + The number of contiguous rows sampled prior to concatenation and shuffling. + Larger numbers correspond to less randomness, but greater read performance. + If ``shuffle == False``, this parameter is ignored. + return_sparse_X: + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will + return ``X`` data as a :class:`numpy.ndarray`. + seed: + The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *MUST* be specified when using + :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker + processes. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made + available for processing via the iterator. This allows network (or filesystem) requests to be made in + parallel with client-side processing of the SOMA data, potentially improving overall performance at the + cost of doubling memory utilization. Defaults to ``True``. + + Returns: + An ``iterable``, which can be iterated over using the Python ``iter()`` built-in, or passed directly to + a :class:`torch.utils.data.DataLoader` instance. + + Raises: + ``ValueError`` on various unsupported or malformed parameter values. + + Lifecycle: + experimental + + .. warning:: + When using this class in any distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will always be used. + + In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you + must provide a seed, ensuring that the same shuffle is used across all replicas. 
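+ + Example: + A sketch of typical per-epoch usage in a training loop, where ``ds`` is an + instance of this class and ``n_epochs`` is assumed to be defined by the caller: + + >>> for epoch in range(n_epochs): + ...     ds.set_epoch(epoch) + ...     for X, obs in ds: + ...         ...  # train on the mini-batch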
+ """ + + super().__init__() + + # Anything set in the instance needs to be picklable for multi-process DataLoaders + self.experiment_locator = _ExperimentLocator.create(query.experiment) + self.layer_name = X_name + self.measurement_name = query.measurement_name + self.obs_query = query._matrix_axis_query.obs + self.var_query = query._matrix_axis_query.var + self.obs_column_names = list(obs_column_names) + self.batch_size = batch_size + self.io_batch_size = io_batch_size + self.shuffle = shuffle + self.return_sparse_X = return_sparse_X + self.use_eager_fetch = use_eager_fetch + self._obs_joinids: npt.NDArray[np.int64] | None = None + self._var_joinids: npt.NDArray[np.int64] | None = None + self.seed = ( + seed if seed is not None else np.random.default_rng().integers(0, 2**32 - 1) + ) + self._user_specified_seed = seed is not None + self.shuffle_chunk_size = shuffle_chunk_size + self._initialized = False + self.epoch = 0 + + if self.shuffle: + # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. + self.io_batch_size = ( + ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size + ) + + if not self.obs_column_names: + raise ValueError("Must specify at least one value in `obs_column_names`") + + def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: + """Create iterator over obs id chunks with split size of (roughly) io_batch_size. + + As appropriate, will chunk, shuffle and apply partitioning per worker. + + IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will + always partition such that each process has the same number of samples. Where + the number of obs_joinids is not evenly divisible by the number of processes, + the number of joinids will be dropped (dropped ids can never exceed WORLD_SIZE-1). + + Abstractly, the steps taken: + 1. Split the joinids into WORLD_SIZE sections (aka number of GPUS in DDP) + 2. Trim the splits to be of equal length + 3. Chunk and optionally shuffle the chunks + 4. Partition by number of data loader workers (to not generate redundant batches + in cases where the DataLoader is running with `n_workers>1`). + + Private method. + """ + assert self._obs_joinids is not None + obs_joinids: npt.NDArray[np.int64] = self._obs_joinids + + # 1. Get the split for the model replica/GPU + world_size, rank = _get_distributed_world_rank() + _gpu_splits = _splits(len(obs_joinids), world_size) + _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] + + # 2. Trim to be all of equal length - equivalent to a "drop_last" + # TODO: may need to add an option to do padding as well. + min_len = np.diff(_gpu_splits).min() + assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 + _gpu_split = _gpu_split[:min_len] + + # 3. Chunk and optionally shuffle chunks + if self.shuffle: + assert self.io_batch_size % self.shuffle_chunk_size == 0 + shuffle_split = np.array_split( + _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) + ) + + # Deterministically create RNG - state must be same across all processes, ensuring + # that the joinid partitions are identical across all processes. + rng = np.random.default_rng(self.seed + self.epoch + 99) + rng.shuffle(shuffle_split) + obs_joinids_chunked = list( + np.concatenate(b) + for b in _batched( + shuffle_split, self.io_batch_size // self.shuffle_chunk_size + ) + ) + else: + obs_joinids_chunked = np.array_split( + _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) + ) + + # 4. 
Partition by DataLoader worker + n_workers, worker_id = _get_worker_world_rank() + obs_splits = _splits(len(obs_joinids_chunked), n_workers) + obs_partition_joinids = obs_joinids_chunked[ + obs_splits[worker_id] : obs_splits[worker_id + 1] + ].copy() + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, " + f"n_workers={n_workers}, epoch={self.epoch}, " + f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" + ) + + return iter(obs_partition_joinids) + + def _init_once(self, exp: soma.Experiment | None = None) -> None: + """One-time per worker initialization. + + All operations be idempotent in order to support pipe reset(). + + Private method. + """ + if self._initialized: + return + + logger.debug( + f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})" + ) + + if exp is None: + # If no user-provided Experiment, open/close it ourselves + exp_cm: ContextManager[soma.Experiment] = ( + self.experiment_locator.open_experiment() + ) + else: + # else, it is caller responsibility to open/close the experiment + exp_cm = contextlib.nullcontext(exp) + + with exp_cm as exp: + with exp.axis_query( + measurement_name=self.measurement_name, + obs_query=self.obs_query, + var_query=self.var_query, + ) as query: + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() + + self._initialized = True + + def __iter__(self) -> Iterator[XObsDatum]: + """Create iterator over query. + + Returns: + ``iterator`` + + Lifecycle: + experimental + """ + + if ( + self.return_sparse_X + and torch.utils.data.get_worker_info() + and torch.utils.data.get_worker_info().num_workers > 0 + ): + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) + + world_size, rank = _get_distributed_world_rank() + n_workers, worker_id = _get_worker_world_rank() + logger.debug( + f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}, seed={self.seed}, epoch={self.epoch}" + ) + if world_size > 1 and self.shuffle and self._user_specified_seed is None: + raise ValueError( + "ExperimentAxisQueryIterable requires an explicit `seed` when shuffle is used in a multi-process configuration." + ) + + with self.experiment_locator.open_experiment() as exp: + self._init_once(exp) + X = exp.ms[self.measurement_name].X[self.layer_name] + if not isinstance(X, soma.SparseNDArray): + raise NotImplementedError( + "ExperimentAxisQueryIterable only supports X layers which are of type SparseNDArray" + ) + + obs_joinid_iter = self._create_obs_joinids_partition() + _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter) + if self.use_eager_fetch: + _mini_batch_iter = _EagerIterator( + _mini_batch_iter, pool=exp.context.threadpool + ) + + yield from _mini_batch_iter + + self.epoch += 1 + + def __len__(self) -> int: + """Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or + as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) + count will reflect the size of the data partition assigned to the active process. 
+ + See important caveats in the PyTorch + [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) + documentation regarding ``len(dataloader)``, which also apply to this class. + + Returns: + An ``int``. + + Lifecycle: + experimental + """ + return self.shape[0] + + @property + def shape(self) -> Tuple[int, int]: + """Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode + (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect + the size of the data partition assigned to the active process. + + Returns: + A tuple of two ``int`` values: number of obs, number of vars. + + Lifecycle: + experimental + """ + self._init_once() + assert self._obs_joinids is not None + assert self._var_joinids is not None + world_size, _ = _get_distributed_world_rank() + n_workers, _ = _get_worker_world_rank() + partition_len = len(self._obs_joinids) // world_size // n_workers + div, rem = divmod(partition_len, self.batch_size) + return div + bool(rem), len(self._var_joinids) + + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + """ + self.epoch = epoch + + def __getitem__(self, index: int) -> XObsDatum: + raise NotImplementedError( + "``ExperimentAxisQueryIterable`` can only be iterated - does not support mapping" + ) + + def _io_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[Tuple[_CSR_IO_Buffer, pd.DataFrame]]: + """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of + (X: csr_array, obs: DataFrame). + + obs joinids read are controlled by the obs_joinid_iter. Iterator results will + be reindexed and shuffled (if shuffling enabled). + + Private method. + """ + assert self._var_joinids is not None + + # Create RNG - does not need to be identical across processes, but use the seed anyway + # for reproducibility. + shuffle_rng = np.random.default_rng(self.seed + self.epoch) + + obs_column_names = ( + list(self.obs_column_names) + if "soma_joinid" in self.obs_column_names + else ["soma_joinid", *self.obs_column_names] + ) + var_indexer = soma.IntIndexer(self._var_joinids, context=X.context) + + for obs_coords in obs_joinid_iter: + st_time = time.perf_counter() + obs_shuffled_coords = ( + obs_coords if not self.shuffle else shuffle_rng.permuted(obs_coords) + ) + obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) + logger.debug( + f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." + ) + + # To maximize opportunities for concurrency, when in eager_fetch mode, + # create the X read iterator first, as the eager iterator will begin + # the read-ahead immediately. Then proceed to fetch the obs DataFrame. + # This matters most on high-latency backing stores, e.g., S3. 
+ # + X_tbl_iter: Iterator[pa.Table] = X.read( + coords=(obs_coords, self._var_joinids) + ).tables() + + def make_io_buffer( + X_tbl: pa.Table, + obs_coords: npt.NDArray[np.int64], + var_coords: npt.NDArray[np.int64], + obs_indexer: soma.IntIndexer, + ) -> _CSR_IO_Buffer: + """This function provides a GC after we throw off (large) garbage.""" + m = _CSR_IO_Buffer.from_ijd( + obs_indexer.get_indexer(X_tbl["soma_dim_0"]), + var_indexer.get_indexer(X_tbl["soma_dim_1"]), + X_tbl["soma_data"].to_numpy(), + shape=(len(obs_coords), len(var_coords)), + ) + gc.collect(generation=0) + return m + + _io_buf_iter = ( + make_io_buffer(X_tbl, obs_coords, self._var_joinids, obs_indexer) + for X_tbl in X_tbl_iter + ) + if self.use_eager_fetch: + _io_buf_iter = _EagerIterator(_io_buf_iter, pool=X.context.threadpool) + + # Now that X read is potentially in progress (in eager mode), go fetch obs data + # + obs_io_batch = cast( + pd.DataFrame, + obs.read(coords=(obs_coords,), column_names=obs_column_names) + .concat() + .to_pandas() + .set_index("soma_joinid") + .reindex(obs_shuffled_coords, copy=False) + .reset_index(), + ) + obs_io_batch = obs_io_batch[self.obs_column_names] + + X_io_batch = _CSR_IO_Buffer.merge(tuple(_io_buf_iter)) + + del obs_indexer, obs_coords, obs_shuffled_coords, _io_buf_iter + gc.collect() + + tm = time.perf_counter() - st_time + logger.debug( + f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec" + ) + yield X_io_batch, obs_io_batch + + def _mini_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[XObsDatum]: + """Break IO batches into shuffled mini-batch-sized chunks. + + Private method. + """ + assert self._obs_joinids is not None + assert self._var_joinids is not None + + io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter) + if self.use_eager_fetch: + io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool) + + mini_batch_size = self.batch_size + result: Tuple[NDArrayNumber, pd.DataFrame] | None = None + for X_io_batch, obs_io_batch in io_batch_iter: + assert X_io_batch.shape[0] == obs_io_batch.shape[0] + assert X_io_batch.shape[1] == len(self._var_joinids) + iob_idx = 0 # current offset into io batch + iob_len = X_io_batch.shape[0] + + while iob_idx < iob_len: + if result is None: + # perform zero copy slice where possible + X_datum = ( + X_io_batch.slice_toscipy( + slice(iob_idx, iob_idx + mini_batch_size) + ) + if self.return_sparse_X + else X_io_batch.slice_tonumpy( + slice(iob_idx, iob_idx + mini_batch_size) + ) + ) + result = ( + X_datum, + obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size], + ) + iob_idx += len(result[1]) + else: + # use any remnant from previous IO batch + to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) + X_datum = ( + sparse.vstack( + [result[0], X_io_batch.slice_toscipy(slice(0, to_take))] + ) + if self.return_sparse_X + else np.concatenate( + [result[0], X_io_batch.slice_tonumpy(slice(0, to_take))] + ) + ) + result = ( + X_datum, + pd.concat([result[1], obs_io_batch.iloc[0:to_take]]), + ) + iob_idx += to_take + + assert result[0].shape[0] == result[1].shape[0] + if result[0].shape[0] == mini_batch_size: + yield result + result = None + + else: + # yield the remnant, if any + if result is not None: + yield result + + +class ExperimentAxisQueryIterDataPipe( + torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] + torch.utils.data.dataset.Dataset[XObsDatum] + ], +): + """A 
:class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + + This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for + legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the + TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information. + + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str = "raw", + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + io_batch_size: int = 2**16, + shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + super().__init__() + self._exp_iter = ExperimentAxisQueryIterable( + query=query, + X_name=X_name, + obs_column_names=obs_column_names, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + shuffle_chunk_size=shuffle_chunk_size, + ) + + def __iter__(self) -> Iterator[XObsDatum]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + if batch_size == 1: + X = X[0] + yield X, obs + + def __len__(self) -> int: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return self._exp_iter.__len__() + + @property + def shape(self) -> Tuple[int, int]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return self._exp_iter.shape + + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + + Lifecycle: + experimental + """ + self._exp_iter.set_epoch(epoch) + + @property + def epoch(self) -> int: + return self._exp_iter.epoch + + +class ExperimentAxisQueryIterableDataset( + torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] +): + """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + + This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as + specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of + ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray` + and a :class:`pandas.DataFrame`. 
+ + For example: + + >>> import torch + >>> import tiledbsoma + >>> import tiledbsoma_ml + >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: + with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: + ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) + dataloader = torch.utils.data.DataLoader(ds) + >>> data = next(iter(dataloader)) + >>> data + (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), + soma_joinid + 0 57905025) + >>> data[0] + array([0., 0., 0., ..., 0., 0., 0.], dtype=float32) + >>> data[1] + soma_joinid + 0 57905025 + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each + iteration. If the ``batch_size`` is 1, then each result will have rank 1, else it will have rank 2. A ``batch_size`` + of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will usually be more + performant to create mini-batches using this class, and set the ``DataLoader`` batch size to ``None``. + + The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the + default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). + + The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A + larger value will increase total memory usage and may reduce average read time per row. + + Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using + :class:`DataLoader` shuffling. The shuffling algorithm works as follows: + + 1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk". + 2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_batch_size``). + 3. The entire I/O buffer is shuffled. + + Put another way, we read randomly selected groups of observations from across all query results, concatenate + those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle + is therefore determined by the ``io_batch_size`` (number of rows read), and the ``shuffle_chunk_size`` + (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, but decrease I/O + performance. + + This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader` + and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition + data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any + data tail will be dropped. + + Lifecycle: + experimental + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str = "raw", + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + io_batch_size: int = 2**16, + shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + Construct a new ``ExperimentAxisQueryIterableDataset``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame``, respectively. 
+
+        Args:
+            query:
+                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over.
+            X_name:
+                The name of the ``X`` layer to read.
+            obs_column_names:
+                The names of the ``obs`` columns to return. At least one column name must be specified.
+                Default is ``('soma_joinid',)``.
+            batch_size:
+                The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of
+                ``1`` will result in a :class:`numpy.ndarray` of rank 1 being returned (a single row); larger values will
+                result in :class:`numpy.ndarray`\ s of rank 2 (multiple rows).
+
+                Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
+                batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader``
+                ``batch_size`` parameter to ``None``.
+            shuffle:
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
+            seed:
+                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using
+                :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
+                processes.
+            io_batch_size:
+                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
+                this class's behavior: 1) the maximum memory utilization, with larger values providing
+                better read performance, but also requiring more memory; 2) the number of rows read prior to shuffling
+                (see the ``shuffle`` parameter for details). The default value of 65,536 (``2**16``) provides high
+                performance, but may need to be reduced on memory-limited hosts (or where a large number of
+                :class:`DataLoader` workers are employed).
+            shuffle_chunk_size:
+                The number of contiguous rows sampled, prior to concatenation and shuffling.
+                Larger numbers correspond to less randomness, but greater read performance.
+                If ``shuffle == False``, this parameter is ignored.
+            return_sparse_X:
+                If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will
+                return ``X`` data as a :class:`numpy.ndarray`.
+            use_eager_fetch:
+                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
+                available for processing via the iterator. This allows network (or filesystem) requests to be made in
+                parallel with client-side processing of the SOMA data, potentially improving overall performance at the
+                cost of doubling memory utilization. Defaults to ``True``.
+
+        Returns:
+            An ``iterable``, which can be iterated over using the Python ``iter()`` function, or passed directly to
+            a :class:`torch.utils.data.DataLoader` instance.
+
+        Raises:
+            ``ValueError`` on various unsupported or malformed parameter values.
+
+        Lifecycle:
+            experimental
+
+        """
+        super().__init__()
+        self._exp_iter = ExperimentAxisQueryIterable(
+            query=query,
+            X_name=X_name,
+            obs_column_names=obs_column_names,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            seed=seed,
+            io_batch_size=io_batch_size,
+            return_sparse_X=return_sparse_X,
+            use_eager_fetch=use_eager_fetch,
+            shuffle_chunk_size=shuffle_chunk_size,
+        )
+
+    def __iter__(self) -> Iterator[XObsDatum]:
+        """Create an iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.
+
+        Returns:
+            ``iterator``
+
+        Lifecycle:
+            experimental
+        """
+        batch_size = self._exp_iter.batch_size
+        for X, obs in self._exp_iter:
+            if batch_size == 1:
+                X = X[0]
+            yield X, obs
+
+    def __len__(self) -> int:
+        """Return the approximate number of batches this iterable will produce.
+
+        See important caveats in the PyTorch
+        [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+        documentation regarding ``len(dataloader)``, which also apply to this class.
+
+        Returns:
+            An ``int``.
+
+        Lifecycle:
+            experimental
+        """
+        return self._exp_iter.__len__()
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
+
+        This is the obs (cell) and var (feature) count of the returned data. If used in multiprocessing mode
+        (i.e., a :class:`torch.utils.data.DataLoader` instantiated with ``num_workers > 0``), the obs (cell) count will
+        reflect the size of the data partition assigned to the active process.
+
+        Returns:
+            A tuple of ``int``s, for obs and var counts, respectively.
+
+        Lifecycle:
+            experimental
+        """
+        return self._exp_iter.shape
+
+    def set_epoch(self, epoch: int) -> None:
+        """
+        Set the epoch for this Data iterator.
+
+        When :attr:`shuffle=True`, this will ensure that all replicas use a different
+        random ordering for each epoch. Failure to call this method before each epoch
+        will result in the same data ordering.
+
+        This call must be made before the per-epoch iterator is created.
+
+        Lifecycle:
+            experimental
+        """
+        self._exp_iter.set_epoch(epoch)
+
+    @property
+    def epoch(self) -> int:
+        return self._exp_iter.epoch
+
+
+def experiment_dataloader(
+    ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
+    **dataloader_kwargs: Any,
+) -> torch.utils.data.DataLoader:
+    """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
+    :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
+    or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.
+
+    Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
+    when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
+    Specifying any of these parameters will result in an error.
+
+    Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
+    :class:`torch.utils.data.DataLoader` parameters.
+
+    Args:
+        ds:
+            A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
+            include chained data pipes.
+        **dataloader_kwargs:
+            Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
+            except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
+            supported when using data loaders in this module.
+
+    Returns:
+        A :class:`torch.utils.data.DataLoader`.
+
+    Raises:
+        ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
+            are passed as keyword arguments.
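+
+    For example (a minimal sketch; assumes ``ds`` was constructed as shown in the
+    :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` example above):
+
+    >>> dl = experiment_dataloader(ds, num_workers=2)
+    >>> X, obs = next(iter(dl))  # batching and shuffling are handled by ``ds``, not the DataLoader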
+
+    Lifecycle:
+        experimental
+    """
+    unsupported_dataloader_args = [
+        "shuffle",
+        "batch_size",
+        "sampler",
+        "batch_sampler",
+    ]
+    unsupported = set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys())
+    if unsupported:
+        raise ValueError(
+            f"The following DataLoader parameters are not supported: {', '.join(sorted(unsupported))}"
+        )
+
+    if dataloader_kwargs.get("num_workers", 0) > 0:
+        _init_multiprocessing()
+
+    if "collate_fn" not in dataloader_kwargs:
+        dataloader_kwargs["collate_fn"] = _collate_noop
+
+    return torch.utils.data.DataLoader(
+        ds,
+        batch_size=None,  # batching is handled by upstream iterator
+        shuffle=False,  # shuffling is handled by upstream iterator
+        **dataloader_kwargs,
+    )
+
+
+def _collate_noop(datum: _T) -> _T:
+    """No-op collation for use with a dataloader instance.
+
+    Private.
+    """
+    return datum
+
+
+def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
+    """For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes.
+
+    A ``total_length`` of L, split into N sections, will return L%N sections of size L//N+1,
+    and the remaining sections of size L//N. This results in the same split as numpy.array_split,
+    for an array of length L and N sections.
+
+    Private.
+
+    Examples
+    --------
+    >>> _splits(10, 3)
+    array([0, 4, 7, 10])
+    >>> _splits(4, 2)
+    array([0, 2, 4])
+    """
+    if sections <= 0:
+        raise ValueError("number of sections must be greater than 0")
+    each_section, extras = divmod(total_length, sections)
+    per_section_sizes = (
+        [0] + extras * [each_section + 1] + (sections - extras) * [each_section]
+    )
+    splits = np.array(per_section_sizes, dtype=np.intp).cumsum()
+    return splits
+
+
+if sys.version_info >= (3, 12):
+    _batched = itertools.batched
+
+else:
+
+    def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]:
+        """Polyfill for the Python 3.12+ ``itertools.batched``, for use on older Python versions."""
+        if n < 1:
+            raise ValueError("n must be at least one")
+        it = iter(iterable)
+        while batch := tuple(islice(it, n)):
+            yield batch
+
+
+def _get_distributed_world_rank() -> Tuple[int, int]:
+    """Return tuple containing equivalent of torch.distributed world size and rank."""
+    world_size, rank = 1, 0
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        world_size = int(os.environ["WORLD_SIZE"])
+        rank = int(os.environ["RANK"])
+    elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There
+        # is a NODE_RANK for the node's rank, but no way to tell the local node's
+        # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a
+        # proxy, which works fine on a single-node host. TODO: could throw/error
+        # if NODE_RANK != 0.
+        world_size = int(os.environ["WORLD_SIZE"])
+        rank = int(os.environ["LOCAL_RANK"])
+    elif torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
+    return world_size, rank
+
+
+def _get_worker_world_rank() -> Tuple[int, int]:
+    """Return the number of DataLoader workers and our worker rank/id."""
+    num_workers, worker = 1, 0
+    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
+        num_workers = int(os.environ["NUM_WORKERS"])
+        worker = int(os.environ["WORKER"])
+    else:
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            num_workers = worker_info.num_workers
+            worker = worker_info.id
+    return num_workers, worker
+
+
+def _init_multiprocessing() -> None:
+    """Ensures use of "spawn" for starting child processes with multiprocessing.
+
+    Forked processes are known to be problematic:
+      https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks
+    Also, CUDA does not support forked child processes:
+      https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
+
+    Private.
+    """
+    orig_start_method = torch.multiprocessing.get_start_method()
+    if orig_start_method != "spawn":
+        if orig_start_method:
+            logger.warning(
+                "switching torch multiprocessing start method from "
+                f'"{orig_start_method}" to "spawn"'
+            )
+        torch.multiprocessing.set_start_method("spawn", force=True)
+
+
+class _CSR_IO_Buffer:
+    """Implement a minimal CSR matrix with specific optimizations for use in this package.
+
+    Operations supported are:
+    * Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for I/O batches,
+      and a final "merge" step which combines the results.
+    * Zero-intermediate-copy conversion of an arbitrary row slice to dense (i.e., mini-batch extraction).
+    * Parallel ops where it makes sense (construction, merge, etc.).
+    * Minimized memory use for index arrays.
+
+    Overall, this is significantly faster and uses less memory than the equivalent scipy.sparse operations.
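+
+    An illustrative usage sketch (``i_a``, ``j_a``, ``d_a``, etc. are hypothetical COO
+    coordinate/data arrays; see the factory methods below):
+
+        # Build per-I/O-batch buffers from COO triples, merge them, then
+        # densify a row slice for mini-batch extraction.
+        a = _CSR_IO_Buffer.from_ijd(i_a, j_a, d_a, shape)
+        b = _CSR_IO_Buffer.from_ijd(i_b, j_b, d_b, shape)
+        merged = _CSR_IO_Buffer.merge([a, b])
+        mini_batch = merged.slice_tonumpy(slice(0, batch_size))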
+ """ + + __slots__ = ("indptr", "indices", "data", "shape") + + def __init__( + self, + indptr: NDArrayNumber, + indices: NDArrayNumber, + data: NDArrayNumber, + shape: Tuple[int, int], + ) -> None: + """Construct from PJV format.""" + assert len(data) == len(indices) + assert len(data) <= np.iinfo(indptr.dtype).max + assert shape[1] <= np.iinfo(indices.dtype).max + assert indptr[-1] == len(data) and indptr[0] == 0 + + self.shape = shape + self.indptr = indptr + self.indices = indices + self.data = data + + @staticmethod + def from_ijd( + i: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] + ) -> _CSR_IO_Buffer: + """Factory from COO""" + nnz = len(d) + indptr = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz)) + indices = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1])) + data = np.empty((nnz,), dtype=d.dtype) + _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) + return _CSR_IO_Buffer(indptr, indices, data, shape) + + @staticmethod + def from_pjd( + p: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] + ) -> _CSR_IO_Buffer: + """Factory from CSR""" + return _CSR_IO_Buffer(p, j, d, shape) + + @property + def nnz(self) -> int: + return len(self.indices) + + @property + def nbytes(self) -> int: + return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes) + + @property + def dtype(self) -> npt.DTypeLike: + return self.data.dtype + + def slice_tonumpy(self, row_index: slice) -> NDArrayNumber: + """Extract slice as a dense ndarray. Does not assume any particular ordering of minor axis.""" + assert isinstance(row_index, slice) + assert row_index.step in (1, None) + row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + n_rows = max(row_idx_end - row_idx_start, 0) + out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype) + if n_rows >= 0: + _csr_to_dense_inner( + row_idx_start, n_rows, self.indptr, self.indices, self.data, out + ) + return out + + def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix: + """Extract slice as a sparse.csr_matrix. 
+        Does not assume any particular ordering of the minor axis, but will
+        return a canonically ordered scipy sparse object."""
+        assert isinstance(row_index, slice)
+        assert row_index.step in (1, None)
+        row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1)
+        n_rows = max(row_idx_end - row_idx_start, 0)
+        if n_rows == 0:
+            return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype)
+
+        indptr = self.indptr[row_idx_start : row_idx_end + 1].copy()
+        indices = self.indices[indptr[0] : indptr[-1]].copy()
+        data = self.data[indptr[0] : indptr[-1]].copy()
+        indptr -= indptr[0]
+        return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1]))
+
+    @staticmethod
+    def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer:
+        assert len(mtxs) > 0
+        nnz = sum(m.nnz for m in mtxs)
+        shape = mtxs[0].shape
+        for m in mtxs[1:]:
+            assert m.shape == shape
+            assert m.indices.dtype == mtxs[0].indices.dtype
+
+        indptr = np.sum(
+            [m.indptr for m in mtxs], axis=0, dtype=smallest_uint_dtype(nnz)
+        )
+        indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype)
+        data = np.empty((nnz,), dtype=mtxs[0].data.dtype)
+
+        _csr_merge_inner(
+            tuple((m.indptr.astype(indptr.dtype), m.indices, m.data) for m in mtxs),
+            indptr,
+            indices,
+            data,
+        )
+        return _CSR_IO_Buffer.from_pjd(indptr, indices, data, shape)
+
+    def sort_indices(self) -> Self:
+        """Sort minor-axis indices, IN PLACE."""
+        _csr_sort_indices(self.indptr, self.indices, self.data)
+        return self
+
+
+def smallest_uint_dtype(max_val: int) -> npt.DTypeLike:
+    """Return the smallest unsigned-integer dtype that can represent ``max_val``."""
+    for dt in (np.uint16, np.uint32):
+        if max_val <= np.iinfo(dt).max:
+            return dt
+    return np.uint64
+
+
+@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
+def _csr_merge_inner(
+    As: Tuple[Tuple[NDArrayNumber, NDArrayNumber, NDArrayNumber], ...],  # P,J,D
+    Bp: NDArrayNumber,
+    Bj: NDArrayNumber,
+    Bd: NDArrayNumber,
+) -> None:
+    n_rows = len(Bp) - 1
+    offsets = Bp.copy()
+    for Ap, Aj, Ad in As:
+        n_elmts = Ap[1:] - Ap[:-1]
+        for n in numba.prange(n_rows):
+            Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]]
+            Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]]
+        offsets[:-1] += n_elmts
+
+
+@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
+def _csr_to_dense_inner(
+    row_idx_start: int,
+    n_rows: int,
+    indptr: NDArrayNumber,
+    indices: NDArrayNumber,
+    data: NDArrayNumber,
+    out: NDArrayNumber,
+) -> None:
+    for i in numba.prange(row_idx_start, row_idx_start + n_rows):
+        for j in range(indptr[i], indptr[i + 1]):
+            out[i - row_idx_start, indices[j]] = data[j]
+
+
+@numba.njit(nogil=True, parallel=True, inline="always")  # type:ignore[misc]
+def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber:
+    """Private: parallel row count."""
+    nnz = len(Ai)
+
+    partition_size = 32 * 1024**2  # 32M elements per partition
+    n_partitions = math.ceil(nnz / partition_size)
+    if n_partitions > 1:
+        counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype)
+        for p in numba.prange(n_partitions):
+            for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)):
+                row = Ai[n]
+                counts[p, row] += 1
+
+        Bp[:-1] = counts.sum(axis=0)
+    else:
+        for n in range(nnz):
+            row = Ai[n]
+            Bp[row] += 1
+
+    return Bp
+
+
+@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
+def _coo_to_csr_inner(
+    n_rows: int,
+    Ai: NDArrayNumber,
+    Aj: NDArrayNumber,
+    Ad: NDArrayNumber,
+    Bp: NDArrayNumber,
+    Bj: NDArrayNumber,
+    Bd: NDArrayNumber,
+) -> None:
+    nnz = len(Ai)
+
+    _count_rows(n_rows, Ai, Bp)
+
+    # Cumulative sum to get the row index pointers (NOTE: starting with zero)
+    cumsum = 0
+    for n in range(n_rows):
+        tmp = Bp[n]
+        Bp[n] = cumsum
+        cumsum += tmp
+    Bp[n_rows] = nnz
+
+    # Reorganize all of the data. Side-effect: pointers shifted (reversed in the
+    # subsequent section).
+    #
+    # Method is concurrent (partitioned by rows) if the number of rows is greater
+    # than 2**partition_bits. This partitioning scheme leverages the fact
+    # that reads are much cheaper than writes.
+    #
+    # The code is equivalent to:
+    #   for n in range(nnz):
+    #       row = Ai[n]
+    #       dst_row = Bp[row]
+    #       Bj[dst_row] = Aj[n]
+    #       Bd[dst_row] = Ad[n]
+    #       Bp[row] += 1
+
+    partition_bits = 13
+    n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits
+    for p in numba.prange(n_partitions):
+        for n in range(nnz):
+            row = Ai[n]
+            if (row >> partition_bits) != p:
+                continue
+            dst_row = Bp[row]
+            Bj[dst_row] = Aj[n]
+            Bd[dst_row] = Ad[n]
+            Bp[row] += 1
+
+    # Shift the pointers by one slot (i.e., start at zero)
+    prev_ptr = 0
+    for n in range(n_rows + 1):
+        tmp = Bp[n]
+        Bp[n] = prev_ptr
+        prev_ptr = tmp
+
+
+@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
+def _csr_sort_indices(Bp: NDArrayNumber, Bj: NDArrayNumber, Bd: NDArrayNumber) -> None:
+    """In-place sort of minor axis indices."""
+    n_rows = len(Bp) - 1
+    for r in numba.prange(n_rows):
+        row_start = Bp[r]
+        row_end = Bp[r + 1]
+        order = np.argsort(Bj[row_start:row_end])
+        Bj[row_start:row_end] = Bj[row_start:row_end][order]
+        Bd[row_start:row_end] = Bd[row_start:row_end][order]
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
new file mode 100644
index 0000000..ed6440e
--- /dev/null
+++ b/tests/test_pytorch.py
@@ -0,0 +1,1033 @@
+# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
+# Copyright (c) 2021-2024 TileDB, Inc.
+#
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import pathlib
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from unittest.mock import patch
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import pyarrow as pa
+import pytest
+import tiledbsoma as soma
+from scipy import sparse
+from scipy.sparse import coo_matrix, spmatrix
+from tiledbsoma import Experiment, _factory
+from tiledbsoma._collection import CollectionBase
+
+# Conditionally import torch, as it will not be available in all test environments.
+# This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent
+# tests.
+try:
+    from torch.utils.data._utils.worker import WorkerInfo
+
+    from tiledbsoma_ml.pytorch import (
+        ExperimentAxisQueryIterable,
+        ExperimentAxisQueryIterableDataset,
+        ExperimentAxisQueryIterDataPipe,
+        experiment_dataloader,
+    )
+except ImportError:
+    # This should only occur when not running `ml`-marked tests
+    pass
+
+
+def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix:
+    occupied_shape = (
+        obs_range.stop - obs_range.start,
+        var_range.stop - var_range.start,
+    )
+    checkerboard_of_ones = coo_matrix(np.indices(occupied_shape).sum(axis=0) % 2)
+    checkerboard_of_ones.row += obs_range.start
+    checkerboard_of_ones.col += var_range.start
+    return checkerboard_of_ones
+
+
+def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix:
+    """A sparse matrix where the values of each col are the obs_range values. 
Useful for checking the + X values are being returned in the correct order.""" + data = np.vstack([list(obs_range)] * len(var_range)).flatten() + rows = np.vstack([list(obs_range)] * len(var_range)).flatten() + cols = np.column_stack([list(var_range)] * len(obs_range)).flatten() + return coo_matrix((data, (rows, cols))) + + +@pytest.fixture +def X_layer_names() -> List[str]: + return ["raw"] + + +@pytest.fixture +def obsp_layer_names() -> Optional[List[str]]: + return None + + +@pytest.fixture +def varp_layer_names() -> Optional[List[str]]: + return None + + +def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: + df = coll.add_new_dataframe( + key, + schema=pa.schema( + [ + ("soma_joinid", pa.int64()), + ("label", pa.large_string()), + ("label2", pa.large_string()), + ] + ), + index_column_names=["soma_joinid"], + ) + df.write( + pa.Table.from_pydict( + { + "soma_joinid": list(value_range), + "label": [str(i) for i in value_range], + "label2": ["c" for i in value_range], + } + ) + ) + + +def add_sparse_array( + coll: CollectionBase, + key: str, + obs_range: range, + var_range: range, + value_gen: Callable[[range, range], spmatrix], +) -> None: + a = coll.add_new_sparse_ndarray( + key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) + ) + tensor = pa.SparseCOOTensor.from_scipy(value_gen(obs_range, var_range)) + a.write(tensor) + + +@pytest.fixture(scope="function") +def soma_experiment( + tmp_path: pathlib.Path, + obs_range: Union[int, range], + var_range: Union[int, range], + X_value_gen: Callable[[range, range], sparse.spmatrix], + obsp_layer_names: Sequence[str], + varp_layer_names: Sequence[str], +) -> soma.Experiment: + with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp: + if isinstance(obs_range, int): + obs_range = range(obs_range) + if isinstance(var_range, int): + var_range = range(var_range) + + add_dataframe(exp, "obs", obs_range) + ms = exp.add_new_collection("ms") + rna = ms.add_new_collection("RNA", soma.Measurement) + add_dataframe(rna, "var", var_range) + rna_x = rna.add_new_collection("X", soma.Collection) + add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) + + if obsp_layer_names: + obsp = rna.add_new_collection("obsp") + for obsp_layer_name in obsp_layer_names: + add_sparse_array( + obsp, obsp_layer_name, obs_range, var_range, X_value_gen + ) + + if varp_layer_names: + varp = rna.add_new_collection("varp") + for varp_layer_name in varp_layer_names: + add_sparse_array( + varp, varp_layer_name, obs_range, var_range, X_value_gen + ) + return _factory.open((tmp_path / "exp").as_posix()) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_non_batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + # batch_size should default to 1 + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + assert type(exp_data_pipe.shape) is tuple + assert len(exp_data_pipe.shape) == 2 + assert exp_data_pipe.shape == 
(6, 3) + + row_iter = iter(exp_data_pipe) + + row = next(row_iter) + + if return_sparse_X: + # sparse slices remain 2D, always + assert isinstance(row[0], sparse.csr_matrix) + assert row[0].shape == (1, 3) + assert row[0].todense().tolist() == [[0, 1, 0]] + + else: + assert isinstance(row[0], np.ndarray) + assert row[0].shape == (3,) + assert row[0].tolist() == [0, 1, 0] + + assert isinstance(row[1], pd.DataFrame) + assert row[1].shape == (1, 1) + assert row[1].keys() == ["label"] + assert row[1]["label"].tolist() == ["0"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_uneven_soma_and_result_batches( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + """This is checking that batches are correctly created when they require fetching multiple chunks.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + batch_size=3, + io_batch_size=2, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + row_iter = iter(exp_data_pipe) + + X_batch, obs_batch = next(row_iter) + + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + assert X_batch.todense()[0].tolist() == [[0, 1, 0]] + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch[0].tolist() == [0, 1, 0] + + assert isinstance(obs_batch, pd.DataFrame) + assert X_batch.shape[0] == obs_batch.shape[0] + assert X_batch.shape == (3, 3) + assert obs_batch.shape == (3, 1) + assert ["label"] == obs_batch.keys() + assert obs_batch["label"].tolist() == ["0", "1", "2"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__all_batches_full_size( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["3", "4", "5"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (range(100_000_000, 100_000_003), 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def 
test_unique_soma_joinids( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] + ) + assert len(np.unique(soma_joinids)) == len(soma_joinids) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__partial_final_batch_size( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + next(batch_iter) + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__exactly_one_batch( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__empty_query_result( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(coords=([],)) + ) as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def 
test_sparse_output__non_batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + return_sparse_X=True, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().A.squeeze().tolist() == [0, 1, 0] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_sparse_output__batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + return_sparse_X=True, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 1, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__partial_soma_batches_are_concatenated( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches + io_batch_size=4, + use_eager_fetch=use_eager_fetch, + ) + + full_result = list(exp_data_pipe) + + assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_multiprocessing__returns_full_result( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + """Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a + PyTorch DataLoader with multiple workers configured.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + io_batch_size=3, # two chunks, one per worker + ) + # Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing + dl = experiment_dataloader(dp, num_workers=2) + + full_result = list(iter(dl)) + + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + assert sorted(soma_joinids) 
== list(range(6)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank", + [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_distributed__returns_data_partition_for_rank( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, + using mocks to avoid having to do real PyTorch distributed setup.""" + + with ( + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) + full_result = list(iter(dp)) + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ].tolist() + assert sorted(soma_joinids) == expected_joinids + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(12, 3, pytorch_x_value_gen), (13, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank,num_workers,worker_id", + [ + (3, 1, 2, 0), + (3, 1, 2, 1), + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_distributed_and_multiprocessing__returns_data_partition_for_rank( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, + num_workers: int, + worker_id: int, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and + DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch + setup or real DataLoader multiprocessing.""" + + with ( + patch("torch.utils.data.get_worker_info") as mock_get_worker_info, + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_get_worker_info.return_value = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=1234 + ) + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) + + full_result = list(iter(dp)) + + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ] + expected_joinids = np.array_split(expected_joinids, 
num_workers)[worker_id] + assert sorted(soma_joinids) == expected_joinids.tolist() + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__non_batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + assert all(d[0].shape == (3,) for d in data) + assert all(d[1].shape == (1, 1) for d in data) + + row = data[0] + assert row[0].tolist() == [0, 1, 0] + assert row[1]["label"].tolist() == ["0"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + + batch = data[0] + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].to_numpy().tolist() == [[0], [1], [2]] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__batched_length( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + assert len(dl) == len(list(dl)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,batch_size", + [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__collate_fn( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + batch_size: int, +) -> None: + def collate_fn( + batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] + ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]: + assert isinstance(data, tuple) + assert len(data) == 2 + assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame) + if batch_size > 1: + assert data[0].shape[0] == data[1].shape[0] + assert data[0].shape[0] <= batch_size + else: + assert data[0].ndim 
== 1 + assert data[1].shape[1] <= batch_size + return data + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=batch_size, + shuffle=False, + ) + dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) + assert len(list(dl)) > 0 + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test__X_tensor_dtype_matches_X_matrix( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + data = next(iter(dp)) + + assert data[0].dtype == np.float32 + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)] +) +def test__pytorch_splitting( + soma_experiment: Experiment, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryIterDataPipe( + query, + X_name="raw", + obs_column_names=["label"], + ) + # function not available for IterableDataset, yet.... + dp_train, dp_test = dp.random_split( + weights={"train": 0.7, "test": 0.3}, seed=1234 + ) + dl = experiment_dataloader(dp_train) + + all_rows = list(iter(dl)) + assert len(all_rows) == 7 + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test__shuffle( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + shuffle=True, + ) + + all_rows = list(iter(dp)) + assert all(r[0].shape == (1,) for r in all_rows) + soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows] + X_values = [row[0][0].item() for row in all_rows] + + # same elements + assert set(soma_joinids) == set(range(16)) + # not ordered! 
(...with a `1/16!` probability of being ordered) + assert soma_joinids != list(range(16)) + # randomizes X in same order as obs + # note: X values were explicitly set to match obs_joinids to allow for this simple assertion + assert X_values == soma_joinids + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +def test_experiment_axis_query_iterable_error_checks( + soma_experiment: Experiment, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryIterable( + query, + X_name="raw", + shuffle=True, + ) + with pytest.raises(NotImplementedError): + dp[0] + + with pytest.raises(ValueError): + dp = ExperimentAxisQueryIterable( + query, + obs_column_names=(), + X_name="raw", + shuffle=True, + ) + + +def test_experiment_dataloader__unsupported_params__fails() -> None: + with patch( + "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe" + ) as dummy_exp_data_pipe: + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, shuffle=True) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_size=3) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, sampler=[]) + + +def test_batched() -> None: + from tiledbsoma_ml.pytorch import _batched + + assert list(_batched(range(6), 1)) == list((i,) for i in range(6)) + assert list(_batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] + assert list(_batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] + assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] + assert list(_batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] + assert list(_batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] + assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] + + # bogus batch value + with pytest.raises(ValueError): + list(_batched([0, 1], 0)) + with pytest.raises(ValueError): + list(_batched([2, 3], -1)) + + +def test_splits() -> None: + from tiledbsoma_ml.pytorch import _splits + + assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 3).tolist() == [0, 4, 7, 10] + assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] + assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert _splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] + + # bad number of sections + with pytest.raises(ValueError): + _splits(10, 0) + with pytest.raises(ValueError): + _splits(10, -1) + + +@pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer + + sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.05) + sp_csr = sp_coo.tocsr() + + _ncsr = _CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ) + assert _ncsr.nnz == sp_coo.nnz == sp_csr.nnz + assert _ncsr.dtype == sp_coo.dtype == sp_csr.dtype + assert _ncsr.nbytes == ( + _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes + ) + + # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. 
+ assert ( + sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) + != sp_csr + ).nnz == 0 + + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_coo.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer + + sp_csr = sparse.random(shape[0], shape[1], dtype=dtype, format="csr", density=0.05) + + _ncsr = _CSR_IO_Buffer.from_pjd( + sp_csr.indptr.copy(), + sp_csr.indices.copy(), + sp_csr.data.copy(), + shape=sp_csr.shape, + ) + + # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. + assert ( + sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) + != sp_csr + ).nnz == 0 + + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +@pytest.mark.parametrize("n_splits", [2, 3, 4]) +def test_csr__merge( + shape: Tuple[int, int], dtype: npt.DTypeLike, n_splits: int +) -> None: + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer + + sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.5) + splits = [ + t + for t in zip( + np.array_split(sp_coo.row, n_splits), + np.array_split(sp_coo.col, n_splits), + np.array_split(sp_coo.data, n_splits), + ) + ] + _ncsr = _CSR_IO_Buffer.merge( + [_CSR_IO_Buffer.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits] + ) + + assert ( + sp_coo.tocsr() + != sparse.csr_matrix( + (_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape + ) + ).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +def test_csr__sort_indices(shape: Tuple[int, int]) -> None: + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer + + sp_coo = sparse.random( + shape[0], shape[1], dtype=np.float32, format="coo", 
density=0.05 + ) + sp_csr = sp_coo.tocsr() + + _ncsr = _CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ).sort_indices() + + assert np.array_equal(sp_csr.indptr, _ncsr.indptr) + assert np.array_equal(sp_csr.indices, _ncsr.indices) + assert np.array_equal(sp_csr.data, _ncsr.data)
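+
+
+def test_smallest_uint_dtype() -> None:
+    # Minimal sketch exercising the `smallest_uint_dtype` helper defined in
+    # tiledbsoma_ml.pytorch: it should select the narrowest unsigned dtype
+    # capable of representing `max_val`.
+    from tiledbsoma_ml.pytorch import smallest_uint_dtype
+
+    assert smallest_uint_dtype(np.iinfo(np.uint16).max) == np.uint16
+    assert smallest_uint_dtype(np.iinfo(np.uint16).max + 1) == np.uint32
+    assert smallest_uint_dtype(np.iinfo(np.uint32).max) == np.uint32
+    assert smallest_uint_dtype(np.iinfo(np.uint32).max + 1) == np.uint64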