Skip to content

Commit

Permalink
Merge pull request #398 from JuliaDiff/ox/optout
Browse files Browse the repository at this point in the history
Add opting out of rules
  • Loading branch information
oxinabox authored Jul 20, 2021
2 parents 460a559 + db35df7 commit 6efb2d2
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ makedocs(
"Introduction" => "index.md",
"FAQ" => "FAQ.md",
"Rule configurations and calling back into AD" => "config.md",
"Opting out of rules" => "opting_out_of_rules.md",
"Writing Good Rules" => "writing_good_rules.md",
"Complex Numbers" => "complex.md",
"Deriving Array Rules" => "arrays.md",
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ ProjectTo
```@docs
ChainRulesCore.AbstractTangent
ChainRulesCore.debug_mode
ChainRulesCore.no_rrule
ChainRulesCore.no_frule
```
98 changes: 98 additions & 0 deletions docs/src/opting_out_of_rules.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# [Opting out of rules](@id opt_out)

It is common to define rules fairly generically.
Often matching (or exceeding) how generic the matching original primal method is.
Sometimes this is not the correct behaviour.
Sometimes the AD can do better than this human defined rule.
If this is generally the case, then we should not have the rule defined at all.
But if it is only the case for a particular set of types, then we want to opt-out just that one.
This is done with the [`@opt_out`](@ref) macro.

Consider one a `rrule` for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself)
```julia
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:)
y = sum(x; dims=dims)
project = ProjectTo(x)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
# project makes sure we stay in the same vector subspace as `x`
# no putting in off-diagonal entries in Diagonal etc
= project(broadcast(lasttuple, x, ȳ)))
return (NoTangent(), x̄)
end
return y, sum_pullback
end
```

That is a fairly reasonable `rrule` for the vast majority of cases.

You might have a custom array type for which you could write a faster rule.
For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
To do that, you can indeed write another more specific [`rrule`](@ref).
But another case is where the AD system itself would generate a more optimized case.

For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
Its sum method is basically just to call `sum` on its parent.
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
For example by avoiding the overhead of [`project`ing](@ref ProjectTo).

To opt-out of using the generic `rrule` and to allow the AD system to do its own thing we use the
[`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`.

```julia
@opt_out rrule(::typeof(sum), ::NamedDimsArray)
```

We could even opt-out for all 1 arg functions.
```@julia
@opt_out rrule(::Any, ::NamedDimsArray)
```
Though this is likely to cause some method-ambiguities.

Similar can be done `@opt_out frule`.
It can also be done passing in a [`RuleConfig`](@ref config).


!!! warning "If the general rule uses a config, the opt-out must also"
Following the same principles as for [rules with config](@ref config), a rule with a `RuleConfig` argument will take precedence over one without, including if that one is a opt-out rule.
But if the general rule does not use a config, then the opt-out rule *can* use a config.
This allows, for example, you to use opt-out to avoid a particular AD system using a opt-out rule that takes that particular AD's config.


## How to support this (for AD implementers)

We provide two ways to know that a rule has been opted out of.

### `rrule` / `frule` returns `nothing`

`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.

If you are in a position to generate code, in response to values returned by function calls then you can do something like:
```@julia
res = rrule(f, xs)
if res === nothing
y, pullback = perform_ad_via_decomposition(r, xs) # do AD without hitting the rrule
else
y, pullback = res
end
```
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch.

### `no_rrule` / `no_frule` has a method

`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref).
The body of this method doesn't matter, what matters is that it is a method-table.
A simple thing you can do with this is not support opting out.
To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table.
This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring).
This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule.

More complex you can use this to generate code that triggers your AD.
If for a given signature there is a more specific method in the `no_rrule`/`no_frule` method-table, than the one that would be hit from the `rrule`/`frule` table
(Excluding the one that exactly matches which will return `nothing`) then you know that the rule should not be used.
You can, likely by looking at the primal method table, workout which method you would have it if the rule had not been defined,
and then `invoke` it.



[^1]: It is also possible, that this is not the case. Benchmark your real uses cases.
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using SparseArrays: SparseVector, SparseMatrixCSC
using Compat: hasfield
Expand All @@ -9,7 +10,7 @@ export frule, rrule # core function
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
Expand Down
79 changes: 75 additions & 4 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# These are some macros (and supporting functions) to make it easier to define rules.
using Base.Meta

# Note: must be declared before it is used, which is later in this file.
macro strip_linenos(expr)
return esc(Base.remove_linenums!(expr))
end

############################################################################################
### @scalar_rule

"""
@scalar_rule(f(x₁, x₂, ...),
@setup(statement₁, statement₂, ...),
Expand Down Expand Up @@ -88,7 +91,6 @@ macro scalar_rule(call, maybe_setup, partials...)
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)

############################################################################
# Final return: building the expression to insert in the place of this macro
code = quote
if !($f isa Type) && fieldcount(typeof($f)) > 0
Expand All @@ -114,7 +116,6 @@ returns (in order) the correctly escaped:
- `partials`: which are all `Expr{:tuple,...}`
"""
function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
############################################################################
# Setup: normalizing input form etc

if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
Expand Down Expand Up @@ -275,6 +276,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)

############################################################################################
### @non_differentiable

"""
@non_differentiable(signature_expression)
Expand Down Expand Up @@ -394,7 +398,74 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
end


###########
############################################################################################
# @opt_out

"""
@opt_out frule([config], _, f, args...)
@opt_out rrule([config], f, args...)
This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method,
that says to use the AD system to differentiate it.
For example, consider some function `foo(x::AbtractArray)`.
In general, you know an efficient and generic way to implement its `rrule`.
You do so, (likely making use of [`ProjectTo`](@ref)).
But it actually turns out that for some `FancyArray` type it is better to let the AD do its
thing.
Then you would write something like:
```julia
function rrule(::typeof(foo), x::AbstractArray)
foo_pullback(ȳ) = ...
return foo(x), foo_pullback
end
@opt_out rrule(::typeof(foo), ::FancyArray)
```
This will generate an [`rrule`](@ref) that returns `nothing`,
and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref).
Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref)
For more information see the [documentation on opting out of rules](@ref opt_out).
"""
macro opt_out(expr)
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))

return @strip_linenos quote
$(esc(no_rule_target)) = nothing
$(esc(expr)) = nothing
end
end

"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`."
function _no_rule_target_rewrite!(expr::Expr)
length(expr.args)===0 && error("Malformed method expression. $expr")
if expr.head === :call || expr.head === :where
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
expr = _no_rule_target_rewrite!(expr.args[end])
else
error("Malformed method expression. $(expr)")
end
return expr
end
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
function _no_rule_target_rewrite!(call_target::Symbol)
return if call_target == :rrule
:(ChainRulesCore.no_rrule)
elseif call_target == :frule
:(ChainRulesCore.no_frule)
else
error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target")
end
end



############################################################################################
# Helpers

"""
Expand Down
57 changes: 57 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,60 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance
function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...)
return rrule_kwfunc(kws, rrule, args...)
end

##############################################################
### Opt out functionality

const NO_RRULE_DOC = """
no_rrule
This is an piece of infastructure supporting opting out of [`rrule`](@ref).
It follows the signature for `rrule` exactly.
A collection of type-tuples is stored in its method-table.
If something has this defined, it means that it must having a must also have a `rrule`,
defined that returns `nothing`.
!!! warning "do not overload no_rrule directly
It is fine and intended to query the method table of `no_rrule`.
It is not safe to add to that directly, as corresponding changes also need to be made to
`rrule`.
The [`@opt_out`](@ref) macro does both these things, and so should almost always be used
rather than defining a method of `no_rrule` directly.
### Mechanics
note: when the text below says methods `==` it actually means:
`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself.
To decide if should opt-out using this mechanism.
- find the most specific method of `rrule` and `no_rule` e.g with `Base.which`
- if the method of `no_rrule` `==` the method of `rrule`, then should opt-out
To just ignore the fact that rules can be opted-out from, and that some rules thus return
`nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones
that occur in the method table of `no_rrule`.
Note also when doing this you must still also handle falling back from rule with config, to
rule without config.
On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is
simpler to just use that mechanism for opting out; and you don't need to worry about this
at all.
For more information see the [documentation on opting out of rules](@ref opt_out)
"""

"""
$NO_RRULE_DOC
See also [`ChainRulesCore.no_frule`](@ref).
"""
function no_rrule end
no_rrule(::Any, ::Vararg{Any}) = nothing

"""
$(replace(NO_RRULE_DOC, "rrule"=>"frule"))
See also [`ChainRulesCore.no_rrule`](@ref).
"""
function no_frule end
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing
28 changes: 28 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test_skip ∂xr isa Float64 # to be made true with projection
@test_skip ∂xr real(∂x)
end


@testset "@opt_out" begin
first_oa(x, y) = x
@scalar_rule(first_oa(x, y), (1, 0))
@opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32
@opt_out(
ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32
)

@testset "rrule" begin
@test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0)
@test rrule(first_oa, 3f0, 4f0) === nothing

@test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m
m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32
end)
end

@testset "frule" begin
@test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1)
@test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing

@test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m
m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32
end)
end
end
end

0 comments on commit 6efb2d2

Please sign in to comment.