Skip to content

Commit

Permalink
Merge pull request #169 from aelligp/main
Browse files Browse the repository at this point in the history
Enzyme `>= 0.12` compatibility
  • Loading branch information
omlins authored Oct 9, 2024
2 parents 51da6f5 + 01d4f7b commit 489252e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ParallelStencil_EnzymeExt = "Enzyme"
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
CUDA = "3.12, 4, 5"
CellArrays = "0.2.1"
Enzyme = "0.11"
Enzyme = "0.11, 0.12, 0.13"
MacroTools = "0.5"
Polyester = "0.7"
StaticArrays = "1"
Expand Down
24 changes: 21 additions & 3 deletions src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,36 @@ import ParallelStencil
import ParallelStencil: PKG_THREADS, PKG_POLYESTER
import Enzyme

function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol)
if iscpu(package)
Enzyme.API.runtimeActivity!(true) # NOTE: this is currently required for Enzyme to work correctly with threads
# function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol)
# if iscpu(package)
# Enzyme.API.runtimeActivity!(true) # NOTE: this is currently required for Enzyme to work correctly with threads
# end
# end

# ParallelStencil injects a configuration parameter at the end; for Enzyme we need to wrap that parameter as an Annotation.
# For all purposes this ought to be Const. This is not ideal, since we might accidentally wrap other parameters the user
# provided as well. This is needed to support @parallel autodiff_deferred(...)
"""
    promote_to_const(args...)

Return a tuple with one entry per element of `args`. An element is passed through
unchanged when it is an `Enzyme.Annotation` instance, or when it is a `UnionAll` or
`DataType` that is a subtype of `Enzyme.Annotation` (e.g. `Const` or `Const{Nothing}`).
Every other element is wrapped in `Enzyme.Const`.
"""
function promote_to_const(args...)
    return ntuple(length(args)) do idx
        @inline
        arg = args[idx]
        # Already usable by Enzyme: an annotation value or an annotation type.
        is_annotation = arg isa Enzyme.Annotation ||
            (arg isa UnionAll && arg <: Enzyme.Annotation) ||  # e.g. Const
            (arg isa DataType && arg <: Enzyme.Annotation)     # e.g. Const{Nothing}
        is_annotation ? arg : Enzyme.Const(arg)
    end
end

# Forward to `Enzyme.autodiff_deferred`, first passing the trailing arguments through
# `promote_to_const`. The result of the Enzyme call is discarded; `nothing` is returned.
function ParallelStencil.ParallelKernel.AD.autodiff_deferred!(arg, args...) # NOTE: minimal specialization is used to avoid overwriting the default method
    annotated = promote_to_const(args...)
    Enzyme.autodiff_deferred(arg, annotated...)
    return nothing
end

# Forward to `Enzyme.autodiff_deferred_thunk`, first passing the trailing arguments through
# `promote_to_const`. The result of the Enzyme call is discarded; `nothing` is returned.
function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(arg, args...) # NOTE: minimal specialization is used to avoid overwriting the default method
    annotated = promote_to_const(args...)
    Enzyme.autodiff_deferred_thunk(arg, annotated...)
    return nothing
end
Expand Down
4 changes: 2 additions & 2 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ import Enzyme
end
return
end
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, f!, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, g!, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@test Array(Ā) Ā_ref
@test Array(B̄) B̄_ref
end
Expand Down

0 comments on commit 489252e

Please sign in to comment.