Skip to content

Commit

Permalink
concepts tests: move random test to debug notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoninPoche committed Oct 8, 2024
1 parent dbf7d79 commit c888d71
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 418 deletions.
180 changes: 4 additions & 176 deletions tests/concepts/test_craft_tf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import numpy as np
import tensorflow as tf
import random
import pytest
import os
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Activation, Flatten, Input
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.layers import Input

from xplique.concepts import CraftTf as Craft
from ..utils import generate_data, generate_model, generate_txt_images_data
from ..utils import download_file

from ..utils import generate_data, generate_model


def test_shape():
Expand Down Expand Up @@ -100,172 +97,3 @@ def test_wrong_layers():
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

def test_classifier():
""" Check the Craft results on a small fake dataset """

input_shape = (64, 64, 3)
nb_labels = 3
nb_samples = 200

# Create a dataset of 'ABC', 'BCD', 'CDE' images
x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples)

# train a small classifier on the dataset
def create_classifier_model(input_shape=(64, 64, 3), output_shape=10):
model = Sequential()
model.add(Input(shape=input_shape))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu', name='relu'))
model.add(Flatten())
model.add(Dense(output_shape))
model.add(Activation('softmax'))
opt = Adam(learning_rate=0.005)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

return model

model = create_classifier_model(input_shape, nb_labels)

tf.random.set_seed(0)
np.random.seed(0)
random.seed(0)

# Retrieve checkpoints
checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_tf.ckpt"
if not os.path.exists(f"{checkpoint_path}.index"):
os.makedirs("tests/concepts/checkpoints/", exist_ok=True)
identifier = "1NLA7x2EpElzEEmyvFQhD6VS6bMwS_bCs"
download_file(identifier, f"{checkpoint_path}.index")

identifier = "1wDi-y9b-3I_a-ZtqRlfuib-D7Ox4j8pX"
download_file(identifier, f"{checkpoint_path}.data-00000-of-00001")

model.load_weights(checkpoint_path)

acc = np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples
assert acc == 1.0

# cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
cut_layer = model.get_layer('relu')
g = tf.keras.Model(model.inputs, cut_layer.output)
h = tf.keras.Model(Input(tensor=cut_layer.output), model.outputs)

assert np.all(g(x) >= 0.0)

# Init Craft on the full dataset
craft = Craft(input_to_latent_model = g,
latent_to_logit_model = h,
number_of_concepts = 3,
patch_size = 12,
batch_size = 32)

# Expected best crop for class 0 (ABC) is AB
AB_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 0 0 0 0 1
1 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 0
0 1 1 0 0 0 1 0 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
"""
AB = np.genfromtxt(AB_str.splitlines())

# Expected best crop for class 1 (BCD) is BC
BC_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 0 0 0 0 1 1
1 0 0 0 1 1 0 0 0 1 1 0
1 0 0 0 0 1 0 0 0 1 0 0
1 0 0 0 1 1 0 0 0 1 0 0
1 1 1 1 1 1 0 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 0 0
1 0 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 1 0
1 1 1 1 1 1 0 0 0 0 1 1
"""
BC = np.genfromtxt(BC_str.splitlines())

# Expected best crop for class 2 (CDE) is DE
DE_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 1 1 1 1 1 1 1 1 0
1 1 0 0 0 1 0 0 0 0 1 0
0 1 0 0 0 1 0 0 0 0 1 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 1 0 0 1 1 1 1 0 0 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 0 0 0 1 0 0 0 0 1 1
1 1 0 0 0 1 0 0 0 0 1 1
1 0 0 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE = np.genfromtxt(DE_str.splitlines())

DE2_str = """
0 0 0 0 0 0 0 0 0 0 0 Z
1 1 1 1 0 0 1 1 1 1 1 1
0 0 1 1 1 0 0 0 1 0 0 0
0 0 0 0 1 0 0 0 1 0 0 0
0 0 0 0 1 1 0 0 1 0 0 1
0 0 0 0 1 1 0 0 1 1 1 1
0 0 0 0 1 1 0 0 1 0 0 1
0 0 0 0 1 0 0 0 1 0 0 0
0 0 1 1 1 0 0 0 1 0 0 0
1 1 1 1 0 0 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE2 = np.genfromtxt(DE2_str.splitlines())

expected_best_crops = [[AB], [BC], [DE, DE2]]
expected_best_crops_names = ['AB', 'BC', 'DE']

# Run 3 Craft studies on each class, and in each case check if the best crop is the expected one
class_check = [False, False, False]
for class_id in range(3):
# Focus on class class_id
# Selecting subset for class {class_id} : {labels_str[class_id]}'
x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]

# fit craft on the selected class
crops, crops_u, w = craft.fit(x_subset, class_id)

# compute importances
importances = craft.estimate_importance()
assert importances[0] > 0.8

# find the best crop and compare it to the expected best crop
most_important_concepts = np.argsort(importances)[::-1]

# Find the best crop for the most important concept
c_id = most_important_concepts[0]
best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
best_crop = np.array(crops)[best_crops_ids[0]]

# Compare this best crop to the expectation
predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
for expected_best_crop in expected_best_crops[class_id]:
expected_best_crop = expected_best_crop.astype(np.uint8)

comparison = predicted_best_crop == expected_best_crop
acc = np.sum(comparison) / len(comparison.ravel())
check = acc > 0.9
if check:
class_check[class_id] = True
break
assert np.all(class_check)
178 changes: 0 additions & 178 deletions tests/concepts/test_craft_torch.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
import random

from xplique.concepts import CraftTorch as Craft
from ..utils import generate_txt_images_data
from ..utils import download_file

def generate_torch_data(x_shape=(3, 32, 32), num_labels=10, samples=100):
x = torch.tensor(np.random.rand(samples, *x_shape).astype(np.float32))
Expand Down Expand Up @@ -133,177 +129,3 @@ def test_wrong_layers():
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

def test_classifier():
""" Check the Craft results on a small fake dataset """

input_shape = (64, 64, 3)
nb_labels = 3
nb_samples = 200

torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
random.seed(0)
np.random.seed(0)

# Create a dataset of 'ABC', 'BCD', 'CDE' images
x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples)
x = np.moveaxis(x, -1, 1) # reorder the axis to match torch format
x, y = torch.Tensor(x), torch.Tensor(y)

# train a small classifier on the dataset
def create_torch_classifier_model(input_shape=(3, 64, 64), output_shape=10):
flatten_size = 6*(input_shape[1]-3)*(input_shape[2]-3)
model = nn.Sequential(
nn.Conv2d(3, 6, kernel_size=(2, 2)),
nn.ReLU(),
nn.Conv2d(6, 6, kernel_size=(2, 2)),
nn.ReLU(),
nn.Conv2d(6, 6, kernel_size=(2, 2)),
nn.ReLU(),
nn.Flatten(1, -1),
# nn.Dropout(p=0.2),
nn.Linear(flatten_size, output_shape))
for layer in model:
if isinstance(layer, nn.Conv2d):
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
layer.bias.data.fill_(0.01)
elif isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight)
layer.bias.data.fill_(0.01)
return model

model = create_torch_classifier_model((input_shape[-1], *input_shape[0:2]), nb_labels)

# Retrieve checkpoints
checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_torch.ckpt"
if not os.path.exists(checkpoint_path):
os.makedirs("tests/concepts/checkpoints/", exist_ok=True)
identifier = "1vz6hMibMEN6_t9yAY9SS4iaMY8G8aAPQ"
download_file(identifier, checkpoint_path)
model.load_state_dict(torch.load(checkpoint_path))

# check accuracy
model.eval()
acc = torch.sum(torch.argmax(model(x), axis=1) == torch.argmax(y, axis=1))/len(y)
assert acc > 0.9

# cut pytorch model
g = nn.Sequential(*(list(model.children())[:6])) # input to penultimate layer
h = nn.Sequential(*(list(model.children())[6:])) # penultimate layer to logits
assert torch.all(g(x) >= 0.0)

# Init Craft on the full dataset
craft = Craft(input_to_latent_model = g,
latent_to_logit_model = h,
number_of_concepts = 3,
patch_size = 12,
batch_size = 32,
device='cpu')

# Expected best crop for class 0 (ABC) is AB
AB_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 1 1 1 1 1 1 1 1 0 0 0
0 0 0 1 0 0 0 1 1 0 0 1
0 0 0 1 0 0 0 0 1 0 0 1
0 0 0 1 0 0 0 1 1 0 0 1
0 0 0 1 1 1 1 1 1 0 0 1
0 0 0 1 0 0 0 0 1 1 0 1
0 0 0 1 0 0 0 0 0 1 0 1
0 0 0 1 0 0 0 0 1 1 0 1
1 1 1 1 1 1 1 1 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
AB = np.genfromtxt(AB_str.splitlines())

# Expected best crop for class 1 (BCD) is BC
BC_str = """
1 1 1 1 1 1 0 0 0 0 1 1
1 0 0 0 1 1 0 0 0 1 1 0
1 0 0 0 0 1 0 0 0 1 0 0
1 0 0 0 1 1 0 0 0 1 0 0
1 1 1 1 1 1 0 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 0 0
1 0 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 1 0
1 1 1 1 1 1 0 0 0 0 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
BC = np.genfromtxt(BC_str.splitlines())

# Expected best crop for class 2 (CDE) is DE
DE_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 1 1 1 1 1 1 1 1 0
1 1 0 0 0 1 0 0 0 0 1 0
0 1 0 0 0 1 0 0 0 0 1 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 1 0 0 1 1 1 1 0 0 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 0 0 0 1 0 0 0 0 1 1
1 1 0 0 0 1 0 0 0 0 1 1
1 0 0 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE = np.genfromtxt(DE_str.splitlines())

DE2_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 0 0 1 1 1 1 1 1
0 0 1 1 1 0 0 0 1 0 0 0
0 0 0 0 1 0 0 0 1 0 0 0
0 0 0 0 1 1 0 0 1 0 0 1
0 0 0 0 1 1 0 0 1 1 1 1
0 0 0 0 1 1 0 0 1 0 0 1
0 0 0 0 1 0 0 0 1 0 0 0
0 0 1 1 1 0 0 0 1 0 0 0
1 1 1 1 0 0 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE2 = np.genfromtxt(DE2_str.splitlines())

expected_best_crops = [[AB], [BC], [DE, DE2]]
expected_best_crops_names = ['AB', 'BC', 'DE']

# Run 3 Craft studies on each class, and in each case check if the best crop is the expected one
class_check = [False, False, False]
for class_id in range(3):
# Focus on class class_id
# Selecting subset for class {class_id} : {labels_str[class_id]}'
x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]

# fit craft on the selected class
crops, crops_u, w = craft.fit(x_subset, class_id)

# compute importances
importances = craft.estimate_importance()
assert np.all(importances >= 0)

# find the best crop and compare it to the expected best crop
most_important_concepts = np.argsort(importances)[::-1]

# Find the best crop for the most important concept
c_id = most_important_concepts[0]
best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
best_crop = np.array(crops)[best_crops_ids[0]]
best_crop = np.moveaxis(best_crop, 0, -1)

# Compare this best crop to the expectation
predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)

# Comparison between expected:
for expected_best_crop in expected_best_crops[class_id]:
expected_best_crop = expected_best_crop.astype(np.uint8)
comparison = predicted_best_crop == expected_best_crop
acc = np.sum(comparison) / len(comparison.ravel())
check = acc > 0.9
if check:
class_check[class_id] = True
break
assert np.all(class_check)
Loading

0 comments on commit c888d71

Please sign in to comment.