Skip to content

Commit

Permalink
add toggle for switching classes in binary classification case
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Jun 6, 2022
1 parent 65ef440 commit 218ff29
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 23 deletions.
12 changes: 3 additions & 9 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,7 @@ export class JointDataset {
featureArray.map((val) => [val])
);
}
case ModelTypes.Binary: {
return JointDataset.transposeLocalImportanceMatrix(
localExplanationRaw as number[][][]
).map((featuresByClasses) =>
featuresByClasses.map((classArray) => classArray.slice(0, 1))
);
}
case ModelTypes.Binary:
case ModelTypes.Multiclass:
default: {
return JointDataset.transposeLocalImportanceMatrix(
Expand Down Expand Up @@ -598,8 +592,7 @@ export class JointDataset {
Number.MIN_SAFE_INTEGER
);
switch (this._modelMeta.modelType) {
case ModelTypes.Regression:
case ModelTypes.Binary: {
case ModelTypes.Regression: {
// no need to flatten what is already flat
this.rawLocalImportance.forEach((featuresByClasses, rowIndex) => {
featuresByClasses.forEach((classArray, featureIndex) => {
Expand All @@ -621,6 +614,7 @@ export class JointDataset {
});
break;
}
case ModelTypes.Binary:
case ModelTypes.Multiclass: {
this.rawLocalImportance.forEach((featuresByClasses, rowIndex) => {
featuresByClasses.forEach((classArray, featureIndex) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ export class SidePanel extends React.Component<
onChange={this.onChartTypeChange}
id="ChartTypeSelection"
/>
{this.props.metadata.modelType === ModelTypes.Multiclass &&
{(this.props.metadata.modelType === ModelTypes.Multiclass ||
this.props.metadata.modelType === ModelTypes.Binary) &&
this.state.weightOptions && (
<div>
<LabelWithCallout
Expand Down Expand Up @@ -136,7 +137,10 @@ export class SidePanel extends React.Component<
};

private getWeightOptions(): IDropdownOption[] | undefined {
if (this.props.metadata.modelType === ModelTypes.Multiclass) {
if (
this.props.metadata.modelType === ModelTypes.Multiclass ||
this.props.metadata.modelType === ModelTypes.Binary
) {
return this.props.weightOptions.map((option) => {
return {
key: option,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ export class ExplanationExploration extends React.PureComponent<
dropdownOptions
);
const weightContext = this.props.dashboardContext.weightContext;
const modelType =
this.props.dashboardContext.explanationContext.modelMetadata.modelType;
const includeWeightDropdown =
this.props.dashboardContext.explanationContext.modelMetadata
.modelType === ModelTypes.Multiclass;
modelType === ModelTypes.Multiclass || modelType === ModelTypes.Binary;
let plotProp = ScatterUtils.populatePlotlyProps(
projectedData,
_.cloneDeep(this.plotlyProps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,19 @@ export class SinglePointFeatureImportance extends React.PureComponent<
// if (!this.props.explanationContext.testDataset.predictedY) {
// return result;
// }
const modelType = this.props.explanationContext.modelMetadata.modelType;
if (
this.props.explanationContext.modelMetadata.modelType !==
ModelTypes.Multiclass
modelType !== ModelTypes.Multiclass &&
modelType !== ModelTypes.Binary
) {
result.push({
key: FeatureKeys.AbsoluteLocal,
text: localization.Interpret.BarChart.absoluteLocal
});
}
if (
this.props.explanationContext.modelMetadata.modelType ===
ModelTypes.Multiclass
modelType === ModelTypes.Multiclass ||
modelType === ModelTypes.Binary
) {
result.push(
...this.props.explanationContext.modelMetadata.classNames.map(
Expand All @@ -243,8 +244,9 @@ export class SinglePointFeatureImportance extends React.PureComponent<
if (!this.props.explanationContext.testDataset.predictedY) {
return FeatureKeys.AbsoluteGlobal;
}
return this.props.explanationContext.modelMetadata.modelType ===
ModelTypes.Multiclass
const modelType = this.props.explanationContext.modelMetadata.modelType;
return modelType === ModelTypes.Multiclass ||
modelType === ModelTypes.Binary
? this.props.explanationContext.testDataset.predictedY[
this.props.selectedRow
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ export class LocalImportancePlots extends React.Component<
if (!this.props.jointDataset.hasDataset) {
return;
}
if (this.props.metadata.modelType === ModelTypes.Multiclass) {
const modelType = this.props.metadata.modelType;
if (
modelType === ModelTypes.Multiclass ||
modelType === ModelTypes.Binary
) {
this.weightOptions = this.props.weightOptions.map((option) => {
return {
key: option,
Expand Down Expand Up @@ -214,7 +218,8 @@ export class LocalImportancePlots extends React.Component<
</Stack.Item>
</Stack>

{this.props.metadata.modelType === ModelTypes.Multiclass && (
{(this.props.metadata.modelType === ModelTypes.Multiclass ||
this.props.metadata.modelType === ModelTypes.Binary) && (
<div>
<div className={classNames.multiclassWeightLabel}>
<Text
Expand Down
4 changes: 2 additions & 2 deletions libs/interpret/src/lib/MLIDashboard/ExplanationDashboard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,7 @@ export class ExplanationDashboard extends React.Component<
return undefined;
}
switch (modelType) {
case ModelTypes.Regression:
case ModelTypes.Binary: {
case ModelTypes.Regression: {
// no need to flatten what is already flat
return localExplanations.map((featuresByClasses) => {
return featuresByClasses.map((classArray) => {
Expand All @@ -491,6 +490,7 @@ export class ExplanationDashboard extends React.Component<
});
}
case ModelTypes.Multiclass:
case ModelTypes.Binary:
default: {
return localExplanations.map((featuresByClasses, rowIndex) => {
return featuresByClasses.map((classArray) => {
Expand Down

0 comments on commit 218ff29

Please sign in to comment.