Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: CEM and Anchor docs #40

Merged
merged 4 commits into from
Apr 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
Welcome to alibi's documentation!
=================================

.. toctree::
:maxdepth: 1
:caption: Methods

methods/Anchors.ipynb
methods/CEM.ipynb

.. toctree::
:maxdepth: 1
:caption: Examples
Expand Down
460 changes: 460 additions & 0 deletions doc/source/methods/Anchors.ipynb

Large diffs are not rendered by default.

336 changes: 336 additions & 0 deletions doc/source/methods/CEM.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[source](../api/alibi.explainers.cem.rst)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Contrastive Explanation Method"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The *Contrastive Explanation Method* (CEM) is based on the paper [Explanations based on the Missing: Towards Constrastive Explanations with Pertinent Negatives](https://arxiv.org/abs/1802.07623) and extends the [code](https://github.com/IBM/Contrastive-Explanation-Method) open sourced by the authors. CEM generates instance based local black box explanations for classification models in terms of Pertinent Positives (PP) and Pertinent Negatives (PN). For a PP, the method finds the features that should be minimally and sufficiently present (e.g. important pixels in an image) to predict the same class as on the original instance. PN's on the other hand identify what features should be minimally and necessarily absent from the instance to be explained in order to maintain the original prediction class. The aim of PN's is not to provide a full set of characteristics that should be absent in the explained instance, but to provide a minimal set that differentiates it from the closest different class. Intuitively, the Pertinent Positives could be compared to Anchors while Pertinent Negatives are similar to Counterfactuals. As the authors of the paper state, CEM can generate clear explanations of the form: \"An input x is classified in class y because features $f_{i}$, ..., $f_{k}$ are present and because features $f_{m}$, ..., $f_{p}$ are absent.\" The current implementation is most suitable for images and tabular data without categorical features.\n",
"\n",
"In order to create interpretable PP's and PN's, feature-wise perturbation needs to be done in a meaningful way. To keep the perturbations sparse and close to the original instance, the objective function contains an elastic net ($\\beta$$L_{1}$ + $L_{2}$) regularizer. Optionally, an auto-encoder can be trained to reconstruct instances of the training set. We can then introduce the $L_{2}$ reconstruction error of the perturbed instance as an additional loss term in our objective function. As a result, the perturbed instance lies close to the training data manifold.\n",
"\n",
"The ability to add or remove features to arrive at respectively PN's or PP's implies that there are feature values that contain no information with regards to the model's predictions. Consider for instance the MNIST image below where the pixels are scaled between 0 and 1. The pixels with values close to 1 define the number in the image while the background pixels have value 0. We assume that perturbations towards the background value 0 are equivalent to removing features, while perturbations towards 1 imply adding features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![mnist4](mnist_orig.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is intuitive to understand that adding features to get a PN means changing 0's into 1's until a different number is formed, in this case changing a 4 into a 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![mnist4pn](mnist_pn.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To find the PP, we do the opposite and change 1's from the original instance into 0's, the background value, and only keep a vague outline of the original 4."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![mnist4pp](mnist_pp.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is however often not trivial to find these non-informative feature values and domain knowledge becomes very important. \n",
"\n",
"For more details, we refer the reader to the original [paper](https://arxiv.org/abs/1802.07623)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the optimizer is defined in TensorFlow (TF), we need to run the CEM explainer within a TensorFlow session:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"# initialize TensorFlow session before model definition\n",
"sess = tf.Session()\n",
"K.set_session(sess) # using a Keras model in the same session\n",
"sess.run(tf.global_variables_initializer())\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then load our MNIST classifier and the (optional) auto-encoder. The example below uses Keras or TF models. This allows optimization of the objective function to run entirely with automatic differentiation because the TF graph has access to the underlying model architecture. For models built in different frameworks (e.g. scikit-learn), the gradients of part of the loss function with respect to the input features need to be evaluated numerically. We'll handle this case later."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"# define models\n",
"cnn = load_model('mnist_cnn.h5')\n",
"ae = load_model('mnist_ae.h5')\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now initialize the CEM explainer:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"# initialize CEM explainer\n",
"shape = (1,) + x_train.shape[1:]\n",
"mode = 'PN'\n",
"cem = CEM(sess, cnn, mode, shape, kappa=0., beta=.1, \n",
" feature_range=(x_train.min(), x_train.max()), \n",
" gamma=100, ae_model=ae, max_iterations=1000, \n",
" c_init=1., c_steps=10, learning_rate_init=1e-2, \n",
" clip=(-1000.,1000.), no_info_val=-1.)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides passing the previously defined session as well as the predictive and auto-encoder models, we set a number of **hyperparameters** ...\n",
"\n",
"... **general**:\n",
"\n",
"* `mode`: 'PN' or 'PP'.\n",
"\n",
"* `shape`: shape of the instance to be explained, starting with batch dimension. Currently only single explanations are supported, so the batch dimension should be equal to 1.\n",
"\n",
"* `feature_range`: global or feature-wise min and max values for the perturbed instance.\n",
"\n",
"... related to the **optimizer**:\n",
"\n",
"* `max_iterations`: number of loss optimization steps for each value of *c*; the multiplier of the first loss term.\n",
"\n",
"* `learning_rate_init`: initial learning rate, follows polynomial decay.\n",
"\n",
"* `clip`: min and max gradient values.\n",
"\n",
"... related to the **non-informative value**:\n",
"\n",
"* `no_info_val`: as explained in the previous section, it is important to define which feature values are considered background and not crucial for the class predictions. For MNIST images scaled between 0 and 1 or -0.5 and 0.5 as in the notebooks, pixel perturbations in the direction of the (low) background pixel value can be seen as removing features, moving towards the non-informative value. As a result, the `no_info_val` parameter is set at a low value like -1. `no_info_val` can be defined globally or feature-wise. For most applications, domain knowledge becomes very important here. If a representative sample of the training set is available, we can always (naively) infer a `no_info_val` by taking the feature-wise median or mean:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"cem.fit(x_train, no_info_type='median')\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"... related to the **objective function**:\n",
"\n",
"* `c_init` and `c_steps`: the multiplier $c$ of the first loss term is updated for `c_steps` iterations, starting at `c_init`. The first loss term encourages the perturbed instance to be predicted as a different class for a PN and the same class for a PP. If we find a candidate PN or PP for the current value of $c$, we reduce the value of $c$ for the next optimization cycle to put more emphasis on the regularization terms and improve the solution. If we cannot find a solution, $c$ is increased to put more weight on the prediction class restrictions of the PN and PP before focusing on the regularization.\n",
"\n",
"* `kappa`: the first term in the loss function is defined by a difference between the predicted probabilities for the perturbed instance of the original class and the max of the other classes. $\\kappa \\geq 0$ defines a cap for this difference, limiting its impact on the overall loss to be optimized. Similar to the original paper, we set $\\kappa$ to 0. in the examples.\n",
"\n",
"* `beta`: $\\beta$ is the $L_{1}$ loss term multiplier. A higher value for $\\beta$ means more weight on the sparsity restrictions of the perturbations. Similar to the paper, we set $\\beta$ to 0.1 for the MNIST and Iris datasets.\n",
"\n",
"* `gamma`: multiplier for the optional $L_{2}$ reconstruction error. A higher value for $\\gamma$ means more emphasis on the reconstruction error penalty defined by the auto-encoder. Similar to the paper, we set $\\gamma$ to 100 when we have an auto-encoder available.\n",
"\n",
"While the paper's default values for the loss term coefficients worked well for the simple examples provided in the notebooks, it is recommended to test their robustness for your own applications."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can finally explain the instance and close the TensorFlow session when we are done:\n",
"\n",
"```python\n",
"explanation = cem.explain(X)\n",
"sess.close()\n",
"K.clear_session()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ```explain``` method returns a dictionary with the following *key: value* pairs:\n",
"\n",
"* *X*: original instance\n",
"\n",
"* *X_pred*: predicted class of original instance\n",
"\n",
"* *PN* or *PP*: Pertinent Negative or Pertinant Positive\n",
"\n",
"* *PN_pred* or *PP_pred*: predicted class of PN or PP\n",
"\n",
"* *grads_graph*: gradient values computed from the TF graph with respect to the input features at the PN or PP\n",
"\n",
"* *grads_num*: numerical gradient values with respect to the input features at the PN or PP"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Numerical Gradients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far, the whole optimization problem could be defined within the TF graph, making autodiff possible. It is however possible that we do not have access to the model architecture and weights, and are only provided with a ```predict``` function returning probabilities for each class. The CEM can be initialized in the TF session as follows:\n",
"\n",
"```python\n",
"# define model\n",
"lr = load_model('iris_lr.h5')\n",
"predict_fn = lambda x: lr.predict(x)\n",
" \n",
"# initialize CEM explainer\n",
"shape = (1,) + x_train.shape[1:]\n",
"mode = 'PP'\n",
"cem = CEM(sess, predict_fn, mode, shape, kappa=0., beta=.1, \n",
" feature_range=(x_train.min(), x_train.max()), \n",
" eps=(1e-2, 1e-2), update_num_grad=100)\n",
"```\n",
"\n",
"In this case, we need to evaluate the gradients of the loss function with respect to the input features numerically:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\\begin{equation*} \\frac{\\partial L}{\\partial x} = \\frac{\\partial L}{\\partial p} \\frac{\\partial p}{\\partial x} \\end{equation*}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"where $L$ is the loss function, $p$ the predict function and $x$ the input features to optimize. There are now 2 additional hyperparameters to consider:\n",
"\n",
"* `eps`: a tuple to define the perturbation size used to compute the numerical gradients. `eps[0]` and `eps[1]` are used respectively for $^{\\delta L}/_{\\delta p}$ and $^{\\delta p}/_{\\delta x}$. `eps[0]` and `eps[1]` can be a combination of float values or numpy arrays. For `eps[0]`, the array dimension should be *(1 x nb of prediction categories)* and for `eps[1]` it should be *(1 x nb of features)*. For the Iris dataset, `eps` could look as follows:\n",
"\n",
"```python\n",
"eps0 = np.array([[1e-2, 1e-2, 1e-2]]) # 3 prediction categories, equivalent to 1e-2\n",
"eps1 = np.array([[1e-2, 1e-2, 1e-2, 1e-2]]) # 4 features, also equivalent to 1e-2\n",
"eps = (eps0, eps1)\n",
"```\n",
"\n",
"- `update_num_grad`: for complex models with a high number of parameters and a high dimensional feature space (e.g. Inception on ImageNet), evaluating numerical gradients can be expensive as they involve prediction calls for each perturbed instance. The `update_num_grad` parameter allows you to set a batch size on which to evaluate the numerical gradients, reducing the number of prediction calls required. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Examples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Contrastive Explanations Method (CEM) applied to MNIST](../examples/cem_mnist.nblink)\n",
"\n",
"[Contrastive Explanations Method (CEM) applied to Iris dataset](../examples/cem_iris.nblink)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added doc/source/methods/anchor_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/lime_sentiment.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/mnist_orig.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/mnist_pn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/mnist_pp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/persiancat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/persiancatanchor.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/source/methods/persiancatsegm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.