Skip to content

Commit

Permalink
More implementation.
Browse files Browse the repository at this point in the history
Just need a few more tests.
  • Loading branch information
pllim committed Dec 1, 2022
1 parent 79bd22f commit e718cf7
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New Features
Bug Fixes
^^^^^^^^^

- Spectrum1D math operators no longer perform illogical math. [#998]

1.9.1 (2022-11-22)
------------------

Expand Down
52 changes: 34 additions & 18 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,39 +671,55 @@ def redshift(self, val):
def radial_velocity(self, val):
self.shift_spectrum_to(radial_velocity=val)

def _validate_op_other(self, other):
"""Throw error if math with other is impossible."""
if isinstance(other, u.Quantity):
if other.size != 1:
raise ValueError('Quantity must be scalar.')

elif isinstance(other, Spectrum1D):
if not np.allclose(self.spectral_axis, other.spectral_axis):
raise ValueError('Mismatched spectral_axis, please resample spectrum.')

elif not isinstance(other, NDCube):
raise NotImplementedError(f'Cannot operate on {other.__class__.__name__} class.')

def __add__(self, other):
if not isinstance(other, (NDCube, u.Quantity)):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
if isinstance(other, (int, float)):
other = u.Quantity(other, unit=self.unit) # I hope you know what you doing.
else:
self._validate_op_other(other)

return self.add(other)

def __sub__(self, other):
if not isinstance(other, NDCube):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented
# Enables specreduce background subtraction via __rsub__ in
# https://github.com/astropy/specreduce/blob/main/specreduce/background.py
if hasattr(other, 'image') and isinstance(other.image, Spectrum1D):
if not np.allclose(self.spectral_axis, other.image.spectral_axis):
raise ValueError('Mismatched spectral_axis, please resample spectrum.')
return NotImplemented

if isinstance(other, (int, float)):
other = u.Quantity(other, unit=self.unit) # I hope you know what you doing.
else:
self._validate_op_other(other)

return self.subtract(other)

def __mul__(self, other):
if not isinstance(other, NDCube):
if isinstance(other, (int, float)):
other = u.Quantity(other)
else:
self._validate_op_other(other)

return self.multiply(other)

def __div__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self.divide(other)

def __truediv__(self, other):
if not isinstance(other, NDCube):
if isinstance(other, (int, float)):
other = u.Quantity(other)
else:
self._validate_op_other(other)

return self.divide(other)

Expand Down
130 changes: 114 additions & 16 deletions specutils/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,129 @@
from specutils.spectra.spectrum1d import Spectrum1D


# Mimic specreduce background object a little.
class _MockBackground:
def __init__(self, spec):
self.image = spec

def __rsub__(self, other):
return 42 # Does not matter what, just want to make sure this is called.


class TestMathWithAllOnes:
def setup_class(self):
flux = np.ones(10) * u.nJy
wave = (np.arange(flux.size) + 1) * u.um
self.spec = Spectrum1D(spectral_axis=wave, flux=flux)

def test_add_sub_spectral_axes_same(self):
def test_math_with_spectral_axes_same(self):
spec_added = self.spec + self.spec
assert_quantity_allclose(spec_added.flux, 2 * u.nJy)

spec_subbed = self.spec - self.spec
assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy)

def test_add_sub_spectral_axes_different(self):
# TODO: mul, div

def test_math_with_diff_spectral_axis(self):
"""Same spectral axis but in different units."""
new_spec = Spectrum1D(spectral_axis=self.spec.spectral_axis.to(u.AA),
flux=self.spec.flux)

spec_added = self.spec + new_spec
assert_quantity_allclose(spec_added.flux, 2 * u.nJy)

spec_subbed = self.spec - new_spec
assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy)

# TODO: mul, div

def test_math_with_spectral_axes_different(self):
new_wave = self.spec.spectral_axis + (1 * u.um)
new_spec = Spectrum1D(spectral_axis=new_wave, flux=self.spec.flux)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match='Mismatched spectral_axis'):
self.spec + new_spec

with pytest.raises(ValueError):
with pytest.raises(ValueError, match='Mismatched spectral_axis'):
self.spec - new_spec

# TODO: mul, div

def test_math_with_plain_number(self):
spec_added_fwd = self.spec + 1
assert_quantity_allclose(spec_added_fwd.flux, 2 * u.nJy)

spec_added_bak = 1 + self.spec
assert_quantity_allclose(spec_added_bak.flux, 2 * u.nJy)

spec_subbed_fwd = self.spec - 1
assert_quantity_allclose(spec_subbed_fwd.flux, 0 * u.nJy)

spec_subbed_bak = 2 - self.spec
assert_quantity_allclose(spec_subbed_bak.flux, 1 * u.nJy)

spec_mul_fwd = self.spec * 2
assert_quantity_allclose(spec_mul_fwd.flux, 2 * u.nJy)

spec_mul_bak = 2 * self.spec
assert_quantity_allclose(spec_mul_bak.flux, 2 * u.nJy)

spec_div = self.spec / 2
assert_quantity_allclose(spec_div.flux, 0.5 * u.nJy)

with pytest.raises(TypeError, match='unsupported operand'):
1 / self.spec

with pytest.raises(NotImplementedError, match='Cannot operate on ndarray class'):
self.spec * np.ones(self.spec.flux.shape)

def test_math_with_quantity(self):
# TODO: Like plain number, but also see Jy * Jy becomes Jy**2, Jy + nJy
# TODO: Array Quantity now errors
pass

def test_math_with_ndcube(self):
from astropy.wcs import WCS
from ndcube import NDCube

data = np.ones((4, 4, 10)) * u.nJy
wcs = WCS(naxis=3)
wcs.wcs.ctype = 'WAVE', 'RA--TAN', 'DEC-TAN'
wcs.wcs.cunit = 'Angstrom', 'deg', 'deg'
wcs.wcs.cdelt = 0.2, 0.5, 0.4
wcs.wcs.crpix = 0, 2, 2
wcs.wcs.crval = 10, 0.5, 1
wcs.wcs.cname = 'wavelength', 'lon', 'lat'

ndc = NDCube(data, wcs=wcs)
spec3d = Spectrum1D(flux=data, wcs=wcs)

spec_added = spec3d + ndc
assert_quantity_allclose(spec_added.flux, 2 * u.nJy)

spec_subbed = spec3d - ndc
assert_quantity_allclose(spec_subbed.flux, 0 * u.nJy)

spec_mul = spec3d * ndc
assert_quantity_allclose(spec_mul.flux, 1 * (u.nJy * u.nJy))

spec_div = spec3d / ndc
assert_quantity_allclose(spec_div.flux, 1)

# Also test 3D vs 1D Spectrum1D
with pytest.raises(ValueError, match='Mismatched spectral_axis'):
spec3d + self.spec

def test_nd_vs_1d(self):
spec2d = Spectrum1D(flux=np.ones((10, 10)) * u.nJy, spectral_axis=self.spec.spectral_axis)
spec_added_2d = spec2d + self.spec
assert_quantity_allclose(spec_added_2d.flux, 2 * u.nJy)

spec3d = Spectrum1D(flux=np.ones((10, 10, 10)) * u.nJy, spectral_axis=self.spec.spectral_axis)
spec_added_3d = spec3d + self.spec
assert_quantity_allclose(spec_added_3d.flux, 2 * u.nJy)

def test_mask_nans(self):
new_flux = deepcopy(self.spec.flux)
nan_idx = [1, 3, 5]
Expand All @@ -39,6 +139,10 @@ def test_mask_nans(self):
spec_added = self.spec + new_spec
assert spec_added.mask[nan_idx].all()

def test_specreduce_bg_rsub(self):
bg = _MockBackground(self.spec)
assert (self.spec - bg) == 42


def test_add_basic_spectra(simulated_spectra):

Expand All @@ -50,7 +154,7 @@ def test_add_basic_spectra(simulated_spectra):
# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_um_mJy_e2

assert np.allclose(spec3.flux.value, flux3)
np.testing.assert_allclose(spec3.flux.value, flux3)


def test_add_diff_flux_prefix(simulated_spectra):
Expand All @@ -64,7 +168,7 @@ def test_add_diff_flux_prefix(simulated_spectra):
# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_AA_mJy_e3 + simulated_spectra.s1_AA_nJy_e4

assert np.allclose(spec3.flux.value, flux3)
np.testing.assert_allclose(spec3.flux.value, flux3)


def test_subtract_basic_spectra(simulated_spectra):
Expand All @@ -77,7 +181,7 @@ def test_subtract_basic_spectra(simulated_spectra):
# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_um_mJy_e2 - simulated_spectra.s1_um_mJy_e1

assert np.allclose(spec3.flux.value, flux3)
np.testing.assert_allclose(spec3.flux.value, flux3)


def test_divide_basic_spectra(simulated_spectra):
Expand All @@ -90,7 +194,7 @@ def test_divide_basic_spectra(simulated_spectra):
# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 / simulated_spectra.s1_um_mJy_e2

assert np.allclose(spec3.flux.value, flux3)
np.testing.assert_allclose(spec3.flux.value, flux3)


def test_multiplication_basic_spectra(simulated_spectra):
Expand All @@ -103,13 +207,7 @@ def test_multiplication_basic_spectra(simulated_spectra):
# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 * simulated_spectra.s1_um_mJy_e2

assert np.allclose(spec3.flux.value, flux3)


def test_add_diff_spectral_axis(simulated_spectra):

# Calculate using the spectrum1d/nddata code
spec3 = simulated_spectra.s1_um_mJy_e1 + simulated_spectra.s1_AA_mJy_e3 # noqa
np.testing.assert_allclose(spec3.flux.value, flux3)


def test_masks(simulated_spectra):
Expand All @@ -120,7 +218,7 @@ def test_masks(simulated_spectra):

masked_sum.mask[:50] = True
masked_diff = masked_sum - masked_spec
assert u.allclose(masked_diff.flux, masked_spec.flux)
assert_quantity_allclose(masked_diff.flux, masked_spec.flux)
assert np.all(masked_diff.mask == masked_sum.mask | masked_spec.mask)


Expand Down

0 comments on commit e718cf7

Please sign in to comment.