diff --git a/docs/release-notes/1.8.0.rst b/docs/release-notes/1.8.0.rst index b98676558a..d7d709ecd0 100644 --- a/docs/release-notes/1.8.0.rst +++ b/docs/release-notes/1.8.0.rst @@ -5,6 +5,7 @@ - Switched to flit_ for building and deploying the package, a simple tool with an easy to understand command line interface and metadata. +- Added `layer` and `copy` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup` .. _flit: https://flit.readthedocs.io/en/latest/ @@ -15,3 +16,5 @@ .. rubric:: Bug fixes .. rubric:: Deprecations + +- Deprecated `layers` and `layer_norm` kwargs to :func:`~scanpy.pp.normalize_total` :pr:`1667` :smaller:`I Virshup` diff --git a/scanpy/preprocessing/_normalization.py b/scanpy/preprocessing/_normalization.py index 2d745634f6..542afd38ca 100644 --- a/scanpy/preprocessing/_normalization.py +++ b/scanpy/preprocessing/_normalization.py @@ -1,4 +1,5 @@ from typing import Optional, Union, Iterable, Dict +from warnings import warn import numpy as np from anndata import AnnData @@ -8,6 +9,7 @@ from .. import logging as logg from .._compat import Literal from .._utils import view_to_actual +from scanpy.get import _get_obs_rep, _set_obs_rep def _normalize_data(X, counts, after=None, copy=False): @@ -31,9 +33,11 @@ def normalize_total( exclude_highly_expressed: bool = False, max_fraction: float = 0.05, key_added: Optional[str] = None, + layer: Optional[str] = None, layers: Union[Literal['all'], Iterable[str]] = None, layer_norm: Optional[str] = None, inplace: bool = True, + copy: bool = False, ) -> Optional[Dict[str, np.ndarray]]: """\ Normalize counts per cell. @@ -72,23 +76,13 @@ def normalize_total( key_added Name of the field in `adata.obs` where the normalization factor is stored. - layers - List of layers to normalize. Set to `'all'` to normalize all layers. 
- layer_norm - Specifies how to normalize layers: - - * If `None`, after normalization, for each layer in *layers* each cell - has a total count equal to the median of the *counts_per_cell* before - normalization of the layer. - * If `'after'`, for each layer in *layers* each cell has - a total count equal to `target_sum`. - * If `'X'`, for each layer in *layers* each cell has a total count - equal to the median of total counts for observations (cells) of - `adata.X` before normalization. - + layer + Layer to normalize instead of `X`. If `None`, `X` is normalized. inplace Whether to update `adata` or return dictionary with normalized copies of `adata.X` and `adata.layers`. + copy + Whether to modify copied input object. Not compatible with inplace=False. Returns ------- @@ -127,9 +121,30 @@ def normalize_total( [ 0.5, 0.5, 0.5, 1. , 1. ], [ 0.5, 11. , 0.5, 1. , 1. ]], dtype=float32) """ + if copy: + if not inplace: + raise ValueError("`copy=True` cannot be used with `inplace=False`.") + adata = adata.copy() + if max_fraction < 0 or max_fraction > 1: raise ValueError('Choose max_fraction between 0 and 1.') + # Deprecated features + if layers is not None: + warn( + FutureWarning( + "The `layers` argument is deprecated. Instead, specify individual " + "layers to normalize with `layer`." + ) + ) + if layer_norm is not None: + warn( + FutureWarning( + "The `layer_norm` argument is deprecated. Specify the target size " + "factor directly with `target_sum`." 
+ ) + ) + if layers == 'all': layers = adata.layers.keys() elif isinstance(layers, str): @@ -139,32 +154,47 @@ def normalize_total( view_to_actual(adata) + X = _get_obs_rep(adata, layer=layer) + gene_subset = None msg = 'normalizing counts per cell' if exclude_highly_expressed: - counts_per_cell = adata.X.sum(1) # original counts per cell + counts_per_cell = X.sum(1) # original counts per cell counts_per_cell = np.ravel(counts_per_cell) # at least one cell as more than max_fraction of counts per cell - gene_subset = (adata.X > counts_per_cell[:, None] * max_fraction).sum(0) + + gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0) gene_subset = np.ravel(gene_subset) == 0 msg += ( ' The following highly-expressed genes are not considered during ' f'normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}' ) + counts_per_cell = X[:, gene_subset].sum(1) + else: + counts_per_cell = X.sum(1) start = logg.info(msg) - - # counts per cell for subset, if max_fraction!=1 - X = adata.X if gene_subset is None else adata[:, gene_subset].X - counts_per_cell = X.sum(1) - # get rid of adata view - counts_per_cell = np.ravel(counts_per_cell).copy() + counts_per_cell = np.ravel(counts_per_cell) cell_subset = counts_per_cell > 0 if not np.all(cell_subset): - logg.warning('Some cells have total count of genes equal to zero') + warn(UserWarning('Some cells have zero counts')) + if inplace: + if key_added is not None: + adata.obs[key_added] = counts_per_cell + _set_obs_rep( + adata, _normalize_data(X, counts_per_cell, target_sum), layer=layer + ) + else: + # not recarray because need to support sparse + dat = dict( + X=_normalize_data(X, counts_per_cell, target_sum, copy=True), + norm_factor=counts_per_cell, + ) + + # Deprecated features if layer_norm == 'after': after = target_sum elif layer_norm == 'X': @@ -173,26 +203,13 @@ def normalize_total( after = None else: raise ValueError('layer_norm should be "after", "X" or None') - del cell_subset - if 
inplace: - if key_added is not None: - adata.obs[key_added] = counts_per_cell - adata.X = _normalize_data(adata.X, counts_per_cell, target_sum) - else: - # not recarray because need to support sparse - dat = dict( - X=_normalize_data(adata.X, counts_per_cell, target_sum, copy=True), - norm_factor=counts_per_cell, + for layer_to_norm in layers if layers is not None else (): + res = normalize_total( + adata, layer=layer_to_norm, target_sum=after, inplace=inplace ) - - for layer_name in layers or (): - layer = adata.layers[layer_name] - counts = np.ravel(layer.sum(1)) - if inplace: - adata.layers[layer_name] = _normalize_data(layer, counts, after) - else: - dat[layer_name] = _normalize_data(layer, counts, after, copy=True) + if not inplace: + dat[layer_to_norm] = res["X"] logg.info( ' finished ({time_passed})', @@ -203,4 +220,7 @@ def normalize_total( f'and added {key_added!r}, counts per cell before normalization (adata.obs)' ) - return dat if not inplace else None + if copy: + return adata + elif not inplace: + return dat diff --git a/scanpy/tests/helpers.py b/scanpy/tests/helpers.py index 35253bad69..61fc35e23e 100644 --- a/scanpy/tests/helpers.py +++ b/scanpy/tests/helpers.py @@ -2,68 +2,84 @@ This file contains helper functions for the scanpy test suite. """ +from itertools import permutations + import scanpy as sc import numpy as np from anndata.tests.helpers import asarray, assert_equal +# TODO: Report more context on the fields being compared on error +# TODO: Allow specifying paths to ignore on comparison + ########################### # Representation choice ########################### # These functions can be used to check that functions are correctly using arugments like `layers`, `obsm`, etc. 
-def check_rep_mutation(func, X, **kwargs): +def check_rep_mutation(func, X, *, fields=["layer", "obsm"], **kwargs): """Check that only the array meant to be modified is modified.""" - adata = sc.AnnData( - X=X.copy(), - layers={"layer": X.copy()}, - obsm={"obsm": X.copy()}, - dtype=X.dtype, - ) + adata = sc.AnnData(X=X.copy(), dtype=X.dtype) + for field in fields: + sc.get._set_obs_rep(adata, X, **{field: field}) + X_array = asarray(X) + adata_X = func(adata, copy=True, **kwargs) - adata_layer = func(adata, layer="layer", copy=True, **kwargs) - adata_obsm = func(adata, obsm="obsm", copy=True, **kwargs) + adatas_proc = { + field: func(adata, copy=True, **{field: field}, **kwargs) for field in fields + } - assert np.array_equal(asarray(adata_X.X), asarray(adata_layer.layers["layer"])) - assert np.array_equal(asarray(adata_X.X), asarray(adata_obsm.obsm["obsm"])) + # Modified fields + for field in fields: + result_array = asarray( + sc.get._get_obs_rep(adatas_proc[field], **{field: field}) + ) + np.testing.assert_array_equal(asarray(adata_X.X), result_array) - assert np.array_equal(asarray(adata_layer.X), asarray(adata_layer.obsm["obsm"])) - assert np.array_equal(asarray(adata_obsm.X), asarray(adata_obsm.layers["layer"])) - assert np.array_equal( - asarray(adata_X.layers["layer"]), asarray(adata_X.obsm["obsm"]) - ) + # Unmodified fields + for field in fields: + np.testing.assert_array_equal(X_array, asarray(adatas_proc[field].X)) + np.testing.assert_array_equal( + X_array, asarray(sc.get._get_obs_rep(adata_X, **{field: field})) + ) + for field_a, field_b in permutations(fields, 2): + result_array = asarray( + sc.get._get_obs_rep(adatas_proc[field_a], **{field_b: field_b}) + ) + np.testing.assert_array_equal(X_array, result_array) -def check_rep_results(func, X, **kwargs): +def check_rep_results(func, X, *, fields=["layer", "obsm"], **kwargs): """Checks that the results of a computation add values/ mutate the anndata object in a consistent way.""" # Gen data - adata_X 
= sc.AnnData( - X=X.copy(), - layers={"layer": np.zeros(shape=X.shape, dtype=X.dtype)}, - obsm={"obsm": np.zeros(shape=X.shape, dtype=X.dtype)}, - ) - adata_layer = sc.AnnData( - X=np.zeros(shape=X.shape, dtype=X.dtype), - layers={"layer": X.copy()}, - obsm={"obsm": np.zeros(shape=X.shape, dtype=X.dtype)}, - ) - adata_obsm = sc.AnnData( - X=np.zeros(shape=X.shape, dtype=X.dtype), - layers={"layer": np.zeros(shape=X.shape, dtype=X.dtype)}, - obsm={"obsm": X.copy()}, + empty_X = np.zeros(shape=X.shape, dtype=X.dtype) + adata = sc.AnnData( + X=empty_X.copy(), + layers={"layer": empty_X.copy()}, + obsm={"obsm": empty_X.copy()}, ) + adata_X = adata.copy() + adata_X.X = X.copy() + + adatas_proc = {} + for field in fields: + cur = adata.copy() + sc.get._set_obs_rep(cur, X.copy(), **{field: field}) + adatas_proc[field] = cur + # Apply function func(adata_X, **kwargs) - func(adata_layer, layer="layer", **kwargs) - func(adata_obsm, obsm="obsm", **kwargs) + for field in fields: + func(adatas_proc[field], **{field: field}, **kwargs) # Reset X - adata_X.X = np.zeros(shape=X.shape, dtype=X.dtype) - adata_layer.layers["layer"] = np.zeros(shape=X.shape, dtype=X.dtype) - adata_obsm.obsm["obsm"] = np.zeros(shape=X.shape, dtype=X.dtype) + adata_X.X = empty_X.copy() + for field in fields: + sc.get._set_obs_rep(adatas_proc[field], empty_X.copy(), **{field: field}) - # Check equality - assert_equal(adata_X, adata_layer) - assert_equal(adata_X, adata_obsm) + for field_a, field_b in permutations(fields, 2): + assert_equal(adatas_proc[field_a], adatas_proc[field_b]) + for field in fields: + assert_equal(adata_X, adatas_proc[field]) diff --git a/scanpy/tests/test_normalization.py b/scanpy/tests/test_normalization.py index 9b84699d9a..0f5dbb102d 100644 --- a/scanpy/tests/test_normalization.py +++ b/scanpy/tests/test_normalization.py @@ -2,9 +2,11 @@ import numpy as np from anndata import AnnData from scipy.sparse import csr_matrix +from scipy import sparse import scanpy as sc -from 
anndata.tests.helpers import assert_equal +from scanpy.tests.helpers import check_rep_mutation, check_rep_results +from anndata.tests.helpers import assert_equal, asarray X_total = [[1, 0], [3, 0], [5, 6]] X_frac = [[1, 0, 1], [3, 0, 1], [5, 6, 1]] @@ -24,12 +26,22 @@ def test_normalize_total(typ, dtype): assert np.allclose(np.ravel(adata.X[:, 1:3].sum(axis=1)), [1.0, 1.0, 1.0]) +@pytest.mark.parametrize('typ', [asarray, csr_matrix], ids=lambda x: x.__name__) +@pytest.mark.parametrize('dtype', ['float32', 'int64']) +def test_normalize_total_rep(typ, dtype): + # Test that layer kwarg works + X = typ(sparse.random(100, 50, format="csr", density=0.2, dtype=dtype)) + check_rep_mutation(sc.pp.normalize_total, X, fields=["layer"]) + check_rep_results(sc.pp.normalize_total, X, fields=["layer"]) + + @pytest.mark.parametrize('typ', [np.array, csr_matrix], ids=lambda x: x.__name__) @pytest.mark.parametrize('dtype', ['float32', 'int64']) def test_normalize_total_layers(typ, dtype): adata = AnnData(typ(X_total), dtype=dtype) adata.layers["layer"] = adata.X.copy() - sc.pp.normalize_total(adata, layers=["layer"]) + with pytest.warns(FutureWarning, match=r".*layers.*deprecated"): + sc.pp.normalize_total(adata, layers=["layer"]) assert np.allclose(adata.layers["layer"].sum(axis=1), [3.0, 3.0, 3.0])