Skip to content

Commit

Permalink
Add initial_params argument
Browse files Browse the repository at this point in the history
  • Loading branch information
robjmcgibbon committed Feb 7, 2024
1 parent f42aba2 commit 3dce39e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
16 changes: 11 additions & 5 deletions docs/source/emulator_analysis/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,23 @@ Another way to explore the effect of varying the parameters is
to try an interactive plot. Every emulator object contains an
`interactive_plot` method. This generates a plot with a slider
for each parameter. The plot will update to show the emulator
predictions when sliders are adjusted.
predictions when sliders are adjusted. The emulator will make
its initial prediction using the parameter values passed to it.
If no parameters are passed if will default to the midpoint of
each parameter range. It is also possible to pass reference data
to overplot on the emulator predictions. If no reference data is
passed the plot will display a fixed dashed line corresponding to
the prediction using the initial parameter values.

.. code-block:: python
schecter_emulator.interactive_plot(predict_x, xlabel="Stellar mass", ylabel="dn/dlogM")
schecter_emulator.interactive_plot(predict_x, initial_params=center,
xlabel="Stellar mass", ylabel="dn/dlogM",
x_data=[10.5, 11, 11.5],
y_data=[-10, -11, -12])
.. image:: interactive_plot.png

It is possible to pass reference data to be plotted when calling
:meth:`swiftemulator.emulators.base.BaseEmulator.interactive\_plot`.


Model Parameters Features
-------------------------
Expand Down
Binary file modified docs/source/emulator_analysis/interactive_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 18 additions & 10 deletions swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,30 @@ def predict_values(
def interactive_plot(
self,
x: np.array,
initial_params: Dict[str, float] = {},
xlabel: str = "",
ylabel: str = "",
x_data: np.array = None,
y_data: np.array = None,
):
"""
Generates an interactive plot which displays the emulator predictions.
If no reference data is passed to be overplotted then the plot will
display a line which corresponds to the predictions for the mean
of the parameter values.
If initial_params should contain the initial parameter values to make a
prediction for. If initial_params is not passed the midpoint of each of
the parameter values will be used instead. If no reference data is
passed to be overplotted then the plot will display a line which
corresponds to the predictions for the initial parameter values.
Parameters
----------
x: np.array
Array of data for which the emulator should make predictions.
initial_params: Dict[str, float], optional
What parameters values to plot the predicition for initally.
If missing the midpoint of each parameter range will be used.
xlabel: str, optional
Label for horizontal axis on the resultant figure.
Expand All @@ -162,7 +169,6 @@ def interactive_plot(

fig, ax = plt.subplots()
model_specification = self.model_specification
param_means = {}
sliders = []
n_param = model_specification.number_of_parameters
fig.subplots_adjust(bottom=0.12 + n_param * 0.1)
Expand All @@ -171,23 +177,25 @@ def interactive_plot(
name = model_specification.parameter_names[i]
lo_lim = sorted(model_specification.parameter_limits[i])[0]
hi_lim = sorted(model_specification.parameter_limits[i])[1]
param_means[name] = (lo_lim + hi_lim) / 2
if not name in initial_params:
initial_params[name] = (lo_lim + hi_lim) / 2

# Adding slider
printable_name = name
if model_specification.parameter_printable_names:
name = model_specification.parameter_printable_names[i]
printable_name = model_specification.parameter_printable_names[i]
slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1])
slider = Slider(
ax=slider_ax,
label=name,
label=printable_name,
valmin=lo_lim,
valmax=hi_lim,
valinit=(lo_lim + hi_lim) / 2,
valinit=initial_params[name],
)
sliders.append(slider)

# Setting up initial value
pred, pred_var = self.predict_values(x, param_means)
# Plotting lines and reference data
pred, pred_var = self.predict_values(x, initial_params)
if (x_data is None) or (y_data is None):
ax.plot(x, pred, "k--")
else:
Expand Down

0 comments on commit 3dce39e

Please sign in to comment.