-
Notifications
You must be signed in to change notification settings - Fork 0
/
bayesn_model.py
3200 lines (2782 loc) · 153 KB
/
bayesn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
BayeSN SED Model. Defines a class which allows you to fit or simulate from the
BayeSN Optical+NIR SED model.
"""
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
from scipy.integrate import simpson
import numpyro
from numpyro.infer import MCMC, NUTS, init_to_median, init_to_sample, init_to_value, Predictive, log_likelihood
import numpyro.distributions as dist
from numpyro.optim import Adam
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoMultivariateNormal, AutoDiagonalNormal, AutoLaplaceApproximation
from numpyro.infer.util import log_density
from numpyro.distributions.transforms import LowerCholeskyAffine
from numpyro.contrib.control_flow import scan
from numpyro import handlers
import h5py
import sncosmo
import spline_utils
import pickle
import pandas as pd
import jax
from jax import device_put
import jax.numpy as jnp
from jax.random import PRNGKey, split, normal
from astropy.cosmology import FlatLambdaCDM
import astropy.table as at
import astropy.constants as const
import matplotlib as mpl
from matplotlib import rc
import extinction
import timeit
from astropy.io import fits
import ruamel.yaml as yaml
import time
from tqdm import tqdm
import arviz.stats.stats as astats
from zltn_utils import *
import timeit
import arviz
# Global plotting defaults: serif Computer Modern fonts for text and maths,
# ASCII minus signs on axes, and a larger base font size.
rc('font', family='serif', serif=['cmr10'])
mpl.rcParams.update({
    'axes.unicode_minus': False,
    'mathtext.fontset': 'cm',
})
plt.rcParams['font.size'] = 22

# Select the CPU platform for JAX.
jax.config.update('jax_platform_name', 'cpu')

# Report the working directory: the model and data files used by SEDmodel are
# opened via relative paths ('model_files/...', 'data/...').
print(f'Currently working in {os.getcwd()}')
class SEDmodel(object):
"""
BayeSN-SED Model
Class which imports a BayeSN model, and allows one to fit or simulate
Type Ia supernovae based on this model.
Parameters
----------
num_devices: int, optional
If running on a CPU, numpyro will by default see it as a single device - this argument will set the number
of available cores for numpyro to use e.g. set to 4, you can train 4 chains on 4 cores in parallel. Defaults
to 4
enable_x64: Bool, optional
Determines whether 64-bit precision is used. Often required when training on GPU, typically better left
enabled although worth a try disabled to improve performance depending on your model/initialisation. Defaults to
True
load_model : str, optional
Can be either a pre-defined BayeSN model name (see table below), or
a path to directory containing a set of .txt files from which a
valid model can be constructed. Currently implemented default models
are listed below - the constructor default is ``'T21_model'``. See README in `BayeSNmodel/model_files`
for more info.
``'M20'`` | Mandel+20 BayeSN model (arXiv:2008.07538). Covers
| rest wavelength range of 3000-18500A (BVRIYJH). No
| treatment of host mass effects. Global RV assumed.
| Trained on low-z Avelino+19 (ApJ, 887, 106)
| compilation of CfA, CSP and others.
``'T21'`` | Thorp+21 No-Split BayeSN model (arXiv:2102:05678). Covers
| rest wavelength range of 3500-9500A (griz). No
| treatment of host mass effects. Global RV assumed.
| Trained on Foundation DR1 (Foley+18, Jones+19).
fiducial_cosmology : dict, optional
Dictionary containing kwargs ``{H0, Om0}`` for initialising a
:py:class:`astropy.cosmology.FlatLambdaCDM` instance. Defaults to
Riess+16 (ApJ, 826, 56) cosmology ``{"H0": 73.24, "Om0": 0.28}``.
obsmodel_file: str, optional
Path to file containing details on all bands loaded into model. Defaults to data/SNmodel_pb_obsmode_map.txt
Attributes
----------
cosmo: :py:class:`astropy.cosmology.FlatLambdaCDM`
:py:class:`astropy.cosmology.FlatLambdaCDM` instance defining the
fiducial cosmology which the model was trained using.
RV_MW: float
Rv value for calculating Milky Way extinction
scale: float
Scaling factor used when training/fitting in flux space to ensure that flux values are of order unity
sigma_pec: float
Peculiar velocity to be used in calculating redshift uncertainties, stored in units of the speed of light as 150 / 3e5 (i.e. 150 km/s)
l_knots: array-like
Array of wavelength knots which the model is defined at
tau_knots: array-like
Array of time knots which the model is defined at
W0: array-like
W0 matrix for loaded model
W1: array-like
W1 matrix for loaded model
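L_Sigma: array-like
Matrix describing the epsilon distribution for loaded model; the fitting models multiply it onto standard-normal draws to obtain epsilon (i.e. it is used as a Cholesky-style factor, not as the covariance itself)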
M0: float
Reference absolute magnitude for scaling Hsiao template
sigma0: float
Standard deviation of grey offset parameter for loaded model
Rv: float
Global host extinction value for loaded model
tauA: float
Global tauA value for exponential AV prior for loaded model
min_wave: float
Minimum wavelength covered by model, used when preparing band responses
max_wave: float
Maximum wavelength covered by model, used when preparing band responses
spectrum_bins: int
Number of wavelength bins used for modelling spectra and calculating photometry. Based on ParSNiP as presented
in Boone+21
hsiao_flux: array-like
Grid of flux value for Hsiao template
hsiao_t: array-like
Time values corresponding to Hsiao template grid
hsiao_l: array-like
Wavelength values corresponding to Hsiao template grid
Returns
-------
out : :py:class:`bayesn_model.SEDmodel` instance
"""
def __init__(self, num_devices=4, enable_x64=True, load_model='T21_model',
             fiducial_cosmology={"H0": 73.24, "Om0": 0.28}, obsmodel_file='data/SNmodel_pb_obsmode_map.txt'):
    """
    Configure jax/numpyro, load the global model parameters and precompute the
    band weights, spline matrices and Hsiao template used by the model.

    Parameters
    ----------
    num_devices: int, optional
        Passed to ``numpyro.set_host_device_count``.
    enable_x64: bool, optional
        Value written to the ``jax_enable_x64`` config flag.
    load_model: str, optional
        Name of a sub-directory of ``model_files/``. If it contains ``BAYESN.YAML`` the parameters are read
        from that file, otherwise from the individual ``.txt`` files in the same directory.
    fiducial_cosmology: dict, optional
        Keyword arguments for :py:class:`astropy.cosmology.FlatLambdaCDM`. The default dict is shared between
        calls; it is only unpacked here, never modified.
    obsmodel_file: str, optional
        Path to the table describing the bands, read by ``_setup_band_weights``.
    """
    # Settings for jax/numpyro
    numpyro.set_host_device_count(num_devices)
    jax.config.update('jax_enable_x64', enable_x64)
    print('Current devices:', jax.devices())
    print(jax.local_device_count())

    self.cosmo = FlatLambdaCDM(**fiducial_cosmology)
    self.data = None
    self.hsiao_interp = None
    self.RV_MW = device_put(jnp.array(3.1))
    self.scale = 1e18
    self.device_scale = device_put(jnp.array(self.scale))
    # 150 km/s expressed as a fraction of the speed of light (3e5 km/s)
    self.sigma_pec = device_put(jnp.array(150 / 3e5))

    if os.path.exists(f'model_files/{load_model}/BAYESN.YAML'):
        with open(f'model_files/{load_model}/BAYESN.YAML', 'r') as file:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects, so only trusted model files
            # should be loaded here. The module-level ``yaml.load`` may also be deprecated in recent
            # ruamel.yaml releases - confirm against the pinned version.
            params = yaml.load(file, Loader=yaml.Loader)
        self.l_knots = jnp.array(params['L_KNOTS'])
        self.tau_knots = jnp.array(params['TAU_KNOTS'])
        self.W0 = jnp.array(params['W0'])
        self.W1 = jnp.array(params['W1'])
        self.L_Sigma = jnp.array(params['L_SIGMA_EPSILON'])
        self.M0 = jnp.array(params['M0'])
        self.sigma0 = jnp.array(params['SIGMA0'])
        if 'MU_R' in params.keys():
            # Population distribution of R_V: no single global value is stored
            self.Rv = None
            self.mu_R = jnp.array(params['MU_R'])
            self.sigma_R = jnp.array(params['SIGMA_R'])
        else:
            self.Rv = jnp.array(params['RV'])
            self.mu_R = None
            self.sigma_R = None
        self.tauA = jnp.array(params['TAUA'])
    else:
        self.l_knots = np.genfromtxt(f'model_files/{load_model}/l_knots.txt')
        self.tau_knots = np.genfromtxt(f'model_files/{load_model}/tau_knots.txt')
        self.W0 = np.genfromtxt(f'model_files/{load_model}/W0.txt')
        self.W1 = np.genfromtxt(f'model_files/{load_model}/W1.txt')
        self.L_Sigma = np.genfromtxt(f'model_files/{load_model}/L_Sigma_epsilon.txt')
        model_params = np.genfromtxt(f'model_files/{load_model}/M0_sigma0_RV_tauA.txt')
        self.M0 = device_put(model_params[0])
        self.sigma0 = device_put(model_params[1])
        self.Rv = device_put(model_params[2])
        self.tauA = device_put(model_params[3])
        # The txt format only carries a global R_V; define the population attributes anyway so that both
        # loading paths leave the instance with the same set of attributes.
        self.mu_R = None
        self.sigma_R = None

    self.l_knots = device_put(self.l_knots)
    self.tau_knots = device_put(self.tau_knots)
    self.W0 = device_put(self.W0)
    self.W1 = device_put(self.W1)
    self.L_Sigma = device_put(self.L_Sigma)

    # Initialise arrays and values for band responses - these are based on ParSNiP as presented in Boone+21
    self.obsmode_file = obsmodel_file
    self._setup_band_weights()

    # Spline matrices mapping knot values onto the model wavelength grid / time knots
    KD_l = spline_utils.invKD_irr(self.l_knots)
    self.J_l_T = device_put(spline_utils.spline_coeffs_irr(self.model_wave, self.l_knots, KD_l))
    self.KD_t = device_put(spline_utils.invKD_irr(self.tau_knots))

    self.load_hsiao_template()
    self.ZPT = 27.5
    self.hsiao_flux = device_put(self.hsiao_flux)
    self.J_l_T_hsiao = device_put(self.J_l_T_hsiao)

    # Spline knots in inverse wavelength (1e4 / wavelength) used for the host extinction curve in get_spectra
    self.xk = jnp.array(
        [0.0, 1e4 / 26500., 1e4 / 12200., 1e4 / 6000., 1e4 / 5470., 1e4 / 4670., 1e4 / 4110., 1e4 / 2700.,
         1e4 / 2600.])
    KD_x = spline_utils.invKD_irr(self.xk)
    self.M_fitz_block = device_put(spline_utils.spline_coeffs_irr(1e4 / self.model_wave, self.xk, KD_x))

    # Jitted, vectorised (over the first argument) version of the per-point spline coefficient routine
    self.J_t_map = jax.jit(jax.vmap(self.spline_coeffs_irr_step, in_axes=(0, None, None)))
def load_hsiao_template(self):
    """
    Read the Hsiao template from ``data/hsiao.h5`` and cache it on the instance.

    Sets ``hsiao_t`` and ``hsiao_l`` (the template phase and wavelength grids), ``KD_t_hsiao``,
    ``J_l_T_hsiao`` (spline matrix from the template wavelength grid onto ``self.model_wave``) and
    ``hsiao_flux`` (the template flux after applying ``J_l_T_hsiao``). Requires ``self.model_wave`` to have
    been set beforehand.
    """
    template_path = os.path.join('data', 'hsiao.h5')
    with h5py.File(template_path, 'r') as h5_file:
        group = h5_file['default']
        # [()] pulls each dataset fully into memory before the file is closed
        phase, wave, flux = (group[key][()].astype('float64') for key in ('phase', 'wave', 'flux'))

    inv_kd_wave = spline_utils.invKD_irr(wave)
    self.KD_t_hsiao = device_put(spline_utils.invKD_irr(phase))
    self.J_l_T_hsiao = device_put(spline_utils.spline_coeffs_irr(self.model_wave, wave, inv_kd_wave))
    self.hsiao_t = device_put(phase)
    self.hsiao_l = device_put(wave)
    # Transpose the stored flux grid before projecting it onto the model wavelength grid
    self.hsiao_flux = jnp.matmul(self.J_l_T_hsiao, device_put(flux.T))
def _setup_band_weights(self):
    """
    Sets up the interpolation for the band weights used for photometry as well as calculating the zero points for
    each band. This code is partly based off ParSNiP from Boone+21.

    Reads the band table from ``self.obsmode_file`` and, for every band whose throughput file can be loaded
    and whose ``magsys`` is ``'abmag'`` or ``'vegamag'``, stores:

    * ``band_dict`` / ``inv_band_dict``: band name <-> row index into ``band_interpolate_weights`` and ``zps``
    * ``zp_dict`` / ``zps``: zero point per band
    * ``band_lim_dict``: first wavelength where the throughput exceeds 1% of its maximum
    * ``band_interpolate_weights``: band weights on the oversampled log-wavelength grid

    Also sets ``min_wave``, ``max_wave``, ``spectrum_bins``, ``band_oversampling``, ``max_redshift``,
    ``band_interpolate_locations``, ``band_interpolate_spacing`` and ``model_wave``.

    Bands that are skipped (unreadable throughput file or unsupported ``magsys``) contribute nothing to any of
    these containers, so the indices in ``band_dict`` always line up with the rows of
    ``band_interpolate_weights`` and ``zps``.
    """
    # Build the model in log wavelength
    self.min_wave = self.l_knots[0]
    self.max_wave = self.l_knots[-1]
    self.spectrum_bins = 300
    self.band_oversampling = 51
    self.max_redshift = 4

    model_log_wave = np.linspace(np.log10(self.min_wave),
                                 np.log10(self.max_wave),
                                 self.spectrum_bins)
    model_spacing = model_log_wave[1] - model_log_wave[0]
    band_spacing = model_spacing / self.band_oversampling
    band_max_log_wave = (
        np.log10(self.max_wave * (1 + self.max_redshift))
        + band_spacing
    )

    # Oversampling must be odd.
    assert self.band_oversampling % 2 == 1
    pad = (self.band_oversampling - 1) // 2
    band_log_wave = np.arange(np.log10(self.min_wave),
                              band_max_log_wave, band_spacing)
    band_wave = 10 ** (band_log_wave)
    band_pad_log_wave = np.arange(
        np.log10(self.min_wave) - band_spacing * pad,
        band_max_log_wave + band_spacing * pad,
        band_spacing
    )
    # Quantities that do not depend on the band are computed once, outside the loop
    band_pad_wave = 10 ** band_pad_log_wave
    dlamba = jnp.diff(band_wave)
    dlamba = jnp.r_[dlamba, dlamba[-1]]

    # Load reference source spectra
    with fits.open('data/zp/alpha_lyr_stis_010.fits') as hdu:
        vega_df = pd.DataFrame.from_records(hdu[1].data)
    vega_lam, vega_f = vega_df.WAVELENGTH, vega_df.FLUX
    vega_interp = None  # built lazily, only if a 'vegamag' band is present

    def f_lam(l):
        # Reference flux density for 'abmag' bands as a function of wavelength
        f = (const.c.to('AA/s').value / 1e23) * ((l) ** -2) * 10 ** (-48.6 / 2.5) * 1e23
        return f

    band_weights, zps = [], []
    self.band_dict, self.zp_dict, self.band_lim_dict = {}, {}, {}

    # sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True
    obsmode = pd.read_csv(self.obsmode_file, sep=r'\s+')

    band_ind = 0
    for _, row in obsmode.iterrows():
        band, magsys = row.pb, row.magsys
        try:
            R = np.loadtxt(os.path.join('data', row.obsmode))
        except Exception:
            # Best effort: bands whose throughput file cannot be read are skipped
            continue

        # Reference spectrum for the zero point. This is decided *before* anything is stored so that a band
        # with an unsupported magnitude system leaves no orphan row in band_weights.
        lam = R[:, 0]
        if magsys == 'abmag':
            zp = f_lam(lam)
        elif magsys == 'vegamag':
            if vega_interp is None:
                vega_interp = interp1d(vega_lam, vega_f, kind='cubic')
            zp = vega_interp(lam)
        else:
            continue

        band_transmission = np.interp(band_pad_wave, R[:, 0], R[:, 1])
        band_low_lim = R[np.where(R[:, 1] > 0.01 * R[:, 1].max())[0][0], 0]

        # Resample the band onto the sampling of the spectrum.
        band_conv_transmission = jnp.interp(band_wave, band_pad_wave, band_transmission)

        num = band_wave * band_conv_transmission * dlamba
        denom = jnp.sum(num)
        band_weight = num / denom
        band_weights.append(band_weight)

        # Get zero points; x is passed by keyword because newer SciPy releases no longer accept it positionally
        int1 = simpson(lam * zp * R[:, 1], x=lam)
        int2 = simpson(lam * R[:, 1], x=lam)
        zp = 2.5 * np.log10(int1 / int2)

        self.band_dict[band] = band_ind
        self.band_lim_dict[band] = band_low_lim
        self.zp_dict[band] = zp
        zps.append(zp)
        band_ind += 1

    self.zps = jnp.array(zps)
    self.inv_band_dict = {val: key for key, val in self.band_dict.items()}

    # Get the locations that should be sampled at redshift 0. We can scale these to
    # get the locations at any redshift.
    band_interpolate_locations = jnp.arange(
        0,
        self.spectrum_bins * self.band_oversampling,
        self.band_oversampling
    )

    # Save the variables that we need to do interpolation.
    self.band_interpolate_locations = device_put(band_interpolate_locations)
    self.band_interpolate_spacing = band_spacing
    self.band_interpolate_weights = jnp.array(band_weights)
    self.model_wave = 10 ** (model_log_wave)
def _calculate_band_weights(self, redshifts, ebv):
    """
    Builds the observer-frame band weights for each SN, including Milky Way extinction.

    Parameters
    ----------
    redshifts: array-like
        Array of redshifts for each SN
    ebv: array-like
        Array of Milky Way E(B-V) values for each SN

    Returns
    -------
    weights: array-like
        Array containing observer-frame band weights, indexed as (SN, model wavelength, band)
    """
    # Fractional positions on the oversampled band grid that each model wavelength maps to at each redshift
    z_shift = jnp.log10(1 + redshifts)[:, None] / self.band_interpolate_spacing
    sample_locs = self.band_interpolate_locations + z_shift
    locs_1d = sample_locs.flatten()

    # Linear interpolation between the two neighbouring grid points
    lower_idx = locs_1d.astype(jnp.int32)
    frac = locs_1d - lower_idx
    lower_vals = self.band_interpolate_weights[..., lower_idx]
    upper_vals = self.band_interpolate_weights[..., lower_idx + 1]
    interpolated = frac * upper_vals + (1 - frac) * lower_vals
    weights = interpolated.reshape((-1,) + sample_locs.shape).transpose(1, 2, 0)

    # Rescale so that, for every SN and band, the weights sum to one over the wavelength axis
    band_totals = jnp.sum(weights, axis=1)
    weights = weights / band_totals[:, None, :]

    # Apply MW extinction, evaluated at the observer-frame wavelengths of each SN
    a_v_mw = self.RV_MW * ebv
    mw_array = jnp.zeros((weights.shape[0], weights.shape[1]))
    for sn, a_v in enumerate(a_v_mw):
        observer_wave = self.model_wave * (1 + np.array(redshifts[sn]))
        mw_transmission = jnp.power(10, -0.4 * extinction.fitzpatrick99(observer_wave, a_v, self.RV_MW))
        mw_array = mw_array.at[sn, :].set(mw_transmission)
    weights = weights * mw_array[..., None]

    # We need an extra term of 1 + z from the filter contraction.
    weights = weights / (1 + redshifts)[:, None, None]
    return weights
def get_spectra(self, theta, Av, W0, W1, eps, Rv, J_t, hsiao_interp):
"""
Calculates rest-frame spectra for given parameter values
Parameters
----------
theta: array-like
Set of theta values for each SN
Av: array-like
Set of host extinction values for each SN
W0: array-like
Global W0 matrix
W1: array-like
Global W1 matrix
eps: array-like
Set of epsilon values for each SN, describing residual colour variation
Rv: float
Global R_V value for host extinction (need to allow this to be variable in future)
J_t: array-like
Matrix for cubic spline interpolation in time axis for each SN
hsiao_interp: array-like
Array containing Hsiao template spectra for each t value, comprising model for previous day, next day and
t % 1 to allow for linear interpolation
Returns
-------
model_spectra: array-like
Matrix containing model spectra for all SNe at all time-steps
"""
num_batch = theta.shape[0]
W0 = jnp.repeat(W0[None, ...], num_batch, axis=0)
W1 = jnp.repeat(W1[None, ...], num_batch, axis=0)
W = W0 + theta[..., None, None] * W1 + eps
WJt = jnp.matmul(W, J_t)
W_grid = jnp.matmul(self.J_l_T, WJt)
low_hsiao = self.hsiao_flux[:, hsiao_interp[0, ...].astype(int)]
up_hsiao = self.hsiao_flux[:, hsiao_interp[1, ...].astype(int)]
H_grid = ((1 - hsiao_interp[2, :]) * low_hsiao + hsiao_interp[2, :] * up_hsiao).transpose(2, 0, 1)
model_spectra = H_grid * 10 ** (-0.4 * W_grid)
# Extinction----------------------------------------------------------
f99_x0 = 4.596
f99_gamma = 0.99
f99_c2 = -0.824 + 4.717 / Rv
f99_c1 = 2.030 - 3.007 * f99_c2
f99_c3 = 3.23
f99_c4 = 0.41
f99_c5 = 5.9
f99_d1 = self.xk[7] ** 2 / ((self.xk[7] ** 2 - f99_x0 ** 2) ** 2 + (f99_gamma * self.xk[7]) ** 2)
f99_d2 = self.xk[8] ** 2 / ((self.xk[8] ** 2 - f99_x0 ** 2) ** 2 + (f99_gamma * self.xk[8]) ** 2)
yk = jnp.zeros((num_batch, 9))
yk = yk.at[:, 0].set(-Rv)
yk = yk.at[:, 1].set(0.26469 * Rv / 3.1 - Rv)
yk = yk.at[:, 2].set(0.82925 * Rv / 3.1 - Rv)
yk = yk.at[:, 3].set(-0.422809 + 1.00270 * Rv + 2.13572e-4 * Rv ** 2 - Rv)
yk = yk.at[:, 4].set(-5.13540e-2 + 1.00216 * Rv - 7.35778e-5 * Rv ** 2 - Rv)
yk = yk.at[:, 5].set(0.700127 + 1.00184 * Rv - 3.32598e-5 * Rv ** 2 - Rv)
yk = yk.at[:, 6].set(
1.19456 + 1.01707 * Rv - 5.46959e-3 * Rv ** 2 + 7.97809e-4 * Rv ** 3 - 4.45636e-5 * Rv ** 4 - Rv)
yk = yk.at[:, 7].set(f99_c1 + f99_c2 * self.xk[7] + f99_c3 * f99_d1)
yk = yk.at[:, 8].set(f99_c1 + f99_c2 * self.xk[8] + f99_c3 * f99_d2)
A = Av[..., None] * (1 + (self.M_fitz_block @ yk.T).T / Rv[..., None]) # Rv[..., None]
f_A = 10 ** (-0.4 * A)
model_spectra = model_spectra * f_A[..., None]
return model_spectra
def get_flux_batch(self, theta, Av, W0, W1, eps, Ds, Rv, band_indices, mask, J_t, hsiao_interp, weights):
"""
Calculates observer-frame fluxes for given parameter values
Parameters
----------
theta: array-like
Set of theta values for each SN
Av: array-like
Set of host extinction values for each SN
W0: array-like
Global W0 matrix
W1: array-like
Global W1 matrix
eps: array-like
Set of epsilon values for each SN, describing residual colour variation
Ds: array-like
Set of distance moduli for each SN
Rv: float
Global R_V value for host extinction (need to allow this to be variable in future)
band_indices: array-like
Array containing indices describing which filter each observation is in
mask: array-like
Array containing mask describing whether observations should contribute to the posterior
J_t: array-like
Matrix for cubic spline interpolation in time axis for each SN
hsiao_interp: array-like
Array containing Hsiao template spectra for each t value, comprising model for previous day, next day and
t % 1 to allow for linear interpolation
weights: array_like
Array containing band weights to use for photometry
Returns
-------
model_flux: array-like
Matrix containing model fluxes for all SNe at all time-steps
"""
num_batch = theta.shape[0]
num_observations = band_indices.shape[0]
model_spectra = self.get_spectra(theta, Av, W0, W1, eps, Rv, J_t, hsiao_interp)
batch_indices = (
jnp.arange(num_batch)
.repeat(num_observations)
).astype(int)
obs_band_weights = (
weights[batch_indices, :, band_indices.T.flatten()]
.reshape((num_batch, num_observations, -1))
.transpose(0, 2, 1)
)
model_flux = jnp.sum(model_spectra * obs_band_weights, axis=1).T
model_flux = model_flux * 10 ** (-0.4 * (self.M0 + Ds))
zps = self.zps[band_indices]
zp_flux = 10 ** (zps / 2.5)
#model_flux = model_flux * self.device_scale
model_flux = (model_flux / zp_flux) * 10 ** (0.4 * 27.5) # Convert to FLUXCAL
model_flux *= mask
return model_flux
def get_mag_batch(self, theta, Av, W0, W1, eps, Ds, Rv, band_indices, mask, J_t, hsiao_interp, weights):
"""
Calculates observer-frame magnitudes for given parameter values
Parameters
----------
theta: array-like
Set of theta values for each SN
Av: array-like
Set of host extinction values for each SN
W0: array-like
Global W0 matrix
W1: array-like
Global W1 matrix
eps: array-like
Set of epsilon values for each SN, describing residual colour variation
Ds: array-like
Set of distance moduli for each SN
Rv: float
Global R_V value for host extinction (need to allow this to be variable in future)
band_indices: array-like
Array containing indices describing which filter each observation is in
mask: array-like
Array containing mask describing whether observations should contribute to the posterior
J_t: array-like
Matrix for cubic spline interpolation in time axis for each SN
hsiao_interp: array-like
Array containing Hsiao template spectra for each t value, comprising model for previous day, next day and
t % 1 to allow for linear interpolation
weights: array_like
Array containing band weights to use for photometry
Returns
-------
model_mag: array-like
Matrix containing model magnitudes for all SNe at all time-steps
"""
model_flux = self.get_flux_batch(theta, Av, W0, W1, eps, Ds, Rv, band_indices, mask, J_t, hsiao_interp, weights)
#model_flux = model_flux / self.device_scale
model_flux = model_flux + (1 - mask) * 0.01
#zps = self.zps[band_indices]
#model_mag = - 2.5 * jnp.log10(model_flux) + zps # self.M0 + Ds
model_mag = - 2.5 * jnp.log10(model_flux) + 27.5 # self.M0 + Ds
model_mag *= mask
return model_mag
@staticmethod
def spline_coeffs_irr_step(x_now, x, invkd):
"""
Vectorized version of cubic spline coefficient calculator found in spline_utils
Parameters
----------
x_now: array-like
Current x location to calculate spline knots for
x: array-like
Numpy array containing the locations of the spline knots.
invkd: array-like
Precomputed matrix for generating second derivatives. Can be obtained
from the output of ``spline_utils.invKD_irr``.
Returns
-------
X: Set of spline coefficients for each x knot
"""
X = jnp.zeros_like(x)
up_extrap = x_now > x[-1]
down_extrap = x_now < x[0]
interp = 1 - up_extrap - down_extrap
h = x[-1] - x[-2]
a = (x[-1] - x_now) / h
b = 1 - a
f = (x_now - x[-1]) * h / 6.0
X = X.at[-2].set(X[-2] + a * up_extrap)
X = X.at[-1].set(X[-1] + b * up_extrap)
X = X.at[:].set(X[:] + f * invkd[-2, :] * up_extrap)
h = x[1] - x[0]
b = (x_now - x[0]) / h
a = 1 - b
f = (x_now - x[0]) * h / 6.0
X = X.at[0].set(X[0] + a * down_extrap)
X = X.at[1].set(X[1] + b * down_extrap)
X = X.at[:].set(X[:] - f * invkd[1, :] * down_extrap)
q = jnp.argmax(x_now < x) - 1
h = x[q + 1] - x[q]
a = (x[q + 1] - x_now) / h
b = 1 - a
c = ((a ** 3 - a) / 6) * h ** 2
d = ((b ** 3 - b) / 6) * h ** 2
X = X.at[q].set(X[q] + a * interp)
X = X.at[q + 1].set(X[q + 1] + b * interp)
X = X.at[:].set(X[:] + c * invkd[q, :] * interp + d * invkd[q + 1, :] * interp)
return X
def fit_model_vi(self, obs, weights):
    """
    Numpyro model for fitting individual SNe with the global model parameters held fixed. Samples AV, theta,
    tmax, epsilon (via a standard-normal ``eps_tform`` site) and the distance modulus for each SN.

    Parameters
    ----------
    obs: array-like
        Data to fit, from output of process_dataset. The last axis runs over SNe. Along the first axis this
        model reads row 0 (times, from which tmax is subtracted), row 1 (observed values), row 2 (their
        uncertainties), row -6 (band indices), row -3 (distance modulus used to centre the prior on Ds) and
        row -1 (mask).
    weights: array-like
        Band-weights to calculate photometry
    """
    n_sne = obs.shape[-1]
    n_l, n_tau = self.l_knots.shape[0], self.tau_knots.shape[0]
    # epsilon is only free on the interior wavelength knots
    n_eps = (n_l - 2) * n_tau

    with numpyro.plate('SNe', n_sne) as sn_index:
        Av = numpyro.sample('AV', My_Exponential(1 / self.tauA))
        theta = numpyro.sample('theta', dist.Normal(0, 1.0))
        tmax = numpyro.sample('tmax', dist.Normal(0, 5))

        # Times relative to the sampled time of maximum
        phase = obs[0, ...] - tmax[None, sn_index]
        # Lower/upper Hsiao template indices and the fraction between them (19 is presumably the offset of
        # phase zero in the template grid - TODO confirm)
        hsiao_interp = jnp.array([19 + jnp.floor(phase), 19 + jnp.ceil(phase), jnp.remainder(phase, 1)])
        phase_shape = phase.shape
        flat_phase = phase.flatten(order='F')
        J_t = self.J_t_map(flat_phase, self.tau_knots, self.KD_t)
        J_t = J_t.reshape((*phase_shape, n_tau), order='F').transpose(1, 2, 0)

        # Non-centred parameterisation: draw standard normals and transform them with L_Sigma
        eps_tform = numpyro.sample('eps_tform', dist.MultivariateNormal(jnp.zeros(n_eps), jnp.eye(n_eps)))
        eps_flat = numpyro.deterministic('eps', jnp.matmul(self.L_Sigma, eps_tform.T))
        eps_inner = jnp.reshape(eps_flat.T, (n_sne, n_l - 2, n_tau), order='F')
        # First and last wavelength knots are pinned to zero
        eps = jnp.zeros((n_sne, n_l, n_tau)).at[:, 1:-1, :].set(eps_inner)

        band_indices = obs[-6, :, sn_index].astype(int).T
        muhat = obs[-3, 0, sn_index]
        mask = obs[-1, :, sn_index].T.astype(bool)

        # Broad prior on the distance modulus around muhat
        muhat_err = 5
        Ds_err = jnp.sqrt(muhat_err * muhat_err + self.sigma0 * self.sigma0)
        Ds = numpyro.sample('Ds', dist.Normal(muhat, Ds_err))

        flux = self.get_flux_batch(theta, Av, self.W0, self.W1, eps, Ds, self.Rv, band_indices, mask,
                                   J_t, hsiao_interp, weights)

        flux_err = obs[2, :, sn_index].T
        observed_flux = obs[1, :, sn_index].T
        with numpyro.handlers.mask(mask=mask):
            numpyro.sample('obs', dist.Normal(flux, flux_err), obs=observed_flux)
def fit_model_vi_no_eps(self, obs, weights):
    """
    Numpyro model for fitting individual SNe with the global model parameters held fixed and epsilon fixed to
    zero. Samples AV, theta, tmax and the distance modulus for each SN.

    Parameters
    ----------
    obs: array-like
        Data to fit, from output of process_dataset. The last axis runs over SNe. Along the first axis this
        model reads row 0 (times, from which tmax is subtracted), row 1 (observed values), row 2 (their
        uncertainties), row -6 (band indices), row -3 (distance modulus used to centre the prior on Ds) and
        row -1 (mask).
    weights: array-like
        Band-weights to calculate photometry
    """
    sample_size = obs.shape[-1]
    n_l_knots = self.l_knots.shape[0]
    n_tau_knots = self.tau_knots.shape[0]
    N_knots_sig = (n_l_knots - 2) * n_tau_knots

    with numpyro.plate('SNe', sample_size) as sn_index:
        Av = numpyro.sample(f'AV', My_Exponential(1 / self.tauA))
        theta = numpyro.sample(f'theta', dist.Normal(0, 1.0))
        tmax = numpyro.sample('tmax', dist.Normal(0, 5))

        # Times relative to the sampled time of maximum
        t = obs[0, ...] - tmax[None, sn_index]
        hsiao_interp = jnp.array([19 + jnp.floor(t), 19 + jnp.ceil(t), jnp.remainder(t, 1)])
        keep_shape = t.shape
        t = t.flatten(order='F')
        J_t = self.J_t_map(t, self.tau_knots, self.KD_t).reshape((*keep_shape, n_tau_knots),
                                                                 order='F').transpose(1, 2, 0)

        # Epsilon is switched off. The 'eps' site is still recorded exactly as before (a zero vector of
        # length N_knots_sig) so that anything reading it keeps working.
        numpyro.deterministic('eps', jnp.matmul(self.L_Sigma, jnp.zeros(N_knots_sig)))
        # Build the zero epsilon array for every SN directly. Reshaping the single length-N_knots_sig vector
        # to (sample_size, ...) only worked when sample_size == 1.
        eps = jnp.zeros((sample_size, n_l_knots, n_tau_knots))

        band_indices = obs[-6, :, sn_index].astype(int).T
        muhat = obs[-3, 0, sn_index]
        mask = obs[-1, :, sn_index].T.astype(bool)

        # Broad prior on the distance modulus around muhat
        muhat_err = 5
        Ds_err = jnp.sqrt(muhat_err * muhat_err + self.sigma0 * self.sigma0)
        Ds = numpyro.sample('Ds', dist.Normal(muhat, Ds_err))

        flux = self.get_flux_batch(theta, Av, self.W0, self.W1, eps, Ds, self.Rv, band_indices, mask,
                                   J_t, hsiao_interp, weights)

        with numpyro.handlers.mask(mask=mask):
            numpyro.sample(f'obs', dist.Normal(flux, obs[2, :, sn_index].T),
                           obs=obs[1, :, sn_index].T)
def fit_model_mcmc(self, obs, weights, epsilons_on=True):
    """
    Numpyro model for fitting individual SNe with the global model parameters held fixed. Samples AV, theta,
    tmax, epsilon (via a standard-normal ``eps_tform`` site) and the distance modulus for each SN.

    Parameters
    ----------
    obs: array-like
        Data to fit, from output of process_dataset. The last axis runs over SNe. Along the first axis this
        model reads row 0 (times, from which tmax is subtracted), row 1 (observed values), row 2 (their
        uncertainties), row -6 (band indices), row -3 (distance modulus used to centre the prior on Ds) and
        row -1 (mask).
    weights: array-like
        Band-weights to calculate photometry
    epsilons_on: bool, optional
        Currently unused by this model; epsilon is always sampled.
    """
    num_sne = obs.shape[-1]
    num_l = self.l_knots.shape[0]
    num_tau = self.tau_knots.shape[0]
    # epsilon is only free on the interior wavelength knots
    num_eps = (num_l - 2) * num_tau

    with numpyro.plate('SNe', num_sne) as sn_index:
        Av = numpyro.sample('AV', dist.Exponential(1 / self.tauA))  # TODO CHANGE THIS BACK
        theta = numpyro.sample('theta', dist.Normal(0, 1.0))
        tmax = numpyro.sample('tmax', dist.Normal(0, 5))

        # Times relative to the sampled time of maximum
        rest_phase = obs[0, ...] - tmax[None, sn_index]
        hsiao_interp = jnp.array(
            [19 + jnp.floor(rest_phase), 19 + jnp.ceil(rest_phase), jnp.remainder(rest_phase, 1)]
        )
        grid_shape = rest_phase.shape
        phase_vector = rest_phase.flatten(order='F')
        spline_rows = self.J_t_map(phase_vector, self.tau_knots, self.KD_t)
        J_t = spline_rows.reshape((*grid_shape, num_tau), order='F').transpose(1, 2, 0)

        # Non-centred parameterisation: draw standard normals and transform them with L_Sigma
        unit_normal = dist.MultivariateNormal(jnp.zeros(num_eps), jnp.eye(num_eps))
        eps_tform = numpyro.sample('eps_tform', unit_normal)
        eps_flat = numpyro.deterministic('eps', jnp.matmul(self.L_Sigma, eps_tform.T))
        eps_interior = jnp.reshape(eps_flat.T, (num_sne, num_l - 2, num_tau), order='F')
        # First and last wavelength knots are pinned to zero
        eps_padded = jnp.zeros((num_sne, num_l, num_tau))
        eps = eps_padded.at[:, 1:-1, :].set(eps_interior)

        band_indices = obs[-6, :, sn_index].astype(int).T
        muhat = obs[-3, 0, sn_index]
        mask = obs[-1, :, sn_index].T.astype(bool)

        # Broad prior on the distance modulus around muhat
        muhat_err = 5
        Ds_err = jnp.sqrt(muhat_err * muhat_err + self.sigma0 * self.sigma0)
        Ds = numpyro.sample('Ds', dist.Normal(muhat, Ds_err))

        flux = self.get_flux_batch(theta, Av, self.W0, self.W1, eps, Ds, self.Rv, band_indices, mask,
                                   J_t, hsiao_interp, weights)

        likelihood = dist.Normal(flux, obs[2, :, sn_index].T)
        with numpyro.handlers.mask(mask=mask):
            numpyro.sample('obs', likelihood, obs=obs[1, :, sn_index].T)
def fit_model_mcmc_no_eps(self, obs, weights):
    """
    Numpyro model used for fitting SN properties assuming fixed global properties from a trained model, with the
    epsilon residuals switched off. Samples AV, theta, tmax and Ds for every SN; epsilon is held at zero and is
    NOT a free parameter of this model (use fit_model_mcmc to sample it).

    Parameters
    ----------
    obs: array-like
        Data to fit, from output of process_dataset. The last axis indexes SNe. Row 0 is used as the observation
        times, row 1 as the observed values, row 2 as their Gaussian scale, row -6 as band indices, row -3 as
        muhat and row -1 as the observation mask.
    weights: array-like
        Band-weights to calculate photometry

    """
    sample_size = obs.shape[-1]
    n_l_knots = self.l_knots.shape[0]
    n_tau_knots = self.tau_knots.shape[0]
    N_knots_sig = (n_l_knots - 2) * n_tau_knots

    with numpyro.plate('SNe', sample_size) as sn_index:
        Av = numpyro.sample('AV', dist.Exponential(1 / self.tauA))
        theta = numpyro.sample('theta', dist.Normal(0, 1.0))
        tmax = numpyro.sample('tmax', dist.Normal(0, 5))

        # Shift the observation times by each SN's sampled tmax
        t = obs[0, ...] - tmax[None, sn_index]
        # Lower index, upper index and fractional part for interpolating in time; the offset of 19 presumably
        # maps phase onto the row index of the Hsiao template grid - TODO confirm
        hsiao_interp = jnp.array([19 + jnp.floor(t), 19 + jnp.ceil(t), jnp.remainder(t, 1)])
        keep_shape = t.shape
        t = t.flatten(order='F')
        J_t = self.J_t_map(t, self.tau_knots, self.KD_t).reshape((*keep_shape, n_tau_knots),
                                                                 order='F').transpose(1, 2, 0)

        # Epsilon is fixed to zero here. The 'eps' site is still recorded, with the same
        # (N_knots_sig, sample_size) layout that fit_model_mcmc produces, so both models expose the same sites.
        # Building the zeros per SN also avoids reshaping a single length-N_knots_sig vector into
        # (sample_size, ...), which only succeeds when sample_size == 1.
        numpyro.deterministic('eps', jnp.zeros((N_knots_sig, sample_size)))
        eps = jnp.zeros((sample_size, n_l_knots, n_tau_knots))

        band_indices = obs[-6, :, sn_index].astype(int).T
        muhat = obs[-3, 0, sn_index]
        mask = obs[-1, :, sn_index].T.astype(bool)

        # Hard-coded, deliberately wide uncertainty on muhat, combined in quadrature with sigma0
        muhat_err = 5
        Ds_err = jnp.sqrt(muhat_err * muhat_err + self.sigma0 * self.sigma0)
        Ds = numpyro.sample('Ds', dist.Normal(muhat, Ds_err))

        flux = self.get_flux_batch(theta, Av, self.W0, self.W1, eps, Ds, self.Rv, band_indices, mask,
                                   J_t, hsiao_interp, weights)

        # Only entries flagged by the mask contribute to the likelihood
        with numpyro.handlers.mask(mask=mask):
            numpyro.sample('obs', dist.Normal(flux, obs[2, :, sn_index].T),
                           obs=obs[1, :, sn_index].T)
def fit(self, num_samples, num_warmup, num_chains, output, epsilons_on, model_path=None, chain_method='parallel',
        init_strategy='median'):
    """
    Function to run fitting process and save chains and fit statistics. I'm still experimenting with the best way to
    do this - you can either run lots of separate HMC processes or you can do one big process which fits all SNe
    (with different SN parameters treated as independent). The latter has advantages as all flux integrals across
    all objects are calculated in one tensor operation, but the downside is that it can make it more difficult to
    converge as the parameter space grows.

    Runs NUTS on self.data / self.band_weights, prints the MCMC summary and elapsed time, and passes the samples
    (grouped by chain) to self.fit_postprocess_samples together with output.

    Parameters
    ----------
    num_samples: int
        Number of posterior samples
    num_warmup: int
        Number of warmup steps before sampling
    num_chains: int
        Number of chains
    output: str
        Name of output directory which will store results
    epsilons_on: bool
        If truthy, the model is fit_model_mcmc; otherwise fit_model_mcmc_no_eps is used.
    model_path: str, optional
        Name of directory containing model parameters to use for fitting. I'm using this for now to keep my
        numpyro trained models separate from T21/M20/W22 etc. until we're confident with them. Defaults to None,
        which means that the model loaded when initialising the SEDmodel object is used. When given,
        results/<model_path>/chains.pkl is loaded and W0, W1, Rv, sigma0 and tauA are replaced by their means
        over the first two axes of the stored chains (L_Sigma is left unchanged).
    chain_method: str, optional
        Method used to distribute different chains, defaults to parallel. Options are:
        ``'sequential'`` | Chains are run one after the other.
        ``'parallel'``   | Chains are spread in parallel across all available devices and run simultaneously. If you
                         | try to run more chains than there are devices available, numpyro will automatically revert
                         | from parallel to sequential
        ``'vectorized'`` | Chains are run simultaneously on a single device. Only really advisable on a GPU, and
                         | will probably lead to a memory error on CPU
    init_strategy: str, optional
        Strategy to use for initialisation, default to median. Options are:
        ``'median'`` | Chains are initialised to prior median
        ``'value'``  | Chains are initialised to the values returned by self.fit_initial_guess()
        ``'map'``    | Chains are initialised to the values returned by self.map_initial_guess(mode='fit')
        ``'sample'`` | Chains are initialised to a random sample from the priors

    Raises
    ------
    ValueError
        If init_strategy is not one of the options listed above.

    """
    if init_strategy == 'median':
        init_strategy = init_to_median()
    elif init_strategy == 'value':
        init_strategy = init_to_value(values=self.fit_initial_guess())
    elif init_strategy == 'map':
        # The first positional parameter of init_to_value is `site`, so the guess has to be passed as `values=`
        init_strategy = init_to_value(values=self.map_initial_guess(mode='fit'))
    elif init_strategy == 'sample':
        init_strategy = init_to_sample()
    else:
        raise ValueError('Invalid init strategy, must be one of median, value, map or sample')

    if model_path is not None:
        # NOTE: pickle.load can execute arbitrary code - only point model_path at chains you trust
        with open(os.path.join('results', model_path, 'chains.pkl'), 'rb') as file:
            result = pickle.load(file)
        knot_shape = (self.l_knots.shape[0], self.tau_knots.shape[0])
        # Posterior means are taken over the first two axes of each stored array
        self.W0 = device_put(np.reshape(np.mean(result['W0'], axis=(0, 1)), knot_shape, order='F'))
        self.W1 = device_put(np.reshape(np.mean(result['W1'], axis=(0, 1)), knot_shape, order='F'))
        # self.L_Sigma is intentionally not rebuilt from the loaded chains here
        self.Rv = device_put(np.mean(result['Rv'], axis=(0, 1)))
        self.sigma0 = device_put(np.mean(result['sigma0'], axis=(0, 1)))
        self.tauA = device_put(np.mean(result['tauA'], axis=(0, 1)))

    # Fixed seed; the first half of the split is the key handed to MCMC, the second half is not needed
    rng, _ = split(PRNGKey(321))

    model = self.fit_model_mcmc if epsilons_on else self.fit_model_mcmc_no_eps
    nuts_kernel = NUTS(model, adapt_step_size=True, init_strategy=init_strategy, max_tree_depth=10)

    start = timeit.default_timer()
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, num_chains=num_chains,
                chain_method=chain_method)
    mcmc.run(rng, self.data, self.band_weights)
    mcmc.print_summary()
    samples = mcmc.get_samples(group_by_chain=True)
    end = timeit.default_timer()
    print('original: ', end - start)

    self.fit_postprocess_samples(samples, output)
def fit_mcmc_vmap(self, data, band_weights, num_samples=250, num_warmup=250, num_chains=4, model_path=None, chain_method='parallel',
init_strategy='median'):
"""
Function to run fitting process and save chains and fit statistics. I'm still experimenting with the best way to
do this - you can either run lots of separate HMC processes or you can do one big process which fits all SNe
(with different SN parameters treated as independent). The latter has advantages as all flux integrals across
all objects are calculated in one tensor operation, but the downside is that it can make it more difficult to
converge as the parameter space grows.
Parameters
----------
num_samples: int
Number of posterior samples
num_warmup: int
Number of warmup steps before sampling
num_chains: int
Number of chains
output: str
Name of output directory which will store results
model_path: str, optional
Name of directory containing model parameters to use for fitting. I'm using this for now to keep my
numpyro trained models separate from T21/M20/W22 etc. until we're confident with them. Defaults to None,
which means that the model loaded when initialising the SEDmodel object is used.