-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[docs] Add PyTorch loaders article release #1214
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
148cb86
Add pytorch article
pablo-gar af077a1
update article
pablo-gar 6325439
add article
pablo-gar a87b65b
add author
pablo-gar cadef7a
Editorial
pablo-gar e7d398c
lint
pablo-gar cbacf9e
editorial
pablo-gar 5c7e5bf
editorial
pablo-gar a33479a
lint
pablo-gar db019c4
editorial
pablo-gar bdbdff8
final edits
pablo-gar cce2337
update date
pablo-gar 273ffc3
update date
pablo-gar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# First stable iteration of Census (SOMA) PyTorch loaders | ||
|
||
*Published:* *July 9th, 2024* | ||
|
||
*By:* *[Emanuele Bezzi](mailto:[email protected]), [Pablo Garcia-Nieto](mailto:[email protected]), [Prathap Sridharan](mailto:[email protected]), [Ryan Williams](mailto:[email protected])* | ||
|
||
The Census team is excited to share the release of Census PyTorch loaders that work out-of-the-box for memory-efficient training across any slice of the >70M cells in Census. | ||
|
||
In 2023, we released a beta version of the loaders and we have observed interest from users to utilize them with Census or their own data. For example [Wolf et al.](https://lamin.ai/blog/arrayloader-benchmarks) performed comparisons across different training approaches and found our loaders to be ideal for *uncached* training of Census data, albeit with some caveats. | ||
|
||
We have continued the development of the loaders in collaboration with our partners at TileDB, and we are happy to announce this release as the first stable iteration. We hope the loaders can accelerate the development of large-scale models of single-cell data by leveraging the following main features: | ||
|
||
- **Out-of-the-box training on all or any slice of Census data.** | ||
- **Efficient memory usage with out-of-core training.** | ||
- **Calibrated shuffling of observations (cells).** | ||
- **Cloud-based or local data access.** | ||
- **Increased training speed.** | ||
- **Custom data encoders.** | ||
|
||
Keep on reading for usage and more details on the main loader features. | ||
|
||
## Census PyTorch loaders usage | ||
|
||
The loaders are ready to use for PyTorch modeling via the specialized Data Pipe [`ExperimentDataPipe`](https://chanzuckerberg.github.io/cellxgene-census/_autosummary/cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe.html#cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe), which takes advantage of the out-of-core data access TileDB-SOMA offers. | ||
|
||
Please follow the [Training a PyTorch Model](https://chanzuckerberg.github.io/cellxgene-census/notebooks/experimental/pytorch.html) tutorial for a full reproducible example to train a logistic regression on cell type labels. | ||
|
||
In short, the following shows you how to initialize the loader to train a model on a small subset of cells. First, you can initialize a `ExperimentDataPipe` to train a model on tongue cells as follows: | ||
|
||
```python | ||
import cellxgene_census.experimental.ml as census_ml | ||
import cellxgene_census | ||
import tiledbsoma as soma | ||
|
||
experiment = census["census_data"]["homo_sapiens"] | ||
|
||
experiment_datapipe = census_ml.ExperimentDataPipe( | ||
experiment, | ||
measurement_name="RNA", | ||
X_name="raw", | ||
obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue' and is_primary_data == True"), | ||
obs_column_names=["cell_type"], | ||
batch_size=128, | ||
shuffle=True, | ||
) | ||
``` | ||
|
||
Then you can perform any PyTorch operations and training. | ||
|
||
```python | ||
# Splitting training and test sets | ||
train_datapipe, test_datapipe = experiment_datapipe.random_split(weights={"train": 0.8, "test": 0.2}, seed=1) | ||
|
||
# Creating data loader | ||
experiment_dataloader = census_ml.experiment_dataloader(train_datapipe) | ||
|
||
# Training a PyTorch model | ||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | ||
model = MODEL().to(device) | ||
model.train() | ||
``` | ||
|
||
## Census PyTorch loaders main features | ||
|
||
### Out-of-the-box training on all or any slice of Census data | ||
|
||
Since the `ExperimentDataPipe` inherits from the [PyTorch Iterable-style DataPipe](https://pytorch.org/data/main/torchdata.datapipes.iter.html) it can be readily used with PyTorch models. | ||
|
||
The single-cell expression data is encoded in numerical tensors, and for supervised training the cell metadata can be automatically transformed with a default encoder, or with custom user-defined encoders (see below). | ||
|
||
### Efficient memory usage with out-of-core training | ||
|
||
Thanks to the underlying backend of Census — TileDB-SOMA — the PyTorch loaders take advantage of incremental data materialization of fixed and small size to keep memory usage constant throughout training. | ||
|
||
In addition, data is eagerly fetched while batches go through training so that compute is never idle or waiting for data to be loaded. This feature is particularly useful when fetching Census data directly from the cloud. | ||
|
||
Memory usage is defined by the parameters `soma_chunk_size` and `shuffle_chunk_count` - see below for a full description on how these should be tuned. | ||
|
||
### Calibrated shuffling of observations (cells) | ||
|
||
Shuffling along efficient out-of-core data fetching is a challenge. In general, increasing randomness of shuffling leads to slower data fetching. | ||
|
||
In the first iteration of the loaders, shuffling was done through large blocks of data of user-defined size. This shuffling strategy led to non-random distribution of observations per training batch, becasue Census has a non-random data structure (observations from the same datasets are adjacent to one another) thus training loss was unstable (Figure 1). | ||
|
||
**Now we have implemented a scatter-gather approach**, whereby multiple chunks of data are fetched randomly from Census, then a number of chunks are concatenated into a block and all observations within the block are randomly shuffled. Adjusting the size and number of chunks per block leads to well-calibrated shuffling with stable training loss (Figure 2) while maintaining efficient data fetching (Figure 3). | ||
|
||
The balance between memory usage, efficiency, and level of randomness can be adjusted with the parameters `soma_chunk_size` and `shuffle_chunk_count`. Increasing `shuffle_chunk_count` will improve randomness, as more scattered chunks will be collected before the pool is randomized. Increasing `soma_chunk_size` will improve I/O efficiency while decreasing it will improve memory usage. We recommend a default of `soma_chunk_size=64, shuffle_chunk_count=2000` as we determined this configuration yields a good balance. | ||
|
||
```{figure} ./20240709-pytorch-fig-loss-before.png | ||
:alt: Census PyTorch loaders shuffling | ||
:align: center | ||
:figwidth: 80% | ||
|
||
**Figure 1. Training loss was unstable with the previous shuffling strategy**. Based on a trial scVI run on 64K Census cells. | ||
``` | ||
|
||
```{figure} ./20240709-pytorch-fig-loss-after.png | ||
:alt: Census PyTorch loaders callibrated shuffling | ||
:align: center | ||
:figwidth: 80% | ||
|
||
**Figure 2. Training loss is well-calibrated with the current scatter-gather shuffling strategy.** Based on a trial scVI run on 250K Census cells. | ||
``` | ||
|
||
### Increased training speed | ||
|
||
We have made improvements to the loaders to reduce the amount of data transformations required from data fetching to model training. One such important change is to encode the expression data as a dense matrix immediately after the data is retrieved from disk/cloud. | ||
|
||
In our benchmarks, we found that densifying data increases training speed ~3X while maintaining relatively constant memory usage (Figure 3). For this reason, we have disable the intermediate data processing in sparse format unless Torch Sparse Tensors are requested via the `ExperimentDataPipe` parameter `return_sparse_X`. | ||
|
||
```{figure} ./20240709-pytorch-fig-benchmark.png | ||
:alt: Census PyTorch loaders benchmark | ||
:align: center | ||
:figwidth: 80% | ||
|
||
**Figure 3. Benchmark of memory usage and speed of data processing during modeling, default parameters lead to 3K+ samples/sec with 27GB of memory.** The benchmark was done processing 4M cells out of a 10M-cell Census, data was fetched from the cloud (S3). "Method" indicates the expression matrix encoding, circles are dense (np.array) and squares are sparse (scipy.csr). Size indicates the total number of cells per processing block (max cells materialized at any given time) and color is the number of individual randomly grabbed chunks composing a processing block, higher chunks per block lead to better shuffling. Data was fetched until modeling step, but no model was trained. | ||
``` | ||
|
||
We repeated the benchmark in Figure 3 in different conditions encompassing varying number of total cells and multiple epochs, please [follow this link for the full benchmark report and code.](https://github.com/ryan-williams/arrayloader-benchmarks). | ||
|
||
When comparing dense vs sparse processing in an end-to-end training exercise with scVI, we also observed slight increased speed with the dense approach and comparable memory usage to sparse processing (Figure 4). However in this full training example the differences were less substantial, highlighting that other model-specific factors during the training phase will contribute to memory and speed performance. | ||
|
||
```{figure} ./20240709-pytorch-fig-scvi.png | ||
:alt: Census scVI PyTorch run | ||
:align: center | ||
:figwidth: 80% | ||
|
||
**Figure 4. Trial scVI training run with default parameters of the Census Pytorch loaders, highlighting increased speed of dense vs sparse data processing.** Training was done on 5684805 mouse cells for 1 epoch on a g4dn.16xlarge EC2 machine. | ||
``` | ||
|
||
### Custom data encoders | ||
|
||
For maximum flexibility, users can provide custom encoders for the cell metadata enabling custom transformations or interactions between different metadata variables. | ||
|
||
To use custom encoders you need to instantiate the desired encoder via the [Encoder](https://chanzuckerberg.github.io/cellxgene-census/_autosummary/cellxgene_census.experimental.ml.encoders.Encoder.html#cellxgene_census.experimental.ml.encoders.Encoder) class and pass it to the `encoders` parameter of the `ExperimentDataPipe`. |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ryan's email is wrong. Worth checking if they're ok with adding the email here though?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I saw that #1228 addresses this 🙏 (and yes, I'm ok / appreciate being listed!)