Skip to content

Commit

Permalink
Merge pull request #46 from robjmcgibbon/interactive_plot
Browse files Browse the repository at this point in the history
Interactive plot
  • Loading branch information
JBorrow authored Aug 12, 2024
2 parents 6bc0f31 + 3dce39e commit cebb4fd
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 12 deletions.
27 changes: 26 additions & 1 deletion docs/source/emulator_analysis/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,31 @@ This is implemented into the SWIFT-Emulator with
sweep as `ModelValues` and `ModelParameters`
containers, that are easy to parse.

Interactive plots
-----------------

Another way to explore the effect of varying the parameters is
to try an interactive plot. Every emulator object contains an
`interactive_plot` method. This generates a plot with a slider
for each parameter. The plot will update to show the emulator
predictions when sliders are adjusted. The emulator will make
its initial prediction using the parameter values passed to it.
If no parameters are passed if will default to the midpoint of
each parameter range. It is also possible to pass reference data
to overplot on the emulator predictions. If no reference data is
passed the plot will display a fixed dashed line corresponding to
the prediction using the initial parameter values.

.. code-block:: python
schecter_emulator.interactive_plot(predict_x, initial_params=center,
xlabel="Stellar mass", ylabel="dn/dlogM",
x_data=[10.5, 11, 11.5],
y_data=[-10, -11, -12])
.. image:: interactive_plot.png


Model Parameters Features
-------------------------

Expand Down Expand Up @@ -234,4 +259,4 @@ This method is a lot slower than the default hyperparameter
optimisation, and may take some time to compute. The main
take away from plots like this is to see whether the
hyperparameters are converged, and whether they are
consistent with the faster optimisation method.
consistent with the faster optimisation method.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
101 changes: 98 additions & 3 deletions swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def predict_values(
Parameters
----------
independent, np.array
independent: np.array
Independent continuous variables to evaluate the emulator
at. If the emulator is discrete, these are only allowed to be
the discrete independent variables that the emulator was trained at
Expand All @@ -98,12 +98,12 @@ def predict_values(
Returns
-------
dependent_predictions, np.array
dependent_predictions: np.array
Array of predictions, if the emulator is a function f, these
are the predicted values of f(independent) evaluted at the position
of the input ``model_parameters``.
dependent_prediction_errors, np.array
dependent_prediction_errors: np.array
Errors on the model predictions. For models where the errors are
unconstrained, this is an array of zeroes.
Expand All @@ -123,3 +123,98 @@ def predict_values(
)

raise NotImplementedError

def interactive_plot(
self,
x: np.array,
initial_params: Dict[str, float] = {},
xlabel: str = "",
ylabel: str = "",
x_data: np.array = None,
y_data: np.array = None,
):
"""
Generates an interactive plot which displays the emulator predictions.
If initial_params should contain the initial parameter values to make a
prediction for. If initial_params is not passed the midpoint of each of
the parameter values will be used instead. If no reference data is
passed to be overplotted then the plot will display a line which
corresponds to the predictions for the initial parameter values.
Parameters
----------
x: np.array
Array of data for which the emulator should make predictions.
initial_params: Dict[str, float], optional
What parameters values to plot the predicition for initally.
If missing the midpoint of each parameter range will be used.
xlabel: str, optional
Label for horizontal axis on the resultant figure.
ylabel: str, optional
Label for vertical axis on the resultant figure.
x_data: np.array, optional
Array containing x-values of reference data to plot.
y_data: np.array, optional
Array containing y-values of reference data to plot.
Must be the same shape as x_data
"""
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

fig, ax = plt.subplots()
model_specification = self.model_specification
sliders = []
n_param = model_specification.number_of_parameters
fig.subplots_adjust(bottom=0.12 + n_param * 0.1)
for i in range(n_param):
# Extracting information needed for slider
name = model_specification.parameter_names[i]
lo_lim = sorted(model_specification.parameter_limits[i])[0]
hi_lim = sorted(model_specification.parameter_limits[i])[1]
if not name in initial_params:
initial_params[name] = (lo_lim + hi_lim) / 2

# Adding slider
printable_name = name
if model_specification.parameter_printable_names:
printable_name = model_specification.parameter_printable_names[i]
slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1])
slider = Slider(
ax=slider_ax,
label=printable_name,
valmin=lo_lim,
valmax=hi_lim,
valinit=initial_params[name],
)
sliders.append(slider)

# Plotting lines and reference data
pred, pred_var = self.predict_values(x, initial_params)
if (x_data is None) or (y_data is None):
ax.plot(x, pred, "k--")
else:
ax.plot(x_data, y_data, "k.")
(line,) = ax.plot(x, pred)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

# Define and enable update function
def update(val):
params = {
model_specification.parameter_names[i]: sliders[i].val
for i in range(n_param)
}
pred, pred_var = self.predict_values(x, params)
line.set_ydata(pred)

for slider in sliders:
slider.on_changed(update)

plt.show()
plt.close()
16 changes: 10 additions & 6 deletions swiftemulator/emulators/multi_gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,16 @@ def predict_values(

for index, (low, high) in enumerate(self.independent_regions):
mask = np.logical_and(
independent > low
if low is not None
else np.ones_like(independent).astype(bool),
independent < high
if high is not None
else np.ones_like(independent).astype(bool),
(
independent > low
if low is not None
else np.ones_like(independent).astype(bool)
),
(
independent < high
if high is not None
else np.ones_like(independent).astype(bool)
),
)

predicted, errors = self.emulators[index].predict_values(
Expand Down
4 changes: 2 additions & 2 deletions swiftemulator/io/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def load_pipeline_outputs(
"adaptive_mass_function",
"histogram",
]
recursive_search = (
lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
recursive_search = lambda d, k: (
d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
)
line_search = lambda d: recursive_search(d, line_types)

Expand Down

0 comments on commit cebb4fd

Please sign in to comment.