-
Notifications
You must be signed in to change notification settings - Fork 45
/
Getting_Partial_Dependence_Plot.rmd
264 lines (228 loc) · 9.56 KB
/
Getting_Partial_Dependence_Plot.rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
---
title: "Getting Partial Dependence Plot"
output: html_document
---
**Author**: Matt Marzillo
**Label**: Evaluating Models
### Scope
The scope of this notebook is to provide instructions on how to get the Partial Dependence Plot of a specific model using the R API.
### Background
Partial dependence conveys how changing the value of a feature changes the model's predictions while all other features are held constant. In the DataRobot application, the partial dependence plot is available on the "Feature Effects" tab.
### Requirements
- R version 3.6.2
- DataRobot API version 2.17.0.
Small adjustments might be needed depending on the R version and DataRobot API version you are using.
Full documentation of the R package can be found here: https://cran.r-project.org/web/packages/datarobot/index.html
It is assumed you already have a DataRobot Project object and a DataRobot Model object.
#### Import Packages
```{r results = 'hide', warning=FALSE, message=FALSE}
library(datarobot)
library(tidyverse)
```
#### Connect to DataRobot
Connect to DataRobot using your endpoint and your API token. Replace the placeholder values below with your own.
```{r results = 'hide', warning=FALSE, message=FALSE}
# Open the DataRobot API connection used by every call below.
# Both values are placeholders and must be replaced before running.
ConnectToDataRobot(
  endpoint = "YOUR_DATAROBOT_HOSTNAME",
  token = "YOUR_API_KEY"
)
```
#### Getting training predictions from DataRobot and building the validation set
```{r eval=FALSE}
# Request training predictions for the validation and holdout subset; the
# returned value is handed to GetTrainingPredictionsFromJobId() below.
validation_fold_id <- RequestTrainingPredictions(
  Model,
  dataSubset = DataSubset$ValidationAndHoldout
)

# Retrieve the predictions, which carry the rowId of every scored row.
validation_fold <- GetTrainingPredictionsFromJobId(
  Project,
  validation_fold_id
)

# Keep only the rows of Dataset referenced by those rowId values.
# NOTE(review): R row indexing is 1-based -- confirm that rowId is not
# 0-based, otherwise this selects shifted rows and drops row 0.
validation <- Dataset[validation_fold$rowId, ]
```
#### Creating Function to Manually Generate Partial Dependence Plots
```{r eval=FALSE}
# Manually generate a partial dependence plot ("Feature Effects") for one
# feature by scoring an augmented copy of `data` with a DataRobot model.
#
# Arguments:
#   project       Project passed to GetProject(), UploadPredictionDataset(),
#                 RequestPredictions() and GetPredictions().
#   model         Model object; only model$modelId is used.
#   data          Dataset to evaluate on; coerced with as.data.frame().
#   feature       Name (single string) of the column of `data` to vary. Must
#                 be a categorical or numeric column of the requested dataset
#                 (no DR derived features).
#   size_of_grid  Number of feature values to use (default 25). The grid is
#                 created automatically: quantiles for numerics, random
#                 sampling for categoricals. A vector with more than one
#                 element is used as-is as a custom grid.
#   sample_size   Rows are sampled down to this size for speed (default 1000);
#                 NULL turns sampling off.
#   ice_plot      If TRUE, draw the individual per-row prediction curves
#                 behind the partial dependence curve.
#   std_dev_plot  If TRUE, add mean +/- one standard deviation.
#
# Returns a list with `pd_plot` (the ggplot object) and `pd_data` (per grid
# value: mean_pred, sd, mean_minus_sd, mean_plus_sd).
#
# Side effects: attaches ggplot2 and dplyr, calls set.seed(10), uploads a
# prediction dataset to the project and prints progress with cat().
partial_dependence <- function(project, model, data, feature, size_of_grid = 25,
                               sample_size = 1000, ice_plot = FALSE,
                               std_dev_plot = FALSE) {
  # Loading required packages. library() stops with an error when a package
  # is missing, instead of failing later with an obscure "could not find
  # function" message.
  for (pkg in c("ggplot2", "dplyr")) {
    library(pkg, character.only = TRUE)
  }

  # For reproducibility
  set.seed(10)

  # Get needed info
  project_info <- GetProject(project)
  data <- as.data.frame(data)

  # No Multiclass
  if (project_info$targetType == "Multiclass") {
    stop("Feature Effects is not support for Multiclass yet.")
  }

  # The feature must be a single column name that exists in the dataset
  if (!is.character(feature) || length(feature) != 1) {
    stop("feature must be a single column name given as a string.")
  }
  if (!(feature %in% colnames(data))) {
    stop("Specified feature is not found in dataset.")
  }
  if (nrow(data) == 0) {
    stop("data has no rows.")
  }

  # Can sample to a smaller size for speed
  if (!is.null(sample_size)) {
    # Random sample if regression or if target is not included in dataset
    if (project_info$targetType == "Regression" ||
        !project_info$target %in% colnames(data)) {
      data <- data %>%
        sample_n(size = min(nrow(data), sample_size), replace = FALSE)
    } else {
      # Stratified random sample: the same fraction is drawn from every
      # target class. num_rows is constant within a class and stays in the
      # data as an extra column.
      data <- data %>%
        group_by_at(project_info$target) %>%
        mutate(num_rows = n()) %>%
        sample_frac(size = min(1, sample_size / nrow(data)),
                    weight = num_rows, replace = FALSE) %>%
        ungroup() %>%
        as.data.frame()
    }
  }

  # Defining feature type
  feature_values <- data[, feature]
  if (is.character(feature_values) || is.factor(feature_values)) {
    feature_type <- "categorical"
  } else {
    feature_type <- "numeric"
  }
  is_numeric <- feature_type == "numeric"

  # If scalar, create grid. If more than one value, assume it's a supplied grid
  if (length(size_of_grid) == 1) {
    cats <- unique(feature_values)
    if (length(cats) > size_of_grid) {
      if (is_numeric) {
        # Skewed features can yield the same quantile several times; keep
        # each grid point once so it is scored once and the per-value
        # standard deviation is not distorted by duplicated rows.
        sampled_values <- unique(unname(quantile(
          feature_values,
          probs = seq(0.05, 0.95, length.out = size_of_grid),
          na.rm = TRUE
        )))
      } else {
        sampled_values <- sample(cats, size = size_of_grid, replace = FALSE)
      }
    } else {
      sampled_values <- cats
    }
  } else {
    sampled_values <- size_of_grid
  }

  # Creating augmented dataset (function of sample_size and size_of_grid):
  # one full copy of the data per grid value, with the feature overwritten.
  data$rowID <- seq_len(nrow(data))
  augmented_dataset <- bind_rows(lapply(sampled_values, function(x) {
    data[, feature] <- x
    data
  }))

  # Uploading augmented dataset
  cat("Uploading augmented dataset of size",
      format(object.size(augmented_dataset), units = "Mb"),
      "with", nrow(augmented_dataset), "rows.", "\n")
  augmented_dataset_id <- UploadPredictionDataset(project, augmented_dataset)

  # Requesting predictions on augmented dataset
  cat("Requesting predictions on augmented dataset", '\n')
  pred_job_id <- RequestPredictions(project,
                                    model$modelId,
                                    augmented_dataset_id$id)

  # Adding back predictions
  prediction_type <- if (project_info$targetType == "Regression") {
    "response"
  } else {
    "probability"
  }
  augmented_dataset$predictions <- GetPredictions(project, pred_job_id,
                                                  type = prediction_type,
                                                  maxWait = 600000)

  # Preparing data for plotting
  cat("Preparing plots", '\n')

  # Collecting needed info
  ice_plot_data <- augmented_dataset[, c(feature, "rowID", "predictions")]

  # Calculating partial dependence
  pd_plot_data <- ice_plot_data %>%
    group_by_at(feature) %>%
    summarise(mean_pred = mean(predictions),
              sd = sd(predictions),
              mean_minus_sd = mean_pred - sd,
              mean_plus_sd = mean_pred + sd)

  # Build the x aesthetic from the *value* of `feature`. ensym(feature)
  # would capture the caller's argument expression instead and only work
  # when the caller passes a string literal.
  x_sym <- sym(feature)

  # Numeric features are drawn as lines, categorical ones as points.
  geom_fun <- if (is_numeric) geom_line else geom_point
  outline_size <- if (is_numeric) 2.5 else 4.5
  fill_size <- if (is_numeric) 2 else 4
  sd_size <- if (is_numeric) 1 else 2

  # Plotting partial dependence: gold curve on a slightly larger black one
  pd_plot <- ggplot() +
    geom_fun(data = pd_plot_data,
             aes(x = !!x_sym, y = mean_pred, group = 1),
             size = outline_size, color = "black") +
    geom_fun(data = pd_plot_data,
             aes(x = !!x_sym, y = mean_pred, group = 1),
             size = fill_size, color = "gold") +
    xlab(noquote(feature)) +
    ylab(paste0("Target (", project_info$target, ")")) +
    theme_bw()
  if (!is_numeric) {
    pd_plot <- pd_plot +
      theme(axis.text.x = element_text(angle = 45, hjust = 1))
  }
  pd_plot <- pd_plot + ggtitle("Partial dependence")

  # Individual per-row curves are prepended to the layer list so that they
  # are drawn underneath the partial dependence curve.
  if (ice_plot) {
    ice_layer <- geom_fun(data = ice_plot_data,
                          aes(x = !!x_sym, y = predictions, group = rowID),
                          size = 1, color = "darkgray", alpha = 0.25)
    pd_plot$layers <- c(ice_layer, pd_plot$layers)
  }

  # Mean +/- one standard deviation
  if (std_dev_plot) {
    pd_plot <- pd_plot +
      geom_fun(data = pd_plot_data,
               aes(x = !!x_sym, y = mean_minus_sd, group = 1),
               size = sd_size, color = "darkorchid") +
      geom_fun(data = pd_plot_data,
               aes(x = !!x_sym, y = mean_plus_sd, group = 1),
               size = sd_size, color = "darkorchid")
  }

  list(pd_plot = pd_plot, pd_data = pd_plot_data)
}
```
#### Run Partial Dependence Plot Function
```{r}
# Build the partial dependence plot for one feature on the validation rows.
# Replace "FEATURE_NAME" with the name of a column of `validation`.
partial_dependence(
  project = Project,
  model = Model,
  data = validation,
  feature = "FEATURE_NAME"
)
```