From 25600cae0c119ab55880682cf6f346aed0eccb1e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 Aug 2024 07:47:17 +0200 Subject: [PATCH] Add `function_annotation` to `AutoEnzyme` (#77) * Add function_annotation to AutoEnzyme * Fix printing * Test warning --- Project.toml | 2 +- src/dense.jl | 24 +++++++++++++++++------- test/dense.jl | 19 ++++++++++--------- test/misc.jl | 7 +++++++ 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index adcecd9..69c8e9f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.6.2" +version = "1.7.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 8e6d960..dfcf6a1 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -39,7 +39,7 @@ struct AutoDiffractor <: AbstractADType end mode(::AutoDiffractor) = ForwardOrReverseMode() """ - AutoEnzyme{M} + AutoEnzyme{M,A} Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation. @@ -47,28 +47,38 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoEnzyme(; mode=nothing) + AutoEnzyme(; mode::M=nothing, function_annotation::Type{A}=Nothing) + +# Type parameters + + - `A` determines how the function `f` to differentiate is passed to Enzyme. It can be: + + + a subtype of `EnzymeCore.Annotation` (like `EnzymeCore.Const` or `EnzymeCore.Duplicated`) to enforce a given annotation + + `Nothing` to simply pass `f` and let Enzyme choose the most appropriate annotation # Fields - - `mode::M`: can be either + - `mode::M` determines the autodiff mode (forward or reverse). It can be: + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + `nothing` to choose the best mode automatically """ -struct AutoEnzyme{M} <: AbstractADType +struct AutoEnzyme{M, A} <: AbstractADType mode::M end -function AutoEnzyme(; mode::M = nothing) where {M} - return AutoEnzyme{M}(mode) +function AutoEnzyme(; + mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A} + return AutoEnzyme{M, A}(mode) end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension -function Base.show(io::IO, backend::AutoEnzyme) +function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A} print(io, AutoEnzyme, "(") !isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io)) + !isnothing(backend.mode) && !(A <: Nothing) && print(io, ", ") + !(A <: Nothing) && print(io, "function_annotation=", repr(A; context = io)) print(io, ")") end diff --git a/test/dense.jl b/test/dense.jl index 7554565..0b1d944 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -28,25 +28,26 @@ end @testset "AutoEnzyme" begin ad = AutoEnzyme() @test ad isa AbstractADType - @test ad isa AutoEnzyme{Nothing} + @test ad isa AutoEnzyme{Nothing, Nothing} @test mode(ad) isa ForwardOrReverseMode @test ad.mode === nothing - ad = AutoEnzyme(EnzymeCore.Forward) + ad = AutoEnzyme(; mode = EnzymeCore.Forward) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), Nothing} @test mode(ad) isa ForwardMode @test ad.mode == EnzymeCore.Forward - ad = AutoEnzyme(; mode = EnzymeCore.Forward) + ad = AutoEnzyme(; function_annotation = EnzymeCore.Const) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} - @test mode(ad) isa ForwardMode - @test ad.mode == EnzymeCore.Forward + @test ad isa AutoEnzyme{Nothing, EnzymeCore.Const} + @test mode(ad) isa ForwardOrReverseMode + @test ad.mode === nothing - ad = AutoEnzyme(; mode = EnzymeCore.Reverse) + ad = AutoEnzyme(; + mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated} @test mode(ad) isa ReverseMode @test ad.mode == EnzymeCore.Reverse end diff --git a/test/misc.jl b/test/misc.jl index b3e501f..e27b63c 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -21,12 +21,19 @@ end @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end +#= +The following tests are only for visual assessment of the printing behavior. +They do not correspond to proper use of ADTypes constructors. +Please refer to the docstrings for that. +=# for backend in [ # dense ADTypes.AutoChainRules(; ruleconfig = :rc), ADTypes.AutoDiffractor(), ADTypes.AutoEnzyme(), ADTypes.AutoEnzyme(mode = :forward), + ADTypes.AutoEnzyme(function_annotation = Val{:forward}), + ADTypes.AutoEnzyme(mode = :reverse, function_annotation = Val{:duplicated}), ADTypes.AutoFastDifferentiation(), ADTypes.AutoFiniteDiff(), ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),