Skip to content

Commit

Permalink
Make things kind-of type stable when chunksize is not specified
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent c8c61bf commit 63fc800
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ function sparse_jacobian(ad::AbstractADType, sd::AbstractMaybeSparsityDetection,
kwargs...)
cache = sparse_jacobian_cache(ad, sd, args...; kwargs...)
J = init_jacobian(cache)
return sparse_jacobian!(J, ad, cache, args...)
sparse_jacobian!(J, ad, cache, args...)
return J
end

"""
Expand All @@ -199,7 +200,8 @@ Jacobian at every function call
function sparse_jacobian(ad::AbstractADType, cache::AbstractMaybeSparseJacobianCache,
args...)
J = init_jacobian(cache)
return sparse_jacobian!(J, ad, cache, args...)
sparse_jacobian!(J, ad, cache, args...)
return J
end

"""
Expand All @@ -216,7 +218,8 @@ with the same cache to compute the jacobian.
function sparse_jacobian!(J::AbstractMatrix, ad::AbstractADType,
sd::AbstractMaybeSparsityDetection, args...; kwargs...)
cache = sparse_jacobian_cache(ad, sd, args...; kwargs...)
return sparse_jacobian!(J, ad, cache, args...)
sparse_jacobian!(J, ad, cache, args...)
return J
end

## Internal
Expand Down
2 changes: 2 additions & 0 deletions src/highlevel/finite_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct FiniteDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobianC
x::X
end

__getfield(c::FiniteDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(fd, f, x)
Expand Down
15 changes: 11 additions & 4 deletions src/highlevel/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ struct ForwardDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobian
x::X
end

__getfield(c::ForwardDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

struct SparseDiffToolsTag end

__standard_tag(::Nothing, x) = ForwardDiff.Tag(SparseDiffToolsTag(), eltype(x))
__standard_tag(tag, _) = tag

Check warning on line 14 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L14

Added line #L14 was not covered by tests

function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x),
ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag))
__standard_tag(ad.tag, x))
jac_prototype = nothing
else
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
Expand All @@ -29,7 +34,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
coloring_result = sd(ad, f!, fx, x)
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag))
__standard_tag(ad.tag, x))
jac_prototype = nothing
else
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
Expand All @@ -44,7 +49,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
else
ForwardDiff.jacobian!(J, f, x, cache.cache) # Don't try to exploit sparsity
# Disable tag checking since we set the tag to our custom tag
ForwardDiff.jacobian!(J, f, x, cache.cache, Val(false)) # Don't try to exploit sparsity
end
return J
end
Expand All @@ -54,7 +60,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
else
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache) # Don't try to exploit sparsity
# Disable tag checking since we set the tag to our custom tag
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache, Val(false)) # Don't try to exploit sparsity
end
return J
end
2 changes: 2 additions & 0 deletions src/highlevel/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob
idx_vec::I
end

__getfield(c::ReverseModeJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
fx = fx === nothing ? similar(f(x)) : fx
Expand Down

0 comments on commit 63fc800

Please sign in to comment.