diff --git a/api/python/cellxgene_census/src/cellxgene_census/_experiment.py b/api/python/cellxgene_census/src/cellxgene_census/_experiment.py index 8f5b3f5fd..f8edb0ce2 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_experiment.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_experiment.py @@ -31,7 +31,7 @@ def _get_experiment(census: soma.Collection, organism: str) -> soma.Experiment: ValueError: if unable to find the specified organism. Lifecycle: - Experimental. + maturing Examples: diff --git a/api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py b/api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py index 271a9498c..9d5c6f068 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_get_anndata.py @@ -63,7 +63,7 @@ def get_anndata( An :class:`anndata.AnnData` object containing the census slice. Lifecycle: - Experimental. + maturing Examples: >>> get_anndata(census, "Mus musculus", obs_value_filter="tissue_general in ['brain', 'lung']") diff --git a/api/python/cellxgene_census/src/cellxgene_census/_open.py b/api/python/cellxgene_census/src/cellxgene_census/_open.py index 2b126c3ea..11d977c4a 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_open.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_open.py @@ -89,7 +89,7 @@ def open_soma( or a version are specified. Lifecycle: - Experimental. + maturing Examples: Open the default Census version, using a context manager which will automatically @@ -169,7 +169,7 @@ def get_source_h5ad_uri(dataset_id: str, *, census_version: str = "latest") -> C KeyError: if either `dataset_id` or `census_version` do not exist. Lifecycle: - Experimental. + maturing Examples: >>> cellxgene_census.get_source_h5ad_uri("cb5efdb0-f91c-4cbd-9ad4-9d4fa41c572d") @@ -206,7 +206,7 @@ def download_source_h5ad(dataset_id: str, to_path: str, *, census_version: str = an existing file), or is not a file. Lifecycle: - Experimental. + maturing See Also: :func:`get_source_h5ad_uri`: Look up the location of the source H5AD. diff --git a/api/python/cellxgene_census/src/cellxgene_census/_presence_matrix.py b/api/python/cellxgene_census/src/cellxgene_census/_presence_matrix.py index 2ab5c0d5c..bd6f57e9f 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_presence_matrix.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_presence_matrix.py @@ -38,7 +38,7 @@ def get_presence_matrix( ValueError: if the organism cannot be found. Lifecycle: - Experimental. + maturing Examples: >>> get_presence_matrix(census, "Homo sapiens", "RNA") diff --git a/api/python/cellxgene_census/src/cellxgene_census/_release_directory.py b/api/python/cellxgene_census/src/cellxgene_census/_release_directory.py index 3d61ac18c..3f8f43b32 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/_release_directory.py +++ b/api/python/cellxgene_census/src/cellxgene_census/_release_directory.py @@ -54,7 +54,7 @@ def get_census_version_description(census_version: str) -> CensusVersionDescript KeyError: if unknown census_version value. Lifecycle: - Experimental. + maturing See Also: :func:`get_census_version_directory`: returns the entire directory as a dict. @@ -83,7 +83,7 @@ def get_census_version_directory() -> Dict[CensusVersionName, CensusVersionDescr A dictionary that contains release names and their corresponding release description. Lifecycle: - Experimental. + maturing See Also: :func:`get_census_version_description`: get description by census_version. diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py index e69de29bb..ab62b4e63 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/__init__.py @@ -0,0 +1,7 @@ +""" +An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census. +""" + +from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader + +__all__ = ["Stats", "ExperimentDataPipe", "experiment_dataloader"] diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 8c89536ae..153584bd9 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -34,6 +34,13 @@ @attrs class Stats: + """ + Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. + + Lifecycle: + experimental + """ + n_obs: int = 0 """The total number of obs rows retrieved""" @@ -71,21 +78,21 @@ def _open_experiment( class _ObsAndXIterator(Iterator[ObsDatum]): """ - Iterates through a set of obs and related X rows, specified as `soma_joinid`s. Encapsulates the batch-based data - fetching from TileDB-SOMA objects, providing row-based iteration. + Iterates through a set of obs and related X rows, specified as ``soma_joinid``s. Encapsulates the batch-based data + fetching from SOMA objects, providing row-based iteration. """ obs_tables_iter: somacore.ReadIter[pa.Table] - """Iterates the TileDB-SOMA batches (tables) of obs data""" + """Iterates the SOMA batches (tables) of obs data""" obs_batch_: pd.DataFrame = pd.DataFrame() - """The current TileDB-SOMA batch of obs data""" + """The current SOMA batch of obs data""" X_batch: scipy.matrix = None - """All X data for the soma_joinids of the current obs - batch""" + """All X data for the ``soma_joinid``s of the current obs - batch""" i: int = -1 - """Index into current obs TileDB-SOMA batch""" + """Index into current obs ``SOMA`` batch""" def __init__( self, @@ -177,7 +184,7 @@ def obs_batch(self) -> pd.DataFrame: Returns the current SOMA batch of obs rows. If the current SOMA batch has been fully iterated, loads the next SOMA batch of both obs and X data and returns the new obs batch (only). - Raises StopIteration if there are no more SOMA batches to retrieve. + Raises ``StopIteration`` if there are no more SOMA batches to retrieve. """ if 0 <= self.i < len(self.obs_batch_): return self.obs_batch_ @@ -203,30 +210,33 @@ def obs_batch(self) -> pd.DataFrame: class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsDatum]]): # type: ignore """ - An iterable-style PyTorch data pipe that reads obs and X data from a SOMA Experiment, and returns an iterator of - tuples of torch tensors: + An iterable-style PyTorch ``DataPipe`` that reads obs and X data from a SOMA Experiment, and returns an iterator of + tuples of PyTorch ``Tensor``s: - (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data - tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded + >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data + tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded Supports batching via `batch_size` param: - DataLoader(..., batch_size=3, ...): - (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch - [0., 0., 0., 0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), - tensor([[2415, 0, 0], # obs batch - [2416, 0, 4], - [2417, 0, 3]], dtype=torch.int32)) + >>> DataLoader(..., batch_size=3, ...): + (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch + [0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), + tensor([[2415, 0, 0], # obs batch + [2416, 0, 4], + [2417, 0, 3]], dtype=torch.int32)) Obs attribute values are encoded as categoricals. Values can be decoded by obtaining the encoder for a given attribute: - exp_data_pipe.obs_encoders()[""].inverse_transform(encoded_values) + >>> exp_data_pipe.obs_encoders()[""].inverse_transform(encoded_values) + + Lifecycle: + experimental """ _query: Optional[soma.ExperimentAxisQuery] - """In multi-processing mode (i.e. num_workers > 0), this ExperimentAxisQuery object will *not* be pickled; + """In multi-processing mode (i.e. num_workers > 0), this ``ExperimentAxisQuery`` object will *not* be pickled; each worker will instantiate a new query""" _obs_joinids_partitioned: Optional[pa.Array] @@ -252,6 +262,16 @@ def __init__( num_workers: int = 0, soma_buffer_bytes: Optional[int] = None, ) -> None: + """ + Construct a new ``ExperimentDataPipe``. + + Returns: + ``ExperimentDataPipe``. + + Lifecycle: + experimental + """ + self.exp_uri = experiment.uri self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region") self.measurement_name = measurement_name @@ -356,6 +376,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self._query = None def obs_encoders(self) -> Encoders: + """ + Returns the encoders that were used to encode obs column values and that are needed to decode them. + + Returns: + ``Dict[str, LabelEncoder]`` mapping column names to ``LabelEncoder``s. + + Lifecycle: + experimental + """ if self._encoders is not None: return self._encoders @@ -373,10 +402,33 @@ def obs_encoders(self) -> Encoders: return self._encoders def stats(self) -> Stats: + """ + Get data loading stats for this ``ExperimentDataPipe``. + + Args: None. + + Returns: + ``Stats`` object. + + Lifecycle: + experimental + """ return self._stats @property def shape(self) -> Tuple[int, int]: + """ + Get the shape of the data that will be returned by this ExperimentDataPipe. This is the number of + obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode + (i.e. DataLoader instantiated with num_workers > 0), the obs (cell) count will reflect the size of the + partition of the data assigned to the active process. + + Returns: + 2-tuple of ``int``s, for obs and var counts, respectively. + + Lifecycle: + experimental + """ self._init() assert self._query is not None @@ -384,7 +436,7 @@ def shape(self) -> Tuple[int, int]: # Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling -def collate_noop(x: Any) -> Any: +def _collate_noop(x: Any) -> Any: return x @@ -395,9 +447,15 @@ def experiment_dataloader( **dataloader_kwargs: Any, ) -> DataLoader: """ - Factory method for PyTorch DataLoader. Provides a safer, more convenient interface for instantiating a DataLoader - that works with the ExperimentDataPipe, since not all of DataLoader's params can be used (batch_size, sampler, - batch_sampler, collate_fn). + Factory method for PyTorch ``DataLoader``. Provides a safer, more convenient interface for instantiating a + ``DataLoader`` that works with the ``ExperimentDataPipe``, since not all of ``DataLoader``'s params can be + used (``batch_size``, ``sampler``, ``batch_sampler``, ``collate_fn``). + + Returns: + PyTorch ``DataLoader``. + + Lifecycle: + experimental """ unsupported_dataloader_args = ["batch_size", "sampler", "batch_sampler", "collate_fn"] @@ -409,7 +467,7 @@ def experiment_dataloader( batch_size=None, # batching is handled by our ExperimentDataPipe num_workers=num_workers, # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches - collate_fn=collate_noop, + collate_fn=_collate_noop, **dataloader_kwargs, ) diff --git a/api/python/notebooks/ml_demo/pytorch_lr_classifier.py b/api/python/notebooks/ml_demo/pytorch_lr_classifier.py index 711108117..a8ef9c68e 100644 --- a/api/python/notebooks/ml_demo/pytorch_lr_classifier.py +++ b/api/python/notebooks/ml_demo/pytorch_lr_classifier.py @@ -3,7 +3,7 @@ import torch import cellxgene_census -from cellxgene_census.experimental.ml.pytorch import experiment_dataloader, ExperimentDataPipe +import cellxgene_census.experimental.ml as census_ml # TODO: Convert this to a notebook @@ -82,7 +82,7 @@ def run(): obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True" var_value_filter = "" - exp_dp = ExperimentDataPipe( + exp_dp = census_ml.ExperimentDataPipe( census["census_data"]["homo_sapiens"], measurement_name="RNA", X_name="raw", @@ -95,7 +95,7 @@ def run(): dp = exp_dp.shuffle(buffer_size=len(exp_dp)) dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=RANDOM_SEED) - dl_train = experiment_dataloader( + dl_train = census_ml.experiment_dataloader( dp_train, # >= 1 uses multiprocessing to load data num_workers=0,