646 Implement LIME for tabular data #661

Merged: 15 commits, merged on Dec 7, 2023
62 changes: 46 additions & 16 deletions dianna/__init__.py
@@ -31,14 +31,13 @@
__version__ = '1.2.0'


def explain_timeseries(model_or_function, timeseries_data, method, labels,
**kwargs):
def explain_timeseries(model_or_function, input_timeseries, method, labels, **kwargs):
"""Explain timeseries data given a model and a chosen method.

Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to a ONNX model on disk.
timeseries_data (np.ndarray): Timeseries data to be explained
input_timeseries (np.ndarray): Timeseries data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: key word arguments
@@ -49,18 +48,20 @@
"""
explainer = _get_explainer(method, kwargs, modality='Timeseries')
explain_timeseries_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
return explainer.explain(model_or_function, timeseries_data, labels,
**explain_timeseries_kwargs)
explainer.explain, kwargs
)
return explainer.explain(
model_or_function, input_timeseries, labels, **explain_timeseries_kwargs
)


def explain_image(model_or_function, input_data, method, labels, **kwargs):
def explain_image(model_or_function, input_image, method, labels, **kwargs):
"""Explain an image (input_data) given a model and a chosen method.

Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to a ONNX model on disk.
input_data (np.ndarray): Image data to be explained
input_image (np.ndarray): Image data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: These keyword parameters are passed on
@@ -74,13 +75,14 @@ def explain_image(model_or_function, input_data, method, labels, **kwargs):
from onnx_tf.backend import prepare # noqa: F401
explainer = _get_explainer(method, kwargs, modality='Image')
explain_image_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
return explainer.explain(model_or_function, input_data, labels,
**explain_image_kwargs)
explainer.explain, kwargs
)
return explainer.explain(
model_or_function, input_image, labels, **explain_image_kwargs
)


def explain_text(model_or_function, input_text, tokenizer, method, labels,
**kwargs):
def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwargs):
"""Explain text (input_text) given a model and a chosen method.

Args:
@@ -98,7 +100,8 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels,
"""
explainer = _get_explainer(method, kwargs, modality='Text')
explain_text_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs)
explainer.explain, kwargs
)
return explainer.explain(
model_or_function=model_or_function,
input_text=input_text,
@@ -108,10 +111,36 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels,
)


def explain_tabular(model_or_function, input_tabular, method, labels=(1, ), **kwargs):
"""Explain tabular (input_text) given a model and a chosen method.

Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to a ONNX model on disk.
input_tabular (np.ndarray): Tabular data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int), optional): Labels to be explained
kwargs: These keyword parameters are passed on

Returns:
One heatmap (2D array) per class.
"""
explainer = _get_explainer(method, kwargs, modality='Tabular')
explain_tabular_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
return explainer.explain(
model_or_function=model_or_function,
input_tabular=input_tabular,
Member commented: haha, input_tabular! I was already guessing what this new variable's name would be after the changes you made above.

labels=labels,
**explain_tabular_kwargs,
)

def _get_explainer(method, kwargs, modality):
try:
method_submodule = importlib.import_module(
f'dianna.methods.{method.lower()}_{modality.lower()}')
f'dianna.methods.{method.lower()}_{modality.lower()}'
)
except ImportError as err:
raise ValueError(
f'Method {method.lower()}_{modality.lower()} does not exist'
@@ -123,5 +152,6 @@ def _get_explainer(method, kwargs, modality):
f'Data modality {modality} is not available for method {method.upper()}'
) from err
method_kwargs = utils.get_kwargs_applicable_to_function(
method_class.__init__, kwargs)
method_class.__init__, kwargs
)
return method_class(**method_kwargs)
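For illustration, a minimal usage sketch of the new explain_tabular entry point, modelled on the tests in this PR. The toy data, model, and feature/class names below are made-up placeholders, not part of this diff:

```python
import numpy as np
import dianna

# Toy data and model; names and values are placeholders for illustration only.
training_data = np.random.random((100, 3))   # 2D array LIME uses to learn feature statistics
instance = np.random.random(3)               # single row to be explained

def run_model(data):
    # Toy two-class probability model standing in for a real (ONNX) model.
    p = data.mean(axis=1, keepdims=True)
    return np.hstack([1 - p, p])

relevances = dianna.explain_tabular(
    run_model,
    input_tabular=instance,
    method='lime',
    mode='classification',
    training_data=training_data,
    feature_names=['feature_1', 'feature_2', 'feature_3'],
    class_names=['class_1', 'class_2'],
)
print(relevances.shape)  # (n_classes, n_features), here (2, 3)
```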
2 changes: 1 addition & 1 deletion dianna/dashboard/_models_ts.py
@@ -27,7 +27,7 @@ def run_model(ts_data):

explanation = dianna.explain_timeseries(
run_model,
timeseries_data=ts_data[0],
input_timeseries=ts_data[0],
method='RISE',
**kwargs,
)
135 changes: 135 additions & 0 deletions dianna/methods/lime_tabular.py
@@ -0,0 +1,135 @@
"""LIME tabular explainer."""
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
from lime.lime_tabular import LimeTabularExplainer
from dianna import utils


class LIMETabular:
"""Wrapper around the LIME explainer for tabular data."""

def __init__(
self,
training_data: np.array,
mode: str = "classification",
feature_names: List[int] = None,
categorical_features: List[int] = None,
kernel_width: int = 25,
kernel: callable = None,
verbose: bool = False,
class_names: List[str] = None,
feature_selection: str = "auto",
random_state: int = None,
**kwargs,
) -> None:
"""Initializes Lime explainer.

For numerical features, perturb them by sampling from a Normal(0,1) and
doing the inverse operation of mean-centering and scaling, according to the
means and stds in the training data.

For categorical features, perturb by sampling according to the training
distribution, and making a binary feature that is 1 when the value is the
same as the instance being explained.

More information can be found in the API guide:
https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

Args:
training_data (np.array): numpy 2d array
mode (str, optional): "classification" or "regression"
feature_names (list(str), optional): list of names corresponding to the columns
in the training data.
categorical_features (list(int), optional): list of indices corresponding to the
categorical columns. Values in these
columns MUST be integers.
kernel_width (int, optional): kernel width
kernel (callable, optional): kernel
verbose (bool, optional): verbose
class_names (str, optional): list of class names, ordered according to whatever
the classifier is using. If not present, class names
will be '0', '1', ...
feature_selection (str, optional): feature selection
random_state (int or np.RandomState, optional): seed or random state
kwargs: These parameters are passed on

"""
self.mode = mode
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
LimeTabularExplainer, kwargs
)

# temporary solution for setting num_features and top_labels
self.num_features = len(feature_names)
self.top_labels = len(class_names)

self.explainer = LimeTabularExplainer(
training_data,
Member commented: Would it be possible to call LIME without training_data?

Member Author replied: See my comments above.

mode=self.mode,
feature_names=feature_names,
categorical_features=categorical_features,
kernel_width=kernel_width,
kernel=kernel,
verbose=verbose,
class_names=class_names,
feature_selection=feature_selection,
random_state=random_state,
**init_instance_kwargs,
)

def explain(
self,
model_or_function: Union[str, callable],
input_tabular: np.array,
labels: Iterable[int] = (1,),
num_samples: int = 5000,
**kwargs,
) -> np.array:
"""Run the LIME explainer.

Args:
model_or_function (callable or str): The function that runs the model to be explained
or the path to a ONNX model on disk.
input_tabular (np.ndarray): Data to be explained.
labels (Iterable(int), optional): Indices of classes to be explained.
num_samples (int, optional): Number of samples
kwargs: These parameters are passed on

Other keyword arguments: see the documentation for LimeTabularExplainer.explain_instance:
https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_tabular.LimeTabularExplainer.explain_instance

Returns:
explanation: An Explanation object containing the LIME explanations for each class.
"""
# run the explanation.
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
self.explainer.explain_instance, kwargs
)
runner = utils.get_function(model_or_function)

explanation = self.explainer.explain_instance(
input_tabular,
runner,
labels=labels,
top_labels=self.top_labels,
num_features=self.num_features,
num_samples=num_samples,
**explain_instance_kwargs,
)

if self.mode == "regression":
local_exp = sorted(explanation.local_exp[1])
saliency = [i[1] for i in local_exp]

elif self.mode == "classification":
# extract scores from lime explainer
saliency = []
for i in range(self.top_labels):
local_exp = sorted(explanation.local_exp[i])
# shape of local_exp [(index, saliency)]
selected_saliency = [x[1] for x in local_exp]
saliency.append(selected_saliency[:])

return np.array(saliency)
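A rough sketch of how the wrapper behaves in regression mode, where only the explanation stored under label 1 is read, giving one weight per feature. The data, model, and names here are hypothetical, not taken from the PR:

```python
import numpy as np
from dianna.methods.lime_tabular import LIMETabular

# Hypothetical example; data, model and names are placeholders for illustration.
training_data = np.random.random((200, 2))
instance = np.random.random(2)

def run_regressor(data):
    # Toy regressor: the prediction is a fixed linear combination of the two features.
    return 2.0 * data[:, 0] - 1.0 * data[:, 1]

explainer = LIMETabular(
    training_data,
    mode='regression',
    feature_names=['feature_1', 'feature_2'],
    class_names=['prediction'],  # only used here to set top_labels in the wrapper
)
saliency = explainer.explain(run_regressor, instance, num_samples=1000)
print(saliency)  # one LIME weight per feature, ordered by feature index
```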
1 change: 1 addition & 0 deletions setup.cfg
@@ -104,6 +104,7 @@ notebooks =
scipy
skl2onnx
spacy
seaborn
tf2onnx
torch
torchtext
36 changes: 36 additions & 0 deletions tests/methods/test_lime_tabular.py
@@ -0,0 +1,36 @@
"""Test LIME tabular method."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.lime_tabular import LIMETabular
from tests.utils import run_model


class LIMEOnTabular(TestCase):
"""Suite of LIME tests for the tabular case."""

def test_lime_tabular_classification_correct_output_shape(self):
"""Test the output of explainer."""
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
explainer = LIMETabular(training_data,
mode ='classification',
feature_names=feature_names,
class_names = ["class_1", "class_2"])
exp = explainer.explain(
run_model,
input_data,
)
assert len(exp[0]) == len(feature_names)

def test_lime_tabular_regression_correct_output_shape(self):
"""Test the output of explainer."""
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='lime',
mode ='regression', training_data = training_data,
feature_names=feature_names, class_names=['class_1'])

assert len(exp) == len(feature_names)
Comment on lines +9 to +36
Member commented: Now you're only testing for shapes. It would be nice to add a test with synthetic data and a model with some stupidly simple pattern, to check whether the method can detect it, just like we have for timeseries RISE.

Member Author replied: I like the idea of adding a "naive test" for it. I thought about it, but we added a test for timeseries mainly because we implemented those methods ourselves (referring to papers and methods like LIME segmentation, for instance, but implementing them mostly from scratch).

This LIME tabular explainer comes entirely from LIME; we simply put a wrapper around it, and it is already tested to some extent in the original implementation of LIME:

https://github.com/marcotcr/lime/blob/master/lime/tests/test_lime_tabular.py

But I do agree that it would be nice to add a test for it. Maybe let's create an issue and design a simple test for it? This PR is already getting too big... I would prefer to do it in a separate PR.

Member Author replied: Issue #669 created! 😄
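For the record, a rough sketch of what such a naive test could look like. This is illustrative only and not part of this PR; the synthetic data, model, and assertion threshold are assumptions:

```python
"""Sketch of a possible naive test (not part of this PR)."""
import numpy as np
from dianna.methods.lime_tabular import LIMETabular


def test_lime_tabular_detects_dominant_feature():
    """One feature fully determines the class; LIME should weight it far above a noise feature."""
    rng = np.random.default_rng(42)
    training_data = rng.random((500, 2))

    def run_model(data):
        # Class 1 if and only if the first feature exceeds 0.5; the second feature is pure noise.
        p = (data[:, 0] > 0.5).astype(float).reshape(-1, 1)
        return np.hstack([1 - p, p])

    explainer = LIMETabular(training_data,
                            mode='classification',
                            feature_names=['relevant', 'noise'],
                            class_names=['class_0', 'class_1'])
    saliency = explainer.explain(run_model, np.array([0.9, 0.5]), num_samples=2000)

    # Relevance of the informative feature should dominate that of the noise feature for class 1.
    assert abs(saliency[1][0]) > abs(saliency[1][1])
```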

2 changes: 1 addition & 1 deletion tests/methods/test_rise_timeseries.py
@@ -25,7 +25,7 @@ def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index)

summer_explanation, winter_explanation = dianna.explain_timeseries(run_expert_model,
timeseries_data=temperature_timeseries,
input_timeseries=temperature_timeseries,
method='rise',
labels=[0, 1],
p_keep=0.1, n_masks=10000,