Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup normalize_total #1667

Merged
merged 6 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/release-notes/1.8.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

- Switched to flit_ for building and deploying the package,
a simple tool with an easy to understand command line interface and metadata.
- Added `layer` and `copy` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup`

.. _flit: https://flit.readthedocs.io/en/latest/

Expand All @@ -15,3 +16,5 @@
.. rubric:: Bug fixes

.. rubric:: Deprecations

- Deprecated `layers` and `layers_norm` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup`
104 changes: 62 additions & 42 deletions scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Union, Iterable, Dict
from warnings import warn

import numpy as np
from anndata import AnnData
Expand All @@ -8,6 +9,7 @@
from .. import logging as logg
from .._compat import Literal
from .._utils import view_to_actual
from scanpy.get import _get_obs_rep, _set_obs_rep


def _normalize_data(X, counts, after=None, copy=False):
Expand All @@ -31,9 +33,11 @@ def normalize_total(
exclude_highly_expressed: bool = False,
max_fraction: float = 0.05,
key_added: Optional[str] = None,
layer: Optional[str] = None,
layers: Union[Literal['all'], Iterable[str]] = None,
layer_norm: Optional[str] = None,
inplace: bool = True,
copy: bool = False,
) -> Optional[Dict[str, np.ndarray]]:
"""\
Normalize counts per cell.
Expand Down Expand Up @@ -72,23 +76,13 @@ def normalize_total(
key_added
Name of the field in `adata.obs` where the normalization factor is
stored.
layers
List of layers to normalize. Set to `'all'` to normalize all layers.
layer_norm
Specifies how to normalize layers:

* If `None`, after normalization, for each layer in *layers* each cell
has a total count equal to the median of the *counts_per_cell* before
normalization of the layer.
* If `'after'`, for each layer in *layers* each cell has
a total count equal to `target_sum`.
* If `'X'`, for each layer in *layers* each cell has a total count
equal to the median of total counts for observations (cells) of
`adata.X` before normalization.

layer
Layer to normalize instead of `X`. If `None`, `X` is normalized.
inplace
Whether to update `adata` or return dictionary with normalized copies of
`adata.X` and `adata.layers`.
copy
Whether to operate on a copy of the input object and return that copy instead of modifying `adata`. Not compatible with `inplace=False`.

Returns
-------
Expand Down Expand Up @@ -127,9 +121,30 @@ def normalize_total(
[ 0.5, 0.5, 0.5, 1. , 1. ],
[ 0.5, 11. , 0.5, 1. , 1. ]], dtype=float32)
"""
if copy:
if not inplace:
raise ValueError("`copy=True` cannot be used with `inplace=False`.")
adata = adata.copy()

if max_fraction < 0 or max_fraction > 1:
raise ValueError('Choose max_fraction between 0 and 1.')

# Deprecated features
if layers is not None:
warn(
FutureWarning(
"The `layers` argument is deprecated. Instead, specify individual "
"layers to normalize with `layer`."
)
)
if layer_norm is not None:
warn(
FutureWarning(
"The `layer_norm` argument is deprecated. Specify the target size "
"factor directly with `target_sum`."
)
)

if layers == 'all':
layers = adata.layers.keys()
elif isinstance(layers, str):
Expand All @@ -139,32 +154,47 @@ def normalize_total(

view_to_actual(adata)

X = _get_obs_rep(adata, layer=layer)

gene_subset = None
msg = 'normalizing counts per cell'
if exclude_highly_expressed:
counts_per_cell = adata.X.sum(1) # original counts per cell
counts_per_cell = X.sum(1) # original counts per cell
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell has more than max_fraction of counts per cell
gene_subset = (adata.X > counts_per_cell[:, None] * max_fraction).sum(0)

gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0)
gene_subset = np.ravel(gene_subset) == 0

msg += (
' The following highly-expressed genes are not considered during '
f'normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}'
)
counts_per_cell = X[:, gene_subset].sum(1)
else:
counts_per_cell = X.sum(1)
start = logg.info(msg)

# counts per cell for subset, if max_fraction!=1
X = adata.X if gene_subset is None else adata[:, gene_subset].X
counts_per_cell = X.sum(1)
# get rid of adata view
counts_per_cell = np.ravel(counts_per_cell).copy()
counts_per_cell = np.ravel(counts_per_cell)

cell_subset = counts_per_cell > 0
if not np.all(cell_subset):
logg.warning('Some cells have total count of genes equal to zero')
warn(UserWarning('Some cells have zero counts'))

if inplace:
if key_added is not None:
adata.obs[key_added] = counts_per_cell
_set_obs_rep(
adata, _normalize_data(X, counts_per_cell, target_sum), layer=layer
)
else:
# not recarray because need to support sparse
dat = dict(
X=_normalize_data(X, counts_per_cell, target_sum, copy=True),
norm_factor=counts_per_cell,
)

# Deprecated features
if layer_norm == 'after':
after = target_sum
elif layer_norm == 'X':
Expand All @@ -173,26 +203,13 @@ def normalize_total(
after = None
else:
raise ValueError('layer_norm should be "after", "X" or None')
del cell_subset

if inplace:
if key_added is not None:
adata.obs[key_added] = counts_per_cell
adata.X = _normalize_data(adata.X, counts_per_cell, target_sum)
else:
# not recarray because need to support sparse
dat = dict(
X=_normalize_data(adata.X, counts_per_cell, target_sum, copy=True),
norm_factor=counts_per_cell,
for layer_to_norm in layers if layers is not None else ():
res = normalize_total(
adata, layer=layer_to_norm, target_sum=after, inplace=inplace
)

for layer_name in layers or ():
layer = adata.layers[layer_name]
counts = np.ravel(layer.sum(1))
if inplace:
adata.layers[layer_name] = _normalize_data(layer, counts, after)
else:
dat[layer_name] = _normalize_data(layer, counts, after, copy=True)
if not inplace:
dat[layer_to_norm] = res["X"]

logg.info(
' finished ({time_passed})',
Expand All @@ -203,4 +220,7 @@ def normalize_total(
f'and added {key_added!r}, counts per cell before normalization (adata.obs)'
)

return dat if not inplace else None
if copy:
return adata
elif not inplace:
return dat
94 changes: 55 additions & 39 deletions scanpy/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,84 @@
This file contains helper functions for the scanpy test suite.
"""

from itertools import permutations

import scanpy as sc
import numpy as np

from anndata.tests.helpers import asarray, assert_equal

# TODO: Report more context on the fields being compared on error
# TODO: Allow specifying paths to ignore on comparison

###########################
# Representation choice
###########################
# These functions can be used to check that functions are correctly using arguments like `layers`, `obsm`, etc.


def check_rep_mutation(func, X, *, fields=("layer", "obsm"), **kwargs):
    """Check that only the array meant to be modified is modified.

    Parameters
    ----------
    func
        Function under test. It is called as ``func(adata, copy=True, **kwargs)``
        and, once per entry of `fields`, as
        ``func(adata, copy=True, <field>=<field>, **kwargs)``
        (e.g. ``layer="layer"``). The returned object is inspected.
    X
        Data stored in `.X` and in every representation named in `fields`.
    fields
        Names of the representation keyword arguments to exercise. Each name
        is used both as the keyword and as the key under which `X` is stored.
        An immutable default is used so it cannot be shared and mutated
        between calls.
    **kwargs
        Extra keyword arguments forwarded to every call of `func`.
    """
    adata = sc.AnnData(X=X.copy(), dtype=X.dtype)
    for field in fields:
        # Store an independent copy per field so the representations do not
        # alias each other (or the caller's `X`).
        sc.get._set_obs_rep(adata, X.copy(), **{field: field})
    X_array = asarray(X)

    adata_X = func(adata, copy=True, **kwargs)
    adatas_proc = {
        field: func(adata, copy=True, **{field: field}, **kwargs) for field in fields
    }

    # Modified fields: processing a field must give the same result as
    # processing `.X`.
    for field in fields:
        result_array = asarray(
            sc.get._get_obs_rep(adatas_proc[field], **{field: field})
        )
        np.testing.assert_array_equal(asarray(adata_X.X), result_array)

    # Unmodified fields: everything that was not the target must still equal
    # the input data.
    for field in fields:
        np.testing.assert_array_equal(X_array, asarray(adatas_proc[field].X))
        np.testing.assert_array_equal(
            X_array, asarray(sc.get._get_obs_rep(adata_X, **{field: field}))
        )
    for field_a, field_b in permutations(fields, 2):
        result_array = asarray(
            sc.get._get_obs_rep(adatas_proc[field_a], **{field_b: field_b})
        )
        np.testing.assert_array_equal(X_array, result_array)


def check_rep_results(func, X, *, fields=("layer", "obsm"), **kwargs):
    """Check that `func` adds values to / mutates the AnnData object consistently.

    One object holds `X` in `.X`; for each entry of `fields` another object
    holds `X` in that representation instead, with zeros everywhere else.
    `func` is applied in place to each object with the matching keyword, the
    slot that held the data is reset to zeros, and all objects are then
    required to compare equal.

    Parameters
    ----------
    func
        Function under test. It is called as ``func(adata, **kwargs)`` and,
        once per entry of `fields`, as ``func(adata, <field>=<field>, **kwargs)``.
    X
        Data placed in the representation being processed.
    fields
        Names of the representation keyword arguments to exercise. Each name
        is used both as the keyword and as the key of the representation.
        An immutable default is used so it cannot be shared and mutated
        between calls.
    **kwargs
        Extra keyword arguments forwarded to every call of `func`.
    """
    # Gen data
    empty_X = np.zeros(shape=X.shape, dtype=X.dtype)
    adata = sc.AnnData(
        X=empty_X.copy(),
        layers={"layer": empty_X.copy()},
        obsm={"obsm": empty_X.copy()},
    )

    adata_X = adata.copy()
    adata_X.X = X.copy()

    adatas_proc = {}
    for field in fields:
        cur = adata.copy()
        sc.get._set_obs_rep(cur, X.copy(), **{field: field})
        adatas_proc[field] = cur

    # Apply function
    func(adata_X, **kwargs)
    for field in fields:
        func(adatas_proc[field], **{field: field}, **kwargs)

    # Reset the slot that held the data, so only what `func` added or changed
    # elsewhere on the object is compared.
    adata_X.X = empty_X.copy()
    for field in fields:
        sc.get._set_obs_rep(adatas_proc[field], empty_X.copy(), **{field: field})

    # Check equality
    for field_a, field_b in permutations(fields, 2):
        assert_equal(adatas_proc[field_a], adatas_proc[field_b])
    for field in fields:
        assert_equal(adata_X, adatas_proc[field])
16 changes: 14 additions & 2 deletions scanpy/tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix
from scipy import sparse

import scanpy as sc
from anndata.tests.helpers import assert_equal
from scanpy.tests.helpers import check_rep_mutation, check_rep_results
from anndata.tests.helpers import assert_equal, asarray

X_total = [[1, 0], [3, 0], [5, 6]]
X_frac = [[1, 0, 1], [3, 0, 1], [5, 6, 1]]
Expand All @@ -24,12 +26,22 @@ def test_normalize_total(typ, dtype):
assert np.allclose(np.ravel(adata.X[:, 1:3].sum(axis=1)), [1.0, 1.0, 1.0])


@pytest.mark.parametrize('typ', [asarray, csr_matrix], ids=lambda x: x.__name__)
@pytest.mark.parametrize('dtype', ['float32', 'int64'])
def test_normalize_total_rep(typ, dtype):
    """Run the representation checks on `normalize_total` with the `layer` kwarg."""
    random_counts = sparse.random(100, 50, format="csr", density=0.2, dtype=dtype)
    X = typ(random_counts)
    # Mutation check first, then result-consistency check, both limited to `layer`.
    for check in (check_rep_mutation, check_rep_results):
        check(sc.pp.normalize_total, X, fields=["layer"])


@pytest.mark.parametrize('typ', [np.array, csr_matrix], ids=lambda x: x.__name__)
@pytest.mark.parametrize('dtype', ['float32', 'int64'])
def test_normalize_total_layers(typ, dtype):
    """The deprecated `layers` argument emits a FutureWarning and the layer's row sums are checked."""
    adata = AnnData(typ(X_total), dtype=dtype)
    adata.layers["layer"] = adata.X.copy()

    expect_deprecation = pytest.warns(FutureWarning, match=r".*layers.*deprecated")
    with expect_deprecation:
        sc.pp.normalize_total(adata, layers=["layer"])

    layer_totals = adata.layers["layer"].sum(axis=1)
    assert np.allclose(layer_totals, [3.0, 3.0, 3.0])


Expand Down