Skip to content

Commit

Permalink
Run formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
robjmcgibbon committed Feb 6, 2024
1 parent 1b49b50 commit 89c67f1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
22 changes: 14 additions & 8 deletions swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,38 +124,43 @@ def predict_values(

raise NotImplementedError

def interactive_plot(self, x: np.array, xlabel: str = '', ylabel: str = ''):
def interactive_plot(self, x: np.array, xlabel: str = "", ylabel: str = ""):
"""
Generates an interactive plot over which shows the emulator predictions
for the input data passed to this method
"""
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

fig, ax = plt.subplots()
model_specification = self.model_specification
param_means = {}
sliders = []
n_param = model_specification.number_of_parameters
fig.subplots_adjust(bottom=0.12+n_param*0.1)
fig.subplots_adjust(bottom=0.12 + n_param * 0.1)
for i in range(n_param):
# Extracting information needed for slider
name = model_specification.parameter_names[i]
lo_lim = sorted(model_specification.parameter_limits[i])[0]
hi_lim = sorted(model_specification.parameter_limits[i])[1]
param_means[name] = (lo_lim + hi_lim)/2
param_means[name] = (lo_lim + hi_lim) / 2

# Adding slider
if model_specification.parameter_printable_names:
name = model_specification.parameter_printable_names[i]
slider_ax = fig.add_axes([0.35, i*0.1, 0.3, 0.1])
slider = Slider(ax=slider_ax, label=name,
valmin=lo_lim, valmax=hi_lim,
valinit=(lo_lim + hi_lim)/2)
slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1])
slider = Slider(
ax=slider_ax,
label=name,
valmin=lo_lim,
valmax=hi_lim,
valinit=(lo_lim + hi_lim) / 2,
)
sliders.append(slider)

# Setting up initial value
pred, pred_var = self.predict_values(x, param_means)
line, = ax.plot(x, pred)
(line,) = ax.plot(x, pred)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

Expand All @@ -167,6 +172,7 @@ def update(val):
}
pred, pred_var = self.predict_values(x, params)
line.set_ydata(pred)

for slider in sliders:
slider.on_changed(update)

Expand Down
16 changes: 10 additions & 6 deletions swiftemulator/emulators/multi_gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,16 @@ def predict_values(

for index, (low, high) in enumerate(self.independent_regions):
mask = np.logical_and(
independent > low
if low is not None
else np.ones_like(independent).astype(bool),
independent < high
if high is not None
else np.ones_like(independent).astype(bool),
(
independent > low
if low is not None
else np.ones_like(independent).astype(bool)
),
(
independent < high
if high is not None
else np.ones_like(independent).astype(bool)
),
)

predicted, errors = self.emulators[index].predict_values(
Expand Down
4 changes: 2 additions & 2 deletions swiftemulator/io/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def load_pipeline_outputs(
"adaptive_mass_function",
"histogram",
]
recursive_search = (
lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
recursive_search = lambda d, k: (
d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
)
line_search = lambda d: recursive_search(d, line_types)

Expand Down

0 comments on commit 89c67f1

Please sign in to comment.