Skip to content

Commit

Permalink
fix: address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhaoyang Xie committed Oct 5, 2023
1 parent 1de5b74 commit 4d31f66
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 32 deletions.
2 changes: 1 addition & 1 deletion doc/config_rsmeval.rst.inc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ Defaults to 0.4998.
truncate_outliers *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If this option is set to ``false``, outliers will not be truncated by truncating outliers that are 4 standard deviations away from the mean. Defaults to ``true``.
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
.. _use_thumbnails_rsmeval:
Expand Down
2 changes: 1 addition & 1 deletion doc/config_rsmexplain.rst.inc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ If this option is set to ``false``, the feature values for the responses in ``ba
truncate_outliers *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If this option is set to ``false``, outliers will not be truncated by truncating outliers that are 4 standard deviations away from the mean. Defaults to ``true``.
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
.. _use_wandb_rsmexplain:
Expand Down
2 changes: 1 addition & 1 deletion doc/config_rsmpredict.rst.inc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ A list of column names indicating grouping variables used for generating analyse
truncate_outliers *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If this option is set to ``false``, outliers will not be truncated by truncating outliers that are 4 standard deviations away from the mean. Defaults to ``true``.
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
.. _use_wandb_rsmpredict:
Expand Down
2 changes: 1 addition & 1 deletion doc/config_rsmtool.rst.inc
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ Defaults to 0.4998.

truncate_outliers *(Optional)*
"""""""""""""""""""""""""""""""
If this option is set to ``false``, outliers will not be truncated by truncating outliers that are 4 standard deviations away from the mean. Defaults to ``true``.
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will _not_ be truncated. Defaults to ``true``.
.. _use_scaled_predictions_rsmtool:
Expand Down
19 changes: 14 additions & 5 deletions rsmtool/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ def preprocess_feature(
A set of pre-defined truncation values.
Defaults to ``None``.
truncate_outliers : bool, optional
Truncate outlier values if set in the config file
Whether to truncate outlier values.
Defaults to ``True``.
Returns
Expand Down Expand Up @@ -2554,7 +2554,10 @@ def process_data_rsmpredict(self, config_obj, data_container_obj):
)

(df_features_preprocessed, df_excluded) = self.preprocess_new_data(
df_input, df_feature_info, standardize_features, truncate_outliers
df_input,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)

trim_min = df_postproc_params["trim_min"].values[0]
Expand Down Expand Up @@ -2715,10 +2718,16 @@ def process_data_rsmexplain(self, config_obj, data_container_obj):

# now pre-process all the features that go into the model
(df_background_preprocessed, _) = self.preprocess_new_data(
df_background_preprocessed, df_feature_info, standardize_features, truncate_outliers
df_background_preprocessed,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)
(df_explain_preprocessed, _) = self.preprocess_new_data(
df_explain_preprocessed, df_feature_info, standardize_features, truncate_outliers
df_explain_preprocessed,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)

# set ID column as index for the background and explain feature frames
Expand Down Expand Up @@ -2809,7 +2818,7 @@ def preprocess_new_data(
Defaults to ``True``.
truncate_outliers : bool, optional
Whether the outlier should be truncated prior to prediction.
Whether outlier should be truncated prior to prediction.
Defaults to ``True``.
Returns
Expand Down
89 changes: 67 additions & 22 deletions rsmtool/rsmexplain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,47 @@
from .utils.wandb import init_wandb_run, log_configuration_to_wandb


def verify_config_features(explain_config, rsmtool_config, feature, logger=None):
"""
Verify and update a specific feature in the explanation configuration.
Parameters
----------
explain_config : rsmtool.configuration_parser.Configuration
The Configuration object for rsmexplain module.
rsmtool_config : rsmtool.configuration_parser.Configuration
The Configuration object for rsmtool module.
feature : str
The name of the feature to verify and update.
logger : logging object optional
A logging object. If ``None`` is passed, get logger from ``__name__``.
Defaults to ``None``.
Returns
-------
rsmtool.configuration_parser.Configuration
The updated explanation Configuration object.
"""
logger = logger if logger else logging.getLogger(__name__)

rsmexplain_feature = explain_config[feature]
rsmtool_feature = rsmtool_config[feature]

# use the original rsmtool experiment's value for either `standardize_features`
# or `truncate_outliers` for rsmexplain as well; raise a warning if the values
# were different to begin with
if rsmexplain_feature != rsmtool_feature:
logger.warning(
f"overwriting current {feature} value "
f"({rsmexplain_feature}) to match "
f"value specified in original rsmtool experiment "
f"({rsmtool_feature})."
)
explain_config[feature] = rsmtool_feature
return explain_config


def select_examples(featureset, range_size=None):
"""
Sample examples from the given featureset and return indices.
Expand Down Expand Up @@ -133,7 +174,11 @@ def mask(learner, featureset, feature_range=None):


def generate_explanation(
config_file_or_obj_or_dict, output_dir, overwrite_output=False, logger=None, wandb_run=None
config_file_or_obj_or_dict,
output_dir,
overwrite_output=False,
logger=None,
wandb_run=None,
):
"""
Generate a shap.Explanation object.
Expand Down Expand Up @@ -268,32 +313,22 @@ def generate_explanation(
# read the original rsmtool configuration file, if it exists, and figure
# out the value of `standardize_features` that was specified when running
# the original rsmtool experiment
rsmexplain_standardize_features = configuration["standardize_features"]
expected_config_file_path = join(experiment_output_dir, f"{experiment_id}_rsmtool.json")
if exists(expected_config_file_path):
with open(expected_config_file_path, "r") as rsmtool_configfh:
rsmtool_config = json.load(rsmtool_configfh)
rsmtool_standardize_features = rsmtool_config["standardize_features"]

# use the original rsmtool experiment's value for `standardize_features`
# for rsmexplain as well; raise a warning if the values were different
# to begin with
if rsmexplain_standardize_features != rsmtool_standardize_features:
logger.warning(
f"overwriting current `standardize_features` value "
f"({rsmexplain_standardize_features}) to match "
f"value specified in original rsmtool experiment "
f"({rsmtool_standardize_features})."
)
configuration["standardize_features"] = rsmtool_standardize_features
for feature in ["standardize_features", "truncate_outliers"]:
configuration = verify_config_features(
configuration, rsmtool_config, feature, logger
)

# if the original experiment rsmtool does not exist, let the user know
else:
logger.warning(
f"cannot locate original rsmtool configuration; "
f"ensure that current value of "
f"`standardize_features` ({rsmexplain_standardize_features}) "
f"was the same when running rsmtool."
"cannot locate original rsmtool configuration; "
"ensure that current value of "
"`standardize_features` and `truncate_outliers`"
"were the same when running rsmtool."
)

# load the background and explain data sets
Expand Down Expand Up @@ -547,7 +582,12 @@ def main():
# or one of the valid optional arguments, then assume that they
# are arguments for the "run" sub-command. This allows the
# old style command-line invocations to work without modification.
if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + ["-h", "--help", "-V", "--version"]:
if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + [
"-h",
"--help",
"-V",
"--version",
]:
args_to_pass = ["run"] + sys.argv[1:]
else:
args_to_pass = sys.argv[1:]
Expand All @@ -561,7 +601,9 @@ def main():
logger.info(f"Output directory: {args.output_dir}")

generate_explanation(
abspath(args.config_file), abspath(args.output_dir), overwrite_output=args.force_write
abspath(args.config_file),
abspath(args.output_dir),
overwrite_output=args.force_write,
)

else:
Expand All @@ -570,7 +612,10 @@ def main():

# auto-generate an example configuration and print it to STDOUT
generator = ConfigurationGenerator(
"rsmexplain", as_string=True, suppress_warnings=args.quiet, use_subgroups=False
"rsmexplain",
as_string=True,
suppress_warnings=args.quiet,
use_subgroups=False,
)
configuration = (
generator.interact(output_file_name=args.output_file.name if args.output_file else None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"exclude_zero_scores": true,
"select_transformations": false,
"standardize_features": true,
"truncate_outliers": true,
"use_thumbnails": false,
"use_truncation_thresholds": false,
"predict_expected_scores": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"exclude_zero_scores": true,
"select_transformations": false,
"standardize_features": false,
"truncate_outliers": true,
"use_thumbnails": false,
"use_truncation_thresholds": false,
"predict_expected_scores": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"exclude_zero_scores": true,
"select_transformations": false,
"standardize_features": true,
"truncate_outliers": true,
"use_thumbnails": false,
"use_truncation_thresholds": false,
"predict_expected_scores": false,
Expand Down
26 changes: 25 additions & 1 deletion tests/test_explanation_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from os import environ
from os.path import join
Expand All @@ -9,7 +10,7 @@
from skll.learner import Learner

from rsmtool.modeler import Modeler
from rsmtool.rsmexplain import mask, select_examples
from rsmtool.rsmexplain import mask, select_examples, verify_config_features

# allow test directory to be set via an environment variable
# which is needed for package testing
Expand Down Expand Up @@ -205,3 +206,26 @@ def test_mask_from_learner_on_disk(self):
computed_ids, computed_features = mask(model, background, feature_range=[5, 10])
self.assertEqual(computed_ids, expected_ids)
assert_array_equal(computed_features, expected_features)

def test_verify_config_features(self):
"""Test verify_config_features when features are different."""
experiment_path = join(rsmtool_test_dir, "data", "experiments", "knn-explain-diff-std")
rsmtool_config_path = join(
experiment_path,
"existing_experiment",
"output",
"knn_diff_std_rsmtool.json",
)
rsmexplain_config_path = join(experiment_path, "rsmexplain.json")
expected_output = True

with open(rsmtool_config_path) as rsmtool_configfh, open(
rsmexplain_config_path
) as rsmexplain_configfh:
rsmexplain_config = json.load(rsmexplain_configfh)
rsmtool_config = json.load(rsmtool_configfh)
computed_config = verify_config_features(
rsmexplain_config, rsmtool_config, "standardize_features"
)

self.assertEqual(computed_config["standardize_features"], expected_output)

0 comments on commit 4d31f66

Please sign in to comment.