diff --git a/docs/source/emulator_analysis/index.rst b/docs/source/emulator_analysis/index.rst index 0c36324..4ebc0c2 100644 --- a/docs/source/emulator_analysis/index.rst +++ b/docs/source/emulator_analysis/index.rst @@ -153,6 +153,24 @@ This is implemented into the SWIFT-Emulator with sweep as `ModelValues` and `ModelParameters` containers, that are easy to parse. +Interactive plots +----------------- + +Another way to explore the effect of varying the parameters is +to try an interactive plot. This provides a slider for each +parameter, and the plot will update to show the emulator +predictions when you adjust the sliders. + +.. code-block:: python + + schecter_emulator.interactive_plot(independent) + +.. image:: interactive_plot.png + +It is possible to pass reference data to be plotted when calling +:meth:`swiftemulator.emulators.base.BaseEmulator.interactive\_plot`. + + Model Parameters Features ------------------------- @@ -234,4 +252,4 @@ This method is a lot slower than the default hyperparameter optimisation, and may take some time to compute. The main take away from plots like this is to see whether the hyperparameters are converged, and whether they are -consistent with the faster optimisation method. \ No newline at end of file +consistent with the faster optimisation method. diff --git a/docs/source/emulator_analysis/interactive_plot.png b/docs/source/emulator_analysis/interactive_plot.png new file mode 100644 index 0000000..0fc8ea5 Binary files /dev/null and b/docs/source/emulator_analysis/interactive_plot.png differ diff --git a/swiftemulator/emulators/base.py b/swiftemulator/emulators/base.py index 1fcadaf..298d9b2 100644 --- a/swiftemulator/emulators/base.py +++ b/swiftemulator/emulators/base.py @@ -85,7 +85,7 @@ def predict_values( Parameters ---------- - independent, np.array + independent: np.array Independent continuous variables to evaluate the emulator at. If the emulator is discrete, these are only allowed to be the discrete independent variables that the emulator was trained at @@ -98,12 +98,12 @@ def predict_values( Returns ------- - dependent_predictions, np.array + dependent_predictions: np.array Array of predictions, if the emulator is a function f, these are the predicted values of f(independent) evaluted at the position of the input ``model_parameters``. - dependent_prediction_errors, np.array + dependent_prediction_errors: np.array Errors on the model predictions. For models where the errors are unconstrained, this is an array of zeroes. @@ -124,10 +124,38 @@ def predict_values( raise NotImplementedError - def interactive_plot(self, x: np.array, xlabel: str = "", ylabel: str = "", x_data: np.array = None, y_data: np.array = None): + def interactive_plot( + self, + x: np.array, + xlabel: str = "", + ylabel: str = "", + x_data: np.array = None, + y_data: np.array = None, + ): """ - Generates an interactive plot over which shows the emulator predictions - for the input data passed to this method + Generates an interactive plot which displays the emulator predictions. + If no reference data is passed to be overplotted then the plot will + display a line which corresponds to the predictions for the mean + of the parameter values. + + Parameters + ---------- + + x: np.array + Array of data for which the emulator should make predictions. + + xlabel: str, optional + Label for horizontal axis on the resultant figure. + + ylabel: str, optional + Label for vertical axis on the resultant figure. + + x_data: np.array, optional + Array containing x-values of reference data to plot. + + y_data: np.array, optional + Array containing y-values of reference data to plot. + Must be the same shape as x_data """ import matplotlib.pyplot as plt from matplotlib.widgets import Slider @@ -161,9 +189,9 @@ def interactive_plot(self, x: np.array, xlabel: str = "", ylabel: str = "", x_da # Setting up initial value pred, pred_var = self.predict_values(x, param_means) if (x_data is None) or (y_data is None): - ax.plot(x, pred, 'k--') + ax.plot(x, pred, "k--") else: - ax.plot(x_data, y_data, 'k.') + ax.plot(x_data, y_data, "k.") (line,) = ax.plot(x, pred) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel)