Skip to content

Commit

Permalink
Merge pull request #72 from SciML/gd/constant_f_enzyme
Browse files Browse the repository at this point in the history
Add constant_function kwarg to AutoEnzyme
  • Loading branch information
ChrisRackauckas authored Jul 17, 2024
2 parents 97d5146 + 091d3b6 commit 39da305
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = [
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
]
version = "1.5.4"
version = "1.6.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
46 changes: 42 additions & 4 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,63 @@ struct AutoDiffractor <: AbstractADType end
mode(::AutoDiffractor) = ForwardOrReverseMode()

"""
AutoEnzyme{M}
AutoEnzyme{M,constant_function}
Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
# Constructors
AutoEnzyme(; mode=nothing)
AutoEnzyme(; mode=nothing, constant_function::Bool=false)
The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below).
# Fields
- `mode::M`: can be either
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
+ `nothing` to choose the best mode automatically
# Notes
If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values.
An example of such a function is:
```julia
cache = [0.0]
function f(x)
cache[1] = x[1]^2
cache[1] + x[1]
end
```
In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input.
Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant.
Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`:
```julia
parameter = [0.0]
function f(x)
parameter[1] + x[1]
end
```
In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant.
"""
Base.@kwdef struct AutoEnzyme{M} <: AbstractADType
mode::M = nothing
struct AutoEnzyme{M, constant_function} <: AbstractADType
mode::M
end

function AutoEnzyme(mode::M; constant_function::Bool = false) where {M}
return AutoEnzyme{M, constant_function}(mode)
end

function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M}
return AutoEnzyme{M, constant_function}(mode)
end

mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
Expand Down
14 changes: 10 additions & 4 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,25 @@ end
@testset "AutoEnzyme" begin
ad = AutoEnzyme()
@test ad isa AbstractADType
@test ad isa AutoEnzyme{Nothing}
@test ad isa AutoEnzyme{Nothing, false}
@test mode(ad) isa ForwardOrReverseMode
@test ad.mode === nothing

ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true}
@test mode(ad) isa ForwardMode
@test ad.mode == EnzymeCore.Forward

ad = AutoEnzyme(; mode = EnzymeCore.Forward)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false}
@test mode(ad) isa ForwardMode
@test ad.mode == EnzymeCore.Forward

ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true)
@test ad isa AbstractADType
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true}
@test mode(ad) isa ReverseMode
@test ad.mode == EnzymeCore.Reverse
end
Expand Down
5 changes: 0 additions & 5 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ end
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
end

import ADTypes

struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end

for backend in [
# dense
ADTypes.AutoChainRules(; ruleconfig = :rc),
Expand Down

0 comments on commit 39da305

Please sign in to comment.