Skip to content

Commit

Permalink
Dont break code for non-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2023
1 parent 62ff85d commit 27f6b79
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/highlevel/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
# Approximate Jacobian Sparsity Detection
## Right now we hardcode it to use `ForwardDiff`
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, x; fx = nothing,

Check warning on line 35 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L35

Added line #L35 was not covered by tests
kwargs...) where {F <: Function}
kwargs...) where {F}
@unpack ntrials, rng = alg
fx = fx === nothing ? f(x) : fx
J = fill!(similar(fx, length(fx), length(x)), 0)
Expand All @@ -48,7 +48,7 @@ function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f::F, x; f
end

function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f!::F, fx, x;

Check warning on line 50 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L50

Added line #L50 was not covered by tests
kwargs...) where {F <: Function}
kwargs...) where {F}
@unpack ntrials, rng = alg
cfg = ForwardDiff.JacobianConfig(f!, fx, x)
J = fill!(similar(fx, length(fx), length(x)), 0)
Expand Down
8 changes: 4 additions & 4 deletions src/highlevel/finite_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct FiniteDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobianC
end

function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(fd, f, x)
fx = fx === nothing ? similar(f(x)) : fx
if coloring_result isa NoMatrixColoring
Expand All @@ -22,7 +22,7 @@ function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
end

function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(fd, f!, fx, x)
if coloring_result isa NoMatrixColoring
cache = FiniteDiff.JacobianCache(x, fx)
Expand All @@ -36,13 +36,13 @@ function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
end

function sparse_jacobian!(J::AbstractMatrix, fd, cache::FiniteDiffJacobianCache, f::F,

Check warning on line 38 in src/highlevel/finite_diff.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/finite_diff.jl#L38

Added line #L38 was not covered by tests
x) where {F <: Function}
x) where {F}
f!(y, x) = (y .= f(x))
return sparse_jacobian!(J, fd, cache, f!, cache.fx, x)
end

function sparse_jacobian!(J::AbstractMatrix, _, cache::FiniteDiffJacobianCache, f!::F, _,

Check warning on line 44 in src/highlevel/finite_diff.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/finite_diff.jl#L44

Added line #L44 was not covered by tests
x) where {F <: Function}
x) where {F}
FiniteDiff.finite_difference_jacobian!(J, f!, x, cache.cache)
return J
end
8 changes: 4 additions & 4 deletions src/highlevel/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
struct SparseDiffToolsTag end

function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
if coloring_result isa NoMatrixColoring
Expand All @@ -25,7 +25,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
end

function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
Expand All @@ -40,7 +40,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
end

function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, f::F,

Check warning on line 42 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L42

Added line #L42 was not covered by tests
x) where {F <: Function}
x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
else
Expand All @@ -50,7 +50,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
end

function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache, f!::F, fx,

Check warning on line 52 in src/highlevel/forward_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/forward_mode.jl#L52

Added line #L52 was not covered by tests
x) where {F <: Function}
x) where {F}
if cache.cache isa ForwardColorJacCache
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
else
Expand Down
8 changes: 4 additions & 4 deletions src/highlevel/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob
end

function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
fx = fx === nothing ? similar(f(x)) : fx
coloring_result = sd(ad, f, x)
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
Expand All @@ -17,7 +17,7 @@ function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
end

function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F <: Function}
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
jac_prototype = __getfield(coloring_result, Val(:jacobian_sparsity))
return ReverseModeJacobianCache(coloring_result, nothing, jac_prototype, fx, x,
Expand All @@ -34,12 +34,12 @@ function sparse_jacobian!(J::AbstractMatrix, ad, cache::ReverseModeJacobianCache
end

function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
cache::MatrixColoringResult, f::F, x) where {F <: Function}
cache::MatrixColoringResult, f::F, x) where {F}
return __sparse_jacobian_reverse_impl!(J, ad, idx_vec, cache, f, nothing, x)
end

function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
cache::MatrixColoringResult, f::F, fx, x) where {F <: Function}
cache::MatrixColoringResult, f::F, fx, x) where {F}
# If `fx` is `nothing` then assume `f` is not in-place
x_ = __maybe_copy_x(ad, x)
fx_ = __maybe_copy_x(ad, fx)
Expand Down
6 changes: 3 additions & 3 deletions test/test_sparse_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in (AutoSparseZygote(),
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
AutoSparseForwardDiff(; chunksize = 0), AutoForwardDiff(; chunksize = 0),
AutoSparseForwardDiff(; chunksize = 8), AutoForwardDiff(; chunksize = 8),
AutoSparseForwardDiff(; chunksize = 4), AutoForwardDiff(; chunksize = 4),
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
@testset "Cache & Reuse" begin
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
Expand Down Expand Up @@ -95,8 +95,8 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa

@testset "sparse_jacobian $(nameof(typeof(difftype))): In place" for difftype in (AutoSparseForwardDiff(),
AutoForwardDiff(), AutoSparseForwardDiff(; chunksize = 0),
AutoForwardDiff(; chunksize = 0), AutoSparseForwardDiff(; chunksize = 8),
AutoForwardDiff(; chunksize = 8), AutoSparseFiniteDiff(), AutoFiniteDiff(),
AutoForwardDiff(; chunksize = 0), AutoSparseForwardDiff(; chunksize = 4),
AutoForwardDiff(; chunksize = 4), AutoSparseFiniteDiff(), AutoFiniteDiff(),
AutoEnzyme(), AutoSparseEnzyme())
y = similar(x)
cache = sparse_jacobian_cache(difftype, sd, fdiff, y, x)
Expand Down

0 comments on commit 27f6b79

Please sign in to comment.