From c1e8699b14e1ffb42579be7240c413471c260315 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 7 Aug 2023 16:29:50 -0400 Subject: [PATCH] add multiclass statistics to text and tabular RAI dashboards --- libs/core-ui/src/index.ts | 2 +- ...sUtils.ts => MulticlassStatisticsUtils.ts} | 26 ++-- libs/core-ui/src/lib/util/StatisticsUtils.ts | 53 ++------ .../src/lib/util/StatisticsUtilsEnums.ts | 8 +- .../datasets/MulticlassDnnModelDebugging.ts | 10 +- ...tasetCohortsViewBasicElementsArePresent.ts | 6 + .../Controls/ModelOverview/ModelOverview.tsx | 15 +-- .../Controls/ModelOverview/StatsTableUtils.ts | 116 ++++++++---------- 8 files changed, 101 insertions(+), 135 deletions(-) rename libs/core-ui/src/lib/util/{ImageStatisticsUtils.ts => MulticlassStatisticsUtils.ts} (81%) diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index b2a0c2d5db..cf2faa2d24 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -38,7 +38,7 @@ export * from "./lib/util/getRandomId"; export * from "./lib/util/getCohortFilterCount"; export * from "./lib/util/getDependencyChartOptions"; export * from "./lib/util/IGenericChartProps"; -export * from "./lib/util/ImageStatisticsUtils"; +export * from "./lib/util/MulticlassStatisticsUtils"; export * from "./lib/util/initializeOfficeFabric"; export * from "./lib/util/isNumber"; export * from "./lib/util/ModelExplanationUtils"; diff --git a/libs/core-ui/src/lib/util/ImageStatisticsUtils.ts b/libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts similarity index 81% rename from libs/core-ui/src/lib/util/ImageStatisticsUtils.ts rename to libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts index f77e156cbc..4a82e6e408 100644 --- a/libs/core-ui/src/lib/util/ImageStatisticsUtils.ts +++ b/libs/core-ui/src/lib/util/MulticlassStatisticsUtils.ts @@ -8,15 +8,7 @@ import { TotalCohortSamples } from "../Interfaces/IStatistic"; -export enum ImageClassificationMetrics { - Accuracy = "accuracy", - MacroF1 = "f1", - MacroPrecision = "precision", - MacroRecall = "recall", - MicroF1 = "microF1", - MicroPrecision = "microPrecision", - MicroRecall = "microRecall" -} +import { MulticlassClassificationMetrics } from "./StatisticsUtilsEnums"; interface IMicroMacroRetVal { macroScore: number; @@ -64,7 +56,7 @@ export const generateMicroMacroMetrics = ( }; }; -export const generateImageStats: ( +export const generateMulticlassStats: ( trueYs: number[], predYs: number[] ) => ILabeledStatistic[] = ( @@ -90,37 +82,37 @@ export const generateImageStats: ( stat: predYs.length }, { - key: ImageClassificationMetrics.Accuracy, + key: MulticlassClassificationMetrics.Accuracy, label: localization.Interpret.Statistics.accuracy, stat: accuracy }, { - key: ImageClassificationMetrics.MicroPrecision, + key: MulticlassClassificationMetrics.MicroPrecision, label: localization.Interpret.Statistics.precision, stat: microP }, { - key: ImageClassificationMetrics.MicroRecall, + key: MulticlassClassificationMetrics.MicroRecall, label: localization.Interpret.Statistics.recall, stat: microR }, { - key: ImageClassificationMetrics.MicroF1, + key: MulticlassClassificationMetrics.MicroF1, label: localization.Interpret.Statistics.f1Score, stat: microF1 }, { - key: ImageClassificationMetrics.MacroPrecision, + key: MulticlassClassificationMetrics.MacroPrecision, label: localization.Interpret.Statistics.precision, stat: macroP }, { - key: ImageClassificationMetrics.MacroRecall, + key: MulticlassClassificationMetrics.MacroRecall, label: localization.Interpret.Statistics.recall, stat: macroR }, { - key: ImageClassificationMetrics.MacroF1, + key: MulticlassClassificationMetrics.MacroF1, label: localization.Interpret.Statistics.f1Score, stat: macroF1 } diff --git a/libs/core-ui/src/lib/util/StatisticsUtils.ts b/libs/core-ui/src/lib/util/StatisticsUtils.ts index 27f9077557..cde9e2ca93 100644 --- a/libs/core-ui/src/lib/util/StatisticsUtils.ts +++ b/libs/core-ui/src/lib/util/StatisticsUtils.ts @@ -10,18 +10,14 @@ import { } from "../Interfaces/IStatistic"; import { IsBinary } from "../util/ExplanationUtils"; -import { generateImageStats } from "./ImageStatisticsUtils"; import { JointDataset } from "./JointDataset"; -import { - ClassificationEnum, - MulticlassClassificationEnum -} from "./JointDatasetUtils"; +import { ClassificationEnum } from "./JointDatasetUtils"; +import { generateMulticlassStats } from "./MulticlassStatisticsUtils"; import { generateMultilabelStats } from "./MultilabelStatisticsUtils"; import { generateObjectDetectionStats } from "./ObjectDetectionStatisticsUtils"; import { generateQuestionAnsweringStats } from "./QuestionAnsweringStatisticsUtils"; import { BinaryClassificationMetrics, - MulticlassClassificationMetrics, RegressionMetrics } from "./StatisticsUtilsEnums"; @@ -147,27 +143,6 @@ const generateRegressionStats: ( ]; }; -const generateMulticlassStats: (outcomes: number[]) => ILabeledStatistic[] = ( - outcomes: number[] -): ILabeledStatistic[] => { - const correctCount = outcomes.filter( - (x) => x === MulticlassClassificationEnum.Correct - ).length; - const total = outcomes.length; - return [ - { - key: TotalCohortSamples, - label: localization.Interpret.Statistics.samples, - stat: total - }, - { - key: MulticlassClassificationMetrics.Accuracy, - label: localization.Interpret.Statistics.accuracy, - stat: correctCount / total - } - ]; -}; - export const generateMetrics: ( jointDataset: JointDataset, selectionIndexes: number[][], @@ -206,13 +181,6 @@ export const generateMetrics: ( return generateRegressionStats(trueYSubset, predYSubset, errorsSubset); }); } - if (modelType === ModelTypes.ImageMulticlass) { - return selectionIndexes.map((selectionArray) => { - const trueYSubset = selectionArray.map((i) => trueYs[i]); - const predYSubset = selectionArray.map((i) => predYs[i]); - return generateImageStats(trueYSubset, predYSubset); - }); - } if ( modelType === ModelTypes.ObjectDetection && objectDetectionCache && @@ -225,12 +193,17 @@ export const generateMetrics: ( ); } const outcomes = jointDataset.unwrap(JointDataset.ClassificationError); - return selectionIndexes.map((selectionArray) => { - const outcomeSubset = selectionArray.map((i) => outcomes[i]); - if (IsBinary(modelType)) { + if (IsBinary(modelType)) { + return selectionIndexes.map((selectionArray) => { + const outcomeSubset = selectionArray.map((i) => outcomes[i]); + return generateBinaryStats(outcomeSubset); - } - // modelType === ModelTypes.Multiclass - return generateMulticlassStats(outcomeSubset); + }); + } + // modelType === ModelTypes.Multiclass + return selectionIndexes.map((selectionArray) => { + const trueYSubset = selectionArray.map((i) => trueYs[i]); + const predYSubset = selectionArray.map((i) => predYs[i]); + return generateMulticlassStats(trueYSubset, predYSubset); }); }; diff --git a/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts b/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts index f029438b03..6a1a51392b 100644 --- a/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts +++ b/libs/core-ui/src/lib/util/StatisticsUtilsEnums.ts @@ -19,5 +19,11 @@ export enum RegressionMetrics { } export enum MulticlassClassificationMetrics { - Accuracy = "accuracy" + Accuracy = "accuracy", + MacroF1 = "f1", + MacroPrecision = "precision", + MacroRecall = "recall", + MicroF1 = "microF1", + MicroPrecision = "microPrecision", + MicroRecall = "microRecall" } diff --git a/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts b/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts index 1adb1d917d..28daade702 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/datasets/MulticlassDnnModelDebugging.ts @@ -63,7 +63,10 @@ export const MulticlassDnnModelDebugging = { initialCohorts: [ { metrics: { - accuracy: "0.674" + accuracy: "0.674", + macroF1Score: "0.673", + macroPrecisionScore: "0.669", + macroRecallScore: "0.677" }, name: "All data", sampleSize: "89" @@ -71,7 +74,10 @@ export const MulticlassDnnModelDebugging = { ], newCohort: { metrics: { - accuracy: "0.67" + accuracy: "0.67", + macroF1Score: "0.671", + macroPrecisionScore: "0.666", + macroRecallScore: "0.675" }, name: "CohortCreateE2E-multiclass-dnn", sampleSize: "88" diff --git a/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts b/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts index cc1c846203..6f00727607 100644 --- a/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts +++ b/libs/e2e/src/lib/describer/modelAssessment/modelOverview/ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent.ts @@ -75,6 +75,12 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent( "falseNegativeRate", "selectionRate" ); + } else { + metricsOrder.push( + "macroF1Score", + "macroPrecisionScore", + "macroRecallScore" + ); } } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx index bb524b9176..cdb8c54aa3 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx @@ -33,7 +33,6 @@ import { TelemetryLevels, TelemetryEventName, DatasetTaskType, - ImageClassificationMetrics, QuestionAnsweringMetrics, TotalCohortSamples } from "@responsible-ai/core-ui"; @@ -147,17 +146,13 @@ export class ModelOverview extends React.Component< BinaryClassificationMetrics.FalseNegativeRate, BinaryClassificationMetrics.SelectionRate ]; - } else if ( - this.context.dataset.task_type === DatasetTaskType.ImageClassification - ) { + } else { defaultSelectedMetrics = [ - ImageClassificationMetrics.Accuracy, - ImageClassificationMetrics.MacroF1, - ImageClassificationMetrics.MacroPrecision, - ImageClassificationMetrics.MacroRecall + MulticlassClassificationMetrics.Accuracy, + MulticlassClassificationMetrics.MacroF1, + MulticlassClassificationMetrics.MacroPrecision, + MulticlassClassificationMetrics.MacroRecall ]; - } else { - defaultSelectedMetrics = [MulticlassClassificationMetrics.Accuracy]; } } else if ( this.context.dataset.task_type === diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts index 073bc0364f..f4a4d32751 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts @@ -8,7 +8,6 @@ import { ErrorCohort, HighchartsNull, ILabeledStatistic, - ImageClassificationMetrics, MulticlassClassificationMetrics, MultilabelMetrics, ObjectDetectionMetrics, @@ -258,74 +257,63 @@ export function getSelectableMetrics( taskType === DatasetTaskType.ImageClassification ) { if (isMulticlass) { - if (taskType === DatasetTaskType.ImageClassification) { - selectableMetrics.push( - { - description: - localization.ModelAssessment.ModelOverview.metrics.accuracy - .description, - key: ImageClassificationMetrics.Accuracy, - text: localization.ModelAssessment.ModelOverview.metrics.accuracy - .name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.precisionMacro - .description, - key: ImageClassificationMetrics.MacroPrecision, - text: localization.ModelAssessment.ModelOverview.metrics - .precisionMacro.name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.recallMacro - .description, - key: ImageClassificationMetrics.MacroRecall, - text: localization.ModelAssessment.ModelOverview.metrics.recallMacro - .name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro - .description, - key: ImageClassificationMetrics.MacroF1, - text: localization.ModelAssessment.ModelOverview.metrics - .f1ScoreMacro.name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.precisionMicro - .description, - key: ImageClassificationMetrics.MicroPrecision, - text: localization.ModelAssessment.ModelOverview.metrics - .precisionMicro.name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.recallMicro - .description, - key: ImageClassificationMetrics.MicroRecall, - text: localization.ModelAssessment.ModelOverview.metrics.recallMicro - .name - }, - { - description: - localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro - .description, - key: ImageClassificationMetrics.MicroF1, - text: localization.ModelAssessment.ModelOverview.metrics - .f1ScoreMicro.name - } - ); - } else { - selectableMetrics.push({ + selectableMetrics.push( + { description: localization.ModelAssessment.ModelOverview.metrics.accuracy .description, key: MulticlassClassificationMetrics.Accuracy, text: localization.ModelAssessment.ModelOverview.metrics.accuracy.name - }); - } + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.precisionMacro + .description, + key: MulticlassClassificationMetrics.MacroPrecision, + text: localization.ModelAssessment.ModelOverview.metrics + .precisionMacro.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.recallMacro + .description, + key: MulticlassClassificationMetrics.MacroRecall, + text: localization.ModelAssessment.ModelOverview.metrics.recallMacro + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro + .description, + key: MulticlassClassificationMetrics.MacroF1, + text: localization.ModelAssessment.ModelOverview.metrics.f1ScoreMacro + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.precisionMicro + .description, + key: MulticlassClassificationMetrics.MicroPrecision, + text: localization.ModelAssessment.ModelOverview.metrics + .precisionMicro.name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.recallMicro + .description, + key: MulticlassClassificationMetrics.MicroRecall, + text: localization.ModelAssessment.ModelOverview.metrics.recallMicro + .name + }, + { + description: + localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro + .description, + key: MulticlassClassificationMetrics.MicroF1, + text: localization.ModelAssessment.ModelOverview.metrics.f1ScoreMicro + .name + } + ); } else { selectableMetrics.push( {