From 3f3635538804a21ed390a9c1c414166841205130 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Tue, 29 Jun 2021 08:56:48 +0200 Subject: [PATCH] backport #41363: fix equality of QRCompactWY (#41395) * fix equality of QRCompactWY (#41363) Equality for `QRCompactWY` did not ignore the subdiagonal entries of `T` leading to nondeterministic behavior. This is pulled out from #41228, since this change should be less controversial than the other changes there and this particular bug just came up in ChainRules again. --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 3 +- stdlib/LinearAlgebra/src/qr.jl | 34 +++++++++++ stdlib/LinearAlgebra/test/factorization.jl | 65 ++++++++++++++++++++++ stdlib/LinearAlgebra/test/testgroups | 1 + 4 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 stdlib/LinearAlgebra/test/factorization.jl diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 4eee4ecd92b11..136d61c758900 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -16,7 +16,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as setindex!, show, similar, sin, sincos, sinh, size, sqrt, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec using Base: hvcat_fill, IndexLinear, promote_op, promote_typeof, - @propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing + @propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing, + splat using Base.Broadcast: Broadcasted, broadcasted export diff --git a/stdlib/LinearAlgebra/src/qr.jl b/stdlib/LinearAlgebra/src/qr.jl index a76577bb63a0d..27afc6237f07c 100644 --- a/stdlib/LinearAlgebra/src/qr.jl +++ b/stdlib/LinearAlgebra/src/qr.jl @@ -127,6 +127,40 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R)) Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done)) Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing +# returns upper triangular views of all non-undef values of `qr(A).T`: +# +# julia> sparse(qr(A).T .== qr(A).T) +# 36×100 SparseMatrixCSC{Bool, Int64} with 1767 stored entries: +# ⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿ +# ⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿ +# ⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿ +# ⠀⠀⠀⠀⠀⠂⠛⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿ +# ⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⢀⠐⠙⢿⣿⣿⣿⣿ +# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⢀⢙⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠁⠀⡀⠀⠙⢿⣿⣿ +# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⠀⠄⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⡀⠀⠀⢀⠀⠀⠙⢿ +# ⠀⡀⠀⠀⠀⠀⠀⠀⠂⠒⠒⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⢀⠀⠀⠀⡀⠀⠀ +# ⠀⠀⠀⠀⠀⠀⠀⠀⣈⡀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠂⠀⢀⠀ +# +function _triuppers_qr(T) + blocksize, cols = size(T) + return Iterators.map(0:div(cols - 1, blocksize)) do i + n = min(blocksize, cols - i * blocksize) + return UpperTriangular(view(T, 1:n, (1:n) .+ i * blocksize)) + end +end + +function Base.hash(F::QRCompactWY, h::UInt) + return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h))) +end +function Base.:(==)(A::QRCompactWY, B::QRCompactWY) + return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...)) +end +function Base.isequal(A::QRCompactWY, B::QRCompactWY) + return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b) + isequal(a, b)::Bool + end +end + """ QRPivoted <: Factorization diff --git a/stdlib/LinearAlgebra/test/factorization.jl b/stdlib/LinearAlgebra/test/factorization.jl new file mode 100644 index 0000000000000..804341f4a05f4 --- /dev/null +++ b/stdlib/LinearAlgebra/test/factorization.jl @@ -0,0 +1,65 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module TestFactorization +using Test, LinearAlgebra + +@testset "equality for factorizations - $f" for f in Any[ + bunchkaufman, + cholesky, + x -> cholesky(x, Val(true)), + eigen, + hessenberg, + lq, + lu, + qr, + x -> qr(x, Val(true)), + svd, + schur, +] + A = randn(3, 3) + A = A * A' # ensure A is pos. def. and symmetric + F, G = f(A), f(A) + + @test F == G + @test isequal(F, G) + @test hash(F) == hash(G) + + f === hessenberg && continue + + # change all arrays in F to have eltype Float32 + F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i + x = getfield(F, i) + return x isa AbstractArray{Float64} ? Float32.(x) : x + end...) + # round all arrays in G to the nearest Float64 representable as Float32 + G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i + x = getfield(G, i) + return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x + end...) + + if f === qr + @test F == G + @test isequal(F, G) + else + @test_broken F == G + @test_broken isequal(F, G) + end + @test hash(F) == hash(G) +end + +@testset "equality of QRCompactWY" begin + A = rand(100, 100) + F, G = qr(A), qr(A) + + @test F == G + @test isequal(F, G) + @test hash(F) == hash(G) + + G.T[28, 100] = 42 + + @test F != G + @test !isequal(F, G) + @test hash(F) != hash(G) +end + +end diff --git a/stdlib/LinearAlgebra/test/testgroups b/stdlib/LinearAlgebra/test/testgroups index b33dfecaa82ee..de082d8e7dce0 100644 --- a/stdlib/LinearAlgebra/test/testgroups +++ b/stdlib/LinearAlgebra/test/testgroups @@ -25,3 +25,4 @@ givens structuredbroadcast addmul ldlt +factorization