Skip to content

Commit

Permalink
Cleanup normalize_total (#1667)
Browse files Browse the repository at this point in the history
* Cleanup normalize_total

* Add modification tests and copy kwarg for normalize_total

* Test that 'layers' argument is deprecated

* Added more mutation checks for normalize_total

* release note

* Error message
  • Loading branch information
ivirshup authored and Zethson committed Mar 15, 2021
1 parent c11c486 commit f637c08
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 83 deletions.
3 changes: 3 additions & 0 deletions docs/release-notes/1.8.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

- Switched to flit_ for building and deploying the package,
a simple tool with an easy to understand command line interface and metadata.
- Added `layer` and `copy` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup`

.. _flit: https://flit.readthedocs.io/en/latest/

Expand All @@ -22,3 +23,5 @@
.. rubric:: Bug fixes

.. rubric:: Deprecations

- Deprecated `layers` and `layers_norm` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup`
104 changes: 62 additions & 42 deletions scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Union, Iterable, Dict
from warnings import warn

import numpy as np
from anndata import AnnData
Expand All @@ -8,6 +9,7 @@
from .. import logging as logg
from .._compat import Literal
from .._utils import view_to_actual
from scanpy.get import _get_obs_rep, _set_obs_rep


def _normalize_data(X, counts, after=None, copy=False):
Expand All @@ -31,9 +33,11 @@ def normalize_total(
exclude_highly_expressed: bool = False,
max_fraction: float = 0.05,
key_added: Optional[str] = None,
layer: Optional[str] = None,
layers: Union[Literal['all'], Iterable[str]] = None,
layer_norm: Optional[str] = None,
inplace: bool = True,
copy: bool = False,
) -> Optional[Dict[str, np.ndarray]]:
"""\
Normalize counts per cell.
Expand Down Expand Up @@ -72,23 +76,13 @@ def normalize_total(
key_added
Name of the field in `adata.obs` where the normalization factor is
stored.
layers
List of layers to normalize. Set to `'all'` to normalize all layers.
layer_norm
Specifies how to normalize layers:
* If `None`, after normalization, for each layer in *layers* each cell
has a total count equal to the median of the *counts_per_cell* before
normalization of the layer.
* If `'after'`, for each layer in *layers* each cell has
a total count equal to `target_sum`.
* If `'X'`, for each layer in *layers* each cell has a total count
equal to the median of total counts for observations (cells) of
`adata.X` before normalization.
layer
Layer to normalize instead of `X`. If `None`, `X` is normalized.
inplace
Whether to update `adata` or return dictionary with normalized copies of
`adata.X` and `adata.layers`.
copy
Whether to modify copied input object. Not compatible with inplace=False.
Returns
-------
Expand Down Expand Up @@ -127,9 +121,30 @@ def normalize_total(
[ 0.5, 0.5, 0.5, 1. , 1. ],
[ 0.5, 11. , 0.5, 1. , 1. ]], dtype=float32)
"""
if copy:
if not inplace:
raise ValueError("`copy=True` cannot be used with `inplace=False`.")
adata = adata.copy()

if max_fraction < 0 or max_fraction > 1:
raise ValueError('Choose max_fraction between 0 and 1.')

# Deprecated features
if layers is not None:
warn(
FutureWarning(
"The `layers` argument is deprecated. Instead, specify individual "
"layers to normalize with `layer`."
)
)
if layer_norm is not None:
warn(
FutureWarning(
"The `layer_norm` argument is deprecated. Specify the target size "
"factor directly with `target_sum`."
)
)

if layers == 'all':
layers = adata.layers.keys()
elif isinstance(layers, str):
Expand All @@ -139,32 +154,47 @@ def normalize_total(

view_to_actual(adata)

X = _get_obs_rep(adata, layer=layer)

gene_subset = None
msg = 'normalizing counts per cell'
if exclude_highly_expressed:
counts_per_cell = adata.X.sum(1) # original counts per cell
counts_per_cell = X.sum(1) # original counts per cell
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell as more than max_fraction of counts per cell
gene_subset = (adata.X > counts_per_cell[:, None] * max_fraction).sum(0)

gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0)
gene_subset = np.ravel(gene_subset) == 0

msg += (
' The following highly-expressed genes are not considered during '
f'normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}'
)
counts_per_cell = X[:, gene_subset].sum(1)
else:
counts_per_cell = X.sum(1)
start = logg.info(msg)

# counts per cell for subset, if max_fraction!=1
X = adata.X if gene_subset is None else adata[:, gene_subset].X
counts_per_cell = X.sum(1)
# get rid of adata view
counts_per_cell = np.ravel(counts_per_cell).copy()
counts_per_cell = np.ravel(counts_per_cell)

cell_subset = counts_per_cell > 0
if not np.all(cell_subset):
logg.warning('Some cells have total count of genes equal to zero')
warn(UserWarning('Some cells have zero counts'))

if inplace:
if key_added is not None:
adata.obs[key_added] = counts_per_cell
_set_obs_rep(
adata, _normalize_data(X, counts_per_cell, target_sum), layer=layer
)
else:
# not recarray because need to support sparse
dat = dict(
X=_normalize_data(X, counts_per_cell, target_sum, copy=True),
norm_factor=counts_per_cell,
)

# Deprecated features
if layer_norm == 'after':
after = target_sum
elif layer_norm == 'X':
Expand All @@ -173,26 +203,13 @@ def normalize_total(
after = None
else:
raise ValueError('layer_norm should be "after", "X" or None')
del cell_subset

if inplace:
if key_added is not None:
adata.obs[key_added] = counts_per_cell
adata.X = _normalize_data(adata.X, counts_per_cell, target_sum)
else:
# not recarray because need to support sparse
dat = dict(
X=_normalize_data(adata.X, counts_per_cell, target_sum, copy=True),
norm_factor=counts_per_cell,
for layer_to_norm in layers if layers is not None else ():
res = normalize_total(
adata, layer=layer_to_norm, target_sum=after, inplace=inplace
)

for layer_name in layers or ():
layer = adata.layers[layer_name]
counts = np.ravel(layer.sum(1))
if inplace:
adata.layers[layer_name] = _normalize_data(layer, counts, after)
else:
dat[layer_name] = _normalize_data(layer, counts, after, copy=True)
if not inplace:
dat[layer_to_norm] = res["X"]

logg.info(
' finished ({time_passed})',
Expand All @@ -203,4 +220,7 @@ def normalize_total(
f'and added {key_added!r}, counts per cell before normalization (adata.obs)'
)

return dat if not inplace else None
if copy:
return adata
elif not inplace:
return dat
94 changes: 55 additions & 39 deletions scanpy/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,84 @@
This file contains helper functions for the scanpy test suite.
"""

from itertools import permutations

import scanpy as sc
import numpy as np

from anndata.tests.helpers import asarray, assert_equal

# TODO: Report more context on the fields being compared on error
# TODO: Allow specifying paths to ignore on comparison

###########################
# Representation choice
###########################
# These functions can be used to check that functions are correctly using arugments like `layers`, `obsm`, etc.


def check_rep_mutation(func, X, **kwargs):
def check_rep_mutation(func, X, *, fields=["layer", "obsm"], **kwargs):
"""Check that only the array meant to be modified is modified."""
adata = sc.AnnData(
X=X.copy(),
layers={"layer": X.copy()},
obsm={"obsm": X.copy()},
dtype=X.dtype,
)
adata = sc.AnnData(X=X.copy(), dtype=X.dtype)
for field in fields:
sc.get._set_obs_rep(adata, X, **{field: field})
X_array = asarray(X)

adata_X = func(adata, copy=True, **kwargs)
adata_layer = func(adata, layer="layer", copy=True, **kwargs)
adata_obsm = func(adata, obsm="obsm", copy=True, **kwargs)
adatas_proc = {
field: func(adata, copy=True, **{field: field}, **kwargs) for field in fields
}

assert np.array_equal(asarray(adata_X.X), asarray(adata_layer.layers["layer"]))
assert np.array_equal(asarray(adata_X.X), asarray(adata_obsm.obsm["obsm"]))
# Modified fields
for field in fields:
result_array = asarray(
sc.get._get_obs_rep(adatas_proc[field], **{field: field})
)
np.testing.assert_array_equal(asarray(adata_X.X), result_array)

assert np.array_equal(asarray(adata_layer.X), asarray(adata_layer.obsm["obsm"]))
assert np.array_equal(asarray(adata_obsm.X), asarray(adata_obsm.layers["layer"]))
assert np.array_equal(
asarray(adata_X.layers["layer"]), asarray(adata_X.obsm["obsm"])
)
# Unmodified fields
for field in fields:
np.testing.assert_array_equal(X_array, asarray(adatas_proc[field].X))
np.testing.assert_array_equal(
X_array, asarray(sc.get._get_obs_rep(adata_X, **{field: field}))
)
for field_a, field_b in permutations(fields, 2):
result_array = asarray(
sc.get._get_obs_rep(adatas_proc[field_a], **{field_b: field_b})
)
np.testing.assert_array_equal(X_array, result_array)


def check_rep_results(func, X, **kwargs):
def check_rep_results(func, X, *, fields=["layer", "obsm"], **kwargs):
"""Checks that the results of a computation add values/ mutate the anndata object in a consistent way."""
# Gen data
adata_X = sc.AnnData(
X=X.copy(),
layers={"layer": np.zeros(shape=X.shape, dtype=X.dtype)},
obsm={"obsm": np.zeros(shape=X.shape, dtype=X.dtype)},
)
adata_layer = sc.AnnData(
X=np.zeros(shape=X.shape, dtype=X.dtype),
layers={"layer": X.copy()},
obsm={"obsm": np.zeros(shape=X.shape, dtype=X.dtype)},
)
adata_obsm = sc.AnnData(
X=np.zeros(shape=X.shape, dtype=X.dtype),
layers={"layer": np.zeros(shape=X.shape, dtype=X.dtype)},
obsm={"obsm": X.copy()},
empty_X = np.zeros(shape=X.shape, dtype=X.dtype)
adata = sc.AnnData(
X=empty_X.copy(),
layers={"layer": empty_X.copy()},
obsm={"obsm": empty_X.copy()},
)

adata_X = adata.copy()
adata_X.X = X.copy()

adatas_proc = {}
for field in fields:
cur = adata.copy()
sc.get._set_obs_rep(cur, X.copy(), **{field: field})
adatas_proc[field] = cur

# Apply function
func(adata_X, **kwargs)
func(adata_layer, layer="layer", **kwargs)
func(adata_obsm, obsm="obsm", **kwargs)
for field in fields:
func(adatas_proc[field], **{field: field}, **kwargs)

# Reset X
adata_X.X = np.zeros(shape=X.shape, dtype=X.dtype)
adata_layer.layers["layer"] = np.zeros(shape=X.shape, dtype=X.dtype)
adata_obsm.obsm["obsm"] = np.zeros(shape=X.shape, dtype=X.dtype)
adata_X.X = empty_X.copy()
for field in fields:
sc.get._set_obs_rep(adatas_proc[field], empty_X.copy(), **{field: field})

# Check equality
assert_equal(adata_X, adata_layer)
assert_equal(adata_X, adata_obsm)
for field_a, field_b in permutations(fields, 2):
assert_equal(adatas_proc[field_a], adatas_proc[field_b])
for field in fields:
assert_equal(adata_X, adatas_proc[field])
16 changes: 14 additions & 2 deletions scanpy/tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix
from scipy import sparse

import scanpy as sc
from anndata.tests.helpers import assert_equal
from scanpy.tests.helpers import check_rep_mutation, check_rep_results
from anndata.tests.helpers import assert_equal, asarray

X_total = [[1, 0], [3, 0], [5, 6]]
X_frac = [[1, 0, 1], [3, 0, 1], [5, 6, 1]]
Expand All @@ -24,12 +26,22 @@ def test_normalize_total(typ, dtype):
assert np.allclose(np.ravel(adata.X[:, 1:3].sum(axis=1)), [1.0, 1.0, 1.0])


@pytest.mark.parametrize('typ', [asarray, csr_matrix], ids=lambda x: x.__name__)
@pytest.mark.parametrize('dtype', ['float32', 'int64'])
def test_normalize_total_rep(typ, dtype):
# Test that layer kwarg works
X = typ(sparse.random(100, 50, format="csr", density=0.2, dtype=dtype))
check_rep_mutation(sc.pp.normalize_total, X, fields=["layer"])
check_rep_results(sc.pp.normalize_total, X, fields=["layer"])


@pytest.mark.parametrize('typ', [np.array, csr_matrix], ids=lambda x: x.__name__)
@pytest.mark.parametrize('dtype', ['float32', 'int64'])
def test_normalize_total_layers(typ, dtype):
adata = AnnData(typ(X_total), dtype=dtype)
adata.layers["layer"] = adata.X.copy()
sc.pp.normalize_total(adata, layers=["layer"])
with pytest.warns(FutureWarning, match=r".*layers.*deprecated"):
sc.pp.normalize_total(adata, layers=["layer"])
assert np.allclose(adata.layers["layer"].sum(axis=1), [3.0, 3.0, 3.0])


Expand Down

0 comments on commit f637c08

Please sign in to comment.