Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hendersontrent committed Aug 20, 2023
2 parents cc114ac + b4de758 commit bd7e492
Show file tree
Hide file tree
Showing 83 changed files with 193 additions and 80 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: correctR
Type: Package
Title: Corrected Test Statistics for Comparing Machine Learning Models on Correlated Samples
Version: 0.1.3
Date: 2023-08-20
Date: 2023-01-27
Authors@R: c(
person("Trent", "Henderson", email = "[email protected]", role = c("cre", "aut"))
)
Expand Down
18 changes: 13 additions & 5 deletions R/kfold_ttest.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,22 @@ kfold_ttest <- function(x, y, n, k){
# Calculations

d <- x - y # Calculate differences
statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * ((1/n + (1/k)) / (1 - 1/k))) # Calculate t-statistic

if(statistic < 0){
p.value <- stats::pt(statistic, n - 1) # p-value for left tail
# Catch for when there is zero difference(s) between the models

if (sum(d) == 0) {
tmp <- data.frame(statistic = 0, p.value = 1)
} else{
p.value <- stats::pt(statistic, n - 1, lower.tail = FALSE) # p-value for right tail
statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * ((1/n + (1/k)) / (1 - 1/k))) # Calculate t-statistic

if(statistic < 0){
p.value <- stats::pt(statistic, n - 1) # p-value for left tail
} else{
p.value <- stats::pt(statistic, n - 1, lower.tail = FALSE) # p-value for right tail
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
return(tmp)
}
34 changes: 15 additions & 19 deletions R/repkfold_ttest.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Compute correlated t-statistic and p-value for repeated k-fold cross-validated results
#' @importFrom stats var pt
#' @param data \code{data.frame} of values for model A and model B over repeated k-fold cross-validation. Three named columns are expected:
#' @param data \code{data.frame} of values for model A and model B over repeated k-fold cross-validation. Four named columns are expected: \code{"model"}, \code{"values"}, \code{"k"}, and \code{"r"}
#' @param n1 \code{integer} denoting train set size
#' @param n2 \code{integer} denoting test set size
#' @param k \code{integer} denoting number of folds used in k-fold
Expand All @@ -18,19 +18,7 @@ repkfold_ttest <- function(data, n1, n2, k, r){

'%ni%' <- Negate('%in%')

if("model" %ni% colnames(data)){
stop("data should contain at least four columns called 'model', 'values', 'k', and 'r'.")
}

if("values" %ni% colnames(data)){
stop("data should contain at least four columns called 'model', 'values', 'k', and 'r'.")
}

if("k" %ni% colnames(data)){
stop("data should contain at least four columns called 'model', 'values', 'k', and 'r'.")
}

if("r" %ni% colnames(data)){
if("model" %ni% colnames(data) || "values" %ni% colnames(data) || "k" %ni% colnames(data) || "r" %ni% colnames(data)){
stop("data should contain at least four columns called 'model', 'values', 'k', and 'r'.")
}

Expand Down Expand Up @@ -59,14 +47,22 @@ repkfold_ttest <- function(data, n1, n2, k, r){
}
}

statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * ((1/(k * r)) + (n2/n1))) # Calculate t-statistic
# Catch for when there is zero difference(s) between the models

if(statistic < 0){
p.value <- stats::pt(statistic, (k * r) - 1) # p-value for left tail
if (sum(d) == 0) {
tmp <- data.frame(statistic = 0, p.value = 1)
} else{
p.value <- stats::pt(statistic, (k * r) - 1, lower.tail = FALSE) # p-value for right tail

statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * ((1/(k * r)) + (n2/n1))) # Calculate t-statistic

if(statistic < 0){
p.value <- stats::pt(statistic, (k * r) - 1) # p-value for left tail
} else{
p.value <- stats::pt(statistic, (k * r) - 1, lower.tail = FALSE) # p-value for right tail
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
return(tmp)
}
18 changes: 13 additions & 5 deletions R/resampled_ttest.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,22 @@ resampled_ttest <- function(x, y, n, n1, n2){
# Calculations

d <- x - y # Calculate differences
statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * (1/n + n2/n1)) # Calculate t-statistic

if(statistic < 0){
p.value <- stats::pt(statistic, n - 1) # p-value for left tail
# Catch for when there is zero difference(s) between the models

if (sum(d) == 0) {
tmp <- data.frame(statistic = 0, p.value = 1)
} else{
p.value <- stats::pt(statistic, n - 1, lower.tail = FALSE) # p-value for right tail
statistic <- mean(d, na.rm = TRUE) / sqrt(stats::var(d, na.rm = TRUE) * (1/n + n2/n1)) # Calculate t-statistic

if(statistic < 0){
p.value <- stats::pt(statistic, n - 1) # p-value for left tail
} else{
p.value <- stats::pt(statistic, n - 1, lower.tail = FALSE) # p-value for right tail
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
}

tmp <- data.frame(statistic = statistic, p.value = p.value)
return(tmp)
}
4 changes: 4 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ devtools::install_github("hendersontrent/theft")
## General purpose

Often in machine learning, we want to compare the performance of different models to determine if one statistically outperforms another. However, the methods used (e.g., data resampling, $k$-fold cross-validation) to obtain these performance metrics (e.g., classification accuracy) violate the assumptions of traditional statistical tests such as a $t$-test. The purpose of these methods is either to aid generalisability of findings (i.e., through quantification of error, as they produce multiple values for each model instead of just one) or to optimise model hyperparameters. This makes them invaluable, but unusable with traditional tests: [Dietterich (1998)](https://pubmed.ncbi.nlm.nih.gov/9744903/) found that the standard $t$-test underestimates the variance, thereby inflating the Type I error rate. `correctR` is a lightweight package that implements a small number of corrected test statistics for cases where samples are not independent (and are therefore correlated), such as resampling, $k$-fold cross-validation, and repeated $k$-fold cross-validation. These corrections were all originally proposed by [Nadeau and Bengio (2003)](https://link.springer.com/article/10.1023/A:1024068626366). Currently, only comparisons between two models are supported.

## Python version

A Python version of `correctR` called `correctipy` is available at the [GitHub repository](https://github.com/hendersontrent/correctipy).
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ cross-validation. These corrections were all originally proposed by
[Nadeau and Bengio
(2003)](https://link.springer.com/article/10.1023/A:1024068626366).
Currently, only cases where two models are to be compared are supported.

## Python version

A Python version of `correctR` called `correctipy` is available at the
[GitHub repository](https://github.com/hendersontrent/correctipy).
9 changes: 9 additions & 0 deletions docs/LICENSE-text.html
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
<!DOCTYPE html>
<<<<<<< HEAD
<!-- Generated by pkgdown: do not edit by hand --><html lang="en"><head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"><meta charset="utf-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"><title>License • correctR</title><script src="deps/jquery-3.6.0/jquery-3.6.0.min.js"></script><meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"><link href="deps/bootstrap-5.1.3/bootstrap.min.css" rel="stylesheet"><script src="deps/bootstrap-5.1.3/bootstrap.bundle.min.js"></script><!-- Font Awesome icons --><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/all.min.css" integrity="sha256-mmgLkCYLUQbXn0B1SRqzHar6dCnv9oZFPEC1g1cwlkk=" crossorigin="anonymous"><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/v4-shims.min.css" integrity="sha256-wZjR52fzng1pJHwx4aV2AO3yyTOXrcDW7jBpJtTwVxw=" crossorigin="anonymous"><!-- bootstrap-toc --><script src="https://cdn.jsdelivr.net/gh/afeld/[email protected]/dist/bootstrap-toc.min.js" integrity="sha256-4veVQbu7//Lk5TSmc7YV48MxtMy98e26cf5MrgZYnwo=" crossorigin="anonymous"></script><!-- headroom.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/headroom.min.js" integrity="sha256-AsUX4SJE1+yuDu5+mAVzJbuYNPHj/WroHuZ8Ir/CkE0=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/jQuery.headroom.min.js" integrity="sha256-ZX/yNShbjqsohH1k95liqY9Gd8uOiE1S4vZc+9KQ1K4=" crossorigin="anonymous"></script><!-- clipboard.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/clipboard.js/2.0.6/clipboard.min.js" integrity="sha256-inc5kl9MA1hkeYUt+EC3BhlIgyp/2jDIyBLS6k3UxPI=" crossorigin="anonymous"></script><!-- search --><script src="https://cdnjs.cloudflare.com/ajax/libs/fuse.js/6.4.6/fuse.js" 
integrity="sha512-zv6Ywkjyktsohkbp9bb45V6tEMoWhzFzXis+LrMehmJZZSys19Yxf1dopHx7WzIKxr5tK2dVcYmaCk2uqdjF4A==" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/autocomplete.js/0.38.0/autocomplete.jquery.min.js" integrity="sha512-GU9ayf+66Xx2TmpxqJpliWbT5PiGYxpaG8rfnBEk1LL8l1KGkRShhngwdXK1UgqhAzWpZHSiYPc09/NwDQIGyg==" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/mark.js/8.11.1/mark.min.js" integrity="sha512-5CYOlHXGh6QpOFA/TeTylKLWfB3ftPsde7AnmhuitiTX4K5SqCLBeKro6sPS8ilsz1Q4NRx3v8Ko2IBiszzdww==" crossorigin="anonymous"></script><!-- pkgdown --><script src="pkgdown.js"></script><meta property="og:title" content="License"><!-- mathjax --><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js" integrity="sha256-nvJJv9wWKEm88qvoQl9ekL2J+k/RWIsaSScxxlsrv8k=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/config/TeX-AMS-MML_HTMLorMML.js" integrity="sha256-84DKXVJXs0/F8OTMzX4UR909+jtl4G7SPypPavF+GfA=" crossorigin="anonymous"></script><!--[if lt IE 9]>
=======
<!-- Generated by pkgdown: do not edit by hand --><html lang="en"><head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"><meta charset="utf-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"><title>License • correctR</title><script src="deps/jquery-3.6.0/jquery-3.6.0.min.js"></script><meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"><link href="deps/bootstrap-5.1.0/bootstrap.min.css" rel="stylesheet"><script src="deps/bootstrap-5.1.0/bootstrap.bundle.min.js"></script><!-- Font Awesome icons --><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/all.min.css" integrity="sha256-mmgLkCYLUQbXn0B1SRqzHar6dCnv9oZFPEC1g1cwlkk=" crossorigin="anonymous"><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.12.1/css/v4-shims.min.css" integrity="sha256-wZjR52fzng1pJHwx4aV2AO3yyTOXrcDW7jBpJtTwVxw=" crossorigin="anonymous"><!-- bootstrap-toc --><script src="https://cdn.rawgit.com/afeld/bootstrap-toc/v1.0.1/dist/bootstrap-toc.min.js"></script><!-- headroom.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/headroom.min.js" integrity="sha256-AsUX4SJE1+yuDu5+mAVzJbuYNPHj/WroHuZ8Ir/CkE0=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/headroom/0.11.0/jQuery.headroom.min.js" integrity="sha256-ZX/yNShbjqsohH1k95liqY9Gd8uOiE1S4vZc+9KQ1K4=" crossorigin="anonymous"></script><!-- clipboard.js --><script src="https://cdnjs.cloudflare.com/ajax/libs/clipboard.js/2.0.6/clipboard.min.js" integrity="sha256-inc5kl9MA1hkeYUt+EC3BhlIgyp/2jDIyBLS6k3UxPI=" crossorigin="anonymous"></script><!-- search --><script src="https://cdnjs.cloudflare.com/ajax/libs/fuse.js/6.4.6/fuse.js" integrity="sha512-zv6Ywkjyktsohkbp9bb45V6tEMoWhzFzXis+LrMehmJZZSys19Yxf1dopHx7WzIKxr5tK2dVcYmaCk2uqdjF4A==" crossorigin="anonymous"></script><script 
src="https://cdnjs.cloudflare.com/ajax/libs/autocomplete.js/0.38.0/autocomplete.jquery.min.js" integrity="sha512-GU9ayf+66Xx2TmpxqJpliWbT5PiGYxpaG8rfnBEk1LL8l1KGkRShhngwdXK1UgqhAzWpZHSiYPc09/NwDQIGyg==" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/mark.js/8.11.1/mark.min.js" integrity="sha512-5CYOlHXGh6QpOFA/TeTylKLWfB3ftPsde7AnmhuitiTX4K5SqCLBeKro6sPS8ilsz1Q4NRx3v8Ko2IBiszzdww==" crossorigin="anonymous"></script><!-- pkgdown --><script src="pkgdown.js"></script><meta property="og:title" content="License"><!-- mathjax --><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js" integrity="sha256-nvJJv9wWKEm88qvoQl9ekL2J+k/RWIsaSScxxlsrv8k=" crossorigin="anonymous"></script><script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/config/TeX-AMS-MML_HTMLorMML.js" integrity="sha256-84DKXVJXs0/F8OTMzX4UR909+jtl4G7SPypPavF+GfA=" crossorigin="anonymous"></script><!--[if lt IE 9]>
>>>>>>> b4de758c9fd2b61f632e58cff96f46b8d8e30d63
<script src="https://oss.maxcdn.com/html5shiv/3.7.3/html5shiv.min.js"></script>
<script src="https://oss.maxcdn.com/respond/1.4.2/respond.min.js"></script>
<![endif]--></head><body>
Expand Down Expand Up @@ -56,7 +60,12 @@
</div>

<div class="pkgdown-footer-right">
<p></p><p>Site built with <a href="https://pkgdown.r-lib.org/" class="external-link">pkgdown</a> 2.0.7.</p>
</div>

</footer></div>
Expand Down
Loading

0 comments on commit bd7e492

Please sign in to comment.