Skip to content

Commit

Permalink
Workaround umap<0.4 and increase numerical stability of gearys_c
Browse files Browse the repository at this point in the history
* Work around lmcinnes/umap#306 by not
  calling out to kernel function. That code has been kept, but commented
  out.
* Increase numerical stability by casting data to system width. Tests
  were failing due to instability.
  • Loading branch information
ivirshup committed Nov 13, 2019
1 parent 7ca25a6 commit 4d344dc
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 23 deletions.
124 changes: 101 additions & 23 deletions scanpy/metrics/_gearys_c.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# TODO: Calling this after UMAP has imported will hang, pending https://github.com/lmcinnes/umap/issues/306
from typing import Optional, Union

from anndata import AnnData
Expand Down Expand Up @@ -36,6 +35,16 @@ def _choose_obs_rep(adata, *, use_raw=False, layer=None, obsm=None, obsp=None):
###############################################################################
# Calculation
###############################################################################
# Some notes on the implementation:
# * This could be phrased as tensor multiplication. However that does not get
# parallelized, which boosts performance almost linearly with cores.
# * Due to the umap setting the default threading backend, a parallel numba
# function that calls another parallel numba function can get stuck. This
# ends up meaning code re-use will be limited until umap 0.4.
# See: https://github.com/lmcinnes/umap/issues/306
# * There can be a fair amount of numerical instability here (big reductions),
# so data is cast to float64. Removing these casts/ conversion will cause the
# tests to fail.


@numba.njit(cache=True, parallel=True)
Expand All @@ -48,6 +57,7 @@ def _gearys_c_vec(data, indices, indptr, x):
def _gearys_c_vec_W(data, indices, indptr, x, W):
N = len(indptr) - 1
x_bar = x.mean()
x = x.astype(np.float_)

total = 0.0
for i in numba.prange(N):
Expand All @@ -69,78 +79,143 @@ def _gearys_c_mtx_csr(
):
M, N = x_shape
W = g_data.sum()
out = np.zeros(M, dtype=np.float64)
out = np.zeros(M, dtype=np.float_)
for k in numba.prange(M):
x_arr = np.zeros(N, dtype=x_data.dtype)
x_k = np.zeros(N, dtype=np.float_)
sk = slice(x_indptr[k], x_indptr[k + 1])
x_arr[x_indices[sk]] = x_data[sk]
outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, x_arr, W)
out[k] = outval
x_k_data = x_data[sk]
x_k[x_indices[sk]] = x_k_data
x_k_bar = np.sum(x_k_data) / N
total = 0.0
for i in numba.prange(N):
s = slice(g_indptr[i], g_indptr[i + 1])
i_indices = g_indices[s]
i_data = g_data[s]
total += np.sum(i_data * ((x_k[i] - x_k[i_indices]) ** 2))
numer = (N - 1) * total
# Expanded from 2 * W * ((x_k - x_k_bar) ** 2).sum(), but uses sparsity
# to skip some calculations
denom = (
2
* W
* (
np.sum(x_k_data ** 2)
- np.sum(x_k_data * x_k_bar * 2)
+ (x_k_bar ** 2) * N
)
)
C = numer / denom
out[k] = C
return out


# Simplified implementation, hits race condition after umap import due to numba
# parallel backend
# @numba.njit(cache=True, parallel=True)
# def _gearys_c_mtx_csr(
# g_data, g_indices, g_indptr, x_data, x_indices, x_indptr, x_shape
# ):
# M, N = x_shape
# W = g_data.sum()
# out = np.zeros(M, dtype=np.float64)
# for k in numba.prange(M):
# x_arr = np.zeros(N, dtype=x_data.dtype)
# sk = slice(x_indptr[k], x_indptr[k + 1])
# x_arr[x_indices[sk]] = x_data[sk]
# outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, x_arr, W)
# out[k] = outval
# return out


@numba.njit(cache=True, parallel=True)
def _gearys_c_mtx(g_data, g_indices, g_indptr, X):
M, N = X.shape
W = g_data.sum()
out = np.zeros(M, dtype=np.float64)
out = np.zeros(M, dtype=np.float_)
for k in numba.prange(M):
outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, X[k, :], W)
out[k] = outval
x = X[k, :].astype(np.float_)
x_bar = x.mean()

total = 0.0
for i in numba.prange(N):
s = slice(g_indptr[i], g_indptr[i + 1])
i_indices = g_indices[s]
i_data = g_data[s]
total += np.sum(i_data * ((x[i] - x[i_indices]) ** 2))

numer = (N - 1) * total
denom = 2 * W * ((x - x_bar) ** 2).sum()
C = numer / denom

out[k] = C
return out


# Similar to above, simplified version umaps choice of parallel backend breaks:
# @numba.njit(cache=True, parallel=True)
# def _gearys_c_mtx(g_data, g_indices, g_indptr, X):
# M, N = X.shape
# W = g_data.sum()
# out = np.zeros(M, dtype=np.float64)
# for k in numba.prange(M):
# outval = _gearys_c_vec_W(g_data, g_indices, g_indptr, X[k, :], W)
# out[k] = outval
# return out


###############################################################################
# Interface
###############################################################################


@dispatch(sparse.csr_matrix, sparse.csr_matrix)
def gearys_c(g, vals) -> np.ndarray:
assert g.shape[0] == g.shape[1]
assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix"
assert g.shape[0] == vals.shape[1]
return _gearys_c_mtx_csr(
g.data,
g.data.astype(np.float_, copy=False),
g.indices,
g.indptr,
vals.data,
vals.data.astype(np.float_, copy=False),
vals.indices,
vals.indptr,
vals.shape,
)


@dispatch(sparse.spmatrix, np.ndarray)
@dispatch(sparse.spmatrix, np.ndarray) # noqa
def gearys_c(g, vals):
assert g.shape[0] == g.shape[1], "`g` should be a square matrix."
if not isinstance(g, sparse.csr_matrix):
g = g.tocsr()
g_data = g.data.astype(np.float_, copy=False)
if vals.ndim == 1:
assert g.shape[0] == vals.shape[0]
return _gearys_c_vec(g.data, g.indices, g.indptr, vals)
return _gearys_c_vec(g_data, g.indices, g.indptr, vals)
elif vals.ndim == 2:
assert g.shape[0] == vals.shape[1]
return _gearys_c_mtx(g.data, g.indices, g.indptr, vals)
return _gearys_c_mtx(g_data, g.indices, g.indptr, vals)
else:
raise ValueError()


@dispatch(sparse.spmatrix, (pd.DataFrame, pd.Series))
@dispatch(sparse.spmatrix, (pd.DataFrame, pd.Series)) # noqa
def gearys_c(g, vals):
return gearys_c(g, vals.values)

@dispatch(sparse.spmatrix, sparse.spmatrix)

@dispatch(sparse.spmatrix, sparse.spmatrix) # noqa
def gearys_c(g, vals) -> np.ndarray:
if not isinstance(g, sparse.csr_matrix):
g = g.tocsr()
if not isinstance(vals, sparse.csc_matrix):
if not isinstance(vals, sparse.csr_matrix):
vals = vals.tocsr()
return gearys_c(g, vals)


# TODO: Document better
# TODO: Have scanpydoc work with multipledispatch
@dispatch(AnnData)
@dispatch(AnnData) # noqa
def gearys_c(
adata: AnnData,
*,
Expand All @@ -156,12 +231,12 @@ def gearys_c(
by `VISION <https://doi.org/10.1038/s41467-019-12235-0>`_.
Geary's C is a measure of autocorrelation for some measure on a graph. This
can be to whether measures are correlated between neighboring cells. Lower
can be to whether measures are correlated between neighboring cells. Lower
values indicate greater correlation.
..math
C =
C =
\frac{
(N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2
}{
Expand Down Expand Up @@ -191,7 +266,8 @@ def gearys_c(
Returns
-------
If vals is two dimensional, returns a 1 dimensional ndarray array. Returns a scalar if `vals` is 1d.
If vals is two dimensional, returns a 1 dimensional ndarray array. Returns
a scalar if `vals` is 1d.
"""
if use_graph is None:
if "connectivities" in adata.obsp:
Expand All @@ -203,5 +279,7 @@ def gearys_c(
else:
raise NotImplementedError()
if vals is None:
vals = _choose_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T
vals = _choose_obs_rep(
adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp
).T
return gearys_c(g, vals)
15 changes: 15 additions & 0 deletions scanpy/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def test_gearys_c():
sc.metrics.gearys_c(g, pbmc.obsm["X_pca"].T),
)

all_genes = sc.metrics.gearys_c(pbmc, layer="raw")
first_gene = sc.metrics.gearys_c(
pbmc,
vals=pbmc.obs_vector(pbmc.var_names[0], layer="raw")
)

assert np.allclose(all_genes[0], first_gene)

assert np.allclose(
sc.metrics.gearys_c(pbmc, layer="raw"),
sc.metrics.gearys_c(pbmc, vals=pbmc.layers["raw"].T.toarray())
)

# Test case with perfectly seperated groups
connected = np.zeros(100)
connected[np.random.choice(100, size=30, replace=False)] = 1
Expand All @@ -36,6 +49,8 @@ def test_gearys_c():
graph = sparse.csr_matrix(graph)

assert sc.metrics.gearys_c(graph, connected) == 0.0
assert sc.metrics.gearys_c(graph, connected) \
== sc.metrics.gearys_c(graph, sparse.csr_matrix(connected))
adata = sc.AnnData(
sparse.csr_matrix((100, 100)), obsp={"connectivities": graph}
)
Expand Down

0 comments on commit 4d344dc

Please sign in to comment.