Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
robjmcgibbon committed Feb 6, 2024
1 parent fc05765 commit 318f165
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
20 changes: 19 additions & 1 deletion docs/source/emulator_analysis/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,24 @@ This is implemented into the SWIFT-Emulator with
sweep as `ModelValues` and `ModelParameters`
containers, that are easy to parse.

Interactive plots
-----------------

Another way to explore the effect of varying the parameters is
to try an interactive plot. This provides a slider for each
parameter, and the plot will update to show the emulator
predictions when you adjust the sliders.

.. code-block:: python
schecter_emulator.interactive_plot(independent)
.. image:: interactive_plot.png

It is possible to pass reference data to be plotted when calling
:meth:`swiftemulator.emulators.base.BaseEmulator.interactive\_plot`.


Model Parameters Features
-------------------------

Expand Down Expand Up @@ -234,4 +252,4 @@ This method is a lot slower than the default hyperparameter
optimisation, and may take some time to compute. The main
take away from plots like this is to see whether the
hyperparameters are converged, and whether they are
consistent with the faster optimisation method.
consistent with the faster optimisation method.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 36 additions & 8 deletions swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def predict_values(
Parameters
----------
independent, np.array
independent: np.array
Independent continuous variables to evaluate the emulator
at. If the emulator is discrete, these are only allowed to be
the discrete independent variables that the emulator was trained at
Expand All @@ -98,12 +98,12 @@ def predict_values(
Returns
-------
dependent_predictions, np.array
dependent_predictions: np.array
Array of predictions, if the emulator is a function f, these
are the predicted values of f(independent) evaluted at the position
of the input ``model_parameters``.
dependent_prediction_errors, np.array
dependent_prediction_errors: np.array
Errors on the model predictions. For models where the errors are
unconstrained, this is an array of zeroes.
Expand All @@ -124,10 +124,38 @@ def predict_values(

raise NotImplementedError

def interactive_plot(self, x: np.array, xlabel: str = "", ylabel: str = "", x_data: np.array = None, y_data: np.array = None):
def interactive_plot(
self,
x: np.array,
xlabel: str = "",
ylabel: str = "",
x_data: np.array = None,
y_data: np.array = None,
):
"""
Generates an interactive plot over which shows the emulator predictions
for the input data passed to this method
Generates an interactive plot which displays the emulator predictions.
If no reference data is passed to be overplotted then the plot will
display a line which corresponds to the predictions for the mean
of the parameter values.
Parameters
----------
x: np.array
Array of data for which the emulator should make predictions.
xlabel: str, optional
Label for horizontal axis on the resultant figure.
ylabel: str, optional
Label for vertical axis on the resultant figure.
x_data: np.array, optional
Array containing x-values of reference data to plot.
y_data: np.array, optional
Array containing y-values of reference data to plot.
Must be the same shape as x_data
"""
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
Expand Down Expand Up @@ -161,9 +189,9 @@ def interactive_plot(self, x: np.array, xlabel: str = "", ylabel: str = "", x_da
# Setting up initial value
pred, pred_var = self.predict_values(x, param_means)
if (x_data is None) or (y_data is None):
ax.plot(x, pred, 'k--')
ax.plot(x, pred, "k--")
else:
ax.plot(x_data, y_data, 'k.')
ax.plot(x_data, y_data, "k.")
(line,) = ax.plot(x, pred)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
Expand Down

0 comments on commit 318f165

Please sign in to comment.