Projecting Cotangents #286

Closed
willtebbutt opened this issue Jan 24, 2021 · 4 comments

Comments

@willtebbutt
Member

willtebbutt commented Jan 24, 2021

ChainRules embraces multiple possible representations of a cotangent; for example, AbstractZero, Composite, and AbstractArray are all valid representations for the cotangent of a Diagonal. However, this flexibility places an increased burden on rule implementers: in principle there is no upper bound on the number of types that one might have to accept as the cotangent w.r.t. the output of some function foo that returns a Diagonal.

I wonder whether some design-orthogonalisation might help to deal with this -- could we separate out the standardisation of the representation of cotangents from the rule implementation?

Consider a function canonicalise(primal, cotangent) whose job is to map a cotangent of any valid type onto a well-defined, predictable, finite set of types for a given primal type. For example, you might implement this as follows for Diagonal:

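# Zero-like cotangents are already canonical for any primal.
canonicalise(::Any, dX::AbstractZero) = dX
# A structural cotangent stays a Composite; its diag field is canonicalised recursively.
canonicalise(X::Diagonal, dX::Composite{T}) where {T} = Composite{T}(diag=canonicalise(X.diag, dX.diag))
# A matrix cotangent is reduced to its diagonal and wrapped in a Composite for the primal type.
canonicalise(X::Diagonal, dX::AbstractMatrix) = Composite{typeof(X)}(diag=canonicalise(X.diag, diag(dX)))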

Note that I've chosen to make the canonical cotangent type for a Diagonal a Composite rather than an AbstractMatrix for the usual performance-related reasons discussed extensively in JuliaDiff/ChainRules.jl#232. An AbstractMatrix doesn't count as a "canonical" type in my definition here since it's abstract, and so doesn't meet the finiteness criterion.

If we did this, we would certainly be able to avoid defining + on so many things -- we could just assume that things have been canonicalised before hitting +. Similarly, Zygote's automatic constructor pullback generation ought to have an easier time because, if everything is appropriately canonicalised, constructors should always receive appropriate NamedTuples.
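To make the idea concrete, here is a minimal, self-contained sketch of the proposal. It deliberately does not use ChainRulesCore: `AbstractZeroSketch`, `ZeroSketch`, and `CompositeSketch` are hypothetical stand-ins defined only for illustration, and the base-case method for plain numeric vectors is an assumption that the text above leaves unspecified.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the zero and structural cotangent types,
# defined here only so that the sketch runs without any package.
abstract type AbstractZeroSketch end
struct ZeroSketch <: AbstractZeroSketch end

struct CompositeSketch{P,T<:NamedTuple}
    backing::T
end
function CompositeSketch{P}(; kwargs...) where {P}
    backing = (; kwargs...)
    return CompositeSketch{P,typeof(backing)}(backing)
end

# Zero-like cotangents are already canonical for any primal.
canonicalise(::Any, dX::AbstractZeroSketch) = dX
# Assumed base case: a numeric vector cotangent becomes a dense Vector.
canonicalise(::AbstractVector{<:Number}, dx::AbstractVector{<:Number}) = collect(dx)
# Structural cotangents are canonicalised field by field.
canonicalise(X::Diagonal, dX::CompositeSketch{P}) where {P} =
    CompositeSketch{P}(diag=canonicalise(X.diag, dX.backing.diag))
# Matrix cotangents are reduced to their diagonal.
canonicalise(X::Diagonal, dX::AbstractMatrix) =
    CompositeSketch{typeof(X)}(diag=canonicalise(X.diag, diag(dX)))

# After canonicalisation, one + method per canonical type is enough.
Base.:+(a::CompositeSketch{P}, b::CompositeSketch{P}) where {P} =
    CompositeSketch{P}(; map(+, a.backing, b.backing)...)

X = Diagonal([1.0, 2.0])
c_dense = canonicalise(X, [3.0 9.0; 9.0 4.0])
c_struct = canonicalise(X, CompositeSketch{typeof(X)}(diag=[3.0, 4.0]))
total = c_dense + c_struct
```

Both inputs end up with the same concrete type, so accumulation never has to handle a mix of matrix and structural cotangents. Whether the off-diagonal entries of a dense cotangent should simply be dropped, as done here, is part of the design question rather than something this sketch settles.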

@sethaxen pointed out that this is something that we might want to concern ourselves with in #160, but I wanted to raise it separately, as I think it's an interesting thing to consider on its own.

edit: not sure whether we want to choose a different name from canonicalise, given that we already have a function with that name. Possibly we could extend it to handle the more general class of things described here.

@sethaxen
Member

sethaxen commented Feb 4, 2021

In general I am for something like this, for all the reasons noted elsewhere. I do have some concerns though. Suppose I have a primal function f(::AbstractMatrix) -> AbstractMatrix, where it just so happens that the internal operations through dispatch result in f(::Diagonal) -> Diagonal. Now I define rrule(::typeof(f), ::AbstractMatrix), where the rrule calls the primal function. The output will be Diagonal, so the cotangent my pullback will be passed will be a Composite{Diagonal}. But my pullback is probably assuming I am being passed an AbstractMatrix and will end up erroring. This is currently the problem we have, because users are defining such abstract rules, and a method like this that picks a Composite representation could just error and not work at all for such rules. Of course, the alternative has its own problems. Do you have any ideas for how we can work around this?
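The failure mode described above can be reproduced in a small sketch. Nothing here uses ChainRulesCore: `f_with_pullback` is a hand-written stand-in for an abstractly typed rrule, `StructuralCotangentSketch` is a hypothetical stand-in for a Composite-style cotangent, and the choice of `f(A) = A * A` is only an example of a function whose output type follows its input type through dispatch.

```julia
using LinearAlgebra

# Hypothetical stand-in for a structural (Composite-style) cotangent of a Diagonal.
struct StructuralCotangentSketch{V<:AbstractVector}
    diag::V
end

# Through dispatch, f(::Diagonal) returns a Diagonal.
f(A::AbstractMatrix) = A * A

# Stand-in for a rule written against AbstractMatrix only.
function f_with_pullback(A::AbstractMatrix)
    Y = f(A)
    # Assumes dY supports matrix multiplication.
    pullback(dY) = dY * A' + A' * dY
    return Y, pullback
end

A = Diagonal([1.0, 2.0])
Y, pullback = f_with_pullback(A)

# A matrix-valued cotangent works as the rule author expected.
dense_result = pullback(Matrix(1.0I, 2, 2))

# A structural cotangent has no * method, so the same pullback throws.
structural_errors = try
    pullback(StructuralCotangentSketch([1.0, 1.0]))
    false
catch err
    err isa MethodError
end
```

The primal output `Y` is a Diagonal, so a canonicalisation step that always picks the structural representation would hand this pullback exactly the input it cannot handle.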

@willtebbutt
Member Author

This is a good point -- perhaps we need a more general piece of functionality that rule-implementers can also hook into that translates between any valid representation of a differential? So if they receive a Composite they have functionality to translate that into a Diagonal?
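One possible shape for such translation functionality is sketched below. The names `as_matrix` and `as_structural` and the type `StructuralCotangentSketch` are hypothetical and exist only in this sketch; they are not part of ChainRulesCore, and the real design may look quite different.

```julia
using LinearAlgebra

# Hypothetical stand-in for a structural (Composite-style) cotangent of a Diagonal.
struct StructuralCotangentSketch{V<:AbstractVector}
    diag::V
end

# Translate any supported representation into a matrix-like Diagonal.
as_matrix(::Diagonal, dX::StructuralCotangentSketch) = Diagonal(dX.diag)
as_matrix(::Diagonal, dX::AbstractMatrix) = Diagonal(diag(dX))

# Translate any supported representation into the structural form.
as_structural(::Diagonal, dX::StructuralCotangentSketch) = dX
as_structural(::Diagonal, dX::AbstractMatrix) = StructuralCotangentSketch(diag(dX))

# A rule author who wants matrices converts at the top of the pullback.
function square_with_pullback(A::AbstractMatrix)
    Y = A * A
    function pullback(dY)
        dYm = as_matrix(Y, dY)
        return dYm * A' + A' * dYm
    end
    return Y, pullback
end

A = Diagonal([1.0, 2.0])
Y, pullback = square_with_pullback(A)
from_structural = pullback(StructuralCotangentSketch([1.0, 1.0]))
from_dense = pullback(Matrix(1.0I, 2, 2))
```

With the conversion in place, the pullback gives the same answer for either representation, at the cost of each rule author choosing which representation to request.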

@willtebbutt
Member Author

I've sketched an implementation of this proposal here: #306

@mzgubic
Member

mzgubic commented Jul 6, 2021

closed by #385

@mzgubic closed this as completed Jul 6, 2021