Skip to content

Commit

Permalink
Merge pull request #422 from JuliaSparse/jn/cat
Browse files Browse the repository at this point in the history
cat performance tweaks
  • Loading branch information
vtjnash authored Aug 21, 2023
2 parents 2c4f870 + e2c78b8 commit 18b7fce
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 44 deletions.
56 changes: 35 additions & 21 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,9 @@ function hcat(Xin::AbstractSparseVector...)
X = map(_unsafe_unfix, Xin)
Tv = promote_type(map(eltype, X)...)
Ti = promote_type(map(indtype, X)...)
r = _absspvec_hcat(map(x -> convert(SparseVector{Tv,Ti}, x), X)...)
r = (function (::Type{SV}) where SV
_absspvec_hcat(map(x -> convert(SV, x), X)...)
end)(SparseVector{Tv,Ti})
return @if_move_fixed Xin... r
end
function _absspvec_hcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
Expand Down Expand Up @@ -1144,7 +1146,9 @@ function vcat(Xin::AbstractSparseVector...)
X = map(_unsafe_unfix, Xin)
Tv = promote_type(map(eltype, X)...)
Ti = promote_type(map(indtype, X)...)
r = _absspvec_vcat(map(x -> convert(SparseVector{Tv,Ti}, x), X)...)
r = (function (::Type{SV}) where SV
_absspvec_vcat(map(x -> convert(SV, x), X)...)
end)(SparseVector{Tv,Ti})
return @if_move_fixed Xin... r
end
function _absspvec_vcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
Expand Down Expand Up @@ -1194,13 +1198,14 @@ anysparse() = false
anysparse(X) = X isa AbstractArray && issparse(X)
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)

function hcat(X::Union{Vector, AbstractSparseVector}...)
const _SparseVecConcatGroup = Union{Vector, AbstractSparseVector}
function hcat(X::_SparseVecConcatGroup...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(2))
end
function vcat(X::Union{Vector, AbstractSparseVector}...)
function vcat(X::_SparseVecConcatGroup...)
if anysparse(X...)
X = map(sparse, X)
end
Expand All @@ -1213,30 +1218,30 @@ end
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}

# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
T = promote_eltype(X...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
Base.@constprop :aggressive function Base._cat(dims, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
T = promote_eltype(X1, X...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return Base._cat_t(dims, T, X...)
return Base._cat_t(dims, T, X1, X...)
end
function hcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
function hcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return cat(X..., dims=Val(2))
return cat(X1, X..., dims=Val(2))
end
function vcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
function vcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return cat(X..., dims=Val(1))
return cat(X1, X..., dims=Val(1))
end
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if anysparse(X...)
vcat(_hvcat_rows(rows, X...)...)
function hvcat(rows::Tuple{Vararg{Int}}, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
vcat(_hvcat_rows(rows, X1, X...)...)
else
Base.typed_hvcat(Base.promote_eltypeof(X...), rows, X...)
Base.typed_hvcat(Base.promote_eltypeof(X1, X...), rows, X1, X...)
end
end
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
Expand All @@ -1254,6 +1259,15 @@ function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup.
end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# disambiguation for type-piracy problems created above
hcat(n1::Number, ns::Vararg{Number}) = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
vcat(n1::Number, ns::Vararg{Number}) = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
hcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
vcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat, Tuple{typeof(rows), Vararg{Number}}, rows, n1, ns...)
hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)


# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)
Expand Down
44 changes: 21 additions & 23 deletions test/ambiguous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,31 +50,29 @@ using Test, LinearAlgebra, SparseArrays, Aqua
end
end

@testset "detect_ambiguities" begin
@test_nowarn detect_ambiguities(SparseArrays; recursive=true, ambiguous_bottom=false)
let ambig = detect_ambiguities(SparseArrays; recursive=true)
@test_broken isempty(ambig)
ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in ambig))
expect = []
push!(expect, (Tuple{typeof(LinearAlgebra.generic_trimatmul!), AbstractVecOrMat, Any, Any, Function, AbstractMatrix, AbstractVecOrMat},
Tuple{typeof(LinearAlgebra.generic_trimatmul!), StridedVecOrMat, Any, Any, Any, Union{Adjoint{var"#s388", var"#s387"}, Transpose{var"#s388", var"#s387"}} where {var"#s388", var"#s387"<:(Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, I}} where I<:AbstractUnitRange} where {Tv, Ti})}, AbstractVecOrMat}))
push!(expect, (Tuple{typeof(LinearAlgebra.generic_trimatmul!), AbstractVecOrMat, Any, Any, Function, Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}, AbstractVecOrMat},
Tuple{typeof(LinearAlgebra.generic_trimatmul!), StridedVecOrMat, Any, Any, Any, Union{Adjoint{var"#s388", var"#s387"}, Transpose{var"#s388", var"#s387"}} where {var"#s388", var"#s387"<:(Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, I}} where I<:AbstractUnitRange} where {Tv, Ti})}, AbstractVecOrMat}))
good = true
while !isempty(ambig)
sigs = pop!(ambig)
i = findfirst(==(sigs), expect)
if i === nothing
println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))")
good = false
continue
end
deleteat!(expect, i)
end
@test isempty(expect)
@test good
end

## This was the older version that was disabled

# let ambig = detect_ambiguities(SparseArrays; recursive=true)
# @test isempty(ambig)
# ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in ambig))
# expect = []
# good = true
# while !isempty(ambig)
# sigs = pop!(ambig)
# i = findfirst(==(sigs), expect)
# if i === nothing
# println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))")
# good = false
# continue
# end
# deleteat!(expect, i)
# end
# @test isempty(expect)
# @test good
# end

###
# Now we restore the original env, as promised
empty!(Base.DEPOT_PATH)
Expand Down

0 comments on commit 18b7fce

Please sign in to comment.