chore: update lifecycle tags #509

Merged: 7 commits, May 30, 2023
@@ -31,7 +31,7 @@ def _get_experiment(census: soma.Collection, organism: str) -> soma.Experiment:
ValueError: if unable to find the specified organism.

Lifecycle:
- Experimental.
+ maturing

Examples:

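For orientation, a minimal sketch of what this helper resolves: the per-organism SOMA Experiment inside the census collection. The collection layout and organism key follow the demo script later in this diff; treat the sketch as illustrative, not authoritative.

```python
import cellxgene_census

# Open the default Census release and grab the experiment for one organism.
# "census_data" / "homo_sapiens" are the keys used in the demo script in this PR.
with cellxgene_census.open_soma() as census:
    experiment = census["census_data"]["homo_sapiens"]
    print(experiment.uri)
```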
@@ -63,7 +63,7 @@ def get_anndata(
An :class:`anndata.AnnData` object containing the census slice.

Lifecycle:
- Experimental.
+ maturing

Examples:
>>> get_anndata(census, "Mus musculus", obs_value_filter="tissue_general in ['brain', 'lung']")
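A slightly fuller version of the doctest above, as a hedged usage sketch; the value filter is the one shown in the doctest, everything else is left at defaults.

```python
import cellxgene_census

with cellxgene_census.open_soma() as census:
    # Slice the Census into an in-memory AnnData object.
    adata = cellxgene_census.get_anndata(
        census,
        "Mus musculus",
        obs_value_filter="tissue_general in ['brain', 'lung']",
    )
    print(adata)
```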
6 changes: 3 additions & 3 deletions api/python/cellxgene_census/src/cellxgene_census/_open.py
@@ -89,7 +89,7 @@ def open_soma(
or a version are specified.

Lifecycle:
- Experimental.
+ maturing

Examples:
Open the default Census version, using a context manager which will automatically
@@ -169,7 +169,7 @@ def get_source_h5ad_uri(dataset_id: str, *, census_version: str = "latest") -> C
KeyError: if either `dataset_id` or `census_version` do not exist.

Lifecycle:
- Experimental.
+ maturing

Examples:
>>> cellxgene_census.get_source_h5ad_uri("cb5efdb0-f91c-4cbd-9ad4-9d4fa41c572d")
@@ -206,7 +206,7 @@ def download_source_h5ad(dataset_id: str, to_path: str, *, census_version: str =
an existing file), or is not a file.

Lifecycle:
- Experimental.
+ maturing

See Also:
:func:`get_source_h5ad_uri`: Look up the location of the source H5AD.
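Rounding out the two H5AD helpers touched above, a brief sketch; the dataset id is copied from the doctest, and the local path is arbitrary.

```python
import cellxgene_census

dataset_id = "cb5efdb0-f91c-4cbd-9ad4-9d4fa41c572d"  # id from the doctest above

# Look up the cloud locator of the dataset's source H5AD ...
locator = cellxgene_census.get_source_h5ad_uri(dataset_id)
print(locator)

# ... or fetch it to a local file in one call.
cellxgene_census.download_source_h5ad(dataset_id, to_path="dataset.h5ad", census_version="latest")
```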
@@ -38,7 +38,7 @@ def get_presence_matrix(
ValueError: if the organism cannot be found.

Lifecycle:
- Experimental.
+ maturing

Examples:
>>> get_presence_matrix(census, "Homo sapiens", "RNA")
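A usage sketch for the accessor above, following the doctest arguments:

```python
import cellxgene_census

with cellxgene_census.open_soma() as census:
    # Sparse matrix of dataset-level feature presence.
    presence = cellxgene_census.get_presence_matrix(census, "Homo sapiens", "RNA")
    print(presence.shape)
```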
@@ -54,7 +54,7 @@ def get_census_version_description(census_version: str) -> CensusVersionDescript
KeyError: if unknown census_version value.

Lifecycle:
- Experimental.
+ maturing

See Also:
:func:`get_census_version_directory`: returns the entire directory as a dict.
@@ -83,7 +83,7 @@ def get_census_version_directory() -> Dict[CensusVersionName, CensusVersionDescr
A dictionary that contains release names and their corresponding release description.

Lifecycle:
- Experimental.
+ maturing

See Also:
:func:`get_census_version_description`: get description by census_version.
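A sketch of the two release-directory helpers documented above:

```python
import cellxgene_census

# All known release names/aliases, mapped to their descriptions.
directory = cellxgene_census.get_census_version_directory()
print(list(directory.keys()))

# Description of a single release, here via the "latest" alias.
print(cellxgene_census.get_census_version_description("latest"))
```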
@@ -0,0 +1,7 @@
"""
An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census.
"""

from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader

__all__ = ["Stats", "ExperimentDataPipe", "experiment_dataloader"]
@@ -34,6 +34,13 @@

@attrs
class Stats:
"""
Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API.

Lifecycle:
experimental
"""

n_obs: int = 0
"""The total number of obs rows retrieved"""

@@ -71,21 +78,21 @@ def _open_experiment(

class _ObsAndXIterator(Iterator[ObsDatum]):
"""
- Iterates through a set of obs and related X rows, specified as `soma_joinid`s. Encapsulates the batch-based data
- fetching from TileDB-SOMA objects, providing row-based iteration.
+ Iterates through a set of obs and related X rows, specified as ``soma_joinid``s. Encapsulates the batch-based data
+ fetching from SOMA objects, providing row-based iteration.
"""

obs_tables_iter: somacore.ReadIter[pa.Table]
"""Iterates the TileDB-SOMA batches (tables) of obs data"""
"""Iterates the SOMA batches (tables) of obs data"""

obs_batch_: pd.DataFrame = pd.DataFrame()
"""The current TileDB-SOMA batch of obs data"""
"""The current SOMA batch of obs data"""

X_batch: scipy.matrix = None
"""All X data for the soma_joinids of the current obs - batch"""
"""All X data for the ``soma_joinid``s of the current obs - batch"""

i: int = -1
"""Index into current obs TileDB-SOMA batch"""
"""Index into current obs ``SOMA`` batch"""

def __init__(
self,
@@ -177,7 +184,7 @@ def obs_batch(self) -> pd.DataFrame:
Returns the current SOMA batch of obs rows.
If the current SOMA batch has been fully iterated, loads the next SOMA batch of both obs and X data and returns
the new obs batch (only).
- Raises StopIteration if there are no more SOMA batches to retrieve.
+ Raises ``StopIteration`` if there are no more SOMA batches to retrieve.
"""
if 0 <= self.i < len(self.obs_batch_):
return self.obs_batch_
@@ -203,30 +210,33 @@ def obs_batch(self) -> pd.DataFrame:

class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsDatum]]): # type: ignore
"""
- An iterable-style PyTorch data pipe that reads obs and X data from a SOMA Experiment, and returns an iterator of
- tuples of torch tensors:
+ An iterable-style PyTorch ``DataPipe`` that reads obs and X data from a SOMA Experiment, and returns an iterator of
+ tuples of PyTorch ``Tensor``s:

- (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data
-  tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded
+ >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data
+      tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded

Supports batching via `batch_size` param:
- DataLoader(..., batch_size=3, ...):
-
- (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch
-          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
-          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
-  tensor([[2415, 0, 0], # obs batch
-          [2416, 0, 4],
-          [2417, 0, 3]], dtype=torch.int32))
+ >>> DataLoader(..., batch_size=3, ...):
+     (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch
+              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
+              [0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
+      tensor([[2415, 0, 0], # obs batch
+              [2416, 0, 4],
+              [2417, 0, 3]], dtype=torch.int32))

Obs attribute values are encoded as categoricals. Values can be decoded by obtaining the encoder for a given
attribute:

- exp_data_pipe.obs_encoders()["<obs_attr_name>"].inverse_transform(encoded_values)
+ >>> exp_data_pipe.obs_encoders()["<obs_attr_name>"].inverse_transform(encoded_values)

Lifecycle:
experimental
"""

_query: Optional[soma.ExperimentAxisQuery]
"""In multi-processing mode (i.e. num_workers > 0), this ExperimentAxisQuery object will *not* be pickled;
"""In multi-processing mode (i.e. num_workers > 0), this ``ExperimentAxisQuery`` object will *not* be pickled;
each worker will instantiate a new query"""

_obs_joinids_partitioned: Optional[pa.Array]
@@ -252,6 +262,16 @@ def __init__(
num_workers: int = 0,
soma_buffer_bytes: Optional[int] = None,
) -> None:
"""
Construct a new ``ExperimentDataPipe``.

Returns:
``ExperimentDataPipe``.

Lifecycle:
experimental
"""

self.exp_uri = experiment.uri
self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region")
self.measurement_name = measurement_name
@@ -356,6 +376,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self._query = None

def obs_encoders(self) -> Encoders:
"""
Returns the encoders that were used to encode obs column values and that are needed to decode them.

Returns:
``Dict[str, LabelEncoder]`` mapping column names to ``LabelEncoder``s.

Lifecycle:
experimental
"""
if self._encoders is not None:
return self._encoders

@@ -373,18 +402,41 @@ def obs_encoders(self) -> Encoders:
return self._encoders

def stats(self) -> Stats:
"""
Get data loading stats for this ``ExperimentDataPipe``.

Args: None.

Returns:
``Stats`` object.

Lifecycle:
experimental
"""
return self._stats

@property
def shape(self) -> Tuple[int, int]:
"""
Get the shape of the data that will be returned by this ExperimentDataPipe. This is the number of
obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
(i.e. DataLoader instantiated with num_workers > 0), the obs (cell) count will reflect the size of the
partition of the data assigned to the active process.

Returns:
2-tuple of ``int``s, for obs and var counts, respectively.

Lifecycle:
experimental
"""
self._init()
assert self._query is not None

return self._query.n_obs, self._query.n_vars


# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling
- def collate_noop(x: Any) -> Any:
+ def _collate_noop(x: Any) -> Any:
return x


@@ -395,9 +447,15 @@ def experiment_dataloader(
**dataloader_kwargs: Any,
) -> DataLoader:
"""
- Factory method for PyTorch DataLoader. Provides a safer, more convenient interface for instantiating a DataLoader
- that works with the ExperimentDataPipe, since not all of DataLoader's params can be used (batch_size, sampler,
- batch_sampler, collate_fn).
+ Factory method for PyTorch ``DataLoader``. Provides a safer, more convenient interface for instantiating a
+ ``DataLoader`` that works with the ``ExperimentDataPipe``, since not all of ``DataLoader``'s params can be
+ used (``batch_size``, ``sampler``, ``batch_sampler``, ``collate_fn``).

Returns:
PyTorch ``DataLoader``.

Lifecycle:
experimental
"""

unsupported_dataloader_args = ["batch_size", "sampler", "batch_sampler", "collate_fn"]
@@ -409,7 +467,7 @@
batch_size=None, # batching is handled by our ExperimentDataPipe
num_workers=num_workers,
# avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches
- collate_fn=collate_noop,
+ collate_fn=_collate_noop,
**dataloader_kwargs,
)

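Putting the documented pieces together, a minimal end-to-end sketch of the data pipe and loader. Only parameters visible in this diff are used; query/filter arguments are omitted, so as written this would iterate the full human experiment (treat it as a shape of the API, not a recommended workload).

```python
import cellxgene_census
import cellxgene_census.experimental.ml as census_ml

with cellxgene_census.open_soma() as census:
    exp_dp = census_ml.ExperimentDataPipe(
        census["census_data"]["homo_sapiens"],
        measurement_name="RNA",
        X_name="raw",
    )

    # (n_obs, n_vars); with num_workers > 0 the obs count reflects this worker's partition.
    print(exp_dp.shape)

    # experiment_dataloader wires up a DataLoader with batching/collation handled by the pipe.
    dl = census_ml.experiment_dataloader(exp_dp)
    X, obs = next(iter(dl))  # tensors, as shown in the class docstring

    # Encoders for the obs columns; decode values with .inverse_transform(...).
    print(list(exp_dp.obs_encoders().keys()))

    # Data-loading statistics accumulated while iterating.
    print(exp_dp.stats())
```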
6 changes: 3 additions & 3 deletions api/python/notebooks/ml_demo/pytorch_lr_classifier.py
@@ -3,7 +3,7 @@
import torch

import cellxgene_census
- from cellxgene_census.experimental.ml.pytorch import experiment_dataloader, ExperimentDataPipe
+ import cellxgene_census.experimental.ml as census_ml

# TODO: Convert this to a notebook

@@ -82,7 +82,7 @@ def run():
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"
var_value_filter = ""

- exp_dp = ExperimentDataPipe(
+ exp_dp = census_ml.ExperimentDataPipe(
census["census_data"]["homo_sapiens"],
measurement_name="RNA",
X_name="raw",
@@ -95,7 +95,7 @@ def run():
dp = exp_dp.shuffle(buffer_size=len(exp_dp))
dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=RANDOM_SEED)

- dl_train = experiment_dataloader(
+ dl_train = census_ml.experiment_dataloader(
dp_train,
# >= 1 uses multiprocessing to load data
num_workers=0,
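The demo is cut off here. Purely as an illustration of consuming dl_train (not the notebook's actual code), a minimal classifier loop might look like the following; the label column index, the "cell_type" encoder key, and the assumption that the pipe was built with a batch size greater than one are all hypothetical choices for the sketch.

```python
import torch

# Continues from the demo above: exp_dp and dl_train are already constructed.
n_vars = exp_dp.shape[1]
n_classes = len(exp_dp.obs_encoders()["cell_type"].classes_)  # assumed label column/encoder

model = torch.nn.Linear(n_vars, n_classes)  # logistic-regression-style classifier
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for X_batch, obs_batch in dl_train:
    optimizer.zero_grad()
    logits = model(X_batch)
    y_batch = obs_batch[:, 1].long()  # assumed: encoded label is the second obs column
    loss = loss_fn(logits, y_batch)
    loss.backward()
    optimizer.step()
```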