
Allow dask arrays to propagate through _normalize_data() #1663

Merged: 7 commits merged into scverse:master on Mar 21, 2022

Conversation

@GenevieveBuckley (Contributor):

While I was looking through the scanpy source code, I found a comment that says "# dask doesn't do medians":

https://github.com/theislab/scanpy/blob/0c4ca5b21524c2972d514ddbd85834002ed623de/scanpy/preprocessing/_normalization.py#L17

Dask does in fact do medians, provided the median is taken along an axis: dask/dask#5575
But this feature was only merged in November 2019 (the same month the comment above was added), so I think it was too new at the time to be widely known and available.

This PR attempts to remove the coercion to numpy and allow dask arrays to propagate through the _normalize_data function.
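
For illustration (my example, not from the PR), dask's median works once an axis is given, while the no-axis form is still unsupported:

import dask.array as da
import numpy as np

x = da.from_array(np.arange(12).reshape(3, 4), chunks=(3, 2))

# da.median only works along an explicit axis (dask/dask#5575);
# da.median(x) without an axis raises NotImplementedError
print(da.median(x, axis=0).compute())  # [4. 5. 6. 7.]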

@GenevieveBuckley changed the title from "Dask median" to "Allow dask arrays to propagate through _normalize_data()" on Feb 19, 2021
@GenevieveBuckley (Contributor, Author) commented on this diff in scanpy/preprocessing/_normalization.py:

-    counts = np.asarray(counts) # dask doesn't do medians
-    after = np.median(counts[counts>0], axis=0) if after is None else after
+    try: # dask array
+        counts_greater_than_zero = counts[counts>0].compute_chunk_sizes()

This is not perfectly efficient: we have to compute the chunk sizes here because they are unknown after counts[counts>0]. There are probably still advantages compared with coercing the whole thing to a numpy array at the beginning.

See here for more details: https://docs.dask.org/en/latest/array-chunks.html

compute_chunk_sizes() immediately performs computation and modifies the array in-place.
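
To make the trade-off concrete, here is a small sketch (mine, not from the PR) of the unknown-chunks problem that compute_chunk_sizes() resolves:

import dask.array as da
import numpy as np

counts = da.from_array(np.array([0.0, 1.0, 2.0, 0.0, 3.0]), chunks=2)
masked = counts[counts > 0]
print(masked.chunks)          # ((nan, nan, nan),) -- unknown after boolean masking
masked.compute_chunk_sizes()  # eagerly computes and fixes the sizes in place
print(masked.chunks)          # ((1, 1, 1),)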

@ivirshup (Member) left a review:

Thanks for the PR, great to have this working! Bad timing on the last time we looked at it I guess 😄

Apart from the minor comments I've left inline, could this get a test that checks that the dask array is propagated through?
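
For example, a propagation test could look roughly like this (a sketch only; the test name is illustrative, and X is reassigned after construction for the reason discussed later in this thread):

import dask.array as da
import numpy as np
import scanpy as sc
from anndata import AnnData

def test_normalize_total_dask():
    adata = AnnData(np.array([[1.0, 0.0], [3.0, 0.0], [5.0, 6.0]]))
    adata.X = da.from_array(adata.X)
    sc.pp.normalize_total(adata)
    assert isinstance(adata.X, da.Array)  # still lazy, not coerced to numpy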

Inline review comments on scanpy/tests/test_normalization.py and scanpy/preprocessing/_normalization.py (outdated, resolved).
@GenevieveBuckley (Contributor, Author):

Huh, I don't think the tests are checking what I thought they were.

  1. It doesn't look like you can have an AnnData object with a dask array. What am I doing wrong here?
from anndata import AnnData
import numpy as np
import dask.array as da
from scipy.sparse import csr_matrix

X_total = [[1, 0], [3, 0], [5, 6]]

adata = AnnData(np.array(X_total), dtype='float32')
print(type(adata.X))  # is a numpy array, as expected

adata = AnnData(csr_matrix(X_total), dtype='float32')
print(type(adata.X))  # is a sparse matrix, as expected

adata = AnnData(da.from_array(X_total), dtype='float32')
print(type(adata.X))  # is a numpy array, NOT a dask array; not what I expected
  2. The change I made to _normalize_data() changed the coercion of counts, not X. When I stepped through this private function, things seemed to work the way I expected, but there's a lot of other stuff happening before and after in normalize_total() which I haven't looked at much.

What combinations of inputs to _normalize_data() need to be supported? (See the test sketch after this list.)

  • numpy X, numpy counts
  • dask X, dask counts
  • csr_matrix X, csr_matrix counts

Mixed combinations?

  • numpy X, dask counts
  • dask X, numpy counts
  • numpy X, csr_matrix counts
  • csr_matrix X, numpy counts
  • dask X, csr_matrix counts
  • csr_matrix X, dask counts
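
A hypothetical way to exercise the homogeneous combinations with pytest (the parametrization and the private-function signature here are my assumptions, not code from the PR):

import dask.array as da
import numpy as np
import pytest
from scipy.sparse import csr_matrix

from scanpy.preprocessing._normalization import _normalize_data

X_total = [[1, 0], [3, 0], [5, 6]]

@pytest.mark.parametrize(
    "array_type", [np.asarray, csr_matrix, da.from_array], ids=["numpy", "csr", "dask"]
)
def test_normalize_data_preserves_type(array_type):
    X = array_type(np.asarray(X_total, dtype="float64"))
    counts = np.asarray(X.sum(axis=1)).ravel()  # per-row totals as a numpy vector
    out = _normalize_data(X, counts, after=1.0)
    assert type(out) is type(X)  # the container type should survive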

@ivirshup (Member):

So, I'm not too surprised to see this, since I don't think much of the distributed stuff has good testing, and I'm not too familiar with it.

I believe the AnnData constructor is converting the array. You can get around this by assigning X to be a dask array, e.g.:

import anndata as ad
import dask.array as da
import numpy as np

a = ad.AnnData(np.ones((1000, 100)))
a.X = da.from_array(a.X)
type(a.X)
# dask.array.core.Array

Better support for dask arrays would be a great feature request and series of additions to anndata. I think this is the endemic numeric Python problem of "these things are all like arrays, so you can kinda use the same API, but in practice every type needs to be special-cased".

but there's a lot of other stuff happening before & afterwards in normalize_total() which I haven't looked at much

Yeah, I think this function has built up some cruft. I've opened a PR to streamline this (#1667), but I will need to check with people more familiar with the code. The private method should handle all of the computation, while the outer wrapper does the argument handling, getting data out of the AnnData, and assigning it back.

What combinations of inputs to _normalize_data() need to be supported

I believe counts should always be generated from X, so we don't need to worry about the combinations of types.
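
For illustration (not from the thread): a reduction over X keeps X's container type, so counts should track X:

import dask.array as da
import numpy as np

X = da.from_array(np.ones((100, 50)), chunks=(50, 50))
counts = X.sum(axis=1)  # a dask reduction yields another (lazy) dask array
print(type(counts))     # <class 'dask.array.core.Array'>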

@GenevieveBuckley (Contributor, Author):

Better support for dask arrays would be a great feature request and series of additions to anndata.

It'd be great to have a bigger-picture conversation about this with you. I've just sent you a Twitter DM with my contact details.

@ryan-williams:

Apologies for jumping in, but I wanted to mention that @sakoht and I have been working on Dask+AnnData a bit over in celsiustx/anndata, and we would love to discuss some of these issues or listen in.

To date, a lot of the work has been upstream in Dask (e.g. dask/dask#6661; also not quite a PR yet, but celsiustx/dask@sum2 improves support for dask arrays with scipy.sparse.spmatrix blocks), but the focus should be moving more into AnnData soon.

I'll try to put some more cogent thoughts into an issue (likely on AnnData) tomorrow, but I just wanted to mention this here since I've been following the thread with interest! Thanks.

@ivirshup (Member):

Great to hear from both of you. I'd really love to have better Dask integration with AnnData and am excited to see these efforts progress!

@ryan-williams, it'd be great if you could open an issue over on anndata about this! I think that'd be a good place to discuss design considerations.

@GenevieveBuckley (Contributor, Author):

So it seems that no matter what array type is given to adata.X, the counts_per_cell variable generated in normalize_total() is always created as a numpy array.

So I'm not sure why there was a note next to the line in _normalize_data() about not being able to use dask: the counts input here is always numpy, because it has already been created in normalize_total().

Presumably this is not intended?

@ivirshup (Member) commented Feb 24, 2021:

@GenevieveBuckley, could you provide an example? I'm not sure what you mean, as I see dask arrays being passed to _normalize_data.

I just checked out this PR, added:

print(type(X), type(counts))

as the first line of _normalize_data and ran this code:

import scanpy as sc
import dask.array as da

adata = sc.datasets.blobs(n_observations=100, n_variables=50)
adata.X = da.from_array(adata.X)
# WARNING: Some cells have total count of genes equal to zero
# <class 'dask.array.core.Array'>
# <class 'dask.array.core.Array'>
display(adata.X)
# dask.array<true_divide, shape=(100, 50), dtype=float32, chunksize=(100, 50), chunktype=numpy.ndarray>

Oh, and sorry about the conflicts! I've just removed all the ignored patterns from the pyproject.toml.

@Koncopd (Member) commented Feb 24, 2021:

Yes, if I remember correctly, np.ravel converts dask to numpy.

@ivirshup (Member):

@Koncopd, I thought that too, but this isn't the case currently:

import dask.array as da; import numpy as np

np.ravel(da.ones(10))
# dask.array<ones, shape=(10,), dtype=float64, chunksize=(10,), chunktype=numpy.ndarray>

@GenevieveBuckley (Contributor, Author):

I'd been adding a breakpoint just before the call to _normalize_data() and was printing out type(adata.X) and type(counts_per_cell).

I'll check it again the same way you did, and also check whether I have uncommitted code in my local repo.

@ivirshup (Member) commented Mar 3, 2021:

Mind if I merge #1667? It would cause more conflicts, but should clean up the flow control in normalize_total. I'd also be happy to add a commit fixing up the conflicts here.

Edit: I added that commit. Feel free to remove it if it's troublesome.

@codecov (bot) commented Mar 5, 2021:

Codecov Report

Merging #1663 (a074541) into master (1be0a68) will increase coverage by 0.00%.
The diff coverage is 80.00%.

@@           Coverage Diff           @@
##           master    #1663   +/-   ##
=======================================
  Coverage   71.38%   71.38%           
=======================================
  Files          92       92           
  Lines       11283    11291    +8     
=======================================
+ Hits         8054     8060    +6     
- Misses       3229     3231    +2     
Impacted file: scanpy/preprocessing/_normalization.py, coverage 87.05% <80.00%> (-1.26%) ⬇️

@ivirshup added this to the 1.8.1 milestone on Jun 23, 2021
@ivirshup modified the milestones: 1.8.1 → 1.8.2 on Jul 7, 2021
@ivirshup modified the milestones: 1.8.2 → 1.8.3 on Nov 3, 2021
@ivirshup modified the milestones: 1.8.3 → 1.9.0 on Mar 21, 2022
@ivirshup merged commit 9cb915b into scverse:master on Mar 21, 2022
@flying-sheep mentioned this pull request on Aug 4, 2023