diff --git a/docs/make.jl b/docs/make.jl index 1b2063ac4..71c4468c1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -48,6 +48,7 @@ makedocs( "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", + "Opting out of rules" => "opting_out_of_rules.md", "Writing Good Rules" => "writing_good_rules.md", "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", diff --git a/docs/src/api.md b/docs/src/api.md index 3b48151aa..491749f04 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -50,4 +50,6 @@ ProjectTo ```@docs ChainRulesCore.AbstractTangent ChainRulesCore.debug_mode +ChainRulesCore.no_rrule +ChainRulesCore.no_frule ``` \ No newline at end of file diff --git a/docs/src/opting_out_of_rules.md b/docs/src/opting_out_of_rules.md new file mode 100644 index 000000000..87a2403f0 --- /dev/null +++ b/docs/src/opting_out_of_rules.md @@ -0,0 +1,98 @@ +# [Opting out of rules](@id opt_out) + +It is common to define rules fairly generically. +Often matching (or exceeding) how generic the matching original primal method is. +Sometimes this is not the correct behaviour. +Sometimes the AD can do better than this human defined rule. +If this is generally the case, then we should not have the rule defined at all. +But if it is only the case for a particular set of types, then we want to opt-out just that one. +This is done with the [`@opt_out`](@ref) macro. + +Consider one a `rrule` for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself) +```julia +function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:) + y = sum(x; dims=dims) + project = ProjectTo(x) + function sum_pullback(ȳ) + # broadcasting the two works out the size no-matter `dims` + # project makes sure we stay in the same vector subspace as `x` + # no putting in off-diagonal entries in Diagonal etc + x̄ = project(broadcast(last∘tuple, x, ȳ))) + return (NoTangent(), x̄) + end + return y, sum_pullback +end +``` + +That is a fairly reasonable `rrule` for the vast majority of cases. + +You might have a custom array type for which you could write a faster rule. +For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`. +To do that, you can indeed write another more specific [`rrule`](@ref). +But another case is where the AD system itself would generate a more optimized case. + +For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. +Its sum method is basically just to call `sum` on its parent. +It is entirely conceivable[^1] that the AD system can do better than our `rrule` here. +For example by avoiding the overhead of [`project`ing](@ref ProjectTo). + +To opt-out of using the generic `rrule` and to allow the AD system to do its own thing we use the +[`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`. + +```julia +@opt_out rrule(::typeof(sum), ::NamedDimsArray) +``` + +We could even opt-out for all 1 arg functions. +```@julia +@opt_out rrule(::Any, ::NamedDimsArray) +``` +Though this is likely to cause some method-ambiguities. + +Similar can be done `@opt_out frule`. +It can also be done passing in a [`RuleConfig`](@ref config). + + +!!! warning "If the general rule uses a config, the opt-out must also" + Following the same principles as for [rules with config](@ref config), a rule with a `RuleConfig` argument will take precedence over one without, including if that one is a opt-out rule. + But if the general rule does not use a config, then the opt-out rule *can* use a config. + This allows, for example, you to use opt-out to avoid a particular AD system using a opt-out rule that takes that particular AD's config. + + +## How to support this (for AD implementers) + +We provide two ways to know that a rule has been opted out of. + +### `rrule` / `frule` returns `nothing` + +`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`. + +If you are in a position to generate code, in response to values returned by function calls then you can do something like: +```@julia +res = rrule(f, xs) +if res === nothing + y, pullback = perform_ad_via_decomposition(r, xs) # do AD without hitting the rrule +else + y, pullback = res +end +``` +The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch. + +### `no_rrule` / `no_frule` has a method + +`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref). +The body of this method doesn't matter, what matters is that it is a method-table. +A simple thing you can do with this is not support opting out. +To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table. +This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring). +This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule. + +More complex you can use this to generate code that triggers your AD. +If for a given signature there is a more specific method in the `no_rrule`/`no_frule` method-table, than the one that would be hit from the `rrule`/`frule` table +(Excluding the one that exactly matches which will return `nothing`) then you know that the rule should not be used. +You can, likely by looking at the primal method table, workout which method you would have it if the rule had not been defined, +and then `invoke` it. + + + +[^1]: It is also possible, that this is not the case. Benchmark your real uses cases. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 1114bd6fd..09c3ebc4c 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,5 +1,6 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! +using Base.Meta using LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC using Compat: hasfield @@ -9,7 +10,7 @@ export frule, rrule # core function export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode export frule_via_ad, rrule_via_ad # definition helper macros -export @non_differentiable, @scalar_rule, @thunk, @not_implemented +export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented export ProjectTo, canonicalize, unthunk # differential operations export add!! # gradient accumulation operations # differentials diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 830b5bba8..13d7d9406 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,10 +1,13 @@ # These are some macros (and supporting functions) to make it easier to define rules. -using Base.Meta +# Note: must be declared before it is used, which is later in this file. macro strip_linenos(expr) return esc(Base.remove_linenums!(expr)) end +############################################################################################ +### @scalar_rule + """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), @@ -88,7 +91,6 @@ macro scalar_rule(call, maybe_setup, partials...) frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) - ############################################################################ # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 @@ -114,7 +116,6 @@ returns (in order) the correctly escaped: - `partials`: which are all `Expr{:tuple,...}` """ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) - ############################################################################ # Setup: normalizing input form etc if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup") @@ -275,6 +276,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) +############################################################################################ +### @non_differentiable + """ @non_differentiable(signature_expression) @@ -394,7 +398,74 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end -########### +############################################################################################ +# @opt_out + +""" + @opt_out frule([config], _, f, args...) + @opt_out rrule([config], f, args...) + +This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method, +that says to use the AD system to differentiate it. + +For example, consider some function `foo(x::AbtractArray)`. +In general, you know an efficient and generic way to implement its `rrule`. +You do so, (likely making use of [`ProjectTo`](@ref)). +But it actually turns out that for some `FancyArray` type it is better to let the AD do its +thing. + +Then you would write something like: +```julia +function rrule(::typeof(foo), x::AbstractArray) + foo_pullback(ȳ) = ... + return foo(x), foo_pullback +end + +@opt_out rrule(::typeof(foo), ::FancyArray) +``` + +This will generate an [`rrule`](@ref) that returns `nothing`, +and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref). + +Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref) + +For more information see the [documentation on opting out of rules](@ref opt_out). +""" +macro opt_out(expr) + no_rule_target = _no_rule_target_rewrite!(deepcopy(expr)) + + return @strip_linenos quote + $(esc(no_rule_target)) = nothing + $(esc(expr)) = nothing + end +end + +"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." +function _no_rule_target_rewrite!(expr::Expr) + length(expr.args)===0 && error("Malformed method expression. $expr") + if expr.head === :call || expr.head === :where + expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) + elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore + expr = _no_rule_target_rewrite!(expr.args[end]) + else + error("Malformed method expression. $(expr)") + end + return expr +end +_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value) +function _no_rule_target_rewrite!(call_target::Symbol) + return if call_target == :rrule + :(ChainRulesCore.no_rrule) + elseif call_target == :frule + :(ChainRulesCore.no_frule) + else + error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target") + end +end + + + +############################################################################################ # Helpers """ diff --git a/src/rules.jl b/src/rules.jl index 0abc81205..1d9113c24 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -139,3 +139,60 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) return rrule_kwfunc(kws, rrule, args...) end + +############################################################## +### Opt out functionality + +const NO_RRULE_DOC = """ + no_rrule + +This is an piece of infastructure supporting opting out of [`rrule`](@ref). +It follows the signature for `rrule` exactly. +A collection of type-tuples is stored in its method-table. +If something has this defined, it means that it must having a must also have a `rrule`, +defined that returns `nothing`. + +!!! warning "do not overload no_rrule directly + It is fine and intended to query the method table of `no_rrule`. + It is not safe to add to that directly, as corresponding changes also need to be made to + `rrule`. + The [`@opt_out`](@ref) macro does both these things, and so should almost always be used + rather than defining a method of `no_rrule` directly. + +### Mechanics +note: when the text below says methods `==` it actually means: +`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself. + +To decide if should opt-out using this mechanism. + - find the most specific method of `rrule` and `no_rule` e.g with `Base.which` + - if the method of `no_rrule` `==` the method of `rrule`, then should opt-out + +To just ignore the fact that rules can be opted-out from, and that some rules thus return +`nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones +that occur in the method table of `no_rrule`. + +Note also when doing this you must still also handle falling back from rule with config, to +rule without config. + +On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is +simpler to just use that mechanism for opting out; and you don't need to worry about this +at all. + +For more information see the [documentation on opting out of rules](@ref opt_out) +""" + +""" +$NO_RRULE_DOC + +See also [`ChainRulesCore.no_frule`](@ref). +""" +function no_rrule end +no_rrule(::Any, ::Vararg{Any}) = nothing + +""" +$(replace(NO_RRULE_DOC, "rrule"=>"frule")) + +See also [`ChainRulesCore.no_rrule`](@ref). +""" +function no_frule end +no_frule(ȧrgs, f, ::Vararg{Any}) = nothing diff --git a/test/rules.jl b/test/rules.jl index f5247797f..d43ca42d2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr isa Float64 # to be made true with projection @test_skip ∂xr ≈ real(∂x) end + + + @testset "@opt_out" begin + first_oa(x, y) = x + @scalar_rule(first_oa(x, y), (1, 0)) + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out( + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ) + + @testset "rrule" begin + @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) + @test rrule(first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 + end) + end + + @testset "frule" begin + @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 + end) + end + end end