Skip to content

Commit

Permalink
Merge pull request #19 from mayer79/ice_tests
Browse files Browse the repository at this point in the history
Improve ICE tests
  • Loading branch information
mayer79 authored Jul 2, 2023
2 parents 8c95df4 + 1096bbb commit 3eb0240
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ jobs:
function_exclusions = c(
"partial_dep\\.Learner",
"partial_dep\\.ranger",
"ice\\.Learner",
"ice\\.ranger",
"interact\\.Learner",
"interact\\.ranger"
)
Expand Down
19 changes: 18 additions & 1 deletion tests/testthat/test_ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ test_that("ice() returns same as partial_dep() for one row", {
ic <- ice(fit1, v = "Species", X = iris[1L, ])$ice_curves[2:3]
pd <- partial_dep(fit1, v = "Species", X = iris[1L, ])$pd
expect_equal(ic, pd)
capture_output(expect_no_error(print(ic)))
})

test_that("ice() returns the same values as ice_raw()", {
Expand Down Expand Up @@ -136,19 +137,29 @@ test_that("Plots give 'ggplot' objects", {
# One v, no by, univariate
expect_s3_class(plot(ice(fit, v = "Species", X = iris2)), "ggplot")

# Two v give error
ic <- ice(fit, v = c("Species", "Petal.Width"), X = iris2)
expect_error(plot(ic))

# One v, one by, univariate
expect_s3_class(
plot(ice(fit, v = "Species", X = iris2, BY = "Petal.Width")),
"ggplot"
)

# Centered
expect_s3_class(
plot(ice(fit, v = "Species", X = iris2, BY = "Petal.Width"), center = TRUE),
"ggplot"
)

# One v, two by, univariate
expect_s3_class(
plot(ice(fit, v = "Petal.Length", X = iris2, BY = c("Petal.Width", "Species"))),
"ggplot"
)

# NOW multioutput
# Now multioutput
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)

# One v, no by, multivariate
Expand All @@ -157,6 +168,12 @@ test_that("Plots give 'ggplot' objects", {
"ggplot"
)

# Same centered
expect_s3_class(
plot(ice(fit, v = "Species", X = iris2), center = TRUE),
"ggplot"
)

# One v, one by, multivariate
expect_s3_class(
plot(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_pdp.R
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ test_that("Plots give 'ggplot' objects", {
pd <- partial_dep(fit, v = c(v, "Petal.Length"), X = iris)
expect_error(plot(pd))

# NOW multioutput
# Now multioutput
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)

# One v, no by, multivariate
Expand Down

0 comments on commit 3eb0240

Please sign in to comment.