diff --git a/src/highlevel/common.jl b/src/highlevel/common.jl index 133eed58..8c8e3e51 100644 --- a/src/highlevel/common.jl +++ b/src/highlevel/common.jl @@ -186,7 +186,8 @@ function sparse_jacobian(ad::AbstractADType, sd::AbstractMaybeSparsityDetection, kwargs...) cache = sparse_jacobian_cache(ad, sd, args...; kwargs...) J = init_jacobian(cache) - return sparse_jacobian!(J, ad, cache, args...) + sparse_jacobian!(J, ad, cache, args...) + return J end """ @@ -199,7 +200,8 @@ Jacobian at every function call function sparse_jacobian(ad::AbstractADType, cache::AbstractMaybeSparseJacobianCache, args...) J = init_jacobian(cache) - return sparse_jacobian!(J, ad, cache, args...) + sparse_jacobian!(J, ad, cache, args...) + return J end """ @@ -216,7 +218,8 @@ with the same cache to compute the jacobian. function sparse_jacobian!(J::AbstractMatrix, ad::AbstractADType, sd::AbstractMaybeSparsityDetection, args...; kwargs...) cache = sparse_jacobian_cache(ad, sd, args...; kwargs...) - return sparse_jacobian!(J, ad, cache, args...) + sparse_jacobian!(J, ad, cache, args...) + return J end ## Internal diff --git a/src/highlevel/finite_diff.jl b/src/highlevel/finite_diff.jl index b9c955d7..31114ae0 100644 --- a/src/highlevel/finite_diff.jl +++ b/src/highlevel/finite_diff.jl @@ -6,6 +6,8 @@ struct FiniteDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobianC x::X end +__getfield(c::FiniteDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype + function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff}, sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F} coloring_result = sd(fd, f, x) diff --git a/src/highlevel/forward_mode.jl b/src/highlevel/forward_mode.jl index f1dc6e2b..47bafd77 100644 --- a/src/highlevel/forward_mode.jl +++ b/src/highlevel/forward_mode.jl @@ -6,15 +6,20 @@ struct ForwardDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobian x::X end +__getfield(c::ForwardDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype + struct SparseDiffToolsTag end +__standard_tag(::Nothing, x) = ForwardDiff.Tag(SparseDiffToolsTag(), eltype(x)) +__standard_tag(tag, _) = tag + function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F} coloring_result = sd(ad, f, x) fx = fx === nothing ? similar(f(x)) : fx if coloring_result isa NoMatrixColoring cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x), - ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag)) + __standard_tag(ad.tag, x)) jac_prototype = nothing else cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec, @@ -29,7 +34,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff} coloring_result = sd(ad, f!, fx, x) if coloring_result isa NoMatrixColoring cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x), - ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag)) + __standard_tag(ad.tag, x)) jac_prototype = nothing else cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec, @@ -44,7 +49,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, if cache.cache isa ForwardColorJacCache forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff else - ForwardDiff.jacobian!(J, f, x, cache.cache) # Don't try to exploit sparsity + # Disable tag checking since we set the tag to our custom tag + ForwardDiff.jacobian!(J, f, x, cache.cache, Val(false)) # Don't try to exploit sparsity end return J end @@ -54,7 +60,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, if cache.cache isa ForwardColorJacCache forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff else - ForwardDiff.jacobian!(J, f!, fx, x, cache.cache) # Don't try to exploit sparsity + # Disable tag checking since we set the tag to our custom tag + ForwardDiff.jacobian!(J, f!, fx, x, cache.cache, Val(false)) # Don't try to exploit sparsity end return J end diff --git a/src/highlevel/reverse_mode.jl b/src/highlevel/reverse_mode.jl index 815a879c..b24c42dc 100644 --- a/src/highlevel/reverse_mode.jl +++ b/src/highlevel/reverse_mode.jl @@ -7,6 +7,8 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob idx_vec::I end +__getfield(c::ReverseModeJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype + function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode}, sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F} fx = fx === nothing ? similar(f(x)) : fx