Skip to content

Commit

Permalink
Merge pull request #619 from rosteen/jdat-184
Browse files Browse the repository at this point in the history
Implement SpectralCoord in SpectrumCollection
  • Loading branch information
eteq authored Apr 1, 2020
2 parents 60f2c6c + ca66348 commit 8339d8a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
16 changes: 15 additions & 1 deletion specutils/spectra/spectral_coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def __new__(cls, value, unit=None, observer=None, target=None,
if isinstance(value, u.Quantity) and unit is None:
obj._unit = value.unit

# If we're initializing from an existing SpectralCoord, keep any
# parameters that aren't being overridden
if isinstance(value, SpectralCoord):
if observer is None:
observer = value.observer
if target is None:
target = value.target
if radial_velocity is None and redshift is None:
radial_velocity = value.radial_velocity
if doppler_rest is None:
doppler_rest = value.doppler_rest
if doppler_convention is None:
doppler_convention = value.doppler_convention

# Store state about whether the observer and target were defined
# explicitly (True), or implicity from rv/redshift (False)
obj._frames_state = dict(observer=observer is not None,
Expand Down Expand Up @@ -243,7 +257,7 @@ def _copy(self, **kwargs):
@property
def quantity(self):
"""
Convert the ``SpectralCoord`` to a simple ``~astropy.units.Quantity``.
Convert the ``SpectralCoord`` to a `~astropy.units.Quantity`.
Returns
-------
Expand Down
24 changes: 22 additions & 2 deletions specutils/spectra/spectrum_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astropy.nddata import NDUncertainty, StdDevUncertainty

from .spectrum1d import Spectrum1D
from .spectral_coordinate import SpectralCoord

__all__ = ['SpectrumCollection']

Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(self, flux, spectral_axis=None, wcs=None, uncertainty=None,
if spectral_axis is not None:
if not isinstance(spectral_axis, u.Quantity):
raise u.UnitsError("Spectral axis must be a `Quantity`.")
spectral_axis = SpectralCoord(spectral_axis)

# Ensure that the input values are the same shape
if not (flux.shape == spectral_axis.shape):
Expand Down Expand Up @@ -116,15 +118,33 @@ def from_spectra(cls, spectra):
----------
spectra : list, ndarray
A list of :class:`~specutils.Spectrum1D` objects to be held in the
collection.
collection. Currently the spectral_axis parameters (e.g. observer,
radial_velocity) must be the same for each spectrum.
"""
# Enforce that the shape of each item must be the same
if not all((x.shape == spectra[0].shape for x in spectra)):
raise ValueError("Shape of all elements must be the same.")

# Compose multi-dimensional ndarrays for each property
flux = u.Quantity([spec.flux for spec in spectra])
spectral_axis = u.Quantity([spec.spectral_axis for spec in spectra])

# Check that the spectral parameters are the same for each input
# spectral_axis and create the multi-dimensional SpectralCoord
sa = [x.spectral_axis for x in spectra]
if not all(x.radial_velocity == sa[0].radial_velocity for x in sa) or \
not all(x.target == sa[0].target for x in sa) or \
not all(x.observer == sa[0].observer for x in sa) or \
not all(x.doppler_convention == sa[0].doppler_convention for
x in sa) or \
not all(x.doppler_rest == sa[0].doppler_rest for x in sa):
raise ValueError("All input spectral_axis SpectralCoord "
"objects must have the same parameters.")
spectral_axis = SpectralCoord(sa,
radial_velocity=sa[0].radial_velocity,
doppler_rest=sa[0].doppler_rest,
doppler_convention=sa[0].doppler_convention,
observer=sa[0].observer,
target=sa[0].target)

# Check that either all spectra have associated uncertainties, or that
# none of them do. If only some do, log an error and ignore the
Expand Down
13 changes: 13 additions & 0 deletions specutils/tests/test_spectral_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ def test_create_spectral_coord_observer_target(observer, target):
else:
raise NotImplementedError()

def test_create_from_spectral_coord(observer, target):
"""
Checks that parameters are correctly copied to the new SpectralCoord object
"""
spec_coord1 = SpectralCoord([100, 200, 300] * u.nm, observer=observer,
target=target, radial_velocity=u.Quantity(1000, 'km/s'),
doppler_convention = 'optical', doppler_rest = 6000*u.AA)
spec_coord2 = SpectralCoord(spec_coord1)
assert spec_coord1.observer == spec_coord2.observer
assert spec_coord2.target == spec_coord2.target
assert spec_coord2.radial_velocity == spec_coord2.radial_velocity
assert spec_coord2.doppler_convention == spec_coord2.doppler_convention
assert spec_coord2.doppler_rest == spec_coord2.doppler_rest

# SCIENCE USE CASE TESTS

Expand Down
22 changes: 17 additions & 5 deletions specutils/tests/test_spectrum_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from gwcs.wcs import WCS as GWCS

from ..spectra.spectrum1d import Spectrum1D
from ..spectra.spectral_coordinate import SpectralCoord
from ..spectra.spectrum_collection import SpectrumCollection
from ..utils.wcs_utils import gwcs_from_array

Expand Down Expand Up @@ -76,11 +77,11 @@ def test_collection_without_optional_arguments():


def test_create_collection_from_spectrum1D():
spec = Spectrum1D(spectral_axis=np.linspace(0, 50, 50) * u.AA,
flux=np.random.randn(50) * u.Jy,
spec = Spectrum1D(spectral_axis=SpectralCoord(np.linspace(0, 50, 50) * u.AA,
redshift=0.1), flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
spec1 = Spectrum1D(spectral_axis=np.linspace(20, 60, 50) * u.AA,
flux=np.random.randn(50) * u.Jy,
spec1 = Spectrum1D(spectral_axis=SpectralCoord(np.linspace(20, 60, 50) * u.AA,
redshift=0.1), flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

spec_coll = SpectrumCollection.from_spectra([spec, spec1])
Expand All @@ -89,9 +90,10 @@ def test_create_collection_from_spectrum1D():
assert spec_coll.shape == (2, )
assert spec_coll.nspectral == 50
assert isinstance(spec_coll.flux, u.Quantity)
assert isinstance(spec_coll.spectral_axis, u.Quantity)
assert isinstance(spec_coll.spectral_axis, SpectralCoord)
assert spec.spectral_axis.unit == spec_coll.spectral_axis.unit
assert spec.flux.unit == spec_coll.flux.unit
assert spec_coll.spectral_axis.redshift == 0.1


def test_create_collection_from_collections():
Expand Down Expand Up @@ -129,6 +131,16 @@ def test_create_collection_from_spectra_without_uncertainties():

SpectrumCollection.from_spectra([spec, spec1])

def test_mismatched_spectral_axes_parameters():
spec = Spectrum1D(spectral_axis=SpectralCoord(np.linspace(0, 50, 50) * u.AA,
radial_velocity=u.Quantity(100.0, "km/s")),
flux=np.random.randn(50) * u.Jy)
spec1 = Spectrum1D(spectral_axis=SpectralCoord(np.linspace(20, 60, 50) * u.AA,
radial_velocity=u.Quantity(200.0, "km/s")),
flux=np.random.randn(50) * u.Jy)

with pytest.raises(ValueError) as e_info:
SpectrumCollection.from_spectra([spec, spec1])

@pytest.mark.parametrize('scshape,expected_len', [((5, 10), 5), ((4, 5, 10), 4)])
def test_len(scshape, expected_len):
Expand Down

0 comments on commit 8339d8a

Please sign in to comment.