Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Miha Zgubic <[email protected]>
  • Loading branch information
oxinabox and mzgubic authored Jul 20, 2021
1 parent cd8babc commit 819e56a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
26 changes: 13 additions & 13 deletions docs/src/opting_out_of_rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

It is common to define rules fairly generically.
Often matching (or exceeding) how generic the matching original primal method is.
Sometimes this is not the correct behavour.
Sometimes this is not the correct behaviour.
Sometimes the AD can do better than this human defined rule.
If this is generally the case, then we should not have the rule defined at all.
But if it is only the case for a particular set of types, then we want to opt-out just that one.
This is done with the [`@opt_out`](@ref) macro.

Consider one might have a rrule for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself)
Consider one a `rrule` for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself)
```julia
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:)
y = sum(x; dims=dims)
Expand All @@ -27,17 +27,17 @@ end
That is a fairly reasonable `rrule` for the vast majority of cases.

You might have a custom array type for which you could write a faster rule.
For example, the pullback for summing a`SkewSymmetric` matrix can be optimizes to basically be `Diagonal(fill(ȳ, size(x,1)))`.
For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
To do that, you can indeed write another more specific [`rrule`](@ref).
But another case is where the AD system itself would generate a more optimized case.

For example, the a [`NamedDimArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
It's sum method is basically just to call `sum` on it's parent.
For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
Its sum method is basically just to call `sum` on its parent.
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
For example by avoiding the overhead of [`project`ing](@ref ProjectTo).

To opt-out of using the `rrule` and to allow the AD system to do its own thing we use the
[`@opt_out`](@ref) macro, to say to not use it for sum.
To opt-out of using the generic `rrule` and to allow the AD system to do its own thing we use the
[`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`.

```julia
@opt_out rrule(::typeof(sum), ::NamedDimsArray)
Expand All @@ -53,11 +53,11 @@ Similar can be done `@opt_out frule`.
It can also be done passing in a [`RuleConfig`](@ref config).


### How to support this (for AD implementers)
## How to support this (for AD implementers)

We provide two ways to know that a rule has been opted out of.

## `rrule` / `frule` returns `nothing`
### `rrule` / `frule` returns `nothing`

`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.

Expand All @@ -70,12 +70,12 @@ else
y, pullback = res
end
```
The Julia compiler, will specialize based on inferring the restun type of `rrule`, and so can remove that branch.
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch.

## `no_rrule` / `no_frule` has a method
### `no_rrule` / `no_frule` has a method

`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref).
The use of this method doesn't matter, what matters is it's method-table.
The body of this method doesn't matter, what matters is that it is a method-table.
A simple thing you can do with this is not support opting out.
To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table.
This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` and thus makes your library error.
Expand All @@ -89,4 +89,4 @@ and then `invoke` it.



[^1]: It is also possible, that this is not the case. Benchmark your real uses cases.
[^1]: It is also possible, that this is not the case. Benchmark your real uses cases.
10 changes: 5 additions & 5 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,13 @@ end
@opt_out frule([config], _, f, args...)
@opt_out rrule([config], f, args...)
This allows you to opt-out of a `frule` or `rrule` by providing a more specific method,
that says to use the AD system, to solver it.
This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method,
that says to use the AD system to differentiate it.
For example, consider some function `foo(x::AbtractArray)`.
In general, you know a efficicent and generic way to implement it's `rrule`.
In general, you know an efficient and generic way to implement its `rrule`.
You do so, (likely making use of [`ProjectTo`](@ref)).
But it actually turns out that for some `FancyArray` type it is better to let the AD do it's
But it actually turns out that for some `FancyArray` type it is better to let the AD do its
thing.
Then you would write something like:
Expand All @@ -422,7 +422,7 @@ end
@opt_out rrule(::typeof(foo), ::FancyArray)
```
This will generate a [`rrule`](@ref) that returns `nothing`,
This will generate an [`rrule`](@ref) that returns `nothing`,
and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref).
Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref)
Expand Down
8 changes: 4 additions & 4 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ We use it as a way to store a collection of type-tuples in its method-table.
If something has this defined, it means that it must having a must also have a `rrule`,
that returns `nothing`.
### Machanics
note: when this says methods `==` or `<:` it actually means:
`parameters(m.sig)[2:end]` rather than the method object `m` itself.
### Mechanics
note: when the text below says methods `==` or `<:` it actually means:
`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself.
To decide if should opt-out using this mechanism.
- find the most specific method of `rrule`
Expand Down Expand Up @@ -187,4 +187,4 @@ $(replace(NO_RRULE_DOC, "rrule"=>"frule"))
See also [`ChainRulesCore.no_rrule`](@ref).
"""
function no_frule end
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing

0 comments on commit 819e56a

Please sign in to comment.