Reduce allocations for multiplying LazyTensor of sparse and dense #80

Merged 17 commits on Mar 27, 2023
8 changes: 6 additions & 2 deletions src/operators.jl
@@ -392,5 +392,9 @@ multiplicable(a::AbstractOperator, b::Ket) = multiplicable(a.basis_r, b.basis)
 multiplicable(a::Bra, b::AbstractOperator) = multiplicable(a.basis, b.basis_l)
 multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)

-Base.size(op::AbstractOperator) = prod(length(op.basis_l),length(op.basis_r))
-Base.size(op::AbstractOperator, i::Int) = (i==1 ? length(op.basis_l) : length(op.basis_r))
+Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
+function Base.size(op::AbstractOperator, i::Int)
+    i < 1 && throw(ErrorException(lazy"dimension out of range, should be strictly positive, got $i"))
+    i > 2 && return 1
+    i==1 ? length(op.basis_l) : length(op.basis_r)
+end
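Note on the new `Base.size` methods: they mirror Base's array conventions, where `size(A, i)` is 1 for any `i` past the last dimension and an error for `i < 1`; previously `size(op, 0)` silently returned `length(op.basis_r)`. A minimal sketch of the intended behavior (hypothetical REPL usage with `NLevelBasis` and `randoperator` as in the tests below, not part of the PR):

```julia
using QuantumOpticsBase

b1, b2 = NLevelBasis(2), NLevelBasis(3)
op = randoperator(b1, b2)            # a 2x3 dense operator

@assert size(op) == (2, 3)           # (length(basis_l), length(basis_r))
@assert size(op, 1) == 2 && size(op, 2) == 3
@assert size(op, 3) == 1             # trailing dimensions are 1, as for Base arrays
# size(op, 0) now throws an ErrorException instead of returning length(basis_r)
```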
4 changes: 4 additions & 0 deletions src/operators_dense.jl
@@ -274,6 +274,10 @@ function _strides(shape)
     return S
 end

+function _strides(shape::Ty)::Ty where Ty <: Tuple
+    accumulate(*, (1,Base.front(shape)...))
+end
+
 # Dense operator version
 @generated function _ptrace(::Type{Val{RANK}}, a,
                             shape_l, shape_r,
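For context, the new tuple method computes the same column-major strides as the loop-based vector method above it, but returns a tuple, so tuple-valued shapes no longer allocate a `Vector` on every call. A standalone sketch (re-implemented for illustration; `strides_of` is a hypothetical name, not the package's internal `_strides`):

```julia
# Column-major strides: the stride of dimension k is the product of the extents before it.
strides_of(shape::Tuple) = accumulate(*, (1, Base.front(shape)...))

@assert strides_of((2, 3, 4)) == (1, 2, 6)
@assert strides_of((5,)) == (1,)
# A tuple result stays off the heap, so the call is allocation-free once compiled.
```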
74 changes: 50 additions & 24 deletions src/operators_lazytensor.jl
@@ -572,9 +572,11 @@ end
 function _gemm_recursive_dense_lazy(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
-        for I=1:size(op, 1)
+        if isa(op, AbstractVector)
+            result[K] += val*op[J]
+        else for I=1:size(op, 1)
             result[I, K] += val*op[I, J]
         end
+        end
         return nothing
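The new `isa(op, AbstractVector)` branch lets the same recursion accept a bra's coefficient vector directly: its base case accumulates `result[K] += val*op[J]`, which is exactly what the matrix path does for a 1-row `op`. A toy equivalence check with plain arrays (standing in for the real recursion):

```julia
v = [1.0, 2.0, 3.0]; val = 2.0; K = 2; J = 3

r_vec = zeros(3)
r_vec[K] += val*v[J]                        # vector base case

r_mat = zeros(1, 3)
r_mat[1, K] += val*reshape(v, 1, :)[1, J]   # matrix base case, op as a 1-row matrix

@assert r_vec[K] == r_mat[1, K] == 6.0
```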
@@ -609,7 +611,7 @@ end
 function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
         for I=1:size(op, 2)
             result[J, I] += val*op[K, I]
@@ -641,45 +643,69 @@ function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
     end
 end

-function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+"""
+    check_mul!_compatibility(R, A, B)
+
+Check that `R`, `A` and `B` are dimensionally compatible for `R .= A*B`, and that `R` is not aliased with either `A` or `B`.
+"""
+function check_mul!_compatibility(R::AbstractVecOrMat, A, B)
+    _check_mul!_aliasing_compatibility(R, A, B)
+    _check_mul!_dim_compatibility(size(R), size(A), size(B))
+end
+function _check_mul!_dim_compatibility(sizeR::Tuple, sizeA::Tuple, sizeB::Tuple)
+    # R .= A*B
+    if sizeA[2] != sizeB[1]
+        throw(DimensionMismatch(lazy"A has dimensions $sizeA but B has dimensions $sizeB. Can't do `A*B`"))
+    end
+    if sizeR != (sizeA[1], Base.tail(sizeB)...)  # using tail to account for vectors
+        throw(DimensionMismatch(lazy"R has dimensions $sizeR but A*B has dimensions $((sizeA[1], Base.tail(sizeB)...)). Can't do `R.=A*B`"))
+    end
+end
+function _check_mul!_aliasing_compatibility(R, A, B)
+    if R===A || R===B
+        throw(ArgumentError(lazy"output matrix must not be aliased with input matrix"))
+    end
+end
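To make the failure modes concrete, here is how these checks behave on plain arrays (toy shapes, assuming the three functions above are in scope):

```julia
A, B = rand(2, 3), rand(3, 4)
R = zeros(2, 4)
check_mul!_compatibility(R, A, B)              # passes: (2,4) == (2,3)*(3,4)

# check_mul!_compatibility(A, A, B)            # ArgumentError: R aliases A
# check_mul!_compatibility(zeros(2, 5), A, B)  # DimensionMismatch: wrong output shape
```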


+function _gemm_puresparse(alpha, op::AbstractArray, h::LazyTensor{B1,B2,F,I,T}, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    if op isa AbstractVector
+        # _gemm_recursive_dense_lazy will treat `op` as a `Bra`
+        _check_mul!_aliasing_compatibility(result, op, h)
+        _check_mul!_dim_compatibility(size(result), reverse(size(h)), size(op))
+    else
+        check_mul!_compatibility(result, op, h)
+    end
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_r.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
     _gemm_recursive_dense_lazy(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end

-function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::Matrix, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::AbstractArray, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    check_mul!_compatibility(result, h, op)
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_l.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
     _gemm_recursive_lazy_dense(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end

+function _get_shape_and_strides(h)
+    shape_l, shape_r = _comp_size(h.basis_l), _comp_size(h.basis_r)
+    shape = min.(shape_l, shape_r)
+    strides_j, strides_k = _strides(shape_l), _strides(shape_r)
+    return shape, strides_j, strides_k
+end
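For intuition, the values `_get_shape_and_strides` produces for the operator used in the tests below (`basis_l = NLevelBasis(2)^2`, `basis_r = NLevelBasis(3)^2`) can be worked out by hand from the definitions above; `_comp_size` is assumed to return the composite basis shape as a tuple:

```julia
shape_l, shape_r = (2, 2), (3, 3)                        # _comp_size of each composite basis
shape     = min.(shape_l, shape_r)                       # (2, 2): iterate the smaller extent per factor
strides_j = accumulate(*, (1, Base.front(shape_l)...))   # (1, 2): column-major strides of shape_l
strides_k = accumulate(*, (1, Base.front(shape_r)...))   # (1, 3): column-major strides of shape_r
@assert (shape, strides_j, strides_k) == ((2, 2), (1, 2), (1, 3))
```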

 _mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
 _mul_puresparse!(result::DenseOpType{B1,B3},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)
-_mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a, b.data, beta, result.data); result)
-_mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a.data, b, beta, result.data); result)
+
+function _mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    b_data = reshape(b.data, length(b.data), 1)
+    result_data = reshape(result.data, length(result.data), 1)
+    _gemm_puresparse(alpha, a, b_data, beta, result_data)
+    result
+end
+
+function _mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    a_data = reshape(a.data, 1, length(a.data))
+    result_data = reshape(result.data, 1, length(result.data))
+    _gemm_puresparse(alpha, a_data, b, beta, result_data)
+    result
+end
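These wrappers promote a `Ket`/`Bra` coefficient vector to a one-column/one-row matrix without copying, since `reshape` on an `Array` returns a view over the same memory; writes through the reshaped matrix land in the state's data. A minimal sketch:

```julia
v = rand(4)
m = reshape(v, length(v), 1)   # 4x1 matrix over v's memory, no copy
m[2, 1] = 42.0
@assert v[2] == 42.0           # the write went through to the vector
```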
10 changes: 10 additions & 0 deletions test/test_operators.jl
@@ -130,4 +130,14 @@ op12 = destroy(bfock)⊗sigmap(bspin)
 @test embed(b, [1,2], op12) == destroy(bfock)⊗sigmap(bspin)⊗one(bspin)
 @test embed(b, [1,3], op12) == destroy(bfock)⊗one(bspin)⊗sigmap(bspin)

+# size of AbstractOperator
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test size(Lop1) == size(dense(Lop1)) == size(dense(Lop1).data)
+@test all(size(Lop1, k) == size(dense(Lop1), k) for k=1:4)
+@test_throws ErrorException size(Lop1, 0)
+@test_throws ErrorException size(Lop1, -1)
+@test_throws ErrorException size(dense(Lop1), 0) # check for consistency
+@test_throws ErrorException size(dense(Lop1), -1)
+
end # testset
8 changes: 8 additions & 0 deletions test/test_operators_lazytensor.jl
@@ -404,5 +404,13 @@ dop = randoperator(b3a⊗b3b, b2a⊗b2b)
 @test dop*lop' ≈ Operator(dop.basis_l, lop.basis_l, dop.data*dense(lop).data')
 @test lop*dop' ≈ Operator(lop.basis_l, dop.basis_l, dense(lop).data*dop.data')

+# Dimension mismatches for LazyTensor with sparse
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test_throws DimensionMismatch Lop1*Lop1
+@test_throws DimensionMismatch dense(Lop1)*Lop1
+@test_throws DimensionMismatch sparse(Lop1)*Lop1
+@test_throws DimensionMismatch Lop1*dense(Lop1)
+@test_throws DimensionMismatch Lop1*sparse(Lop1)
+
end # testset
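End to end, the effect of the PR can be eyeballed by counting allocations of an in-place multiply. A rough sketch (hypothetical benchmark, not part of the PR; the `LazyTensor` construction follows the package's tests, and exact allocation counts will vary by version):

```julia
using QuantumOpticsBase, SparseArrays, LinearAlgebra

b = NLevelBasis(2)
h = LazyTensor(b^3, b^3, [1, 3], (sparse(randoperator(b)), sparse(randoperator(b))))
ψ = randstate(b^3)
out = copy(ψ)

mul!(out, h, ψ)                   # warm up, compile specializations
@show @allocated mul!(out, h, ψ)  # should drop with this PR's changes
```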