diff --git a/docs/src/index.md b/docs/src/index.md index eacd11c..5a0cf82 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -30,6 +30,12 @@ AutoFiniteDiff AutoFiniteDifferences ``` +Taylor mode: + +```@docs +AutoGTPSA +``` + ### Reverse mode ```@docs diff --git a/src/ADTypes.jl b/src/ADTypes.jl index d289b09..336a026 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -40,6 +40,7 @@ export AutoChainRules, AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, + AutoGTPSA, AutoModelingToolkit, AutoPolyesterForwardDiff, AutoReverseDiff, diff --git a/src/dense.jl b/src/dense.jl index dfcf6a1..cbea28d 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -195,6 +195,38 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize print(io, ")") end +""" + AutoGTPSA{D} + +Struct used to select the [GTPSA.jl](https://github.com/bmad-sim/GTPSA.jl) backend for automatic differentiation. + +Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + +# Constructors + + AutoGTPSA(; descriptor=nothing) + +# Fields + + - `descriptor::D`: can be either + + + a GTPSA `Descriptor` specifying the number of variables/parameters, parameter + order, individual variable/parameter truncation orders, and maximum order. See + the [GTPSA.jl documentation](https://bmad-sim.github.io/GTPSA.jl/stable/man/c_descriptor/) for more details. + + `nothing` to automatically use a `Descriptor` given the context. +""" +Base.@kwdef struct AutoGTPSA{D} <: AbstractADType + descriptor::D = nothing +end + +mode(::AutoGTPSA) = ForwardMode() + +function Base.show(io::IO, backend::AutoGTPSA{D}) where {D} + print(io, AutoGTPSA, "(") + D != Nothing && print(io, "descriptor=", repr(backend.descriptor; context = io)) + print(io, ")") +end + """ AutoPolyesterForwardDiff{chunksize,T} diff --git a/test/dense.jl b/test/dense.jl index 0b1d944..7c8b0ca 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -105,6 +105,20 @@ end @test ad.tag == CustomTag() end +@testset "AutoGTPSA" begin + ad = AutoGTPSA(; descriptor = nothing) + @test ad isa AbstractADType + @test ad isa AutoGTPSA{Nothing} + @test mode(ad) isa ForwardMode + @test ad.descriptor === nothing + + ad = AutoGTPSA(; descriptor = Val(:descriptor)) + @test ad isa AbstractADType + @test ad isa AutoGTPSA{Val{:descriptor}} + @test mode(ad) isa ForwardMode + @test ad.descriptor == Val(:descriptor) +end + @testset "AutoPolyesterForwardDiff" begin ad = AutoPolyesterForwardDiff() @test ad isa AbstractADType diff --git a/test/misc.jl b/test/misc.jl index e27b63c..bb11fb1 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -40,6 +40,8 @@ for backend in [ ADTypes.AutoFiniteDifferences(; fdm = :fdm), ADTypes.AutoForwardDiff(), ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag), + ADTypes.AutoGTPSA(), + ADTypes.AutoGTPSA(; descriptor = Val(:descriptor)), ADTypes.AutoPolyesterForwardDiff(), ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), ADTypes.AutoReverseDiff(), diff --git a/test/runtests.jl b/test/runtests.jl index 2584a67..2e3f3e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,7 @@ function every_ad() AutoFiniteDiff(), AutoFiniteDifferences(; fdm = :fdm), AutoForwardDiff(), + AutoGTPSA(), AutoPolyesterForwardDiff(), AutoReverseDiff(), AutoSymbolics(), @@ -66,6 +67,8 @@ function every_ad_with_options() AutoFiniteDifferences(; fdm = :fdm), AutoForwardDiff(), AutoForwardDiff(chunksize = 3, tag = :tag), + AutoGTPSA(), + AutoGTPSA(descriptor = Val(:descriptor)), AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), AutoReverseDiff(),