646 Implement LIME for tabular data #661
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""LIME tabular explainer.""" | ||
from typing import Iterable | ||
from typing import List | ||
from typing import Union | ||
import numpy as np | ||
from lime.lime_tabular import LimeTabularExplainer | ||
from dianna import utils | ||
|
||
|
||
class LIMETabular: | ||
"""Wrapper around the LIME explainer for tabular data.""" | ||
|
||
def __init__( | ||
self, | ||
training_data: np.array, | ||
mode: str = "classification", | ||
feature_names: List[str] = None, | ||
categorical_features: List[int] = None, | ||
kernel_width: int = 25, | ||
kernel: callable = None, | ||
verbose: bool = False, | ||
class_names: List[str] = None, | ||
feature_selection: str = "auto", | ||
random_state: int = None, | ||
**kwargs, | ||
) -> None: | ||
"""Initializes Lime explainer. | ||
|
||
For numerical features, perturb them by sampling from a Normal(0,1) and | ||
doing the inverse operation of mean-centering and scaling, according to the | ||
means and stds in the training data. | ||
|
||
For categorical features, perturb by sampling according to the training | ||
distribution, and making a binary feature that is 1 when the value is the | ||
same as the instance being explained. | ||
|
||
More information can be found in the API guide: | ||
https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular | ||
|
||
Args: | ||
training_data (np.array): numpy 2d array | ||
mode (str, optional): "classification" or "regression" | ||
feature_names (list(str), optional): list of names corresponding to the columns | ||
in the training data. | ||
categorical_features (list(int), optional): list of indices corresponding to the | ||
categorical columns. Values in these | ||
columns MUST be integers. | ||
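kernel_width (int, optional): width of the exponential kernel used to weight | ||
perturbed samples by their distance to the instance. | ||
kernel (callable, optional): similarity kernel that maps distances and kernel | ||
width to sample weights. If None, LIME's default | ||
exponential kernel is used. | ||
verbose (bool, optional): if True, print the local prediction values of the | ||
linear surrogate model. | ||
class_names (list(str), optional): list of class names, ordered according to whatever | ||
the classifier is using. If not present, class names | ||
will be '0', '1', ... | ||
feature_selection (str, optional): feature selection method, e.g. "auto", | ||
"forward_selection", "lasso_path" or "none". | ||
random_state (int or np.random.RandomState, optional): seed or random state | ||
kwargs: additional keyword arguments passed on to LimeTabularExplainer | ||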
|
||
""" | ||
self.mode = mode | ||
init_instance_kwargs = utils.get_kwargs_applicable_to_function( | ||
LimeTabularExplainer, kwargs | ||
) | ||
|
||
# temporary solution for setting num_features and top_labels | ||
self.num_features = len(feature_names) | ||
self.top_labels = len(class_names) | ||
|
||
self.explainer = LimeTabularExplainer( | ||
training_data, | ||
Reviewer comment: Would it be possible to call LIME without training_data?
Reply: See my comments above. |
||
mode=self.mode, | ||
feature_names=feature_names, | ||
categorical_features=categorical_features, | ||
kernel_width=kernel_width, | ||
kernel=kernel, | ||
verbose=verbose, | ||
class_names=class_names, | ||
feature_selection=feature_selection, | ||
random_state=random_state, | ||
**init_instance_kwargs, | ||
) | ||
|
||
def explain( | ||
self, | ||
model_or_function: Union[str, callable], | ||
input_tabular: np.array, | ||
labels: Iterable[int] = (1,), | ||
num_samples: int = 5000, | ||
**kwargs, | ||
) -> np.array: | ||
"""Run the LIME explainer. | ||
|
||
Args: | ||
model_or_function (callable or str): The function that runs the model to be explained | ||
or the path to an ONNX model on disk. | ||
input_tabular (np.ndarray): Data to be explained. | ||
labels (Iterable(int), optional): Indices of classes to be explained. | ||
num_samples (int, optional): Number of perturbed samples used to fit the local model | ||
kwargs: additional keyword arguments passed on to LimeTabularExplainer.explain_instance | ||
|
||
Other keyword arguments: see the documentation for LimeTabularExplainer.explain_instance: | ||
https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_tabular.LimeTabularExplainer.explain_instance | ||
|
||
Returns: | ||
np.ndarray: saliency scores, one per feature in feature-index order. In classification | ||
mode there is one row per class; in regression mode a single 1d array. | ||
""" | ||
# run the explanation. | ||
explain_instance_kwargs = utils.get_kwargs_applicable_to_function( | ||
self.explainer.explain_instance, kwargs | ||
) | ||
runner = utils.get_function(model_or_function) | ||
|
||
explanation = self.explainer.explain_instance( | ||
input_tabular, | ||
runner, | ||
labels=labels, | ||
top_labels=self.top_labels, | ||
num_features=self.num_features, | ||
num_samples=num_samples, | ||
**explain_instance_kwargs, | ||
) | ||
|
||
if self.mode == "regression": | ||
local_exp = sorted(explanation.local_exp[1]) | ||
saliency = [i[1] for i in local_exp] | ||
|
||
elif self.mode == "classification": | ||
# extract scores from lime explainer | ||
saliency = [] | ||
for i in range(self.top_labels): | ||
local_exp = sorted(explanation.local_exp[i]) | ||
# shape of local_exp [(index, saliency)] | ||
selected_saliency = [x[1] for x in local_exp] | ||
saliency.append(selected_saliency[:]) | ||
|
||
return np.array(saliency) |
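The post-processing at the end of `explain` can be illustrated in isolation. The sketch below is not part of the PR; it assumes that `explanation.local_exp` maps each label to a list of `(feature_index, weight)` pairs ordered by importance rather than by index, which is how LIME appears to return them. The name `saliency_per_label` and the example dictionary are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for `explanation.local_exp`: per label, a list of
# (feature_index, weight) pairs that is NOT ordered by feature index.
LocalExp = Dict[int, List[Tuple[int, float]]]


def saliency_per_label(local_exp: LocalExp, labels: List[int]) -> List[List[float]]:
    """Return one row of weights per label, with columns in feature-index order."""
    rows = []
    for label in labels:
        # Tuples sort on their first element, so this orders by feature index.
        by_index = sorted(local_exp[label])
        rows.append([weight for _, weight in by_index])
    return rows


# Made-up weights for three features and two labels.
example = {
    0: [(2, -0.5), (0, 0.3), (1, 0.1)],
    1: [(2, 0.5), (0, -0.3), (1, -0.1)],
}
print(saliency_per_label(example, labels=[0, 1]))
```

Column `j` of each row only lines up with feature `j` when every feature appears in the explanation, which is presumably why the wrapper sets `num_features` to `len(feature_names)`.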
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,6 +104,7 @@ notebooks = | |
scipy | ||
skl2onnx | ||
spacy | ||
seaborn | ||
tf2onnx | ||
torch | ||
torchtext | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Test LIME tabular method.""" | ||
from unittest import TestCase | ||
import numpy as np | ||
import dianna | ||
from dianna.methods.lime_tabular import LIMETabular | ||
from tests.utils import run_model | ||
|
||
|
||
class LIMEOnTabular(TestCase): | ||
"""Suite of LIME tests for the tabular case.""" | ||
|
||
def test_lime_tabular_classification_correct_output_shape(self): | ||
"""Test the output of explainer.""" | ||
training_data = np.random.random((10, 2)) | ||
input_data = np.random.random(2) | ||
feature_names = ["feature_1", "feature_2"] | ||
explainer = LIMETabular(training_data, | ||
mode='classification', | ||
feature_names=feature_names, | ||
class_names=["class_1", "class_2"]) | ||
exp = explainer.explain( | ||
run_model, | ||
input_data, | ||
) | ||
assert len(exp[0]) == len(feature_names) | ||
|
||
def test_lime_tabular_regression_correct_output_shape(self): | ||
"""Test the output of explainer.""" | ||
training_data = np.random.random((10, 2)) | ||
input_data = np.random.random(2) | ||
feature_names = ["feature_1", "feature_2"] | ||
exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='lime', | ||
mode='regression', training_data=training_data, | ||
feature_names=feature_names, class_names=['class_1']) | ||
|
||
assert len(exp) == len(feature_names) | ||
Comment on lines +9 to +36
Reviewer comment: Now you're only testing for shapes. It would be nice to add a test with synthetic data+model with some stupidly simple pattern to check if the method can detect it. Just like we have for timeseries RISE.
Reply: I like the idea of adding a "naive test" for it. I thought about it. But then I was thinking that we added a test for timeseries mainly because we implemented those methods ourselves (referring to some papers and methods like LIME segmentation, for instance, but implementing those methods mostly from scratch). This LIME tabular is entirely from LIME. We simply put a wrapper around it. It is already tested to some extent in the original implementation of LIME: https://github.com/marcotcr/lime/blob/master/lime/tests/test_lime_tabular.py
But I do agree that it is nice to add a test for it. Maybe let's create an issue and design a simple test for it? This PR is already getting too big... I would prefer to do it in a separate PR.
Reply: Issue #669 created! 😄 |
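One possible shape for the "naive test" discussed in this thread is sketched below. It is only an illustration, not the test tracked in issue #669: `toy_model` and `local_linear_weights` are hypothetical names, and the explainer is replaced by a plain unweighted least-squares surrogate so that the example runs with NumPy alone. In a real test the surrogate call would presumably be swapped for the explainer under test.

```python
import numpy as np


def toy_model(x: np.ndarray) -> np.ndarray:
    """Hypothetical regression model whose output depends on feature 0 only."""
    return 3.0 * x[:, 0]


def local_linear_weights(model, instance, num_samples=500, scale=0.1, seed=0):
    """Fit an unweighted linear surrogate to the model around `instance`.

    A crude stand-in for an explainer, NOT LIME: no kernel weighting,
    no discretisation and no feature selection.
    """
    rng = np.random.default_rng(seed)
    # Perturb the instance with Gaussian noise.
    samples = instance + scale * rng.standard_normal((num_samples, instance.size))
    # Append a column of ones so the fit includes an intercept.
    design = np.column_stack([samples, np.ones(num_samples)])
    coef, *_ = np.linalg.lstsq(design, model(samples), rcond=None)
    return coef[:-1]  # drop the intercept


weights = local_linear_weights(toy_model, np.array([0.5, 0.5]))
# The "stupidly simple pattern": feature 0 should dominate feature 1.
assert abs(weights[0]) > 10 * abs(weights[1])
```

Asserting on the relative size of the weights rather than on exact values keeps such a test robust to the randomness of the sampling step.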
haha input_tabular! I was already guessing what this new variable's name would be after the changes you made above