Skip to content

Commit

Permalink
backport #41363: fix equality of QRCompactWY (#41395)
Browse files Browse the repository at this point in the history
* fix equality of QRCompactWY (#41363)

Equality for `QRCompactWY` did not ignore the subdiagonal entries of
`T` leading to nondeterministic behavior.

This is pulled out from #41228, since this change should be less
controversial than the other changes there and this particular bug just
came up in ChainRules again.
  • Loading branch information
simeonschaub authored and staticfloat committed Dec 22, 2022
1 parent 8a0ee3d commit 3f36355
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 1 deletion.
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
using Base: hvcat_fill, IndexLinear, promote_op, promote_typeof,
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing,
splat
using Base.Broadcast: Broadcasted, broadcasted

export
Expand Down
34 changes: 34 additions & 0 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,40 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R))
Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done))
Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing

# returns upper triangular views of all non-undef values of `qr(A).T`:
#
# julia> sparse(qr(A).T .== qr(A).T)
# 36×100 SparseMatrixCSC{Bool, Int64} with 1767 stored entries:
# ⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿
# ⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿
# ⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿
# ⠀⠀⠀⠀⠀⠂⠛⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿
# ⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⢀⠐⠙⢿⣿⣿⣿⣿
# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⢀⢙⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠁⠀⡀⠀⠙⢿⣿⣿
# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⠀⠄⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⡀⠀⠀⢀⠀⠀⠙⢿
# ⠀⡀⠀⠀⠀⠀⠀⠀⠂⠒⠒⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⢀⠀⠀⠀⡀⠀⠀
# ⠀⠀⠀⠀⠀⠀⠀⠀⣈⡀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠂⠀⢀⠀
#
function _triuppers_qr(T)
blocksize, cols = size(T)
return Iterators.map(0:div(cols - 1, blocksize)) do i
n = min(blocksize, cols - i * blocksize)
return UpperTriangular(view(T, 1:n, (1:n) .+ i * blocksize))
end
end

function Base.hash(F::QRCompactWY, h::UInt)
return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h)))
end
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...))
end
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b)
isequal(a, b)::Bool
end
end

"""
QRPivoted <: Factorization
Expand Down
65 changes: 65 additions & 0 deletions stdlib/LinearAlgebra/test/factorization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module TestFactorization
using Test, LinearAlgebra

@testset "equality for factorizations - $f" for f in Any[
bunchkaufman,
cholesky,
x -> cholesky(x, Val(true)),
eigen,
hessenberg,
lq,
lu,
qr,
x -> qr(x, Val(true)),
svd,
schur,
]
A = randn(3, 3)
A = A * A' # ensure A is pos. def. and symmetric
F, G = f(A), f(A)

@test F == G
@test isequal(F, G)
@test hash(F) == hash(G)

f === hessenberg && continue

# change all arrays in F to have eltype Float32
F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i
x = getfield(F, i)
return x isa AbstractArray{Float64} ? Float32.(x) : x
end...)
# round all arrays in G to the nearest Float64 representable as Float32
G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i
x = getfield(G, i)
return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x
end...)

if f === qr
@test F == G
@test isequal(F, G)
else
@test_broken F == G
@test_broken isequal(F, G)
end
@test hash(F) == hash(G)
end

@testset "equality of QRCompactWY" begin
A = rand(100, 100)
F, G = qr(A), qr(A)

@test F == G
@test isequal(F, G)
@test hash(F) == hash(G)

G.T[28, 100] = 42

@test F != G
@test !isequal(F, G)
@test hash(F) != hash(G)
end

end
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/test/testgroups
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ givens
structuredbroadcast
addmul
ldlt
factorization

0 comments on commit 3f36355

Please sign in to comment.