Skip to content

Commit

Permalink
Use AbstractDifferentiation package
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Feb 5, 2024
1 parent 9a9bbdc commit 392d21c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Gabriel S. Gerlero <[email protected]>"]
version = "2.5.2"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +21,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"

[compat]
AbstractDifferentiation = "0.6.2"
ArgCheck = "2"
ForwardDiff = "0.10"
LinearAlgebra = "1"
Expand Down
24 changes: 16 additions & 8 deletions src/_Diff.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
module _Diff

using ForwardDiff: derivative
using ForwardDiff: Dual, Tag, value, extract_derivative
import AbstractDifferentiation
import ForwardDiff

@inline function derivative(f, x::Real)
return only(AbstractDifferentiation.derivative(AbstractDifferentiation.ForwardDiffBackend(),
f,
x))
end

@inline function value_and_derivative(f, x::Real)
T = typeof(Tag(f, typeof(x)))
ydual = f(Dual{T}(x, oneunit(x)))
return value(T, ydual), extract_derivative(T, ydual)
a, b = AbstractDifferentiation.value_and_derivative(AbstractDifferentiation.ForwardDiffBackend(),
f,
x)
return a, only(b)
end

@inline function value_and_derivatives(f, x::Real)
T = typeof(Tag(f, typeof(x)))
ydual, ddual = value_and_derivative(f, Dual{T}(x, oneunit(x)))
return value(T, ydual), value(T, ddual), extract_derivative(T, ddual)
a, b, c = AbstractDifferentiation.value_derivative_and_second_derivative(AbstractDifferentiation.ForwardDiffBackend(),
f,
x)
return a, only(b), only(c)
end

export derivative, value_and_derivative, value_and_derivatives
Expand Down

0 comments on commit 392d21c

Please sign in to comment.