Skip to content

Commit

Permalink
Merge pull request #101 from omlins/ad
Browse files Browse the repository at this point in the history
Add high-level support for architecture-agnostic automatic differentiation
  • Loading branch information
omlins authored Jul 12, 2023
2 parents 8530bfd + 75d9fca commit 19b8835
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
54 changes: 48 additions & 6 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,17 @@ parallel_async(caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_

function parallel(caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_package(), async::Bool=false)
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:stream, :shmem, :launch, :configcall), "@parallel <kernelcall>", true; eval_args=(:launch,))
launch = haskey(kwargs, :launch) ? kwargs.launch : true
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
if isgpu(package) parallel_call_gpu(posargs..., kernelarg, backend_kwargs_expr, async, package; kwargs...)
elseif (package == PKG_THREADS) parallel_call_threads(posargs..., kernelarg, async; launch=launch, configcall=configcall) # Ignore keyword args as they are not for the threads case (noted in doc).
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:stream, :shmem, :launch, :configcall, :∇, :ad_mode, :ad_annotations), "@parallel <kernelcall>", true; eval_args=(:launch,))
is_ad_highlevel = haskey(kwargs, :∇)
launch = haskey(kwargs, :launch) ? kwargs.launch : true
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
if is_ad_highlevel
parallel_call_ad(caller, kernelarg, backend_kwargs_expr, async, package, posargs, kwargs)
else
if isgpu(package) parallel_call_gpu(posargs..., kernelarg, backend_kwargs_expr, async, package; kwargs...)
elseif (package == PKG_THREADS) parallel_call_threads(posargs..., kernelarg, async; launch=launch, configcall=configcall) # Ignore keyword args as they are not for the threads case (noted in doc).
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
end
end

Expand Down Expand Up @@ -176,6 +181,43 @@ end

## @PARALLEL CALL FUNCTIONS

function parallel_call_ad(caller::Module, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol, posargs, kwargs)
ad_mode = haskey(kwargs, :ad_mode) ? kwargs.ad_mode : AD_MODE_DEFAULT
ad_annotations_expr = haskey(kwargs, :ad_annotations) ? extract_tuple(kwargs.ad_annotations; nested=true) : []
ad_vars_expr = extract_tuple(kwargs.∇; nested=true)
~, ~, ad_vars = extract_kwargs(caller, ad_vars_expr, (), "", true; separator=:->)
~, ~, ad_annotations = extract_kwargs(caller, ad_annotations_expr, (), "", true)
ad_vars = map(x->unblock(x), ad_vars)
ad_annotations = map(x->extract_tuple(x), ad_annotations)
f_name = extract_kernelcall_name(kernelcall)
f_posargs, ~ = extract_kernelcall_args(kernelcall)
ad_annotations_byvar = Dict(a => [] for a in f_posargs)
for (keyword, vars) in zip(keys(ad_annotations), values(ad_annotations))
if (keyword keys(AD_SUPPORTED_ANNOTATIONS)) @KeywordArgumentError("annotation $keyword is not (yet) supported with high-level syntax; use the generic syntax calling directly `autodiff_deferred!`.") end
for var in vars
if (ad_annotations_byvar[var] != []) @KeywordArgumentError("variable $var has more than one annotation. Nested annotations are not (yet) supported with high-level syntax; use the generic syntax calling directly `autodiff_deferred!`.") end
push!(ad_annotations_byvar[var], AD_SUPPORTED_ANNOTATIONS[keyword])
end
end
for var in keys(ad_vars)
if ad_annotations_byvar[var] == []
push!(ad_annotations_byvar[var], AD_DUPLICATE_DEFAULT)
end
end
for var in f_posargs
if ad_annotations_byvar[var] == []
push!(ad_annotations_byvar[var], AD_ANNOTATION_DEFAULT)
end
end
annotated_args = (:($(ad_annotations_byvar[var][1])($((var keys(ad_vars) ? (var, ad_vars[var]) : (var,))...))) for var in f_posargs)
ad_call = :(autodiff_deferred!($ad_mode, $f_name, $(annotated_args...)))
kwargs_remaining = filter(x->!(x in (:∇, :ad_mode, :ad_annotations)), keys(kwargs))
kwargs_remaining_expr = [:($key=$val) for (key,val) in kwargs_remaining]
if (async) return :( @parallel $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #TODO: the package needs to be passed further here later.
else return :( @parallel_async $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #...
end
end

function parallel_call_gpu(ranges::Union{Symbol,Expr}, nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
ranges = :(ParallelStencil.ParallelKernel.promote_ranges($ranges))
if (package == PKG_CUDA) int_type = INT_CUDA
Expand Down
29 changes: 18 additions & 11 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ else
end
import Enzyme
using CellArrays, StaticArrays, MacroTools
import MacroTools: postwalk, splitdef, combinedef, isexpr # NOTE: inexpr_walk used instead of MacroTools.inexpr
import MacroTools: postwalk, splitdef, combinedef, isexpr, unblock # NOTE: inexpr_walk used instead of MacroTools.inexpr


## CONSTANTS AND TYPES (and the macros wrapping them)
Expand Down Expand Up @@ -58,6 +58,10 @@ const SUPPORTED_LITERALTYPES = [Float16, Float32, Float64, Complex{Fl
const SUPPORTED_NUMBERTYPES = [Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}]
const PKNumber = Union{Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}} # NOTE: this always needs to correspond to SUPPORTED_NUMBERTYPES!
const NUMBERTYPE_NONE = DataType
const AD_MODE_DEFAULT = :(Enzyme.Reverse)
const AD_DUPLICATE_DEFAULT = Enzyme.DuplicatedNoNeed
const AD_ANNOTATION_DEFAULT = Enzyme.Const
const AD_SUPPORTED_ANNOTATIONS = (Const=Enzyme.Const, Active=Enzyme.Active, Duplicated=Enzyme.Duplicated, DuplicatedNoNeed=Enzyme.DuplicatedNoNeed)
const ERRMSG_UNSUPPORTED_PACKAGE = "unsupported package for parallelization"
const ERRMSG_CHECK_PACKAGE = "package has to be functional and one of the following: $(join(SUPPORTED_PACKAGES,", "))"
const ERRMSG_CHECK_NUMBERTYPE = "numbertype has to be one of the following: $(join(SUPPORTED_NUMBERTYPES,", "))"
Expand Down Expand Up @@ -200,10 +204,11 @@ function extract_args(call::Expr, macroname::Symbol)
end

extract_kernelcall_args(call::Expr) = split_args(call.args[2:end]; in_kernelcall=true)
extract_kernelcall_name(call::Expr) = call.args[1]

function is_kwarg(arg; in_kernelcall=false)
function is_kwarg(arg; in_kernelcall=false, separator=:(=))
if in_kernelcall return ( isa(arg, Expr) && inexpr_walk(arg, :kw; match_only_head=true) )
else return ( isa(arg, Expr) && (arg.head == :(=)) && isa(arg.args[1], Symbol))
else return ( isa(arg, Expr) && (arg.head == separator) && isa(arg.args[1], Symbol))
end
end

Expand All @@ -220,8 +225,8 @@ function split_args(args; in_kernelcall=false)
return posargs, kwargs
end

function split_kwargs(kwargs)
if !all(is_kwarg.(kwargs)) @ModuleInternalError("not all of kwargs are keyword arguments.") end
function split_kwargs(kwargs; separator=:(=))
if !all(is_kwarg.(kwargs; separator=separator)) @ModuleInternalError("not all of kwargs are keyword arguments.") end
return Dict(x.args[1] => x.args[2] for x in kwargs)
end

Expand All @@ -241,16 +246,16 @@ function extract_kwargvalues(kwargs_expr, valid_kwargs, macroname)
return extract_values(kwargs, valid_kwargs)
end

function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname, has_unknown_kwargs; eval_args=())
kwargs = split_kwargs(kwargs_expr)
function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname, has_unknown_kwargs; eval_args=(), separator=:(=))
kwargs = split_kwargs(kwargs_expr, separator=separator)
if (!has_unknown_kwargs) validate_kwargkeys(kwargs, valid_kwargs, macroname) end
for k in keys(kwargs)
if (k in eval_args) kwargs[k] = eval_arg(caller, kwargs[k]) end
end
kwargs_known = NamedTuple(filter(x -> x.first valid_kwargs, kwargs))
kwargs_unknown = NamedTuple(filter(x -> x.first valid_kwargs, kwargs))
kwargs_unknown_expr = [:($k = $(kwargs_unknown[k])) for k in keys(kwargs_unknown)]
return kwargs_known, kwargs_unknown_expr
return kwargs_known, kwargs_unknown_expr, kwargs_unknown
end

function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname; eval_args=())
Expand Down Expand Up @@ -314,9 +319,11 @@ inexpr_walk(expr, s::Symbol; match_only_head=false) = false

Base.unquoted(s::Symbol) = s

function extract_tuple(t::Union{Expr,Symbol}) # NOTE: this could return a tuple, but would require to change all small arrays to tuples...
if isa(t, Expr)
return Base.unquoted.(t.args)
function extract_tuple(t::Union{Expr,Symbol}; nested=false) # NOTE: this could return a tuple, but would require to change all small arrays to tuples...
if isa(t, Expr) && t.head == :tuple
if (nested) return t.args
else return Base.unquoted.(t.args)
end
else
return [t]
end
Expand Down
7 changes: 5 additions & 2 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
elseif is_call(args[end])
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:memopt, :configcall), "@parallel <kernelcall>", true; eval_args=(:memopt,))
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
configcall_kwarg_expr = :(configcall=$configcall)
if memopt
is_ad_highlevel = haskey(kwargs, :∇)
if is_ad_highlevel
ParallelKernel.parallel_call_ad(caller, kernelarg, backend_kwargs_expr, async, package, posargs, kwargs)
elseif memopt
if (length(posargs) > 1) @ArgumentError("maximum one positional argument (ranges) is allowed in a @parallel memopt=true call.") end
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
else
Expand Down
7 changes: 7 additions & 0 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ end
@test @prettystring(1, @parallel stream=mystream f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))"
end;
end;
@testset "@parallel ∇" begin
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.Const)(A), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel_async configcall = f!(A, B, a, b) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.Duplicated)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a), (EnzymeCore.Active)(b))"
end;
@testset "@parallel_indices" begin
@testset "addition of range arguments" begin
expansion = @gorgeousstring(1, @parallel_indices (ix,iy) f(a::T, b::T) where T <: Union{Array{Float32}, Array{Float64}} = (println("a=$a, b=$b)"); return))
Expand Down

0 comments on commit 19b8835

Please sign in to comment.