Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add opting out of rules #398

Merged
merged 6 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ makedocs(
"Introduction" => "index.md",
"FAQ" => "FAQ.md",
"Rule configurations and calling back into AD" => "config.md",
"Opting out of rules" => "opting_out_of_rules.md",
"Writing Good Rules" => "writing_good_rules.md",
"Complex Numbers" => "complex.md",
"Deriving Array Rules" => "arrays.md",
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ ProjectTo
```@docs
ChainRulesCore.AbstractTangent
ChainRulesCore.debug_mode
ChainRulesCore.no_rrule
ChainRulesCore.no_frule
```
92 changes: 92 additions & 0 deletions docs/src/opting_out_of_rules.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Opting out of rules

It is common to define rules fairly generically.
Often matching (or exceeding) how generic the matching original primal method is.
Sometimes this is not the correct behaviour.
Sometimes the AD can do better than this human defined rule.
If this is generally the case, then we should not have the rule defined at all.
But if it is only the case for a particular set of types, then we want to opt-out just that one.
This is done with the [`@opt_out`](@ref) macro.

Consider one a `rrule` for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself)
```julia
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:)
y = sum(x; dims=dims)
project = ProjectTo(x)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
# project makes sure we stay in the same vector subspace as `x`
# no putting in off-diagonal entries in Diagonal etc
x̄ = project(broadcast(last∘tuple, x, ȳ)))
return (NoTangent(), x̄)
end
return y, sum_pullback
end
```

That is a fairly reasonable `rrule` for the vast majority of cases.

You might have a custom array type for which you could write a faster rule.
For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
To do that, you can indeed write another more specific [`rrule`](@ref).
But another case is where the AD system itself would generate a more optimized case.

For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
Its sum method is basically just to call `sum` on its parent.
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
For example by avoiding the overhead of [`project`ing](@ref ProjectTo).

To opt-out of using the generic `rrule` and to allow the AD system to do its own thing we use the
[`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`.

```julia
@opt_out rrule(::typeof(sum), ::NamedDimsArray)
```

We could even opt-out for all 1 arg functions.
```@julia
@opt_out rrule(::Any, ::NamedDimsArray)
```
Though this is likely to cause some method-ambiguities.

Similar can be done `@opt_out frule`.
It can also be done passing in a [`RuleConfig`](@ref config).


## How to support this (for AD implementers)

We provide two ways to know that a rule has been opted out of.

### `rrule` / `frule` returns `nothing`

`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.

If you are in a position to generate code, in response to values returned by function calls then you can do something like:
```@julia
res = rrule(f, xs)
if res === nothing
y, pullback = perform_ad_via_decomposition(r, xs) # do AD without hitting the rrule
else
y, pullback = res
end
```
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch.

### `no_rrule` / `no_frule` has a method

`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref).
The body of this method doesn't matter, what matters is that it is a method-table.
A simple thing you can do with this is not support opting out.
To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table.
This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring).
This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule.

More complex you can use this to generate code that triggers your AD.
If for a given signature there is a more specific method in the `no_rrule`/`no_frule` method-table, than the one that would be hit from the `rrule`/`frule` table
(Excluding the one that exactly matches which will return `nothing`) then you know that the rule should not be used.
You can, likely by looking at the primal method table, workout which method you would have it if the rule had not been defined,
and then `invoke` it.



[^1]: It is also possible, that this is not the case. Benchmark your real uses cases.
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using SparseArrays: SparseVector, SparseMatrixCSC
using Compat: hasfield
Expand All @@ -9,7 +10,7 @@ export frule, rrule # core function
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
Expand Down
75 changes: 71 additions & 4 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# These are some macros (and supporting functions) to make it easier to define rules.
using Base.Meta

# Note: must be declared before it is used, which is later in this file.
macro strip_linenos(expr)
return esc(Base.remove_linenums!(expr))
end

############################################################################################
### @scalar_rule

"""
@scalar_rule(f(x₁, x₂, ...),
@setup(statement₁, statement₂, ...),
Expand Down Expand Up @@ -88,7 +91,6 @@ macro scalar_rule(call, maybe_setup, partials...)
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)

############################################################################
# Final return: building the expression to insert in the place of this macro
code = quote
if !($f isa Type) && fieldcount(typeof($f)) > 0
Expand All @@ -114,7 +116,6 @@ returns (in order) the correctly escaped:
- `partials`: which are all `Expr{:tuple,...}`
"""
function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
############################################################################
# Setup: normalizing input form etc

if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
Expand Down Expand Up @@ -275,6 +276,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)

############################################################################################
### @non_differentiable

"""
@non_differentiable(signature_expression)

Expand Down Expand Up @@ -394,7 +398,70 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
end


###########
############################################################################################
# @opt_out

"""
@opt_out frule([config], _, f, args...)
@opt_out rrule([config], f, args...)

This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method,
that says to use the AD system to differentiate it.

For example, consider some function `foo(x::AbtractArray)`.
In general, you know an efficient and generic way to implement its `rrule`.
You do so, (likely making use of [`ProjectTo`](@ref)).
But it actually turns out that for some `FancyArray` type it is better to let the AD do its
thing.

Then you would write something like:
```julia
function rrule(::typeof(foo), x::AbstractArray)
foo_pullback(ȳ) = ...
return foo(x), foo_pullback
end

@opt_out rrule(::typeof(foo), ::FancyArray)
```

This will generate an [`rrule`](@ref) that returns `nothing`,
and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref).

Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref)
"""
macro opt_out(expr)
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))

return @strip_linenos quote
$(esc(no_rule_target)) = nothing
$(esc(expr)) = nothing
Comment on lines +438 to +439
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this escaping will mean this doesn't work in anything that doesn't have ChainRulesCore imported.
It needs a test in the isolated scope testset.
But I am ok leaving that for a follow up PR.

end
end

function _no_rule_target_rewrite!(call_target::Symbol)
return if call_target == :rrule
:(ChainRulesCore.no_rrule)
elseif call_target == :frule
:(ChainRulesCore.no_frule)
else
error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target")
end
end
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
function _no_rule_target_rewrite!(expr::Expr)
length(expr.args)===0 && error("Malformed method expression. $expr")
if expr.head === :call || expr.head === :where
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
expr = _no_rule_target_rewrite!(expr.args[end])
else
error("Malformed method expression. $(expr)")
end
return expr
end
oxinabox marked this conversation as resolved.
Show resolved Hide resolved


############################################################################################
# Helpers

"""
Expand Down
49 changes: 49 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,52 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance
function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...)
return rrule_kwfunc(kws, rrule, args...)
end

##############################################################
### Opt out functionality

const NO_RRULE_DOC = """
no_rrule
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we maybe use _no_rrule to signal it is an internal thing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really internal
Zygote, and ChainRulesOverloadGeneration need to access it.

I will instead put a warning in the docstring


This is an implementation detail for opting out of [`rrule`](@ref).
It follows the signature for `rrule` exactly.
We use it as a way to store a collection of type-tuples in its method-table.
If something has this defined, it means that it must having a must also have a `rrule`,
that returns `nothing`.

### Mechanics
note: when the text below says methods `==` or `<:` it actually means:
`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself.

To decide if should opt-out using this mechanism.
- find the most specific method of `rrule`
- find the most specific method of `no_rrule`
- if the method of `no_rrule` `<:` the method of `rrule`, then should opt-out

To just ignore the fact that rules can be opted-out from, and that some rules thus return
`nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones
that occur in the method table of `no_rrule`.

Note also when doing this you must still also handle falling back from rule with config, to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we expand on how to do this in the docs? Maybe just mentioning and seeing an example in Nabla or Zygote is enough

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do expand on this in the docs already.

rule without config.

On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is
simpler to just use that mechanism for opting out; and you don't need to worry about this
at all.
"""

"""
$NO_RRULE_DOC

See also [`ChainRulesCore.no_frule`](@ref).
"""
function no_rrule end
no_rrule(::Any, ::Vararg{Any}) = nothing

"""
$(replace(NO_RRULE_DOC, "rrule"=>"frule"))

See also [`ChainRulesCore.no_rrule`](@ref).
"""
function no_frule end
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing
28 changes: 28 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test_skip ∂xr isa Float64 # to be made true with projection
@test_skip ∂xr ≈ real(∂x)
end


@testset "@opt_out" begin
first_oa(x, y) = x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does oa stand for? we could use oo for opt out?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

opt-aut. 😂
idk why I said oa

@scalar_rule(first_oa(x, y), (1, 0))
@opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32
@opt_out(
ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
)

@testset "rrule" begin
@test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0)
@test rrule(first_oa, 3f0, 4f0) === nothing

@test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m
m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32
end)
end

@testset "frule" begin
@test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1)
@test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing

@test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m
m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32
end)
end
end
end