Package-Independent BLAS / LAPACK Rule System #190
willtebbutt started this conversation in Ideas
@willtebbutt, would it be possible to make EDIT: can we have a
Both Enzyme.jl and Tapir.jl can in principle support low-ish level calls to BLAS / LAPACK functionality. However, both really do need rules for them (certainly Tapir.jl does, because BLAS / LAPACK sit behind C-call barriers, and I believe I'm correct in saying that Enzyme.jl also needs rules for them). To this end, I would like to consider implementing a restricted set of rules which is independent of both packages, so that each can make use of it.
In this sense, it should be more like DiffRules.jl than ChainRules.jl, albeit it won't involve symbolic rules.
In order for this to happen, it would need to support (in some sense) the union of the interfaces for the rule systems for the two packages, and be just flexible enough to cover everything in BLAS / LAPACK.
Key Point: Limited Ambition
The main idea is to avoid having to write BLAS / LAPACK rules twice, once for Enzyme.jl and once for Tapir.jl.
Moreover, I want a system which is entirely focused on BLAS / LAPACK, with no regard for anything more general than this.
This lets us agree on how to handle each of the possible types we might encounter, rather than having to support an open-ended collection -- this is a much easier task.
BLAS / LAPACK Interfaces
To the best of my knowledge, there are basically 3 or 4 types which appear as arguments in BLAS / LAPACK implementations in Julia:
- Ptr{T}
- T
- Char
- Int

where T<:IEEEFloat. It might be that Char and Int are actually pointers to such types in some cases, but the important feature is that they're non-differentiable arguments.
The Idea
Implement rules for linear algebra with these types.
They would use shadow memory for pointers, but not for scalars.
They would look a lot like both Tapir.jl rules and Enzyme.jl rules.
For example, the rule for gemm would be something like
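As a minimal sketch only, here is one shape such a gemm rule might take, written against the same arguments as LinearAlgebra.BLAS.gemm!(tA, tB, alpha, A, B, beta, C). The name gemm_rule! and the closure-based pullback are hypothetical; arrays stand in for the Ptr{T} arguments so the example is runnable, and only real element types with tA, tB in {'N', 'T'} are handled. The array arguments carry shadows (dA, dB, dC) while the scalars alpha and beta do not, so their gradients are returned. The sketch always saves and restores C, which is exactly the kind of choice a real design would make configurable.

```julia
using LinearAlgebra
using LinearAlgebra: BLAS

# Hypothetical package-independent rule for BLAS.gemm!(tA, tB, alpha, A, B, beta, C),
# which computes C = alpha * op(A) * op(B) + beta * C.
# Arrays stand in for Ptr{T} arguments and carry shadows (dA, dB, dC);
# the scalars alpha and beta have no shadow memory, so their gradients are returned.
function gemm_rule!(tA::Char, tB::Char, alpha::T, A::Matrix{T}, dA::Matrix{T},
                    B::Matrix{T}, dB::Matrix{T}, beta::T,
                    C::Matrix{T}, dC::Matrix{T}) where {T<:Union{Float32,Float64}}
    (tA in ('N', 'T') && tB in ('N', 'T')) ||
        throw(ArgumentError("this sketch only handles 'N' and 'T'"))
    opA = tA == 'N' ? A : transpose(A)
    opB = tB == 'N' ? B : transpose(B)

    # C is overwritten, so keep what the reverse pass needs before calling BLAS.
    C_old = copy(C)        # for dbeta and for restoring C
    AB = opA * opB         # for dalpha (costs one extra matrix product)
    BLAS.gemm!(tA, tB, alpha, A, B, beta, C)

    # On entry, dC holds the cotangent of the new C.
    function pullback()
        dalpha = dot(dC, AB)
        dbeta = dot(dC, C_old)
        # Cotangent of op(A) is alpha * dC * op(B)'; transpose it back if tA == 'T'.
        if tA == 'N'
            dA .+= alpha .* (dC * transpose(opB))
        else
            dA .+= alpha .* (opB * transpose(dC))
        end
        # Cotangent of op(B) is alpha * op(A)' * dC; transpose it back if tB == 'T'.
        if tB == 'N'
            dB .+= alpha .* (transpose(opA) * dC)
        else
            dB .+= alpha .* (transpose(dC) * opA)
        end
        copyto!(C, C_old)  # restore the primal memory
        dC .*= beta        # dC now holds the cotangent of the old C
        return dalpha, dbeta
    end
    return pullback
end
```

Returning a closure keeps the sketch close to the shape of a Tapir.jl rrule!!, whereas an Enzyme.jl-facing adaptor would split the same arithmetic into separate forward and reverse functions.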
We would need some data structures to represent activity, whether or not certain memory needs to be restored on the reverse pass, etc., and to decide what happens in each of those situations. I imagine this will largely be the same as in Enzyme.jl.
Types which are identified with their address should have shadow memory if they're active, and no shadow memory if they're not.
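A minimal sketch of such activity data structures, assuming the closed set of argument types is exactly the one listed earlier. Const, Duplicated and Active are hypothetical names chosen to echo Enzyme.jl's annotations, but they are defined from scratch here and are not Enzyme.jl's types; annotate and gemm_must_save_C are likewise illustrative. Address-identified data gets a zero-initialised shadow only when active, scalars never get shadow memory, and a gemm-specific predicate decides whether C must be copied before it is overwritten.

```julia
# Hypothetical activity annotations for the closed set of BLAS / LAPACK
# argument types. Not Enzyme.jl's types: defined from scratch for illustration.
abstract type Activity end

# Non-differentiable argument (Char flags, Int sizes/strides, inactive data).
struct Const{V} <: Activity
    val::V
end

# Active address-identified data: the primal plus shadow memory of the same shape.
struct Duplicated{V} <: Activity
    val::V
    dval::V
end

# Active scalar T <: IEEEFloat: no shadow memory; its gradient is returned instead.
struct Active{V<:Base.IEEEFloat} <: Activity
    val::V
end

# Choose an annotation from the argument's type and whether it is active.
annotate(x::Union{Char,Integer}, active::Bool) = Const(x)  # never differentiable
annotate(x::Base.IEEEFloat, active::Bool) = active ? Active(x) : Const(x)
annotate(x::Array{<:Base.IEEEFloat}, active::Bool) =
    active ? Duplicated(x, zero(x)) : Const(x)  # shadow only if active

# For gemm!, must the forward pass copy C before overwriting it?
# - If the consumer requires mutated memory to be restored on the reverse
#   pass, the old C is needed regardless of activity.
# - Otherwise the old C is only needed for the gradient w.r.t. beta, which
#   exists only when beta is Active and C is differentiated (Duplicated).
function gemm_must_save_C(beta::Activity, C::Activity; restore::Bool)
    restore && return true
    return beta isa Active && C isa Duplicated
end
```

Because the argument types form a closed set, these few methods cover every case, and there is no need for an extensible fallback.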
How would Tapir.jl make use of such rules?
Tapir.jl targets LinearAlgebra.BLAS.gemm!, rather than the C-call to gemm, basically because it's a little bit cleaner. You could use the above to write an rrule!! along the following lines:
Testing
Testing can be done using FiniteDifferences.jl. Due to the heavily restricted interface, it will be easy to write some simple standardised testing functionality. We would not make it available to other packages, as the intention would not be for other packages to extend this collection of rules -- all of the things we wish to support are already known to us.
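A minimal sketch of a standardised check, assuming rules take primal arrays paired with shadows and return scalar gradients from a pullback. To keep it dependency-free, a hand-rolled central difference stands in for FiniteDifferences.jl; axpy_rule! and check_axpy_rule are hypothetical names. The check seeds the shadow of the output with a random cotangent w, runs the pullback, and compares the resulting directional derivative with a finite difference of L = <w, y> along a random direction. It also checks that the overwritten memory was restored.

```julia
using LinearAlgebra
using LinearAlgebra: BLAS
using Random

# Hypothetical rule for BLAS.axpy!(a, x, y), which computes y = a * x + y.
# x and y carry shadows (dx, dy); the scalar a does not, so da is returned.
function axpy_rule!(a::T, x::Vector{T}, dx::Vector{T},
                    y::Vector{T}, dy::Vector{T}) where {T<:Union{Float32,Float64}}
    y_old = copy(y)            # y is overwritten; keep it so it can be restored
    BLAS.axpy!(a, x, y)
    function pullback()        # on entry, dy holds the cotangent of the new y
        da = dot(dy, x)
        dx .+= a .* dy
        copyto!(y, y_old)      # restore the primal memory
        return da              # dy is unchanged: d(new y)/d(old y) is the identity
    end
    return pullback
end

# Standardised check: compare the rule's reverse pass with a central finite
# difference of L(a, x, y) = <w, a * x + y> along a random direction
# (a hand-rolled stand-in for FiniteDifferences.jl).
function check_axpy_rule(n::Int; h=1e-6, rtol=1e-6, atol=1e-8, rng=MersenneTwister(0))
    a, x, y, w = randn(rng), randn(rng, n), randn(rng, n), randn(rng, n)
    ta, tx, ty = randn(rng), randn(rng, n), randn(rng, n)   # tangent direction

    # Primal objective, evaluated on a copy so y itself is never mutated.
    L(a, x, y) = (ynew = copy(y); BLAS.axpy!(a, x, ynew); dot(w, ynew))
    fd = (L(a + h * ta, x .+ h .* tx, y .+ h .* ty) -
          L(a - h * ta, x .- h .* tx, y .- h .* ty)) / (2h)

    dx, dy = zeros(n), copy(w)     # seed the shadow of the output with w
    y_work = copy(y)
    da = axpy_rule!(a, x, dx, y_work, dy)()
    ad = da * ta + dot(dx, tx) + dot(dy, ty)

    # Gradient matches finite differences, and the overwritten memory was restored.
    return isapprox(ad, fd; rtol=rtol, atol=atol) && y_work == y
end
```

Since every rule in the collection shares the same small set of argument types, this seed, pull back, and compare pattern can be reused almost unchanged from rule to rule.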