Skip to content

Commit

Permalink
Add sparseL extractor and use it in condVar (#492)
Browse files Browse the repository at this point in the history
* Add sparseL extractor and use it in condVar

* Add and export condVartables.  Comment out failing tests.

* update version bound for condVar

* make cvtbl internal

* remove some unused type restrictions

* Remove guards on code for v1.7.0-DEV

* Add tests of condVartables

Co-authored-by: Phillip Alday <[email protected]>
  • Loading branch information
dmbates and palday authored May 8, 2021
1 parent 16bb7fb commit cbcc1e1
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/MixedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export @formula,
coeftable,
cond,
condVar,
condVartables,
describeblocks,
deviance,
dispersion,
Expand Down Expand Up @@ -118,6 +119,7 @@ export @formula,
setθ!,
simulate!,
sparse,
sparseL,
std,
stderror,
updateL!,
Expand Down
117 changes: 109 additions & 8 deletions src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,48 @@ diagonal blocks from the conditional variance-covariance matrix,
s² Λ(Λ'Z'ZΛ + I)⁻¹Λ'
"""
function condVar(m::LinearMixedModel{T}) where {T}
    # Returns one `vi × vi × ℓi` array per random-effects term, where `vi` is the number
    # of columns of `transpose(re.λ)` and `ℓi = length(re.levels)`; slice `[:, :, b]`
    # is the conditional covariance block for level `b` of that term.
    s = sdest(m)
    # NOTE(review): the Int-index conversion is only applied before v1.6.1; presumably the
    # triangular solve on the Int32-indexed factor is unavailable there — confirm.
    @static if VERSION < v"1.6.1"
        spL = LowerTriangular(SparseMatrixCSC{T, Int}(sparseL(m)))
    else
        spL = LowerTriangular(sparseL(m))
    end
    nre = size(spL, 1)
    val = Array{T,3}[]
    offset = 0                       # rows of spL consumed by the preceding terms
    for re in m.reterms
        λt = s * transpose(re.λ)
        vi = size(λt, 2)
        ℓi = length(re.levels)
        vali = Array{T}(undef, (vi, vi, ℓi))
        scratch = Matrix{T}(undef, (nre, vi))
        for b in 1:ℓi
            # Place s*λ' in the rows belonging to level `b`, zero elsewhere, then
            # solve L*X = scratch in place; X'X is the covariance block for this level.
            fill!(scratch, zero(T))
            copyto!(view(scratch, (offset + (b - 1) * vi) .+ (1:vi), :), λt)
            ldiv!(spL, scratch)
            mul!(view(vali, :, :, b), scratch', scratch)
        end
        push!(val, vali)
        offset += vi * ℓi
    end
    return val
end

"""
    _cvtbl(arr::Array{T,3}, trm) where {T}

Build a column table for one random-effects term: a column named `fname(trm)` holding
`trm.levels`, followed by columns `σ` and `ρ` obtained by applying `sdcorr` to each
slice `arr[:, :, i]`.
"""
function _cvtbl(arr::Array{T,3}, trm) where {T}
    levelcol = NamedTuple{(fname(trm),)}((trm.levels,))
    sdcorrows = [
        NamedTuple{(:σ, :ρ)}(sdcorr(view(arr, :, :, slice))) for slice in axes(arr, 3)
    ]
    return merge(levelcol, columntable(sdcorrows))
end

"""
    condVartables(m::MixedModel)

Return the conditional covariance matrices of the random effects as a `NamedTuple` of
columntables, with one entry per random-effects term keyed by `fnames(m)`.
"""
function condVartables(m::MixedModel{T}) where {T}
    termnames = fnames(m)
    tables = map(_cvtbl, condVar(m), m.reterms)
    return NamedTuple{termnames}((tables...,))
end

function pushALblock!(A, L, blk)
Expand Down Expand Up @@ -910,6 +943,74 @@ end

Base.show(io::IO, m::LinearMixedModel) = Base.show(io, MIME"text/plain"(), m)

"""
    _coord(A::AbstractMatrix)

Return the positions and values of the stored entries of `A` as a `NamedTuple` with
fields `i` (row indices, `Int32`), `j` (column indices, `Int32`) and `v` (values).

For a `Diagonal` every diagonal position is reported and `v` is the diagonal vector
itself (not a copy).
"""
function _coord(A::Diagonal)
    rowidx = Int32.(axes(A, 1))
    colidx = Int32.(axes(A, 2))
    return (i = rowidx, j = colidx, v = A.diag)
end

# Coordinates for a `UniformBlockDiagonal`: every element of every block in `A.data`
# is reported, in the storage order of `A.data`.
function _coord(A::UniformBlockDiagonal)
    blocks = A.data
    nrow, ncol, nblk = size(blocks)
    # Offset of each element's block, repeated once per element of that block.
    shift = repeat(nrow .* (0:(nblk - 1)); inner=nrow * ncol)
    rowidx = Int32.(repeat(1:nrow; outer=ncol * nblk) .+ shift)
    colidx = Int32.(repeat(1:ncol; inner=nrow, outer=nblk) .+ shift)
    return (i = rowidx, j = colidx, v = vec(blocks))
end

# Coordinates for an `Int32`-indexed sparse matrix. The row indices and values are the
# matrix's own `rowvals` / `nonzeros` vectors; column indices are expanded to match.
function _coord(A::SparseMatrixCSC{T,Int32}) where {T}
    rowidx = rowvals(A)
    colidx = similar(rowidx)
    for col in axes(A, 2)
        for pos in nzrange(A, col)
            colidx[pos] = col
        end
    end
    return (i = rowidx, j = colidx, v = nonzeros(A))
end

# Coordinates for a dense matrix: every element is reported in column-major order.
function _coord(A::Matrix)
    nrow, ncol = size(A)
    rowidx = Int32.(repeat(axes(A, 1); outer=ncol))
    colidx = Int32.(repeat(axes(A, 2); inner=nrow))
    return (i = rowidx, j = colidx, v = vec(A))
end

"""
    sparseL(m::LinearMixedModel{T}; full::Bool=false) where {T}

Return the lower Cholesky factor `L` as a `SparseMatrixCSC{T,Int32}`.

`full` indicates whether the parts of `L` associated with the fixed-effects and response
are to be included.
"""
function sparseL(m::LinearMixedModel{T}; full::Bool=false) where {T}
    L, reterms = m.L, m.reterms
    nt = length(reterms) + full
    rowoffset, coloffset = 0, 0
    val = (i = Int32[], j = Int32[], v = T[])
    # Visit the blocks of the lower triangle row by row, shifting each block's local
    # coordinates by the rows/columns occupied by the blocks before it.
    for i in 1:nt, j in 1:i
        Lblk = L[block(i, j)]
        cblk = _coord(Lblk)
        append!(val.i, cblk.i .+ Int32(rowoffset))
        append!(val.j, cblk.j .+ Int32(coloffset))
        append!(val.v, cblk.v)
        if i == j
            # The diagonal block ends a block row: restart columns, advance rows.
            coloffset = 0
            rowoffset += size(Lblk, 1)
        else
            coloffset += size(Lblk, 2)
        end
    end
    # After the loop `rowoffset` is the sum of the diagonal block sizes, i.e. the order
    # of the factor. Passing it explicitly keeps the result square even when the
    # trailing rows or columns contain no entries.
    # `tril!` discards anything the diagonal blocks hold above the diagonal.
    return dropzeros!(tril!(sparse(val.i, val.j, val.v, rowoffset, rowoffset)))
end


"""
ssqdenom(m::LinearMixedModel)
Expand Down
26 changes: 25 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
Return the average of `a` and `b`
"""
average(a::T, b::T) where {T<:AbstractFloat} = (a + b) / T(2)  # divisor kept in type T

"""
cpad(s::AbstractString, n::Integer)
Expand Down Expand Up @@ -137,3 +137,27 @@ function replicate(f::Function, n::Integer;
end
results
end

"""
    sdcorr(A::AbstractMatrix{T}) where {T}

Convert a square matrix `A` with positive diagonal into a 2-tuple: an
`NTuple{size(A,1), T}` of standard deviations (square roots of the diagonal) and a tuple
of correlations.

`A` is assumed to be symmetric and only the lower triangle is used. The correlations are
ordered row-major over the lower triangle (equivalently, column-major over the upper
triangle).

Throws an `ArgumentError` if `A` is not square.
"""
function sdcorr(A::AbstractMatrix{T}) where {T}
    nrow, ncol = size(A)
    if nrow != ncol
        throw(ArgumentError("matrix A must be square"))
    end
    prs = checkindprsk(nrow)
    sds = sqrt.(NTuple{nrow,T}(diag(A)))
    corrs = ntuple(kchoose2(nrow)) do k
        r, c = prs[k]
        A[r, c] / (sds[r] * sds[c])
    end
    return (sds, corrs)
end
4 changes: 1 addition & 3 deletions test/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using LinearAlgebra
using MixedModels
using Random
using StableRNGs
using Statistics
using Tables
using Test

Expand Down Expand Up @@ -34,7 +35,6 @@ include("modelcache.jl")
# restore the original state
refit!(fm, vec(float.(ds.yield)))
end

@testset "Poisson" begin
center(v::AbstractVector) = v .- (sum(v) / length(v))
grouseticks = DataFrame(dataset(:grouseticks))
Expand All @@ -50,7 +50,6 @@ include("modelcache.jl")
gm2sim = refit!(simulate!(StableRNG(42), deepcopy(gm2)), fast=true)
@test isapprox(gm2.β, gm2sim.β; atol=norm(stderror(gm2)))
end

@testset "_rand with dispersion" begin
@test_throws ArgumentError MixedModels._rand(StableRNG(42), Normal(), 1, 1, 1)
@test_throws ArgumentError MixedModels._rand(StableRNG(42), Gamma(), 1, 1, 1)
Expand Down Expand Up @@ -99,7 +98,6 @@ end
@test sum(issingular(bsamp)) == sum(issingular(bsamp_threaded))
end


@testset "Bernoulli simulate! and GLMM boostrap" begin
contra = dataset(:contra)
gm0 = fit(MixedModel, only(gfms[:contra]), contra, Bernoulli(), fast=true)
Expand Down
1 change: 0 additions & 1 deletion test/likelihoodratiotest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ end
fm1 = fit(MixedModel,@formula(reaction ~ 1 + days + (1+days|subj)),slp, REML=true);

@test_throws ArgumentError likelihoodratiotest(fm0,fm1)

contra = MixedModels.dataset(:contra);
# glm doesn't like categorical responses, so we convert it to numeric ourselves
# TODO: upstream fix
Expand Down
3 changes: 0 additions & 3 deletions test/mime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using MixedModels: pirls!, setβθ!, setθ!, updateL!

include("modelcache.jl")


# explicitly setting theta for these to so that we can do exact textual comparisons
βθ = [0.1955554704948119, 0.05755412761885973, 0.3207843518569843, -1.0582595252774376,
-2.1047524824609853, -1.0549789653925743, 1.339766125847893, 0.4953047709862237]
Expand All @@ -32,7 +31,6 @@ lrt = likelihoodratiotest(fm0, fm1)
mime = MIME"text/markdown"()
@test_logs (:warn, "Model has not been fit: results will be nonsense") sprint(show, mime, gm3)
gm3.optsum.feval = 1

@testset "lmm" begin
@test sprint(show, mime, fm0) == """
| | Est. | SE | z | p | σ_subj |
Expand Down Expand Up @@ -153,7 +151,6 @@ end
@testset "html" begin
# this is minimal since we're mostly testing that dispatch works
# the stdlib actually handles most of the conversion

@test sprint(show, MIME("text/html"), BlockDescription(gm3)) == """
<table><tr><th align="left">rows</th><th align="left">subj</th><th align="left">item</th><th align="left">fixed</th></tr><tr><td align="left">316</td><td align="left">Diagonal</td><td align="left"></td><td align="left"></td></tr><tr><td align="left">24</td><td align="left">Dense</td><td align="left">Diag/Dense</td><td align="left"></td></tr><tr><td align="left">7</td><td align="left">Dense</td><td align="left">Dense</td><td align="left">Dense</td></tr></table>
"""
Expand Down
41 changes: 40 additions & 1 deletion test/pls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,12 @@ end
@test varest(fm) ≈ 0.3024263987592062 atol=0.0001
@test logdet(fm) ≈ 95.74614821367786 atol=0.001

@test_throws ArgumentError condVar(fm)
cv = condVar(fm)
@test length(cv) == 2
@test size(first(cv)) == (1, 1, 24)
@test size(last(cv)) == (1, 1, 6)
@test first(first(cv)) ≈ 0.07331320237988301 rtol=1.e-4
@test last(last(cv)) ≈ 0.04051547211287544 rtol=1.e-4

rfu = ranef(fm, uscale=true)
@test length(rfu) == 2
Expand Down Expand Up @@ -200,6 +205,13 @@ end
@test varest(fm) ≈ 0.6780020742644107 atol=0.0001
@test logdet(fm) ≈ 101.0381339953986 atol=0.001

cv = condVar(fm)
@test length(cv) == 2
@test size(first(cv)) == (1, 1, 30)
@test first(first(cv)) ≈ 1.111873335663485 rtol=1.e-4
@test size(last(cv)) == (1, 1, 10)
@test last(last(cv)) ≈ 0.850428770978789 rtol=1.e-4

show(io, BlockDescription(fm))
@test countlines(seekstart(io)) == 4
tokens = Set(split(String(take!(io)), r"\s+"))
Expand All @@ -217,6 +229,10 @@ end
@test fm1.optsum.initial == ones(3)
@test lowerbd(fm1) == zeros(3)

spL = sparseL(fm1)
@test size(spL) == (4114, 4114)
@test 733090 < nnz(spL) < 733100

@test objective(fm1) ≈ 237721.7687745563 atol=0.001
ftd1 = fitted(fm1);
@test size(ftd1) == (73421, )
Expand Down Expand Up @@ -291,6 +307,29 @@ end
@test size(first(u3)) == (2, 18)
@test first(u3)[1, 1] ≈ 3.030300122575336 atol=0.001

cv = condVar(fm)
@test length(cv) == 1
@test size(first(cv)) == (2, 2, 18)
@test first(first(cv)) ≈ 140.96612241084617 rtol=1.e-4
@test last(last(cv)) ≈ 5.157750215432247 rtol=1.e-4
@test first(cv)[2] ≈ -20.60428045516186 rtol=1.e-4

cvt = condVartables(fm)
@test length(cvt) == 1
@test only(keys(cvt)) == :subj
cvtsubj = cvt.subj
@test only(cvt) === cvtsubj
@test keys(cvtsubj) == (:subj, :σ, :ρ)
@test Tables.istable(cvtsubj)
@test first(cvtsubj.subj) == "S308"
cvtsubjσ1 = first(cvtsubj.σ)
@test all(==(cvtsubjσ1), cvtsubj.σ)
@test first(cvtsubjσ1) ≈ 11.87291549750297 atol=1.0e-4
@test last(cvtsubjσ1) ≈ 2.271068078114843 atol=1.0e-4
cvtsubjρ = first(cvtsubj.ρ)
@test all(==(cvtsubjρ), cvtsubj.ρ)
@test only(cvtsubjρ) ≈ -0.7641347018831385 atol=1.0e-4

b3 = ranef(fm)
@test length(b3) == 1
@test size(first(b3)) == (2, 18)
Expand Down

0 comments on commit cbcc1e1

Please sign in to comment.