Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Allow scalar indexing where necessary and add a few tests #675

Merged
merged 3 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/solver/CUSOLVER.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module CUSOLVER

using ..CuArrays
using ..CuArrays: libcusolver, @allowscalar, unsafe_free!, @argout, @workspace, @retry_reclaim
using ..CuArrays: libcusolver, @allowscalar, assertscalar, unsafe_free!, @argout, @workspace, @retry_reclaim

using ..CUBLAS: cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasDiagType_t
using ..CUSPARSE: cusparseMatDescr_t
Expand Down
3 changes: 2 additions & 1 deletion src/solver/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) wher
ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B)

function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S}
assertscalar("CuQRPackedQ getindex")
x = CuArrays.zeros(T, size(A, 2))
x[j] = 1
lmul!(A, x)
return @allowscalar x[i]
return x[i]
end

function Base.show(io::IO, F::CuQR)
Expand Down
11 changes: 10 additions & 1 deletion test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ k = 1
@test abs.(h_U'h_U) ≈ I
@test abs.(h_U[:,1:min(_m,_n)]'U[:,1:min(_m,_n)]) ≈ I
@test collect(svdvals(d_A, method)) ≈ svdvals(A)
@test abs.(h_V'h_V) ≈ I
@test abs.(h_V'*h_V) ≈ I
@test abs.(h_V[:,1:min(_m,_n)]'*V[:,1:min(_m,_n)]) ≈ I
@test collect(d_U'*d_A*d_V) ≈ U'*A*V
@test collect(svd(d_A, method).V') == h_V[:,1:min(_m,_n)]'
Expand All @@ -312,6 +312,7 @@ k = 1
tol = min(m, n)*eps(real(elty))*(1 + (elty <: Complex))

A = rand(elty, m, n)
qra = qr(A)
d_A = CuArray(A)
d_F = qr(d_A)
d_RR = d_F.Q'*d_A
Expand All @@ -320,6 +321,14 @@ k = 1
@test size(d_F) == size(A)
@test size(d_F.Q, 1) == size(A, 1)
@test det(d_F.Q) ≈ det(collect(d_F.Q * CuMatrix{elty}(I, size(d_F.Q)))) atol=tol*norm(A)
CuArrays.@allowscalar begin
qval = d_F.Q[1, 1]
@test qval ≈ qra.Q[1, 1]
qrstr = sprint(show, d_F)
@test qrstr == "$(typeof(d_F)) with factors Q and R:\n$(sprint(show, d_F.Q))\n$(sprint(show, d_F.R))"
end
dQ, dR = d_F
@test collect(dQ*dR) ≈ A
A = rand(elty, n, m)
d_A = CuArray(A)
d_F = qr(d_A)
Expand Down