Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cube writer to properly write out cube generated by model fitting #2012

Merged
merged 7 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New Features
Cubeviz
^^^^^^^

- Custom Spectrum1D writer for spectral cube generated by Cubeviz. [#2012]

Imviz
^^^^^

Expand Down
9 changes: 9 additions & 0 deletions docs/cubeviz/export_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ Alternatively, you can wrap this all into a single command:

mydata = cubeviz.app.get_data_from_viewer("uncert-viewer", "data_name")

To write out a `specutils.Spectrum1D` cube from Cubeviz
(e.g., a fitted cube from :ref:`model-fitting`),
where the mask (if available) is as defined in
`Spectrum1D masks <https://specutils.readthedocs.io/en/latest/spectrum1d.html#including-masks>`_:

.. code-block:: python

mydata.write("mydata.fits", format="jdaviz-cube")

Data can also be accessed directly from ``data_collection`` using the following code:

.. code-block:: python
Expand Down
64 changes: 64 additions & 0 deletions jdaviz/configs/cubeviz/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from astropy.io import fits
from astropy.io import registry as io_registry
from astropy.utils.decorators import deprecated
from glue.core import BaseData
from specutils import Spectrum1D
from specutils.io.registers import _astropy_has_priorities

from jdaviz.core.helpers import ImageConfigHelper
from jdaviz.configs.default.plugins.line_lists.line_list_mixin import LineListMixin
Expand Down Expand Up @@ -126,3 +130,63 @@ class CubeViz(Cubeviz):

def layer_is_cube_image_data(layer):
return isinstance(layer, BaseData) and layer.ndim in (2, 3)


# TODO: We can remove this when specutils supports it, i.e.,
# https://github.com/astropy/specutils/issues/592 and
# https://github.com/astropy/specutils/pull/1009
# NOTE: Cannot use custom_write decorator from specutils because
# that involves asking user to manually put something in
# their ~/.specutils directory.

def jdaviz_cube_fitswriter(spectrum, file_name, **kwargs):
"""This is a custom writer for Spectrum1D data cube.
This writer is specifically targetting data cube
generated from Cubeviz plugins (e.g., cube fitting)
with FITS WCS. It writes out data in the following format
(with MASK only exist when applicable)::

No. Name Ver Type
0 PRIMARY 1 PrimaryHDU
1 SCI 1 ImageHDU (float32)
2 MASK 1 ImageHDU (uint16)
Comment on lines +150 to +152
Copy link
Collaborator

@dhomeier dhomeier Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the file structure you need, I don't think wcs1d-fits can do it in its present form – with or without the new PR. You can write to the first extension instead of the primary with write(..., hdu=1, format='wcs1d-fits'), but currently there is no provision to write the mask. It could certainly be added, but last I recall specutils had rather poor support to write to a new HDU in an existing HDUList.
If this loader is working as it should, might be best to use it as is and perhaps then contribute to the default_loaders upstream. Would this make sense to provide as a writer for jwst_s3d_loader, or is that yet another format?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has nothing to do with JWST, it is a "generic" format for a cube output from model fitting plugin that could or could not have mask populated. So putting this in jwst_s3d_loader would be very misleading.

Copy link
Collaborator

@dhomeier dhomeier Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case it would probably belong into wcs1d-fits. I have not found much on mask treatment apart from SDSS/MaNGA (which is storing it in hdulist['MASK']), so otherwise don't know what would be a standard for storing it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dunno either. If this is not generic enough, personally, I am okay with this writer just living in Jdaviz forever.


The FITS file generated by this writer does not need a
custom reader to be read back into Spectrum1D.

Examples
--------
To write out a Spectrum1D cube using this writer:

>>> spec.write("my_output.fits", format="jdaviz-cube", overwrite=True) # doctest: +SKIP

"""
pri_hdu = fits.PrimaryHDU()

flux = spectrum.flux
sci_hdu = fits.ImageHDU(flux.value.astype(np.float32))
sci_hdu.name = "SCI"
sci_hdu.header.update(spectrum.meta)
sci_hdu.header.update(spectrum.wcs.to_header())
sci_hdu.header['BUNIT'] = flux.unit.to_string(format='fits')

hlist = [pri_hdu, sci_hdu]

# https://specutils.readthedocs.io/en/latest/spectrum1d.html#including-masks
# Good: False or 0
# Bad: True or non-zero
if spectrum.mask is not None:
mask_hdu = fits.ImageHDU(spectrum.mask.astype(np.uint16))
mask_hdu.name = "MASK"
hlist.append(mask_hdu)

hdulist = fits.HDUList(hlist)
hdulist.writeto(file_name, **kwargs)


if _astropy_has_priorities():
kwargs = {"priority": 0}
else: # pragma: no cover
kwargs = {}
io_registry.register_writer(
"jdaviz-cube", Spectrum1D, jdaviz_cube_fitswriter, force=False, **kwargs)
17 changes: 13 additions & 4 deletions jdaviz/configs/cubeviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,22 @@ def _return_spectrum_with_correct_units(flux, wcs, metadata, data_type, target_w
category=UserWarning)
sc = Spectrum1D(flux=flux, wcs=wcs)

# TODO: Can the unit be defined in a different keyword, e.g., CUNIT1?
if target_wave_unit is None and hdulist is not None:
cunit_key = 'CUNIT3'
found_target = False
for ext in ('SCI', 'FLUX', 'PRIMARY'): # In priority order
if ext in hdulist and cunit_key in hdulist[ext].header:
target_wave_unit = u.Unit(hdulist[ext].header[cunit_key])
if found_target:
break
if ext not in hdulist:
continue
hdr = hdulist[ext].header
# The WCS could be swapped or unswapped.
for cunit_num in (3, 1):
cunit_key = f"CUNIT{cunit_num}"
ctype_key = f"CTYPE{cunit_num}"
if cunit_key in hdr and 'WAVE' in hdr[ctype_key]:
target_wave_unit = u.Unit(hdr[cunit_key])
found_target = True
break

if (data_type == 'flux' and target_wave_unit is not None
and target_wave_unit != sc.spectral_axis.unit):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,10 @@ def _fit_model_to_cube(self, add_data):
return

# Get the primary data component
spec = data.get_object(cls=Spectrum1D, statistic=None)
if "_orig_spec" in data.meta:
spec = data.meta["_orig_spec"]
else:
spec = data.get_object(cls=Spectrum1D, statistic=None)

snackbar_message = SnackbarMessage(
"Fitting model to cube...",
Expand Down Expand Up @@ -825,8 +828,8 @@ def _fit_model_to_cube(self, add_data):

# Create new glue data object
output_cube = Data(label=label,
coords=data.coords)
output_cube['flux'] = fitted_spectrum.flux.value
coords=fitted_spectrum.wcs,
flux=fitted_spectrum.flux.value)
output_cube.get_component('flux').units = fitted_spectrum.flux.unit.to_string()

if add_data:
Expand Down
101 changes: 75 additions & 26 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import astropy.modeling.models as models
import astropy.modeling.parameters as params
from astropy.nddata import StdDevUncertainty
import astropy.units as u
import warnings

import numpy as np
import pytest

from astropy import units as u
from astropy.io import fits
from astropy.io.registry.base import IORegistryError
from astropy.modeling import models, parameters as params
from astropy.nddata import StdDevUncertainty
from astropy.wcs import WCS
from numpy.testing import assert_allclose, assert_array_equal
from specutils.spectra import Spectrum1D

from jdaviz.configs.default.plugins.model_fitting import fitting_backend as fb
Expand Down Expand Up @@ -100,7 +104,7 @@ def test_fitting_backend(unc):

parameters_expected = np.array([0.7, 4.65, 0.3, 2., 5.55, 0.3, -2.,
8.15, 0.2, 1.])
assert np.allclose(fm.parameters, parameters_expected, atol=1e-5)
assert_allclose(fm.parameters, parameters_expected, atol=1e-5)

# Returns the fitted model
fm, fitted_spectrum = fb.fit_model_to_spectrum(spectrum, model_list, expression,
Expand All @@ -109,14 +113,11 @@ def test_fitting_backend(unc):
parameters_expected = np.array([1.0104705, 4.58956282, 0.19590464, 2.39892026,
5.49867754, 0.10834472, -1.66902953, 8.19714439,
0.09535613, 3.99125545])
assert np.allclose(fm.parameters, parameters_expected, atol=1e-5)
assert_allclose(fm.parameters, parameters_expected, atol=1e-5)


# When pytest turns warnings into errors, this silently fails with
# len(fitted_parameters) == 0
@pytest.mark.filterwarnings('ignore')
@pytest.mark.parametrize('unc', ('zeros', None))
def test_cube_fitting_backend(unc):
def test_cube_fitting_backend(cubeviz_helper, unc, tmp_path):
np.random.seed(42)

SIGMA = 0.1 # noise in data
Expand All @@ -140,7 +141,14 @@ def test_cube_fitting_backend(unc):
flux_cube[:, spx[0], spx[1]] = build_spectrum(sigma=SIGMA)[1]

# Transpose so it can be packed in a Spectrum1D instance.
flux_cube = flux_cube.transpose(1, 2, 0)
flux_cube = flux_cube.transpose(1, 2, 0) # (15, 14, 200)
cube_wcs = WCS({
'WCSAXES': 3, 'RADESYS': 'ICRS', 'EQUINOX': 2000.0,
'CRPIX3': 38.0, 'CRPIX2': 38.0, 'CRPIX1': 1.0,
'CRVAL3': 205.4384, 'CRVAL2': 27.004754, 'CRVAL1': 0.0,
'CDELT3': 0.01, 'CDELT2': 0.01, 'CDELT1': 0.05,
'CUNIT3': 'deg', 'CUNIT2': 'deg', 'CUNIT1': 'um',
'CTYPE3': 'RA---TAN', 'CTYPE2': 'DEC--TAN', 'CTYPE1': 'WAVE'})

# Mask part of the spectral axis to later ensure that it gets propagated through:
mask = np.zeros_like(flux_cube).astype(bool)
Expand All @@ -152,7 +160,7 @@ def test_cube_fitting_backend(unc):
elif unc is None:
uncertainties = None

spectrum = Spectrum1D(flux=flux_cube*u.Jy, spectral_axis=x*u.um,
spectrum = Spectrum1D(flux=flux_cube*u.Jy, wcs=cube_wcs,
uncertainty=uncertainties, mask=mask)

# Initial model for fit.
Expand All @@ -168,8 +176,10 @@ def test_cube_fitting_backend(unc):
# n_cpu = 1 # NOTE: UNCOMMENT TO DEBUG LOCALLY, AS NEEDED

# Fit to all spaxels.
fitted_parameters, fitted_spectrum = fb.fit_model_to_spectrum(
spectrum, model_list, expression, n_cpu=n_cpu)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=r"The fit may be unsuccessful.*")
fitted_parameters, fitted_spectrum = fb.fit_model_to_spectrum(
spectrum, model_list, expression, n_cpu=n_cpu)

# Check that parameter results are formatted as expected.
assert type(fitted_parameters) == list
Expand Down Expand Up @@ -198,19 +208,58 @@ def test_cube_fitting_backend(unc):
# interested here in checking the correctness of the data
# packaging into the output products.

assert np.allclose(fitted_model[0].amplitude.value, 1.09, atol=TOL)
assert np.allclose(fitted_model[1].amplitude.value, 2.4, atol=TOL)
assert np.allclose(fitted_model[2].amplitude.value, -1.7, atol=TOL)
assert_allclose(fitted_model[0].amplitude.value, 1.09, atol=TOL)
assert_allclose(fitted_model[1].amplitude.value, 2.4, atol=TOL)
assert_allclose(fitted_model[2].amplitude.value, -1.7, atol=TOL)

assert np.allclose(fitted_model[0].mean.value, 4.6, atol=TOL)
assert np.allclose(fitted_model[1].mean.value, 5.5, atol=TOL)
assert np.allclose(fitted_model[2].mean.value, 8.2, atol=TOL)
assert_allclose(fitted_model[0].mean.value, 4.6, atol=TOL)
assert_allclose(fitted_model[1].mean.value, 5.5, atol=TOL)
assert_allclose(fitted_model[2].mean.value, 8.2, atol=TOL)

assert np.allclose(fitted_model[0].stddev.value, 0.2, atol=TOL)
assert np.allclose(fitted_model[1].stddev.value, 0.1, atol=TOL)
assert np.allclose(fitted_model[2].stddev.value, 0.1, atol=TOL)
assert_allclose(fitted_model[0].stddev.value, 0.2, atol=TOL)
assert_allclose(fitted_model[1].stddev.value, 0.1, atol=TOL)
assert_allclose(fitted_model[2].stddev.value, 0.1, atol=TOL)

assert np.allclose(fitted_model[3].amplitude.value, 4.0, atol=TOL)
assert_allclose(fitted_model[3].amplitude.value, 4.0, atol=TOL)

# Check that the fitted spectrum is masked correctly:
assert np.all(fitted_spectrum.mask == mask)
assert_array_equal(fitted_spectrum.mask, mask)

# Check I/O roundtrip.
out_fn = tmp_path / "fitted_cube.fits"
fitted_spectrum.write(out_fn, format="jdaviz-cube", overwrite=True)
flux_unit_str = fitted_spectrum.flux.unit.to_string(format="fits")
coo_expected = fitted_spectrum.wcs.pixel_to_world(1, 0, 2)
with fits.open(out_fn) as pf:
assert len(pf) == 3
assert pf[0].name == "PRIMARY"
assert pf[1].name == "SCI"
assert pf[1].header["BUNIT"] == flux_unit_str
assert_allclose(pf[1].data, fitted_spectrum.flux.value)
assert pf[2].name == "MASK"
assert_array_equal(pf[2].data, mask)
w = WCS(pf[1].header)
coo = w.pixel_to_world(1, 0, 2)
assert_allclose(coo[0].value, coo_expected[0].value) # SpectralCoord
assert_allclose([coo[1].ra.deg, coo[1].dec.deg],
[coo_expected[1].ra.deg, coo_expected[1].dec.deg])

# Our custom format is not registered to readers, just writers.
# You can read it back in without custom read. See "Cubeviz roundtrip" below.
with pytest.raises(IORegistryError, match="No reader defined"):
Spectrum1D.read(out_fn, format="jdaviz-cube")

# Check Cubeviz roundtrip.
cubeviz_helper.load_data(out_fn)
assert len(cubeviz_helper.app.data_collection) == 2
data_sci = cubeviz_helper.app.data_collection["fitted_cube.fits[SCI]"]
flux_sci = data_sci.get_component("flux")
assert_allclose(flux_sci.data, fitted_spectrum.flux.value)
assert flux_sci.units == flux_unit_str
coo = data_sci.coords.pixel_to_world(1, 0, 2)
assert_allclose(coo[0].value, coo_expected[0].value) # SpectralCoord
assert_allclose([coo[1].ra.deg, coo[1].dec.deg],
[coo_expected[1].ra.deg, coo_expected[1].dec.deg])
data_mask = cubeviz_helper.app.data_collection["fitted_cube.fits[MASK]"]
flux_mask = data_mask.get_component("flux")
assert_array_equal(flux_mask.data, mask)