Skip to content

Commit

Permalink
Simplify the train pipeline responsibleaidashboard-census-classificat…
Browse files Browse the repository at this point in the history
…ion-model-debugging.ipynb (#1195)

* Simplify the train pipeline responsibleaidashboard-census-classification-model-debugging.ipynb

Signed-off-by: Gaurav Gupta <[email protected]>

* Address code review comments

* Update notebooks/responsibleaidashboard/responsibleaidashboard-census-classification-model-debugging.ipynb

Co-authored-by: Roman Lutz <[email protected]>

Co-authored-by: Roman Lutz <[email protected]>
  • Loading branch information
gaugup and romanlutz authored Feb 27, 2022
1 parent e9ecb46 commit cdb0e4a
Showing 1 changed file with 17 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"id": "clinical-henry",
"metadata": {},
"source": [
"First, load the census dataset and specify the different types of features. Then, clean the target feature values to include only 0 and 1."
"First, load the census dataset and specify the different types of features. Compose a pipeline which contains a preprocessor and estimator."
]
},
{
Expand All @@ -99,7 +99,7 @@
" y = dataset[[target_feature]]\n",
" return X, y\n",
"\n",
"def clean_data(X, y, target_feature):\n",
"def create_classification_pipeline(X, y, target_feature):\n",
" features = X.columns.values.tolist()\n",
" classes = y[target_feature].unique().tolist()\n",
" pipe_cfg = {\n",
Expand All @@ -118,9 +118,13 @@
" ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n",
" ('cat_pipe', cat_pipe, pipe_cfg['cat_cols'])\n",
" ])\n",
" X = feat_pipe.fit_transform(X)\n",
" print(pipe_cfg['cat_cols'])\n",
" return X, feat_pipe, features, classes\n",
"\n",
" # Append classifier to preprocessing pipeline.\n",
" # Now we have a full prediction pipeline.\n",
" pipeline = Pipeline(steps=[('preprocessor', feat_pipe),\n",
" ('model', LGBMClassifier())])\n",
"\n",
" return pipeline\n",
"\n",
"outdirname = 'responsibleai.12.28.21'\n",
"try:\n",
Expand All @@ -140,30 +144,25 @@
"train_data = pd.read_csv('adult-train.csv')\n",
"test_data = pd.read_csv('adult-test.csv')\n",
"\n",
"\n",
"X_train_original, y_train = split_label(train_data, target_feature)\n",
"X_test_original, y_test = split_label(test_data, target_feature)\n",
"\n",
"pipeline = create_classification_pipeline(X_train_original, y_train, target_feature)\n",
"\n",
"X_train, feat_pipe, features, classes = clean_data(X_train_original, y_train, target_feature)\n",
"y_train = y_train[target_feature].to_numpy()\n",
"\n",
"X_test = feat_pipe.transform(X_test_original)\n",
"y_test = y_test[target_feature].to_numpy()\n",
"\n",
"train_data[target_feature] = y_train\n",
"test_data[target_feature] = y_test\n",
"\n",
"test_data_sample = test_data.sample(n=500, random_state=5)\n",
"train_data_sample = train_data.sample(n=8000, random_state=5)"
"# Take 500 samples from the test data\n",
"test_data_sample = test_data.sample(n=500, random_state=5)"
]
},
{
"cell_type": "markdown",
"id": "potential-proportion",
"metadata": {},
"source": [
"Train a LightGBM classifier on the training data."
"Train the classification pipeline composed in the previous cell on the training data."
]
},
{
Expand All @@ -173,8 +172,7 @@
"metadata": {},
"outputs": [],
"source": [
"clf = LGBMClassifier()\n",
"model = clf.fit(X_train, y_train)"
"model = pipeline.fit(X_train_original, y_train)"
]
},
{
Expand Down Expand Up @@ -213,10 +211,8 @@
"metadata": {},
"outputs": [],
"source": [
"dashboard_pipeline = Pipeline(steps=[('preprocess', feat_pipe), ('model', model)])\n",
"\n",
"rai_insights = RAIInsights(dashboard_pipeline, train_data_sample, test_data_sample, target_feature, 'classification',\n",
" categorical_features=categorical_features)"
"rai_insights = RAIInsights(model, train_data, test_data_sample, target_feature, 'classification',\n",
" categorical_features=categorical_features)"
]
},
{
Expand Down Expand Up @@ -519,7 +515,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
"version": "3.7.11"
}
},
"nbformat": 4,
Expand Down

0 comments on commit cdb0e4a

Please sign in to comment.