Skip to content

Commit

Permalink
Enable tag checking
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2023
1 parent a7666fa commit d6fae4e
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/highlevel/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@ __getfield(c::ForwardDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype

struct SparseDiffToolsTag end

function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SparseDiffToolsTag, <:T}}, f::F,

Check warning on line 13 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L13

Added line #L13 was not covered by tests
x::AbstractArray{T}) where {T, F}
return true

Check warning on line 15 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L15

Added line #L15 was not covered by tests
end

__standard_tag(::Nothing, x) = ForwardDiff.Tag(SparseDiffToolsTag(), eltype(x))
__standard_tag(tag, _) = tag
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))

Check warning on line 19 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L18-L19

Added lines #L18 - L19 were not covered by tests

function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
tag = __standard_tag(ad.tag, x)

Check warning on line 25 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L25

Added line #L25 was not covered by tests
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x),
__standard_tag(ad.tag, x))
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x), tag)

Check warning on line 27 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L27

Added line #L27 was not covered by tests
jac_prototype = nothing
else
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
dx = fx, sparsity = coloring_result.jacobian_sparsity, ad.tag)
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
jac_prototype = coloring_result.jacobian_sparsity
end
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
Expand All @@ -32,13 +38,14 @@ end
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
tag = __standard_tag(ad.tag, x)

Check warning on line 41 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L41

Added line #L41 was not covered by tests
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
__standard_tag(ad.tag, x))
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x), tag)

Check warning on line 43 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L43

Added line #L43 was not covered by tests
jac_prototype = nothing
else
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
dx = fx, sparsity = coloring_result.jacobian_sparsity, ad.tag)
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
jac_prototype = coloring_result.jacobian_sparsity
end
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
Expand All @@ -49,8 +56,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
else
# Disable tag checking since we set the tag to our custom tag
ForwardDiff.jacobian!(J, f, x, cache.cache, Val(false)) # Don't try to exploit sparsity
ForwardDiff.jacobian!(J, f, x, cache.cache) # Don't try to exploit sparsity
end
return J
end
Expand All @@ -60,8 +66,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
else
# Disable tag checking since we set the tag to our custom tag
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache, Val(false)) # Don't try to exploit sparsity
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache) # Don't try to exploit sparsity
end
return J
end

0 comments on commit d6fae4e

Please sign in to comment.