Skip to content

Commit

Permalink
Change: align_to and shift now return a copy
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofle committed Oct 21, 2024
1 parent bf66119 commit c5a692e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
28 changes: 14 additions & 14 deletions src/dysh/spectra/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,23 +542,26 @@ def shift(self, s, remove_wrap=True, fill_value=np.nan, method="fft"):
"fft" uses a phase shift.
"""

new_data = core.data_shift(self.data, s, remove_wrap=remove_wrap, fill_value=fill_value, method=method)
new_spec = self._copy()
new_data = core.data_shift(new_spec.data, s, remove_wrap=remove_wrap, fill_value=fill_value, method=method)

# Update data values.
self._data = new_data
new_spec._data = new_data

# Update metadata.
self.meta["CRPIX1"] += s
new_spec.meta["CRPIX1"] += s

# Update WCS.
self.wcs.wcs.crpix[0] += s
new_spec.wcs.wcs.crpix[0] += s

# Update `SpectralAxis` values.
# Radial velocity needs to be copied by hand.
radial_velocity = deepcopy(self._spectral_axis._radial_velocity)
new_spectral_axis_values = self.wcs.spectral.pixel_to_world(np.arange(self.flux.shape[-1]))
self._spectral_axis = self.spectral_axis.replicate(value=new_spectral_axis_values)
self._spectral_axis._radial_velocity = radial_velocity
radial_velocity = deepcopy(new_spec._spectral_axis._radial_velocity)
new_spectral_axis_values = new_spec.wcs.spectral.pixel_to_world(np.arange(new_spec.flux.shape[-1]))
new_spec._spectral_axis = new_spec.spectral_axis.replicate(value=new_spectral_axis_values)
new_spec._spectral_axis._radial_velocity = radial_velocity

return new_spec

def find_shift(self, other, units=None, frame=None):
"""
Expand Down Expand Up @@ -633,7 +636,7 @@ def align_to(self, other, units=None, frame=None, remove_wrap=True, fill_value=n
"""

s = self.find_shift(other, units=units, frame=frame)
self.shift(s, remove_wrap=remove_wrap, fill_value=fill_value, method=method)
return self.shift(s, remove_wrap=remove_wrap, fill_value=fill_value, method=method)

@property
def equivalencies(self):
Expand Down Expand Up @@ -1554,12 +1557,9 @@ def average_spectra(spectra, equal_weights=False, align=False):
f"Element {i} of `spectra` has units {s.flux.unit}, but the first element has units {units}."
)
if align:
s_ = s._copy()
if i > 0:
s_.align_to(spectra[0])
else:
s_ = s
data_array[i] = s_.data
s = s.align_to(spectra[0])
data_array[i] = s.data
if not equal_weights:
weights[i] = core.tsys_weight(s.meta["EXPOSURE"], s.meta["CDELT1"], s.meta["TSYS"])
else:
Expand Down
8 changes: 4 additions & 4 deletions src/dysh/spectra/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_shift(self):

# Apply method to be tested.
shift = 5.5
spec.shift(shift)
spec = spec.shift(shift)

# Internal tests.
assert np.all(np.isnan(spec[: int(np.round(shift))].data))
Expand Down Expand Up @@ -480,20 +480,20 @@ def test_align_to(self):
org_spec = spec._copy()

# Align to itself.
spec.align_to(spec)
spec = spec.align_to(spec)
compare_spectrum(spec, org_spec, ignore_history=True)
assert np.all((spec - org_spec).data == 0)

# Align to a shifted version.
shift = 5
spec.shift(shift)
spec = spec.shift(shift)
assert np.all((spec.data[shift:] - org_spec.data[:-shift]) == 0.0)

# Align to a shifted version with signal.
fshift = 0.5
spec = self.ss._copy()
org_spec = spec._copy()
spec.shift(shift + fshift)
spec = spec.shift(shift + fshift)
# The amplitude of the signal will decrease because of the sampling.
tol = np.sqrt(
(1 - np.exp(-0.5 * (fshift) ** 2 / spec.meta["STDD"] ** 2)) ** 2.0
Expand Down

0 comments on commit c5a692e

Please sign in to comment.