Skip to content

Commit

Permalink
Merge branch 'main' into gaugup/UpdateRaiUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
gaugup authored Jan 5, 2023
2 parents 848d665 + a4920bf commit 5044947
Show file tree
Hide file tree
Showing 19 changed files with 233 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ on:
jobs:
ci-python:
strategy:
# keep running remaining matrix jobs even if one fails
# to avoid having to rerun all jobs several times
fail-fast: false
matrix:
packageDirectory:
[
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/CI-raiwidgets-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ jobs:
env:
node-version: 16.x
strategy:
# keep running remaining matrix jobs even if one fails
# to avoid having to rerun all jobs several times
fail-fast: false
matrix:
packageDirectory: ["raiwidgets"]
operatingSystem: [ubuntu-latest, macos-latest, windows-latest]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ export const covid19events: IDataset = {
"event7",
"event8"
],
task_type: DatasetTaskType.TextClassification,
task_type: DatasetTaskType.MultilabelTextClassification,
true_y: [
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 1, 0, 0, 0, 0, 1],
Expand Down
2 changes: 2 additions & 0 deletions libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export * from "./lib/util/getFeatureOptions";
export * from "./lib/util/getFilterBoundsArgs";
export * from "./lib/util/calculateBoxData";
export * from "./lib/util/calculateLineData";
export * from "./lib/util/MultilabelStatisticsUtils";
export * from "./lib/util/StatisticsUtils";
export * from "./lib/util/string";
export * from "./lib/util/toScientific";
Expand Down Expand Up @@ -88,6 +89,7 @@ export * from "./lib/Interfaces/IErrorAnalysisData";
export * from "./lib/Interfaces/IDataBalanceMeasures";
export * from "./lib/Interfaces/IHighchartBoxData";
export * from "./lib/Interfaces/IMetaData";
export * from "./lib/Interfaces/IStatistic";
export * from "./lib/Interfaces/TextExplanationInterfaces";
export * from "./lib/Interfaces/VisionExplanationInterfaces";
export * from "./lib/Highchart/BasicHighChart";
Expand Down
4 changes: 3 additions & 1 deletion libs/core-ui/src/lib/Interfaces/IExplanationContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ export enum ModelTypes {
Multiclass = "multiclass",
ImageBinary = "imagebinary",
ImageMulticlass = "imagemulticlass",
ImageMultilabel = "imagemultilabel",
TextBinary = "textbinary",
TextMulticlass = "textmulticlass"
TextMulticlass = "textmulticlass",
TextMultilabel = "textmultilabel"
}

export interface IExplanationContext {
Expand Down
10 changes: 10 additions & 0 deletions libs/core-ui/src/lib/Interfaces/IStatistic.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

export interface ILabeledStatistic {
key: string;
label: string;
stat: number;
}

export const TotalCohortSamples = "samples";
3 changes: 2 additions & 1 deletion libs/core-ui/src/lib/components/OverallMetricChartUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import { cohortKey } from "../cohortKey";
import { IModelAssessmentContext } from "../Context/ModelAssessmentContext";
import { getPrimaryChartColor } from "../Highchart/ChartColors";
import { ModelTypes } from "../Interfaces/IExplanationContext";
import { ILabeledStatistic } from "../Interfaces/IStatistic";
import { IsBinary } from "../util/ExplanationUtils";
import { FluentUIStyles } from "../util/FluentUIStyles";
import { ChartTypes, IGenericChartProps } from "../util/IGenericChartProps";
import { JointDataset } from "../util/JointDataset";
import { generateMetrics, ILabeledStatistic } from "../util/StatisticsUtils";
import { generateMetrics } from "../util/StatisticsUtils";

export function generatePlotlyProps(
jointData: JointDataset,
Expand Down
4 changes: 3 additions & 1 deletion libs/core-ui/src/lib/util/ExplanationUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export function IsClassifier(modelType: ModelTypes): boolean {
modelType === ModelTypes.TextBinary ||
modelType === ModelTypes.Multiclass ||
modelType === ModelTypes.ImageMulticlass ||
modelType === ModelTypes.TextMulticlass
modelType === ModelTypes.TextMulticlass ||
modelType === ModelTypes.ImageMultilabel ||
modelType === ModelTypes.TextMultilabel
);
}
72 changes: 72 additions & 0 deletions libs/core-ui/src/lib/util/MultilabelStatisticsUtils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { localization } from "@responsible-ai/localization";

import {
ILabeledStatistic,
TotalCohortSamples
} from "../Interfaces/IStatistic";

import { JointDataset } from "./JointDataset";

export enum MultilabelMetrics {
ExactMatchRatio = "exactMatchRatio",
HammingScore = "hammingScore"
}

export const generateMultilabelStats: (
jointDataset: JointDataset,
selectionIndexes: number[][]
) => ILabeledStatistic[][] = (
jointDataset: JointDataset,
selectionIndexes: number[][]
): ILabeledStatistic[][] => {
const numLabels = jointDataset.numLabels;
return selectionIndexes.map((selectionArray) => {
const matchingLabels = [];
let hammingScore = 0;
const count = selectionArray.length;
for (let i = 0; i < numLabels; i++) {
const trueYs = jointDataset.unwrap(JointDataset.TrueYLabel + i);
const predYs = jointDataset.unwrap(JointDataset.PredictedYLabel + i);

const trueYSubset = selectionArray.map((i) => trueYs[i]);
const predYSubset = selectionArray.map((i) => predYs[i]);
matchingLabels.push(
trueYSubset.filter((trueY, index) => trueY === predYSubset[index])
.length
);
const sumLogicalOr = trueYSubset
.map((trueY, index) => trueY | predYSubset[index])
.reduce((prev, curr) => prev + curr, 0);
const sumLogicalAnd = trueYSubset
.map((trueY, index) => trueY & predYSubset[index])
.reduce((prev, curr) => prev + curr, 0);
if (sumLogicalOr !== 0) {
hammingScore += sumLogicalAnd / sumLogicalOr;
}
}
hammingScore = hammingScore / numLabels;
const sum = matchingLabels.reduce((prev, curr) => prev + curr, 0);
const exactMatchRatio = sum / (numLabels * selectionArray.length);

return [
{
key: TotalCohortSamples,
label: localization.Interpret.Statistics.samples,
stat: count
},
{
key: MultilabelMetrics.ExactMatchRatio,
label: localization.Interpret.Statistics.exactMatchRatio,
stat: exactMatchRatio
},
{
key: MultilabelMetrics.HammingScore,
label: localization.Interpret.Statistics.hammingScore,
stat: hammingScore
}
];
});
};
19 changes: 11 additions & 8 deletions libs/core-ui/src/lib/util/StatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import { localization } from "@responsible-ai/localization";

import { ModelTypes } from "../Interfaces/IExplanationContext";
import {
ILabeledStatistic,
TotalCohortSamples
} from "../Interfaces/IStatistic";
import { IsBinary } from "../util/ExplanationUtils";

import {
Expand All @@ -15,12 +19,7 @@ import {
ClassificationEnum,
MulticlassClassificationEnum
} from "./JointDatasetUtils";

export interface ILabeledStatistic {
key: string;
label: string;
stat: number;
}
import { generateMultilabelStats } from "./MultilabelStatisticsUtils";

export enum BinaryClassificationMetrics {
Accuracy = "accuracy",
Expand All @@ -43,8 +42,6 @@ export enum MulticlassClassificationMetrics {
Accuracy = "accuracy"
}

export const TotalCohortSamples = "samples";

const generateBinaryStats: (outcomes: number[]) => ILabeledStatistic[] = (
outcomes: number[]
): ILabeledStatistic[] => {
Expand Down Expand Up @@ -255,6 +252,12 @@ export const generateMetrics: (
selectionIndexes: number[][],
modelType: ModelTypes
): ILabeledStatistic[][] => {
if (
modelType === ModelTypes.ImageMultilabel ||
modelType === ModelTypes.TextMultilabel
) {
return generateMultilabelStats(jointDataset, selectionIndexes);
}
const trueYs = jointDataset.unwrap(JointDataset.TrueYLabel);
const predYs = jointDataset.unwrap(JointDataset.PredictedYLabel);
if (modelType === ModelTypes.Regression) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import React from "react";

export interface ITreeViewLink {
d: string;
id: string;
key: string;
style: React.CSSProperties | undefined;
}

export interface ITreeViewPathProps {
link: ITreeViewLink;
onMouseOver: (linkId: string) => void;
onMouseOut: () => void;
}

export class TreeViewPath extends React.Component<ITreeViewPathProps> {
public render(): React.ReactNode {
const { link } = this.props;

return (
<path
key={link.key}
id={link.id}
d={link.d}
pointerEvents="all"
style={link.style}
onMouseOver={this.onMouseOver}
onMouseOut={this.onMouseOut}
/>
);
}

private onMouseOver = (): void => {
this.props.onMouseOver(this.props.link.id);
};

private onMouseOut = (): void => {
this.props.onMouseOut();
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import { FilterProps } from "../../FilterProps";
import { TreeLegend } from "../TreeLegend/TreeLegend";

import { TreeViewNode } from "./TreeViewNode";
import { ITreeViewLink, TreeViewPath } from "./TreeViewPath";
import { ITreeViewRendererProps } from "./TreeViewProps";
import {
ITreeViewRendererStyles,
Expand Down Expand Up @@ -190,7 +191,7 @@ export class TreeViewRenderer extends React.PureComponent<
// or not we highlight it. We use the d3 linkVertical which is a curved
// spline to draw the link. The thickness of the links depends on the
// ratio of data going through the path versus overall data in the tree.
const links = rootDescendants
const links: ITreeViewLink[] = rootDescendants
.slice(1)
.map((d: HierarchyPointNode<ITreeNode>) => {
const thick = 1 + Math.floor(30 * (d.data.size / this.state.rootSize));
Expand All @@ -201,8 +202,10 @@ export class TreeViewRenderer extends React.PureComponent<
const linkVerticalD = linkVertical({ source: d.parent, target: d });
return {
d: linkVerticalD || "",
id: id + getRandomId(),
id,
key: id + getRandomId(),
style: {
cursor: "pointer",
fill: theme.semanticColors.bodyBackground,
stroke: lineColor,
strokeWidth: thick
Expand Down Expand Up @@ -230,6 +233,10 @@ export class TreeViewRenderer extends React.PureComponent<
bbY: -0.5 * (bb.height + labelPaddingY) - labelYOffset,
id: `linkLabel${d.id}`,
style: {
display:
d.data.nodeState.onSelectedPath || this.state.hoverPathId === d.id
? undefined
: "none",
transform: `translate(${labelX}px, ${labelY}px)`
},
text: d.data.condition
Expand Down Expand Up @@ -326,11 +333,11 @@ export class TreeViewRenderer extends React.PureComponent<
<g className={containerStyles} tabIndex={0}>
<g>
{links.map((link) => (
<path
key={link.id}
id={link.id}
d={link.d}
style={link.style}
<TreeViewPath
key={link.key}
link={link}
onMouseOver={this.onMouseOver}
onMouseOut={this.onMouseOut}
/>
))}
</g>
Expand All @@ -352,7 +359,7 @@ export class TreeViewRenderer extends React.PureComponent<
<g
key={linkLabel.id}
style={linkLabel.style}
pointerEvents="none"
pointerEvents="all"
>
<rect
x={linkLabel.bbX}
Expand Down Expand Up @@ -384,6 +391,14 @@ export class TreeViewRenderer extends React.PureComponent<
);
}

private onMouseOver = (linkId: string | undefined): void => {
this.setState({ hoverPathId: linkId });
};

private onMouseOut = (): void => {
this.setState({ hoverPathId: undefined });
};

private calculateFilterProps(
node: IErrorAnalysisTreeNode,
rootErrorSize: number
Expand Down Expand Up @@ -540,6 +555,7 @@ export class TreeViewRenderer extends React.PureComponent<
}

return {
hoverPathId: undefined,
isErrorMetric,
maxDepth,
metric,
Expand Down Expand Up @@ -658,6 +674,7 @@ export class TreeViewRenderer extends React.PureComponent<
// APPLY TO NODEDETAIL OBJECT TO UPDATE DISPLAY PANEL
const nodeDetail = this.getNodeDetail(node);
return {
hoverPathId: undefined,
isErrorMetric: state.isErrorMetric,
maxDepth: state.maxDepth,
metric: state.metric,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export interface ITreeViewRendererState {
request?: AbortController;
nodeDetail: INodeDetail;
selectedNode: any;
hoverPathId: string | undefined;
transform: any;
treeNodes: any[];
root?: HierarchyPointNode<ITreeNode>;
Expand Down Expand Up @@ -72,6 +73,7 @@ export function createInitialTreeViewState(
errorAnalysisData: IErrorAnalysisData | undefined
): ITreeViewRendererState {
return {
hoverPathId: undefined,
isErrorMetric: true,
maxDepth: errorAnalysisData?.maxDepth ?? 4,
metric: errorAnalysisData?.metric ?? Metrics.ErrorRate,
Expand Down
10 changes: 10 additions & 0 deletions libs/localization/src/lib/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1153,8 +1153,10 @@
"_rSquared.comment": "the coefficient of determination, see https://en.wikipedia.org/wiki/Coefficient_of_determination",
"_recall.comment": "computed recall of model, see https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers",
"accuracy": "Accuracy: {0}",
"exactMatchRatio": "Exact match ratio: {0}",
"fnr": "False negative rate: {0}",
"fpr": "False positive rate: {0}",
"hammingScore": "Hamming score: {0}",
"meanPrediction": "Mean prediction {0}",
"mse": "Mean squared error: {0}",
"precision": "Precision: {0}",
Expand Down Expand Up @@ -1652,6 +1654,14 @@
"name": "Accuracy score",
"description": "The fraction of data points classified correctly."
},
"exactMatchRatio": {
"name": "Exact match ratio",
"description": "The ratio of instances classified correctly for every label in multilabel task."
},
"hammingScore": {
"name": "Hamming score",
"description": "The average ratio of labels classified correctly among those classified as 1 in multilabel task."
},
"f1Score": {
"name": "F1 score",
"description": "F1-score is the harmonic mean of precision and recall."
Expand Down
Loading

0 comments on commit 5044947

Please sign in to comment.