-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add more unittests RAI dashboard input class (#1448)
* Add unit tests for ResponsibleAIDashboardInput Signed-off-by: Gaurav Gupta <[email protected]> * Add more tests Signed-off-by: Gaurav Gupta <[email protected]> * Fix imports Signed-off-by: Gaurav Gupta <[email protected]> * Address code review comments Signed-off-by: Gaurav Gupta <[email protected]>
- Loading branch information
Showing
1 changed file
with
208 additions
and
10 deletions.
There are no files selected for viewing
218 changes: 208 additions & 10 deletions
218
raiwidgets/tests/test_responsibleai_dashboard_input.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,222 @@ | ||
# Copyright (c) Microsoft Corporation | ||
# Licensed under the MIT License. | ||
|
||
from unittest.mock import patch | ||
|
||
from raiwidgets.interfaces import WidgetRequestResponseConstants | ||
from raiwidgets.responsibleai_dashboard_input import \ | ||
ResponsibleAIDashboardInput | ||
|
||
|
||
class TestResponsibleAIDashboardInput: | ||
def test_model_analysis_adult( | ||
class TestResponsibleAIDashboardInputClassification: | ||
def test_rai_dashboard_input_adult_on_predict_success( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
knn = ri.model | ||
test_data = ri.test | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
with patch.object(knn, "predict_proba") as predict_mock: | ||
test_pred_data = test_data.head(1).drop("Income", axis=1).values | ||
dashboard_input.on_predict( | ||
test_pred_data) | ||
test_pred_data = test_data.head(1).drop("Income", axis=1).values | ||
flask_server_prediction_output = dashboard_input.on_predict( | ||
test_pred_data) | ||
knn_prediction = knn.predict_proba(test_pred_data) | ||
|
||
assert knn_prediction is not None | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert (flask_server_prediction_output['data'] == knn_prediction).all() | ||
|
||
def test_rai_dashboard_input_adult_on_predict_failure( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
test_data = ri.test | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
test_pred_data = test_data.head(1).values | ||
flask_server_prediction_output = dashboard_input.on_predict( | ||
test_pred_data) | ||
|
||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.error in \ | ||
flask_server_prediction_output | ||
assert "Model threw exception while predicting..." in \ | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.error] | ||
assert len( | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.data]) == 0 | ||
|
||
def test_rai_dashboard_input_adult_importances_success( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
flask_server_prediction_output = dashboard_input.importances() | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert WidgetRequestResponseConstants.error not in \ | ||
flask_server_prediction_output | ||
|
||
def test_rai_dashboard_input_adult_matrix_success( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
features = ['Age', 'Workclass'] | ||
filters = [] | ||
composite_filters = [] | ||
quantile_binning = False | ||
num_bins = 8 | ||
metric = "Error rate" | ||
post_data = [features, filters, composite_filters, | ||
quantile_binning, num_bins, metric] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
flask_server_prediction_output = dashboard_input.matrix(post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert WidgetRequestResponseConstants.error not in \ | ||
flask_server_prediction_output | ||
|
||
def test_rai_dashboard_input_adult_matrix_failure( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
features = ['Age', 'Workclass'] | ||
filters = [] | ||
composite_filters = [] | ||
quantile_binning = False | ||
num_bins = 8 | ||
metric = "Error Rate" | ||
post_data = [features, filters, composite_filters, | ||
quantile_binning, num_bins, metric] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
flask_server_prediction_output = dashboard_input.matrix(post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.error in \ | ||
flask_server_prediction_output | ||
assert "Failed to generate json matrix representation," in \ | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.error] | ||
assert len( | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.data]) == 0 | ||
|
||
assert (predict_mock.call_args[0] | ||
[0].values == test_pred_data).all() | ||
def test_rai_dashboard_input_adult_debug_ml_success( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
|
||
features = ri.test.drop("Income", axis=1).columns.tolist() | ||
filters = [] | ||
composite_filters = [] | ||
max_depth = 3 | ||
num_leaves = 3 | ||
min_child_samples = 8 | ||
metric = "Error rate" | ||
post_data = [features, filters, composite_filters, | ||
max_depth, num_leaves, min_child_samples, metric] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
flask_server_prediction_output = dashboard_input.debug_ml(post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert WidgetRequestResponseConstants.error not in \ | ||
flask_server_prediction_output | ||
|
||
def test_rai_dashboard_input_adult_debug_ml_failure( | ||
self, create_rai_insights_object_classification): | ||
ri = create_rai_insights_object_classification | ||
|
||
features = ri.test.drop("Income", axis=1).columns.tolist() | ||
filters = [] | ||
composite_filters = [] | ||
max_depth = 3 | ||
num_leaves = 3 | ||
min_child_samples = 8 | ||
metric = "Error Rate" | ||
post_data = [features, filters, composite_filters, | ||
max_depth, num_leaves, min_child_samples, metric] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
flask_server_prediction_output = dashboard_input.debug_ml(post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.error in \ | ||
flask_server_prediction_output | ||
assert "Failed to generate json tree representation," in \ | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.error] | ||
assert len( | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.data]) == 0 | ||
|
||
|
||
class TestResponsibleAIDashboardInputRegression: | ||
def test_rai_dashboard_input_housing_on_predict_success( | ||
self, create_rai_insights_object_regression): | ||
ri = create_rai_insights_object_regression | ||
rf = ri.model | ||
test_data = ri.test | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
test_pred_data = test_data.head(1).drop("target", axis=1).values | ||
flask_server_prediction_output = dashboard_input.on_predict( | ||
test_pred_data) | ||
rf_prediction = rf.predict(test_pred_data) | ||
|
||
assert rf_prediction is not None | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert (flask_server_prediction_output['data'] == rf_prediction).all() | ||
|
||
def test_rai_dashboard_input_housing_causal_whatif_success( | ||
self, create_rai_insights_object_regression): | ||
ri = create_rai_insights_object_regression | ||
id = ri.causal.get()[0].id | ||
causal_whatif_test_data = ri.test.head(1).drop( | ||
"target", axis=1).to_dict(orient='records') | ||
treatment_feature = 'AveRooms' | ||
current_treatment_value = [causal_whatif_test_data[0][ | ||
treatment_feature]] | ||
current_outcome = [ri.test.head(1)["target"].values[0]] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
post_data = (id, causal_whatif_test_data, | ||
treatment_feature, current_treatment_value, | ||
current_outcome) | ||
flask_server_prediction_output = dashboard_input.causal_whatif( | ||
post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert WidgetRequestResponseConstants.error not in \ | ||
flask_server_prediction_output | ||
|
||
def test_rai_dashboard_input_housing_causal_whatif_failure( | ||
self, create_rai_insights_object_regression): | ||
ri = create_rai_insights_object_regression | ||
id = "some_id" | ||
causal_whatif_test_data = ri.test.head(1).drop( | ||
"target", axis=1).to_dict(orient='records') | ||
treatment_feature = 'AveRooms' | ||
current_treatment_value = [causal_whatif_test_data[0][ | ||
treatment_feature]] | ||
current_outcome = [ri.test.head(1)["target"].values[0]] | ||
|
||
dashboard_input = ResponsibleAIDashboardInput(ri) | ||
post_data = (id, causal_whatif_test_data, | ||
treatment_feature, current_treatment_value, | ||
current_outcome) | ||
flask_server_prediction_output = dashboard_input.causal_whatif( | ||
post_data) | ||
assert flask_server_prediction_output is not None | ||
assert WidgetRequestResponseConstants.data in \ | ||
flask_server_prediction_output | ||
assert len( | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.data]) == 0 | ||
assert WidgetRequestResponseConstants.error in \ | ||
flask_server_prediction_output | ||
assert "Failed to generate causal what-if," in \ | ||
flask_server_prediction_output[ | ||
WidgetRequestResponseConstants.error] |