
660/truncate feature values #661

Merged
merged 13 commits into from
Oct 6, 2023
6 changes: 6 additions & 0 deletions doc/config_rsmexplain.rst.inc
@@ -95,6 +95,12 @@ If this option is set to ``false``, the feature values for the responses in ``ba

If ``experiment_dir`` contains the rsmtool configuration file, that file's value for ``standardize_features`` will override the value specified by the user. The reason is that if ``rsmtool`` trained the model with (or without) standardized features, then ``rsmexplain`` must do the same for the explanations to be meaningful.

.. _truncate_outliers_rsmexplain:

truncate_outliers *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will *not* be truncated. Defaults to ``true``.
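For illustration, a minimal configuration fragment disabling truncation might look like the sketch below. Only ``truncate_outliers`` and ``experiment_dir`` are the options discussed here; the remaining field names and values are hypothetical placeholders, not a complete or authoritative rsmexplain configuration.

```json
{
    "experiment_id": "explain_example",
    "experiment_dir": "path/to/rsmtool/output",
    "truncate_outliers": false
}
```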

.. _use_wandb_rsmexplain:

use_wandb *(Optional)*
6 changes: 6 additions & 0 deletions doc/config_rsmpredict.rst.inc
@@ -74,6 +74,12 @@ subgroups *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~
A list of column names indicating grouping variables used for generating analyses specific to each of those defined subgroups. For example, ``["prompt", "gender", "native_language", "test_country"]``. All these columns will be included in the predictions file under their original names.

.. _truncate_outliers_rsmpredict:

truncate_outliers *(Optional)*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will *not* be truncated. Defaults to ``true``.

.. _use_wandb_rsmpredict:

use_wandb *(Optional)*
6 changes: 6 additions & 0 deletions doc/config_rsmtool.rst.inc
@@ -355,6 +355,12 @@ Defaults to 0.4998.

For more fine-grained control over the trimming range, you can set ``trim_tolerance`` to ``0`` and use ``trim_min`` and ``trim_max`` to specify the exact floor and ceiling values.

.. _truncate_outliers:

truncate_outliers *(Optional)*
"""""""""""""""""""""""""""""""
If this option is set to ``false``, outliers (values more than 4 standard deviations away from the mean) in feature columns will *not* be truncated. Defaults to ``true``.

.. _use_scaled_predictions_rsmtool:

use_scaled_predictions *(Optional)*
1 change: 1 addition & 0 deletions doc/config_rsmxval.rst.inc
@@ -83,6 +83,7 @@ In addition to the fields described so far, an ``rsmxval`` configuration file al
- ``trim_max``
- ``trim_min``
- ``trim_tolerance``
- ``truncate_outliers``
- ``use_scaled_predictions``
- ``use_thumbnails``
- ``use_truncation_thresholds``
3 changes: 2 additions & 1 deletion examples/rsmtool/config_rsmtool.json
@@ -11,5 +11,6 @@
"trim_max": 6,
"id_column": "ID",
"second_human_score_column": "score2",
"length_column": "LENGTH"
"length_column": "LENGTH",
"standardize_features": false
}
66 changes: 54 additions & 12 deletions rsmtool/preprocessor.py
@@ -1020,6 +1020,7 @@ def preprocess_feature(
exclude_zero_sd=False,
raise_error=True,
truncations=None,
truncate_outliers=True,
):
"""
Remove outliers and transform the values in given numpy array.
@@ -1050,6 +1051,9 @@
truncations : pandas DataFrame, optional
A set of pre-defined truncation values.
Defaults to ``None``.
truncate_outliers : bool, optional
Whether to truncate outlier values.
Defaults to ``True``.

Returns
-------
@@ -1063,16 +1067,21 @@
If the preprocessed feature values have zero standard deviation
and ``exclude_zero_sd`` is set to ``True``.
"""
if truncations is not None:
# clamp outlier values using the truncations set
features_no_outliers = self.remove_outliers_using_truncations(
values, feature_name, truncations
)
if truncate_outliers:
if truncations is not None:
# clamp outlier values using the truncations set
features_no_outliers = self.remove_outliers_using_truncations(
values, feature_name, truncations
)

else:
# clamp any outlier values that are 4 standard deviations
# away from the mean
features_no_outliers = self.remove_outliers(
values, mean=feature_mean, sd=feature_sd
)
else:
# clamp any outlier values that are 4 standard deviations
# away from the mean
features_no_outliers = self.remove_outliers(values, mean=feature_mean, sd=feature_sd)
features_no_outliers = values
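The clamping behavior described in the comments above can be sketched in isolation as follows. This is a standalone illustration of 4-standard-deviation truncation, not rsmtool's actual `remove_outliers` implementation; the function name and defaults here are hypothetical.

```python
import numpy as np

def clamp_outliers(values, mean=None, sd=None, sd_multiplier=4):
    """Truncate values lying more than `sd_multiplier` SDs from the mean."""
    values = np.asarray(values, dtype=float)
    # fall back to statistics computed from the data itself when no
    # pre-computed training mean/sd are supplied
    mean = values.mean() if mean is None else mean
    sd = values.std(ddof=1) if sd is None else sd
    floor, ceiling = mean - sd_multiplier * sd, mean + sd_multiplier * sd
    # clamp (rather than drop) out-of-range values to the boundaries
    return np.clip(values, floor, ceiling)

# with a training mean of 0 and sd of 1, a value of 10 is clamped to 4
print(clamp_outliers([0.0, 10.0], mean=0, sd=1))  # [0. 4.]
```

Passing ``truncate_outliers=False`` skips this step entirely, leaving the raw feature values untouched before transformation.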

# apply the requested transformation to the feature
transformed_feature = FeatureTransformer().transform_feature(
@@ -1105,6 +1114,7 @@ def preprocess_features(
df_feature_specs,
standardize_features=True,
use_truncations=False,
truncate_outliers=True,
):
"""
Preprocess features in given data using corresponding specifications.
@@ -1128,11 +1138,15 @@
standardize_features : bool, optional
Whether to standardize the features.
Defaults to ``True``.
truncate_outliers : bool, optional
Whether to truncate outlier values.
Defaults to ``True``.
use_truncations : bool, optional
Whether we should use the truncation set
for removing outliers.
Defaults to ``False``.


Returns
-------
df_train_preprocessed : pandas DataFrame
@@ -1178,6 +1192,7 @@
train_feature_sd,
exclude_zero_sd=True,
truncations=truncations,
truncate_outliers=truncate_outliers,
)

testing_feature_values = df_test[feature_name].values
@@ -1188,6 +1203,7 @@
train_feature_mean,
train_feature_sd,
truncations=truncations,
truncate_outliers=truncate_outliers,
)

# Standardize the features using the mean and sd computed on the
@@ -1708,6 +1724,9 @@ def process_data_rsmtool(self, config_obj, data_container_obj):
# should we standardize the features
standardize_features = config_obj["standardize_features"]

# should outliers be truncated?
truncate_outliers = config_obj.get("truncate_outliers", True)

# if we are excluding zero scores but trim_min
# is set to 0, then we need to warn the user
if exclude_zero_scores and spec_trim_min == 0:
@@ -1973,6 +1992,7 @@ def process_data_rsmtool(self, config_obj, data_container_obj):
feature_specs,
standardize_features,
use_truncations,
truncate_outliers,
)

# configuration options that either override previous values or are
@@ -2471,6 +2491,9 @@ def process_data_rsmpredict(self, config_obj, data_container_obj):
# should features be standardized?
standardize_features = config_obj.get("standardize_features", True)

# should outliers be truncated?
truncate_outliers = config_obj.get("truncate_outliers", True)

# should we predict expected scores
predict_expected_scores = config_obj["predict_expected_scores"]

@@ -2531,7 +2554,10 @@
)

(df_features_preprocessed, df_excluded) = self.preprocess_new_data(
df_input, df_feature_info, standardize_features
df_input,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)

trim_min = df_postproc_params["trim_min"].values[0]
@@ -2646,6 +2672,9 @@ def process_data_rsmexplain(self, config_obj, data_container_obj):
# should features be standardized?
standardize_features = config_obj.get("standardize_features", True)

# should outliers be truncated?
truncate_outliers = config_obj.get("truncate_outliers", True)

# rename the ID columns in both frames
df_background_preprocessed = self.rename_default_columns(
df_background_features,
@@ -2689,10 +2718,16 @@

# now pre-process all the features that go into the model
(df_background_preprocessed, _) = self.preprocess_new_data(
df_background_preprocessed, df_feature_info, standardize_features
df_background_preprocessed,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)
(df_explain_preprocessed, _) = self.preprocess_new_data(
df_explain_preprocessed, df_feature_info, standardize_features
df_explain_preprocessed,
df_feature_info,
standardize_features=standardize_features,
truncate_outliers=truncate_outliers,
)

# set ID column as index for the background and explain feature frames
@@ -2748,7 +2783,9 @@ def process_data(self, config_obj, data_container_obj, context="rsmtool"):
f"'rsmeval', 'rsmpredict', 'rsmexplain']. You specified `{context}`."
)

def preprocess_new_data(self, df_input, df_feature_info, standardize_features=True):
def preprocess_new_data(
self, df_input, df_feature_info, standardize_features=True, truncate_outliers=True
):
"""
Preprocess feature values using the parameters in ``df_feature_info``.

@@ -2780,6 +2817,10 @@ def preprocess_new_data(self, df_input, df_feature_info, standardize_features=Tr
Whether the features should be standardized prior to prediction.
Defaults to ``True``.

truncate_outliers : bool, optional
Whether outliers should be truncated prior to prediction.
Defaults to ``True``.

Returns
-------
df_features_preprocessed : pandas DataFrame
@@ -2881,6 +2922,7 @@ def preprocess_new_data(self, df_input, df_feature_info, standardize_features=Tr
train_feature_sd,
exclude_zero_sd=False,
raise_error=False,
truncate_outliers=truncate_outliers,
)

# filter the feature values once again to remove possible NaN and inf values that
64 changes: 38 additions & 26 deletions rsmtool/rsmexplain.py
@@ -133,7 +133,11 @@ def mask(learner, featureset, feature_range=None):


def generate_explanation(
config_file_or_obj_or_dict, output_dir, overwrite_output=False, logger=None, wandb_run=None
config_file_or_obj_or_dict,
output_dir,
overwrite_output=False,
logger=None,
wandb_run=None,
):
"""
Generate a shap.Explanation object.
@@ -265,35 +269,33 @@
f"generated during model training."
)

# read the original rsmtool configuration file, if it exists, and figure
# out the value of `standardize_features` that was specified when running
# the original rsmtool experiment
rsmexplain_standardize_features = configuration["standardize_features"]
# read the original rsmtool configuration file, if it exists, and ensure
# that we use its value of `standardize_features` and `truncate_outliers`
# even if that means we have to override the values specified in the
# rsmexplain configuration file
expected_config_file_path = join(experiment_output_dir, f"{experiment_id}_rsmtool.json")
if exists(expected_config_file_path):
with open(expected_config_file_path, "r") as rsmtool_configfh:
rsmtool_config = json.load(rsmtool_configfh)
rsmtool_standardize_features = rsmtool_config["standardize_features"]

# use the original rsmtool experiment's value for `standardize_features`
# for rsmexplain as well; raise a warning if the values were different
# to begin with
if rsmexplain_standardize_features != rsmtool_standardize_features:
logger.warning(
f"overwriting current `standardize_features` value "
f"({rsmexplain_standardize_features}) to match "
f"value specified in original rsmtool experiment "
f"({rsmtool_standardize_features})."
)
configuration["standardize_features"] = rsmtool_standardize_features
rsmtool_configuration = json.load(rsmtool_configfh)

for option in ["standardize_features", "truncate_outliers"]:
rsmtool_value = rsmtool_configuration[option]
rsmexplain_value = configuration[option]
if rsmexplain_value != rsmtool_value:
logger.warning(
f"overwriting current `{option}` value "
f"({rsmexplain_value}) to match "
f"value specified in original rsmtool experiment "
f"({rsmtool_value})."
)
configuration[option] = rsmtool_value

# if the original experiment rsmtool does not exist, let the user know
else:
logger.warning(
f"cannot locate original rsmtool configuration; "
f"ensure that current value of "
f"`standardize_features` ({rsmexplain_standardize_features}) "
f"was the same when running rsmtool."
"cannot locate original rsmtool configuration; "
"ensure that the values of `standardize_features` "
"and `truncate_outliers` were the same as when running rsmtool."
)

# load the background and explain data sets
@@ -547,7 +549,12 @@ def main():
# or one of the valid optional arguments, then assume that they
# are arguments for the "run" sub-command. This allows the
# old style command-line invocations to work without modification.
if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + ["-h", "--help", "-V", "--version"]:
if sys.argv[1] not in VALID_PARSER_SUBCOMMANDS + [
"-h",
"--help",
"-V",
"--version",
]:
args_to_pass = ["run"] + sys.argv[1:]
else:
args_to_pass = sys.argv[1:]
@@ -561,7 +568,9 @@
logger.info(f"Output directory: {args.output_dir}")

generate_explanation(
abspath(args.config_file), abspath(args.output_dir), overwrite_output=args.force_write
abspath(args.config_file),
abspath(args.output_dir),
overwrite_output=args.force_write,
)

else:
@@ -570,7 +579,10 @@

# auto-generate an example configuration and print it to STDOUT
generator = ConfigurationGenerator(
"rsmexplain", as_string=True, suppress_warnings=args.quiet, use_subgroups=False
"rsmexplain",
as_string=True,
suppress_warnings=args.quiet,
use_subgroups=False,
)
configuration = (
generator.interact(output_file_name=args.output_file.name if args.output_file else None)
Expand Down