Skip to content

Commit

Permalink
simplified code to compute var stats fits
Browse files Browse the repository at this point in the history
  • Loading branch information
iprafols committed May 19, 2023
1 parent 9aa79cd commit 64b8849
Showing 1 changed file with 12 additions and 21 deletions.
33 changes: 12 additions & 21 deletions py/picca/delta_extraction/expected_fluxes/dr16_expected_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
"use ivar as weight": False,
})

FUDGE_FIT_START = FUDGE_REF
ETA_FIT_START = 1.
VAR_LSS_FIT_START = 0.1
FUDGE_DEFAULT = 0
ETA_DEFAULT = 1.
VAR_LSS_DEFAULT = 0.1


class Dr16ExpectedFlux(ExpectedFlux):
Expand Down Expand Up @@ -204,7 +204,7 @@ def _initialize_get_eta(self):
eta = np.zeros(self.num_bins_variance)
# normal initialization, starting values eta=1, var_lss=0.2 , and fudge=0
else:
eta = np.ones(self.num_bins_variance)
eta = np.zeros(self.num_bins_variance) + ETA_DEFAULT
# this bit is what is actually freeing eta for the fit
self.fit_variance_functions.append("eta")

Expand All @@ -221,7 +221,7 @@ def _initialize_get_fudge(self):
if not self.use_ivar_as_weight and not self.use_constant_weight:
# this bit is what is actually freeing fudge for the fit
self.fit_variance_functions.append("fudge")
fudge = np.zeros(self.num_bins_variance)
fudge = np.zeros(self.num_bins_variance) + FUDGE_DEFAULT
self.get_fudge = interp1d(self.log_lambda_var_func_grid,
fudge,
fill_value='extrapolate',
Expand All @@ -237,7 +237,7 @@ def _initialize_get_var_lss(self):
var_lss = np.ones(self.num_bins_variance)
# normal initialization, starting values eta=1, var_lss=0.2 , and fudge=0
else:
var_lss = np.zeros(self.num_bins_variance) + 0.2
var_lss = np.zeros(self.num_bins_variance) + VAR_LSS_DEFAULT
# this bit is what is actually freeing var_lss for the fit
self.fit_variance_functions.append("var_lss")
self.get_var_lss = interp1d(self.log_lambda_var_func_grid,
Expand Down Expand Up @@ -529,18 +529,9 @@ def compute_var_stats(self, forests):
ExpectedFluxError if wavelength solution is not valid
"""
# initialize arrays
if "eta" in self.fit_variance_functions:
eta = np.zeros(self.num_bins_variance) + ETA_FIT_START
else:
eta = self.get_eta(self.log_lambda_var_func_grid)
if "var_lss" in self.fit_variance_functions:
var_lss = np.zeros(self.num_bins_variance) + VAR_LSS_FIT_START
else:
var_lss = self.get_var_lss(self.log_lambda_var_func_grid)
if "fudge" in self.fit_variance_functions:
fudge = np.zeros(self.num_bins_variance) + FUDGE_FIT_START
else:
fudge = self.get_fudge(self.log_lambda_var_func_grid)
eta = self.get_eta(self.log_lambda_var_func_grid)
var_lss = self.get_var_lss(self.log_lambda_var_func_grid)
fudge = self.get_fudge(self.log_lambda_var_func_grid)
num_pixels = np.zeros(self.num_bins_variance)
valid_fit = np.zeros(self.num_bins_variance)

Expand Down Expand Up @@ -587,9 +578,9 @@ def compute_var_stats(self, forests):
fudge[index] = minimizer.values["fudge"] * FUDGE_REF
valid_fit[index] = True
else:
eta[index] = 1.
var_lss[index] = 0.1
fudge[index] = 1. * FUDGE_REF
eta[index] = ETA_DEFAULT
var_lss[index] = VAR_LSS_DEFAULT
fudge[index] = FUDGE_DEFAULT
valid_fit[index] = False
num_pixels[index] = leasts_squares.get_num_pixels()
chi2_in_bin[index] = minimizer.fval
Expand Down

0 comments on commit 64b8849

Please sign in to comment.