Skip to content

Commit

Permalink
Merge pull request #63 from SciML/ap/reversediff
Browse files Browse the repository at this point in the history
Promote ReverseDiff compile field to type
  • Loading branch information
ChrisRackauckas authored Jun 23, 2024
2 parents 72f806d + 6c27b87 commit eab4336
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = [
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
]
version = "1.4.0"
version = "1.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 3 additions & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ abstract type AbstractADType end

Base.broadcastable(ad::AbstractADType) = Ref(ad)

@inline _unwrap_val(::Val{T}) where {T} = T
@inline _unwrap_val(x) = x

include("mode.jl")
include("dense.jl")
include("sparse.jl")
Expand Down
22 changes: 18 additions & 4 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,28 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
# Constructors
AutoReverseDiff(; compile=false)
AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
# Fields
- `compile::Bool`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
- `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation
"""
Base.@kwdef struct AutoReverseDiff <: AbstractADType
compile::Bool = false
struct AutoReverseDiff{C} <: AbstractADType
compile::Bool # this field if left for legacy reasons

function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false))
_compile = _unwrap_val(compile)
return new{_compile}(_compile)
end
end

function Base.getproperty(ad::AutoReverseDiff, s::Symbol)
if s === :compile
Base.depwarn(
"`ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`.",
:getproperty)
end
return getfield(ad, s)
end

mode(::AutoReverseDiff) = ReverseMode()
Expand Down
8 changes: 4 additions & 4 deletions src/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ ADTypes.AutoZygote()
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)

for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...)
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
args...; kws...)
end

11 changes: 10 additions & 1 deletion test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,26 @@ end
end

@testset "AutoReverseDiff" begin
ad = AutoReverseDiff()
ad = @inferred AutoReverseDiff()
@test ad isa AbstractADType
@test ad isa AutoReverseDiff
@test mode(ad) isa ReverseMode
@test !ad.compile
@test_deprecated ad.compile

ad = AutoReverseDiff(; compile = true)
@test ad isa AbstractADType
@test ad isa AutoReverseDiff
@test mode(ad) isa ReverseMode
@test ad.compile
@test_deprecated ad.compile

ad = @inferred AutoReverseDiff(; compile = Val(true))
@test ad isa AbstractADType
@test ad isa AutoReverseDiff
@test mode(ad) isa ReverseMode
@test ad.compile
@test_deprecated ad.compile
end

@testset "AutoSymbolics" begin
Expand Down

0 comments on commit eab4336

Please sign in to comment.