From c8df34e9ebfb1c5a7862e3f1e55f1fbbd122b4c1 Mon Sep 17 00:00:00 2001
From: Adelin CONSTANS
Date: Mon, 29 May 2023 09:37:09 +0200
Subject: [PATCH 001/138] example-based: add cole method
---
xplique/example_based/cole.py | 411 ++++++++++++++++++++++++++++++++++
1 file changed, 411 insertions(+)
create mode 100644 xplique/example_based/cole.py
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
new file mode 100644
index 00000000..63fa69c6
--- /dev/null
+++ b/xplique/example_based/cole.py
@@ -0,0 +1,411 @@
+"""
+Module related to Case Base Explainer
+"""
+
+
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import DistanceMetric
+from sklearn.neighbors import KDTree
+import tensorflow as tf
+
+from ..plots.image import _standardize_image
+from ..types import Callable, Union, Optional
+
+
+class Cole:
+    """
+    Used to compute the Case Based Explainer system, a twin system that uses an ANN
+    and a KNN with the same dataset.
+
+    Ref. Eoin M. Kenny and Mark T. Keane.
+    Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
+    Comparative Tests of Feature-Weighting Methods in ANN-CBR Twins for XAI. (2019)
+    https://www.ijcai.org/proceedings/2019/376
+    """
+
+ def __init__(
+ self,
+ model: Callable,
+ case_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_train: np.ndarray,
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ distance_function: DistanceMetric = None,
+ weights_extraction_function: Callable = None,
+ k: Optional[int] = 3,
+ ):
+        """
+        Parameters
+        ----------
+        model
+            The model from which we want to obtain explanations.
+        case_dataset
+            The dataset used to train the model,
+            also used by the method to compute the closest examples.
+        labels_train
+            Labels defined by the dataset.
+        targets
+            Labels predicted by the model from the dataset.
+        distance_function
+            The function to compute the distance between the inputs and the whole dataset.
+            (Can use: euclidean, manhattan, minkowski, etc...)
+        weights_extraction_function
+            The function to compute the weight of every feature;
+            many kinds of methods can be used, but it will depend on
+            what type of dataset you have.
+            examples:
+                def my_function(inputs, targets):
+                    # outputs.shape == inputs.shape
+                    return outputs
+        k
+            How many nearest neighbours should be returned.
+        """
+ # set attributes
+ self.model = model
+ self.case_dataset = case_dataset
+ self.weights_extraction_function = weights_extraction_function
+ self.k_neighbors = k
+ self.labels_train = labels_train
+
+        # verify the targets parameter
+ if targets is None:
+ targets = model(case_dataset)
+ nb_classes = targets.shape[1]
+ targets = tf.argmax(targets, axis=1)
+ targets = tf.one_hot(
+ targets, nb_classes
+        )  # one-hot encode the predicted classes (depth = nb_classes)
+
+        # verify the distance_function parameter
+ if distance_function is None:
+ distance_function = DistanceMetric.get_metric("euclidean")
+
+        # verify the weights_extraction_function parameter
+ if weights_extraction_function is None:
+ self.weights_extraction_function = self._get_default_weights_extraction_function()
+
+ # compute case dataset weights (used in distance)
+ # the weight extraction function may need the predictions to extract the weights
+ case_dataset_weight = self.weights_extraction_function(case_dataset, targets)
+ # for images, channels may disappear
+ if len(case_dataset_weight.shape) != len(case_dataset.shape):
+ case_dataset_weight = tf.expand_dims(case_dataset_weight, -1)
+ self.case_dataset_weight = case_dataset_weight
+
+ # apply weights to the case dataset (weighted distance)
+ weighted_case_dataset = tf.math.multiply(case_dataset_weight, case_dataset)
+ # flatten features for kdtree
+ weighted_case_dataset = tf.reshape(
+ weighted_case_dataset, [weighted_case_dataset.shape[0], -1]
+ )
+
+ # create kdtree instance with weighted case dataset
+ # will be called to estimate closest examples
+ self.knn = KDTree(weighted_case_dataset, metric=distance_function)
+
+ def extract_element_from_indices(
+ self,
+ labels_train: np.ndarray,
+ examples_indice: np.ndarray,
+ ):
+        """
+        Extract every example and its weights from the dataset
+        using the indices computed by the knn query in the explain function.
+
+        Parameters
+        ----------
+        labels_train
+            Labels defined by the dataset.
+        examples_indice
+            Indices of the K nearest neighbours of the input.
+
+        Returns
+        -------
+        examples
+            The K nearest neighbours of the input.
+        examples_weights
+            Features weight of the examples.
+        labels_examples
+            Labels of the examples.
+        """
+ all_examples = []
+ all_weight_examples = []
+ all_labels_examples = []
+ for sample_examples_indice in examples_indice:
+ sample_examples = []
+ weight_ex = []
+ label_ex = []
+ for indice in sample_examples_indice:
+ sample_examples.append(self.case_dataset[indice])
+ weight_ex.append(self.case_dataset_weight[indice])
+ label_ex.append(labels_train[indice])
+ # (k, h, w, 1)
+ all_examples.append(tf.stack(sample_examples, axis=0))
+ all_weight_examples.append(tf.stack(weight_ex, axis=0))
+ all_labels_examples.append(tf.stack(label_ex, axis=0))
+ # (n, k, h, w, 1)
+ examples = tf.stack(all_examples, axis=0)
+ examples_weights = tf.stack(all_weight_examples, axis=0)
+ labels_examples = tf.stack(all_labels_examples, axis=0)
+
+ return examples, examples_weights, labels_examples
+
+ def explain(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Union[tf.Tensor, np.ndarray] = None,
+ ):
+        """
+        Calculate the indices of the k closest examples for the different inputs.
+        Then call extract_element_from_indices to extract the examples from those indices.
+
+        Parameters
+        ----------
+        inputs
+            Tensor or Array. Input samples to be explained.
+            Expected shape among (N,W), (N,T,W), (N,W,H,C).
+        targets
+            Tensor or Array. Corresponding to the prediction of the samples by the model.
+            shape: (n, nb_classes)
+            Used by the `weights_extraction_function` if it is an Xplique attribution function.
+            For more details, please refer to the explain methods documentation.
+
+        Returns
+        -------
+        examples
+            The K nearest neighbours of the input.
+        examples_distance
+            Distance between the input and the examples.
+        examples_weights
+            Features weight of the examples.
+        inputs_weights
+            Features weight of the inputs.
+        examples_labels
+            Labels of the examples.
+        """
+
+        # verify the targets parameter
+ if targets is None:
+ targets = self.model(inputs)
+ nb_classes = targets.shape[1]
+ targets = tf.argmax(targets, axis=1)
+ targets = tf.one_hot(targets, nb_classes)
+
+ # compute weight (used in distance)
+ # the weight extraction function may need the prediction to extract the weights
+ inputs_weights = self.weights_extraction_function(inputs, targets)
+
+ # for images, channels may disappear
+ if len(inputs_weights.shape) != len(inputs.shape):
+ inputs_weights = tf.expand_dims(inputs_weights, -1)
+
+ # apply weights to the inputs
+ weighted_inputs = tf.math.multiply(inputs_weights, inputs)
+ # flatten features for knn query
+ weighted_inputs = tf.reshape(weighted_inputs, [weighted_inputs.shape[0], -1])
+
+ # kdtree instance call with knn.query,
+ # call with the weighted inputs and the number of closest examples (k)
+ examples_distance, examples_indice = self.knn.query(
+ weighted_inputs, k=self.k_neighbors
+ )
+
+ # call the extract_element_from_indices function
+ examples, examples_weights, examples_labels = self.extract_element_from_indices(
+ self.labels_train, examples_indice
+ )
+
+ return (
+ examples,
+ examples_distance,
+ examples_weights,
+ inputs_weights,
+ examples_labels,
+ )
+
+ @staticmethod
+ def _get_default_weights_extraction_function():
+ """
+ This function allows you to get the default weight extraction function.
+ """
+ return lambda inputs, targets: tf.ones(inputs.shape)
+
+ def show_result_images(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ examples: Union[tf.Tensor, np.ndarray],
+ examples_distance: float,
+ inputs_weights: np.ndarray,
+ examples_weights: np.ndarray,
+ indice_original: int,
+ examples_labels: np.ndarray,
+ labels_test: np.ndarray,
+ clip_percentile: Optional[float] = 0.2,
+ cmapimages: Optional[str] = "gray",
+ cmapexplanation: Optional[str] = "coolwarm",
+ alpha: Optional[float] = 0.5,
+ ):
+        """
+        This function is for image data; it shows the returns of the explain function.
+
+        Parameters
+        ---------
+        inputs
+            Tensor or Array. Input samples to be shown next to examples.
+            Expected shape among (N,W), (N,T,W), (N,W,H,C).
+        examples
+            The K nearest neighbours of the input.
+        examples_distance
+            Distance between input data and examples.
+        inputs_weights
+            Features weight of the inputs.
+        examples_weights
+            Features weight of the examples.
+        indice_original
+            Indices of the inputs, used to show the true labels.
+        examples_labels
+            Labels of the examples.
+        labels_test
+            Corresponding to the labels of the test dataset.
+        clip_percentile
+            Percentile value to use if clipping is needed, e.g a value of 1 will perform a clipping
+            between percentile 1 and 99.
+            This parameter allows to avoid outliers in case of too extreme values.
+        cmapimages
+            For images.
+            The Colormap instance or registered colormap name used to map scalar data to colors.
+            This parameter is ignored for RGB(A) data.
+        cmapexplanation
+            For explanation.
+            The Colormap instance or registered colormap name used to map scalar data to colors.
+            This parameter is ignored for RGB(A) data.
+        alpha
+            The alpha blending value, between 0 (transparent) and 1 (opaque).
+            If alpha is an array, the alpha blending values are applied pixel by pixel,
+            and alpha must have the same shape as X.
+        """
+ # pylint: disable=too-many-arguments
+
+        # Initialize 'input_and_examples' and 'corresponding_weights';
+        # they will be used to show every closest example and the explanation
+ inputs = tf.expand_dims(inputs, 1)
+ inputs_weights = tf.expand_dims(inputs_weights, 1)
+ input_and_examples = tf.concat([inputs, examples], axis=1)
+ corresponding_weights = tf.concat([inputs_weights, examples_weights], axis=1)
+
+        # compute the prediction of input and examples;
+        # they will be used in the title of the image
+        # (loop necessary because we have n * k elements)
+ predicted_labels = []
+ for samples in input_and_examples:
+ predicted = self.model(samples)
+ predicted = tf.argmax(predicted, axis=1)
+ predicted_labels.append(predicted)
+
+ # configure the grid to show all results
+ plt.rcParams["figure.autolayout"] = True
+ plt.rcParams["figure.figsize"] = [20, 10]
+
+ # loop to organize and show all results
+ for j in range(np.asarray(input_and_examples).shape[0]):
+ fig = plt.figure()
+ gridspec = fig.add_gridspec(2, input_and_examples.shape[1])
+ for k in range(len(input_and_examples[j])):
+ fig.add_subplot(gridspec[0, k])
+ if k == 0:
+ plt.title(
+ f"Original image\nGround Truth: {labels_test[indice_original[j]]}"\
+ + f"\nPrediction: {predicted_labels[j][k]}"
+ )
+ else:
+ plt.title(
+ f"Examples\nGround Truth: {examples_labels[j][k-1]}"\
+ + f"\nPrediction: {predicted_labels[j][k]}"\
+ + f"\nDistance: {round(examples_distance[j][k-1], 2)}"
+ )
+ plt.imshow(input_and_examples[j][k], cmap=cmapimages)
+ plt.axis("off")
+ fig.add_subplot(gridspec[1, k])
+ plt.imshow(input_and_examples[j][k], cmap=cmapimages)
+ plt.imshow(
+ _standardize_image(corresponding_weights[j][k], clip_percentile),
+ cmap=cmapexplanation,
+ alpha=alpha,
+ )
+ plt.axis("off")
+ plt.show()
+
+ def show_result_tabular(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ examples: Union[tf.Tensor, np.ndarray],
+ examples_distance: float,
+ indice_original: int,
+ examples_labels: np.ndarray,
+ labels_test: np.ndarray,
+ show_values: bool = False,
+ ):
+        """
+        This function is for tabular data; it shows the returns of the explain function.
+
+        Parameters
+        ---------
+        inputs
+            Tensor or Array. Input samples to be shown next to examples.
+            Expected shape among (N,W), (N,T,W), (N,W,H,C).
+        examples
+            The K nearest neighbours of the input.
+        examples_distance
+            Distance between the input and the examples.
+        indice_original
+            Indices of the inputs, used to show the true labels.
+        examples_labels
+            Labels of the examples.
+        labels_test
+            Corresponding to the labels of the test dataset.
+        show_values
+            Boolean, False by default; whether to show the values of the examples.
+        """
+
+        # Initialize 'input_and_examples';
+        # it will be used to show every closest example
+ inputs = tf.expand_dims(inputs, 1)
+ input_and_examples = tf.concat([inputs, examples], axis=1)
+
+        # compute the prediction of input and examples;
+        # they will be used in the title of the output
+        # (loop necessary because we have n * k elements)
+ predicted_labels = []
+ for samples in input_and_examples:
+ predicted = self.model(samples)
+ predicted = tf.argmax(predicted, axis=1)
+ predicted_labels.append(predicted)
+
+ # apply argmax function to labels
+ labels_test = tf.argmax(labels_test, axis=1)
+ examples_labels = tf.argmax(examples_labels, axis=1)
+
+        # default values_string, used when show_values is False
+ values_string = ""
+
+ # loop to organize and show all results
+ for i in range(input_and_examples.shape[0]):
+ for j in range(input_and_examples.shape[1]):
+ if show_values is True:
+ values_string = f"\t\tValues: {input_and_examples[i][j]}"
+ if j == 0:
+ print(
+ f"Originale_data, indice: {indice_original[i]}"\
+ + f"\tDistance: \t\tGround Truth: {labels_test[i]}"\
+ + f"\t\tPrediction: {predicted_labels[i][j]}"
+ + values_string
+ )
+ else:
+ print(
+ f"\tExamples: {j}"\
+ + f"\t\tDistance: {round(examples_distance[i][j-1], 2)}"\
+ + f"\t\tGround Truth: {examples_labels[i][j-1]}"\
+ + f"\t\tPrediction: {predicted_labels[i][j]}"
+ + values_string
+ )
+ print("\n")
From c331801391a10fe64490018b12df3226bdb7cb79 Mon Sep 17 00:00:00 2001
From: Adelin CONSTANS
Date: Mon, 29 May 2023 09:39:36 +0200
Subject: [PATCH 002/138] tests: unit testing cole
---
tests/example_based/test_cole.py | 199 +++++++++++++++++++++++++++++++
1 file changed, 199 insertions(+)
create mode 100644 tests/example_based/test_cole.py
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
new file mode 100644
index 00000000..941df68b
--- /dev/null
+++ b/tests/example_based/test_cole.py
@@ -0,0 +1,199 @@
+"""
+Test Cole
+"""
+from math import prod, sqrt
+
+import numpy as np
+from sklearn.metrics import DistanceMetric
+import tensorflow as tf
+
+from xplique.example_based import Cole
+from xplique.types import Union
+
+from ..utils import generate_data, generate_model, almost_equal, generate_agnostic_model
+
+
+def test_neighbors_distance():
+    """
+    This function tests every output of the explanation method
+    """
+ # Method parameters initialisation
+ input_shape = (3, 3, 1)
+ nb_labels = 10
+ nb_samples = 10
+ nb_samples_test = 8
+ k = 3
+
+ # Data generation
+ matrix_train = tf.stack([i * tf.ones(input_shape) for i in range(nb_samples)])
+ matrix_test = matrix_train[1:-1]
+ labels_train = tf.range(nb_samples)
+ labels_test = labels_train[1:-1]
+
+ # Model generation
+ model = generate_model(input_shape, nb_labels)
+
+ # Initialisation of weights_extraction_function and distance_function
+ # They will be used in CaseBasedExplainer initialisation
+ distance_function = DistanceMetric.get_metric("euclidean")
+
+ # CaseBasedExplainer initialisation
+ method = Cole(
+ model,
+ matrix_train,
+ labels_train,
+ targets=None,
+ distance_function=distance_function,
+ weights_extraction_function=lambda inputs, targets: tf.ones(inputs.shape),
+ )
+
+ # Method explanation
+ (
+ examples,
+ examples_distance,
+ examples_weights,
+ inputs_weights,
+ examples_labels,
+ ) = method.explain(matrix_test, labels_test)
+
+ # test every outputs shape
+ assert examples.shape == (nb_samples_test, k) + input_shape
+ assert examples_distance.shape == (nb_samples_test, k)
+ assert examples_weights.shape == (nb_samples_test, k) + input_shape
+ assert inputs_weights.shape == (nb_samples_test,) + input_shape
+ assert examples_labels.shape == (nb_samples_test, k)
+
+ for i in range(len(labels_test)):
+ # test examples:
+ assert almost_equal(examples[i][0], matrix_train[i + 1])
+ assert almost_equal(examples[i][1], matrix_train[i + 2]) or almost_equal(
+ examples[i][1], matrix_train[i]
+ )
+ assert almost_equal(examples[i][2], matrix_train[i]) or almost_equal(
+ examples[i][2], matrix_train[i + 2]
+ )
+
+ # test examples_distance
+ assert almost_equal(examples_distance[i][0], 0)
+ assert almost_equal(examples_distance[i][1], sqrt(prod(input_shape)))
+ assert almost_equal(examples_distance[i][2], sqrt(prod(input_shape)))
+
+ # test examples_labels
+ assert almost_equal(examples_labels[i][0], labels_train[i + 1])
+ assert almost_equal(examples_labels[i][1], labels_train[i + 2]) or almost_equal(
+ examples_labels[i][1], labels_train[i]
+ )
+ assert almost_equal(examples_labels[i][2], labels_train[i]) or almost_equal(
+ examples_labels[i][2], labels_train[i + 2]
+ )
+
+
+def weights_attribution(
+ inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]
+):
+ """
+ Custom weights extraction function
+ Zeros everywhere and target at 0, 0, 0
+ """
+ weights = tf.Variable(tf.zeros(inputs.shape, dtype=tf.float32))
+ weights[:, 0, 0, 0].assign(targets)
+ return weights
+
+
+def test_weights_attribution():
+ """
+ Function to test the weights attribution
+ """
+ # Method parameters initialisation
+ input_shape = (3, 3, 1)
+ nb_labels = 10
+ nb_samples = 10
+
+ # Data generation
+ matrix_train = tf.stack(
+ [i * tf.ones(input_shape, dtype=tf.float32) for i in range(nb_samples)]
+ )
+ matrix_test = matrix_train[1:-1]
+ labels_train = tf.range(nb_samples, dtype=tf.float32)
+ labels_test = labels_train[1:-1]
+
+ # Model generation
+ model = generate_model(input_shape, nb_labels)
+
+ # Initialisation of distance_function
+ # It will be used in CaseBasedExplainer initialisation
+ distance_function = DistanceMetric.get_metric("euclidean")
+
+ # CaseBasedExplainer initialisation
+ method = Cole(
+ model,
+ matrix_train,
+ labels_train,
+ targets=labels_train,
+ distance_function=distance_function,
+ weights_extraction_function=weights_attribution,
+ )
+
+    # test case dataset weight
+ assert almost_equal(method.case_dataset_weight[:, 0, 0, 0], method.labels_train)
+ assert almost_equal(
+ tf.reduce_sum(method.case_dataset_weight, axis=[1, 2, 3]), method.labels_train
+ )
+
+ # Method explanation
+ _, _, examples_weights, inputs_weights, examples_labels =\
+ method.explain(matrix_test, labels_test)
+
+ # test examples weights
+ assert almost_equal(examples_weights[:, :, 0, 0, 0], examples_labels)
+ assert almost_equal(
+ tf.reduce_sum(examples_weights, axis=[2, 3, 4]), examples_labels
+ )
+
+ # test inputs weights
+ assert almost_equal(inputs_weights[:, 0, 0, 0], labels_test)
+ assert almost_equal(tf.reduce_sum(inputs_weights, axis=[1, 2, 3]), labels_test)
+
+
+def test_tabular_inputs():
+    """
+    Function to test that the method accepts tabular data input
+    """
+ # Method parameters initialisation
+ data_shape = (3,)
+ input_shape = data_shape
+ nb_labels = 3
+ nb_samples = 20
+ nb_inputs = 5
+ k = 3
+
+ # Data generation
+ dataset, targets = generate_data(data_shape, nb_labels, nb_samples)
+ dataset_train = dataset[:-nb_inputs]
+ dataset_test = dataset[-nb_inputs:]
+ targets_train = targets[:-nb_inputs]
+ targets_test = targets[-nb_inputs:]
+
+ # Model generation
+ model = generate_agnostic_model(input_shape, nb_labels)
+
+ # Initialisation of weights_extraction_function and distance_function
+ # They will be used in CaseBasedExplainer initialisation
+ distance_function = DistanceMetric.get_metric("euclidean")
+
+ # CaseBasedExplainer initialisation
+ method = Cole(
+ model,
+ dataset_train,
+ targets_train,
+ targets=targets_train,
+ distance_function=distance_function,
+ weights_extraction_function=lambda inputs, targets: tf.ones(inputs.shape),
+ k=k,
+ )
+
+ # Method explanation
+ examples, _, _, _, _ = method.explain(dataset_test, targets_test)
+
+ # test examples shape
+ assert examples.shape == (nb_inputs, k) + input_shape
From e1e4d82307464be8e462b99e9540f44e298fc6e4 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 29 May 2023 09:50:25 +0200
Subject: [PATCH 003/138] example based: introduce projections
---
xplique/example_based/projections/__init__.py | 8 +
.../example_based/projections/attributions.py | 156 ++++++++++++++++++
xplique/example_based/projections/base.py | 151 +++++++++++++++++
xplique/example_based/projections/custom.py | 90 ++++++++++
.../example_based/projections/latent_space.py | 48 ++++++
5 files changed, 453 insertions(+)
create mode 100644 xplique/example_based/projections/__init__.py
create mode 100644 xplique/example_based/projections/attributions.py
create mode 100644 xplique/example_based/projections/base.py
create mode 100644 xplique/example_based/projections/custom.py
create mode 100644 xplique/example_based/projections/latent_space.py
diff --git a/xplique/example_based/projections/__init__.py b/xplique/example_based/projections/__init__.py
new file mode 100644
index 00000000..d5d4cf90
--- /dev/null
+++ b/xplique/example_based/projections/__init__.py
@@ -0,0 +1,8 @@
+"""
+Projections
+"""
+
+from .attributions import AttributionProjection
+from .base import Projection
+from .custom import CustomProjection
+from .latent_space import LatentSpaceProjection
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
new file mode 100644
index 00000000..7f9f624f
--- /dev/null
+++ b/xplique/example_based/projections/attributions.py
@@ -0,0 +1,156 @@
+"""
+Attribution, a projection from example based module
+"""
+
+
+import tensorflow as tf
+import numpy as np
+
+from ...attributions.base import BlackBoxExplainer
+from ...attributions import Saliency
+from ...commons import find_layer
+from ...types import Callable, Union, Optional
+
+from .base import Projection
+
+
+class AttributionProjection(Projection):
+    """
+    Projection built on an attribution function to provide local projections.
+    This class is used as the projection of the `Cole` similar examples method.
+
+    Depending on the `latent_layer`, the model will be split between
+    the feature extractor and the predictor.
+    The feature extractor will become the `space_projection()` method, then
+    the predictor will be used to build the attribution method explain, and
+    its `explain()` method will become the `get_weights()` method.
+
+    If no `latent_layer` is provided, the model is not split,
+    the `space_projection()` is the identity function, and
+    the attributions (`get_weights()`) are computed on the whole model.
+
+    Parameters
+    ----------
+    model
+        The model from which we want to obtain explanations.
+    latent_layer
+        Layer used to split the model, the first part will be used for projection and
+        the second to compute the attributions. By default, the model is not split.
+        For such split, the `model` should be a `tf.keras.Model`.
+
+        Layer to target for the outputs (e.g logits or after softmax).
+        If an `int` is provided it will be interpreted as a layer index.
+        If a `string` is provided it will look for the layer name.
+
+        The method as described in the paper applies the separation on the last convolutional
+        layer. To do so, the `"last_conv"` parameter will extract it.
+        Otherwise, `-1` could be used for the last layer before softmax.
+    attribution_method
+        Class of the attribution method to use for projection.
+        It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
+        Ignored if a projection is given.
+    attribution_kwargs
+        Parameters to be passed at the construction of the `attribution_method`.
+    """
+
+ def __init__(
+ self,
+ model: Callable,
+ method: BlackBoxExplainer = Saliency,
+ latent_layer: Optional[Union[str, int]] = None,
+ **attribution_kwargs
+ ):
+ self.model = model
+
+ if latent_layer is None:
+ # no split
+ self.latent_layer = None
+ space_projection = lambda inputs: inputs
+ get_weights = method(model, **attribution_kwargs)
+ else:
+ # split the model if a latent_layer is provided
+ if latent_layer == "last_conv":
+ self.latent_layer = next(
+ layer for layer in model.layers[::-1] if hasattr(layer, "filters")
+ )
+ else:
+ self.latent_layer = find_layer(model, latent_layer)
+
+ space_projection = tf.keras.Model(
+ model.input, self.latent_layer.output, name="features_extractor"
+ )
+ self.predictor = tf.keras.Model(
+ self.latent_layer.output, model.output, name="predictor"
+ )
+ get_weights = method(self.predictor, **attribution_kwargs)
+
+ # set methods
+ super().__init__(get_weights, space_projection)
+
+        # attribution method outputs do not have a channel dimension;
+        # we wrap get_weights to expand dimensions if needed
+ self.__wrap_get_weights_to_extend_channels(self.get_weights)
+
+ def __wrap_get_weights_to_extend_channels(self, get_weights: Callable):
+ """
+        Extend channels if there is a mismatch between inputs and weights
+ """
+
+ def wrapped_get_weights(inputs, targets):
+ weights = get_weights(inputs, targets)
+ weights = tf.cond(
+ pred=weights.shape == inputs.shape,
+ true_fn=lambda: weights,
+ false_fn=lambda: tf.expand_dims(weights, axis=-1),
+ )
+ return weights
+
+ self.get_weights = wrapped_get_weights
+
+ def get_input_weights(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """
+ For visualization purpose (and only), we may be interested to project weights
+ from the projected space to the input space.
+        This is applied only if there is a difference in dimension.
+ We assume here that we are treating images and an upsampling is applied.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Additional parameter for `self.get_weights` function.
+
+ Returns
+ -------
+ input_weights
+ Tensor with the same dimension as `inputs` modulo the channels.
+ They are an upsampled version of the actual weights used in the projection.
+ """
+ projected_inputs = self.space_projection(inputs)
+ weights = self.get_weights(projected_inputs, targets)
+
+ # take mean over channels for images
+ channel_mean_fn = lambda: tf.reduce_mean(weights, axis=-1, keepdims=True)
+ weights = tf.cond(
+ pred=tf.shape(weights).shape[0] < 4,
+ true_fn=lambda: weights,
+ false_fn=channel_mean_fn,
+ )
+
+ # resizing
+ resize_fn = lambda: tf.image.resize(
+ weights, inputs.shape[1:-1], method="bicubic"
+ )
+ input_weights = tf.cond(
+ pred=projected_inputs.shape == inputs.shape,
+ true_fn=lambda: weights,
+ false_fn=resize_fn,
+ )
+ return input_weights
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
new file mode 100644
index 00000000..debe261a
--- /dev/null
+++ b/xplique/example_based/projections/base.py
@@ -0,0 +1,151 @@
+"""
+Base projection for similar examples in example based module
+"""
+
+from abc import ABC
+
+import tensorflow as tf
+import numpy as np
+
+from ...commons import sanitize_inputs_targets
+from ...types import Callable, Union, Optional
+
+
+class Projection(ABC):
+    """
+    Base class used by `NaturalExampleBasedExplainer` to project samples to a meaningful space
+    for the model to explain.
+
+    A projection has two parts, a `space_projection` and `weights`; to apply a projection,
+    the samples are first projected to a new space and then weighted.
+    Either the `space_projection` or the `weights` could be `None` but,
+    if both are, the projection is an identity function.
+
+    At least one of the two parts should include the model in the computation
+    for distance between projected elements to make sense for the model.
+
+    Note that the cost of this projection should be limited
+    as it will be applied to all samples of the train dataset.
+
+    Parameters
+    ----------
+    get_weights
+        Callable, a function that returns the weights (Tensor) for a given input (Tensor).
+        Weights should have the same shape as the input (possible difference on channels).
+
+        Example of `get_weights()` function:
+        ```
+        def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
+                                targets: Union(tf.Tensor, np.ndarray) = None):
+            '''
+            Example of function to get weights;
+            projected_inputs are the elements for which weights are computed.
+            targets are optional additional parameters for weights computation.
+            '''
+            weights = ... # do some magic with inputs and targets, it should use the model.
+            return weights
+        ```
+    space_projection
+        Callable that takes samples and returns a Tensor in the projected space.
+        An example of projected space is the latent space of a model. See `LatentSpaceProjection`
+    """
+
+ def __init__(self, get_weights: Callable = None, space_projection: Callable = None):
+ assert get_weights is not None or space_projection is not None, (
+ "At least one of `get_weights` and `space_projection`"
+ + "should not be `None`."
+ )
+
+ # set get weights
+ if get_weights is None:
+ # no weights
+ get_weights = lambda inputs, _: tf.ones(tf.shape(inputs))
+ if not hasattr(get_weights, "__call__"):
+ raise TypeError(
+ f"`get_weights` should be `Callable`, not a {type(get_weights)}"
+ )
+ self.get_weights = get_weights
+
+ # set space_projection
+ if space_projection is None:
+ space_projection = lambda inputs: inputs
+ if not hasattr(space_projection, "__call__"):
+ raise TypeError(
+ f"`space_projection` should be a `Callable`, not a {type(space_projection)}"
+ )
+ self.space_projection = space_projection
+
+ def get_input_weights(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """
+ Depending on the projection, we may not be able to visualize weights
+ as they are after the space projection. In this case, this method should be overwritten,
+ as in `AttributionProjection` that applies an upsampling.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Additional parameter for `self.get_weights` function.
+
+ Returns
+ -------
+ input_weights
+ Tensor with the same dimension as `inputs` modulo the channels.
+ They are an upsampled version of the actual weights used in the projection.
+ """
+ projected_inputs = self.space_projection(inputs)
+ assert tf.reduce_all(tf.equal(projected_inputs, inputs)), (
+ "Weights cannot be interpreted in the input space"
+ + "if `space_projection()` is not an identity."
+ + "Either remove 'weights' from the returns or"
+ + "make your own projection and overwrite `get_input_weights`."
+ )
+
+ weights = self.get_weights(projected_inputs, targets)
+
+ return weights
+
+ @sanitize_inputs_targets
+ def project(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """
+ Project samples in a space meaningful for the model,
+ either by weights the inputs, projecting in a latent space or both.
+ This function should be called at the init and for each explanation.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Additional parameter for `self.get_weights` function.
+
+ Returns
+ -------
+ projected_samples
+ The samples projected in the new space.
+ """
+ projected_inputs = self.space_projection(inputs)
+ weights = self.get_weights(projected_inputs, targets)
+
+ return tf.multiply(weights, projected_inputs)
+
+ def __call__(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """project alias"""
+ return self.project(inputs, targets)
diff --git a/xplique/example_based/projections/custom.py b/xplique/example_based/projections/custom.py
new file mode 100644
index 00000000..966c6ada
--- /dev/null
+++ b/xplique/example_based/projections/custom.py
@@ -0,0 +1,90 @@
+"""
+Custom, a projection from example based module
+"""
+
+import tensorflow as tf
+import numpy as np
+
+from ...types import Callable, Union
+
+from .base import Projection
+
+
class CustomProjection(Projection):
    """
    Projection built from user-provided weighting and/or space-projection functions,
    used by `NaturalExampleBasedExplainer` to project samples to a meaningful space
    for the model to explain.

    Projections have two parts, a `space_projection` and `weights`; to apply a projection,
    the samples are first projected to a new space and then weighted.
    Either the `space_projection` or the `weights` could be `None` but,
    if both are, the projection is an identity function.

    At least one of the two parts should include the model in the computation
    for distances between projected elements to make sense for the model.

    Note that the cost of this projection should be limited
    as it will be applied to all samples of the train dataset.

    Parameters
    ----------
    weights
        Either a Tensor or a Callable.
        - In the case of a Tensor, weights are applied in the projected space
        (after `space_projection`).
        Hence weights should have the same shape as a `projected_input`.
        - In the case of a Callable, the function should return the weights when called,
        as a way to get the weights (a Tensor).
        It is pertinent in the case of weights dependent on the inputs, i.e. local weighting.

        Example of `get_weights()` function:
        ```
        def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
                                targets: Union(tf.Tensor, np.ndarray) = None):
            '''
            Example of function to get weights,
            projected_inputs are the elements for which weights are computed.
            targets are optional additional parameters for weights computation.
            '''
            weights = ...  # do some magic with inputs and targets, it should use the model.
            return weights
        ```
    space_projection
        Callable that takes samples and returns a Tensor in the projected space.
        An example of projected space is the latent space of a model.
        In this case, the model should be split in two and `space_projection`
        is the first part, mapping the inputs to their latent representation.
    """

    def __init__(
        self,
        weights: Union[Callable, tf.Tensor, np.ndarray] = None,
        space_projection: Callable = None,
    ):
        # normalize `weights` into a `get_weights` callable (or None)
        if weights is None or callable(weights):
            # weights is already a function or there is no weights
            get_weights = weights
        elif isinstance(weights, (tf.Tensor, np.ndarray)):
            # weights is a constant tensor, tile it along the batch dimension
            if isinstance(weights, np.ndarray):
                weights = tf.convert_to_tensor(weights, dtype=tf.float32)

            # define a function that returns the weights
            def get_weights(inputs, _ = None):
                nweights = tf.expand_dims(weights, axis=0)
                return tf.repeat(nweights, tf.shape(inputs)[0], axis=0)

        else:
            raise TypeError(
                "`weights` should be a tensor or a `Callable`,"
                + f"not a {type(weights)}"
            )

        # `space_projection` must be a callable when provided
        if space_projection is not None and not callable(space_projection):
            raise TypeError(
                "`space_projection` should be a `Callable`,"
                + f"not a {type(space_projection)}"
            )

        super().__init__(get_weights, space_projection)
diff --git a/xplique/example_based/projections/latent_space.py b/xplique/example_based/projections/latent_space.py
new file mode 100644
index 00000000..04ce0304
--- /dev/null
+++ b/xplique/example_based/projections/latent_space.py
@@ -0,0 +1,48 @@
+"""
+Custom, a projection from example based module
+"""
+
+import tensorflow as tf
+
+from ...commons import find_layer
+from ...types import Callable, Union
+
+from .base import Projection
+
+
class LatentSpaceProjection(Projection):
    """
    Projection mapping inputs into the latent space of the model.
    It does not have weighting.

    Parameters
    ----------
    model
        The model from which we want to obtain explanations.
    latent_layer
        Layer used to split the `model`.

        Layer to target for the outputs (e.g logits or after softmax).
        If an `int` is provided it will be interpreted as a layer index.
        If a `string` is provided it will look for the layer name.

        To separate after the last convolution, `"last_conv"` can be used.
        Otherwise, `-1` could be used for the last layer before softmax.
    """

    def __init__(self, model: Callable, latent_layer: Union[str, int] = -1):
        self.model = model

        # resolve the layer at which the model is split
        if latent_layer == "last_conv":
            # the last convolution is the last layer exposing a `filters` attribute
            convolutions = (layer for layer in reversed(model.layers)
                            if hasattr(layer, "filters"))
            self.latent_layer = next(convolutions)
        else:
            self.latent_layer = find_layer(model, latent_layer)

        # sub-model mapping the inputs to their latent representation
        features_extractor = tf.keras.Model(
            model.input, self.latent_layer.output, name="features_extractor"
        )

        super().__init__(space_projection=features_extractor)
From eda33170c54a914a05da54a9673818d0ad36c683 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 29 May 2023 09:50:46 +0200
Subject: [PATCH 004/138] example based: introduce search methods
---
.../example_based/search_methods/__init__.py | 8 +
xplique/example_based/search_methods/base.py | 180 +++++++++++++++
xplique/example_based/search_methods/knn.py | 207 ++++++++++++++++++
3 files changed, 395 insertions(+)
create mode 100644 xplique/example_based/search_methods/__init__.py
create mode 100644 xplique/example_based/search_methods/base.py
create mode 100644 xplique/example_based/search_methods/knn.py
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
new file mode 100644
index 00000000..228e1acd
--- /dev/null
+++ b/xplique/example_based/search_methods/__init__.py
@@ -0,0 +1,8 @@
+"""
+Search methods
+"""
+
+from .base import BaseSearchMethod
+
+# from .sklearn_knn import SklearnKNN
+from .knn import KNN
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
new file mode 100644
index 00000000..13a05f6a
--- /dev/null
+++ b/xplique/example_based/search_methods/base.py
@@ -0,0 +1,180 @@
+"""
+Base search method for example-based module
+"""
+
+from abc import ABC, abstractmethod
+
+import tensorflow as tf
+import numpy as np
+
+from ...types import Callable, Union, Optional, List
+
+from ...commons import sanitize_dataset
+
+from ..projections.base import Projection
+
+
+def _sanitize_returns(returns: Optional[Union[List[str], str]] = None,
+ possibilities: List[str] = None,
+ default: Union[List[str], str] = None):
+ """
+ Factorization of `set_returns` for `BaseSearchMethod` and `SimilarExamples`.
+ It cleans the `returns` parameter.
+ Results is either a sublist of possibilities or a value among possibilities.
+
+ Parameters
+ ----------
+ returns
+ The value to verify and put to the `instance.returns` attribute.
+ possibilities
+ List of possible unit values for `instance.returns`.
+ default
+ Value in case `returns` is None.
+
+ Returns
+ -------
+ returns
+ The cleaned `returns` value.
+ """
+ if possibilities is None:
+ possibilities = ["examples"]
+ if default is None:
+ default = ["examples"]
+
+ if returns is None:
+ returns = default
+ elif isinstance(returns, str):
+ if returns == "all":
+ returns = possibilities
+ elif returns in possibilities:
+ returns = [returns]
+ else:
+ raise ValueError(f"{returns} should belong to {possibilities}")
+ elif isinstance(returns, list):
+ pass # already in the right format.
+ else:
+ raise ValueError(f"{returns} should either be `str` or `List[str]`")
+
+ return returns
+
+
class BaseSearchMethod(ABC):
    """
    Base class used by `NaturalExampleBasedExplainer` to search examples in
    a meaningful space for the model. It can also be used alone but will not provide
    model explanations.

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        For natural example-based methods it is the train dataset.
    targets_dataset
        Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
    k
        The number of examples to retrieve.
    projection
        Projection or Callable that project samples from the input space to the search space.
        The search space should be a space where distances make sense for the model.
        It should not be `None`, otherwise,
        all examples could be computed only with the `search_method`.

        Example of Callable:
        ```
        def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
            '''
            Example of projection,
            inputs are the elements to project.
            targets are optional parameters to orient the projection.
            '''
            projected_inputs = # do some magic on inputs, it should use the model.
            return projected_inputs
        ```
    search_returns
        String or list of string with the elements to return in `self.find_examples()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of samples treated simultaneously.
        It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
    """

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        k: int = 1,
        projection: Union[Projection, Callable] = None,
        search_returns: Optional[Union[List[str], str]] = None,
        batch_size: Optional[int] = 32,
    ): # pylint: disable=R0801
        # set batch size
        # NOTE(review): `_batch_size` is a private attribute of batched
        # `tf.data.Dataset`; when present it overrides the `batch_size` argument
        if hasattr(cases_dataset, "_batch_size"):
            self.batch_size = cases_dataset._batch_size
        else:
            self.batch_size = batch_size

        # convert tensors / arrays into batched `tf.data.Dataset`
        self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
        self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size)
        if self.targets_dataset is None:
            # The `find_examples()` method need to be able to iterate on `self.targets_dataset`
            # hence a list of `None` with one entry per batch of `cases_dataset`
            self.targets_dataset = [None] * self.cases_dataset.cardinality().numpy()

        self.set_k(k)
        self.set_returns(search_returns)
        self.projection = projection

    def set_k(self, k: int):
        """
        Change value of k without constructing a new `BaseSearchMethod`.
        It is useful because the constructor can be computationally expensive.

        Parameters
        ----------
        k
            The number of examples to retrieve.
        """
        assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
        self.k = k

    def set_returns(self, returns: Optional[Union[List[str], str]] = None):
        """
        Set `self.returns` used to define returned elements in `self.find_examples()`.

        Parameters
        ----------
        returns
            Most elements are useful in `xplique.plots.plot_examples()`.
            `returns` can be set to 'all' for all possible elements to be returned.
            - 'examples' correspond to the expected examples,
            the inputs may be included in first position. (n, k(+1), ...)
            - 'indices' the indices of the examples in the `search_set`.
            Used to retrieve the original example and labels. (n, k, ...)
            - 'distances' the distances between the inputs and the corresponding examples.
            They are associated to the examples. (n, k, ...)
            - 'include_inputs' specify if inputs should be included in the returned elements.
            Note that it changes the number of returned elements from k to k+1.
        """
        possibilities = ["examples", "indices", "distances", "include_inputs"]
        default = "examples"
        self.returns = _sanitize_returns(returns, possibilities, default)


    @abstractmethod
    def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
        """
        Search the samples to return as examples. Called by the explain methods.
        It may also return the indices corresponding to the samples,
        based on `self.returns` value.

        Parameters
        ----------
        inputs
            Tensor or Array. Input samples to be explained.
            Assumed to have been already projected.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
        """
        raise NotImplementedError()

    def __call__(self, inputs: Union[tf.Tensor, np.ndarray]):
        """find_examples alias"""
        return self.find_examples(inputs)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
new file mode 100644
index 00000000..ed8d721b
--- /dev/null
+++ b/xplique/example_based/search_methods/knn.py
@@ -0,0 +1,207 @@
+"""
+KNN online search method in example-based module
+"""
+
+import numpy as np
+import tensorflow as tf
+
+from ...commons import dataset_gather
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .base import BaseSearchMethod
+from ..projections import Projection
+
+
class KNN(BaseSearchMethod):
    """
    KNN method to search examples.
    Distances between the queries and the cases dataset are computed batch by batch
    while keeping track of the `k` nearest elements, to match the `BaseSearchMethod` API.

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        For natural example-based methods it is the train dataset.
    targets_dataset
        Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
    k
        The number of examples to retrieve.
    projection
        Projection or Callable that project samples from the input space to the search space.
        The search space should be a space where distances make sense for the model.
        It should not be `None`, otherwise,
        all examples could be computed only with the `search_method`.

        Example of Callable:
        ```
        def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
            '''
            Example of projection,
            inputs are the elements to project.
            targets are optional parameters to orient the projection.
            '''
            projected_inputs = # do some magic on inputs, it should use the model.
            return projected_inputs
        ```
    search_returns
        String or list of string with the elements to return in `self.find_examples()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of samples treated simultaneously.
        It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
    distance
        Either a Callable, 'cosine', or a value supported by `tf.norm` `ord` parameter.
        Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
        "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
        yielding the corresponding p-norm." We also added 'cosine'.
    """

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        k: int = 1,
        projection: Union[Projection, Callable] = None,
        search_returns: Optional[Union[List[str], str]] = None,
        batch_size: Optional[int] = 32,
        distance: Union[int, str, Callable] = "euclidean",
    ): # pylint: disable=R0801
        super().__init__(
            cases_dataset, targets_dataset, k, projection, search_returns, batch_size
        )

        # resolve the distance function between two single samples
        if hasattr(distance, "__call__"):
            self.distance_fn = distance
        elif distance == "cosine":
            # 'cosine' was documented but not implemented:
            # one minus the cosine similarity of the flattened samples
            def _cosine_distance(x1, x2):
                flat_x1 = tf.nn.l2_normalize(tf.reshape(x1, [-1]), axis=0)
                flat_x2 = tf.nn.l2_normalize(tf.reshape(x2, [-1]), axis=0)
                return 1.0 - tf.reduce_sum(flat_x1 * flat_x2)

            self.distance_fn = _cosine_distance
        elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
            distance, int
        ):
            self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
        else:
            # bug fix: the message was previously built with a stray comma and a
            # unary `+` applied to an f-string, so the raise itself crashed
            raise AttributeError(
                "The distance parameter is expected to be either a Callable or in"
                + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] "
                + f"but {distance} was received."
            )

        # distances between one element x1 and all elements of a batch x2
        self.distance_fn_over_all_x2 = lambda x1, x2: tf.map_fn(
            fn=lambda x2: self.distance_fn(x1, x2),
            elems=x2,
        )

        # Computes crossed distances between two tensors x1(shape=(n1, ...)) and x2(shape=(n2, ...))
        # The result is a distance matrix of size (n1, n2)
        self.crossed_distances_fn = lambda x1, x2: tf.vectorized_map(
            fn=lambda a1: self.distance_fn_over_all_x2(a1, x2),
            elems=x1
        )

    def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Compute the k-nearest neighbors to each tensor of `inputs` in `self.cases_dataset`.
        Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.

        Parameters
        ----------
        inputs
            Tensor or Array. Input samples on which knn are computed.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
            More information in the documentation.

        Returns
        -------
        best_distances
            Tensor of distances between the knn and the inputs with dimension (n, k).
            The n inputs times their k-nearest neighbors.
        best_indices
            Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2).
            Where, n represent the number of inputs and k the number of corresponding examples.
            The index of each element is encoded by two values,
            the batch index and the index of the element in the batch.
            Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`.
        """
        nb_inputs = tf.shape(inputs)[0]

        # initialize the trackers of the current k best elements
        # (n, k, 2), -1 is a placeholder index
        best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1))
        # (n, k), +inf so that any real distance is smaller
        best_distances = tf.Variable(tf.fill((nb_inputs, self.k), np.inf))
        # (n, bs), in-batch indices replicated for each input
        batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0)
        batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1))

        # iterate on batches
        for batch_index, (cases, targets) in enumerate(
            zip(self.cases_dataset, self.targets_dataset)
        ):
            # project batch of dataset cases
            if self.projection is not None:
                projected_cases = self.projection.project(cases, targets)
            else:
                projected_cases = cases

            # add new elements
            # (n, current_bs, 2) - truncated in case the last batch is smaller
            indices = batch_indices[:, : tf.shape(projected_cases)[0]]
            new_indices = tf.stack(
                [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1
            )

            # compute distances
            # (n, current_bs)
            distances = self.crossed_distances_fn(inputs, projected_cases)

            # (n, k+current_bs, 2)
            concatenated_indices = tf.concat([best_indices, new_indices], axis=1)
            # (n, k+current_bs)
            concatenated_distances = tf.concat([best_distances, distances], axis=1)

            # keep the k smallest distances among the previous best and the new batch
            # (n, k)
            sort_order = tf.argsort(
                concatenated_distances, axis=1, direction="ASCENDING"
            )[:, : self.k]

            best_indices.assign(
                tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1)
            )
            best_distances.assign(
                tf.gather(concatenated_distances, sort_order, axis=1, batch_dims=1)
            )

        return best_distances, best_indices

    def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
        """
        Search the samples to return as examples. Called by the explain methods.
        It may also return the indices corresponding to the samples,
        based on `self.returns` value.

        Parameters
        ----------
        inputs
            Tensor or Array. Input samples to be explained.
            Assumed to have been already projected.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
        """
        # compute neighbors
        examples_distances, examples_indices = self.kneighbors(inputs)

        # Set values in return dict
        return_dict = {}
        if "examples" in self.returns or "include_inputs" in self.returns:
            # gathering is also needed for 'include_inputs' alone,
            # previously that case raised a KeyError
            examples = dataset_gather(self.cases_dataset, examples_indices)
            if "include_inputs" in self.returns:
                # prepend each input to its own examples
                examples = tf.concat(
                    [tf.expand_dims(inputs, axis=1), examples], axis=1
                )
            return_dict["examples"] = examples
        if "indices" in self.returns:
            return_dict["indices"] = examples_indices
        if "distances" in self.returns:
            return_dict["distances"] = examples_distances

        # Return a dict only if several variables are returned
        if len(return_dict) == 1:
            return list(return_dict.values())[0]
        return return_dict
From 4b0a3406ee66b5fec54532d0f196913cff2cebc3 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 29 May 2023 09:51:58 +0200
Subject: [PATCH 005/138] example based: add base class
---
xplique/example_based/similar_examples.py | 380 ++++++++++++++++++++++
1 file changed, 380 insertions(+)
create mode 100644 xplique/example_based/similar_examples.py
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
new file mode 100644
index 00000000..2961b1e0
--- /dev/null
+++ b/xplique/example_based/similar_examples.py
@@ -0,0 +1,380 @@
+"""
+Base model for example-based
+"""
+
+import math
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import KNN, BaseSearchMethod
+from .projections import Projection
+
+from .search_methods.base import _sanitize_returns
+
+
class SimilarExamples:
    """
    Base class for natural example-based methods explaining models,
    they project the cases_dataset into a pertinent space for the model with a `Projection`,
    then they call the `BaseSearchMethod` on it.

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    labels_dataset
        Labels associated to the examples in the dataset. Indices should match with cases_dataset.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Batch size and cardinality of other dataset should match `cases_dataset`.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    targets_dataset
        Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Batch size and cardinality of other dataset should match `cases_dataset`.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    search_method
        An algorithm to search the examples in the projected space.
    k
        The number of examples to retrieve.
    projection
        Projection or Callable that project samples from the input space to the search space.
        The search space should be a space where distances make sense for the model.
        It should not be `None`, otherwise,
        all examples could be computed only with the `search_method`.

        Example of Callable:
        ```
        def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
            '''
            Example of projection,
            inputs are the elements to project.
            targets are optional parameters to orient the projection.
            '''
            projected_inputs = # do some magic on inputs, it should use the model.
            return projected_inputs
        ```
    case_returns
        String or list of string with the elements to return in `self.explain()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of samples treated simultaneously for projection and search.
        Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
    search_method_kwargs
        Parameters to be passed at the construction of the `search_method`.
    """

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        search_method: Type[BaseSearchMethod] = KNN,
        k: int = 1,
        projection: Union[Projection, Callable] = None,
        case_returns: Union[List[str], str] = "examples",
        batch_size: Optional[int] = 32,
        **search_method_kwargs,
    ):
        assert (
            projection is not None
        ), "`SimilarExamples` without `projection` is a `BaseSearchMethod`."

        # set attributes
        batch_size = self.__initialize_cases_dataset(
            cases_dataset, labels_dataset, targets_dataset, batch_size
        )
        self.k = k
        self.set_returns(case_returns)
        self.projection = projection

        # the search method is always asked for indices and distances,
        # `self.format_search_output()` then selects what `self.returns` requires
        search_method_kwargs["search_returns"] = ["indices", "distances"]

        # initiate search_method
        self.search_method = search_method(
            cases_dataset=cases_dataset,
            targets_dataset=targets_dataset,
            k=k,
            projection=projection,
            batch_size=batch_size,
            **search_method_kwargs,
        )

    def __initialize_cases_dataset(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
        targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
        batch_size: Optional[int],
    ) -> int:
        """
        Factorization of `__init__()` method for dataset related attributes.

        Parameters
        ----------
        cases_dataset
            The dataset used to train the model, examples are extracted from the dataset.
        labels_dataset
            Labels associated to the examples in the dataset.
            Indices should match with cases_dataset.
        targets_dataset
            Targets associated to the cases_dataset for dataset projection.
            See `projection` for detail.
        batch_size
            Number of samples treated simultaneously when using the datasets.
            Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).

        Returns
        -------
        batch_size
            Number of samples treated simultaneously when using the datasets.
            Extracted from the datasets in case they are `tf.data.Dataset`.
            Otherwise, the input value.
        """
        # at least one dataset provided
        if isinstance(cases_dataset, tf.data.Dataset):
            # set batch size (ignore provided argument) and cardinality
            if isinstance(cases_dataset.element_spec, tuple):
                batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy()
            else:
                batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy()

            cardinality = cases_dataset.cardinality().numpy()
        else:
            # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
            assert not isinstance(labels_dataset, tf.data.Dataset)
            assert not isinstance(targets_dataset, tf.data.Dataset)
            # set batch size and cardinality
            batch_size = min(batch_size, len(cases_dataset))
            cardinality = math.ceil(len(cases_dataset) / batch_size)

        # verify cardinality and create datasets from the tensors
        self.cases_dataset = sanitize_dataset(
            cases_dataset, batch_size, cardinality
        )
        self.labels_dataset = sanitize_dataset(
            labels_dataset, batch_size, cardinality
        )
        self.targets_dataset = sanitize_dataset(
            targets_dataset, batch_size, cardinality
        )

        # if the provided `cases_dataset` has several columns
        if isinstance(self.cases_dataset.element_spec, tuple):
            # switch case on the number of columns of `cases_dataset`
            if len(self.cases_dataset.element_spec) == 2:
                assert self.labels_dataset is None, (
                    "The second column of `cases_dataset` is assumed to be the labels."
                    + " Hence, `labels_dataset` should be empty."
                )
                self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
                self.cases_dataset = self.cases_dataset.map(lambda x, y: x)

            elif len(self.cases_dataset.element_spec) == 3:
                assert self.labels_dataset is None, (
                    "The second column of `cases_dataset` is assumed to be the labels."
                    + " Hence, `labels_dataset` should be empty."
                )
                # bug fix: this message was a copy-paste of the labels one
                assert self.targets_dataset is None, (
                    "The third column of `cases_dataset` is assumed to be the targets."
                    + " Hence, `targets_dataset` should be empty."
                )
                self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
                self.labels_dataset = self.cases_dataset.map(lambda x, y, t: y)
                self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
            else:
                raise AttributeError(
                    "`cases_dataset` cannot possess more than 3 columns, "
                    + f"{len(self.cases_dataset.element_spec)} were detected."
                )

        self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
        if self.labels_dataset is not None:
            self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
        if self.targets_dataset is not None:
            self.targets_dataset = self.targets_dataset.prefetch(tf.data.AUTOTUNE)

        return batch_size

    def set_k(self, k: int):
        """
        Setter for the k parameter.

        Parameters
        ----------
        k
            Number of examples to return, it should be a positive integer.
        """
        assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
        self.k = k
        self.search_method.set_k(k)

    def set_returns(self, returns: Union[List[str], str]):
        """
        Set `self.returns` used to define returned elements in `self.explain()`.

        Parameters
        ----------
        returns
            Most elements are useful in `xplique.plots.plot_examples()`.
            `returns` can be set to 'all' for all possible elements to be returned.
            - 'examples' correspond to the expected examples,
            the inputs may be included in first position. (n, k(+1), ...)
            - 'weights' the weights in the input space used in the projection.
            They are associated to the input and the examples. (n, k(+1), ...)
            - 'distances' the distances between the inputs and the corresponding examples.
            They are associated to the examples. (n, k, ...)
            - 'labels' if provided through `dataset_labels`,
            they are the labels associated with the examples. (n, k, ...)
            - 'include_inputs' specify if inputs should be included in the returned elements.
            Note that it changes the number of returned elements from k to k+1.
        """
        possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
        default = "examples"
        self.returns = _sanitize_returns(returns, possibilities, default)

    @sanitize_inputs_targets
    def explain(
        self,
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """
        Compute examples to explain the inputs.
        It projects inputs with `self.projection` in the search space
        and finds examples with `self.search_method`.

        Parameters
        ----------
        inputs
            Tensor or Array. Input samples to be explained.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
            More information in the documentation.
        targets
            Tensor or Array passed to the projection function.

        Returns
        -------
        return_dict
            Dictionary with listed elements in `self.returns`.
            If only one element is present it returns the element.
            The elements that can be returned are:
            examples, weights, distances, indices, and labels.
        """
        # project inputs
        projected_inputs = self.projection(inputs, targets)

        # look for closest elements to projected inputs
        search_output = self.search_method(projected_inputs)

        # manage returned elements
        return self.format_search_output(search_output, inputs, targets)

    def __call__(
        self,
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """explain alias"""
        return self.explain(inputs, targets)

    def format_search_output(
        self,
        search_output: Dict[str, tf.Tensor],
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """
        Format the output of the `search_method` to match the expected returns in `self.returns`.

        Parameters
        ----------
        search_output
            Dictionary with the required outputs from the `search_method`.
        inputs
            Tensor or Array. Input samples to be explained.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
            More information in the documentation.
        targets
            Tensor or Array passed to the projection function.
            Here it is used by the explain function of attribution methods.
            Refer to the corresponding method documentation for more detail.
            Note that the default method is `Saliency`.

        Returns
        -------
        return_dict
            Dictionary with listed elements in `self.returns`.
            If only one element is present it returns the element.
            The elements that can be returned are:
            examples, weights, distances, indices, and labels.
        """
        return_dict = {}

        examples = dataset_gather(self.cases_dataset, search_output["indices"])
        # guard against missing optional datasets before gathering from them
        examples_labels = None
        if self.labels_dataset is not None:
            examples_labels = dataset_gather(
                self.labels_dataset, search_output["indices"]
            )
        examples_targets = None
        if self.targets_dataset is not None:
            examples_targets = dataset_gather(
                self.targets_dataset, search_output["indices"]
            )

        # add examples and weights
        if "examples" in self.returns or "weights" in self.returns:
            if "include_inputs" in self.returns:
                # include inputs
                inputs = tf.expand_dims(inputs, axis=1)
                examples = tf.concat([inputs, examples], axis=1)
                if targets is not None and examples_targets is not None:
                    targets = tf.expand_dims(targets, axis=1)
                    examples_targets = tf.concat([targets, examples_targets], axis=1)
                else:
                    examples_targets = [None] * len(examples)
            if "examples" in self.returns:
                return_dict["examples"] = examples
            if "weights" in self.returns:
                # the loop below iterates on `examples_targets`,
                # provide placeholders when no targets are available
                if examples_targets is None:
                    examples_targets = [None] * len(examples)
                # get weights of examples (n, k, ...)
                # we iterate on the inputs dimension through maps
                # and ask weights for batch of examples
                weights = []
                for ex, ex_targ in zip(examples, examples_targets):
                    if isinstance(self.projection, Projection):
                        # get weights in the input space
                        weights.append(self.projection.get_input_weights(ex, ex_targ))
                    else:
                        raise AttributeError(
                            "Cannot extract weights from the provided projection function"
                            + "Either remove 'weights' from the `case_returns` or"
                            + "inherit from `Projection` and overwrite `get_input_weights`."
                        )

                return_dict["weights"] = tf.stack(weights, axis=0)

        # add indices, distances, and labels
        if "distances" in self.returns:
            return_dict["distances"] = search_output["distances"]
        if "labels" in self.returns:
            assert (
                examples_labels is not None
            ), "The method cannot return labels without a label dataset."
            return_dict["labels"] = examples_labels

        # return a dict only if several variables are returned
        if len(return_dict) == 1:
            return list(return_dict.values())[0]
        return return_dict
From 0daf2817abc8ce8d6d31ff1bd7d1f76954e1c95d Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:36:52 +0100
Subject: [PATCH 006/138] cole: update and improve
---
xplique/__init__.py | 3 +-
xplique/example_based/__init__.py | 6 +
xplique/example_based/cole.py | 492 ++++++------------------------
xplique/types/__init__.py | 2 +-
4 files changed, 109 insertions(+), 394 deletions(-)
create mode 100644 xplique/example_based/__init__.py
diff --git a/xplique/__init__.py b/xplique/__init__.py
index 32ee5166..8ab3377a 100644
--- a/xplique/__init__.py
+++ b/xplique/__init__.py
@@ -9,9 +9,10 @@
__version__ = '1.3.3'
from . import attributions
+from . import commons
from . import concepts
+from . import example_based
from . import features_visualizations
-from . import commons
from . import plots
from .commons import Tasks
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
new file mode 100644
index 00000000..a958a62b
--- /dev/null
+++ b/xplique/example_based/__init__.py
@@ -0,0 +1,6 @@
+"""
+Example-based methods available
+"""
+
+from .cole import Cole
+from .similar_examples import SimilarExamples
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index 63fa69c6..85c4c2d6 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -1,411 +1,119 @@
"""
-Module related to Case Base Explainer
+Implementation of the Cole method, a similar-examples method from the example-based module
"""
-
-import matplotlib.pyplot as plt
import numpy as np
-from sklearn.metrics import DistanceMetric
-from sklearn.neighbors import KDTree
import tensorflow as tf
-from ..plots.image import _standardize_image
-from ..types import Callable, Union, Optional
+from ..attributions.base import BlackBoxExplainer
+from ..attributions import Saliency
+from ..types import Callable, List, Optional, Union, Type
+
+from .similar_examples import SimilarExamples
+from .projections import AttributionProjection
+from .search_methods import KNN
+from .search_methods import BaseSearchMethod
-class Cole:
+class Cole(SimilarExamples):
"""
- Used to compute the Case Based Explainer sytem, a twins sytem that use ANN and knn with
- the same dataset.
+ Cole is a similar-examples method that gives the most similar examples to a query.
+ Cole uses the model to build a search space so that distances are meaningful for the model.
+ It uses attribution methods to weight inputs.
+ Those attributions may be computed in the latent space for complex data types like images.
- Ref. Eoin M. Kenny and Mark T. Keane.
+ It is an implementation of a method proposed by Kenny et Keane in 2019,
Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
- Comparative Tests of Feature-Weighting Methods in ANN-CBR Twins for XAI. (2019)
- https://www.ijcai.org/proceedings/2019/376
+ https://researchrepository.ucd.ie/handle/10197/11064
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provides no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provides no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provides no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ search_method
+ An algorithm to search the examples in the projected space.
+ k
+ The number of examples to retrieve. Default value is `1`.
+ distance
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm."
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See `self.set_returns()` from parent class `SimilarExamples` for detail.
+ By default, the `explain()` method will only return the examples.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ latent_layer
+ Layer used to split the model, the first part will be used for projection and
+ the second to compute the attributions. By default, the model is not split.
+ For such split, the `model` should be a `tf.keras.Model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+ The method as described in the paper applies the separation on the last convolutional layer.
+ To do so, the `"last_conv"` parameter will extract it.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ attribution_method
+ Class of the attribution method to use for projection.
+ It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
+ Ignored if a projection is given.
+ attribution_kwargs
+ Parameters to be passed at the construction of the `attribution_method`.
"""
def __init__(
self,
- model: Callable,
- case_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_train: np.ndarray,
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- distance_function: DistanceMetric = None,
- weights_extraction_function: Callable = None,
- k: Optional[int] = 3,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ model: tf.keras.Model,
+ targets_dataset: Union[tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ search_method: Type[BaseSearchMethod] = KNN,
+ k: int = 1,
+ distance: Union[str, Callable] = "euclidean",
+ case_returns: Optional[Union[List[str], str]] = "examples",
+ batch_size: Optional[int] = 32,
+ latent_layer: Optional[Union[str, int]] = None,
+ attribution_method: Type[BlackBoxExplainer] = Saliency,
+ **attribution_kwargs,
):
- """
- Parameters
- ----------
- model
- The model from wich we want to obtain explanations.
- case_dataset
- The dataset used to train the model,
- also use by the function to calcul the closest examples.
- labels_train
- labels define by the dataset.
- targets
- labels predict by the model from the dataset.
- distance_function
- The function to calcul the distance between the inputs and all the dataset.
- (Can use : euclidean, manhattan, minkowski etc...)
- weights_extraction_function
- The function to calcul the weight of every features,
- many type of methode can be use but it will depend of
- what type of dataset you've got.
- examples:
- def my_function(inputs, targets):
- # outputs.shape == inputs.shape
- return outputs
- k
- Represante how many nearest neighbours you want to be returns.
- """
- # set attributes
- self.model = model
- self.case_dataset = case_dataset
- self.weights_extraction_function = weights_extraction_function
- self.k_neighbors = k
- self.labels_train = labels_train
-
- # verify targets parametre
- if targets is None:
- targets = model(case_dataset)
- nb_classes = targets.shape[1]
- targets = tf.argmax(targets, axis=1)
- targets = tf.one_hot(
- targets, nb_classes
- ) # nb_classes normalement en second argument mais la du coup 10.
-
- # verify distance_function parametre
- if distance_function is None:
- distance_function = DistanceMetric.get_metric("euclidean")
-
- # verify weight_extraction_function parametre
- if weights_extraction_function is None:
- self.weights_extraction_function = self._get_default_weights_extraction_function()
-
- # compute case dataset weights (used in distance)
- # the weight extraction function may need the predictions to extract the weights
- case_dataset_weight = self.weights_extraction_function(case_dataset, targets)
- # for images, channels may disappear
- if len(case_dataset_weight.shape) != len(case_dataset.shape):
- case_dataset_weight = tf.expand_dims(case_dataset_weight, -1)
- self.case_dataset_weight = case_dataset_weight
-
- # apply weights to the case dataset (weighted distance)
- weighted_case_dataset = tf.math.multiply(case_dataset_weight, case_dataset)
- # flatten features for kdtree
- weighted_case_dataset = tf.reshape(
- weighted_case_dataset, [weighted_case_dataset.shape[0], -1]
+        # build the attribution projection
+ projection = AttributionProjection(
+ model=model,
+ method=attribution_method,
+ latent_layer=latent_layer,
+ **attribution_kwargs,
)
- # create kdtree instance with weighted case dataset
- # will be called to estimate closest examples
- self.knn = KDTree(weighted_case_dataset, metric=distance_function)
-
- def extract_element_from_indices(
- self,
- labels_train: np.ndarray,
- examples_indice: np.ndarray,
- ):
- """
- This function has to extract every example and weights from the dataset
- by the indice calculate with de knn query in the explain function
-
- Parameters
- ----------
- labels_train
- labels define by the dataset.
- examples_indice
- Represente the indice of the K nearust neighbours of the input.
-
- Returns
- -------
- examples
- Represente the K nearust neighbours of the input.
- examples_weights
- features weight of the examples.
- labels_examples
- labels of the examples.
- """
- all_examples = []
- all_weight_examples = []
- all_labels_examples = []
- for sample_examples_indice in examples_indice:
- sample_examples = []
- weight_ex = []
- label_ex = []
- for indice in sample_examples_indice:
- sample_examples.append(self.case_dataset[indice])
- weight_ex.append(self.case_dataset_weight[indice])
- label_ex.append(labels_train[indice])
- # (k, h, w, 1)
- all_examples.append(tf.stack(sample_examples, axis=0))
- all_weight_examples.append(tf.stack(weight_ex, axis=0))
- all_labels_examples.append(tf.stack(label_ex, axis=0))
- # (n, k, h, w, 1)
- examples = tf.stack(all_examples, axis=0)
- examples_weights = tf.stack(all_weight_examples, axis=0)
- labels_examples = tf.stack(all_labels_examples, axis=0)
-
- return examples, examples_weights, labels_examples
-
- def explain(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Union[tf.Tensor, np.ndarray] = None,
- ):
- """
- This function calculates the indice of the k closest example of the different inputs.
- Then calls extract_element_from_indice to extract the examples from those indices.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N,W), (N,T,W), (N,W,H,C).
- targets
- Tensor or Array. Corresponding to the prediction of the samples by the model.
- shape: (n, nb_classes)
- Used by the `weights_extraction_function` if it is an Xplique attribution function,
- For more details, please refer to the explain methods documentation.
-
- Returns
- -------
- examples
- Represente the K nearust neighbours of the input.
- examples_distance
- distance between the input and the examples.
- examples_weight
- features weight of the examples.
- inputs_weights
- features weight of the inputs.
- examples_labels
- labels of the examples.
- """
-
- # verify targets parametre
- if targets is None:
- targets = self.model(inputs)
- nb_classes = targets.shape[1]
- targets = tf.argmax(targets, axis=1)
- targets = tf.one_hot(targets, nb_classes)
-
- # compute weight (used in distance)
- # the weight extraction function may need the prediction to extract the weights
- inputs_weights = self.weights_extraction_function(inputs, targets)
-
- # for images, channels may disappear
- if len(inputs_weights.shape) != len(inputs.shape):
- inputs_weights = tf.expand_dims(inputs_weights, -1)
-
- # apply weights to the inputs
- weighted_inputs = tf.math.multiply(inputs_weights, inputs)
- # flatten features for knn query
- weighted_inputs = tf.reshape(weighted_inputs, [weighted_inputs.shape[0], -1])
-
- # kdtree instance call with knn.query,
- # call with the weighted inputs and the number of closest examples (k)
- examples_distance, examples_indice = self.knn.query(
- weighted_inputs, k=self.k_neighbors
+ assert targets_dataset is not None
+
+ super().__init__(
+ cases_dataset,
+ labels_dataset,
+ targets_dataset,
+ search_method,
+ k,
+ projection,
+ case_returns,
+ batch_size,
+ distance=distance,
)
-
- # call the extract_element_from_indices function
- examples, examples_weights, examples_labels = self.extract_element_from_indices(
- self.labels_train, examples_indice
- )
-
- return (
- examples,
- examples_distance,
- examples_weights,
- inputs_weights,
- examples_labels,
- )
-
- @staticmethod
- def _get_default_weights_extraction_function():
- """
- This function allows you to get the default weight extraction function.
- """
- return lambda inputs, targets: tf.ones(inputs.shape)
-
- def show_result_images(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- examples: Union[tf.Tensor, np.ndarray],
- examples_distance: float,
- inputs_weights: np.ndarray,
- examples_weights: np.ndarray,
- indice_original: int,
- examples_labels: np.ndarray,
- labels_test: np.ndarray,
- clip_percentile: Optional[float] = 0.2,
- cmapimages: Optional[str] = "gray",
- cmapexplanation: Optional[str] = "coolwarm",
- alpha: Optional[float] = 0.5,
- ):
- """
- This function is for image data, it show the returns of the explain function.
-
- Parameters
- ---------
- inputs
- Tensor or Array. Input samples to be show next to examples.
- Expected shape among (N,W), (N,T,W), (N,W,H,C).
- examples
- Represente the K nearust neighbours of the input.
- examples_distance
- Distance between input data and examples.
- inputs_weights
- features weight of the inputs.
- examples_weight
- features weight of the examples.
- indice_original
- Represente the indice of the inputs to show the true labels.
- examples_labels
- labels of the examples.
- labels_test
- Corresponding to labels of the dataset test.
- clip_percentile
- Percentile value to use if clipping is needed, e.g a value of 1 will perform a clipping
- between percentile 1 and 99.
- This parameter allows to avoid outliers in case of too extreme values.
- cmapimages
- For images.
- The Colormap instance or registered colormap name used to map scalar data to colors.
- This parameter is ignored for RGB(A) data.
- cmapexplanation
- For explanation.
- The Colormap instance or registered colormap name used to map scalar data to colors.
- This parameter is ignored for RGB(A) data.
- alpha
- The alpha blending value, between 0 (transparent) and 1 (opaque).
- If alpha is an array, the alpha blending values are applied pixel by pixel,
- and alpha must have the same shape as X.
- """
- # pylint: disable=too-many-arguments
-
- # Initialize 'input_and_examples' and 'corresponding_weights' that they
- # will be use to show every closest examples and the explanation
- inputs = tf.expand_dims(inputs, 1)
- inputs_weights = tf.expand_dims(inputs_weights, 1)
- input_and_examples = tf.concat([inputs, examples], axis=1)
- corresponding_weights = tf.concat([inputs_weights, examples_weights], axis=1)
-
- # calcul the prediction of input and examples
- # that they will be used at title of the image
- # nevessary loop becaue we have n * k elements
- predicted_labels = []
- for samples in input_and_examples:
- predicted = self.model(samples)
- predicted = tf.argmax(predicted, axis=1)
- predicted_labels.append(predicted)
-
- # configure the grid to show all results
- plt.rcParams["figure.autolayout"] = True
- plt.rcParams["figure.figsize"] = [20, 10]
-
- # loop to organize and show all results
- for j in range(np.asarray(input_and_examples).shape[0]):
- fig = plt.figure()
- gridspec = fig.add_gridspec(2, input_and_examples.shape[1])
- for k in range(len(input_and_examples[j])):
- fig.add_subplot(gridspec[0, k])
- if k == 0:
- plt.title(
- f"Original image\nGround Truth: {labels_test[indice_original[j]]}"\
- + f"\nPrediction: {predicted_labels[j][k]}"
- )
- else:
- plt.title(
- f"Examples\nGround Truth: {examples_labels[j][k-1]}"\
- + f"\nPrediction: {predicted_labels[j][k]}"\
- + f"\nDistance: {round(examples_distance[j][k-1], 2)}"
- )
- plt.imshow(input_and_examples[j][k], cmap=cmapimages)
- plt.axis("off")
- fig.add_subplot(gridspec[1, k])
- plt.imshow(input_and_examples[j][k], cmap=cmapimages)
- plt.imshow(
- _standardize_image(corresponding_weights[j][k], clip_percentile),
- cmap=cmapexplanation,
- alpha=alpha,
- )
- plt.axis("off")
- plt.show()
-
- def show_result_tabular(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- examples: Union[tf.Tensor, np.ndarray],
- examples_distance: float,
- indice_original: int,
- examples_labels: np.ndarray,
- labels_test: np.ndarray,
- show_values: bool = False,
- ):
- """
- This function is for image data, it show the returns of the explain function.
-
- Parameters
- ---------
- inputs
- Tensor or Array. Input samples to be show next to examples.
- Expected shape among (N,W), (N,T,W), (N,W,H,C).
- examples
- Represente the K nearust neighbours of the input.
- examples_weight
- features weight of the examples.
- indice_original
- Represente the indice of the inputs to show the true labels.
- examples_labels
- labels of the examples.
- labels_test
- Corresponding to labels of the dataset test.
- show_values
- boolean default at False, to show the values of examples.
- """
-
- # Initialize 'input_and_examples' and 'corresponding_weights' that they
- # will be use to show every closest examples and the explanation
- inputs = tf.expand_dims(inputs, 1)
- input_and_examples = tf.concat([inputs, examples], axis=1)
-
- # calcul the prediction of input and examples
- # that they will be used at title of the image
- # nevessary loop becaue we have n * k elements
- predicted_labels = []
- for samples in input_and_examples:
- predicted = self.model(samples)
- predicted = tf.argmax(predicted, axis=1)
- predicted_labels.append(predicted)
-
- # apply argmax function to labels
- labels_test = tf.argmax(labels_test, axis=1)
- examples_labels = tf.argmax(examples_labels, axis=1)
-
- # define values_string if show_values is at None
- values_string = ""
-
- # loop to organize and show all results
- for i in range(input_and_examples.shape[0]):
- for j in range(input_and_examples.shape[1]):
- if show_values is True:
- values_string = f"\t\tValues: {input_and_examples[i][j]}"
- if j == 0:
- print(
- f"Originale_data, indice: {indice_original[i]}"\
- + f"\tDistance: \t\tGround Truth: {labels_test[i]}"\
- + f"\t\tPrediction: {predicted_labels[i][j]}"
- + values_string
- )
- else:
- print(
- f"\tExamples: {j}"\
- + f"\t\tDistance: {round(examples_distance[i][j-1], 2)}"\
- + f"\t\tGround Truth: {examples_labels[i][j-1]}"\
- + f"\t\tPrediction: {predicted_labels[i][j]}"
- + values_string
- )
- print("\n")
diff --git a/xplique/types/__init__.py b/xplique/types/__init__.py
index 52cca202..ba01d0c2 100644
--- a/xplique/types/__init__.py
+++ b/xplique/types/__init__.py
@@ -2,5 +2,5 @@
Typing module
"""
-from typing import Union, Tuple, List, Callable, Dict, Optional, Any
+from typing import Union, Tuple, List, Callable, Dict, Optional, Any, Type
from .custom_type import OperatorSignature
From 62a621d4e6a431d338f3b39f66816a25c8a9ca0d Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 29 May 2023 09:59:04 +0200
Subject: [PATCH 007/138] example based tests: update and complete
---
tests/commons/test_tf_dataset_operation.py | 144 +++++++++
tests/example_based/__init__.py | 0
tests/example_based/test_cole.py | 285 ++++++++---------
tests/example_based/test_image_plot.py | 101 ++++++
tests/example_based/test_similar_examples.py | 305 +++++++++++++++++++
tests/example_based/test_split_projection.py | 85 ++++++
tests/utils.py | 8 +
xplique/commons/__init__.py | 2 +-
xplique/commons/data_conversion.py | 33 +-
9 files changed, 806 insertions(+), 157 deletions(-)
create mode 100644 tests/commons/test_tf_dataset_operation.py
create mode 100644 tests/example_based/__init__.py
create mode 100644 tests/example_based/test_image_plot.py
create mode 100644 tests/example_based/test_similar_examples.py
create mode 100644 tests/example_based/test_split_projection.py
diff --git a/tests/commons/test_tf_dataset_operation.py b/tests/commons/test_tf_dataset_operation.py
new file mode 100644
index 00000000..1f9a5f42
--- /dev/null
+++ b/tests/commons/test_tf_dataset_operation.py
@@ -0,0 +1,144 @@
+"""
+Test operations on tf datasets
+"""
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import unittest
+
+import numpy as np
+import tensorflow as tf
+
+
+from xplique.commons.tf_dataset_operations import *
+from xplique.commons.tf_dataset_operations import _almost_equal
+
+
+def test_are_dataset_first_elems_equal():
+ """
+ Verify that the function is able to compare the first element of datasets
+ """
+ tf_dataset_up = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(90), (10, 3, 3))
+ )
+ tf_dataset_up_small = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(45), (5, 3, 3))
+ )
+ tf_dataset_down = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(90, 0, -1), (10, 3, 3))
+ )
+
+ zipped = tf.data.Dataset.zip((tf_dataset_up, tf_dataset_up))
+ zipped_batched_in = tf.data.Dataset.zip(
+ (tf_dataset_up.batch(3), tf_dataset_up.batch(3))
+ )
+
+ assert are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_up)
+ assert are_dataset_first_elems_equal(tf_dataset_up.batch(3), tf_dataset_up.batch(3))
+ assert are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_up_small)
+ assert are_dataset_first_elems_equal(
+ tf_dataset_up.batch(3), tf_dataset_up_small.batch(3)
+ )
+ assert are_dataset_first_elems_equal(zipped, zipped)
+ assert are_dataset_first_elems_equal(zipped.batch(3), zipped.batch(3))
+ assert are_dataset_first_elems_equal(zipped_batched_in, zipped_batched_in)
+ assert not are_dataset_first_elems_equal(tf_dataset_up, zipped)
+ assert not are_dataset_first_elems_equal(tf_dataset_up.batch(3), zipped.batch(3))
+ assert not are_dataset_first_elems_equal(tf_dataset_up.batch(3), zipped_batched_in)
+ assert not are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_down)
+ assert not are_dataset_first_elems_equal(
+ tf_dataset_up.batch(3), tf_dataset_down.batch(3)
+ )
+
+
+def test_is_not_shuffled():
+ """
+ Verify the function is able to detect datasets that do not provide a stable order of elements
+ """
+ tf_dataset = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(90), (10, 3, 3))
+ )
+ tf_shuffled_once = tf_dataset.shuffle(3, reshuffle_each_iteration=False)
+ zipped = tf.data.Dataset.zip((tf_dataset, tf_dataset))
+
+ assert is_not_shuffled(tf_dataset)
+ assert is_not_shuffled(tf_dataset.batch(3))
+ assert is_not_shuffled(tf_shuffled_once)
+ assert is_not_shuffled(tf_shuffled_once.batch(3))
+ assert is_not_shuffled(zipped)
+ assert is_not_shuffled(zipped.batch(3))
+
+
+def test_batch_size_matches():
+ """
+ Test that the function is able to detect incoherence between dataset and batch_size
+ """
+ tf_dataset = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(90), (10, 3, 3))
+ )
+ tf_dataset_b2 = tf_dataset.batch(2)
+ tf_dataset_b5 = tf_dataset.batch(5)
+ tf_dataset_b25 = tf_dataset_b5.batch(2)
+ tf_dataset_b52 = tf_dataset_b2.batch(5)
+ tf_dataset_b32 = tf_dataset.batch(32)
+
+ assert batch_size_matches(tf_dataset, 3)
+ assert batch_size_matches(tf_dataset_b2, 2)
+ assert batch_size_matches(tf_dataset_b5, 5)
+ assert batch_size_matches(tf_dataset_b25, 2)
+ assert batch_size_matches(tf_dataset_b52, 5)
+ assert batch_size_matches(tf_dataset_b32, 10)
+
+
+def test_sanitize_dataset():
+ """
+ Test that verifies that the function harmonizes inputs into datasets
+ """
+ tf_tensor = tf.reshape(tf.range(90), (10, 3, 3))
+ np_array = np.array(tf_tensor)
+ tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor)
+ tf_dataset_b4 = tf_dataset.batch(4)
+
+    # test conversion
+ assert sanitize_dataset(None, 1) is None
+ assert are_dataset_first_elems_equal(tf_dataset, tf_dataset)
+ assert are_dataset_first_elems_equal(tf_dataset_b4, tf_dataset_b4)
+ assert are_dataset_first_elems_equal(
+ sanitize_dataset(tf_tensor, 4, 3), tf_dataset_b4
+ )
+ assert are_dataset_first_elems_equal(
+ sanitize_dataset(np_array, 4, 3), tf_dataset_b4
+ )
+
+ # test catch assertion errors
+ test_raise_assertion_error = unittest.TestCase().assertRaises
+ test_raise_assertion_error(
+ AssertionError, sanitize_dataset, tf_dataset.shuffle(2).batch(4), 4
+ )
+ test_raise_assertion_error(AssertionError, sanitize_dataset, tf_dataset_b4, 3)
+ test_raise_assertion_error(AssertionError, sanitize_dataset, tf_dataset_b4, 4, 4)
+ test_raise_assertion_error(AssertionError, sanitize_dataset, np_array[:6], 4, 4)
+
+
+def test_dataset_gather():
+ """
+ Test dataset gather function
+ """
+ # (5, 2, 3, 3)
+ tf_dataset = tf.data.Dataset.from_tensor_slices(
+ tf.reshape(tf.range(90), (10, 3, 3))
+ ).batch(2)
+
+ indices_1 = np.array([[[0, 0], [1, 1]], [[2, 1], [0, 0]]])
+ # (2, 2, 3, 3)
+ results_1 = dataset_gather(tf_dataset, indices_1)
+ assert np.all(tf.shape(results_1).numpy() == np.array([2, 2, 3, 3]))
+ assert _almost_equal(results_1[0, 0], results_1[1, 1])
+
+ indices_2 = tf.constant([[[1, 1]]])
+ # (1, 1, 3, 3)
+ results_2 = dataset_gather(tf_dataset, indices_2)
+ assert np.all(tf.shape(results_2).numpy() == np.array([1, 1, 3, 3]))
+ assert _almost_equal(results_1[0, 1], results_2[0, 0])
diff --git a/tests/example_based/__init__.py b/tests/example_based/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index 941df68b..9fb1b73b 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -1,199 +1,182 @@
"""
Test Cole
"""
+import os
+
+import sys
+
+sys.path.append(os.getcwd())
+
from math import prod, sqrt
import numpy as np
-from sklearn.metrics import DistanceMetric
+import scipy
import tensorflow as tf
-from xplique.example_based import Cole
+from xplique.attributions import Occlusion, Saliency
+
+from xplique.example_based import Cole, SimilarExamples
+from xplique.example_based.projections import CustomProjection
+from xplique.example_based.search_methods import KNN
from xplique.types import Union
-from ..utils import generate_data, generate_model, almost_equal, generate_agnostic_model
+from tests.utils import (
+ generate_data,
+ generate_model,
+ almost_equal,
+ generate_timeseries_model,
+)
-def test_neighbors_distance():
+def get_setup(input_shape, nb_samples=10, nb_labels=10):
"""
- The function test every output of the explanation method
+ Generate data and model for Cole
"""
- # Method parameters initialisation
- input_shape = (3, 3, 1)
- nb_labels = 10
- nb_samples = 10
- nb_samples_test = 8
- k = 3
-
# Data generation
- matrix_train = tf.stack([i * tf.ones(input_shape) for i in range(nb_samples)])
- matrix_test = matrix_train[1:-1]
- labels_train = tf.range(nb_samples)
- labels_test = labels_train[1:-1]
+ x_train = tf.stack(
+ [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)]
+ )
+ x_test = x_train[1:-1]
+ y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, depth=nb_labels)
# Model generation
model = generate_model(input_shape, nb_labels)
- # Initialisation of weights_extraction_function and distance_function
- # They will be used in CaseBasedExplainer initialisation
- distance_function = DistanceMetric.get_metric("euclidean")
-
- # CaseBasedExplainer initialisation
- method = Cole(
- model,
- matrix_train,
- labels_train,
- targets=None,
- distance_function=distance_function,
- weights_extraction_function=lambda inputs, targets: tf.ones(inputs.shape),
- )
-
- # Method explanation
- (
- examples,
- examples_distance,
- examples_weights,
- inputs_weights,
- examples_labels,
- ) = method.explain(matrix_test, labels_test)
-
- # test every outputs shape
- assert examples.shape == (nb_samples_test, k) + input_shape
- assert examples_distance.shape == (nb_samples_test, k)
- assert examples_weights.shape == (nb_samples_test, k) + input_shape
- assert inputs_weights.shape == (nb_samples_test,) + input_shape
- assert examples_labels.shape == (nb_samples_test, k)
-
- for i in range(len(labels_test)):
- # test examples:
- assert almost_equal(examples[i][0], matrix_train[i + 1])
- assert almost_equal(examples[i][1], matrix_train[i + 2]) or almost_equal(
- examples[i][1], matrix_train[i]
- )
- assert almost_equal(examples[i][2], matrix_train[i]) or almost_equal(
- examples[i][2], matrix_train[i + 2]
- )
-
- # test examples_distance
- assert almost_equal(examples_distance[i][0], 0)
- assert almost_equal(examples_distance[i][1], sqrt(prod(input_shape)))
- assert almost_equal(examples_distance[i][2], sqrt(prod(input_shape)))
-
- # test examples_labels
- assert almost_equal(examples_labels[i][0], labels_train[i + 1])
- assert almost_equal(examples_labels[i][1], labels_train[i + 2]) or almost_equal(
- examples_labels[i][1], labels_train[i]
- )
- assert almost_equal(examples_labels[i][2], labels_train[i]) or almost_equal(
- examples_labels[i][2], labels_train[i + 2]
- )
-
-
-def weights_attribution(
- inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]
-):
- """
- Custom weights extraction function
- Zeros everywhere and target at 0, 0, 0
- """
- weights = tf.Variable(tf.zeros(inputs.shape, dtype=tf.float32))
- weights[:, 0, 0, 0].assign(targets)
- return weights
+ return model, x_train, x_test, y_train
-def test_weights_attribution():
+def test_cole_attribution():
"""
- Function to test the weights attribution
+ Test Cole attribution projection.
+ It should be the same as a manual projection.
+ Test that the distance has an impact.
"""
- # Method parameters initialisation
- input_shape = (3, 3, 1)
+ # Setup
+ nb_samples = 20
+ input_shape = (5, 5)
nb_labels = 10
- nb_samples = 10
-
- # Data generation
- matrix_train = tf.stack(
- [i * tf.ones(input_shape, dtype=tf.float32) for i in range(nb_samples)]
+ k = 3
+ x_train = tf.random.uniform(
+ (nb_samples,) + input_shape, minval=-1, maxval=1, seed=0
)
- matrix_test = matrix_train[1:-1]
- labels_train = tf.range(nb_samples, dtype=tf.float32)
- labels_test = labels_train[1:-1]
+ x_test = tf.random.uniform((nb_samples,) + input_shape, minval=-1, maxval=1, seed=2)
+ labels = tf.one_hot(
+ indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]),
+ depth=nb_labels,
+ )
+ y_train = labels
+ y_test = tf.random.shuffle(labels, seed=1)
# Model generation
- model = generate_model(input_shape, nb_labels)
+ model = generate_timeseries_model(input_shape, nb_labels)
- # Initialisation of distance_function
- # It will be used in CaseBasedExplainer initialisation
- distance_function = DistanceMetric.get_metric("euclidean")
+ # Cole with attribution method constructor
+ method_constructor = Cole(
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ search_method=KNN,
+ k=k,
+ batch_size=7,
+ distance="euclidean",
+ model=model,
+ attribution_method=Saliency,
+ )
- # CaseBasedExplainer initialisation
- method = Cole(
- model,
- matrix_train,
- labels_train,
- targets=labels_train,
- distance_function=distance_function,
- weights_extraction_function=weights_attribution,
+ # Cole with attribution explain
+ projection = CustomProjection(weights=Saliency(model))
+
+ euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z)))
+ method_call = SimilarExamples(
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ search_method=KNN,
+ k=k,
+ distance=euclidean_dist,
+ projection=projection,
)
- # test case dataset weigth
- assert almost_equal(method.case_dataset_weight[:, 0, 0, 0], method.labels_train)
- assert almost_equal(
- tf.reduce_sum(method.case_dataset_weight, axis=[1, 2, 3]), method.labels_train
+ method_different_distance = Cole(
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ search_method=KNN,
+ k=k,
+ batch_size=2,
+ distance=np.inf, # infinity norm based distance
+ model=model,
+ attribution_method=Saliency,
)
- # Method explanation
- _, _, examples_weights, inputs_weights, examples_labels =\
- method.explain(matrix_test, labels_test)
+ # Generate explanation
+ examples_constructor = method_constructor.explain(x_test, y_test)
+ examples_call = method_call.explain(x_test, y_test)
+ examples_different_distance = method_different_distance(x_test, y_test)
+
+ # Verifications
+ # Shape should be (n, k, h, w, c)
+ assert examples_constructor.shape == (len(x_test), k) + input_shape
+ assert examples_call.shape == (len(x_test), k) + input_shape
+ assert examples_different_distance.shape == (len(x_test), k) + input_shape
+
+ # both methods should be the same
+ assert almost_equal(examples_constructor, examples_call)
- # test examples weights
- assert almost_equal(examples_weights[:, :, 0, 0, 0], examples_labels)
+ # a different distance should give different results
+ assert not almost_equal(examples_constructor, examples_different_distance)
+
+ # check weights are equal to the attribution directly on the input
+ method_constructor.set_returns(["weights", "include_inputs"])
assert almost_equal(
- tf.reduce_sum(examples_weights, axis=[2, 3, 4]), examples_labels
+ method_constructor.explain(x_test, y_test)[:, 0],
+ Saliency(model)(x_test, y_test),
)
- # test inputs weights
- assert almost_equal(inputs_weights[:, 0, 0, 0], labels_test)
- assert almost_equal(tf.reduce_sum(inputs_weights, axis=[1, 2, 3]), labels_test)
-
-def test_tabular_inputs():
+def test_cole_spliting():
"""
- Function to test the acceptation of tabular data input in the method
+ Test Cole with a `latent_layer` provided.
+ It should split the model.
"""
- # Method parameters initialisation
- data_shape = (3,)
- input_shape = data_shape
- nb_labels = 3
- nb_samples = 20
- nb_inputs = 5
- k = 3
-
- # Data generation
- dataset, targets = generate_data(data_shape, nb_labels, nb_samples)
- dataset_train = dataset[:-nb_inputs]
- dataset_test = dataset[-nb_inputs:]
- targets_train = targets[:-nb_inputs]
- targets_test = targets[-nb_inputs:]
+ # Setup
+ nb_samples = 10
+ input_shape = (6, 6, 3)
+ nb_labels = 5
+ k = 1
+ x_train = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1)
+ x_test = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1)
+ labels = tf.one_hot(
+ indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]),
+ depth=nb_labels,
+ )
+ y_train = labels
+ y_test = tf.random.shuffle(labels)
# Model generation
- model = generate_agnostic_model(input_shape, nb_labels)
-
- # Initialisation of weights_extraction_function and distance_function
- # They will be used in CaseBasedExplainer initialisation
- distance_function = DistanceMetric.get_metric("euclidean")
+ model = generate_model(input_shape, nb_labels)
- # CaseBasedExplainer initialisation
+ # Cole with attribution method constructor
method = Cole(
- model,
- dataset_train,
- targets_train,
- targets=targets_train,
- distance_function=distance_function,
- weights_extraction_function=lambda inputs, targets: tf.ones(inputs.shape),
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ search_method=KNN,
k=k,
+ case_returns=["examples", "weights", "include_inputs"],
+ model=model,
+ latent_layer="last_conv",
+ attribution_method=Occlusion,
+ patch_size=2,
+ patch_stride=1,
)
- # Method explanation
- examples, _, _, _, _ = method.explain(dataset_test, targets_test)
+ # Generate explanation
+ outputs = method.explain(x_test, y_test)
+ examples, weights = outputs["examples"], outputs["weights"]
+
+ # Verifications
+ # Shape should be (n, k, h, w, c)
+ nb_samples_test = x_test.shape[0]
+ assert examples.shape == (nb_samples_test, k + 1) + input_shape
+ assert weights.shape[:-1] == (nb_samples_test, k + 1) + input_shape[:-1]
+
- # test examples shape
- assert examples.shape == (nb_inputs, k) + input_shape
+# test_cole_attribution()
+# test_cole_spliting()
diff --git a/tests/example_based/test_image_plot.py b/tests/example_based/test_image_plot.py
new file mode 100644
index 00000000..f8254d17
--- /dev/null
+++ b/tests/example_based/test_image_plot.py
@@ -0,0 +1,101 @@
+"""
+Test plot_examples for example-based methods
+"""
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from math import prod, sqrt
+
+import numpy as np
+import scipy
+import tensorflow as tf
+
+from xplique.attributions import Occlusion, Saliency
+
+from xplique.example_based import Cole, SimilarExamples
+from xplique.example_based.projections import CustomProjection
+from xplique.example_based.search_methods import KNN
+from xplique.plots.image import plot_examples
+
+from tests.utils import (
+ generate_data,
+ generate_model,
+ almost_equal,
+ generate_timeseries_model,
+)
+
+
+def get_setup(input_shape, nb_samples=10, nb_labels=10):
+ """
+ Generate data and model for Cole
+ """
+ # Data generation
+ x_train = tf.stack(
+ [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)]
+ )
+ x_test = x_train[1:-1]
+ y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, depth=nb_labels)
+
+ # Model generation
+ model = generate_model(input_shape, nb_labels)
+
+ return model, x_train, x_test, y_train
+
+
+def test_plot_cole_spliting():
+ """
+ Test examples plot function.
+ """
+ # Setup
+ nb_samples = 10
+ input_shape = (6, 6, 3)
+ nb_labels = 5
+ k = 1
+ x_train = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1)
+ x_test = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1)
+ labels = tf.one_hot(
+ indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]),
+ depth=nb_labels,
+ )
+ y_train = labels
+ y_test = tf.random.shuffle(labels)
+
+ # Model generation
+ model = generate_model(input_shape, nb_labels)
+
+ # Cole with attribution method constructor
+ method = Cole(
+ cases_dataset=x_train,
+ labels_dataset=tf.argmax(y_train, axis=1),
+ targets_dataset=y_train,
+ search_method=KNN,
+ k=k,
+ case_returns="all",
+ model=model,
+ latent_layer="last_conv",
+ attribution_method=Occlusion,
+ patch_size=2,
+ patch_stride=1,
+ )
+
+ # Generate explanation
+ outputs = method.explain(x_test, y_test)
+
+ # get predictions on examples
+ predicted_labels = tf.map_fn(
+ fn=lambda x: tf.cast(tf.argmax(model(x), axis=1), tf.int32),
+ elems=outputs["examples"],
+ fn_output_signature=tf.int32,
+ )
+
+ # test plot
+ plot_examples(
+ test_labels=tf.argmax(y_test, axis=1),
+ predicted_labels=predicted_labels,
+ **outputs
+ )
+
+
+# test_plot_cole_spliting()
diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py
new file mode 100644
index 00000000..fbc5824e
--- /dev/null
+++ b/tests/example_based/test_similar_examples.py
@@ -0,0 +1,305 @@
+"""
+Test SimilarExamples
+"""
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from math import prod, sqrt
+import unittest
+
+import numpy as np
+import tensorflow as tf
+
+from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
+from xplique.types import Union
+
+from xplique.example_based import SimilarExamples
+from xplique.example_based.projections import CustomProjection
+from xplique.example_based.search_methods import KNN
+
+from tests.utils import almost_equal
+
+
+def get_setup(input_shape, nb_samples=10, nb_labels=10):
+ """
+ Generate data and model for SimilarExamples
+ """
+ # Data generation
+ x_train = tf.stack(
+ [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)]
+ )
+ x_test = x_train[1:-1]
+ y_train = tf.range(len(x_train), dtype=tf.float32) % nb_labels
+
+ return x_train, x_test, y_train
+
+
+def test_similar_examples_input_datasets_management():
+ """
+ Test management of dataset init inputs
+ """
+ proj = CustomProjection(space_projection=lambda inputs, targets=None: inputs)
+
+ tf_tensor = tf.reshape(tf.range(90), (10, 3, 3))
+ np_array = np.array(tf_tensor)
+ tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor)
+ too_short_np_array = np_array[:3]
+ too_long_tf_dataset = tf_dataset.concatenate(tf_dataset)
+
+ tf_dataset_b3 = tf_dataset.batch(3)
+ tf_dataset_b5 = tf_dataset.batch(5)
+ too_long_tf_dataset_b5 = too_long_tf_dataset.batch(5)
+ too_long_tf_dataset_b10 = too_long_tf_dataset.batch(10)
+
+ tf_shuffled = tf_dataset.shuffle(32, 0).batch(4)
+ tf_one_shuffle = tf_dataset.shuffle(32, 0, reshuffle_each_iteration=False).batch(4)
+
+ # Method initialization that should work
+ method = SimilarExamples(tf_dataset_b3, None, np_array, projection=proj)
+ assert are_dataset_first_elems_equal(method.cases_dataset, tf_dataset_b3)
+ assert are_dataset_first_elems_equal(method.labels_dataset, None)
+ assert are_dataset_first_elems_equal(method.targets_dataset, tf_dataset_b3)
+
+ method = SimilarExamples(np_array, tf_tensor, None, batch_size=5, projection=proj)
+ assert are_dataset_first_elems_equal(method.cases_dataset, tf_dataset_b5)
+ assert are_dataset_first_elems_equal(method.labels_dataset, tf_dataset_b5)
+ assert are_dataset_first_elems_equal(method.targets_dataset, None)
+
+ method = SimilarExamples(
+ tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5)),
+ None,
+ np_array,
+ projection=proj,
+ )
+ assert are_dataset_first_elems_equal(method.cases_dataset, tf_dataset_b5)
+ assert are_dataset_first_elems_equal(method.labels_dataset, tf_dataset_b5)
+ assert are_dataset_first_elems_equal(method.targets_dataset, tf_dataset_b5)
+
+ method = SimilarExamples(
+ tf.data.Dataset.zip((tf_one_shuffle, tf_one_shuffle)), projection=proj
+ )
+ assert are_dataset_first_elems_equal(method.cases_dataset, tf_one_shuffle)
+ assert are_dataset_first_elems_equal(method.labels_dataset, tf_one_shuffle)
+ assert are_dataset_first_elems_equal(method.targets_dataset, None)
+
+ method = SimilarExamples(tf_one_shuffle, projection=proj)
+ assert are_dataset_first_elems_equal(method.cases_dataset, tf_one_shuffle)
+ assert are_dataset_first_elems_equal(method.labels_dataset, None)
+ assert are_dataset_first_elems_equal(method.targets_dataset, None)
+
+ # Method initialization that should not work
+ test_raise_assertion_error = unittest.TestCase().assertRaises
+ test_raise_assertion_error(TypeError, SimilarExamples)
+ test_raise_assertion_error(AssertionError, SimilarExamples, tf_tensor)
+ test_raise_assertion_error(
+ AssertionError, SimilarExamples, tf_shuffled, projection=proj
+ )
+ test_raise_assertion_error(
+ AssertionError, SimilarExamples, tf_dataset, tf_tensor, projection=proj
+ )
+ test_raise_assertion_error(
+ AssertionError, SimilarExamples, tf_dataset_b3, tf_dataset_b5, projection=proj
+ )
+ test_raise_assertion_error(
+ AssertionError,
+ SimilarExamples,
+ tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5)),
+ np_array,
+ projection=proj,
+ )
+ test_raise_assertion_error(
+ AssertionError, SimilarExamples, tf_dataset_b3, too_short_np_array
+ )
+ test_raise_assertion_error(
+ AssertionError, SimilarExamples, tf_dataset, None, too_long_tf_dataset
+ )
+ test_raise_assertion_error(
+ AssertionError,
+ SimilarExamples,
+ tf_dataset_b5,
+ too_long_tf_dataset_b5,
+ projection=proj,
+ )
+ test_raise_assertion_error(
+ AssertionError,
+ SimilarExamples,
+ too_long_tf_dataset_b10,
+ tf_dataset_b5,
+ projection=proj,
+ )
+
+
+def test_similar_examples_basic():
+ """
+ Test the SimilarExamples with an identity projection.
+ """
+ # Setup
+ input_shape = (4, 4, 1)
+ k = 3
+ x_train, x_test, _ = get_setup(input_shape)
+
+ identity_projection = CustomProjection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
+
+ # Method initialization
+ method = SimilarExamples(
+ cases_dataset=x_train,
+ projection=identity_projection,
+ search_method=KNN,
+ k=k,
+ batch_size=3,
+ distance="euclidean",
+ )
+
+ # Generate explanation
+ examples = method.explain(x_test)
+
+ # Verifications
+ # Shape should be (n, k, h, w, c)
+ assert examples.shape == (len(x_test), k) + input_shape
+
+ for i in range(len(x_test)):
+ # test examples:
+ assert almost_equal(examples[i, 0], x_train[i + 1])
+ assert almost_equal(examples[i, 1], x_train[i + 2]) or almost_equal(
+ examples[i, 1], x_train[i]
+ )
+ assert almost_equal(examples[i, 2], x_train[i]) or almost_equal(
+ examples[i, 2], x_train[i + 2]
+ )
+
+
+def test_similar_examples_return_multiple_elements():
+ """
+ Test the returns attribute.
+ Test modifying k.
+ """
+ # Setup
+ input_shape = (5, 5, 1)
+ k = 3
+ x_train, x_test, y_train = get_setup(input_shape)
+
+ nb_samples_test = len(x_test)
+ assert nb_samples_test + 2 == len(y_train)
+
+ identity_projection = CustomProjection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
+
+ # Method initialization
+ method = SimilarExamples(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ projection=identity_projection,
+ search_method=KNN,
+ k=1,
+ batch_size=3,
+ distance="euclidean",
+ )
+
+ method.set_returns("all")
+
+ method.set_k(k)
+
+ # Generate explanation
+ method_output = method.explain(x_test)
+
+ assert isinstance(method_output, dict)
+
+ examples = method_output["examples"]
+ weights = method_output["weights"]
+ distances = method_output["distances"]
+ labels = method_output["labels"]
+
+    # test every output's shape (with the included inputs)
+ assert examples.shape == (nb_samples_test, k + 1) + input_shape
+ assert weights.shape == (nb_samples_test, k + 1) + input_shape
+    # the inputs' distances are zero and their indices do not exist
+ assert distances.shape == (nb_samples_test, k)
+ assert labels.shape == (nb_samples_test, k)
+
+ for i in range(nb_samples_test):
+ # test examples:
+ assert almost_equal(examples[i, 0], x_test[i])
+ assert almost_equal(examples[i, 1], x_train[i + 1])
+ assert almost_equal(examples[i, 2], x_train[i + 2]) or almost_equal(
+ examples[i, 2], x_train[i]
+ )
+ assert almost_equal(examples[i, 3], x_train[i]) or almost_equal(
+ examples[i, 3], x_train[i + 2]
+ )
+
+ # test weights
+ assert almost_equal(weights[i], tf.ones(weights[i].shape, dtype=tf.float32))
+
+ # test distances
+ assert almost_equal(distances[i, 0], 0)
+ assert almost_equal(distances[i, 1], sqrt(prod(input_shape)))
+ assert almost_equal(distances[i, 2], sqrt(prod(input_shape)))
+
+ # test labels
+ assert almost_equal(labels[i, 0], y_train[i + 1])
+ assert almost_equal(labels[i, 1], y_train[i]) or almost_equal(
+ labels[i, 1], y_train[i + 2]
+ )
+ assert almost_equal(labels[i, 2], y_train[i]) or almost_equal(
+ labels[i, 2], y_train[i + 2]
+ )
+
+
+def test_similar_examples_weighting():
+ """
+ Test the application of the projection weighting.
+ """
+ # Setup
+ input_shape = (4, 4, 1)
+ nb_samples = 10
+ k = 3
+ x_train, x_test, y_train = get_setup(input_shape, nb_samples)
+
+ # Define the weighing function
+ weights = np.zeros(x_train[0].shape)
+ weights[1] = np.ones(weights[1].shape)
+
+ # create huge noise on non interesting features
+ noise = np.random.uniform(size=x_train.shape, low=-100, high=100)
+ x_train = np.float32(weights * np.array(x_train) + (1 - weights) * noise)
+
+ weighting_function = CustomProjection(weights=weights)
+
+ method = SimilarExamples(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ projection=weighting_function,
+ search_method=KNN,
+ k=k,
+ batch_size=5,
+ distance="euclidean",
+ )
+
+ # Generate explanation
+ examples = method.explain(x_test)
+
+ # Verifications
+ # Shape should be (n, k, h, w, c)
+ nb_samples_test = x_test.shape[0]
+ assert examples.shape == (nb_samples_test, k) + input_shape
+
+ for i in range(nb_samples_test):
+ # test examples:
+ assert almost_equal(examples[i, 0], x_train[i + 1])
+ assert almost_equal(examples[i, 1], x_train[i + 2]) or almost_equal(
+ examples[i, 1], x_train[i]
+ )
+ assert almost_equal(examples[i, 2], x_train[i]) or almost_equal(
+ examples[i, 2], x_train[i + 2]
+ )
+
+
+# test_similar_examples_input_dataset_management()
+# test_similar_examples_basic()
+# test_similar_examples_return_multiple_elements()
+# test_similar_examples_weighting()
diff --git a/tests/example_based/test_split_projection.py b/tests/example_based/test_split_projection.py
new file mode 100644
index 00000000..bc560b48
--- /dev/null
+++ b/tests/example_based/test_split_projection.py
@@ -0,0 +1,85 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.layers import (
+ Dense,
+ Conv2D,
+ Activation,
+ Dropout,
+ Flatten,
+ MaxPooling2D,
+ Input,
+)
+
+from xplique.example_based.projections import AttributionProjection
+from xplique.example_based.projections import LatentSpaceProjection
+from ..utils import generate_data, almost_equal
+
+
+def _generate_model(input_shape=(32, 32, 3), output_shape=10):
+ model = tf.keras.Sequential()
+ model.add(Input(shape=input_shape))
+ model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_1"))
+ model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_2"))
+ model.add(MaxPooling2D(pool_size=(2, 2)))
+ model.add(Dropout(0.25))
+ model.add(Flatten())
+ model.add(Dense(output_shape, name="dense"))
+ model.add(Activation("softmax", name="softmax"))
+ model.compile(loss="categorical_crossentropy", optimizer="sgd")
+
+ return model
+
+
+def test_attribution_latent_layer():
+ """We should target the right layer using either int, string or default procedure"""
+ tf.keras.backend.clear_session()
+
+ model = _generate_model()
+
+ first_conv_layer = model.get_layer("conv2d_1")
+ last_conv_layer = model.get_layer("conv2d_2")
+ flatten_layer = model.get_layer("flatten")
+
+    # default should not include model splitting
+ projection_default = AttributionProjection(model)
+ assert projection_default.latent_layer is None
+
+ # last_conv should be recognized
+ projection_default = AttributionProjection(model, latent_layer="last_conv")
+ assert projection_default.latent_layer == last_conv_layer
+
+ # target the first conv layer
+ projection_default = AttributionProjection(model, latent_layer=0)
+ assert projection_default.latent_layer == first_conv_layer
+
+ # target a random flatten layer
+ projection_default = AttributionProjection(model, latent_layer="flatten")
+ assert projection_default.latent_layer == flatten_layer
+
+
+def test_latent_space_latent_layer():
+ """We should target the right layer using either int, string or default procedure"""
+ tf.keras.backend.clear_session()
+
+ model = _generate_model()
+
+ first_conv_layer = model.get_layer("conv2d_1")
+ last_conv_layer = model.get_layer("conv2d_2")
+ flatten_layer = model.get_layer("flatten")
+ last_layer = model.get_layer("softmax")
+
+    # default should not include model splitting
+ projection_default = LatentSpaceProjection(model)
+ assert projection_default.latent_layer == last_layer
+
+ # last_conv should be recognized
+ projection_default = LatentSpaceProjection(model, latent_layer="last_conv")
+ assert projection_default.latent_layer == last_conv_layer
+
+ # target the first conv layer
+ projection_default = LatentSpaceProjection(model, latent_layer=0)
+ assert projection_default.latent_layer == first_conv_layer
+
+ # target a random flatten layer
+ projection_default = LatentSpaceProjection(model, latent_layer="flatten")
+ assert projection_default.latent_layer == flatten_layer
diff --git a/tests/utils.py b/tests/utils.py
index 67cf0e36..92d348e2 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -31,6 +31,14 @@ def generate_model(input_shape=(32, 32, 3), output_shape=10):
return model
+def generate_agnostic_model(input_shape=(3,), nb_labels=3):
+ model = Sequential()
+ model.add(Input(input_shape))
+ model.add(Flatten())
+ model.add(Dense(nb_labels))
+
+ return model
+
def generate_timeseries_model(input_shape=(20, 10), output_shape=10):
model = Sequential()
model.add(Input(shape=input_shape))
diff --git a/xplique/commons/__init__.py b/xplique/commons/__init__.py
index 94237f90..db9fcf3a 100644
--- a/xplique/commons/__init__.py
+++ b/xplique/commons/__init__.py
@@ -2,7 +2,7 @@
Utility classes and functions
"""
-from .data_conversion import tensor_sanitize, numpy_sanitize
+from .data_conversion import tensor_sanitize, numpy_sanitize, sanitize_inputs_targets
from .model_override import guided_relu_policy, deconv_relu_policy, override_relu_gradient, \
find_layer, open_relu_policy
from .tf_operations import repeat_labels, batch_tensor
diff --git a/xplique/commons/data_conversion.py b/xplique/commons/data_conversion.py
index 517f86ad..ae5d7eeb 100644
--- a/xplique/commons/data_conversion.py
+++ b/xplique/commons/data_conversion.py
@@ -5,11 +5,12 @@
import tensorflow as tf
import numpy as np
-from ..types import Union, Optional, Tuple
+from ..types import Union, Optional, Tuple, Callable
def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]]) -> Tuple[tf.Tensor, tf.Tensor]:
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Ensure the output as tf.Tensor, accept various inputs format including:
tf.Tensor, List, numpy array, tf.data.Dataset (when label = None).
@@ -35,17 +36,20 @@ def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
if hasattr(inputs, '_batch_size'):
inputs = inputs.unbatch()
# unpack the dataset, assume we have tuple of (input, target)
- targets = [target for _, target in inputs]
inputs = [inp for inp, _ in inputs]
+ if targets is not None:
+ targets = [target for _, target in inputs]
inputs = tf.cast(inputs, tf.float32)
- targets = tf.cast(targets, tf.float32)
+ if targets is not None:
+ targets = tf.cast(targets, tf.float32)
return inputs, targets
def numpy_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]]) -> Tuple[tf.Tensor, tf.Tensor]:
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Ensure the output as np.ndarray, accept various inputs format including:
tf.Tensor, List, numpy array, tf.data.Dataset (when label = None).
@@ -66,3 +70,22 @@ def numpy_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
"""
inputs, targets = tensor_sanitize(inputs, targets)
return inputs.numpy(), targets.numpy()
+
+
+def sanitize_inputs_targets(explanation_method: Callable):
+ """
+ Wrap a method explanation function to ensure tf.Tensor as inputs and targets.
+ But targets may be None.
+
+ explanation_method
+ Function to wrap, should return an tf.tensor.
+ """
+ def sanitize(self, inputs: Union[tf.data.Dataset, tf.Tensor, np.array],
+ targets: Optional[Union[tf.Tensor, np.array]] = None,
+ *args):
+ # ensure we have tf.tensor
+ inputs, targets = tensor_sanitize(inputs, targets)
+ # then enter the explanation function
+ return explanation_method(self, inputs, targets, *args)
+
+ return sanitize
From 2aa133dbf824fb0ffd7614a54fb62ab2cf8779c0 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:44:32 +0100
Subject: [PATCH 008/138] plots: add image visualization for example based
---
xplique/plots/__init__.py | 2 +-
xplique/plots/image.py | 118 +++++++++++++++++++++++++++++++++++++-
2 files changed, 118 insertions(+), 2 deletions(-)
diff --git a/xplique/plots/__init__.py b/xplique/plots/__init__.py
index 12e25eae..c7037f6a 100644
--- a/xplique/plots/__init__.py
+++ b/xplique/plots/__init__.py
@@ -1,6 +1,6 @@
"""
Utility functions to visualize explanations
"""
-from .image import plot_attributions, plot_attribution, plot_maco
+from .image import plot_attributions, plot_attribution, plot_maco, plot_examples
from .tabular import plot_feature_impact, plot_mean_feature_impact, summary_plot_tabular
from .timeseries import plot_timeseries_attributions
diff --git a/xplique/plots/image.py b/xplique/plots/image.py
index ca69b87d..c90d956b 100644
--- a/xplique/plots/image.py
+++ b/xplique/plots/image.py
@@ -171,7 +171,7 @@ def plot_attributions(
cols
Number of columns.
img_size
- Size of each subplots (in inch), considering we keep aspect ratio
+ Size of each subplots (in inch), considering we keep aspect ratio.
plot_kwargs
Additional parameters passed to `plt.imshow()`.
"""
@@ -230,3 +230,119 @@ def plot_maco(image, alpha, percentile_image=1.0, percentile_alpha=80):
plt.imshow(np.concatenate([image, alpha], -1))
plt.axis('off')
+
+
+def plot_examples(
+ examples: np.ndarray,
+ weights: np.ndarray = None,
+ distances: float = None,
+ labels: np.ndarray = None,
+ test_labels: np.ndarray = None,
+ predicted_labels: np.ndarray = None,
+ img_size: float = 2.,
+ **attribution_kwargs,
+):
+ """
+    This function is for image data, it shows the returns of the explain function.
+
+ Parameters
+    ----------
+ examples
+        Represents the k nearest neighbours of the input. (n, k+1, h, w, c)
+ weights
+ Features weight of the examples.
+ distances
+ Distance between input data and examples.
+ labels
+ Labels of the examples.
+    test_labels
+        Labels corresponding to the test dataset.
+ attribution_kwargs
+        Additional parameters passed to `xplique.plots.plot_attribution()`.
+ img_size:
+ Size of each subplots (in inch), considering we keep aspect ratio
+ """
+ # pylint: disable=too-many-arguments
+ if weights is not None:
+ assert examples.shape[:2] == weights.shape[:2],\
+ "Number of weights must correspond to the number of examples."
+ if distances is not None:
+ assert examples.shape[0] == distances.shape[0],\
+ "Number of samples treated should match between examples and distances."
+ assert examples.shape[1] == distances.shape[1] + 1,\
+ "Number of distances for each input must correspond to the number of examples -1."
+ if labels is not None:
+ assert examples.shape[0] == labels.shape[0],\
+ "Number of samples treated should match between examples and labels."
+ assert examples.shape[1] == labels.shape[1] + 1,\
+ "Number of labels for each input must correspond to the number of examples -1."
+
+ # number of rows depends if weights are provided
+ rows_by_input = 1 + (weights is not None)
+ rows = rows_by_input * examples.shape[0]
+ cols = examples.shape[1]
+ # get width and height of our images
+ l_width, l_height = examples.shape[2:4]
+
+ # define the figure margin, width, height in inch
+ margin = 0.3
+ spacing = 0.3
+ figwidth = cols * img_size + (cols-1) * spacing + 2 * margin
+ figheight = rows * img_size * l_height/l_width + (rows-1) * spacing + 2 * margin
+
+ left = margin/figwidth
+ bottom = margin/figheight
+
+ space_with_line = spacing / (3 * img_size)
+
+ fig = plt.figure()
+ fig.set_size_inches(figwidth, figheight)
+
+ fig.subplots_adjust(
+ left = left,
+ bottom = bottom,
+ right = 1.-left,
+ top = 1.-bottom,
+ wspace = spacing/img_size,
+ hspace= spacing/img_size * l_width/l_height
+ )
+
+ # configure the grid to show all results
+ plt.rcParams["figure.autolayout"] = True
+ plt.rcParams["figure.figsize"] = [3 * examples.shape[1], 4 * (1 + (weights is not None))]
+
+ # loop to organize and show all results
+ for i in range(examples.shape[0]):
+ for k in range(examples.shape[1]):
+ plt.subplot(rows, cols, rows_by_input * i * cols + k + 1)
+
+ # set title
+ if k == 0:
+ title = "Original image"
+ title += f"\nGround Truth: {test_labels[i]}" if test_labels is not None else ""
+ title += f"\nPrediction: {predicted_labels[i, k]}"\
+ if predicted_labels is not None else ""
+ else:
+ title = f"Example {k}"
+ title += f"\nGround Truth: {labels[i, k-1]}" if labels is not None else ""
+ title += f"\nPrediction: {predicted_labels[i, k]}"\
+ if predicted_labels is not None else ""
+ title += f"\nDistance: {distances[i, k-1]:.4f}" if distances is not None else ""
+ plt.title(title)
+
+ # plot image
+ img = _normalize(examples[i, k])
+ if img.shape[-1] == 1:
+ plt.imshow(img[:,:,0], cmap="gray")
+ else:
+ plt.imshow(img)
+ plt.axis("off")
+
+ # plot weights
+ if weights is not None:
+ plt.subplot(rows, cols, (rows_by_input * i + 1) * cols + k + 1)
+ plot_attribution(weights[i, k], examples[i, k], **attribution_kwargs)
+ plt.axis("off")
+ plt.plot([-1, 1.5], [-space_with_line, -space_with_line],
+ color='black', lw=1, transform=plt.gca().transAxes, clip_on=False)
+ fig.tight_layout()
From bed4747aa34035826886eb9ed168104a9c7cd4ed Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:50:46 +0100
Subject: [PATCH 009/138] commons: add operations for tf dataset
---
xplique/commons/__init__.py | 2 +
xplique/commons/data_conversion.py | 22 +--
xplique/commons/tf_dataset_operations.py | 235 +++++++++++++++++++++++
3 files changed, 247 insertions(+), 12 deletions(-)
create mode 100644 xplique/commons/tf_dataset_operations.py
diff --git a/xplique/commons/__init__.py b/xplique/commons/__init__.py
index db9fcf3a..c5312a2e 100644
--- a/xplique/commons/__init__.py
+++ b/xplique/commons/__init__.py
@@ -11,3 +11,5 @@
get_inference_function, get_gradient_functions)
from .exceptions import no_gradients_available, raise_invalid_operator
from .forgrad import forgrad
+from .tf_dataset_operations import are_dataset_first_elems_equal, dataset_gather, sanitize_dataset,\
+ is_not_shuffled, batch_size_matches
diff --git a/xplique/commons/data_conversion.py b/xplique/commons/data_conversion.py
index ae5d7eeb..9bcf3309 100644
--- a/xplique/commons/data_conversion.py
+++ b/xplique/commons/data_conversion.py
@@ -9,8 +9,7 @@
def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None
- ) -> Tuple[tf.Tensor, tf.Tensor]:
+ targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Ensure the output as tf.Tensor, accept various inputs format including:
tf.Tensor, List, numpy array, tf.data.Dataset (when label = None).
@@ -36,20 +35,17 @@ def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
if hasattr(inputs, '_batch_size'):
inputs = inputs.unbatch()
# unpack the dataset, assume we have tuple of (input, target)
+ targets = [target for _, target in inputs]
inputs = [inp for inp, _ in inputs]
- if targets is not None:
- targets = [target for _, target in inputs]
inputs = tf.cast(inputs, tf.float32)
- if targets is not None:
- targets = tf.cast(targets, tf.float32)
+ targets = tf.cast(targets, tf.float32)
return inputs, targets
def numpy_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None
- ) -> Tuple[tf.Tensor, tf.Tensor]:
+ targets: Optional[Union[tf.Tensor, np.ndarray]]) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Ensure the output as np.ndarray, accept various inputs format including:
tf.Tensor, List, numpy array, tf.data.Dataset (when label = None).
@@ -80,12 +76,14 @@ def sanitize_inputs_targets(explanation_method: Callable):
explanation_method
Function to wrap, should return an tf.tensor.
"""
- def sanitize(self, inputs: Union[tf.data.Dataset, tf.Tensor, np.array],
+ def sanitize(self, inputs: Union[tf.Tensor, np.array],
targets: Optional[Union[tf.Tensor, np.array]] = None,
- *args):
+ ):
# ensure we have tf.tensor
- inputs, targets = tensor_sanitize(inputs, targets)
+ inputs = tf.cast(inputs, tf.float32)
+ if targets is not None:
+ targets = tf.cast(targets, tf.float32)
# then enter the explanation function
- return explanation_method(self, inputs, targets, *args)
+ return explanation_method(self, inputs, targets)
return sanitize
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
new file mode 100644
index 00000000..69e750eb
--- /dev/null
+++ b/xplique/commons/tf_dataset_operations.py
@@ -0,0 +1,235 @@
+"""
+Set of functions to manipulate `tf.data.Dataset`
+"""
+from itertools import product
+
+import numpy as np
+import tensorflow as tf
+
+from ..types import Optional, Union
+
+
+def _almost_equal(arr1, arr2, epsilon=1e-6):
+ """Ensure two array are almost equal at an epsilon"""
+ return np.shape(arr1) == np.shape(arr2) and np.sum(np.abs(arr1 - arr2)) < epsilon
+
+
def are_dataset_first_elems_equal(
    dataset1: "Optional[tf.data.Dataset]", dataset2: "Optional[tf.data.Dataset]"
) -> bool:
    """
    Test if the first batch of elements of two datasets are the same.
    It is used to verify equality between datasets in a lazy way.

    Parameters
    ----------
    dataset1
        First `tf.data.Dataset` to compare. `None` is accepted.
    dataset2
        Second `tf.data.Dataset` to compare. `None` is accepted.

    Returns
    -------
    test_result
        Boolean value of the equality. Two `None` datasets are considered equal.
    """
    if dataset1 is None:
        return dataset2 is None
    if dataset2 is None:
        return False

    next1 = next(iter(dataset1))
    next2 = next(iter(dataset2))

    # tuple elements are assumed to be (input, target, ...): compare the inputs only
    if isinstance(next1, tuple):
        if not isinstance(next2, tuple):
            # one dataset yields tuples and the other does not: structures differ
            return False
        next1 = next1[0]
        next2 = next2[0]
    elif isinstance(next2, tuple):
        # fix: the original only rejected (tuple, non-tuple) mismatches;
        # the symmetric (non-tuple, tuple) case fell through and compared
        # a tensor against a tuple instead of returning False
        return False

    return _almost_equal(next1, next2)
+
+
def is_not_shuffled(dataset: Optional[tf.data.Dataset]) -> bool:
    """
    Test if the provided dataset reshuffles at each iteration.
    Tensorflow does not provide a clean way to verify it,
    hence we draw the first element twice and compare the two draws.
    It may not always detect shuffled datasets, but this is enough of a safety net.

    Parameters
    ----------
    dataset
        Tensorflow dataset to test. `None` is accepted and reported as not shuffled.

    Returns
    -------
    test_result
        True when both draws of the first element match,
        i.e. the dataset appears to iterate deterministically.
    """
    # comparing the dataset with itself only fails if the iteration order changed
    return are_dataset_first_elems_equal(dataset, dataset)
+
+
def batch_size_matches(dataset: "Optional[tf.data.Dataset]", batch_size: int) -> bool:
    """
    Check whether the batch size of a tensorflow dataset is the expected one.
    Tensorflow offers no clean way to read it back,
    so we draw one batch and inspect its first dimension.
    This heuristic may fail in some really precise cases,
    but it is enough of a safety net.

    Parameters
    ----------
    dataset
        Tensorflow dataset to test.
    batch_size
        The expected batch size of the dataset.

    Returns
    -------
    test_result
        Boolean value of the test.
    """
    if dataset is None:
        # nothing to verify, treat the absent dataset as a match
        return True

    sample = next(iter(dataset))
    if not isinstance(sample, tuple):
        return tf.shape(sample)[0].numpy() == batch_size

    # tuple elements: every column must carry the expected batch dimension
    return tf.reduce_all(
        [tf.shape(column)[0].numpy() == batch_size for column in sample]
    )
+
+
def sanitize_dataset(
    dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
    batch_size: int,
    cardinality: Optional[int] = None,
) -> Optional[tf.data.Dataset]:
    """
    Function to ensure the input dataset matches the expected format.
    It transforms tensors into `tf.data.Dataset` and verifies dataset properties.
    This function verifies that datasets do not reshuffle at each iteration and
    that their batch size and cardinality match the expected ones.
    Note that Tensorflow does not provide easy ways to make those tests, hence,
    for cost constraints, our tests are not perfect.

    Parameters
    ----------
    dataset
        Tensorflow dataset to verify or tensor to transform in `tf.data.Dataset` and verify.
        `None` is accepted and forwarded as `None`.
    batch_size
        The expected batch size used either to verify the input dataset
        or to batch the transformed tensor.
    cardinality
        Expected number of batches in the dataset or batched transformed tensor.

    Returns
    -------
    dataset
        Verified dataset or transformed tensor. In both cases a `tf.data.Dataset`,
        that does not reshuffle at each iteration and
        with batch size and cardinality matching the expected ones.
    """
    if dataset is not None:
        if isinstance(dataset, tf.data.Dataset):
            # datasets are assumed to be already batched,
            # we can only verify their properties with the module heuristics
            assert is_not_shuffled(dataset), (
                "Datasets should not be shuffled, "
                + "the order of the element should stay the same at each iteration."
            )
            assert batch_size_matches(
                dataset, batch_size
            ), "The batch size should match between datasets."
        else:
            # tensors and arrays are converted and batched at the expected size
            dataset = tf.data.Dataset.from_tensor_slices(dataset).batch(batch_size)

        if cardinality is not None and cardinality > 0:
            dataset_cardinality = dataset.cardinality().numpy()
            if dataset_cardinality > 0:
                assert dataset_cardinality == cardinality, (
                    "The number of batch should match between datasets. "
                    + f"Received {dataset.cardinality().numpy()} vs {cardinality}. "
                    + "You may have provided non-batched datasets or datasets with different length."
                )
            # a non-positive cardinality means tf could not infer it
            # (INFINITE_CARDINALITY or UNKNOWN_CARDINALITY): the check is skipped

    return dataset
+
+
def dataset_gather(dataset: "tf.data.Dataset", indices: "tf.Tensor") -> "tf.Tensor":
    """
    Imitation of `tf.gather` for `tf.data.Dataset`,
    it extracts elements from `dataset` at the given indices.
    We could see it as returning the `indices` tensor
    where each index was replaced by the corresponding element in `dataset`.
    The aim is to use it in the `example_based` module to extract examples from the cases dataset.

    Example of application:
    ```
    dataset = tf.data.Dataset.from_tensor_slices(
        tf.reshape(tf.range(20), (-1, 2, 2))
    ).batch(3)  # shape=(None, 2, 2)
    indices = tf.constant([[[0, 0]], [[1, 0]]])  # shape=(2, 1, 2)
    dataset_gather(dataset, indices)
    ```

    Parameters
    ----------
    dataset
        `tf.data.Dataset` from which elements are extracted.
        `None` is accepted and forwarded as `None`.
        NOTE(review): assumes the dataset yields plain tensors (no tuple
        columns), as `dataset.element_spec.dtype` is read below — TODO confirm.
    indices
        Tensor of indices of elements to extract from the `dataset`.
        `indices` should be of dimensions (n, k, 2),
        this is to match the format of indices in the `example_based` module.
        Indeed, n represents the number of inputs and k the number of corresponding examples.
        The index of each element is encoded by two values,
        the batch index and the index of the element in the batch.

    Returns
    -------
    results
        Tensor of shape (n, k, ...) with the gathered elements,
        or `None` when `dataset` is `None`.
    """
    if dataset is None:
        return None

    # allocate the output: (n, k) + shape of one dataset element
    example = next(iter(dataset))
    results = tf.Variable(
        tf.zeros(
            indices.shape[:-1] + example[0].shape, dtype=dataset.element_spec.dtype
        )
    )

    # fix: the original computed `product(indices.shape[:-1])` with
    # `itertools.product`, which returns an iterator, so the early-exit
    # comparison below could never be True and the whole dataset was
    # always iterated; a numeric product restores the early break
    nb_results = int(np.prod(indices.shape[:-1]))
    current_nb_results = 0

    for i, batch in enumerate(dataset):
        # skip batches that contain no requested element
        if not tf.reduce_any(indices[..., 0] == i):
            continue

        # extract pertinent elements
        pertinent_indices_location = tf.where(indices[..., 0] == i)
        samples_index = tf.gather_nd(indices[..., 1], pertinent_indices_location)
        samples = tf.gather(batch, samples_index)

        # put them at the right place in results
        for location, sample in zip(pertinent_indices_location, samples):
            results[location[0], location[1]].assign(sample)
            current_nb_results += 1

        # stop iterating once every requested element has been written
        if current_nb_results == nb_results:
            break
    return results
From 9f9acfb27452262631bceca95722e8102e45af8d Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:53:13 +0100
Subject: [PATCH 010/138] pylint: disable similarities for signatures
---
.pylintrc | 19 +++++++++++++++++++
setup.cfg | 1 +
2 files changed, 20 insertions(+)
create mode 100644 .pylintrc
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 00000000..91513741
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,19 @@
+[MASTER]
+disable=
+ R0903, # allows to expose only one public method
+ R0914, # allow multiples local variables
+ E0401, # pending issue with pylint see pylint#2603
+ E1123, # issues between pylint and tensorflow since 2.2.0
+ E1120, # see pylint#3613
+ C3001, # lambda function as variable
+
+[FORMAT]
+max-line-length=100
+max-args=12
+
+[SIMILARITIES]
+min-similarity-lines=6
+ignore-comments=yes
+ignore-docstrings=yes
+ignore-imports=no
+ignore-signatures=yes
diff --git a/setup.cfg b/setup.cfg
index 3fde7c5a..3a85bd3f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,6 +26,7 @@ min-similarity-lines = 6
ignore-comments = yes
ignore-docstrings = yes
ignore-imports = no
+ignore-signatures = yes
[tox:tox]
envlist = py{37,38,39,310}-lint, py{37,38,39,310}-tf{22,25,28,211}, py{38,39,310}-tf{25,28,211}-torch{111,113,200}
From d29f92f38ac58f67d9979ecd8ea6a336157300e3 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 12 Feb 2024 16:48:57 +0100
Subject: [PATCH 011/138] example based: introduce base example method
abstraction
---
xplique/example_based/base_example_method.py | 380 +++++++++++++++++++
1 file changed, 380 insertions(+)
create mode 100644 xplique/example_based/base_example_method.py
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
new file mode 100644
index 00000000..eca11a9e
--- /dev/null
+++ b/xplique/example_based/base_example_method.py
@@ -0,0 +1,380 @@
+"""
+Base model for example-based
+"""
+
+import math
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import KNN, BaseSearchMethod
+from .projections import Projection
+
+from .search_methods.base import _sanitize_returns
+
+
class BaseExampleMethod:
    """
    Base class for natural example-based methods explaining models,
    they project the cases_dataset into a pertinent space with a `Projection`,
    then they call the `BaseSearchMethod` on it.

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    labels_dataset
        Labels associated to the examples in the dataset. Indices should match with cases_dataset.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Batch size and cardinality of other dataset should match `cases_dataset`.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    targets_dataset
        Targets associated to the cases_dataset for dataset projection.
        See `projection` for detail.
        `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
        Batch size and cardinality of other dataset should match `cases_dataset`.
        Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is
        not the case for your dataset, otherwise, examples will not make sense.
    search_method
        An algorithm to search the examples in the projected space.
    k
        The number of examples to retrieve.
    projection
        Projection or Callable that project samples from the input space to the search space.
        The search space should be a space where distance make sense for the model.
        It should not be `None`, otherwise,
        all examples could be computed only with the `search_method`.

        Example of Callable:
        ```
        def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
            '''
            Example of projection,
            inputs are the elements to project.
            targets are optional parameters to orientated the projection.
            '''
            projected_inputs = # do some magic on inputs, it should use the model.
            return projected_inputs
        ```
    case_returns
        String or list of string with the elements to return in `self.explain()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of sample treated simultaneously for projection and search.
        Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
    search_method_kwargs
        Parameters to be passed at the construction of the `search_method`.
    """

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        search_method: Type[BaseSearchMethod] = KNN,
        k: int = 1,
        projection: Union[Projection, Callable] = None,
        case_returns: Union[List[str], str] = "examples",
        batch_size: Optional[int] = 32,
        **search_method_kwargs,
    ):
        assert (
            projection is not None
        ), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."

        # sanitize datasets and set dataset-related attributes;
        # `batch_size` is extracted from the datasets when they are `tf.data.Dataset`
        batch_size = self.__initialize_cases_dataset(
            cases_dataset, labels_dataset, targets_dataset, batch_size
        )
        self.k = k
        self.set_returns(case_returns)
        self.projection = projection

        # set `search_returns` if not provided and overwrite it otherwise,
        # "indices" and "distances" are required by `format_search_output()`
        search_method_kwargs["search_returns"] = ["indices", "distances"]

        # initiate search_method on the sanitized datasets
        # (fix: the raw constructor arguments were forwarded before, which
        # could hand a multi-column dataset that `__initialize_cases_dataset`
        # had already split to the search method)
        self.search_method = search_method(
            cases_dataset=self.cases_dataset,
            targets_dataset=self.targets_dataset,
            k=k,
            projection=projection,
            batch_size=batch_size,
            **search_method_kwargs,
        )

    def __initialize_cases_dataset(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
        targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
        batch_size: Optional[int],
    ) -> int:
        """
        Factorization of `__init__()` method for dataset related attributes.

        Parameters
        ----------
        cases_dataset
            The dataset used to train the model, examples are extracted from the dataset.
        labels_dataset
            Labels associated to the examples in the dataset.
            Indices should match with cases_dataset.
        targets_dataset
            Targets associated to the cases_dataset for dataset projection.
            See `projection` for detail.
        batch_size
            Number of sample treated simultaneously when using the datasets.
            Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).

        Returns
        -------
        batch_size
            Number of sample treated simultaneously when using the datasets.
            Extracted from the datasets in case they are `tf.data.Dataset`.
            Otherwise, the input value.
        """
        # at least one dataset provided
        if isinstance(cases_dataset, tf.data.Dataset):
            # set batch size (ignore provided argument) and cardinality
            if isinstance(cases_dataset.element_spec, tuple):
                batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy()
            else:
                batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy()

            cardinality = cases_dataset.cardinality().numpy()
        else:
            # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
            assert not isinstance(labels_dataset, tf.data.Dataset)
            assert not isinstance(targets_dataset, tf.data.Dataset)
            # set batch size and cardinality
            batch_size = min(batch_size, len(cases_dataset))
            cardinality = math.ceil(len(cases_dataset) / batch_size)

        # verify cardinality and create datasets from the tensors
        self.cases_dataset = sanitize_dataset(
            cases_dataset, batch_size, cardinality
        )
        self.labels_dataset = sanitize_dataset(
            labels_dataset, batch_size, cardinality
        )
        self.targets_dataset = sanitize_dataset(
            targets_dataset, batch_size, cardinality
        )

        # if the provided `cases_dataset` has several columns, split them into
        # the cases, labels, and targets datasets
        if isinstance(self.cases_dataset.element_spec, tuple):
            # switch case on the number of columns of `cases_dataset`
            if len(self.cases_dataset.element_spec) == 2:
                assert self.labels_dataset is None, (
                    "The second column of `cases_dataset` is assumed to be the labels."
                    + "Hence, `labels_dataset` should be empty."
                )
                self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
                self.cases_dataset = self.cases_dataset.map(lambda x, y: x)

            elif len(self.cases_dataset.element_spec) == 3:
                assert self.labels_dataset is None, (
                    "The second column of `cases_dataset` is assumed to be the labels."
                    + "Hence, `labels_dataset` should be empty."
                )
                # fix: this message was a copy-paste of the labels one,
                # it now describes the targets column
                assert self.targets_dataset is None, (
                    "The third column of `cases_dataset` is assumed to be the targets."
                    + "Hence, `targets_dataset` should be empty."
                )
                self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
                self.labels_dataset = self.cases_dataset.map(lambda x, y, t: y)
                self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
            else:
                raise AttributeError(
                    "`cases_dataset` cannot possess more than 3 columns,"
                    + f"{len(self.cases_dataset.element_spec)} were detected."
                )

        self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
        if self.labels_dataset is not None:
            self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
        if self.targets_dataset is not None:
            self.targets_dataset = self.targets_dataset.prefetch(tf.data.AUTOTUNE)

        return batch_size

    def set_k(self, k: int):
        """
        Setter for the k parameter.
        It is propagated to the wrapped search method.

        Parameters
        ----------
        k
            Number of examples to return, it should be a positive integer.
        """
        assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
        self.k = k
        self.search_method.set_k(k)

    def set_returns(self, returns: Union[List[str], str]):
        """
        Set `self.returns` used to define returned elements in `self.explain()`.

        Parameters
        ----------
        returns
            Most elements are useful in `xplique.plots.plot_examples()`.
            `returns` can be set to 'all' for all possible elements to be returned.
            - 'examples' correspond to the expected examples,
            the inputs may be included in first position. (n, k(+1), ...)
            - 'weights' the weights in the input space used in the projection.
            They are associated to the input and the examples. (n, k(+1), ...)
            - 'distances' the distances between the inputs and the corresponding examples.
            They are associated to the examples. (n, k, ...)
            - 'labels' if provided through `dataset_labels`,
            they are the labels associated with the examples. (n, k, ...)
            - 'include_inputs' specify if inputs should be included in the returned elements.
            Note that it changes the number of returned elements from k to k+1.
        """
        possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
        default = "examples"
        self.returns = _sanitize_returns(returns, possibilities, default)

    @sanitize_inputs_targets
    def explain(
        self,
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """
        Compute examples to explain the inputs.
        It projects inputs with `self.projection` in the search space
        and finds examples with `self.search_method`.

        Parameters
        ----------
        inputs
            Tensor or Array. Input samples to be explained.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
            More information in the documentation.
        targets
            Tensor or Array passed to the projection function.

        Returns
        -------
        return_dict
            Dictionary with listed elements in `self.returns`.
            If only one element is present it returns the element.
            The elements that can be returned are:
            examples, weights, distances, indices, and labels.
        """
        # project inputs
        projected_inputs = self.projection(inputs, targets)

        # look for closest elements to projected inputs
        search_output = self.search_method(projected_inputs)

        # manage returned elements
        return self.format_search_output(search_output, inputs, targets)

    def __call__(
        self,
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """explain alias"""
        return self.explain(inputs, targets)

    def format_search_output(
        self,
        search_output: Dict[str, tf.Tensor],
        inputs: Union[tf.Tensor, np.ndarray],
        targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
    ):
        """
        Format the output of the `search_method` to match the expected returns in `self.returns`.

        Parameters
        ----------
        search_output
            Dictionary with the required outputs from the `search_method`.
        inputs
            Tensor or Array. Input samples to be explained.
            Expected shape among (N, W), (N, T, W), (N, W, H, C).
            More information in the documentation.
        targets
            Tensor or Array passed to the projection function.
            Here it is used by the explain function of attribution methods.
            Refer to the corresponding method documentation for more detail.
            Note that the default method is `Saliency`.

        Returns
        -------
        return_dict
            Dictionary with listed elements in `self.returns`.
            If only one element is present it returns the element.
            The elements that can be returned are:
            examples, weights, distances, indices, and labels.
        """
        return_dict = {}

        # gather the elements at the searched indices;
        # `dataset_gather` forwards `None` datasets as `None`
        examples = dataset_gather(self.cases_dataset, search_output["indices"])
        examples_labels = dataset_gather(self.labels_dataset, search_output["indices"])
        examples_targets = dataset_gather(
            self.targets_dataset, search_output["indices"]
        )

        # add examples and weights
        if "examples" in self.returns or "weights" in self.returns:
            if "include_inputs" in self.returns:
                # include inputs in front of their examples
                inputs = tf.expand_dims(inputs, axis=1)
                examples = tf.concat([inputs, examples], axis=1)
                if targets is not None:
                    targets = tf.expand_dims(targets, axis=1)
                    examples_targets = tf.concat([targets, examples_targets], axis=1)
                else:
                    examples_targets = [None] * len(examples)
            if "examples" in self.returns:
                return_dict["examples"] = examples
            if "weights" in self.returns:
                # robustness fix: without a `targets_dataset`, `examples_targets`
                # is None; iterate over None placeholders instead of crashing
                if examples_targets is None:
                    examples_targets = [None] * len(examples)
                # get weights of examples (n, k, ...)
                # we iterate on the inputs dimension through maps
                # and ask weights for batch of examples
                weights = []
                for ex, ex_targ in zip(examples, examples_targets):
                    if isinstance(self.projection, Projection):
                        # get weights in the input space
                        weights.append(self.projection.get_input_weights(ex, ex_targ))
                    else:
                        raise AttributeError(
                            "Cannot extract weights from the provided projection function"
                            + "Either remove 'weights' from the `case_returns` or"
                            + "inherit from `Projection` and overwrite `get_input_weights`."
                        )

                return_dict["weights"] = tf.stack(weights, axis=0)

        # add indices, distances, and labels
        if "distances" in self.returns:
            return_dict["distances"] = search_output["distances"]
        if "labels" in self.returns:
            assert (
                examples_labels is not None
            ), "The method cannot return labels without a label dataset."
            return_dict["labels"] = examples_labels

        # return a dict only if several variables are returned
        if len(return_dict) == 1:
            return list(return_dict.values())[0]
        return return_dict
From 69fef126c26117813ea23fdd053d2409f3a6fec2 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 12 Feb 2024 16:49:36 +0100
Subject: [PATCH 012/138] example based: adapt similar examples
---
xplique/example_based/similar_examples.py | 325 ++--------------------
1 file changed, 23 insertions(+), 302 deletions(-)
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index 2961b1e0..2a9634d3 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -13,42 +13,39 @@
from ..commons import sanitize_dataset, dataset_gather
from .search_methods import KNN, BaseSearchMethod
from .projections import Projection
+from .base_example_method import BaseExampleMethod
from .search_methods.base import _sanitize_returns
-class SimilarExamples:
+class SimilarExamples(BaseExampleMethod):
"""
- Base class for natural example-base methods explaining models,
- they project the cases_dataset into a pertinent space for the with a `Projection`,
- then they call the `BaseSearchMethod` on it.
+ Base class for similar examples.
Parameters
----------
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other dataset should match `cases_dataset`.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other dataset should match `cases_dataset`.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
- search_method
- An algorithm to search the examples in the projected space.
k
The number of examples to retrieve.
projection
Projection or Callable that project samples from the input space to the search space.
- The search space sould be a space where distance make sense for the model.
+ The search space should be a space where distance make sense for the model.
It should not be `None`, otherwise,
all examples could be computed only with the `search_method`.
@@ -58,7 +55,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
'''
Example of projection,
inputs are the elements to project.
- targets are optionnal parameters to orientated the projection.
+ targets are optional parameters to orientated the projection.
'''
projected_inputs = # do some magic on inputs, it should use the model.
return projected_inputs
@@ -69,8 +66,12 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
+ distance
+ Distance for the knn search method.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
"""
def __init__(
@@ -78,303 +79,23 @@ def __init__(
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- search_method: Type[BaseSearchMethod] = KNN,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
- **search_method_kwargs,
+ distance: Union[int, str, Callable] = "euclidean",
):
- assert (
- projection is not None
- ), "`SimilarExamples` without `projection` is a `BaseSearchMethod`."
-
- # set attributes
- batch_size = self.__initialize_cases_dataset(
- cases_dataset, labels_dataset, targets_dataset, batch_size
- )
- self.k = k
- self.set_returns(case_returns)
- self.projection = projection
-
- # set `search_returns` if not provided and overwrite it otherwise
- search_method_kwargs["search_returns"] = ["indices", "distances"]
-
- # initiate search_method
- self.search_method = search_method(
+ # the only difference with parent is that the search method is always KNN
+ search_method = KNN
+
+ super().__init__(
cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
targets_dataset=targets_dataset,
+ search_method=search_method,
k=k,
projection=projection,
+ case_returns=case_returns,
batch_size=batch_size,
- **search_method_kwargs,
- )
-
- def __initialize_cases_dataset(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
- batch_size: Optional[int],
- ) -> int:
- """
- Factorization of `__init__()` method for dataset related attributes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- labels_dataset
- Labels associated to the examples in the dataset.
- Indices should match with cases_dataset.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection.
- See `projection` for detail.
- batch_size
- Number of sample treated simultaneously when using the datasets.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
-
- Returns
- -------
- batch_size
- Number of sample treated simultaneously when using the datasets.
- Extracted from the datasets in case they are `tf.data.Dataset`.
- Otherwise, the input value.
- """
- # at least one dataset provided
- if isinstance(cases_dataset, tf.data.Dataset):
- # set batch size (ignore provided argument) and cardinality
- if isinstance(cases_dataset.element_spec, tuple):
- batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy()
- else:
- batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy()
-
- cardinality = cases_dataset.cardinality().numpy()
- else:
- # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
- assert not isinstance(labels_dataset, tf.data.Dataset)
- assert not isinstance(targets_dataset, tf.data.Dataset)
- # set batch size and cardinality
- batch_size = min(batch_size, len(cases_dataset))
- cardinality = math.ceil(len(cases_dataset) / batch_size)
-
- # verify cardinality and create datasets from the tensors
- self.cases_dataset = sanitize_dataset(
- cases_dataset, batch_size, cardinality
- )
- self.labels_dataset = sanitize_dataset(
- labels_dataset, batch_size, cardinality
- )
- self.targets_dataset = sanitize_dataset(
- targets_dataset, batch_size, cardinality
- )
-
- # if the provided `cases_dataset` has several columns
- if isinstance(self.cases_dataset.element_spec, tuple):
- # switch case on the number of columns of `cases_dataset`
- if len(self.cases_dataset.element_spec) == 2:
- assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
- self.cases_dataset = self.cases_dataset.map(lambda x, y: x)
-
- elif len(self.cases_dataset.element_spec) == 3:
- assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- assert self.targets_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
- self.labels_dataset = self.cases_dataset.map(lambda x, y, t: y)
- self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
- else:
- raise AttributeError(
- "`cases_dataset` cannot possess more than 3 columns,"
- + f"{len(self.cases_dataset.element_spec)} were detected."
- )
-
- self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
- if self.labels_dataset is not None:
- self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
- if self.targets_dataset is not None:
- self.targets_dataset = self.targets_dataset.prefetch(tf.data.AUTOTUNE)
-
- return batch_size
-
- def set_k(self, k: int):
- """
- Setter for the k parameter.
-
- Parameters
- ----------
- k
- Number of examples to return, it should be a positive integer.
- """
- assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
- self.k = k
- self.search_method.set_k(k)
-
- def set_returns(self, returns: Union[List[str], str]):
- """
- Set `self.returns` used to define returned elements in `self.explain()`.
-
- Parameters
- ----------
- returns
- Most elements are useful in `xplique.plots.plot_examples()`.
- `returns` can be set to 'all' for all possible elements to be returned.
- - 'examples' correspond to the expected examples,
- the inputs may be included in first position. (n, k(+1), ...)
- - 'weights' the weights in the input space used in the projection.
- They are associated to the input and the examples. (n, k(+1), ...)
- - 'distances' the distances between the inputs and the corresponding examples.
- They are associated to the examples. (n, k, ...)
- - 'labels' if provided through `dataset_labels`,
- they are the labels associated with the examples. (n, k, ...)
- - 'include_inputs' specify if inputs should be included in the returned elements.
- Note that it changes the number of returned elements from k to k+1.
- """
- possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
- default = "examples"
- self.returns = _sanitize_returns(returns, possibilities, default)
-
- @sanitize_inputs_targets
- def explain(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """
- Compute examples to explain the inputs.
- It project inputs with `self.projection` in the search space
- and find examples with `self.search_method`.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
- targets
- Tensor or Array passed to the projection function.
-
- Returns
- -------
- return_dict
- Dictionnary with listed elements in `self.returns`.
- If only one element is present it returns the element.
- The elements that can be returned are:
- examples, weights, distances, indices, and labels.
- """
- # project inputs
- projected_inputs = self.projection(inputs, targets)
-
- # look for closest elements to projected inputs
- search_output = self.search_method(projected_inputs)
-
- # manage returned elements
- return self.format_search_output(search_output, inputs, targets)
-
- def __call__(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """explain alias"""
- return self.explain(inputs, targets)
-
- def format_search_output(
- self,
- search_output: Dict[str, tf.Tensor],
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """
- Format the output of the `search_method` to match the expected returns in `self.returns`.
-
- Parameters
- ----------
- search_output
- Dictionnary with the required outputs from the `search_method`.
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
- targets
- Tensor or Array passed to the projection function.
- Here it is used by the explain function of attribution methods.
- Refer to the corresponding method documentation for more detail.
- Note that the default method is `Saliency`.
-
- Returns
- -------
- return_dict
- Dictionnary with listed elements in `self.returns`.
- If only one element is present it returns the element.
- The elements that can be returned are:
- examples, weights, distances, indices, and labels.
- """
- return_dict = {}
-
- examples = dataset_gather(self.cases_dataset, search_output["indices"])
- examples_labels = dataset_gather(self.labels_dataset, search_output["indices"])
- examples_targets = dataset_gather(
- self.targets_dataset, search_output["indices"]
+ distance=distance
)
-
- # add examples and weights
- if "examples" in self.returns or "weights" in self.returns:
- if "include_inputs" in self.returns:
- # include inputs
- inputs = tf.expand_dims(inputs, axis=1)
- examples = tf.concat([inputs, examples], axis=1)
- if targets is not None:
- targets = tf.expand_dims(targets, axis=1)
- examples_targets = tf.concat([targets, examples_targets], axis=1)
- else:
- examples_targets = [None] * len(examples)
- if "examples" in self.returns:
- return_dict["examples"] = examples
- if "weights" in self.returns:
- # get weights of examples (n, k, ...)
- # we iterate on the inputs dimension through maps
- # and ask weights for batch of examples
- weights = []
- for ex, ex_targ in zip(examples, examples_targets):
- if isinstance(self.projection, Projection):
- # get weights in the input space
- weights.append(self.projection.get_input_weights(ex, ex_targ))
- else:
- raise AttributeError(
- "Cannot extract weights from the provided projection function"
- + "Either remove 'weights' from the `case_returns` or"
- + "inherit from `Projection` and overwrite `get_input_weights`."
- )
-
- return_dict["weights"] = tf.stack(weights, axis=0)
-
- # optimization test TODO
- # return_dict["weights"] = tf.vectorized_map(
- # fn=lambda x: self.projection.get_input_weights(x[0], x[1]),
- # elems=(examples, examples_targets),
- # # fn_output_signature=tf.float32,
- # )
-
- # add indices, distances, and labels
- if "distances" in self.returns:
- return_dict["distances"] = search_output["distances"]
- if "labels" in self.returns:
- assert (
- examples_labels is not None
- ), "The method cannot return labels without a label dataset."
- return_dict["labels"] = examples_labels
-
- # return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
- return return_dict
From 539c387a382a166526c32b5eeea9484e6d374ad0 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 12 Feb 2024 16:50:22 +0100
Subject: [PATCH 013/138] example based: adapt cole
---
xplique/example_based/cole.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index 85c4c2d6..ded8fbfd 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -45,8 +45,6 @@ class Cole(SimilarExamples):
Batch size and cardinality of other dataset should match `cases_dataset`.
Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
- search_method
- An algorithm to search the examples in the projected space.
k
The number of examples to retrieve. Default value is `1`.
distance
@@ -87,7 +85,6 @@ def __init__(
model: tf.keras.Model,
targets_dataset: Union[tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.Tensor, np.ndarray]] = None,
- search_method: Type[BaseSearchMethod] = KNN,
k: int = 1,
distance: Union[str, Callable] = "euclidean",
case_returns: Optional[Union[List[str], str]] = "examples",
@@ -110,7 +107,6 @@ def __init__(
cases_dataset,
labels_dataset,
targets_dataset,
- search_method,
k,
projection,
case_returns,
From 7a58c9fd2b7453dd1b6a245d69167b5c3930bfab Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 12 Feb 2024 16:50:59 +0100
Subject: [PATCH 014/138] example based: adapt tests
---
tests/example_based/test_cole.py | 4 ----
tests/example_based/test_image_plot.py | 3 +--
tests/example_based/test_similar_examples.py | 3 ---
tests/example_based/test_split_projection.py | 2 +-
4 files changed, 2 insertions(+), 10 deletions(-)
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index 9fb1b73b..9d8c63a0 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -74,7 +74,6 @@ def test_cole_attribution():
method_constructor = Cole(
cases_dataset=x_train,
targets_dataset=y_train,
- search_method=KNN,
k=k,
batch_size=7,
distance="euclidean",
@@ -89,7 +88,6 @@ def test_cole_attribution():
method_call = SimilarExamples(
cases_dataset=x_train,
targets_dataset=y_train,
- search_method=KNN,
k=k,
distance=euclidean_dist,
projection=projection,
@@ -98,7 +96,6 @@ def test_cole_attribution():
method_different_distance = Cole(
cases_dataset=x_train,
targets_dataset=y_train,
- search_method=KNN,
k=k,
batch_size=2,
distance=np.inf, # infinity norm based distance
@@ -157,7 +154,6 @@ def test_cole_spliting():
method = Cole(
cases_dataset=x_train,
targets_dataset=y_train,
- search_method=KNN,
k=k,
case_returns=["examples", "weights", "include_inputs"],
model=model,
diff --git a/tests/example_based/test_image_plot.py b/tests/example_based/test_image_plot.py
index f8254d17..25908f44 100644
--- a/tests/example_based/test_image_plot.py
+++ b/tests/example_based/test_image_plot.py
@@ -15,7 +15,7 @@
from xplique.attributions import Occlusion, Saliency
from xplique.example_based import Cole, SimilarExamples
-from xplique.example_based.projections import CustomProjection
+from xplique.example_based.projections import Projection
from xplique.example_based.search_methods import KNN
from xplique.plots.image import plot_examples
@@ -70,7 +70,6 @@ def test_plot_cole_spliting():
cases_dataset=x_train,
labels_dataset=tf.argmax(y_train, axis=1),
targets_dataset=y_train,
- search_method=KNN,
k=k,
case_returns="all",
model=model,
diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py
index fbc5824e..3e6c0401 100644
--- a/tests/example_based/test_similar_examples.py
+++ b/tests/example_based/test_similar_examples.py
@@ -148,7 +148,6 @@ def test_similar_examples_basic():
method = SimilarExamples(
cases_dataset=x_train,
projection=identity_projection,
- search_method=KNN,
k=k,
batch_size=3,
distance="euclidean",
@@ -194,7 +193,6 @@ def test_similar_examples_return_multiple_elements():
cases_dataset=x_train,
labels_dataset=y_train,
projection=identity_projection,
- search_method=KNN,
k=1,
batch_size=3,
distance="euclidean",
@@ -274,7 +272,6 @@ def test_similar_examples_weighting():
cases_dataset=x_train,
labels_dataset=y_train,
projection=weighting_function,
- search_method=KNN,
k=k,
batch_size=5,
distance="euclidean",
diff --git a/tests/example_based/test_split_projection.py b/tests/example_based/test_split_projection.py
index bc560b48..db3105d1 100644
--- a/tests/example_based/test_split_projection.py
+++ b/tests/example_based/test_split_projection.py
@@ -40,7 +40,7 @@ def test_attribution_latent_layer():
last_conv_layer = model.get_layer("conv2d_2")
flatten_layer = model.get_layer("flatten")
- # default should not include model spliting
+ # default should not include model splitting
projection_default = AttributionProjection(model)
assert projection_default.latent_layer is None
From 2255896be846f63e7976851f93f00696f5f262d1 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:07:30 +0100
Subject: [PATCH 015/138] base example method: dataset projections in
projections
---
xplique/example_based/base_example_method.py | 29 +++++++++++++++-----
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index eca11a9e..2c4b99df 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -93,19 +93,34 @@ def __init__(
batch_size = self.__initialize_cases_dataset(
cases_dataset, labels_dataset, targets_dataset, batch_size
)
+
self.k = k
self.set_returns(case_returns)
- self.projection = projection
+
+ assert hasattr(projection, "__call__"), "projection should be a callable."
+
+ # check projection type
+ if isinstance(projection, Projection):
+ self.projection = projection
+ elif hasattr(projection, "__call__"):
+ self.projection = Projection(get_weights=None, space_projection=projection)
+ else:
+ raise AttributeError(
+ "projection should be a `Projection` or a `Callable`, not a"
+ + f"{type(projection)}"
+ )
+
+ # project dataset
+ projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
+ self.targets_dataset)
# set `search_returns` if not provided and overwrite it otherwise
search_method_kwargs["search_returns"] = ["indices", "distances"]
# initiate search_method
self.search_method = search_method(
- cases_dataset=cases_dataset,
- targets_dataset=targets_dataset,
+ cases_dataset=projected_cases_dataset,
k=k,
- projection=projection,
batch_size=batch_size,
**search_method_kwargs,
)
@@ -266,7 +281,7 @@ def explain(
Returns
-------
return_dict
- Dictionnary with listed elements in `self.returns`.
+ Dictionary with listed elements in `self.returns`.
If only one element is present it returns the element.
The elements that can be returned are:
examples, weights, distances, indices, and labels.
@@ -300,7 +315,7 @@ def format_search_output(
Parameters
----------
search_output
- Dictionnary with the required outputs from the `search_method`.
+ Dictionary with the required outputs from the `search_method`.
inputs
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
@@ -314,7 +329,7 @@ def format_search_output(
Returns
-------
return_dict
- Dictionnary with listed elements in `self.returns`.
+ Dictionary with listed elements in `self.returns`.
If only one element is present it returns the element.
The elements that can be returned are:
examples, weights, distances, indices, and labels.
From ead19ea762c3be35b71bb8b327b318b5807c6919 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:08:17 +0100
Subject: [PATCH 016/138] base projection: dataset projections in projections
---
xplique/example_based/projections/base.py | 60 ++++++++++++++++++-----
1 file changed, 48 insertions(+), 12 deletions(-)
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index debe261a..9581cc22 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -7,13 +7,13 @@
import tensorflow as tf
import numpy as np
-from ...commons import sanitize_inputs_targets
+from ...commons import sanitize_inputs_targets, get_device
from ...types import Callable, Union, Optional
-class Projection(ABC):
+class Projection(ABC): # TODO See if this should stay as abstract class or if we should remove CustomProjection
"""
- Base class used by `NaturalExampleBasedExplainer` to projet samples to a meaningfull space
+ Base class used by `NaturalExampleBasedExplainer` to project samples to a meaningful space
for the model to explain.
Projection have two parts a `space_projection` and `weights`, to apply a projection,
@@ -39,14 +39,14 @@ def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
targets: Union(tf.Tensor, np.ndarray) = None):
'''
Example of function to get weights,
- projected_inputs are the elements for which weights are comlputed.
- targets are optionnal additionnal parameters for weights computation.
+ projected_inputs are the elements for which weights are computed.
+ targets are optional additional parameters for weights computation.
'''
weights = ... # do some magic with inputs and targets, it should use the model.
return weights
```
space_projection
- Callable that take samples and return a Tensor in the projected sapce.
+        Callable that takes samples and returns a Tensor in the projected space.
An example of projected space is the latent space of a model. See `LatentSpaceProjection`
"""
@@ -75,6 +75,9 @@ def __init__(self, get_weights: Callable = None, space_projection: Callable = No
)
self.space_projection = space_projection
+ # set device
+ self.device = get_device()
+
def get_input_weights(
self,
inputs: Union[tf.Tensor, np.ndarray],
@@ -83,7 +86,7 @@ def get_input_weights(
"""
Depending on the projection, we may not be able to visualize weights
as they are after the space projection. In this case, this method should be overwritten,
- as in `AttributionProjection` that applies an upsampling.
+ as in `AttributionProjection` that applies an up-sampling.
Parameters
----------
@@ -98,7 +101,7 @@ def get_input_weights(
-------
input_weights
Tensor with the same dimension as `inputs` modulo the channels.
- They are an upsampled version of the actual weights used in the projection.
+ They are an up-sampled version of the actual weights used in the projection.
"""
projected_inputs = self.space_projection(inputs)
assert tf.reduce_all(tf.equal(projected_inputs, inputs)), (
@@ -137,10 +140,10 @@ def project(
projected_samples
The samples projected in the new space.
"""
- projected_inputs = self.space_projection(inputs)
- weights = self.get_weights(projected_inputs, targets)
-
- return tf.multiply(weights, projected_inputs)
+ with tf.device(self.device):
+ projected_inputs = self.space_projection(inputs)
+ weights = self.get_weights(projected_inputs, targets)
+ return tf.multiply(weights, projected_inputs)
def __call__(
self,
@@ -149,3 +152,36 @@ def __call__(
):
"""project alias"""
return self.project(inputs, targets)
+
+ def project_dataset(
+ self,
+ cases_dataset: tf.data.Dataset,
+ targets_dataset: Optional[tf.data.Dataset] = None,
+ ) -> Optional[tf.data.Dataset]:
+ """
+ Apply the projection to a dataset through `Dataset.map`
+
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of samples to be projected.
+ targets_dataset
+ Dataset of targets for the samples.
+
+ Returns
+ -------
+ projected_dataset
+ The projected dataset.
+ """
+ # project dataset, note that projection is done at iteration time
+ if targets_dataset is None:
+ projected_cases_dataset = cases_dataset.map(self.project)
+ else:
+ # in case targets are provided, we zip the datasets and project them together
+ projected_cases_dataset = tf.data.Dataset.zip(
+ (cases_dataset, targets_dataset)
+ ).map(
+ lambda x, y: self.project(x, y)
+ )
+
+ return projected_cases_dataset
From 48bc853ad0edb641bce6c18c888ac7ced74ff01c Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:09:42 +0100
Subject: [PATCH 017/138] latent space projection: dataset projections in
projections
---
.../example_based/projections/latent_space.py | 19 ++++---------------
1 file changed, 4 insertions(+), 15 deletions(-)
diff --git a/xplique/example_based/projections/latent_space.py b/xplique/example_based/projections/latent_space.py
index 04ce0304..3bfc1d9f 100644
--- a/xplique/example_based/projections/latent_space.py
+++ b/xplique/example_based/projections/latent_space.py
@@ -8,6 +8,7 @@
from ...types import Callable, Union
from .base import Projection
+from .commons import model_splitting
class LatentSpaceProjection(Projection):
@@ -31,18 +32,6 @@ class LatentSpaceProjection(Projection):
"""
def __init__(self, model: Callable, latent_layer: Union[str, int] = -1):
- self.model = model
-
- # split the model if a latent_layer is provided
- if latent_layer == "last_conv":
- self.latent_layer = next(
- layer for layer in model.layers[::-1] if hasattr(layer, "filters")
- )
- else:
- self.latent_layer = find_layer(model, latent_layer)
-
- latent_space_projection = tf.keras.Model(
- model.input, self.latent_layer.output, name="features_extractor"
- )
-
- super().__init__(space_projection=latent_space_projection)
+ features_extractor, _ = model_splitting(model, latent_layer)
+ super().__init__(space_projection=features_extractor)
+ # TODO test if gpu is used for the projection
From fcacfd1f263b654ac168470dca42eaecb53fb865 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:10:01 +0100
Subject: [PATCH 018/138] attribution projection: dataset projections in
projections
---
.../example_based/projections/attributions.py | 94 ++++++++++---------
1 file changed, 52 insertions(+), 42 deletions(-)
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index 7f9f624f..2ebf37c8 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -1,17 +1,18 @@
"""
Attribution, a projection from example based module
"""
-
+import warnings
import tensorflow as tf
import numpy as np
+from xplique.types import Optional
from ...attributions.base import BlackBoxExplainer
from ...attributions import Saliency
-from ...commons import find_layer
from ...types import Callable, Union, Optional
from .base import Projection
+from .commons import model_splitting
class AttributionProjection(Projection):
@@ -19,13 +20,13 @@ class AttributionProjection(Projection):
Projection build on an attribution function to provide local projections.
This class is used as the projection of the `Cole` similar examples method.
- Depending on the `latent_layer`, the model will be splited between
+    Depending on the `latent_layer`, the model will be split between
the feature extractor and the predictor.
The feature extractor will become the `space_projection()` method, then
the predictor will be used to build the attribution method explain, and
its `explain()` method will become the `get_weights()` method.
- If no `latent_layer` is provided, the model is not splited,
+    If no `latent_layer` is provided, the model is not split,
the `space_projection()` is the identity function, and
the attributions (`get_weights()`) are compute on the whole model.
@@ -42,7 +43,7 @@ class AttributionProjection(Projection):
If an `int` is provided it will be interpreted as a layer index.
If a `string` is provided it will look for the layer name.
- The method as described in the paper apply the separation on the last convolutionnal layer.
+        The method as described in the paper applies the separation on the last convolutional layer.
To do so, the `"last_conv"` parameter will extract it.
Otherwise, `-1` could be used for the last layer before softmax.
attribution_method
@@ -60,53 +61,23 @@ def __init__(
latent_layer: Optional[Union[str, int]] = None,
**attribution_kwargs
):
- self.model = model
+ self.method = method
if latent_layer is None:
# no split
self.latent_layer = None
- space_projection = lambda inputs: inputs
- get_weights = method(model, **attribution_kwargs)
+ space_projection = None
+ self.predictor = model
else:
# split the model if a latent_layer is provided
- if latent_layer == "last_conv":
- self.latent_layer = next(
- layer for layer in model.layers[::-1] if hasattr(layer, "filters")
- )
- else:
- self.latent_layer = find_layer(model, latent_layer)
-
- space_projection = tf.keras.Model(
- model.input, self.latent_layer.output, name="features_extractor"
- )
- self.predictor = tf.keras.Model(
- self.latent_layer.output, model.output, name="predictor"
- )
- get_weights = method(self.predictor, **attribution_kwargs)
+ space_projection, self.predictor = model_splitting(model, latent_layer)
+
+ # compute attributions
+ get_weights = self.method(self.predictor, **attribution_kwargs)
# set methods
super().__init__(get_weights, space_projection)
- # attribution methods output do not have channel
- # we wrap get_weights to expend dimensions if needed
- self.__wrap_get_weights_to_extend_channels(self.get_weights)
-
- def __wrap_get_weights_to_extend_channels(self, get_weights: Callable):
- """
- Extend channel if miss match between inputs and weights
- """
-
- def wrapped_get_weights(inputs, targets):
- weights = get_weights(inputs, targets)
- weights = tf.cond(
- pred=weights.shape == inputs.shape,
- true_fn=lambda: weights,
- false_fn=lambda: tf.expand_dims(weights, axis=-1),
- )
- return weights
-
- self.get_weights = wrapped_get_weights
-
def get_input_weights(
self,
inputs: Union[tf.Tensor, np.ndarray],
@@ -154,3 +125,42 @@ def get_input_weights(
false_fn=resize_fn,
)
return input_weights
+
+ def project_dataset(
+ self,
+ cases_dataset: tf.data.Dataset,
+ targets_dataset: tf.data.Dataset,
+ ) -> tf.data.Dataset:
+ """
+ Apply the projection to a dataset without `Dataset.map`.
+ Because attribution methods create a `tf.data.Dataset` for batching,
+ however doing so inside a `Dataset.map` is not recommended.
+
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of samples to be projected.
+ targets_dataset
+ Dataset of targets for the samples.
+
+ Returns
+ -------
+ projected_dataset
+ The projected dataset.
+ """
+ # TODO see if a warning is needed
+
+ projected_cases_dataset = []
+ batch_size = None
+
+ # iteratively project the dataset
+ for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)):
+ if batch_size is None:
+ batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
+ projected_cases_dataset.append(self.project(inputs, targets))
+
+ projected_cases_dataset = tf.concat(projected_cases_dataset, axis=0)
+ projected_cases_dataset = tf.data.Dataset.from_tensor_slices(projected_cases_dataset)
+ projected_cases_dataset = projected_cases_dataset.batch(batch_size)
+
+ return projected_cases_dataset
\ No newline at end of file
From 4dbb6b6141cd903080266cc491038436dfc55138 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:10:39 +0100
Subject: [PATCH 019/138] example based: introduce hadamard projection
---
xplique/example_based/projections/__init__.py | 2 +-
xplique/example_based/projections/hadamard.py | 120 ++++++++++++++++++
2 files changed, 121 insertions(+), 1 deletion(-)
create mode 100644 xplique/example_based/projections/hadamard.py
diff --git a/xplique/example_based/projections/__init__.py b/xplique/example_based/projections/__init__.py
index d5d4cf90..4b33a895 100644
--- a/xplique/example_based/projections/__init__.py
+++ b/xplique/example_based/projections/__init__.py
@@ -4,5 +4,5 @@
from .attributions import AttributionProjection
from .base import Projection
-from .custom import CustomProjection
+from .hadamard import HadamardProjection
from .latent_space import LatentSpaceProjection
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
new file mode 100644
index 00000000..87234883
--- /dev/null
+++ b/xplique/example_based/projections/hadamard.py
@@ -0,0 +1,120 @@
+"""
+Attribution, a projection from example based module
+"""
+import warnings
+
+import tensorflow as tf
+import numpy as np
+from xplique.types import Optional
+
+from ...commons import get_gradient_functions
+from ...types import Callable, Union, Optional, OperatorSignature
+
+from .base import Projection
+from .commons import model_splitting
+
+
+class HadamardProjection(Projection):
+ """
+    Projection built on the latent space and the gradient.
+ This class is used as the projection of the `Cole` similar examples method.
+
+    Depending on the `latent_layer`, the model will be split between
+ the feature extractor and the predictor.
+ The feature extractor will become the `space_projection()` method, then
+ the predictor will be used to build the attribution method explain, and
+ its `explain()` method will become the `get_weights()` method.
+
+    If no `latent_layer` is provided, the model is not split,
+    the `space_projection()` is the identity function, and
+    the attributions (`get_weights()`) are computed on the whole model.
+
+ Parameters
+ ----------
+ model
+ The model from which we want to obtain explanations.
+ latent_layer
+ Layer used to split the model, the first part will be used for projection and
+ the second to compute the attributions. By default, the model is not split.
+ For such split, the `model` should be a `tf.keras.Model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+        The method as described in the paper applies the separation on the last convolutional layer.
+ To do so, the `"last_conv"` parameter will extract it.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ operator
+ Operator to use to compute the explanation, if None use standard predictions.
+ """
+
+ def __init__(
+ self,
+ model: Callable,
+ latent_layer: Optional[Union[str, int]] = None,
+ operator: Optional[OperatorSignature] = None,
+ ):
+ if latent_layer is None:
+ # no split
+ self.latent_layer = None
+ space_projection = None
+ self.predictor = model
+ else:
+ # split the model if a latent_layer is provided
+ space_projection, self.predictor = model_splitting(model, latent_layer)
+
+        # the weights are given by the gradient of the operator
+ gradients, _ = get_gradient_functions(self.predictor, operator)
+ get_weights = lambda inputs, targets: gradients(self.predictor, inputs, targets) # TODO check usage of gpu
+
+ # set methods
+ super().__init__(get_weights, space_projection)
+
+ def get_input_weights(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """
+        For visualization purposes only, we may be interested in projecting weights
+        from the projected space to the input space.
+        This is applied only if there is a difference in dimension.
+ We assume here that we are treating images and an upsampling is applied.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Additional parameter for `self.get_weights` function.
+
+ Returns
+ -------
+ input_weights
+ Tensor with the same dimension as `inputs` modulo the channels.
+ They are an upsampled version of the actual weights used in the projection.
+ """
+ projected_inputs = self.space_projection(inputs)
+ weights = self.get_weights(projected_inputs, targets)
+
+ # take mean over channels for images
+ channel_mean_fn = lambda: tf.reduce_mean(weights, axis=-1, keepdims=True)
+ weights = tf.cond(
+ pred=tf.shape(weights).shape[0] < 4,
+ true_fn=lambda: weights,
+ false_fn=channel_mean_fn,
+ )
+
+ # resizing
+ resize_fn = lambda: tf.image.resize(
+ weights, inputs.shape[1:-1], method="bicubic"
+ )
+ input_weights = tf.cond(
+ pred=projected_inputs.shape == inputs.shape,
+ true_fn=lambda: weights,
+ false_fn=resize_fn,
+ )
+ return input_weights
From 6e8ddb1d9fbc693cd7669ac6119f54a236a1ac5f Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:13:32 +0100
Subject: [PATCH 020/138] base search method: remove projection from search
---
xplique/example_based/search_methods/base.py | 32 ++------------------
1 file changed, 2 insertions(+), 30 deletions(-)
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 13a05f6a..1c7c0f1b 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -11,8 +11,6 @@
from ...commons import sanitize_dataset
-from ..projections.base import Projection
-
def _sanitize_returns(returns: Optional[Union[List[str], str]] = None,
possibilities: List[str] = None,
@@ -69,27 +67,8 @@ class BaseSearchMethod(ABC):
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
For natural example-based methods it is the train dataset.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
k
The number of examples to retrieve.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space sould be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
-
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optionnal parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
search_returns
String or list of string with the elements to return in `self.find_examples()`.
See `self.set_returns()` for detail.
@@ -101,12 +80,11 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
k: int = 1,
- projection: Union[Projection, Callable] = None,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
): # pylint: disable=R0801
+
# set batch size
if hasattr(cases_dataset, "_batch_size"):
self.batch_size = cases_dataset._batch_size
@@ -114,19 +92,14 @@ def __init__(
self.batch_size = batch_size
self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
- self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size)
- if self.targets_dataset is None:
- # The `find_examples()` method need to be able to iterate on `self.targets_dataset`
- self.targets_dataset = [None] * self.cases_dataset.cardinality().numpy()
self.set_k(k)
self.set_returns(search_returns)
- self.projection = projection
def set_k(self, k: int):
"""
Change value of k with constructing a new `BaseSearchMethod`.
- It is useful because the constructor can be computionnaly expensive.
+ It is useful because the constructor can be computationally expensive.
Parameters
----------
@@ -170,7 +143,6 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
----------
inputs
Tensor or Array. Input samples to be explained.
- Assumed to have been already projected.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
"""
raise NotImplementedError()
From 5abb298d500e0485fab7fa0373cab7d809a0459b Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:13:45 +0100
Subject: [PATCH 021/138] knn search method: remove projection from search
---
xplique/example_based/search_methods/knn.py | 37 +++------------------
1 file changed, 4 insertions(+), 33 deletions(-)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index ed8d721b..8530f4a6 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -22,27 +22,8 @@ class KNN(BaseSearchMethod):
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
For natural example-based methods it is the train dataset.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
k
The number of examples to retrieve.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space sould be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
-
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optionnal parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
search_returns
String or list of string with the elements to return in `self.find_examples()`.
See `self.set_returns()` for detail.
@@ -59,15 +40,13 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
k: int = 1,
- projection: Union[Projection, Callable] = None,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
): # pylint: disable=R0801
super().__init__(
- cases_dataset, targets_dataset, k, projection, search_returns, batch_size
+ cases_dataset, k, search_returns, batch_size
)
if hasattr(distance, "__call__"):
@@ -131,25 +110,17 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, t
batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1))
# iterate on batches
- for batch_index, (cases, targets) in enumerate(
- zip(self.cases_dataset, self.targets_dataset)
- ):
- # project batch of dataset cases
- if self.projection is not None:
- projected_cases = self.projection.project(cases, targets)
- else:
- projected_cases = cases
-
+ for batch_index, cases in enumerate(self.cases_dataset):
# add new elements
# (n, current_bs, 2)
- indices = batch_indices[:, : tf.shape(projected_cases)[0]]
+ indices = batch_indices[:, : tf.shape(cases)[0]]
new_indices = tf.stack(
[tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1
)
# compute distances
# (n, current_bs)
- distances = self.crossed_distances_fn(inputs, projected_cases)
+ distances = self.crossed_distances_fn(inputs, cases)
# (n, k+curent_bs, 2)
concatenated_indices = tf.concat([best_indices, new_indices], axis=1)
From 92056f161c21c0d706fdb0d632e0da5684264e7a Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:14:58 +0100
Subject: [PATCH 022/138] cole: add hadamard product possibility
---
xplique/example_based/cole.py | 54 +++++++++++++++++++++++------------
1 file changed, 36 insertions(+), 18 deletions(-)
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index ded8fbfd..3fdfc82f 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -6,13 +6,10 @@
import tensorflow as tf
from ..attributions.base import BlackBoxExplainer
-from ..attributions import Saliency
from ..types import Callable, List, Optional, Union, Type
from .similar_examples import SimilarExamples
-from .projections import AttributionProjection
-from .search_methods import KNN
-from .search_methods import BaseSearchMethod
+from .projections import AttributionProjection, HadamardProjection
class Cole(SimilarExamples):
@@ -31,19 +28,19 @@ class Cole(SimilarExamples):
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other dataset should match `cases_dataset`.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other dataset should match `cases_dataset`.
- Becareful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve. Default value is `1`.
@@ -59,6 +56,8 @@ class Cole(SimilarExamples):
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ device
+ Device to use for the projection, if None, use the default device.
latent_layer
Layer used to split the model, the first part will be used for projection and
the second to compute the attributions. By default, the model is not split.
@@ -68,13 +67,13 @@ class Cole(SimilarExamples):
If an `int` is provided it will be interpreted as a layer index.
If a `string` is provided it will look for the layer name.
- The method as described in the paper apply the separation on the last convolutionnal layer.
+        The method as described in the paper applies the separation on the last convolutional layer.
To do so, the `"last_conv"` parameter will extract it.
Otherwise, `-1` could be used for the last layer before softmax.
attribution_method
Class of the attribution method to use for projection.
It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
- Ignored if a projection is given.
+ By default, it computes the gradient to make the Hadamard product in the latent space.
attribution_kwargs
Parameters to be passed at the construction of the `attribution_method`.
"""
@@ -89,20 +88,39 @@ def __init__(
distance: Union[str, Callable] = "euclidean",
case_returns: Optional[Union[List[str], str]] = "examples",
batch_size: Optional[int] = 32,
+ device: Optional[str] = None,
latent_layer: Optional[Union[str, int]] = None,
- attribution_method: Type[BlackBoxExplainer] = Saliency,
+ attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient",
**attribution_kwargs,
):
- # buil attribution projection
- projection = AttributionProjection(
- model=model,
- method=attribution_method,
- latent_layer=latent_layer,
- **attribution_kwargs,
- )
-
assert targets_dataset is not None
+ # build the corresponding projection
+ if isinstance(attribution_method, str) and attribution_method.lower() == "gradient":
+
+ operator = attribution_kwargs.get("operator", None)
+
+ projection = HadamardProjection(
+ model=model,
+ latent_layer=latent_layer,
+ operator=operator,
+ device=device,
+ )
+ elif issubclass(attribution_method, BlackBoxExplainer):
+ # build attribution projection
+ projection = AttributionProjection(
+ model=model,
+ method=attribution_method,
+ latent_layer=latent_layer,
+ device=device,
+ **attribution_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"attribution_method should be 'gradient' or a subclass of BlackBoxExplainer," +\
+ "not {attribution_method}"
+ )
+
super().__init__(
cases_dataset,
labels_dataset,
From 3f69e7b034add4f9b5dab98404aa734ac15cede3 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:16:55 +0100
Subject: [PATCH 023/138] projections: factorize model splitting
---
xplique/example_based/projections/commons.py | 60 ++++++++++++++++++++
1 file changed, 60 insertions(+)
create mode 100644 xplique/example_based/projections/commons.py
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
new file mode 100644
index 00000000..59dc7ee8
--- /dev/null
+++ b/xplique/example_based/projections/commons.py
@@ -0,0 +1,60 @@
+"""
+Commons for projections
+"""
+
+import tensorflow as tf
+
+from ...commons import find_layer
+from ...types import Callable, Union, Optional, Tuple
+
+
+def model_splitting(model: tf.keras.Model,
+ latent_layer: Union[str, int],
+ return_layer: bool = False,
+ ) -> Tuple[Callable, Callable, Optional[tf.keras.layers.Layer]]:
+ """
+ Split the model into two parts, before and after the `latent_layer`.
+ The parts will respectively be called `features_extractor` and `predictor`.
+
+ Parameters
+ ----------
+ model
+ Model to be split.
+ latent_layer
+ Layer used to split the `model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+ To separate after the last convolution, `"last_conv"` can be used.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ return_layer
+ If True, return the latent layer found.
+
+ Returns
+ -------
+ features_extractor
+ Model used to project the inputs.
+ predictor
+ Model used to compute the attributions.
+ latent_layer
+ Layer used to split the `model`.
+ """
+ if latent_layer == "last_conv":
+ latent_layer = next(
+ layer for layer in model.layers[::-1] if hasattr(layer, "filters")
+ )
+ else:
+ latent_layer = find_layer(model, latent_layer)
+
+ features_extractor = tf.keras.Model(
+ model.input, latent_layer.output, name="features_extractor"
+ )
+ predictor = tf.keras.Model(
+ latent_layer.output, model.output, name="predictor"
+ )
+
+ if return_layer:
+ return features_extractor, predictor, latent_layer
+ return features_extractor, predictor
\ No newline at end of file
From 5cdf1c132da0d29a88e35900ac5470ce76a80700 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 28 Feb 2024 11:54:01 +0100
Subject: [PATCH 024/138] example based projections: fuse custom projection
with base class
---
xplique/example_based/projections/base.py | 53 ++++++++----
xplique/example_based/projections/custom.py | 90 ---------------------
2 files changed, 38 insertions(+), 105 deletions(-)
delete mode 100644 xplique/example_based/projections/custom.py
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index 9581cc22..54192ed5 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -11,7 +11,7 @@
from ...types import Callable, Union, Optional
-class Projection(ABC): # TODO See if this should stay as abstract class or if we should remove CustomProjection
+class Projection(ABC):
"""
Base class used by `NaturalExampleBasedExplainer` to project samples to a meaningful space
for the model to explain.
@@ -30,13 +30,17 @@ class Projection(ABC): # TODO See if this should stay as abstract class or if w
Parameters
----------
get_weights
- Callable, a function that return the weights (Tensor) for a given input (Tensor).
+ Either a Tensor or a Callable.
+ - In the case of a Tensor, weights are applied in the projected space.
+ - In the case of a callable, a function is expected.
+ It should take inputs and targets as parameters and return the weights (Tensor).
Weights should have the same shape as the input (possible difference on channels).
+ The inputs of `get_weights()` correspond to the projected inputs.
Example of `get_weights()` function:
```
def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
- targets: Union(tf.Tensor, np.ndarray) = None):
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
'''
Example of function to get weights,
projected_inputs are the elements for which weights are computed.
@@ -48,35 +52,53 @@ def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
space_projection
Callable that take samples and return a Tensor in the projected space.
An example of projected space is the latent space of a model. See `LatentSpaceProjection`
+ device
+ Device to use for the projection, if None, use the default device.
"""
- def __init__(self, get_weights: Callable = None, space_projection: Callable = None):
+ def __init__(self,
+ get_weights: Optional[Union[Callable, tf.Tensor, np.ndarray]] = None,
+ space_projection: Optional[Callable] = None,
+ device: Optional[str] = None):
assert get_weights is not None or space_projection is not None, (
"At least one of `get_weights` and `space_projection`"
+ "should not be `None`."
)
- # set get weights
+ # set get_weights
if get_weights is None:
# no weights
- get_weights = lambda inputs, _: tf.ones(tf.shape(inputs))
- if not hasattr(get_weights, "__call__"):
+ self.get_weights = lambda inputs, _: tf.ones(tf.shape(inputs))
+ elif isinstance(get_weights, (tf.Tensor, np.ndarray)):
+ # weights is a tensor
+ if isinstance(get_weights, np.ndarray):
+ weights = tf.convert_to_tensor(get_weights, dtype=tf.float32)
+
+ # define a function that returns the weights
+ def get_weights(inputs, _ = None):
+ nweights = tf.expand_dims(weights, axis=0)
+ return tf.repeat(nweights, tf.shape(inputs)[0], axis=0)
+ self.get_weights = get_weights
+ elif hasattr(get_weights, "__call__"):
+ # weights is a function
+ self.get_weights = get_weights
+ else:
raise TypeError(
- f"`get_weights` should be `Callable`, not a {type(get_weights)}"
+ f"`get_weights` should be `Callable` or a Tensor, not a {type(get_weights)}"
)
- self.get_weights = get_weights
-
+
# set space_projection
if space_projection is None:
- space_projection = lambda inputs: inputs
- if not hasattr(space_projection, "__call__"):
+ self.space_projection = lambda inputs: inputs
+ elif hasattr(space_projection, "__call__"):
+ self.space_projection = space_projection
+ else:
raise TypeError(
f"`space_projection` should be a `Callable`, not a {type(space_projection)}"
)
- self.space_projection = space_projection
# set device
- self.device = get_device()
+ self.device = get_device(device)
def get_input_weights(
self,
@@ -143,7 +165,8 @@ def project(
with tf.device(self.device):
projected_inputs = self.space_projection(inputs)
weights = self.get_weights(projected_inputs, targets)
- return tf.multiply(weights, projected_inputs)
+ weighted_projected_inputs = tf.multiply(weights, projected_inputs)
+ return weighted_projected_inputs
def __call__(
self,
diff --git a/xplique/example_based/projections/custom.py b/xplique/example_based/projections/custom.py
deleted file mode 100644
index 966c6ada..00000000
--- a/xplique/example_based/projections/custom.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-Custom, a projection from example based module
-"""
-
-import tensorflow as tf
-import numpy as np
-
-from ...types import Callable, Union
-
-from .base import Projection
-
-
-class CustomProjection(Projection):
- """
- Base class used by `NaturalExampleBasedExplainer` to projet samples to a meaningfull space
- for the model to explain.
-
- Projection have two parts a `space_projection` and `weights`, to apply a projection,
- the samples are first projected to a new space and then weighted.
- Either the `space_projection` or the `weights` could be `None` but,
- if both are, the projection is an identity function.
-
- At least one of the two part should include the model in the computation
- for distance between projected elements to make sense for the model.
-
- Note that the cost of this projection should be limited
- as it will be applied to all samples of the train dataset.
-
- Parameters
- ----------
- weights
- Either a Tensor or a Callable.
- - In the case of a Tensor, weights are applied in the projected space
- (after `space_projection`).
- Hence weights should have the same shape as a `projected_input`.
- - In the case of a Callable, the function should return the weights when called,
- as a way to get the weights (a Tensor)
- It is pertinent in the case on weights dependent on the inputs, i.e. local weighting.
-
- Example of `get_weights()` function:
- ```
- def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
- targets: Union(tf.Tensor, np.ndarray) = None):
- '''
- Example of function to get weights,
- projected_inputs are the elements for which weights are comlputed.
- targets are optionnal additionnal parameters for weights computation.
- '''
- weights = ... # do some magic with inputs and targets, it should use the model.
- return weights
- ```
- space_projection
- Callable that take samples and return a Tensor in the projected sapce.
- An example of projected space is the latent space of a model.
- In this case, the model should be splitted and the
- """
-
- def __init__(
- self,
- weights: Union[Callable, tf.Tensor, np.ndarray] = None,
- space_projection: Callable = None,
- ):
- # Set weights or
- if weights is None or hasattr(weights, "__call__"):
- # weights is already a function or there is no weights
- get_weights = weights
- elif isinstance(weights, (tf.Tensor, np.ndarray)):
- # weights is a tensor
- if isinstance(weights, np.ndarray):
- weights = tf.convert_to_tensor(weights, dtype=tf.float32)
-
- # define a function that returns the weights
- def get_weights(inputs, _ = None):
- nweights = tf.expand_dims(weights, axis=0)
- return tf.repeat(nweights, tf.shape(inputs)[0], axis=0)
-
- else:
- raise TypeError(
- "`weights` should be a tensor or a `Callable`,"
- + f"not a {type(weights)}"
- )
-
- # Set space_projection
- if space_projection is not None and not hasattr(space_projection, "__call__"):
- raise TypeError(
- "`space_projection` should be a `Callable`,"
- + f"not a {type(space_projection)}"
- )
-
- super().__init__(get_weights, space_projection)
From 0c141bb6e64d751939ca7871afb5e633addaee4c Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:18:22 +0100
Subject: [PATCH 025/138] projections tests: adapt to changes and complete
---
tests/example_based/test_projections.py | 133 +++++++++++++++++++
tests/example_based/test_split_projection.py | 85 ------------
2 files changed, 133 insertions(+), 85 deletions(-)
create mode 100644 tests/example_based/test_projections.py
delete mode 100644 tests/example_based/test_split_projection.py
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
new file mode 100644
index 00000000..8fe8b28f
--- /dev/null
+++ b/tests/example_based/test_projections.py
@@ -0,0 +1,133 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.layers import (
+ Dense,
+ Conv2D,
+ Activation,
+ Dropout,
+ Flatten,
+ MaxPooling2D,
+ Input,
+)
+
+from xplique.attributions import Saliency
+from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection
+from xplique.example_based.projections.commons import model_splitting
+from ..utils import generate_data, almost_equal
+
+def get_setup(input_shape, nb_samples=10, nb_labels=2):
+ """
+ Generate data and model for SimilarExamples
+ """
+ # Data generation
+ x_train = tf.stack(
+ [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)]
+ )
+ x_test = x_train[1:-1]
+ y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, nb_labels)
+
+ return x_train, x_test, y_train
+
+
+def _generate_model(input_shape=(32, 32, 3), output_shape=2):
+ model = tf.keras.Sequential()
+ model.add(Input(shape=input_shape))
+ model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_1"))
+ model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_2"))
+ model.add(MaxPooling2D(pool_size=(2, 2)))
+ model.add(Dropout(0.25))
+ model.add(Flatten())
+ model.add(Dense(output_shape, name="dense"))
+ model.add(Activation("softmax", name="softmax"))
+ model.compile(loss="categorical_crossentropy", optimizer="sgd")
+
+ return model
+
+
+def test_model_splitting_latent_layer():
+ """We should target the right layer using either int, string or default procedure"""
+ tf.keras.backend.clear_session()
+
+ model = _generate_model()
+
+ first_conv_layer = model.get_layer("conv2d_1")
+ last_conv_layer = model.get_layer("conv2d_2")
+ flatten_layer = model.get_layer("flatten")
+
+ # last_conv should be recognized
+ _, _, latent_layer = model_splitting(model, latent_layer="last_conv", return_layer=True)
+ assert latent_layer == last_conv_layer
+
+ # target the first conv layer
+ _, _, latent_layer = model_splitting(model, latent_layer=0, return_layer=True)
+ assert latent_layer == first_conv_layer
+
+ # target a random flatten layer
+ _, _, latent_layer = model_splitting(model, latent_layer="flatten", return_layer=True)
+ assert latent_layer == flatten_layer
+
+
+def test_simple_projection_mapping():
+ """
+ Test if a simple projection can be mapped.
+ """
+ # Setup
+ input_shape = (7, 7, 3)
+ nb_samples = 10
+ nb_labels = 2
+ x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels)
+
+ weights = tf.random.uniform((input_shape[0], input_shape[1], 1), minval=0, maxval=1)
+
+ space_projection = lambda x, y=None: tf.nn.max_pool2d(x, ksize=3, strides=1, padding="SAME")
+
+ projection = Projection(get_weights=weights, space_projection=space_projection)
+
+ # Generate tf.data.Dataset from numpy
+ train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(3)
+
+ # Apply the projection by mapping the dataset
+ projected_train_dataset = projection.project_dataset(train_dataset)
+
+
+def test_latent_space_projection_mapping():
+ """
+ Test if the latent space projection can be mapped.
+ """
+ # Setup
+ input_shape = (7, 7, 3)
+ nb_samples = 10
+ nb_labels = 2
+ x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels)
+
+ model = _generate_model(input_shape=input_shape, output_shape=nb_labels)
+
+ projection = LatentSpaceProjection(model, "last_conv")
+
+ # Generate tf.data.Dataset from numpy
+ train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(3)
+
+ # Apply the projection by mapping the dataset
+ projected_train_dataset = projection.project_dataset(train_dataset)
+
+
+def test_attribution_projection_mapping():
+ """
+ Test if the attribution projection can be mapped.
+ """
+ # Setup
+ input_shape = (7, 7, 3)
+ nb_samples = 10
+ nb_labels = 2
+ x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels)
+
+ model = _generate_model(input_shape=input_shape, output_shape=nb_labels)
+
+ projection = AttributionProjection(model, method=Saliency, latent_layer="last_conv")
+
+ # Generate tf.data.Dataset from numpy
+ train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3)
+ targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3)
+
+ # Apply the projection by mapping the dataset
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
\ No newline at end of file
diff --git a/tests/example_based/test_split_projection.py b/tests/example_based/test_split_projection.py
deleted file mode 100644
index db3105d1..00000000
--- a/tests/example_based/test_split_projection.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import numpy as np
-import tensorflow as tf
-from tensorflow.keras.layers import (
- Dense,
- Conv2D,
- Activation,
- Dropout,
- Flatten,
- MaxPooling2D,
- Input,
-)
-
-from xplique.example_based.projections import AttributionProjection
-from xplique.example_based.projections import LatentSpaceProjection
-from ..utils import generate_data, almost_equal
-
-
-def _generate_model(input_shape=(32, 32, 3), output_shape=10):
- model = tf.keras.Sequential()
- model.add(Input(shape=input_shape))
- model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_1"))
- model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_2"))
- model.add(MaxPooling2D(pool_size=(2, 2)))
- model.add(Dropout(0.25))
- model.add(Flatten())
- model.add(Dense(output_shape, name="dense"))
- model.add(Activation("softmax", name="softmax"))
- model.compile(loss="categorical_crossentropy", optimizer="sgd")
-
- return model
-
-
-def test_attribution_latent_layer():
- """We should target the right layer using either int, string or default procedure"""
- tf.keras.backend.clear_session()
-
- model = _generate_model()
-
- first_conv_layer = model.get_layer("conv2d_1")
- last_conv_layer = model.get_layer("conv2d_2")
- flatten_layer = model.get_layer("flatten")
-
- # default should not include model splitting
- projection_default = AttributionProjection(model)
- assert projection_default.latent_layer is None
-
- # last_conv should be recognized
- projection_default = AttributionProjection(model, latent_layer="last_conv")
- assert projection_default.latent_layer == last_conv_layer
-
- # target the first conv layer
- projection_default = AttributionProjection(model, latent_layer=0)
- assert projection_default.latent_layer == first_conv_layer
-
- # target a random flatten layer
- projection_default = AttributionProjection(model, latent_layer="flatten")
- assert projection_default.latent_layer == flatten_layer
-
-
-def test_latent_space_latent_layer():
- """We should target the right layer using either int, string or default procedure"""
- tf.keras.backend.clear_session()
-
- model = _generate_model()
-
- first_conv_layer = model.get_layer("conv2d_1")
- last_conv_layer = model.get_layer("conv2d_2")
- flatten_layer = model.get_layer("flatten")
- last_layer = model.get_layer("softmax")
-
- # default should not include model spliting
- projection_default = LatentSpaceProjection(model)
- assert projection_default.latent_layer == last_layer
-
- # last_conv should be recognized
- projection_default = LatentSpaceProjection(model, latent_layer="last_conv")
- assert projection_default.latent_layer == last_conv_layer
-
- # target the first conv layer
- projection_default = LatentSpaceProjection(model, latent_layer=0)
- assert projection_default.latent_layer == first_conv_layer
-
- # target a random flatten layer
- projection_default = LatentSpaceProjection(model, latent_layer="flatten")
- assert projection_default.latent_layer == flatten_layer
From 56ddd00cc35cf6a6ccc9c5872dae69d9572a60b4 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:18:47 +0100
Subject: [PATCH 026/138] similar examples tests: adapt to changes
---
tests/example_based/test_similar_examples.py | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py
index 3e6c0401..2ec371d3 100644
--- a/tests/example_based/test_similar_examples.py
+++ b/tests/example_based/test_similar_examples.py
@@ -16,7 +16,7 @@
from xplique.types import Union
from xplique.example_based import SimilarExamples
-from xplique.example_based.projections import CustomProjection
+from xplique.example_based.projections import Projection, LatentSpaceProjection
from xplique.example_based.search_methods import KNN
from tests.utils import almost_equal
@@ -40,9 +40,9 @@ def test_similar_examples_input_datasets_management():
"""
Test management of dataset init inputs
"""
- proj = CustomProjection(space_projection=lambda inputs, targets=None: inputs)
+ proj = Projection(space_projection=lambda inputs, targets=None: inputs)
- tf_tensor = tf.reshape(tf.range(90), (10, 3, 3))
+ tf_tensor = tf.reshape(tf.range(90, dtype=tf.float32), (10, 3, 3))
np_array = np.array(tf_tensor)
tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor)
too_short_np_array = np_array[:3]
@@ -140,7 +140,7 @@ def test_similar_examples_basic():
k = 3
x_train, x_test, _ = get_setup(input_shape)
- identity_projection = CustomProjection(
+ identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
)
@@ -184,7 +184,7 @@ def test_similar_examples_return_multiple_elements():
nb_samples_test = len(x_test)
assert nb_samples_test + 2 == len(y_train)
- identity_projection = CustomProjection(
+ identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
)
@@ -266,7 +266,7 @@ def test_similar_examples_weighting():
noise = np.random.uniform(size=x_train.shape, low=-100, high=100)
x_train = np.float32(weights * np.array(x_train) + (1 - weights) * noise)
- weighting_function = CustomProjection(weights=weights)
+ weighting_function = Projection(get_weights=weights)
method = SimilarExamples(
cases_dataset=x_train,
@@ -286,6 +286,9 @@ def test_similar_examples_weighting():
assert examples.shape == (nb_samples_test, k) + input_shape
for i in range(nb_samples_test):
+ print(i)
+ print(examples[i, 0])
+ print(x_train[i + 1])
# test examples:
assert almost_equal(examples[i, 0], x_train[i + 1])
assert almost_equal(examples[i, 1], x_train[i + 2]) or almost_equal(
From c44d256d5486664db178f21bcdc7b0793ab39dad Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Feb 2024 16:19:14 +0100
Subject: [PATCH 027/138] cole tests: adapt to changes and add hadamard
---
tests/example_based/test_cole.py | 74 ++++++++++++++++++++++++++------
1 file changed, 62 insertions(+), 12 deletions(-)
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index 9d8c63a0..a9dc1afe 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -7,18 +7,13 @@
sys.path.append(os.getcwd())
-from math import prod, sqrt
-
import numpy as np
-import scipy
import tensorflow as tf
+from xplique.commons.operators_operations import gradients_predictions
from xplique.attributions import Occlusion, Saliency
-
from xplique.example_based import Cole, SimilarExamples
-from xplique.example_based.projections import CustomProjection
-from xplique.example_based.search_methods import KNN
-from xplique.types import Union
+from xplique.example_based.projections import Projection
from tests.utils import (
generate_data,
@@ -38,11 +33,12 @@ def get_setup(input_shape, nb_samples=10, nb_labels=10):
)
x_test = x_train[1:-1]
y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, depth=nb_labels)
+ y_test = y_train[1:-1]
# Model generation
model = generate_model(input_shape, nb_labels)
- return model, x_train, x_test, y_train
+ return model, x_train, x_test, y_train, y_test
def test_cole_attribution():
@@ -81,8 +77,12 @@ def test_cole_attribution():
attribution_method=Saliency,
)
- # Cole with attribution explain
- projection = CustomProjection(weights=Saliency(model))
+ # Cole with attribution explain. Batch gradient is overwritten for test purposes, do not copy!
+ explainer = Saliency(model)
+ explainer.batch_gradient = \
+ lambda model, inputs, targets, batch_size:\
+ explainer.gradient(model, inputs, targets)
+ projection = Projection(get_weights=explainer)
euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z)))
method_call = SimilarExamples(
@@ -128,7 +128,57 @@ def test_cole_attribution():
)
-def test_cole_spliting():
+def test_cole_hadamard():
+ """
+ Test Cole with Hadamard projection.
+ It should be the same as a manual projection.
+ """
+ # Setup
+ input_shape = (7, 7, 3)
+ nb_samples = 10
+ nb_labels = 2
+ k = 3
+ model, x_train, x_test, y_train, y_test =\
+ get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels)
+
+ # Cole with Hadamard projection constructor
+ method_constructor = Cole(
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ k=k,
+ batch_size=7,
+ distance="euclidean",
+ model=model,
+ projection_method="gradient",
+ )
+
+ # Cole with Hadamard projection explain batch gradient is overwritten for test purpose, do not copy!
+ weights_extraction = lambda inputs, targets: gradients_predictions(model, inputs, targets)
+ projection = Projection(get_weights=weights_extraction)
+
+ euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z)))
+ method_call = SimilarExamples(
+ cases_dataset=x_train,
+ targets_dataset=y_train,
+ k=k,
+ distance=euclidean_dist,
+ projection=projection,
+ )
+
+ # Generate explanation
+ examples_constructor = method_constructor.explain(x_test, y_test)
+ examples_call = method_call.explain(x_test, y_test)
+
+ # Verifications
+ # Shape should be (n, k, h, w, c)
+ assert examples_constructor.shape == (len(x_test), k) + input_shape
+ assert examples_call.shape == (len(x_test), k) + input_shape
+
+ # both methods should be the same
+ assert almost_equal(examples_constructor, examples_call)
+
+
+def test_cole_splitting():
"""
Test Cole with a `latent_layer` provided.
It should split the model.
@@ -175,4 +225,4 @@ def test_cole_spliting():
# test_cole_attribution()
-# test_cole_spliting()
+# test_cole_splitting()
From 445fce18cbb20f80ebffa2c680595928fd179605 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 15 Feb 2024 15:03:50 +0100
Subject: [PATCH 028/138] tf operations: add get device for dataset mapping
---
xplique/commons/__init__.py | 2 +-
xplique/commons/tf_operations.py | 25 +++++++++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/xplique/commons/__init__.py b/xplique/commons/__init__.py
index c5312a2e..6153c01a 100644
--- a/xplique/commons/__init__.py
+++ b/xplique/commons/__init__.py
@@ -5,7 +5,7 @@
from .data_conversion import tensor_sanitize, numpy_sanitize, sanitize_inputs_targets
from .model_override import guided_relu_policy, deconv_relu_policy, override_relu_gradient, \
find_layer, open_relu_policy
-from .tf_operations import repeat_labels, batch_tensor
+from .tf_operations import repeat_labels, batch_tensor, get_device
from .callable_operations import predictions_one_hot_callable
from .operators_operations import (Tasks, get_operator, check_operator, operator_batching,
get_inference_function, get_gradient_functions)
diff --git a/xplique/commons/tf_operations.py b/xplique/commons/tf_operations.py
index 1d6e5fae..3831b41f 100644
--- a/xplique/commons/tf_operations.py
+++ b/xplique/commons/tf_operations.py
@@ -54,3 +54,28 @@ def batch_tensor(tensors: Union[Tuple, tf.Tensor],
dataset = dataset.batch(batch_size)
return dataset
+
+
+def get_device(device: Optional[str] = None) -> str:
+ """
+ Gets the name of the device to use. If there are any available GPUs, it will use the first one
+ in the system, otherwise, it will use the CPU.
+
+ Parameters
+ ----------
+ device
+ A string specifying the device on which to run the computations. If None, it will search
+ for available GPUs, and if none are found, it will return the first CPU.
+
+ Returns
+ -------
+ device
+ A string with the name of the device on which to run the computations.
+ """
+ if device is not None:
+ return device
+
+ physical_devices = tf.config.list_physical_devices('GPU')
+ if physical_devices is None or len(physical_devices) == 0:
+ return 'cpu:0'
+ return 'GPU:0'
From 7cfa2f391ed0a5a78f885cc5862b470ae28816be Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 6 Mar 2024 17:56:49 +0100
Subject: [PATCH 029/138] feat: add a new KNN object and improve distance
computation efficiency
---
xplique/example_based/base_example_method.py | 2 +-
.../example_based/search_methods/__init__.py | 4 +-
xplique/example_based/search_methods/base.py | 16 +-
xplique/example_based/search_methods/knn.py | 313 +++++++++++++++---
4 files changed, 279 insertions(+), 56 deletions(-)
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 2c4b99df..9e3facf2 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -290,7 +290,7 @@ def explain(
projected_inputs = self.projection(inputs, targets)
# look for closest elements to projected inputs
- search_output = self.search_method(projected_inputs)
+ search_output = self.search_method(projected_inputs, targets)
# manage returned elements
return self.format_search_output(search_output, inputs, targets)
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
index 228e1acd..010b7cb3 100644
--- a/xplique/example_based/search_methods/__init__.py
+++ b/xplique/example_based/search_methods/__init__.py
@@ -2,7 +2,7 @@
Search methods
"""
-from .base import BaseSearchMethod
+from .base import BaseSearchMethod, ORDER
# from .sklearn_knn import SklearnKNN
-from .knn import KNN
+from .knn import BaseKNN, KNN, FilterKNN
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 1c7c0f1b..303575c3 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -1,7 +1,7 @@
"""
Base search method for example-based module
"""
-
+from enum import Enum
from abc import ABC, abstractmethod
import tensorflow as tf
@@ -11,6 +11,14 @@
from ...commons import sanitize_dataset
+class ORDER(Enum):
+ """
+ Enumeration for the two types of ordering for the sorting function.
+ ASCENDING puts the elements with the smallest value first.
+ DESCENDING puts the elements with the largest value first.
+ """
+ ASCENDING = 1
+ DESCENDING = 2
def _sanitize_returns(returns: Optional[Union[List[str], str]] = None,
possibilities: List[str] = None,
@@ -133,7 +141,7 @@ def set_returns(self, returns: Optional[Union[List[str], str]] = None):
@abstractmethod
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
"""
Search the samples to return as examples. Called by the explain methods.
It may also return the indices corresponding to the samples,
@@ -147,6 +155,6 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
"""
raise NotImplementedError()
- def __call__(self, inputs: Union[tf.Tensor, np.ndarray]):
+ def __call__(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
"""find_samples alias"""
- return self.find_examples(inputs)
+ return self.find_examples(inputs, targets)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index 8530f4a6..9b0f228b 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -1,18 +1,102 @@
"""
KNN online search method in example-based module
"""
+import math
+from abc import abstractmethod
import numpy as np
import tensorflow as tf
-from ...commons import dataset_gather
+from ...commons import dataset_gather, sanitize_dataset
from ...types import Callable, List, Union, Optional, Tuple
-from .base import BaseSearchMethod
-from ..projections import Projection
+from .base import BaseSearchMethod, ORDER
+class BaseKNN(BaseSearchMethod):
+ """
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ k: int = 1,
+ search_returns: Optional[Union[List[str], str]] = None,
+ batch_size: Optional[int] = 32,
+ order: ORDER = ORDER.ASCENDING
+ ):
+ super().__init__(
+ cases_dataset, k, search_returns, batch_size
+ )
+
+ assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
+ self.order = order
+ # fill value
+ self.fill_value = np.inf if self.order == ORDER.ASCENDING else -np.inf
+
+ @abstractmethod
+ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> Tuple[tf.Tensor, tf.Tensor]:
+ """
+ Compute the k-nearest neighbors to each tensor of `inputs` in `self.cases_dataset`.
+ Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples on which knn are computed.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Tensor or Array. Target samples to be explained.
+
+ Returns
+ -------
+ best_distances
+ Tensor of distances between the knn and the inputs with dimension (n, k).
+ The n inputs times their k-nearest neighbors.
+ best_indices
+ Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2).
+ Where, n represent the number of inputs and k the number of corresponding examples.
+ The index of each element is encoded by two values,
+ the batch index and the index of the element in the batch.
+ Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`.
+ """
+ raise NotImplementedError
-class KNN(BaseSearchMethod):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ """
+ Search the samples to return as examples. Called by the explain methods.
+ It may also return the indices corresponding to the samples,
+ based on `return_indices` value.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Assumed to have been already projected.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ """
+ # compute neighbors
+ examples_distances, examples_indices = self.kneighbors(inputs, targets)
+
+ # Set values in return dict
+ return_dict = {}
+ if "examples" in self.returns:
+ return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
+ if "include_inputs" in self.returns:
+ inputs = tf.expand_dims(inputs, axis=1)
+ return_dict["examples"] = tf.concat(
+ [inputs, return_dict["examples"]], axis=1
+ )
+ if "indices" in self.returns:
+ return_dict["indices"] = examples_indices
+ if "distances" in self.returns:
+ return_dict["distances"] = examples_distances
+
+ # Return a dict only if different variables are returned
+ if len(return_dict) == 1:
+ return list(return_dict.values())[0]
+ return return_dict
+
+class KNN(BaseKNN):
"""
KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
@@ -36,7 +120,6 @@ class KNN(BaseSearchMethod):
"Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
yielding the corresponding p-norm." We also added 'cosine'.
"""
-
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
@@ -44,9 +127,10 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
+ order: ORDER = ORDER.ASCENDING
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size
+ cases_dataset, k, search_returns, batch_size, order
)
if hasattr(distance, "__call__"):
@@ -54,27 +138,32 @@ def __init__(
elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
distance, int
):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
+ self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance, axis=-1)
else:
raise AttributeError(
"The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
- +f"but {distance} was received.",
+ + " ['fro', 'euclidean', 1, 2, np.inf] "
+ +f"but {type(distance)} was received."
)
- self.distance_fn_over_all_x2 = lambda x1, x2: tf.map_fn(
- fn=lambda x2: self.distance_fn(x1, x2),
- elems=x2,
- )
-
- # Computes crossed distances between two tensors x1(shape=(n1, ...)) and x2(shape=(n2, ...))
- # The result is a distance matrix of size (n1, n2)
- self.crossed_distances_fn = lambda x1, x2: tf.vectorized_map(
- fn=lambda a1: self.distance_fn_over_all_x2(a1, x2),
- elems=x1
- )
+ @tf.function
+ def _crossed_distances_fn(self, x1, x2):
+ n = x1.shape[0]
+ m = x2.shape[0]
+ x2 = tf.expand_dims(x2, axis=0)
+ x2 = tf.repeat(x2, n, axis=0)
+ # reshape for broadcasting
+ x1 = tf.reshape(x1, (n, 1, -1))
+ x2 = tf.reshape(x2, (n, m, -1))
+ def compute_distance(args):
+ a, b = args
+ return self.distance_fn(a, b)
+ args = (x1, x2)
+ # Use vectorized_map to apply compute_distance element-wise
+ distances = tf.vectorized_map(compute_distance, args)
+ return distances
- def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
+ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], _ = None) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Compute the k-neareast neighbors to each tensor of `inputs` in `self.cases_dataset`.
Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
@@ -104,12 +193,13 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, t
# (n, k, 2)
best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1))
# (n, k)
- best_distances = tf.Variable(tf.fill((nb_inputs, self.k), np.inf))
+ best_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value))
# (n, bs)
batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0)
batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1))
# iterate on batches
+ # for batch_index, (cases, cases_targets) in enumerate(zip(self.cases_dataset, self.targets_dataset)):
for batch_index, cases in enumerate(self.cases_dataset):
# add new elements
# (n, current_bs, 2)
@@ -120,7 +210,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, t
# compute distances
# (n, current_bs)
- distances = self.crossed_distances_fn(inputs, cases)
+ distances = self._crossed_distances_fn(inputs, cases)
# (n, k+curent_bs, 2)
concatenated_indices = tf.concat([best_indices, new_indices], axis=1)
@@ -130,7 +220,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, t
# sort all
# (n, k)
sort_order = tf.argsort(
- concatenated_distances, axis=1, direction="ASCENDING"
+ concatenated_distances, axis=1, direction=self.order.name.upper()
)[:, : self.k]
best_indices.assign(
@@ -142,37 +232,162 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, t
return best_distances, best_indices
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
+class FilterKNN(BaseKNN):
+ """
+ KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
+ Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ For natural example-based methods it is the train dataset.
+ k
+ The number of examples to retrieve.
+ search_returns
+ String or list of string with the elements to return in `self.find_examples()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously.
+ It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ distance
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
+ filter_fn
+ A Callable that takes as inputs the inputs, their targets, the cases and their targets and
+ returns a boolean mask of shape (n, m) where n is the number of inputs and m the number of cases.
+ This boolean mask is used to choose between which inputs and cases to compute the distances.
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ filter_fn: Optional[Callable] = None,
+ search_returns: Optional[Union[List[str], str]] = None,
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ order: ORDER = ORDER.ASCENDING
+ ): # pylint: disable=R0801
+ super().__init__(
+ cases_dataset, k, search_returns, batch_size, order
+ )
+
+ if hasattr(distance, "__call__"):
+ self.distance_fn = distance
+ elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
+ distance, int
+ ):
+ self.distance_fn = lambda x1, x2, m: tf.where(m, tf.norm(x1 - x2, ord=distance, axis=-1), self.fill_value)
+ else:
+ raise AttributeError(
+ "The distance parameter is expected to be either a Callable or in"
+ + " ['fro', 'euclidean', 1, 2, np.inf] "
+ +f"but {type(distance)} was received."
+ )
+
+ # set targets_dataset
+ if targets_dataset is not None:
+ batch_size = min(batch_size, len(cases_dataset))
+ cardinality = math.ceil(len(cases_dataset) / batch_size)
+ self.targets_dataset = sanitize_dataset(
+ targets_dataset, batch_size, cardinality
+ )
+ else:
+ self.targets_dataset = [None]*len(cases_dataset)
+
+ # TODO: Assertion on the function signature
+ if filter_fn is None:
+ filter_fn = lambda x, z, y, t: tf.ones((tf.shape(x)[0], tf.shape(z)[0]), dtype=tf.bool)
+ self.filter_fn = filter_fn
+
+ @tf.function
+ def _crossed_distances_fn(self, x1, x2, mask):
+ n = x1.shape[0]
+ m = x2.shape[0]
+ x2 = tf.expand_dims(x2, axis=0)
+ x2 = tf.repeat(x2, n, axis=0)
+ # reshape for broadcasting
+ x1 = tf.reshape(x1, (n, 1, -1))
+ x2 = tf.reshape(x2, (n, m, -1))
+ def compute_distance(args):
+ a, b, mask = args
+ return self.distance_fn(a, b, mask)
+ args = (x1, x2, mask)
+ # Use vectorized_map to apply compute_distance element-wise
+ distances = tf.vectorized_map(compute_distance, args)
+ return distances
+
+ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> Tuple[tf.Tensor, tf.Tensor]:
"""
- Search the samples to return as examples. Called by the explain methods.
- It may also return the indices corresponding to the samples,
- based on `return_indices` value.
+ Compute the k-nearest neighbors to each tensor of `inputs` in `self.cases_dataset`.
+ Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
Parameters
----------
inputs
- Tensor or Array. Input samples to be explained.
- Assumed to have been already projected.
+ Tensor or Array. Input samples on which knn are computed.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+
+ Returns
+ -------
+ best_distances
+ Tensor of distances between the knn and the inputs with dimension (n, k).
+ The n inputs times their k-nearest neighbors.
+ best_indices
+ Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2).
+ Where, n represent the number of inputs and k the number of corresponding examples.
+ The index of each element is encoded by two values,
+ the batch index and the index of the element in the batch.
+ Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`.
"""
- # compute neighbors
- examples_distances, examples_indices = self.kneighbors(inputs)
+ nb_inputs = tf.shape(inputs)[0]
- # Set values in return dict
- return_dict = {}
- if "examples" in self.returns:
- return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
- if "include_inputs" in self.returns:
- inputs = tf.expand_dims(inputs, axis=1)
- return_dict["examples"] = tf.concat(
- [inputs, return_dict["examples"]], axis=1
- )
- if "indices" in self.returns:
- return_dict["indices"] = examples_indices
- if "distances" in self.returns:
- return_dict["distances"] = examples_distances
+ # initialize
+ # (n, k, 2)
+ best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1))
+ # (n, k)
+ best_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value))
+ # (n, bs)
+ batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0)
+ batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1))
- # Return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
- return return_dict
+ # iterate on batches
+ for batch_index, (cases, cases_targets) in enumerate(zip(self.cases_dataset, self.targets_dataset)):
+ # add new elements
+ # (n, current_bs, 2)
+ indices = batch_indices[:, : tf.shape(cases)[0]]
+ new_indices = tf.stack(
+ [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1
+ )
+
+ # get filter masks
+ # (n, current_bs)
+ filter_mask = self.filter_fn(inputs, cases, targets, cases_targets)
+
+ # compute distances
+ # (n, current_bs)
+ distances = self._crossed_distances_fn(inputs, cases, mask=filter_mask)
+
+ # (n, k+current_bs, 2)
+ concatenated_indices = tf.concat([best_indices, new_indices], axis=1)
+ # (n, k+current_bs)
+ concatenated_distances = tf.concat([best_distances, distances], axis=1)
+
+ # sort all
+ # (n, k)
+ sort_order = tf.argsort(
+ concatenated_distances, axis=1, direction=self.order.name.upper()
+ )[:, : self.k]
+
+ best_indices.assign(
+ tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1)
+ )
+ best_distances.assign(
+ tf.gather(concatenated_distances, sort_order, axis=1, batch_dims=1)
+ )
+
+ return best_distances, best_indices
\ No newline at end of file
From 52d1c6a7641379f636569b6bca0d2bce8be4da6a Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 6 Mar 2024 17:57:32 +0100
Subject: [PATCH 030/138] tests: add tests for KNNs
---
tests/example_based/test_knn.py | 505 ++++++++++++++++++++++++++++++++
1 file changed, 505 insertions(+)
create mode 100644 tests/example_based/test_knn.py
diff --git a/tests/example_based/test_knn.py b/tests/example_based/test_knn.py
new file mode 100644
index 00000000..e1d43b08
--- /dev/null
+++ b/tests/example_based/test_knn.py
@@ -0,0 +1,505 @@
+"""
+Test the different search methods.
+"""
+import pytest
+import numpy as np
+import tensorflow as tf
+
+from xplique.example_based.search_methods import BaseKNN, KNN, FilterKNN, ORDER
+
+def get_setup(input_shape, nb_samples=10, nb_labels=10):
+ """
+    Generate toy data for the KNN search tests (no model is built).
+ """
+ # Data generation
+ x_train = tf.stack(
+ [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)]
+ )
+ x_test = x_train[1:-1]
+ y_train = tf.range(len(x_train), dtype=tf.float32) % nb_labels
+
+ return x_train, x_test, y_train
+
+class MockKNN(BaseKNN):
+ """
+ Mock KNN class for testing the find_examples method
+ """
+ def kneighbors(self, inputs, targets):
+ """
+ Define a mock kneighbors method for testing the find_examples method of
+ the base class.
+ """
+ best_distances = tf.random.normal((inputs.shape[0], self.k), dtype=tf.float32)
+ best_indices= tf.random.uniform((inputs.shape[0], self.k, 2), maxval=self.k, dtype=tf.int32)
+ return best_distances, best_indices
+
+def same_target_filter(inputs, cases, targets, cases_targets):
+ """
+ Filter function that returns a boolean mask with true when point-wise inputs and cases
+ have the same target.
+ """
+ # get the labels predicted by the model
+ # (n, )
+ predicted_labels = tf.argmax(targets, axis=-1)
+
+ # for each input, if the target label is the same as the predicted label
+    # the mask has a True value and False otherwise
+ label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
+ mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
+ return mask
+
+def test_base_init():
+ """
+ Test the initialization of the base KNN class (not the super).
+ Check if it raises the relevant errors when the input is invalid.
+ """
+ base_knn = MockKNN(
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ k=2,
+ search_returns='distances',
+ )
+ assert base_knn.order == ORDER.ASCENDING
+ assert base_knn.fill_value == np.inf
+
+ # Test with reverse order
+ order = ORDER.DESCENDING
+ base_knn = MockKNN(
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ k=2,
+ search_returns='distances',
+ order=order
+ )
+ assert base_knn.order == order
+ assert base_knn.fill_value == -np.inf
+
+ # Test with invalid order
+ with pytest.raises(AssertionError):
+ base_knn = MockKNN(
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ k=2,
+ search_returns='distances',
+ order='invalid'
+ )
+
+def test_base_find_examples():
+ """
+ Test the find_examples method of the base KNN class.
+ """
+ returns = ["examples", "indices", "distances"]
+ mock_knn = MockKNN(
+ tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32),
+ k = 2,
+ search_returns = returns,
+ )
+
+ inputs = tf.random.normal((5, 3), dtype=tf.float32)
+ return_dict = mock_knn.find_examples(inputs)
+ assert set(return_dict.keys()) == set(returns)
+ assert return_dict["examples"].shape == (5, 2, 3)
+ assert return_dict["indices"].shape == (5, 2, 2)
+ assert return_dict["distances"].shape == (5, 2)
+
+ returns = ["examples", "include_inputs"]
+ mock_knn = MockKNN(
+ tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32),
+ k = 2,
+ search_returns = returns,
+ )
+ return_dict = mock_knn.find_examples(inputs)
+ assert return_dict.shape == (5, 3, 3)
+
+ mock_knn = MockKNN(
+ tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32),
+ k = 2,
+ )
+ return_dict = mock_knn.find_examples(inputs)
+ assert return_dict.shape == (5, 2, 3)
+
+def test_knn_init():
+ """
+ Test the initialization of the KNN class which are not linked to the super class.
+ """
+ cases_dataset = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32)
+ x1 = tf.random.normal((1, 3), dtype=tf.float32)
+ x2 = tf.random.normal((3, 3), dtype=tf.float32)
+
+ # Test with distances that are compatible with tf.norm
+ distances = ["euclidean", 1, 2, np.inf, 5]
+ for distance in distances:
+ knn = KNN(
+ cases_dataset,
+ k=2,
+ search_returns='distances',
+ distance=distance,
+ )
+ assert tf.reduce_all(tf.equal(knn.distance_fn(x1, x2), tf.norm(x1 - x2, ord=distance, axis=-1)))
+
+ # Test with a custom distance function
+ def custom_distance(x1, x2):
+ return tf.reduce_sum(tf.abs(x1 - x2), axis=-1)
+ knn = KNN(
+ cases_dataset,
+ k=2,
+ search_returns='distances',
+ distance=custom_distance,
+ )
+ assert tf.reduce_all(tf.equal(knn.distance_fn(x1, x2), custom_distance(x1, x2)))
+
+ # Test with invalid distance
+ invalid_distances = [None, "invalid", 0.5]
+ for distance in invalid_distances:
+ with pytest.raises(AttributeError):
+ knn = KNN(
+ cases_dataset,
+ k=2,
+ search_returns='distances',
+ distance=distance,
+ )
+
+def test_knn_compute_distances():
+ """
+ Test the private method _compute_distances_fn of the KNN class.
+ """
+ # Test with input and cases being 1D
+ knn = KNN(
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ k=2,
+ distance='euclidean',
+ order=ORDER.ASCENDING
+ )
+ x1 = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
+ x2 = tf.constant([[7.0, 8.0], [9.0, 10.0]], dtype=tf.float32)
+
+ expected_distance = tf.constant(
+ [
+ [np.sqrt(72), np.sqrt(128)],
+ [np.sqrt(32), np.sqrt(72)],
+ [np.sqrt(8), np.sqrt(32)]
+ ], dtype=tf.float32
+ )
+
+ distances = knn._crossed_distances_fn(x1, x2)
+ assert distances.shape == (x1.shape[0], x2.shape[0])
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+ # Test with higher dimensions
+ data = np.array([
+ [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
+ ])
+
+ knn = KNN(
+ data,
+ k=2,
+ distance="euclidean",
+ order=ORDER.ASCENDING
+ )
+
+ x1 = tf.constant(
+ [
+ [[1, 2, 3],[4, 5, 6],[7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
+ ], dtype=tf.float32
+ )
+
+ x2 = tf.constant(
+ [
+ [[28, 29, 30], [31, 32, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ ], dtype=tf.float32
+ )
+
+ expected_distance = tf.constant(
+ [[np.sqrt(9)*27, np.sqrt(9)*36],
+ [np.sqrt(9)*18, np.sqrt(9)*27],
+ [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32)
+
+ distances = knn._crossed_distances_fn(x1, x2)
+ assert distances.shape == (x1.shape[0], x2.shape[0])
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+
+def test_knn_kneighbors():
+ """
+ Test the kneighbors method of the KNN class.
+ """
+ # Test with input and cases being 1D
+ cases = tf.constant([[1.], [2.], [3.], [4.], [5.]], dtype=tf.float32)
+ inputs = tf.constant([[1.5], [2.5], [4.5]], dtype=tf.float32)
+ knn = KNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+
+ # Test with reverse order
+ knn = KNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+
+ # Test with input and cases being 2D
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ knn = KNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+
+ # Test with reverse order
+ knn = KNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+
+def test_filter_knn_compute_distances():
+ """
+ Test the private method _compute_distances_fn of the FilterKNN class.
+ """
+ # Test in Low dimension
+ knn = FilterKNN(
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ k=2,
+ distance='euclidean',
+ order=ORDER.ASCENDING
+ )
+ x1 = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
+ x2 = tf.constant([[7.0, 8.0], [9.0, 10.0]], dtype=tf.float32)
+ expected_distance = tf.constant(
+ [
+ [np.sqrt(72), np.sqrt(128)],
+ [np.sqrt(32), np.sqrt(72)],
+ [np.sqrt(8), np.sqrt(32)]
+ ], dtype=tf.float32
+ )
+ mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool)
+ distances = knn._crossed_distances_fn(x1, x2, mask)
+ assert distances.shape == (x1.shape[0], x2.shape[0])
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+ mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool)
+ expected_distance = tf.constant([[np.sqrt(72), np.inf], [np.inf, np.sqrt(72)], [np.sqrt(8), np.sqrt(32)]], dtype=tf.float32)
+ distances = knn._crossed_distances_fn(x1, x2, mask)
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+ # Test with higher dimensions
+ data = np.array([
+ [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]]
+ ])
+
+ knn = FilterKNN(
+ data,
+ k=2,
+ distance="euclidean",
+ order=ORDER.ASCENDING
+ )
+
+ x1 = tf.constant(
+ [
+ [[1, 2, 3],[4, 5, 6],[7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
+ ], dtype=tf.float32
+ )
+
+ x2 = tf.constant(
+ [
+ [[28, 29, 30], [31, 32, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ ], dtype=tf.float32
+ )
+
+ expected_distance = tf.constant(
+ [[np.sqrt(9)*27, np.sqrt(9)*36],
+ [np.sqrt(9)*18, np.sqrt(9)*27],
+ [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32)
+
+ mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool)
+ distances = knn._crossed_distances_fn(x1, x2, mask)
+ assert distances.shape == (x1.shape[0], x2.shape[0])
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+ mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool)
+ expected_distance = tf.constant([[np.sqrt(9)*27, np.inf], [np.inf, np.sqrt(9)*27], [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32)
+ distances = knn._crossed_distances_fn(x1, x2, mask)
+ assert distances.shape == (x1.shape[0], x2.shape[0])
+ assert tf.reduce_all(tf.equal(distances, expected_distance))
+
+def test_filter_knn_kneighbors():
+ """
+ """
+ # Test with input and cases being 1D
+ cases = tf.constant([[1.], [2.], [3.], [4.], [5.]], dtype=tf.float32)
+ inputs = tf.constant([[1.5], [2.5], [4.5]], dtype=tf.float32)
+ ## default filter and default order
+ knn = KNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ ## add a filter that is not the default
+ knn = FilterKNN(
+ cases,
+ targets_dataset=cases_targets,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ filter_fn=same_target_filter
+ )
+ mask = same_target_filter(inputs, cases, targets, cases_targets)
+ print(mask)
+ distances, indices = knn.kneighbors(inputs, targets)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 2.5], [0.5, 0.5], [0.5, 1.5]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)))
+
+ ## test with reverse order
+ knn = FilterKNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING
+ )
+
+ distances, indices = knn.kneighbors(inputs, targets)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ expected_distances = tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(distances, expected_distances))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+
+ ## add a filter that is not the default one and reverse order
+ knn = FilterKNN(
+ cases,
+ targets_dataset=cases_targets,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING,
+ filter_fn=same_target_filter
+ )
+
+ distances, indices = knn.kneighbors(inputs, targets)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[2.5, 0.5], [2.5, 0.5], [2.5, 1.5]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)))
+
+ # Test with input and cases being 2D
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ ## default filter and default order
+ knn = FilterKNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ assert tf.reduce_all(tf.equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)))
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+ ## add a filter that is not the default
+ knn = FilterKNN(
+ cases,
+ targets_dataset=cases_targets,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ filter_fn=same_target_filter
+ )
+
+ distances, indices = knn.kneighbors(inputs, targets)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ expected_distances = tf.constant([[np.sqrt(0.5), np.sqrt(2*2.5**2)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(2*1.5**2)],], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)))
+
+ ## test with reverse order and default filter
+ knn = FilterKNN(
+ cases,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING
+ )
+
+ distances, indices = knn.kneighbors(inputs)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+
+ ## add a filter that is not the default one and reverse order
+ knn = FilterKNN(
+ cases,
+ targets_dataset=cases_targets,
+ k=2,
+ batch_size=2,
+ distance="euclidean",
+ order=ORDER.DESCENDING,
+ filter_fn=same_target_filter
+ )
+
+ distances, indices = knn.kneighbors(inputs, targets)
+ assert distances.shape == (3, 2)
+ assert indices.shape == (3, 2, 2)
+ expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+ assert tf.reduce_all(tf.equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)))
From 939bbc8bf02a4aea5d9c1bed38bc9fc983bb2f40 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Thu, 7 Mar 2024 15:23:49 +0100
Subject: [PATCH 031/138] feat: develop the naive semi factual, update search
methods for integration
---
tests/example_based/test_contrastive.py | 53 ++++++++++++++
tests/example_based/test_knn.py | 4 +-
xplique/example_based/__init__.py | 1 +
xplique/example_based/base_example_method.py | 3 +
xplique/example_based/contrastive_examples.py | 70 +++++++++++++++++++
xplique/example_based/search_methods/base.py | 8 +++
xplique/example_based/search_methods/knn.py | 25 +++----
7 files changed, 145 insertions(+), 19 deletions(-)
create mode 100644 tests/example_based/test_contrastive.py
create mode 100644 xplique/example_based/contrastive_examples.py
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
new file mode 100644
index 00000000..40204492
--- /dev/null
+++ b/tests/example_based/test_contrastive.py
@@ -0,0 +1,53 @@
+"""
+Tests for the contrastive methods.
+"""
+import tensorflow as tf
+import numpy as np
+
+from xplique.example_based import NaiveSemiFactuals
+from xplique.example_based.projections import Projection
+
+def test_naive_semi_factuals():
+ """
+ """
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+ semi_factuals = NaiveSemiFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
+
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ mask = semi_factuals.filter_fn(inputs, cases, targets, cases_targets)
+ assert mask.shape == (inputs.shape[0], cases.shape[0])
+
+ expected_mask = tf.constant([
+ [True, False, False, True, False],
+ [False, True, True, False, True],
+ [False, True, True, False, True]], dtype=tf.bool)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
+
+ return_dict = semi_factuals(inputs, targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+ indices = return_dict["indices"]
+
+ assert examples.shape == (3, 3, 2) # (n, k+1, W)
+ assert distances.shape == (3, 2) # (n, k)
+ assert indices.shape == (3, 2, 2) # (n, k, 2)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [4., 5.], [1., 2.]],
+ [[2.5, 3.5], [5., 6.], [2., 3.]],
+ [[4.5, 5.5], [2., 3.], [3., 4.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+ expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+ expected_indices = tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(indices, expected_indices))
diff --git a/tests/example_based/test_knn.py b/tests/example_based/test_knn.py
index e1d43b08..63d4d504 100644
--- a/tests/example_based/test_knn.py
+++ b/tests/example_based/test_knn.py
@@ -368,7 +368,7 @@ def test_filter_knn_kneighbors():
cases = tf.constant([[1.], [2.], [3.], [4.], [5.]], dtype=tf.float32)
inputs = tf.constant([[1.5], [2.5], [4.5]], dtype=tf.float32)
## default filter and default order
- knn = KNN(
+ knn = FilterKNN(
cases,
k=2,
batch_size=2,
@@ -393,8 +393,6 @@ def test_filter_knn_kneighbors():
distance="euclidean",
filter_fn=same_target_filter
)
- mask = same_target_filter(inputs, cases, targets, cases_targets)
- print(mask)
distances, indices = knn.kneighbors(inputs, targets)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index a958a62b..8f0ea443 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,3 +4,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
+from .contrastive_examples import NaiveSemiFactuals
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 9e3facf2..df8ac306 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -122,6 +122,7 @@ def __init__(
cases_dataset=projected_cases_dataset,
k=k,
batch_size=batch_size,
+ targets_dataset=self.targets_dataset,
**search_method_kwargs,
)
@@ -381,6 +382,8 @@ def format_search_output(
# )
# add indices, distances, and labels
+ if "indices" in self.returns:
+ return_dict["indices"] = search_output["indices"]
if "distances" in self.returns:
return_dict["distances"] = search_output["distances"]
if "labels" in self.returns:
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
new file mode 100644
index 00000000..365e8355
--- /dev/null
+++ b/xplique/example_based/contrastive_examples.py
@@ -0,0 +1,70 @@
+"""
+Implementation of both counterfactuals and semi factuals methods for classification tasks.
+"""
+import numpy as np
+import tensorflow as tf
+
+from ..types import Callable, List, Optional, Union
+
+from .base_example_method import BaseExampleMethod
+from .search_methods import BaseSearchMethod, KNN, ORDER, FilterKNN
+from .projections import Projection
+
+class NaiveSemiFactuals(BaseExampleMethod):
+ """
+ Define a naive version of semi factuals search. That for a given sample
+ it will return the farthest sample which have the same label.
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+ search_method = FilterKNN
+
+ if projection is None:
+ projection = Projection(space_projection=lambda inputs: inputs)
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ filter_fn=self.filter_fn,
+ order = ORDER.DESCENDING
+ )
+
+
+ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
+ """
+ Filter function to mask the cases for which the label is different from the predicted
+ label on the inputs.
+ """
+ # get the labels predicted by the model
+ # (n, )
+ predicted_labels = tf.argmax(targets, axis=-1)
+
+ # for each input, if the target label is the same as the predicted label
+        # the mask has a True value and False otherwise
+ label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
+ mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
+ return mask
+
+class PredictedLabelAwareSemiFactuals():
+ def __init__(self) -> None:
+ raise NotImplementedError
+
+class NaiveCounterFactuals(BaseExampleMethod):
+ def __init__():
+ raise NotImplementedError
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 303575c3..a165688d 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -91,6 +91,7 @@ def __init__(
k: int = 1,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
): # pylint: disable=R0801
# set batch size
@@ -104,6 +105,13 @@ def __init__(
self.set_k(k)
self.set_returns(search_returns)
+ # set targets_dataset
+ if targets_dataset is not None:
+ self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size)
+ else:
+ # make an iterable of None
+ self.targets_dataset = [None]*len(cases_dataset)
+
def set_k(self, k: int):
"""
Change value of k with constructing a new `BaseSearchMethod`.
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index 9b0f228b..5291999a 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -21,10 +21,11 @@ def __init__(
k: int = 1,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
- order: ORDER = ORDER.ASCENDING
+ order: ORDER = ORDER.ASCENDING,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
):
super().__init__(
- cases_dataset, k, search_returns, batch_size
+ cases_dataset, k, search_returns, batch_size, targets_dataset
)
assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
@@ -127,10 +128,11 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
- order: ORDER = ORDER.ASCENDING
+ order: ORDER = ORDER.ASCENDING,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset
)
if hasattr(distance, "__call__"):
@@ -272,7 +274,7 @@ def __init__(
order: ORDER = ORDER.ASCENDING
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset
)
if hasattr(distance, "__call__"):
@@ -288,16 +290,6 @@ def __init__(
+f"but {type(distance)} was received."
)
- # set targets_dataset
- if targets_dataset is not None:
- batch_size = min(batch_size, len(cases_dataset))
- cardinality = math.ceil(len(cases_dataset) / batch_size)
- self.targets_dataset = sanitize_dataset(
- targets_dataset, batch_size, cardinality
- )
- else:
- self.targets_dataset = [None]*len(cases_dataset)
-
# TODO: Assertion on the function signature
if filter_fn is None:
filter_fn = lambda x, z, y, t: tf.ones((tf.shape(x)[0], tf.shape(z)[0]), dtype=tf.bool)
@@ -390,4 +382,5 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Uni
tf.gather(concatenated_distances, sort_order, axis=1, batch_dims=1)
)
- return best_distances, best_indices
\ No newline at end of file
+ return best_distances, best_indices
+
\ No newline at end of file
From 8d9f234f590dbd20fc917dcff77ba07173a3d66e Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Thu, 7 Mar 2024 17:23:13 +0100
Subject: [PATCH 032/138] feat: add a semi factuals method that is dedicated to
one specific label
---
tests/example_based/test_contrastive.py | 118 +++++++++++++++++-
xplique/example_based/__init__.py | 2 +-
xplique/example_based/contrastive_examples.py | 62 ++++++++-
3 files changed, 176 insertions(+), 6 deletions(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index 40204492..c87fd67b 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -1,11 +1,12 @@
"""
Tests for the contrastive methods.
"""
+import pytest
+
import tensorflow as tf
import numpy as np
-from xplique.example_based import NaiveSemiFactuals
-from xplique.example_based.projections import Projection
+from xplique.example_based import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals
def test_naive_semi_factuals():
"""
@@ -51,3 +52,116 @@ def test_naive_semi_factuals():
expected_indices = tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)
assert tf.reduce_all(tf.equal(indices, expected_indices))
+
+def test_labelaware_semifactuals():
+ """
+ """
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ semi_factuals = PredictedLabelAwareSemiFactuals(cases_dataset, cases_targets_dataset, target_label=0, k=2, batch_size=2, case_returns=["examples", "distances", "include_inputs"])
+ # assert the filtering on the right label went right
+
+ combined_dataset = tf.data.Dataset.zip((cases_dataset.unbatch(), cases_targets_dataset.unbatch()))
+ for elem, label in combined_dataset:
+ print(f"elem: {elem}, label: {label}")
+ print(f"lambda_fn: {tf.equal(tf.argmax(label, axis=-1),0)}")
+ combined_dataset = combined_dataset.filter(lambda x, y: tf.equal(tf.argmax(y, axis=-1),0))
+
+ for elem, label in combined_dataset:
+ print(f"elem: {elem}, label: {label}")
+
+ filter_cases = semi_factuals.cases_dataset
+ filter_targets = semi_factuals.targets_dataset
+
+ # for elem in filter_cases:
+ # print(elem)
+ # for elem in filter_targets:
+ # print(elem)
+
+ expected_filter_cases = tf.constant([[2., 3.], [3., 4.], [5., 6.]], dtype=tf.float32)
+ expected_filter_targets = tf.constant([[1, 0], [1, 0], [1, 0]], dtype=tf.float32)
+
+ tensor_filter_cases = []
+ for elem in filter_cases.unbatch():
+ tensor_filter_cases.append(elem)
+ tensor_filter_cases = tf.stack(tensor_filter_cases)
+ assert tf.reduce_all(tf.equal(tensor_filter_cases, expected_filter_cases))
+
+ tensor_filter_targets = []
+ for elem in filter_targets.unbatch():
+ tensor_filter_targets.append(elem)
+ tensor_filter_targets = tf.stack(tensor_filter_targets)
+ assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
+
+ # check the call method
+ filter_inputs = tf.constant([[2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ filter_targets = tf.constant([[1, 0], [1, 0]], dtype=tf.float32)
+
+ return_dict = semi_factuals(filter_inputs, filter_targets)
+ assert set(return_dict.keys()) == set(["examples", "distances"])
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+
+ assert examples.shape == (2, 3, 2) # (n_label0, k+1, W)
+ assert distances.shape == (2, 2) # (n_label0, k)
+
+ expected_examples = tf.constant([
+ [[2.5, 3.5], [5., 6.], [2., 3.]],
+ [[4.5, 5.5], [2., 3.], [3., 4.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+ expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+
+ # check an error is raised when a target does not match the target label
+ with pytest.raises(AssertionError):
+ semi_factuals(inputs, targets)
+
+ # same but with the other label
+ semi_factuals = PredictedLabelAwareSemiFactuals(cases_dataset, cases_targets_dataset, target_label=1, k=2, batch_size=2, case_returns=["examples", "distances", "include_inputs"])
+ filter_cases = semi_factuals.cases_dataset
+ filter_targets = semi_factuals.targets_dataset
+
+ expected_filter_cases = tf.constant([[1., 2.], [4., 5.]], dtype=tf.float32)
+ expected_filter_targets = tf.constant([[0, 1], [0, 1]], dtype=tf.float32)
+
+ tensor_filter_cases = []
+ for elem in filter_cases.unbatch():
+ tensor_filter_cases.append(elem)
+ tensor_filter_cases = tf.stack(tensor_filter_cases)
+ assert tf.reduce_all(tf.equal(tensor_filter_cases, expected_filter_cases))
+
+ tensor_filter_targets = []
+ for elem in filter_targets.unbatch():
+ tensor_filter_targets.append(elem)
+ tensor_filter_targets = tf.stack(tensor_filter_targets)
+ assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
+
+ # check the call method
+ filter_inputs = tf.constant([[1.5, 2.5]], dtype=tf.float32)
+ filter_targets = tf.constant([[0, 1]], dtype=tf.float32)
+
+ return_dict = semi_factuals(filter_inputs, filter_targets)
+ assert set(return_dict.keys()) == set(["examples", "distances"])
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+
+ assert examples.shape == (1, 3, 2) # (n_label1, k+1, W)
+ assert distances.shape == (1, 2) # (n_label1, k)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [4., 5.], [1., 2.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+ expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 8f0ea443..75238093 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,4 +4,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
-from .contrastive_examples import NaiveSemiFactuals
+from .contrastive_examples import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 365e8355..c3383e92 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -61,9 +61,65 @@ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
return mask
-class PredictedLabelAwareSemiFactuals():
- def __init__(self) -> None:
- raise NotImplementedError
+class PredictedLabelAwareSemiFactuals(BaseExampleMethod):
+ """
+ As we know semi-factuals should belong to the same class as the input,
+ we propose here a method that is dedicated to a specific label.
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ target_label: int,
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+ # filter the cases dataset and targets dataset to keep only the ones
+ # that have the target label
+ # TODO: improve this unbatch and batch
+ combined_dataset = tf.data.Dataset.zip((cases_dataset.unbatch(), targets_dataset.unbatch()))
+ combined_dataset = combined_dataset.filter(lambda x, y: tf.equal(tf.argmax(y, axis=-1),target_label))
+
+ # separate the cases and targets
+ cases_dataset = combined_dataset.map(lambda x, y: x).batch(batch_size)
+ targets_dataset = combined_dataset.map(lambda x, y: y).batch(batch_size)
+
+ # delete the combined dataset
+ del combined_dataset
+
+ if projection is None:
+ projection = Projection(space_projection=lambda inputs: inputs)
+
+ search_method = KNN
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ order = ORDER.DESCENDING
+ )
+
+ self.target_label = target_label
+
+ def __call__(
+ self,
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ # assert targets are all the same as the target label
+ if targets is not None:
+ assert tf.reduce_all(tf.argmax(targets, axis=-1) == self.target_label), "All targets should be the same as the target label."
+ return super().__call__(inputs, targets)
class NaiveCounterFactuals(BaseExampleMethod):
def __init__():
From f5f05a3df6b92d73e2b810e129f7d9d5667c0090 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 13 Mar 2024 11:28:29 +0100
Subject: [PATCH 033/138] feat: add a naive counter factuals methods and its
test
---
tests/example_based/test_contrastive.py | 65 ++++++++++++++-----
xplique/example_based/__init__.py | 2 +-
xplique/example_based/contrastive_examples.py | 50 +++++++++++++-
3 files changed, 98 insertions(+), 19 deletions(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index c87fd67b..bac1aaa2 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -6,7 +6,7 @@
import tensorflow as tf
import numpy as np
-from xplique.example_based import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals
+from xplique.example_based import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals, NaiveCounterFactuals
def test_naive_semi_factuals():
"""
@@ -69,22 +69,11 @@ def test_labelaware_semifactuals():
# assert the filtering on the right label went right
combined_dataset = tf.data.Dataset.zip((cases_dataset.unbatch(), cases_targets_dataset.unbatch()))
- for elem, label in combined_dataset:
- print(f"elem: {elem}, label: {label}")
- print(f"lambda_fn: {tf.equal(tf.argmax(label, axis=-1),0)}")
combined_dataset = combined_dataset.filter(lambda x, y: tf.equal(tf.argmax(y, axis=-1),0))
- for elem, label in combined_dataset:
- print(f"elem: {elem}, label: {label}")
-
filter_cases = semi_factuals.cases_dataset
filter_targets = semi_factuals.targets_dataset
- # for elem in filter_cases:
- # print(elem)
- # for elem in filter_targets:
- # print(elem)
-
expected_filter_cases = tf.constant([[2., 3.], [3., 4.], [5., 6.]], dtype=tf.float32)
expected_filter_targets = tf.constant([[1, 0], [1, 0], [1, 0]], dtype=tf.float32)
@@ -99,7 +88,7 @@ def test_labelaware_semifactuals():
tensor_filter_targets.append(elem)
tensor_filter_targets = tf.stack(tensor_filter_targets)
assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
-
+
# check the call method
filter_inputs = tf.constant([[2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
filter_targets = tf.constant([[1, 0], [1, 0]], dtype=tf.float32)
@@ -121,11 +110,10 @@ def test_labelaware_semifactuals():
expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
-
# check an error is raised when a target does not match the target label
with pytest.raises(AssertionError):
semi_factuals(inputs, targets)
-
+
# same but with the other label
semi_factuals = PredictedLabelAwareSemiFactuals(cases_dataset, cases_targets_dataset, target_label=1, k=2, batch_size=2, case_returns=["examples", "distances", "include_inputs"])
filter_cases = semi_factuals.cases_dataset
@@ -145,7 +133,7 @@ def test_labelaware_semifactuals():
tensor_filter_targets.append(elem)
tensor_filter_targets = tf.stack(tensor_filter_targets)
assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
-
+
# check the call method
filter_inputs = tf.constant([[1.5, 2.5]], dtype=tf.float32)
filter_targets = tf.constant([[0, 1]], dtype=tf.float32)
@@ -165,3 +153,48 @@ def test_labelaware_semifactuals():
expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+def test_naive_counter_factuals():
+    """Test NaiveCounterFactuals on a small toy dataset: check the filter
+    mask, returned examples, distances and indices."""
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+ counter_factuals = NaiveCounterFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
+
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ mask = counter_factuals.filter_fn(inputs, cases, targets, cases_targets)
+ assert mask.shape == (inputs.shape[0], cases.shape[0])
+
+ expected_mask = tf.constant([
+ [False, True, True, False, True],
+ [True, False, False, True, False],
+ [True, False, False, True, False]], dtype=tf.bool)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
+
+ return_dict = counter_factuals(inputs, targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+ indices = return_dict["indices"]
+
+ assert examples.shape == (3, 3, 2) # (n, k+1, W)
+ assert distances.shape == (3, 2) # (n, k)
+ assert indices.shape == (3, 2, 2) # (n, k, 2)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [2., 3.], [3., 4.]],
+ [[2.5, 3.5], [1., 2.], [4., 5.]],
+ [[4.5, 5.5], [4., 5.], [1., 2.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+ expected_distances = tf.constant([[np.sqrt(2*0.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*1.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2), np.sqrt(2*3.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+ expected_indices = tf.constant([[[0, 1], [1, 0]],[[0, 0], [1, 1]],[[1, 1], [0, 0]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(indices, expected_indices))
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 75238093..89e08d1b 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,4 +4,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
-from .contrastive_examples import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals
+from .contrastive_examples import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals, NaiveCounterFactuals
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index c3383e92..83b03b11 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -122,5 +122,51 @@ def __call__(
return super().__call__(inputs, targets)
class NaiveCounterFactuals(BaseExampleMethod):
- def __init__():
- raise NotImplementedError
+    """
+    Naive method that searches counterfactuals among the cases whose label differs from the predicted label of the input.
+    """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+ search_method = FilterKNN
+
+ if projection is None:
+ projection = Projection(space_projection=lambda inputs: inputs)
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ filter_fn=self.filter_fn,
+ order = ORDER.ASCENDING
+ )
+
+
+ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
+ """
+ Filter function to mask the cases for which the label is different from the predicted
+ label on the inputs.
+ """
+ # get the labels predicted by the model
+ # (n, )
+ predicted_labels = tf.argmax(targets, axis=-1)
+
+        # for each input, the mask has a True value where the case label
+        # differs from the predicted label, and False otherwise
+ label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
+ mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
+ return mask
From 7b32df9830e99d6390f0e40f72f04f1944faa98d Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Thu, 21 Mar 2024 12:58:09 +0100
Subject: [PATCH 034/138] add Prototypes
---
tests/example_based/test_prototypes.py | 248 +++++++++++
tests/utils.py | 49 +++
xplique/example_based/__init__.py | 1 +
xplique/example_based/base_example_method.py | 1 +
xplique/example_based/prototypes.py | 114 ++++++
.../example_based/search_methods/__init__.py | 4 +
xplique/example_based/search_methods/base.py | 2 +
.../search_methods/mmd_critic_search.py | 98 +++++
.../search_methods/proto_dash_search.py | 244 +++++++++++
.../search_methods/proto_greedy_search.py | 385 ++++++++++++++++++
.../search_methods/prototypes_search.py | 137 +++++++
11 files changed, 1283 insertions(+)
create mode 100644 tests/example_based/test_prototypes.py
create mode 100644 xplique/example_based/prototypes.py
create mode 100644 xplique/example_based/search_methods/mmd_critic_search.py
create mode 100644 xplique/example_based/search_methods/proto_dash_search.py
create mode 100644 xplique/example_based/search_methods/proto_greedy_search.py
create mode 100644 xplique/example_based/search_methods/prototypes_search.py
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
new file mode 100644
index 00000000..39ba2425
--- /dev/null
+++ b/tests/example_based/test_prototypes.py
@@ -0,0 +1,248 @@
+"""
+Test Prototypes
+"""
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+from math import prod, sqrt
+import unittest
+import time
+
+import numpy as np
+from sklearn.metrics.pairwise import rbf_kernel
+import tensorflow as tf
+
+from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
+from xplique.types import Union
+
+from xplique.example_based import Prototypes
+from xplique.example_based.projections import Projection, LatentSpaceProjection
+from xplique.example_based.search_methods import ProtoGreedySearch, ProtoDashSearch, MMDCriticSearch
+
+from tests.utils import almost_equal, get_Gaussian_Data, load_data, plot
+
+
+def test_proto_greedy_basic():
+ """
+    Test Prototypes with ProtoGreedySearch and an identity projection.
+ """
+ # Setup
+ k = 3
+ nb_prototypes = 3
+ gamma = 0.026
+ x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ # x_train, y_train = load_data('usps')
+ # x_test, y_test = load_data('usps.t')
+ # x_test = tf.random.shuffle(x_test)
+ # x_test = x_test[0:8]
+
+ identity_projection = Projection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
+
+ def custom_kernel_wrapper(gamma):
+ def custom_kernel(x,y=None):
+ return rbf_kernel(x,y,gamma)
+ return custom_kernel
+
+ kernel_fn = custom_kernel_wrapper(gamma)
+
+ kernel_type = "global"
+
+ # Method initialization
+ method = Prototypes(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ search_method=ProtoGreedySearch,
+ k=k,
+ projection=identity_projection,
+ batch_size=32,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ kernel_fn=kernel_fn,
+ )
+
+ # Generate explanation
+ prototype_indices, prototype_weights = method.get_prototypes()
+
+ prototypes = tf.gather(x_train, prototype_indices)
+ prototype_labels = tf.gather(y_train, prototype_indices)
+
+ # sort by label
+ prototype_labels_sorted = prototype_labels.numpy().argsort()
+
+ prototypes = tf.gather(prototypes, prototype_labels_sorted)
+ prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
+ prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
+ prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
+
+ # Verifications
+ # Shape
+ assert prototype_indices.shape == (nb_prototypes,)
+ assert prototypes.shape == (nb_prototypes, x_train.shape[1])
+ assert prototype_weights.shape == (nb_prototypes,)
+
+ # at least 1 prototype per class is selected
+ assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
+
+ # uniqueness test of prototypes
+ assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
+
+ # Check if all indices are between 0 and x_train.shape[0]-1
+ assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+
+ # # Visualize all prototypes
+ # plot(prototypes, prototype_weights, 'proto_greedy')
+
+def test_proto_dash_basic():
+ """
+    Test Prototypes with ProtoDashSearch and an identity projection.
+ """
+ # Setup
+ k = 3
+ nb_prototypes = 3
+ gamma = 0.026
+ x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ # x_train, y_train = load_data('usps')
+ # x_test, y_test = load_data('usps.t')
+ # x_test = tf.random.shuffle(x_test)
+ # x_test = x_test[0:8]
+
+ identity_projection = Projection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
+
+ def custom_kernel_wrapper(gamma):
+ def custom_kernel(x,y=None):
+ return rbf_kernel(x,y,gamma)
+ return custom_kernel
+
+ kernel_fn = custom_kernel_wrapper(gamma)
+
+ kernel_type = "global"
+
+ # Method initialization
+ method = Prototypes(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ search_method=ProtoDashSearch,
+ k=k,
+ projection=identity_projection,
+ batch_size=32,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ kernel_fn=kernel_fn,
+ )
+
+ # Generate explanation
+ prototype_indices, prototype_weights = method.get_prototypes()
+
+ prototypes = tf.gather(x_train, prototype_indices)
+ prototype_labels = tf.gather(y_train, prototype_indices)
+
+ # sort by label
+ prototype_labels_sorted = prototype_labels.numpy().argsort()
+
+ prototypes = tf.gather(prototypes, prototype_labels_sorted)
+ prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
+ prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
+ prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
+
+ # Verifications
+ # Shape
+ assert prototype_indices.shape == (nb_prototypes,)
+ assert prototypes.shape == (nb_prototypes, x_train.shape[1])
+ assert prototype_weights.shape == (nb_prototypes,)
+
+ # at least 1 prototype per class is selected
+ assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
+
+ # uniqueness test of prototypes
+ assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
+
+ # Check if all indices are between 0 and x_train.shape[0]-1
+ assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+
+ # # Visualize all prototypes
+ # plot(prototypes, prototype_weights, 'proto_dash')
+
+def test_mmd_critic_basic():
+ """
+    Test Prototypes with MMDCriticSearch and an identity projection.
+ """
+ # Setup
+ k = 3
+ nb_prototypes = 3
+ gamma = 0.026
+ x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ # x_train, y_train = load_data('usps')
+ # x_test, y_test = load_data('usps.t')
+ # x_test = tf.random.shuffle(x_test)
+ # x_test = x_test[0:8]
+
+ identity_projection = Projection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
+
+ def custom_kernel_wrapper(gamma):
+ def custom_kernel(x,y=None):
+ return rbf_kernel(x,y,gamma)
+ return custom_kernel
+
+ kernel_fn = custom_kernel_wrapper(gamma)
+
+ kernel_type = "global"
+
+ # Method initialization
+ method = Prototypes(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ search_method=MMDCriticSearch,
+ k=k,
+ projection=identity_projection,
+ batch_size=32,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ kernel_fn=kernel_fn,
+ )
+
+ # Generate explanation
+ prototype_indices, prototype_weights = method.get_prototypes()
+
+ prototypes = tf.gather(x_train, prototype_indices)
+ prototype_labels = tf.gather(y_train, prototype_indices)
+
+ # sort by label
+ prototype_labels_sorted = prototype_labels.numpy().argsort()
+
+ prototypes = tf.gather(prototypes, prototype_labels_sorted)
+ prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
+ prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
+ prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
+
+ # Verifications
+ # Shape
+ assert prototype_indices.shape == (nb_prototypes,)
+ assert prototypes.shape == (nb_prototypes, x_train.shape[1])
+ assert prototype_weights.shape == (nb_prototypes,)
+
+ # at least 1 prototype per class is selected
+ assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
+
+ # uniqueness test of prototypes
+ assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
+
+ # Check if all indices are between 0 and x_train.shape[0]-1
+ assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+
+ # # Visualize all prototypes
+ # plot(prototypes, prototype_weights, 'mmd_critic')
+
+test_proto_greedy_basic()
+# test_proto_dash_basic()
+# test_mmd_critic_basic()
diff --git a/tests/utils.py b/tests/utils.py
index 92d348e2..4219cd1a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,7 +1,11 @@
import signal, time
import numpy as np
+import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
+from sklearn.datasets import load_svmlight_file
+from pathlib import Path
+from math import ceil
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (Dense, Conv1D, Conv2D, Activation, GlobalAveragePooling1D,
@@ -250,3 +254,48 @@ def download_file(identifier: str,
for chunk in response.iter_content(chunk_size):
if chunk:
file.write(chunk)
+
+def get_Gaussian_Data(nb_samples_class=20):
+ tf.random.set_seed(42)
+
+ sigma = 0.05
+ mu = [10, 20, 30]
+
+ X = tf.concat([tf.random.normal(shape=(nb_samples_class,1), mean=mu[i], stddev=sigma, dtype=tf.float32) for i in range(3)], axis=0)
+ y = tf.concat([tf.ones(shape=(nb_samples_class), dtype=tf.int32) * i for i in range(3)], axis=0)
+
+ return(X, y)
+
+def load_data(fname):
+ data_dir = Path('/home/mohamed-chafik.bakey/MMD-critic/data')
+ X, y = load_svmlight_file(str(data_dir / fname))
+ X = tf.constant(X.todense(), dtype=tf.float32)
+ y = tf.constant(np.array(y), dtype=tf.int64)
+ sort_indices = y.numpy().argsort()
+ X = tf.gather(X, sort_indices, axis=0)
+ y = tf.gather(y, sort_indices)
+ y -= 1
+ return X, y
+
+def plot(prototypes_sorted, prototype_weights_sorted, extension):
+
+ output_dir = Path('tests/example_based/tmp')
+ k = prototypes_sorted.shape[0]
+
+ # Visualize all prototypes
+ num_cols = 8
+ num_rows = ceil(k / num_cols)
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(6, num_rows * 1.25))
+ if prototype_weights_sorted is not None:
+ # Adjust the spacing between lines
+ plt.subplots_adjust(hspace=1)
+ for i, axis in enumerate(axes.ravel()):
+ if i >= k:
+ axis.axis('off')
+ continue
+ axis.imshow(prototypes_sorted[i].numpy().reshape(16, 16), cmap='gray')
+ if prototype_weights_sorted is not None:
+ axis.set_title("{:.2f}".format(prototype_weights_sorted[i].numpy()))
+ axis.axis('off')
+ # fig.suptitle(f'{k} Prototypes')
+ plt.savefig(output_dir / f'{k}_prototypes_{extension}.png')
\ No newline at end of file
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index a958a62b..0cdb3d2f 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,3 +4,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
+from .prototypes import Prototypes
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 2c4b99df..31ea4f89 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -120,6 +120,7 @@ def __init__(
# initiate search_method
self.search_method = search_method(
cases_dataset=projected_cases_dataset,
+ labels_dataset=labels_dataset,
k=k,
batch_size=batch_size,
**search_method_kwargs,
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
new file mode 100644
index 00000000..9df3081f
--- /dev/null
+++ b/xplique/example_based/prototypes.py
@@ -0,0 +1,114 @@
+"""
+Base model for prototypes
+"""
+
+import math
+
+import time
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import ProtoGreedySearch, PrototypesSearch
+from .projections import Projection
+from .base_example_method import BaseExampleMethod
+
+from .search_methods.base import _sanitize_returns
+
+
+class Prototypes(BaseExampleMethod):
+ """
+ Base class for prototypes.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+        The search space should be a space where distances make sense for the model.
+ It should not be `None`, otherwise,
+ all examples could be computed only with the `search_method`.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orientated the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ distance
+ Distance for the knn search method.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ search_method: Type[PrototypesSearch] = ProtoGreedySearch,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ **search_method_kwargs,
+ ):
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ **search_method_kwargs,
+ )
+
+ def get_prototypes(self):
+ """
+ Return the prototypes computed by the search method.
+
+ Returns:
+ prototype_indices : Tensor
+ prototype indices.
+            prototype_weights : Tensor
+ prototype weights.
+ """
+ return self.search_method.prototype_indices, self.search_method.prototype_weights
+
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
index 228e1acd..d54e85a4 100644
--- a/xplique/example_based/search_methods/__init__.py
+++ b/xplique/example_based/search_methods/__init__.py
@@ -6,3 +6,7 @@
# from .sklearn_knn import SklearnKNN
from .knn import KNN
+from .prototypes_search import PrototypesSearch
+from .proto_greedy_search import ProtoGreedySearch
+from .proto_dash_search import ProtoDashSearch
+from .mmd_critic_search import MMDCriticSearch
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 1c7c0f1b..5dde1a9c 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -80,6 +80,7 @@ class BaseSearchMethod(ABC):
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
@@ -92,6 +93,7 @@ def __init__(
self.batch_size = batch_size
self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
+ self.labels_dataset = sanitize_dataset(labels_dataset, self.batch_size)
self.set_k(k)
self.set_returns(search_returns)
diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py
new file mode 100644
index 00000000..cfe70941
--- /dev/null
+++ b/xplique/example_based/search_methods/mmd_critic_search.py
@@ -0,0 +1,98 @@
+"""
+MMDCritic search method in example-based module
+"""
+
+import numpy as np
+import tensorflow as tf
+
+from ...commons import dataset_gather
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .proto_greedy_search import ProtoGreedySearch
+from ..projections import Projection
+
+
+class MMDCriticSearch(ProtoGreedySearch):
+ """
+ KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
+ Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ For natural example-based methods it is the train dataset.
+ k
+ The number of examples to retrieve.
+ search_returns
+ String or list of string with the elements to return in `self.find_examples()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously.
+ It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ distance
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
+ """
+
+ def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
+ """
+ Compute the objective function and corresponding weights for a given set of selected prototypes and a candidate.
+
+ Here, we have a special case of protogreedy where we give equal weights to all prototypes,
+ the objective here is simplified to speed up processing
+
+ Find argmax_{c} F(S ∪ c) - F(S)
+ ≡
+ Find argmax_{c} F(S ∪ c)
+ ≡
+ Find argmax_{c} (sum1 - sum2) where: sum1 = (2 / n) * ∑[i=1 to n] κ(x_i, c)
+ sum2 = 1/(|S|+1) [2 * ∑[j=1 to |S|] * κ(x_j, c) + κ(c, c)]
+
+ Parameters
+ ----------
+ selection_indices : Tensor
+ Indices corresponding to the selected prototypes.
+ selection_cases : Tensor
+ Cases corresponding to the selected prototypes.
+ selection_labels : Tensor
+ Labels corresponding to the selected prototypes.
+ selection_weights : Tensor
+ Weights corresponding to the selected prototypes.
+ selection_selection_kernel : Tensor
+ Kernel matrix computed from the selected prototypes.
+ candidates_indices : Tensor
+ Indices corresponding to the candidate prototypes.
+ candidates_cases : Tensor
+ Cases corresponding to the candidate prototypes.
+ candidates_labels : Tensor
+ Labels corresponding to the candidate prototypes.
+ candidates_selection_kernel : Tensor
+ Kernel matrix between the candidates and the selected prototypes.
+
+ Returns
+ -------
+ objectives
+ Tensor that contains the computed objective values for each candidate.
+ objectives_weights
+ Tensor that contains the computed objective weights for each candidate.
+ """
+
+ nb_candidates = candidates_indices.shape[0]
+ nb_selection = selection_indices.shape[0]
+
+ sum1 = 2 * tf.gather(self.col_means, candidates_indices)
+
+ if nb_selection == 0:
+ sum2 = tf.abs(tf.gather(self.diag, candidates_indices))
+ else:
+ temp = tf.transpose(candidates_selection_kernel, perm=[1, 0])
+ sum2 = tf.reduce_sum(temp, axis=0) * 2 + tf.gather(self.diag, candidates_indices)
+ sum2 /= (nb_selection + 1)
+
+ objectives = sum1 - sum2
+ objectives_weights = tf.ones(shape=(nb_candidates, nb_selection+1), dtype=tf.float32) / tf.cast(nb_selection+1, dtype=tf.float32)
+
+ return objectives, objectives_weights
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
new file mode 100644
index 00000000..e29d12b8
--- /dev/null
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -0,0 +1,244 @@
+"""
+ProtoDash search method in example-based module
+"""
+
+import numpy as np
+from sklearn.metrics.pairwise import rbf_kernel
+from scipy.optimize import minimize
+import tensorflow as tf
+
+from ...commons import dataset_gather
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .proto_greedy_search import ProtoGreedySearch
+from ..projections import Projection
+
class Optimizer():
    """
    Solver for the quadratic problem:
    F(S) ≡ max_{w:supp(w)∈ S, w ≥ 0} l(w),
    where l(w) = w^T * μ_p - 1/2 * w^T * K * w

    Parameters
    ----------
    initial_weights : Tensor
        Initial weight vector.
    min_weight : float, optional
        Lower bound on each weight. Default is 0.
    max_weight : float, optional
        Upper bound on each weight. Default is 10000.
    """

    def __init__(
        self,
        initial_weights: Union[tf.Tensor, np.ndarray],
        min_weight: float = 0,
        max_weight: float = 10000
    ):
        self.initial_weights = initial_weights
        self.min_weight = min_weight
        self.max_weight = max_weight
        # One (lower, upper) bound pair per weight component.
        self.bounds = [(min_weight, max_weight)] * initial_weights.shape[0]
        # Negated objective: scipy minimizes, while we want to maximize l(w).
        self.objective_fn = lambda w, u, K: - (w @ u - 0.5 * w @ K @ w)

    def optimize(self, u, K):
        """
        Run the bounded SLSQP solver and return the optimal weight vector
        together with the (maximized) objective value.

        Parameters
        ----------
        u : Tensor
            Mean similarity of each prototype.
        K : Tensor
            The kernel matrix.

        Returns
        -------
        best_weights : Tensor
            The optimal value of the weight vector (w), shape (1, n).
        best_objective : Tensor
            The value of the objective function at `best_weights`, shape (1,).
        """
        u_np = u.numpy()
        K_np = K.numpy()

        solution = minimize(
            self.objective_fn,
            self.initial_weights,
            args=(u_np, K_np),
            method='SLSQP',
            bounds=self.bounds,
            options={'disp': False},
        )

        # Convert the solver output back to tensors with a leading batch axis.
        best_weights = tf.expand_dims(
            tf.convert_to_tensor(solution.x, dtype=tf.float32), axis=0
        )
        best_objective = tf.expand_dims(
            tf.convert_to_tensor(-solution.fun, dtype=tf.float32), axis=0
        )

        # The bounds guarantee non-negativity; keep the invariant explicit.
        assert tf.reduce_all(best_weights >= 0)

        return best_weights, best_objective
+
+
class ProtoDashSearch(ProtoGreedySearch):
    """
    Protodash method for searching prototypes.

    References:
    .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi,
       "ProtoDash: Fast Interpretable Prototype Selection"
       <https://arxiv.org/abs/1707.01212>`_

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        For natural example-based methods it is the train dataset.
    k
        The number of examples to retrieve.
    search_returns
        String or list of string with the elements to return in `self.find_examples()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of sample treated simultaneously.
        It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
    distance
        Either a Callable, or a value supported by `tf.norm` `ord` parameter.
        Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
        "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
        yielding the corresponding p-norm." We also added 'cosine'.
    nb_prototypes : int
        Number of prototypes to find.
    find_prototypes_kwargs
        Additional parameters passed to `find_prototypes` function.
    """
def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel, use_optimizer: bool = False):
    """
    Search for prototypes and their corresponding weights.

    Parameters
    ----------
    nb_prototypes : int
        Number of prototypes to find.
    kernel_type : str, optional
        The kernel type. It can be 'local' or 'global', by default 'local'.
        When it is local, the distances are calculated only within the classes.
    kernel_fn : Callable, optional
        Kernel function or kernel matrix, by default rbf_kernel.
    use_optimizer : bool, optional
        Flag indicating whether to use an optimizer for prototype selection, by default False.

    Returns
    -------
    prototype_indices : Tensor
        The indices of the selected prototypes.
    prototype_weights : Tensor
        The normalized weights of the selected prototypes.
    """

    # Remember the weight-update strategy (read by `update_selection_weights`),
    # then delegate the greedy selection loop to ProtoGreedySearch.
    self.use_optimizer = use_optimizer

    return super().find_prototypes(nb_prototypes, kernel_type, kernel_fn)
+
def update_selection_weights(self, selection_indices, selection_weights, selection_selection_kernel, best_indice, best_weights, best_objective):
    """
    Update the selection weights based on the given parameters.
    Pursuant to Lemma IV.4:
    If best_gradient ≤ 0, then
    ζ(S∪{best_sample_index}) = ζ(S) and specifically, w_{best_sample_index} = 0.
    Otherwise, the stationarity and complementary slackness KKT conditions
    entails that w_{best_sample_index} = best_gradient / κ(best_sample_index, best_sample_index)

    Parameters
    ----------
    selection_indices : Tensor
        Indices corresponding to the selected prototypes.
    selection_weights : Tensor
        Weights corresponding to the selected prototypes.
    selection_selection_kernel : Tensor
        Kernel matrix computed from the selected prototypes.
    best_indice : int
        The index of the selected prototype with the highest objective function value.
    best_weights : Tensor
        The weights corresponding to the optimal solution of the objective function for each candidate.
        (Unused here: ProtoDash recomputes the weights from the kernel.)
    best_objective : float
        The computed objective function value.

    Returns
    -------
    selection_weights : Tensor
        Updated weights corresponding to the selected prototypes.
    """

    if best_objective <= 0:
        # Lemma IV.4: adding the candidate does not change ζ(S);
        # its weight is exactly zero.
        selection_weights = tf.concat([selection_weights, [0]], axis=0)
    else:
        u = tf.expand_dims(tf.gather(self.col_means, selection_indices), axis=1)
        K = selection_selection_kernel

        if self.use_optimizer:
            # Warm-start the solver with the KKT closed-form weight for the
            # newly added prototype.
            initial_weights = tf.concat([selection_weights, [best_objective / tf.gather(self.diag, best_indice)]], axis=0)
            opt = Optimizer(initial_weights)
            selection_weights, _ = opt.optimize(u, K)
            selection_weights = tf.squeeze(selection_weights, axis=0)
        else:
            # We added epsilon to the diagonal of K to ensure that K is invertible
            K_inv = tf.linalg.inv(K + ProtoDashSearch.EPSILON * tf.eye(K.shape[-1]))
            selection_weights = tf.linalg.matmul(K_inv, u)
            # Project onto the non-negativity constraint.
            selection_weights = tf.maximum(selection_weights, 0)
            selection_weights = tf.squeeze(selection_weights, axis=1)

    return selection_weights
+
def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
    """
    Compute the objective function and corresponding weights for a given set of selected prototypes and a candidate.
    Calculate the gradient of l(w) = w^T * μ_p - 1/2 * w^T * K * w
    w.r.t w, on the optimal weight point ζ^(S)
    g = ∇l(ζ^(S)) = μ_p - K * ζ^(S)
    g is computed for each candidate c

    Parameters
    ----------
    selection_indices : Tensor
        Indices corresponding to the selected prototypes.
    selection_cases : Tensor
        Cases corresponding to the selected prototypes.
    selection_labels : Tensor
        Labels corresponding to the selected prototypes.
    selection_weights : Tensor
        Weights corresponding to the selected prototypes.
    selection_selection_kernel : Tensor
        Kernel matrix computed from the selected prototypes.
    candidates_indices : Tensor
        Indices corresponding to the candidate prototypes.
    candidates_cases : Tensor
        Cases corresponding to the candidate prototypes.
    candidates_labels : Tensor
        Labels corresponding to the candidate prototypes.
    candidates_selection_kernel : Tensor
        Kernel matrix between the candidates and the selected prototypes.

    Returns
    -------
    objectives
        Tensor that contains the computed objective values for each candidate.
    objectives_weights
        Always None for ProtoDash: the weights are derived later,
        in `update_selection_weights`.
    """

    u = tf.gather(self.col_means, candidates_indices)

    if selection_indices.shape[0] == 0:
        # S = ∅ and ζ^(∅) = 0, g = ∇l(ζ^(∅)) = μ_p
        objectives = u
    else:
        u = tf.expand_dims(u, axis=1)
        K = candidates_selection_kernel

        objectives = u - tf.matmul(K, tf.expand_dims(selection_weights, axis=1))
        objectives = tf.squeeze(objectives, axis=1)

    return objectives, None
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
new file mode 100644
index 00000000..00128b5c
--- /dev/null
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -0,0 +1,385 @@
+"""
+ProtoGreedy search method in example-based module
+"""
+
+import numpy as np
+from sklearn.metrics.pairwise import rbf_kernel
+import tensorflow as tf
+
+from ...commons import dataset_gather, sanitize_dataset
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .prototypes_search import PrototypesSearch
+from ..projections import Projection
+
+
class ProtoGreedySearch(PrototypesSearch):
    """
    ProtoGreedy method for searching prototypes.

    References:
    .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi,
       "ProtoDash: Fast Interpretable Prototype Selection"
       <https://arxiv.org/abs/1707.01212>`_

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        For natural example-based methods it is the train dataset.
    k
        The number of examples to retrieve.
    search_returns
        String or list of string with the elements to return in `self.find_examples()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of sample treated simultaneously.
        It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
    distance
        Either a Callable, or a value supported by `tf.norm` `ord` parameter.
        Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
        "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
        yielding the corresponding p-norm." We also added 'cosine'.
    nb_prototypes : int
        Number of prototypes to find.
    find_prototypes_kwargs
        Additional parameters passed to `find_prototypes` function.
    """

    # Avoid zero division during procedure. (the value is not important, as if the denominator is
    # zero, then the nominator will also be zero).
    EPSILON = tf.constant(1e-6)

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        k: int = 1,
        search_returns: Optional[Union[List[str], str]] = None,
        batch_size: Optional[int] = 32,
        distance: Union[int, str, Callable] = "euclidean",
        nb_prototypes: int = 1,
        **find_prototypes_kwargs
    ):  # pylint: disable=R0801
        # Thin pass-through: PrototypesSearch.__init__ validates the distance
        # and calls back into `find_prototypes` defined below.
        super().__init__(
            cases_dataset, labels_dataset, k, search_returns, batch_size, distance, nb_prototypes, **find_prototypes_kwargs
        )
def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
    """
    Compute the objective and its weights for each candidate.

    For every candidate c, the augmented kernel matrix
    K' = [[K_S, k_c], [k_c^T, κ(c, c)]] is assembled, the (clipped)
    optimal weights w = max(K'^{-1} u, 0) are solved for, and the
    objective l(w) = w^T u - 1/2 w^T K' w is evaluated.

    Parameters
    ----------
    selection_indices : Tensor
        Indices corresponding to the selected prototypes.
    selection_cases : Tensor
        Cases corresponding to the selected prototypes.
    selection_labels : Tensor
        Labels corresponding to the selected prototypes.
    selection_weights : Tensor
        Weights corresponding to the selected prototypes.
    selection_selection_kernel : Tensor
        Kernel matrix computed from the selected prototypes.
    candidates_indices : Tensor
        Indices corresponding to the candidate prototypes.
    candidates_cases : Tensor
        Cases corresponding to the candidate prototypes.
    candidates_labels : Tensor
        Labels corresponding to the candidate prototypes.
    candidates_selection_kernel : Tensor
        Kernel matrix between the candidates and the selected prototypes.

    Returns
    -------
    objectives
        Tensor that contains the computed objective values for each candidate.
    objectives_weights
        Tensor that contains the computed objective weights for each candidate.
    """

    nb_candidates = candidates_indices.shape[0]
    nb_selection = selection_cases.shape[0]

    # u holds the mean-similarity vector for S ∪ {c}, one per candidate.
    repeated_selection_indices = tf.tile(tf.expand_dims(selection_indices, 0), [nb_candidates, 1])
    repeated_selection_candidates_indices = tf.concat([repeated_selection_indices, tf.expand_dims(candidates_indices, 1)], axis=1)
    u = tf.expand_dims(tf.gather(self.col_means, repeated_selection_candidates_indices), axis=2)

    if nb_selection == 0:
        # Empty selection: K' reduces to the scalar κ(c, c) per candidate.
        K = tf.expand_dims(tf.expand_dims(tf.gather(self.diag, candidates_indices), axis=-1), axis=-1)
    else:
        # Assemble K' block-wise via zero-padding so the three parts
        # (K_S, κ(c, c), and k_c on both borders) can simply be summed.
        repeated_selection_selection_kernel = tf.tile(tf.expand_dims(selection_selection_kernel, 0), [nb_candidates, 1, 1])
        repeated_selection_selection_kernel = tf.pad(repeated_selection_selection_kernel, [[0, 0], [0, 1], [0, 1]])

        candidates_diag = tf.expand_dims(tf.expand_dims(tf.gather(self.diag, candidates_indices), axis=-1), axis=-1)
        candidates_diag = tf.pad(candidates_diag, [[0, 0], [nb_selection, 0], [nb_selection, 0]])

        candidates_selection_kernel = tf.expand_dims(candidates_selection_kernel, axis=-1)
        candidates_selection_kernel = tf.pad(candidates_selection_kernel, [[0, 0], [0, 1], [nb_selection, 0]])

        K = repeated_selection_selection_kernel + candidates_diag + candidates_selection_kernel + tf.transpose(candidates_selection_kernel, [0, 2, 1])

    # Compute the objective weights for each candidate in the batch
    K_inv = tf.linalg.inv(K + ProtoGreedySearch.EPSILON * tf.eye(K.shape[-1]))
    objectives_weights = tf.matmul(K_inv, u)
    objectives_weights = tf.maximum(objectives_weights, 0)

    # Compute the objective for each candidate in the batch
    objectives = tf.matmul(tf.transpose(objectives_weights, [0, 2, 1]), u) - 0.5 * tf.matmul(tf.matmul(tf.transpose(objectives_weights, [0, 2, 1]), K), objectives_weights)
    objectives = tf.squeeze(objectives, axis=[1,2])

    return objectives, objectives_weights
+
def update_selection_weights(self, selection_indices, selection_weights, selection_selection_kernel, best_indice, best_weights, best_objective):
    """
    Update the selection weights based on the optimization results.

    ProtoGreedy already solved for the optimal weights of the augmented
    selection while scoring candidates (`compute_objectives`), so the
    update is a plain replacement by the winning candidate's weights.

    Parameters
    ----------
    selection_indices : Tensor
        Indices corresponding to the selected prototypes.
    selection_weights : Tensor
        Weights corresponding to the selected prototypes.
    selection_selection_kernel : Tensor
        Kernel matrix computed from the selected prototypes.
    best_indice : int
        The index of the selected prototype with the highest objective function value.
    best_weights : Tensor
        The weights corresponding to the optimal solution of the objective function for each candidate.
    best_objective : float
        The computed objective function value.

    Returns
    -------
    selection_weights : Tensor
        Updated weights corresponding to the selected prototypes.
    """
    return best_weights
+
def compute_kernel_attributes(self, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel):
    """
    Compute the attributes of the class that are related to the kernel.

    Parameters
    ----------
    kernel_type : str, optional
        The kernel type. It can be 'local' or 'global', by default 'local'.
        When it is local, the distances are calculated only within the classes.
    kernel_fn : Callable, optional
        Kernel function or kernel matrix, by default rbf_kernel.

    Returns
    -------
    None
        Sets `kernel_type`, `kernel_fn`, `col_sums`, `col_means`, `diag`,
        `n` and `nb_features` on the instance.

    Raises
    ------
    AttributeError
        If `kernel_type` is not 'local' or 'global', or if `kernel_fn`
        is not callable.
    """
    if kernel_type in ['local', 'global']:
        self.kernel_type = kernel_type
    else:
        # Fixed: the original combined a trailing comma with a unary `+` on the
        # f-string, which raised `TypeError: bad operand type for unary +`
        # instead of this AttributeError with a readable message.
        raise AttributeError(
            "The kernel_type parameter is expected to be in"
            + " ['local', 'global'] "
            + f"but {kernel_type} was received."
        )

    if hasattr(kernel_fn, "__call__"):
        def custom_kernel_fn(x1, x2, y1, y2):
            if self.kernel_type == 'global':
                kernel_matrix = kernel_fn(x1,x2)
                if isinstance(kernel_matrix, np.ndarray):
                    kernel_matrix = tf.convert_to_tensor(kernel_matrix)
            else:
                # In the case of a local kernel, calculations are limited to within the class.
                # Across different classes, the kernel values are set to 0.
                kernel_matrix = np.zeros((x1.shape[0], x2.shape[0]), dtype=np.float32)
                y_intersect = np.intersect1d(y1, y2)
                for i in range(y_intersect.shape[0]):
                    y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
                    y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
                    sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
                    kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
                kernel_matrix = tf.convert_to_tensor(kernel_matrix)
            return kernel_matrix

        self.kernel_fn = custom_kernel_fn
    else:
        # Fixed: same unary-`+`-on-f-string bug as above.
        raise AttributeError(
            "The kernel parameter is expected to be a Callable "
            + f"but {kernel_fn} was received."
        )

    # TODO: for local explanation add the ability to compute distance_fn based on the kernel

    # Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset.
    # We take advantage of the symmetry of this matrix to traverse only its lower triangle.
    col_sums = []
    diag = []
    row_sums = []

    for batch_col_index, (batch_col_cases, batch_col_labels) in enumerate(
        zip(self.cases_dataset, self.labels_dataset)
    ):
        batch_col_sums = tf.zeros((batch_col_cases.shape[0]))

        for batch_row_index, (batch_row_cases, batch_row_labels) in enumerate(
            zip(self.cases_dataset, self.labels_dataset)
        ):
            # Lower-triangle traversal: skip row batches above the column batch.
            if batch_row_index < batch_col_index:
                continue

            batch_kernel = self.kernel_fn(batch_row_cases, batch_col_cases, batch_row_labels, batch_col_labels)

            batch_col_sums = batch_col_sums + tf.reduce_sum(batch_kernel, axis=0)

            if batch_col_index == batch_row_index:
                if batch_col_index != 0:
                    # Add the contributions accumulated from the strict lower
                    # triangle (stored transposed in `row_sums`).
                    batch_col_sums = batch_col_sums + row_sums[batch_row_index]

                diag.append(tf.linalg.diag_part(batch_kernel))

            if batch_col_index == 0:
                if batch_row_index == 0:
                    row_sums.append(None)
                else:
                    row_sums.append(tf.reduce_sum(batch_kernel, axis=1))
            else:
                row_sums[batch_row_index] += tf.reduce_sum(batch_kernel, axis=1)

        col_sums.append(batch_col_sums)

    self.col_sums = tf.concat(col_sums, axis=0)
    self.n = self.col_sums.shape[0]
    self.col_means = self.col_sums / self.n
    self.diag = tf.concat(diag, axis=0)
    # NOTE(review): assumes cases are rank-2 (n_samples, n_features) — confirm.
    self.nb_features = batch_col_cases.shape[1]
+
def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel):
    """
    Search for prototypes and their corresponding weights.

    Greedy loop: at each iteration every non-selected point is scored by
    `compute_objectives`, the best one is appended to the selection, the
    cached kernel blocks are extended, and the weights are refreshed via
    `update_selection_weights`.

    Parameters
    ----------
    nb_prototypes : int
        Number of prototypes to find.
    kernel_type : str, optional
        The kernel type. It can be 'local' or 'global', by default 'local'.
        When it is local, the distances are calculated only within the classes.
    kernel_fn : Callable, optional
        Kernel function or kernel matrix, by default rbf_kernel.

    Returns
    -------
    prototype_indices : Tensor
        The indices of the selected prototypes.
    prototype_weights : Tensor
        The normalized weights of the selected prototypes.
    """

    self.compute_kernel_attributes(kernel_type, kernel_fn)

    # Tensors to store selected indices and their corresponding cases, labels and weights.
    selection_indices = tf.constant([], dtype=tf.int32)
    selection_cases = tf.zeros((0, self.nb_features), dtype=tf.float32)
    selection_labels = tf.constant([], dtype=tf.int32)
    selection_weights = tf.constant([], dtype=tf.float32)
    # Tensor to store the all_candidates-selection kernel of the previous iteration.
    all_candidates_selection_kernel = tf.zeros((self.n, 0), dtype=tf.float32)
    # Tensor to store the selection-selection kernel.
    selection_selection_kernel = None

    k = 0
    while k < nb_prototypes:

        nb_selection = selection_cases.shape[0]

        # Tensor to store the all_candidates-last_selected kernel
        if nb_selection !=0:
            all_candidates_last_selected_kernel = tf.zeros((self.n), dtype=tf.float32)

        best_objective = None
        best_indice = None
        best_case = None
        best_label = None
        best_weights = None

        for batch_index, (cases, labels) in enumerate(
            zip(self.cases_dataset, self.labels_dataset)
        ):
            # Global dataset indices of the current batch.
            # NOTE(review): assumes all batches (except possibly the last) have
            # exactly `self.batch_size` elements — confirm against the dataset.
            batch_inside_indices = tf.range(cases.shape[0], dtype=tf.int32)
            batch_indices = batch_index * self.batch_size + batch_inside_indices

            # Filter the batch to keep only candidate indices.
            if nb_selection == 0:
                candidates_indices = batch_indices
            else:
                candidates_indices = tf.convert_to_tensor(np.setdiff1d(batch_indices, selection_indices))

            nb_candidates = candidates_indices.shape[0]

            if nb_candidates == 0:
                continue

            candidates_inside_indices = candidates_indices % self.batch_size
            candidates_cases = tf.gather(cases, candidates_inside_indices)
            candidates_labels = tf.gather(labels, candidates_inside_indices)

            # Compute the candidates-selection kernel for the batch
            if nb_selection == 0:
                candidates_selection_kernel = None
            else:
                # Only the column for the most recently selected prototype is
                # new; the rest is reused from the cached kernel.
                candidates_last_selected_kernel = self.kernel_fn(candidates_cases, selection_cases[-1:, :], candidates_labels, selection_labels[-1:])
                candidates_selection_kernel = tf.concat([tf.gather(all_candidates_selection_kernel, candidates_indices, axis=0), candidates_last_selected_kernel], axis=1)
                all_candidates_last_selected_kernel = tf.tensor_scatter_nd_update(all_candidates_last_selected_kernel, tf.expand_dims(candidates_indices, axis=1), tf.squeeze(candidates_last_selected_kernel, axis=1))

            # Compute the objectives for the batch
            objectives, objectives_weights = self.compute_objectives(selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel)

            # Select the best objective in the batch
            objectives_argmax = tf.argmax(objectives)

            if (best_objective is None) or (tf.gather(objectives, objectives_argmax) > best_objective):
                best_objective = tf.gather(objectives, objectives_argmax)
                best_indice = tf.squeeze(tf.gather(candidates_indices, objectives_argmax))
                best_case = tf.gather(candidates_cases, objectives_argmax)
                best_label = tf.gather(candidates_labels, objectives_argmax)
                if objectives_weights is not None:
                    best_weights = tf.squeeze(tf.gather(objectives_weights, objectives_argmax))

        # Update the all_candidates-selection kernel
        if nb_selection != 0:
            all_candidates_selection_kernel = tf.concat([all_candidates_selection_kernel, tf.expand_dims(all_candidates_last_selected_kernel, axis=1)], axis=1)

        # Update the selection-selection kernel
        if nb_selection == 0:
            selection_selection_kernel = tf.gather(self.diag, [[best_indice]])
        else:
            # Grow the (nb_selection+1)x(nb_selection+1) kernel by summing
            # zero-padded blocks: old kernel, new diagonal entry and the new
            # row/column for the best candidate.
            selection_selection_kernel = tf.pad(selection_selection_kernel, [[0, 1], [0, 1]])

            best_candidate_selection_kernel = tf.gather(all_candidates_selection_kernel, [best_indice], axis=0)
            best_candidate_selection_kernel = tf.pad(best_candidate_selection_kernel, [[nb_selection, 0], [0, 1]])

            best_candidate_diag = tf.expand_dims(tf.gather(self.diag, [best_indice]), axis=-1)
            best_candidate_diag = tf.pad(best_candidate_diag, [[nb_selection, 0], [nb_selection, 0]])

            selection_selection_kernel = selection_selection_kernel + best_candidate_diag + best_candidate_selection_kernel + tf.transpose(best_candidate_selection_kernel)

        # Update selection indices, cases and labels
        selection_indices = tf.concat([selection_indices, [best_indice]], axis=0)
        selection_cases = tf.concat([selection_cases, [best_case]], axis=0)
        selection_labels = tf.concat([selection_labels, [best_label]], axis=0)

        # Update selection weights
        selection_weights = self.update_selection_weights(selection_indices, selection_weights, selection_selection_kernel, best_indice, best_weights, best_objective)

        k += 1

    prototype_indices = selection_indices
    prototype_weights = selection_weights

    # Normalize the weights
    prototype_weights = prototype_weights / tf.reduce_sum(prototype_weights)

    return prototype_indices, prototype_weights
\ No newline at end of file
diff --git a/xplique/example_based/search_methods/prototypes_search.py b/xplique/example_based/search_methods/prototypes_search.py
new file mode 100644
index 00000000..6db01b3c
--- /dev/null
+++ b/xplique/example_based/search_methods/prototypes_search.py
@@ -0,0 +1,137 @@
+"""
+Prototypes search method in example-based module
+"""
+
+from abc import ABC, abstractmethod
+import numpy as np
+from sklearn.metrics.pairwise import rbf_kernel
+import tensorflow as tf
+
+from ...commons import dataset_gather
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .base import BaseSearchMethod
+from ..projections import Projection
+
+
class PrototypesSearch(BaseSearchMethod):
    """
    Prototypes search method to find prototypes and the examples closest to these prototypes.

    Parameters
    ----------
    cases_dataset
        The dataset used to train the model, examples are extracted from the dataset.
        For natural example-based methods it is the train dataset.
    k
        The number of examples to retrieve.
    search_returns
        String or list of string with the elements to return in `self.find_examples()`.
        See `self.set_returns()` for detail.
    batch_size
        Number of sample treated simultaneously.
        It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
    distance
        Either a Callable, or a value supported by `tf.norm` `ord` parameter.
        Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
        "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
        yielding the corresponding p-norm."
        NOTE(review): 'cosine' is advertised by subclasses but not handled
        below — confirm before passing it.
    nb_prototypes : int
        Number of prototypes to find.
    find_prototypes_kwargs
        Additional parameters passed to `find_prototypes` function.
    """

    # Avoid zero division during procedure. (the value is not important, as if the denominator is
    # zero, then the nominator will also be zero).
    EPSILON = tf.constant(1e-6)

    def __init__(
        self,
        cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
        labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
        k: int = 1,
        search_returns: Optional[Union[List[str], str]] = None,
        batch_size: Optional[int] = 32,
        distance: Union[int, str, Callable] = "euclidean",
        nb_prototypes: int = 1,
        **find_prototypes_kwargs,
    ):  # pylint: disable=R0801
        super().__init__(
            cases_dataset, labels_dataset, k, search_returns, batch_size
        )

        if hasattr(distance, "__call__"):
            self.distance_fn = distance
        elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
            distance, int
        ):
            self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
        else:
            # Fixed: the original used a trailing comma plus a unary `+` on the
            # f-string, which raised `TypeError: bad operand type for unary +`
            # instead of this AttributeError with a readable message.
            raise AttributeError(
                "The distance parameter is expected to be either a Callable or in"
                + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] "
                + f"but {distance} was received."
            )

        # Prototypes are computed eagerly at construction time, delegating to
        # the subclass implementation of `find_prototypes`.
        self.prototype_indices, self.prototype_weights = self.find_prototypes(
            nb_prototypes, **find_prototypes_kwargs
        )
+
@abstractmethod
def find_prototypes(self, nb_prototypes: int, **find_prototypes_kwargs):
    """
    Search for prototypes and their corresponding weights.

    Parameters
    ----------
    nb_prototypes : int
        Number of prototypes to find.

    find_prototypes_kwargs
        Additional parameters passed to `find_prototypes` function.

    Returns
    -------
    prototype_indices : Tensor
        The indices of the selected prototypes.
    prototype_weights : Tensor
        The normalized weights of the selected prototypes.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by a subclass.
    """
    # Fixed: the original *returned* a NotImplementedError instance instead of
    # raising it, so a non-overriding subclass would silently unpack the
    # exception object rather than fail.
    raise NotImplementedError()
+
def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
    """
    Search the samples to return as examples. Called by the explain methods.
    It may also return the indices corresponding to the samples,
    based on `return_indices` value.

    Parameters
    ----------
    inputs
        Tensor or Array. Input samples to be explained.
        Assumed to have been already projected.
        Expected shape among (N, W), (N, T, W), (N, W, H, C).

    Returns
    -------
    return_dict
        Dictionary of the requested elements ("examples", "indices",
        "distances"), or the single element itself when only one is requested.
    """
    # TODO: Find examples: here we provide a local explanation.
    # Find the nearest prototypes to inputs
    # we use self.distance_fn and self.prototype_indices.
    # NOTE(review): the nearest-prototype search is not implemented yet; with
    # `examples_indices` left as None, `dataset_gather` cannot return
    # meaningful examples — confirm before relying on this method.
    examples_indices = None
    examples_distances = None

    # Set values in return dict
    return_dict = {}
    if "examples" in self.returns:
        return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
        if "include_inputs" in self.returns:
            # Prepend each input to its row of examples.
            inputs = tf.expand_dims(inputs, axis=1)
            return_dict["examples"] = tf.concat(
                [inputs, return_dict["examples"]], axis=1
            )
    if "indices" in self.returns:
        return_dict["indices"] = examples_indices
    if "distances" in self.returns:
        return_dict["distances"] = examples_distances

    # Return a dict only different variables are returned
    if len(return_dict) == 1:
        return list(return_dict.values())[0]
    return return_dict
From a86fd55e3857d534e131d6fc821721cabde0c856 Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Thu, 21 Mar 2024 13:13:00 +0100
Subject: [PATCH 035/138] add Prototypes fix up
---
tests/example_based/test_prototypes.py | 2 +-
.../search_methods/proto_greedy_search.py | 15 ---------------
2 files changed, 1 insertion(+), 16 deletions(-)
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
index 39ba2425..68ae3e59 100644
--- a/tests/example_based/test_prototypes.py
+++ b/tests/example_based/test_prototypes.py
@@ -243,6 +243,6 @@ def custom_kernel(x,y=None):
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'mmd_critic')
-test_proto_greedy_basic()
+# test_proto_greedy_basic()
# test_proto_dash_basic()
# test_mmd_critic_basic()
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 00128b5c..a383b33a 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -50,21 +50,6 @@ class ProtoGreedySearch(PrototypesSearch):
# zero, then the nominator will also be zero).
EPSILON = tf.constant(1e-6)
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- search_returns: Optional[Union[List[str], str]] = None,
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = "euclidean",
- nb_prototypes: int = 1,
- **find_prototypes_kwargs
- ): # pylint: disable=R0801
- super().__init__(
- cases_dataset, labels_dataset, k, search_returns, batch_size, distance, nb_prototypes, **find_prototypes_kwargs
- )
-
def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
"""
Compute the objective and its weights for each candidate.
From d70557330fb6b9e1f3bedbeb4cc9e9692fc75d9f Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Thu, 4 Apr 2024 12:48:02 +0200
Subject: [PATCH 036/138] add Prototypes fix up
---
tests/example_based/test_prototypes.py | 50 +--
tests/utils.py | 28 +-
xplique/example_based/__init__.py | 3 +
xplique/example_based/base_example_method.py | 1 -
xplique/example_based/mmd_critic.py | 100 ++++++
xplique/example_based/proto_dash.py | 100 ++++++
xplique/example_based/proto_greedy.py | 100 ++++++
xplique/example_based/prototypes.py | 154 +++++++++-
.../example_based/search_methods/__init__.py | 1 -
xplique/example_based/search_methods/base.py | 2 -
.../search_methods/mmd_critic_search.py | 17 +-
.../search_methods/proto_dash_search.py | 63 ++--
.../search_methods/proto_greedy_search.py | 289 +++++++++++-------
.../search_methods/prototypes_search.py | 137 ---------
14 files changed, 728 insertions(+), 317 deletions(-)
create mode 100644 xplique/example_based/mmd_critic.py
create mode 100644 xplique/example_based/proto_dash.py
create mode 100644 xplique/example_based/proto_greedy.py
delete mode 100644 xplique/example_based/search_methods/prototypes_search.py
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
index 68ae3e59..720c47ee 100644
--- a/tests/example_based/test_prototypes.py
+++ b/tests/example_based/test_prototypes.py
@@ -17,11 +17,10 @@
from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
from xplique.types import Union
-from xplique.example_based import Prototypes
+from xplique.example_based import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
from xplique.example_based.projections import Projection, LatentSpaceProjection
-from xplique.example_based.search_methods import ProtoGreedySearch, ProtoDashSearch, MMDCriticSearch
-from tests.utils import almost_equal, get_Gaussian_Data, load_data, plot
+from tests.utils import almost_equal, get_Gaussian_Data, load_data, plot, plot_local_explanation
def test_proto_greedy_basic():
@@ -52,10 +51,9 @@ def custom_kernel(x,y=None):
kernel_type = "global"
# Method initialization
- method = Prototypes(
+ method = ProtoGreedy(
cases_dataset=x_train,
labels_dataset=y_train,
- search_method=ProtoGreedySearch,
k=k,
projection=identity_projection,
batch_size=32,
@@ -65,8 +63,8 @@ def custom_kernel(x,y=None):
kernel_fn=kernel_fn,
)
- # Generate explanation
- prototype_indices, prototype_weights = method.get_prototypes()
+ # Generate global explanation
+ prototype_indices, prototype_weights = method.get_global_prototypes()
prototypes = tf.gather(x_train, prototype_indices)
prototype_labels = tf.gather(y_train, prototype_indices)
@@ -97,6 +95,12 @@ def custom_kernel(x,y=None):
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'proto_greedy')
+ # # Generate local explanation
+ # examples = method.explain(x_test)
+
+ # # Visualize local explanation
+ # plot_local_explanation(examples, x_test, 'proto_greedy')
+
def test_proto_dash_basic():
"""
Test the SimilarExamples with an identity projection.
@@ -125,10 +129,9 @@ def custom_kernel(x,y=None):
kernel_type = "global"
# Method initialization
- method = Prototypes(
+ method = ProtoDash(
cases_dataset=x_train,
labels_dataset=y_train,
- search_method=ProtoDashSearch,
k=k,
projection=identity_projection,
batch_size=32,
@@ -138,8 +141,8 @@ def custom_kernel(x,y=None):
kernel_fn=kernel_fn,
)
- # Generate explanation
- prototype_indices, prototype_weights = method.get_prototypes()
+ # Generate global explanation
+ prototype_indices, prototype_weights = method.get_global_prototypes()
prototypes = tf.gather(x_train, prototype_indices)
prototype_labels = tf.gather(y_train, prototype_indices)
@@ -170,6 +173,12 @@ def custom_kernel(x,y=None):
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'proto_dash')
+ # # Generate local explanation
+ # examples = method.explain(x_test)
+
+ # # Visualize local explanation
+ # plot_local_explanation(examples, x_test, 'proto_dash')
+
def test_mmd_critic_basic():
"""
Test the SimilarExamples with an identity projection.
@@ -198,10 +207,9 @@ def custom_kernel(x,y=None):
kernel_type = "global"
# Method initialization
- method = Prototypes(
+ method = MMDCritic(
cases_dataset=x_train,
labels_dataset=y_train,
- search_method=MMDCriticSearch,
k=k,
projection=identity_projection,
batch_size=32,
@@ -211,8 +219,8 @@ def custom_kernel(x,y=None):
kernel_fn=kernel_fn,
)
- # Generate explanation
- prototype_indices, prototype_weights = method.get_prototypes()
+ # Generate global explanation
+ prototype_indices, prototype_weights = method.get_global_prototypes()
prototypes = tf.gather(x_train, prototype_indices)
prototype_labels = tf.gather(y_train, prototype_indices)
@@ -243,6 +251,12 @@ def custom_kernel(x,y=None):
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'mmd_critic')
-# test_proto_greedy_basic()
-# test_proto_dash_basic()
-# test_mmd_critic_basic()
+ # # Generate local explanation
+ # examples = method.explain(x_test)
+
+ # # Visualize local explanation
+ # plot_local_explanation(examples, x_test, 'mmd_critic')
+
+test_proto_greedy_basic()
+test_proto_dash_basic()
+test_mmd_critic_basic()
diff --git a/tests/utils.py b/tests/utils.py
index 4219cd1a..280d7a5f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -298,4 +298,30 @@ def plot(prototypes_sorted, prototype_weights_sorted, extension):
axis.set_title("{:.2f}".format(prototype_weights_sorted[i].numpy()))
axis.axis('off')
# fig.suptitle(f'{k} Prototypes')
- plt.savefig(output_dir / f'{k}_prototypes_{extension}.png')
\ No newline at end of file
+ plt.savefig(output_dir / f'{k}_prototypes_{extension}.png')
+
+def plot_local_explanation(examples, x_test, extension):
+
+ output_dir = Path('tests/example_based/tmp')
+ k = examples.shape[1]
+
+ # Visualize
+ num_cols = k+1
+ num_rows = x_test.shape[0]
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(6, num_rows * 0.75))
+ # Adjust the spacing between lines
+ plt.subplots_adjust(hspace=1)
+ axes[0,0].set_title("x_test")
+ for i in range(examples.shape[0]):
+ axes[i,0].imshow(x_test[i].numpy().reshape(16, 16), cmap='gray')
+ axes[i,0].axis('off')
+ for j in range(examples.shape[1]):
+ axe = axes[i,j+1]
+ axe.imshow(examples[i,j].numpy().reshape(16, 16), cmap='gray')
+ # axe.set_title("{:.2f}".format(prototype_distances[i,j]))
+ if i == 0:
+ axe.set_title("prototype_{}".format(j + 1))
+ axe.axis('off')
+
+    fig.suptitle(f'{k}-nearest prototypes')
+ plt.savefig(output_dir / f'{k}_nearest_prototypes_{extension}.png')
\ No newline at end of file
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 0cdb3d2f..a30a789a 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -5,3 +5,6 @@
from .cole import Cole
from .similar_examples import SimilarExamples
from .prototypes import Prototypes
+from .proto_greedy import ProtoGreedy
+from .proto_dash import ProtoDash
+from .mmd_critic import MMDCritic
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 31ea4f89..2c4b99df 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -120,7 +120,6 @@ def __init__(
# initiate search_method
self.search_method = search_method(
cases_dataset=projected_cases_dataset,
- labels_dataset=labels_dataset,
k=k,
batch_size=batch_size,
**search_method_kwargs,
diff --git a/xplique/example_based/mmd_critic.py b/xplique/example_based/mmd_critic.py
new file mode 100644
index 00000000..a2ccfb47
--- /dev/null
+++ b/xplique/example_based/mmd_critic.py
@@ -0,0 +1,100 @@
+"""
+MMDCritic method for searching prototypes
+"""
+
+import math
+
+import time
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import MMDCriticSearch
+from .projections import Projection
+from .prototypes import Prototypes
+
+from .search_methods.base import _sanitize_returns
+
+
+class MMDCritic(Prototypes):
+ """
+ MMDCritic method for searching prototypes.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distance make sense for the model.
+ It should not be `None`, otherwise,
+ all examples could be computed only with the `search_method`.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orientated the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ search_method_kwargs
+ Parameters to be passed at the construction of the `search_method`.
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ **search_method_kwargs,
+ ):
+ # the only difference with parent is that the search method is always MMDCriticSearch
+ search_method = MMDCriticSearch
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ **search_method_kwargs,
+ )
+
diff --git a/xplique/example_based/proto_dash.py b/xplique/example_based/proto_dash.py
new file mode 100644
index 00000000..475e138b
--- /dev/null
+++ b/xplique/example_based/proto_dash.py
@@ -0,0 +1,100 @@
+"""
+ProtoDash method for searching prototypes
+"""
+
+import math
+
+import time
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import ProtoDashSearch
+from .projections import Projection
+from .prototypes import Prototypes
+
+from .search_methods.base import _sanitize_returns
+
+
+class ProtoDash(Prototypes):
+ """
+ ProtoDash method for searching prototypes.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distance make sense for the model.
+ It should not be `None`, otherwise,
+ all examples could be computed only with the `search_method`.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orientated the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ search_method_kwargs
+ Parameters to be passed at the construction of the `search_method`.
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ **search_method_kwargs,
+ ):
+ # the only difference with parent is that the search method is always ProtoDashSearch
+ search_method = ProtoDashSearch
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ **search_method_kwargs,
+ )
+
diff --git a/xplique/example_based/proto_greedy.py b/xplique/example_based/proto_greedy.py
new file mode 100644
index 00000000..2c43565b
--- /dev/null
+++ b/xplique/example_based/proto_greedy.py
@@ -0,0 +1,100 @@
+"""
+ProtoGreedy method for searching prototypes
+"""
+
+import math
+
+import time
+
+import tensorflow as tf
+import numpy as np
+
+from ..types import Callable, Dict, List, Optional, Type, Union
+
+from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_dataset, dataset_gather
+from .search_methods import ProtoGreedySearch
+from .projections import Projection
+from .prototypes import Prototypes
+
+from .search_methods.base import _sanitize_returns
+
+
+class ProtoGreedy(Prototypes):
+ """
+ ProtoGreedy method for searching prototypes.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distance make sense for the model.
+ It should not be `None`, otherwise,
+ all examples could be computed only with the `search_method`.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orientated the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ search_method_kwargs
+ Parameters to be passed at the construction of the `search_method`.
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ **search_method_kwargs,
+ ):
+ # the only difference with parent is that the search method is always ProtoGreedySearch
+ search_method = ProtoGreedySearch
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ **search_method_kwargs,
+ )
+
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 9df3081f..906b6bca 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -13,7 +13,7 @@
from ..commons import sanitize_inputs_targets
from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import ProtoGreedySearch, PrototypesSearch
+from .search_methods import ProtoGreedySearch
from .projections import Projection
from .base_example_method import BaseExampleMethod
@@ -68,12 +68,8 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- distance
- Distance for the knn search method.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ search_method_kwargs
+ Parameters to be passed at the construction of the `search_method`.
"""
def __init__(
@@ -81,28 +77,58 @@ def __init__(
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- search_method: Type[PrototypesSearch] = ProtoGreedySearch,
+ search_method: Type[ProtoGreedySearch] = ProtoGreedySearch,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
**search_method_kwargs,
):
- super().__init__(
- cases_dataset=cases_dataset,
+ assert (
+ projection is not None
+ ), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."
+
+ # set attributes
+ batch_size = self.__initialize_cases_dataset(
+ cases_dataset, labels_dataset, targets_dataset, batch_size
+ )
+
+ self.k = k
+ self.set_returns(case_returns)
+
+ assert hasattr(projection, "__call__"), "projection should be a callable."
+
+ # check projection type
+ if isinstance(projection, Projection):
+ self.projection = projection
+ elif hasattr(projection, "__call__"):
+ self.projection = Projection(get_weights=None, space_projection=projection)
+ else:
+ raise AttributeError(
                "projection should be a `Projection` or a `Callable`, not a "
                + f"{type(projection)}"
+ )
+
+ # project dataset
+ projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
+ self.targets_dataset)
+
+ # set `search_returns` if not provided and overwrite it otherwise
+ search_method_kwargs["search_returns"] = ["indices", "distances"]
+
+ # initiate search_method
+ self.search_method = search_method(
+ cases_dataset=projected_cases_dataset,
labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- search_method=search_method,
k=k,
- projection=projection,
- case_returns=case_returns,
batch_size=batch_size,
**search_method_kwargs,
)
- def get_prototypes(self):
+ def get_global_prototypes(self):
"""
- Return the prototypes computed by the search method.
+ Return all the prototypes computed by the search method,
+ which consist of a global explanation of the dataset.
Returns:
prototype_indices : Tensor
@@ -111,4 +137,100 @@ def get_prototypes(self):
prototype weights.
"""
return self.search_method.prototype_indices, self.search_method.prototype_weights
+
+ def __initialize_cases_dataset(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
+ batch_size: Optional[int],
+ ) -> int:
+ """
+ Factorization of `__init__()` method for dataset related attributes.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ labels_dataset
+ Labels associated to the examples in the dataset.
+ Indices should match with cases_dataset.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection.
+ See `projection` for detail.
+ batch_size
+ Number of sample treated simultaneously when using the datasets.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+
+ Returns
+ -------
+ batch_size
+ Number of sample treated simultaneously when using the datasets.
+ Extracted from the datasets in case they are `tf.data.Dataset`.
+ Otherwise, the input value.
+ """
+ # at least one dataset provided
+ if isinstance(cases_dataset, tf.data.Dataset):
+ # set batch size (ignore provided argument) and cardinality
+ if isinstance(cases_dataset.element_spec, tuple):
+ batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy()
+ else:
+ batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy()
+
+ cardinality = cases_dataset.cardinality().numpy()
+ else:
+ # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
+ assert not isinstance(labels_dataset, tf.data.Dataset)
+ assert not isinstance(targets_dataset, tf.data.Dataset)
+ # set batch size and cardinality
+ batch_size = min(batch_size, len(cases_dataset))
+ cardinality = math.ceil(len(cases_dataset) / batch_size)
+
+ # verify cardinality and create datasets from the tensors
+ self.cases_dataset = sanitize_dataset(
+ cases_dataset, batch_size, cardinality
+ )
+ self.labels_dataset = sanitize_dataset(
+ labels_dataset, batch_size, cardinality
+ )
+ self.targets_dataset = sanitize_dataset(
+ targets_dataset, batch_size, cardinality
+ )
+
+ # if the provided `cases_dataset` has several columns
+ if isinstance(self.cases_dataset.element_spec, tuple):
+ # switch case on the number of columns of `cases_dataset`
+ if len(self.cases_dataset.element_spec) == 2:
+ assert self.labels_dataset is None, (
+                "The second column of `cases_dataset` is assumed to be the labels."
+                + " Hence, `labels_dataset` should be empty."
+ )
+ self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
+ self.cases_dataset = self.cases_dataset.map(lambda x, y: x)
+
+ elif len(self.cases_dataset.element_spec) == 3:
+ assert self.labels_dataset is None, (
+                "The second column of `cases_dataset` is assumed to be the labels."
+                + " Hence, `labels_dataset` should be empty."
+ )
+            assert self.targets_dataset is None, (
+                "The third column of `cases_dataset` is assumed to be the targets."
+                + " Hence, `targets_dataset` should be empty."
+ )
+ self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
+ self.labels_dataset = self.cases_dataset.map(lambda x, y, t: y)
+ self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
+ else:
+ raise AttributeError(
+ "`cases_dataset` cannot possess more than 3 columns,"
+                    + f" {len(self.cases_dataset.element_spec)} were detected."
+ )
+
+ self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
+ if self.labels_dataset is not None:
+ self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
+ if self.targets_dataset is not None:
+ self.targets_dataset = self.targets_dataset.prefetch(tf.data.AUTOTUNE)
+
+ return batch_size
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
index d54e85a4..351e7ba9 100644
--- a/xplique/example_based/search_methods/__init__.py
+++ b/xplique/example_based/search_methods/__init__.py
@@ -6,7 +6,6 @@
# from .sklearn_knn import SklearnKNN
from .knn import KNN
-from .prototypes_search import PrototypesSearch
from .proto_greedy_search import ProtoGreedySearch
from .proto_dash_search import ProtoDashSearch
from .mmd_critic_search import MMDCriticSearch
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 5dde1a9c..1c7c0f1b 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -80,7 +80,6 @@ class BaseSearchMethod(ABC):
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
@@ -93,7 +92,6 @@ def __init__(
self.batch_size = batch_size
self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
- self.labels_dataset = sanitize_dataset(labels_dataset, self.batch_size)
self.set_k(k)
self.set_returns(search_returns)
diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py
index cfe70941..87ad3188 100644
--- a/xplique/example_based/search_methods/mmd_critic_search.py
+++ b/xplique/example_based/search_methods/mmd_critic_search.py
@@ -14,14 +14,20 @@
class MMDCriticSearch(ProtoGreedySearch):
"""
- KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
- Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
+ MMDCritic method to search prototypes.
+
+ References:
+ .. [#] `Been Kim, Rajiv Khanna, Oluwasanmi Koyejo,
+ "Examples are not enough, learn to criticize! criticism for interpretability"
+ `_
Parameters
----------
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
For natural example-based methods it is the train dataset.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
k
The number of examples to retrieve.
search_returns
@@ -35,6 +41,13 @@ class MMDCriticSearch(ProtoGreedySearch):
Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
"Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
yielding the corresponding p-norm." We also added 'cosine'.
+ nb_prototypes : int
+ Number of prototypes to find.
+ kernel_type : str, optional
+ The kernel type. It can be 'local' or 'global', by default 'local'.
+ When it is local, the distances are calculated only within the classes.
+ kernel_fn : Callable, optional
+ Kernel function or kernel matrix, by default rbf_kernel.
"""
def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
index e29d12b8..44d6e3e9 100644
--- a/xplique/example_based/search_methods/proto_dash_search.py
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -93,6 +93,8 @@ class ProtoDashSearch(ProtoGreedySearch):
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
For natural example-based methods it is the train dataset.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
k
The number of examples to retrieve.
search_returns
@@ -108,39 +110,42 @@ class ProtoDashSearch(ProtoGreedySearch):
yielding the corresponding p-norm." We also added 'cosine'.
nb_prototypes : int
Number of prototypes to find.
- find_prototypes_kwargs
- Additional parameters passed to `find_prototypes` function.
+ kernel_type : str, optional
+ The kernel type. It can be 'local' or 'global', by default 'local'.
+ When it is local, the distances are calculated only within the classes.
+ kernel_fn : Callable, optional
+ Kernel function or kernel matrix, by default rbf_kernel.
+ use_optimizer : bool, optional
+ Flag indicating whether to use an optimizer for prototype selection, by default False.
"""
- def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel, use_optimizer: bool = False):
- """
- Search for prototypes and their corresponding weights.
-
- Parameters
- ----------
- nb_prototypes : int
- Number of prototypes to find.
- nb_prototypes : int
- Number of prototypes to find.
- kernel_type : str, optional
- The kernel type. It can be 'local' or 'global', by default 'local'.
- When it is local, the distances are calculated only within the classes.
- kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
- use_optimizer : bool, optional
- Flag indicating whether to use an optimizer for prototype selection, by default False.
-
- Returns
- -------
- prototype_indices : Tensor
- The indices of the selected prototypes.
- prototype_weights :
- The normalized weights of the selected prototypes.
- """
-
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ search_returns: Optional[Union[List[str], str]] = None,
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = None,
+ nb_prototypes: int = 1,
+ kernel_type: str = 'local',
+ kernel_fn: callable = rbf_kernel,
+ use_optimizer: bool = False,
+ ): # pylint: disable=R0801
+
self.use_optimizer = use_optimizer
- return super().find_prototypes(nb_prototypes, kernel_type, kernel_fn)
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ k=k,
+ search_returns=search_returns,
+ batch_size=batch_size,
+ distance=distance,
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ kernel_fn=kernel_fn
+ )
def update_selection_weights(self, selection_indices, selection_weights, selection_selection_kernel, best_indice, best_weights, best_objective):
"""
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index a383b33a..df3bd5a5 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -9,11 +9,12 @@
from ...commons import dataset_gather, sanitize_dataset
from ...types import Callable, List, Union, Optional, Tuple
-from .prototypes_search import PrototypesSearch
+from .base import BaseSearchMethod
+from .knn import KNN
from ..projections import Projection
-class ProtoGreedySearch(PrototypesSearch):
+class ProtoGreedySearch(BaseSearchMethod):
"""
ProtoGreedy method for searching prototypes.
@@ -27,6 +28,8 @@ class ProtoGreedySearch(PrototypesSearch):
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
For natural example-based methods it is the train dataset.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
k
The number of examples to retrieve.
search_returns
@@ -42,14 +45,144 @@ class ProtoGreedySearch(PrototypesSearch):
yielding the corresponding p-norm." We also added 'cosine'.
nb_prototypes : int
Number of prototypes to find.
- find_prototypes_kwargs
- Additional parameters passed to `find_prototypes` function.
+ kernel_type : str, optional
+ The kernel type. It can be 'local' or 'global', by default 'local'.
+ When it is local, the distances are calculated only within the classes.
+ kernel_fn : Callable, optional
+ Kernel function or kernel matrix, by default rbf_kernel.
"""
# Avoid zero division during procedure. (the value is not important, as if the denominator is
# zero, then the nominator will also be zero).
EPSILON = tf.constant(1e-6)
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ search_returns: Optional[Union[List[str], str]] = None,
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = None,
+ nb_prototypes: int = 1,
+ kernel_type: str = 'local',
+ kernel_fn: callable = rbf_kernel,
+ ): # pylint: disable=R0801
+ super().__init__(
+ cases_dataset, k, search_returns, batch_size
+ )
+
+ self.labels_dataset = sanitize_dataset(labels_dataset, self.batch_size)
+
+ if kernel_type in ['local', 'global']:
+ self.kernel_type = kernel_type
+ else:
+ raise AttributeError(
+ "The kernel_type parameter is expected to be in"
+ + " ['local', 'global'] ",
+ +f"but {kernel_type} was received.",
+ )
+
+ if hasattr(kernel_fn, "__call__"):
+ def custom_kernel_fn(x1, x2, y1=None, y2=None):
+ if self.kernel_type == 'global':
+ kernel_matrix = kernel_fn(x1,x2)
+ if isinstance(kernel_matrix, np.ndarray):
+ kernel_matrix = tf.convert_to_tensor(kernel_matrix)
+ else:
+ # In the case of a local kernel, calculations are limited to within the class.
+ # Across different classes, the kernel values are set to 0.
+ kernel_matrix = np.zeros((x1.shape[0], x2.shape[0]), dtype=np.float32)
+ y_intersect = np.intersect1d(y1, y2)
+ for i in range(y_intersect.shape[0]):
+ y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
+ y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
+ sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
+ kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
+ kernel_matrix = tf.convert_to_tensor(kernel_matrix)
+ return kernel_matrix
+
+ self.kernel_fn = custom_kernel_fn
+ else:
+ raise AttributeError(
+ "The kernel parameter is expected to be a Callable",
+ +f"but {kernel_fn} was received.",
+ )
+
+ # Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset.
+ # We take advantage of the symmetry of this matrix to traverse only its lower triangle.
+ col_sums = []
+ diag = []
+ row_sums = []
+
+ for batch_col_index, (batch_col_cases, batch_col_labels) in enumerate(
+ zip(self.cases_dataset, self.labels_dataset)
+ ):
+ batch_col_sums = tf.zeros((batch_col_cases.shape[0]))
+
+ for batch_row_index, (batch_row_cases, batch_row_labels) in enumerate(
+ zip(self.cases_dataset, self.labels_dataset)
+ ):
+ if batch_row_index < batch_col_index:
+ continue
+
+ batch_kernel = self.kernel_fn(batch_row_cases, batch_col_cases, batch_row_labels, batch_col_labels)
+
+ batch_col_sums = batch_col_sums + tf.reduce_sum(batch_kernel, axis=0)
+
+ if batch_col_index == batch_row_index:
+ if batch_col_index != 0:
+ batch_col_sums = batch_col_sums + row_sums[batch_row_index]
+
+ diag.append(tf.linalg.diag_part(batch_kernel))
+
+ if batch_col_index == 0:
+ if batch_row_index == 0:
+ row_sums.append(None)
+ else:
+ row_sums.append(tf.reduce_sum(batch_kernel, axis=1))
+ else:
+ row_sums[batch_row_index] += tf.reduce_sum(batch_kernel, axis=1)
+
+ col_sums.append(batch_col_sums)
+
+ self.col_sums = tf.concat(col_sums, axis=0)
+ self.n = self.col_sums.shape[0]
+ self.col_means = self.col_sums / self.n
+ self.diag = tf.concat(diag, axis=0)
+ self.nb_features = batch_col_cases.shape[1]
+
+ # compute the prototypes in the latent space
+ self.prototype_indices, self.prototype_cases, self.prototype_labels, self.prototype_weights = self.find_prototypes(nb_prototypes)
+
+ if distance is None:
+ def custom_distance(x1,x2):
+ x1 = tf.expand_dims(x1, axis=0)
+ x2 = tf.expand_dims(x2, axis=0)
+ distance = tf.sqrt(kernel_fn(x1,x1) - 2 * kernel_fn(x1,x2) + kernel_fn(x2,x2))
+ return distance
+ self.distance_fn = custom_distance
+ elif hasattr(distance, "__call__"):
+ self.distance_fn = distance
+ elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
+ distance, int
+ ):
+ self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
+ else:
+ raise AttributeError(
+ "The distance parameter is expected to be either a Callable or in"
+ + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
+ +f"but {distance} was received.",
+ )
+
+ self.knn = KNN(
+ cases_dataset=self.prototype_cases,
+ k=k,
+ search_returns=search_returns,
+ batch_size=batch_size,
+ distance=self.distance_fn
+ )
+
def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
"""
Compute the objective and its weights for each candidate.
@@ -144,104 +277,7 @@ def update_selection_weights(self, selection_indices, selection_weights, selecti
return selection_weights
- def compute_kernel_attributes(self, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel):
- """
- Compute the attributes of the class that are related to the kernel.
-
- Parameters
- ----------
- kernel_type : str, optional
- The kernel type. It can be 'local' or 'global', by default 'local'.
- When it is local, the distances are calculated only within the classes.
- kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
-
- Returns
- -------
- selection_weights : Tensor
- Updated weights corresponding to the selected prototypes.
- """
- if kernel_type in ['local', 'global']:
- self.kernel_type = kernel_type
- else:
- raise AttributeError(
- "The kernel_type parameter is expected to be in"
- + " ['local', 'global'] ",
- +f"but {kernel_type} was received.",
- )
-
- if hasattr(kernel_fn, "__call__"):
- def custom_kernel_fn(x1, x2, y1, y2):
- if self.kernel_type == 'global':
- kernel_matrix = kernel_fn(x1,x2)
- if isinstance(kernel_matrix, np.ndarray):
- kernel_matrix = tf.convert_to_tensor(kernel_matrix)
- else:
- # In the case of a local kernel, calculations are limited to within the class.
- # Across different classes, the kernel values are set to 0.
- kernel_matrix = np.zeros((x1.shape[0], x2.shape[0]), dtype=np.float32)
- y_intersect = np.intersect1d(y1, y2)
- for i in range(y_intersect.shape[0]):
- y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
- y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
- sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
- kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
- kernel_matrix = tf.convert_to_tensor(kernel_matrix)
- return kernel_matrix
-
- self.kernel_fn = custom_kernel_fn
- else:
- raise AttributeError(
- "The kernel parameter is expected to be a Callable",
- +f"but {kernel_fn} was received.",
- )
-
- # TODO: for local explanation add the ability to compute distance_fn based on the kernel
-
- # Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset.
- # We take advantage of the symmetry of this matrix to traverse only its lower triangle.
- col_sums = []
- diag = []
- row_sums = []
-
- for batch_col_index, (batch_col_cases, batch_col_labels) in enumerate(
- zip(self.cases_dataset, self.labels_dataset)
- ):
- batch_col_sums = tf.zeros((batch_col_cases.shape[0]))
-
- for batch_row_index, (batch_row_cases, batch_row_labels) in enumerate(
- zip(self.cases_dataset, self.labels_dataset)
- ):
- if batch_row_index < batch_col_index:
- continue
-
- batch_kernel = self.kernel_fn(batch_row_cases, batch_col_cases, batch_row_labels, batch_col_labels)
-
- batch_col_sums = batch_col_sums + tf.reduce_sum(batch_kernel, axis=0)
-
- if batch_col_index == batch_row_index:
- if batch_col_index != 0:
- batch_col_sums = batch_col_sums + row_sums[batch_row_index]
-
- diag.append(tf.linalg.diag_part(batch_kernel))
-
- if batch_col_index == 0:
- if batch_row_index == 0:
- row_sums.append(None)
- else:
- row_sums.append(tf.reduce_sum(batch_kernel, axis=1))
- else:
- row_sums[batch_row_index] += tf.reduce_sum(batch_kernel, axis=1)
-
- col_sums.append(batch_col_sums)
-
- self.col_sums = tf.concat(col_sums, axis=0)
- self.n = self.col_sums.shape[0]
- self.col_means = self.col_sums / self.n
- self.diag = tf.concat(diag, axis=0)
- self.nb_features = batch_col_cases.shape[1]
-
- def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn: callable = rbf_kernel):
+ def find_prototypes(self, nb_prototypes):
"""
Search for prototypes and their corresponding weights.
@@ -249,22 +285,19 @@ def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn:
----------
nb_prototypes : int
Number of prototypes to find.
- kernel_type : str, optional
- The kernel type. It can be 'local' or 'global', by default 'local'.
- When it is local, the distances are calculated only within the classes.
- kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
Returns
-------
prototype_indices : Tensor
The indices of the selected prototypes.
+ prototype_cases : Tensor
+ The cases of the selected prototypes.
+ prototype_labels : Tensor
+ The labels of the selected prototypes.
prototype_weights :
The normalized weights of the selected prototypes.
"""
- self.compute_kernel_attributes(kernel_type, kernel_fn)
-
# Tensors to store selected indices and their corresponding cases, labels and weights.
selection_indices = tf.constant([], dtype=tf.int32)
selection_cases = tf.zeros((0, self.nb_features), dtype=tf.float32)
@@ -362,9 +395,45 @@ def find_prototypes(self, nb_prototypes, kernel_type: str = 'local', kernel_fn:
k += 1
prototype_indices = selection_indices
+ prototype_cases = selection_cases
+ prototype_labels = selection_labels
prototype_weights = selection_weights
# Normalize the weights
prototype_weights = prototype_weights / tf.reduce_sum(prototype_weights)
- return prototype_indices, prototype_weights
\ No newline at end of file
+ return prototype_indices, prototype_cases, prototype_labels, prototype_weights
+
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
+ """
+ Search the samples to return as examples. Called by the explain methods.
+ It may also return the indices corresponding to the samples,
+ based on `return_indices` value.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Assumed to have been already projected.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ """
+
+ # look for closest prototypes to projected inputs
+ knn_output = self.knn(inputs)
+
+ # obtain closest prototypes indices with respect to the prototypes
+ indices_wrt_prototypes = knn_output["indices"]
+
+ # convert to unique indices
+ indices_wrt_prototypes = indices_wrt_prototypes[:, :, 0] * self.batch_size + indices_wrt_prototypes[:, :, 1]
+
+ # get prototypes indices with respect to the dataset
+ indices = tf.gather(self.prototype_indices, indices_wrt_prototypes)
+
+ # convert back to batch-element indices
+ batch_indices, elem_indices = indices // self.batch_size, indices % self.batch_size
+ indices = tf.stack([batch_indices, elem_indices], axis=-1)
+
+ knn_output["indices"] = indices
+
+ return knn_output
\ No newline at end of file
diff --git a/xplique/example_based/search_methods/prototypes_search.py b/xplique/example_based/search_methods/prototypes_search.py
deleted file mode 100644
index 6db01b3c..00000000
--- a/xplique/example_based/search_methods/prototypes_search.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""
-Prototypes search method in example-based module
-"""
-
-from abc import ABC, abstractmethod
-import numpy as np
-from sklearn.metrics.pairwise import rbf_kernel
-import tensorflow as tf
-
-from ...commons import dataset_gather
-from ...types import Callable, List, Union, Optional, Tuple
-
-from .base import BaseSearchMethod
-from ..projections import Projection
-
-
-class PrototypesSearch(BaseSearchMethod):
- """
- Prototypes search method to find prototypes and the examples closest to these prototypes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- For natural example-based methods it is the train dataset.
- k
- The number of examples to retrieve.
- search_returns
- String or list of string with the elements to return in `self.find_examples()`.
- See `self.set_returns()` for detail.
- batch_size
- Number of sample treated simultaneously.
- It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
- distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
- nb_prototypes : int
- Number of prototypes to find.
- find_prototypes_kwargs
- Additional parameters passed to `find_prototypes` function.
- """
-
- # Avoid zero division during procedure. (the value is not important, as if the denominator is
- # zero, then the nominator will also be zero).
- EPSILON = tf.constant(1e-6)
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- search_returns: Optional[Union[List[str], str]] = None,
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = "euclidean",
- nb_prototypes: int = 1,
- **find_prototypes_kwargs,
- ): # pylint: disable=R0801
- super().__init__(
- cases_dataset, labels_dataset, k, search_returns, batch_size
- )
-
- if hasattr(distance, "__call__"):
- self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
- distance, int
- ):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
- else:
- raise AttributeError(
- "The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
- +f"but {distance} was received.",
- )
-
- self.prototype_indices, self.prototype_weights = self.find_prototypes(nb_prototypes, **find_prototypes_kwargs)
-
- @abstractmethod
- def find_prototypes(self, nb_prototypes: int, **find_prototypes_kwargs):
- """
- Search for prototypes and their corresponding weights.
-
- Parameters
- ----------
- nb_prototypes : int
- Number of prototypes to find.
-
- find_prototypes_kwargs
- Additional parameters passed to `find_prototypes` function.
-
- Returns
- -------
- prototype_indices : Tensor
- The indices of the selected prototypes.
- prototype_weights :
- The normalized weights of the selected prototypes.
- """
- return NotImplementedError()
-
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
- """
- Search the samples to return as examples. Called by the explain methods.
- It may also return the indices corresponding to the samples,
- based on `return_indices` value.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Assumed to have been already projected.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- """
- # TODO: Find examples: here we provide a local explanation.
- # Find the nearest prototypes to inputs
- # we use self.distance_fn and self.prototype_indices.
- examples_indices = None
- examples_distances = None
-
- # Set values in return dict
- return_dict = {}
- if "examples" in self.returns:
- return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
- if "include_inputs" in self.returns:
- inputs = tf.expand_dims(inputs, axis=1)
- return_dict["examples"] = tf.concat(
- [inputs, return_dict["examples"]], axis=1
- )
- if "indices" in self.returns:
- return_dict["indices"] = examples_indices
- if "distances" in self.returns:
- return_dict["distances"] = examples_distances
-
- # Return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
- return return_dict
From 46494360b4132547679d4b931cbd0c8e18f6621f Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Tue, 30 Apr 2024 16:44:42 +0200
Subject: [PATCH 037/138] add Prototypes fix up
---
tests/example_based/test_prototypes.py | 63 +++++---------
.../search_methods/mmd_critic_search.py | 13 ++-
.../search_methods/proto_dash_search.py | 20 ++---
.../search_methods/proto_greedy_search.py | 85 ++++++++++++-------
4 files changed, 88 insertions(+), 93 deletions(-)
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
index 720c47ee..8a31b24d 100644
--- a/tests/example_based/test_prototypes.py
+++ b/tests/example_based/test_prototypes.py
@@ -11,7 +11,6 @@
import time
import numpy as np
-from sklearn.metrics.pairwise import rbf_kernel
import tensorflow as tf
from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
@@ -25,13 +24,14 @@
def test_proto_greedy_basic():
"""
- Test the SimilarExamples with an identity projection.
+ Test the Prototypes with an identity projection.
"""
# Setup
k = 3
nb_prototypes = 3
gamma = 0.026
x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
# x_train, y_train = load_data('usps')
# x_test, y_test = load_data('usps.t')
# x_test = tf.random.shuffle(x_test)
@@ -41,13 +41,6 @@ def test_proto_greedy_basic():
space_projection=lambda inputs, targets=None: inputs
)
- def custom_kernel_wrapper(gamma):
- def custom_kernel(x,y=None):
- return rbf_kernel(x,y,gamma)
- return custom_kernel
-
- kernel_fn = custom_kernel_wrapper(gamma)
-
kernel_type = "global"
# Method initialization
@@ -57,10 +50,10 @@ def custom_kernel(x,y=None):
k=k,
projection=identity_projection,
batch_size=32,
- distance="euclidean",
+ distance=None, #"euclidean",
nb_prototypes=nb_prototypes,
kernel_type=kernel_type,
- kernel_fn=kernel_fn,
+ gamma=gamma,
)
# Generate global explanation
@@ -92,24 +85,25 @@ def custom_kernel(x,y=None):
# Check if all indices are between 0 and x_train.shape[0]-1
assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+ # Generate local explanation
+ examples = method.explain(x_test)
+
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'proto_greedy')
- # # Generate local explanation
- # examples = method.explain(x_test)
-
# # Visualize local explanation
# plot_local_explanation(examples, x_test, 'proto_greedy')
def test_proto_dash_basic():
"""
- Test the SimilarExamples with an identity projection.
+ Test the Prototypes with an identity projection.
"""
# Setup
k = 3
nb_prototypes = 3
gamma = 0.026
x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
# x_train, y_train = load_data('usps')
# x_test, y_test = load_data('usps.t')
# x_test = tf.random.shuffle(x_test)
@@ -119,13 +113,6 @@ def test_proto_dash_basic():
space_projection=lambda inputs, targets=None: inputs
)
- def custom_kernel_wrapper(gamma):
- def custom_kernel(x,y=None):
- return rbf_kernel(x,y,gamma)
- return custom_kernel
-
- kernel_fn = custom_kernel_wrapper(gamma)
-
kernel_type = "global"
# Method initialization
@@ -138,7 +125,7 @@ def custom_kernel(x,y=None):
distance="euclidean",
nb_prototypes=nb_prototypes,
kernel_type=kernel_type,
- kernel_fn=kernel_fn,
+ gamma=gamma,
)
# Generate global explanation
@@ -170,24 +157,25 @@ def custom_kernel(x,y=None):
# Check if all indices are between 0 and x_train.shape[0]-1
assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+ # Generate local explanation
+ examples = method.explain(x_test)
+
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'proto_dash')
- # # Generate local explanation
- # examples = method.explain(x_test)
-
# # Visualize local explanation
# plot_local_explanation(examples, x_test, 'proto_dash')
def test_mmd_critic_basic():
"""
- Test the SimilarExamples with an identity projection.
+ Test the Prototypes with an identity projection.
"""
# Setup
k = 3
nb_prototypes = 3
gamma = 0.026
x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
+ x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
# x_train, y_train = load_data('usps')
# x_test, y_test = load_data('usps.t')
# x_test = tf.random.shuffle(x_test)
@@ -197,13 +185,6 @@ def test_mmd_critic_basic():
space_projection=lambda inputs, targets=None: inputs
)
- def custom_kernel_wrapper(gamma):
- def custom_kernel(x,y=None):
- return rbf_kernel(x,y,gamma)
- return custom_kernel
-
- kernel_fn = custom_kernel_wrapper(gamma)
-
kernel_type = "global"
# Method initialization
@@ -216,7 +197,7 @@ def custom_kernel(x,y=None):
distance="euclidean",
nb_prototypes=nb_prototypes,
kernel_type=kernel_type,
- kernel_fn=kernel_fn,
+ gamma=gamma,
)
# Generate global explanation
@@ -248,15 +229,15 @@ def custom_kernel(x,y=None):
# Check if all indices are between 0 and x_train.shape[0]-1
assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+ # Generate local explanation
+ examples = method.explain(x_test)
+
# # Visualize all prototypes
# plot(prototypes, prototype_weights, 'mmd_critic')
- # # Generate local explanation
- # examples = method.explain(x_test)
-
# # Visualize local explanation
# plot_local_explanation(examples, x_test, 'mmd_critic')
-test_proto_greedy_basic()
-test_proto_dash_basic()
-test_mmd_critic_basic()
+# test_proto_greedy_basic()
+# test_proto_dash_basic()
+# test_mmd_critic_basic()
diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py
index 87ad3188..7465fcfb 100644
--- a/xplique/example_based/search_methods/mmd_critic_search.py
+++ b/xplique/example_based/search_methods/mmd_critic_search.py
@@ -47,10 +47,13 @@ class MMDCriticSearch(ProtoGreedySearch):
The kernel type. It can be 'local' or 'global', by default 'local'.
When it is local, the distances are calculated only within the classes.
kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
+ Kernel function, by default the rbf kernel.
+ This function must only use TensorFlow operations.
+ gamma : float, optional
+ Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
"""
- def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
+ def compute_objectives(self, selection_indices, selection_cases, selection_weights, selection_selection_kernel, candidates_indices, candidates_selection_kernel):
"""
Compute the objective function and corresponding weights for a given set of selected prototypes and a candidate.
@@ -70,18 +73,12 @@ def compute_objectives(self, selection_indices, selection_cases, selection_label
Indices corresponding to the selected prototypes.
selection_cases : Tensor
Cases corresponding to the selected prototypes.
- selection_labels : Tensor
- Labels corresponding to the selected prototypes.
selection_weights : Tensor
Weights corresponding to the selected prototypes.
selection_selection_kernel : Tensor
Kernel matrix computed from the selected prototypes.
candidates_indices : Tensor
Indices corresponding to the candidate prototypes.
- candidates_cases : Tensor
- Cases corresponding to the candidate prototypes.
- candidates_labels : Tensor
- Labels corresponding to the candidate prototypes.
candidates_selection_kernel : Tensor
Kernel matrix between the candidates and the selected prototypes.
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
index 44d6e3e9..5bb7b78b 100644
--- a/xplique/example_based/search_methods/proto_dash_search.py
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -3,7 +3,6 @@
"""
import numpy as np
-from sklearn.metrics.pairwise import rbf_kernel
from scipy.optimize import minimize
import tensorflow as tf
@@ -114,7 +113,10 @@ class ProtoDashSearch(ProtoGreedySearch):
The kernel type. It can be 'local' or 'global', by default 'local'.
When it is local, the distances are calculated only within the classes.
kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
+ Kernel function, by default the rbf kernel.
+ This function must only use TensorFlow operations.
+ gamma : float, optional
+ Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
use_optimizer : bool, optional
Flag indicating whether to use an optimizer for prototype selection, by default False.
"""
@@ -129,7 +131,8 @@ def __init__(
distance: Union[int, str, Callable] = None,
nb_prototypes: int = 1,
kernel_type: str = 'local',
- kernel_fn: callable = rbf_kernel,
+ kernel_fn: callable = None,
+ gamma: float = None,
use_optimizer: bool = False,
): # pylint: disable=R0801
@@ -144,7 +147,8 @@ def __init__(
distance=distance,
nb_prototypes=nb_prototypes,
kernel_type=kernel_type,
- kernel_fn=kernel_fn
+ kernel_fn=kernel_fn,
+ gamma=gamma
)
def update_selection_weights(self, selection_indices, selection_weights, selection_selection_kernel, best_indice, best_weights, best_objective):
@@ -197,7 +201,7 @@ def update_selection_weights(self, selection_indices, selection_weights, selecti
return selection_weights
- def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
+ def compute_objectives(self, selection_indices, selection_cases, selection_weights, selection_selection_kernel, candidates_indices, candidates_selection_kernel):
"""
Compute the objective function and corresponding weights for a given set of selected prototypes and a candidate.
Calculate the gradient of l(w) = w^T * μ_p - 1/2 * w^T * K * w
@@ -211,18 +215,12 @@ def compute_objectives(self, selection_indices, selection_cases, selection_label
Indices corresponding to the selected prototypes.
selection_cases : Tensor
Cases corresponding to the selected prototypes.
- selection_labels : Tensor
- Labels corresponding to the selected prototypes.
selection_weights : Tensor
Weights corresponding to the selected prototypes.
selection_selection_kernel : Tensor
Kernel matrix computed from the selected prototypes.
candidates_indices : Tensor
Indices corresponding to the candidate prototypes.
- candidates_cases : Tensor
- Cases corresponding to the candidate prototypes.
- candidates_labels : Tensor
- Labels corresponding to the candidate prototypes.
candidates_selection_kernel : Tensor
Kernel matrix between the candidates and the selected prototypes.
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index df3bd5a5..a86f610d 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -3,7 +3,6 @@
"""
import numpy as np
-from sklearn.metrics.pairwise import rbf_kernel
import tensorflow as tf
from ...commons import dataset_gather, sanitize_dataset
@@ -49,7 +48,10 @@ class ProtoGreedySearch(BaseSearchMethod):
The kernel type. It can be 'local' or 'global', by default 'local'.
When it is local, the distances are calculated only within the classes.
kernel_fn : Callable, optional
- Kernel function or kernel matrix, by default rbf_kernel.
+ Kernel function, by default the rbf kernel.
+ This function must only use TensorFlow operations.
+ gamma : float, optional
+ Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
"""
# Avoid zero division during procedure. (the value is not important, as if the denominator is
@@ -66,7 +68,8 @@ def __init__(
distance: Union[int, str, Callable] = None,
nb_prototypes: int = 1,
kernel_type: str = 'local',
- kernel_fn: callable = rbf_kernel,
+ kernel_fn: callable = None,
+ gamma: float = None
): # pylint: disable=R0801
super().__init__(
cases_dataset, k, search_returns, batch_size
@@ -82,7 +85,27 @@ def __init__(
+ " ['local', 'global'] ",
+f"but {kernel_type} was received.",
)
+
+ if kernel_fn is None:
+ # define rbf kernel function
+ def rbf_kernel(X, Y=None, gamma=None):
+ if Y is None:
+ Y = X
+
+ if gamma is None:
+ gamma = 1.0 / tf.cast(tf.shape(X)[1], dtype=X.dtype)
+
+ X = tf.expand_dims(X, axis=1)
+ Y = tf.expand_dims(Y, axis=0)
+
+ pairwise_diff = X - Y
+ pairwise_sq_dist = tf.reduce_sum(tf.square(pairwise_diff), axis=-1)
+ kernel_matrix = tf.exp(-gamma * pairwise_sq_dist)
+
+ return kernel_matrix
+ kernel_fn = lambda x, y: rbf_kernel(x,y,gamma)
+
if hasattr(kernel_fn, "__call__"):
def custom_kernel_fn(x1, x2, y1=None, y2=None):
if self.kernel_type == 'global':
@@ -105,10 +128,32 @@ def custom_kernel_fn(x1, x2, y1=None, y2=None):
self.kernel_fn = custom_kernel_fn
else:
raise AttributeError(
- "The kernel parameter is expected to be a Callable",
+ "The kernel_fn parameter is expected to be a Callable",
+f"but {kernel_fn} was received.",
- )
+ )
+
+ if distance is None:
+ def kernel_induced_distance(x1,x2):
+ x1 = tf.expand_dims(x1, axis=0)
+ x2 = tf.expand_dims(x2, axis=0)
+ distance = tf.squeeze(tf.sqrt(kernel_fn(x1,x1) - 2 * kernel_fn(x1,x2) + kernel_fn(x2,x2)))
+ return distance
+
+ self.distance_fn = lambda x1, x2: kernel_induced_distance(x1,x2)
+ elif hasattr(distance, "__call__"):
+ self.distance_fn = distance
+ elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
+ distance, int
+ ):
+ self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
+ else:
+ raise AttributeError(
+ "The distance parameter is expected to be either a Callable or in"
+ + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
+ +f"but {distance} was received.",
+ )
+
# Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset.
# We take advantage of the symmetry of this matrix to traverse only its lower triangle.
col_sums = []
@@ -155,26 +200,6 @@ def custom_kernel_fn(x1, x2, y1=None, y2=None):
# compute the prototypes in the latent space
self.prototype_indices, self.prototype_cases, self.prototype_labels, self.prototype_weights = self.find_prototypes(nb_prototypes)
- if distance is None:
- def custom_distance(x1,x2):
- x1 = tf.expand_dims(x1, axis=0)
- x2 = tf.expand_dims(x2, axis=0)
- distance = tf.sqrt(kernel_fn(x1,x1) - 2 * kernel_fn(x1,x2) + kernel_fn(x2,x2))
- return distance
- self.distance_fn = custom_distance
- elif hasattr(distance, "__call__"):
- self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
- distance, int
- ):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
- else:
- raise AttributeError(
- "The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
- +f"but {distance} was received.",
- )
-
self.knn = KNN(
cases_dataset=self.prototype_cases,
k=k,
@@ -183,7 +208,7 @@ def custom_distance(x1,x2):
distance=self.distance_fn
)
- def compute_objectives(self, selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel):
+ def compute_objectives(self, selection_indices, selection_cases, selection_weights, selection_selection_kernel, candidates_indices, candidates_selection_kernel):
"""
Compute the objective and its weights for each candidate.
@@ -193,18 +218,12 @@ def compute_objectives(self, selection_indices, selection_cases, selection_label
Indices corresponding to the selected prototypes.
selection_cases : Tensor
Cases corresponding to the selected prototypes.
- selection_labels : Tensor
- Labels corresponding to the selected prototypes.
selection_weights : Tensor
Weights corresponding to the selected prototypes.
selection_selection_kernel : Tensor
Kernel matrix computed from the selected prototypes.
candidates_indices : Tensor
Indices corresponding to the candidate prototypes.
- candidates_cases : Tensor
- Cases corresponding to the candidate prototypes.
- candidates_labels : Tensor
- Labels corresponding to the candidate prototypes.
candidates_selection_kernel : Tensor
Kernel matrix between the candidates and the selected prototypes.
@@ -353,7 +372,7 @@ def find_prototypes(self, nb_prototypes):
all_candidates_last_selected_kernel = tf.tensor_scatter_nd_update(all_candidates_last_selected_kernel, tf.expand_dims(candidates_indices, axis=1), tf.squeeze(candidates_last_selected_kernel, axis=1))
# Compute the objectives for the batch
- objectives, objectives_weights = self.compute_objectives(selection_indices, selection_cases, selection_labels, selection_weights, selection_selection_kernel, candidates_indices, candidates_cases, candidates_labels, candidates_selection_kernel)
+ objectives, objectives_weights = self.compute_objectives(selection_indices, selection_cases, selection_weights, selection_selection_kernel, candidates_indices, candidates_selection_kernel)
# Select the best objective in the batch
objectives_argmax = tf.argmax(objectives)
From a62fd8ea55f4938b69230f5fcc79cea4e8523540 Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Tue, 30 Apr 2024 17:01:55 +0200
Subject: [PATCH 038/138] add Prototypes fix up
---
xplique/example_based/prototypes.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 906b6bca..29946c22 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -133,7 +133,7 @@ def get_global_prototypes(self):
Returns:
prototype_indices : Tensor
prototype indices.
- prototype_indices : Tensor
+ prototype_weights : Tensor
prototype weights.
"""
return self.search_method.prototype_indices, self.search_method.prototype_weights
From e839e24ff857cea4dc02a4f21759bffd1bd5acdc Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:18:33 +0200
Subject: [PATCH 039/138] feat: change the private initialize_cases_dataset
method to a protected one
---
xplique/example_based/base_example_method.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index df8ac306..9ce6b154 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -90,7 +90,7 @@ def __init__(
), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."
# set attributes
- batch_size = self.__initialize_cases_dataset(
+ batch_size = self._initialize_cases_dataset(
cases_dataset, labels_dataset, targets_dataset, batch_size
)
@@ -126,7 +126,7 @@ def __init__(
**search_method_kwargs,
)
- def __initialize_cases_dataset(
+ def _initialize_cases_dataset(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
From b3ebe4a31c451a30bd387a735d19b336216a4f70 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:21:01 +0200
Subject: [PATCH 040/138] feat: change the fill value to np.inf when gathering
elements of a dataset from indices, such that indices to -1, -1 create inf
valued examples
---
xplique/commons/tf_dataset_operations.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index 69e750eb..f74f4ea2 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -206,9 +206,7 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
example = next(iter(dataset))
# (n, bs, ...)
results = tf.Variable(
- tf.zeros(
- indices.shape[:-1] + example[0].shape, dtype=dataset.element_spec.dtype
- )
+ tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
)
nb_results = product(indices.shape[:-1])
From 335dcb95406d472407574cc36b20f2c8f1321a44 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:25:27 +0200
Subject: [PATCH 041/138] feat: add the kleor search methods and their tests
---
tests/example_based/test_kleor.py | 212 +++++++++++++++
.../example_based/search_methods/__init__.py | 2 +-
xplique/example_based/search_methods/kleor.py | 241 ++++++++++++++++++
3 files changed, 454 insertions(+), 1 deletion(-)
create mode 100644 tests/example_based/test_kleor.py
create mode 100644 xplique/example_based/search_methods/kleor.py
diff --git a/tests/example_based/test_kleor.py b/tests/example_based/test_kleor.py
new file mode 100644
index 00000000..f4965f8d
--- /dev/null
+++ b/tests/example_based/test_kleor.py
@@ -0,0 +1,212 @@
+"""
+Tests for the KLEOR search methods.
+"""
+import tensorflow as tf
+import numpy as np
+
+from xplique.example_based.search_methods import KLEORSimMiss, KLEORGlobalSim
+
+def test_kleor_base_and_sim_miss():
+ """
+ Test suite for both the BaseKLEOR and KLEORSimMiss class. Indeed, the KLEORSimMiss class is a subclass of the
+ BaseKLEOR class with a very basic implementation of the only abstract method (identity function).
+ """
+ # setup the tests
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ # build the kleor object
+ kleor = KLEORSimMiss(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
+
+ # test the _filter_fn method
+ fake_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+ fake_cases_targets = tf.constant([[0, 1], [1, 0], [0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+ # the mask should be True when the targets are the same i.e we keep those cases
+ expected_mask = tf.constant([[True, False, True, False, False],
+ [False, True, False, True, True],
+ [False, True, False, True, True],
+ [True, False, True, False, False],
+ [False, True, False, True, True]], dtype=tf.bool)
+ mask = kleor._filter_fn(inputs, cases, fake_targets, fake_cases_targets)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
+
+ # test the _filter_fn_nun method, this time the mask should be True when the targets are different
+ expected_mask = tf.constant([[False, True, False, True, True],
+ [True, False, True, False, False],
+ [True, False, True, False, False],
+ [False, True, False, True, True],
+ [True, False, True, False, False]], dtype=tf.bool)
+ mask = kleor._filter_fn_nun(inputs, cases, fake_targets, fake_cases_targets)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
+
+ # test the _get_nuns method
+ nuns, nuns_distances = kleor._get_nuns(inputs, targets)
+ expected_nuns = tf.constant([
+ [[2., 3.]],
+ [[1., 2.]],
+ [[4., 5.]]], dtype=tf.float32)
+ expected_nuns_distances = tf.constant([
+ [np.sqrt(2*0.5**2)],
+ [np.sqrt(2*1.5**2)],
+ [np.sqrt(2*0.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+ assert tf.reduce_all(tf.equal(nuns_distances, expected_nuns_distances))
+
+ # test the _initialize_search method
+ sf_indices, input_sf_distances, nun_sf_distances, batch_indices = kleor._initialize_search(inputs)
+ assert sf_indices.shape == (3, 1, 2) # (n, k, 2)
+ assert input_sf_distances.shape == (3, 1) # (n, k)
+ assert nun_sf_distances.shape == (3, 1) # (n, k)
+ assert batch_indices.shape == (3, 2) # (n, bs)
+ expected_sf_indices = tf.constant([[[-1, -1]],[[-1, -1]],[[-1, -1]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(sf_indices, expected_sf_indices))
+ assert tf.reduce_all(tf.math.is_inf(input_sf_distances))
+ assert tf.reduce_all(tf.math.is_inf(nun_sf_distances))
+ expected_batch_indices = tf.constant([[0, 1], [0, 1], [0, 1]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(batch_indices, expected_batch_indices))
+
+ # test the kneighbors method
+ input_sf_distances, sf_indices, nuns = kleor.kneighbors(inputs, targets)
+
+ assert input_sf_distances.shape == (3, 1) # (n, k)
+ assert sf_indices.shape == (3, 1, 2) # (n, k, 2)
+ assert nuns.shape == (3, 1, 2) # (n, k, 2)
+
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+
+ expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ assert tf.reduce_all(tf.abs(input_sf_distances - expected_distances) < 1e-5)
+
+ expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(sf_indices, expected_indices))
+
+ # test the find_examples method
+ return_dict = kleor.find_examples(inputs, targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"])
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+ indices = return_dict["indices"]
+ nuns = return_dict["nuns"]
+
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+ assert tf.reduce_all(tf.equal(expected_indices, indices))
+ assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [1., 2.]],
+ [[2.5, 3.5], [2., 3.]],
+ [[4.5, 5.5], [3., 4.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+def test_kleor_global_sim():
+ """
+    Test suite for the KLEORGlobalSim class. As only the kneighbors and format_output methods are
+    impacted by the _additional_filtering method, we test those methods.
+ """
+ # setup the tests
+ cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+
+ inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
+ targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+
+ # build the kleor object
+ kleor = KLEORGlobalSim(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
+
+ # test the _additionnal_filtering method
+ # (n, bs)
+ fake_nun_sf_distances = tf.constant([[1., 2.], [2., 3.], [3., 4.]])
+ # (n, bs)
+ fake_input_sf_distances = tf.constant([[2., 1.], [3., 2.], [2., 5.]])
+ # (n,1)
+ fake_nuns_input_distances = tf.constant([[3.], [1.], [4.]])
+ # the expected filtering should be such that we keep the distance of a sf candidates
+ # when the input is closer to the sf than the nun, otherwise we set it to infinity
+ expected_nun_sf_distances = tf.constant([[1., 2.], [np.inf, np.inf], [3., np.inf]], dtype=tf.float32)
+ expected_input_sf_distances = tf.constant([[2., 1.], [np.inf, np.inf], [2., np.inf]], dtype=tf.float32)
+
+ nun_sf_distances, input_sf_distances = kleor._additional_filtering(fake_nun_sf_distances, fake_input_sf_distances, fake_nuns_input_distances)
+ assert nun_sf_distances.shape == (3, 2)
+ assert input_sf_distances.shape == (3, 2)
+
+ inf_mask_expected_nun_sf = tf.math.is_inf(expected_nun_sf_distances)
+ inf_mask_nun_sf = tf.math.is_inf(nun_sf_distances)
+ assert tf.reduce_all(tf.equal(inf_mask_expected_nun_sf, inf_mask_nun_sf))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_nun_sf, 0.0, nun_sf_distances) - tf.where(inf_mask_expected_nun_sf, 0.0, expected_nun_sf_distances)
+ ) < 1e-5)
+
+ inf_mask_expected_input_sf = tf.math.is_inf(expected_input_sf_distances)
+ inf_mask_input_sf = tf.math.is_inf(input_sf_distances)
+ assert tf.reduce_all(tf.equal(inf_mask_expected_input_sf, inf_mask_input_sf))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_input_sf, 0.0, input_sf_distances) - tf.where(inf_mask_expected_input_sf, 0.0, expected_input_sf_distances)
+ ) < 1e-5)
+
+ # test the kneighbors method
+ input_sf_distances, sf_indices, nuns = kleor.kneighbors(inputs, targets)
+
+ expected_nuns = tf.constant([
+ [[2., 3.]],
+ [[1., 2.]],
+ [[4., 5.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+
+ assert input_sf_distances.shape == (3, 1) # (n, k)
+ assert sf_indices.shape == (3, 1, 2) # (n, k, 2)
+
+ expected_indices = tf.constant([[[-1, -1]],[[0, 1]],[[-1, -1]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(sf_indices, expected_indices))
+
+ expected_distances = tf.constant([[kleor.fill_value], [np.sqrt(2*0.5**2)], [kleor.fill_value]], dtype=tf.float32)
+
+ # create masks for inf values
+ inf_mask_input = tf.math.is_inf(input_sf_distances)
+ inf_mask_expected = tf.math.is_inf(expected_distances)
+ assert tf.reduce_all(tf.equal(inf_mask_input, inf_mask_expected))
+
+ # compare finite values
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_input, 0.0, input_sf_distances) - tf.where(inf_mask_expected, 0.0, expected_distances)
+ ) < 1e-5)
+
+ # test the find_examples
+ return_dict = kleor.find_examples(inputs, targets)
+
+ indices = return_dict["indices"]
+ nuns = return_dict["nuns"]
+ distances = return_dict["distances"]
+ examples = return_dict["examples"]
+
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+ assert tf.reduce_all(tf.equal(expected_indices, indices))
+
+ # create masks for inf values
+ inf_mask_dist = tf.math.is_inf(distances)
+ assert tf.reduce_all(tf.equal(inf_mask_dist, inf_mask_expected))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_dist, 0.0, distances) - tf.where(inf_mask_expected, 0.0, expected_distances)
+ ) < 1e-5)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [np.inf, np.inf]],
+ [[2.5, 3.5], [2., 3.]],
+ [[4.5, 5.5], [np.inf, np.inf]]], dtype=tf.float32)
+
+ # mask for inf values
+ inf_mask_examples = tf.math.is_inf(examples)
+ inf_mask_expected_examples = tf.math.is_inf(expected_examples)
+ assert tf.reduce_all(tf.equal(inf_mask_examples, inf_mask_expected_examples))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples)
+ ) < 1e-5)
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
index 010b7cb3..3c7897c5 100644
--- a/xplique/example_based/search_methods/__init__.py
+++ b/xplique/example_based/search_methods/__init__.py
@@ -4,5 +4,5 @@
from .base import BaseSearchMethod, ORDER
-# from .sklearn_knn import SklearnKNN
from .knn import BaseKNN, KNN, FilterKNN
+from .kleor import KLEORSimMiss, KLEORGlobalSim
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
new file mode 100644
index 00000000..380d668a
--- /dev/null
+++ b/xplique/example_based/search_methods/kleor.py
@@ -0,0 +1,241 @@
+"""
+Define the KLEOR search method.
+"""
+from abc import abstractmethod, ABC
+
+import numpy as np
+import tensorflow as tf
+
+from ...commons import dataset_gather
+from ...types import Callable, List, Union, Optional, Tuple
+
+from .base import ORDER
+from .knn import FilterKNN
+
+class BaseKLEOR(FilterKNN, ABC):
+ """
+ Base class for the KLEOR search methods.
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ k: int = 1,
+ search_returns: Optional[Union[List[str], str]] = None,
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ): # pylint: disable=R0801
+ possibilities = ["examples", "indices", "distances", "include_inputs", "nuns"]
+ super().__init__(
+ cases_dataset = cases_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ filter_fn=self._filter_fn,
+ search_returns=search_returns,
+ batch_size=batch_size,
+ distance=distance,
+ order=ORDER.ASCENDING,
+ possibilities=possibilities
+ )
+
+ self.search_nuns = FilterKNN(
+ cases_dataset=cases_dataset,
+ targets_dataset=targets_dataset,
+ k=1,
+ filter_fn=self._filter_fn_nun,
+ search_returns=["indices", "distances"],
+ batch_size=batch_size,
+ distance=distance,
+ order = ORDER.ASCENDING,
+ )
+
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ """
+ Search the samples to return as examples. Called by the explain methods.
+ It may also return the indices corresponding to the samples,
+ based on `return_indices` value.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Assumed to have been already projected.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ """
+ # compute neighbors
+ examples_distances, examples_indices, nuns = self.kneighbors(inputs, targets)
+
+ # Set values in return dict
+ return_dict = {}
+ if "examples" in self.returns:
+ return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
+ # replace examples for which indices is -1, -1 by an inf value
+ # mask = tf.reduce_all(tf.equal(examples_indices, -1), axis=-1)
+ # return_dict["examples"] = tf.where(
+ # tf.expand_dims(mask, axis=-1),
+ # tf.fill(return_dict["examples"].shape, tf.constant(np.inf, dtype=tf.float32)),
+ # return_dict["examples"],
+ # )
+ if "include_inputs" in self.returns:
+ inputs = tf.expand_dims(inputs, axis=1)
+ return_dict["examples"] = tf.concat(
+ [inputs, return_dict["examples"]], axis=1
+ )
+ if "nuns" in self.returns:
+ return_dict["nuns"] = nuns
+ if "indices" in self.returns:
+ return_dict["indices"] = examples_indices
+ if "distances" in self.returns:
+ return_dict["distances"] = examples_distances
+
+ # Return a dict only different variables are returned
+ if len(return_dict) == 1:
+ return list(return_dict.values())[0]
+ return return_dict
+
+ def _filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
+ """
+ """
+ # get the labels predicted by the model
+ # (n, )
+ predicted_labels = tf.argmax(targets, axis=-1)
+ label_targets = tf.argmax(cases_targets, axis=-1)
+ # for each input, if the target label is the same as the cases label
+        # the mask has a True value and False otherwise
+ mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets)
+ return mask
+
+ def _filter_fn_nun(self, _, __, targets, cases_targets) -> tf.Tensor:
+ """
+ Filter function to mask the cases for which the label is different from the predicted
+ label on the inputs.
+ """
+ # get the labels predicted by the model
+ # (n, )
+ predicted_labels = tf.argmax(targets, axis=-1)
+ label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
+ # for each input, if the target label is the same as the predicted label
+        # the mask has a False value and True otherwise
+ mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
+ return mask
+
+ def _get_nuns(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
+ """
+ """
+ nuns_dict = self.search_nuns(inputs, targets)
+ nuns_indices, nuns_distances = nuns_dict["indices"], nuns_dict["distances"]
+ nuns = dataset_gather(self.cases_dataset, nuns_indices)
+ return nuns, nuns_distances
+
+ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
+ """
+ """
+ # get the Nearest Unlike Neighbors and their distance to the related input
+ nuns, nuns_input_distances = self._get_nuns(inputs, targets)
+
+ # initialize the search for the KLEOR semi-factual methods
+ sf_indices, input_sf_distances, nun_sf_distances, batch_indices = self._initialize_search(inputs)
+
+ # iterate on batches
+ for batch_index, (cases, cases_targets) in enumerate(zip(self.cases_dataset, self.targets_dataset)):
+ # add new elements
+ # (n, current_bs, 2)
+ indices = batch_indices[:, : tf.shape(cases)[0]]
+ new_indices = tf.stack(
+ [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1
+ )
+
+ # get filter masks
+ # (n, current_bs)
+ filter_mask = self.filter_fn(inputs, cases, targets, cases_targets)
+
+ # compute distances
+ # (n, current_bs)
+ b_nun_sf_distances = self._crossed_distances_fn(nuns, cases, mask=filter_mask)
+ b_input_sf_distances = self._crossed_distances_fn(inputs, cases, mask=filter_mask)
+
+ # additional filtering
+ b_nun_sf_distances, b_input_sf_distances = self._additional_filtering(
+ b_nun_sf_distances, b_input_sf_distances, nuns_input_distances
+ )
+ # concatenate distances and indices
+            # (n, k+current_bs, 2)
+ concatenated_indices = tf.concat([sf_indices, new_indices], axis=1)
+            # (n, k+current_bs)
+ concatenated_nun_sf_distances = tf.concat([nun_sf_distances, b_nun_sf_distances], axis=1)
+ concatenated_input_sf_distances = tf.concat([input_sf_distances, b_input_sf_distances], axis=1)
+
+ # sort according to the smallest distances between sf and nun
+ # (n, k)
+ sort_order = tf.argsort(
+ concatenated_nun_sf_distances, axis=1, direction=self.order.name.upper()
+ )[:, : self.k]
+
+ sf_indices.assign(
+ tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1)
+ )
+ nun_sf_distances.assign(
+ tf.gather(concatenated_nun_sf_distances, sort_order, axis=1, batch_dims=1)
+ )
+ input_sf_distances.assign(
+ tf.gather(concatenated_input_sf_distances, sort_order, axis=1, batch_dims=1)
+ )
+
+ return input_sf_distances, sf_indices, nuns
+
+ def _initialize_search(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Variable, tf.Variable, tf.Variable, tf.Tensor]:
+ """
+ Initialize the search for the KLEOR semi-factual methods.
+ """
+ nb_inputs = tf.shape(inputs)[0]
+
+ # sf_indices shape (n, k, 2)
+ sf_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1))
+ # (n, k)
+ input_sf_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value))
+ nun_sf_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value))
+ # (n, bs)
+ batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0)
+ batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1))
+ return sf_indices, input_sf_distances, nun_sf_distances, batch_indices
+
+ @abstractmethod
+ def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ """
+ Additional filtering to apply to the distances.
+ """
+ raise NotImplementedError
+
+class KLEORSimMiss(BaseKLEOR):
+ """
+    KLEOR SimMiss search method.
+
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of cases.
+ targets_dataset
+        Dataset of targets. Should be a one-hot encoding of the predicted class
+ """
+ def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple:
+ return nun_sf_distances, input_sf_distances
+
+class KLEORGlobalSim(BaseKLEOR):
+ """
+    KLEOR GlobalSim search method.
+
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of cases.
+ targets_dataset
+        Dataset of targets. Should be a one-hot encoding of the predicted class
+ """
+ def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple:
+ # filter non acceptable cases, i.e. cases for which the distance to the input is greater
+ # than the distance between the input and its nun
+ # (n, current_bs)
+ mask = tf.less(input_sf_distances, nuns_input_distances)
+ nun_sf_distances = tf.where(mask, nun_sf_distances, self.fill_value)
+ input_sf_distances = tf.where(mask, input_sf_distances, self.fill_value)
+ return nun_sf_distances, input_sf_distances
From fe5cc44b5e00ad0e8a92e1340575a8a39b61af84 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:28:28 +0200
Subject: [PATCH 042/138] feat: add the possibilities as an initialisation args
---
xplique/example_based/search_methods/base.py | 8 +++++---
xplique/example_based/search_methods/knn.py | 12 ++++++++----
2 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index a165688d..a7bf4e02 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -92,6 +92,7 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
# set batch size
@@ -103,7 +104,7 @@ def __init__(
self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
self.set_k(k)
- self.set_returns(search_returns)
+ self.set_returns(search_returns, possibilities)
# set targets_dataset
if targets_dataset is not None:
@@ -125,7 +126,7 @@ def set_k(self, k: int):
assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
self.k = k
- def set_returns(self, returns: Optional[Union[List[str], str]] = None):
+ def set_returns(self, returns: Optional[Union[List[str], str]] = None, possibilities: Optional[List[str]] = None):
"""
Set `self.returns` used to define returned elements in `self.find_examples()`.
@@ -143,7 +144,8 @@ def set_returns(self, returns: Optional[Union[List[str], str]] = None):
- 'include_inputs' specify if inputs should be included in the returned elements.
Note that it changes the number of returned elements from k to k+1.
"""
- possibilities = ["examples", "indices", "distances", "include_inputs"]
+ if possibilities is None:
+ possibilities = ["examples", "indices", "distances", "include_inputs"]
default = "examples"
self.returns = _sanitize_returns(returns, possibilities, default)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index 5291999a..e53833cd 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -23,9 +23,10 @@ def __init__(
batch_size: Optional[int] = 32,
order: ORDER = ORDER.ASCENDING,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ possibilities: Optional[List[str]] = None,
):
super().__init__(
- cases_dataset, k, search_returns, batch_size, targets_dataset
+ cases_dataset, k, search_returns, batch_size, targets_dataset, possibilities
)
assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
@@ -130,9 +131,10 @@ def __init__(
distance: Union[int, str, Callable] = "euclidean",
order: ORDER = ORDER.ASCENDING,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset, possibilities
)
if hasattr(distance, "__call__"):
@@ -236,6 +238,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], _ = None) -> Tuple[tf
class FilterKNN(BaseKNN):
"""
+ TODO: Change the class description
KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
@@ -271,10 +274,11 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
- order: ORDER = ORDER.ASCENDING
+ order: ORDER = ORDER.ASCENDING,
+ possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset, possibilities
)
if hasattr(distance, "__call__"):
From 67b78047891014286be24b71f994e7cea08e60b1 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:29:18 +0200
Subject: [PATCH 043/138] feat: add the KLEOR example based method and its
tests
---
tests/example_based/test_contrastive.py | 292 ++++++++++++------
xplique/example_based/contrastive_examples.py | 206 ++++++++----
2 files changed, 344 insertions(+), 154 deletions(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index bac1aaa2..82a47d60 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -1,14 +1,12 @@
"""
Tests for the contrastive methods.
"""
-import pytest
-
import tensorflow as tf
import numpy as np
-from xplique.example_based import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals, NaiveCounterFactuals
+from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEOR
-def test_naive_semi_factuals():
+def test_naive_counter_factuals():
"""
"""
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
@@ -16,21 +14,21 @@ def test_naive_semi_factuals():
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- semi_factuals = NaiveSemiFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
+ counter_factuals = NaiveCounterFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
- mask = semi_factuals.filter_fn(inputs, cases, targets, cases_targets)
+ mask = counter_factuals.filter_fn(inputs, cases, targets, cases_targets)
assert mask.shape == (inputs.shape[0], cases.shape[0])
expected_mask = tf.constant([
- [True, False, False, True, False],
[False, True, True, False, True],
- [False, True, True, False, True]], dtype=tf.bool)
+ [True, False, False, True, False],
+ [True, False, False, True, False]], dtype=tf.bool)
assert tf.reduce_all(tf.equal(mask, expected_mask))
- return_dict = semi_factuals(inputs, targets)
+ return_dict = counter_factuals(inputs, targets)
assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
examples = return_dict["examples"]
@@ -42,159 +40,259 @@ def test_naive_semi_factuals():
assert indices.shape == (3, 2, 2) # (n, k, 2)
expected_examples = tf.constant([
- [[1.5, 2.5], [4., 5.], [1., 2.]],
- [[2.5, 3.5], [5., 6.], [2., 3.]],
- [[4.5, 5.5], [2., 3.], [3., 4.]]], dtype=tf.float32)
+ [[1.5, 2.5], [2., 3.], [3., 4.]],
+ [[2.5, 3.5], [1., 2.], [4., 5.]],
+ [[4.5, 5.5], [4., 5.], [1., 2.]]], dtype=tf.float32)
assert tf.reduce_all(tf.equal(examples, expected_examples))
- expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ expected_distances = tf.constant([[np.sqrt(2*0.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*1.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2), np.sqrt(2*3.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- expected_indices = tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)
+ expected_indices = tf.constant([[[0, 1], [1, 0]],[[0, 0], [1, 1]],[[1, 1], [0, 0]]], dtype=tf.int32)
assert tf.reduce_all(tf.equal(indices, expected_indices))
-def test_labelaware_semifactuals():
+def test_label_aware_cf():
"""
+ Test suite for the LabelAwareCounterFactuals class
"""
+ # Same tests as the previous one but with the LabelAwareCounterFactuals class
+ # thus we only need to use cf_targets = 1 - targets of the previous tests
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+ counter_factuals = LabelAwareCounterFactuals(cases_dataset, cases_targets_dataset, k=1, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
- targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
-
- semi_factuals = PredictedLabelAwareSemiFactuals(cases_dataset, cases_targets_dataset, target_label=0, k=2, batch_size=2, case_returns=["examples", "distances", "include_inputs"])
- # assert the filtering on the right label went right
-
- combined_dataset = tf.data.Dataset.zip((cases_dataset.unbatch(), cases_targets_dataset.unbatch()))
- combined_dataset = combined_dataset.filter(lambda x, y: tf.equal(tf.argmax(y, axis=-1),0))
-
- filter_cases = semi_factuals.cases_dataset
- filter_targets = semi_factuals.targets_dataset
-
- expected_filter_cases = tf.constant([[2., 3.], [3., 4.], [5., 6.]], dtype=tf.float32)
- expected_filter_targets = tf.constant([[1, 0], [1, 0], [1, 0]], dtype=tf.float32)
+ cf_targets = tf.constant([[1, 0], [0, 1], [0, 1]], dtype=tf.float32)
- tensor_filter_cases = []
- for elem in filter_cases.unbatch():
- tensor_filter_cases.append(elem)
- tensor_filter_cases = tf.stack(tensor_filter_cases)
- assert tf.reduce_all(tf.equal(tensor_filter_cases, expected_filter_cases))
-
- tensor_filter_targets = []
- for elem in filter_targets.unbatch():
- tensor_filter_targets.append(elem)
- tensor_filter_targets = tf.stack(tensor_filter_targets)
- assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
+ mask = counter_factuals.filter_fn(inputs, cases, cf_targets, cases_targets)
+ assert mask.shape == (inputs.shape[0], cases.shape[0])
- # check the call method
- filter_inputs = tf.constant([[2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
- filter_targets = tf.constant([[1, 0], [1, 0]], dtype=tf.float32)
+ expected_mask = tf.constant([
+ [False, True, True, False, True],
+ [True, False, False, True, False],
+ [True, False, False, True, False]], dtype=tf.bool)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
- return_dict = semi_factuals(filter_inputs, filter_targets)
- assert set(return_dict.keys()) == set(["examples", "distances"])
+ return_dict = counter_factuals(inputs, cf_targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
examples = return_dict["examples"]
distances = return_dict["distances"]
+ indices = return_dict["indices"]
- assert examples.shape == (2, 3, 2) # (n_label0, k+1, W)
- assert distances.shape == (2, 2) # (n_label0, k)
+ assert examples.shape == (3, 2, 2) # (n, k+1, W)
+ assert distances.shape == (3, 1) # (n, k)
+ assert indices.shape == (3, 1, 2) # (n, k, 2)
expected_examples = tf.constant([
- [[2.5, 3.5], [5., 6.], [2., 3.]],
- [[4.5, 5.5], [2., 3.], [3., 4.]]], dtype=tf.float32)
+ [[1.5, 2.5], [2., 3.]],
+ [[2.5, 3.5], [1., 2.]],
+ [[4.5, 5.5], [4., 5.]]], dtype=tf.float32)
assert tf.reduce_all(tf.equal(examples, expected_examples))
- expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
+ expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- # check an error is raised when a target does not match the target label
- with pytest.raises(AssertionError):
- semi_factuals(inputs, targets)
+ expected_indices = tf.constant([[[0, 1]],[[0, 0]],[[1, 1]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(indices, expected_indices))
- # same but with the other label
- semi_factuals = PredictedLabelAwareSemiFactuals(cases_dataset, cases_targets_dataset, target_label=1, k=2, batch_size=2, case_returns=["examples", "distances", "include_inputs"])
- filter_cases = semi_factuals.cases_dataset
- filter_targets = semi_factuals.targets_dataset
+ # Now let's dive when multiple classes are available in 1D
+ cases = tf.constant([[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.], [10.]], dtype=tf.float32)
+ cases_targets = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=tf.float32)
- expected_filter_cases = tf.constant([[1., 2.], [4., 5.]], dtype=tf.float32)
- expected_filter_targets = tf.constant([[0, 1], [0, 1]], dtype=tf.float32)
+ cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+ cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- tensor_filter_cases = []
- for elem in filter_cases.unbatch():
- tensor_filter_cases.append(elem)
- tensor_filter_cases = tf.stack(tensor_filter_cases)
- assert tf.reduce_all(tf.equal(tensor_filter_cases, expected_filter_cases))
+ counter_factuals = LabelAwareCounterFactuals(cases_dataset, cases_targets_dataset, k=1, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
- tensor_filter_targets = []
- for elem in filter_targets.unbatch():
- tensor_filter_targets.append(elem)
- tensor_filter_targets = tf.stack(tensor_filter_targets)
- assert tf.reduce_all(tf.equal(tensor_filter_targets, expected_filter_targets))
+ inputs = tf.constant([[1.5], [2.5], [4.5], [6.5], [8.5]], dtype=tf.float32)
+ cf_targets = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]], dtype=tf.float32)
- # check the call method
- filter_inputs = tf.constant([[1.5, 2.5]], dtype=tf.float32)
- filter_targets = tf.constant([[0, 1]], dtype=tf.float32)
+ mask = counter_factuals.filter_fn(inputs, cases, cf_targets, cases_targets)
+ assert mask.shape == (inputs.shape[0], cases.shape[0])
- return_dict = semi_factuals(filter_inputs, filter_targets)
- assert set(return_dict.keys()) == set(["examples", "distances"])
+ expected_mask = tf.constant([
+ [False, True, False, True, True, False, False, False, False, True],
+ [True, False, False, False, False, False, True, False, True, False],
+ [False, False, True, False, False, True, False, True, False, False],
+ [False, False, True, False, False, True, False, True, False, False],
+ [True, False, False, False, False, False, True, False, True, False]], dtype=tf.bool)
+ assert tf.reduce_all(tf.equal(mask, expected_mask))
+
+ return_dict = counter_factuals(inputs, cf_targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
examples = return_dict["examples"]
distances = return_dict["distances"]
+ indices = return_dict["indices"]
- assert examples.shape == (1, 3, 2) # (n_label1, k+1, W)
- assert distances.shape == (1, 2) # (n_label1, k)
+ assert examples.shape == (5, 2, 1) # (n, k+1, W)
+ assert distances.shape == (5, 1) # (n, k)
+ assert indices.shape == (5, 1, 2) # (n, k, 2)
expected_examples = tf.constant([
- [[1.5, 2.5], [4., 5.], [1., 2.]]], dtype=tf.float32)
+ [[1.5], [2.]],
+ [[2.5], [1.]],
+ [[4.5], [3.]],
+ [[6.5], [6.]],
+ [[8.5], [9.]]], dtype=tf.float32)
assert tf.reduce_all(tf.equal(examples, expected_examples))
- expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)]], dtype=tf.float32)
+ expected_distances = tf.constant([[np.sqrt(0.5**2)], [np.sqrt(1.5**2)], [np.sqrt(1.5**2)], [np.sqrt(0.5**2)], [np.sqrt(0.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
-def test_naive_counter_factuals():
+ expected_indices = tf.constant([[[0, 1]],[[0, 0]],[[1, 0]],[[2, 1]],[[4, 0]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(indices, expected_indices))
+
+def test_kleor():
"""
+ Test suite for the KLEOR class
"""
+ # setup the tests
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- counter_factuals = NaiveCounterFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
- mask = counter_factuals.filter_fn(inputs, cases, targets, cases_targets)
- assert mask.shape == (inputs.shape[0], cases.shape[0])
-
- expected_mask = tf.constant([
- [False, True, True, False, True],
- [True, False, False, True, False],
- [True, False, False, True, False]], dtype=tf.bool)
- assert tf.reduce_all(tf.equal(mask, expected_mask))
+ # start when strategy is sim_miss
+ kleor_sim_miss = KLEOR(
+ cases_dataset,
+ cases_targets_dataset,
+ k=1,
+ case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
+ batch_size=2,
+ strategy="sim_miss"
+ )
- return_dict = counter_factuals(inputs, targets)
- assert set(return_dict.keys()) == set(["examples", "indices", "distances"])
+ return_dict = kleor_sim_miss(inputs, targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"])
examples = return_dict["examples"]
distances = return_dict["distances"]
indices = return_dict["indices"]
+ nuns = return_dict["nuns"]
- assert examples.shape == (3, 3, 2) # (n, k+1, W)
- assert distances.shape == (3, 2) # (n, k)
- assert indices.shape == (3, 2, 2) # (n, k, 2)
+ expected_nuns = tf.constant([
+ [[2., 3.]],
+ [[1., 2.]],
+ [[4., 5.]]], dtype=tf.float32)
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+
+ assert examples.shape == (3, 2, 2) # (n, k+1, W)
+ assert distances.shape == (3, 1) # (n, k)
+ assert indices.shape == (3, 1, 2) # (n, k, 2)
expected_examples = tf.constant([
- [[1.5, 2.5], [2., 3.], [3., 4.]],
- [[2.5, 3.5], [1., 2.], [4., 5.]],
- [[4.5, 5.5], [4., 5.], [1., 2.]]], dtype=tf.float32)
+ [[1.5, 2.5], [1., 2.]],
+ [[2.5, 3.5], [2., 3.]],
+ [[4.5, 5.5], [3., 4.]]], dtype=tf.float32)
assert tf.reduce_all(tf.equal(examples, expected_examples))
- expected_distances = tf.constant([[np.sqrt(2*0.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*1.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2), np.sqrt(2*3.5**2)]], dtype=tf.float32)
+ expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- expected_indices = tf.constant([[[0, 1], [1, 0]],[[0, 0], [1, 1]],[[1, 1], [0, 0]]], dtype=tf.int32)
+ expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32)
+ assert tf.reduce_all(tf.equal(indices, expected_indices))
+
+ # now strategy is global_sim
+ kleor_global_sim = KLEOR(
+ cases_dataset,
+ cases_targets_dataset,
+ k=1,
+ case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
+ batch_size=2,
+ strategy="global_sim"
+ )
+
+ return_dict = kleor_global_sim(inputs, targets)
+ assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"])
+
+ nuns = return_dict["nuns"]
+ assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+
+ examples = return_dict["examples"]
+ distances = return_dict["distances"]
+ indices = return_dict["indices"]
+
+ assert examples.shape == (3, 2, 2) # (n, k+1, W)
+ assert distances.shape == (3, 1) # (n, k)
+ assert indices.shape == (3, 1, 2) # (n, k, 2)
+
+ expected_indices = tf.constant([[[-1, -1]],[[0, 1]],[[-1, -1]]], dtype=tf.int32)
assert tf.reduce_all(tf.equal(indices, expected_indices))
+
+ expected_distances = tf.constant([[np.inf], [np.sqrt(2*0.5**2)], [np.inf]], dtype=tf.float32)
+ # create masks for inf values
+ inf_mask_dist = tf.math.is_inf(distances)
+ inf_mask_expected_distances = tf.math.is_inf(expected_distances)
+ assert tf.reduce_all(tf.equal(inf_mask_dist, inf_mask_expected_distances))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_dist, 0.0, distances) - tf.where(inf_mask_expected_distances, 0.0, expected_distances)
+ ) < 1e-5)
+
+ expected_examples = tf.constant([
+ [[1.5, 2.5], [np.inf, np.inf]],
+ [[2.5, 3.5], [2., 3.]],
+ [[4.5, 5.5], [np.inf, np.inf]]], dtype=tf.float32)
+ # mask for inf values
+ inf_mask_examples = tf.math.is_inf(examples)
+ inf_mask_expected_examples = tf.math.is_inf(expected_examples)
+ assert tf.reduce_all(tf.equal(inf_mask_examples, inf_mask_expected_examples))
+ assert tf.reduce_all(
+ tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples)
+ ) < 1e-5)
+
+# def test_kleor_global_sim():
+# """
+# Test suite for the KleorSimMiss class
+# """
+# cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
+# cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
+
+# cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
+# cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
+# semi_factuals = KLEOR(
+# cases_dataset,
+# cases_targets_dataset,
+# k=1,
+# case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
+# batch_size=2,
+# strategy="global_sim"
+# )
+
+# return_dict = semi_factuals(inputs, targets)
+# assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"])
+
+# examples = return_dict["examples"]
+# distances = return_dict["distances"]
+# indices = return_dict["indices"]
+# nuns = return_dict["nuns"]
+
+# expected_nuns = tf.constant([
+# [[2., 3.]],
+# [[1., 2.]],
+# [[4., 5.]]], dtype=tf.float32)
+# assert tf.reduce_all(tf.equal(nuns, expected_nuns))
+
+# assert examples.shape == (3, 2, 2) # (n, k+1, W)
+# assert distances.shape == (3, 1) # (n, k)
+# assert indices.shape == (3, 1, 2) # (n, k, 2)
+
+# expected_examples = tf.constant([
+# [[1.5, 2.5], [1., 2.]],
+# [[2.5, 3.5], [2., 3.]],
+# [[4.5, 5.5], [3., 4.]]], dtype=tf.float32)
+# assert tf.reduce_all(tf.equal(examples, expected_examples))
+
+# expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32)
+# assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
+
+# expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32)
+# assert tf.reduce_all(tf.equal(indices, expected_indices))
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 83b03b11..0b996f89 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -1,19 +1,25 @@
"""
Implementation of both counterfactuals and semi factuals methods for classification tasks.
+
+SM CF guided to be implemented (I think): KLEOR at least Sim-Miss and Global-Sim
+SM CF free to be implemented: MDN but has to be adapted, Local-Region Model??
"""
import numpy as np
import tensorflow as tf
-from ..types import Callable, List, Optional, Union
+from ..types import Callable, List, Optional, Union, Dict
+from ..commons import sanitize_inputs_targets
from .base_example_method import BaseExampleMethod
-from .search_methods import BaseSearchMethod, KNN, ORDER, FilterKNN
+from .search_methods import ORDER, FilterKNN, KLEORSimMiss, KLEORGlobalSim
from .projections import Projection
-class NaiveSemiFactuals(BaseExampleMethod):
+from .search_methods.base import _sanitize_returns
+
+class NaiveCounterFactuals(BaseExampleMethod):
"""
- Define a naive version of semi factuals search. That for a given sample
- it will return the farthest sample which have the same label.
+ This class allows searching for counterfactuals by finding the closest sample that does not have the same label.
+ It is a naive method as it follows a greedy approach.
"""
def __init__(
self,
@@ -42,7 +48,7 @@ def __init__(
batch_size=batch_size,
distance=distance,
filter_fn=self.filter_fn,
- order = ORDER.DESCENDING
+ order = ORDER.ASCENDING
)
@@ -58,19 +64,18 @@ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
# for each input, if the target label is the same as the predicted label
# the mask as a True value and False otherwise
label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
- mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
+ mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
return mask
-class PredictedLabelAwareSemiFactuals(BaseExampleMethod):
+class LabelAwareCounterFactuals(BaseExampleMethod):
"""
- As we know semi-factuals should belong to the same class as the input,
- we propose here a method that is dedicated to a specific label.
+ This method searches for counterfactuals with a specific label. This label should be provided by the user in the
+ cf_labels_dataset argument.
"""
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- target_label: int,
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
projection: Union[Projection, Callable] = None,
@@ -78,23 +83,11 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
):
- # filter the cases dataset and targets dataset to keep only the ones
- # that have the target label
- # TODO: improve this unbatch and batch
- combined_dataset = tf.data.Dataset.zip((cases_dataset.unbatch(), targets_dataset.unbatch()))
- combined_dataset = combined_dataset.filter(lambda x, y: tf.equal(tf.argmax(y, axis=-1),target_label))
-
- # separate the cases and targets
- cases_dataset = combined_dataset.map(lambda x, y: x).batch(batch_size)
- targets_dataset = combined_dataset.map(lambda x, y: y).batch(batch_size)
-
- # delete the combined dataset
- del combined_dataset
+ search_method = FilterKNN
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
-
- search_method = KNN
+ # TODO: add a warning here if it is a custom projection that requires using targets as it might mismatch with the explain
super().__init__(
cases_dataset=cases_dataset,
@@ -106,24 +99,60 @@ def __init__(
case_returns=case_returns,
batch_size=batch_size,
distance=distance,
- order = ORDER.DESCENDING
+ filter_fn=self.filter_fn,
+ order = ORDER.ASCENDING
)
- self.target_label = target_label
+ def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
+ """
+ Filter function to mask the cases for which the label is different from the label(s) expected for the
+ counterfactuals.
+
+ Parameters
+ ----------
+ cf_targets
+ TODO
+ cases_targets
+ TODO
+ """
+ mask = tf.matmul(cf_targets, cases_targets, transpose_b=True) #(n, bs)
+ # TODO: I think some retracing are done here
+ mask = tf.cast(mask, dtype=tf.bool)
+ return mask
- def __call__(
+ @sanitize_inputs_targets
+ def explain(
self,
inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ cf_targets: Union[tf.Tensor, np.ndarray],
):
- # assert targets are all the same as the target label
- if targets is not None:
- assert tf.reduce_all(tf.argmax(targets, axis=-1) == self.target_label), "All targets should be the same as the target label."
- return super().__call__(inputs, targets)
+ """
+ Compute examples to explain the inputs.
+ It projects inputs with `self.projection` in the search space
+ and finds examples with `self.search_method`.
-class NaiveCounterFactuals(BaseExampleMethod):
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ cf_targets
+ TODO: change the description here
+
+ Returns
+ -------
+ return_dict
+ Dictionary with listed elements in `self.returns`.
+ If only one element is present it returns the element.
+ The elements that can be returned are:
+ examples, weights, distances, indices, and labels.
+ """
+ # TODO make an assert on the cf_targets
+ return super().explain(inputs, cf_targets)
+
+class KLEOR(BaseExampleMethod):
"""
-
"""
def __init__(
self,
@@ -135,38 +164,101 @@ def __init__(
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
+ strategy: str = "sim_miss",
):
- search_method = FilterKNN
+
+ self.k = k
+ self.set_returns(case_returns)
+
+ if strategy == "global_sim":
+ search_method = KLEORGlobalSim
+ elif strategy == "sim_miss":
+ search_method = KLEORSimMiss
+ else:
+ raise ValueError("strategy should be either 'global_sim' or 'sim_miss'.")
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- search_method=search_method,
+ # set attributes
+ batch_size = super()._initialize_cases_dataset(
+ cases_dataset, labels_dataset, targets_dataset, batch_size
+ )
+
+ assert hasattr(projection, "__call__"), "projection should be a callable."
+
+ # check projection type
+ if isinstance(projection, Projection):
+ self.projection = projection
+ elif hasattr(projection, "__call__"):
+ self.projection = Projection(get_weights=None, space_projection=projection)
+ else:
+ raise AttributeError(
+ "projection should be a `Projection` or a `Callable`, not a"
+ + f"{type(projection)}"
+ )
+
+ # project dataset
+ projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
+ self.targets_dataset)
+
+ # set `search_returns` if not provided and overwrite it otherwise
+ if isinstance(case_returns, list) and ("nuns" in case_returns):
+ search_method_returns = ["indices", "distances", "nuns"]
+ else:
+ search_method_returns = ["indices", "distances"]
+
+ # initiate search_method
+ self.search_method = search_method(
+ cases_dataset=projected_cases_dataset,
+ targets_dataset=self.targets_dataset,
k=k,
- projection=projection,
- case_returns=case_returns,
+ search_returns=search_method_returns,
batch_size=batch_size,
distance=distance,
- filter_fn=self.filter_fn,
- order = ORDER.ASCENDING
)
-
- def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
+ def set_returns(self, returns: Union[List[str], str]):
"""
- Filter function to mask the cases for which the label is different from the predicted
- label on the inputs.
+ Set `self.returns` used to define returned elements in `self.explain()`.
+
+ Parameters
+ ----------
+ returns
+ Most elements are useful in `xplique.plots.plot_examples()`.
+ `returns` can be set to 'all' for all possible elements to be returned.
+ - 'examples' correspond to the expected examples,
+ the inputs may be included in first position. (n, k(+1), ...)
+ - 'weights' the weights in the input space used in the projection.
+ They are associated to the input and the examples. (n, k(+1), ...)
+ - 'distances' the distances between the inputs and the corresponding examples.
+ They are associated to the examples. (n, k, ...)
+ - 'labels' if provided through `dataset_labels`,
+ they are the labels associated with the examples. (n, k, ...)
+ - 'include_inputs' specify if inputs should be included in the returned elements.
+ Note that it changes the number of returned elements from k to k+1.
"""
- # get the labels predicted by the model
- # (n, )
- predicted_labels = tf.argmax(targets, axis=-1)
+ possibilities = ["examples", "weights", "distances", "labels", "include_inputs", "nuns"]
+ default = "examples"
+ self.returns = _sanitize_returns(returns, possibilities, default)
- # for each input, if the target label is the same as the predicted label
- # the mask as a True value and False otherwise
- label_targets = tf.argmax(cases_targets, axis=-1) # (bs,)
- mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
- return mask
+ def format_search_output(
+ self,
+ search_output: Dict[str, tf.Tensor],
+ inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ ):
+ """
+ """
+ return_dict = super().format_search_output(search_output, inputs, targets)
+ if "nuns" in self.returns:
+ if isinstance(return_dict, dict):
+ return_dict["nuns"] = search_output["nuns"]
+ else:
+ # find the other only key
+ other_key = [k for k in self.returns if k != "nuns"][0]
+ return_dict = {
+ other_key: return_dict,
+ "nuns": search_output["nuns"]
+ }
+ return return_dict
From c6be204c930ae25af17bdca9b9906a280554ed19 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 15:30:07 +0200
Subject: [PATCH 044/138] feat: add the kleor method in the package init
---
xplique/example_based/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 89e08d1b..5e23b9d6 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,4 +4,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
-from .contrastive_examples import NaiveSemiFactuals, PredictedLabelAwareSemiFactuals, NaiveCounterFactuals
+from .contrastive_examples import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEOR
From aa9a50001846d35506478e350bde3d205cd60839 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Tue, 21 May 2024 17:45:23 +0200
Subject: [PATCH 045/138] fix: change the set_returns and set_k methods to
properties with a setter for easier factorization. Some changes to improve
factorization and update of tests accordingly
---
tests/example_based/test_knn.py | 4 +-
tests/example_based/test_similar_examples.py | 15 ++--
xplique/example_based/base_example_method.py | 90 ++++++++++---------
xplique/example_based/contrastive_examples.py | 88 +++++-------------
xplique/example_based/search_methods/base.py | 47 +++++-----
xplique/example_based/search_methods/kleor.py | 31 ++-----
xplique/example_based/search_methods/knn.py | 24 ++---
7 files changed, 120 insertions(+), 179 deletions(-)
diff --git a/tests/example_based/test_knn.py b/tests/example_based/test_knn.py
index 63d4d504..61740a0e 100644
--- a/tests/example_based/test_knn.py
+++ b/tests/example_based/test_knn.py
@@ -106,14 +106,14 @@ def test_base_find_examples():
search_returns = returns,
)
return_dict = mock_knn.find_examples(inputs)
- assert return_dict.shape == (5, 3, 3)
+ assert return_dict["examples"].shape == (5, 3, 3)
mock_knn = MockKNN(
tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32),
k = 2,
)
return_dict = mock_knn.find_examples(inputs)
- assert return_dict.shape == (5, 2, 3)
+ assert return_dict["examples"].shape == (5, 2, 3)
def test_knn_init():
"""
diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py
index 2ec371d3..db4af594 100644
--- a/tests/example_based/test_similar_examples.py
+++ b/tests/example_based/test_similar_examples.py
@@ -12,12 +12,10 @@
import numpy as np
import tensorflow as tf
-from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
-from xplique.types import Union
+from xplique.commons import are_dataset_first_elems_equal
from xplique.example_based import SimilarExamples
-from xplique.example_based.projections import Projection, LatentSpaceProjection
-from xplique.example_based.search_methods import KNN
+from xplique.example_based.projections import Projection
from tests.utils import almost_equal
@@ -154,7 +152,7 @@ def test_similar_examples_basic():
)
# Generate explanation
- examples = method.explain(x_test)
+ examples = method.explain(x_test)["examples"]
# Verifications
# Shape should be (n, k, h, w, c)
@@ -198,9 +196,8 @@ def test_similar_examples_return_multiple_elements():
distance="euclidean",
)
- method.set_returns("all")
-
- method.set_k(k)
+ method.returns = "all"
+ method.k = k
# Generate explanation
method_output = method.explain(x_test)
@@ -278,7 +275,7 @@ def test_similar_examples_weighting():
)
# Generate explanation
- examples = method.explain(x_test)
+ examples = method.explain(x_test)["examples"]
# Verifications
# Shape should be (n, k, h, w, c)
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 9ce6b154..f45bdc91 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -72,6 +72,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
search_method_kwargs
Parameters to be passed at the construction of the `search_method`.
"""
+ _returns_possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
def __init__(
self,
@@ -94,9 +95,7 @@ def __init__(
cases_dataset, labels_dataset, targets_dataset, batch_size
)
- self.k = k
- self.set_returns(case_returns)
-
+ self._search_returns = ["indices", "distances"]
assert hasattr(projection, "__call__"), "projection should be a callable."
# check projection type
@@ -115,7 +114,7 @@ def __init__(
self.targets_dataset)
# set `search_returns` if not provided and overwrite it otherwise
- search_method_kwargs["search_returns"] = ["indices", "distances"]
+ search_method_kwargs["search_returns"] = self._search_returns
# initiate search_method
self.search_method = search_method(
@@ -125,6 +124,49 @@ def __init__(
targets_dataset=self.targets_dataset,
**search_method_kwargs,
)
+ self.k = k
+ self.returns = case_returns
+
+ @property
+ def k(self) -> int:
+ """Getter for the k parameter."""
+ return self._k
+
+ @k.setter
+ def k(self, k: int):
+ """Setter for the k parameter."""
+ assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
+ self._k = k
+ self.search_method.k = k
+
+ @property
+ def returns(self) -> Union[List[str], str]:
+ """Getter for the returns parameter."""
+ return self._returns
+
+ @returns.setter
+ def returns(self, returns: Union[List[str], str]):
+ """
+ Setter for the returns parameter used to define returned elements in `self.explain()`.
+
+ Parameters
+ ----------
+ returns
+ Most elements are useful in `xplique.plots.plot_examples()`.
+ `returns` can be set to 'all' for all possible elements to be returned.
+ - 'examples' correspond to the expected examples,
+ the inputs may be included in first position. (n, k(+1), ...)
+ - 'weights' the weights in the input space used in the projection.
+ They are associated to the input and the examples. (n, k(+1), ...)
+ - 'distances' the distances between the inputs and the corresponding examples.
+ They are associated to the examples. (n, k, ...)
+ - 'labels' if provided through `dataset_labels`,
+ they are the labels associated with the examples. (n, k, ...)
+ - 'include_inputs' specify if inputs should be included in the returned elements.
+ Note that it changes the number of returned elements from k to k+1.
+ """
+ default = "examples"
+ self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
def _initialize_cases_dataset(
self,
@@ -222,43 +264,6 @@ def _initialize_cases_dataset(
return batch_size
- def set_k(self, k: int):
- """
- Setter for the k parameter.
-
- Parameters
- ----------
- k
- Number of examples to return, it should be a positive integer.
- """
- assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
- self.k = k
- self.search_method.set_k(k)
-
- def set_returns(self, returns: Union[List[str], str]):
- """
- Set `self.returns` used to define returned elements in `self.explain()`.
-
- Parameters
- ----------
- returns
- Most elements are useful in `xplique.plots.plot_examples()`.
- `returns` can be set to 'all' for all possible elements to be returned.
- - 'examples' correspond to the expected examples,
- the inputs may be included in first position. (n, k(+1), ...)
- - 'weights' the weights in the input space used in the projection.
- They are associated to the input and the examples. (n, k(+1), ...)
- - 'distances' the distances between the inputs and the corresponding examples.
- They are associated to the examples. (n, k, ...)
- - 'labels' if provided through `dataset_labels`,
- they are the labels associated with the examples. (n, k, ...)
- - 'include_inputs' specify if inputs should be included in the returned elements.
- Note that it changes the number of returned elements from k to k+1.
- """
- possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
- default = "examples"
- self.returns = _sanitize_returns(returns, possibilities, default)
-
@sanitize_inputs_targets
def explain(
self,
@@ -392,7 +397,4 @@ def format_search_output(
), "The method cannot return labels without a label dataset."
return_dict["labels"] = examples_labels
- # return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
return return_dict
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 0b996f89..0c584ae0 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -154,6 +154,8 @@ def explain(
class KLEOR(BaseExampleMethod):
"""
"""
+ _returns_possibilities = ["examples", "weights", "distances", "labels", "include_inputs", "nuns"]
+
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
@@ -166,9 +168,6 @@ def __init__(
distance: Union[int, str, Callable] = "euclidean",
strategy: str = "sim_miss",
):
-
- self.k = k
- self.set_returns(case_returns)
if strategy == "global_sim":
search_method = KLEORGlobalSim
@@ -180,67 +179,34 @@ def __init__(
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
- # set attributes
- batch_size = super()._initialize_cases_dataset(
- cases_dataset, labels_dataset, targets_dataset, batch_size
- )
-
- assert hasattr(projection, "__call__"), "projection should be a callable."
-
- # check projection type
- if isinstance(projection, Projection):
- self.projection = projection
- elif hasattr(projection, "__call__"):
- self.projection = Projection(get_weights=None, space_projection=projection)
- else:
- raise AttributeError(
- "projection should be a `Projection` or a `Callable`, not a"
- + f"{type(projection)}"
- )
-
- # project dataset
- projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
- self.targets_dataset)
-
- # set `search_returns` if not provided and overwrite it otherwise
- if isinstance(case_returns, list) and ("nuns" in case_returns):
- search_method_returns = ["indices", "distances", "nuns"]
- else:
- search_method_returns = ["indices", "distances"]
-
- # initiate search_method
- self.search_method = search_method(
- cases_dataset=projected_cases_dataset,
- targets_dataset=self.targets_dataset,
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ search_method=search_method,
k=k,
- search_returns=search_method_returns,
+ projection=projection,
+ case_returns=case_returns,
batch_size=batch_size,
distance=distance,
)
- def set_returns(self, returns: Union[List[str], str]):
- """
- Set `self.returns` used to define returned elements in `self.explain()`.
+ @property
+ def returns(self) -> Union[List[str], str]:
+ """Getter for the returns parameter."""
+ return self._returns
- Parameters
- ----------
- returns
- Most elements are useful in `xplique.plots.plot_examples()`.
- `returns` can be set to 'all' for all possible elements to be returned.
- - 'examples' correspond to the expected examples,
- the inputs may be included in first position. (n, k(+1), ...)
- - 'weights' the weights in the input space used in the projection.
- They are associated to the input and the examples. (n, k(+1), ...)
- - 'distances' the distances between the inputs and the corresponding examples.
- They are associated to the examples. (n, k, ...)
- - 'labels' if provided through `dataset_labels`,
- they are the labels associated with the examples. (n, k, ...)
- - 'include_inputs' specify if inputs should be included in the returned elements.
- Note that it changes the number of returned elements from k to k+1.
+ @returns.setter
+ def returns(self, returns: Union[List[str], str]):
+ """
"""
- possibilities = ["examples", "weights", "distances", "labels", "include_inputs", "nuns"]
default = "examples"
- self.returns = _sanitize_returns(returns, possibilities, default)
+ self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
+ if isinstance(self._returns, list) and ("nuns" in self._returns):
+ self._search_returns = ["indices", "distances", "nuns"]
+ else:
+ self._search_returns = ["indices", "distances"]
+ self.search_method.returns = self._search_returns
def format_search_output(
self,
@@ -252,13 +218,5 @@ def format_search_output(
"""
return_dict = super().format_search_output(search_output, inputs, targets)
if "nuns" in self.returns:
- if isinstance(return_dict, dict):
- return_dict["nuns"] = search_output["nuns"]
- else:
- # find the other only key
- other_key = [k for k in self.returns if k != "nuns"][0]
- return_dict = {
- other_key: return_dict,
- "nuns": search_output["nuns"]
- }
+ return_dict["nuns"] = search_output["nuns"]
return return_dict
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index a7bf4e02..018f3ad4 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -7,7 +7,7 @@
import tensorflow as tf
import numpy as np
-from ...types import Callable, Union, Optional, List
+from ...types import Union, Optional, List
from ...commons import sanitize_dataset
@@ -84,6 +84,7 @@ class BaseSearchMethod(ABC):
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
"""
+ _returns_possibilities = ["examples", "indices", "distances", "include_inputs"]
def __init__(
self,
@@ -92,7 +93,6 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
# set batch size
@@ -103,8 +103,8 @@ def __init__(
self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size)
- self.set_k(k)
- self.set_returns(search_returns, possibilities)
+ self.k = k
+ self.returns = search_returns
# set targets_dataset
if targets_dataset is not None:
@@ -113,22 +113,26 @@ def __init__(
# make an iterable of None
self.targets_dataset = [None]*len(cases_dataset)
- def set_k(self, k: int):
- """
- Change value of k with constructing a new `BaseSearchMethod`.
- It is useful because the constructor can be computationally expensive.
+ @property
+ def k(self) -> int:
+ """Getter for the k parameter."""
+ return self._k
- Parameters
- ----------
- k
- The number of examples to retrieve.
- """
+ @k.setter
+ def k(self, k: int):
+ """Setter for the k parameter."""
assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
- self.k = k
+ self._k = k
- def set_returns(self, returns: Optional[Union[List[str], str]] = None, possibilities: Optional[List[str]] = None):
+ @property
+ def returns(self) -> Union[List[str], str]:
+ """Getter for the returns parameter."""
+ return self._returns
+
+ @returns.setter
+ def returns(self, returns: Union[List[str], str]):
"""
- Set `self.returns` used to define returned elements in `self.find_examples()`.
+ Setter for the returns parameter used to define returned elements in `self.find_examples()`.
Parameters
----------
@@ -137,18 +141,17 @@ def set_returns(self, returns: Optional[Union[List[str], str]] = None, possibili
`returns` can be set to 'all' for all possible elements to be returned.
- 'examples' correspond to the expected examples,
the inputs may be included in first position. (n, k(+1), ...)
- - 'indices' the indices of the examples in the `search_set`.
- Used to retrieve the original example and labels. (n, k, ...)
+ - 'weights' the weights in the input space used in the projection.
+ They are associated to the input and the examples. (n, k(+1), ...)
- 'distances' the distances between the inputs and the corresponding examples.
They are associated to the examples. (n, k, ...)
+ - 'labels' if provided through `dataset_labels`,
+ they are the labels associated with the examples. (n, k, ...)
- 'include_inputs' specify if inputs should be included in the returned elements.
Note that it changes the number of returned elements from k to k+1.
"""
- if possibilities is None:
- possibilities = ["examples", "indices", "distances", "include_inputs"]
default = "examples"
- self.returns = _sanitize_returns(returns, possibilities, default)
-
+ self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
@abstractmethod
def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index 380d668a..c62c14cd 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -25,7 +25,6 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
): # pylint: disable=R0801
- possibilities = ["examples", "indices", "distances", "include_inputs", "nuns"]
super().__init__(
cases_dataset = cases_dataset,
targets_dataset=targets_dataset,
@@ -35,7 +34,6 @@ def __init__(
batch_size=batch_size,
distance=distance,
order=ORDER.ASCENDING,
- possibilities=possibilities
)
self.search_nuns = FilterKNN(
@@ -65,32 +63,13 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
# compute neighbors
examples_distances, examples_indices, nuns = self.kneighbors(inputs, targets)
- # Set values in return dict
- return_dict = {}
- if "examples" in self.returns:
- return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices)
- # replace examples for which indices is -1, -1 by an inf value
- # mask = tf.reduce_all(tf.equal(examples_indices, -1), axis=-1)
- # return_dict["examples"] = tf.where(
- # tf.expand_dims(mask, axis=-1),
- # tf.fill(return_dict["examples"].shape, tf.constant(np.inf, dtype=tf.float32)),
- # return_dict["examples"],
- # )
- if "include_inputs" in self.returns:
- inputs = tf.expand_dims(inputs, axis=1)
- return_dict["examples"] = tf.concat(
- [inputs, return_dict["examples"]], axis=1
- )
+ # build return dict
+ return_dict = self._build_return_dict(inputs, examples_distances, examples_indices)
+
+ # add the nuns if needed
if "nuns" in self.returns:
return_dict["nuns"] = nuns
- if "indices" in self.returns:
- return_dict["indices"] = examples_indices
- if "distances" in self.returns:
- return_dict["distances"] = examples_distances
-
- # Return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
+
return return_dict
def _filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index e53833cd..be686b45 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -1,13 +1,12 @@
"""
KNN online search method in example-based module
"""
-import math
from abc import abstractmethod
import numpy as np
import tensorflow as tf
-from ...commons import dataset_gather, sanitize_dataset
+from ...commons import dataset_gather
from ...types import Callable, List, Union, Optional, Tuple
from .base import BaseSearchMethod, ORDER
@@ -23,10 +22,9 @@ def __init__(
batch_size: Optional[int] = 32,
order: ORDER = ORDER.ASCENDING,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- possibilities: Optional[List[str]] = None,
):
super().__init__(
- cases_dataset, k, search_returns, batch_size, targets_dataset, possibilities
+ cases_dataset, k, search_returns, batch_size, targets_dataset
)
assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
@@ -79,6 +77,15 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
# compute neighbors
examples_distances, examples_indices = self.kneighbors(inputs, targets)
+ # build the return dict
+ return_dict = self._build_return_dict(inputs, examples_distances, examples_indices)
+
+ return return_dict
+
+ def _build_return_dict(self, inputs, examples_distances, examples_indices):
+ """
+ Build the dictionary returned by `find_examples()` according to `self.returns`.
+ """
# Set values in return dict
return_dict = {}
if "examples" in self.returns:
@@ -93,9 +100,6 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
if "distances" in self.returns:
return_dict["distances"] = examples_distances
- # Return a dict only different variables are returned
- if len(return_dict) == 1:
- return list(return_dict.values())[0]
return return_dict
class KNN(BaseKNN):
@@ -131,10 +135,9 @@ def __init__(
distance: Union[int, str, Callable] = "euclidean",
order: ORDER = ORDER.ASCENDING,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset, possibilities
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset
)
if hasattr(distance, "__call__"):
@@ -275,10 +278,9 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
order: ORDER = ORDER.ASCENDING,
- possibilities: Optional[List[str]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset, possibilities
+ cases_dataset, k, search_returns, batch_size, order, targets_dataset
)
if hasattr(distance, "__call__"):
From 9d5062e21c03758be62ddafe703d260d0fd317bb Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 22 May 2024 17:32:08 +0200
Subject: [PATCH 046/138] example based: merge and solve part of problems and
refacto
---
tests/example_based/test_cole.py | 10 +-
xplique/example_based/__init__.py | 5 +-
xplique/example_based/base_example_method.py | 36 ++-
xplique/example_based/cole.py | 20 +-
xplique/example_based/contrastive_examples.py | 65 ++++-
xplique/example_based/mmd_critic.py | 100 -------
xplique/example_based/proto_dash.py | 100 -------
xplique/example_based/proto_greedy.py | 100 -------
xplique/example_based/prototypes.py | 268 +++++++++---------
xplique/example_based/search_methods/base.py | 8 -
xplique/example_based/search_methods/kleor.py | 4 +-
xplique/example_based/search_methods/knn.py | 30 +-
.../search_methods/proto_greedy_search.py | 6 +-
xplique/example_based/similar_examples.py | 21 +-
14 files changed, 259 insertions(+), 514 deletions(-)
delete mode 100644 xplique/example_based/mmd_critic.py
delete mode 100644 xplique/example_based/proto_dash.py
delete mode 100644 xplique/example_based/proto_greedy.py
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index a9dc1afe..ba71d5d3 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -104,9 +104,9 @@ def test_cole_attribution():
)
# Generate explanation
- examples_constructor = method_constructor.explain(x_test, y_test)
- examples_call = method_call.explain(x_test, y_test)
- examples_different_distance = method_different_distance(x_test, y_test)
+ examples_constructor = method_constructor.explain(x_test, y_test)["examples"]
+ examples_call = method_call.explain(x_test, y_test)["examples"]
+ examples_different_distance = method_different_distance(x_test, y_test)["examples"]
# Verifications
# Shape should be (n, k, h, w, c)
@@ -166,8 +166,8 @@ def test_cole_hadamard():
)
# Generate explanation
- examples_constructor = method_constructor.explain(x_test, y_test)
- examples_call = method_call.explain(x_test, y_test)
+ examples_constructor = method_constructor.explain(x_test, y_test)["examples"]
+ examples_call = method_call.explain(x_test, y_test)["examples"]
# Verifications
# Shape should be (n, k, h, w, c)
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 7a174e91..e1d70b05 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -4,8 +4,5 @@
from .cole import Cole
from .similar_examples import SimilarExamples
-from .prototypes import Prototypes
-from .proto_greedy import ProtoGreedy
-from .proto_dash import ProtoDash
-from .mmd_critic import MMDCritic
+from .prototypes import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
from .contrastive_examples import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEOR
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index f45bdc91..0f4c1dfa 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -2,6 +2,8 @@
Base model for example-based
"""
+from abc import ABC, abstractmethod
+
import math
import tensorflow as tf
@@ -17,7 +19,7 @@
from .search_methods.base import _sanitize_returns
-class BaseExampleMethod:
+class BaseExampleMethod(ABC):
"""
Base class for natural example-based methods explaining models,
they project the cases_dataset into a pertinent space for the with a `Projection`,
@@ -79,19 +81,17 @@ def __init__(
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- search_method: Type[BaseSearchMethod] = KNN,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
- **search_method_kwargs,
):
assert (
projection is not None
), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."
# set attributes
- batch_size = self._initialize_cases_dataset(
+ self.batch_size = self._initialize_cases_dataset(
cases_dataset, labels_dataset, targets_dataset, batch_size
)
@@ -110,22 +110,16 @@ def __init__(
)
# project dataset
- projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
- self.targets_dataset)
-
- # set `search_returns` if not provided and overwrite it otherwise
- search_method_kwargs["search_returns"] = self._search_returns
-
- # initiate search_method
- self.search_method = search_method(
- cases_dataset=projected_cases_dataset,
- k=k,
- batch_size=batch_size,
- targets_dataset=self.targets_dataset,
- **search_method_kwargs,
- )
+ self.projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
+ self.targets_dataset)
+
self.k = k
self.returns = case_returns
+
+ @property
+ @abstractmethod
+ def search_method_class(self) -> Type[BaseSearchMethod]:
+ raise NotImplementedError
@property
def k(self) -> int:
@@ -137,7 +131,11 @@ def k(self, k: int):
"""Setter for the k parameter."""
assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}"
self._k = k
- self.search_method.k = k
+
+ try:
+ self.search_method.k = k
+ except AttributeError:
+ pass
@property
def returns(self) -> Union[List[str], str]:
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index 3fdfc82f..ca203b12 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -88,7 +88,7 @@ def __init__(
distance: Union[str, Callable] = "euclidean",
case_returns: Optional[Union[List[str], str]] = "examples",
batch_size: Optional[int] = 32,
- device: Optional[str] = None,
+ # device: Optional[str] = None,
latent_layer: Optional[Union[str, int]] = None,
attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient",
**attribution_kwargs,
@@ -104,7 +104,7 @@ def __init__(
model=model,
latent_layer=latent_layer,
operator=operator,
- device=device,
+ # device=device,
)
elif issubclass(attribution_method, BlackBoxExplainer):
# build attribution projection
@@ -112,7 +112,7 @@ def __init__(
model=model,
method=attribution_method,
latent_layer=latent_layer,
- device=device,
+ # device=device,
**attribution_kwargs,
)
else:
@@ -122,12 +122,12 @@ def __init__(
)
super().__init__(
- cases_dataset,
- labels_dataset,
- targets_dataset,
- k,
- projection,
- case_returns,
- batch_size,
+ cases_dataset=cases_dataset,
+ targets_dataset=targets_dataset,
+ labels_dataset=labels_dataset,
+ projection=projection,
+ k=k,
+ case_returns=case_returns,
+ batch_size=batch_size,
distance=distance,
)
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 0c584ae0..05b1d9ad 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -32,7 +32,6 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
):
- search_method = FilterKNN
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
@@ -41,15 +40,29 @@ def __init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
targets_dataset=targets_dataset,
- search_method=search_method,
k=k,
projection=projection,
case_returns=case_returns,
batch_size=batch_size,
+ )
+
+ self.distance = distance
+ self.order = ORDER.ASCENDING
+
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
distance=distance,
filter_fn=self.filter_fn,
- order = ORDER.ASCENDING
+ order=self.order
)
+
+ @property
+ def search_method_class(self):
+ return FilterKNN
def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
@@ -83,8 +96,6 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
):
- search_method = FilterKNN
-
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
# TODO: add a warning here if it is a custom projection that requires using targets as it might mismatch with the explain
@@ -93,15 +104,30 @@ def __init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
targets_dataset=targets_dataset,
- search_method=search_method,
k=k,
projection=projection,
case_returns=case_returns,
batch_size=batch_size,
+ )
+
+ self.distance = distance
+ self.order = ORDER.ASCENDING
+
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
distance=distance,
filter_fn=self.filter_fn,
- order = ORDER.ASCENDING
+ order=self.order
)
+
+ @property
+ def search_method_class(self):
+ return FilterKNN
+
def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
"""
@@ -183,13 +209,30 @@ def __init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
targets_dataset=targets_dataset,
- search_method=search_method,
k=k,
projection=projection,
case_returns=case_returns,
batch_size=batch_size,
+ )
+
+ self.distance = distance
+ self.order = ORDER.ASCENDING
+
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
distance=distance,
+ filter_fn=self.filter_fn,
+ order=self.order
)
+
+ @property
+ def search_method_class(self):
+ return FilterKNN
+
@property
def returns(self) -> Union[List[str], str]:
@@ -206,7 +249,11 @@ def returns(self, returns: Union[List[str], str]):
self._search_returns = ["indices", "distances", "nuns"]
else:
self._search_returns = ["indices", "distances"]
- self.search_method.returns = self._search_returns
+
+ try:
+ self.search_method.returns = self._search_returns
+ except AttributeError:
+ pass
def format_search_output(
self,
diff --git a/xplique/example_based/mmd_critic.py b/xplique/example_based/mmd_critic.py
deleted file mode 100644
index a2ccfb47..00000000
--- a/xplique/example_based/mmd_critic.py
+++ /dev/null
@@ -1,100 +0,0 @@
-"""
-MMDCritic method for searching prototypes
-"""
-
-import math
-
-import time
-
-import tensorflow as tf
-import numpy as np
-
-from ..types import Callable, Dict, List, Optional, Type, Union
-
-from ..commons import sanitize_inputs_targets
-from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import MMDCriticSearch
-from .projections import Projection
-from .prototypes import Prototypes
-
-from .search_methods.base import _sanitize_returns
-
-
-class MMDCritic(Prototypes):
- """
- MMDCritic method for searching prototypes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
-
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optional parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
- case_returns
- String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` for detail.
- batch_size
- Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- **search_method_kwargs,
- ):
- # the only difference with parent is that the search method is always MMDCriticSearch
- search_method = MMDCriticSearch
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- search_method=search_method,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- **search_method_kwargs,
- )
-
diff --git a/xplique/example_based/proto_dash.py b/xplique/example_based/proto_dash.py
deleted file mode 100644
index 475e138b..00000000
--- a/xplique/example_based/proto_dash.py
+++ /dev/null
@@ -1,100 +0,0 @@
-"""
-ProtoDash method for searching prototypes
-"""
-
-import math
-
-import time
-
-import tensorflow as tf
-import numpy as np
-
-from ..types import Callable, Dict, List, Optional, Type, Union
-
-from ..commons import sanitize_inputs_targets
-from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import ProtoDashSearch
-from .projections import Projection
-from .prototypes import Prototypes
-
-from .search_methods.base import _sanitize_returns
-
-
-class ProtoDash(Prototypes):
- """
- ProtoDash method for searching prototypes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
-
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optional parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
- case_returns
- String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` for detail.
- batch_size
- Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- **search_method_kwargs,
- ):
- # the only difference with parent is that the search method is always ProtoDashSearch
- search_method = ProtoDashSearch
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- search_method=search_method,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- **search_method_kwargs,
- )
-
diff --git a/xplique/example_based/proto_greedy.py b/xplique/example_based/proto_greedy.py
deleted file mode 100644
index 2c43565b..00000000
--- a/xplique/example_based/proto_greedy.py
+++ /dev/null
@@ -1,100 +0,0 @@
-"""
-ProtoGreedy method for searching prototypes
-"""
-
-import math
-
-import time
-
-import tensorflow as tf
-import numpy as np
-
-from ..types import Callable, Dict, List, Optional, Type, Union
-
-from ..commons import sanitize_inputs_targets
-from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import ProtoGreedySearch
-from .projections import Projection
-from .prototypes import Prototypes
-
-from .search_methods.base import _sanitize_returns
-
-
-class ProtoGreedy(Prototypes):
- """
- ProtoGreedy method for searching prototypes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
-
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optional parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
- case_returns
- String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` for detail.
- batch_size
- Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- **search_method_kwargs,
- ):
- # the only difference with parent is that the search method is always ProtoGreedySearch
- search_method = ProtoGreedySearch
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- search_method=search_method,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- **search_method_kwargs,
- )
-
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 29946c22..16f68243 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -2,25 +2,19 @@
Base model for prototypes
"""
-import math
-
-import time
+from abc import ABC, abstractmethod
import tensorflow as tf
import numpy as np
from ..types import Callable, Dict, List, Optional, Type, Union
-from ..commons import sanitize_inputs_targets
-from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import ProtoGreedySearch
+from .search_methods import BaseSearchMethod, ProtoGreedySearch, MMDCriticSearch, ProtoDashSearch
from .projections import Projection
from .base_example_method import BaseExampleMethod
-from .search_methods.base import _sanitize_returns
-
-class Prototypes(BaseExampleMethod):
+class Prototypes(BaseExampleMethod, ABC):
"""
Base class for prototypes.
@@ -76,54 +70,51 @@ def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- search_method: Type[ProtoGreedySearch] = ProtoGreedySearch,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
- **search_method_kwargs,
- ):
- assert (
- projection is not None
- ), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."
-
- # set attributes
- batch_size = self.__initialize_cases_dataset(
- cases_dataset, labels_dataset, targets_dataset, batch_size
+ distance: Union[int, str, Callable] = None,
+ nb_prototypes: int = 1,
+ kernel_type: str = 'local',
+ kernel_fn: callable = None,
+ gamma: float = None
+ ):
+ # set common example-based parameters
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
)
- self.k = k
- self.set_returns(case_returns)
-
- assert hasattr(projection, "__call__"), "projection should be a callable."
-
- # check projection type
- if isinstance(projection, Projection):
- self.projection = projection
- elif hasattr(projection, "__call__"):
- self.projection = Projection(get_weights=None, space_projection=projection)
- else:
- raise AttributeError(
- "projection should be a `Projection` or a `Callable`, not a"
- + f"{type(projection)}"
- )
-
- # project dataset
- projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
- self.targets_dataset)
-
- # set `search_returns` if not provided and overwrite it otherwise
- search_method_kwargs["search_returns"] = ["indices", "distances"]
+ # set prototypes parameters
+ self.distance = distance
+ self.nb_prototypes = nb_prototypes
+ self.kernel_type = kernel_type
+ self.kernel_fn = kernel_fn
+ self.gamma = gamma
# initiate search_method
- self.search_method = search_method(
- cases_dataset=projected_cases_dataset,
- labels_dataset=labels_dataset,
- k=k,
- batch_size=batch_size,
- **search_method_kwargs,
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ labels_dataset=self.labels_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
+ distance=self.distance,
+ nb_prototypes=self.nb_prototypes,
+ kernel_type=self.kernel_type,
+ kernel_fn=self.kernel_fn,
+ gamma=self.gamma
)
+
+ @property
+ @abstractmethod
+ def search_method_class(self) -> Type[ProtoGreedySearch]:
+ raise NotImplementedError
def get_global_prototypes(self):
"""
@@ -137,100 +128,95 @@ def get_global_prototypes(self):
prototype weights.
"""
return self.search_method.prototype_indices, self.search_method.prototype_weights
-
- def __initialize_cases_dataset(
+
+
+class ProtoGreedy(Prototypes):
+ @property
+ def search_method_class(self) -> Type[ProtoGreedySearch]:
+ return ProtoGreedySearch
+
+
+class MMDCritic(Prototypes):
+ @property
+ def search_method_class(self) -> Type[ProtoGreedySearch]:
+ return MMDCriticSearch
+
+
+class ProtoDash(Prototypes):
+ """
+ Protodash method for searching prototypes.
+
+ References:
+ .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi,
+ "ProtoDash: Fast Interpretable Prototype Selection"
+ `_
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from the dataset.
+ For natural example-based methods it is the train dataset.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ k
+ The number of examples to retrieve.
+ search_returns
+ String or list of string with the elements to return in `self.find_examples()`.
+ See `self.set_returns()` for detail.
+ batch_size
+ Number of samples treated simultaneously.
+ It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ distance
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
+ nb_prototypes : int
+ Number of prototypes to find.
+ kernel_type : str, optional
+ The kernel type. It can be 'local' or 'global', by default 'local'.
+ When it is local, the distances are calculated only within the classes.
+ kernel_fn : Callable, optional
+ Kernel function, by default the rbf kernel.
+ This function must only use TensorFlow operations.
+ gamma : float, optional
+ Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
+ use_optimizer : bool, optional
+ Flag indicating whether to use an optimizer for prototype selection, by default False.
+ """
+
+ def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]],
- batch_size: Optional[int],
- ) -> int:
- """
- Factorization of `__init__()` method for dataset related attributes.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- labels_dataset
- Labels associated to the examples in the dataset.
- Indices should match with cases_dataset.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection.
- See `projection` for detail.
- batch_size
- Number of sample treated simultaneously when using the datasets.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
-
- Returns
- -------
- batch_size
- Number of sample treated simultaneously when using the datasets.
- Extracted from the datasets in case they are `tf.data.Dataset`.
- Otherwise, the input value.
- """
- # at least one dataset provided
- if isinstance(cases_dataset, tf.data.Dataset):
- # set batch size (ignore provided argument) and cardinality
- if isinstance(cases_dataset.element_spec, tuple):
- batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy()
- else:
- batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy()
-
- cardinality = cases_dataset.cardinality().numpy()
- else:
- # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
- assert not isinstance(labels_dataset, tf.data.Dataset)
- assert not isinstance(targets_dataset, tf.data.Dataset)
- # set batch size and cardinality
- batch_size = min(batch_size, len(cases_dataset))
- cardinality = math.ceil(len(cases_dataset) / batch_size)
-
- # verify cardinality and create datasets from the tensors
- self.cases_dataset = sanitize_dataset(
- cases_dataset, batch_size, cardinality
- )
- self.labels_dataset = sanitize_dataset(
- labels_dataset, batch_size, cardinality
- )
- self.targets_dataset = sanitize_dataset(
- targets_dataset, batch_size, cardinality
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = None,
+ nb_prototypes: int = 1,
+ kernel_type: str = 'local',
+ kernel_fn: callable = None,
+ gamma: float = None,
+ use_optimizer: bool = False,
+ ): # pylint: disable=R0801
+ self.use_optimizer = use_optimizer
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ kernel_fn=kernel_fn,
+ gamma=gamma
)
- # if the provided `cases_dataset` has several columns
- if isinstance(self.cases_dataset.element_spec, tuple):
- # switch case on the number of columns of `cases_dataset`
- if len(self.cases_dataset.element_spec) == 2:
- assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
- self.cases_dataset = self.cases_dataset.map(lambda x, y: x)
-
- elif len(self.cases_dataset.element_spec) == 3:
- assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- assert self.targets_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
- + "Hence, `labels_dataset` should be empty."
- )
- self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
- self.labels_dataset = self.cases_dataset.map(lambda x, y, t: y)
- self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
- else:
- raise AttributeError(
- "`cases_dataset` cannot possess more than 3 columns,"
- + f"{len(self.cases_dataset.element_spec)} were detected."
- )
-
- self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
- if self.labels_dataset is not None:
- self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
- if self.targets_dataset is not None:
- self.targets_dataset = self.targets_dataset.prefetch(tf.data.AUTOTUNE)
-
- return batch_size
-
+ @property
+ def search_method_class(self) -> Type[ProtoGreedySearch]:
+ return ProtoDashSearch
+
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 018f3ad4..60db96af 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -92,7 +92,6 @@ def __init__(
k: int = 1,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
): # pylint: disable=R0801
# set batch size
@@ -106,13 +105,6 @@ def __init__(
self.k = k
self.returns = search_returns
- # set targets_dataset
- if targets_dataset is not None:
- self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size)
- else:
- # make an iterable of None
- self.targets_dataset = [None]*len(cases_dataset)
-
@property
def k(self) -> int:
"""Getter for the k parameter."""
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index c62c14cd..315a25af 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -29,22 +29,22 @@ def __init__(
cases_dataset = cases_dataset,
targets_dataset=targets_dataset,
k=k,
- filter_fn=self._filter_fn,
search_returns=search_returns,
batch_size=batch_size,
distance=distance,
order=ORDER.ASCENDING,
+ filter_fn=self._filter_fn,
)
self.search_nuns = FilterKNN(
cases_dataset=cases_dataset,
targets_dataset=targets_dataset,
k=1,
- filter_fn=self._filter_fn_nun,
search_returns=["indices", "distances"],
batch_size=batch_size,
distance=distance,
order = ORDER.ASCENDING,
+ filter_fn=self._filter_fn_nun,
)
def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index be686b45..c1fbe4db 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -6,7 +6,7 @@
import numpy as np
import tensorflow as tf
-from ...commons import dataset_gather
+from ...commons import dataset_gather, sanitize_dataset
from ...types import Callable, List, Union, Optional, Tuple
from .base import BaseSearchMethod, ORDER
@@ -21,10 +21,12 @@ def __init__(
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
order: ORDER = ORDER.ASCENDING,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
):
super().__init__(
- cases_dataset, k, search_returns, batch_size, targets_dataset
+ cases_dataset=cases_dataset,
+ k=k,
+ search_returns=search_returns,
+ batch_size=batch_size,
)
assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
@@ -134,10 +136,13 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
order: ORDER = ORDER.ASCENDING,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset
+ cases_dataset=cases_dataset,
+ k=k,
+ search_returns=search_returns,
+ batch_size=batch_size,
+ order=order,
)
if hasattr(distance, "__call__"):
@@ -273,14 +278,18 @@ def __init__(
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
- filter_fn: Optional[Callable] = None,
search_returns: Optional[Union[List[str], str]] = None,
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
order: ORDER = ORDER.ASCENDING,
+ filter_fn: Optional[Callable] = None,
): # pylint: disable=R0801
super().__init__(
- cases_dataset, k, search_returns, batch_size, order, targets_dataset
+ cases_dataset=cases_dataset,
+ k=k,
+ search_returns=search_returns,
+ batch_size=batch_size,
+ order=order,
)
if hasattr(distance, "__call__"):
@@ -301,6 +310,13 @@ def __init__(
filter_fn = lambda x, z, y, t: tf.ones((tf.shape(x)[0], tf.shape(z)[0]), dtype=tf.bool)
self.filter_fn = filter_fn
+ # set targets_dataset
+ if targets_dataset is not None:
+ self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size)
+ else:
+ # make an iterable of None
+ self.targets_dataset = [None]*len(cases_dataset)
+
@tf.function
def _crossed_distances_fn(self, x1, x2, mask):
n = x1.shape[0]
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index a86f610d..4ed79899 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -146,7 +146,7 @@ def kernel_induced_distance(x1,x2):
elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
distance, int
):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance)
+ self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance, axis=-1)
else:
raise AttributeError(
"The distance parameter is expected to be either a Callable or in"
@@ -423,7 +423,7 @@ def find_prototypes(self, nb_prototypes):
return prototype_indices, prototype_cases, prototype_labels, prototype_weights
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], _):
"""
Search the samples to return as examples. Called by the explain methods.
It may also return the indices corresponding to the samples,
@@ -438,7 +438,7 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray]):
"""
# look for closest prototypes to projected inputs
- knn_output = self.knn(inputs)
+ knn_output = self.knn(inputs, _)
# obtain closest prototypes indices with respect to the prototypes
indices_wrt_prototypes = knn_output["indices"]
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index 2a9634d3..e8836167 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -84,18 +84,27 @@ def __init__(
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
- ):
- # the only difference with parent is that the search method is always KNN
- search_method = KNN
-
+ ):
super().__init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
targets_dataset=targets_dataset,
- search_method=search_method,
k=k,
projection=projection,
case_returns=case_returns,
batch_size=batch_size,
- distance=distance
)
+
+ self.distance = distance
+
+ # initiate search_method
+ self.search_method = self.search_method_class(
+ cases_dataset=self.projected_cases_dataset,
+ search_returns=self._search_returns,
+ k=self.k,
+ batch_size=self.batch_size,
+ )
+
+ @property
+ def search_method_class(self) -> Type[BaseSearchMethod]:
+ return KNN
From 026c29e737b11104aebc915ab172e0c4642093d0 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 22 May 2024 18:21:53 +0200
Subject: [PATCH 047/138] fix: change the fill value depending on the dataset
type
---
xplique/commons/tf_dataset_operations.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index f74f4ea2..83c81fa0 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -205,9 +205,14 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
example = next(iter(dataset))
# (n, bs, ...)
- results = tf.Variable(
- tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
- )
+ if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
+ results = tf.Variable(
+ tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype)),
+ )
+ else:
+ results = tf.Variable(
+ tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
+ )
nb_results = product(indices.shape[:-1])
current_nb_results = 0
From 31bb02240a1ef48d320ce52f39de255cc8a40e5f Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 22 May 2024 18:22:34 +0200
Subject: [PATCH 048/138] fix: update kleor search and example base methods to
fit the new interface
---
tests/example_based/test_contrastive.py | 8 +-
tests/example_based/test_kleor.py | 12 +-
xplique/example_based/__init__.py | 2 +-
xplique/example_based/contrastive_examples.py | 131 +++++++++++++-----
.../example_based/search_methods/__init__.py | 2 +-
xplique/example_based/search_methods/kleor.py | 20 ++-
6 files changed, 123 insertions(+), 52 deletions(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index 82a47d60..af5cdc6f 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -4,7 +4,7 @@
import tensorflow as tf
import numpy as np
-from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEOR
+from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss
def test_naive_counter_factuals():
"""
@@ -162,13 +162,12 @@ def test_kleor():
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
# start when strategy is sim_miss
- kleor_sim_miss = KLEOR(
+ kleor_sim_miss = KLEORSimMiss(
cases_dataset,
cases_targets_dataset,
k=1,
case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
batch_size=2,
- strategy="sim_miss"
)
return_dict = kleor_sim_miss(inputs, targets)
@@ -202,13 +201,12 @@ def test_kleor():
assert tf.reduce_all(tf.equal(indices, expected_indices))
# now strategy is global_sim
- kleor_global_sim = KLEOR(
+ kleor_global_sim = KLEORGlobalSim(
cases_dataset,
cases_targets_dataset,
k=1,
case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
batch_size=2,
- strategy="global_sim"
)
return_dict = kleor_global_sim(inputs, targets)
diff --git a/tests/example_based/test_kleor.py b/tests/example_based/test_kleor.py
index f4965f8d..fec68950 100644
--- a/tests/example_based/test_kleor.py
+++ b/tests/example_based/test_kleor.py
@@ -4,7 +4,7 @@
import tensorflow as tf
import numpy as np
-from xplique.example_based.search_methods import KLEORSimMiss, KLEORGlobalSim
+from xplique.example_based.search_methods import KLEORSimMissSearch, KLEORGlobalSimSearch
def test_kleor_base_and_sim_miss():
"""
@@ -22,7 +22,7 @@ def test_kleor_base_and_sim_miss():
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
# build the kleor object
- kleor = KLEORSimMiss(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
+ kleor = KLEORSimMissSearch(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
# test the _filter_fn method
fake_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
@@ -46,7 +46,7 @@ def test_kleor_base_and_sim_miss():
assert tf.reduce_all(tf.equal(mask, expected_mask))
# test the _get_nuns method
- nuns, nuns_distances = kleor._get_nuns(inputs, targets)
+ nuns, _, nuns_distances = kleor._get_nuns(inputs, targets)
expected_nuns = tf.constant([
[[2., 3.]],
[[1., 2.]],
@@ -72,7 +72,7 @@ def test_kleor_base_and_sim_miss():
assert tf.reduce_all(tf.equal(batch_indices, expected_batch_indices))
# test the kneighbors method
- input_sf_distances, sf_indices, nuns = kleor.kneighbors(inputs, targets)
+ input_sf_distances, sf_indices, nuns, _, __ = kleor.kneighbors(inputs, targets)
assert input_sf_distances.shape == (3, 1) # (n, k)
assert sf_indices.shape == (3, 1, 2) # (n, k, 2)
@@ -121,7 +121,7 @@ def test_kleor_global_sim():
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
# build the kleor object
- kleor = KLEORGlobalSim(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
+ kleor = KLEORGlobalSimSearch(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2)
# test the _additionnal_filtering method
# (n, bs)
@@ -154,7 +154,7 @@ def test_kleor_global_sim():
) < 1e-5)
# test the kneighbors method
- input_sf_distances, sf_indices, nuns = kleor.kneighbors(inputs, targets)
+ input_sf_distances, sf_indices, nuns, _, __ = kleor.kneighbors(inputs, targets)
expected_nuns = tf.constant([
[[2., 3.]],
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index e1d70b05..3de46d18 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -5,4 +5,4 @@
from .cole import Cole
from .similar_examples import SimilarExamples
from .prototypes import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
-from .contrastive_examples import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEOR
+from .contrastive_examples import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 05b1d9ad..5477b450 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -8,10 +8,10 @@
import tensorflow as tf
from ..types import Callable, List, Optional, Union, Dict
-from ..commons import sanitize_inputs_targets
+from ..commons import sanitize_inputs_targets, dataset_gather
from .base_example_method import BaseExampleMethod
-from .search_methods import ORDER, FilterKNN, KLEORSimMiss, KLEORGlobalSim
+from .search_methods import ORDER, FilterKNN, KLEORSimMissSearch, KLEORGlobalSimSearch
from .projections import Projection
from .search_methods.base import _sanitize_returns
@@ -177,10 +177,12 @@ def explain(
# TODO make an assert on the cf_targets
return super().explain(inputs, cf_targets)
-class KLEOR(BaseExampleMethod):
+class KLEORBase(BaseExampleMethod):
"""
"""
- _returns_possibilities = ["examples", "weights", "distances", "labels", "include_inputs", "nuns"]
+ _returns_possibilities = [
+ "examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
+ ]
def __init__(
self,
@@ -192,16 +194,8 @@ def __init__(
case_returns: Union[List[str], str] = "examples",
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
- strategy: str = "sim_miss",
):
- if strategy == "global_sim":
- search_method = KLEORGlobalSim
- elif strategy == "sim_miss":
- search_method = KLEORSimMiss
- else:
- raise ValueError("strategy should be either 'global_sim' or 'sim_miss'.")
-
if projection is None:
projection = Projection(space_projection=lambda inputs: inputs)
@@ -218,22 +212,6 @@ def __init__(
self.distance = distance
self.order = ORDER.ASCENDING
- self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
- targets_dataset=self.targets_dataset,
- k=self.k,
- search_returns=self._search_returns,
- batch_size=self.batch_size,
- distance=distance,
- filter_fn=self.filter_fn,
- order=self.order
- )
-
- @property
- def search_method_class(self):
- return FilterKNN
-
-
@property
def returns(self) -> Union[List[str], str]:
"""Getter for the returns parameter."""
@@ -245,10 +223,15 @@ def returns(self, returns: Union[List[str], str]):
"""
default = "examples"
self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
+ self._search_returns = ["indices", "distances"]
+
if isinstance(self._returns, list) and ("nuns" in self._returns):
- self._search_returns = ["indices", "distances", "nuns"]
- else:
- self._search_returns = ["indices", "distances"]
+ self._search_returns.append("nuns_indices")
+ elif isinstance(self._returns, list) and ("nuns_indices" in self._returns):
+ self._search_returns.append("nuns_indices")
+
+ if isinstance(self._returns, list) and ("dist_to_nuns" in self._returns):
+ self._search_returns.append("dist_to_nuns")
try:
self.search_method.returns = self._search_returns
@@ -265,5 +248,89 @@ def format_search_output(
"""
return_dict = super().format_search_output(search_output, inputs, targets)
if "nuns" in self.returns:
- return_dict["nuns"] = search_output["nuns"]
+ return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"])
+ if "nuns_indices" in self.returns:
+ return_dict["nuns_indices"] = search_output["nuns_indices"]
+ if "dist_to_nuns" in self.returns:
+ return_dict["dist_to_nuns"] = search_output["dist_to_nuns"]
return return_dict
+
+class KLEORGlobalSim(KLEORBase):
+ """
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ )
+
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
+ distance=self.distance,
+ )
+
+ @property
+ def search_method_class(self):
+ return KLEORGlobalSimSearch
+
+class KLEORSimMiss(KLEORBase):
+ """
+ """
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ )
+
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
+ distance=self.distance,
+ )
+
+ @property
+ def search_method_class(self):
+ return KLEORSimMissSearch
diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py
index 998a3025..24a2e14c 100644
--- a/xplique/example_based/search_methods/__init__.py
+++ b/xplique/example_based/search_methods/__init__.py
@@ -9,4 +9,4 @@
from .proto_dash_search import ProtoDashSearch
from .mmd_critic_search import MMDCriticSearch
from .knn import BaseKNN, KNN, FilterKNN
-from .kleor import KLEORSimMiss, KLEORGlobalSim
+from .kleor import KLEORSimMissSearch, KLEORGlobalSimSearch
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index 315a25af..bb057bf3 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -12,7 +12,7 @@
from .base import ORDER
from .knn import FilterKNN
-class BaseKLEOR(FilterKNN, ABC):
+class BaseKLEORSearch(FilterKNN, ABC):
"""
Base class for the KLEOR search methods.
"""
@@ -61,7 +61,7 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
Expected shape among (N, W), (N, T, W), (N, W, H, C).
"""
# compute neighbors
- examples_distances, examples_indices, nuns = self.kneighbors(inputs, targets)
+ examples_distances, examples_indices, nuns, nuns_indices, nuns_sf_distances = self.kneighbors(inputs, targets)
# build return dict
return_dict = self._build_return_dict(inputs, examples_distances, examples_indices)
@@ -70,6 +70,12 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
if "nuns" in self.returns:
return_dict["nuns"] = nuns
+ if "dist_to_nuns" in self.returns:
+ return_dict["dist_to_nuns"] = nuns_sf_distances
+
+ if "nuns_indices" in self.returns:
+ return_dict["nuns_indices"] = nuns_indices
+
return return_dict
def _filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
@@ -104,13 +110,13 @@ def _get_nuns(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tens
nuns_dict = self.search_nuns(inputs, targets)
nuns_indices, nuns_distances = nuns_dict["indices"], nuns_dict["distances"]
nuns = dataset_gather(self.cases_dataset, nuns_indices)
- return nuns, nuns_distances
+ return nuns, nuns_indices, nuns_distances
def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
"""
"""
# get the Nearest Unlike Neighbors and their distance to the related input
- nuns, nuns_input_distances = self._get_nuns(inputs, targets)
+ nuns, nuns_indices, nuns_input_distances = self._get_nuns(inputs, targets)
# initialize the search for the KLEOR semi-factual methods
sf_indices, input_sf_distances, nun_sf_distances, batch_indices = self._initialize_search(inputs)
@@ -160,7 +166,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Ten
tf.gather(concatenated_input_sf_distances, sort_order, axis=1, batch_dims=1)
)
- return input_sf_distances, sf_indices, nuns
+ return input_sf_distances, sf_indices, nuns, nuns_indices, nun_sf_distances
def _initialize_search(self, inputs: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Variable, tf.Variable, tf.Variable, tf.Tensor]:
"""
@@ -185,7 +191,7 @@ def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances:
"""
raise NotImplementedError
-class KLEORSimMiss(BaseKLEOR):
+class KLEORSimMissSearch(BaseKLEORSearch):
"""
KLEOR search method.
@@ -199,7 +205,7 @@ class KLEORSimMiss(BaseKLEOR):
def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple:
return nun_sf_distances, input_sf_distances
-class KLEORGlobalSim(BaseKLEOR):
+class KLEORGlobalSimSearch(BaseKLEORSearch):
"""
KLEOR search method.
From 19d7cb3888b062c0178cb306799b011aae55c2b8 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 23 May 2024 14:59:46 +0200
Subject: [PATCH 049/138] example based: cole tests pass
---
tests/example_based/test_cole.py | 10 +--
xplique/example_based/contrastive_examples.py | 82 +++----------------
xplique/example_based/projections/base.py | 2 +
xplique/example_based/similar_examples.py | 4 +-
4 files changed, 22 insertions(+), 76 deletions(-)
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index ba71d5d3..e96abbb7 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -48,7 +48,7 @@ def test_cole_attribution():
Test that the distance has an impact.
"""
# Setup
- nb_samples = 20
+ nb_samples = 50
input_shape = (5, 5)
nb_labels = 10
k = 3
@@ -84,7 +84,7 @@ def test_cole_attribution():
explainer.gradient(model, inputs, targets)
projection = Projection(get_weights=explainer)
- euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z)))
+ euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z), axis=-1))
method_call = SimilarExamples(
cases_dataset=x_train,
targets_dataset=y_train,
@@ -121,9 +121,9 @@ def test_cole_attribution():
assert not almost_equal(examples_constructor, examples_different_distance)
# check weights are equal to the attribution directly on the input
- method_constructor.set_returns(["weights", "include_inputs"])
+ method_constructor.returns = ["weights", "include_inputs"]
assert almost_equal(
- method_constructor.explain(x_test, y_test)[:, 0],
+ method_constructor.explain(x_test, y_test)["weights"][:, 0],
Saliency(model)(x_test, y_test),
)
@@ -156,7 +156,7 @@ def test_cole_hadamard():
weights_extraction = lambda inputs, targets: gradients_predictions(model, inputs, targets)
projection = Projection(get_weights=weights_extraction)
- euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z)))
+ euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z), axis=-1))
method_call = SimilarExamples(
cases_dataset=x_train,
targets_dataset=y_train,
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 5477b450..fd6b6204 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -177,6 +177,7 @@ def explain(
# TODO make an assert on the cf_targets
return super().explain(inputs, cf_targets)
+
class KLEORBase(BaseExampleMethod):
"""
"""
@@ -212,6 +213,15 @@ def __init__(
self.distance = distance
self.order = ORDER.ASCENDING
+ self.search_method = self.search_method_class(
+ cases_dataset=self.cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
+ distance=self.distance,
+ )
+
@property
def returns(self) -> Union[List[str], str]:
"""Getter for the returns parameter."""
@@ -255,82 +265,14 @@ def format_search_output(
return_dict["dist_to_nuns"] = search_output["dist_to_nuns"]
return return_dict
-class KLEORGlobalSim(KLEORBase):
- """
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = "euclidean",
- ):
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- distance=distance,
- )
-
- self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
- targets_dataset=self.targets_dataset,
- k=self.k,
- search_returns=self._search_returns,
- batch_size=self.batch_size,
- distance=self.distance,
- )
+class KLEORGlobalSim(KLEORBase):
@property
def search_method_class(self):
return KLEORGlobalSimSearch
-class KLEORSimMiss(KLEORBase):
- """
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = "euclidean",
- ):
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- distance=distance,
- )
-
- self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
- targets_dataset=self.targets_dataset,
- k=self.k,
- search_returns=self._search_returns,
- batch_size=self.batch_size,
- distance=self.distance,
- )
+class KLEORSimMiss(KLEORBase):
@property
def search_method_class(self):
return KLEORSimMissSearch
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index 54192ed5..5efb3d27 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -73,6 +73,8 @@ def __init__(self,
# weights is a tensor
if isinstance(get_weights, np.ndarray):
weights = tf.convert_to_tensor(get_weights, dtype=tf.float32)
+ else:
+ weights = get_weights
# define a function that returns the weights
def get_weights(inputs, _ = None):
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index e8836167..75756074 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -11,7 +11,7 @@
from ..commons import sanitize_inputs_targets
from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import KNN, BaseSearchMethod
+from .search_methods import KNN, BaseSearchMethod, ORDER
from .projections import Projection
from .base_example_method import BaseExampleMethod
@@ -103,6 +103,8 @@ def __init__(
search_returns=self._search_returns,
k=self.k,
batch_size=self.batch_size,
+ distance=self.distance,
+ order=ORDER.ASCENDING,
)
@property
From 122f8d56f543e09cf29a03626bf2f80728df6ad2 Mon Sep 17 00:00:00 2001
From: Lucas Hervier
Date: Wed, 29 May 2024 17:54:10 +0200
Subject: [PATCH 050/138] docs: refactoring of the documentation for the new
interfaces
---
tests/example_based/test_contrastive.py | 89 +++---
xplique/example_based/base_example_method.py | 91 +++---
xplique/example_based/cole.py | 31 +-
xplique/example_based/contrastive_examples.py | 291 +++++++++++++++---
xplique/example_based/search_methods/base.py | 31 +-
xplique/example_based/search_methods/kleor.py | 110 ++++++-
xplique/example_based/search_methods/knn.py | 120 +++++++-
xplique/example_based/similar_examples.py | 33 +-
8 files changed, 574 insertions(+), 222 deletions(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index af5cdc6f..eab75ca7 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -5,20 +5,33 @@
import numpy as np
from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss
+from xplique.example_based.projections import Projection
def test_naive_counter_factuals():
"""
"""
+ # setup the tests
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- counter_factuals = NaiveCounterFactuals(cases_dataset, cases_targets_dataset, k=2, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+ projection = Projection(space_projection=lambda inputs: inputs)
+
+ # build the NaiveCounterFactuals object
+ counter_factuals = NaiveCounterFactuals(
+ cases_dataset,
+ cases_targets_dataset,
+ k=2,
+ projection=projection,
+ case_returns=["examples", "indices", "distances", "include_inputs"],
+ batch_size=2
+ )
+
mask = counter_factuals.filter_fn(inputs, cases, targets, cases_targets)
assert mask.shape == (inputs.shape[0], cases.shape[0])
@@ -62,11 +75,22 @@ def test_label_aware_cf():
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- counter_factuals = LabelAwareCounterFactuals(cases_dataset, cases_targets_dataset, k=1, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
cf_targets = tf.constant([[1, 0], [0, 1], [0, 1]], dtype=tf.float32)
+ projection = Projection(space_projection=lambda inputs: inputs)
+
+ # build the LabelAwareCounterFactuals object
+ counter_factuals = LabelAwareCounterFactuals(
+ cases_dataset,
+ cases_targets_dataset,
+ k=1,
+ projection=projection,
+ case_returns=["examples", "indices", "distances", "include_inputs"],
+ batch_size=2
+ )
+
mask = counter_factuals.filter_fn(inputs, cases, cf_targets, cases_targets)
assert mask.shape == (inputs.shape[0], cases.shape[0])
@@ -106,7 +130,14 @@ def test_label_aware_cf():
cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
- counter_factuals = LabelAwareCounterFactuals(cases_dataset, cases_targets_dataset, k=1, case_returns=["examples", "indices", "distances", "include_inputs"], batch_size=2)
+ counter_factuals = LabelAwareCounterFactuals(
+ cases_dataset,
+ cases_targets_dataset,
+ k=1,
+ projection=projection,
+ case_returns=["examples", "indices", "distances", "include_inputs"],
+ batch_size=2
+ )
inputs = tf.constant([[1.5], [2.5], [4.5], [6.5], [8.5]], dtype=tf.float32)
cf_targets = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]], dtype=tf.float32)
@@ -161,11 +192,14 @@ def test_kleor():
inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
+ projection = Projection(space_projection=lambda inputs: inputs)
+
# start when strategy is sim_miss
kleor_sim_miss = KLEORSimMiss(
cases_dataset,
cases_targets_dataset,
k=1,
+ projection=projection,
case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
batch_size=2,
)
@@ -205,6 +239,7 @@ def test_kleor():
cases_dataset,
cases_targets_dataset,
k=1,
+ projection=projection,
case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
batch_size=2,
)
@@ -246,51 +281,3 @@ def test_kleor():
assert tf.reduce_all(
tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples)
) < 1e-5)
-
-# def test_kleor_global_sim():
-# """
-# Test suite for the KleorSimMiss class
-# """
-# cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
-# cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
-
-# cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2)
-# cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2)
-# semi_factuals = KLEOR(
-# cases_dataset,
-# cases_targets_dataset,
-# k=1,
-# case_returns=["examples", "indices", "distances", "include_inputs", "nuns"],
-# batch_size=2,
-# strategy="global_sim"
-# )
-
-# return_dict = semi_factuals(inputs, targets)
-# assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"])
-
-# examples = return_dict["examples"]
-# distances = return_dict["distances"]
-# indices = return_dict["indices"]
-# nuns = return_dict["nuns"]
-
-# expected_nuns = tf.constant([
-# [[2., 3.]],
-# [[1., 2.]],
-# [[4., 5.]]], dtype=tf.float32)
-# assert tf.reduce_all(tf.equal(nuns, expected_nuns))
-
-# assert examples.shape == (3, 2, 2) # (n, k+1, W)
-# assert distances.shape == (3, 1) # (n, k)
-# assert indices.shape == (3, 1, 2) # (n, k, 2)
-
-# expected_examples = tf.constant([
-# [[1.5, 2.5], [1., 2.]],
-# [[2.5, 3.5], [2., 3.]],
-# [[4.5, 5.5], [3., 4.]]], dtype=tf.float32)
-# assert tf.reduce_all(tf.equal(examples, expected_examples))
-
-# expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32)
-# assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
-
-# expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32)
-# assert tf.reduce_all(tf.equal(indices, expected_indices))
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 0f4c1dfa..02cc1f4e 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -13,7 +13,7 @@
from ..commons import sanitize_inputs_targets
from ..commons import sanitize_dataset, dataset_gather
-from .search_methods import KNN, BaseSearchMethod
+from .search_methods import BaseSearchMethod
from .projections import Projection
from .search_methods.base import _sanitize_returns
@@ -21,38 +21,39 @@
class BaseExampleMethod(ABC):
"""
- Base class for natural example-based methods explaining models,
- they project the cases_dataset into a pertinent space for the with a `Projection`,
- then they call the `BaseSearchMethod` on it.
+ Base class for natural example-based methods explaining classification models.
+ An example-based method is a method that explains a model's predictions by providing examples from the cases_dataset
+ (usually the training dataset). The examples are selected with the help of a search method that performs a search in
+ the search space. The search space is defined with the help of a projection function that projects the cases_dataset
+ and the (inputs, targets) to explain into a space where the search method is relevant.
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
+ The dataset used to train the model, examples are extracted from this dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
+ predictions. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
- search_method
- An algorithm to search the examples in the projected space.
k
- The number of examples to retrieve.
+ The number of examples to retrieve per input.
projection
Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
+ The search space should be a space where distances are relevant for the model.
+        It should not be `None`, otherwise, the model is not involved and thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
Example of Callable:
```
@@ -67,12 +68,10 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
```
case_returns
String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` for detail.
+ See the returns property for details.
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
"""
_returns_possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
@@ -88,7 +87,7 @@ def __init__(
):
assert (
projection is not None
- ), "`BaseExampleMethod` without `projection` is a `BaseSearchMethod`."
+ ), "`BaseExampleMethod` without Projection method should be a `BaseSearchMethod`."
# set attributes
self.batch_size = self._initialize_cases_dataset(
@@ -96,9 +95,9 @@ def __init__(
)
self._search_returns = ["indices", "distances"]
- assert hasattr(projection, "__call__"), "projection should be a callable."
- # check projection type
+ # check projection
+ assert hasattr(projection, "__call__"), "projection should be a callable."
if isinstance(projection, Projection):
self.projection = projection
elif hasattr(projection, "__call__"):
@@ -112,13 +111,17 @@ def __init__(
# project dataset
self.projected_cases_dataset = self.projection.project_dataset(self.cases_dataset,
self.targets_dataset)
-
+
+ # set properties
self.k = k
self.returns = case_returns
@property
@abstractmethod
def search_method_class(self) -> Type[BaseSearchMethod]:
+ """
+ When inheriting from `BaseExampleMethod`, one should define the search method class to use.
+ """
raise NotImplementedError
@property
@@ -179,13 +182,13 @@ def _initialize_cases_dataset(
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
+ The dataset used to train the model, examples are extracted from this dataset.
labels_dataset
- Labels associated to the examples in the dataset.
+ Labels associated to the examples in the cases_dataset.
Indices should match with cases_dataset.
targets_dataset
Targets associated to the cases_dataset for dataset projection.
- See `projection` for detail.
+ See `projection` for details.
batch_size
Number of sample treated simultaneously when using the datasets.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
@@ -254,6 +257,7 @@ def _initialize_cases_dataset(
+ f"{len(self.cases_dataset.element_spec)} were detected."
)
+ # prefetch datasets
self.cases_dataset = self.cases_dataset.prefetch(tf.data.AUTOTUNE)
if self.labels_dataset is not None:
self.labels_dataset = self.labels_dataset.prefetch(tf.data.AUTOTUNE)
@@ -269,9 +273,9 @@ def explain(
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
):
"""
- Compute examples to explain the inputs.
- It project inputs with `self.projection` in the search space
- and find examples with `self.search_method`.
+ Return the relevant examples to explain the (inputs, targets).
+ It projects inputs with `self.projection` in the search space
+        and finds examples with the `self.search_method`.
Parameters
----------
@@ -280,20 +284,19 @@ def explain(
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
targets
- Tensor or Array passed to the projection function.
+ Targets associated to the cases_dataset for dataset projection.
+ See `projection` for details.
Returns
-------
return_dict
Dictionary with listed elements in `self.returns`.
- If only one element is present it returns the element.
- The elements that can be returned are:
- examples, weights, distances, indices, and labels.
+            The elements that can be returned are defined by the `_returns_possibilities` static attribute of the class.
"""
- # project inputs
+ # project inputs into the search space
projected_inputs = self.projection(inputs, targets)
- # look for closest elements to projected inputs
+ # look for relevant elements in the search space
search_output = self.search_method(projected_inputs, targets)
# manage returned elements
@@ -304,7 +307,7 @@ def __call__(
inputs: Union[tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
):
- """explain alias"""
+ """explain() alias"""
return self.explain(inputs, targets)
def format_search_output(
@@ -323,23 +326,20 @@ def format_search_output(
inputs
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
targets
- Tensor or Array passed to the projection function.
- Here it is used by the explain function of attribution methods.
- Refer to the corresponding method documentation for more detail.
- Note that the default method is `Saliency`.
+ Targets associated to the cases_dataset for dataset projection.
+ See `projection` for details.
Returns
-------
return_dict
Dictionary with listed elements in `self.returns`.
- If only one element is present it returns the element.
- The elements that can be returned are:
- examples, weights, distances, indices, and labels.
+            The elements that can be returned are defined by the `_returns_possibilities` static attribute of the class.
"""
+ # initialize return dictionary
return_dict = {}
+ # gather examples, labels, and targets from the example's indices of the search output
examples = dataset_gather(self.cases_dataset, search_output["indices"])
examples_labels = dataset_gather(self.labels_dataset, search_output["indices"])
examples_targets = dataset_gather(
@@ -377,13 +377,6 @@ def format_search_output(
return_dict["weights"] = tf.stack(weights, axis=0)
- # optimization test TODO
- # return_dict["weights"] = tf.vectorized_map(
- # fn=lambda x: self.projection.get_input_weights(x[0], x[1]),
- # elems=(examples, examples_targets),
- # # fn_output_signature=tf.float32,
- # )
-
# add indices, distances, and labels
if "indices" in self.returns:
return_dict["indices"] = search_output["indices"]
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index ca203b12..296bcea2 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -1,7 +1,6 @@
"""
Implementation of Cole method a simlilar examples method from example based module
"""
-
import numpy as np
import tensorflow as tf
@@ -14,10 +13,10 @@
class Cole(SimilarExamples):
"""
- Cole is a similar examples methods that gives the most similar examples to a query.
- Cole use the model to build a search space so that distances are meaningful for the model.
- It uses attribution methods to weights inputs.
- Those attributions may be computed in the latent space for complex data types like images.
+ Cole is a similar examples method that gives the most similar examples to a query in some specific projection space.
+    Cole uses the model (to be explained) to build a search space so that distances are meaningful for the model.
+ It uses attribution methods to weight inputs.
+ Those attributions may be computed in the latent space for high-dimensional data like images.
It is an implementation of a method proposed by Kenny et Keane in 2019,
Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
@@ -26,24 +25,25 @@ class Cole(SimilarExamples):
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
+ The dataset used to train the model, examples are extracted from this dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
+ predictions. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
k
- The number of examples to retrieve. Default value is `1`.
+ The number of examples to retrieve per input.
distance
Either a Callable, or a value supported by `tf.norm` `ord` parameter.
Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
@@ -51,13 +51,10 @@ class Cole(SimilarExamples):
yielding the corresponding p-norm."
case_returns
String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` from parent class `SimilarExamples` for detail.
- By default, the `explain()` method will only return the examples.
+ See the base class returns property for details.
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- device
- Device to use for the projection, if None, use the default device.
latent_layer
Layer used to split the model, the first part will be used for projection and
the second to compute the attributions. By default, the model is not split.
@@ -75,9 +72,8 @@ class Cole(SimilarExamples):
It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
By default, it computes the gradient to make the Hadamard product in the latent space.
attribution_kwargs
- Parameters to be passed at the construction of the `attribution_method`.
+ Parameters to be passed for the construction of the `attribution_method`.
"""
-
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
@@ -88,7 +84,6 @@ def __init__(
distance: Union[str, Callable] = "euclidean",
case_returns: Optional[Union[List[str], str]] = "examples",
batch_size: Optional[int] = 32,
- # device: Optional[str] = None,
latent_layer: Optional[Union[str, int]] = None,
attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient",
**attribution_kwargs,
@@ -104,7 +99,6 @@ def __init__(
model=model,
latent_layer=latent_layer,
operator=operator,
- # device=device,
)
elif issubclass(attribution_method, BlackBoxExplainer):
# build attribution projection
@@ -112,7 +106,6 @@ def __init__(
model=model,
method=attribution_method,
latent_layer=latent_layer,
- # device=device,
**attribution_kwargs,
)
else:
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index fd6b6204..caf4a3fe 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -1,9 +1,8 @@
"""
Implementation of both counterfactuals and semi factuals methods for classification tasks.
-
-SM CF guided to be implemented (I think): KLEOR at least Sim-Miss and Global-Sim
-SM CF free to be implemented: MDN but has to be adapated, Local-Region Model??
"""
+import warnings
+
import numpy as np
import tensorflow as tf
@@ -18,8 +17,60 @@
class NaiveCounterFactuals(BaseExampleMethod):
"""
- This class allows to search for counterfactuals by searching for the closest sample that do not have the same label.
+ This class allows to search for counterfactuals by searching for the closest sample to a query in a projection space
+ that do not have the same model's prediction.
It is a naive approach as it follows a greedy approach.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from this dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distances are relevant for the model.
+ It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+                targets are optional parameters to orientate the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See the base class returns property for more details.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ distance
+ Distance for the FilterKNN search method.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
"""
def __init__(
self,
@@ -32,10 +83,6 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
):
-
- if projection is None:
- projection = Projection(space_projection=lambda inputs: inputs)
-
super().__init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
@@ -45,10 +92,12 @@ def __init__(
case_returns=case_returns,
batch_size=batch_size,
)
-
+
+ # set distance function and order for the search method
self.distance = distance
self.order = ORDER.ASCENDING
+ # initiate search_method
self.search_method = self.search_method_class(
cases_dataset=self.cases_dataset,
targets_dataset=self.targets_dataset,
@@ -59,16 +108,20 @@ def __init__(
filter_fn=self.filter_fn,
order=self.order
)
-
+
@property
def search_method_class(self):
+ """
+ This property defines the search method class to use for the search. In this case, it is the FilterKNN that
+ is an efficient KNN search method ignoring non-acceptable cases, thus not considering them in the search.
+ """
return FilterKNN
def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
"""
- Filter function to mask the cases for which the label is different from the predicted
- label on the inputs.
+ Filter function to mask the cases for which the model's prediction is different from the model's prediction
+ on the inputs.
"""
# get the labels predicted by the model
# (n, )
@@ -82,8 +135,59 @@ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
class LabelAwareCounterFactuals(BaseExampleMethod):
"""
- This method will search the counterfactuals with a specific label. This label should be provided by the user in the
- cf_labels_dataset args.
+ This method will search the counterfactuals of a query within an expected class. This class should be provided with
+ the query when calling the explain method.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from this dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distances are relevant for the model.
+ It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+                targets are optional parameters to orientate the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See the base class returns property for more details.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ distance
+ Distance for the FilterKNN search method.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
"""
def __init__(
self,
@@ -96,9 +200,6 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = "euclidean",
):
- if projection is None:
- projection = Projection(space_projection=lambda inputs: inputs)
- # TODO: add a warning here if it is a custom projection that requires using targets as it might mismatch with the explain
super().__init__(
cases_dataset=cases_dataset,
@@ -109,10 +210,18 @@ def __init__(
case_returns=case_returns,
batch_size=batch_size,
)
-
+
+ # raise a warning to specify that target in the explain method is not the same as the target used for
+ # the target dataset
+ warnings.warn("If your projection method requires the target, be aware that when using the explain method,"
+ " the target provided is the class within one should search for the counterfactual.\nThus,"
+ " it is possible that the projection of the query is going wrong.")
+
+ # set distance function and order for the search method
self.distance = distance
self.order = ORDER.ASCENDING
+ # initiate search_method
self.search_method = self.search_method_class(
cases_dataset=self.cases_dataset,
targets_dataset=self.targets_dataset,
@@ -126,20 +235,24 @@ def __init__(
@property
def search_method_class(self):
+ """
+ This property defines the search method class to use for the search. In this case, it is the FilterKNN that
+ is an efficient KNN search method ignoring non-acceptable cases, thus not considering them in the search.
+ """
return FilterKNN
def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
"""
- Filter function to mask the cases for which the label is different from the label(s) expected for the
+ Filter function to mask the cases for which the target is different from the target(s) expected for the
counterfactuals.
Parameters
----------
cf_targets
- TODO
+            The one-hot encoding of the target class for the counterfactuals.
cases_targets
- TODO
+ The one-hot encoding of the target class for the cases.
"""
mask = tf.matmul(cf_targets, cases_targets, transpose_b=True) #(n, bs)
# TODO: I think some retracing are done here
@@ -153,9 +266,9 @@ def explain(
cf_targets: Union[tf.Tensor, np.ndarray],
):
"""
- Compute examples to explain the inputs.
- It project inputs with `self.projection` in the search space
- and find examples with `self.search_method`.
+ Return the relevant CF examples to explain the inputs.
+ The CF examples are searched within cases for which the target is the one provided in `cf_targets`.
+ It projects inputs with `self.projection` in the search space and find examples with the `self.search_method`.
Parameters
----------
@@ -164,22 +277,80 @@ def explain(
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
cf_targets
- TODO: change the description here
+ Tensor or Array. One-hot encoding of the target class for the counterfactuals.
Returns
-------
return_dict
Dictionary with listed elements in `self.returns`.
- If only one element is present it returns the element.
- The elements that can be returned are:
- examples, weights, distances, indices, and labels.
+ The elements that can be returned are defined with _returns_possibilities static attribute of the class.
"""
- # TODO make an assert on the cf_targets
return super().explain(inputs, cf_targets)
class KLEORBase(BaseExampleMethod):
"""
+ Base class for KLEOR methods. KLEOR methods search Semi-Factuals examples. In those methods, one should first
+ retrieve the Nearest Unlike Neighbor (NUN) which is the closest example to the query that has a different prediction
+ than the query. Then, the method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction
+ as the query.
+
+ All the searches are done in a projection space where distances are relevant for the model. The projection space is
+ defined by the `projection` method.
+
+ Depending on the KLEOR method some additional condition for the search are added. See the specific KLEOR method for
+ more details.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from this dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distances are relevant for the model.
+ It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+                targets are optional parameters to orientate the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See the base class returns property for more details.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ distance
+ Distance for the FilterKNN search method.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
"""
_returns_possibilities = [
"examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
@@ -197,9 +368,6 @@ def __init__(
distance: Union[int, str, Callable] = "euclidean",
):
- if projection is None:
- projection = Projection(space_projection=lambda inputs: inputs)
-
super().__init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
@@ -209,10 +377,12 @@ def __init__(
case_returns=case_returns,
batch_size=batch_size,
)
-
+
+ # set distance function and order for the search method
self.distance = distance
self.order = ORDER.ASCENDING
+ # initiate search_method
self.search_method = self.search_method_class(
cases_dataset=self.cases_dataset,
targets_dataset=self.targets_dataset,
@@ -224,12 +394,14 @@ def __init__(
@property
def returns(self) -> Union[List[str], str]:
- """Getter for the returns parameter."""
+ """Override the Base class returns' parameter."""
return self._returns
@returns.setter
def returns(self, returns: Union[List[str], str]):
"""
+ Set the returns parameter. The returns parameter is a string or a list of string with the elements to return
+ in `self.explain()`. The elements that can be returned are defined with _returns_possibilities static attribute
"""
default = "examples"
self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
@@ -255,6 +427,24 @@ def format_search_output(
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
):
"""
+ Format the output of the `search_method` to match the expected returns in `self.returns`.
+
+ Parameters
+ ----------
+ search_output
+ Dictionary with the required outputs from the `search_method`.
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ targets
+ Targets associated to the cases_dataset for dataset projection.
+ See `projection` for details.
+
+ Returns
+ -------
+ return_dict
+ Dictionary with listed elements in `self.returns`.
+ The elements that can be returned are defined with _returns_possibilities static attribute of the class.
"""
return_dict = super().format_search_output(search_output, inputs, targets)
if "nuns" in self.returns:
@@ -266,13 +456,38 @@ def format_search_output(
return return_dict
-class KLEORGlobalSim(KLEORBase):
+class KLEORSimMiss(KLEORBase):
+ """
+ The KLEORSimMiss method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
+
+ The search is done in a projection space where distances are relevant for the model. The projection space is defined
+ by the `projection` method.
+ """
@property
def search_method_class(self):
- return KLEORGlobalSimSearch
+ """
+ This property defines the search method class to use for the search. In this case, it is the KLEORSimMissSearch.
+ """
+ return KLEORSimMissSearch
+class KLEORGlobalSim(KLEORBase):
+ """
+ The KLEORGlobalSim method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
-class KLEORSimMiss(KLEORBase):
+ In addition, for a SF candidate to be considered, the SF should be closer to the query than the NUN in the
+ projection space (i.e. the SF should be 'between' the input and its NUN). This condition is added to the search.
+
+ The search is done in a projection space where distances are relevant for the model. The projection space is defined
+ by the `projection` method.
+ """
@property
def search_method_class(self):
- return KLEORSimMissSearch
+ """
+ This property defines the search method class to use for the search. In this case, it is the
+ KLEORGlobalSimSearch.
+ """
+ return KLEORGlobalSimSearch
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index 60db96af..dc06bca9 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -8,7 +8,6 @@
import numpy as np
from ...types import Union, Optional, List
-
from ...commons import sanitize_dataset
class ORDER(Enum):
@@ -24,7 +23,6 @@ def _sanitize_returns(returns: Optional[Union[List[str], str]] = None,
possibilities: List[str] = None,
default: Union[List[str], str] = None):
"""
- Factorization of `set_returns` for `BaseSearchMethod` and `SimilarExamples`.
It cleans the `returns` parameter.
Results is either a sublist of possibilities or a value among possibilities.
@@ -66,20 +64,22 @@ def _sanitize_returns(returns: Optional[Union[List[str], str]] = None,
class BaseSearchMethod(ABC):
"""
- Base class used by `NaturalExampleBasedExplainer` search examples in
- a meaningful space for the model. It can also be used alone but will not provided
- model explanations.
+ Base class for the example-based search methods. This class is abstract. It should be inherited by
+ the search methods that are used to find examples in a dataset. It also defines the interface for the
+ search methods.
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- For natural example-based methods it is the train dataset.
+ The dataset containing the examples to search in.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve.
search_returns
String or list of string with the elements to return in `self.find_examples()`.
- See `self.set_returns()` for detail.
+ It should be a subset of `self._returns_possibilities`.
batch_size
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
@@ -133,8 +133,6 @@ def returns(self, returns: Union[List[str], str]):
`returns` can be set to 'all' for all possible elements to be returned.
- 'examples' correspond to the expected examples,
the inputs may be included in first position. (n, k(+1), ...)
- - 'weights' the weights in the input space used in the projection.
- They are associated to the input and the examples. (n, k(+1), ...)
- 'distances' the distances between the inputs and the corresponding examples.
They are associated to the examples. (n, k, ...)
- 'labels' if provided through `dataset_labels`,
@@ -146,7 +144,7 @@ def returns(self, returns: Union[List[str], str]):
self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
@abstractmethod
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict:
"""
Search the samples to return as examples. Called by the explain methods.
It may also return the indices corresponding to the samples,
@@ -157,9 +155,16 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
inputs
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ targets
+ Tensor or Array. Target of the samples to be explained.
+
+ Returns
+ -------
+ return_dict
+ Dictionary containing the elements to return which are specified in `self.returns`.
"""
raise NotImplementedError()
- def __call__(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
- """find_samples alias"""
+ def __call__(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict:
+ """find_samples() alias"""
return self.find_examples(inputs, targets)
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index bb057bf3..08baa293 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -14,7 +14,39 @@
class BaseKLEORSearch(FilterKNN, ABC):
"""
- Base class for the KLEOR search methods.
+ Base class for the KLEOR search methods. In those methods, one should first retrieve the Nearest Unlike Neighbor
+ (NUN) which is the closest example to the query that has a different prediction than the query.
+ Then, the method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
+
+ Depending on the KLEOR method some additional condition for the search are added. See the specific KLEOR method for
+ more details.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to search the examples.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ search_returns
+ String or list of string with the elements to return in `self.find_examples()`.
+ It should be a subset of `self._returns_possibilities`.
+ batch_size
+ Number of sample treated simultaneously.
+ distance
+ Distance function to use to measure similarity.
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+ Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
"""
def __init__(
self,
@@ -36,6 +68,7 @@ def __init__(
filter_fn=self._filter_fn,
)
+ # search method for the Nearest Unlike Neighbors
self.search_nuns = FilterKNN(
cases_dataset=cases_dataset,
targets_dataset=targets_dataset,
@@ -47,7 +80,7 @@ def __init__(
filter_fn=self._filter_fn_nun,
)
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict:
"""
Search the samples to return as examples. Called by the explain methods.
It may also return the indices corresponding to the samples,
@@ -59,6 +92,13 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
Tensor or Array. Input samples to be explained.
Assumed to have been already projected.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ targets
+ Tensor or Array. Target of the samples to be explained.
+
+ Returns
+ -------
+ return_dict
+ Dictionary containing the elements to return which are specified in `self.returns`.
"""
# compute neighbors
examples_distances, examples_indices, nuns, nuns_indices, nuns_sf_distances = self.kneighbors(inputs, targets)
@@ -80,6 +120,7 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
def _filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
"""
+ Filter function to mask the cases for which the prediction is the same as the predicted label on the inputs.
"""
# get the labels predicted by the model
# (n, )
@@ -106,6 +147,7 @@ def _filter_fn_nun(self, _, __, targets, cases_targets) -> tf.Tensor:
def _get_nuns(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
"""
+ Get the Nearest Unlike Neighbors and their distance to the related input.
"""
nuns_dict = self.search_nuns(inputs, targets)
nuns_indices, nuns_distances = nuns_dict["indices"], nuns_dict["distances"]
@@ -114,6 +156,41 @@ def _get_nuns(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tens
def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]:
"""
+ Compute the k SF to each tensor of `inputs` in `self.cases_dataset`.
+ Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
+
+ Parameters
+ ----------
+ inputs
+ Tensor or Array. Input samples on which knn are computed.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ More information in the documentation.
+ targets
+ Tensor or Array. Target of the samples to be explained.
+
+ Returns
+ -------
+ input_sf_distances
+ Tensor of distances between the SFs and the inputs with dimension (n, k).
+ The n inputs times their k-SF.
+ sf_indices
+ Tensor of indices of the SFs in `self.cases_dataset` with dimension (n, k, 2).
+ Where, n represent the number of inputs and k the number of corresponding SFs.
+ The index of each element is encoded by two values,
+ the batch index and the index of the element in the batch.
+ Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`.
+ nuns
+ Tensor of Nearest Unlike Neighbors with dimension (n, 1, ...).
+ The n inputs times their NUN.
+ nuns_indices
+ Tensor of indices of the NUN in `self.cases_dataset` with dimension (n, 1, 2).
+ Where, n represent the number of inputs.
+ The index of each element is encoded by two values,
+ the batch index and the index of the element in the batch.
+ Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`.
+ nun_sf_distances
+ Tensor of distances between the SFs and the NUN with dimension (n, k).
+ The n NUNs times the k-SF.
"""
# get the Nearest Unlike Neighbors and their distance to the related input
nuns, nuns_indices, nuns_input_distances = self._get_nuns(inputs, targets)
@@ -193,30 +270,29 @@ def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances:
class KLEORSimMissSearch(BaseKLEORSearch):
"""
- KLEOR search method.
-
- Parameters
- ----------
- cases_dataset
- Dataset of cases.
- targets_dataset
- Dataset of targets. Should be a one-hot encoded of the predicted class
+ The KLEORSimMiss method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
"""
def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple:
+ """
+ No additional filtering for the KLEORSimMiss method.
+ """
return nun_sf_distances, input_sf_distances
class KLEORGlobalSimSearch(BaseKLEORSearch):
"""
- KLEOR search method.
+ The KLEORGlobalSim method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
- Parameters
- ----------
- cases_dataset
- Dataset of cases.
- targets_dataset
- Dataset of targets. Should be a one-hot encoded of the predicted class
+ In addition, for a SF candidate to be considered, the SF should be closer to the query than the NUN
+ (i.e. the SF should be 'between' the input and its NUN). This condition is added to the search.
"""
def _additional_filtering(self, nun_sf_distances: tf.Tensor, input_sf_distances: tf.Tensor, nuns_input_distances: tf.Tensor) -> Tuple:
+ """
+ Filter the distances to keep only the SF that are 'between' the input and its NUN.
+ """
# filter non acceptable cases, i.e. cases for which the distance to the input is greater
# than the distance between the input and its nun
# (n, current_bs)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index c1fbe4db..70102a24 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -13,6 +13,27 @@
class BaseKNN(BaseSearchMethod):
"""
+ Base class for the KNN search methods. It is an abstract class that should be inherited by a specific KNN method.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset containing the examples to search in.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve.
+ search_returns
+ String or list of string with the elements to return in `self.find_examples()`.
+ It should be a subset of `self._returns_possibilities`.
+ batch_size
+ Number of sample treated simultaneously.
+ It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ order
+ The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. Default is `ORDER.ASCENDING`.
+ ASCENDING means that the smallest distances are the best, DESCENDING means that the biggest distances are
+ the best.
"""
def __init__(
self,
@@ -28,7 +49,7 @@ def __init__(
search_returns=search_returns,
batch_size=batch_size,
)
-
+ # set order
assert isinstance(order, ORDER), f"order should be an instance of ORDER and not {type(order)}"
self.order = order
# fill value
@@ -47,7 +68,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Uni
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
targets
- Tensor or Array. Target samples to be explained.
+ Tensor or Array. Target of the samples to be explained.
Returns
-------
@@ -63,7 +84,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Uni
"""
raise NotImplementedError
- def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict:
"""
Search the samples to return as examples. Called by the explain methods.
It may also return the indices corresponding to the samples,
@@ -75,6 +96,13 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
Tensor or Array. Input samples to be explained.
Assumed to have been already projected.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
+ targets
+ Tensor or Array. Target of the samples to be explained.
+
+ Returns
+ -------
+ return_dict
+ Dictionary containing the elements to return which are specified in `self.returns`.
"""
# compute neighbors
examples_distances, examples_indices = self.kneighbors(inputs, targets)
@@ -84,9 +112,10 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[
return return_dict
- def _build_return_dict(self, inputs, examples_distances, examples_indices):
+ def _build_return_dict(self, inputs, examples_distances, examples_indices) -> dict:
"""
- TODO: Change the description
+ Build the return dict based on the `self.returns` values. It builds the return dict with the value in the
+ subset of ['examples', 'include_inputs', 'indices', 'distances'] which is commonly shared.
"""
# Set values in return dict
return_dict = {}
@@ -107,22 +136,29 @@ def _build_return_dict(self, inputs, examples_distances, examples_indices):
class KNN(BaseKNN):
"""
KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
- Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
+ The kneighbors method is implemented in a batched way to handle large datasets.
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- For natural example-based methods it is the train dataset.
+ The dataset containing the examples to search in.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve.
search_returns
String or list of string with the elements to return in `self.find_examples()`.
- See `self.set_returns()` for detail.
+ It should be a subset of `self._returns_possibilities`.
batch_size
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ order
+ The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. Default is `ORDER.ASCENDING`.
+ ASCENDING means that the smallest distances are the best, DESCENDING means that the biggest distances are
+ the best.
distance
+ Distance function to use to measure similarity.
Either a Callable, or a value supported by `tf.norm` `ord` parameter.
Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
"Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
@@ -145,6 +181,7 @@ def __init__(
order=order,
)
+ # set distance function
if hasattr(distance, "__call__"):
self.distance_fn = distance
elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
@@ -159,7 +196,23 @@ def __init__(
)
@tf.function
- def _crossed_distances_fn(self, x1, x2):
+ def _crossed_distances_fn(self, x1, x2) -> tf.Tensor:
+ """
+ Element-wise distance computation between two tensors.
+ It has been vectorized to handle batches of inputs and cases.
+
+ Parameters
+ ----------
+ x1
+ Tensor. Input samples of shape (n, ...).
+ x2
+ Tensor. Cases samples of shape (m, ...).
+
+ Returns
+ -------
+ distances
+ Tensor of distances between the inputs and the cases with dimension (n, m).
+ """
n = x1.shape[0]
m = x2.shape[0]
x2 = tf.expand_dims(x2, axis=0)
@@ -186,6 +239,8 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], _ = None) -> Tuple[tf
Tensor or Array. Input samples on which knn are computed.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
+ targets
+ Tensor or Array. Target of the samples to be explained.
Returns
-------
@@ -246,24 +301,33 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], _ = None) -> Tuple[tf
class FilterKNN(BaseKNN):
"""
- TODO: Change the class description
KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`.
- Basically a wrapper of `NearestNeighbors` to match the `BaseSearchMethod` API.
+ The kneighbors method is implemented in a batched way to handle large datasets.
+    In addition, a filter function is used to select the elements to compute the distances, thus reducing the
+    computational cost of the distance computation (worthwhile if the filter is cheap to compute and the matrix
+    of distances is sparse).
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- For natural example-based methods it is the train dataset.
+ The dataset containing the examples to search in.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve.
search_returns
String or list of string with the elements to return in `self.find_examples()`.
- See `self.set_returns()` for detail.
+ It should be a subset of `self._returns_possibilities`.
batch_size
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
+ order
+ The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. Default is `ORDER.ASCENDING`.
+ ASCENDING means that the smallest distances are the best, DESCENDING means that the biggest distances are
+ the best.
distance
+ Distance function to use to measure similarity.
Either a Callable, or a value supported by `tf.norm` `ord` parameter.
Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
"Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
@@ -291,7 +355,8 @@ def __init__(
batch_size=batch_size,
order=order,
)
-
+
+ # set distance function
if hasattr(distance, "__call__"):
self.distance_fn = distance
elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
@@ -319,6 +384,24 @@ def __init__(
@tf.function
def _crossed_distances_fn(self, x1, x2, mask):
+ """
+ Element-wise distance computation between two tensors with a mask.
+ It has been vectorized to handle batches of inputs and cases.
+
+ Parameters
+ ----------
+ x1
+ Tensor. Input samples of shape (n, ...).
+ x2
+ Tensor. Cases samples of shape (m, ...).
+ mask
+ Tensor. Boolean mask of shape (n, m). It is used to filter the elements for which the distance is computed.
+
+ Returns
+ -------
+ distances
+ Tensor of distances between the inputs and the cases with dimension (n, m).
+ """
n = x1.shape[0]
m = x2.shape[0]
x2 = tf.expand_dims(x2, axis=0)
@@ -338,6 +421,9 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Uni
"""
Compute the k-neareast neighbors to each tensor of `inputs` in `self.cases_dataset`.
Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
+ In addition, a filter function is used to select the elements to compute the distances, thus reducing the
+ computational cost of the distance computation (worth if the computation of the filter is low and the matrix
+ of distances is sparse).
Parameters
----------
@@ -345,6 +431,8 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Uni
Tensor or Array. Input samples on which knn are computed.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
+ targets
+ Tensor or Array. Target of the samples to be explained.
Returns
-------
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index 75756074..ea16f261 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -1,53 +1,48 @@
"""
Base model for example-based
"""
-
-import math
-
import tensorflow as tf
import numpy as np
-from ..types import Callable, Dict, List, Optional, Type, Union
+from ..types import Callable, List, Optional, Type, Union
-from ..commons import sanitize_inputs_targets
-from ..commons import sanitize_dataset, dataset_gather
from .search_methods import KNN, BaseSearchMethod, ORDER
from .projections import Projection
from .base_example_method import BaseExampleMethod
-from .search_methods.base import _sanitize_returns
-
class SimilarExamples(BaseExampleMethod):
"""
- Base class for similar examples.
+ Class for similar example-based method. This class allows to search the k Nearest Neighbor of an input in the
+ projected space (defined by the projection method) using the distance defined by the distance method provided.
Parameters
----------
cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
+ The dataset used to train the model, examples are extracted from this dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
+ predictions. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
k
- The number of examples to retrieve.
+ The number of examples to retrieve per input.
projection
Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
- all examples could be computed only with the `search_method`.
+ The search space should be a space where distances are relevant for the model.
+ It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
Example of Callable:
```
@@ -62,7 +57,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
```
case_returns
String or list of string with the elements to return in `self.explain()`.
- See `self.set_returns()` for detail.
+ See the base class returns property for more details.
batch_size
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
@@ -73,7 +68,6 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
"Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
yielding the corresponding p-norm." We also added 'cosine'.
"""
-
def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
@@ -95,6 +89,7 @@ def __init__(
batch_size=batch_size,
)
+ # set distance function
self.distance = distance
# initiate search_method
From 7a665256de7f1f42168e0ebb351373c38c35b9cf Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Wed, 3 Jul 2024 17:30:35 +0200
Subject: [PATCH 051/138] add the documentation for the prototypes search
methods
---
.../mmd_critic_search.md | 3 +
.../proto_dash_search.md | 3 +
.../proto_greedy_search.md | 3 +
.../prototypes_search_methods/prototypes.md | 69 +++++++++++++++++++
.../search_methods/search_methods.md | 0
mkdocs.yml | 7 ++
6 files changed, 85 insertions(+)
create mode 100644 docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
create mode 100644 docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
create mode 100644 docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
create mode 100644 docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
create mode 100644 docs/api/example_based/search_methods/search_methods.md
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md b/docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
new file mode 100644
index 00000000..cb85d17c
--- /dev/null
+++ b/docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
@@ -0,0 +1,3 @@
+# MMDCriticSearch
+
+MMDCriticSearch ([Kim et al., 2016](https://proceedings.neurips.cc/paper/2016/hash/5680522b8e2bb01943234bce7bf84534-Abstract.html))
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md b/docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
new file mode 100644
index 00000000..b54dec50
--- /dev/null
+++ b/docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
@@ -0,0 +1,3 @@
+# ProtoDashSearch
+
+ProtoDashSearch ([Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212))
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md b/docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
new file mode 100644
index 00000000..9213caa1
--- /dev/null
+++ b/docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
@@ -0,0 +1,3 @@
+# ProtoGreedySearch
+
+ProtoGreedySearch ([Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212))
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md b/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
new file mode 100644
index 00000000..e617985e
--- /dev/null
+++ b/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
@@ -0,0 +1,69 @@
+# Prototypes
+
+Prototype-based explanation is a family of natural example-based XAI methods. Prototypes consist of a set of samples that are representative of either the dataset or a class.
+
+Three classes of prototype-based methods are found in the literature ([Poché et al., 2023](https://hal.science/hal-04117520/document)): Prototypes for Data-Centric Interpretability, Prototypes for Post-hoc Interpretability and Prototype-Based Models Interpretable by Design. This library focuses on the first two classes.
+
+## Prototypes for Data-Centric Interpretability
+In this class, prototypes are selected without relying on the model and provide an overview of
+the dataset. In this library, the following methods are implemented as [search methods](./algorithms/search_methods/):
+
+Xplique includes the following prototypes search methods:
+
+| Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
+|:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
+| [ProtoGreedySearch](../proto_greedy_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+| [ProtoDashSearch](../proto_dash_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+| [MMDCriticSearch](../mmd_critic_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+
+*: Before using a PyTorch model it is highly recommended to read the [dedicated documentation](../pytorch/)
+
+### What is MMD?
+The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows:
+
+$$
+\begin{align*}
+\text{MMD}(P, Q) &= \left\| \mathbb{E}_{X \sim P}[\varphi(X)] - \mathbb{E}_{Y \sim Q}[\varphi(Y)] \right\|_\mathcal{H}
+\end{align*}
+$$
+
+where $\varphi(\cdot)$ is a mapping function of the data points. If we want to consider all orders of moments of the distributions, the mapping vectors $\varphi(X)$ and $\varphi(Y)$ will be infinite-dimensional. Thus, we cannot calculate them directly. However, if we have a kernel that gives the same result as the inner product of these two mappings in Hilbert space ($k(x, y) = \langle \varphi(x), \varphi(y) \rangle_\mathcal{H}$), then the $MMD^2$ can be computed using only the kernel and without explicitly using $\varphi(X)$ and $\varphi(Y)$ (this is called the kernel trick):
+
+$$
+\begin{align*}
+\text{MMD}^2(P, Q) &= \langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{X' \sim P}[\varphi(X')] \rangle_\mathcal{H} + \langle \mathbb{E}_{Y \sim Q}[\varphi(Y)], \mathbb{E}_{Y' \sim Q}[\varphi(Y')] \rangle_\mathcal{H} \\
+&\quad - 2\langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{Y \sim Q}[\varphi(Y)] \rangle_\mathcal{H} \\
+&= \mathbb{E}_{X, X' \sim P}[k(X, X')] + \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')] - 2\mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)]
+\end{align*}
+$$
+
+### How to choose the kernel ?
+The choice of the kernel for selecting prototypes depends on the specific problem and the characteristics of your data. Several kernels can be used, including:
+
+- Gaussian
+- Laplace
+- Polynomial
+- Linear...
+
+If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automatically consider all the moments for the distribution, as the Taylor expansion of the exponential considers infinite-order moments. It is better to use a non-linear kernel to capture non-linear relationships in your data. If the problem is linear, it is better to choose a linear kernel such as the dot product kernel, since it is computationally efficient and often requires fewer hyperparameters to tune.
+
+For the MMD-critic method, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for Protodash and Protogreedy, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity.
+
+### Default kernel
+The default kernel used is the Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when your data has complex, non-linear structure. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned.
+
+## Prototypes for Post-hoc Interpretability
+
+Data-Centric methods such as Protogreedy, ProtoDash and MMD-critic can be used in either the output or the latent space of the model. In these cases, [projection methods](./algorithms/projections/) are used to transfer the data from the input space to the latent/output spaces.
+
+# Architecture of the code
+
+The Data-Centric prototypes methods are implemented as `search_methods`. The search method can have attribute `projection` that projects samples to a space where distances between samples make sense for the model. Then the `search_method` finds the prototypes by looking in the projected space.
+
+The class `ProtoGreedySearch` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
+
+Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class.
+
+The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
+
+The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is different: it maximizes a tight lower bound on $l(w)$ instead of maximizing $l(w)$, as done in `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides the `compute_objective` method to calculate an objective based on the gradient of $l(w)$. It also overrides the `update_selection` method to select the best weights of the selection based on the gradient of the best candidate.
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/search_methods.md b/docs/api/example_based/search_methods/search_methods.md
new file mode 100644
index 00000000..e69de29b
diff --git a/mkdocs.yml b/mkdocs.yml
index 6b20f7f5..eb975edc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -42,6 +42,13 @@ nav:
- Cav: api/concepts/cav.md
- Tcav: api/concepts/tcav.md
- Craft: api/concepts/craft.md
+ - Example based:
+ - Search Methods:
+ - Prototypes Search Methods:
+ - Prototypes: api/example_based/search_methods/prototypes_search_methods/prototypes.md
+ - ProtoGreedySearch: api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
+ - ProtoDashSearch: api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
+ - MMDCriticSearch: api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
- Feature visualization:
- Modern Feature Visualization (MaCo): api/feature_viz/maco.md
- Feature visualization: api/feature_viz/feature_viz.md
From c8047b0e5b4fb1e6f134875f1651083cafeba360 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 8 Jul 2024 15:14:37 +0200
Subject: [PATCH 052/138] prototypes: hotfix
---
.gitignore | 1 +
xplique/example_based/prototypes.py | 2 ++
2 files changed, 3 insertions(+)
diff --git a/.gitignore b/.gitignore
index fec8e55f..84161dc4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,6 +43,7 @@ coverage.xml
.pytest_cache/
cover/
*test*.sh
+tests/concepts/checkpoints/
# Environments
.env
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 16f68243..6d8c8c34 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -70,6 +70,7 @@ def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
@@ -84,6 +85,7 @@ def __init__(
super().__init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
k=k,
projection=projection,
case_returns=case_returns,
From 693fa42a4013fe9c6f4a361b1c738f3a110d6961 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 8 Jul 2024 15:38:19 +0200
Subject: [PATCH 053/138] prototypes: hotfix
---
xplique/example_based/prototypes.py | 18 ++++++++++++++++--
1 file changed, 16 insertions(+), 2 deletions(-)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 6d8c8c34..cd14766b 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -157,9 +157,21 @@ class ProtoDash(Prototypes):
----------
cases_dataset
The dataset used to train the model, examples are extracted from the dataset.
- For natural example-based methods it is the train dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
labels_dataset
Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other dataset should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve.
search_returns
@@ -191,6 +203,7 @@ def __init__(
self,
cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
k: int = 1,
projection: Union[Projection, Callable] = None,
case_returns: Union[List[str], str] = "examples",
@@ -206,7 +219,8 @@ def __init__(
super().__init__(
cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
k=k,
projection=projection,
case_returns=case_returns,
From d1e8031fd93183c3ac5257886b84f6a978e27571 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 8 Jul 2024 17:33:06 +0200
Subject: [PATCH 054/138] example-based: make tests pass
---
tests/example_based/test_kleor.py | 2 +-
tests/example_based/test_knn.py | 58 +++++++++++---------
tests/example_based/test_projections.py | 2 +-
xplique/example_based/projections/commons.py | 19 ++++++-
4 files changed, 52 insertions(+), 29 deletions(-)
diff --git a/tests/example_based/test_kleor.py b/tests/example_based/test_kleor.py
index fec68950..cd2cd333 100644
--- a/tests/example_based/test_kleor.py
+++ b/tests/example_based/test_kleor.py
@@ -56,7 +56,7 @@ def test_kleor_base_and_sim_miss():
[np.sqrt(2*1.5**2)],
[np.sqrt(2*0.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.equal(nuns, expected_nuns))
- assert tf.reduce_all(tf.equal(nuns_distances, expected_nuns_distances))
+ assert tf.reduce_all(tf.abs(nuns_distances - expected_nuns_distances) < 1e-5)
# test the _initialize_search method
sf_indices, input_sf_distances, nun_sf_distances, batch_indices = kleor._initialize_search(inputs)
diff --git a/tests/example_based/test_knn.py b/tests/example_based/test_knn.py
index 61740a0e..4a9df427 100644
--- a/tests/example_based/test_knn.py
+++ b/tests/example_based/test_knn.py
@@ -5,6 +5,8 @@
import numpy as np
import tensorflow as tf
+from ..utils import almost_equal
+
from xplique.example_based.search_methods import BaseKNN, KNN, FilterKNN, ORDER
def get_setup(input_shape, nb_samples=10, nb_labels=10):
@@ -180,7 +182,7 @@ def test_knn_compute_distances():
distances = knn._crossed_distances_fn(x1, x2)
assert distances.shape == (x1.shape[0], x2.shape[0])
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert almost_equal(distances, expected_distance, epsilon=1e-5)
# Test with higher dimensions
data = np.array([
@@ -217,7 +219,7 @@ def test_knn_compute_distances():
distances = knn._crossed_distances_fn(x1, x2)
assert distances.shape == (x1.shape[0], x2.shape[0])
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert almost_equal(distances, expected_distance)
def test_knn_kneighbors():
@@ -237,8 +239,8 @@ def test_knn_kneighbors():
distances, indices = knn.kneighbors(inputs)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32))
# Test with reverse order
knn = KNN(
@@ -252,8 +254,8 @@ def test_knn_kneighbors():
distances, indices = knn.kneighbors(inputs)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32))
# Test with input and cases being 2D
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
@@ -268,8 +270,8 @@ def test_knn_kneighbors():
distances, indices = knn.kneighbors(inputs)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32))
# Test with reverse order
knn = KNN(
@@ -285,7 +287,7 @@ def test_knn_kneighbors():
assert indices.shape == (3, 2, 2)
expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+ assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32))
def test_filter_knn_compute_distances():
"""
@@ -310,12 +312,14 @@ def test_filter_knn_compute_distances():
mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool)
distances = knn._crossed_distances_fn(x1, x2, mask)
assert distances.shape == (x1.shape[0], x2.shape[0])
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert almost_equal(distances, expected_distance, epsilon=1e-5)
mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool)
expected_distance = tf.constant([[np.sqrt(72), np.inf], [np.inf, np.sqrt(72)], [np.sqrt(8), np.sqrt(32)]], dtype=tf.float32)
distances = knn._crossed_distances_fn(x1, x2, mask)
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert np.allclose(distances, expected_distance, equal_nan=True)
+ assert np.array_equal(distances == np.inf, expected_distance == np.inf)
+ assert np.array_equal(distances == -np.inf, expected_distance == -np.inf)
# Test with higher dimensions
data = np.array([
@@ -353,13 +357,15 @@ def test_filter_knn_compute_distances():
mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool)
distances = knn._crossed_distances_fn(x1, x2, mask)
assert distances.shape == (x1.shape[0], x2.shape[0])
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert almost_equal(distances, expected_distance)
mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool)
expected_distance = tf.constant([[np.sqrt(9)*27, np.inf], [np.inf, np.sqrt(9)*27], [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32)
distances = knn._crossed_distances_fn(x1, x2, mask)
assert distances.shape == (x1.shape[0], x2.shape[0])
- assert tf.reduce_all(tf.equal(distances, expected_distance))
+ assert np.allclose(distances, expected_distance, equal_nan=True)
+ assert np.array_equal(distances == np.inf, expected_distance == np.inf)
+ assert np.array_equal(distances == -np.inf, expected_distance == -np.inf)
def test_filter_knn_kneighbors():
"""
@@ -378,8 +384,8 @@ def test_filter_knn_kneighbors():
distances, indices = knn.kneighbors(inputs)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32))
cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
@@ -396,8 +402,8 @@ def test_filter_knn_kneighbors():
distances, indices = knn.kneighbors(inputs, targets)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[0.5, 2.5], [0.5, 0.5], [0.5, 1.5]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[0.5, 2.5], [0.5, 0.5], [0.5, 1.5]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32))
## test with reverse order
knn = FilterKNN(
@@ -412,8 +418,8 @@ def test_filter_knn_kneighbors():
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
expected_distances = tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32)
- assert tf.reduce_all(tf.equal(distances, expected_distances))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+ assert almost_equal(distances, expected_distances)
+ assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32))
## add a filter that is not the default one and reverse order
knn = FilterKNN(
@@ -429,8 +435,8 @@ def test_filter_knn_kneighbors():
distances, indices = knn.kneighbors(inputs, targets)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[2.5, 0.5], [2.5, 0.5], [2.5, 1.5]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[2.5, 0.5], [2.5, 0.5], [2.5, 1.5]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32))
# Test with input and cases being 2D
cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32)
@@ -446,8 +452,8 @@ def test_filter_knn_kneighbors():
distances, indices = knn.kneighbors(inputs)
assert distances.shape == (3, 2)
assert indices.shape == (3, 2, 2)
- assert tf.reduce_all(tf.equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)))
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)))
+ assert almost_equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32))
+ assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32))
cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32)
targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32)
@@ -466,7 +472,7 @@ def test_filter_knn_kneighbors():
assert indices.shape == (3, 2, 2)
expected_distances = tf.constant([[np.sqrt(0.5), np.sqrt(2*2.5**2)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(2*1.5**2)],], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)))
+ assert almost_equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32))
## test with reverse order and default filter
knn = FilterKNN(
@@ -482,7 +488,7 @@ def test_filter_knn_kneighbors():
assert indices.shape == (3, 2, 2)
expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)))
+ assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32))
## add a filter that is not the default one and reverse order
knn = FilterKNN(
@@ -500,4 +506,4 @@ def test_filter_knn_kneighbors():
assert indices.shape == (3, 2, 2)
expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32)
assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5)
- assert tf.reduce_all(tf.equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)))
+ assert almost_equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32))
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
index 8fe8b28f..f624b68a 100644
--- a/tests/example_based/test_projections.py
+++ b/tests/example_based/test_projections.py
@@ -13,7 +13,7 @@
from xplique.attributions import Saliency
from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection
from xplique.example_based.projections.commons import model_splitting
-from ..utils import generate_data, almost_equal
+
def get_setup(input_shape, nb_samples=10, nb_labels=2):
"""
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
index 59dc7ee8..c8747592 100644
--- a/xplique/example_based/projections/commons.py
+++ b/xplique/example_based/projections/commons.py
@@ -51,8 +51,25 @@ def model_splitting(model: tf.keras.Model,
features_extractor = tf.keras.Model(
model.input, latent_layer.output, name="features_extractor"
)
+ # predictor = tf.keras.Model(
+ # latent_layer.output, model.output, name="predictor"
+ # )
+ second_input = tf.keras.Input(shape=latent_layer.output_shape[1:])
+
+ # Reconstruct the second part of the model
+ x = second_input
+ layer_found = False
+ for layer in model.layers:
+ if layer_found:
+ x = layer(x)
+ if layer == latent_layer:
+ layer_found = True
+
+ # Create the second part of the model (predictor)
predictor = tf.keras.Model(
- latent_layer.output, model.output, name="predictor"
+ inputs=second_input,
+ outputs=x,
+ name="predictor"
)
if return_layer:
From 84fdb6852820b2330c121da12ad4edb71817d172 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 8 Jul 2024 18:20:32 +0200
Subject: [PATCH 055/138] prototypes: support non-identity projections
---
xplique/example_based/prototypes.py | 2 +-
.../example_based/search_methods/proto_greedy_search.py | 9 ++++++++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index cd14766b..5742190d 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -101,7 +101,7 @@ def __init__(
# initiate search_method
self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
+ cases_dataset=self.projected_cases_dataset,
labels_dataset=self.labels_dataset,
k=self.k,
search_returns=self._search_returns,
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 4ed79899..39ed7159 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -10,7 +10,7 @@
from .base import BaseSearchMethod
from .knn import KNN
-from ..projections import Projection
+# from ..projections import Projection
class ProtoGreedySearch(BaseSearchMethod):
@@ -163,6 +163,13 @@ def kernel_induced_distance(x1,x2):
for batch_col_index, (batch_col_cases, batch_col_labels) in enumerate(
zip(self.cases_dataset, self.labels_dataset)
):
+ # elements should be tabular data
+ assert len(batch_col_cases.shape) == 2,\
+ "Prototypes' searches expect 2D data, (nb_samples, nb_features), "+\
+ f"but got {batch_col_cases.shape}. "+\
+ "Please verify your projection if you provided a custom one. "+\
+ "If you use a split model, make sure the output of the first part of the model is flattened."
+
batch_col_sums = tf.zeros((batch_col_cases.shape[0]))
for batch_row_index, (batch_row_cases, batch_row_labels) in enumerate(
From a7a9cc462ae11dad00ee247fa611edd179d484ff Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Tue, 9 Jul 2024 09:36:58 +0200
Subject: [PATCH 056/138] example-based: hotfix
---
xplique/example_based/search_methods/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py
index dc06bca9..77dd768b 100644
--- a/xplique/example_based/search_methods/base.py
+++ b/xplique/example_based/search_methods/base.py
@@ -96,7 +96,7 @@ def __init__(
# set batch size
if hasattr(cases_dataset, "_batch_size"):
- self.batch_size = cases_dataset._batch_size
+ self.batch_size = tf.cast(cases_dataset._batch_size, tf.int32)
else:
self.batch_size = batch_size
From 30e6a56f620f7b5909abaf3c8b5de0a670ab5ceb Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Tue, 9 Jul 2024 17:54:19 +0200
Subject: [PATCH 057/138] prototypes: enhance tests and adapt code
---
tests/example_based/test_prototypes.py | 452 +++++++++++-------
tests/utils.py | 17 +-
xplique/commons/tf_dataset_operations.py | 13 +-
xplique/example_based/prototypes.py | 166 +++----
.../search_methods/proto_dash_search.py | 12 +-
.../search_methods/proto_greedy_search.py | 123 +++--
6 files changed, 426 insertions(+), 357 deletions(-)
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
index 8a31b24d..b9fd43f8 100644
--- a/tests/example_based/test_prototypes.py
+++ b/tests/example_based/test_prototypes.py
@@ -19,225 +19,317 @@
from xplique.example_based import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
from xplique.example_based.projections import Projection, LatentSpaceProjection
-from tests.utils import almost_equal, get_Gaussian_Data, load_data, plot, plot_local_explanation
+from tests.utils import almost_equal, get_gaussian_data, load_data, plot, plot_local_explanation
-def test_proto_greedy_basic():
+def test_prototypes_global_explanations_basic():
"""
- Test the Prototypes with an identity projection.
+ Test prototypes shapes and uniqueness.
"""
# Setup
k = 3
- nb_prototypes = 3
+ nb_prototypes = 5
+ nb_classes = 3
+
gamma = 0.026
- x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
- x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
- # x_train, y_train = load_data('usps')
- # x_test, y_test = load_data('usps.t')
- # x_test = tf.random.shuffle(x_test)
- # x_test = x_test[0:8]
+ x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
+ x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=10)
identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
)
- kernel_type = "global"
-
- # Method initialization
- method = ProtoGreedy(
- cases_dataset=x_train,
- labels_dataset=y_train,
- k=k,
- projection=identity_projection,
- batch_size=32,
- distance=None, #"euclidean",
- nb_prototypes=nb_prototypes,
- kernel_type=kernel_type,
- gamma=gamma,
- )
-
- # Generate global explanation
- prototype_indices, prototype_weights = method.get_global_prototypes()
-
- prototypes = tf.gather(x_train, prototype_indices)
- prototype_labels = tf.gather(y_train, prototype_indices)
-
- # sort by label
- prototype_labels_sorted = prototype_labels.numpy().argsort()
-
- prototypes = tf.gather(prototypes, prototype_labels_sorted)
- prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
- prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
- prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
-
- # Verifications
- # Shape
- assert prototype_indices.shape == (nb_prototypes,)
- assert prototypes.shape == (nb_prototypes, x_train.shape[1])
- assert prototype_weights.shape == (nb_prototypes,)
-
- # at least 1 prototype per class is selected
- assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
-
- # uniqueness test of prototypes
- assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
-
- # Check if all indices are between 0 and x_train.shape[0]-1
- assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
-
- # Generate local explanation
- examples = method.explain(x_test)
-
- # # Visualize all prototypes
- # plot(prototypes, prototype_weights, 'proto_greedy')
-
- # # Visualize local explanation
- # plot_local_explanation(examples, x_test, 'proto_greedy')
-
-def test_proto_dash_basic():
+ for kernel_type in ["local", "global"]:
+ for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
+ # compute general prototypes
+ method = method_class(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ k=k,
+ projection=identity_projection,
+ batch_size=8,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ gamma=gamma,
+ )
+ # extract prototypes
+ prototypes_dict = method.get_global_prototypes()
+ prototypes = prototypes_dict["prototypes"]
+ prototypes_indices = prototypes_dict["prototypes_indices"]
+ prototypes_labels = prototypes_dict["prototypes_labels"]
+ prototypes_weights = prototypes_dict["prototypes_weights"]
+
+ # check shapes
+ assert prototypes.shape == (nb_prototypes,) + x_train.shape[1:]
+ assert prototypes_indices.shape == (nb_prototypes,)
+ assert prototypes_labels.shape == (nb_prototypes,)
+ assert prototypes_weights.shape == (nb_prototypes,)
+
+ # check uniqueness
+ assert len(prototypes_indices) == len(tf.unique(prototypes_indices)[0])
+
+ # for each prototype
+ for i in range(nb_prototypes):
+ # check prototypes are in the dataset and correspond to the index
+ assert tf.reduce_all(tf.equal(prototypes[i], x_train[prototypes_indices[i]]))
+
+ # same for labels
+ assert tf.reduce_all(tf.equal(prototypes_labels[i], y_train[prototypes_indices[i]]))
+
+ # check indices are in the dataset
+ assert prototypes_indices[i] >= 0 and prototypes_indices[i] < x_train.shape[0]
+
+
+def test_prototypes_local_explanations_basic():
"""
- Test the Prototypes with an identity projection.
+ Test prototypes local explanations.
"""
# Setup
k = 3
- nb_prototypes = 3
+ nb_prototypes = 5
+ nb_classes = 3
+ batch_size = 8
+
gamma = 0.026
- x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
- x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
- # x_train, y_train = load_data('usps')
- # x_test, y_test = load_data('usps.t')
- # x_test = tf.random.shuffle(x_test)
- # x_test = x_test[0:8]
+ x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
+ x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=10)
identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
)
- kernel_type = "global"
-
- # Method initialization
- method = ProtoDash(
- cases_dataset=x_train,
- labels_dataset=y_train,
- k=k,
- projection=identity_projection,
- batch_size=32,
- distance="euclidean",
- nb_prototypes=nb_prototypes,
- kernel_type=kernel_type,
- gamma=gamma,
- )
-
- # Generate global explanation
- prototype_indices, prototype_weights = method.get_global_prototypes()
-
- prototypes = tf.gather(x_train, prototype_indices)
- prototype_labels = tf.gather(y_train, prototype_indices)
-
- # sort by label
- prototype_labels_sorted = prototype_labels.numpy().argsort()
-
- prototypes = tf.gather(prototypes, prototype_labels_sorted)
- prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
- prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
- prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
-
- # Verifications
- # Shape
- assert prototype_indices.shape == (nb_prototypes,)
- assert prototypes.shape == (nb_prototypes, x_train.shape[1])
- assert prototype_weights.shape == (nb_prototypes,)
-
- # at least 1 prototype per class is selected
- assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
-
- # uniqueness test of prototypes
- assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
-
- # Check if all indices are between 0 and x_train.shape[0]-1
- assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
-
- # Generate local explanation
- examples = method.explain(x_test)
-
- # # Visualize all prototypes
- # plot(prototypes, prototype_weights, 'proto_dash')
-
- # # Visualize local explanation
- # plot_local_explanation(examples, x_test, 'proto_dash')
-
-def test_mmd_critic_basic():
+ for kernel_type in ["local", "global"]:
+ for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
+ # compute general prototypes
+ method = method_class(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ k=k,
+ projection=identity_projection,
+ case_returns=["examples", "distances", "labels", "indices"],
+ batch_size=batch_size,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ gamma=gamma,
+ )
+ # extract prototypes
+ prototypes_dict = method.get_global_prototypes()
+ prototypes = prototypes_dict["prototypes"]
+ prototypes_indices = prototypes_dict["prototypes_indices"]
+ prototypes_labels = prototypes_dict["prototypes_labels"]
+ prototypes_weights = prototypes_dict["prototypes_weights"]
+
+ # compute local explanations
+ outputs = method.explain(x_test)
+ examples = outputs["examples"]
+ distances = outputs["distances"]
+ labels = outputs["labels"]
+ indices = outputs["indices"]
+
+ # check shapes
+ assert examples.shape == (x_test.shape[0], k) + x_train.shape[1:]
+ assert distances.shape == (x_test.shape[0], k)
+ assert labels.shape == (x_test.shape[0], k)
+ assert indices.shape == (x_test.shape[0], k, 2)
+
+ # for each sample
+ for i in range(x_test.shape[0]):
+ # check first closest prototype label is the same as the sample label
+ assert tf.reduce_all(tf.equal(labels[i, 0], y_test[i]))
+
+ for j in range(k):
+ # check indices in prototypes' indices
+ index = indices[i, j, 0] * batch_size + indices[i, j, 1]
+ assert index in prototypes_indices
+
+ # check examples are in prototypes
+ assert tf.reduce_all(tf.equal(prototypes[prototypes_indices == index], examples[i, j]))
+
+ # check indices are in the dataset
+ assert tf.reduce_all(tf.equal(x_train[index], examples[i, j]))
+
+ # check distances
+ assert almost_equal(distances[i, j], tf.norm(x_test[i] - x_train[index]), epsilon=1e-5)
+
+ # check labels
+ assert tf.reduce_all(tf.equal(labels[i, j], y_train[index]))
+
+
+def test_prototypes_global_sanity_checks_1():
"""
- Test the Prototypes with an identity projection.
+ Test prototypes global explanations sanity checks.
+
+ Check 1: For n separated gaussians, for n requested prototypes, there should be 1 prototype per gaussian.
"""
+
# Setup
k = 3
nb_prototypes = 3
+
gamma = 0.026
- x_train, y_train = get_Gaussian_Data(nb_samples_class=20)
- x_test, y_test = get_Gaussian_Data(nb_samples_class=10)
- # x_train, y_train = load_data('usps')
- # x_test, y_test = load_data('usps.t')
- # x_test = tf.random.shuffle(x_test)
- # x_test = x_test[0:8]
+ x_train, y_train = get_gaussian_data(nb_classes=nb_prototypes, nb_samples_class=20)
identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
)
- kernel_type = "global"
-
- # Method initialization
- method = MMDCritic(
- cases_dataset=x_train,
- labels_dataset=y_train,
- k=k,
- projection=identity_projection,
- batch_size=32,
- distance="euclidean",
- nb_prototypes=nb_prototypes,
- kernel_type=kernel_type,
- gamma=gamma,
- )
-
- # Generate global explanation
- prototype_indices, prototype_weights = method.get_global_prototypes()
-
- prototypes = tf.gather(x_train, prototype_indices)
- prototype_labels = tf.gather(y_train, prototype_indices)
-
- # sort by label
- prototype_labels_sorted = prototype_labels.numpy().argsort()
+ for kernel_type in ["local", "global"]:
+ for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
+ # compute general prototypes
+ method = method_class(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ k=k,
+ projection=identity_projection,
+ batch_size=8,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ gamma=gamma,
+ )
+ # extract prototypes
+ prototypes_dict = method.get_global_prototypes()
+ prototypes = prototypes_dict["prototypes"]
+ prototypes_indices = prototypes_dict["prototypes_indices"]
+ prototypes_labels = prototypes_dict["prototypes_labels"]
+ prototypes_weights = prototypes_dict["prototypes_weights"]
+
+ # check 1
+ assert len(tf.unique(prototypes_labels)[0]) == nb_prototypes
+
+
+def test_prototypes_global_sanity_checks_2():
+ """
+ Test prototypes global explanations sanity checks.
- prototypes = tf.gather(prototypes, prototype_labels_sorted)
- prototype_indices = tf.gather(prototype_indices, prototype_labels_sorted)
- prototype_labels = tf.gather(prototype_labels, prototype_labels_sorted)
- prototype_weights = tf.gather(prototype_weights, prototype_labels_sorted)
+ Check 2: With local kernel_type, if there are more requested prototypes than classes, there should be at least 1 prototype per class.
+ """
+
+ # Setup
+ k = 3
+ nb_prototypes = 5
+ nb_classes = 3
- # Verifications
- # Shape
- assert prototype_indices.shape == (nb_prototypes,)
- assert prototypes.shape == (nb_prototypes, x_train.shape[1])
- assert prototype_weights.shape == (nb_prototypes,)
+ gamma = 0.026
+ x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
- # at least 1 prototype per class is selected
- assert tf.unique(prototype_labels)[0].shape == tf.unique(y_train)[0].shape
+ # randomize y_train
+ y_train = tf.random.shuffle(y_train)
- # uniqueness test of prototypes
- assert prototype_indices.shape == tf.unique(prototype_indices)[0].shape
+ identity_projection = Projection(
+ space_projection=lambda inputs, targets=None: inputs
+ )
- # Check if all indices are between 0 and x_train.shape[0]-1
- assert tf.reduce_all(tf.math.logical_and(prototype_indices >= 0, prototype_indices <= x_train.shape[0]-1))
+ for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
+ # compute general prototypes
+ method = method_class(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ k=k,
+ projection=identity_projection,
+ batch_size=8,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type="local",
+ gamma=gamma,
+ )
+ # extract prototypes
+ prototypes_dict = method.get_global_prototypes()
+ prototypes = prototypes_dict["prototypes"]
+ prototypes_indices = prototypes_dict["prototypes_indices"]
+ prototypes_labels = prototypes_dict["prototypes_labels"]
+ prototypes_weights = prototypes_dict["prototypes_weights"]
+
+ # check 2
+ assert len(tf.unique(prototypes_labels)[0]) == nb_classes
+
+
+def test_prototypes_local_explanations_with_projection():
+ """
+ Test prototypes local explanations with a projection.
+ """
+ # Setup
+ k = 3
+ nb_prototypes = 5
+ nb_classes = 3
+ batch_size = 8
- # Generate local explanation
- examples = method.explain(x_test)
+ gamma = 0.026
+ x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
+ x_train_bis, _ = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
+ x_train = tf.concat([x_train, x_train_bis], axis=1) # make a dataset with two dimensions
- # # Visualize all prototypes
- # plot(prototypes, prototype_weights, 'mmd_critic')
+ x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=10)
- # # Visualize local explanation
- # plot_local_explanation(examples, x_test, 'mmd_critic')
+ projection = Projection(
+ space_projection=lambda inputs, targets=None: tf.reduce_mean(inputs, axis=1, keepdims=True)
+ )
-# test_proto_greedy_basic()
-# test_proto_dash_basic()
-# test_mmd_critic_basic()
+ for kernel_type in ["local", "global"]:
+ for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
+ # compute general prototypes
+ method = method_class(
+ cases_dataset=x_train,
+ labels_dataset=y_train,
+ k=k,
+ projection=projection,
+ case_returns=["examples", "distances", "labels", "indices"],
+ batch_size=batch_size,
+ distance="euclidean",
+ nb_prototypes=nb_prototypes,
+ kernel_type=kernel_type,
+ gamma=gamma,
+ )
+ # extract prototypes
+ prototypes_dict = method.get_global_prototypes()
+ prototypes = prototypes_dict["prototypes"]
+ prototypes_indices = prototypes_dict["prototypes_indices"]
+ prototypes_labels = prototypes_dict["prototypes_labels"]
+ prototypes_weights = prototypes_dict["prototypes_weights"]
+
+ # check shapes
+ assert prototypes.shape == (nb_prototypes,) + x_train.shape[1:]
+ assert prototypes_indices.shape == (nb_prototypes,)
+ assert prototypes_labels.shape == (nb_prototypes,)
+ assert prototypes_weights.shape == (nb_prototypes,)
+
+ # compute local explanations
+ outputs = method.explain(x_test)
+ examples = outputs["examples"]
+ distances = outputs["distances"]
+ labels = outputs["labels"]
+ indices = outputs["indices"]
+
+ # check shapes
+ assert examples.shape == (x_test.shape[0], k) + x_train.shape[1:]
+ assert distances.shape == (x_test.shape[0], k)
+ assert labels.shape == (x_test.shape[0], k)
+ assert indices.shape == (x_test.shape[0], k, 2)
+
+ # for each sample
+ for i in range(x_test.shape[0]):
+ # check first closest prototype label is the same as the sample label
+ assert tf.reduce_all(tf.equal(labels[i, 0], y_test[i]))
+
+ for j in range(k):
+ # check indices in prototypes' indices
+ index = indices[i, j, 0] * batch_size + indices[i, j, 1]
+ assert index in prototypes_indices
+
+ # check examples are in prototypes
+ assert tf.reduce_all(tf.equal(prototypes[prototypes_indices == index], examples[i, j]))
+
+ # check indices are in the dataset
+ assert tf.reduce_all(tf.equal(x_train[index], examples[i, j]))
+
+ # check labels
+ assert tf.reduce_all(tf.equal(labels[i, j], y_train[index]))
+
+ # check distances
+ assert almost_equal(
+ distances[i, j],
+ tf.norm(tf.reduce_mean(x_train[index]) - tf.reduce_mean(x_test[i])),
+ epsilon=1e-5
+ )
diff --git a/tests/utils.py b/tests/utils.py
index 280d7a5f..000b7f01 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -255,14 +255,21 @@ def download_file(identifier: str,
if chunk:
file.write(chunk)
-def get_Gaussian_Data(nb_samples_class=20):
+def get_gaussian_data(nb_classes=3, nb_samples_class=20):
tf.random.set_seed(42)
- sigma = 0.05
- mu = [10, 20, 30]
+ sigma = 1
+ mu = [10 * (id + 1) for id in range(nb_classes)]
- X = tf.concat([tf.random.normal(shape=(nb_samples_class,1), mean=mu[i], stddev=sigma, dtype=tf.float32) for i in range(3)], axis=0)
- y = tf.concat([tf.ones(shape=(nb_samples_class), dtype=tf.int32) * i for i in range(3)], axis=0)
+ X = tf.concat([
+ tf.random.normal(shape=(nb_samples_class,1), mean=mu[i], stddev=sigma, dtype=tf.float32)
+ for i in range(nb_classes)
+ ], axis=0)
+
+ y = tf.concat([
+ tf.ones(shape=(nb_samples_class), dtype=tf.int32) * i
+ for i in range(nb_classes)
+ ], axis=0)
return(X, y)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index 83c81fa0..ea010933 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -197,11 +197,20 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
Returns
-------
results
-
- indices should be (n, k, 2)
+ A tensor with the extracted elements from the `dataset`.
+ The shape of the tensor is (n, k, ...), where ... is the shape of the elements in the `dataset`.
"""
if dataset is None:
return None
+
+ if len(indices.shape) != 3 or indices.shape[-1] != 2:
+ raise ValueError(
+ "Indices should have dimensions (n, k, 2), "
+ + "where n represents the number of inputs and k the number of corresponding examples. "
+ + "The index of each element is encoded by two values, "
+ + "the batch index and the index of the element in the batch. "
+ + f"Received {indices.shape}."
+ )
example = next(iter(dataset))
# (n, bs, ...)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 5742190d..a4819952 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -9,6 +9,8 @@
from ..types import Callable, Dict, List, Optional, Type, Union
+from ..commons.tf_dataset_operations import dataset_gather
+
from .search_methods import BaseSearchMethod, ProtoGreedySearch, MMDCriticSearch, ProtoDashSearch
from .projections import Projection
from .base_example_method import BaseExampleMethod
@@ -38,11 +40,13 @@ class Prototypes(BaseExampleMethod, ABC):
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
k
- The number of examples to retrieve.
+ For decision explanations, the number of closest prototypes to return. Used in `explain`.
+ Default is 1, which means that only the closest prototype is returned.
projection
Projection or Callable that project samples from the input space to the search space.
The search space should be a space where distance make sense for the model.
- It should not be `None`, otherwise,
+ The output of the projection should be a two dimensional tensor. (nb_samples, nb_features).
+ `projection` should not be `None`, otherwise,
all examples could be computed only with the `search_method`.
Example of Callable:
@@ -61,9 +65,23 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
See `self.set_returns()` for detail.
batch_size
Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- search_method_kwargs
- Parameters to be passed at the construction of the `search_method`.
+ Ignored if `tf.data.Dataset` are provided (these are supposed to be batched).
+ distance
+ Either a Callable, or a value supported by `tf.norm` `ord` parameter.
+        Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) says:
+ "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
+ yielding the corresponding p-norm." We also added 'cosine'.
+ nb_prototypes : int
+ For general explanations, the number of prototypes to select.
+ If `class_wise` is True, it will correspond to the number of prototypes per class.
+ kernel_type : str, optional
+ The kernel type. It can be 'local' or 'global', by default 'local'.
+ When it is local, the distances are calculated only within the classes.
+ kernel_fn : Callable, optional
+ Kernel function, by default the rbf kernel.
+ This function must only use TensorFlow operations.
+ gamma : float, optional
+ Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
"""
def __init__(
@@ -77,7 +95,7 @@ def __init__(
batch_size: Optional[int] = 32,
distance: Union[int, str, Callable] = None,
nb_prototypes: int = 1,
- kernel_type: str = 'local',
+ kernel_type: str = 'local',
kernel_fn: callable = None,
gamma: float = None
):
@@ -118,18 +136,47 @@ def __init__(
def search_method_class(self) -> Type[ProtoGreedySearch]:
raise NotImplementedError
- def get_global_prototypes(self):
+ def get_global_prototypes(self) -> Dict[str, tf.Tensor]:
"""
- Return all the prototypes computed by the search method,
- which consist of a global explanation of the dataset.
-
- Returns:
- prototype_indices : Tensor
- prototype indices.
- prototype_weights : Tensor
- prototype weights.
+ Provide the global prototypes computed at the initialization.
+ Prototypes and their labels are extracted from the indices.
+ The weights of the prototypes and their indices are also returned.
+
+ Returns
+ -------
+ prototypes_dict : Dict[str, tf.Tensor]
+ A dictionary with the following
+ - 'prototypes': The prototypes found by the method.
+ - 'prototype_labels': The labels of the prototypes.
+ - 'prototype_weights': The weights of the prototypes.
+ - 'prototype_indices': The indices of the prototypes.
"""
- return self.search_method.prototype_indices, self.search_method.prototype_weights
+ # (nb_prototypes,)
+ indices = self.search_method.prototypes_indices
+ batch_indices = indices // self.batch_size
+ elem_indices = indices % self.batch_size
+
+ # (nb_prototypes, 2)
+ batch_elem_indices = tf.stack([batch_indices, elem_indices], axis=1)
+
+ # (1, nb_prototypes, 2)
+ batch_elem_indices = tf.expand_dims(batch_elem_indices, axis=0)
+
+ # (nb_prototypes, ...)
+ prototypes = dataset_gather(self.cases_dataset, batch_elem_indices)[0]
+
+ # (nb_prototypes,)
+ labels = dataset_gather(self.labels_dataset, batch_elem_indices)[0]
+
+ # (nb_prototypes,)
+ weights = self.search_method.prototypes_weights
+
+ return {
+ "prototypes": prototypes,
+ "prototypes_labels": labels,
+ "prototypes_weights": weights,
+ "prototypes_indices": indices,
+ }
class ProtoGreedy(Prototypes):
@@ -145,93 +192,6 @@ def search_method_class(self) -> Type[ProtoGreedySearch]:
class ProtoDash(Prototypes):
- """
- Protodash method for searching prototypes.
-
- References:
- .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi,
- "ProtoDash: Fast Interpretable Prototype Selection"
- `_
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from the dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection. See `projection` for detail.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other dataset should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve.
- search_returns
- String or list of string with the elements to return in `self.find_examples()`.
- See `self.set_returns()` for detail.
- batch_size
- Number of sample treated simultaneously.
- It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
- distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
- nb_prototypes : int
- Number of prototypes to find.
- kernel_type : str, optional
- The kernel type. It can be 'local' or 'global', by default 'local'.
- When it is local, the distances are calculated only within the classes.
- kernel_fn : Callable, optional
- Kernel function, by default the rbf kernel.
- This function must only use TensorFlow operations.
- gamma : float, optional
- Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
- use_optimizer : bool, optional
- Flag indicating whether to use an optimizer for prototype selection, by default False.
- """
-
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = None,
- nb_prototypes: int = 1,
- kernel_type: str = 'local',
- kernel_fn: callable = None,
- gamma: float = None,
- use_optimizer: bool = False,
- ): # pylint: disable=R0801
- self.use_optimizer = use_optimizer
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- distance=distance,
- nb_prototypes=nb_prototypes,
- kernel_type=kernel_type,
- kernel_fn=kernel_fn,
- gamma=gamma
- )
-
@property
def search_method_class(self) -> Type[ProtoGreedySearch]:
return ProtoDashSearch
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
index 5bb7b78b..cbe78b40 100644
--- a/xplique/example_based/search_methods/proto_dash_search.py
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -117,8 +117,10 @@ class ProtoDashSearch(ProtoGreedySearch):
This function must only use TensorFlow operations.
gamma : float, optional
Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
- use_optimizer : bool, optional
- Flag indicating whether to use an optimizer for prototype selection, by default False.
+ exact_selection_weights_update : bool, optional
+        Whether to use an exact method to update selection weights, by default False.
+ Exact method is based on a scipy optimization,
+ while the other is based on a tensorflow inverse operation.
"""
def __init__(
@@ -133,10 +135,10 @@ def __init__(
kernel_type: str = 'local',
kernel_fn: callable = None,
gamma: float = None,
- use_optimizer: bool = False,
+ exact_selection_weights_update: bool = False,
): # pylint: disable=R0801
- self.use_optimizer = use_optimizer
+ self.exact_selection_weights_update = exact_selection_weights_update
super().__init__(
cases_dataset=cases_dataset,
@@ -187,7 +189,7 @@ def update_selection_weights(self, selection_indices, selection_weights, selecti
u = tf.expand_dims(tf.gather(self.col_means, selection_indices), axis=1)
K = selection_selection_kernel
- if self.use_optimizer:
+ if self.exact_selection_weights_update:
initial_weights = tf.concat([selection_weights, [best_objective / tf.gather(self.diag, best_indice)]], axis=0)
opt = Optimizer(initial_weights)
selection_weights, _ = opt.optimize(u, K)
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 39ed7159..5000e4d4 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -13,6 +13,23 @@
# from ..projections import Projection
+def rbf_kernel(X, Y=None, gamma=None):
+ if Y is None:
+ Y = X
+
+ if gamma is None:
+ gamma = 1.0 / tf.cast(tf.shape(X)[1], dtype=X.dtype)
+
+ X = tf.expand_dims(X, axis=1)
+ Y = tf.expand_dims(Y, axis=0)
+
+ pairwise_diff = X - Y
+ pairwise_sq_dist = tf.reduce_sum(tf.square(pairwise_diff), axis=-1)
+ kernel_matrix = tf.exp(-gamma * pairwise_sq_dist)
+
+ return kernel_matrix
+
+
class ProtoGreedySearch(BaseSearchMethod):
"""
ProtoGreedy method for searching prototypes.
@@ -77,75 +94,57 @@ def __init__(
self.labels_dataset = sanitize_dataset(labels_dataset, self.batch_size)
- if kernel_type in ['local', 'global']:
- self.kernel_type = kernel_type
- else:
+ if kernel_type not in ['local', 'global']:
raise AttributeError(
"The kernel_type parameter is expected to be in"
+ " ['local', 'global'] ",
+f"but {kernel_type} was received.",
)
+
+ self.kernel_type = kernel_type
+ # set default kernel function (rbf_kernel) or raise error if kernel_fn is not callable
if kernel_fn is None:
# define rbf kernel function
- def rbf_kernel(X, Y=None, gamma=None):
- if Y is None:
- Y = X
-
- if gamma is None:
- gamma = 1.0 / tf.cast(tf.shape(X)[1], dtype=X.dtype)
-
- X = tf.expand_dims(X, axis=1)
- Y = tf.expand_dims(Y, axis=0)
-
- pairwise_diff = X - Y
- pairwise_sq_dist = tf.reduce_sum(tf.square(pairwise_diff), axis=-1)
- kernel_matrix = tf.exp(-gamma * pairwise_sq_dist)
-
- return kernel_matrix
-
kernel_fn = lambda x, y: rbf_kernel(x,y,gamma)
-
- if hasattr(kernel_fn, "__call__"):
- def custom_kernel_fn(x1, x2, y1=None, y2=None):
- if self.kernel_type == 'global':
- kernel_matrix = kernel_fn(x1,x2)
- if isinstance(kernel_matrix, np.ndarray):
- kernel_matrix = tf.convert_to_tensor(kernel_matrix)
- else:
- # In the case of a local kernel, calculations are limited to within the class.
- # Across different classes, the kernel values are set to 0.
- kernel_matrix = np.zeros((x1.shape[0], x2.shape[0]), dtype=np.float32)
- y_intersect = np.intersect1d(y1, y2)
- for i in range(y_intersect.shape[0]):
- y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
- y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
- sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
- kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
- kernel_matrix = tf.convert_to_tensor(kernel_matrix)
- return kernel_matrix
-
- self.kernel_fn = custom_kernel_fn
- else:
+ elif not hasattr(kernel_fn, "__call__"):
raise AttributeError(
"The kernel_fn parameter is expected to be a Callable",
+f"but {kernel_fn} was received.",
)
+ # define custom kernel function depending on the kernel type
+ def custom_kernel_fn(x1, x2, y1=None, y2=None):
+ if self.kernel_type == 'global':
+ kernel_matrix = kernel_fn(x1,x2)
+ if isinstance(kernel_matrix, np.ndarray):
+ kernel_matrix = tf.convert_to_tensor(kernel_matrix)
+ else:
+ # In the case of a local kernel, calculations are limited to within the class.
+ # Across different classes, the kernel values are set to 0.
+ kernel_matrix = np.zeros((x1.shape[0], x2.shape[0]), dtype=np.float32)
+ y_intersect = np.intersect1d(y1, y2)
+ for i in range(y_intersect.shape[0]):
+ y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
+ y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
+ sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
+ kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
+ kernel_matrix = tf.convert_to_tensor(kernel_matrix)
+ return kernel_matrix
+
+ self.kernel_fn = custom_kernel_fn
+
+
if distance is None:
- def kernel_induced_distance(x1,x2):
+ def kernel_induced_distance(x1, x2):
x1 = tf.expand_dims(x1, axis=0)
x2 = tf.expand_dims(x2, axis=0)
distance = tf.squeeze(tf.sqrt(kernel_fn(x1,x1) - 2 * kernel_fn(x1,x2) + kernel_fn(x2,x2)))
return distance
-
- self.distance_fn = lambda x1, x2: kernel_induced_distance(x1,x2)
-
+ self.distance_fn = kernel_induced_distance
elif hasattr(distance, "__call__"):
self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
- distance, int
- ):
+ elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(distance, int):
self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance, axis=-1)
else:
raise AttributeError(
@@ -165,7 +164,7 @@ def kernel_induced_distance(x1,x2):
):
# elements should be tabular data
assert len(batch_col_cases.shape) == 2,\
- "Expected prototypes' searches expects 2D data, (nb_samples, nb_features),"+\
+ "Prototypes' searches expects 2D data, (nb_samples, nb_features),"+\
f"but got {batch_col_cases.shape}"+\
"Please verify your projection if you provided a custom one."+\
"If you use a splitted model, make sure the output of the first part of the model is flattened."
@@ -205,10 +204,10 @@ def kernel_induced_distance(x1,x2):
self.nb_features = batch_col_cases.shape[1]
# compute the prototypes in the latent space
- self.prototype_indices, self.prototype_cases, self.prototype_labels, self.prototype_weights = self.find_prototypes(nb_prototypes)
+ self.prototypes_indices, self.prototypes, self.prototypes_labels, self.prototypes_weights = self.find_prototypes(nb_prototypes)
self.knn = KNN(
- cases_dataset=self.prototype_cases,
+ cases_dataset=self.prototypes,
k=k,
search_returns=search_returns,
batch_size=batch_size,
@@ -314,13 +313,13 @@ def find_prototypes(self, nb_prototypes):
Returns
-------
- prototype_indices : Tensor
+ prototypes_indices : Tensor
The indices of the selected prototypes.
- prototype_cases : Tensor
+ prototypes : Tensor
The cases of the selected prototypes.
- prototype_labels : Tensor
+ prototypes_labels : Tensor
The labels of the selected prototypes.
- prototype_weights :
+ prototypes_weights :
The normalized weights of the selected prototypes.
"""
@@ -420,15 +419,15 @@ def find_prototypes(self, nb_prototypes):
k += 1
- prototype_indices = selection_indices
- prototype_cases = selection_cases
- prototype_labels = selection_labels
- prototype_weights = selection_weights
+ prototypes_indices = selection_indices
+ prototypes = selection_cases
+ prototypes_labels = selection_labels
+ prototypes_weights = selection_weights
# Normalize the weights
- prototype_weights = prototype_weights / tf.reduce_sum(prototype_weights)
+ prototypes_weights = prototypes_weights / tf.reduce_sum(prototypes_weights)
- return prototype_indices, prototype_cases, prototype_labels, prototype_weights
+ return prototypes_indices, prototypes, prototypes_labels, prototypes_weights
def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], _):
"""
@@ -454,7 +453,7 @@ def find_examples(self, inputs: Union[tf.Tensor, np.ndarray], _):
indices_wrt_prototypes = indices_wrt_prototypes[:, :, 0] * self.batch_size + indices_wrt_prototypes[:, :, 1]
# get prototypes indices with respect to the dataset
- indices = tf.gather(self.prototype_indices, indices_wrt_prototypes)
+ indices = tf.gather(self.prototypes_indices, indices_wrt_prototypes)
# convert back to batch-element indices
batch_indices, elem_indices = indices // self.batch_size, indices % self.batch_size
From e0731785639daa08f4c974f5858e441a447fc757 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 10 Jul 2024 17:08:04 +0200
Subject: [PATCH 058/138] tf dataset operations: make them on cpu
---
xplique/commons/tf_dataset_operations.py | 108 ++++++++++++++++++++---
1 file changed, 98 insertions(+), 10 deletions(-)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index ea010933..783b7e57 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -156,10 +156,101 @@ def sanitize_dataset(
return dataset
+# def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
+# """
+# Imitation of `tf.gather` for `tf.data.Dataset`,
+# it extract elements from `dataset` at the given indices.
+# We could see it as returning the `indices` tensor
+# where each index was replaced by the corresponding element in `dataset`.
+# The aim is to use it in the `example_based` module to extract examples form the cases dataset.
+# Hence, `indices` expect dimensions of (n, k, 2),
+# where n represent the number of inputs and k the number of corresponding examples.
+# Here indices for each element are encoded by two values,
+# the batch index and the index of the element in the batch.
+
+# Example of application
+# ```
+# >>> dataset = tf.data.Dataset.from_tensor_slices(
+# ... tf.reshape(tf.range(20), (-1, 2, 2))
+# ... ).batch(3) # shape=(None, 2, 2)
+# >>> indices = tf.constant([[[0, 0]], [[1, 0]]]) # shape=(2, 1, 2)
+# >>> dataset_gather(dataset, indices)
+#
+# ```
+
+# Parameters
+# ----------
+# dataset
+# Tensorflow dataset to verify or tensor to transform in `tf.data.Dataset` and verify.
+# indices
+# Tensor of indices of elements to extract from the `dataset`.
+# `indices` should be of dimensions (n, k, 2),
+# this is to match the format of indices in the `example_based` module.
+# Indeed, n represent the number of inputs and k the number of corresponding examples.
+# The index of each element is encoded by two values,
+# the batch index and the index of the element in the batch.
+
+# Returns
+# -------
+# results
+# A tensor with the extracted elements from the `dataset`.
+# The shape of the tensor is (n, k, ...), where ... is the shape of the elements in the `dataset`.
+# """
+# if dataset is None:
+# return None
+
+# if len(indices.shape) != 3 or indices.shape[-1] != 2:
+# raise ValueError(
+# "Indices should have dimensions (n, k, 2), "
# + "where n represents the number of inputs and k the number of corresponding examples. "
+# + "The index of each element is encoded by two values, "
+# + "the batch index and the index of the element in the batch. "
+# + f"Received {indices.shape}."
+# )
+
+# example = next(iter(dataset))
+# # (n, bs, ...)
+# with tf.device('/CPU:0'):
+# if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
+# results = tf.Variable(
+# tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype)),
+# )
+# else:
+# results = tf.Variable(
+# tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
+# )
+
+# nb_results = product(indices.shape[:-1])
+# current_nb_results = 0
+
+# for i, batch in enumerate(dataset):
+# # check if the batch is interesting
+# if not tf.reduce_any(indices[..., 0] == i):
+# continue
+
+# # extract pertinent elements
+# pertinent_indices_location = tf.where(indices[..., 0] == i)
+# samples_index = tf.gather_nd(indices[..., 1], pertinent_indices_location)
+# samples = tf.gather(batch, samples_index)
+
+# # put them at the right place in results
+# for location, sample in zip(pertinent_indices_location, samples):
+# results[location[0], location[1]].assign(sample)
+# current_nb_results += 1
+
+# # test if results are filled to break the loop
+# if current_nb_results == nb_results:
+# break
+# return results
+
def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
"""
Imitation of `tf.gather` for `tf.data.Dataset`,
- it extract elements from `dataset` at the given indices.
+ it extracts elements from `dataset` at the given indices.
We could see it as returning the `indices` tensor
where each index was replaced by the corresponding element in `dataset`.
The aim is to use it in the `example_based` module to extract examples form the cases dataset.
@@ -175,7 +266,7 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
... ).batch(3) # shape=(None, 2, 2)
>>> indices = tf.constant([[[0, 0]], [[1, 0]]]) # shape=(2, 1, 2)
>>> dataset_gather(dataset, indices)
- tf.Tensor:
)
example = next(iter(dataset))
- # (n, bs, ...)
+
if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
- results = tf.Variable(
- tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype)),
- )
+ results = tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype))
else:
- results = tf.Variable(
- tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
- )
+ results = tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype))
nb_results = product(indices.shape[:-1])
current_nb_results = 0
@@ -238,10 +325,11 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
# put them at the right place in results
for location, sample in zip(pertinent_indices_location, samples):
- results[location[0], location[1]].assign(sample)
+ results = tf.tensor_scatter_nd_update(results, [location], [sample])
current_nb_results += 1
# test if results are filled to break the loop
if current_nb_results == nb_results:
break
+
return results
From 1e1b0c7aa00b2beab4dae9dd6d6b8897ba561a9a Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 10 Jul 2024 17:08:41 +0200
Subject: [PATCH 059/138] test prototypes: remove absurd tests
---
tests/example_based/test_prototypes.py | 69 +++-----------------------
1 file changed, 7 insertions(+), 62 deletions(-)
diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py
index b9fd43f8..fe1ae962 100644
--- a/tests/example_based/test_prototypes.py
+++ b/tests/example_based/test_prototypes.py
@@ -6,20 +6,12 @@
sys.path.append(os.getcwd())
-from math import prod, sqrt
-import unittest
-import time
-
-import numpy as np
import tensorflow as tf
-from xplique.commons import sanitize_dataset, are_dataset_first_elems_equal
-from xplique.types import Union
-
from xplique.example_based import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
from xplique.example_based.projections import Projection, LatentSpaceProjection
-from tests.utils import almost_equal, get_gaussian_data, load_data, plot, plot_local_explanation
+from tests.utils import almost_equal, get_gaussian_data, generate_model
def test_prototypes_global_explanations_basic():
@@ -30,10 +22,9 @@ def test_prototypes_global_explanations_basic():
k = 3
nb_prototypes = 5
nb_classes = 3
-
gamma = 0.026
+
x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
- x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=10)
identity_projection = Projection(
space_projection=lambda inputs, targets=None: inputs
@@ -90,8 +81,8 @@ def test_prototypes_local_explanations_basic():
nb_prototypes = 5
nb_classes = 3
batch_size = 8
-
gamma = 0.026
+
x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=10)
@@ -157,18 +148,18 @@ def test_prototypes_local_explanations_basic():
assert tf.reduce_all(tf.equal(labels[i, j], y_train[index]))
-def test_prototypes_global_sanity_checks_1():
+def test_prototypes_global_sanity_check():
"""
Test prototypes global explanations sanity checks.
- Check 1: For n separated gaussians, for n requested prototypes, there should be 1 prototype per gaussian.
+ Check: For n separated gaussians, for n requested prototypes, there should be 1 prototype per gaussian.
"""
# Setup
k = 3
nb_prototypes = 3
-
gamma = 0.026
+
x_train, y_train = get_gaussian_data(nb_classes=nb_prototypes, nb_samples_class=20)
identity_projection = Projection(
@@ -198,52 +189,6 @@ def test_prototypes_global_sanity_checks_1():
# check 1
assert len(tf.unique(prototypes_labels)[0]) == nb_prototypes
-
-
-def test_prototypes_global_sanity_checks_2():
- """
- Test prototypes global explanations sanity checks.
-
- Check 2: With local kernel_type, if there are more requested prototypes than classes, there should be at least 1 prototype per class.
- """
-
- # Setup
- k = 3
- nb_prototypes = 5
- nb_classes = 3
-
- gamma = 0.026
- x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
-
- # randomize y_train
- y_train = tf.random.shuffle(y_train)
-
- identity_projection = Projection(
- space_projection=lambda inputs, targets=None: inputs
- )
-
- for method_class in [ProtoGreedy, ProtoDash, MMDCritic]:
- # compute general prototypes
- method = method_class(
- cases_dataset=x_train,
- labels_dataset=y_train,
- k=k,
- projection=identity_projection,
- batch_size=8,
- distance="euclidean",
- nb_prototypes=nb_prototypes,
- kernel_type="local",
- gamma=gamma,
- )
- # extract prototypes
- prototypes_dict = method.get_global_prototypes()
- prototypes = prototypes_dict["prototypes"]
- prototypes_indices = prototypes_dict["prototypes_indices"]
- prototypes_labels = prototypes_dict["prototypes_labels"]
- prototypes_weights = prototypes_dict["prototypes_weights"]
-
- # check 2
- assert len(tf.unique(prototypes_labels)[0]) == nb_classes
def test_prototypes_local_explanations_with_projection():
@@ -255,8 +200,8 @@ def test_prototypes_local_explanations_with_projection():
nb_prototypes = 5
nb_classes = 3
batch_size = 8
-
gamma = 0.026
+
x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
x_train_bis, _ = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20)
x_train = tf.concat([x_train, x_train_bis], axis=1) # make a dataset with two dimensions
From edabcda999cfd704ef81244bda8ecbc1cbd562a8 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 10 Jul 2024 17:09:13 +0200
Subject: [PATCH 060/138] example based: linting
---
xplique/example_based/base_example_method.py | 20 ++++++++++++-------
.../search_methods/proto_greedy_search.py | 7 +++----
2 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 02cc1f4e..3fabe7c1 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -210,9 +210,15 @@ def _initialize_cases_dataset(
cardinality = cases_dataset.cardinality().numpy()
else:
- # if case_dataset is not a `tf.data.Dataset`, then neither should the other.
- assert not isinstance(labels_dataset, tf.data.Dataset)
- assert not isinstance(targets_dataset, tf.data.Dataset)
+ # if cases_dataset is not a `tf.data.Dataset`, then neither should the other.
+ assert not isinstance(labels_dataset, tf.data.Dataset), (
+ "if the cases_dataset is not a `tf.data.Dataset`, "
+ + "then neither should the labels_dataset."
+ )
+ assert not isinstance(targets_dataset, tf.data.Dataset), (
+ "if the cases_dataset is not a `tf.data.Dataset`, "
+ + "then neither should the targets_dataset."
+ )
# set batch size and cardinality
batch_size = min(batch_size, len(cases_dataset))
cardinality = math.ceil(len(cases_dataset) / batch_size)
@@ -233,7 +239,7 @@ def _initialize_cases_dataset(
# switch case on the number of columns of `cases_dataset`
if len(self.cases_dataset.element_spec) == 2:
assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
+ "The second column of `cases_dataset` is assumed to be the labels. "
+ "Hence, `labels_dataset` should be empty."
)
self.labels_dataset = self.cases_dataset.map(lambda x, y: y)
@@ -241,11 +247,11 @@ def _initialize_cases_dataset(
elif len(self.cases_dataset.element_spec) == 3:
assert self.labels_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
+ "The second column of `cases_dataset` is assumed to be the labels. "
+ "Hence, `labels_dataset` should be empty."
)
assert self.targets_dataset is None, (
- "The second column of `cases_dataset` is assumed to be the labels."
+                "The third column of `cases_dataset` is assumed to be the targets. "
+                "Hence, `targets_dataset` should be empty."
)
self.targets_dataset = self.cases_dataset.map(lambda x, y, t: t)
@@ -253,7 +259,7 @@ def _initialize_cases_dataset(
self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
else:
raise AttributeError(
- "`cases_dataset` cannot possess more than 3 columns,"
+ "`cases_dataset` cannot possess more than 3 columns, "
+ f"{len(self.cases_dataset.element_spec)} were detected."
)
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 5000e4d4..d21ae9e1 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -126,15 +126,14 @@ def custom_kernel_fn(x1, x2, y1=None, y2=None):
y_intersect = np.intersect1d(y1, y2)
for i in range(y_intersect.shape[0]):
y1_indices = tf.where(tf.equal(y1, y_intersect[i]))[:, 0]
- y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
- sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
+ y2_indices = tf.where(tf.equal(y2, y_intersect[i]))[:, 0]
+ sub_matrix = kernel_fn(tf.gather(x1, y1_indices), tf.gather(x2, y2_indices))
kernel_matrix[tf.reshape(y1_indices, (-1, 1)), tf.reshape(y2_indices, (1, -1))] = sub_matrix
kernel_matrix = tf.convert_to_tensor(kernel_matrix)
return kernel_matrix
self.kernel_fn = custom_kernel_fn
-
-
+
if distance is None:
def kernel_induced_distance(x1, x2):
x1 = tf.expand_dims(x1, axis=0)
From 5904a7700f0a24091ae864a092c22cc7a9edb32b Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 10 Jul 2024 17:17:35 +0200
Subject: [PATCH 061/138] prototypes: change constant for memory
---
xplique/example_based/search_methods/proto_greedy_search.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index d21ae9e1..238ce8ad 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -73,7 +73,7 @@ class ProtoGreedySearch(BaseSearchMethod):
# Avoid zero division during procedure. (the value is not important, as if the denominator is
# zero, then the nominator will also be zero).
- EPSILON = tf.constant(1e-6)
+ EPSILON = 1e-6
def __init__(
self,
From dee479377eaea5ced62fcb8811d4213a92b88386 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 11:49:22 +0200
Subject: [PATCH 062/138] test contrastive: add projection test
---
tests/example_based/test_contrastive.py | 30 ++++++++++++++++++++++++-
1 file changed, 29 insertions(+), 1 deletion(-)
diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py
index eab75ca7..b91c5814 100644
--- a/tests/example_based/test_contrastive.py
+++ b/tests/example_based/test_contrastive.py
@@ -5,7 +5,10 @@
import numpy as np
from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss
-from xplique.example_based.projections import Projection
+from xplique.example_based.projections import Projection, LatentSpaceProjection
+
+from ..utils import generate_data, generate_model
+
def test_naive_counter_factuals():
"""
@@ -281,3 +284,28 @@ def test_kleor():
assert tf.reduce_all(
tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples)
) < 1e-5)
+
+
+def test_contrastive_with_projection():
+ input_shapes = [(28, 28, 1), (32, 32, 3)]
+ nb_labels = 10
+ nb_samples = 50
+
+ for input_shape in input_shapes:
+ features, labels = generate_data(input_shape, nb_labels, nb_samples)
+ model = generate_model(input_shape, nb_labels)
+
+ projection = LatentSpaceProjection(model, latent_layer=-1)
+
+ for contrastive_method_class in [NaiveCounterFactuals, LabelAwareCounterFactuals,
+ KLEORGlobalSim, KLEORSimMiss]:
+ contrastive_method = contrastive_method_class(
+ features,
+ labels,
+ k=1,
+ projection=projection,
+ case_returns=["examples", "indices", "distances", "include_inputs"],
+ batch_size=7
+ )
+
+ contrastive_method(features, labels)
\ No newline at end of file
From 22849058377160553a4a99f231aa93b5c225e29e Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 11:49:44 +0200
Subject: [PATCH 063/138] contrastive: solve projection problems
---
xplique/example_based/contrastive_examples.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index caf4a3fe..3b82c73b 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -99,7 +99,7 @@ def __init__(
# initiate search_method
self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
+ cases_dataset=self.projected_cases_dataset,
targets_dataset=self.targets_dataset,
k=self.k,
search_returns=self._search_returns,
@@ -223,7 +223,7 @@ def __init__(
# initiate search_method
self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
+ cases_dataset=self.projected_cases_dataset,
targets_dataset=self.targets_dataset,
k=self.k,
search_returns=self._search_returns,
@@ -250,7 +250,7 @@ def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
Parameters
----------
cf_targets
- The one-hot enoding of the target class for the counterfactuals.
+ The one-hot encoding of the target class for the counterfactuals.
cases_targets
The one-hot encoding of the target class for the cases.
"""
@@ -384,7 +384,7 @@ def __init__(
# initiate search_method
self.search_method = self.search_method_class(
- cases_dataset=self.cases_dataset,
+ cases_dataset=self.projected_cases_dataset,
targets_dataset=self.targets_dataset,
k=self.k,
search_returns=self._search_returns,
From 0356279ba8c71b095dd95a560d753ee03b91488a Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 15:21:41 +0200
Subject: [PATCH 064/138] semi-factual: allow to return nuns labels
---
xplique/example_based/contrastive_examples.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 3b82c73b..e0dfce7d 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -411,6 +411,8 @@ def returns(self, returns: Union[List[str], str]):
self._search_returns.append("nuns_indices")
elif isinstance(self._returns, list) and ("nuns_indices" in self._returns):
self._search_returns.append("nuns_indices")
+ elif isinstance(self._returns, list) and ("nuns_labels" in self._returns):
+ self._search_returns.append("nuns_indices")
if isinstance(self._returns, list) and ("dist_to_nuns" in self._returns):
self._search_returns.append("dist_to_nuns")
@@ -449,6 +451,8 @@ def format_search_output(
return_dict = super().format_search_output(search_output, inputs, targets)
if "nuns" in self.returns:
return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"])
+ if "nuns_labels" in self.returns:
+ return_dict["nuns_labels"] = dataset_gather(self.labels_dataset, search_output["nuns_indices"])
if "nuns_indices" in self.returns:
return_dict["nuns_indices"] = search_output["nuns_indices"]
if "dist_to_nuns" in self.returns:
From ac812ae70ecca6f46efa8fcc53db25a9c2e68eaa Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 15:21:54 +0200
Subject: [PATCH 065/138] prototypes: linting
---
xplique/example_based/prototypes.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index a4819952..4f00ff57 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -195,4 +195,3 @@ class ProtoDash(Prototypes):
@property
def search_method_class(self) -> Type[ProtoGreedySearch]:
return ProtoDashSearch
-
From f8a377dc1236f3bd88eead39e288f8661b9b63fa Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 16:14:57 +0200
Subject: [PATCH 066/138] example based: clarify distances
---
tests/example_based/test_cole.py | 2 +-
xplique/example_based/cole.py | 7 +-
xplique/example_based/contrastive_examples.py | 22 ++-
xplique/example_based/prototypes.py | 7 +-
.../example_based/search_methods/common.py | 144 ++++++++++++++++++
xplique/example_based/search_methods/kleor.py | 8 +-
xplique/example_based/search_methods/knn.py | 30 +---
.../search_methods/mmd_critic_search.py | 7 +-
.../search_methods/proto_dash_search.py | 7 +-
.../search_methods/proto_greedy_search.py | 19 +--
xplique/example_based/similar_examples.py | 8 +-
11 files changed, 186 insertions(+), 75 deletions(-)
create mode 100644 xplique/example_based/search_methods/common.py
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index e96abbb7..3864a71d 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -98,7 +98,7 @@ def test_cole_attribution():
targets_dataset=y_train,
k=k,
batch_size=2,
- distance=np.inf, # infinity norm based distance
+        distance="cosine",  # cosine distance
model=model,
attribution_method=Saliency,
)
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index 296bcea2..47b21dbe 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -45,10 +45,9 @@ class Cole(SimilarExamples):
k
The number of examples to retrieve per input.
distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm."
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
case_returns
String or list of string with the elements to return in `self.explain()`.
See the base class returns property for details.
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index e0dfce7d..31afc794 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -66,11 +66,9 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
- Distance for the FilterKNN search method.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
def __init__(
self,
@@ -184,10 +182,9 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
Distance for the FilterKNN search method.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
def __init__(
self,
@@ -347,10 +344,9 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
Distance for the FilterKNN search method.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
_returns_possibilities = [
"examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 4f00ff57..5f4017d4 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -67,10 +67,9 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (these are supposed to be batched).
distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
nb_prototypes : int
For general explanations, the number of prototypes to select.
If `class_wise` is True, it will correspond to the number of prototypes per class.
diff --git a/xplique/example_based/search_methods/common.py b/xplique/example_based/search_methods/common.py
new file mode 100644
index 00000000..bac0bce6
--- /dev/null
+++ b/xplique/example_based/search_methods/common.py
@@ -0,0 +1,144 @@
+"""
+Common functions for search methods.
+"""
+
+import numpy as np
+import tensorflow as tf
+
+from ...types import Callable, List, Union, Optional, Tuple
+
+
+def _manhattan_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+ """
+ Compute the Manhattan distance between two vectors.
+
+ Parameters
+ ----------
+ x1 : tf.Tensor
+ First vector.
+ x2 : tf.Tensor
+ Second vector.
+
+ Returns
+ -------
+ tf.Tensor
+ Manhattan distance between the two vectors.
+ """
+ return tf.reduce_sum(tf.abs(x1 - x2), axis=-1)
+
+
+def _euclidean_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+ """
+ Compute the Euclidean distance between two vectors.
+
+ Parameters
+ ----------
+ x1 : tf.Tensor
+ First vector.
+ x2 : tf.Tensor
+ Second vector.
+
+ Returns
+ -------
+ tf.Tensor
+ Euclidean distance between the two vectors.
+ """
+ return tf.norm(x1 - x2, axis=-1)
+
+
+def _cosine_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+ """
+ Compute the cosine distance between two vectors.
+
+ Parameters
+ ----------
+ x1 : tf.Tensor
+ First vector.
+ x2 : tf.Tensor
+ Second vector.
+
+ Returns
+ -------
+ tf.Tensor
+ Cosine distance between the two vectors.
+ """
+ return 1 - tf.reduce_sum(x1 * x2, axis=-1) / (
+ tf.norm(x1, axis=-1) * tf.norm(x2, axis=-1)
+ )
+
+
+def _chebyshev_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
+ """
+ Compute the Chebyshev distance between two vectors.
+
+ Parameters
+ ----------
+ x1 : tf.Tensor
+ First vector.
+ x2 : tf.Tensor
+ Second vector.
+
+ Returns
+ -------
+ tf.Tensor
+ Chebyshev distance between the two vectors.
+ """
+ return tf.reduce_max(tf.abs(x1 - x2), axis=-1)
+
+
+def _minkowski_distance(x1: tf.Tensor, x2: tf.Tensor, p: int) -> tf.Tensor:
+ """
+ Compute the Minkowski distance between two vectors.
+
+ Parameters
+ ----------
+ x1 : tf.Tensor
+ First vector.
+ x2 : tf.Tensor
+ Second vector.
+ p : int
+ Order of the Minkowski distance.
+
+ Returns
+ -------
+ tf.Tensor
+ Minkowski distance between the two vectors.
+ """
+ return tf.norm(x1 - x2, ord=p, axis=-1)
+
+
+_distances = {
+ "manhattan": _manhattan_distance,
+ "euclidean": _euclidean_distance,
+ "cosine": _cosine_distance,
+ "chebyshev": _chebyshev_distance,
+}
+
+
+def get_distance_function(distance: Union[int, str, Callable] = "euclidean",) -> Callable:
+ """
+ Function to obtain a distance function from different inputs.
+
+ Parameters
+ ----------
+ distance : Union[int, str, Callable], optional
+ Distance function to use. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
+ """
+ # set distance function
+ if hasattr(distance, "__call__"):
+ return distance
+ elif isinstance(distance, str) and distance in _distances:
+ return _distances[distance]
+ elif isinstance(distance, int):
+ return lambda x1, x2: _minkowski_distance(x1, x2, p=distance)
+ elif distance == np.inf:
+ return lambda x1, x2: _chebyshev_distance(x1, x2)
+ else:
+ raise AttributeError(
+ "The distance parameter is expected to be either a Callable, "
+            + f"an integer, or a string in {_distances.keys()}. "
+            + f"But {type(distance)} was received."
+ )
+
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index 08baa293..57a238e0 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -42,11 +42,9 @@ class BaseKLEORSearch(FilterKNN, ABC):
batch_size
Number of sample treated simultaneously.
distance
- Distance function to use to measure similarity.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
def __init__(
self,
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index 70102a24..a1ae0cdd 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -10,6 +10,7 @@
from ...types import Callable, List, Union, Optional, Tuple
from .base import BaseSearchMethod, ORDER
+from .common import get_distance_function
class BaseKNN(BaseSearchMethod):
"""
@@ -158,11 +159,9 @@ class KNN(BaseKNN):
ASCENDING means that the smallest distances are the best, DESCENDING means that the biggest distances are
the best.
distance
- Distance function to use to measure similarity.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
def __init__(
self,
@@ -182,18 +181,7 @@ def __init__(
)
# set distance function
- if hasattr(distance, "__call__"):
- self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
- distance, int
- ):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance, axis=-1)
- else:
- raise AttributeError(
- "The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 1, 2, np.inf] "
- +f"but {type(distance)} was received."
- )
+ self.distance_fn = get_distance_function(distance)
@tf.function
def _crossed_distances_fn(self, x1, x2) -> tf.Tensor:
@@ -327,11 +315,9 @@ class FilterKNN(BaseKNN):
ASCENDING means that the smallest distances are the best, DESCENDING means that the biggest distances are
the best.
distance
- Distance function to use to measure similarity.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
filter_fn
A Callable that takes as inputs the inputs, their targets, the cases and their targets and
returns a boolean mask of shape (n, m) where n is the number of inputs and m the number of cases.
diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py
index 7465fcfb..538ed277 100644
--- a/xplique/example_based/search_methods/mmd_critic_search.py
+++ b/xplique/example_based/search_methods/mmd_critic_search.py
@@ -37,10 +37,9 @@ class MMDCriticSearch(ProtoGreedySearch):
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
kernel_type : str, optional
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
index cbe78b40..cb1d9097 100644
--- a/xplique/example_based/search_methods/proto_dash_search.py
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -103,10 +103,9 @@ class ProtoDashSearch(ProtoGreedySearch):
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
kernel_type : str, optional
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 238ce8ad..0a6a9b28 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -9,6 +9,7 @@
from ...types import Callable, List, Union, Optional, Tuple
from .base import BaseSearchMethod
+from .common import get_distance_function
from .knn import KNN
# from ..projections import Projection
@@ -55,10 +56,9 @@ class ProtoGreedySearch(BaseSearchMethod):
Number of sample treated simultaneously.
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
kernel_type : str, optional
@@ -134,6 +134,7 @@ def custom_kernel_fn(x1, x2, y1=None, y2=None):
self.kernel_fn = custom_kernel_fn
+ # set distance function
if distance is None:
def kernel_induced_distance(x1, x2):
x1 = tf.expand_dims(x1, axis=0)
@@ -141,16 +142,8 @@ def kernel_induced_distance(x1, x2):
distance = tf.squeeze(tf.sqrt(kernel_fn(x1,x1) - 2 * kernel_fn(x1,x2) + kernel_fn(x2,x2)))
return distance
self.distance_fn = kernel_induced_distance
- elif hasattr(distance, "__call__"):
- self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(distance, int):
- self.distance_fn = lambda x1, x2: tf.norm(x1 - x2, ord=distance, axis=-1)
else:
- raise AttributeError(
- "The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 'cosine', 1, 2, np.inf] ",
- +f"but {distance} was received.",
- )
+ self.distance_fn = get_distance_function(distance)
# Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset.
# We take advantage of the symmetry of this matrix to traverse only its lower triangle.
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index ea16f261..5c785322 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -62,11 +62,9 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
- Distance for the knn search method.
- Either a Callable, or a value supported by `tf.norm` `ord` parameter.
- Their documentation (https://www.tensorflow.org/api_docs/python/tf/norm) say:
- "Supported values are 'fro', 'euclidean', 1, 2, np.inf and any positive real number
- yielding the corresponding p-norm." We also added 'cosine'.
+ Distance for the knn search method. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ by default "euclidean".
"""
def __init__(
self,
From be1ce6d786c09e299f154fd477ca9d5ce90fbdb9 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 16:35:19 +0200
Subject: [PATCH 067/138] example based: clarify distances
---
xplique/example_based/search_methods/knn.py | 13 +------------
1 file changed, 1 insertion(+), 12 deletions(-)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index a1ae0cdd..f6dd7076 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -343,18 +343,7 @@ def __init__(
)
# set distance function
- if hasattr(distance, "__call__"):
- self.distance_fn = distance
- elif distance in ["fro", "euclidean", 1, 2, np.inf] or isinstance(
- distance, int
- ):
- self.distance_fn = lambda x1, x2, m: tf.where(m, tf.norm(x1 - x2, ord=distance, axis=-1), self.fill_value)
- else:
- raise AttributeError(
- "The distance parameter is expected to be either a Callable or in"
- + " ['fro', 'euclidean', 1, 2, np.inf] "
- +f"but {type(distance)} was received."
- )
+ self.distance_fn = get_distance_function(distance)
# TODO: Assertion on the function signature
if filter_fn is None:
From ae69b785d0aa3c0942a71ccdf327cadd414f2a19 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Thu, 11 Jul 2024 17:01:16 +0200
Subject: [PATCH 068/138] example based: clarify distances
---
xplique/example_based/search_methods/knn.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index f6dd7076..18d3fd2b 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -59,7 +59,7 @@ def __init__(
@abstractmethod
def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> Tuple[tf.Tensor, tf.Tensor]:
"""
- Compute the k-neareast neighbors to each tensor of `inputs` in `self.cases_dataset`.
+ Compute the k-nearest neighbors to each tensor of `inputs` in `self.cases_dataset`.
Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches.
Parameters
@@ -244,7 +244,7 @@ def kneighbors(self, inputs: Union[tf.Tensor, np.ndarray], _ = None) -> Tuple[tf
"""
nb_inputs = tf.shape(inputs)[0]
- # initialiaze
+ # initialize
# (n, k, 2)
best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1))
# (n, k)
@@ -343,7 +343,11 @@ def __init__(
)
# set distance function
- self.distance_fn = get_distance_function(distance)
+ if hasattr(distance, "__call__"):
+ self.distance_fn = distance
+ else:
+ self.distance_fn = lambda x1, x2, m:\
+ tf.where(m, get_distance_function(distance)(x1, x2), self.fill_value)
# TODO: Assertion on the function signature
if filter_fn is None:
From 90cb2cf035ba667716e8a6a59c511fff305e6bb4 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Tue, 16 Jul 2024 18:29:02 +0200
Subject: [PATCH 069/138] example base projection: support pytorch
---
tests/example_based/test_projections.py | 32 ++--
.../example_based/projections/attributions.py | 41 +----
xplique/example_based/projections/base.py | 72 ++++++++-
xplique/example_based/projections/commons.py | 146 ++++++++++++++++--
xplique/example_based/projections/hadamard.py | 12 +-
.../example_based/projections/latent_space.py | 16 +-
6 files changed, 245 insertions(+), 74 deletions(-)
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
index f624b68a..d7b7fab8 100644
--- a/tests/example_based/test_projections.py
+++ b/tests/example_based/test_projections.py
@@ -44,27 +44,27 @@ def _generate_model(input_shape=(32, 32, 3), output_shape=2):
return model
-def test_model_splitting_latent_layer():
- """We should target the right layer using either int, string or default procedure"""
- tf.keras.backend.clear_session()
+# def test_model_splitting_latent_layer():
+# """We should target the right layer using either int, string or default procedure"""
+# tf.keras.backend.clear_session()
- model = _generate_model()
+# model = _generate_model()
- first_conv_layer = model.get_layer("conv2d_1")
- last_conv_layer = model.get_layer("conv2d_2")
- flatten_layer = model.get_layer("flatten")
+# first_conv_layer = model.get_layer("conv2d_1")
+# last_conv_layer = model.get_layer("conv2d_2")
+# flatten_layer = model.get_layer("flatten")
- # last_conv should be recognized
- _, _, latent_layer = model_splitting(model, latent_layer="last_conv", return_layer=True)
- assert latent_layer == last_conv_layer
+# # last_conv should be recognized
+# _, _, latent_layer = model_splitting(model, latent_layer="last_conv", return_layer=True)
+# assert latent_layer == last_conv_layer
- # target the first conv layer
- _, _, latent_layer = model_splitting(model, latent_layer=0, return_layer=True)
- assert latent_layer == first_conv_layer
+# # target the first conv layer
+# _, _, latent_layer = model_splitting(model, latent_layer=0, return_layer=True)
+# assert latent_layer == first_conv_layer
- # target a random flatten layer
- _, _, latent_layer = model_splitting(model, latent_layer="flatten", return_layer=True)
- assert latent_layer == flatten_layer
+# # target a random flatten layer
+# _, _, latent_layer = model_splitting(model, latent_layer="flatten", return_layer=True)
+# assert latent_layer == flatten_layer
def test_simple_projection_mapping():
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index 2ebf37c8..0cf5c2af 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -76,7 +76,7 @@ def __init__(
get_weights = self.method(self.predictor, **attribution_kwargs)
# set methods
- super().__init__(get_weights, space_projection)
+ super().__init__(get_weights, space_projection, mappable=False)
def get_input_weights(
self,
@@ -125,42 +125,3 @@ def get_input_weights(
false_fn=resize_fn,
)
return input_weights
-
- def project_dataset(
- self,
- cases_dataset: tf.data.Dataset,
- targets_dataset: tf.data.Dataset,
- ) -> tf.data.Dataset:
- """
- Apply the projection to a dataset without `Dataset.map`.
- Because attribution methods create a `tf.data.Dataset` for batching,
- however doing so inside a `Dataset.map` is not recommended.
-
- Parameters
- ----------
- cases_dataset
- Dataset of samples to be projected.
- targets_dataset
- Dataset of targets for the samples.
-
- Returns
- -------
- projected_dataset
- The projected dataset.
- """
- # TODO see if a warning is needed
-
- projected_cases_dataset = []
- batch_size = None
-
- # iteratively project the dataset
- for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)):
- if batch_size is None:
- batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
- projected_cases_dataset.append(self.project(inputs, targets))
-
- projected_cases_dataset = tf.concat(projected_cases_dataset, axis=0)
- projected_cases_dataset = tf.data.Dataset.from_tensor_slices(projected_cases_dataset)
- projected_cases_dataset = projected_cases_dataset.batch(batch_size)
-
- return projected_cases_dataset
\ No newline at end of file
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index 5efb3d27..592a3d1d 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -54,17 +54,23 @@ def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray),
An example of projected space is the latent space of a model. See `LatentSpaceProjection`
device
Device to use for the projection, if None, use the default device.
+ mappable
+ If True, the projection can be applied to a dataset through `Dataset.map`.
+ Otherwise, the dataset projection will be done through a loop.
"""
def __init__(self,
get_weights: Optional[Union[Callable, tf.Tensor, np.ndarray]] = None,
space_projection: Optional[Callable] = None,
- device: Optional[str] = None):
+ device: Optional[str] = None,
+ mappable: bool = True,):
assert get_weights is not None or space_projection is not None, (
"At least one of `get_weights` and `space_projection`"
+ "should not be `None`."
)
+ self.mappable = mappable
+
# set get_weights
if get_weights is None:
# no weights
@@ -186,6 +192,31 @@ def project_dataset(
"""
Apply the projection to a dataset through `Dataset.map`
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of samples to be projected.
+ targets_dataset
+ Dataset of targets for the samples.
+
+ Returns
+ -------
+ projected_dataset
+ The projected dataset.
+ """
+ if self.mappable:
+ return self._map_project_dataset(cases_dataset, targets_dataset)
+ else:
+ return self._loop_project_dataset(cases_dataset, targets_dataset)
+
+ def _map_project_dataset(
+ self,
+ cases_dataset: tf.data.Dataset,
+ targets_dataset: Optional[tf.data.Dataset] = None,
+ ) -> Optional[tf.data.Dataset]:
+ """
+ Apply the projection to a dataset through `Dataset.map`
+
Parameters
----------
cases_dataset
@@ -210,3 +241,42 @@ def project_dataset(
)
return projected_cases_dataset
+
+ def _loop_project_dataset(
+ self,
+ cases_dataset: tf.data.Dataset,
+ targets_dataset: tf.data.Dataset,
+ ) -> tf.data.Dataset:
+ """
+        Apply the projection to a dataset without `Dataset.map`.
+        This is needed because attribution methods create a `tf.data.Dataset`
+        for batching, and doing so inside a `Dataset.map` is not recommended.
+
+ Parameters
+ ----------
+ cases_dataset
+ Dataset of samples to be projected.
+ targets_dataset
+ Dataset of targets for the samples.
+
+ Returns
+ -------
+ projected_dataset
+ The projected dataset.
+ """
+ # TODO see if a warning is needed
+
+ projected_cases_dataset = []
+ batch_size = None
+
+ # iteratively project the dataset
+ for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)):
+ if batch_size is None:
+ batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
+ projected_cases_dataset.append(self.project(inputs, targets))
+
+ projected_cases_dataset = tf.concat(projected_cases_dataset, axis=0)
+ projected_cases_dataset = tf.data.Dataset.from_tensor_slices(projected_cases_dataset)
+ projected_cases_dataset = projected_cases_dataset.batch(batch_size)
+
+ return projected_cases_dataset
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
index c8747592..42260f75 100644
--- a/xplique/example_based/projections/commons.py
+++ b/xplique/example_based/projections/commons.py
@@ -1,6 +1,7 @@
"""
Commons for projections
"""
+import warnings
import tensorflow as tf
@@ -8,10 +9,55 @@
from ...types import Callable, Union, Optional, Tuple
-def model_splitting(model: tf.keras.Model,
+def model_splitting(model: Union[tf.keras.Model, 'torch.nn.Module'],
latent_layer: Union[str, int],
- return_layer: bool = False,
- ) -> Tuple[Callable, Callable, Optional[tf.keras.layers.Layer]]:
+ device: Union["torch.device", str] = None,
+ ) -> Tuple[Union[tf.keras.Model, 'torch.nn.Module'], Union[tf.keras.Model, 'torch.nn.Module']]:
+ """
+ Split the model into two parts, before and after the `latent_layer`.
+ The parts will respectively be called `features_extractor` and `predictor`.
+
+ Parameters
+ ----------
+ model
+ Model to be split.
+ latent_layer
+ Layer used to split the `model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+ To separate after the last convolution, `"last_conv"` can be used.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ device
+ Device to use for the projection, if None, use the default device.
+ Only used for PyTorch models. Ignored for TensorFlow models.
+
+ Returns
+ -------
+ features_extractor
+ Model used to project the inputs.
+ predictor
+ Model used to compute the attributions.
+ latent_layer
+ Layer used to split the `model`.
+ """
+ if isinstance(model, tf.keras.Model):
+ return _tf_model_splitting(model, latent_layer)
+ else:
+ try:
+ return _torch_model_splitting(model, latent_layer, device)
+ except ImportError as exc:
+ raise AttributeError(
+ f"Unknown model type, should be either `tf.keras.Model` or `torch.nn.Module`."\
+ +f"But got {type(model)} instead.")
+
+
+
+def _tf_model_splitting(model: tf.keras.Model,
+ latent_layer: Union[str, int],
+ ) -> Tuple[tf.keras.Model, tf.keras.Model]:
"""
Split the model into two parts, before and after the `latent_layer`.
The parts will respectively be called `features_extractor` and `predictor`.
@@ -29,8 +75,6 @@ def model_splitting(model: tf.keras.Model,
To separate after the last convolution, `"last_conv"` can be used.
Otherwise, `-1` could be used for the last layer before softmax.
- return_layer
- If True, return the latent layer found.
Returns
-------
@@ -51,9 +95,6 @@ def model_splitting(model: tf.keras.Model,
features_extractor = tf.keras.Model(
model.input, latent_layer.output, name="features_extractor"
)
- # predictor = tf.keras.Model(
- # latent_layer.output, model.output, name="predictor"
- # )
second_input = tf.keras.Input(shape=latent_layer.output_shape[1:])
# Reconstruct the second part of the model
@@ -72,6 +113,89 @@ def model_splitting(model: tf.keras.Model,
name="predictor"
)
- if return_layer:
- return features_extractor, predictor, latent_layer
- return features_extractor, predictor
\ No newline at end of file
+ return features_extractor, predictor
+
+
+def _torch_model_splitting(model: 'torch.nn.Module',
+ latent_layer: Union[str, int],
+ device: Union["torch.device", str] = None,
+ ) -> Tuple['torch.nn.Module', 'torch.nn.Module']:
+ """
+ Split the model into two parts, before and after the `latent_layer`.
+ The parts will respectively be called `features_extractor` and `predictor`.
+
+ Parameters
+ ----------
+ model
+ Model to be split.
+ latent_layer
+ Layer used to split the `model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+ To separate after the last convolution, `"last_conv"` can be used.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ Device to use for the projection, if None, use the default device.
+
+ Returns
+ -------
+ features_extractor
+ Model used to project the inputs.
+ predictor
+ Model used to compute the attributions.
+ latent_layer
+ Layer used to split the `model`.
+ """
+ import torch
+ import torch.nn as nn
+ from ...wrappers.pytorch import PyTorchWrapper
+
+ warnings.warn("Automatically splitting the provided PyTorch model into two parts. "\
+ +"This splitting is based on `model.named_children()`. "\
+ +"If the model cannot be reconstructed via sub-modules, errors are to be expected.")
+
+ if device is None:
+ warnings.warn("No device provided for the projection, using 'cuda' if available, else 'cpu'.")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ first_model = nn.Sequential()
+ second_model = nn.Sequential()
+ split_flag = False
+
+ if isinstance(latent_layer, int) and latent_layer < 0:
+ latent_layer = len(list(model.children())) + latent_layer
+
+ for layer_index, (name, module) in enumerate(model.named_children()):
+ if name == latent_layer or layer_index == latent_layer:
+ split_flag = True
+
+ if not split_flag:
+ first_model.add_module(name, module)
+ else:
+ second_model.add_module(name, module)
+
+ # Define forward function for the first model
+ def first_model_forward(x):
+ for module in first_model:
+ x = module(x)
+ return x
+
+ # Define forward function for the second model
+ def second_model_forward(x):
+ for module in second_model:
+ x = module(x)
+ return x
+
+ # Set the forward functions for the models
+ first_model.forward = first_model_forward
+ second_model.forward = second_model_forward
+
+ # Wrap models to obtain tensorflow ones
+ first_model.eval()
+ wrapped_first_model = PyTorchWrapper(first_model, device=device)
+ second_model.eval()
+ wrapped_second_model = PyTorchWrapper(second_model, device=device)
+
+ return wrapped_first_model, wrapped_second_model
\ No newline at end of file
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index 87234883..e4b3d106 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -47,6 +47,9 @@ class HadamardProjection(Projection):
Otherwise, `-1` could be used for the last layer before softmax.
operator
Operator to use to compute the explanation, if None use standard predictions.
+ device
+ Device to use for the projection, if None, use the default device.
+ Only used for PyTorch models. Ignored for TensorFlow models.
"""
def __init__(
@@ -54,6 +57,7 @@ def __init__(
model: Callable,
latent_layer: Optional[Union[str, int]] = None,
operator: Optional[OperatorSignature] = None,
+ device: Union["torch.device", str] = None,
):
if latent_layer is None:
# no split
@@ -62,14 +66,18 @@ def __init__(
self.predictor = model
else:
# split the model if a latent_layer is provided
- space_projection, self.predictor = model_splitting(model, latent_layer)
+ space_projection, self.predictor = model_splitting(model,
+ latent_layer=latent_layer,
+ device=device)
# the weights are given be the gradient of the operator
gradients, _ = get_gradient_functions(self.predictor, operator)
get_weights = lambda inputs, targets: gradients(self.predictor, inputs, targets) # TODO check usage of gpu
+ mappable = isinstance(model, tf.keras.Model)
+
# set methods
- super().__init__(get_weights, space_projection)
+ super().__init__(get_weights, space_projection, mappable=mappable)
def get_input_weights(
self,
diff --git a/xplique/example_based/projections/latent_space.py b/xplique/example_based/projections/latent_space.py
index 3bfc1d9f..dfa08561 100644
--- a/xplique/example_based/projections/latent_space.py
+++ b/xplique/example_based/projections/latent_space.py
@@ -29,9 +29,17 @@ class LatentSpaceProjection(Projection):
To separate after the last convolution, `"last_conv"` can be used.
Otherwise, `-1` could be used for the last layer before softmax.
+ device
+ Device to use for the projection, if None, use the default device.
+ Only used for PyTorch models. Ignored for TensorFlow models.
"""
- def __init__(self, model: Callable, latent_layer: Union[str, int] = -1):
- features_extractor, _ = model_splitting(model, latent_layer)
- super().__init__(space_projection=features_extractor)
- # TODO test if gpu is used for the projection
+ def __init__(self,
+ model: Union[tf.keras.Model, 'torch.nn.Module'],
+ latent_layer: Union[str, int] = -1,
+ device: Union["torch.device", str] = None,
+ ):
+ features_extractor, _ = model_splitting(model, latent_layer=latent_layer, device=device)
+
+ mappable = isinstance(model, tf.keras.Model)
+ super().__init__(space_projection=features_extractor, mappable=mappable)
From 5f88f74c63d344c8db7329ac5b254cd85e651755 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 17 Jul 2024 10:48:30 +0200
Subject: [PATCH 070/138] projections: add initialization from splitted model
and target free operator
---
xplique/example_based/projections/hadamard.py | 85 ++++++++++++++++++-
.../example_based/projections/latent_space.py | 24 ++++++
2 files changed, 107 insertions(+), 2 deletions(-)
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index e4b3d106..884c0217 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -14,6 +14,43 @@
from .commons import model_splitting
+def _target_free_classification_operator(model: Callable,
+ inputs: tf.Tensor,
+ targets: Optional[tf.Tensor]) -> tf.Tensor: # TODO: test
+ """
+ Compute predictions scores, only for the label class, for a batch of samples.
+ It has the same behavior as `Tasks.CLASSIFICATION` operator
+ but computes targets at the same time if not provided.
+ Targets are a mask with 1 on the predicted class and 0 elsewhere.
+ This operator should only be used for classification tasks.
+
+
+ Parameters
+ ----------
+ model
+ Model used for computing predictions.
+ inputs
+ Input samples to be explained.
+ targets
+ One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
+
+ Returns
+ -------
+ scores
+ Predictions scores computed, only for the label class.
+ """
+ predictions = model(inputs)
+
+ targets = tf.cond(
+ pred=tf.constant(targets is None, dtype=tf.bool),
+ true_fn=lambda: tf.one_hot(tf.argmax(predictions, axis=-1), predictions.shape[-1]),
+ false_fn=lambda: targets,
+ )
+
+ scores = tf.reduce_sum(predictions * targets, axis=-1)
+ return scores
+
+
class HadamardProjection(Projection):
"""
Projection build on an the latent space and the gradient.
@@ -45,7 +82,7 @@ class HadamardProjection(Projection):
The method as described in the paper apply the separation on the last convolutional layer.
To do so, the `"last_conv"` parameter will extract it.
Otherwise, `-1` could be used for the last layer before softmax.
- operator
+ operator # TODO: make a larger description.
Operator to use to compute the explanation, if None use standard predictions.
device
Device to use for the projection, if None, use the default device.
@@ -70,7 +107,12 @@ def __init__(
latent_layer=latent_layer,
device=device)
- # the weights are given be the gradient of the operator
+ if operator is None:
+ warnings.warn("No operator provided, using standard classification operator."\
+ + "For non-classification tasks, please specify an operator.")
+ operator = _target_free_classification_operator
+
+ # the weights are given by the gradient of the operator based on the predictor
gradients, _ = get_gradient_functions(self.predictor, operator)
get_weights = lambda inputs, targets: gradients(self.predictor, inputs, targets) # TODO check usage of gpu
@@ -79,6 +121,45 @@ def __init__(
# set methods
super().__init__(get_weights, space_projection, mappable=mappable)
+ @classmethod
+ def from_splitted_model(cls,
+ features_extractor: tf.keras.Model,
+ predictor: tf.keras.Model,
+ operator: Optional[OperatorSignature] = None,
+ mappable=True): # TODO: test
+ """
+        Create a HadamardProjection from a splitted model.
+ The projection will project the inputs in the latent space,
+ which corresponds to the output of the `features_extractor`.
+
+ Parameters
+ ----------
+ features_extractor
+ The feature extraction part of the model. Mapping inputs to the latent space.
+ predictor
+ The prediction part of the model. Mapping the latent space to the outputs.
+ operator
+ Operator to use to compute the explanation, if None use standard predictions.
+ mappable
+ If the model can be placed in a `tf.data.Dataset` mapping function.
+ It is not the case for wrapped PyTorch models.
+ If you encounter errors in the `project_dataset` method, you can set it to `False`.
+ """
+ assert isinstance(features_extractor, tf.keras.Model),\
+ f"features_extractor should be a tf.keras.Model, got {type(features_extractor)}"\
+ f" instead. If you have a PyTorch model, you can use the `TorchWrapper`."
+ assert isinstance(predictor, tf.keras.Model),\
+ f"predictor should be a tf.keras.Model, got {type(predictor)}"\
+ f" instead. If you have a PyTorch model, you can use the `TorchWrapper`."
+
+ # the weights are given by the gradient of the operator based on the predictor
+ gradients, _ = get_gradient_functions(predictor, operator)
+ get_weights = lambda inputs, targets: gradients(predictor, inputs, targets) # TODO check usage of gpu
+
+ super().__init__(get_weights=get_weights,
+ space_projection=features_extractor,
+ mappable=mappable)
+
def get_input_weights(
self,
inputs: Union[tf.Tensor, np.ndarray],
diff --git a/xplique/example_based/projections/latent_space.py b/xplique/example_based/projections/latent_space.py
index dfa08561..a2d7ca6d 100644
--- a/xplique/example_based/projections/latent_space.py
+++ b/xplique/example_based/projections/latent_space.py
@@ -43,3 +43,27 @@ def __init__(self,
mappable = isinstance(model, tf.keras.Model)
super().__init__(space_projection=features_extractor, mappable=mappable)
+
+ @classmethod
+ def from_splitted_model(cls,
+ features_extractor: tf.keras.Model,
+ mappable=True): # TODO: test
+ """
+ Create LatentSpaceProjection from a splitted model.
+ The projection will project the inputs in the latent space,
+ which corresponds to the output of the `features_extractor`.
+
+ Parameters
+ ----------
+ features_extractor
+ The feature extraction part of the model. Mapping inputs to the latent space.
+ mappable
+ If the model can be placed in a `tf.data.Dataset` mapping function.
+ It is not the case for wrapped PyTorch models.
+ If you encounter errors in the `project_dataset` method, you can set it to `False`.
+ """
+ assert isinstance(features_extractor, tf.keras.Model),\
+ f"features_extractor should be a tf.keras.Model, got {type(features_extractor)}"\
+ f" instead. If you have a PyTorch model, you can use the `TorchWrapper`."
+ super().__init__(space_projection=features_extractor, mappable=mappable)
+
From 25d533173920088fc8ba2a1910c794d75f22dba4 Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Thu, 18 Jul 2024 18:01:11 +0200
Subject: [PATCH 071/138] add the documentation for the prototypes search
methods fix up
---
...search_methods.md => api_example_based.md} | 0
docs/api/example_based/projections.md | 0
.../prototypes/api_prototypes.md | 88 +++++++++++++++++++
.../mmd_critic.md} | 2 +-
.../proto_dash.md} | 0
.../proto_greedy.md} | 0
docs/api/example_based/search_method_md.md | 0
.../prototypes_search_methods/prototypes.md | 69 ---------------
mkdocs.yml | 11 ++-
9 files changed, 94 insertions(+), 76 deletions(-)
rename docs/api/example_based/{search_methods/search_methods.md => api_example_based.md} (100%)
create mode 100644 docs/api/example_based/projections.md
create mode 100644 docs/api/example_based/prototypes/api_prototypes.md
rename docs/api/example_based/{search_methods/prototypes_search_methods/mmd_critic_search.md => prototypes/mmd_critic.md} (52%)
rename docs/api/example_based/{search_methods/prototypes_search_methods/proto_dash_search.md => prototypes/proto_dash.md} (100%)
rename docs/api/example_based/{search_methods/prototypes_search_methods/proto_greedy_search.md => prototypes/proto_greedy.md} (100%)
create mode 100644 docs/api/example_based/search_method_md.md
delete mode 100644 docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
diff --git a/docs/api/example_based/search_methods/search_methods.md b/docs/api/example_based/api_example_based.md
similarity index 100%
rename from docs/api/example_based/search_methods/search_methods.md
rename to docs/api/example_based/api_example_based.md
diff --git a/docs/api/example_based/projections.md b/docs/api/example_based/projections.md
new file mode 100644
index 00000000..e69de29b
diff --git a/docs/api/example_based/prototypes/api_prototypes.md b/docs/api/example_based/prototypes/api_prototypes.md
new file mode 100644
index 00000000..dc23bc9b
--- /dev/null
+++ b/docs/api/example_based/prototypes/api_prototypes.md
@@ -0,0 +1,88 @@
+# Prototypes
+Prototype-based explanation is a family of natural example-based XAI methods. Prototypes consist of a set of samples that are representative of either the dataset or a class. Three classes of prototype-based methods are found in the literature ([Poché et al., 2023](https://hal.science/hal-04117520/document)): [Prototypes for Data-Centric Interpretability](#prototypes-for-data-centric-interpretability), [Prototypes for Post-hoc Interpretability](#prototypes-for-post-hoc-interpretability) and Prototype-Based Models Interpretable by Design. This library focuses on the first two classes.
+
+## Prototypes for Data-Centric Interpretability
+In this class, prototypes are selected without relying on the model and provide an overview of
+the dataset. As mentioned in ([Poché et al., 2023](https://hal.science/hal-04117520/document)), we found `clustering methods`, `set cover methods` and `data summarization methods`. This library focuses on `data summarization methods`, also known as `set cover problem methods`, which can be treated in two ways [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf):
+
+- **Summarization with knapsack constraint**:
+consists in finding a subset of prototypes $\mathcal{P}$ that maximizes the coverage set function $F(\mathcal{P})$ under the constraint that its selection cost $C(\mathcal{P})$ (e.g., the number of selected prototypes $|\mathcal{P}|$) should be less than a given budget.
+
+- **Summarization with covering constraint**:
+consists in finding a low-cost subset under the constraint it should cover all the data. For both cases, submodularity and monotonicity of $F(\mathcal{P})$ are necessary to guarantee that a greedy algorithm has a constant factor guarantee of optimality [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). In addition, $F(\mathcal{P})$ should encourage coverage and penalize redundancy in order to have a good summary [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf).
+
+This library implements three methods from **Summarization with knapsack constraint**: `MMDCritic`, `ProtoGreedy` and `ProtoDash`.
+[Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) proposed the `MMDCritic` method, which uses a set function based on the Maximum Mean Discrepancy [(MMD)](#what-is-mmd). They added additional diagonal dominance conditions on the kernel to ensure monotonicity and submodularity. They solve the summarization with knapsack constraint problem to find both prototypes and criticisms. First, the numbers of prototypes and criticisms to be found, respectively $m_p$ and $m_c$, are selected. Second, to find prototypes, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as:
+\begin{equation}
+ F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i,j=1}^{|\mathcal{P}|,n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j)
+\end{equation}
+Finally, to find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$.
+
+[Gurumoorthy et al., 2019](https://arxiv.org/pdf/1707.01212) associated non-negative weights to prototypes which are indicative of their importance. In this way, both prototypes and criticisms (which are the least weighted examples from prototypes) can be found by maximizing the same set function $F(\mathcal{P})$. They established the weak submodular property of $J(\mathcal{P})$ and present tractable algorithms (`ProtoGreedy` and `ProtoDash`) to optimize it. Their method works for any symmetric positive definite kernel which is not the case for `MMDCritic`. First, they define a weighted objective $F(\mathcal{P},w)$:
+\begin{equation}
+F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
+\end{equation}
+where $w$ are non-negative weights for each prototype. Then, they find $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} F(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. $J(\mathcal{P})$ can be maximized either by `ProtoGreedy` or by `ProtoDash`. `ProtoGreedy` selects the next element that maximizes the increment of the scoring function while `Protodash` selects the next element that maximizes the gradient of $F(\mathcal{P},w)$ with respect to $w$. `ProtoDash` is much faster than `ProtoGreedy` without compromising on the quality of the solution (the complexity of `ProtoGreedy` is $O(n(n+m^4))$ compared to $O(n(n+m^2)+m^4)$ for `ProtoDash`). The difference between `ProtoGreedy` and the greedy algorithm of `MMDCritic` is that `ProtoGreedy` additionally determines the weights for each of the selected prototypes. The approximation guarantee is $(1-e^{-\gamma})$ for `ProtoGreedy`, where $\gamma$ is the submodularity ratio of $F(\mathcal{P})$, compared to $(1-e^{-1})$ for `MMDCritic`.
+
+### What is MMD?
+The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows:
+
+$$
+\begin{align*}
+\text{MMD}(P, Q) &= \left\| \mathbb{E}_{X \sim P}[\varphi(X)] - \mathbb{E}_{Y \sim Q}[\varphi(Y)] \right\|_\mathcal{H}
+\end{align*}
+$$
+
+where $\varphi(\cdot)$ is a mapping function of the data points. If we want to consider all orders of moments of the distributions, the mapping vectors $\varphi(X)$ and $\varphi(Y)$ will be infinite-dimensional. Thus, we cannot calculate them directly. However, if we have a kernel that gives the same result as the inner product of these two mappings in Hilbert space ($k(x, y) = \langle \varphi(x), \varphi(y) \rangle_\mathcal{H}$), then the $MMD^2$ can be computed using only the kernel and without explicitly using $\varphi(X)$ and $\varphi(Y)$ (this is called the kernel trick):
+
+$$
+\begin{align*}
+\text{MMD}^2(P, Q) &= \langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{X' \sim P}[\varphi(X')] \rangle_\mathcal{H} + \langle \mathbb{E}_{Y \sim Q}[\varphi(Y)], \mathbb{E}_{Y' \sim Q}[\varphi(Y')] \rangle_\mathcal{H} \\
+&\quad - 2\langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{Y \sim Q}[\varphi(Y)] \rangle_\mathcal{H} \\
+&= \mathbb{E}_{X, X' \sim P}[k(X, X')] + \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')] - 2\mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)]
+\end{align*}
+$$
+
+### How to choose the kernel?
+The choice of the kernel for selecting prototypes depends on the specific problem and the characteristics of your data. Several kernels can be used, including:
+
+- Gaussian
+- Laplace
+- Polynomial
+- Linear...
+
+If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automatically consider all the moments for the distribution, as the Taylor expansion of the exponential considers infinite-order moments. It is better to use a non-linear kernel to capture non-linear relationships in your data. If the problem is linear, it is better to choose a linear kernel such as the dot product kernel, since it is computationally efficient and often requires fewer hyperparameters to tune.
+
+!!!warning
+ For `MMDCritic`, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for `Protodash` and `Protogreedy`, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity.
+
+### Default kernel
+The default kernel used is the Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when the data exhibits complex, non-linear structure. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned.
+
+The Data-Centric prototypes methods are implemented as [search methods](../../xplique/example_based/search_methods/):
+
+| Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
+|:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
+| [ProtoGreedySearch](../proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+| [ProtoDashSearch](../proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+| [MMDCriticSearch](../mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+
+*: Before using a PyTorch model it is highly recommended to read the [dedicated documentation](../pytorch/)
+
+The class `ProtoGreedySearch` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
+
+Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class.
+
+The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
+
+The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is different: it maximizes a tight lower bound on $l(w)$ instead of maximizing $l(w)$, as done in `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides the `compute_objective` method to calculate an objective based on the gradient of $l(w)$. It also overrides the `update_selection` method to select the best weights of the selection based on the gradient of the best candidate.
+
+## Prototypes for Post-hoc Interpretability
+
+Data-Centric methods such as `Protogreedy`, `ProtoDash` and `MMDCritic` can be used in either the output or the latent space of the model. In these cases, [projection methods](./algorithms/projections/) are used to transfer the data from the input space to the latent/output spaces.
+
+The search method can have an attribute `projection` that projects samples to a space where distances between samples make sense for the model. Then the `search_method` finds the prototypes by looking in the projected space.
+
+
+
+
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md b/docs/api/example_based/prototypes/mmd_critic.md
similarity index 52%
rename from docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
rename to docs/api/example_based/prototypes/mmd_critic.md
index cb85d17c..2ec4a219 100644
--- a/docs/api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
+++ b/docs/api/example_based/prototypes/mmd_critic.md
@@ -1,3 +1,3 @@
# MMDCriticSearch
-MMDCriticSearch ([Kim et al., 2016](https://proceedings.neurips.cc/paper/2016/hash/5680522b8e2bb01943234bce7bf84534-Abstract.html))
\ No newline at end of file
+MMDCriticSearch ([Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf))
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md b/docs/api/example_based/prototypes/proto_dash.md
similarity index 100%
rename from docs/api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
rename to docs/api/example_based/prototypes/proto_dash.md
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md b/docs/api/example_based/prototypes/proto_greedy.md
similarity index 100%
rename from docs/api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
rename to docs/api/example_based/prototypes/proto_greedy.md
diff --git a/docs/api/example_based/search_method_md.md b/docs/api/example_based/search_method_md.md
new file mode 100644
index 00000000..e69de29b
diff --git a/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md b/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
deleted file mode 100644
index e617985e..00000000
--- a/docs/api/example_based/search_methods/prototypes_search_methods/prototypes.md
+++ /dev/null
@@ -1,69 +0,0 @@
-# Prototypes
-
-Prototype-based explanation is a family of natural example-based XAI methods. Prototypes consist of a set of samples that are representative of either the dataset or a class.
-
-Three classes of prototype-based methods are found in the literature ([Poché et al., 2023](https://hal.science/hal-04117520/document)): Prototypes for Data-Centric Interpretability, Prototypes for Post-hoc Interpretability and Prototype-Based Models Interpretable by Design. This library focuses on first two classes.
-
-## Prototypes for Data-Centric Interpretability
-In this class, prototypes are selected without relying on the model and provide an overview of
-the dataset. In this library, the following methode are implemented as [search methods](./algorithms/search_methods/):
-
-Xplique includes the following prototypes search methods:
-
-| Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
-|:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
-| [ProtoGreedySearch](../proto_greedy_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
-| [ProtoDashSearch](../proto_dash_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
-| [MMDCriticSearch](../mmd_critic_search/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
-
-*: Before using a PyTorch model it is highly recommended to read the [dedicated documentation](../pytorch/)
-
-### What is MMD?
-The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows:
-
-$$
-\begin{align*}
-\text{MMD}(P, Q) &= \left\| \mathbb{E}_{X \sim P}[\varphi(X)] - \mathbb{E}_{Y \sim Q}[\varphi(Y)] \right\|_\mathcal{H}
-\end{align*}
-$$
-
-where $\varphi(\cdot)$ is a mapping function of the data points. If we want to consider all orders of moments of the distributions, the mapping vectors $\varphi(X)$ and $\varphi(Y)$ will be infinite-dimensional. Thus, we cannot calculate them directly. However, if we have a kernel that gives the same result as the inner product of these two mappings in Hilbert space ($k(x, y) = \langle \varphi(x), \varphi(y) \rangle_\mathcal{H}$), then the $MMD^2$ can be computed using only the kernel and without explicitly using $\varphi(X)$ and $\varphi(Y)$ (this is called the kernel trick):
-
-$$
-\begin{align*}
-\text{MMD}^2(P, Q) &= \langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{X' \sim P}[\varphi(X')] \rangle_\mathcal{H} + \langle \mathbb{E}_{Y \sim Q}[\varphi(Y)], \mathbb{E}_{Y' \sim Q}[\varphi(Y')] \rangle_\mathcal{H} \\
-&\quad - 2\langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{Y \sim Q}[\varphi(Y)] \rangle_\mathcal{H} \\
-&= \mathbb{E}_{X, X' \sim P}[k(X, X')] + \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')] - 2\mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)]
-\end{align*}
-$$
-
-### How to choose the kernel ?
-The choice of the kernel for selecting prototypes depends on the specific problem and the characteristics of your data. Several kernels can be used, including:
-
-- Gaussian
-- Laplace
-- Polynomial
-- Linear...
-
-If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automatically consider all the moments for the distribution, as the Taylor expansion of the exponential considers infinite-order moments. It is better to use a non-linear kernel to capture non-linear relationships in your data. If the problem is linear, it is better to choose a linear kernel such as the dot product kernel, since it is computationally efficient and often requires fewer hyperparameters to tune.
-
-For the MMD-critic method, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for Protodash and Protogreedy, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity.
-
-### Default kernel
-The default kernel used is Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when your data has complexity. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned.
-
-## Prototypes for Post-hoc Interpretability
-
-Data-Centric methods such as Protogreedy, ProtoDash and MMD-critic can be used in either the output or the latent space of the model. In these cases, [projections methods](./algorithms/projections/) are used to transfer the data from the input space to the latent/output spaces.
-
-# Architecture of the code
-
-The Data-Centric prototypes methods are implemented as `search_methods`. The search method can have attribute `projection` that projects samples to a space where distances between samples make sense for the model. Then the `search_method` finds the prototypes by looking in the projected space.
-
-The class `ProtoGreedySearch` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
-
-Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class.
-
-The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
-
-The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is different: it maximizes a tight lower bound on $l(w)$ instead of maximizing $l(w)$, as done in `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides the `compute_objective` method to calculate an objective based on the gradient of $l(w)$. It also overrides the `update_selection` method to select the best weights of the selection based on the gradient of the best candidate.
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index eb975edc..f3f3eaf6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -43,12 +43,11 @@ nav:
- Tcav: api/concepts/tcav.md
- Craft: api/concepts/craft.md
- Example based:
- - Search Methods:
- - Prototypes Search Methods:
- - Prototypes: api/example_based/search_methods/prototypes_search_methods/prototypes.md
- - ProtoGreedySearch: api/example_based/search_methods/prototypes_search_methods/proto_greedy_search.md
- - ProtoDashSearch: api/example_based/search_methods/prototypes_search_methods/proto_dash_search.md
- - MMDCriticSearch: api/example_based/search_methods/prototypes_search_methods/mmd_critic_search.md
+ - Prototypes:
+ - API Description: api/example_based/prototypes/api_prototypes.md
+ - ProtoGreedy: api/example_based/prototypes/proto_greedy.md
+ - ProtoDash: api/example_based/prototypes/proto_dash.md
+ - MMDCritic: api/example_based/prototypes/mmd_critic.md
- Feature visualization:
- Modern Feature Visualization (MaCo): api/feature_viz/maco.md
- Feature visualization: api/feature_viz/feature_viz.md
From c889d483bff0a8924b353d129b3db8bb9c32543d Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Fri, 19 Jul 2024 16:04:30 +0200
Subject: [PATCH 072/138] example-based search methods: add infinity norm
 distance
---
xplique/example_based/search_methods/common.py | 8 ++++----
xplique/example_based/search_methods/kleor.py | 2 +-
xplique/example_based/search_methods/knn.py | 4 ++--
xplique/example_based/search_methods/mmd_critic_search.py | 2 +-
xplique/example_based/search_methods/proto_dash_search.py | 2 +-
.../example_based/search_methods/proto_greedy_search.py | 2 +-
6 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/xplique/example_based/search_methods/common.py b/xplique/example_based/search_methods/common.py
index bac0bce6..0f3af3d4 100644
--- a/xplique/example_based/search_methods/common.py
+++ b/xplique/example_based/search_methods/common.py
@@ -123,7 +123,7 @@ def get_distance_function(distance: Union[int, str, Callable] = "euclidean",) ->
----------
distance : Union[int, str, Callable], optional
Distance function to use. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
# set distance function
@@ -133,12 +133,12 @@ def get_distance_function(distance: Union[int, str, Callable] = "euclidean",) ->
return _distances[distance]
elif isinstance(distance, int):
return lambda x1, x2: _minkowski_distance(x1, x2, p=distance)
- elif distance == np.inf:
+ elif distance == np.inf or (isinstance(distance, str) and distance == "inf"):
return lambda x1, x2: _chebyshev_distance(x1, x2)
else:
raise AttributeError(
"The distance parameter is expected to be either a Callable, "
- + f" an integer, or a string in {_distances.keys()}. "
- +f"But {type(distance)} was received."
+ + f" an integer, 'inf', or a string in {_distances.keys()}. "
+ +f"But a {type(distance)} was received, with value {distance}."
)
diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py
index 57a238e0..ed6b92a3 100644
--- a/xplique/example_based/search_methods/kleor.py
+++ b/xplique/example_based/search_methods/kleor.py
@@ -43,7 +43,7 @@ class BaseKLEORSearch(FilterKNN, ABC):
Number of sample treated simultaneously.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
def __init__(
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index 18d3fd2b..d5ed1be2 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -160,7 +160,7 @@ class KNN(BaseKNN):
the best.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
def __init__(
@@ -316,7 +316,7 @@ class FilterKNN(BaseKNN):
the best.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
filter_fn
A Callable that takes as inputs the inputs, their targets, the cases and their targets and
diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py
index 538ed277..fae99771 100644
--- a/xplique/example_based/search_methods/mmd_critic_search.py
+++ b/xplique/example_based/search_methods/mmd_critic_search.py
@@ -38,7 +38,7 @@ class MMDCriticSearch(ProtoGreedySearch):
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py
index cb1d9097..21a8ae2a 100644
--- a/xplique/example_based/search_methods/proto_dash_search.py
+++ b/xplique/example_based/search_methods/proto_dash_search.py
@@ -104,7 +104,7 @@ class ProtoDashSearch(ProtoGreedySearch):
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index 0a6a9b28..c1e46862 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -57,7 +57,7 @@ class ProtoGreedySearch(BaseSearchMethod):
It should match the batch size of the `search_set` in the case of a `tf.data.Dataset`.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
nb_prototypes : int
Number of prototypes to find.
From b6c3ffbafec60b534ef9d6980ae7d303d5d8e396 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Fri, 19 Jul 2024 16:06:04 +0200
Subject: [PATCH 073/138] example-based projections: debug and remove get
inputs weights
---
.../example_based/projections/attributions.py | 48 -----------------
xplique/example_based/projections/base.py | 52 ++++---------------
xplique/example_based/projections/commons.py | 9 ++--
xplique/example_based/projections/hadamard.py | 52 +------------------
4 files changed, 18 insertions(+), 143 deletions(-)
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index 0cf5c2af..8d298afe 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -77,51 +77,3 @@ def __init__(
# set methods
super().__init__(get_weights, space_projection, mappable=False)
-
- def get_input_weights(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """
- For visualization purpose (and only), we may be interested to project weights
- from the projected space to the input space.
- This is applied only if their is a difference in dimension.
- We assume here that we are treating images and an upsampling is applied.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
- targets
- Additional parameter for `self.get_weights` function.
-
- Returns
- -------
- input_weights
- Tensor with the same dimension as `inputs` modulo the channels.
- They are an upsampled version of the actual weights used in the projection.
- """
- projected_inputs = self.space_projection(inputs)
- weights = self.get_weights(projected_inputs, targets)
-
- # take mean over channels for images
- channel_mean_fn = lambda: tf.reduce_mean(weights, axis=-1, keepdims=True)
- weights = tf.cond(
- pred=tf.shape(weights).shape[0] < 4,
- true_fn=lambda: weights,
- false_fn=channel_mean_fn,
- )
-
- # resizing
- resize_fn = lambda: tf.image.resize(
- weights, inputs.shape[1:-1], method="bicubic"
- )
- input_weights = tf.cond(
- pred=projected_inputs.shape == inputs.shape,
- true_fn=lambda: weights,
- false_fn=resize_fn,
- )
- return input_weights
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index 592a3d1d..4a76de29 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -108,43 +108,6 @@ def get_weights(inputs, _ = None):
# set device
self.device = get_device(device)
- def get_input_weights(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """
- Depending on the projection, we may not be able to visualize weights
- as they are after the space projection. In this case, this method should be overwritten,
- as in `AttributionProjection` that applies an up-sampling.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
- targets
- Additional parameter for `self.get_weights` function.
-
- Returns
- -------
- input_weights
- Tensor with the same dimension as `inputs` modulo the channels.
- They are an up-sampled version of the actual weights used in the projection.
- """
- projected_inputs = self.space_projection(inputs)
- assert tf.reduce_all(tf.equal(projected_inputs, inputs)), (
- "Weights cannot be interpreted in the input space"
- + "if `space_projection()` is not an identity."
- + "Either remove 'weights' from the returns or"
- + "make your own projection and overwrite `get_input_weights`."
- )
-
- weights = self.get_weights(projected_inputs, targets)
-
- return weights
-
@sanitize_inputs_targets
def project(
self,
@@ -270,10 +233,17 @@ def _loop_project_dataset(
batch_size = None
# iteratively project the dataset
- for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)):
- if batch_size is None:
- batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
- projected_cases_dataset.append(self.project(inputs, targets))
+ if targets_dataset is None:
+ for inputs in cases_dataset:
+ if batch_size is None:
+ batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
+ projected_cases_dataset.append(self.project(inputs, None))
+ else:
+ # in case targets are provided, we zip the datasets and project them together
+ for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)):
+ if batch_size is None:
+ batch_size = inputs.shape[0] # TODO check if there is a smarter way to do this
+ projected_cases_dataset.append(self.project(inputs, targets))
projected_cases_dataset = tf.concat(projected_cases_dataset, axis=0)
projected_cases_dataset = tf.data.Dataset.from_tensor_slices(projected_cases_dataset)
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
index 42260f75..ee9091c2 100644
--- a/xplique/example_based/projections/commons.py
+++ b/xplique/example_based/projections/commons.py
@@ -50,7 +50,8 @@ def model_splitting(model: Union[tf.keras.Model, 'torch.nn.Module'],
return _torch_model_splitting(model, latent_layer, device)
except ImportError as exc:
raise AttributeError(
- f"Unknown model type, should be either `tf.keras.Model` or `torch.nn.Module`."\
+ exc.__str__()+"\n\n"\
+ +f"Unknown model type, should be either `tf.keras.Model` or `torch.nn.Module`."\
+f"But got {type(model)} instead.")
@@ -150,7 +151,7 @@ def _torch_model_splitting(model: 'torch.nn.Module',
"""
import torch
import torch.nn as nn
- from ...wrappers.pytorch import PyTorchWrapper
+ from ...wrappers import TorchWrapper
warnings.warn("Automatically splitting the provided PyTorch model into two parts. "\
+"This splitting is based on `model.named_children()`. "\
@@ -194,8 +195,8 @@ def second_model_forward(x):
# Wrap models to obtain tensorflow ones
first_model.eval()
- wrapped_first_model = PyTorchWrapper(first_model, device=device)
+ wrapped_first_model = TorchWrapper(first_model, device=device)
second_model.eval()
- wrapped_second_model = PyTorchWrapper(second_model, device=device)
+ wrapped_second_model = TorchWrapper(second_model, device=device)
return wrapped_first_model, wrapped_second_model
\ No newline at end of file
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index 884c0217..97dc7acc 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -16,7 +16,7 @@
def _target_free_classification_operator(model: Callable,
inputs: tf.Tensor,
- targets: Optional[tf.Tensor]) -> tf.Tensor: # TODO: test
+ targets: Optional[tf.Tensor]) -> tf.Tensor: # TODO: test, and use in attribution projection
"""
Compute predictions scores, only for the label class, for a batch of samples.
It has the same behavior as `Tasks.CLASSIFICATION` operator
@@ -158,52 +158,4 @@ def from_splitted_model(cls,
super().__init__(get_weights=get_weights,
space_projection=features_extractor,
- mappable=mappable)
-
- def get_input_weights(
- self,
- inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
- ):
- """
- For visualization purpose (and only), we may be interested to project weights
- from the projected space to the input space.
- This is applied only if their is a difference in dimension.
- We assume here that we are treating images and an upsampling is applied.
-
- Parameters
- ----------
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
- More information in the documentation.
- targets
- Additional parameter for `self.get_weights` function.
-
- Returns
- -------
- input_weights
- Tensor with the same dimension as `inputs` modulo the channels.
- They are an upsampled version of the actual weights used in the projection.
- """
- projected_inputs = self.space_projection(inputs)
- weights = self.get_weights(projected_inputs, targets)
-
- # take mean over channels for images
- channel_mean_fn = lambda: tf.reduce_mean(weights, axis=-1, keepdims=True)
- weights = tf.cond(
- pred=tf.shape(weights).shape[0] < 4,
- true_fn=lambda: weights,
- false_fn=channel_mean_fn,
- )
-
- # resizing
- resize_fn = lambda: tf.image.resize(
- weights, inputs.shape[1:-1], method="bicubic"
- )
- input_weights = tf.cond(
- pred=projected_inputs.shape == inputs.shape,
- true_fn=lambda: weights,
- false_fn=resize_fn,
- )
- return input_weights
+ mappable=mappable)
\ No newline at end of file
From 10438b0003495ff7eda3dfd11a388172c62d17fa Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Fri, 19 Jul 2024 16:07:16 +0200
Subject: [PATCH 074/138] example-based: remove weights from possible returns
---
tests/example_based/test_cole.py | 16 +-------
tests/example_based/test_similar_examples.py | 11 -----
xplique/example_based/base_example_method.py | 40 +++----------------
xplique/example_based/cole.py | 2 +-
xplique/example_based/contrastive_examples.py | 12 ++----
xplique/example_based/prototypes.py | 2 +-
xplique/example_based/similar_examples.py | 2 +-
7 files changed, 15 insertions(+), 70 deletions(-)
diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py
index 3864a71d..f94ea24a 100644
--- a/tests/example_based/test_cole.py
+++ b/tests/example_based/test_cole.py
@@ -120,13 +120,6 @@ def test_cole_attribution():
# a different distance should give different results
assert not almost_equal(examples_constructor, examples_different_distance)
- # check weights are equal to the attribution directly on the input
- method_constructor.returns = ["weights", "include_inputs"]
- assert almost_equal(
- method_constructor.explain(x_test, y_test)["weights"][:, 0],
- Saliency(model)(x_test, y_test),
- )
-
def test_cole_hadamard():
"""
@@ -205,7 +198,7 @@ def test_cole_splitting():
cases_dataset=x_train,
targets_dataset=y_train,
k=k,
- case_returns=["examples", "weights", "include_inputs"],
+ case_returns=["examples", "include_inputs"],
model=model,
latent_layer="last_conv",
attribution_method=Occlusion,
@@ -215,14 +208,9 @@ def test_cole_splitting():
# Generate explanation
outputs = method.explain(x_test, y_test)
- examples, weights = outputs["examples"], outputs["weights"]
+ examples = outputs["examples"]
# Verifications
# Shape should be (n, k, h, w, c)
nb_samples_test = x_test.shape[0]
assert examples.shape == (nb_samples_test, k + 1) + input_shape
- assert weights.shape[:-1] == (nb_samples_test, k + 1) + input_shape[:-1]
-
-
-# test_cole_attribution()
-# test_cole_splitting()
diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py
index db4af594..4580ed6d 100644
--- a/tests/example_based/test_similar_examples.py
+++ b/tests/example_based/test_similar_examples.py
@@ -205,13 +205,11 @@ def test_similar_examples_return_multiple_elements():
assert isinstance(method_output, dict)
examples = method_output["examples"]
- weights = method_output["weights"]
distances = method_output["distances"]
labels = method_output["labels"]
# test every outputs shape (with the include inputs)
assert examples.shape == (nb_samples_test, k + 1) + input_shape
- assert weights.shape == (nb_samples_test, k + 1) + input_shape
# the inputs distance ae zero and indices do not exist
assert distances.shape == (nb_samples_test, k)
assert labels.shape == (nb_samples_test, k)
@@ -227,9 +225,6 @@ def test_similar_examples_return_multiple_elements():
examples[i, 3], x_train[i + 2]
)
- # test weights
- assert almost_equal(weights[i], tf.ones(weights[i].shape, dtype=tf.float32))
-
# test distances
assert almost_equal(distances[i, 0], 0)
assert almost_equal(distances[i, 1], sqrt(prod(input_shape)))
@@ -294,9 +289,3 @@ def test_similar_examples_weighting():
assert almost_equal(examples[i, 2], x_train[i]) or almost_equal(
examples[i, 2], x_train[i + 2]
)
-
-
-# test_similar_examples_input_dataset_management()
-# test_similar_examples_basic()
-# test_similar_examples_return_multiple_elements()
-# test_similar_examples_weighting()
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 3fabe7c1..7e6e19e9 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -73,7 +73,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
"""
- _returns_possibilities = ["examples", "weights", "distances", "labels", "include_inputs"]
+ _returns_possibilities = ["examples", "distances", "labels", "include_inputs"]
def __init__(
self,
@@ -157,8 +157,6 @@ def returns(self, returns: Union[List[str], str]):
`returns` can be set to 'all' for all possible elements to be returned.
- 'examples' correspond to the expected examples,
the inputs may be included in first position. (n, k(+1), ...)
- - 'weights' the weights in the input space used in the projection.
- They are associated to the input and the examples. (n, k(+1), ...)
- 'distances' the distances between the inputs and the corresponding examples.
They are associated to the examples. (n, k, ...)
- 'labels' if provided through `dataset_labels`,
@@ -306,7 +304,7 @@ def explain(
search_output = self.search_method(projected_inputs, targets)
# manage returned elements
- return self.format_search_output(search_output, inputs, targets)
+ return self.format_search_output(search_output, inputs)
def __call__(
self,
@@ -320,7 +318,6 @@ def format_search_output(
self,
search_output: Dict[str, tf.Tensor],
inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
):
"""
Format the output of the `search_method` to match the expected returns in `self.returns`.
@@ -332,9 +329,9 @@ def format_search_output(
inputs
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
- targets
- Targets associated to the cases_dataset for dataset projection.
- See `projection` for details.
+ # targets
+ # Targets associated to the cases_dataset for dataset projection.
+ # See `projection` for details.
Returns
-------
@@ -348,40 +345,15 @@ def format_search_output(
# gather examples, labels, and targets from the example's indices of the search output
examples = dataset_gather(self.cases_dataset, search_output["indices"])
examples_labels = dataset_gather(self.labels_dataset, search_output["indices"])
- examples_targets = dataset_gather(
- self.targets_dataset, search_output["indices"]
- )
# add examples and weights
- if "examples" in self.returns or "weights" in self.returns:
+ if "examples" in self.returns: # or "weights" in self.returns:
if "include_inputs" in self.returns:
# include inputs
inputs = tf.expand_dims(inputs, axis=1)
examples = tf.concat([inputs, examples], axis=1)
- if targets is not None:
- targets = tf.expand_dims(targets, axis=1)
- examples_targets = tf.concat([targets, examples_targets], axis=1)
- else:
- examples_targets = [None] * len(examples)
if "examples" in self.returns:
return_dict["examples"] = examples
- if "weights" in self.returns:
- # get weights of examples (n, k, ...)
- # we iterate on the inputs dimension through maps
- # and ask weights for batch of examples
- weights = []
- for ex, ex_targ in zip(examples, examples_targets):
- if isinstance(self.projection, Projection):
- # get weights in the input space
- weights.append(self.projection.get_input_weights(ex, ex_targ))
- else:
- raise AttributeError(
- "Cannot extract weights from the provided projection function"
- + "Either remove 'weights' from the `case_returns` or"
- + "inherit from `Projection` and overwrite `get_input_weights`."
- )
-
- return_dict["weights"] = tf.stack(weights, axis=0)
# add indices, distances, and labels
if "indices" in self.returns:
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
index 47b21dbe..00955452 100644
--- a/xplique/example_based/cole.py
+++ b/xplique/example_based/cole.py
@@ -46,7 +46,7 @@ class Cole(SimilarExamples):
The number of examples to retrieve per input.
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
case_returns
String or list of string with the elements to return in `self.explain()`.
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/contrastive_examples.py
index 31afc794..b18302b3 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/contrastive_examples.py
@@ -67,7 +67,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
def __init__(
@@ -183,7 +183,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
distance
Distance for the FilterKNN search method.
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
def __init__(
@@ -345,7 +345,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
distance
Distance for the FilterKNN search method.
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
_returns_possibilities = [
@@ -422,7 +422,6 @@ def format_search_output(
self,
search_output: Dict[str, tf.Tensor],
inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
):
"""
Format the output of the `search_method` to match the expected returns in `self.returns`.
@@ -434,9 +433,6 @@ def format_search_output(
inputs
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
- targets
- Targets associated to the cases_dataset for dataset projection.
- See `projection` for details.
Returns
-------
@@ -444,7 +440,7 @@ def format_search_output(
Dictionary with listed elements in `self.returns`.
The elements that can be returned are defined with _returns_possibilities static attribute of the class.
"""
- return_dict = super().format_search_output(search_output, inputs, targets)
+ return_dict = super().format_search_output(search_output, inputs)
if "nuns" in self.returns:
return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"])
if "nuns_labels" in self.returns:
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index 5f4017d4..c1857b48 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -68,7 +68,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Ignored if `tf.data.Dataset` are provided (these are supposed to be batched).
distance
Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
nb_prototypes : int
For general explanations, the number of prototypes to select.
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index 5c785322..b1433370 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -63,7 +63,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
distance
Distance for the knn search method. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev"}, or a Callable,
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
by default "euclidean".
"""
def __init__(
From 797f76e50a35448002bcce21f4f18fb45ebc9e6f Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 22 Jul 2024 15:54:06 +0200
Subject: [PATCH 075/138] example-based projections: put target free operator
as common
---
.../example_based/projections/attributions.py | 8 +++-
xplique/example_based/projections/commons.py | 39 +++++++++++++++++-
xplique/example_based/projections/hadamard.py | 41 +------------------
3 files changed, 47 insertions(+), 41 deletions(-)
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index 8d298afe..ef3b0ce8 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -12,7 +12,7 @@
from ...types import Callable, Union, Optional
from .base import Projection
-from .commons import model_splitting
+from .commons import model_splitting, target_free_classification_operator
class AttributionProjection(Projection):
@@ -71,6 +71,12 @@ def __init__(
else:
# split the model if a latent_layer is provided
space_projection, self.predictor = model_splitting(model, latent_layer)
+
+ # change default operator
+ if not "operator" in attribution_kwargs or attribution_kwargs["operator"] is None:
+ warnings.warn("No operator provided, using standard classification operator. "\
+ + "For non-classification tasks, please specify an operator.")
+ attribution_kwargs["operator"] = target_free_classification_operator
# compute attributions
get_weights = self.method(self.predictor, **attribution_kwargs)
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
index ee9091c2..45110310 100644
--- a/xplique/example_based/projections/commons.py
+++ b/xplique/example_based/projections/commons.py
@@ -199,4 +199,41 @@ def second_model_forward(x):
second_model.eval()
wrapped_second_model = TorchWrapper(second_model, device=device)
- return wrapped_first_model, wrapped_second_model
\ No newline at end of file
+ return wrapped_first_model, wrapped_second_model
+
+
+def target_free_classification_operator(model: Callable,
+ inputs: tf.Tensor,
+ targets: Optional[tf.Tensor] = None) -> tf.Tensor: # TODO: add tests (now used by both attribution and hadamard projections)
+ """
+ Compute predictions scores, only for the label class, for a batch of samples.
+ It has the same behavior as `Tasks.CLASSIFICATION` operator
+ but computes targets at the same time if not provided.
+ Targets are a mask with 1 on the predicted class and 0 elsewhere.
+ This operator should only be used for classification tasks.
+
+
+ Parameters
+ ----------
+ model
+ Model used for computing predictions.
+ inputs
+ Input samples to be explained.
+ targets
+ One-hot encoded labels, one for each sample. If None, they are deduced from the model predictions.
+
+ Returns
+ -------
+ scores
+ Predictions scores computed, only for the label class.
+ """
+ predictions = model(inputs)
+
+ targets = tf.cond(
+ pred=tf.constant(targets is None, dtype=tf.bool),
+ true_fn=lambda: tf.one_hot(tf.argmax(predictions, axis=-1), predictions.shape[-1]),
+ false_fn=lambda: targets,
+ )
+
+ scores = tf.reduce_sum(predictions * targets, axis=-1)
+ return scores
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index 97dc7acc..05fb77e3 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -11,44 +11,7 @@
from ...types import Callable, Union, Optional, OperatorSignature
from .base import Projection
-from .commons import model_splitting
-
-
-def _target_free_classification_operator(model: Callable,
- inputs: tf.Tensor,
- targets: Optional[tf.Tensor]) -> tf.Tensor: # TODO: test, and use in attribution projection
- """
- Compute predictions scores, only for the label class, for a batch of samples.
- It has the same behavior as `Tasks.CLASSIFICATION` operator
- but computes targets at the same time if not provided.
- Targets are a mask with 1 on the predicted class and 0 elsewhere.
- This operator should only be used for classification tasks.
-
-
- Parameters
- ----------
- model
- Model used for computing predictions.
- inputs
- Input samples to be explained.
- targets
- One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
-
- Returns
- -------
- scores
- Predictions scores computed, only for the label class.
- """
- predictions = model(inputs)
-
- targets = tf.cond(
- pred=tf.constant(targets is None, dtype=tf.bool),
- true_fn=lambda: tf.one_hot(tf.argmax(predictions, axis=-1), predictions.shape[-1]),
- false_fn=lambda: targets,
- )
-
- scores = tf.reduce_sum(predictions * targets, axis=-1)
- return scores
+from .commons import model_splitting, target_free_classification_operator
class HadamardProjection(Projection):
@@ -110,7 +73,7 @@ def __init__(
if operator is None:
warnings.warn("No operator provided, using standard classification operator."\
+ "For non-classification tasks, please specify an operator.")
- operator = _target_free_classification_operator
+ operator = target_free_classification_operator
# the weights are given by the gradient of the operator based on the predictor
gradients, _ = get_gradient_functions(self.predictor, operator)
From d2653300ef5dc54197aeb7789f724f9548705704 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 22 Jul 2024 15:54:57 +0200
Subject: [PATCH 076/138] example-based: split counterfactuals and semifactuals
---
xplique/example_based/__init__.py | 6 +-
xplique/example_based/cole.py | 125 ----------
...rastive_examples.py => counterfactuals.py} | 230 ++----------------
xplique/example_based/semifactuals.py | 220 +++++++++++++++++
xplique/example_based/similar_examples.py | 117 ++++++++-
5 files changed, 355 insertions(+), 343 deletions(-)
delete mode 100644 xplique/example_based/cole.py
rename xplique/example_based/{contrastive_examples.py => counterfactuals.py} (54%)
create mode 100644 xplique/example_based/semifactuals.py
diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py
index 3de46d18..fa83c1ba 100644
--- a/xplique/example_based/__init__.py
+++ b/xplique/example_based/__init__.py
@@ -2,7 +2,7 @@
Example-based methods available
"""
-from .cole import Cole
-from .similar_examples import SimilarExamples
+from .similar_examples import SimilarExamples, Cole
from .prototypes import Prototypes, ProtoGreedy, ProtoDash, MMDCritic
-from .contrastive_examples import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss
+from .counterfactuals import NaiveCounterFactuals, LabelAwareCounterFactuals
+from .semifactuals import KLEORGlobalSim, KLEORSimMiss
diff --git a/xplique/example_based/cole.py b/xplique/example_based/cole.py
deleted file mode 100644
index 00955452..00000000
--- a/xplique/example_based/cole.py
+++ /dev/null
@@ -1,125 +0,0 @@
-"""
-Implementation of Cole method a simlilar examples method from example based module
-"""
-import numpy as np
-import tensorflow as tf
-
-from ..attributions.base import BlackBoxExplainer
-from ..types import Callable, List, Optional, Union, Type
-
-from .similar_examples import SimilarExamples
-from .projections import AttributionProjection, HadamardProjection
-
-
-class Cole(SimilarExamples):
- """
- Cole is a similar examples method that gives the most similar examples to a query in some specific projection space.
- Cole use the model (to be explained) to build a search space so that distances are meaningful for the model.
- It uses attribution methods to weight inputs.
- Those attributions may be computed in the latent space for high-dimensional data like images.
-
- It is an implementation of a method proposed by Kenny et Keane in 2019,
- Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
- https://researchrepository.ucd.ie/handle/10197/11064
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from this dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other datasets should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
- predictions. See `projection` for detail.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other datasets should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve per input.
- distance
- Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
- by default "euclidean".
- case_returns
- String or list of string with the elements to return in `self.explain()`.
- See the base class returns property for details.
- batch_size
- Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- latent_layer
- Layer used to split the model, the first part will be used for projection and
- the second to compute the attributions. By default, the model is not split.
- For such split, the `model` should be a `tf.keras.Model`.
-
- Layer to target for the outputs (e.g logits or after softmax).
- If an `int` is provided it will be interpreted as a layer index.
- If a `string` is provided it will look for the layer name.
-
- The method as described in the paper apply the separation on the last convolutional layer.
- To do so, the `"last_conv"` parameter will extract it.
- Otherwise, `-1` could be used for the last layer before softmax.
- attribution_method
- Class of the attribution method to use for projection.
- It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
- By default, it computes the gradient to make the Hadamard product in the latent space.
- attribution_kwargs
- Parameters to be passed for the construction of the `attribution_method`.
- """
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- model: tf.keras.Model,
- targets_dataset: Union[tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- distance: Union[str, Callable] = "euclidean",
- case_returns: Optional[Union[List[str], str]] = "examples",
- batch_size: Optional[int] = 32,
- latent_layer: Optional[Union[str, int]] = None,
- attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient",
- **attribution_kwargs,
- ):
- assert targets_dataset is not None
-
- # build the corresponding projection
- if isinstance(attribution_method, str) and attribution_method.lower() == "gradient":
-
- operator = attribution_kwargs.get("operator", None)
-
- projection = HadamardProjection(
- model=model,
- latent_layer=latent_layer,
- operator=operator,
- )
- elif issubclass(attribution_method, BlackBoxExplainer):
- # build attribution projection
- projection = AttributionProjection(
- model=model,
- method=attribution_method,
- latent_layer=latent_layer,
- **attribution_kwargs,
- )
- else:
- raise ValueError(
- f"attribution_method should be 'gradient' or a subclass of BlackBoxExplainer," +\
- "not {attribution_method}"
- )
-
- super().__init__(
- cases_dataset=cases_dataset,
- targets_dataset=targets_dataset,
- labels_dataset=labels_dataset,
- projection=projection,
- k=k,
- case_returns=case_returns,
- batch_size=batch_size,
- distance=distance,
- )
diff --git a/xplique/example_based/contrastive_examples.py b/xplique/example_based/counterfactuals.py
similarity index 54%
rename from xplique/example_based/contrastive_examples.py
rename to xplique/example_based/counterfactuals.py
index b18302b3..360fdcda 100644
--- a/xplique/example_based/contrastive_examples.py
+++ b/xplique/example_based/counterfactuals.py
@@ -6,14 +6,13 @@
import numpy as np
import tensorflow as tf
-from ..types import Callable, List, Optional, Union, Dict
-from ..commons import sanitize_inputs_targets, dataset_gather
+from ..types import Callable, List, Optional, Union
+from ..commons import sanitize_inputs_targets
from .base_example_method import BaseExampleMethod
-from .search_methods import ORDER, FilterKNN, KLEORSimMissSearch, KLEORGlobalSimSearch
+from .search_methods import ORDER, FilterKNN
from .projections import Projection
-from .search_methods.base import _sanitize_returns
class NaiveCounterFactuals(BaseExampleMethod):
"""
@@ -131,6 +130,7 @@ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs)
return mask
+
class LabelAwareCounterFactuals(BaseExampleMethod):
"""
This method will search the counterfactuals of a query within an expected class. This class should be provided with
@@ -165,11 +165,10 @@ class LabelAwareCounterFactuals(BaseExampleMethod):
Example of Callable:
```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ def custom_projection(inputs: tf.Tensor, np.ndarray):
'''
Example of projection,
inputs are the elements to project.
- targets are optional parameters to orientated the projection.
'''
projected_inputs = # do some magic on inputs, it should use the model.
return projected_inputs
@@ -239,19 +238,19 @@ def search_method_class(self):
return FilterKNN
- def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
+ def filter_fn(self, _, __, cf_expected_classes, cases_targets) -> tf.Tensor:
"""
Filter function to mask the cases for which the target is different from the target(s) expected for the
counterfactuals.
Parameters
----------
- cf_targets
+ cf_expected_classes
The one-hot encoding of the target class for the counterfactuals.
cases_targets
The one-hot encoding of the target class for the cases.
"""
- mask = tf.matmul(cf_targets, cases_targets, transpose_b=True) #(n, bs)
+ mask = tf.matmul(cf_expected_classes, cases_targets, transpose_b=True) #(n, bs)
# TODO: I think some retracing are done here
mask = tf.cast(mask, dtype=tf.bool)
return mask
@@ -260,7 +259,7 @@ def filter_fn(self, _, __, cf_targets, cases_targets) -> tf.Tensor:
def explain(
self,
inputs: Union[tf.Tensor, np.ndarray],
- cf_targets: Union[tf.Tensor, np.ndarray],
+ cf_expected_classes: Union[tf.Tensor, np.ndarray],
):
"""
Return the relevant CF examples to explain the inputs.
@@ -273,7 +272,7 @@ def explain(
Tensor or Array. Input samples to be explained.
Expected shape among (N, W), (N, T, W), (N, W, H, C).
More information in the documentation.
- cf_targets
+ cf_expected_classes
Tensor or Array. One-hot encoding of the target class for the counterfactuals.
Returns
@@ -282,208 +281,11 @@ def explain(
Dictionary with listed elements in `self.returns`.
The elements that can be returned are defined with _returns_possibilities static attribute of the class.
"""
- return super().explain(inputs, cf_targets)
-
-
-class KLEORBase(BaseExampleMethod):
- """
- Base class for KLEOR methods. KLEOR methods search Semi-Factuals examples. In those methods, one should first
- retrieve the Nearest Unlike Neighbor (NUN) which is the closest example to the query that has a different prediction
- than the query. Then, the method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction
- as the query.
-
- All the searches are done in a projection space where distances are relevant for the model. The projection space is
- defined by the `projection` method.
-
- Depending on the KLEOR method some additional condition for the search are added. See the specific KLEOR method for
- more details.
-
- Parameters
- ----------
- cases_dataset
- The dataset used to train the model, examples are extracted from this dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- targets_dataset
- Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other datasets should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- labels_dataset
- Labels associated to the examples in the dataset. Indices should match with cases_dataset.
- `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
- Batch size and cardinality of other datasets should match `cases_dataset`.
- Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
- the case for your dataset, otherwise, examples will not make sense.
- k
- The number of examples to retrieve per input.
- projection
- Projection or Callable that project samples from the input space to the search space.
- The search space should be a space where distances are relevant for the model.
- It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
- searching the input space, you should use a `BaseSearchMethod` instead.
+ # project inputs into the search space
+ projected_inputs = self.projection(inputs)
- Example of Callable:
- ```
- def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
- '''
- Example of projection,
- inputs are the elements to project.
- targets are optional parameters to orientated the projection.
- '''
- projected_inputs = # do some magic on inputs, it should use the model.
- return projected_inputs
- ```
- case_returns
- String or list of string with the elements to return in `self.explain()`.
- See the base class returns property for more details.
- batch_size
- Number of sample treated simultaneously for projection and search.
- Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
- distance
- Distance for the FilterKNN search method.
- Distance function for examples search. It can be an integer, a string in
- {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
- by default "euclidean".
- """
- _returns_possibilities = [
- "examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
- ]
+ # look for relevant elements in the search space
+ search_output = self.search_method(projected_inputs, cf_expected_classes)
- def __init__(
- self,
- cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
- labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
- k: int = 1,
- projection: Union[Projection, Callable] = None,
- case_returns: Union[List[str], str] = "examples",
- batch_size: Optional[int] = 32,
- distance: Union[int, str, Callable] = "euclidean",
- ):
-
- super().__init__(
- cases_dataset=cases_dataset,
- labels_dataset=labels_dataset,
- targets_dataset=targets_dataset,
- k=k,
- projection=projection,
- case_returns=case_returns,
- batch_size=batch_size,
- )
-
- # set distance function and order for the search method
- self.distance = distance
- self.order = ORDER.ASCENDING
-
- # initiate search_method
- self.search_method = self.search_method_class(
- cases_dataset=self.projected_cases_dataset,
- targets_dataset=self.targets_dataset,
- k=self.k,
- search_returns=self._search_returns,
- batch_size=self.batch_size,
- distance=self.distance,
- )
-
- @property
- def returns(self) -> Union[List[str], str]:
- """Override the Base class returns' parameter."""
- return self._returns
-
- @returns.setter
- def returns(self, returns: Union[List[str], str]):
- """
- Set the returns parameter. The returns parameter is a string or a list of string with the elements to return
- in `self.explain()`. The elements that can be returned are defined with _returns_possibilities static attribute
- """
- default = "examples"
- self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
- self._search_returns = ["indices", "distances"]
-
- if isinstance(self._returns, list) and ("nuns" in self._returns):
- self._search_returns.append("nuns_indices")
- elif isinstance(self._returns, list) and ("nuns_indices" in self._returns):
- self._search_returns.append("nuns_indices")
- elif isinstance(self._returns, list) and ("nuns_labels" in self._returns):
- self._search_returns.append("nuns_indices")
-
- if isinstance(self._returns, list) and ("dist_to_nuns" in self._returns):
- self._search_returns.append("dist_to_nuns")
-
- try:
- self.search_method.returns = self._search_returns
- except AttributeError:
- pass
-
- def format_search_output(
- self,
- search_output: Dict[str, tf.Tensor],
- inputs: Union[tf.Tensor, np.ndarray],
- ):
- """
- Format the output of the `search_method` to match the expected returns in `self.returns`.
-
- Parameters
- ----------
- search_output
- Dictionary with the required outputs from the `search_method`.
- inputs
- Tensor or Array. Input samples to be explained.
- Expected shape among (N, W), (N, T, W), (N, W, H, C).
-
- Returns
- -------
- return_dict
- Dictionary with listed elements in `self.returns`.
- The elements that can be returned are defined with _returns_possibilities static attribute of the class.
- """
- return_dict = super().format_search_output(search_output, inputs)
- if "nuns" in self.returns:
- return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"])
- if "nuns_labels" in self.returns:
- return_dict["nuns_labels"] = dataset_gather(self.labels_dataset, search_output["nuns_indices"])
- if "nuns_indices" in self.returns:
- return_dict["nuns_indices"] = search_output["nuns_indices"]
- if "dist_to_nuns" in self.returns:
- return_dict["dist_to_nuns"] = search_output["dist_to_nuns"]
- return return_dict
-
-
-class KLEORSimMiss(KLEORBase):
- """
- The KLEORSimMiss method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
- the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
- method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
-
- The search is done in a projection space where distances are relevant for the model. The projection space is defined
- by the `projection` method.
- """
- @property
- def search_method_class(self):
- """
- This property defines the search method class to use for the search. In this case, it is the KLEORSimMissSearch.
- """
- return KLEORSimMissSearch
-
-class KLEORGlobalSim(KLEORBase):
- """
- The KLEORGlobalSim method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
- the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
- method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
-
- In addition, for a SF candidate to be considered, the SF should be closer to the query than the NUN in the
- projection space (i.e. the SF should be 'between' the input and its NUN). This condition is added to the search.
-
- The search is done in a projection space where distances are relevant for the model. The projection space is defined
- by the `projection` method.
- """
- @property
- def search_method_class(self):
- """
- This property defines the search method class to use for the search. In this case, it is the
- KLEORGlobalSimSearch.
- """
- return KLEORGlobalSimSearch
+ # manage returned elements
+ return self.format_search_output(search_output, inputs)
diff --git a/xplique/example_based/semifactuals.py b/xplique/example_based/semifactuals.py
new file mode 100644
index 00000000..b912eb1e
--- /dev/null
+++ b/xplique/example_based/semifactuals.py
@@ -0,0 +1,220 @@
+"""
+Implementation of semi factuals methods for classification tasks.
+"""
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from ..types import Callable, List, Optional, Union, Dict
+from ..commons import dataset_gather
+
+from .base_example_method import BaseExampleMethod
+from .search_methods import ORDER, KLEORSimMissSearch, KLEORGlobalSimSearch
+from .projections import Projection
+
+from .search_methods.base import _sanitize_returns
+
+
+class KLEORBase(BaseExampleMethod):
+ """
+ Base class for KLEOR methods. KLEOR methods search Semi-Factuals examples. In those methods, one should first
+ retrieve the Nearest Unlike Neighbor (NUN) which is the closest example to the query that has a different prediction
+ than the query. Then, the method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction
+ as the query.
+
+ All the searches are done in a projection space where distances are relevant for the model. The projection space is
+ defined by the `projection` method.
+
+ Depending on the KLEOR method some additional condition for the search are added. See the specific KLEOR method for
+ more details.
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from this dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ projection
+ Projection or Callable that project samples from the input space to the search space.
+ The search space should be a space where distances are relevant for the model.
+ It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
+ searching the input space, you should use a `BaseSearchMethod` instead.
+
+ Example of Callable:
+ ```
+ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orient the projection.
+ '''
+ projected_inputs = # do some magic on inputs, it should use the model.
+ return projected_inputs
+ ```
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See the base class returns property for more details.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ distance
+ Distance for the FilterKNN search method.
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
+ by default "euclidean".
+ """
+ _returns_possibilities = [
+ "examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
+ ]
+
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ projection: Union[Projection, Callable] = None,
+ case_returns: Union[List[str], str] = "examples",
+ batch_size: Optional[int] = 32,
+ distance: Union[int, str, Callable] = "euclidean",
+ ):
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ labels_dataset=labels_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ projection=projection,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ )
+
+ # set distance function and order for the search method
+ self.distance = distance
+ self.order = ORDER.ASCENDING
+
+ # initiate search_method
+ self.search_method = self.search_method_class(
+ cases_dataset=self.projected_cases_dataset,
+ targets_dataset=self.targets_dataset,
+ k=self.k,
+ search_returns=self._search_returns,
+ batch_size=self.batch_size,
+ distance=self.distance,
+ )
+
+ @property
+ def returns(self) -> Union[List[str], str]:
+ """Override the Base class returns' parameter."""
+ return self._returns
+
+ @returns.setter
+ def returns(self, returns: Union[List[str], str]):
+ """
+ Set the returns parameter. The returns parameter is a string or a list of string with the elements to return
+ in `self.explain()`. The elements that can be returned are defined with _returns_possibilities static attribute
+ """
+ default = "examples"
+ self._returns = _sanitize_returns(returns, self._returns_possibilities, default)
+ self._search_returns = ["indices", "distances"]
+
+ if isinstance(self._returns, list) and ("nuns" in self._returns):
+ self._search_returns.append("nuns_indices")
+ elif isinstance(self._returns, list) and ("nuns_indices" in self._returns):
+ self._search_returns.append("nuns_indices")
+ elif isinstance(self._returns, list) and ("nuns_labels" in self._returns):
+ self._search_returns.append("nuns_indices")
+
+ if isinstance(self._returns, list) and ("dist_to_nuns" in self._returns):
+ self._search_returns.append("dist_to_nuns")
+
+ try:
+ self.search_method.returns = self._search_returns
+ except AttributeError:
+ pass
+
+ def format_search_output(
+ self,
+ search_output: Dict[str, tf.Tensor],
+ inputs: Union[tf.Tensor, np.ndarray],
+ ):
+ """
+ Format the output of the `search_method` to match the expected returns in `self.returns`.
+
+ Parameters
+ ----------
+ search_output
+ Dictionary with the required outputs from the `search_method`.
+ inputs
+ Tensor or Array. Input samples to be explained.
+ Expected shape among (N, W), (N, T, W), (N, W, H, C).
+
+ Returns
+ -------
+ return_dict
+ Dictionary with listed elements in `self.returns`.
+ The elements that can be returned are defined with _returns_possibilities static attribute of the class.
+ """
+ return_dict = super().format_search_output(search_output, inputs)
+ if "nuns" in self.returns:
+ return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"])
+ if "nuns_labels" in self.returns:
+ return_dict["nuns_labels"] = dataset_gather(self.labels_dataset, search_output["nuns_indices"])
+ if "nuns_indices" in self.returns:
+ return_dict["nuns_indices"] = search_output["nuns_indices"]
+ if "dist_to_nuns" in self.returns:
+ return_dict["dist_to_nuns"] = search_output["dist_to_nuns"]
+ return return_dict
+
+
+class KLEORSimMiss(KLEORBase):
+ """
+ The KLEORSimMiss method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
+
+ The search is done in a projection space where distances are relevant for the model. The projection space is defined
+ by the `projection` method.
+ """
+ @property
+ def search_method_class(self):
+ """
+ This property defines the search method class to use for the search. In this case, it is the KLEORSimMissSearch.
+ """
+ return KLEORSimMissSearch
+
+class KLEORGlobalSim(KLEORBase):
+ """
+ The KLEORGlobalSim method search for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of
+ the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the
+ method search for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
+
+ In addition, for a SF candidate to be considered, the SF should be closer to the query than the NUN in the
+ projection space (i.e. the SF should be 'between' the input and its NUN). This condition is added to the search.
+
+ The search is done in a projection space where distances are relevant for the model. The projection space is defined
+ by the `projection` method.
+ """
+ @property
+ def search_method_class(self):
+ """
+ This property defines the search method class to use for the search. In this case, it is the
+ KLEORGlobalSimSearch.
+ """
+ return KLEORGlobalSimSearch
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index b1433370..4c598fd8 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -4,10 +4,11 @@
import tensorflow as tf
import numpy as np
+from ..attributions.base import BlackBoxExplainer
from ..types import Callable, List, Optional, Type, Union
from .search_methods import KNN, BaseSearchMethod, ORDER
-from .projections import Projection
+from .projections import Projection, AttributionProjection, HadamardProjection
from .base_example_method import BaseExampleMethod
@@ -103,3 +104,117 @@ def __init__(
@property
def search_method_class(self) -> Type[BaseSearchMethod]:
return KNN
+
+
+class Cole(SimilarExamples):
+ """
+ Cole is a similar examples method that gives the most similar examples to a query in some specific projection space.
+ Cole use the model (to be explained) to build a search space so that distances are meaningful for the model.
+ It uses attribution methods to weight inputs.
+ Those attributions may be computed in the latent space for high-dimensional data like images.
+
+ It is an implementation of a method proposed by Kenny et Keane in 2019,
+ Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
+ https://researchrepository.ucd.ie/handle/10197/11064
+
+ Parameters
+ ----------
+ cases_dataset
+ The dataset used to train the model, examples are extracted from this dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ labels_dataset
+ Labels associated to the examples in the dataset. Indices should match with cases_dataset.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
+ predictions. See `projection` for detail.
+ `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
+ k
+ The number of examples to retrieve per input.
+ distance
+ Distance function for examples search. It can be an integer, a string in
+ {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable,
+ by default "euclidean".
+ case_returns
+ String or list of string with the elements to return in `self.explain()`.
+ See the base class returns property for details.
+ batch_size
+ Number of sample treated simultaneously for projection and search.
+ Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
+ latent_layer
+ Layer used to split the model, the first part will be used for projection and
+ the second to compute the attributions. By default, the model is not split.
+ For such split, the `model` should be a `tf.keras.Model`.
+
+ Layer to target for the outputs (e.g logits or after softmax).
+ If an `int` is provided it will be interpreted as a layer index.
+ If a `string` is provided it will look for the layer name.
+
+ The method as described in the paper apply the separation on the last convolutional layer.
+ To do so, the `"last_conv"` parameter will extract it.
+ Otherwise, `-1` could be used for the last layer before softmax.
+ attribution_method
+ Class of the attribution method to use for projection.
+ It should inherit from `xplique.attributions.base.BlackBoxExplainer`.
+ By default, it computes the gradient to make the Hadamard product in the latent space.
+ attribution_kwargs
+ Parameters to be passed for the construction of the `attribution_method`.
+ """
+ def __init__(
+ self,
+ cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
+ model: tf.keras.Model,
+ targets_dataset: Union[tf.Tensor, np.ndarray],
+ labels_dataset: Optional[Union[tf.Tensor, np.ndarray]] = None,
+ k: int = 1,
+ distance: Union[str, Callable] = "euclidean",
+ case_returns: Optional[Union[List[str], str]] = "examples",
+ batch_size: Optional[int] = 32,
+ latent_layer: Optional[Union[str, int]] = None,
+ attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient",
+ **attribution_kwargs,
+ ):
+ assert targets_dataset is not None
+
+ # build the corresponding projection
+ if isinstance(attribution_method, str) and attribution_method.lower() == "gradient":
+
+ operator = attribution_kwargs.get("operator", None)
+
+ projection = HadamardProjection(
+ model=model,
+ latent_layer=latent_layer,
+ operator=operator,
+ )
+ elif issubclass(attribution_method, BlackBoxExplainer):
+ # build attribution projection
+ projection = AttributionProjection(
+ model=model,
+ method=attribution_method,
+ latent_layer=latent_layer,
+ **attribution_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"attribution_method should be 'gradient' or a subclass of BlackBoxExplainer," +\
+ "not {attribution_method}"
+ )
+
+ super().__init__(
+ cases_dataset=cases_dataset,
+ targets_dataset=targets_dataset,
+ labels_dataset=labels_dataset,
+ projection=projection,
+ k=k,
+ case_returns=case_returns,
+ batch_size=batch_size,
+ distance=distance,
+ )
From 56baf6b315c45f24541eb3cd5a1164d7d4d8c62c Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Wed, 24 Jul 2024 18:17:56 +0200
Subject: [PATCH 077/138] add the documentation for the prototypes search
methods fix up
---
.../prototypes/api_prototypes.md | 59 ++++++++++++++-----
...{search_method_md.md => search_methods.md} | 0
2 files changed, 43 insertions(+), 16 deletions(-)
rename docs/api/example_based/{search_method_md.md => search_methods.md} (100%)
diff --git a/docs/api/example_based/prototypes/api_prototypes.md b/docs/api/example_based/prototypes/api_prototypes.md
index dc23bc9b..c02132cd 100644
--- a/docs/api/example_based/prototypes/api_prototypes.md
+++ b/docs/api/example_based/prototypes/api_prototypes.md
@@ -1,28 +1,44 @@
# Prototypes
-Prototype-based explanation is a family of natural example-based XAI methods. Prototypes consist of a set of samples that are representative of either the dataset or a class. Three classes of prototype-based methods are found in the literature ([Poché et al., 2023](https://hal.science/hal-04117520/document)): [Prototypes for Data-Centric Interpretability](#prototypes-for-data-centric-interpretability), [Prototypes for Post-hoc Interpretability](#prototypes-for-post-hoc-interpretability) and Prototype-Based Models Interpretable by Design. This library focuses on first two classes.
+Prototype-based explanation is a family of natural example-based XAI methods. Prototypes consist of a set of samples that are representative of either the dataset or a class. Three classes of prototype-based methods are found in the literature ([Poché et al., 2023](https://hal.science/hal-04117520/document)):
+
+- [Prototypes for Data-Centric Interpretability](#prototypes-for-data-centric-interpretability)
+- [Prototypes for Post-hoc Interpretability](#prototypes-for-post-hoc-interpretability)
+- Prototype-Based Models Interpretable by Design
+
+This library focuses on first two classes.
## Prototypes for Data-Centric Interpretability
In this class, prototypes are selected without relying on the model and provide an overview of
-the dataset. As mentioned in ([Poché et al., 2023](https://hal.science/hal-04117520/document)), we found `clustering methods`, `set cover methods` and `data summarization methods`. This library focuses on `data summarization methods`, also known as `set cover problem methods`, which can be treated in two ways [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf):
+the dataset. As mentioned in ([Poché et al., 2023](https://hal.science/hal-04117520/document)), we found in this class: **clustering methods** and **data summarization methods**, also known as **set cover methods**. This library focuses on **data summarization methods** which can be treated in two ways [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf):
-- **Summarization with knapsack constraint**:
+- **Data summarization with knapsack constraint**:
consists in finding a subset of prototypes $\mathcal{P}$ that maximizes the coverage set function $F(\mathcal{P})$ under the constraint that its selection cost $C(\mathcal{P})$ (e.g., the number of selected prototypes $|\mathcal{P}|$) should be less than a given budget.
-- **Summarization with covering constraint**:
-consists in finding a low-cost subset under the constraint it should cover all the data. For both cases, submodularity and monotonicity of $F(\mathcal{P})$ are necessary to guarantee that a greedy algorithm has a constant factor guarantee of optimality [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). In addition, $F(\mathcal{P})$ should encourage coverage and penalize redundancy in order to have a good summary [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf).
+- **Data summarization with covering constraint**:
+consists in finding a low-cost subset of prototypes $\mathcal{P}$ under the constraint it should cover all the data.
-This library implements three methods from **Summarization with knapsack constraint**: `MMDCritic`, `ProtoGreedy` and `ProtoDash`.
-[Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) proposed `MMDCritic` method that used a set function based on the Maximum Mean Discrepancy [(MMD)](#what-is-mmd). They added additional diagonal dominance conditions on the kernel to ensure monotonocity and submodularity. They solve summarization with knapsack constraint problem to find both prototypes and criticisms. First, the number of prototypes and criticisms to be found, respectively as $m_p$ and $m_c$, are selected. Second, to find prototypes, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as:
+For both cases, submodularity and monotonicity of $F(\mathcal{P})$ are necessary to guarantee that a greedy algorithm has a constant factor guarantee of optimality [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). In addition, $F(\mathcal{P})$ should encourage coverage and penalize redundancy in order to have a good summary [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf).
+
+This library implements three methods from **Data summarization with knapsack constraint**: `MMDCritic`, `ProtoGreedy` and `ProtoDash`.
+[Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) proposed `MMDCritic` method that used a set function based on the Maximum Mean Discrepancy [(MMD)](#what-is-mmd). They solved **data summarization with knapsack constraint** problem to find both prototypes and criticisms. First, the number of prototypes and criticisms to be found, respectively as $m_p$ and $m_c$, are selected. Second, to find prototypes, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as:
\begin{equation}
F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i,j=1}^{|\mathcal{P}|,n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j)
\end{equation}
-Finally, to find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$.
+They used diagonal dominance conditions on the kernel to ensure monotonicity and submodularity of $F(\mathcal{P})$. To find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$.
-[Gurumoorthy et al., 2019](https://arxiv.org/pdf/1707.01212) associated non-negative weights to prototypes which are indicative of their importance. In this way, both prototypes and criticisms (which are the least weighted examples from prototypes) can be found by maximizing the same set function $F(\mathcal{P})$. They established the weak submodular property of $J(\mathcal{P})$ and present tractable algorithms (`ProtoGreedy` and `ProtoDash`) to optimize it. Their method works for any symmetric positive definite kernel which is not the case for `MMDCritic`. First, they define a weighted objective $F(\mathcal{P},w)$:
+[Gurumoorthy et al., 2019](https://arxiv.org/pdf/1707.01212) associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective $F(\mathcal{P},w)$ defined as:
\begin{equation}
F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
\end{equation}
-where $w$ are non-negative weights for each prototype. Then, they find $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. $J(\mathcal{P})$ can be maximized either by `ProtoGreedy` or by `ProtoDash`. `ProtoGreedy` selects the next element that maximizes the increment of the scoring function while `Protodash` selects the next element that maximizes the gradient of $F(\mathcal{P},w)$ with respect to $w$. `ProtoDash` is much faster than `ProtoGreedy` without compromising on the quality of the solution (the complexity of `ProtoGreedy` is $O(n(n+m^4))$ comparing to $O(n(n+m^2)+m^4)$ for `ProtoDash`). The difference between `ProtoGreedy` and the greedy algorithm of `MMDCritic` is that `ProtoGreedy` additionally determines the weights for each of the selected prototypes. The approximation guarantee is $(1-e^{-\gamma})$ for `ProtoGreedy`, where $\gamma$ is submodularity ratio of $F(\mathcal{P})$, comparing to $(1-e^{-1})$ for `MMDCritic`.
+where $w$ are non-negative weights for each prototype. The problem then consists of finding $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. They established the weak submodular property of $J(\mathcal{P})$ and presented tractable algorithms (`ProtoGreedy` and `ProtoDash`) to optimize it.
+
+### Method comparison
+
+- Compared to `MMDCritic`, both `ProtoGreedy` and `Protodash` additionally determine the weights for each of the selected prototypes.
+- `ProtoGreedy` and `Protodash` works for any symmetric positive definite kernel which is not the case for `MMDCritic`.
+- `MMDCritic` and `ProtoGreedy` select the next element that maximizes the increment of the scoring function while `Protodash` maximizes a tight lower bound on the increment of the scoring function (it maximizes the gradient of $F(\mathcal{P},w)$).
+- `ProtoDash` is much faster than `ProtoGreedy` without compromising on the quality of the solution (the complexity of `ProtoGreedy` is $O(n(n+m^4))$ comparing to $O(n(n+m^2)+m^4)$ for `ProtoDash`).
+- The approximation guarantee for `ProtoGreedy` is $(1-e^{-\gamma})$, where $\gamma$ is submodularity ratio of $F(\mathcal{P})$, comparing to $(1-e^{-1})$ for `MMDCritic`.
### What is MMD?
The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows:
@@ -59,7 +75,9 @@ If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automa
### Default kernel
The default kernel used is Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when your data has complexity. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned.
-The Data-Centric prototypes methods are implemented as [search methods](../../xplique/example_based/search_methods/):
+### API Implementation
+
+The Data-Centric prototypes methods are implemented as [search methods](../../search_methods/):
| Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
|:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
@@ -71,18 +89,27 @@ The Data-Centric prototypes methods are implemented as [search methods](../../xp
The class `ProtoGreedySearch` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
-Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class.
-
-The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
+Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class. The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
-The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is different: it maximizes a tight lower bound on $l(w)$ instead of maximizing $l(w)$, as done in `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides the `compute_objective` method to calculate an objective based on the gradient of $l(w)$. It also overrides the `update_selection` method to select the best weights of the selection based on the gradient of the best candidate.
+The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is [different](#method-comparison) from the one used by `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides both the `compute_objective` method and the `update_selection` method.
## Prototypes for Post-hoc Interpretability
-Data-Centric methods such as `Protogreedy`, `ProtoDash` and `MMDCritic` can be used in either the output or the latent space of the model. In these cases, [projections methods](./algorithms/projections/) are used to transfer the data from the input space to the latent/output spaces.
+Data-Centric methods such as `Protogreedy`, `ProtoDash` and `MMDCritic` can be used in either the output or the latent space of the model. In these cases, [projections methods](../../projections/) are used to transfer the data from the input space to the latent/output spaces.
The search method can have attribute `projection` that projects samples to a space where distances between samples make sense for the model. Then the `search_method` finds the prototypes by looking in the projected space.
+## Common API ##
+```python
+explainer = Method(cases_dataset, labels_dataset, targets_dataset, k,
+ projection, case_returns, batch_size, distance,
+ nb_prototypes, kernel_type,
+ kernel_fn, gamma)
+# compute global explanation
+global_prototypes = explainer.get_global_prototypes()
+# compute local explanation
+local_prototypes = explainer(inputs)
+```
diff --git a/docs/api/example_based/search_method_md.md b/docs/api/example_based/search_methods.md
similarity index 100%
rename from docs/api/example_based/search_method_md.md
rename to docs/api/example_based/search_methods.md
From fdee972028f8aa1e7d87dc18554bc63a67ae6b44 Mon Sep 17 00:00:00 2001
From: Mohamed Chafik Bakey
Date: Thu, 25 Jul 2024 12:42:23 +0200
Subject: [PATCH 078/138] add the documentation for the prototypes search
methods fix up
---
.../example_based/prototypes/mmd_critic.md | 57 ++++++++++++++++-
.../example_based/prototypes/proto_dash.md | 61 ++++++++++++++++++-
.../example_based/prototypes/proto_greedy.md | 61 ++++++++++++++++++-
3 files changed, 175 insertions(+), 4 deletions(-)
diff --git a/docs/api/example_based/prototypes/mmd_critic.md b/docs/api/example_based/prototypes/mmd_critic.md
index 2ec4a219..8743fdbc 100644
--- a/docs/api/example_based/prototypes/mmd_critic.md
+++ b/docs/api/example_based/prototypes/mmd_critic.md
@@ -1,3 +1,58 @@
# MMDCriticSearch
-MMDCriticSearch ([Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf))
\ No newline at end of file
+
+
+[View colab tutorial](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X) |
+
+
+[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
+📰 [Paper](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf)
+
+`MMDCriticSearch` finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).
+
+!!! quote
+ MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and
+ efficiently selects prototypes that maximize the statistic. In addition to prototypes, MMD-critic selects criticism samples i.e. samples that are not well-explained by the prototypes using a regularized witness function score.
+
+ -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212)
+
+First, to find prototypes $\mathcal{P}$, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as:
+\begin{equation}
+ F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i,j=1}^{|\mathcal{P}|,n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j),
+\end{equation}
+where $m_p$ is the number of prototypes to be found. They used diagonal dominance conditions on the kernel to ensure monotonicity and submodularity of $F(\mathcal{P})$.
+
+Second, to find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$.
+
+!!!warning
+ For `MMDCritic`, the kernel must satisfy a condition that ensures the submodularity of the set function. The Gaussian kernel meets this requirement and it is recommended. If you wish to choose a different kernel, it must satisfy the condition described by [Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf).
+
+## Example
+
+```python
+from xplique.example_based import MMDCritic
+
+# load data and labels
+# ...
+
+explainer = MMDCritic(cases_dataset, labels_dataset, targets_dataset, k,
+ projection, case_returns, batch_size, distance,
+ nb_prototypes, kernel_type,
+ kernel_fn, gamma)
+# compute global explanation
+global_prototypes = explainer.get_global_prototypes()
+# compute local explanation
+local_prototypes = explainer(inputs)
+```
+
+## Notebooks
+
+- [**Prototypes**: Getting started](https://colab.research.google.com/drive
+/1XproaVxXjO9nrBSyyy7BuKJ1vy21iHs2)
+- [**MMDCritic**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
+
+
+{{xplique.example_based.search_methods.MMDCriticSearch}}
+
+[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
+
diff --git a/docs/api/example_based/prototypes/proto_dash.md b/docs/api/example_based/prototypes/proto_dash.md
index b54dec50..d694504d 100644
--- a/docs/api/example_based/prototypes/proto_dash.md
+++ b/docs/api/example_based/prototypes/proto_dash.md
@@ -1,3 +1,60 @@
-# ProtoGreedySearch
+# ProtoDashSearch
-ProtoDashSearch ([Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212))
\ No newline at end of file
+
+
+[View colab tutorial](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X) |
+
+
+[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
+📰 [Paper](https://arxiv.org/abs/1707.01212)
+
+`ProtoDashSearch` associates non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective function.
+
+!!! quote
+ Our work notably generalizes the recent work
+ by Kim et al. (2016) where in addition to selecting prototypes, we
+ also associate non-negative weights which are indicative of their
+ importance. This extension provides a single coherent framework
+ under which both prototypes and criticisms (i.e. outliers) can be
+ found. Furthermore, our framework works for any symmetric
+ positive definite kernel thus addressing one of the key open
+ questions laid out in Kim et al. (2016).
+
+ -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212)
+
+More precisely, the weighted objective $F(\mathcal{P},w)$ is defined as:
+\begin{equation}
+F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
+\end{equation}
+where $w$ are non-negative weights for each prototype. The problem then consists of finding a subset $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$.
+
+[Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212) proposed the `ProtoDash` algorithm, which is much faster than `ProtoGreedy` without compromising on the quality of the solution. In fact, `ProtoGreedy` selects the next element that maximizes the increment of the scoring function, whereas `ProtoDash` selects the next element that maximizes a tight lower bound on the increment of the scoring function.
+
+## Example
+
+```python
+from xplique.example_based import ProtoDash
+
+# load data and labels
+# ...
+
+explainer = ProtoDash(cases_dataset, labels_dataset, targets_dataset, k,
+ projection, case_returns, batch_size, distance,
+ nb_prototypes, kernel_type,
+ kernel_fn, gamma)
+# compute global explanation
+global_prototypes = explainer.get_global_prototypes()
+# compute local explanation
+local_prototypes = explainer(inputs)
+```
+
+## Notebooks
+
+- [**Prototypes**: Getting started](https://colab.research.google.com/drive
+/1XproaVxXjO9nrBSyyy7BuKJ1vy21iHs2)
+- [**ProtoDash**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
+
+
+{{xplique.example_based.search_methods.ProtoDashSearch}}
+
+[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
diff --git a/docs/api/example_based/prototypes/proto_greedy.md b/docs/api/example_based/prototypes/proto_greedy.md
index 9213caa1..57644ef3 100644
--- a/docs/api/example_based/prototypes/proto_greedy.md
+++ b/docs/api/example_based/prototypes/proto_greedy.md
@@ -1,3 +1,62 @@
# ProtoGreedySearch
-ProtoGreedySearch ([Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212))
\ No newline at end of file
+
+
+[View colab tutorial](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X) |
+
+
+[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
+📰 [Paper](https://arxiv.org/abs/1707.01212)
+
+`ProtoGreedySearch` associates non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective function.
+
+!!! quote
+ Our work notably generalizes the recent work
+ by Kim et al. (2016) where in addition to selecting prototypes, we
+ also associate non-negative weights which are indicative of their
+ importance. This extension provides a single coherent framework
+ under which both prototypes and criticisms (i.e. outliers) can be
+ found. Furthermore, our framework works for any symmetric
+ positive definite kernel thus addressing one of the key open
+ questions laid out in Kim et al. (2016).
+
+ -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212)
+
+More precisely, the weighted objective $F(\mathcal{P},w)$ is defined as:
+\begin{equation}
+F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
+\end{equation}
+where $w$ are non-negative weights for each prototype. The problem then consists in finding a subset $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} F(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$.
+
+[Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212) demonstrate that this problem is weakly submodular, which immediately leads to a standard greedy algorithm which they call `ProtoGreedy`.
+
+`ProtoGreedy` is algorithmically similar to greedy algorithm used by [Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) where both the methods greedily select the next element that maximizes the increment of the scoring function.
+
+## Example
+
+```python
+from xplique.example_based import ProtoGreedy
+
+# load data and labels
+# ...
+
+explainer = ProtoGreedy(cases_dataset, labels_dataset, targets_dataset, k,
+ projection, case_returns, batch_size, distance,
+ nb_prototypes, kernel_type,
+ kernel_fn, gamma)
+# compute global explanation
+global_prototypes = explainer.get_global_prototypes()
+# compute local explanation
+local_prototypes = explainer(inputs)
+```
+
+## Notebooks
+
+- [**Prototypes**: Getting started](https://colab.research.google.com/drive
+/1XproaVxXjO9nrBSyyy7BuKJ1vy21iHs2)
+- [**ProtoGreedy**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
+
+
+{{xplique.example_based.search_methods.ProtoGreedySearch}}
+
+[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
From 96fbf96ef3669cbeb8656bdfe02bff359060c81f Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Fri, 26 Jul 2024 12:21:51 +0200
Subject: [PATCH 079/138] docs: create the api page for example based methods,
search methods and projections
---
docs/api/example_based/api_example_based.md | 193 ++++++++++++++++++++
1 file changed, 193 insertions(+)
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index e69de29b..b1f6f76f 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -0,0 +1,193 @@
+# API: Example-based API
+
+- [**Example-based Methods**: Getting started]() **WIP**
+
+## Context ##
+
+!!! quote
+ While saliency maps have stolen the show for the last few years in the XAI field, their ability to reflect models' internal processes has been questioned. Although less in the spotlight, example-based XAI methods have continued to improve. It encompasses methods that use examples as explanations for a machine learning model's predictions. This aligns with the psychological mechanisms of human reasoning and makes example-based explanations natural and intuitive for users to understand. Indeed, humans learn and reason by forming mental representations of concepts based on examples.
+
+ -- [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234)[^1]
+
+As mentioned by our team members in the quote above, example-based methods are an alternative to saliency maps and can be more aligned with some users' expectations. Thus, we have been working on implementing some of those methods in Xplique that have been put aside in the previous developments.
+
+While not exhaustive, we tried to cover a range of methods that are representative of the field and that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concept-based methods have a dedicated section).
+
+At present, we made the following choices:
+- Focus on methods that are natural example methods (see the paper above for more details).
+- Try to unify the three families of approaches with a common API.
+
+!!! info
+ We are in the early stages of development and are looking for feedback on the API design and the methods we have chosen to implement. Also, we are counting on the community to furnish the collection of methods available. If you are willing to contribute reach us on the [GitHub](https://github.com/deel-ai/xplique) repository (with an issue, pull request, ...).
+
+## Common API ##
+
+```python
+explainer = ExampleMethod(
+ cases_dataset,
+ labels_dataset,
+ targets_dataset,
+ k,
+ projection,
+ case_returns,
+ batch_size,
+ **kwargs
+)
+
+explanations = explainer.explain(inputs, targets)
+```
+
+We tried to keep the API as close as possible to the one of the attribution methods to keep a consistent experience for the users.
+
+The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are selected using a [search method](#search-methods) within a defined search space, projected from the input space using a [projection function](#projections).
+
+??? abstract "Table of example-based methods available"
+
+ | Method | Documentation | Family |
+ | --- | --- | --- |
+ | `SimilarExamples` | [SimilarExamples](api/example_based/methods/similar_examples) | Similar Examples |
+ | `Cole` | [Cole](api/example_based/methods/cole) | Similar Examples |
+ | `ProtoGreedy` | [ProtoGreedy](api/example_based/methods/proto_greedy/) | Prototypes |
+ | `ProtoDash` | [ProtoDash](api/example_based/methods/proto_dash/) | Prototypes |
+ | `MMDCritic` | [MMDCritic](api/example_based/methods/mmd_critic/) | Prototypes |
+ | `NaiveCounterFactuals` | [NaiveCounterFactuals](api/example_based/methods/naive_counter_factuals/) | Counter Factuals |
+ | `LabelAwareCounterFactuals` | [LabelAwareCounterFactuals](api/example_based/methods/label_aware_counter_factuals/) | Counter Factuals |
+ | `KLEORSimMiss` | [KLEOR](api/example_based/methods/kleor/) | Semi Factuals |
+ | `KLEORGlobalSim` | [KLEOR](api/example_based/methods/kleor/) | Semi Factuals |
+
+### Parameters ###
+
+- **cases_dataset** (`Union[tf.data.Dataset, tf.Tensor, np.ndarray]`): The dataset used to train the model, from which examples are extracted. It should be batched as TensorFlow provides no method to verify this. Ensure the dataset is not reshuffled at each iteration.
+- **labels_dataset** (`Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]]`): Labels associated with the examples in the cases dataset. Indices should match the `cases_dataset`.
+- **targets_dataset** (`Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]]`): Targets associated with the `cases_dataset` for dataset projection, often the one-hot encoding of a model's predictions.
+- **k** (`int`): The number of examples to retrieve per input.
+- **projection** (`Union[Projection, Callable]`): A projection or callable function that projects samples from the input space to the search space. The search space should be relevant for the model. (see [Projections](#projections))
+- **case_returns** (`Union[List[str], str]`): Elements to return in `self.explain()`. Default is "examples".
+- **batch_size** (`Optional[int]`): Number of samples processed simultaneously for projection and search. Ignored if `tf.data.Dataset` is provided.
+
+### Properties ###
+
+- **search_method_class** (`Type[BaseSearchMethod]`): Abstract property to define the search method class to use. Must be implemented in subclasses. (see [Search Methods](#search-methods))
+- **k** (`int`): Getter and setter for the `k` parameter.
+- **returns** (`Union[List[str], str]`): Getter and setter for the `returns` parameter. Defines the elements to return in `self.explain()`.
+
+### `explain(self, inputs, targets)` ###
+
+Returns the relevant examples to explain the (inputs, targets). Projects inputs using `self.projection` and finds examples using the `self.search_method`.
+
+- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained.
+- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the `cases_dataset` for dataset projection.
+
+**Returns:** Dictionary with elements listed in `self.returns`.
+
+!!!info
+ The `__call__` method is an alias for the `explain` method.
+
+## Projections ##
+Projections are functions that map input samples to a search space where examples are retrieved with a `search_method`. The search space should be relevant for the model (e.g. projecting the inputs into the latent space of the model).
+
+!!!info
+ If one decides to use the identity function as a projection, the search space will be the input space, thus rather explaining the dataset than the model. In this case, it may be more relevant to directly use a `search_method` ([Search Methods](#search-methods)) for the dataset.
+
+The `Projection` class is an abstract base class for projections. It involves two parts: `space_projection` and `weights`. The samples are first projected to a new space and then weighted.
+
+!!!warning
+ If both parts are `None`, the projection acts as an identity function. At least one part should involve the model to ensure meaningful distance calculations.
+
+??? abstract "Table of projection methods available"
+
+ | Method | Documentation |
+ | --- | --- |
+ | `Projection` | HERE |
+ | `LatentSpaceProjection`| [LatentSpaceProjection](api/example_based/projections/latent_space_projection/) |
+ | `HadamardProjection` | [HadamardProjection](api/example_based/projections/hadamard_projection/) |
+ | `AttributionProjection` | [AttributionProjection](api/example_based/projections/attribution_projection/) |
+
+### Parameters ###
+
+- **get_weights** (`Optional[Union[Callable, tf.Tensor, np.ndarray]]`): Either a Tensor or a callable function.
+ - **Tensor**: Weights are applied in the projected space.
+ - **Callable**: A function that takes inputs and targets, returning the weights (Tensor). Weights should match the input shape (possibly differing in channels).
+
+ **Example**:
+ ```python
+ def get_weights_example(projected_inputs: Union[tf.Tensor, np.ndarray],
+ targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
+ # Compute weights using projected_inputs and targets.
+ weights = ... # Custom logic involving the model.
+ return weights
+ ```
+
+- **space_projection** (`Optional[Callable]`): Callable that takes samples and returns a Tensor in the projected space. An example of a projected space is the latent space of a model.
+- **device** (`Optional[str]`): Device to use for the projection. If `None`, the default device is used.
+- **mappable** (`bool`): If `True`, the projection can be applied to a dataset through `Dataset.map`. Otherwise, the projection is done through a loop.
+
+### `project(self, inputs, targets=None)` ###
+
+Projects samples into a space meaningful for the model. This involves weighting the inputs, projecting them into a latent space, or both. This method should be called during initialization and for each explanation.
+
+- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Expected shapes include (N, W), (N, T, W), (N, W, H, C).
+- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Additional parameter for `self.get_weights` function.
+
+**Returns:** `projected_samples` - The samples projected into the new space.
+
+!!!info
+ The `__call__` method is an alias for the `project` method.
+
+### `project_dataset(self, cases_dataset, targets_dataset=None)` ###
+
+Applies the projection to a dataset through `Dataset.map`.
+
+- **cases_dataset** (`tf.data.Dataset`): Dataset of samples to be projected.
+- **targets_dataset** (`Optional[tf.data.Dataset]`): Dataset of targets for the samples.
+
+**Returns:** `projected_dataset` - The projected dataset.
+
+## Search Methods ##
+
+Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples.
+
+The `BaseSearchMethod` class is an abstract base class for example-based search methods. It defines the interface for search methods used to find examples in a dataset. This class should be inherited by specific search methods.
+
+??? abstract "Table of search methods available"
+
+ | Method | Documentation |
+ | --- | --- |
+ | `KNN` | [KNN](api/example_based/search_methods/knn/) |
+ | `FilterKNN` | [KNN](api/example_based/search_methods/knn/) |
+ | `ProtoGreedySearch` | [ProtoGreedySearch](api/example_based/search_methods/proto_greedy_search/) |
+ | `ProtoDashSearch` | [ProtoDashSearch](api/example_based/search_methods/proto_dash_search/) |
+ | `MMDCriticSearch` | [MMDCriticSearch](api/example_based/search_methods/mmd_critic_search/) |
+ | `KLEORSimMissSearch` | [KLEOR](api/example_based/search_methods/kleor/) |
+ | `KLEORGlobalSimSearch` | [KLEOR](api/example_based/search_methods/kleor/) |
+
+
+### Parameters ###
+
+- **cases_dataset** (`Union[tf.data.Dataset, tf.Tensor, np.ndarray]`): The dataset containing the examples to search in. It should be batched as TensorFlow provides no method to verify this. Ensure the dataset is not reshuffled at each iteration.
+- **k** (`int`): The number of examples to retrieve.
+- **search_returns** (`Optional[Union[List[str], str]]`): Elements to return in `self.find_examples()`. It should be a subset of `self._returns_possibilities`.
+- **batch_size** (`Optional[int]`): Number of samples treated simultaneously. It should match the batch size of the `cases_dataset` if it is a `tf.data.Dataset`.
+
+### Properties ###
+
+- **k** (`int`): Getter and setter for the `k` parameter.
+- **returns** (`Union[List[str], str]`): Getter and setter for the `returns` parameter. Defines the elements to return in `self.find_examples()`.
+
+### `find_examples(self, inputs, targets)` ###
+
+Abstract method to search for samples to return as examples. It should be implemented in subclasses. It may return the indices corresponding to the samples based on `self.returns` value.
+
+- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Expected shapes include (N, W), (N, T, W), (N, W, H, C).
+- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the samples to be explained.
+
+**Returns:** `return_dict` - Dictionary containing the elements specified in `self.returns`.
+
+!!!info
+ The `__call__` method is an alias for the `find_examples` method.
+
+### `_returns_possibilities`
+
+Attribute that lists the possible elements that can be returned by the search methods. For the base class: `["examples", "distances", "labels", "include_inputs"]`.
+
+[^1]: [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234)
\ No newline at end of file
From 6e99393fa30fc864a2972e2e99413ae7dc76f809 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Fri, 26 Jul 2024 15:45:57 +0200
Subject: [PATCH 080/138] docs: create the KLEOR page
---
docs/api/example_based/methods/kleor.md | 68 +++++++++++++++++++++++++
1 file changed, 68 insertions(+)
create mode 100644 docs/api/example_based/methods/kleor.md
diff --git a/docs/api/example_based/methods/kleor.md b/docs/api/example_based/methods/kleor.md
new file mode 100644
index 00000000..66aa986c
--- /dev/null
+++ b/docs/api/example_based/methods/kleor.md
@@ -0,0 +1,68 @@
+# KLEOR
+
+
+
+ [View colab tutorial]()**WIP** |
+
+
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/kleor.py) |
+📰 [Paper](https://www.researchgate.net/publication/220106308_KLEOR_A_Knowledge_Lite_Approach_to_Explanation_Oriented_Retrieval)
+
+KLEOR, for Knowledge-Light Explanation-Oriented Retrieval, was introduced by Cummins & Bridge in 2006. It is a method that uses counterfactuals, Nearest Unlike Neighbors (NUN), to guide the selection of a semi-factual (SF) example.
+
+Given a distance function $dist$, the NUN of a sample $(x, y)$ is the closest sample in the training dataset which has a different label than $y$.
+
+The KLEOR method actually has three variants, including:
+
+- The Sim-Miss approach
+- The Global-Sim approach
+
+In the Sim-Miss approach, the SF of the sample $(x,y)$ is the closest training sample from the corresponding NUN which has the same label as $y$.
+
+Denoting the training dataset as $\mathcal{D}$:
+
+$$Sim-Miss(x, y, NUN(x,y), \mathcal{D}) = arg \\ min_{(x',y') \in \mathcal{D} \\ | \\ y'=y} dist(x', NUN(x,y))$$
+
+In the Global-Sim approach, they add an additional constraint that the SF should lie between the sample $(x,y)$ and the NUN that is: $dist(x, SF) < dist(x, NUN(x,y))$.
+
+We extended to the $k$ nearest neighbors of the NUN for both approaches.
+
+!!!info
+ In our implementation, we rather consider the labels predicted by the model $\hat{y}$ (*i.e.* the targets) rather than $y$!
+
+## Example
+
+```python
+from xplique.example_based import KLEORGlobalSim, KLEORSimMiss
+
+cases_dataset = ... # load the training dataset
+targets = ... # load the targets of the training dataset
+
+k = 5
+
+# instantiate the KLEOR objects
+kleor_sim_miss = KLEORSimMiss(cases_dataset=cases_dataset,
+ targets_dataset=targets,
+ k=k,
+ )
+
+kleor_global_sim = KLEORGlobalSim(cases_dataset=cases_dataset,
+ targets_dataset=targets,
+ k=k,
+ )
+
+# load the test samples and targets
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the targets of the test samples
+
+# search the SFs for the test samples
+sim_miss_sf = kleor_sim_miss.explain(test_samples, test_targets)
+global_sim_sf = kleor_global_sim.explain(test_samples, test_targets)
+```
+
+## Notebooks
+
+TODO: Add the notebook
+
+{{xplique.example_based.semifactuals.KLEORSimMiss}}
+{{xplique.example_based.semifactuals.KLEORGlobalSim}}
\ No newline at end of file
From 026634e00ede31ace40b5d798930d398316cd699 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Fri, 26 Jul 2024 16:18:11 +0200
Subject: [PATCH 081/138] docs: add pages for naive and label aware cf, for
kleor search methods, knn and filter knn, fix a mistake in the kleor page
---
docs/api/example_based/methods/kleor.md | 2 +-
.../methods/label_aware_counter_factuals.md | 51 ++++++++++++++
.../methods/naive_counter_factuals.md | 51 ++++++++++++++
.../api/example_based/search_methods/kleor.md | 47 +++++++++++++
docs/api/example_based/search_methods/knn.md | 68 +++++++++++++++++++
5 files changed, 218 insertions(+), 1 deletion(-)
create mode 100644 docs/api/example_based/methods/label_aware_counter_factuals.md
create mode 100644 docs/api/example_based/methods/naive_counter_factuals.md
create mode 100644 docs/api/example_based/search_methods/kleor.md
create mode 100644 docs/api/example_based/search_methods/knn.md
diff --git a/docs/api/example_based/methods/kleor.md b/docs/api/example_based/methods/kleor.md
index 66aa986c..12b2a9fb 100644
--- a/docs/api/example_based/methods/kleor.md
+++ b/docs/api/example_based/methods/kleor.md
@@ -5,7 +5,7 @@
[View colab tutorial]()**WIP** |
- [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/kleor.py) |
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/semifactuals.py) |
📰 [Paper](https://www.researchgate.net/publication/220106308_KLEOR_A_Knowledge_Lite_Approach_to_Explanation_Oriented_Retrieval)
KLEOR for Knowledge-Light Explanation-Oriented Retrieval was introduced by Cummins & Bridge in 2006. It is a method that use counterfactuals, Nearest Unlike Neighbor (NUN), to guide the selection of a semi-factual (SF) example.
diff --git a/docs/api/example_based/methods/label_aware_counter_factuals.md b/docs/api/example_based/methods/label_aware_counter_factuals.md
new file mode 100644
index 00000000..d08f6224
--- /dev/null
+++ b/docs/api/example_based/methods/label_aware_counter_factuals.md
@@ -0,0 +1,51 @@
+# Label Aware Counterfactuals
+
+
+
+ [View colab tutorial]()**WIP** |
+
+
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) |
+📰 [Paper](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902)
+
+!!!note
    The paper referenced here is not exactly the one we implemented. However, it is probably the closest in essence to what we implemented.
+
+In contrast to the [Naive Counterfactuals](api/example_based/methods/naive_counter_factuals/) approach, the Label Aware Counterfactuals leverage an *a priori* knowledge of the Counterfactuals' (CFs) targets to guide the search for the CFs (*e.g.* one is looking for a CF of the digit 8 in MNIST dataset within the digit 0 instances).
+
+!!!warning
    Consequently, for this class, when a user calls the `explain` method, the user is not expected to provide the targets corresponding to the input samples but rather a one-hot encoding of the targets of the CFs to search for.
+
+!!!info
+ One can use the `Projection` object to compute the distances between the samples (e.g. search for the CF in the latent space of a model).
+
+## Example
+
+```python
+from xplique.example_based import LabelAwareCounterfactuals
+
+# load the training dataset
+cases_dataset = ... # load the training dataset
+targets_dataset = ... # load the targets of the training dataset
+
+k = 5
+
+# instantiate the LabelAwareCounterfactuals object
+lacf = LabelAwareCounterfactuals(cases_dataset=cases_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ )
+
+# load the test samples
+test_samples = ... # load the test samples to search for
+test_cf_targets = ... # WARNING: provide the one-hot encoding of the targets of the CFs to search for
+
+# search the CFs for the test samples
+counterfactuals = lacf.explain(test_samples, test_cf_targets)
+```
+
+## Notebooks
+
+TODO: Add notebooks
+
+{{xplique.example_based.counterfactuals.LabelAwareCounterfactuals}}
\ No newline at end of file
diff --git a/docs/api/example_based/methods/naive_counter_factuals.md b/docs/api/example_based/methods/naive_counter_factuals.md
new file mode 100644
index 00000000..35ed8779
--- /dev/null
+++ b/docs/api/example_based/methods/naive_counter_factuals.md
@@ -0,0 +1,51 @@
+# Naive Counterfactuals
+
+
+
+ [View colab tutorial]()**WIP** |
+
+
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) |
+📰 [Paper](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902)
+
+!!!note
    The paper referenced here is not exactly the one we implemented, as we implemented a "naive" version of it. However, it is probably the closest in essence to what we implemented.
+
+We define here a "naive" counterfactual method that is based on the Nearest Unlike Neighbor (NUN) concept introduced by Dasarathy in 1991[^1]. In essence, the NUN of a sample $(x, y)$ is the closest sample in the training dataset which has a different label than $y$.
+
+Thus, in this naive approach to counterfactuals, we yield the $k$ nearest training instances that have a different label than the target of the input sample in a greedy fashion.
+
+As it is mentioned in the [API documentation](api/example_based/methods/api_example_based/), by setting a `Projection` object, one can use the projection space to compute the distances between the samples (e.g. search for the CF in the latent space of a model).
+
+## Example
+
+```python
+from xplique.example_based import NaiveCounterfactuals
+
+# load the training dataset
+cases_dataset = ... # load the training dataset
+targets_dataset = ... # load the targets of the training dataset
+
+k = 5
+
+# instantiate the NaiveCounterfactuals object
+ncf = NaiveCounterfactuals(cases_dataset=cases_dataset,
+ targets_dataset=targets_dataset,
+ k=k,
+ )
+
+# load the test samples and targets
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the targets of the test samples
+
+# search the CFs for the test samples
+counterfactuals = ncf.explain(test_samples, test_targets)
+```
+
+## Notebooks
+
+TODO: Add notebooks
+
+{{xplique.example_based.counterfactuals.NaiveCounterfactuals}}
+
+[^1]: [Nearest unlike neighbor (NUN): an aid to decision making](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902)
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/kleor.md b/docs/api/example_based/search_methods/kleor.md
new file mode 100644
index 00000000..9ad70ba8
--- /dev/null
+++ b/docs/api/example_based/search_methods/kleor.md
@@ -0,0 +1,47 @@
+# KLEOR Search Methods
+
+Those search methods are used for the [KLEOR](api/example_based/methods/kleor/) methods.
+
+It encompasses the two following classes:
+- `KLEORSimMissSearch`: looks for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the method searches for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
+- `KLEORGlobalSim`: in addition to the previous method, the SF should be closer to the query than the NUN to be a candidate.
+
+## Examples
+
+```python
+from xplique.example_based.search_methods import KLEORSimMissSearch
+from xplique.example_based.search_methods import KLEORGlobalSim
+
+cases_dataset = ... # load the training dataset
+targets = ... # load the targets of the training dataset
+
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the targets of the test samples
+
+# set some parameters
+k = 5
+distance = "euclidean"
+
+# create the KLEORSimMissSearch object
+kleor_sim_miss_search = KLEORSimMissSearch(cases_dataset=cases_dataset,
+ targets_dataset=targets,
+ k=k,
+ distance=distance)
+
+# create the KLEORGlobalSim object
+kleor_global_sim = KLEORGlobalSim(cases_dataset=cases_dataset,
+ targets_dataset=targets,
+ k=k,
+ distance=distance)
+
+# search for the K-Nearest Neighbors of the test samples
+sim_miss_neighbors = kleor_sim_miss_search.find_examples(test_samples, test_targets)
+global_sim_neighbors = kleor_global_sim.find_examples(test_samples, test_targets)
+```
+
+## Notebooks
+
+TODO: add the notebook for KLEOR
+
+{{xplique.example_based.search_methods.kleor.KLEORSimMissSearch}}
+{{xplique.example_based.search_methods.kleor.KLEORGlobalSim}}
diff --git a/docs/api/example_based/search_methods/knn.md b/docs/api/example_based/search_methods/knn.md
new file mode 100644
index 00000000..7f0eb423
--- /dev/null
+++ b/docs/api/example_based/search_methods/knn.md
@@ -0,0 +1,68 @@
+# K Nearest Neighbors
+
+KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors` [see the documentation](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html).
+The kneighbors method is implemented in a batched way to handle large datasets and try to be memory efficient.
+
+In addition, we also added a `FilterKNN` class that allows to filter the neighbors based on a given criterion avoiding potentially a compute of the distances for all the samples. It is useful when the candidate neighbors are sparse and the distance computation is expensive.
+
+## Examples
+
+```python
+from xplique.example_based.search_methods import ORDER
+from xplique.example_based.search_methods import KNN
+
+# set some parameters
+k = 5
+cases_dataset = ... # load the training dataset
+test_samples = ... # load the test samples to search for
+
+distance = "euclidean"
+order = ORDER.ASCENDING
+
+# create the KNN object
+knn = KNN(cases_dataset = cases_dataset,
+ k = k,
+ distance = distance,
+ order = order)
+
+k_nearest_neighbors = knn.kneighbors(test_samples)
+```
+
+```python
+from xplique.example_based.search_methods import ORDER
+from xplique.example_based.search_methods import FilterKNN
+
+# set some parameters
+k = 5
+cases_dataset = ... # load the training dataset
+targets = ... # load the targets of the training dataset
+
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the targets of the test samples
+
+distance = "euclidean"
+order = ORDER.ASCENDING
+
+# define a filter function
+def filter_fn(cases, inputs, targets, cases_targets):
+ # filter the cases that have the same target as the input
+ mask = tf.not_equal(targets, cases_targets)
+ return mask
+
+# create the KNN object
+filter_knn = FilterKNN(cases_dataset=cases_dataset,
+ targets_dataset=targets,
+ k=k,
+ distance=distance,
+ order=order,
+ filter_fn=filter_fn)
+
+k_nearest_neighbors = filter_knn.kneighbors(test_samples, test_targets)
+```
+
+## Notebooks
+
+TODO: add all notebooks that use this search method
+
+{{xplique.example_based.search_methods.knn.KNN}}
+{{xplique.example_based.search_methods.knn.FilterKNN}}
\ No newline at end of file
From 393914eb711ecb8b14370dcd727214cbaed35f1d Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Tue, 30 Jul 2024 09:21:23 +0200
Subject: [PATCH 082/138] docs: add the page for similar examples
---
.../example_based/methods/similar_examples.md | 45 +++++++++++++++++++
1 file changed, 45 insertions(+)
create mode 100644 docs/api/example_based/methods/similar_examples.md
diff --git a/docs/api/example_based/methods/similar_examples.md b/docs/api/example_based/methods/similar_examples.md
new file mode 100644
index 00000000..4a5748a1
--- /dev/null
+++ b/docs/api/example_based/methods/similar_examples.md
@@ -0,0 +1,45 @@
+# Similar-Examples
+
+
+
+ [View colab tutorial]()**WIP** |
+
+
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/semifactuals.py)
+
+We designate here as *Similar Examples* all methods that given an input sample, search for the most similar **training** samples given a distance function `distance`. Furthermore, one can define the search space using a `projection` function (see [Projections](api/example_based/projections.md)). This function should map an input sample to the search space where the distance function is defined and meaningful (**e.g.** the latent space of a Convolutional Neural Network).
+Then, a K-Nearest Neighbors (KNN) search is performed to find the most similar samples in the search space.
+
+## Example
+
+```python
+from xplique.example_based import SimilarExamples
+
+cases_dataset = ... # load the training dataset
+k = 5
+distance = "euclidean"
+
+# define the projection function
+def custom_projection(inputs: Union[tf.Tensor, np.ndarray], targets: Union[tf.Tensor, np.ndarray] = None):
+ '''
+ Example of projection,
+ inputs are the elements to project.
+ targets are optional parameters to orientate the projection.
+ '''
+    projected_inputs = ...  # do some magic on inputs, it should use the model.
+ return projected_inputs
+
+# instantiate the SimilarExamples object
+sim_ex = SimilarExamples(
+ cases_dataset=cases_dataset,
+ k=k,
+ projection=custom_projection,
+ distance=distance,
+)
+```
+
+# Notebooks
+
+TODO: Add the notebook
+
+{{xplique.example_based.similar_examples.SimilarExamples}}
\ No newline at end of file
From 37312daf90bcb37913fbfad9053002e4cc2ac07b Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Tue, 30 Jul 2024 10:00:18 +0200
Subject: [PATCH 083/138] docs: create the documentation page for Cole
---
docs/api/example_based/methods/cole.md | 62 ++++++++++++++++++++++++++
1 file changed, 62 insertions(+)
create mode 100644 docs/api/example_based/methods/cole.md
diff --git a/docs/api/example_based/methods/cole.md b/docs/api/example_based/methods/cole.md
new file mode 100644
index 00000000..38d52b43
--- /dev/null
+++ b/docs/api/example_based/methods/cole.md
@@ -0,0 +1,62 @@
+# COLE: Contributions Oriented Local Explanations
+
+
+
+ [View colab tutorial]()**WIP** |
+
+
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py) |
+📰 [Paper](https://researchrepository.ucd.ie/handle/10197/11064)
+
+COLE for Contributions Oriented Local Explanations was introduced by Kenny & Keane in 2019.
+
+!!! quote
+ Our method COLE is based on the premise that the contributions of features in a model’s classification represent the most sensible basis to inform case-based explanations.
+
+ -- [COLE paper](https://researchrepository.ucd.ie/handle/10197/11064)[^1]
+
+The core idea of the COLE approach is to use [attribution maps](api/attributions/api_attributions/) to define a relevant search space for the K-Nearest Neighbors (KNN) search.
+
+More specifically, the COLE approach is based on the following steps:
+- (1) Given an input sample $x$, compute the attribution map $A(x)$
+- (2) Consider the projection space defined by: $p: x \rightarrow A(x) \odot x$ ($\odot$ denotes the element-wise product)
+- (3) Perform a KNN search in the projection space to find the most similar training samples
+
+!!! info
+    In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (LRP, Integrated Gradients, and DeepLift). We decided to implement a COLE method that generalizes to a broader range of Neural Networks and attribution methods that are gradient-based (see [API Attributions documentation](api/attributions/api_attributions/) for definition).
+
+## Example
+
+```python
+from xplique.example_based import Cole
+from xplique.attributions import Saliency
+
+model = ... # load the model
+cases_dataset = ... # load the training dataset
+target_dataset = ... # load the target dataset (predicted one-hot encoding of model's predictions)
+k = 5
+
+# instantiate the Cole object
+cole = Cole(
+ cases_dataset=cases_dataset,
+ model=model,
+ k=k,
+ attribution_method=Saliency,
+)
+
+# load the test samples
+test_samples = ... # load the test samples
+test_targets = ... # load the test targets
+
+# search the most similar samples with the COLE method
+similar_samples = cole.explain(test_samples, test_targets)
+```
+
+## Notebooks
+
+TODO: Add the notebook
+
+{{xplique.example_based.similar_examples.Cole}}
+
+[^1]: [Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning:
+Comparative Tests of Feature-Weighting Methods in ANN-CBR Twins for XAI (2019)](https://researchrepository.ucd.ie/handle/10197/11064)
\ No newline at end of file
From 6d1247cca0959202536e8325980ff64cefc73f27 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Tue, 30 Jul 2024 10:01:05 +0200
Subject: [PATCH 084/138] fixup: wrong github link
---
docs/api/example_based/methods/similar_examples.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api/example_based/methods/similar_examples.md b/docs/api/example_based/methods/similar_examples.md
index 4a5748a1..1dc9f1b8 100644
--- a/docs/api/example_based/methods/similar_examples.md
+++ b/docs/api/example_based/methods/similar_examples.md
@@ -5,7 +5,7 @@
[View colab tutorial]()**WIP** |
- [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/semifactuals.py)
+ [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py)
We designate here as *Similar Examples* all methods that given an input sample, search for the most similar **training** samples given a distance function `distance`. Furthermore, one can define the search space using a `projection` function (see [Projections](api/example_based/projections.md)). This function should map an input sample to the search space where the distance function is defined and meaningful (**e.g.** the latent space of a Convolutional Neural Network).
Then, a K-Nearest Neighbors (KNN) search is performed to find the most similar samples in the search space.
From 27af808c52d0bd916097d0a5c90f945716a67916 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Tue, 30 Jul 2024 10:32:32 +0200
Subject: [PATCH 085/138] fixup: in cole documentation, change misleading
information
---
docs/api/example_based/methods/cole.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api/example_based/methods/cole.md b/docs/api/example_based/methods/cole.md
index 38d52b43..0ba3916e 100644
--- a/docs/api/example_based/methods/cole.md
+++ b/docs/api/example_based/methods/cole.md
@@ -23,7 +23,7 @@ More specifically, the COLE approach is based on the following steps:
- (3) Perform a KNN search in the projection space to find the most similar training samples
!!! info
- In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (LPR, Integrated Gradient, and DeepLift). We decided to implement a COLE method that generalizes to a more broader range of Neural Networks and attribution methods that are gradient-based (see [API Attributions documentation](api/attributions/api_attributions/) for definition).
+    In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (LRP, Integrated Gradients, and DeepLift). We decided to implement a COLE method that generalizes to a broader range of Neural Networks and attribution methods (see [API Attributions documentation](api/attributions/api_attributions/) for the list of available methods).
## Example
From 6d15e53ed331e0983ca44d08284069a7b43b5ffc Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Thu, 1 Aug 2024 10:29:25 +0200
Subject: [PATCH 086/138] docs: add a documentation page for projections
---
docs/api/example_based/projections.md | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/docs/api/example_based/projections.md b/docs/api/example_based/projections.md
index e69de29b..2918a573 100644
--- a/docs/api/example_based/projections.md
+++ b/docs/api/example_based/projections.md
@@ -0,0 +1,25 @@
+# Projections
+
+In example-based explainability, one often needs to define a notion of similarity (distance) between samples. However, the original feature space may not be the most suitable space to define this similarity. For instance, in the case of images, two images can be very similar in terms of their pixel values but very different in terms of their semantic content. In addition, computing distances in the original feature space does not take the model into account whatsoever, questioning the explainability of the method.
+
+To address these issues, one can project the samples into a new space where the distances between samples are more meaningful with respect to the model's decision. Two approaches are commonly used to define this projection space: (1) use a latent space and (2) use a feature weighting scheme.
+
+Consequently, we defined the general `Projection` class that will be used as a base class for all projection methods. This class allows one to use one or both of the aforementioned approaches. Indeed, one can instantiate a `Projection` object with a `space_projection` method, that defines a projection from the feature space to a space of interest, and a `get_weights` method, that defines the feature weighting scheme. The `Projection` class will then project a sample with the `space_projection` method and weight the projected sample's features with the `get_weights` method.
+
+In addition, we provide concrete implementations of the `Projection` class: `LatentSpaceProjection`, `AttributionProjection`, and `HadamardProjection`.
+
+## `Projection` class
+
+{{xplique.example_based.projections.Projection}}
+
+## `LatentSpaceProjection` class
+
+{{xplique.example_based.projections.LatentSpaceProjection}}
+
+## `AttributionProjection` class
+
+{{xplique.example_based.projections.AttributionProjection}}
+
+## `HadamardProjection` class
+
+{{xplique.example_based.projections.HadamardProjection}}
From 64903ab4c7c7d94b6d349f7f58f0795ae00756bb Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Thu, 1 Aug 2024 10:56:56 +0200
Subject: [PATCH 087/138] docs: add some details in the api documentation
---
docs/api/example_based/api_example_based.md | 3 +++
docs/api/example_based/search_methods.md | 0
2 files changed, 3 insertions(+)
delete mode 100644 docs/api/example_based/search_methods.md
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index b1f6f76f..0677a265 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -147,6 +147,9 @@ Applies the projection to a dataset through `Dataset.map`.
Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples.
+!!!info
+ In an Example method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods are used to find examples in this projected space.
+
The `BaseSearchMethod` class is an abstract base class for example-based search methods. It defines the interface for search methods used to find examples in a dataset. This class should be inherited by specific search methods.
??? abstract "Table of search methods available"
diff --git a/docs/api/example_based/search_methods.md b/docs/api/example_based/search_methods.md
deleted file mode 100644
index e69de29b..00000000
From 5b6f026634deffa12114b756230f6d487b649d8e Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Fri, 2 Aug 2024 11:42:43 +0200
Subject: [PATCH 088/138] docs: update the mkdocs.yml
---
docs/api/example_based/api_example_based.md | 4 ++--
mkdocs.yml | 11 +++++++++++
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index 0677a265..09a87b01 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -1,6 +1,6 @@
-# API: Example-based API
+# API: Example-based
-- [**Example-based Methods**: Getting strated]() **WIP**
+- [**Example-based Methods**: Getting started]() **WIP**
## Context ##
diff --git a/mkdocs.yml b/mkdocs.yml
index f3f3eaf6..c9a1391e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -43,11 +43,22 @@ nav:
- Tcav: api/concepts/tcav.md
- Craft: api/concepts/craft.md
- Example based:
+ - API Description: api/example_based/api_example_based.md
+ - Methods:
+ - Cole: api/example_based/methods/cole.md
+ - Kleor: api/example_based/methods/kleor.md
+ - LabelAwareCounterFactuals: api/example_based/methods/label_aware_counterfactuals.md
+ - NaiveCounterFactuals: api/example_based/methods/naive_counterfactuals.md
+ - SimilarExamples: api/example_based/methods/similar_examples.md
- Prototypes:
- API Description: api/example_based/prototypes/api_prototypes.md
- ProtoGreedy: api/example_based/prototypes/proto_greedy.md
- ProtoDash: api/example_based/prototypes/proto_dash.md
- MMDCritic: api/example_based/prototypes/mmd_critic.md
+ - Projections: api/example_based/projections.md
+ - Search Methods:
+ - Kleor: api/example_based/search_methods/kleor.md
+ - KNN: api/example_based/search_methods/knn.md
- Feature visualization:
- Modern Feature Visualization (MaCo): api/feature_viz/maco.md
- Feature visualization: api/feature_viz/feature_viz.md
From 70106fe66f27c267264c9f6652c2387312dd6787 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Fri, 2 Aug 2024 16:36:13 +0200
Subject: [PATCH 089/138] example based docs: small modifications
---
docs/api/example_based/api_example_based.md | 16 +++++++++++-----
docs/api/example_based/methods/kleor.md | 3 +++
.../methods/naive_counter_factuals.md | 2 +-
mkdocs.yml | 4 ++--
4 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index 09a87b01..23704ce7 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -5,7 +5,7 @@
## Context ##
!!! quote
- While saliency maps have stolen the show for the last few years in the XAI field, their ability to reflect models' internal processes has been questioned. Although less in the spotlight, example-based XAI methods have continued to improve. It encompasses methods that use examples as explanations for a machine learning model's predictions. This aligns with the psychological mechanisms of human reasoning and makes example-based explanations natural and intuitive for users to understand. Indeed, humans learn and reason by forming mental representations of concepts based on examples.
+ While saliency maps have stolen the show for the last few years in the XAI field, their ability to reflect models' internal processes has been questioned. Although less in the spotlight, example-based XAI methods have continued to improve. It encompasses methods that use samples as explanations for a machine learning model's predictions. This aligns with the psychological mechanisms of human reasoning and makes example-based explanations natural and intuitive for users to understand. Indeed, humans learn and reason by forming mental representations of concepts based on examples.
-- [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234)[^1]
@@ -14,7 +14,7 @@ As mentioned by our team members in the quote above, example-based methods are a
While not being exhaustive we tried to cover a range of methods that are representative of the field and that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concepts based methods have a dedicated sections).
At present, we made the following choices:
-- Focus on methods that are natural example methods (see the paper above for more details).
+- Focus on methods that are natural example methods (post-hoc and non-generative, see the paper above for more details).
- Try to unify the three families of approaches with a common API.
!!! info
@@ -39,7 +39,7 @@ explanations = explainer.explain(inputs, targets)
We tried to keep the API as close as possible to the one of the attribution methods to keep a consistent experience for the users.
-The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are selected using a [search method](#search-methods) within a defined search space, projected from the input space using a [projection function](#projections).
+The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space.
??? abstract "Table of example-based methods available"
@@ -63,7 +63,13 @@ The `BaseExampleMethod` is an abstract base class designed for example-based met
- **k** (`int`): The number of examples to retrieve per input.
- **projection** (`Union[Projection, Callable]`): A projection or callable function that projects samples from the input space to the search space. The search space should be relevant for the model. (see [Projections](#projections))
- **case_returns** (`Union[List[str], str]`): Elements to return in `self.explain()`. Default is "examples".
-- **batch_size** (`Optional[int]`): Number of samples processed simultaneously for projection and search. Ignored if `tf.data.Dataset` is provided.
+- **batch_size** (`Optional[int]`): Number of samples processed simultaneously for projection and search. Ignored if `cases_dataset` is a `tf.data.Dataset`.
+
+!!!tips
+ If the elements of your dataset are tuples (cases, labels), you can pass this dataset directly to the `cases_dataset`.
+
+!!!tips
+ Apart from contrastive explanations, in the case of classification, the built-in [Projections](#projections) compute `targets` online and the `targets_dataset` is not necessary.
### Properties ###
@@ -148,7 +154,7 @@ Applies the projection to a dataset through `Dataset.map`.
Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples.
!!!info
- In an Example method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods are used to find examples in this projected space.
+    In a search method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods are used to find examples in this projected space.
The `BaseSearchMethod` class is an abstract base class for example-based search methods. It defines the interface for search methods used to find examples in a dataset. This class should be inherited by specific search methods.
diff --git a/docs/api/example_based/methods/kleor.md b/docs/api/example_based/methods/kleor.md
index 12b2a9fb..187d6923 100644
--- a/docs/api/example_based/methods/kleor.md
+++ b/docs/api/example_based/methods/kleor.md
@@ -30,6 +30,9 @@ We extended to the $k$ nearest neighbors of the NUN for both approaches.
!!!info
In our implementation, we rather consider the labels predicted by the model $\hat{y}$ (*i.e.* the targets) rather than $y$!
+!!!tips
+    As KLEOR methods use counterfactuals, they can also return them. Therefore, it is possible to obtain both semi-factuals and counterfactuals with a single method. To do so, "nuns" and "nuns_labels" should be added to the `case_returns` list.
+
## Example
```python
diff --git a/docs/api/example_based/methods/naive_counter_factuals.md b/docs/api/example_based/methods/naive_counter_factuals.md
index 35ed8779..e81350c0 100644
--- a/docs/api/example_based/methods/naive_counter_factuals.md
+++ b/docs/api/example_based/methods/naive_counter_factuals.md
@@ -9,7 +9,7 @@
📰 [Paper](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902)
!!!note
- The paper referenced here is not exactly the one we implemented as we a "naive" of it. However, it is probably the closest in essence of what we implemented.
+    The paper referenced here is not exactly the one we implemented as we use a "naive" version of it. However, it is probably the closest in essence to what we implemented.
We define here a "naive" counterfactual method that is based on the Nearest Unlike Neighbor (NUN) concept introduced by Dasarathy in 1991[^1]. In essence, the NUN of a sample $(x, y)$ is the closest sample in the training dataset which has a different label than $y$.
diff --git a/mkdocs.yml b/mkdocs.yml
index c9a1391e..cd59b533 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -47,8 +47,8 @@ nav:
- Methods:
- Cole: api/example_based/methods/cole.md
- Kleor: api/example_based/methods/kleor.md
- - LabelAwareCounterFactuals: api/example_based/methods/label_aware_counterfactuals.md
- - NaiveCounterFactuals: api/example_based/methods/naive_counterfactuals.md
+ - LabelAwareCounterFactuals: api/example_based/methods/label_aware_counter_factuals.md
+ - NaiveCounterFactuals: api/example_based/methods/naive_counter_factuals.md
- SimilarExamples: api/example_based/methods/similar_examples.md
- Prototypes:
- API Description: api/example_based/prototypes/api_prototypes.md
From 467c79bded4af5166061dc724c5508c7aa599975 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Fri, 2 Aug 2024 16:37:15 +0200
Subject: [PATCH 090/138] example based: small fixes
---
xplique/example_based/projections/attributions.py | 2 +-
xplique/example_based/projections/base.py | 1 +
xplique/example_based/projections/hadamard.py | 2 +-
xplique/example_based/search_methods/knn.py | 6 ++++++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index ef3b0ce8..46074fab 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -74,7 +74,7 @@ def __init__(
# change default operator
if not "operator" in attribution_kwargs or attribution_kwargs["operator"] is None:
- warnings.warn("No operator provided, using standard classification operator."\
+ warnings.warn("No operator provided, using standard classification operator. "\
+ "For non-classification tasks, please specify an operator.")
attribution_kwargs["operator"] = target_free_classification_operator
diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py
index 4a76de29..9d076bb4 100644
--- a/xplique/example_based/projections/base.py
+++ b/xplique/example_based/projections/base.py
@@ -99,6 +99,7 @@ def get_weights(inputs, _ = None):
if space_projection is None:
self.space_projection = lambda inputs: inputs
elif hasattr(space_projection, "__call__"):
+ self.mappable = False
self.space_projection = space_projection
else:
raise TypeError(
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index 05fb77e3..9ac2b8c8 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -71,7 +71,7 @@ def __init__(
device=device)
if operator is None:
- warnings.warn("No operator provided, using standard classification operator."\
+ warnings.warn("No operator provided, using standard classification operator. "\
+ "For non-classification tasks, please specify an operator.")
operator = target_free_classification_operator
diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py
index d5ed1be2..fe9b50fc 100644
--- a/xplique/example_based/search_methods/knn.py
+++ b/xplique/example_based/search_methods/knn.py
@@ -302,6 +302,12 @@ class FilterKNN(BaseKNN):
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
+ targets_dataset
+ Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+        `tf.data.Dataset` are assumed to be batched as tensorflow provides no method to verify it.
+ Batch size and cardinality of other datasets should match `cases_dataset`.
+ Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
+ the case for your dataset, otherwise, examples will not make sense.
k
The number of examples to retrieve.
search_returns
From 3e638a86a2c4ac4429fcba0f32cf25e69bbb8383 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Mon, 5 Aug 2024 16:42:10 +0200
Subject: [PATCH 091/138] docs: improve documentation with antonin's feedbacks,
modify a parameter name such it matches its documentation
---
docs/api/example_based/api_example_based.md | 134 ++++--------------
.../label_aware_counter_factuals.md | 17 ++-
.../naive_counter_factuals.md | 17 ++-
docs/api/example_based/projections.md | 65 +++++++--
.../api/example_based/search_methods/kleor.md | 47 ------
docs/api/example_based/search_methods/knn.md | 68 ---------
.../{methods => semifactuals}/kleor.md | 51 ++++++-
.../{methods => similar_examples}/cole.md | 16 ++-
.../similar_examples.md | 14 +-
mkdocs.yml | 21 ++-
tests/example_based/test_projections.py | 2 +-
.../example_based/projections/attributions.py | 6 +-
xplique/example_based/projections/hadamard.py | 4 +-
xplique/example_based/similar_examples.py | 2 +-
14 files changed, 192 insertions(+), 272 deletions(-)
rename docs/api/example_based/{methods => counterfactuals}/label_aware_counter_factuals.md (68%)
rename docs/api/example_based/{methods => counterfactuals}/naive_counter_factuals.md (74%)
delete mode 100644 docs/api/example_based/search_methods/kleor.md
delete mode 100644 docs/api/example_based/search_methods/knn.md
rename docs/api/example_based/{methods => semifactuals}/kleor.md (68%)
rename docs/api/example_based/{methods => similar_examples}/cole.md (83%)
rename docs/api/example_based/{methods => similar_examples}/similar_examples.md (69%)
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index 23704ce7..6485b09a 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -23,6 +23,8 @@ At present, we made the following choices:
## Common API ##
```python
+projection = ProjectionMethod(model)
+
explainer = ExampleMethod(
cases_dataset,
labels_dataset,
@@ -39,21 +41,31 @@ explanations = explainer.explain(inputs, targets)
We tried to keep the API as close as possible to the one of the attribution methods to keep a consistent experience for the users.
-The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space.
+The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space. For all example-based methods, one can define the `distance` function that will be used by the search method.
+
+We can broadly categorize example-based methods into four families: similar examples, counter-factuals, semi-factuals, and prototypes.
+
+- **Similar Examples**: This method involves finding instances in the dataset that are similar to a given instance. The similarity is often determined based on the feature space, and these examples can help in understanding the model's decision by showing what other data points resemble the instance in question.
+- **Counter Factuals**: Counterfactual explanations identify the minimal changes needed to an instance's features to change the model's prediction to a different, specified outcome. They help answer "what-if" scenarios by showing how altering certain aspects of the input would lead to a different decision.
+- **Semi Factuals**: Semifactual explanations describe hypothetical situations where most features of an instance remain the same except for one or a few features, without changing the overall outcome. They highlight which features could vary without altering the prediction.
+- **Prototypes**: Prototypes are representative examples from the dataset that summarize typical cases within a certain category or cluster. They act as archetypal instances that the model uses to make predictions, providing a reference point for understanding model behavior.
??? abstract "Table of example-based methods available"
- | Method | Documentation | Family |
+ | Method | Family | Documentation |
| --- | --- | --- |
- | `SimilarExamples` | [SimilarExamples](api/example_based/methods/similar_examples) | Similar Examples |
- | `Cole` | [Cole](api/example_based/methods/cole) | Similar Examples |
- | `ProtoGreedy` | [ProtoGreedy](api/example_based/methods/proto_greedy/) | Prototypes |
- | `ProtoDash` | [ProtoDash](api/example_based/methods/proto_dash/) | Prototypes |
- | `MMDCritic` | [MMDCritic](api/example_based/methods/mmd_critic/) | Prototypes |
- | `NaiveCounterFactuals` | [NaiveCounterFactuals](api/example_based/methods/naive_counter_factuals/) | Counter Factuals |
- | `LabelAwareCounterFactuals` | [LabelAwareCounterFactuals](api/example_based/methods/label_aware_counter_factuals/) | Counter Factuals |
- | `KLEORSimMiss` | [KLEOR](api/example_based/methods/kleor/) | Semi Factuals |
- | `KLEORGlobalSim` | [KLEOR](api/example_based/methods/kleor/) | Semi Factuals |
+ | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) |
+ | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) |
+ | | | |
+ | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) |
+ | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) |
+ ||||
+ | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) |
+ | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) |
+ ||||
+ | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) |
+ | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) |
+ | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) |
### Parameters ###
@@ -93,110 +105,22 @@ Returns the relevant examples to explain the (inputs, targets). Projects inputs
Projections are functions that map input samples to a search space where examples are retrieved with a `search_method`. The search space should be relevant for the model (e.g. projecting the inputs into the latent space of the model).
!!!info
- If one decides to use the identity function as a projection, the search space will be the input space, thus rather explaining the dataset than the model. In this case, it may be more relevant to directly use a `search_method` ([Search Methods](#search-methods)) for the dataset.
+ If one decides to use the identity function as a projection, the search space will be the input space, thus rather explaining the dataset than the model.
-The `Projection` class is an abstract base class for projections. It involves two parts: `space_projection` and `weights`. The samples are first projected to a new space and then weighted.
+The `Projection` class is a base class for projections. It involves two parts: `space_projection` and `weights`. The samples are first projected to a new space and then weighted.
!!!warning
- If both parts are `None`, the projection acts as an identity function. At least one part should involve the model to ensure meaningful distance calculations.
-
-??? abstract "Table of projection methods available"
-
- | Method | Documentation |
- | --- | --- |
- | `Projection` | HERE |
- | `LatentSpaceProjection`| [LatentSpaceProjection](api/example_based/projections/latent_space_projection/) |
- | `HadamardProjection` | [HadamardProjection](api/example_based/projections/hadamard_projection/) |
- | `AttributionProjection` | [AttributionProjection](api/example_based/projections/attribution_projection/) |
-
-### Parameters ###
-
-- **get_weights** (`Optional[Union[Callable, tf.Tensor, np.ndarray]]`): Either a Tensor or a callable function.
- - **Tensor**: Weights are applied in the projected space.
- - **Callable**: A function that takes inputs and targets, returning the weights (Tensor). Weights should match the input shape (possibly differing in channels).
-
- **Example**:
- ```python
- def get_weights_example(projected_inputs: Union[tf.Tensor, np.ndarray],
- targets: Optional[Union[tf.Tensor, np.ndarray]] = None):
- # Compute weights using projected_inputs and targets.
- weights = ... # Custom logic involving the model.
- return weights
- ```
-
-- **space_projection** (`Optional[Callable]`): Callable that takes samples and returns a Tensor in the projected space. An example of a projected space is the latent space of a model.
-- **device** (`Optional[str]`): Device to use for the projection. If `None`, the default device is used.
-- **mappable** (`bool`): If `True`, the projection can be applied to a dataset through `Dataset.map`. Otherwise, the projection is done through a loop.
-
-### `project(self, inputs, targets=None)` ###
+ If both parts are `None`, the projection acts as an identity function. In general, we advise that one part should involve the model to ensure meaningful distance calculations with respect to the model.
-Projects samples into a space meaningful for the model. This involves weighting the inputs, projecting them into a latent space, or both. This method should be called during initialization and for each explanation.
-
-- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Expected shapes include (N, W), (N, T, W), (N, W, H, C).
-- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Additional parameter for `self.get_weights` function.
-
-**Returns:** `projected_samples` - The samples projected into the new space.
-
-!!!info
- The `__call__` method is an alias for the `project` method.
-
-### `project_dataset(self, cases_dataset, targets_dataset=None)` ###
-
-Applies the projection to a dataset through `Dataset.map`.
-
-- **cases_dataset** (`tf.data.Dataset`): Dataset of samples to be projected.
-- **targets_dataset** (`Optional[tf.data.Dataset]`): Dataset of targets for the samples.
-
-**Returns:** `projected_dataset` - The projected dataset.
+To know more about projections and their importance, you can refer to the [Projections](../../projections/) section.
## Search Methods ##
Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples.
-!!!info
+!!!warning
In an search method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods are used to find examples in this projected space.
-The `BaseSearchMethod` class is an abstract base class for example-based search methods. It defines the interface for search methods used to find examples in a dataset. This class should be inherited by specific search methods.
-
-??? abstract "Table of search methods available"
-
- | Method | Documentation |
- | --- | --- |
- | `KNN` | [KNN](api/example_based/search_methods/knn/) |
- | `FilterKNN` | [KNN](api/example_based/search_methods/knn/) |
- | `ProtoGreedySearch` | [ProtoGreedySearch](api/example_based/search_methods/proto_greedy_search/) |
- | `ProtoDashSearch` | [ProtoDashSearch](api/example_based/search_methods/proto_dash_search/) |
- | `MMDCriticSearch` | [MMDCriticSearch](api/example_based/search_methods/mmd_critic_search/) |
- | `KLEORSimMissSearch` | [KLEOR](api/example_based/search_methods/kleor/) |
- | `KLEORGlobalSimSearch` | [KLEOR](api/example_based/search_methods/kleor/) |
-
-
-### Parameters ###
-
-- **cases_dataset** (`Union[tf.data.Dataset, tf.Tensor, np.ndarray]`): The dataset containing the examples to search in. It should be batched as TensorFlow provides no method to verify this. Ensure the dataset is not reshuffled at each iteration.
-- **k** (`int`): The number of examples to retrieve.
-- **search_returns** (`Optional[Union[List[str], str]]`): Elements to return in `self.find_examples()`. It should be a subset of `self._returns_possibilities`.
-- **batch_size** (`Optional[int]`): Number of samples treated simultaneously. It should match the batch size of the `cases_dataset` if it is a `tf.data.Dataset`.
-
-### Properties ###
-
-- **k** (`int`): Getter and setter for the `k` parameter.
-- **returns** (`Union[List[str], str]`): Getter and setter for the `returns` parameter. Defines the elements to return in `self.find_examples()`.
-
-### `find_examples(self, inputs, targets)` ###
-
-Abstract method to search for samples to return as examples. It should be implemented in subclasses. It may return the indices corresponding to the samples based on `self.returns` value.
-
-- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Expected shapes include (N, W), (N, T, W), (N, W, H, C).
-- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the samples to be explained.
-
-**Returns:** `return_dict` - Dictionary containing the elements specified in `self.returns`.
-
-!!!info
- The `__call__` method is an alias for the `find_examples` method.
-
-### `_returns_possibilities`
-
-Attribute thet list possible elements that can be returned by the search methods. For the base class: `["examples", "distances", "labels", "include_inputs"]`.
+Each example-based method has its own search method. The search method is defined in the `search_method_class` property of the `ExampleMethod` class.
[^1]: [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234)
\ No newline at end of file
diff --git a/docs/api/example_based/methods/label_aware_counter_factuals.md b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
similarity index 68%
rename from docs/api/example_based/methods/label_aware_counter_factuals.md
rename to docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
index d08f6224..93a20c9b 100644
--- a/docs/api/example_based/methods/label_aware_counter_factuals.md
+++ b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
@@ -11,10 +11,10 @@
!!!note
The paper referenced here is not exactly the one we implemented. However, it is probably the closest in essence of what we implemented.
-In contrast to the [Naive Counterfactuals](api/example_based/methods/naive_counter_factuals/) approach, the Label Aware Counterfactuals leverage an *a priori* knowledge of the Counterfactuals' (CFs) targets to guide the search for the CFs (*e.g.* one is looking for a CF of the digit 8 in MNIST dataset within the digit 0 instances).
+In contrast to the [Naive Counterfactuals](../../counterfactuals/naive_counter_factuals/) approach, the Label Aware CounterFactuals leverage *a priori* knowledge of the Counterfactuals' (CFs) targets to guide the search for the CFs (*e.g.* one is looking for a CF of the digit 8 in MNIST dataset within the digit 0 instances).
!!!warning
- Consequently, for this class, when a user call the `explain` method, the user is not expected to provide the targets corresponding to the input samples but rather a one-hot encoding of the targets of the CFs to search for.
+ Consequently, for this class, when a user calls the `explain` method, the user is not expected to provide the targets corresponding to the input samples but rather a one-hot encoding of the label expected for the CFs.
!!!info
One can use the `Projection` object to compute the distances between the samples (e.g. search for the CF in the latent space of a model).
@@ -22,23 +22,26 @@ In contrast to the [Naive Counterfactuals](api/example_based/methods/naive_count
## Example
```python
-from xplique.example_based import LabelAwareCounterfactuals
+from xplique.example_based import LabelAwareCounterFactuals
# load the training dataset
cases_dataset = ... # load the training dataset
-targets_dataset = ... # load the targets of the training dataset
+targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset
+# parameters
k = 5
+distance = "euclidean"
# instantiate the LabelAwareCounterfactuals object
-lacf = LabelAwareCounterfactuals(cases_dataset=cases_dataset,
+lacf = LabelAwareCounterFactuals(cases_dataset=cases_dataset,
targets_dataset=targets_dataset,
k=k,
+ distance=distance,
)
# load the test samples
test_samples = ... # load the test samples to search for
-test_cf_targets = ... # WARNING: provide the one-hot encoding of the targets of the CFs to search for
+test_cf_targets = ... # WARNING: provide the one-hot encoding of the expected label of the CFs
# search the CFs for the test samples
counterfactuals = lacf.explain(test_samples, test_cf_targets)
@@ -48,4 +51,4 @@ counterfactuals = lacf.explain(test_samples, test_cf_targets)
TODO: Add notebooks
-{{xplique.example_based.counterfactuals.LabelAwareCounterfactuals}}
\ No newline at end of file
+{{xplique.example_based.counterfactuals.LabelAwareCounterFactuals}}
\ No newline at end of file
diff --git a/docs/api/example_based/methods/naive_counter_factuals.md b/docs/api/example_based/counterfactuals/naive_counter_factuals.md
similarity index 74%
rename from docs/api/example_based/methods/naive_counter_factuals.md
rename to docs/api/example_based/counterfactuals/naive_counter_factuals.md
index e81350c0..93d35307 100644
--- a/docs/api/example_based/methods/naive_counter_factuals.md
+++ b/docs/api/example_based/counterfactuals/naive_counter_factuals.md
@@ -1,4 +1,4 @@
-# Naive Counterfactuals
+# Naive CounterFactuals
@@ -15,28 +15,31 @@ We define here a "naive" counterfactual method that is based on the Nearest Unli
Thus, in this naive approach to counterfactuals, we yield the $k$ nearest training instances that have a different label than the target of the input sample in a greedy fashion.
-As it is mentioned in the [API documentation](api/example_based/methods/api_example_based/), by setting a `Projection` object, one can use the projection space to compute the distances between the samples (e.g. search for the CF in the latent space of a model).
+As it is mentioned in the [API documentation](../../api_example_based/), by setting a `Projection` object, one will map the inputs to a space where the distance function is meaningful.
## Example
```python
-from xplique.example_based import NaiveCounterfactuals
+from xplique.example_based import NaiveCounterFactuals
# load the training dataset
cases_dataset = ... # load the training dataset
-targets_dataset = ... # load the targets of the training dataset
+targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset
+# parameters
k = 5
+distance = "euclidean"
# instantiate the NaiveCounterfactuals object
-ncf = NaiveCounterfactuals(cases_dataset=cases_dataset,
+ncf = NaiveCounterFactuals(cases_dataset=cases_dataset,
targets_dataset=targets_dataset,
k=k,
+ distance=distance,
)
# load the test samples and targets
test_samples = ... # load the test samples to search for
-test_targets = ... # load the targets of the test samples
+test_targets = ... # load the one-hot encoding of the test samples' predictions
# search the CFs for the test samples
counterfactuals = ncf.explain(test_samples, test_targets)
@@ -46,6 +49,6 @@ counterfactuals = ncf.explain(test_samples, test_targets)
TODO: Add notebooks
-{{xplique.example_based.counterfactuals.NaiveCounterfactuals}}
+{{xplique.example_based.counterfactuals.NaiveCounterFactuals}}
[^1] [Nearest unlike neighbor (NUN): an aid to decision making](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902)
\ No newline at end of file
diff --git a/docs/api/example_based/projections.md b/docs/api/example_based/projections.md
index 2918a573..c0495b20 100644
--- a/docs/api/example_based/projections.md
+++ b/docs/api/example_based/projections.md
@@ -8,18 +8,67 @@ Consequently, we defined the general `Projection` class that will be used as a b
In addition, we provide concrete implementations of the `Projection` class: `LatentSpaceProjection`, `AttributionProjection`, and `HadamardProjection`.
-## `Projection` class
-
{{xplique.example_based.projections.Projection}}
-## `LatentSpaceProjection` class
+!!!info
+ The `__call__` method is an alias for the `project` method.
-{{xplique.example_based.projections.LatentSpaceProjection}}
+## Defining a custom projection
-## `AttributionProjection` class
+To define a custom projection, one needs to implement the `space_projection` and/or `get_weights` methods. The `space_projection` method should return the projected sample, and the `get_weights` method should return the weights of the features of the projected sample.
-{{xplique.example_based.projections.AttributionProjection}}
+!!!info
+ The `get_weights` method should take as input the original sample once it has been projected using the `space_projection` method.
+
+For the sake of clarity, we provide an example of a custom projection that projects the samples into a latent space (the final convolution block of the ResNet50 model) and weights the features with the gradients of the model's output with respect to the inputs once they have gone through the layers until the final convolutional layer.
+
+```python
+import tensorflow as tf
+from xplique.attributions import Saliency
+from xplique.example_based.projections import Projection
+
+# load the model
+model = tf.keras.applications.ResNet50(weights="imagenet", include_top=True)
+
+latent_layer = model.get_layer("conv5_block3_out") # output of the final convolutional block
+features_extractor = tf.keras.Model(
+ model.input, latent_layer.output, name="features_extractor"
+)
+
+# reconstruct the second part of the ResNet50 model
+second_input = tf.keras.Input(shape=latent_layer.output.shape[1:])
+
+x = second_input
+layer_found = False
+for layer in model.layers:
+ if layer_found:
+ x = layer(x)
+ if layer == latent_layer:
+ layer_found = True
-## `HadamardProjection` class
+predictor = tf.keras.Model(
+ inputs=second_input,
+ outputs=x,
+ name="predictor"
+)
+
+# build the custom projection
+space_projection = features_extractor
+get_weights = Saliency(predictor)
+
+custom_projection = Projection(space_projection=space_projection, get_weights=get_weights, mappable=False)
+
+# build random samples
+rdm_imgs = tf.random.normal((5, 224, 224, 3))
+rdm_targets = tf.random.uniform(shape=[5], minval=0, maxval=1000, dtype=tf.int32)
+rdm_targets = tf.one_hot(rdm_targets, depth=1000)
+
+# project the samples
+projections = custom_projection(rdm_imgs, rdm_targets)
+```
+
+{{xplique.example_based.projections.LatentSpaceProjection}}
+
+{{xplique.example_based.projections.AttributionProjection}}
-{{xplique.example_based.projections.HadamardProjection}}
+{{xplique.example_based.projections.HadamardProjection}}
\ No newline at end of file
diff --git a/docs/api/example_based/search_methods/kleor.md b/docs/api/example_based/search_methods/kleor.md
deleted file mode 100644
index 9ad70ba8..00000000
--- a/docs/api/example_based/search_methods/kleor.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# KLEOR Search Methods
-
-Those search methods are used for the [KLEOR](api/example_based/methods/kleor/) methods.
-
-It encompasses the two following classes:
-- `KLEORSimMissSearch`: looks for Semi-Factuals examples by searching for the Nearest Unlike Neighbor (NUN) of the query. The NUN is the closest example to the query that has a different prediction than the query. Then, the method searches for the K-Nearest Neighbors (KNN) of the NUN that have the same prediction as the query.
-- `KLEORGlobalSim`: in addition to the previous method, the SF should be closer to the query than the NUN to be a candidate.
-
-## Examples
-
-```python
-from xplique.example_based.search_methods import KLEORSimMissSearch
-from xplique.example_based.search_methods import KLEORGlobalSim
-
-cases_dataset = ... # load the training dataset
-targets = ... # load the targets of the training dataset
-
-test_samples = ... # load the test samples to search for
-test_targets = ... # load the targets of the test samples
-
-# set some parameters
-k = 5
-distance = "euclidean"
-
-# create the KLEORSimMissSearch object
-kleor_sim_miss_search = KLEORSimMissSearch(cases_dataset=cases_dataset,
- targets_dataset=targets,
- k=k,
- distance=distance)
-
-# create the KLEORGlobalSim object
-kleor_global_sim = KLEORGlobalSim(cases_dataset=cases_dataset,
- targets_dataset=targets,
- k=k,
- distance=distance)
-
-# search for the K-Nearest Neighbors of the test samples
-sim_miss_neighbors = kleor_sim_miss_search.find_examples(test_samples, test_targets)
-global_sim_neighbors = kleor_global_sim.find_examples(test_samples, test_targets)
-```
-
-## Notebooks
-
-TODO: add the notebook for KLEOR
-
-{{xplique.example_based.search_methods.kleor.KLEORSimMissSearch}}
-{{xplique.example_based.search_methods.kleor.KLEORGlobalSim}}
diff --git a/docs/api/example_based/search_methods/knn.md b/docs/api/example_based/search_methods/knn.md
deleted file mode 100644
index 7f0eb423..00000000
--- a/docs/api/example_based/search_methods/knn.md
+++ /dev/null
@@ -1,68 +0,0 @@
-# K Nearest Neighbors
-
-KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors` [see the documentation](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html).
-The kneighbors method is implemented in a batched way to handle large datasets and try to be memory efficient.
-
-In addition, we also added a `FilterKNN` class that allows to filter the neighbors based on a given criterion avoiding potentially a compute of the distances for all the samples. It is useful when the candidate neighbors are sparse and the distance computation is expensive.
-
-## Examples
-
-```python
-from xplique.example_based.search_methods import ORDER
-from xplique.example_based.search_methods import KNN
-
-# set some parameters
-k = 5
-cases_dataset = ... # load the training dataset
-test_samples = ... # load the test samples to search for
-
-distance = "euclidean"
-order = ORDER.ASCENDING
-
-# create the KNN object
-knn = KNN(cases_dataset = cases_dataset
- k = k,
- distance = distance,
- order = order)
-
-k_nearest_neighbors = knn.kneighbors(test_samples)
-```
-
-```python
-from xplique.example_based.search_methods import ORDER
-from xplique.example_based.search_methods import FilterKNN
-
-# set some parameters
-k = 5
-cases_dataset = ... # load the training dataset
-targets = ... # load the targets of the training dataset
-
-test_samples = ... # load the test samples to search for
-test_targets = ... # load the targets of the test samples
-
-distance = "euclidean"
-order = ORDER.ASCENDING
-
-# define a filter function
-def filter_fn(cases, inputs, targets, cases_targets):
- # filter the cases that have the same target as the input
- mask = tf.not_equal(targets, cases_targets)
- return mask
-
-# create the KNN object
-filter_knn = FilterKNN(cases_dataset=cases_dataset,
- targets_dataset=targets,
- k=k,
- distance=distance,
- order=order,
- filter_fn=filter_fn)
-
-k_nearest_neighbors = filter_knn.kneighbors(test_samples, test_targets)
-```
-
-## Notebooks
-
-TODO: add all notebooks that use this search method
-
-{{xplique.example_based.search_methods.knn.KNN}}
-{{xplique.example_based.search_methods.knn.FilterKNN}}
\ No newline at end of file
diff --git a/docs/api/example_based/methods/kleor.md b/docs/api/example_based/semifactuals/kleor.md
similarity index 68%
rename from docs/api/example_based/methods/kleor.md
rename to docs/api/example_based/semifactuals/kleor.md
index 187d6923..99ad2486 100644
--- a/docs/api/example_based/methods/kleor.md
+++ b/docs/api/example_based/semifactuals/kleor.md
@@ -33,34 +33,73 @@ We extended to the $k$ nearest neighbors of the NUN for both approaches.
!!!tips
As KLEOR methods use counterfactuals, they can also return them. Therefore, it is possible to obtain both semi-factuals and counterfactuals with an unique method. To do so "nuns" and "nuns_labels" should be added to the `cases_returns` list.
-## Example
+## Examples
```python
-from xplique.example_based import KLEORGlobalSim, KLEORSimMiss
+from xplique.example_based import KLEORSimMiss
+# loading
cases_dataset = ... # load the training dataset
-targets = ... # load the targets of the training dataset
+targets = ... # load the one-hot encoding of predicted labels of the training dataset
+# parameters
k = 5
+distance = "euclidean"
+case_returns = ["examples", "nuns"]
-# instantiate the KLEOR objects
+# instantiate the KLEOR object
kleor_sim_miss = KLEORSimMiss(cases_dataset=cases_dataset,
targets_dataset=targets,
k=k,
+ distance=distance,
)
+# load the test samples and targets
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the one-hot encoding of the test samples' predictions
+
+# search the SFs for the test samples
+sim_miss_sf = kleor_sim_miss.explain(test_samples, test_targets)
+
+# get the semi-factuals
+semifactuals = sim_miss_sf["examples"]
+
+# get the counterfactuals
+counterfactuals = sim_miss_sf["nuns"]
+```
+
+```python
+from xplique.example_based import KLEORGlobalSim
+
+# loading
+cases_dataset = ... # load the training dataset
+targets = ... # load the one-hot encoding of predicted labels of the training dataset
+
+# parameters
+k = 5
+distance = "euclidean"
+case_returns = ["examples", "nuns"]
+
+# instantiate the KLEOR object
kleor_global_sim = KLEORGlobalSim(cases_dataset=cases_dataset,
targets_dataset=targets,
k=k,
+ distance=distance,
+ case_returns=case_returns,
)
# load the test samples and targets
test_samples = ... # load the test samples to search for
-test_targets = ... # load the targets of the test samples
+test_targets = ... # load the one-hot encoding of the test samples' predictions
# search the SFs for the test samples
-sim_miss_sf = kleor_sim_miss.explain(test_samples, test_targets)
global_sim_sf = kleor_global_sim.explain(test_samples, test_targets)
+
+# get the semi-factuals
+semifactuals = global_sim_sf["examples"]
+
+# get the counterfactuals
+counterfactuals = global_sim_sf["nuns"]
```
## Notebooks
diff --git a/docs/api/example_based/methods/cole.md b/docs/api/example_based/similar_examples/cole.md
similarity index 83%
rename from docs/api/example_based/methods/cole.md
rename to docs/api/example_based/similar_examples/cole.md
index 0ba3916e..004dd7a3 100644
--- a/docs/api/example_based/methods/cole.md
+++ b/docs/api/example_based/similar_examples/cole.md
@@ -15,15 +15,18 @@ COLE for Contributions Oriented Local Explanations was introduced by Kenny & Kea
-- [COLE paper](https://researchrepository.ucd.ie/handle/10197/11064)[^1]
-The core idea of the COLE approach is to use [attribution maps](api/attributions/api_attributions/) to define a relevant search space for the K-Nearest Neighbors (KNN) search.
+The core idea of the COLE approach is to use [attribution maps](../../../attributions/api_attributions/) to define a relevant search space for the K-Nearest Neighbors (KNN) search.
More specifically, the COLE approach is based on the following steps:
+
- (1) Given an input sample $x$, compute the attribution map $A(x)$
+
- (2) Consider the projection space defined by: $p: x \rightarrow A(x) \odot x$ ($\odot$ denotes the element-wise product)
+
- (3) Perform a KNN search in the projection space to find the most similar training samples
!!! info
- In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (LPR, Integrated Gradient, and DeepLift). We decided to implement a COLE method that generalizes to a more broader range of Neural Networks and attribution methods (see [API Attributions documentation](api/attributions/api_attributions/) to see the list of methods available).
+ In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (LRP, Integrated Gradient, and DeepLift). We decided to implement a COLE method that generalizes to a broader range of Neural Networks and attribution methods (see [API Attributions documentation](../../../attributions/api_attributions/) to see the list of methods available).
## Example
@@ -34,7 +37,10 @@ from xplique.attributions import Saliency
model = ... # load the model
cases_dataset = ... # load the training dataset
target_dataset = ... # load the target dataset (predicted one-hot encoding of model's predictions)
+
+# parameters
k = 5
+distance = "euclidean"
# instantiate the Cole object
cole = Cole(
@@ -44,9 +50,9 @@ cole = Cole(
attribution_method=Saliency,
)
-# load the test samples
-test_samples = ... # load the test samples
-test_targets = ... # load the test targets
+# load the test samples and targets
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the one-hot encoding of the test samples' predictions
# search the most similar samples with the COLE method
similar_samples = cole.explain(test_samples, test_targets)
diff --git a/docs/api/example_based/methods/similar_examples.md b/docs/api/example_based/similar_examples/similar_examples.md
similarity index 69%
rename from docs/api/example_based/methods/similar_examples.md
rename to docs/api/example_based/similar_examples/similar_examples.md
index 1dc9f1b8..be875a5d 100644
--- a/docs/api/example_based/methods/similar_examples.md
+++ b/docs/api/example_based/similar_examples/similar_examples.md
@@ -7,7 +7,7 @@
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py)
-We designate here as *Similar Examples* all methods that given an input sample, search for the most similar **training** samples given a distance function `distance`. Furthermore, one can define the search space using a `projection` function (see [Projections](api/example_based/projections.md)). This function should map an input sample to the search space where the distance function is defined and meaningful (**e.g.** the latent space of a Convolutional Neural Network).
+We designate here as *Similar Examples* all methods that given an input sample, search for the most similar **training** samples given a distance function `distance`. Furthermore, one can define the search space using a `projection` function (see [Projections](../../projections/)). This function should map an input sample to the search space where the distance function is defined and meaningful (**e.g.** the latent space of a Convolutional Neural Network).
Then, a K-Nearest Neighbors (KNN) search is performed to find the most similar samples in the search space.
## Example
@@ -16,8 +16,12 @@ Then, a K-Nearest Neighbors (KNN) search is performed to find the most similar s
from xplique.example_based import SimilarExamples
cases_dataset = ... # load the training dataset
+targets = ... # load the one-hot encoding of predicted labels of the training dataset
+
+# parameters
k = 5
distance = "euclidean"
+case_returns = ["examples", "distances", "labels"]
# define the projection function
def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None):
@@ -32,10 +36,18 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
# instantiate the SimilarExamples object
sim_ex = SimilarExamples(
cases_dataset=cases_dataset,
+ targets_dataset=targets,
k=k,
projection=custom_projection,
distance=distance,
)
+
+# load the test samples and targets
+test_samples = ... # load the test samples to search for
+test_targets = ... # load the one-hot encoding of the test samples' predictions
+
+# search the most similar samples with the SimilarExamples method
+similar_samples = sim_ex.explain(test_samples, test_targets)
```
# Notebooks
diff --git a/mkdocs.yml b/mkdocs.yml
index cd59b533..ee5082d3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -44,21 +44,20 @@ nav:
- Craft: api/concepts/craft.md
- Example based:
- API Description: api/example_based/api_example_based.md
- - Methods:
- - Cole: api/example_based/methods/cole.md
- - Kleor: api/example_based/methods/kleor.md
- - LabelAwareCounterFactuals: api/example_based/methods/label_aware_counter_factuals.md
- - NaiveCounterFactuals: api/example_based/methods/naive_counter_factuals.md
- - SimilarExamples: api/example_based/methods/similar_examples.md
+ - Similar Examples:
+ - SimilarExamples: api/example_based/similar_examples/similar_examples.md
+ - Cole: api/example_based/similar_examples/cole.md
+ - Counterfactuals:
+ - LabelAwareCounterFactuals: api/example_based/counterfactuals/label_aware_counter_factuals.md
+ - NaiveCounterFactuals: api/example_based/counterfactuals/naive_counter_factuals.md
+ - Semifactuals:
+ - Kleor: api/example_based/semifactuals/kleor.md
- Prototypes:
- API Description: api/example_based/prototypes/api_prototypes.md
- ProtoGreedy: api/example_based/prototypes/proto_greedy.md
- ProtoDash: api/example_based/prototypes/proto_dash.md
- MMDCritic: api/example_based/prototypes/mmd_critic.md
- Projections: api/example_based/projections.md
- - Search Methods:
- - Kleor: api/example_based/search_methods/kleor.md
- - KNN: api/example_based/search_methods/knn.md
- Feature visualization:
- Modern Feature Visualization (MaCo): api/feature_viz/maco.md
- Feature visualization: api/feature_viz/feature_viz.md
@@ -106,8 +105,8 @@ markdown_extensions:
custom_checkbox: true
clickable_checkbox: true
- pymdownx.emoji:
- emoji_index: !!python/name:materialx.emoji.twemoji
- emoji_generator: !!python/name:materialx.emoji.to_svg
+ emoji_index: !!python/name:material.extensions.emoji.twemoji
+ emoji_generator: !!python/name:material.extensions.emoji.to_svg
extra:
version:
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
index d7b7fab8..9e84694e 100644
--- a/tests/example_based/test_projections.py
+++ b/tests/example_based/test_projections.py
@@ -123,7 +123,7 @@ def test_attribution_projection_mapping():
model = _generate_model(input_shape=input_shape, output_shape=nb_labels)
- projection = AttributionProjection(model, method=Saliency, latent_layer="last_conv")
+ projection = AttributionProjection(model, attribution_method=Saliency, latent_layer="last_conv")
# Generate tf.data.Dataset from numpy
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3)
diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py
index 46074fab..ad8ad878 100644
--- a/xplique/example_based/projections/attributions.py
+++ b/xplique/example_based/projections/attributions.py
@@ -57,11 +57,11 @@ class AttributionProjection(Projection):
def __init__(
self,
model: Callable,
- method: BlackBoxExplainer = Saliency,
+ attribution_method: BlackBoxExplainer = Saliency,
latent_layer: Optional[Union[str, int]] = None,
**attribution_kwargs
):
- self.method = method
+ self.attribution_method = attribution_method
if latent_layer is None:
# no split
@@ -79,7 +79,7 @@ def __init__(
attribution_kwargs["operator"] = target_free_classification_operator
# compute attributions
- get_weights = self.method(self.predictor, **attribution_kwargs)
+ get_weights = self.attribution_method(self.predictor, **attribution_kwargs)
# set methods
super().__init__(get_weights, space_projection, mappable=False)
diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py
index 9ac2b8c8..5eca3c65 100644
--- a/xplique/example_based/projections/hadamard.py
+++ b/xplique/example_based/projections/hadamard.py
@@ -45,13 +45,13 @@ class HadamardProjection(Projection):
The method as described in the paper apply the separation on the last convolutional layer.
To do so, the `"last_conv"` parameter will extract it.
Otherwise, `-1` could be used for the last layer before softmax.
- operator # TODO: make a larger description.
+ operator
Operator to use to compute the explanation, if None use standard predictions.
device
Device to use for the projection, if None, use the default device.
Only used for PyTorch models. Ignored for TensorFlow models.
"""
-
+ # TODO: make a larger description of the operator arg.
def __init__(
self,
model: Callable,
diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py
index 4c598fd8..1b213288 100644
--- a/xplique/example_based/similar_examples.py
+++ b/xplique/example_based/similar_examples.py
@@ -198,7 +198,7 @@ def __init__(
# build attribution projection
projection = AttributionProjection(
model=model,
- method=attribution_method,
+ attribution_method=attribution_method,
latent_layer=latent_layer,
**attribution_kwargs,
)
From 410ed98730142dd498f6666ae42fca7f1eb4b744 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Mon, 12 Aug 2024 16:11:45 +0200
Subject: [PATCH 092/138] docs: change the wording of the prototypes such that
the search methods are abstracted
---
.../prototypes/api_prototypes.md | 85 +++++++++++--------
.../example_based/prototypes/mmd_critic.md | 6 +-
.../example_based/prototypes/proto_dash.md | 8 +-
.../example_based/prototypes/proto_greedy.md | 6 +-
4 files changed, 59 insertions(+), 46 deletions(-)
diff --git a/docs/api/example_based/prototypes/api_prototypes.md b/docs/api/example_based/prototypes/api_prototypes.md
index c02132cd..2dfba112 100644
--- a/docs/api/example_based/prototypes/api_prototypes.md
+++ b/docs/api/example_based/prototypes/api_prototypes.md
@@ -5,9 +5,43 @@ Prototype-based explanation is a family of natural example-based XAI methods. Pr
- [Prototypes for Post-hoc Interpretability](#prototypes-for-post-hoc-interpretability)
- Prototype-Based Models Interpretable by Design
-This library focuses on first two classes.
+For now, the library focuses on the first two classes.
+
+## Common API ##
+
+```python
+
+explainer = Method(cases_dataset, labels_dataset, targets_dataset, k,
+ projection, case_returns, batch_size, distance,
+ nb_prototypes, kernel_type,
+ kernel_fn, gamma)
+# compute global explanation
+global_prototypes = explainer.get_global_prototypes()
+# compute local explanation
+local_prototypes = explainer(inputs)
+
+```
+
+??? abstract "Table of methods available"
+
+ The following Data-Centric prototypes methods are implemented:
+
+ | Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
+ |:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
+ | [ProtoGreedy](../proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+ | [ProtoDash](../proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+ | [MMDCritic](../mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+
+ *: Before using a PyTorch model it is highly recommended to read the [dedicated documentation](../pytorch/)
+
+!!!info
+ Using the identity projection, one is looking for the **dataset prototypes**. In contrast, using the latent space of a model as a projection, one is looking for **prototypes relevant for the model**.
+
+!!!info
+    Prototypes share a common API with other example-based methods. Thus, to understand some parameters, we recommend reading the [dedicated documentation](../../api_example_based/).
## Prototypes for Data-Centric Interpretability
+
In this class, prototypes are selected without relying on the model and provide an overview of
the dataset. As mentioned in ([Poché et al., 2023](https://hal.science/hal-04117520/document)), we found in this class: **clustering methods** and **data summarization methods**, also known as **set cover methods**. This library focuses on **data summarization methods** which can be treated in two ways [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf):
@@ -19,17 +53,22 @@ consists in finding a low-cost subset of prototypes $\mathcal{P}$ under the con
For both cases, submodularity and monotonicity of $F(\mathcal{P})$ are necessary to guarantee that a greedy algorithm has a constant factor guarantee of optimality [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). In addition, $F(\mathcal{P})$ should encourage coverage and penalize redundancy in order to have a good summary [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf).
-This library implements three methods from **Data summarization with knapsack constraint**: `MMDCritic`, `ProtoGreedy` and `ProtoDash`.
+The library implements three methods from **Data summarization with knapsack constraint**: `MMDCritic`, `ProtoGreedy` and `ProtoDash`.
+
[Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) proposed `MMDCritic` method that used a set function based on the Maximum Mean Discrepancy [(MMD)](#what-is-mmd). They solved **data summarization with knapsack constraint** problem to find both prototypes and criticisms. First, the number of prototypes and criticisms to be found, respectively as $m_p$ and $m_c$, are selected. Second, to find prototypes, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as:
+
\begin{equation}
F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i,j=1}^{|\mathcal{P}|,n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j)
\end{equation}
+
They used diagonal dominance conditions on the kernel to ensure monotonocity and submodularity of $F(\mathcal{P})$. To find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$.
[Gurumoorthy et al., 2019](https://arxiv.org/pdf/1707.01212) associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective $F(\mathcal{P},w)$ defined as:
+
\begin{equation}
-F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
+ F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j),
\end{equation}
+
+where $w$ are non-negative weights for each prototype. The problem then consists in finding $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. They established the weak submodular property of $J(\mathcal{P})$ and presented tractable algorithms (`ProtoGreedy` and `ProtoDash`) to optimize it.
### Method comparison
@@ -41,6 +80,7 @@ where $w$ are non-negative weights for each prototype. The problem then consist
- The approximation guarantee for `ProtoGreedy` is $(1-e^{-\gamma})$, where $\gamma$ is submodularity ratio of $F(\mathcal{P})$, comparing to $(1-e^{-1})$ for `MMDCritic`.
### What is MMD?
+
The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows:
$$
@@ -70,46 +110,19 @@ The choice of the kernel for selecting prototypes depends on the specific proble
If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automatically consider all the moments for the distribution, as the Taylor expansion of the exponential considers infinite-order moments. It is better to use a non-linear kernel to capture non-linear relationships in your data. If the problem is linear, it is better to choose a linear kernel such as the dot product kernel, since it is computationally efficient and often requires fewer hyperparameters to tune.
!!!warning
- For `MMDCritic`, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for `Protodash` and `Protogreedy`, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity.
+ For `MMDCritic`, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for `ProtoDash` and `ProtoGreedy`, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity.
### Default kernel
The default kernel used is Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when your data has complexity. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned.
-### API Implementation
-
-The Data-Centric prototypes methods are implemented as [search methods](../../search_methods/):
-
-| Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* |
-|:-------------------------------------- | :----------------------: | :---------------: | :---------------------: |
-| [ProtoGreedySearch](../proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
-| [ProtoDashSearch](../proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
-| [MMDCriticSearch](../mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ |
+### Implementation details
-*: Before using a PyTorch model it is highly recommended to read the [dedicated documentation](../pytorch/)
+The search method for `ProtoGreedy` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
-The class `ProtoGreedySearch` inherits from the `BaseSearchMethod` class. It finds prototypes and assigns a non-negative weight to each one.
+Both the search methods for `MMDCritic` and `ProtoDash` classes inherit from the one defined for `ProtoGreedy`. The search method for `MMDCritic` differs from `ProtoGreedy` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of the search method of `ProtoGreedy`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCritic`, the `compute_objective` method assigns the same weight to all elements in the selection.
-Both the `MMDCriticSearch` and `ProtoDashSearch` classes inherit from the `ProtoGreedySearch` class. The class `MMDCriticSearch` differs from `ProtoGreedySearch` by assigning equal weights to the selection of prototypes. The two classes use the same greedy algorithm. In the `compute_objective` method of `ProtoGreedySearch`, for each new candidate, we calculate the best weights for the selection of prototypes. However, in `MMDCriticSearch`, the `compute_objective` method assigns the same weight to all elements in the selection.
-
-The class `ProtoDashSearch`, like `ProtoGreedySearch`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDashSearch` is [different](#method-comparison) from the one used by `ProtoGreedySearch`. Therefore, `ProtoDashSearch` overrides both the `compute_objective` method and the `update_selection` method.
+`ProtoDash`, like `ProtoGreedy`, assigns a non-negative weight to each prototype. However, the algorithm used by `ProtoDash` is [different](#method-comparison) from the one used by `ProtoGreedy`. Therefore, the search method of `ProtoDash` overrides both the `compute_objective` method and the `update_selection` method.
## Prototypes for Post-hoc Interpretability
-Data-Centric methods such as `Protogreedy`, `ProtoDash` and `MMDCritic` can be used in either the output or the latent space of the model. In these cases, [projections methods](../../projections/) are used to transfer the data from the input space to the latent/output spaces.
-
-The search method can have attribute `projection` that projects samples to a space where distances between samples make sense for the model. Then the `search_method` finds the prototypes by looking in the projected space.
-
-## Common API ##
-
-```python
-
-explainer = Method(cases_dataset, labels_dataset, targets_dataset, k,
- projection, case_returns, batch_size, distance,
- nb_prototypes, kernel_type,
- kernel_fn, gamma)
-# compute global explanation
-global_prototypes = explainer.get_global_prototypes()
-# compute local explanation
-local_prototypes = explainer(inputs)
-
-```
+Data-Centric methods such as `ProtoGreedy`, `ProtoDash` and `MMDCritic` can be used in either the output or the latent space of the model. In these cases, [projections methods](../../projections/) are used to transfer the data from the input space to the latent/output spaces.
diff --git a/docs/api/example_based/prototypes/mmd_critic.md b/docs/api/example_based/prototypes/mmd_critic.md
index 8743fdbc..e9f436a1 100644
--- a/docs/api/example_based/prototypes/mmd_critic.md
+++ b/docs/api/example_based/prototypes/mmd_critic.md
@@ -1,4 +1,4 @@
-# MMDCriticSearch
+# MMDCritic
@@ -8,7 +8,7 @@
[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
📰 [Paper](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf)
-`MMDCriticSearch` finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).
+`MMDCritic` finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD).
!!! quote
MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and
@@ -52,7 +52,7 @@ local_prototypes = explainer(inputs)
- [**MMDCritic**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
-{{xplique.example_based.search_methods.MMDCriticSearch}}
+{{xplique.example_based.prototypes.MMDCritic}}
[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
diff --git a/docs/api/example_based/prototypes/proto_dash.md b/docs/api/example_based/prototypes/proto_dash.md
index d694504d..3684dcf2 100644
--- a/docs/api/example_based/prototypes/proto_dash.md
+++ b/docs/api/example_based/prototypes/proto_dash.md
@@ -1,4 +1,4 @@
-# ProtoDashSearch
+# ProtoDash
@@ -8,7 +8,7 @@
[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
📰 [Paper](https://arxiv.org/abs/1707.01212)
-`ProtoDahsSearch` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximmizing the same weighted objective function.
+`ProtoDash` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective function.
!!! quote
Our work notably generalizes the recent work
@@ -28,7 +28,7 @@ F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\su
\end{equation}
where $w$ are non-negative weights for each prototype. The problem then consist on finding a subset $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$.
-[Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212) proposed `ProtoDash` algorithm, which is much faster that `ProtoGreedy` without compromising on the quality of the solution. In fact, `ProtoGreedy` selects the next element that maximizes the increment of the scoring function, whereas `ProtoDash` selects the next element that maximizes a tight lower bound on the increment of the scoring function.
+[Gurumoorthy et al., 2019](https://arxiv.org/abs/1707.01212) proposed `ProtoDash` algorithm, which is much faster than [`ProtoGreedy`](../proto_greedy/) without compromising on the quality of the solution. In fact, `ProtoGreedy` selects the next element that maximizes the increment of the scoring function, whereas `ProtoDash` selects the next element that maximizes a tight lower bound on the increment of the scoring function.
## Example
@@ -55,6 +55,6 @@ local_prototypes = explainer(inputs)
- [**ProtoDash**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
-{{xplique.example_based.search_methods.ProtoDashSearch}}
+{{xplique.example_based.prototypes.ProtoDash}}
[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
diff --git a/docs/api/example_based/prototypes/proto_greedy.md b/docs/api/example_based/prototypes/proto_greedy.md
index 57644ef3..35900522 100644
--- a/docs/api/example_based/prototypes/proto_greedy.md
+++ b/docs/api/example_based/prototypes/proto_greedy.md
@@ -1,4 +1,4 @@
-# ProtoGreedySearch
+# ProtoGreedy
@@ -8,7 +8,7 @@
[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) |
📰 [Paper](https://arxiv.org/abs/1707.01212)
-`ProtoGreedySearch` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximmizing the same weighted objective function.
+`ProtoGreedy` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective function.
!!! quote
Our work notably generalizes the recent work
@@ -57,6 +57,6 @@ local_prototypes = explainer(inputs)
- [**ProtoGreedy**: Going Further](https://colab.research.google.com/drive/1nsB7xdQbU0zeYQ1-aB_D-M67-RAnvt4X)
-{{xplique.example_based.search_methods.ProtoGreedySearch}}
+{{xplique.example_based.prototypes.ProtoGreedy}}
[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391)
From fbef589cc2e599366c5585f0f0b2b3ce844516ed Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Mon, 12 Aug 2024 16:18:44 +0200
Subject: [PATCH 093/138] docs: update the main example-based doc page
considering antonin's comments and modifications for prototypes
---
docs/api/example_based/api_example_based.md | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index 6485b09a..6412e674 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -15,7 +15,7 @@ While not being exhaustive we tried to cover a range of methods that are represe
At present, we made the following choices:
- Focus on methods that are natural example methods (post-hoc and non-generative, see the paper above for more details).
-- Try to unify the three families of approaches with a common API.
+- Try to unify the four families of approaches with a common API.
!!! info
We are in the early stages of development and are looking for feedback on the API design and the methods we have chosen to implement. Also, we are counting on the community to furnish the collection of methods available. If you are willing to contribute reach us on the [GitHub](https://github.com/deel-ai/xplique) repository (with an issue, pull request, ...).
@@ -41,14 +41,14 @@ explanations = explainer.explain(inputs, targets)
We tried to keep the API as close as possible to the one of the attribution methods to keep a consistent experience for the users.
-The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space. For all example-based methods, one can define the `distance` function that will be used by the search method.
+The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space. For all example-based methods, one can define the `distance` that will be used by the search method.
We can broadly categorize example-based methods into four families: similar examples, counter-factuals, semi-factuals, and prototypes.
- **Similar Examples**: This method involves finding instances in the dataset that are similar to a given instance. The similarity is often determined based on the feature space, and these examples can help in understanding the model's decision by showing what other data points resemble the instance in question.
- **Counter Factuals**: Counterfactual explanations identify the minimal changes needed to an instance's features to change the model's prediction to a different, specified outcome. They help answer "what-if" scenarios by showing how altering certain aspects of the input would lead to a different decision.
- **Semi Factuals**: Semifactual explanations describe hypothetical situations where most features of an instance remain the same except for one or a few features, without changing the overall outcome. They highlight which features could vary without altering the prediction.
-- **Prototypes**: Prototypes are representative examples from the dataset that summarize typical cases within a certain category or cluster. They act as archetypal instances that the model uses to make predictions, providing a reference point for understanding model behavior.
+- **Prototypes**: Prototypes are representative examples from the dataset that summarize typical cases within a certain category or cluster. They act as archetypal instances that the model uses to make predictions, providing a reference point for understanding model behavior. Additional documentation can be found in the [Prototypes API documentation](../prototypes/api_prototypes/).
??? abstract "Table of example-based methods available"
@@ -93,8 +93,8 @@ We can broadly categorize example-based methods into four families: similar exam
Returns the relevant examples to explain the (inputs, targets). Projects inputs using `self.projection` and finds examples using the `self.search_method`.
-- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained.
-- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the `cases_dataset` for dataset projection.
+- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Shape: (n, ...) where n is the number of samples.
+- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the `cases_dataset` for dataset projection. Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes.
**Returns:** Dictionary with elements listed in `self.returns`.
From 4c2dbc07772a293dc82ed2d610780ea5ad350690 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Tue, 13 Aug 2024 14:16:01 +0200
Subject: [PATCH 094/138] projections tests: adapt to commons evolution
---
tests/example_based/test_projections.py | 44 +++++++++++++++++++++----
1 file changed, 38 insertions(+), 6 deletions(-)
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
index 9e84694e..22d083f4 100644
--- a/tests/example_based/test_projections.py
+++ b/tests/example_based/test_projections.py
@@ -11,7 +11,7 @@
)
from xplique.attributions import Saliency
-from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection
+from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection, HadamardProjection
from xplique.example_based.projections.commons import model_splitting
@@ -81,13 +81,18 @@ def test_simple_projection_mapping():
space_projection = lambda x, y=None: tf.nn.max_pool2d(x, ksize=3, strides=1, padding="SAME")
- projection = Projection(get_weights=weights, space_projection=space_projection)
+ projection = Projection(get_weights=weights, space_projection=space_projection, mappable=True)
# Generate tf.data.Dataset from numpy
- train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(3)
+ train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3)
+ targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3)
# Apply the projection by mapping the dataset
- projected_train_dataset = projection.project_dataset(train_dataset)
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
+
+ # Apply the projection by iterating over the dataset
+ projection.mappable = False
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
def test_latent_space_projection_mapping():
@@ -105,10 +110,37 @@ def test_latent_space_projection_mapping():
projection = LatentSpaceProjection(model, "last_conv")
# Generate tf.data.Dataset from numpy
- train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(3)
+ train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3)
+ targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3)
+
+ # Apply the projection by mapping the dataset
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
+ projected_train_dataset = projection._map_project_dataset(train_dataset, targets_dataset)
+ projected_train_dataset = projection._loop_project_dataset(train_dataset, targets_dataset)
+
+
+def test_hadamard_projection_mapping():
+ """
+ Test if the hadamard projection can be mapped.
+ """
+ # Setup
+ input_shape = (7, 7, 3)
+ nb_samples = 10
+ nb_labels = 2
+ x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels)
+
+ model = _generate_model(input_shape=input_shape, output_shape=nb_labels)
+
+ projection = HadamardProjection(model, "last_conv")
+
+ # Generate tf.data.Dataset from numpy
+ train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3)
+ targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3)
# Apply the projection by mapping the dataset
- projected_train_dataset = projection.project_dataset(train_dataset)
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
+ projected_train_dataset = projection._map_project_dataset(train_dataset, targets_dataset)
+ projected_train_dataset = projection._loop_project_dataset(train_dataset, targets_dataset)
def test_attribution_projection_mapping():
From e7f9d4a96d46235c3490ed975d74270b0f1d24e4 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Tue, 13 Aug 2024 17:31:16 +0200
Subject: [PATCH 095/138] projections tests: add simple splitting test
---
tests/example_based/test_projections.py | 26 ++++++++++++++++++++++++-
1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py
index 22d083f4..ec303d05 100644
--- a/tests/example_based/test_projections.py
+++ b/tests/example_based/test_projections.py
@@ -3,6 +3,7 @@
from tensorflow.keras.layers import (
Dense,
Conv2D,
+ Conv1D,
Activation,
Dropout,
Flatten,
@@ -14,6 +15,8 @@
from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection, HadamardProjection
from xplique.example_based.projections.commons import model_splitting
+from ..utils import almost_equal
+
def get_setup(input_shape, nb_samples=10, nb_labels=2):
"""
@@ -95,6 +98,27 @@ def test_simple_projection_mapping():
projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
+def test_model_splitting():
+ """
+ Test if projected samples have the expected values
+ """
+ x_train = np.reshape(np.arange(0, 100), (10, 10))
+
+ model = tf.keras.Sequential()
+ model.add(Input(shape=(10,)))
+ model.add(Dense(10, name="dense1"))
+ model.add(Dense(1, name="dense2"))
+ model.compile(loss="categorical_crossentropy", optimizer="sgd")
+
+ model.get_layer("dense1").set_weights([np.eye(10) * np.sign(np.arange(-4.5, 5.5)), np.zeros(10)])
+ model.get_layer("dense2").set_weights([np.ones((10, 1)), np.zeros(1)])
+
+ # Split the model
+ features_extractor, predictor = model_splitting(model, latent_layer="dense1")
+
+ assert almost_equal(predictor(features_extractor(x_train)).numpy(), model(x_train))
+
+
def test_latent_space_projection_mapping():
"""
Test if the latent space projection can be mapped.
@@ -162,4 +186,4 @@ def test_attribution_projection_mapping():
targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3)
# Apply the projection by mapping the dataset
- projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
\ No newline at end of file
+ projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset)
From 85fc97d3d0de4d0f848b21eeca9119df7aa16889 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Wed, 14 Aug 2024 10:36:07 +0200
Subject: [PATCH 096/138] fix: correct the mask computation in the
 LabelAwareCounterFactuals method, and add a missing case-returns
 possibility for semi-factuals
---
xplique/example_based/counterfactuals.py | 6 +++---
xplique/example_based/semifactuals.py | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/xplique/example_based/counterfactuals.py b/xplique/example_based/counterfactuals.py
index 360fdcda..6f0dac19 100644
--- a/xplique/example_based/counterfactuals.py
+++ b/xplique/example_based/counterfactuals.py
@@ -250,9 +250,9 @@ def filter_fn(self, _, __, cf_expected_classes, cases_targets) -> tf.Tensor:
cases_targets
The one-hot encoding of the target class for the cases.
"""
- mask = tf.matmul(cf_expected_classes, cases_targets, transpose_b=True) #(n, bs)
- # TODO: I think some retracing are done here
- mask = tf.cast(mask, dtype=tf.bool)
+ cases_predicted_labels = tf.argmax(cases_targets, axis=-1)
+ cf_label_targets = tf.argmax(cf_expected_classes, axis=-1)
+ mask = tf.equal(tf.expand_dims(cf_label_targets, axis=1), cases_predicted_labels)
return mask
@sanitize_inputs_targets
diff --git a/xplique/example_based/semifactuals.py b/xplique/example_based/semifactuals.py
index b912eb1e..572d508d 100644
--- a/xplique/example_based/semifactuals.py
+++ b/xplique/example_based/semifactuals.py
@@ -80,7 +80,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
by default "euclidean".
"""
_returns_possibilities = [
- "examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns"
+ "examples", "weights", "distances", "labels", "include_inputs", "nuns", "nuns_indices", "dist_to_nuns", "nuns_labels"
]
def __init__(
From 1bb8a1631ea1e765f62179c58457a638dea7c7f8 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Aug 2024 10:48:26 +0200
Subject: [PATCH 097/138] projections: add warning to tensorflow splitting
---
xplique/example_based/projections/commons.py | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py
index 45110310..ea859690 100644
--- a/xplique/example_based/projections/commons.py
+++ b/xplique/example_based/projections/commons.py
@@ -86,6 +86,13 @@ def _tf_model_splitting(model: tf.keras.Model,
latent_layer
Layer used to split the `model`.
"""
+
+ warnings.warn(
+ "Automatically splitting the provided TensorFlow model into two parts. "\
+ +"This splitting is not robust to all models. "\
+ +"It is recommended to split the model manually. "\
+ +"Then the splitted parts can be provided through the `from_splitted_model` method.")
+
if latent_layer == "last_conv":
latent_layer = next(
layer for layer in model.layers[::-1] if hasattr(layer, "filters")
@@ -153,9 +160,12 @@ def _torch_model_splitting(model: 'torch.nn.Module',
import torch.nn as nn
from ...wrappers import TorchWrapper
- warnings.warn("Automatically splitting the provided PyTorch model into two parts. "\
- +"This splitting is based on `model.named_children()`. "\
- +"If the model cannot be reconstructed via sub-modules, errors are to be expected.")
+ warnings.warn(
+ "Automatically splitting the provided PyTorch model into two parts. "\
+ +"This splitting is based on `model.named_children()`. "\
+ +"If the model cannot be reconstructed via sub-modules, errors are to be expected. "\
+ +"It is recommended to split the model manually and wrap it with `TorchWrapper`. "\
+ +"Then the wrapped parts can be provided through the `from_splitted_model` method.")
if device is None:
warnings.warn("No device provided for the projection, using 'cuda' if available, else 'cpu'.")
From ac74a9d4365ce6b7199655b9f61ae6a854362d0c Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Aug 2024 10:48:56 +0200
Subject: [PATCH 098/138] commons: linting
---
xplique/commons/tf_dataset_operations.py | 103 ++---------------------
1 file changed, 8 insertions(+), 95 deletions(-)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index 783b7e57..a08d2902 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -150,103 +150,13 @@ def sanitize_dataset(
assert dataset_cardinality == cardinality, (
"The number of batch should match between datasets. "
+ f"Received {dataset.cardinality().numpy()} vs {cardinality}. "
- + "You may have provided non-batched datasets or datasets with different length."
+ + "You may have provided non-batched datasets "\
+ + "or datasets with different lengths."
)
return dataset
-# def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
-# """
-# Imitation of `tf.gather` for `tf.data.Dataset`,
-# it extract elements from `dataset` at the given indices.
-# We could see it as returning the `indices` tensor
-# where each index was replaced by the corresponding element in `dataset`.
-# The aim is to use it in the `example_based` module to extract examples form the cases dataset.
-# Hence, `indices` expect dimensions of (n, k, 2),
-# where n represent the number of inputs and k the number of corresponding examples.
-# Here indices for each element are encoded by two values,
-# the batch index and the index of the element in the batch.
-
-# Example of application
-# ```
-# >>> dataset = tf.data.Dataset.from_tensor_slices(
-# ... tf.reshape(tf.range(20), (-1, 2, 2))
-# ... ).batch(3) # shape=(None, 2, 2)
-# >>> indices = tf.constant([[[0, 0]], [[1, 0]]]) # shape=(2, 1, 2)
-# >>> dataset_gather(dataset, indices)
-#
-# ```
-
-# Parameters
-# ----------
-# dataset
-# Tensorflow dataset to verify or tensor to transform in `tf.data.Dataset` and verify.
-# indices
-# Tensor of indices of elements to extract from the `dataset`.
-# `indices` should be of dimensions (n, k, 2),
-# this is to match the format of indices in the `example_based` module.
-# Indeed, n represent the number of inputs and k the number of corresponding examples.
-# The index of each element is encoded by two values,
-# the batch index and the index of the element in the batch.
-
-# Returns
-# -------
-# results
-# A tensor with the extracted elements from the `dataset`.
-# The shape of the tensor is (n, k, ...), where ... is the shape of the elements in the `dataset`.
-# """
-# if dataset is None:
-# return None
-
-# if len(indices.shape) != 3 or indices.shape[-1] != 2:
-# raise ValueError(
-# "Indices should have dimensions (n, k, 2), "
-# + "where n represent the number of inputs and k the number of corresponding examples. "
-# + "The index of each element is encoded by two values, "
-# + "the batch index and the index of the element in the batch. "
-# + f"Received {indices.shape}."
-# )
-
-# example = next(iter(dataset))
-# # (n, bs, ...)
-# with tf.device('/CPU:0'):
-# if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
-# results = tf.Variable(
-# tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype)),
-# )
-# else:
-# results = tf.Variable(
-# tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype)),
-# )
-
-# nb_results = product(indices.shape[:-1])
-# current_nb_results = 0
-
-# for i, batch in enumerate(dataset):
-# # check if the batch is interesting
-# if not tf.reduce_any(indices[..., 0] == i):
-# continue
-
-# # extract pertinent elements
-# pertinent_indices_location = tf.where(indices[..., 0] == i)
-# samples_index = tf.gather_nd(indices[..., 1], pertinent_indices_location)
-# samples = tf.gather(batch, samples_index)
-
-# # put them at the right place in results
-# for location, sample in zip(pertinent_indices_location, samples):
-# results[location[0], location[1]].assign(sample)
-# current_nb_results += 1
-
-# # test if results are filled to break the loop
-# if current_nb_results == nb_results:
-# break
-# return results
-
def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
"""
Imitation of `tf.gather` for `tf.data.Dataset`,
@@ -289,7 +199,8 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
-------
results
A tensor with the extracted elements from the `dataset`.
- The shape of the tensor is (n, k, ...), where ... is the shape of the elements in the `dataset`.
+ The shape of the tensor is (n, k, ...),
+ where ... is the shape of the elements in the `dataset`.
"""
if dataset is None:
return None
@@ -306,9 +217,11 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
example = next(iter(dataset))
if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
- results = tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(-1, dtype=dataset.element_spec.dtype))
+ results = tf.fill(dims=indices.shape[:-1] + example[0].shape,
+ value=tf.constant(-1, dtype=dataset.element_spec.dtype))
else:
- results = tf.fill(indices.shape[:-1] + example[0].shape, tf.constant(np.inf, dtype=dataset.element_spec.dtype))
+ results = tf.fill(dims=indices.shape[:-1] + example[0].shape,
+ value=tf.constant(np.inf, dtype=dataset.element_spec.dtype))
nb_results = product(indices.shape[:-1])
current_nb_results = 0
From d709f0675ed37f3fab9cc1f19d956eb55a66fa5a Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Aug 2024 15:15:29 +0200
Subject: [PATCH 099/138] requirements: limit to tensorflow < 2.16
---
requirements.txt | 2 +-
setup.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index da889c05..b7985ca3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-tensorflow>=2.1.0
+tensorflow >= 2.1.0, < 2.16
numpy
scikit-learn
scikit-image
diff --git a/setup.py b/setup.py
index b5dc14d7..96f0aebf 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
author="Thomas FEL",
author_email="thomas_fel@brown.edu",
license="MIT",
- install_requires=['tensorflow>=2.1.0', 'numpy', 'scikit-learn', 'scikit-image',
+ install_requires=['tensorflow>=2.1.0,<2.16', 'numpy', 'scikit-learn', 'scikit-image',
'matplotlib', 'scipy', 'opencv-python', 'deprecated'],
extras_require={
"tests": ["pytest", "pylint"],
From cedf74481aa666d8fcecabe346063160728c0675 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Wed, 14 Aug 2024 15:15:49 +0200
Subject: [PATCH 100/138] linting
---
xplique/commons/tf_dataset_operations.py | 12 ++--
xplique/example_based/base_example_method.py | 31 ++++++----
xplique/example_based/counterfactuals.py | 61 +++++++++++--------
xplique/example_based/prototypes.py | 14 +++--
.../example_based/search_methods/common.py | 6 +-
.../search_methods/proto_greedy_search.py | 10 +--
6 files changed, 75 insertions(+), 59 deletions(-)
diff --git a/xplique/commons/tf_dataset_operations.py b/xplique/commons/tf_dataset_operations.py
index a08d2902..20933bfa 100644
--- a/xplique/commons/tf_dataset_operations.py
+++ b/xplique/commons/tf_dataset_operations.py
@@ -204,18 +204,18 @@ def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor:
"""
if dataset is None:
return None
-
+
if len(indices.shape) != 3 or indices.shape[-1] != 2:
raise ValueError(
- "Indices should have dimensions (n, k, 2), "
- + "where n represent the number of inputs and k the number of corresponding examples. "
- + "The index of each element is encoded by two values, "
- + "the batch index and the index of the element in the batch. "
+ "Indices should have dimensions (n, k, 2), "\
+ + "where n represent the number of inputs and k the number of corresponding examples. "\
+ + "The index of each element is encoded by two values, "\
+ + "the batch index and the index of the element in the batch. "\
+ f"Received {indices.shape}."
)
example = next(iter(dataset))
-
+
if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']:
results = tf.fill(dims=indices.shape[:-1] + example[0].shape,
value=tf.constant(-1, dtype=dataset.element_spec.dtype))
diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py
index 7e6e19e9..c03c1665 100644
--- a/xplique/example_based/base_example_method.py
+++ b/xplique/example_based/base_example_method.py
@@ -22,9 +22,10 @@
class BaseExampleMethod(ABC):
"""
Base class for natural example-based methods explaining classification models.
- An example-based method is a method that explains a model's predictions by providing examples from the cases_dataset
- (usually the training dataset). The examples are selected with the help of a search method that performs a search in
- the search space. The search space is defined with the help of a projection function that projects the cases_dataset
+ An example-based method is a method that explains a model's predictions by providing
+ examples from the cases_dataset (usually the training dataset). The examples are selected with
+ the help of a search method that performs a search in the search space. The search space is
+ defined with the help of a projection function that projects the cases_dataset
and the (inputs, targets) to explain into a space where the search method is relevant.
Parameters
@@ -41,8 +42,8 @@ class BaseExampleMethod(ABC):
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets associated to the cases_dataset for dataset projection, oftentimes the one-hot encoding of a model's
- predictions. See `projection` for detail.
+ Targets associated to the cases_dataset for dataset projection,
+ oftentimes the one-hot encoding of a model's predictions. See `projection` for detail.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
@@ -52,8 +53,7 @@ class BaseExampleMethod(ABC):
projection
Projection or Callable that project samples from the input space to the search space.
The search space should be a space where distances are relevant for the model.
- It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
- searching the input space, you should use a `BaseSearchMethod` instead.
+ It should not be `None`, otherwise, the model is not involved thus not explained.
Example of Callable:
```
@@ -73,6 +73,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
Number of sample treated simultaneously for projection and search.
Ignored if `tf.data.Dataset` are provided (those are supposed to be batched).
"""
+ # pylint: disable=too-many-instance-attributes
_returns_possibilities = ["examples", "distances", "labels", "include_inputs"]
def __init__(
@@ -104,8 +105,7 @@ def __init__(
self.projection = Projection(get_weights=None, space_projection=projection)
else:
raise AttributeError(
- "projection should be a `Projection` or a `Callable`, not a"
- + f"{type(projection)}"
+ f"projection should be a `Projection` or a `Callable`, not a {type(projection)}"
)
# project dataset
@@ -115,7 +115,10 @@ def __init__(
# set properties
self.k = k
self.returns = case_returns
-
+
+ # temporary value for the search method
+ self.search_method = None
+
@property
@abstractmethod
def search_method_class(self) -> Type[BaseSearchMethod]:
@@ -257,7 +260,7 @@ def _initialize_cases_dataset(
self.cases_dataset = self.cases_dataset.map(lambda x, y, t: x)
else:
raise AttributeError(
- "`cases_dataset` cannot possess more than 3 columns, "
+ "`cases_dataset` cannot possess more than 3 columns, "\
+ f"{len(self.cases_dataset.element_spec)} were detected."
)
@@ -295,7 +298,8 @@ def explain(
-------
return_dict
Dictionary with listed elements in `self.returns`.
- The elements that can be returned are defined with _returns_possibilities static attribute of the class.
+ The elements that can be returned are defined with the `_returns_possibilities`
+ static attribute of the class.
"""
# project inputs into the search space
projected_inputs = self.projection(inputs, targets)
@@ -337,7 +341,8 @@ def format_search_output(
-------
return_dict
Dictionary with listed elements in `self.returns`.
- The elements that can be returned are defined with _returns_possibilities static attribute of the class.
+ The elements that can be returned are defined with the `_returns_possibilities`
+ static attribute of the class.
"""
# initialize return dictionary
return_dict = {}
diff --git a/xplique/example_based/counterfactuals.py b/xplique/example_based/counterfactuals.py
index 6f0dac19..8b486dbf 100644
--- a/xplique/example_based/counterfactuals.py
+++ b/xplique/example_based/counterfactuals.py
@@ -16,8 +16,8 @@
class NaiveCounterFactuals(BaseExampleMethod):
"""
- This class allows to search for counterfactuals by searching for the closest sample to a query in a projection space
- that do not have the same model's prediction.
+ This class allows to search for counterfactuals by searching for the closest sample to
+ a query in a projection space that do not have the same model's prediction.
It is a naive approach as it follows a greedy approach.
Parameters
@@ -28,7 +28,8 @@ class NaiveCounterFactuals(BaseExampleMethod):
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ Targets are expected to be the one-hot encoding of
+ the model's predictions for the samples in cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
@@ -44,8 +45,7 @@ class NaiveCounterFactuals(BaseExampleMethod):
projection
Projection or Callable that project samples from the input space to the search space.
The search space should be a space where distances are relevant for the model.
- It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
- searching the input space, you should use a `BaseSearchMethod` instead.
+ It should not be `None`, otherwise, the model is not involved thus not explained.
Example of Callable:
```
@@ -109,16 +109,17 @@ def __init__(
@property
def search_method_class(self):
"""
- This property defines the search method class to use for the search. In this case, it is the FilterKNN that
- is an efficient KNN search method ignoring non-acceptable cases, thus not considering them in the search.
+ This property defines the search method class to use for the search.
+ In this case, it is the FilterKNN that is an efficient KNN search method
+ ignoring non-acceptable cases, thus not considering them in the search.
"""
return FilterKNN
def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
"""
- Filter function to mask the cases for which the model's prediction is different from the model's prediction
- on the inputs.
+ Filter function to mask the cases for which the model's prediction
+ is different from the model's prediction on the inputs.
"""
# get the labels predicted by the model
# (n, )
@@ -133,8 +134,8 @@ def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor:
class LabelAwareCounterFactuals(BaseExampleMethod):
"""
- This method will search the counterfactuals of a query within an expected class. This class should be provided with
- the query when calling the explain method.
+ This method will search the counterfactuals of a query within an expected class.
+ This class should be provided with the query when calling the explain method.
Parameters
----------
@@ -144,7 +145,8 @@ class LabelAwareCounterFactuals(BaseExampleMethod):
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
the case for your dataset, otherwise, examples will not make sense.
targets_dataset
- Targets are expected to be the one-hot encoding of the model's predictions for the samples in cases_dataset.
+ Targets are expected to be the one-hot encoding of the model's predictions
+ for the samples in cases_dataset.
`tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it.
Batch size and cardinality of other datasets should match `cases_dataset`.
Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not
@@ -160,8 +162,7 @@ class LabelAwareCounterFactuals(BaseExampleMethod):
projection
Projection or Callable that project samples from the input space to the search space.
The search space should be a space where distances are relevant for the model.
- It should not be `None`, otherwise, the model is not involved thus not explained. If you are interested in
- searching the input space, you should use a `BaseSearchMethod` instead.
+ It should not be `None`, otherwise, the model is not involved thus not explained.
Example of Callable:
```
@@ -207,11 +208,13 @@ def __init__(
batch_size=batch_size,
)
- # raise a warning to specify that target in the explain method is not the same as the target used for
- # the target dataset
- warnings.warn("If your projection method requires the target, be aware that when using the explain method,"
- " the target provided is the class within one should search for the counterfactual.\nThus,"
- " it is possible that the projection of the query is going wrong.")
+ # raise a warning to specify that target in the explain method is not the same
+ # as the target used for the target dataset
+ warnings.warn(
+ "If your projection method requires the target, "\
+ + "be aware that when using the explain method,"\
+ + "the target provided is the class within one should search for the counterfactual."\
+ + "\nThus, it is possible that the projection of the query is going wrong.")
# set distance function and order for the search method
self.distance = distance
@@ -228,20 +231,21 @@ def __init__(
filter_fn=self.filter_fn,
order=self.order
)
-
+
@property
def search_method_class(self):
"""
- This property defines the search method class to use for the search. In this case, it is the FilterKNN that
- is an efficient KNN search method ignoring non-acceptable cases, thus not considering them in the search.
+ This property defines the search method class to use for the search.
+ In this case, it is the FilterKNN that is an efficient KNN search method ignoring
+ non-acceptable cases, thus not considering them in the search.
"""
return FilterKNN
def filter_fn(self, _, __, cf_expected_classes, cases_targets) -> tf.Tensor:
"""
- Filter function to mask the cases for which the target is different from the target(s) expected for the
- counterfactuals.
+ Filter function to mask the cases for which the target is different from
+ the target(s) expected for the counterfactuals.
Parameters
----------
@@ -263,8 +267,10 @@ def explain(
):
"""
Return the relevant CF examples to explain the inputs.
- The CF examples are searched within cases for which the target is the one provided in `cf_targets`.
- It projects inputs with `self.projection` in the search space and find examples with the `self.search_method`.
+ The CF examples are searched within cases
+ for which the target is the one provided in `cf_targets`.
+ It projects inputs with `self.projection` in the search space and
+ find examples with the `self.search_method`.
Parameters
----------
@@ -279,7 +285,8 @@ def explain(
-------
return_dict
Dictionary with listed elements in `self.returns`.
- The elements that can be returned are defined with _returns_possibilities static attribute of the class.
+ The elements that can be returned are defined with the `_returns_possibilities`
+ static attribute of the class.
"""
# project inputs into the search space
projected_inputs = self.projection(inputs)
diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py
index c1857b48..ac1c9e7a 100644
--- a/xplique/example_based/prototypes.py
+++ b/xplique/example_based/prototypes.py
@@ -11,7 +11,7 @@
from ..commons.tf_dataset_operations import dataset_gather
-from .search_methods import BaseSearchMethod, ProtoGreedySearch, MMDCriticSearch, ProtoDashSearch
+from .search_methods import ProtoGreedySearch, MMDCriticSearch, ProtoDashSearch
from .projections import Projection
from .base_example_method import BaseExampleMethod
@@ -82,6 +82,7 @@ def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndar
gamma : float, optional
Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features.
"""
+ # pylint: disable=too-many-arguments
def __init__(
self,
@@ -97,8 +98,8 @@ def __init__(
kernel_type: str = 'local',
kernel_fn: callable = None,
gamma: float = None
- ):
- # set common example-based parameters
+ ):
+ # set common example-based parameters
super().__init__(
cases_dataset=cases_dataset,
labels_dataset=labels_dataset,
@@ -129,12 +130,12 @@ def __init__(
kernel_fn=self.kernel_fn,
gamma=self.gamma
)
-
+
@property
@abstractmethod
def search_method_class(self) -> Type[ProtoGreedySearch]:
raise NotImplementedError
-
+
def get_global_prototypes(self) -> Dict[str, tf.Tensor]:
"""
Provide the global prototypes computed at the initialization.
@@ -179,18 +180,21 @@ def get_global_prototypes(self) -> Dict[str, tf.Tensor]:
class ProtoGreedy(Prototypes):
+ # pylint: disable=missing-class-docstring
@property
def search_method_class(self) -> Type[ProtoGreedySearch]:
return ProtoGreedySearch
class MMDCritic(Prototypes):
+ # pylint: disable=missing-class-docstring
@property
def search_method_class(self) -> Type[ProtoGreedySearch]:
return MMDCriticSearch
class ProtoDash(Prototypes):
+ # pylint: disable=missing-class-docstring
@property
def search_method_class(self) -> Type[ProtoGreedySearch]:
return ProtoDashSearch
diff --git a/xplique/example_based/search_methods/common.py b/xplique/example_based/search_methods/common.py
index 0f3af3d4..3daa89ee 100644
--- a/xplique/example_based/search_methods/common.py
+++ b/xplique/example_based/search_methods/common.py
@@ -137,8 +137,8 @@ def get_distance_function(distance: Union[int, str, Callable] = "euclidean",) ->
return lambda x1, x2: _chebyshev_distance(x1, x2)
else:
raise AttributeError(
- "The distance parameter is expected to be either a Callable, "
- + f" an integer, 'inf', or a string in {_distances.keys()}. "
- +f"But a {type(distance)} was received, with value {distance}."
+ "The distance parameter is expected to be either a Callable, "\
+ + f" an integer, 'inf', or a string in {_distances.keys()}. "\
+ + f"But a {type(distance)} was received, with value {distance}."
)
diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py
index c1e46862..0625e8fe 100644
--- a/xplique/example_based/search_methods/proto_greedy_search.py
+++ b/xplique/example_based/search_methods/proto_greedy_search.py
@@ -96,9 +96,9 @@ def __init__(
if kernel_type not in ['local', 'global']:
raise AttributeError(
- "The kernel_type parameter is expected to be in"
- + " ['local', 'global'] ",
- +f"but {kernel_type} was received.",
+ "The kernel_type parameter is expected to be in"\
+ + " ['local', 'global'] "\
+ +f"but {kernel_type} was received."\
)
self.kernel_type = kernel_type
@@ -109,8 +109,8 @@ def __init__(
kernel_fn = lambda x, y: rbf_kernel(x,y,gamma)
elif not hasattr(kernel_fn, "__call__"):
raise AttributeError(
- "The kernel_fn parameter is expected to be a Callable",
- +f"but {kernel_fn} was received.",
+ "The kernel_fn parameter is expected to be a Callable"\
+ +f"but {kernel_fn} was received."\
)
# define custom kernel function depending on the kernel type
From 0316d559df0702c0180382c48d1f49f1788cae88 Mon Sep 17 00:00:00 2001
From: lucas Hervier
Date: Wed, 14 Aug 2024 17:39:10 +0200
Subject: [PATCH 101/138] docs: update the README, add warning concerning the
tensorflow version, add the notebook link and update the tutorials tables
---
README.md | 35 +++++++++++++++++++
TUTORIALS.md | 9 +++++
docs/api/example_based/api_example_based.md | 25 ++++++-------
.../label_aware_counter_factuals.md | 4 +--
.../counterfactuals/naive_counter_factuals.md | 4 +--
docs/api/example_based/semifactuals/kleor.md | 4 +--
.../example_based/similar_examples/cole.md | 4 +--
.../similar_examples/similar_examples.md | 4 +--
docs/index.md | 30 ++++++++++++++++
docs/tutorials.md | 9 +++++
10 files changed, 106 insertions(+), 22 deletions(-)
diff --git a/README.md b/README.md
index 56fc5746..b4d4b719 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,9 @@
+
+
+
@@ -41,8 +44,13 @@
Feature Visualization
·
Metrics
+ .
+ Example-based
+> [!IMPORTANT]
+> With the release of Keras 3.X since TensorFlow 2.16, some methods may not function as expected. We are actively working on a fix. In the meantime, we recommend using TensorFlow 2.15 or earlier versions for optimal compatibility.
+
The library is composed of several modules, the _Attributions Methods_ module implements various methods (e.g Saliency, Grad-CAM, Integrated-Gradients...), with explanations, examples and links to official papers.
The _Feature Visualization_ module allows to see how neural networks build their understanding of images by finding inputs that maximize neurons, channels, layers or compositions of these elements.
The _Concepts_ module allows you to extract human concepts from a model and to test their usefulness with respect to a class.
@@ -54,6 +62,9 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
+> [!NOTE]
+> We are proud to announce the release of the _Example-based_ module! This module is dedicated to methods that explain a model by retrieving relevant examples from a dataset. It includes methods that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concepts based methods have a dedicated sections).
+
## 🔥 Tutorials
@@ -110,6 +121,8 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
+
You can find a certain number of [**other practical tutorials just here**](https://github.com/deel-ai/xplique/blob/master/TUTORIALS.md). This section is actively developed and more contents will be
included. We will try to cover all the possible usage of the library, feel free to contact us if you have any suggestions or recommendations towards tutorials you would like to see.
@@ -361,6 +374,28 @@ TF : Tensorflow compatible
+Even though we are only at the early stages, we have also recently added an [Example-based methods](api/example_based/api_example_based/) module. Do not hesitate to give us feedback! Currently, the methods available are summarized in the following table:
+
+
+Table of example-based methods available
+
+| Method | Family | Documentation | Tutorial |
+| --- | --- | --- | --- |
+| `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| | | |
+| `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+||||
+| `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+||||
+| `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | **TODO** |
+| `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | **TODO** |
+| `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | **TODO** |
+
+
+
## 👍 Contributing
Feel free to propose your ideas or come and contribute with us on the Xplique toolbox! We have a specific document where we describe in a simple way how to make your first pull request: [just here](https://github.com/deel-ai/xplique/blob/master/CONTRIBUTING.md).
diff --git a/TUTORIALS.md b/TUTORIALS.md
index 759964ec..e3681cd6 100644
--- a/TUTORIALS.md
+++ b/TUTORIALS.md
@@ -20,6 +20,8 @@ Here is the lists of the available tutorial for now:
| Metrics | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) |
| Concept Activation Vectors | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iuEz46ZjgG97vTBH8p-vod3y14UETvVE) |
| Feature Visualization | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) |
+| Example-Based Methods | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| Prototypes | **TODO** |
## Attributions
@@ -74,3 +76,10 @@ Here is the lists of the available tutorial for now:
| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: |
| Feature Visualization: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) |
| Modern Feature Visualization: MaCo | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) |
+
+## Example-Based Methods
+
+| **Tutorial Name** | Notebook |
+| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Example-Based Methods: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| Prototypes: Getting started | **TODO** |
\ No newline at end of file
diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md
index 6412e674..fd45c8f0 100644
--- a/docs/api/example_based/api_example_based.md
+++ b/docs/api/example_based/api_example_based.md
@@ -1,6 +1,7 @@
# API: Example-based
-- [**Example-based Methods**: Getting started]() **WIP**
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
+- [**TODO: Add the Getting Started on Prototypes**]()
## Context ##
@@ -52,20 +53,20 @@ We can broadly categorize example-based methods into four families: similar exam
??? abstract "Table of example-based methods available"
- | Method | Family | Documentation |
- | --- | --- | --- |
- | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) |
- | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) |
+ | Method | Family | Documentation | Tutorial |
+ | --- | --- | --- | --- |
+ | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
| | | |
- | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) |
- | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) |
+ | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
||||
- | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) |
- | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) |
+ | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
||||
- | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) |
- | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) |
- | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) |
+ | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | **TODO** |
+ | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | **TODO** |
+ | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | **TODO** |
### Parameters ###
diff --git a/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
index 93a20c9b..2701410c 100644
--- a/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
+++ b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md
@@ -2,7 +2,7 @@
- [View colab tutorial]()**WIP** |
+ [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) |
@@ -49,6 +49,6 @@ counterfactuals = lacf.explain(test_samples, test_cf_targets)
## Notebooks
-TODO: Add notebooks
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
{{xplique.example_based.counterfactuals.LabelAwareCounterFactuals}}
\ No newline at end of file
diff --git a/docs/api/example_based/counterfactuals/naive_counter_factuals.md b/docs/api/example_based/counterfactuals/naive_counter_factuals.md
index 93d35307..1982ea8e 100644
--- a/docs/api/example_based/counterfactuals/naive_counter_factuals.md
+++ b/docs/api/example_based/counterfactuals/naive_counter_factuals.md
@@ -2,7 +2,7 @@
- [View colab tutorial]()**WIP** |
+ [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) |
@@ -47,7 +47,7 @@ counterfactuals = ncf.explain(test_samples, test_targets)
## Notebooks
-TODO: Add notebooks
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
{{xplique.example_based.counterfactuals.NaiveCounterFactuals}}
diff --git a/docs/api/example_based/semifactuals/kleor.md b/docs/api/example_based/semifactuals/kleor.md
index 99ad2486..f8aa571c 100644
--- a/docs/api/example_based/semifactuals/kleor.md
+++ b/docs/api/example_based/semifactuals/kleor.md
@@ -2,7 +2,7 @@
- [View colab tutorial]()**WIP** |
+ [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/semifactuals.py) |
@@ -104,7 +104,7 @@ counterfactuals = global_sim_sf["nuns"]
## Notebooks
-TODO: Add the notebook
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
{{xplique.example_based.semifactuals.KLEORSimMiss}}
{{xplique.example_based.semifactuals.KLEORGlobalSim}}
\ No newline at end of file
diff --git a/docs/api/example_based/similar_examples/cole.md b/docs/api/example_based/similar_examples/cole.md
index 004dd7a3..8f717ae5 100644
--- a/docs/api/example_based/similar_examples/cole.md
+++ b/docs/api/example_based/similar_examples/cole.md
@@ -2,7 +2,7 @@
- [View colab tutorial]()**WIP** |
+ [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py) |
@@ -60,7 +60,7 @@ similar_samples = cole.explain(test_samples, test_targets)
## Notebooks
-TODO: Add the notebook
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
{{xplique.example_based.similar_examples.Cole}}
diff --git a/docs/api/example_based/similar_examples/similar_examples.md b/docs/api/example_based/similar_examples/similar_examples.md
index be875a5d..a36eadc4 100644
--- a/docs/api/example_based/similar_examples/similar_examples.md
+++ b/docs/api/example_based/similar_examples/similar_examples.md
@@ -2,7 +2,7 @@
- [View colab tutorial]()**WIP** |
+ [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
[View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py)
@@ -52,6 +52,6 @@ similar_samples = sim_ex.explain(test_samples, test_targets)
# Notebooks
-TODO: Add the notebook
+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
{{xplique.example_based.similar_examples.SimilarExamples}}
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 55320f23..eb6062ca 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,6 +7,9 @@
+
+
+
@@ -41,8 +44,13 @@
Feature Visualization
·
Metrics
+            ·
+ Example-based
+!!! warning
+ With the release of Keras 3.X since TensorFlow 2.16, some methods may not function as expected. We are actively working on a fix. In the meantime, we recommend using TensorFlow 2.15 or earlier versions for optimal compatibility.
+
The library is composed of several modules, the _Attributions Methods_ module implements various methods (e.g Saliency, Grad-CAM, Integrated-Gradients...), with explanations, examples and links to official papers.
The _Feature Visualization_ module allows to see how neural networks build their understanding of images by finding inputs that maximize neurons, channels, layers or compositions of these elements.
The _Concepts_ module allows you to extract human concepts from a model and to test their usefulness with respect to a class.
@@ -54,6 +62,9 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
+!!! info "🔔 **New Module Available!**"
+    We are proud to announce the release of the _Example-based_ module! This module is dedicated to methods that explain a model by retrieving relevant examples from a dataset. It includes methods that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concept-based methods have a dedicated section).
+
## 🔥 Tutorials
??? example "We propose some Hands-on tutorials to get familiar with the library and its api"
@@ -109,6 +120,7 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
- [**Modern Feature Visualization with MaCo**: Getting started](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic)
+ - [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF)
You can find a certain number of [**other practical tutorials just here**](tutorials/). This section is actively developed and more contents will be
included. We will try to cover all the possible usage of the library, feel free to contact us if you have any suggestions or recommendations towards tutorials you would like to see.
@@ -333,6 +345,24 @@ There are 4 modules in Xplique, [Attribution methods](api/attributions/api_attri
TF : Tensorflow compatible
+Even though we are only at the early stages, we have also recently added an [Example-based methods](api/example_based/api_example_based/) module. Do not hesitate to give us feedback! Currently, the methods available are summarized in the following table:
+
+??? abstract "Table of example-based methods available"
+
+ | Method | Family | Documentation | Tutorial |
+ | --- | --- | --- | --- |
+ | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | | | |
+ | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ ||||
+ | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+ ||||
+ | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | **TODO** |
+ | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | **TODO** |
+ | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | **TODO** |
## 👍 Contributing
diff --git a/docs/tutorials.md b/docs/tutorials.md
index 38957e89..0e3e9429 100644
--- a/docs/tutorials.md
+++ b/docs/tutorials.md
@@ -20,6 +20,8 @@ Here is the lists of the availables tutorial for now:
| Metrics | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) |
| Concept Activation Vectors | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iuEz46ZjgG97vTBH8p-vod3y14UETvVE) |
| Feature Visualization | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) |
+| Example-Based Methods | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| Prototypes | **TODO** |
## Attributions
@@ -79,3 +81,10 @@ Here is the lists of the availables tutorial for now:
| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: |
| Feature Visualization: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) |
| Modern Feature Visualization: MaCo | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) |
+
+## Example-Based Methods
+
+| **Tutorial Name** | Notebook |
+| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Example-Based Methods: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) |
+| Prototypes: Getting started | **TODO** |
\ No newline at end of file
From 02b248a47782cc5337c380eaac347a8097cf2ae2 Mon Sep 17 00:00:00 2001
From: Antonin POCHE
Date: Mon, 19 Aug 2024 16:56:51 +0200
Subject: [PATCH 102/138] linting
---
tests/FreeMono.ttf | Bin 0 -> 592752 bytes
tests/example_based/test_image_plot.py | 3 -
xplique/concepts/craft_torch.py | 1 +
xplique/example_based/base_example_method.py | 7 +-
xplique/example_based/counterfactuals.py | 5 +
.../example_based/projections/attributions.py | 6 +-
xplique/example_based/projections/base.py | 23 +-
xplique/example_based/projections/commons.py | 204 ++++++-------
xplique/example_based/projections/hadamard.py | 19 +-
.../example_based/projections/latent_space.py | 8 +-
xplique/example_based/prototypes.py | 1 +
xplique/example_based/search_methods/base.py | 19 +-
.../example_based/search_methods/common.py | 22 +-
xplique/example_based/search_methods/kleor.py | 100 +++++--
xplique/example_based/search_methods/knn.py | 88 +++---
.../search_methods/mmd_critic_search.py | 27 +-
.../search_methods/proto_dash_search.py | 78 +++--
.../search_methods/proto_greedy_search.py | 271 ++++++++++++------
xplique/example_based/semifactuals.py | 81 +++---
xplique/example_based/similar_examples.py | 30 +-
xplique/plots/image.py | 16 +-
21 files changed, 614 insertions(+), 395 deletions(-)
create mode 100644 tests/FreeMono.ttf
diff --git a/tests/FreeMono.ttf b/tests/FreeMono.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..f88bcef9c138ae61473f852d10803c195601d51e
GIT binary patch
literal 592752
zcmeF)4_Kvh{`mjb{kPLhGZPbHGM$vBnaPL^u_0t8gb*4+Lx^n%A%xg5GC~Mp5JGHr
z4?+mBAvUyx5JEF^tdKQq)Xe!lU#D~X^yxC}etzHUcm1yG_i=q5`+nc=Kd<-ee&6@G
z ?CVu;A*{EsdE#fu9MUb>_8fJ9F(q%~nl?*5D8-gipG_x_-W*|DVH&?BP1EVx8`
zYvV+=-)qScNAA1r${EjyZ(y}Zbk(6p?6~t^b~@>y*=J|J>f5%Etxb=f%;GFNtByfh~x;{WUSs})}JnT`R)^DWq)tpWd&)MzYXs8b2+{)lx03+`Dyk+
zcQ^dMmOqq5UPa^UHaurO6KM~3>i@(Zpe?b78G}+KEfOmx$EIwlQ{8Dqa>O!=jCx8j
zS22G_snaraobie5WSEX)&a$IT(=ly(NEU@&luJXBQlCa7S#rpcescGRLK8WVTX%^V
zp-_nEcbmUsdB50Zk)&xRl%5glH=r12;b@$OVKPajyb;=-hbtrHXOT8-SA{5n&eO6s
zQrAAM>-ejoF@A?buo5~x8Vj*^q|~~1ybE+O{Vj_
z>o7Hj*EczTt?T+;&U)@TZfdD(dF%gDsd03Sp1+pf@o$wn_SfjSXbi8^z1H*nt#Vt&
zR6}z;7hWEVsq1uL^DOO3Dh_NGdW
zsdJ{5X_5ASw~S-FzV0+XkFI+Trk1DC#<_;pQl?MTWl^N`u92P#*EsYddB_4Uj(4CkWzpAB!XPN9AnG}dO3GLO>R
z6Mc@-7y6n~<9X*79d@&u^8wA6^Wv^zI++o6OrkZw;^1bJO$Edwp`w-rQK)bsv@P_0)AV$I0zpzqg)O
zmNHJS!_;&2@<^d?^8G&b+`Vg<$2g5y0o_ksTaWc#^{&^{9K1Dtt<*l<1C6ieo(Mf>
z%}IIt;g$Q*rt5j~soywh+`-f{q4&1tr}34Zhu+gyM(TR3bzN)ve+fNHUC(V-Xij>+
zc<1fqr17V-X?Bs}aPFyeEA^$+g$_
zwa^H4{frPxUGKR#T!v`qYkRLrb$y-cb3~s#XF>aQ+;r$OXMbp)(!NumuSY#zgfz^?
zV(5Hr-y3=_YFvGm={#-IXV0F