From d0047f140f26b8bae736217637ea88a51c49a628 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 11 May 2022 09:33:54 -0400 Subject: [PATCH 01/22] fix error on machines with pyspark installed where passed dataframe is not spark pandas (#1415) --- erroranalysis/erroranalysis/_internal/utils.py | 7 +++++-- erroranalysis/erroranalysis/version.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/erroranalysis/erroranalysis/_internal/utils.py b/erroranalysis/erroranalysis/_internal/utils.py index f2bc3ba7e5..11de548740 100644 --- a/erroranalysis/erroranalysis/_internal/utils.py +++ b/erroranalysis/erroranalysis/_internal/utils.py @@ -6,7 +6,7 @@ import numpy as np try: - import pyspark + import pyspark.pandas as ps spark_available = True except ImportError: spark_available = False @@ -47,4 +47,7 @@ def is_spark(df): :return: True if the dataframe is a spark dataframe, False otherwise. :rtype: bool """ - return spark_available and isinstance(df, pyspark.pandas.frame.DataFrame) + try: + return spark_available and isinstance(df, ps.frame.DataFrame) + except Exception: + return False diff --git a/erroranalysis/erroranalysis/version.py b/erroranalysis/erroranalysis/version.py index 3ef520e4ea..72eb56e0b2 100644 --- a/erroranalysis/erroranalysis/version.py +++ b/erroranalysis/erroranalysis/version.py @@ -4,5 +4,5 @@ name = 'erroranalysis' _major = '0' _minor = '3' -_patch = '0' +_patch = '1' version = '{}.{}.{}'.format(_major, _minor, _patch) From a4a46d509eeb7087dc22cb9f65d8c988140265e6 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 11 May 2022 17:32:18 -0400 Subject: [PATCH 02/22] add postbuild branch trigger (#1417) --- .azure-devops/component-governance.yml | 1 + .github/workflows/CD.yml | 2 +- .github/workflows/CI-codescan.yml | 4 ++-- .github/workflows/CI-notebook.yml | 4 ++-- .github/workflows/CI-python.yml | 4 ++-- .github/workflows/CI-typescript.yml | 2 +- .github/workflows/Ci-raiwigets-python-typescript.yml | 4 ++-- .github/workflows/GitHubPages.yml | 2 +- .github/workflows/python-linting.yml | 4 ++-- 9 files changed, 14 insertions(+), 13 deletions(-) diff --git a/.azure-devops/component-governance.yml b/.azure-devops/component-governance.yml index 3f2d1c979c..b9b144efe7 100644 --- a/.azure-devops/component-governance.yml +++ b/.azure-devops/component-governance.yml @@ -2,6 +2,7 @@ trigger: - main + - postbuild pool: vmImage: "ubuntu-latest" diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml index c348576729..8ae0537e41 100644 --- a/.github/workflows/CD.yml +++ b/.github/workflows/CD.yml @@ -19,7 +19,7 @@ on: push: branches: [main] pull_request: - branches: [main] + branches: [main, postbuild] workflow_dispatch: jobs: diff --git a/.github/workflows/CI-codescan.yml b/.github/workflows/CI-codescan.yml index f93e6ffcca..8cdd13c8d9 100644 --- a/.github/workflows/CI-codescan.yml +++ b/.github/workflows/CI-codescan.yml @@ -2,9 +2,9 @@ name: CI code scan on: push: - branches: [main] + branches: [main, postbuild] pull_request: - branches: [main] + branches: [main, postbuild] jobs: analyze: diff --git a/.github/workflows/CI-notebook.yml b/.github/workflows/CI-notebook.yml index 037652a6d0..e12e0148ef 100644 --- a/.github/workflows/CI-notebook.yml +++ b/.github/workflows/CI-notebook.yml @@ -2,9 +2,9 @@ name: CI Notebooks on: push: - branches: [main] + branches: [main, postbuild] pull_request: - branches: [main] + branches: [main, postbuild] jobs: ci-notebook: diff --git a/.github/workflows/CI-python.yml b/.github/workflows/CI-python.yml index 385e236587..97377ae26f 100644 --- a/.github/workflows/CI-python.yml +++ b/.github/workflows/CI-python.yml @@ -2,9 +2,9 @@ name: CI Python on: push: - branches: [main] + branches: [main, postbuild] pull_request: - branches: [main] + branches: [main, postbuild] jobs: ci-python: diff --git a/.github/workflows/CI-typescript.yml b/.github/workflows/CI-typescript.yml index 9d7b03cc9a..de2c4a28a5 100644 --- a/.github/workflows/CI-typescript.yml +++ b/.github/workflows/CI-typescript.yml @@ -10,7 +10,7 @@ env: on: pull_request: - branches: [main] + branches: [main, postbuild] jobs: ci-typescript: diff --git a/.github/workflows/Ci-raiwigets-python-typescript.yml b/.github/workflows/Ci-raiwigets-python-typescript.yml index 62f41b1025..8c450ceca6 100644 --- a/.github/workflows/Ci-raiwigets-python-typescript.yml +++ b/.github/workflows/Ci-raiwigets-python-typescript.yml @@ -2,9 +2,9 @@ name: CI RAIWidgets Python Typescript on: push: - branches: [main] + branches: [main, postbuild] pull_request: - branches: [main] + branches: [main, postbuild] jobs: ci-raiwidgets-python-typescript: diff --git a/.github/workflows/GitHubPages.yml b/.github/workflows/GitHubPages.yml index ff537ca623..adf9f6c4bb 100644 --- a/.github/workflows/GitHubPages.yml +++ b/.github/workflows/GitHubPages.yml @@ -4,7 +4,7 @@ on: push: branches: [main] pull_request: - branches: [main] + branches: [main, postbuild] jobs: website-build-and-deploy: diff --git a/.github/workflows/python-linting.yml b/.github/workflows/python-linting.yml index 8acf7c5706..0d786cd79c 100644 --- a/.github/workflows/python-linting.yml +++ b/.github/workflows/python-linting.yml @@ -2,9 +2,9 @@ name: Python Linting on: push: - branches: [main] + branches: [main, postbuild] pull_request: - branches: [main] + branches: [main, postbuild] jobs: build: From c53ba4d3ecfe6a02a3b3ed440deffa2ca48b3de6 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Wed, 11 May 2022 18:02:53 -0700 Subject: [PATCH 03/22] Fix causal UI strings according to classification/regression tasks (#1419) * Fix causal UI strings according to classification/regression tasks Signed-off-by: Gaurav Gupta * Fix lint error Signed-off-by: Gaurav Gupta * Fix UI test Signed-off-by: Gaurav Gupta --- .../CausalAggregateView.tsx | 23 +++++++++++++++---- .../modelAssessment/IModelAssessmentData.ts | 2 ++ .../describeAggregateCausalAffects.ts | 4 ++-- .../modelAssessmentDatasets.ts | 8 +++++++ libs/localization/src/lib/en.json | 2 ++ 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx index 484a343d00..95c0fdbd7c 100644 --- a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx +++ b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx @@ -29,6 +29,7 @@ export class CausalAggregateView extends React.PureComponent d2.point - d1.point); + return ( {localization.CausalAnalysis.AggregateView.continuous} - { - localization.CausalAnalysis.AggregateView - .continuousDescription - } + {this.getContinuousDescription()} {localization.CausalAnalysis.AggregateView.binary} - {localization.CausalAnalysis.AggregateView.binaryDescription} + {this.getBinaryDescription()} {localization.CausalAnalysis.AggregateView.lasso}{" "} @@ -107,4 +105,19 @@ export class CausalAggregateView extends React.PureComponent ); } + + private getContinuousDescription(): string { + if (this.context.dataset.task_type === "classification") { + return localization.CausalAnalysis.AggregateView.continuousDescription; + } + return localization.CausalAnalysis.AggregateView + .continuousRegressionDescription; + } + + private getBinaryDescription(): string { + if (this.context.dataset.task_type === "classification") { + return localization.CausalAnalysis.AggregateView.binaryDescription; + } + return localization.CausalAnalysis.AggregateView.regressionDescription; + } } diff --git a/libs/e2e/src/lib/describer/modelAssessment/IModelAssessmentData.ts b/libs/e2e/src/lib/describer/modelAssessment/IModelAssessmentData.ts index 5afdbec047..62858c0f87 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/IModelAssessmentData.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/IModelAssessmentData.ts @@ -111,6 +111,8 @@ export interface ICausalAnalysisData { yAxisPanelOptions?: { [key: string]: string[] }; treatmentPolicyData?: { [key: string]: string[] }; featureListInCausalTable?: string[]; + continuousDescription?: string; + binaryDescription?: string; } export interface IWhatIfCounterfactualsData { diff --git a/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts b/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts index 2879d259b7..4dfdff7bcc 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts @@ -77,11 +77,11 @@ export function describeAggregateCausalAffects( it("should have continuous and binary treatment definitions", () => { cy.get(Locators.CausalAggregateView).should( "contain", - localization.CausalAnalysis.AggregateView.continuousDescription + dataShape.causalAnalysisData?.continuousDescription ); cy.get(Locators.CausalAggregateView).should( "contain", - localization.CausalAnalysis.AggregateView.binaryDescription + dataShape.causalAnalysisData?.binaryDescription ); }); it("should have details about lasso", () => { diff --git a/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts b/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts index d62e13b7de..5ad9e5a2ae 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts @@ -92,6 +92,10 @@ const modelAssessmentDatasets = { }, DiabetesDecisionMaking: { causalAnalysisData: { + binaryDescription: + "On average in this sample, turning on this feature will cause the predictions of the target to increase/decrease by X units.", + continuousDescription: + "On average in this sample, increasing this feature by 1 unit will cause the predictions of the target to increase/decrease by X units.", featureListInCausalTable: ["s2(num)", "bmi(num)", "bp(num)"], hasCausalAnalysisComponent: true }, @@ -318,6 +322,10 @@ const modelAssessmentDatasets = { }, HousingDecisionMaking: { causalAnalysisData: { + binaryDescription: + "On average in this sample, turning on this feature will cause the predictions of the target to increase/decrease by X units.", + continuousDescription: + "On average in this sample, increasing this feature by 1 unit will cause the predictions of the target to increase/decrease by X units.", featureListInCausalTable: [ "GarageCars(num)", "OverallQual(num)", diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index 86f0d462f9..cc8114968e 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -3,12 +3,14 @@ "AggregateView": { "binary": "Binary treatments: ", "binaryDescription": "On average in this sample, turning on this feature will cause the probability of class/label 1 to increase by X units.", + "regressionDescription": "On average in this sample, turning on this feature will cause the predictions of the target to increase/decrease by X units.", "causalPoint": "Causal effect point", "confidenceLower": "Confidence interval (lower)", "confidenceUpper": "Confidence interval (upper)", "confoundingFeature": "A confounding feature is correlated with both the treatment and the outcome of interest. The confounder creates an extra correlation path from the treatment to the outcome on top of the direct causal effect. Unless confounding features are measured and included in the model, these extra correlations can bias estimates of the causal effect. The bias can be positive or negative, depending on the directions of correlations between omitted confounders, treatments, and outcomes.", "continuous": "Continuous treatments: ", "continuousDescription": "On average in this sample, increasing this feature by 1 unit will cause the probability of class/label 1 to increase by X units.", + "continuousRegressionDescription": "On average in this sample, increasing this feature by 1 unit will cause the predictions of the target to increase/decrease by X units.", "description": "Causal analysis answers “what if” questions about how real world outcomes would have changed under different policy choices, such as a different pricing strategy for a product or an alternative treatment for a patient. Unlike model predictions that identify important correlation patterns, these tools help you identify the most important causal features that directly affect your outcome of interest. These models identify the causal effect of one feature (typically referred to as a “treatment”), holding other confounding features constant. For best results, make sure that the full dataset contains all available features that may correlate with the outcome as confounders.", "directAggregate": "Direct aggregate causal effect of each treatment with 95% confidence interval", "here": "here", From fb5c7695fd82f5849a4633a0dfbfe47010e94d29 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 12 May 2022 20:34:44 -0400 Subject: [PATCH 04/22] Fix description for model overview (#1425) * fix description for model overview * keep new description for new model overview --- .../Controls/ModelOverview/ModelOverview.tsx | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx index de3eb0a674..ae3a44e936 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx @@ -213,12 +213,19 @@ export class ModelOverview extends React.Component< className={classNames.sectionStack} tokens={{ childrenGap: "10px" }} > - - {localization.ModelAssessment.ModelOverview.topLevelDescription} - - {!this.props.showNewModelOverviewExperience && } + {!this.props.showNewModelOverviewExperience && ( + <> + + {localization.Interpret.ModelPerformance.helperText} + + + + )} {this.props.showNewModelOverviewExperience && ( + + {localization.ModelAssessment.ModelOverview.topLevelDescription} + Date: Fri, 13 May 2022 09:17:25 -0400 Subject: [PATCH 05/22] fix failing to create error report when filter_features is empty list (#1421) --- .../erroranalysis/analyzer/error_analyzer.py | 2 +- erroranalysis/erroranalysis/version.py | 2 +- erroranalysis/tests/test_error_report.py | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/erroranalysis/erroranalysis/analyzer/error_analyzer.py b/erroranalysis/erroranalysis/analyzer/error_analyzer.py index 51041a81b0..2aebf03381 100644 --- a/erroranalysis/erroranalysis/analyzer/error_analyzer.py +++ b/erroranalysis/erroranalysis/analyzer/error_analyzer.py @@ -334,7 +334,7 @@ def create_error_report(self, num_leaves=num_leaves, min_child_samples=min_child_samples) matrix = None - if filter_features is not None: + if filter_features: matrix = self.compute_matrix(filter_features, None, None) diff --git a/erroranalysis/erroranalysis/version.py b/erroranalysis/erroranalysis/version.py index 72eb56e0b2..8bce628c68 100644 --- a/erroranalysis/erroranalysis/version.py +++ b/erroranalysis/erroranalysis/version.py @@ -4,5 +4,5 @@ name = 'erroranalysis' _major = '0' _minor = '3' -_patch = '1' +_patch = '2' version = '{}.{}.{}'.format(_major, _minor, _patch) diff --git a/erroranalysis/tests/test_error_report.py b/erroranalysis/tests/test_error_report.py index d9be0ea3cd..412d18a7e9 100644 --- a/erroranalysis/tests/test_error_report.py +++ b/erroranalysis/tests/test_error_report.py @@ -51,7 +51,9 @@ def test_error_report_housing(self): run_error_analyzer(model, X_test, y_test, feature_names, categorical_features) - def test_error_report_housing_pandas(self): + @pytest.mark.parametrize('filter_features', + [None, [], ['MedInc', 'HouseAge']]) + def test_error_report_housing_pandas(self, filter_features): X_train, X_test, y_train, y_test, feature_names = \ create_housing_data() X_train = create_dataframe(X_train, feature_names) @@ -61,7 +63,8 @@ def test_error_report_housing_pandas(self): for model in models: categorical_features = [] run_error_analyzer(model, X_test, y_test, feature_names, - categorical_features) + categorical_features, + filter_features=filter_features) def is_valid_uuid(id): @@ -73,7 +76,8 @@ def is_valid_uuid(id): def run_error_analyzer(model, X_test, y_test, feature_names, - categorical_features, expect_user_warnings=False): + categorical_features, expect_user_warnings=False, + filter_features=None): if expect_user_warnings and pd.__version__[0] == '0': with pytest.warns(UserWarning, match='which has issues with pandas version'): @@ -84,7 +88,7 @@ def run_error_analyzer(model, X_test, y_test, feature_names, model_analyzer = ModelAnalyzer(model, X_test, y_test, feature_names, categorical_features) - report1 = model_analyzer.create_error_report(filter_features=None, + report1 = model_analyzer.create_error_report(filter_features, max_depth=3, num_leaves=None, compute_importances=True) @@ -109,6 +113,11 @@ def run_error_analyzer(model, X_test, y_test, feature_names, assert ea_deserialized.importances == report1.importances assert ea_deserialized.root_stats == report1.root_stats + if not filter_features: + assert ea_deserialized.matrix is None + else: + assert ea_deserialized.matrix is not None + # validate error report does not modify original dataset in ModelAnalyzer if isinstance(X_test, pd.DataFrame): assert X_test.equals(model_analyzer.dataset) From 4f1e262cecc4ffa14c80571a66e333fb4d16ad72 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 13 May 2022 09:18:23 -0400 Subject: [PATCH 06/22] filter out missing values from what if dropdown to prevent explanation dashboard from crashing (#1418) --- apps/dashboard/src/app/applications.ts | 2 + .../__mock_data__/ibmDataMissingValues.ts | 660 ++++++++++++++++++ .../core-ui/src/lib/util/getFeatureOptions.ts | 13 +- 3 files changed, 671 insertions(+), 4 deletions(-) create mode 100644 apps/dashboard/src/interpret/__mock_data__/ibmDataMissingValues.ts diff --git a/apps/dashboard/src/app/applications.ts b/apps/dashboard/src/app/applications.ts index e683c2e05d..822ba0eecf 100644 --- a/apps/dashboard/src/app/applications.ts +++ b/apps/dashboard/src/app/applications.ts @@ -28,6 +28,7 @@ import { breastCancerData } from "../interpret/__mock_data__/breastCancerData"; import { ebmData } from "../interpret/__mock_data__/ebmData"; import { ibmData } from "../interpret/__mock_data__/ibmData"; import { ibmDataInconsistent } from "../interpret/__mock_data__/ibmDataInconsistent"; +import { ibmDataMissingValues } from "../interpret/__mock_data__/ibmDataMissingValues"; import { ibmNoClass } from "../interpret/__mock_data__/ibmNoClass"; import { irisData } from "../interpret/__mock_data__/irisData"; import { irisDataNoLocal } from "../interpret/__mock_data__/irisDataNoLocal"; @@ -179,6 +180,7 @@ export const applications: IApplications = { ebmData: { classDimension: 2, data: ebmData }, ibmData: { classDimension: 2, data: ibmData }, ibmDataInconsistent: { classDimension: 2, data: ibmDataInconsistent }, + ibmDataMissingValues: { classDimension: 2, data: ibmDataMissingValues }, ibmNoClass: { classDimension: 2, data: ibmNoClass }, irisData: { classDimension: 3, data: irisData }, irisDataNoLocal: { classDimension: 3, data: irisDataNoLocal }, diff --git a/apps/dashboard/src/interpret/__mock_data__/ibmDataMissingValues.ts b/apps/dashboard/src/interpret/__mock_data__/ibmDataMissingValues.ts new file mode 100644 index 0000000000..b711f39a4f --- /dev/null +++ b/apps/dashboard/src/interpret/__mock_data__/ibmDataMissingValues.ts @@ -0,0 +1,660 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IExplanationDashboardData } from "@responsible-ai/core-ui"; + +import { ibmData } from "./ibmData"; + +export const ibmDataMissingValues: IExplanationDashboardData = { + dataSummary: { + classNames: ibmData.dataSummary.classNames, + featureNames: ibmData.dataSummary.featureNames + }, + modelInformation: ibmData.modelInformation, + precomputedExplanations: ibmData.precomputedExplanations, + predictedY: ibmData.predictedY, + probabilityY: ibmData.probabilityY, + testData: [ + [ + 49, + "Travel_Rarely", + 1098, + "Research & Development", + 4, + 2, + "Medical", + 1, + "Male", + 85, + 2, + 5, + "Manager", + 3, + "Married", + 18711, + 12124, + 2, + "No", + 13, + 3, + 3, + 1, + 23, + 2, + 4, + 1, + 0, + 0, + 0 + ], + [ + 27, + "Travel_Rarely", + 269, + "Research & Development", + 5, + 1, + "Technical Degree", + 3, + "Male", + 42, + 2, + 3, + "Research Director", + 4, + "Divorced", + 12808, + 8842, + 1, + "Yes", + 16, + 3, + 2, + 1, + 9, + 3, + 3, + 9, + 8, + 0, + 8 + ], + [ + 41, + "Travel_Rarely", + 1085, + "Research & Development", + 2, + 4, + "Life Sciences", + 2, + "Female", + 57, + 1, + 1, + "Laboratory Technician", + 4, + "Divorced", + 2778, + 17725, + 4, + "Yes", + 13, + 3, + 3, + 1, + 10, + 1, + 2, + 7, + 7, + 1, + 0 + ], + [ + 44, + undefined, + 1097, + "Research & Development", + 10, + 4, + "Life Sciences", + 3, + "Male", + 96, + 3, + 1, + "Research Scientist", + 3, + "Single", + 2936, + 10826, + 1, + "Yes", + 11, + 3, + 3, + 0, + 6, + 4, + 3, + 6, + 4, + 0, + 2 + ], + [ + 29, + "Travel_Rarely", + 1246, + "Sales", + 19, + 3, + "Life Sciences", + 3, + "Male", + 77, + 2, + 2, + "Sales Executive", + 3, + "Divorced", + 8620, + 23757, + 1, + "No", + 14, + 3, + 3, + 2, + 10, + 3, + 3, + 10, + 7, + 0, + 4 + ], + [ + 37, + "Non-Travel", + 1413, + "Research & Development", + 5, + 2, + "Technical Degree", + 3, + "Male", + 84, + 4, + 1, + "Laboratory Technician", + 3, + "Single", + 3500, + 25470, + 0, + "No", + 14, + 3, + 1, + 0, + 7, + 2, + 1, + 6, + 5, + 1, + 3 + ], + [ + 47, + "Travel_Rarely", + 1454, + "Sales", + 2, + 4, + "Life Sciences", + 4, + "Female", + 65, + 2, + 1, + "Sales Representative", + 4, + "Single", + 3294, + 13137, + 1, + "Yes", + 18, + 3, + 1, + 0, + 3, + 3, + 2, + 3, + 2, + 1, + 2 + ], + [ + 35, + "Travel_Rarely", + 982, + undefined, + 1, + 4, + "Medical", + 4, + "Male", + 58, + 2, + 1, + "Laboratory Technician", + 3, + "Married", + 2258, + 16340, + 6, + "No", + 12, + 3, + 2, + 1, + 10, + 2, + 3, + 8, + 0, + 1, + 7 + ], + [ + 25, + "Travel_Rarely", + 1219, + "Research & Development", + 4, + 1, + "Technical Degree", + 4, + "Male", + 32, + 3, + 1, + "Laboratory Technician", + 4, + "Married", + 3691, + 4605, + 1, + "Yes", + 15, + 3, + 2, + 1, + 7, + 3, + 4, + 7, + 7, + 5, + 6 + ], + [ + 45, + "Non-Travel", + 1238, + "Research & Development", + 1, + 1, + "Life Sciences", + 3, + "Male", + 74, + 2, + 3, + "Healthcare Representative", + 3, + "Married", + 10748, + 3395, + 3, + "No", + 23, + 4, + 4, + 1, + 25, + 3, + 2, + 23, + 15, + 14, + 4 + ], + [ + 19, + "Travel_Rarely", + 1181, + "Research & Development", + 3, + 1, + "Medical", + 2, + "Female", + 79, + 3, + 1, + "Laboratory Technician", + 2, + "Single", + 1483, + 16102, + 1, + "No", + 14, + 3, + 4, + 0, + 1, + 3, + 3, + 1, + 0, + 0, + 0 + ], + [ + 32, + "Travel_Rarely", + 634, + "Research & Development", + 5, + 4, + "Other", + 2, + "Female", + 35, + 4, + 1, + "Research Scientist", + 4, + "Married", + 3312, + 18783, + 3, + "No", + 17, + 3, + 4, + 2, + 6, + 3, + 3, + 3, + 2, + 0, + 2 + ], + [ + 41, + "Travel_Rarely", + 642, + "Research & Development", + 1, + 3, + "Life Sciences", + 4, + "Male", + 76, + 3, + 1, + "Research Scientist", + 4, + "Married", + 2782, + 21412, + 3, + "No", + 22, + 4, + 1, + 1, + 12, + 3, + 3, + 5, + 3, + 1, + 0 + ], + [ + 32, + "Travel_Rarely", + 1018, + "Research & Development", + 2, + 4, + "Medical", + 1, + "Female", + 74, + 4, + 2, + "Research Scientist", + 4, + "Single", + 5055, + 10557, + 7, + "No", + 16, + 3, + 3, + 0, + 10, + 0, + 2, + 7, + 7, + 0, + 7 + ], + [ + 21, + "Travel_Rarely", + 1334, + "Research & Development", + 10, + 3, + "Life Sciences", + 3, + "Female", + 36, + 2, + 1, + "Laboratory Technician", + 1, + "Single", + 1416, + 17258, + 1, + "No", + 13, + 3, + 1, + 0, + 1, + 6, + 2, + 1, + 0, + 1, + 0 + ], + [ + 37, + "Travel_Rarely", + 408, + "Research & Development", + 19, + 2, + "Life Sciences", + 2, + "Male", + 73, + 3, + 1, + "Research Scientist", + 2, + "Married", + 3022, + 10227, + 4, + "No", + 21, + 4, + 1, + 0, + 8, + 1, + 3, + 1, + 0, + 0, + 0 + ], + [ + 30, + "Travel_Rarely", + 1358, + "Sales", + 16, + 1, + "Life Sciences", + 4, + "Male", + 96, + 3, + 2, + "Sales Executive", + 3, + "Married", + 5301, + 2939, + 8, + "No", + 15, + 3, + 3, + 2, + 4, + 2, + 2, + 2, + 1, + 2, + 2 + ], + [ + 39, + "Travel_Rarely", + 119, + "Sales", + 15, + 4, + "Marketing", + 2, + "Male", + 77, + 3, + 4, + "Sales Executive", + 1, + "Single", + 13341, + 25098, + 0, + "No", + 12, + 3, + 1, + 0, + 21, + 3, + 3, + 20, + 8, + 11, + 10 + ], + [ + 51, + "Travel_Rarely", + 432, + "Research & Development", + 9, + 4, + "Life Sciences", + 4, + "Male", + 96, + 3, + 1, + "Laboratory Technician", + 4, + "Married", + 2075, + 18725, + 3, + "No", + 23, + 4, + 2, + 2, + 10, + 4, + 3, + 4, + 2, + 0, + 3 + ], + [ + 48, + undefined, + 1224, + undefined, + 10, + 3, + undefined, + 4, + "Male", + 91, + 2, + 5, + "Research Director", + 2, + "Married", + 19665, + 13583, + 4, + "No", + 12, + 3, + 4, + 0, + 29, + 3, + 3, + 22, + 10, + 12, + 9 + ] + ], + trueY: ibmData.trueY +}; diff --git a/libs/core-ui/src/lib/util/getFeatureOptions.ts b/libs/core-ui/src/lib/util/getFeatureOptions.ts index cc70454dd0..186ca3d511 100644 --- a/libs/core-ui/src/lib/util/getFeatureOptions.ts +++ b/libs/core-ui/src/lib/util/getFeatureOptions.ts @@ -14,14 +14,19 @@ export function getFeatureOptions( const key = JointDataset.DataLabelRoot + index.toString(); const meta = jointDataset.metaDict[key]; const options = meta.isCategorical - ? meta.sortedCategoricalValues?.map( - (optionText: string | number, index: number) => { + ? meta.sortedCategoricalValues + ?.filter( + (optionText: string | number, _: number) => + !!optionText && + typeof optionText !== "boolean" && + typeof optionText !== "number" + ) + .map((optionText: string | number, index: number) => { if (typeof optionText !== "string") { optionText = optionText.toString(); } return { key: index, text: optionText }; - } - ) + }) : undefined; return { data: { From cf3c25bae79711ec7c45a65e3593fe752559851f Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Fri, 13 May 2022 09:39:16 -0700 Subject: [PATCH 07/22] Remove |Set Value| blurb in case it is not availble in counterfactual panel (#1426) Signed-off-by: Gaurav Gupta --- libs/counterfactuals/src/lib/CounterfactualPanel.tsx | 4 +++- libs/localization/src/lib/en.json | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/counterfactuals/src/lib/CounterfactualPanel.tsx b/libs/counterfactuals/src/lib/CounterfactualPanel.tsx index de6920b351..d455ded593 100644 --- a/libs/counterfactuals/src/lib/CounterfactualPanel.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualPanel.tsx @@ -135,7 +135,9 @@ export class CounterfactualPanel extends React.Component< - {localization.Counterfactuals.panelDescription} + {this.context.requestPredictions + ? localization.Counterfactuals.panelDescription + : localization.Counterfactuals.panelDescriptionWithoutSetValue} diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index cc8114968e..b49af8f9c0 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -130,6 +130,7 @@ "noData": "No data", "noFeatures": "No features available", "panelDescription": "Browse counterfactuals and create your own. Search features to see suggested values from a diverse set of counterfactual examples. Set suggested counterfactual feature values by clicking “Set Value” text under each counterfactual name. Name your counterfactual and save it.", + "panelDescriptionWithoutSetValue": "Browse counterfactuals and create your own. Search features to see suggested values from a diverse set of counterfactual examples.", "whatIfPanelHeader": "What-if counterfactuals", "panelHeader": "Counterfactuals", "recommendedPolicy": "Recommended policy gain", From 42a495c709d18ba51fc992a0735526dc9efbe576 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Fri, 13 May 2022 09:39:50 -0700 Subject: [PATCH 08/22] Add y-axis description to counterfactual feature importance chart (#1423) Signed-off-by: Gaurav Gupta Co-authored-by: xuke444 <40614413+xuke444@users.noreply.github.com> --- libs/counterfactuals/src/lib/LocalImportanceChart.tsx | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libs/counterfactuals/src/lib/LocalImportanceChart.tsx b/libs/counterfactuals/src/lib/LocalImportanceChart.tsx index 329588bc1b..3acd4ca7c9 100644 --- a/libs/counterfactuals/src/lib/LocalImportanceChart.tsx +++ b/libs/counterfactuals/src/lib/LocalImportanceChart.tsx @@ -80,6 +80,12 @@ export class LocalImportanceChart extends React.PureComponent Date: Fri, 13 May 2022 09:43:19 -0700 Subject: [PATCH 09/22] Add the user class name to causal UI strings (#1422) * Fix causal UI strings according to classification/regression tasks Signed-off-by: Gaurav Gupta * Fix lint error Signed-off-by: Gaurav Gupta * Fix UI test Signed-off-by: Gaurav Gupta * Add the user class name to causal UI strings Signed-off-by: Gaurav Gupta Co-authored-by: xuke444 <40614413+xuke444@users.noreply.github.com> --- .../CausalAggregateView.tsx | 19 +++++++++++++++++-- libs/localization/src/lib/en.json | 4 ++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx index 95c0fdbd7c..7bc3c52e70 100644 --- a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx +++ b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalAggregateView/CausalAggregateView.tsx @@ -108,7 +108,15 @@ export class CausalAggregateView extends React.PureComponent Date: Fri, 13 May 2022 12:35:46 -0700 Subject: [PATCH 10/22] fix math.min / max for array size more than 10^7 (#1427) Signed-off-by: Ke Xu --- apps/dashboard/src/fairness/utils.ts | 3 ++- libs/core-ui/src/lib/util/JointDataset.ts | 16 ++++++++-------- .../lib/util/getFeatureImportanceBoxOptions.ts | 3 ++- .../TreeViewRenderer/TreeViewRenderer.tsx | 9 +++++---- .../Controls/FeatureImportance/Beehive.tsx | 4 ++-- .../src/lib/components/ModelMetadata.ts | 4 ++-- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/apps/dashboard/src/fairness/utils.ts b/apps/dashboard/src/fairness/utils.ts index 1231c429b4..743a54b7ca 100644 --- a/apps/dashboard/src/fairness/utils.ts +++ b/apps/dashboard/src/fairness/utils.ts @@ -6,6 +6,7 @@ import { IMetricRequest, IMetricResponse } from "@responsible-ai/core-ui"; +import _ from "lodash"; export const supportedBinaryClassificationPerformanceKeys = [ "accuracy_score", @@ -65,7 +66,7 @@ export function generateRandomMetrics( request: IMetricRequest, abortSignal?: AbortSignal ): Promise { - const binSize = Math.max(...request.binVector); + const binSize = _.max(request.binVector) || 0; const bins: number[] = new Array(binSize + 1) .fill(0) .map(() => Math.random() / 3 + 0.33); diff --git a/libs/core-ui/src/lib/util/JointDataset.ts b/libs/core-ui/src/lib/util/JointDataset.ts index fcc617cfbd..83fbb3fb0f 100644 --- a/libs/core-ui/src/lib/util/JointDataset.ts +++ b/libs/core-ui/src/lib/util/JointDataset.ts @@ -190,8 +190,8 @@ export class JointDataset { }; if (args.metadata.modelType === ModelTypes.Regression) { this.metaDict[JointDataset.PredictedYLabel].featureRange = { - max: Math.max(...args.predictedY), - min: Math.min(...args.predictedY), + max: _.max(args.predictedY) || 0, + min: _.min(args.predictedY) || 0, rangeType: RangeTypes.Numeric }; } @@ -223,8 +223,8 @@ export class JointDataset { abbridgedLabel: label, category: ColumnCategories.Outcome, featureRange: { - max: Math.max(...projection), - min: Math.min(...projection), + max: _.max(projection) || 0, + min: _.min(projection) || 0, rangeType: RangeTypes.Numeric }, isCategorical: false, @@ -257,8 +257,8 @@ export class JointDataset { }; if (args.metadata.modelType === ModelTypes.Regression) { this.metaDict[JointDataset.TrueYLabel].featureRange = { - max: Math.max(...args.trueY), - min: Math.min(...args.trueY), + max: _.max(args.trueY) || 0, + min: _.min(args.trueY) || 0, rangeType: RangeTypes.Numeric }; } @@ -278,8 +278,8 @@ export class JointDataset { abbridgedLabel: localization.Interpret.Columns.error, category: ColumnCategories.Outcome, featureRange: { - max: Math.max(...regressionErrorArray), - min: Math.min(...regressionErrorArray), + max: _.max(regressionErrorArray) || 0, + min: _.min(regressionErrorArray) || 0, rangeType: RangeTypes.Numeric }, isCategorical: false, diff --git a/libs/core-ui/src/lib/util/getFeatureImportanceBoxOptions.ts b/libs/core-ui/src/lib/util/getFeatureImportanceBoxOptions.ts index 3d3bdac91b..dd023a908e 100644 --- a/libs/core-ui/src/lib/util/getFeatureImportanceBoxOptions.ts +++ b/libs/core-ui/src/lib/util/getFeatureImportanceBoxOptions.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import { SeriesOptionsType } from "highcharts"; +import _ from "lodash"; import { IGlobalSeries } from "../Highchart/FeatureImportanceBar"; import { IHighchartsConfig } from "../Highchart/IHighchartsConfig"; @@ -31,7 +32,7 @@ export function getFeatureImportanceBoxOptions( const y = base.concat( ...sortArray.map((index) => series.unsortedIndividualY?.[index] || []) ); - const curMin = Math.min(...y); + const curMin = _.min(y) || 0; yAxisMin = Math.min(yAxisMin, curMin); boxTempData.push({ color: FabricStyles.fabricColorPalette[series.colorIndex], diff --git a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TreeViewRenderer/TreeViewRenderer.tsx b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TreeViewRenderer/TreeViewRenderer.tsx index 607c715608..d173585b38 100644 --- a/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TreeViewRenderer/TreeViewRenderer.tsx +++ b/libs/error-analysis/src/lib/ErrorAnalysisDashboard/Controls/TreeViewRenderer/TreeViewRenderer.tsx @@ -27,6 +27,7 @@ import { interpolateHcl as d3interpolateHcl } from "d3-interpolate"; import { scaleLinear as d3scaleLinear } from "d3-scale"; import { select } from "d3-selection"; import { linkVertical as d3linkVertical } from "d3-shape"; +import _ from "lodash"; import { getTheme, IProcessedStyleSet, @@ -261,12 +262,12 @@ export class TreeViewRenderer extends React.PureComponent< ); const x = rootDescendants.map((d) => d.x); const y = rootDescendants.map((d) => d.y); - const minX = Math.min(Math.min(...x) - 40, pathMin); + const minX = Math.min((_.min(x) || 0) - 40, pathMin); //100:tooltip width - const maxX = Math.max(Math.max(...x) + 40 + 100, pathMax); - const minY = Math.min(...y) - 40; + const maxX = Math.max((_.max(x) || 0) + 40 + 100, pathMax); + const minY = (_.min(y) || 0) - 40; //40:tooltip height - const maxY = Math.max(...y) + 40 + 40; + const maxY = (_.max(y) || 0) + 40 + 40; const containerStyles = mergeStyles({ transform: `translate(${-minX}px, ${-minY}px)` }); diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/FeatureImportance/Beehive.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/FeatureImportance/Beehive.tsx index 8548977d18..0cc15e2d30 100644 --- a/libs/interpret/src/lib/MLIDashboard/Controls/FeatureImportance/Beehive.tsx +++ b/libs/interpret/src/lib/MLIDashboard/Controls/FeatureImportance/Beehive.tsx @@ -82,8 +82,8 @@ export class Beehive extends React.PureComponent< const featureArray = data.testDataset.dataset?.map((row: number[]) => row[featureIndex]) || []; - const min = Math.min(...featureArray); - const max = Math.max(...featureArray); + const min = _.min(featureArray) || 0; + const max = _.max(featureArray) || 0; const range = max - min; return (value: string | number): number => { return range !== 0 && typeof value === "number" diff --git a/libs/mlchartlib/src/lib/components/ModelMetadata.ts b/libs/mlchartlib/src/lib/components/ModelMetadata.ts index 83cf84d534..095996d239 100644 --- a/libs/mlchartlib/src/lib/components/ModelMetadata.ts +++ b/libs/mlchartlib/src/lib/components/ModelMetadata.ts @@ -42,8 +42,8 @@ export class ModelMetadata { } const featureVector = testData.map((row) => row[featureIndex]); return { - max: Math.max(...featureVector), - min: Math.min(...featureVector), + max: _.max(featureVector) || 0, + min: _.min(featureVector) || 0, rangeType: featureVector.every((val) => Number.isInteger(val)) ? RangeTypes.Integer : RangeTypes.Numeric From 8c4ce561d593851fece2342b3e42468b2542a9ef Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 17 May 2022 07:28:21 -0400 Subject: [PATCH 11/22] upgrade pytest and lightgbm to try to fix random pytest segfault test failures (#1424) s --- .github/workflows/Ci-raiwigets-python-typescript.yml | 12 ++++++++++-- raiwidgets/requirements-dev.txt | 3 --- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/Ci-raiwigets-python-typescript.yml b/.github/workflows/Ci-raiwigets-python-typescript.yml index 8c450ceca6..fe7dc0b456 100644 --- a/.github/workflows/Ci-raiwigets-python-typescript.yml +++ b/.github/workflows/Ci-raiwigets-python-typescript.yml @@ -45,8 +45,10 @@ jobs: run: yarn buildall - if: ${{ matrix.operatingSystem == 'macos-latest' }} - name: Use Homebrew to install libomp on MacOS for lightgbm - run: brew install https://publictestdatasets.blob.core.windows.net/data/libomp-11.1.0.catalina.bottle.tar.gz + name: Install latest numpy from conda-forge for MacOS + shell: bash -l {0} + run: | + conda install --yes --quiet -c conda-forge numpy - if: ${{ matrix.operatingSystem != 'macos-latest' }} name: Install pytorch on non-MacOS @@ -60,6 +62,12 @@ jobs: run: | conda install --yes --quiet pytorch torchvision captum -c pytorch + - if: ${{ matrix.operatingSystem == 'macos-latest' }} + name: Install latest lightgbm from conda-forge for MacOS + shell: bash -l {0} + run: | + conda install --yes --quiet lightgbm -c conda-forge + - name: Setup tools shell: bash -l {0} run: | diff --git a/raiwidgets/requirements-dev.txt b/raiwidgets/requirements-dev.txt index c125631603..ed3219b7e6 100644 --- a/raiwidgets/requirements-dev.txt +++ b/raiwidgets/requirements-dev.txt @@ -9,9 +9,6 @@ requirements-parser==0.2.0 wheel -# Required for interpret-community 0.18.0 -lightgbm==2.3.0 - fairlearn==0.6.0 # Jupyter dependency that fails with python 3.6 From b98201666bb8ce3453c43e7ee21809a34f04a7e3 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 17 May 2022 12:12:58 -0400 Subject: [PATCH 12/22] fix flaky notebook causing build failures by adding retry logic (#1431) --- ...retability-dashboard-loan-allocation.ipynb | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/notebooks/individual-dashboards/fairness-interpretability-dashboard-loan-allocation.ipynb b/notebooks/individual-dashboards/fairness-interpretability-dashboard-loan-allocation.ipynb index 868ae0bf1e..f7b65a8367 100644 --- a/notebooks/individual-dashboards/fairness-interpretability-dashboard-loan-allocation.ipynb +++ b/notebooks/individual-dashboards/fairness-interpretability-dashboard-loan-allocation.ipynb @@ -54,25 +54,21 @@ "outputs": [], "source": [ "from fairlearn.reductions import GridSearch\n", - "from fairlearn.reductions import DemographicParity, ErrorRate\n", + "from fairlearn.reductions import DemographicParity\n", "from fairlearn.datasets import fetch_adult\n", "from fairlearn.metrics import MetricFrame, selection_rate\n", "\n", - "from sklearn import svm, neighbors, tree\n", "from sklearn.compose import ColumnTransformer, make_column_selector\n", "from sklearn.preprocessing import LabelEncoder,StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", - "from sklearn.svm import SVC\n", "from sklearn.metrics import accuracy_score\n", "\n", "import pandas as pd\n", - "import numpy as np\n", "\n", "# SHAP Tabular Explainer\n", - "from interpret.ext.blackbox import KernelExplainer\n", "from interpret.ext.blackbox import MimicExplainer\n", "from interpret.ext.glassbox import LGBMExplainableModel" ] @@ -90,7 +86,23 @@ "metadata": {}, "outputs": [], "source": [ - "dataset = fetch_adult(as_frame=True)\n", + "from raiutils.common.retries import retry_function\n", + "\n", + "class FetchAdult(object):\n", + " def __init__(self):\n", + " pass\n", + "\n", + " def fetch(self):\n", + " return fetch_adult(as_frame=True)\n", + "\n", + "fetcher = FetchAdult()\n", + "action_name = \"Dataset download\"\n", + "err_msg = \"Failed to download dataset\"\n", + "max_retries = 4\n", + "retry_delay = 60\n", + "dataset = retry_function(fetcher.fetch, action_name, err_msg,\n", + " max_retries=max_retries,\n", + " retry_delay=retry_delay)\n", "X_raw, y = dataset['data'], dataset['target']\n", "X_raw[\"race\"].value_counts().to_dict()" ] @@ -201,7 +213,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Using SHAP KernelExplainer\n", + "# Using SHAP MimicExplainer\n", "# clf.steps[-1][1] returns the trained classification model\n", "explainer = MimicExplainer(model.steps[-1][1], \n", " X_train,\n", From f7b2fc830721c8ddd8b32fcd8f1dab1db2b2181b Mon Sep 17 00:00:00 2001 From: Richard Edgar Date: Tue, 17 May 2022 15:27:50 -0400 Subject: [PATCH 13/22] Upper bound SciKit-Learn to address freeze in causal (#1432) ## Description Replaces #1429 to address #1430 . Causal analysis is getting stuck with the latest release of SciKit-Learn. This contains: - Test case which gets stuck with SciKit-Learn 1.1.0 - Upper bound on SciKit-Learn in `requirements.txt` ## Checklist - [x] I have added screenshots above for all UI changes. - [x] Documentation was updated if it was needed. - [x] New tests were added or changes were manually verified. Signed-off-by: Richard Edgar --- responsibleai/requirements.txt | 2 +- .../tests/causal/test_causal_general.py | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 responsibleai/tests/causal/test_causal_general.py diff --git a/responsibleai/requirements.txt b/responsibleai/requirements.txt index 4176531ce1..9844238436 100644 --- a/responsibleai/requirements.txt +++ b/responsibleai/requirements.txt @@ -7,7 +7,7 @@ lightgbm>=2.0.11 numpy>=1.17.2 numba<0.54.0 pandas>=0.25.1 -scikit-learn>=0.22.1 +scikit-learn>=0.22.1,<1.1 # See PR 1429 about upper bound scipy>=1.4.1 semver~=2.13.0 diff --git a/responsibleai/tests/causal/test_causal_general.py b/responsibleai/tests/causal/test_causal_general.py new file mode 100644 index 0000000000..d32e73a94c --- /dev/null +++ b/responsibleai/tests/causal/test_causal_general.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. + + +from responsibleai import RAIInsights + +from ..causal_manager_validator import _check_causal_result +from ..common_utils import create_adult_income_dataset + + +def test_causal_classification_scikitlearn_issue(): + # This test gets stuck on SciKit-Learn v1.1.0 + # See PR #1429 + data_train, data_test, _, _, categorical_features, \ + _, target_name, classes = create_adult_income_dataset() + + rai_i = RAIInsights( + model=None, + train=data_train, + test=data_test, + task_type='classification', + target_column=target_name, + categorical_features=categorical_features, + classes=classes + ) + assert rai_i is not None + + treatment_features = ["age", "gender"] + cat_expansion = 49 + rai_i.causal.add(treatment_features=treatment_features, + heterogeneity_features=["marital_status"], + nuisance_model="automl", + heterogeneity_model="forest", + alpha=0.06, + upper_bound_on_cat_expansion=cat_expansion, + treatment_cost=[0.1, 0.2], + min_tree_leaf_samples=2, + skip_cat_limit_checks=False, + categories="auto", + n_jobs=1, + verbose=1, + random_state=100, + ) + + rai_i.compute() + + results = rai_i.causal.get() + assert results is not None + assert isinstance(results, list) + assert len(results) == 1 + _check_causal_result(results[0]) From 6e57f20865dc4415a5ea13d2f41e1b6a9b27ccdb Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 May 2022 09:03:02 -0400 Subject: [PATCH 14/22] fix dependency chart axis updating with incorrect values in explanation dashboard (#1437) --- libs/core-ui/src/lib/util/getDependencyChartOptions.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/core-ui/src/lib/util/getDependencyChartOptions.ts b/libs/core-ui/src/lib/util/getDependencyChartOptions.ts index 31d9b528db..6bf6372975 100644 --- a/libs/core-ui/src/lib/util/getDependencyChartOptions.ts +++ b/libs/core-ui/src/lib/util/getDependencyChartOptions.ts @@ -28,6 +28,9 @@ export function getDependencyChartOptions( type: "scatter", zoomType: "xy" }, + custom: { + disableUpdate: true + }, plotOptions: { scatter: { marker: { From 4968e90d95670bcc2802c233e387a1141c20d211 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 May 2022 11:57:48 -0400 Subject: [PATCH 15/22] fix codecov and widget test screenshot uploads (#1428) --- .github/workflows/CD.yml | 2 +- .github/workflows/CI-notebook.yml | 6 ++--- .github/workflows/CI-python.yml | 4 ++-- .github/workflows/CI-typescript.yml | 6 ++--- .../Ci-raiwigets-python-typescript.yml | 23 ++++++++++--------- .github/workflows/release-rai.yml | 6 ++--- 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml index 8ae0537e41..3e2fd526bd 100644 --- a/.github/workflows/CD.yml +++ b/.github/workflows/CD.yml @@ -111,7 +111,7 @@ jobs: run: pytest --durations=10 - name: Upload a Build result - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: dist path: dist diff --git a/.github/workflows/CI-notebook.yml b/.github/workflows/CI-notebook.yml index e12e0148ef..357b418168 100644 --- a/.github/workflows/CI-notebook.yml +++ b/.github/workflows/CI-notebook.yml @@ -63,7 +63,7 @@ jobs: working-directory: raiwidgets - name: Upload requirements - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: requirements-dev.txt path: raiwidgets/installed-requirements-dev.txt @@ -74,14 +74,14 @@ jobs: - name: Upload notebook test result if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: notebook-test-${{ matrix.operatingSystem }}-${{ matrix.pythonVersion }} path: notebooks - name: Upload e2e test screen shot if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: raiwidgets-e2e-screen-shot path: dist/cypress diff --git a/.github/workflows/CI-python.yml b/.github/workflows/CI-python.yml index 97377ae26f..0b2d7a9289 100644 --- a/.github/workflows/CI-python.yml +++ b/.github/workflows/CI-python.yml @@ -46,7 +46,7 @@ jobs: working-directory: ${{ matrix.packageDirectory }} - name: Upload requirements - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: requirements-dev.txt path: ${{ matrix.packageDirectory }}/requirements-dev.txt @@ -67,7 +67,7 @@ jobs: working-directory: ${{ matrix.packageDirectory }} - name: Upload code coverage results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: ${{ matrix.packageDirectory }}-code-coverage-results path: ${{ matrix.packageDirectory }}/htmlcov diff --git a/.github/workflows/CI-typescript.yml b/.github/workflows/CI-typescript.yml index de2c4a28a5..7e22a4afe3 100644 --- a/.github/workflows/CI-typescript.yml +++ b/.github/workflows/CI-typescript.yml @@ -32,19 +32,19 @@ jobs: run: yarn ci - name: Upload unit test code coverage if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-test-coverage-${{ matrix.node-version }} path: coverage - name: Upload e2e test screen shot if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: e2e-screen-shot-${{ matrix.node-version }} path: dist/cypress - name: Upload yarn error if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: yarn-error.log-${{ matrix.node-version }} path: yarn-error.log diff --git a/.github/workflows/Ci-raiwigets-python-typescript.yml b/.github/workflows/Ci-raiwigets-python-typescript.yml index fe7dc0b456..c643435eaf 100644 --- a/.github/workflows/Ci-raiwigets-python-typescript.yml +++ b/.github/workflows/Ci-raiwigets-python-typescript.yml @@ -90,14 +90,15 @@ jobs: working-directory: ${{ matrix.packageDirectory }} - name: Upload requirements - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: requirements-dev.txt path: ${{ matrix.packageDirectory }}/installed-requirements-dev.txt - - name: Run widget tests + - if: ${{ matrix.operatingSystem != 'macos-latest' }} + name: Run widget tests + id: raiwidgettests shell: bash -l {0} - if: ${{ matrix.operatingSystem != 'macos-latest' }} run: yarn e2e-widget - name: Run tests @@ -106,13 +107,13 @@ jobs: pytest --durations=10 --junitxml=junit/test-results.xml --cov=${{ matrix.packageDirectory }} --cov-report=xml --cov-report=html working-directory: ${{ matrix.packageDirectory }} - - name: Upload code coverage results - uses: actions/upload-artifact@v2 + # Only try to upload code cov if python tests were run + - if: ${{ (steps.raiwidgettests.outcome == 'success') }} + name: Upload code coverage results + uses: actions/upload-artifact@v3 with: name: ${{ matrix.packageDirectory }}-code-coverage-results path: ${{ matrix.packageDirectory }}/htmlcov - # Use always() to always run this step to publish test results when there are test failures - if: ${{ always() }} - if: ${{ (matrix.operatingSystem == 'windows-latest') && (matrix.pythonVersion == '3.8') }} name: Upload to codecov @@ -142,12 +143,12 @@ jobs: name: codecov-umbrella verbose: true - - name: Upload e2e test screen shot - if: always() - uses: actions/upload-artifact@v2 + - if: ${{ matrix.operatingSystem != 'macos-latest' }} + name: Upload e2e test screen shot + uses: actions/upload-artifact@v3 with: name: ${{ matrix.packageDirectory }}-e2e-screen-shot - path: dist/cypress + path: ./dist/cypress - name: Set codecov status if: ${{ (matrix.pythonVersion == '3.8') && (matrix.operatingSystem == 'windows-latest') }} diff --git a/.github/workflows/release-rai.yml b/.github/workflows/release-rai.yml index cde69670ed..b25b6f10c1 100644 --- a/.github/workflows/release-rai.yml +++ b/.github/workflows/release-rai.yml @@ -142,19 +142,19 @@ jobs: run: yarn e2e-widget - name: Upload a raiwidgets build result - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: ${{ env.widgetDirectory }} path: ${{ env.widgetDirectory }}/dist/ - name: Upload a responsibleai build result - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: ${{ env.raiDirectory }} path: ${{ env.raiDirectory }}/dist/ - name: Upload a typescript build result - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: ${{ env.typescriptBuildArtifactName }} path: ${{ env.typescriptBuildOutput }} From b74045a6fdd4479e768606e988a8996234d72080 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 May 2022 13:29:16 -0400 Subject: [PATCH 16/22] release raiwidgets and responsibleai v0.18.2 (#1439) --- CHANGES.md | 28 ++++++++++++++++++++++++++++ version.cfg | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index bbe87bbb68..9f937d7ef1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -24,6 +24,34 @@ this file to understand what changed. - bug fixes - other +## v0.18.2 + +- bug fixes and tests + - ## Responsible AI Dashboard + - Bug fixes on 'Set value' not copying over feature values correctly in what if counterfactual panel by @tongyu-microsoft in https://github.com/microsoft/responsible-ai-toolbox/pull/1416 + - Fix description for model overview by @romanlutz in https://github.com/microsoft/responsible-ai-toolbox/pull/1425 + - Fix math.min / max for array size more than 10^7 by @xuke444 in https://github.com/microsoft/responsible-ai-toolbox/pull/1427 + - ## RAIInsights + - Add warning in counterfactual manager when unable to load explainer by @gaugup in https://github.com/microsoft/responsible-ai-toolbox/pull/1412 + - ## Counterfactual + - Remove "Set Value" blurb in case it is not available in counterfactual panel by @gaugup in https://github.com/microsoft/responsible-ai-toolbox/pull/1426 + - Add y-axis description to counterfactual feature importance chart by @gaugup in https://github.com/microsoft/responsible-ai-toolbox/pull/1423 + - ## Causal + - Fix causal UI strings according to classification/regression tasks by @gaugup in https://github.com/microsoft/responsible-ai-toolbox/pull/1419 + - Add the user class name to causal UI strings by @gaugup in https://github.com/microsoft/responsible-ai-toolbox/pull/1422 + - Upper bound SciKit-Learn to address freeze in causal by @riedgar-ms in https://github.com/microsoft/responsible-ai-toolbox/pull/1432 + - ## Error Analysis + - Fix error on machines with pyspark installed where passed dataframe is not spark pandas by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1415 + - Fix failing to create error report when filter_features is empty list by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1421 + - ## Interpret + - Filter out missing values from what if dropdown to prevent explanation dashboard from crashing by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1418 + - Fix dependency chart axis updating with incorrect values in explanation dashboard by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1437 +- ## other + - Add postbuild branch trigger by @romanlutz in https://github.com/microsoft/responsible-ai-toolbox/pull/1417 + - Upgrade numpy to fix random segfault test failures by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1424 + - Fix flaky notebook causing build failures by adding retry logic by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1431 + - Fix codecov and widget test screenshot uploads by @imatiach-msft in https://github.com/microsoft/responsible-ai-toolbox/pull/1428 + ## v0.18.1 - educational materials diff --git a/version.cfg b/version.cfg index 249afd517d..503a21deb4 100644 --- a/version.cfg +++ b/version.cfg @@ -1 +1 @@ -0.18.1 +0.18.2 From 4c811574eadf4f0691ebad0a3413a00ffaa6303e Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Thu, 19 May 2022 14:17:45 -0700 Subject: [PATCH 17/22] fix (#1441) Signed-off-by: vinutha karanth --- .../dataExplorer/describeCohortFunctionality.ts | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts b/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts index 8b4f564955..ef3bccf433 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts @@ -37,9 +37,13 @@ export function describeCohortFunctionality( cy.get("#cohortEditPanel").should("exist"); cy.get(Locators.CohortNameInput).clear().type(cohortName); cy.get(Locators.CohortFilterSelection).eq(1).check(); // select Dataset - cy.get(Locators.CohortDatasetValueInput) - .clear() - .type(dataShape.datasetExplorerData?.cohortDatasetNewValue || ""); + cy.get(Locators.CohortDatasetValueInput).then(($input) => { + if ($input.length > 0) { + cy.get(Locators.CohortDatasetValueInput) + .clear() + .type(dataShape.datasetExplorerData?.cohortDatasetNewValue || ""); + } + }); cy.get(Locators.CohortAddFilterButton).click(); cy.get(Locators.CohortSaveAndSwitchButton).eq(0).click({ force: true }); cy.get(Locators.NewCohortSpan).should("exist"); From 7de48f7e58bf784f39e73d29821f3656d971c79c Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Fri, 20 May 2022 14:09:17 -0700 Subject: [PATCH 18/22] Fix cohort name conflict and not run few tests for AML (#1442) * fix Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth --- .../whatIfCounterfactuals/describeWhatIf.ts | 2 +- .../describeWhatIfCreate.ts | 72 ++++++++++--------- libs/e2e/src/util/createCohort.ts | 5 +- libs/e2e/src/util/generateId.ts | 11 +++ 4 files changed, 55 insertions(+), 35 deletions(-) create mode 100644 libs/e2e/src/util/generateId.ts diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIf.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIf.ts index 0bbf039755..73519d912e 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIf.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIf.ts @@ -42,7 +42,7 @@ export function describeWhatIf( ) { describeWhatIfCommonFunctionalities(datasetShape); describeAxisFlyouts(datasetShape); - describeWhatIfCreate(datasetShape); + describeWhatIfCreate(datasetShape, name); } }); } diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts index 4ca95d00df..96f9c1d4d4 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts @@ -4,8 +4,12 @@ import { getSpan } from "../../../../util/getSpan"; import { Locators } from "../Constants"; import { IModelAssessmentData } from "../IModelAssessmentData"; +import { modelAssessmentDatasets } from "../modelAssessmentDatasets"; -export function describeWhatIfCreate(dataShape: IModelAssessmentData): void { +export function describeWhatIfCreate( + dataShape: IModelAssessmentData, + name?: keyof typeof modelAssessmentDatasets +): void { describe("What if Create counterfactual", () => { before(() => { cy.get(Locators.WICDatapointDropbox).click(); @@ -50,40 +54,42 @@ export function describeWhatIfCreate(dataShape: IModelAssessmentData): void { dataShape.whatIfCounterfactualsData?.columnHeaderAfterSort || "" ); }); + // AML do not need to execute below tests, as these options are not available for static view + if (name) { + it("Should have 'Create your own counterfactual' section and it should be editable", () => { + cy.get(Locators.CreateYourOwnCounterfactualInputField) + .eq(2) + .clear() + .type( + dataShape.whatIfCounterfactualsData + ?.createYourOwnCounterfactualInputFieldUpdated || "25" + ); + cy.get(Locators.CreateYourOwnCounterfactualInputField).eq(2).focus(); + cy.focused() + .should("have.attr", "value") + .and( + "contain", + dataShape.whatIfCounterfactualsData + ?.createYourOwnCounterfactualInputFieldUpdated || "25" + ); + }); - it("Should have 'Create your own counterfactual' section and it should be editable", () => { - cy.get(Locators.CreateYourOwnCounterfactualInputField) - .eq(2) - .clear() - .type( - dataShape.whatIfCounterfactualsData - ?.createYourOwnCounterfactualInputFieldUpdated || "25" - ); - cy.get(Locators.CreateYourOwnCounterfactualInputField).eq(2).focus(); - cy.focused() - .should("have.attr", "value") - .and( - "contain", - dataShape.whatIfCounterfactualsData - ?.createYourOwnCounterfactualInputFieldUpdated || "25" - ); - }); - - it("Should have what-if counterfactual name as 'Copy of row ' by default and should be editable", () => { - cy.get(Locators.WhatIfNameLabel) - .should("have.attr", "value") - .and("contain", dataShape.whatIfCounterfactualsData?.whatIfNameLabel); - cy.get(Locators.WhatIfNameLabel).type( - dataShape.whatIfCounterfactualsData?.whatIfNameLabelUpdated || - "New Copy of row 1" - ); - cy.get(Locators.WhatIfNameLabel) - .should("have.attr", "value") - .and( - "contain", - dataShape.whatIfCounterfactualsData?.whatIfNameLabelUpdated + it("Should have what-if counterfactual name as 'Copy of row ' by default and should be editable", () => { + cy.get(Locators.WhatIfNameLabel) + .should("have.attr", "value") + .and("contain", dataShape.whatIfCounterfactualsData?.whatIfNameLabel); + cy.get(Locators.WhatIfNameLabel).type( + dataShape.whatIfCounterfactualsData?.whatIfNameLabelUpdated || + "New Copy of row 1" ); - }); + cy.get(Locators.WhatIfNameLabel) + .should("have.attr", "value") + .and( + "contain", + dataShape.whatIfCounterfactualsData?.whatIfNameLabelUpdated + ); + }); + } }); describe.skip("What-If save scenario", () => { diff --git a/libs/e2e/src/util/createCohort.ts b/libs/e2e/src/util/createCohort.ts index 67885259a5..08ae1ee178 100644 --- a/libs/e2e/src/util/createCohort.ts +++ b/libs/e2e/src/util/createCohort.ts @@ -3,10 +3,13 @@ import { Locators } from "../lib/describer/modelAssessment/Constants"; +import { generateId } from "./generateId"; + export function createCohort(): void { + const cohortName = `CohortCreateE2E-${generateId(4)}`; cy.get(Locators.CreateNewCohortButton).click(); cy.get("#cohortEditPanel").should("exist"); - cy.get(Locators.CohortNameInput).clear().type("CohortCreateE2E"); + cy.get(Locators.CohortNameInput).clear().type(cohortName); cy.get(Locators.CohortFilterSelection).eq(1).check(); // select Dataset cy.get(Locators.CohortAddFilterButton).click(); cy.get(Locators.CohortSaveAndSwitchButton).eq(0).click({ force: true }); diff --git a/libs/e2e/src/util/generateId.ts b/libs/e2e/src/util/generateId.ts new file mode 100644 index 0000000000..125a67a598 --- /dev/null +++ b/libs/e2e/src/util/generateId.ts @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +export function generateId(length?: number): string { + const len = length === undefined ? 4 : length; + // tslint:disable-next-line: insecure-random + return Math.random() + .toString(36) + .replace(/[^a-z]+/g, "") + .slice(0, Math.max(0, len)); +} From b1bf63805c110ed2988a9c8c7757131e8e8df797 Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Mon, 23 May 2022 13:33:40 -0700 Subject: [PATCH 19/22] Few e2e tests changes to accommodate AML static tests (#1445) * update Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth --- .../describeGlobalExplanationChart.ts | 2 +- .../describeIndividualFeatureImportance.ts | 2 +- .../describeSubLineChart.ts | 86 +++++++++++-------- .../describeTabularDataView.ts | 12 ++- .../describeWhatIfCreate.ts | 37 ++++---- 5 files changed, 79 insertions(+), 60 deletions(-) diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/aggregateFeatureImportance/describeGlobalExplanationChart.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/aggregateFeatureImportance/describeGlobalExplanationChart.ts index 55c7a8d89b..66807110c2 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/aggregateFeatureImportance/describeGlobalExplanationChart.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/aggregateFeatureImportance/describeGlobalExplanationChart.ts @@ -64,7 +64,7 @@ export function describeGlobalExplanationChart< const dependencePlotChart = new ScatterHighchart("#DependencePlot"); describe("DependencePlot", () => { beforeEach(() => { - selectComboBox("#DependencePlotFeatureSelection", 0); + selectComboBox("#DependencePlotFeatureSelection", 3); }); it("should render", () => { expect(dependencePlotChart.Elements.length).greaterThan(0); diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeIndividualFeatureImportance.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeIndividualFeatureImportance.ts index 98f13fa7b1..500747ff24 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeIndividualFeatureImportance.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeIndividualFeatureImportance.ts @@ -45,7 +45,7 @@ export function describeIndividualFeatureImportance( }); } if (datasetShape.featureImportanceData?.hasFeatureImportanceComponent) { - describeTabularDataView(datasetShape); + describeTabularDataView(datasetShape, name); } }); } diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts index 841d84630c..bd284fbc62 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts @@ -6,48 +6,58 @@ import { localization } from "@responsible-ai/localization"; import { selectRow } from "../../../../../util/Table"; import { Locators } from "../../Constants"; import { IModelAssessmentData } from "../../IModelAssessmentData"; +import { modelAssessmentDatasets } from "../../modelAssessmentDatasets"; -export function describeSubLineChart(dataShape: IModelAssessmentData): void { - describe("Sub line chart", () => { - before(() => { - selectRow("Index", "4"); +export function describeSubLineChart( + dataShape: IModelAssessmentData, + name?: keyof typeof modelAssessmentDatasets +): void { + // AML do not need to execute below tests, as these options are not available for static view + if (name) { + describe("Sub line chart", () => { + before(() => { + selectRow("Index", "4"); - cy.get(Locators.ICEPlot).click(); - }); - after(() => { - selectRow("Index", "4"); - }); - it("should have more than one point", () => { - cy.get(Locators.ICENoOfPoints).its("length").should("be.gte", 1); - }); + cy.get(Locators.ICEPlot).click(); + }); + after(() => { + selectRow("Index", "4"); + }); + it("should have more than one point", () => { + cy.get(Locators.ICENoOfPoints).its("length").should("be.gte", 1); + }); - it("should update x-axis value when 'Feature' dropdown is changed", () => { - cy.get(Locators.ICEFeatureDropdown).eq(0).click(); // feature dropdown - cy.get(".ms-Callout") - .should("be.visible") - .contains( + it("should update x-axis value when 'Feature' dropdown is changed", () => { + cy.get(Locators.ICEFeatureDropdown).eq(0).click(); // feature dropdown + cy.get(".ms-Callout") + .should("be.visible") + .contains( + dataShape.featureImportanceData?.newFeatureDropdownValue || "" + ) + .scrollIntoView() + .focus() + .click({ force: true }); + cy.get(Locators.ICEXAxisNewValue).should( + "contain", dataShape.featureImportanceData?.newFeatureDropdownValue || "" - ) - .scrollIntoView() - .focus() - .click({ force: true }); - cy.get(Locators.ICEXAxisNewValue).should( - "contain", - dataShape.featureImportanceData?.newFeatureDropdownValue || "" - ); - }); + ); + }); - it("Should have tooltip 'How to read this chart'", () => { - cy.get(Locators.ICEToolTipButton).should("exist"); - cy.get(Locators.ICEToolTipButton).click({ force: true }); - cy.get(Locators.ICECalloutTitle) - .scrollIntoView() - .should("exist") - .should("contain", localization.Interpret.WhatIfTab.icePlot); - cy.get(Locators.ICECalloutBody) - .scrollIntoView() - .should("exist") - .should("contain", localization.Interpret.WhatIfTab.icePlotHelperText); + it("Should have tooltip 'How to read this chart'", () => { + cy.get(Locators.ICEToolTipButton).should("exist"); + cy.get(Locators.ICEToolTipButton).click({ force: true }); + cy.get(Locators.ICECalloutTitle) + .scrollIntoView() + .should("exist") + .should("contain", localization.Interpret.WhatIfTab.icePlot); + cy.get(Locators.ICECalloutBody) + .scrollIntoView() + .should("exist") + .should( + "contain", + localization.Interpret.WhatIfTab.icePlotHelperText + ); + }); }); - }); + } } diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeTabularDataView.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeTabularDataView.ts index f2e566d5ae..178a62a125 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeTabularDataView.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeTabularDataView.ts @@ -5,12 +5,18 @@ import { getMenu } from "../../../../../util/getMenu"; import { selectRow } from "../../../../../util/Table"; import { Locators } from "../../Constants"; import { IModelAssessmentData } from "../../IModelAssessmentData"; -import { regExForNumbersWithBrackets } from "../../modelAssessmentDatasets"; +import { + modelAssessmentDatasets, + regExForNumbersWithBrackets +} from "../../modelAssessmentDatasets"; // import { describeSubBarChart } from "./describeSubBarChart"; import { describeSubLineChart } from "./describeSubLineChart"; -export function describeTabularDataView(dataShape: IModelAssessmentData): void { +export function describeTabularDataView( + dataShape: IModelAssessmentData, + name?: keyof typeof modelAssessmentDatasets +): void { describe("Tabular data view", () => { before(() => { getMenu("Individual feature importance").click(); @@ -65,7 +71,7 @@ export function describeTabularDataView(dataShape: IModelAssessmentData): void { // describeSubBarChart(dataShape); // } if (!dataShape.featureImportanceData?.noPredict) { - describeSubLineChart(dataShape); + describeSubLineChart(dataShape, name); } }); } diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts index 96f9c1d4d4..83e470a000 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts @@ -24,23 +24,6 @@ export function describeWhatIfCreate( after(() => { cy.get(Locators.WhatIfCloseButton).click(); }); - it.skip("should sort feature on clicking 'Sort feature columns by counterfactual feature importance'", () => { - cy.get(Locators.WhatIfColumnHeaders) - .eq(2) - .contains( - dataShape.whatIfCounterfactualsData?.columnHeaderBeforeSort || "" - ); - cy.get(Locators.WhatIfCreateCounterfactualSortButton).click(); - cy.get(Locators.WhatIfColumnHeaders) - .eq(2) - .invoke("text") - .then((text1) => { - expect(text1).to.not.equal( - dataShape.whatIfCounterfactualsData?.columnHeaderBeforeSort - ); - }); - cy.get(Locators.WhatIfCreateCounterfactualSortButton).click(); - }); it("should filter by included letters in search query", () => { cy.get(Locators.WhatIfSearchBar).type( @@ -118,4 +101,24 @@ export function describeWhatIfCreate( cy.get(Locators.WhatIfSaveAsDataPoints).should("not.exist"); }); }); + + describe.skip("What-If sort scenario", () => { + it("should sort feature on clicking 'Sort feature columns by counterfactual feature importance'", () => { + cy.get(Locators.WhatIfColumnHeaders) + .eq(2) + .contains( + dataShape.whatIfCounterfactualsData?.columnHeaderBeforeSort || "" + ); + cy.get(Locators.WhatIfCreateCounterfactualSortButton).click(); + cy.get(Locators.WhatIfColumnHeaders) + .eq(2) + .invoke("text") + .then((text1) => { + expect(text1).to.not.equal( + dataShape.whatIfCounterfactualsData?.columnHeaderBeforeSort + ); + }); + cy.get(Locators.WhatIfCreateCounterfactualSortButton).click(); + }); + }); } From 0010330ea56e54aa02624df43e011b06cb4c4ba7 Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Mon, 23 May 2022 23:13:06 -0700 Subject: [PATCH 20/22] Fix locators logic for string features - data explorer and model statistics components (#1446) * update Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * fix Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * fix Signed-off-by: vinutha karanth --- .../src/lib/describer/modelAssessment/Constants.ts | 2 ++ .../dataExplorer/describeCohortFunctionality.ts | 11 ++++++----- .../modelAssessment/modelAssessmentDatasets.ts | 10 +++++----- .../describeModelPerformanceSideBar.ts | 13 +++++++++---- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/libs/e2e/src/lib/describer/modelAssessment/Constants.ts b/libs/e2e/src/lib/describer/modelAssessment/Constants.ts index 555265d38c..29fa148c49 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/Constants.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/Constants.ts @@ -30,6 +30,7 @@ export enum Locators { SortByDropdown = "#featureImportanceChartContainer div.ms-Dropdown-container", SortByDropdownOptions = "div[class^='dropdownItemsWrapper'] button:contains('CohortCreateE2E')", CreateNewCohortButton = "button:contains('New cohort')", + CohortEditPanel = "#cohortEditPanel", CohortNameInput = "#cohortEditPanel input:eq(0)", CohortDatasetValueInput = "#cohortEditPanel input[class^='ms-spinButton-input']", CohortFilterSelection = "#cohortEditPanel [type='radio']", @@ -89,6 +90,7 @@ export enum Locators { CausalAnalysisHeader = "#ModelAssessmentDashboard #causalAnalysisHeader", ErrorAnalysisHeader = "#ModelAssessmentDashboard #errorAnalysisHeader", MSSideBarCards = "#OverallMetricChart div[class^='statsBox']", + AxisConfigPanel = "#AxisConfigPanel", MSSideBarNumberOfBinsInput = "#AxisConfigPanel input[class^='ms-spinButton-input']", MSScrollable = "#OverallMetricChart div[class^='scrollableWrapper']", MSCohortDropdown = "#modelPerformanceCohortPicker", diff --git a/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts b/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts index ef3bccf433..b1f9c91f9c 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/dataExplorer/describeCohortFunctionality.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +import { generateId } from "../../../../util/generateId"; import { Locators } from "../Constants"; import { IModelAssessmentData } from "../IModelAssessmentData"; -const cohortName = "CohortCreateE2E"; export function describeCohortFunctionality( dataShape: IModelAssessmentData ): void { @@ -34,11 +34,12 @@ export function describeCohortFunctionality( }); it("Should update dataset selection with new cohort when a new cohort is created", () => { cy.get(Locators.CreateNewCohortButton).click(); - cy.get("#cohortEditPanel").should("exist"); + cy.get(Locators.CohortEditPanel).should("exist"); + const cohortName = `CohortCreateE2E-${generateId(4)}`; cy.get(Locators.CohortNameInput).clear().type(cohortName); cy.get(Locators.CohortFilterSelection).eq(1).check(); // select Dataset - cy.get(Locators.CohortDatasetValueInput).then(($input) => { - if ($input.length > 0) { + cy.get(Locators.CohortEditPanel).then(($panel) => { + if ($panel.find(Locators.CohortDatasetValueInput).length > 0) { cy.get(Locators.CohortDatasetValueInput) .clear() .type(dataShape.datasetExplorerData?.cohortDatasetNewValue || ""); @@ -46,7 +47,7 @@ export function describeCohortFunctionality( }); cy.get(Locators.CohortAddFilterButton).click(); cy.get(Locators.CohortSaveAndSwitchButton).eq(0).click({ force: true }); - cy.get(Locators.NewCohortSpan).should("exist"); + cy.get(`span:contains(${cohortName})`).should("exist"); cy.get(Locators.DECohortDropdown).click(); cy.get(Locators.DEDropdownOptions).should("exist"); diff --git a/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts b/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts index 5ad9e5a2ae..56eae4b941 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/modelAssessmentDatasets.ts @@ -72,7 +72,7 @@ const modelAssessmentDatasets = { xAxisNewValue: "Probability : <=50K", yAxisNewPanelValue: "Dataset", yAxisNewValue: "age", - yAxisNumberOfBins: "8" + yAxisNumberOfBins: "5" }, whatIfCounterfactualsData: { checkForClassField: true, @@ -143,7 +143,7 @@ const modelAssessmentDatasets = { xAxisNewValue: "Error", yAxisNewPanelValue: "Dataset", yAxisNewValue: "age", - yAxisNumberOfBins: "8" + yAxisNumberOfBins: "5" }, whatIfCounterfactualsData: { checkForClassField: false, @@ -209,7 +209,7 @@ const modelAssessmentDatasets = { xAxisNewValue: "Error", yAxisNewPanelValue: "Dataset", yAxisNewValue: "age", - yAxisNumberOfBins: "8" + yAxisNumberOfBins: "5" }, whatIfCounterfactualsData: { checkForClassField: false, @@ -302,7 +302,7 @@ const modelAssessmentDatasets = { xAxisNewValue: "Probability : Less than median", yAxisNewPanelValue: "Dataset", yAxisNewValue: "LotFrontage", - yAxisNumberOfBins: "8" + yAxisNumberOfBins: "5" }, whatIfCounterfactualsData: { checkForClassField: true, @@ -452,7 +452,7 @@ const modelAssessmentDatasets = { xAxisNewValue: "Probability : 0", yAxisNewPanelValue: "Dataset", yAxisNewValue: "alcohol", - yAxisNumberOfBins: "8" + yAxisNumberOfBins: "5" }, whatIfCounterfactualsData: { hasWhatIfCounterfactualsComponent: false diff --git a/libs/e2e/src/lib/describer/modelAssessment/modelStatistics/describeModelPerformanceSideBar.ts b/libs/e2e/src/lib/describer/modelAssessment/modelStatistics/describeModelPerformanceSideBar.ts index 7a0cfe6545..5a70ee8557 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/modelStatistics/describeModelPerformanceSideBar.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/modelStatistics/describeModelPerformanceSideBar.ts @@ -31,16 +31,21 @@ export function describeModelPerformanceSideBar( `${Locators.DECChoiceFieldGroup} label:contains(${dataShape.modelStatisticsData?.yAxisNewPanelValue})` ) .click(); - cy.get(Locators.MSSideBarNumberOfBinsInput) - .clear() - .type(dataShape.modelStatisticsData?.yAxisNumberOfBins || "8"); + cy.get(Locators.AxisConfigPanel).then(($panel) => { + if ($panel.find(Locators.MSSideBarNumberOfBinsInput).length > 0) { + cy.get(Locators.MSSideBarNumberOfBinsInput) + .clear() + .type(dataShape.modelStatisticsData?.yAxisNumberOfBins || "5"); + } + }); + cy.get(Locators.SelectButton).click(); cy.get(`${Locators.MSCRotatedVerticalBox}`).contains( dataShape.modelStatisticsData?.yAxisNewValue || "age" ); cy.get(Locators.MSSideBarCards).should( "have.length", - dataShape.modelStatisticsData?.yAxisNumberOfBins || "8" + dataShape.modelStatisticsData?.yAxisNumberOfBins || "5" ); // Side bar should be scrollable when data cards overflows cy.get(Locators.MSScrollable).should("exist"); From f19be20c406e4ec7d5f62c1d06628c652b4d8ba0 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Tue, 24 May 2022 18:18:47 -0700 Subject: [PATCH 21/22] Add more unittests RAI dashboard input class (#1448) * Add unit tests for ResponsibleAIDashboardInput Signed-off-by: Gaurav Gupta * Add more tests Signed-off-by: Gaurav Gupta * Fix imports Signed-off-by: Gaurav Gupta * Address code review comments Signed-off-by: Gaurav Gupta --- .../test_responsibleai_dashboard_input.py | 218 +++++++++++++++++- 1 file changed, 208 insertions(+), 10 deletions(-) diff --git a/raiwidgets/tests/test_responsibleai_dashboard_input.py b/raiwidgets/tests/test_responsibleai_dashboard_input.py index 0d16e36b4a..071e7c2266 100644 --- a/raiwidgets/tests/test_responsibleai_dashboard_input.py +++ b/raiwidgets/tests/test_responsibleai_dashboard_input.py @@ -1,24 +1,222 @@ # Copyright (c) Microsoft Corporation # Licensed under the MIT License. -from unittest.mock import patch - +from raiwidgets.interfaces import WidgetRequestResponseConstants from raiwidgets.responsibleai_dashboard_input import \ ResponsibleAIDashboardInput -class TestResponsibleAIDashboardInput: - def test_model_analysis_adult( +class TestResponsibleAIDashboardInputClassification: + def test_rai_dashboard_input_adult_on_predict_success( self, create_rai_insights_object_classification): ri = create_rai_insights_object_classification knn = ri.model test_data = ri.test dashboard_input = ResponsibleAIDashboardInput(ri) - with patch.object(knn, "predict_proba") as predict_mock: - test_pred_data = test_data.head(1).drop("Income", axis=1).values - dashboard_input.on_predict( - test_pred_data) + test_pred_data = test_data.head(1).drop("Income", axis=1).values + flask_server_prediction_output = dashboard_input.on_predict( + test_pred_data) + knn_prediction = knn.predict_proba(test_pred_data) + + assert knn_prediction is not None + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert (flask_server_prediction_output['data'] == knn_prediction).all() + + def test_rai_dashboard_input_adult_on_predict_failure( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + test_data = ri.test + + dashboard_input = ResponsibleAIDashboardInput(ri) + test_pred_data = test_data.head(1).values + flask_server_prediction_output = dashboard_input.on_predict( + test_pred_data) + + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.error in \ + flask_server_prediction_output + assert "Model threw exception while predicting..." in \ + flask_server_prediction_output[ + WidgetRequestResponseConstants.error] + assert len( + flask_server_prediction_output[ + WidgetRequestResponseConstants.data]) == 0 + + def test_rai_dashboard_input_adult_importances_success( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + + dashboard_input = ResponsibleAIDashboardInput(ri) + flask_server_prediction_output = dashboard_input.importances() + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert WidgetRequestResponseConstants.error not in \ + flask_server_prediction_output + + def test_rai_dashboard_input_adult_matrix_success( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + features = ['Age', 'Workclass'] + filters = [] + composite_filters = [] + quantile_binning = False + num_bins = 8 + metric = "Error rate" + post_data = [features, filters, composite_filters, + quantile_binning, num_bins, metric] + + dashboard_input = ResponsibleAIDashboardInput(ri) + flask_server_prediction_output = dashboard_input.matrix(post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert WidgetRequestResponseConstants.error not in \ + flask_server_prediction_output + + def test_rai_dashboard_input_adult_matrix_failure( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + features = ['Age', 'Workclass'] + filters = [] + composite_filters = [] + quantile_binning = False + num_bins = 8 + metric = "Error Rate" + post_data = [features, filters, composite_filters, + quantile_binning, num_bins, metric] + + dashboard_input = ResponsibleAIDashboardInput(ri) + flask_server_prediction_output = dashboard_input.matrix(post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.error in \ + flask_server_prediction_output + assert "Failed to generate json matrix representation," in \ + flask_server_prediction_output[ + WidgetRequestResponseConstants.error] + assert len( + flask_server_prediction_output[ + WidgetRequestResponseConstants.data]) == 0 - assert (predict_mock.call_args[0] - [0].values == test_pred_data).all() + def test_rai_dashboard_input_adult_debug_ml_success( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + + features = ri.test.drop("Income", axis=1).columns.tolist() + filters = [] + composite_filters = [] + max_depth = 3 + num_leaves = 3 + min_child_samples = 8 + metric = "Error rate" + post_data = [features, filters, composite_filters, + max_depth, num_leaves, min_child_samples, metric] + + dashboard_input = ResponsibleAIDashboardInput(ri) + flask_server_prediction_output = dashboard_input.debug_ml(post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert WidgetRequestResponseConstants.error not in \ + flask_server_prediction_output + + def test_rai_dashboard_input_adult_debug_ml_failure( + self, create_rai_insights_object_classification): + ri = create_rai_insights_object_classification + + features = ri.test.drop("Income", axis=1).columns.tolist() + filters = [] + composite_filters = [] + max_depth = 3 + num_leaves = 3 + min_child_samples = 8 + metric = "Error Rate" + post_data = [features, filters, composite_filters, + max_depth, num_leaves, min_child_samples, metric] + + dashboard_input = ResponsibleAIDashboardInput(ri) + flask_server_prediction_output = dashboard_input.debug_ml(post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.error in \ + flask_server_prediction_output + assert "Failed to generate json tree representation," in \ + flask_server_prediction_output[ + WidgetRequestResponseConstants.error] + assert len( + flask_server_prediction_output[ + WidgetRequestResponseConstants.data]) == 0 + + +class TestResponsibleAIDashboardInputRegression: + def test_rai_dashboard_input_housing_on_predict_success( + self, create_rai_insights_object_regression): + ri = create_rai_insights_object_regression + rf = ri.model + test_data = ri.test + + dashboard_input = ResponsibleAIDashboardInput(ri) + test_pred_data = test_data.head(1).drop("target", axis=1).values + flask_server_prediction_output = dashboard_input.on_predict( + test_pred_data) + rf_prediction = rf.predict(test_pred_data) + + assert rf_prediction is not None + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert (flask_server_prediction_output['data'] == rf_prediction).all() + + def test_rai_dashboard_input_housing_causal_whatif_success( + self, create_rai_insights_object_regression): + ri = create_rai_insights_object_regression + id = ri.causal.get()[0].id + causal_whatif_test_data = ri.test.head(1).drop( + "target", axis=1).to_dict(orient='records') + treatment_feature = 'AveRooms' + current_treatment_value = [causal_whatif_test_data[0][ + treatment_feature]] + current_outcome = [ri.test.head(1)["target"].values[0]] + + dashboard_input = ResponsibleAIDashboardInput(ri) + post_data = (id, causal_whatif_test_data, + treatment_feature, current_treatment_value, + current_outcome) + flask_server_prediction_output = dashboard_input.causal_whatif( + post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert WidgetRequestResponseConstants.error not in \ + flask_server_prediction_output + + def test_rai_dashboard_input_housing_causal_whatif_failure( + self, create_rai_insights_object_regression): + ri = create_rai_insights_object_regression + id = "some_id" + causal_whatif_test_data = ri.test.head(1).drop( + "target", axis=1).to_dict(orient='records') + treatment_feature = 'AveRooms' + current_treatment_value = [causal_whatif_test_data[0][ + treatment_feature]] + current_outcome = [ri.test.head(1)["target"].values[0]] + + dashboard_input = ResponsibleAIDashboardInput(ri) + post_data = (id, causal_whatif_test_data, + treatment_feature, current_treatment_value, + current_outcome) + flask_server_prediction_output = dashboard_input.causal_whatif( + post_data) + assert flask_server_prediction_output is not None + assert WidgetRequestResponseConstants.data in \ + flask_server_prediction_output + assert len( + flask_server_prediction_output[ + WidgetRequestResponseConstants.data]) == 0 + assert WidgetRequestResponseConstants.error in \ + flask_server_prediction_output + assert "Failed to generate causal what-if," in \ + flask_server_prediction_output[ + WidgetRequestResponseConstants.error] From 647296482a01fc8257a9ce2ddb879aad53e28666 Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Wed, 25 May 2022 17:24:08 -0700 Subject: [PATCH 22/22] Update the way to get the length of elements obtained in e2e tests (#1450) * update Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * skip what-if create tests for AML Signed-off-by: vinutha karanth --- .../describer/modelAssessment/Constants.ts | 1 + .../describeAggregateCausalAffects.ts | 10 ++-- .../describeSubBarChart.ts | 7 ++- .../describeSubLineChart.ts | 4 +- .../describeAxisFlyouts.ts | 1 + .../describeSubBarChart.ts | 6 +- .../describeWhatIfCreate.ts | 60 +++++++++---------- 7 files changed, 46 insertions(+), 43 deletions(-) diff --git a/libs/e2e/src/lib/describer/modelAssessment/Constants.ts b/libs/e2e/src/lib/describer/modelAssessment/Constants.ts index 29fa148c49..65a220c48a 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/Constants.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/Constants.ts @@ -44,6 +44,7 @@ export enum Locators { WhatIfScatterChartYAxis = "#IndividualFeatureContainer div[class^='rotatedVerticalBox']", WhatIfScatterChartFlyoutCancel = "#AxisConfigPanel button:contains('Cancel')", WhatIfScatterChartFlyoutSelect = "#AxisConfigPanel button:contains('Select')", + WhatIfScatterChartSelectFeatureCaretButton = "#AxisConfigPanel i[data-icon-name='ChevronDown']", WhatIfAxisPanel = "#AxisConfigPanel", AxisFeatureDropdown = "#AxisConfigPanel div.ms-ComboBox-container", AxisFeatureDropdownOption = "div.ms-ComboBox-optionsContainerWrapper button[role='option']", diff --git a/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts b/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts index 4dfdff7bcc..2b7017b194 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/causalAnalysis/describeAggregateCausalAffects.ts @@ -59,12 +59,10 @@ export function describeAggregateCausalAffects( }); it("should render feature names on x-axis that are passed in from SDK", () => { - cy.get(Locators.CausalChartXAxisValues) - .its("length") - .should( - "be", - dataShape.causalAnalysisData?.featureListInCausalTable?.length - ); + cy.get(Locators.CausalChartXAxisValues).should( + "have.length", + dataShape.causalAnalysisData?.featureListInCausalTable?.length + ); cy.get(`${Locators.CausalChartXAxisValues}`) .last() .invoke("text") diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubBarChart.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubBarChart.ts index 6a3a720deb..639c395dcc 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubBarChart.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubBarChart.ts @@ -35,9 +35,10 @@ export function describeSubBarChart(dataShape: IModelAssessmentData): void { ); }); it("should have right number of x axis labels", () => { - cy.get(Locators.IFIXAxisValue) - .its("length") - .should("be", props.dataShape.featureNames?.length); + cy.get(Locators.IFIXAxisValue).should( + "have.length", + props.dataShape.featureNames?.length + ); }); it("should update x axis labels on changing top features by their importance number", () => { diff --git a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts index bd284fbc62..ca0cead326 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/featureImportances/individualFeatureImportance/describeSubLineChart.ts @@ -24,7 +24,9 @@ export function describeSubLineChart( selectRow("Index", "4"); }); it("should have more than one point", () => { - cy.get(Locators.ICENoOfPoints).its("length").should("be.gte", 1); + cy.get(Locators.ICENoOfPoints).then(($noOfPoints) => { + expect($noOfPoints).length.to.be.at.least(1); + }); }); it("should update x-axis value when 'Feature' dropdown is changed", () => { diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeAxisFlyouts.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeAxisFlyouts.ts index 0466d52a0e..a448dc676f 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeAxisFlyouts.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeAxisFlyouts.ts @@ -22,6 +22,7 @@ export function describeAxisFlyouts(dataShape: IModelAssessmentData): void { "have.length", dataShape.featureNames?.length ); + cy.get(Locators.WhatIfScatterChartSelectFeatureCaretButton).click(); cy.get(Locators.WhatIfScatterChartFlyoutCancel).click(); }); it("should be able to select different feature", () => { diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeSubBarChart.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeSubBarChart.ts index 5918245dcd..58c184135e 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeSubBarChart.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeSubBarChart.ts @@ -30,9 +30,9 @@ export function describeSubBarChart(dataShape: IModelAssessmentData): void { ).should("contain.text", "Feature importance"); }); it("should have right number of x axis labels", () => { - cy.get("#WhatIfFeatureImportanceBar g.highcharts-xaxis-labels text") - .its("length") - .should("be", props.dataShape.featureNames?.length); + cy.get( + "#WhatIfFeatureImportanceBar g.highcharts-xaxis-labels text" + ).should("have.length", props.dataShape.featureNames?.length); }); }); } diff --git a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts index 83e470a000..fe02b5fc27 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/whatIfCounterfactuals/describeWhatIfCreate.ts @@ -10,35 +10,35 @@ export function describeWhatIfCreate( dataShape: IModelAssessmentData, name?: keyof typeof modelAssessmentDatasets ): void { - describe("What if Create counterfactual", () => { - before(() => { - cy.get(Locators.WICDatapointDropbox).click(); - getSpan( - dataShape.whatIfCounterfactualsData?.selectedDatapoint || "Index 1" - ).click(); - cy.get(Locators.CreateWhatIfCounterfactualButton) - .click() - .get(Locators.WhatIfCounterfactualPanel) - .should("exist"); - }); - after(() => { - cy.get(Locators.WhatIfCloseButton).click(); - }); + // AML do not need to execute below tests, as these options are not available for static view + if (name) { + describe("What if Create counterfactual", () => { + before(() => { + cy.get(Locators.WICDatapointDropbox).click(); + getSpan( + dataShape.whatIfCounterfactualsData?.selectedDatapoint || "Index 1" + ).click(); + cy.get(Locators.CreateWhatIfCounterfactualButton) + .click() + .get(Locators.WhatIfCounterfactualPanel) + .should("exist"); + }); + after(() => { + cy.get(Locators.WhatIfCloseButton).click(); + }); - it("should filter by included letters in search query", () => { - cy.get(Locators.WhatIfSearchBar).type( - dataShape.whatIfCounterfactualsData?.searchBarQuery || "" - ); - cy.get(Locators.WhatIfColumnHeaders) - .eq(2) - .contains(dataShape.whatIfCounterfactualsData?.searchBarQuery || ""); - cy.get(Locators.WhatIfSearchBarClearTextButton).click(); - cy.get(Locators.WhatIfColumnHeaders).contains( - dataShape.whatIfCounterfactualsData?.columnHeaderAfterSort || "" - ); - }); - // AML do not need to execute below tests, as these options are not available for static view - if (name) { + it("should filter by included letters in search query", () => { + cy.get(Locators.WhatIfSearchBar).type( + dataShape.whatIfCounterfactualsData?.searchBarQuery || "" + ); + cy.get(Locators.WhatIfColumnHeaders) + .eq(2) + .contains(dataShape.whatIfCounterfactualsData?.searchBarQuery || ""); + cy.get(Locators.WhatIfSearchBarClearTextButton).click(); + cy.get(Locators.WhatIfColumnHeaders).contains( + dataShape.whatIfCounterfactualsData?.columnHeaderAfterSort || "" + ); + }); it("Should have 'Create your own counterfactual' section and it should be editable", () => { cy.get(Locators.CreateYourOwnCounterfactualInputField) .eq(2) @@ -72,8 +72,8 @@ export function describeWhatIfCreate( dataShape.whatIfCounterfactualsData?.whatIfNameLabelUpdated ); }); - } - }); + }); + } describe.skip("What-If save scenario", () => { before(() => {