diff --git a/docs/analysis.rst b/docs/analysis.rst index 53fb42b00..fb2ac690a 100644 --- a/docs/analysis.rst +++ b/docs/analysis.rst @@ -150,7 +150,7 @@ The `~specutils.analysis.moment` function computes moments of any order: >>> from specutils.analysis import moment >>> moment(noisy_gaussian, SpectralRegion(7*u.GHz, 3*u.GHz)) # doctest:+FLOAT_CMP - + >>> moment(noisy_gaussian, SpectralRegion(7*u.GHz, 3*u.GHz), order=1) # doctest:+FLOAT_CMP >>> moment(noisy_gaussian, SpectralRegion(7*u.GHz, 3*u.GHz), order=2) # doctest:+FLOAT_CMP diff --git a/docs/spectral_cube.rst b/docs/spectral_cube.rst index 8a8d4bd8b..e5911ccc5 100644 --- a/docs/spectral_cube.rst +++ b/docs/spectral_cube.rst @@ -115,9 +115,9 @@ along the spectral axis (remember that the spectral axis is always last in a >>> m.shape # doctest: +REMOTE_DATA (74, 74) >>> m[30:33,30:33] # doctest: +REMOTE_DATA +FLOAT_CMP - + Use Case ======== diff --git a/specutils/analysis/moment.py b/specutils/analysis/moment.py index 4b50cbb5d..4e0c968f2 100644 --- a/specutils/analysis/moment.py +++ b/specutils/analysis/moment.py @@ -5,6 +5,7 @@ import numpy as np from ..manipulation import extract_region +from ..spectra import SpectrumCollection from .utils import computation_wrapper @@ -39,6 +40,9 @@ def moment(spectrum, regions=None, order=0, axis=-1): Moment of the spectrum. Returns None if (order < 0 or None) """ + if isinstance(spectrum, SpectrumCollection): + return [computation_wrapper(_compute_moment, spec, regions,order=order, axis=axis) + for spec in spectrum] return computation_wrapper(_compute_moment, spectrum, regions, order=order, axis=axis) @@ -60,8 +64,10 @@ def _compute_moment(spectrum, regions=None, order=0, axis=-1): if order is None or order < 0: return None + dx = np.abs(np.diff(spectral_axis.bin_edges)) + m0 = np.sum(flux * dx, axis=axis) if order == 0: - return np.sum(flux, axis=axis) + return m0 dispersion = spectral_axis if len(flux.shape) > len(spectral_axis.shape): @@ -69,15 +75,14 @@ def _compute_moment(spectrum, regions=None, order=0, axis=-1): dispersion = np.tile(spectral_axis, _shape) if order == 1: - return np.sum(flux * dispersion, axis=axis) / np.sum(flux, axis=axis) + return np.sum(flux * dispersion * dx, axis=axis) / m0 if order > 1: - m0 = np.sum(flux, axis=axis) # By setting keepdims to True, the axes which are reduced are # left in the result as dimensions with size one. This means # that we can broadcast m1 correctly against dispersion. - m1 = (np.sum(flux * spectral_axis, axis=axis, keepdims=True) - / np.sum(flux, axis=axis, keepdims=True)) + m1 = (np.sum(flux * dispersion * dx, axis=axis, keepdims=True) + / np.sum(flux * dx, axis=axis, keepdims=True)) - return np.sum(flux * (dispersion - m1) ** order, axis=axis) / m0 + return np.sum(flux * dx * (dispersion - m1) ** order, axis=axis) / m0 diff --git a/specutils/tests/test_analysis.py b/specutils/tests/test_analysis.py index 170b58cd3..6efd1958e 100644 --- a/specutils/tests/test_analysis.py +++ b/specutils/tests/test_analysis.py @@ -1065,8 +1065,8 @@ def test_moment(): spectrum = Spectrum1D(spectral_axis=frequencies, flux=flux) moment_0 = moment(spectrum, order=0) - assert moment_0.unit.is_equivalent(u.Jy) - assert quantity_allclose(moment_0, 252.96*u.Jy, atol=0.01*u.Jy) + assert moment_0.unit.is_equivalent(u.Jy * u.GHz) + assert quantity_allclose(moment_0, 2.5045*u.Jy*u.GHz, atol=0.001*u.Jy*u.GHz) moment_1 = moment(spectrum, order=1) assert moment_1.unit.is_equivalent(u.GHz) @@ -1097,6 +1097,11 @@ def test_moment_cube(): spectrum = Spectrum1D(spectral_axis=frequencies, flux=flux_multid) + moment_0 = moment(spectrum, order=0) + + assert moment_0.shape == (9, 10) + assert moment_0.unit.is_equivalent(u.Jy*u.GHz) + moment_1 = moment(spectrum, order=1) assert moment_1.shape == (9, 10) @@ -1159,8 +1164,8 @@ def test_moment_cube_order_2(): assert moment_2.shape == (10, 10000) assert moment_2.unit.is_equivalent(u.GHz**2) # check assorted values. - assert quantity_allclose(moment_2[0][0], 2.019e-28*u.GHz**2, rtol=0.01) - assert quantity_allclose(moment_2[1][0], 2.019e-28*u.GHz**2, rtol=0.01) + assert quantity_allclose(moment_2[0][0], 8.078e-28*u.GHz**2, rtol=0.01) + assert quantity_allclose(moment_2[1][0], 8.078e-28*u.GHz**2, rtol=0.01) assert quantity_allclose(moment_2[0][3], 2.019e-28*u.GHz**2, rtol=0.01)