Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docs] Add PyTorch loaders article release #1214

Merged
merged 13 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/articles/2024/20240709-pytorch-fig-scvi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
135 changes: 135 additions & 0 deletions docs/articles/2024/20240709-pytorch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# First stable iteration of Census (SOMA) PyTorch loaders

*Published:* *July 9th, 2024*

*By:* *[Emanuele Bezzi](mailto:[email protected]), [Pablo Garcia-Nieto](mailto:[email protected]), [Prathap Sridharan](mailto:[email protected]), [Ryan Williams](mailto:[email protected])*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ryan's email is wrong. Worth checking if they're ok with adding the email here though?

Copy link
Contributor

@ryan-williams ryan-williams Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I saw that #1228 addresses this 🙏 (and yes, I'm ok / appreciate being listed!)


The Census team is excited to share the release of Census PyTorch loaders that work out-of-the-box for memory-efficient training across any slice of the >70M cells in Census.

In 2023, we released a beta version of the loaders and we have observed interest from users to utilize them with Census or their own data. For example [Wolf et al.](https://lamin.ai/blog/arrayloader-benchmarks) performed comparisons across different training approaches and found our loaders to be ideal for *uncached* training of Census data, albeit with some caveats.

We have continued the development of the loaders in collaboration with our partners at TileDB, and we are happy to announce this release as the first stable iteration. We hope the loaders can accelerate the development of large-scale models of single-cell data by leveraging the following main features:

- **Out-of-the-box training on all or any slice of Census data.**
- **Efficient memory usage with out-of-core training.**
- **Calibrated shuffling of observations (cells).**
- **Cloud-based or local data access.**
- **Increased training speed.**
- **Custom data encoders.**

Keep on reading for usage and more details on the main loader features.

## Census PyTorch loaders usage

The loaders are ready to use for PyTorch modeling via the specialized Data Pipe [`ExperimentDataPipe`](https://chanzuckerberg.github.io/cellxgene-census/_autosummary/cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe.html#cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe), which takes advantage of the out-of-core data access TileDB-SOMA offers.

Please follow the [Training a PyTorch Model](https://chanzuckerberg.github.io/cellxgene-census/notebooks/experimental/pytorch.html) tutorial for a full reproducible example to train a logistic regression on cell type labels.

In short, the following shows you how to initialize the loader to train a model on a small subset of cells. First, you can initialize a `ExperimentDataPipe` to train a model on tongue cells as follows:

```python
import cellxgene_census.experimental.ml as census_ml
import cellxgene_census
import tiledbsoma as soma

experiment = census["census_data"]["homo_sapiens"]

experiment_datapipe = census_ml.ExperimentDataPipe(
experiment,
measurement_name="RNA",
X_name="raw",
obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue' and is_primary_data == True"),
obs_column_names=["cell_type"],
batch_size=128,
shuffle=True,
)
```

Then you can perform any PyTorch operations and training.

```python
# Splitting training and test sets
train_datapipe, test_datapipe = experiment_datapipe.random_split(weights={"train": 0.8, "test": 0.2}, seed=1)

# Creating data loader
experiment_dataloader = census_ml.experiment_dataloader(train_datapipe)

# Training a PyTorch model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MODEL().to(device)
model.train()
```

## Census PyTorch loaders main features

### Out-of-the-box training on all or any slice of Census data

Since the `ExperimentDataPipe` inherits from the [PyTorch Iterable-style DataPipe](https://pytorch.org/data/main/torchdata.datapipes.iter.html) it can be readily used with PyTorch models.

The single-cell expression data is encoded in numerical tensors, and for supervised training the cell metadata can be automatically transformed with a default encoder, or with custom user-defined encoders (see below).

### Efficient memory usage with out-of-core training

Thanks to the underlying backend of Census — TileDB-SOMA — the PyTorch loaders take advantage of incremental data materialization of fixed and small size to keep memory usage constant throughout training.

In addition, data is eagerly fetched while batches go through training so that compute is never idle or waiting for data to be loaded. This feature is particularly useful when fetching Census data directly from the cloud.

Memory usage is defined by the parameters `soma_chunk_size` and `shuffle_chunk_count` - see below for a full description on how these should be tuned.

### Calibrated shuffling of observations (cells)

Shuffling along efficient out-of-core data fetching is a challenge. In general, increasing randomness of shuffling leads to slower data fetching.

In the first iteration of the loaders, shuffling was done through large blocks of data of user-defined size. This shuffling strategy led to non-random distribution of observations per training batch, becasue Census has a non-random data structure (observations from the same datasets are adjacent to one another) thus training loss was unstable (Figure 1).

**Now we have implemented a scatter-gather approach**, whereby multiple chunks of data are fetched randomly from Census, then a number of chunks are concatenated into a block and all observations within the block are randomly shuffled. Adjusting the size and number of chunks per block leads to well-calibrated shuffling with stable training loss (Figure 2) while maintaining efficient data fetching (Figure 3).

The balance between memory usage, efficiency, and level of randomness can be adjusted with the parameters `soma_chunk_size` and `shuffle_chunk_count`. Increasing `shuffle_chunk_count` will improve randomness, as more scattered chunks will be collected before the pool is randomized. Increasing `soma_chunk_size` will improve I/O efficiency while decreasing it will improve memory usage. We recommend a default of `soma_chunk_size=64, shuffle_chunk_count=2000` as we determined this configuration yields a good balance.

```{figure} ./20240709-pytorch-fig-loss-before.png
:alt: Census PyTorch loaders shuffling
:align: center
:figwidth: 80%

**Figure 1. Training loss was unstable with the previous shuffling strategy**. Based on a trial scVI run on 64K Census cells.
```

```{figure} ./20240709-pytorch-fig-loss-after.png
:alt: Census PyTorch loaders callibrated shuffling
:align: center
:figwidth: 80%

**Figure 2. Training loss is well-calibrated with the current scatter-gather shuffling strategy.** Based on a trial scVI run on 250K Census cells.
```

### Increased training speed

We have made improvements to the loaders to reduce the amount of data transformations required from data fetching to model training. One such important change is to encode the expression data as a dense matrix immediately after the data is retrieved from disk/cloud.

In our benchmarks, we found that densifying data increases training speed ~3X while maintaining relatively constant memory usage (Figure 3). For this reason, we have disable the intermediate data processing in sparse format unless Torch Sparse Tensors are requested via the `ExperimentDataPipe` parameter `return_sparse_X`.

```{figure} ./20240709-pytorch-fig-benchmark.png
:alt: Census PyTorch loaders benchmark
:align: center
:figwidth: 80%

**Figure 3. Benchmark of memory usage and speed of data processing during modeling, default parameters lead to 3K+ samples/sec with 27GB of memory.** The benchmark was done processing 4M cells out of a 10M-cell Census, data was fetched from the cloud (S3). "Method" indicates the expression matrix encoding, circles are dense (np.array) and squares are sparse (scipy.csr). Size indicates the total number of cells per processing block (max cells materialized at any given time) and color is the number of individual randomly grabbed chunks composing a processing block, higher chunks per block lead to better shuffling. Data was fetched until modeling step, but no model was trained.
```

We repeated the benchmark in Figure 3 in different conditions encompassing varying number of total cells and multiple epochs, please [follow this link for the full benchmark report and code.](https://github.com/ryan-williams/arrayloader-benchmarks).

When comparing dense vs sparse processing in an end-to-end training exercise with scVI, we also observed slight increased speed with the dense approach and comparable memory usage to sparse processing (Figure 4). However in this full training example the differences were less substantial, highlighting that other model-specific factors during the training phase will contribute to memory and speed performance.

```{figure} ./20240709-pytorch-fig-scvi.png
:alt: Census scVI PyTorch run
:align: center
:figwidth: 80%

**Figure 4. Trial scVI training run with default parameters of the Census Pytorch loaders, highlighting increased speed of dense vs sparse data processing.** Training was done on 5684805 mouse cells for 1 epoch on a g4dn.16xlarge EC2 machine.
```

### Custom data encoders

For maximum flexibility, users can provide custom encoders for the cell metadata enabling custom transformations or interactions between different metadata variables.

To use custom encoders you need to instantiate the desired encoder via the [Encoder](https://chanzuckerberg.github.io/cellxgene-census/_autosummary/cellxgene_census.experimental.ml.encoders.Encoder.html#cellxgene_census.experimental.ml.encoders.Encoder) class and pass it to the `encoders` parameter of the `ExperimentDataPipe`.
Loading