Automatic differentiation backends for LogDensityProblems.jl.
The only exposed function is `ADgradient`. Example:
```julia
using LogDensityProblemsAD, ForwardDiff
∇ℓ = ADgradient(:ForwardDiff, ℓ) # assumes ℓ implements the LogDensityProblems interface
```
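For a fully self-contained version of the example above, here is a sketch with a hypothetical toy problem `StdNormal` (an isotropic standard normal log density up to a constant; the name is illustrative and not part of any package). It implements the minimal LogDensityProblems interface (`logdensity`, `dimension`, `capabilities`), wraps it with `ADgradient`, and evaluates the value and gradient with `LogDensityProblems.logdensity_and_gradient`.

```julia
using LogDensityProblems, LogDensityProblemsAD, ForwardDiff

# Hypothetical toy problem: log density of a standard normal in `n` dimensions,
# up to an additive constant.
struct StdNormal
    n::Int
end

LogDensityProblems.logdensity(::StdNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(ℓ::StdNormal) = ℓ.n
# Order 0: the problem itself only provides the log density, no derivatives.
LogDensityProblems.capabilities(::Type{StdNormal}) = LogDensityProblems.LogDensityOrder{0}()

ℓ = StdNormal(3)
∇ℓ = ADgradient(:ForwardDiff, ℓ)    # ForwardDiff supplies the gradient

x = [1.0, -2.0, 0.5]
v, g = LogDensityProblems.logdensity_and_gradient(∇ℓ, x)
# The analytical gradient is -x, which makes this a convenient sanity check.
```

Because the analytical gradient of this toy density is simply `-x`, it is easy to verify that the AD result is correct before moving on to a real model.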
Below is the list of supported backends, roughly in the order in which they are recommended for ℝⁿ→ℝ functions. That said, for nontrivial problems you should do your own benchmarking, and compare results from several backends if you suspect an incorrect calculation (e.g. because MCMC does not converge and you have ruled out everything else).
Before using AD, make sure your code is type stable and correctly inferred, and that it minimizes allocations. For example:
```julia
using LogDensityProblems, BenchmarkTools, Test
x = zeros(LogDensityProblems.dimension(ℓ)) # ℓ is your log density
@inferred LogDensityProblems.logdensity(ℓ, x) # check inference, also see @code_warntype
@benchmark LogDensityProblems.logdensity($ℓ, $x) # check performance and allocations
```
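A self-contained variant of these checks might look like the sketch below. It assumes a hypothetical toy problem `StdNormal` (illustrative only, not part of any package) and uses `@allocated` in place of `@benchmark` so the allocation count can be inspected programmatically. The checks are wrapped in a function because timing or allocation measurements on non-constant globals at the top level can report spurious allocations.

```julia
using LogDensityProblems, Test

# Hypothetical toy problem: standard normal log density up to a constant.
struct StdNormal
    n::Int
end

LogDensityProblems.logdensity(::StdNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(ℓ::StdNormal) = ℓ.n
LogDensityProblems.capabilities(::Type{StdNormal}) = LogDensityProblems.LogDensityOrder{0}()

# Runs the inference check and measures allocations of a single call.
function check_logdensity(ℓ)
    x = zeros(LogDensityProblems.dimension(ℓ))
    v = @inferred LogDensityProblems.logdensity(ℓ, x)      # errors if the return type is not inferred
    bytes = @allocated LogDensityProblems.logdensity(ℓ, x) # already compiled by the call above
    return v, bytes
end

v, bytes = check_logdensity(StdNormal(3))
```

For a well-written log density like this one, `bytes` should be zero; a nonzero value is a hint to look for temporary arrays or type instabilities before reaching for AD.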
- ForwardDiff.jl: Robust and mature implementation, but not necessarily the fastest. Its cost scales roughly linearly with the input dimension, so use it with caution for large problems. Ideal for checking correctness.
- Enzyme.jl: The fastest option if it works for your problem, and ideal if your code does not allocate. Try it first, with reverse mode (the default). Since Enzyme is still experimental, check the gradient against another backend.
- Zygote.jl and Tracker.jl: May be a viable choice if Enzyme does not work for your problem and your calculations are non-mutating and operate on matrices and vectors rather than scalars. Benchmark against the alternatives above. Of the two, Zygote is more actively maintained.
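- ReverseDiff.jl: Can be very performant with tape compilation. However, a compiled tape records the branches taken for the input it was built with, so if your code branches in a way that changes the result, the derivatives can be wrong. If you use tape compilation, check your derivatives.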
- FiniteDifferences.jl: Finite differences are very robust, with a small numerical error, but usually not fast enough to practically replace AD on nontrivial problems. The backend in this package is mainly intended for checking and debugging results from other backends, although in most cases ForwardDiff is faster and more accurate.
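Comparing backends, as recommended above, can be sketched as follows. This assumes the finite-difference backend is selected with the symbol `:FiniteDifferences`, by analogy with `:ForwardDiff`, and uses its default settings. The toy problem `Rosenbrock` (the negated 2D Rosenbrock function used as a log density) is hypothetical and chosen only because its gradient is easy to compute by hand.

```julia
using LogDensityProblems, LogDensityProblemsAD, ForwardDiff, FiniteDifferences

# Hypothetical toy problem: negated 2D Rosenbrock function as a log density.
struct Rosenbrock end

LogDensityProblems.logdensity(::Rosenbrock, x) = -(100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2)
LogDensityProblems.dimension(::Rosenbrock) = 2
LogDensityProblems.capabilities(::Type{Rosenbrock}) = LogDensityProblems.LogDensityOrder{0}()

ℓ = Rosenbrock()
x = [0.3, -0.7]

# Same point, two backends: AD (ForwardDiff) versus numerical differentiation.
_, g_ad = LogDensityProblems.logdensity_and_gradient(ADgradient(:ForwardDiff, ℓ), x)
_, g_fd = LogDensityProblems.logdensity_and_gradient(ADgradient(:FiniteDifferences, ℓ), x)

# Finite differences carry a small numerical error, so compare with a tolerance.
agree = isapprox(g_ad, g_fd; rtol = 1e-6)
```

If the two gradients disagree by much more than the expected numerical error, the AD result (or the code it is differentiating) deserves a closer look.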
Other AD frameworks are supported thanks to ADTypes.jl and DifferentiationInterface.jl.
PRs for remaining AD frameworks are welcome, even if they are WIP.
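As a sketch of the ADTypes.jl route, assuming `ADgradient` accepts an ADTypes backend object such as `AutoForwardDiff()` in place of a symbol, the usage could look like this. `StdNormal` is again a hypothetical toy problem, and ForwardDiff is used only because it is a backend whose behavior is easy to check; other frameworks would be selected with their corresponding ADTypes objects.

```julia
using LogDensityProblems, LogDensityProblemsAD, ADTypes, ForwardDiff

# Hypothetical toy problem: standard normal log density up to a constant.
struct StdNormal
    n::Int
end

LogDensityProblems.logdensity(::StdNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(ℓ::StdNormal) = ℓ.n
LogDensityProblems.capabilities(::Type{StdNormal}) = LogDensityProblems.LogDensityOrder{0}()

# Backend given as an ADTypes object rather than a Symbol (assumed to be supported).
∇ℓ = ADgradient(AutoForwardDiff(), StdNormal(2))
v, g = LogDensityProblems.logdensity_and_gradient(∇ℓ, [3.0, 4.0])
```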