
Adding EnzymeRules #503

Closed
sethaxen opened this issue Apr 19, 2023 · 9 comments

sethaxen commented Apr 19, 2023

Motivation and description

On Slack, @ChrisRackauckas said that the functions defined in https://github.com/FluxML/NNlibCUDA.jl/tree/master/src/cudnn will need EnzymeCore.EnzymeRules overloads to work with Enzyme, since they call precompiled, shipped binaries.

Or would it be better to provide overloads for the functions defined in https://github.com/JuliaGPU/CUDA.jl/tree/master/lib/cudnn?

Possible Implementation

On Julia v1.9 this could be done with an extension module; on pre-v1.9, the module could be loaded directly as described in https://pkgdocs.julialang.org/dev/creating-packages/#Transition-from-normal-dependency-to-extension. This is feasible because EnzymeCore's sole dependency is Adapt, which is already a dependency here.
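For concreteness, the extension setup described above might look like the following Project.toml fragment. The extension name, the UUID placeholder, and the compat bound are illustrative assumptions, not taken from this thread:

```toml
# Hypothetical fragment of NNlib's Project.toml (names are illustrative).
# EnzymeCore is declared as a weak dependency that triggers an extension
# module on Julia >= 1.9 whenever EnzymeCore is present in the environment.

[weakdeps]
EnzymeCore = "<EnzymeCore's UUID>"

[extensions]
NNlibEnzymeCoreExt = "EnzymeCore"

[compat]
EnzymeCore = "0.1"  # placeholder bound; pick the actual supported version
```

On pre-1.9 Julia, the same `ext/NNlibEnzymeCoreExt.jl` file would instead be `include`d unconditionally with EnzymeCore promoted to a normal dependency, as the linked Pkg docs describe.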

@ToucheSir (Member)

> Or would it be better to provide overloads for functions defined in https://github.com/JuliaGPU/CUDA.jl/tree/master/lib/cudnn?

Not sure. It would be enlightening to have some other examples of rules for CUDA libraries before doing this.

> on pre-v1.9 the module could be loaded directly as described in https://pkgdocs.julialang.org/dev/creating-packages/#Transition-from-normal-dependency-to-extension, since EnzymeCore's sole dependency is Adapt, which is also a dependency here.

I would want to hold off on this until it's known how the GPU stack plans on using Enzyme. If all those packages add EnzymeCore as a weak dep, then I'd rather not add it as a hard dep here.

@sethaxen (Author)

> Or would it be better to provide overloads for functions defined in https://github.com/JuliaGPU/CUDA.jl/tree/master/lib/cudnn?

> Not sure. It would be enlightening to have some other examples of rules for CUDA libraries before doing this.

EnzymeCore is quite new, so there are few rules out there to point to as examples yet. But it's designed to be used similarly to ChainRulesCore. Where are the ChainRulesCore rules for these CUDA libraries defined? It might make sense to define the EnzymeCore rules in the same place.

One difference is that many ChainRules rules are intended to cover up language features not supported, or poorly supported, by Diffractor/Zygote, especially control flow and mutation, so those rules are often implemented on higher-level functions with generic types. Enzyme, by contrast, supports more language features, so its rules can be implemented at a lower level with more concrete types.
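To make that contrast concrete, here is a minimal, hedged sketch of what a low-level EnzymeCore rule can look like. The kernel `myscale!` is invented for illustration, and the `EnzymeRules.forward` signature follows the EnzymeCore API of roughly this era (later versions added a config argument), so treat this as a sketch rather than a working rule:

```julia
# Sketch only: a forward-mode EnzymeRules overload for a made-up in-place
# kernel `myscale!`. Unlike a ChainRules rrule on a high-level generic
# function, this targets a concrete, mutating, low-level function.
using EnzymeCore
using EnzymeCore.EnzymeRules

myscale!(y, x, a) = (y .= a .* x; nothing)

# Forward rule: with y = a * x, the tangent is dy = a * dx, so the same
# kernel can propagate both the primal values and the tangents.
function EnzymeRules.forward(func::Const{typeof(myscale!)}, ::Type{<:Const},
                             y::Duplicated, x::Duplicated, a::Const)
    myscale!(y.val, x.val, a.val)    # primal
    myscale!(y.dval, x.dval, a.val)  # tangent
    return nothing
end
```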

> I would want to hold off on this until it's known how the GPU stack plans on using Enzyme. If all those packages add EnzymeCore as a weak dep, then I'd rather not add it as a hard dep here.

It can always be added as a weak dep first and made a hard dep later to support older Julia versions.

@ToucheSir (Member)

> Where are ChainRulesCore rules defined for these CUDA libraries?

Perhaps surprisingly, nowhere. They're just defined against stdlib APIs and tested with GPU arrays. However, this doesn't really work with Enzyme because, e.g., CPU BLAS and cuBLAS don't follow the same API. Thus I'm interested in how more straightforward libraries like cuBLAS and cuFFT will get rules: will they be defined at the CUDA wrapper level, or against the LinearAlgebra and AbstractFFTs APIs, respectively?

> It can always be added as weak dep first and made a hard dep later to support older Julia versions.

Yes, that's always the fallback. If a dependency of NNlib picks up EnzymeCore as a hard dep however (e.g. JuliaGPU/KernelAbstractions.jl#382), we might as well make it one too.

@sethaxen (Author)

> Perhaps surprisingly, nowhere. They're just defined against stdlib APIs and tested with GPU arrays.

Oh, I meant the kernels. But I think I answered my own question: the rrules are defined on NNlib's API functions, and in NNlib itself, right? E.g.

NNlib.jl/src/conv.jl, lines 341 to 383 at acf87f5:

```julia
for conv in [:conv, :depthwiseconv]
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

    @eval function rrule(::typeof($conv), x, w, cdims; kw...)
        function $conv_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NoTangent(),
                @thunk($∇conv_data(unthunk(Δ), w, cdims, kw...)),
                @thunk($∇conv_filter(x, unthunk(Δ), cdims, kw...)),
                NoTangent(),
            )
        end
        return $conv(x, w, cdims; kw...), $conv_pullback
    end

    @eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...)
        function $∇conv_data_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NoTangent(),
                @thunk($conv(unthunk(Δ), w, cdims, kw...)),
                @thunk($∇conv_filter(unthunk(Δ), x, cdims, kw...)),
                NoTangent(),
            )
        end
        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
    end
end

function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...)
    function ∇conv_filter_pullback(Δ)
        Δ1 = colmajor(unthunk(Δ))
        return (
            NoTangent(),
            @thunk(∇conv_data(dy, Δ1, cdims, kw...)),
            @thunk(conv(x, Δ1, cdims, kw...)),
            NoTangent(),
        )
    end
    return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback
end
```

> Thus I'm interested in how more straightforward libraries like cuBLAS and cuFFT will get rules: will they be defined at the CUDA wrapper level, or against the LinearAlgebra and AbstractFFTs APIs, respectively?

Good question. It makes sense to define them in terms of the high-level API functions only if all of the following conditions hold:

  1. the resulting rules actually work for the CUDA array types;
  2. using the high-level API doesn't cause excessive allocations that working at a lower level would avoid;
  3. a user is highly unlikely to ever call a lower-level function that we could otherwise have defined a rule for;
  4. writing the rules at a high level takes significantly less work than writing them at a low level.

Naively, I would guess that defining rules at the level of LinearAlgebra, AbstractFFTs, and NNlib would be sufficient, but specialized on CuArrays, since that seems to be what was done with ChainRules. But I'm not very familiar with this stack. What is your sense of which of the above conditions are fulfilled?
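The "high-level API, specialized on CuArrays" idea might look like the following. This is a hypothetical sketch: the names and the `EnzymeRules.forward` signature are assumptions, and the derivative handling is deliberately minimal:

```julia
# Hypothetical dispatch pattern (sketch, not a working rule): the rule is
# written against the generic LinearAlgebra entry point mul!, but the
# argument types are restricted to CUDA arrays so it only fires on the
# code paths that actually call into cuBLAS.
using LinearAlgebra
using CUDA  # assumed available
using EnzymeCore, EnzymeCore.EnzymeRules

function EnzymeRules.forward(func::Const{typeof(mul!)}, ::Type{<:Const},
                             C::Duplicated{<:CuArray}, A::Const{<:CuArray},
                             B::Duplicated{<:CuArray})
    # Primal: C = A * B; tangent: dC = A * dB (A is held constant here).
    mul!(C.val, A.val, B.val)
    mul!(C.dval, A.val, B.dval)
    return nothing
end
```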

@ToucheSir (Member)

I'm not sure. I would have agreed with your criteria and your ideas on where to define rules, but then I see rules being defined for functions like gemm! rather than mul! (which is closer to the level NNlib operates at), and I'm less certain.

A more concrete concern is where to place these rules with our current package extension structure. If EnzymeCore becomes a direct dep, then they can live in the existing GPU backend extensions. If it becomes a weak dep, we'd need to add another extension for each GPU backend just for the rules. Happy to accept PRs using either approach for now, but my inclination is to hold off on tagging until the dust settles there.

@sethaxen (Author)

Another data point: there's an open PR to add EnzymeCore as an extension to CUDA.jl with some basic forward-mode rules: JuliaGPU/CUDA.jl#1869

@ToucheSir (Member)

Yes, that's the main one I've been following to see what should be done on the NNlib side.

@wsmoses (Contributor) commented Sep 24, 2023

See #536 for a conv example; adding these for the other layers would be incredibly valuable as well!

@wsmoses (Contributor) commented May 11, 2024

Given that the conv, scatter, gather, and other rules are now in place, I propose that this issue be closed and that new issues be opened for any unsupported functions.
