Skip to content

Commit

Permalink
Hotfix to add VEM as soft dependency to erroranalysis (#2290)
Browse files Browse the repository at this point in the history
* vem soft dependency update

* import, flake & isort fixes

* comment fixes

* lint fixes

* comment fix

* comment fixes

* comment update

* added extras gate

* ci py vision refactor
  • Loading branch information
Advitya17 authored Aug 31, 2023
1 parent 225bbf1 commit ffa1acb
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 11 deletions.
21 changes: 19 additions & 2 deletions erroranalysis/erroranalysis/_internal/matrix_filter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import logging
import math
import warnings
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
from sklearn.metrics import multilabel_confusion_matrix
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling

from erroranalysis._internal.cohort_filter import filter_from_cohort
from erroranalysis._internal.constants import (DIFF, PRED_Y, ROW_INDEX, TRUE_Y,
Expand All @@ -21,6 +20,19 @@
metric_to_func)
from raiutils.exceptions import UserConfigValidationException

module_logger = logging.getLogger(__name__)
module_logger.setLevel(logging.INFO)

try:
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling
pytorch_installed = True
except ImportError:
pytorch_installed = False
module_logger.debug("Can't import vision_explanation_methods"
"or underlying torch dependencies, "
"required for Object Detection scenario.")

BIN_THRESHOLD = MatrixParams.BIN_THRESHOLD
CATEGORY1 = 'category1'
CATEGORY2 = 'category2'
Expand Down Expand Up @@ -115,6 +127,11 @@ def compute_matrix_on_dataset(analyzer, features, dataset,
if analyzer.model_task == ModelTask.CLASSIFICATION:
diff = pred_y != true_y
elif analyzer.model_task == ModelTask.OBJECT_DETECTION:
if not pytorch_installed:
raise ModuleNotFoundError(
"User Error: torch & torchvision are not installed "
"and are needed for the Object Detection scenario."
)
diff = [
len(
ErrorLabeling(
Expand Down
21 changes: 19 additions & 2 deletions erroranalysis/erroranalysis/_internal/surrogate_error_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import logging
import numbers
from enum import Enum

Expand All @@ -9,8 +10,6 @@
from lightgbm import Booster, LGBMClassifier, LGBMRegressor
from sklearn.metrics import (mean_absolute_error, mean_squared_error,
median_absolute_error, r2_score)
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling

from erroranalysis._internal.cohort_filter import filter_from_cohort
from erroranalysis._internal.constants import (DIFF, LEAF_INDEX, METHOD,
Expand All @@ -28,6 +27,19 @@
from erroranalysis._internal.utils import is_spark
from raiutils.exceptions import UserConfigValidationException

module_logger = logging.getLogger(__name__)
module_logger.setLevel(logging.INFO)

try:
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling
pytorch_installed = True
except ImportError:
pytorch_installed = False
module_logger.debug("Can't import vision_explanation_methods"
"or underlying torch dependencies, "
"required for Object Detection scenario.")

# imports required for pyspark support
try:
import pyspark.sql.functions as F
Expand Down Expand Up @@ -310,6 +322,11 @@ def get_surrogate_booster_local(filtered_df, analyzer, is_model_analyzer,
if analyzer.model_task == ModelTask.CLASSIFICATION:
diff = pred_y != true_y
elif analyzer.model_task == ModelTask.OBJECT_DETECTION:
if not pytorch_installed:
raise ModuleNotFoundError(
"User Error: torch & torchvision are not installed "
"and are needed for the Object Detection scenario."
)
diff = [
len(
ErrorLabeling(
Expand Down
21 changes: 19 additions & 2 deletions erroranalysis/erroranalysis/analyzer/error_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

"""Defines the BaseAnalyzer, the ModelAnalyzer and PredictionsAnalyzer."""

import logging
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
from sklearn.feature_selection import (mutual_info_classif,
mutual_info_regression)
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling

from erroranalysis._internal.constants import (ErrorCorrelationMethods,
MatrixParams, Metrics,
Expand All @@ -32,6 +31,19 @@
compute_ebm_global_importance, compute_gbm_global_importance)
from erroranalysis.report import ErrorReport

module_logger = logging.getLogger(__name__)
module_logger.setLevel(logging.INFO)

try:
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling
pytorch_installed = True
except ImportError:
pytorch_installed = False
module_logger.debug("Can't import vision_explanation_methods"
"or underlying torch dependencies, "
"required for Object Detection scenario.")

BIN_THRESHOLD = MatrixParams.BIN_THRESHOLD
IMPORTANCES_THRESHOLD = 50000
ROOT_COVERAGE = 100
Expand Down Expand Up @@ -690,6 +702,11 @@ def get_diff(self):
if self._model_task == ModelTask.CLASSIFICATION:
return self.model.predict(self.dataset) != self.true_y
elif self._model_task == ModelTask.OBJECT_DETECTION:
if not pytorch_installed:
raise ModuleNotFoundError(
"User Error: torch & torchvision are not installed "
"and are needed for the Object Detection scenario."
)
pred_y = self.model.predict(self.dataset)
diff = [
len(
Expand Down
2 changes: 1 addition & 1 deletion erroranalysis/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ pytest-mock==3.6.1

requirements-parser==0.2.0
rai_test_utils
interpret-core[required]<=0.3.2
interpret-core[required]<=0.3.2
1 change: 1 addition & 0 deletions erroranalysis/requirements-object-detection.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vision_explanation_methods
3 changes: 1 addition & 2 deletions erroranalysis/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ pandas>=0.25.1,<2.0.0
scipy>=1.4.1
scikit-learn>=0.22.1
lightgbm>=2.0.11
raiutils>=0.4.0
vision_explanation_methods
raiutils>=0.4.0
6 changes: 6 additions & 0 deletions erroranalysis/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
with open('requirements.txt') as f:
install_requires = [line.strip() for line in f]

with open('requirements-object-detection.txt') as f:
extras_require = {
'object_detection': [line.strip() for line in f]
}

setuptools.setup(
name=name, # noqa: F821
version=version, # noqa: F821
Expand All @@ -27,6 +32,7 @@
packages=setuptools.find_packages(),
python_requires='>=3.6',
install_requires=install_requires,
extras_require=extras_require,
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
Expand Down
19 changes: 17 additions & 2 deletions erroranalysis/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

import logging

import numpy as np
import pytest
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling

from erroranalysis._internal.constants import (
Metrics, ModelTask, binary_classification_metrics,
multiclass_classification_metrics, object_detection_metrics,
regression_metrics)
from erroranalysis._internal.metrics import metric_to_func

module_logger = logging.getLogger(__name__)
module_logger.setLevel(logging.INFO)

try:
from vision_explanation_methods.error_labeling.error_labeling import \
ErrorLabeling
vem_installed = True
except ImportError:
vem_installed = False
module_logger.debug("Can't import vision_explanation_methods "
"or underlying torch dependencies, "
"required for Object Detection scenario.")


class TestMetrics:
@pytest.mark.parametrize('metric', binary_classification_metrics)
Expand Down Expand Up @@ -54,6 +67,8 @@ def test_multiclass_classification_metrics(self, metric):
metric_value = metric_to_func[metric](y_true, y_pred)
assert isinstance(metric_value, float)

@pytest.mark.skipif(not vem_installed,
reason="vision_explanation_methods not installed")
@pytest.mark.parametrize('metric', object_detection_metrics)
def test_object_detection_metrics(self, metric):
y_true = np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])
Expand Down

0 comments on commit ffa1acb

Please sign in to comment.