forked from scverse/scanpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a module for computing useful metrics. Started off with Geary's C since I'm using it and finding it useful. I've also got a fairly fast way to calculate it worked out. Unfortunatly my implementation runs into some issues with some global configs set by umap (see lmcinnes/umap#306), so I'm going to see if that can be resolved before changing it.
- Loading branch information
Showing
5 changed files
with
249 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ numba>=0.41.0 | |
umap-learn>=0.3.0 | ||
legacy-api-wrap | ||
setuptools_scm | ||
multipledispatch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._gearys_c import gearys_c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# TODO: Calling this after UMAP has imported will hang, pending https://github.com/lmcinnes/umap/issues/306 | ||
from typing import Optional, Union | ||
|
||
from anndata import AnnData | ||
from multipledispatch import dispatch | ||
import numba | ||
import numpy as np | ||
import pandas as pd | ||
from scipy import sparse | ||
|
||
|
||
def _choose_obs_rep(adata, *, use_raw=False, layer=None, obsm=None, obsp=None): | ||
""" | ||
Choose array aligned with obs annotation. | ||
""" | ||
is_layer = layer is not None | ||
is_raw = use_raw is not False | ||
is_obsm = obsm is not None | ||
is_obsp = obsp is not None | ||
choices_made = sum((is_layer, is_raw, is_obsm, is_obsp)) | ||
assert choices_made <= 1 | ||
if choices_made == 0: | ||
return adata.X | ||
elif is_layer: | ||
return adata.layers[layer] | ||
elif use_raw: | ||
return adata.raw.X | ||
elif is_obsm: | ||
return adata.obsm[obsm] | ||
elif is_obsp: | ||
return adata.obsp[obsp] | ||
else: | ||
raise RuntimeError("You broke it. But how? Please report this.") | ||
|
||
|
||
############################################################################### | ||
# Calculation | ||
############################################################################### | ||
|
||
|
||
@numba.njit(cache=True, parallel=True) | ||
def _gearys_c_vec(data, indices, indptr, x): | ||
W = data.sum() | ||
return _gearys_c_vec_W(data, indices, indptr, x, W) | ||
|
||
|
||
@numba.njit(cache=True, parallel=True) | ||
def _gearys_c_vec_W(data, indices, indptr, x, W): | ||
N = len(indptr) - 1 | ||
x_bar = x.mean() | ||
|
||
total = 0.0 | ||
for i in numba.prange(N): | ||
s = slice(indptr[i], indptr[i + 1]) | ||
i_indices = indices[s] | ||
i_data = data[s] | ||
total += np.sum(i_data * ((x[i] - x[i_indices]) ** 2)) | ||
|
||
numer = (N - 1) * total | ||
denom = 2 * W * ((x - x_bar) ** 2).sum() | ||
C = numer / denom | ||
|
||
return C | ||
|
||
|
||
@numba.njit(cache=True, parallel=True) | ||
def _gearys_c_mtx_csr( | ||
g_data, g_indices, g_indptr, x_data, x_indices, x_indptr, x_shape | ||
): | ||
M, N = x_shape | ||
W = g_data.sum() | ||
out = np.zeros(M, dtype=np.float64) | ||
for k in numba.prange(M): | ||
x_arr = np.zeros(N, dtype=x_data.dtype) | ||
sk = slice(x_indptr[k], x_indptr[k + 1]) | ||
x_arr[x_indices[sk]] = x_data[sk] | ||
outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, x_arr, W) | ||
out[k] = outval | ||
return out | ||
|
||
|
||
@numba.njit(cache=True, parallel=True) | ||
def _gearys_c_mtx(g_data, g_indices, g_indptr, X): | ||
M, N = X.shape | ||
W = g_data.sum() | ||
out = np.zeros(M, dtype=np.float64) | ||
for k in numba.prange(M): | ||
outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, X[k, :], W) | ||
out[k] = outval | ||
return out | ||
|
||
|
||
############################################################################### | ||
# Interface | ||
############################################################################### | ||
|
||
|
||
@dispatch(sparse.csr_matrix, sparse.csr_matrix) | ||
def gearys_c(g, vals) -> np.ndarray: | ||
assert g.shape[0] == g.shape[1] | ||
assert g.shape[0] == vals.shape[1] | ||
return _gearys_c_mtx_csr( | ||
g.data, | ||
g.indices, | ||
g.indptr, | ||
vals.data, | ||
vals.indices, | ||
vals.indptr, | ||
vals.shape, | ||
) | ||
|
||
|
||
@dispatch(sparse.spmatrix, np.ndarray) | ||
def gearys_c(g, vals): | ||
assert g.shape[0] == g.shape[1], "`g` should be a square matrix." | ||
if not isinstance(g, sparse.csr_matrix): | ||
g = g.tocsr() | ||
if vals.ndim == 1: | ||
assert g.shape[0] == vals.shape[0] | ||
return _gearys_c_vec(g.data, g.indices, g.indptr, vals) | ||
elif vals.ndim == 2: | ||
assert g.shape[0] == vals.shape[1] | ||
return _gearys_c_mtx(g.data, g.indices, g.indptr, vals) | ||
else: | ||
raise ValueError() | ||
|
||
|
||
@dispatch(sparse.spmatrix, (pd.DataFrame, pd.Series)) | ||
def gearys_c(g, vals): | ||
return gearys_c(g, vals.values) | ||
|
||
@dispatch(sparse.spmatrix, sparse.spmatrix) | ||
def gearys_c(g, vals) -> np.ndarray: | ||
if not isinstance(g, sparse.csr_matrix): | ||
g = g.tocsr() | ||
if not isinstance(vals, sparse.csc_matrix): | ||
vals = vals.tocsr() | ||
return gearys_c(g, vals) | ||
|
||
|
||
# TODO: Document better | ||
# TODO: Have scanpydoc work with multipledispatch | ||
@dispatch(AnnData) | ||
def gearys_c( | ||
adata: AnnData, | ||
*, | ||
vals: Optional[Union[np.ndarray, sparse.spmatrix]] = None, | ||
use_graph: Optional[str] = None, | ||
layer: Optional[str] = None, | ||
obsm: Optional[str] = None, | ||
obsp: Optional[str] = None, | ||
use_raw: bool = False, | ||
) -> Union[np.ndarray, float]: | ||
""" | ||
Calculate `Geary's C` <https://en.wikipedia.org/wiki/Geary's_C>`_, as used | ||
by `VISION <https://doi.org/10.1038/s41467-019-12235-0>`_. | ||
Geary's C is a measure of autocorrelation for some measure on a graph. This | ||
can be to whether measures are correlated between neighboring cells. Lower | ||
values indicate greater correlation. | ||
..math | ||
C = | ||
\frac{ | ||
(N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2 | ||
}{ | ||
2W \sum_i (x_i - \bar{x})^2 | ||
} | ||
Params | ||
------ | ||
adata | ||
vals | ||
Values to calculate Geary's C for. If this is two dimensional, should | ||
be of shape `(n_features, n_cells)`. Otherwise should be of shape | ||
`(n_cells,)`. This matrix can be selected from elements of the anndata | ||
object by using key word arguments: `layer`, `obsm`, `obsp`, or | ||
`use_raw`. | ||
use_graph | ||
Key to use for graph in anndata object. If not provided, default | ||
neighbors connectivities will be used instead. | ||
layer | ||
Key for `adata.layers` to choose `vals`. | ||
obsm | ||
Key for `adata.obsm` to choose `vals`. | ||
obsp | ||
Key for `adata.obsp` to choose `vals`. | ||
use_raw | ||
Whether to use `adata.raw.X` for `vals`. | ||
Returns | ||
------- | ||
If vals is two dimensional, returns a 1 dimensional ndarray array. Returns a scalar if `vals` is 1d. | ||
""" | ||
if use_graph is None: | ||
if "connectivities" in adata.obsp: | ||
g = adata.obsp["connectivities"] | ||
elif "neighbors" in adata.uns: | ||
g = adata.uns["neighbors"]["connectivities"] | ||
else: | ||
raise ValueError("Must run neighbors first.") | ||
else: | ||
raise NotImplementedError() | ||
if vals is None: | ||
vals = _choose_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T | ||
return gearys_c(g, vals) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from operator import eq | ||
import numpy as np | ||
import scanpy as sc | ||
from scipy import sparse | ||
|
||
|
||
def test_gearys_c(): | ||
pbmc = sc.datasets.pbmc68k_reduced() | ||
pbmc.layers["raw"] = pbmc.raw.X.copy() | ||
g = pbmc.uns["neighbors"]["connectivities"] | ||
|
||
assert eq( | ||
sc.metrics.gearys_c(g, pbmc.obs["percent_mito"]), | ||
sc.metrics.gearys_c(pbmc, vals=pbmc.obs["percent_mito"]) | ||
) | ||
|
||
assert eq( # Test that series and vectors return same value | ||
sc.metrics.gearys_c(g, pbmc.obs["percent_mito"]), | ||
sc.metrics.gearys_c(g, pbmc.obs["percent_mito"].values), | ||
) | ||
|
||
assert np.array_equal( | ||
sc.metrics.gearys_c(pbmc, obsm="X_pca"), | ||
sc.metrics.gearys_c(g, pbmc.obsm["X_pca"].T) | ||
) | ||
|
||
# Test case with perfectly seperated groups | ||
connected = np.zeros(100) | ||
connected[np.random.choice(100, size=30, replace=False)] = 1 | ||
graph = np.zeros((100, 100)) | ||
graph[np.ix_(connected.astype(bool), connected.astype(bool))] = 1 | ||
graph[np.ix_(~connected.astype(bool), ~connected.astype(bool))] = 1 | ||
graph = sparse.csr_matrix(graph) | ||
|
||
assert sc.metrics.gearys_c(graph, connected) == 0. | ||
adata = sc.AnnData( | ||
sparse.csr_matrix((100, 100)), obsp={"connectivities": graph} | ||
) | ||
assert sc.metrics.gearys_c(adata, vals=connected) == 0. |