Skip to content

Commit

Permalink
Merge pull request #271 from avik-pal/ap/tagging
Browse files Browse the repository at this point in the history
(Kind-of) Type Stability Fixes for  No Chunksize Specified
  • Loading branch information
ChrisRackauckas authored Nov 1, 2023
2 parents 88dbb15 + 41033f4 commit 7d23bec
Show file tree
Hide file tree
Showing 16 changed files with 137 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.9.1"
version = "2.9.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/SparseDiffToolsSymbolicsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseDiffTools, Symbolics
import SparseDiffTools: AbstractSparseADType

function (alg::SymbolicsSparsityDetection)(ad::AbstractSparseADType, f, x; fx = nothing,
kwargs...)
kwargs...)
fx = fx === nothing ? similar(f(x)) : dx
f!(y, x) = (y .= f(x))
J = Symbolics.jacobian_sparsity(f!, fx, x)
Expand Down
4 changes: 2 additions & 2 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
end

function autoback_hesvec!(dy, f, x, v, cache1 = _default_autoback_hesvec_cache(x, v),
cache2 = _default_autoback_hesvec_cache(x, v))
cache2 = _default_autoback_hesvec_cache(x, v))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand Down Expand Up @@ -140,7 +140,7 @@ end

# prefer non in-place method
function (L::AutoDiffVJP{<:AutoZygote, IIP, true})(dv, v, p, t;
VJP_input = nothing) where {IIP}
VJP_input = nothing) where {IIP}
# ignore VJP_input as pullback was computed in update_coefficients!(...)

_dv = L(v, p, t; VJP_input = VJP_input)
Expand Down
16 changes: 8 additions & 8 deletions src/coloring/acyclic_coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ the induced 2-colored subgraphs/trees where the id of set is an integer
representing an edge of graph 'g'
"""
function prevent_cycle!(first_visit_to_tree::AbstractVector{<:Tuple{Integer, Integer}},
forbidden_colors::AbstractVector{<:Integer}, v::Integer, w::Integer, x::Integer,
g::Graphs.AbstractGraph, two_colored_forest::DisjointSets{<:Integer},
color::AbstractVector{<:Integer})
forbidden_colors::AbstractVector{<:Integer}, v::Integer, w::Integer, x::Integer,
g::Graphs.AbstractGraph, two_colored_forest::DisjointSets{<:Integer},
color::AbstractVector{<:Integer})
e = find(w, x, g, two_colored_forest)
p, q = first_visit_to_tree[e]

Expand All @@ -97,8 +97,8 @@ Disjoint set is used to store stars in sets, which are identified through key
edges present in g.
"""
function grow_star!(two_colored_forest::DisjointSets{<:Integer},
first_neighbor::AbstractVector{<:Tuple{Integer, Integer}}, v::Integer, w::Integer,
g::Graphs.AbstractGraph, color::AbstractVector{<:Integer})
first_neighbor::AbstractVector{<:Tuple{Integer, Integer}}, v::Integer, w::Integer,
g::Graphs.AbstractGraph, color::AbstractVector{<:Integer})
insert_new_tree!(two_colored_forest, v, w, g)
p, q = first_neighbor[color[w]]

Expand All @@ -119,7 +119,7 @@ Subroutine to merge trees present in the disjoint set which have a
common edge.
"""
function merge_trees!(two_colored_forest::DisjointSets{<:Integer}, v::Integer, w::Integer,
x::Integer, g::Graphs.AbstractGraph)
x::Integer, g::Graphs.AbstractGraph)
e1 = find(v, w, g, two_colored_forest)
e2 = find(w, x, g, two_colored_forest)
if e1 != e2
Expand All @@ -135,7 +135,7 @@ creates a new singleton set in the disjoint set 'two_colored_forest' consisting
of the edge connecting v and w in the graph g
"""
function insert_new_tree!(two_colored_forest::DisjointSets{<:Integer}, v::Integer,
w::Integer, g::Graphs.AbstractGraph)
w::Integer, g::Graphs.AbstractGraph)
edge_index = find_edge_index(v, w, g)
push!(two_colored_forest, edge_index)
end
Expand All @@ -157,7 +157,7 @@ Returns the root of the disjoint set to which the edge connecting vertices w and
in the graph g belongs to
"""
function find(w::Integer, x::Integer, g::Graphs.AbstractGraph,
two_colored_forest::DisjointSets{<:Integer})
two_colored_forest::DisjointSets{<:Integer})
edge_index = find_edge_index(w, x, g)
return find_root!(two_colored_forest, edge_index)
end
Expand Down
10 changes: 5 additions & 5 deletions src/coloring/backtracking_coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ Returns an uncolored vertex from the partially
colored graph which has the highest degree
"""
function uncolored_vertex_of_maximal_degree(A::AbstractVector{<:Integer},
F::AbstractVector{<:Integer})
F::AbstractVector{<:Integer})
for v in A
if F[v] == 0
return v
Expand Down Expand Up @@ -145,8 +145,8 @@ g: Graph to be colored
opt: Current optimal number of colors to be used in the coloring of graph g
"""
function free_colors(x::Integer, A::AbstractVector{<:Integer},
colors::AbstractVector{<:Integer}, F::Vector{Integer}, g::Graphs.AbstractGraph,
opt::Integer)
colors::AbstractVector{<:Integer}, F::Vector{Integer}, g::Graphs.AbstractGraph,
opt::Integer)
index = -1

freecolors = zeros(Int, 0)
Expand Down Expand Up @@ -187,7 +187,7 @@ Returns least index i such that color of vertex
A[i] is equal to `opt` (optimal chromatic number)
"""
function least_index(F::AbstractVector{<:Integer}, A::AbstractVector{<:Integer},
opt::Integer)
opt::Integer)
for i in eachindex(A)
if F[A[i]] == opt
return i
Expand All @@ -202,7 +202,7 @@ Uncolors all vertices A[i] where i is
greater than or equal to start
"""
function uncolor_all!(F::AbstractVector{<:Integer}, A::AbstractVector{<:Integer},
start::Integer)
start::Integer)
for i in start:length(A)
F[A[i]] = 0
end
Expand Down
4 changes: 2 additions & 2 deletions src/coloring/high_level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ If `ArrayInterface.fast_matrix_colors(A)` is true, then uses
`ArrayInterface.matrix_colors(A)` to compute the matrix colors.
"""
function ArrayInterface.matrix_colors(A::AbstractMatrix,
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
partition_by_rows::Bool = false)
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
partition_by_rows::Bool = false)

# If fast algorithm for matrix coloring exists use that
if !partition_by_rows
Expand Down
2 changes: 1 addition & 1 deletion src/coloring/matrix2graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Note that the sparsity pattern is defined by structural nonzeroes, ie includes e
stored zeros.
"""
function matrix2graph(sparse_matrix::AbstractSparseMatrix{<:Number},
partition_by_rows::Bool = true)
partition_by_rows::Bool = true)
(rows_index, cols_index, _) = findnz(sparse_matrix)

ncols = size(sparse_matrix, 2)
Expand Down
36 changes: 18 additions & 18 deletions src/differentiation/compute_hessian_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ function make_hessian_buffers(colorvec, x)
end

function ForwardColorHesCache(f, x::AbstractVector{<:Number},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing,
g! = (G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing,
g! = (G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
ncolors, D, buffer, G, G2 = make_hessian_buffers(colorvec, x)
grad_config = ForwardDiff.GradientConfig(f, x)

Expand All @@ -46,7 +46,7 @@ function ForwardColorHesCache(f, x::AbstractVector{<:Number},
end

function numauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
hes_cache::ForwardColorHesCache; safe = true)
hes_cache::ForwardColorHesCache; safe = true)
ϵ = cbrt(eps(eltype(x)))
for j in 1:(hes_cache.ncolors)
x .+= ϵ .* @view hes_cache.D[:, j]
Expand All @@ -67,23 +67,23 @@ function numauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray
end

function numauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
numauto_color_hessian!(H, f, x, hes_cache)
return H
end

function numauto_color_hessian(f, x::AbstractArray{<:Number},
hes_cache::ForwardColorHesCache)
hes_cache::ForwardColorHesCache)
H = convert.(eltype(x), hes_cache.sparsity)
numauto_color_hessian!(H, f, x, hes_cache)
return H
end

function numauto_color_hessian(f, x::AbstractArray{<:Number},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
H = convert.(eltype(x), hes_cache.sparsity)
numauto_color_hessian!(H, f, x, hes_cache)
Expand All @@ -102,9 +102,9 @@ end
struct AutoAutoTag end

function ForwardAutoColorHesCache(f, x::AbstractVector{V},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing,
tag::ForwardDiff.Tag = ForwardDiff.Tag(AutoAutoTag(), V)) where {V}
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing,
tag::ForwardDiff.Tag = ForwardDiff.Tag(AutoAutoTag(), V)) where {V}
if sparsity === nothing
sparsity = sparse(ones(length(x), length(x)))
end
Expand All @@ -124,28 +124,28 @@ function ForwardAutoColorHesCache(f, x::AbstractVector{V},
end

function autoauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
hes_cache::ForwardAutoColorHesCache)
hes_cache::ForwardAutoColorHesCache)
forwarddiff_color_jacobian!(H, hes_cache.grad!, x, hes_cache.jac_cache)
end

function autoauto_color_hessian!(H::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
autoauto_color_hessian!(H, f, x, hes_cache)
return H
end

function autoauto_color_hessian(f, x::AbstractArray{<:Number},
hes_cache::ForwardAutoColorHesCache)
hes_cache::ForwardAutoColorHesCache)
H = convert.(eltype(x), hes_cache.sparsity)
autoauto_color_hessian!(H, f, x, hes_cache)
return H
end

function autoauto_color_hessian(f, x::AbstractArray{<:Number},
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
colorvec::AbstractVector{<:Integer} = eachindex(x),
sparsity::Union{AbstractMatrix, Nothing} = nothing)
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
H = convert.(eltype(x), hes_cache.sparsity)
autoauto_color_hessian!(H, f, x, hes_cache)
Expand Down
50 changes: 25 additions & 25 deletions src/differentiation/compute_jacobian_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const default_chunk_size = ForwardDiff.pickchunksize
const SMALLTAG = typeof(ForwardDiff.Tag(missing, Float64))

function ForwardColorJacCache(f::F, x, _chunksize = nothing; dx = nothing, tag = nothing,
colorvec = 1:length(x), sparsity::Union{AbstractArray, Nothing} = nothing) where {F}
colorvec = 1:length(x), sparsity::Union{AbstractArray, Nothing} = nothing) where {F}
if _chunksize isa Nothing
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
else
Expand Down Expand Up @@ -105,13 +105,13 @@ end
end

function forwarddiff_color_jacobian(f::F,
x::AbstractArray{<:Number};
colorvec = 1:length(x),
sparsity = nothing,
jac_prototype = nothing,
chunksize = nothing,
dx = sparsity === nothing && jac_prototype === nothing ?
nothing : copy(x)) where {F} #if dx is nothing, we will estimate dx at the cost of a function call
x::AbstractArray{<:Number};
colorvec = 1:length(x),
sparsity = nothing,
jac_prototype = nothing,
chunksize = nothing,
dx = sparsity === nothing && jac_prototype === nothing ?
nothing : copy(x)) where {F} #if dx is nothing, we will estimate dx at the cost of a function call
if sparsity === nothing && jac_prototype === nothing
cfg = if chunksize === nothing
if typeof(x) <: StaticArrays.StaticArray
Expand All @@ -136,12 +136,12 @@ function forwarddiff_color_jacobian(f::F,
end

function forwarddiff_color_jacobian(J::AbstractArray{<:Number}, f::F,
x::AbstractArray{<:Number};
colorvec = 1:length(x),
sparsity = nothing,
jac_prototype = nothing,
chunksize = nothing,
dx = similar(x, size(J, 1))) where {F} #dx kwarg can be used to avoid re-allocating dx every time
x::AbstractArray{<:Number};
colorvec = 1:length(x),
sparsity = nothing,
jac_prototype = nothing,
chunksize = nothing,
dx = similar(x, size(J, 1))) where {F} #dx kwarg can be used to avoid re-allocating dx every time
if sparsity === nothing && jac_prototype === nothing
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) :
ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
Expand All @@ -154,8 +154,8 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number}, f::F,
end

function forwarddiff_color_jacobian(f::F, x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache,
jac_prototype = nothing) where {F}
jac_cache::ForwardColorJacCache,
jac_prototype = nothing) where {F}
if jac_prototype isa Nothing ? ArrayInterface.ismutable(x) :
ArrayInterface.ismutable(jac_prototype)
# Whenever J is mutable, we mutate it to avoid allocations
Expand All @@ -174,8 +174,8 @@ end

# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache) where {F}
x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache) where {F}
t = jac_cache.t
dx = jac_cache.dx
p = jac_cache.p
Expand Down Expand Up @@ -248,8 +248,8 @@ end

# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache,
jac_prototype = nothing)
jac_cache::ForwardColorJacCache,
jac_prototype = nothing)
t = jac_cache.t
dx = jac_cache.dx
p = jac_cache.p
Expand Down Expand Up @@ -313,15 +313,15 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
end

function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f,
x::AbstractArray{<:Number}; dx = similar(x, size(J, 1)), colorvec = 1:length(x),
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
x::AbstractArray{<:Number}; dx = similar(x, size(J, 1)), colorvec = 1:length(x),
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
forwarddiff_color_jacobian!(J, f, x, ForwardColorJacCache(f, x; dx, colorvec, sparsity))
end

function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
f,
x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache)
f,
x::AbstractArray{<:Number},
jac_cache::ForwardColorJacCache)
t = jac_cache.t
fx = jac_cache.fx
dx = jac_cache.dx
Expand Down
Loading

0 comments on commit 7d23bec

Please sign in to comment.