diff --git a/apps/dashboard/src/model-assessment/__mock_data__/adultCensus.ts b/apps/dashboard/src/model-assessment/__mock_data__/adultCensus.ts index 936ab964f3..ebba831025 100644 --- a/apps/dashboard/src/model-assessment/__mock_data__/adultCensus.ts +++ b/apps/dashboard/src/model-assessment/__mock_data__/adultCensus.ts @@ -9908,7 +9908,7 @@ export const adultCounterfactualData: ICounterfactualData = { "occupation", "relationship", "race", - "sex", + "gender", "capital-gain", "capital-loss", "hours-per-week", @@ -9924,7 +9924,7 @@ export const adultCounterfactualData: ICounterfactualData = { "occupation", "relationship", "race", - "sex", + "gender", "capital-gain", "capital-loss", "hours-per-week", diff --git a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts index c1c8b0729b..84ecfc91ba 100644 --- a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts +++ b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts @@ -26,7 +26,7 @@ const modelAssessmentDatasets = { }, featureImportanceData: { datapoint: 500, - dropdownRowName: "Row 4", + dropdownRowName: "Row 34", hasCorrectIncorrectDatapoints: true, hasFeatureImportanceComponent: true, newFeatureDropdownValue: "workclass", diff --git a/libs/counterfactuals/src/lib/CounterfactualChart.tsx b/libs/counterfactuals/src/lib/CounterfactualChart.tsx index 7b45feea3c..8b27daca39 100644 --- a/libs/counterfactuals/src/lib/CounterfactualChart.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualChart.tsx @@ -203,6 +203,7 @@ export class CounterfactualChart extends React.PureComponent< closePanel={this.togglePanel} saveAsPoint={this.saveAsPoint} setCustomRowProperty={this.setCustomRowProperty} + setCustomRowPropertyComboBox={this.setCustomRowPropertyComboBox} temporaryPoint={this.temporaryPoint} isPanelOpen={this.state.isPanelOpen} data={this.context.counterfactualData} @@ -845,6 +846,24 @@ export class CounterfactualChart extends React.PureComponent< } }; + private setCustomRowPropertyComboBox = ( + key: string | number, + index?: number, + value?: string + ): void => { + if (!this.temporaryPoint || (!value && !index)) { + return; + } + const editingData = this.temporaryPoint; + if (index !== undefined) { + // User selected/de-selected an existing option + editingData[key] = index; + } + + this.forceUpdate(); + this.fetchData(editingData); + }; + private disableCounterfactualPanel = (): boolean => { return ( this.state.selectedPointsIndexes[0] === undefined || diff --git a/libs/counterfactuals/src/lib/CounterfactualList.tsx b/libs/counterfactuals/src/lib/CounterfactualList.tsx index 0145695a77..fbec169dd0 100644 --- a/libs/counterfactuals/src/lib/CounterfactualList.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualList.tsx @@ -51,6 +51,11 @@ export interface ICounterfactualListProps { isString: boolean, newValue?: string | number ): void; + setCustomRowPropertyComboBox( + key: string | number, + index?: number, + value?: string + ): void; } interface ICounterfactualListState { @@ -270,8 +275,9 @@ export class CounterfactualList extends React.Component< return columns; } - private updateDropdownColValue = ( + private updateComboBoxColValue = ( key: string | number, + options: IComboBoxOption[], _event: React.FormEvent, option?: IComboBoxOption ): void => { @@ -279,10 +285,17 @@ export class CounterfactualList extends React.Component< const keyIndex = this.props.data?.feature_names_including_target.indexOf(id); if (option?.text) { - this.props.setCustomRowProperty(`Data${keyIndex}`, false, option.text); + const optionIndex = options.findIndex( + (feature) => feature.key === option.text + ); + this.props.setCustomRowPropertyComboBox( + `Data${keyIndex}`, + optionIndex, + option.text + ); this.setState((prevState) => { prevState.data[id] = option.text; - return { data: prevState.data }; + return { data: { ...prevState.data } }; }); } }; @@ -298,7 +311,7 @@ export class CounterfactualList extends React.Component< this.props.setCustomRowProperty(`Data${keyIndex}`, false, newValue); this.setState((prevState) => { prevState.data[id] = toNumber(newValue); - return { data: prevState.data }; + return { data: { ...prevState.data } }; }); }; @@ -343,7 +356,17 @@ export class CounterfactualList extends React.Component< allowFreeform selectedKey={`${this.state.data[column.key]}`} options={dropdownOption.data.categoricalOptions} - onChange={this.updateDropdownColValue.bind(this, column.key)} + onChange={( + _event: React.FormEvent, + option?: IComboBoxOption + ) => + this.updateComboBoxColValue( + column.key, + dropdownOption.data.categoricalOptions, + _event, + option + ) + } /> diff --git a/libs/counterfactuals/src/lib/CounterfactualPanel.tsx b/libs/counterfactuals/src/lib/CounterfactualPanel.tsx index 7504c93d37..1d42cc289d 100644 --- a/libs/counterfactuals/src/lib/CounterfactualPanel.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualPanel.tsx @@ -41,6 +41,11 @@ export interface ICounterfactualPanelProps { isString: boolean, newValue?: string ): void; + setCustomRowPropertyComboBox( + key: string | number, + index?: number, + value?: string + ): void; } interface ICounterfactualState { filterText?: string; @@ -84,6 +89,9 @@ export class CounterfactualPanel extends React.Component< data={this.props.data} temporaryPoint={this.props.temporaryPoint} setCustomRowProperty={this.props.setCustomRowProperty} + setCustomRowPropertyComboBox={ + this.props.setCustomRowPropertyComboBox + } sortFeatures={this.state.sortFeatures} /> diff --git a/notebooks/responsibleaidashboard/responsibleaidashboard-housing-classification-model-debugging.ipynb b/notebooks/responsibleaidashboard/responsibleaidashboard-housing-classification-model-debugging.ipynb index e0a5ec72f9..a1cc535fa4 100644 --- a/notebooks/responsibleaidashboard/responsibleaidashboard-housing-classification-model-debugging.ipynb +++ b/notebooks/responsibleaidashboard/responsibleaidashboard-housing-classification-model-debugging.ipynb @@ -153,7 +153,7 @@ "metadata": {}, "outputs": [], "source": [ - "clf = LGBMClassifier(n_estimators=5)\n", + "clf = LGBMClassifier()\n", "model = clf.fit(X_train, y_train)" ] },