diff --git a/CHANGES.rst b/CHANGES.rst index 56a6f2719..f8c1107bf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,8 @@ New Features Bug Fixes ^^^^^^^^^ +- Spectrum1D math operators no longer perform illogical math. [#998] + 1.9.1 (2022-11-22) ------------------ diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index 639e094e0..37e0a4c76 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -671,39 +671,55 @@ def redshift(self, val): def radial_velocity(self, val): self.shift_spectrum_to(radial_velocity=val) + def _validate_op_other(self, other): + """Throw error if math with other is impossible.""" + if isinstance(other, u.Quantity): + if other.size != 1: + raise ValueError('Quantity must be scalar.') + + elif isinstance(other, Spectrum1D): + if not np.allclose(self.spectral_axis, other.spectral_axis): + raise ValueError('Mismatched spectral_axis, please resample spectrum.') + + elif not isinstance(other, NDCube): + raise NotImplementedError(f'Cannot operate on {other.__class__.__name__} class.') + def __add__(self, other): - if not isinstance(other, (NDCube, u.Quantity)): - try: - other = u.Quantity(other, unit=self.unit) - except TypeError: - return NotImplemented + if isinstance(other, (int, float)): + other = u.Quantity(other, unit=self.unit) # I hope you know what you doing. + else: + self._validate_op_other(other) return self.add(other) def __sub__(self, other): - if not isinstance(other, NDCube): - try: - other = u.Quantity(other, unit=self.unit) - except TypeError: - return NotImplemented + # Enables specreduce background subtraction via __rsub__ in + # https://github.com/astropy/specreduce/blob/main/specreduce/background.py + if hasattr(other, 'image') and isinstance(other.image, Spectrum1D): + if not np.allclose(self.spectral_axis, other.image.spectral_axis): + raise ValueError('Mismatched spectral_axis, please resample spectrum.') + return NotImplemented + + if isinstance(other, (int, float)): + other = u.Quantity(other, unit=self.unit) # I hope you know what you doing. + else: + self._validate_op_other(other) return self.subtract(other) def __mul__(self, other): - if not isinstance(other, NDCube): + if isinstance(other, (int, float)): other = u.Quantity(other) + else: + self._validate_op_other(other) return self.multiply(other) - def __div__(self, other): - if not isinstance(other, NDCube): - other = u.Quantity(other) - - return self.divide(other) - def __truediv__(self, other): - if not isinstance(other, NDCube): + if isinstance(other, (int, float)): other = u.Quantity(other) + else: + self._validate_op_other(other) return self.divide(other) diff --git a/specutils/tests/test_arithmetic.py b/specutils/tests/test_arithmetic.py index e8b87ffb8..c8c416fbd 100644 --- a/specutils/tests/test_arithmetic.py +++ b/specutils/tests/test_arithmetic.py @@ -8,29 +8,129 @@ from specutils.spectra.spectrum1d import Spectrum1D +# Mimic specreduce background object a little. +class _MockBackground: + def __init__(self, spec): + self.image = spec + + def __rsub__(self, other): + return 42 # Does not matter what, just want to make sure this is called. + + class TestMathWithAllOnes: def setup_class(self): flux = np.ones(10) * u.nJy wave = (np.arange(flux.size) + 1) * u.um self.spec = Spectrum1D(spectral_axis=wave, flux=flux) - def test_add_sub_spectral_axes_same(self): + def test_math_with_spectral_axes_same(self): spec_added = self.spec + self.spec assert_quantity_allclose(spec_added.flux, 2 * u.nJy) spec_subbed = self.spec - self.spec assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy) - def test_add_sub_spectral_axes_different(self): + # TODO: mul, div + + def test_math_with_diff_spectral_axis(self): + """Same spectral axis but in different units.""" + new_spec = Spectrum1D(spectral_axis=self.spec.spectral_axis.to(u.AA), + flux=self.spec.flux) + + spec_added = self.spec + new_spec + assert_quantity_allclose(spec_added.flux, 2 * u.nJy) + + spec_subbed = self.spec - new_spec + assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy) + + # TODO: mul, div + + def test_math_with_spectral_axes_different(self): new_wave = self.spec.spectral_axis + (1 * u.um) new_spec = Spectrum1D(spectral_axis=new_wave, flux=self.spec.flux) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Mismatched spectral_axis'): self.spec + new_spec - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Mismatched spectral_axis'): self.spec - new_spec + # TODO: mul, div + + def test_math_with_plain_number(self): + spec_added_fwd = self.spec + 1 + assert_quantity_allclose(spec_added_fwd.flux, 2 * u.nJy) + + spec_added_bak = 1 + self.spec + assert_quantity_allclose(spec_added_bak.flux, 2 * u.nJy) + + spec_subbed_fwd = self.spec - 1 + assert_quantity_allclose(spec_subbed_fwd.flux, 0 * u.nJy) + + spec_subbed_bak = 2 - self.spec + assert_quantity_allclose(spec_subbed_bak.flux, 1 * u.nJy) + + spec_mul_fwd = self.spec * 2 + assert_quantity_allclose(spec_mul_fwd.flux, 2 * u.nJy) + + spec_mul_bak = 2 * self.spec + assert_quantity_allclose(spec_mul_bak.flux, 2 * u.nJy) + + spec_div = self.spec / 2 + assert_quantity_allclose(spec_div.flux, 0.5 * u.nJy) + + with pytest.raises(TypeError, match='unsupported operand'): + 1 / self.spec + + with pytest.raises(NotImplementedError, match='Cannot operate on ndarray class'): + self.spec * np.ones(self.spec.flux.shape) + + def test_math_with_quantity(self): + # TODO: Like plain number, but also see Jy * Jy becomes Jy**2, Jy + nJy + # TODO: Array Quantity now errors + pass + + def test_math_with_ndcube(self): + from astropy.wcs import WCS + from ndcube import NDCube + + data = np.ones((4, 4, 10)) * u.nJy + wcs = WCS(naxis=3) + wcs.wcs.ctype = 'WAVE', 'RA--TAN', 'DEC-TAN' + wcs.wcs.cunit = 'Angstrom', 'deg', 'deg' + wcs.wcs.cdelt = 0.2, 0.5, 0.4 + wcs.wcs.crpix = 0, 2, 2 + wcs.wcs.crval = 10, 0.5, 1 + wcs.wcs.cname = 'wavelength', 'lon', 'lat' + + ndc = NDCube(data, wcs=wcs) + spec3d = Spectrum1D(flux=data, wcs=wcs) + + spec_added = spec3d + ndc + assert_quantity_allclose(spec_added.flux, 2 * u.nJy) + + spec_subbed = spec3d - ndc + assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy) + + spec_mul = spec3d * ndc + assert_quantity_allclose(spec_mul.flux, 1 * (u.nJy * u.nJy)) + + spec_div = spec3d / ndc + assert_quantity_allclose(spec_div.flux, 1) + + # Also test 3D vs 1D Spectrum1D + with pytest.raises(ValueError, match='Mismatched spectral_axis'): + spec3d + self.spec + + def test_nd_vs_1d(self): + spec2d = Spectrum1D(flux=np.ones((10, 10)) * u.nJy, spectral_axis=self.spec.spectral_axis) + spec_added_2d = spec2d + self.spec + assert_quantity_allclose(spec_added_2d.flux, 2 * u.nJy) + + spec3d = Spectrum1D(flux=np.ones((10, 10, 10)) * u.nJy, spectral_axis=self.spec.spectral_axis) + spec_added_3d = spec3d + self.spec + assert_quantity_allclose(spec_added_3d.flux, 2 * u.nJy) + def test_mask_nans(self): new_flux = deepcopy(self.spec.flux) nan_idx = [1, 3, 5] @@ -39,6 +139,10 @@ def test_mask_nans(self): spec_added = self.spec + new_spec assert spec_added.mask[nan_idx].all() + def test_specreduce_bg_rsub(self): + bg = _MockBackground(self.spec) + assert (self.spec - bg) == 42 + def test_add_basic_spectra(simulated_spectra): @@ -50,7 +154,7 @@ def test_add_basic_spectra(simulated_spectra): # Calculate using the spectrum1d/nddata code spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_um_mJy_e2 - assert np.allclose(spec3.flux.value, flux3) + np.testing.assert_allclose(spec3.flux.value, flux3) def test_add_diff_flux_prefix(simulated_spectra): @@ -64,7 +168,7 @@ def test_add_diff_flux_prefix(simulated_spectra): # Calculate using the spectrum1d/nddata code spec3 = simulated_spectra.s1_AA_mJy_e3 + simulated_spectra.s1_AA_nJy_e4 - assert np.allclose(spec3.flux.value, flux3) + np.testing.assert_allclose(spec3.flux.value, flux3) def test_subtract_basic_spectra(simulated_spectra): @@ -77,7 +181,7 @@ def test_subtract_basic_spectra(simulated_spectra): # Calculate using the spectrum1d/nddata code spec3 = simulated_spectra.s1_um_mJy_e2 - simulated_spectra.s1_um_mJy_e1 - assert np.allclose(spec3.flux.value, flux3) + np.testing.assert_allclose(spec3.flux.value, flux3) def test_divide_basic_spectra(simulated_spectra): @@ -90,7 +194,7 @@ def test_divide_basic_spectra(simulated_spectra): # Calculate using the spectrum1d/nddata code spec3 = simulated_spectra.s1_um_mJy_e1 / simulated_spectra.s1_um_mJy_e2 - assert np.allclose(spec3.flux.value, flux3) + np.testing.assert_allclose(spec3.flux.value, flux3) def test_multiplication_basic_spectra(simulated_spectra): @@ -103,13 +207,7 @@ def test_multiplication_basic_spectra(simulated_spectra): # Calculate using the spectrum1d/nddata code spec3 = simulated_spectra.s1_um_mJy_e1 * simulated_spectra.s1_um_mJy_e2 - assert np.allclose(spec3.flux.value, flux3) - - -def test_add_diff_spectral_axis(simulated_spectra): - - # Calculate using the spectrum1d/nddata code - spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa + np.testing.assert_allclose(spec3.flux.value, flux3) def test_masks(simulated_spectra): @@ -120,7 +218,7 @@ def test_masks(simulated_spectra): masked_sum.mask[:50] = True masked_diff = masked_sum - masked_spec - assert u.allclose(masked_diff.flux, masked_spec.flux) + assert_quantity_allclose(masked_diff.flux, masked_spec.flux) assert np.all(masked_diff.mask == masked_sum.mask | masked_spec.mask)