Skip to content

Commit

Permalink
Rewrote corkendall (issue #634) (#647)
Browse files Browse the repository at this point in the history
New version of corkendall is approx 4 times faster if both arguments are vectors and 7 times faster if at least one is a matrix. See issue #634 for details.
  • Loading branch information
PGS62 authored Feb 8, 2021
1 parent 3b0b2da commit 11ac5b5
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 87 deletions.
242 changes: 162 additions & 80 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,121 +33,203 @@ corspearman(X::RealMatrix) = (Z = mapslices(tiedrank, X, dims=1); cor(Z, Z))
#
#######################################

# Knight JASA (1966)

function corkendall!(x::RealVector, y::RealVector)
# Knight, William R. “A Computer Method for Calculating Kendall's Tau with Ungrouped Data.”
# Journal of the American Statistical Association, vol. 61, no. 314, 1966, pp. 436–439.
# JSTOR, www.jstor.org/stable/2282833.
function corkendall!(x::RealVector, y::RealVector, permx::AbstractVector{<:Integer}=sortperm(x))
if any(isnan, x) || any(isnan, y) return NaN end
n = length(x)
if n != length(y) error("Vectors must have same length") end

# Initial sorting
pm = sortperm(y)
x[:] = x[pm]
y[:] = y[pm]
pm[:] = sortperm(x)
x[:] = x[pm]

# Counting ties in x and y
iT = 1
nT = 0
iU = 1
nU = 0
for i = 2:n
if x[i] == x[i-1]
iT += 1
else
nT += iT*(iT - 1)
iT = 1
end
if y[i] == y[i-1]
iU += 1
else
nU += iU*(iU - 1)
iU = 1
permute!(x, permx)
permute!(y, permx)

# Use widen to avoid overflows on both 32bit and 64bit
npairs = div(widen(n) * (n - 1), 2)
ntiesx = ndoubleties = nswaps = widen(0)
k = 0

@inbounds for i = 2:n
if x[i - 1] == x[i]
k += 1
elseif k > 0
# Sort the corresponding chunk of y, so the rows of hcat(x,y) are
# sorted first on x, then (where x values are tied) on y. Hence
# double ties can be counted by calling countties.
sort!(view(y, (i - k - 1):(i - 1)))
ntiesx += div(widen(k) * (k + 1), 2) # Must use wide integers here
ndoubleties += countties(y, i - k - 1, i - 1)
k = 0
end
end
if iT > 1 nT += iT*(iT - 1) end
nT = div(nT,2)
if iU > 1 nU += iU*(iU - 1) end
nU = div(nU,2)

# Sort y after x
y[:] = y[pm]

# Calculate double ties
iV = 1
nV = 0
jV = 1
for i = 2:n
if x[i] == x[i-1] && y[i] == y[i-1]
iV += 1
else
nV += iV*(iV - 1)
iV = 1
end
if k > 0
sort!(view(y, (n - k):n))
ntiesx += div(widen(k) * (k + 1), 2)
ndoubleties += countties(y, n - k, n)
end
if iV > 1 nV += iV*(iV - 1) end
nV = div(nV,2)

nD = div(n*(n - 1),2)
return (nD - nT - nU + nV - 2swaps!(y)) / (sqrt(nD - nT) * sqrt(nD - nU))
end
nswaps = merge_sort!(y, 1, n)
ntiesy = countties(y, 1, n)

# Calls to float below prevent possible overflow errors when
# length(x) exceeds 77_936 (32 bit) or 5_107_605_667 (64 bit)
(npairs + ndoubleties - ntiesx - ntiesy - 2 * nswaps) /
sqrt(float(npairs - ntiesx) * float(npairs - ntiesy))
end

"""
corkendall(x, y=x)
Compute Kendall's rank correlation coefficient, τ. `x` and `y` must both be either
matrices or vectors.
"""
corkendall(x::RealVector, y::RealVector) = corkendall!(float(copy(x)), float(copy(y)))
corkendall(x::RealVector, y::RealVector) = corkendall!(copy(x), copy(y))

corkendall(X::RealMatrix, y::RealVector) = Float64[corkendall!(float(X[:,i]), float(copy(y))) for i in 1:size(X, 2)]

corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y,2); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i])) for i in 1:n], 1, n))
function corkendall(X::RealMatrix, y::RealVector)
permy = sortperm(y)
return([corkendall!(copy(y), X[:,i], permy) for i in 1:size(X, 2)])
end

corkendall(X::RealMatrix, Y::RealMatrix) = Float64[corkendall!(float(X[:,i]), float(Y[:,j])) for i in 1:size(X, 2), j in 1:size(Y, 2)]
function corkendall(x::RealVector, Y::RealMatrix)
n = size(Y, 2)
permx = sortperm(x)
return(reshape([corkendall!(copy(x), Y[:,i], permx) for i in 1:n], 1, n))
end

function corkendall(X::RealMatrix)
n = size(X, 2)
C = Matrix{eltype(X)}(I, n, n)
C = Matrix{Float64}(I, n, n)
for j = 2:n
for i = 1:j-1
C[i,j] = corkendall!(X[:,i],X[:,j])
C[j,i] = C[i,j]
permx = sortperm(X[:,j])
for i = 1:j - 1
C[j,i] = corkendall!(X[:,j], X[:,i], permx)
C[i,j] = C[j,i]
end
end
return C
end

function corkendall(X::RealMatrix, Y::RealMatrix)
nr = size(X, 2)
nc = size(Y, 2)
C = Matrix{Float64}(undef, nr, nc)
for j = 1:nr
permx = sortperm(X[:,j])
for i = 1:nc
C[j,i] = corkendall!(X[:,j], Y[:,i], permx)
end
end
return C
end

# Auxilliary functions for Kendall's rank correlation

function swaps!(x::RealVector)
n = length(x)
if n == 1 return 0 end
n2 = div(n, 2)
xl = view(x, 1:n2)
xr = view(x, n2+1:n)
nsl = swaps!(xl)
nsr = swaps!(xr)
sort!(xl)
sort!(xr)
return nsl + nsr + mswaps(xl,xr)
"""
countties(x::RealVector, lo::Integer, hi::Integer)
Return the number of ties within `x[lo:hi]`. Assumes `x` is sorted.
"""
function countties(x::AbstractVector, lo::Integer, hi::Integer)
# Use of widen below prevents possible overflow errors when
# length(x) exceeds 2^16 (32 bit) or 2^32 (64 bit)
thistiecount = result = widen(0)
checkbounds(x, lo:hi)
@inbounds for i = (lo + 1):hi
if x[i] == x[i - 1]
thistiecount += 1
elseif thistiecount > 0
result += div(thistiecount * (thistiecount + 1), 2)
thistiecount = widen(0)
end
end

if thistiecount > 0
result += div(thistiecount * (thistiecount + 1), 2)
end
result
end

function mswaps(x::RealVector, y::RealVector)
i = 1
j = 1
nSwaps = 0
n = length(x)
while i <= n && j <= length(y)
if y[j] < x[i]
nSwaps += n - i + 1
# Tests appear to show that a value of 64 is optimal,
# but note that the equivalent constant in base/sort.jl is 20.
const SMALL_THRESHOLD = 64

# merge_sort! copied from Julia Base
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
"""
merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))
Mutates `v` by sorting elements `x[lo:hi]` using the merge sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
"""
function merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))
# Use of widen below prevents possible overflow errors when
# length(v) exceeds 2^16 (32 bit) or 2^32 (64 bit)
nswaps = widen(0)
@inbounds if lo < hi
hi - lo <= SMALL_THRESHOLD && return insertion_sort!(v, lo, hi)

m = midpoint(lo, hi)
(length(t) < m - lo + 1) && resize!(t, m - lo + 1)

nswaps = merge_sort!(v, lo, m, t)
nswaps += merge_sort!(v, m + 1, hi, t)

i, j = 1, lo
while j <= m
t[i] = v[j]
i += 1
j += 1
else
end

i, k = 1, lo
while k < j <= hi
if v[j] < t[i]
v[k] = v[j]
j += 1
nswaps += m - lo + 1 - (i - 1)
else
v[k] = t[i]
i += 1
end
k += 1
end
while k < j
v[k] = t[i]
k += 1
i += 1
end
end
return nSwaps
return nswaps
end

# insertion_sort! and midpoint copied from Julia Base
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01)
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)

"""
insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
Mutates `v` by sorting elements `x[lo:hi]` using the insertion sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
"""
function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
if lo == hi return widen(0) end
nswaps = widen(0)
@inbounds for i = lo + 1:hi
j = i
x = v[i]
while j > lo
if x < v[j - 1]
nswaps += 1
v[j] = v[j - 1]
j -= 1
continue
end
break
end
v[j] = x
end
return nswaps
end
80 changes: 73 additions & 7 deletions test/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,86 @@ c22 = corspearman(x2, x2)
@test corspearman(X, X) [c11 c12; c12 c22]
@test corspearman(X) [c11 c12; c12 c22]


# corkendall

@test corkendall(x1, y) -0.105409255338946
@test corkendall(x2, y) -0.117851130197758
# Check error, handling of NaN, Inf etc
@test_throws ErrorException("Vectors must have same length") corkendall([1,2,3,4], [1,2,3])
@test isnan(corkendall([1,2], [3,NaN]))
@test isnan(corkendall([1,1,1], [1,2,3]))
@test corkendall([-Inf,-0.0,Inf],[1,2,3]) == 1.0

# Test, with exact equality, some known results.
# RealVector, RealVector
@test corkendall(x1, y) == -1/sqrt(90)
@test corkendall(x2, y) == -1/sqrt(72)
# RealMatrix, RealVector
@test corkendall(X, y) == [-1/sqrt(90), -1/sqrt(72)]
# RealVector, RealMatrix
@test corkendall(y, X) == [-1/sqrt(90) -1/sqrt(72)]

# n = 78_000 tests for overflow errors on 32 bit
# Testing for overflow errors on 64bit would require n be too large for practicality
# This also tests merge_sort! since n is (much) greater than SMALL_THRESHOLD.
n = 78_000
# Test with many repeats
@test corkendall(repeat(x1, n), repeat(y, n)) -1/sqrt(90)
@test corkendall(repeat(x2, n), repeat(y, n)) -1/sqrt(72)
@test corkendall(repeat(X, n), repeat(y, n)) [-1/sqrt(90), -1/sqrt(72)]
@test corkendall(repeat(y, n), repeat(X, n)) [-1/sqrt(90) -1/sqrt(72)]
@test corkendall(repeat([0,1,1,0], n), repeat([1,0,1,0], n)) == 0.0

# Test with no repeats, note testing for exact equality
@test corkendall(collect(1:n), collect(1:n)) == 1.0
@test corkendall(collect(1:n), reverse(collect(1:n))) == -1.0

@test corkendall(X, y) [-0.105409255338946, -0.117851130197758]
@test corkendall(y, X) [-0.105409255338946 -0.117851130197758]
# All elements identical should yield NaN
@test isnan(corkendall(repeat([1], n), collect(1:n)))

c11 = corkendall(x1, x1)
c12 = corkendall(x1, x2)
c22 = corkendall(x2, x2)

@test c11 1.0
@test c22 1.0
# RealMatrix, RealMatrix
@test corkendall(X, X) [c11 c12; c12 c22]
# RealMatrix
@test corkendall(X) [c11 c12; c12 c22]

@test c11 == 1.0
@test c22 == 1.0
@test c12 == 3/sqrt(20)

# Finished testing for overflow, so redefine n for speedier tests
n = 100

@test corkendall(repeat(X, n), repeat(X, n)) [c11 c12; c12 c22]
@test corkendall(repeat(X, n)) [c11 c12; c12 c22]

# All eight three-element permutations
z = [1 1 1;
1 1 2;
1 2 2;
1 2 2;
1 2 1;
2 1 2;
1 1 2;
2 2 2]

@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z[:,1], z) == [1 0 1/3]
@test corkendall(z, z[:,1]) == [1; 0; 1/3]

z = float(z)
@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z[:,1], z) == [1 0 1/3]
@test corkendall(z, z[:,1]) == [1; 0; 1/3]

w = repeat(z, n)
@test corkendall(w) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(w, w) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(w[:,1], w) == [1 0 1/3]
@test corkendall(w, w[:,1]) == [1; 0; 1/3]

StatsBase.midpoint(1,10) == 5
StatsBase.midpoint(1,widen(10)) == 5

0 comments on commit 11ac5b5

Please sign in to comment.