Skip to content

Commit

Permalink
Included warning TreeSHAP background dataset size. (#710)
Browse files Browse the repository at this point in the history
* Included warning TreeSHAP background dataset size.

* Fixed background size when DenseData object returned by summarisation.

* Updated warning to emphasize sampling with replacement.
  • Loading branch information
RobertSamoilescu authored Jul 4, 2022
1 parent e4970f7 commit 5b7931d
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions alibi/explainers/shap_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,7 @@ def reset_predictor(self, predictor: Callable) -> None:
# TODO: Look into pyspark support requirements if requested
# TODO: catboost.Pool not supported for fit stage (due to summarisation) but can do if there is a user need

# Largest background dataset size (number of samples) that the upstream interventional TreeShap
# implementation supports; `fit` logs a warning when the background data has more samples than this.
TREE_SHAP_BACKGROUND_SUPPORTED_SIZE = 100
# Threshold used as a (warning) threshold by the background-data summarisation logic in `fit`.
# NOTE(review): it is larger than TREE_SHAP_BACKGROUND_SUPPORTED_SIZE; this is knowingly left as is.
TREE_SHAP_BACKGROUND_WARNING_THRESHOLD = 1000
# NOTE(review): presumably the accepted model output options for TreeShap - confirm against usage.
TREE_SHAP_MODEL_OUTPUT = ['raw', 'probability', 'probability_doubled', 'log_loss']

Expand Down Expand Up @@ -1159,6 +1160,24 @@ def fit(self, # type: ignore[override]
else:
self._check_inputs(background_data)

# summarisation can return a DenseData object
n_samples = (background_data.data if isinstance(background_data, shap_utils.DenseData)
else background_data).shape[0]

# Warn the user that TreeShap supports only up to TREE_SHAP_BACKGROUND_SUPPORTED_SIZE (100) samples in the
# background dataset. Note that the logic above, which handles the summarisation of the background
# dataset, uses TREE_SHAP_BACKGROUND_WARNING_THRESHOLD (1000) as its (warning) threshold. Although having
# TREE_SHAP_BACKGROUND_WARNING_THRESHOLD > TREE_SHAP_BACKGROUND_SUPPORTED_SIZE is contradictory, we
# leave the logic above untouched. This approach has at least two benefits:
# i) minimal refactoring
# ii) return the correct result if a newer version of shap which fixes the issue is used before we
# update our dependencies in alibi (i.e. just ignore the warning)
if n_samples > TREE_SHAP_BACKGROUND_SUPPORTED_SIZE:
logger.warning(f'The upstream implementation of interventional TreeShap supports only up to '
f'{TREE_SHAP_BACKGROUND_SUPPORTED_SIZE} samples in the background dataset. '
f'A larger background dataset will be sampled with replacement to '
f'{TREE_SHAP_BACKGROUND_SUPPORTED_SIZE} instances.')

perturbation = 'interventional' if background_data is not None else 'tree_path_dependent'
self.background_data = background_data
self._explainer = shap.TreeExplainer(
Expand Down

0 comments on commit 5b7931d

Please sign in to comment.