Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use ProjectTo functor everywhere, and update to v1 of CRC and CRTU #459

Merged
merged 44 commits into from
Jul 26, 2021

Conversation

oxinabox
Copy link
Member

This is the partner to JuliaDiff/ChainRulesCore.jl#385
it matches against #457

It is much cleaner and it should be good again inference problems from closing over types,
as the type gets put into a type-parameter of project_A and project_B

Following minimized version works find for me, haven't tested real thing properly.

using Test, LinearAlgebra, ChainRulesCore

function projectto_infers(x::AbstractArray)
       project = ProjectTo(x)
       function myclosure(y)
           return project(y)
       end
       return myclosure
   end
   
   
const pb = projectto_infers(Diagonal([1.0, 2.0]))
@inferred pb([1.0 2.0; 3.0 4.0])

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 29, 2021
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

oxinabox commented Jul 1, 2021

I have made changes that make the following all work for me locally

using ChainRulesCore, ChainRules, Test
@inferred(
    rrule(hvcat, 1, rand(3)', transpose(rand(3)))[2](rand(2,3))
)
using ChainRulesTestUtils
test_rrule(hvcat, 1, rand(3)', transpose(rand(3))  rand(1,3))

@mcabbott
Copy link
Member

mcabbott commented Jul 7, 2021

One way you could avoid boilerplate here is to define a macro which turns

@ProjectEverything function rrule(f, x, ys...)
   # do stuff
  return y, back
end

into

function rrule(f, x, ys...)
   # do stuff
  extra = MultiProject(f, x, ys...)
  return y, extra∘back
end

In fact even without a macro, a MultiProject which you compose like this would cut a lot of lines here. It's just this, right? (untested)

struct MultiProject{T}
    funs::T
    function MultiProject(xs...)
        funs = map(xs) do x
            if x isa Number || x isa AbstractArray
                ProjectTo
            else
                identity
            end
        end
        new{typeof(funs)}(funs)
    end
end
(m::MultiProject)(dxs::Tuple) = map((f,dx) -> f(dx), m.funs, dxs)

@mzgubic
Copy link
Member

mzgubic commented Jul 22, 2021

still failing:
y, pb = rrule(/, rand(3), rand(3)); @inferred pb(rand(3, 3)) (rrule inference)
y, pb = rrule(rc, sum, sum, [[2.0, 4.0], [4.0,1.9]]); pb(y) (can't deal with project receiving an InplaceableThunk, rc is the fake ruleconfig)
test_rrule(dot, Diagonal(rand(2)), rand(2, 2)) (unthunk inference)
test_rrule(adjoint, rand(2, 3); output_tangent=rand(3,2)) (rrule inference)
y, pb = rrule(Symmetric, rand(3, 3), :U); @inferred pb(ΔΩ)[2] (rrule inference)

it seems that solving

julia> @inferred ProjectTo(rand(3, 3))(Diagonal(rand(3)))
ERROR: return type Diagonal{Float64, Vector{Float64}} does not match inferred return type Union{Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}}

could solve some (maybe all?) of the inference issues

@oxinabox
Copy link
Member Author

For the inference problems, I am happy enough to open an issue and come back to them.
For

y, pb = rrule(rc, sum, sum, [[2.0, 4.0], [4.0,1.9]]); pb(y) (can't deal with project receiving an InplaceableThunk, rc is the fake ruleconfig)

Can we insert a unthunk here before the project ?
We do not always want to project InplacableThunks via unthunking but in this case we do,
The kinda function people realistically use with sum(f,...) are ones that we would basically never want to preserve thunking through.

@oxinabox oxinabox changed the title use ProjectTo functor in * use ProjectTo functor everywhere, and update to v1 of CRC and CRTU Jul 23, 2021
@thunk(project_y(reshape(x .* ΔΩ, axes(y)))),
)
return (NoTangent(), xthunk, ythunk)
x̄ = @thunk(project_x(reshape(y .* ΔΩ', axes(x))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this permanent or until we have something like JuliaDiff/ChainRulesCore.jl#393?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah,
more generally I hate the InplacableThunks that are just deferring broadcast materialization.
They are so boilerplatey.
That has its own issue JuliaDiff/ChainRulesCore.jl#274
and is probably going to be easier to solve.

I am more than happy to delete them.
Can always add them back

@mzgubic mzgubic merged commit e5c5338 into master Jul 26, 2021
@mzgubic mzgubic deleted the ox/project_info branch July 26, 2021 11:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants