Merge pull request #84 from aai-institute/infinimol-dev
Support test_io_data for MultiDataEval and eval util
opcode81 authored Feb 29, 2024
2 parents 363d139 + 248d239 commit 7d20514
Showing 1 changed file with 37 additions and 9 deletions: src/sensai/evaluation/eval_util.py
@@ -91,12 +91,12 @@ def create_vector_model_cross_validator(data: InputOutputData,

def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
cross_validator_params: Optional[Dict[str, Any]] = None) \
cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
-> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
if _is_regression(model, is_regression):
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
else:
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)


def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
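The hunk above threads the new `test_io_data` parameter through `create_evaluation_util`. A minimal sketch of how a caller might use it follows; the import paths, the toy pandas frames and the model setup are assumptions for illustration, not part of this commit.

import pandas as pd

from sensai.data import InputOutputData  # assumed import path
from sensai.evaluation.eval_util import create_evaluation_util

# toy frames standing in for real feature/target data
X_train = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0]})
y_train = pd.DataFrame({"y": [0.1, 1.1, 1.9, 3.2]})
X_test = pd.DataFrame({"x": [4.0, 5.0]})
y_test = pd.DataFrame({"y": [4.1, 4.8]})

train_io_data = InputOutputData(X_train, y_train)
test_io_data = InputOutputData(X_test, y_test)

# With test_io_data given, train_io_data is not split for evaluation;
# models are trained on train_io_data and evaluated on the explicit test set.
ev = create_evaluation_util(train_io_data, is_regression=True, test_io_data=test_io_data)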
@@ -576,16 +576,34 @@ class MultiDataModelEvaluation:
def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None):
"""
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
For evaluation or cross-validation, these datasets will usually be split according to the rules
specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data
sets, the io_data will not be split for evaluation, but the test_io_data will be used instead.
:param key_name: a name for the key value used in io_data_dict, which will be used as a column name in result data frames
:param meta_data_dict: a dictionary which maps from a name (same keys as in io_data_dict) to a dictionary, which maps
from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
:param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation==False)
:param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation==True)
:param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
Entries with non-None values will be used to evaluate the models that were trained on the corresponding data set from io_data_dict.
If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g.
if you want to use test data sets for some entries, and splitting of the io_data for others).
If not None, cross-validation cannot be used when calling ``compare_models``.
"""
if test_io_data_dict is not None:
missing_keys = set(io_data_dict).difference(test_io_data_dict)
if len(missing_keys) > 0:
raise ValueError(
"If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys."
f"However, found missing_keys: {missing_keys}")
self.io_data_dict = io_data_dict
self.test_io_data_dict = test_io_data_dict

self.key_name = key_name
self.evaluator_params = evaluator_params
self.cross_validator_params = cross_validator_params
@@ -612,25 +630,34 @@ def compare_models(self,
"""
:param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
in a model with a fixed model name (otherwise results cannot be correctly aggregated)
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
This can only be used if the instance's ``test_io_data_dict`` is None.
:param result_writer: a writer with which to store results; if None, results are not stored
:param write_per_dataset_results: whether to use result_writer (if not None) in order to generate detailed results for each
dataset in a subdirectory named according to the name of the dataset
:param write_csvs: whether to write metrics table to CSV files
:param column_name_for_model_ranking: column name to use for ranking models
:param rank_max: if True, use max for ranking, else min
:param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
all EvalStats objects.
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
datasets (applies only if resultWriter is not None)
datasets (applies only if result_writer is not None)
:param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only
if result_writer is not None).
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
create_metric_distribution_plots is True and result_writer is not None.
:param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that
distribution_plots_cdf is True.
:param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
comparison.
:return: an object containing the full comparison results
"""
if self.test_io_data_dict and use_cross_validation:
raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

all_results_df = pd.DataFrame()
eval_stats_by_model_name = defaultdict(list)
results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
@@ -659,8 +686,9 @@ else:
else:
raise ValueError("The models have to be either all regression models or all classification, not a mixture")

test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
cross_validator_params=self.cross_validator_params)
cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)

if plot_collector is None:
plot_collector = ev.eval_stats_plot_collector
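A sketch of calling `compare_models` on the instance constructed above follows; the model class and its import path are assumptions and may differ across sensai versions, while the constraint that cross-validation cannot be combined with `test_io_data_dict` is taken from the check added in this commit.

from sensai.sklearn.sklearn_regression import SkLearnLinearRegressionVectorRegressionModel  # assumed import path

def make_linear_model():
    # compare_models requires factories that produce models with a fixed name
    return SkLearnLinearRegressionVectorRegressionModel().with_name("LinearRegression")

# use_cross_validation must remain False here, because test_io_data_dict was passed;
# otherwise compare_models raises a ValueError (see the check added above).
comparison = multi_eval.compare_models([make_linear_model], use_cross_validation=False)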
@@ -918,7 +946,7 @@ def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_c
:param result_writer: the result writer
:param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
:param cdf_complementary: whether to plot the complementary cdf, provided that ``cdf`` is True
:param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True
"""
for modelName in self.get_model_names():
eval_stats_collection = self.get_eval_stats_collection(modelName)
