diff --git a/src/solver/CUSOLVER.jl b/src/solver/CUSOLVER.jl index 5715210a..eaadb8bd 100644 --- a/src/solver/CUSOLVER.jl +++ b/src/solver/CUSOLVER.jl @@ -1,7 +1,7 @@ module CUSOLVER using ..CuArrays -using ..CuArrays: libcusolver, @allowscalar, unsafe_free!, @argout, @workspace, @retry_reclaim +using ..CuArrays: libcusolver, @allowscalar, assertscalar, unsafe_free!, @argout, @workspace, @retry_reclaim using ..CUBLAS: cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasDiagType_t using ..CUSPARSE: cusparseMatDescr_t diff --git a/src/solver/linalg.jl b/src/solver/linalg.jl index e191bc9f..a569b45f 100644 --- a/src/solver/linalg.jl +++ b/src/solver/linalg.jl @@ -54,10 +54,11 @@ LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) wher ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B) function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S} + assertscalar("CuQRPackedQ getindex") x = CuArrays.zeros(T, size(A, 2)) x[j] = 1 lmul!(A, x) - return @allowscalar x[i] + return x[i] end function Base.show(io::IO, F::CuQR) diff --git a/test/solver.jl b/test/solver.jl index 8a1b956e..698ec07d 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -295,7 +295,7 @@ k = 1 @test abs.(h_U'h_U) ≈ I @test abs.(h_U[:,1:min(_m,_n)]'U[:,1:min(_m,_n)]) ≈ I @test collect(svdvals(d_A, method)) ≈ svdvals(A) - @test abs.(h_V'h_V) ≈ I + @test abs.(h_V'*h_V) ≈ I @test abs.(h_V[:,1:min(_m,_n)]'*V[:,1:min(_m,_n)]) ≈ I @test collect(d_U'*d_A*d_V) ≈ U'*A*V @test collect(svd(d_A, method).V') == h_V[:,1:min(_m,_n)]' @@ -312,6 +312,7 @@ k = 1 tol = min(m, n)*eps(real(elty))*(1 + (elty <: Complex)) A = rand(elty, m, n) + qra = qr(A) d_A = CuArray(A) d_F = qr(d_A) d_RR = d_F.Q'*d_A @@ -320,6 +321,14 @@ k = 1 @test size(d_F) == size(A) @test size(d_F.Q, 1) == size(A, 1) @test det(d_F.Q) ≈ det(collect(d_F.Q * CuMatrix{elty}(I, size(d_F.Q)))) atol=tol*norm(A) + CuArrays.@allowscalar begin + qval = d_F.Q[1, 1] + @test qval ≈ qra.Q[1, 1] + qrstr = sprint(show, d_F) + @test qrstr == "$(typeof(d_F)) with factors Q and R:\n$(sprint(show, d_F.Q))\n$(sprint(show, d_F.R))" + end + dQ, dR = d_F + @test collect(dQ*dR) ≈ A A = rand(elty, n, m) d_A = CuArray(A) d_F = qr(d_A)