Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use ProjectTo functor everywhere, and update to v1 of CRC and CRTU #459

Merged
merged 44 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
22621f0
array multiplication
Jun 22, 2021
98a2bdf
change to preproject
Jun 28, 2021
ec54fcb
Merge branch 'master' into mz/preproject
Jun 28, 2021
8d5e101
use ProjectTo functor for *
oxinabox Jun 29, 2021
88c0302
implement project in arraymath.jl
Jun 30, 2021
fc1e9e7
type instability with varargs
Jul 1, 2021
9066d93
simplify maps so inference is happier
oxinabox Jul 1, 2021
94af4dd
base.jl and fastmathable.jl
Jul 1, 2021
71c1dfc
evalpoly.jl
Jul 1, 2021
1251d2d
dense array
Jul 1, 2021
119070c
add array
Jul 2, 2021
e71f1c9
add structured array
Jul 2, 2021
df91e07
symmetric
Jul 2, 2021
07a1a45
fix evalpoly
Jul 2, 2021
ff93ce5
do not test types in array.jl
Jul 2, 2021
b8e9863
mapreduce.jl
Jul 2, 2021
3bbcc22
fix tests
Jul 2, 2021
290cc32
fix tests
Jul 5, 2021
e51d979
fastmathable
Jul 5, 2021
6f4e6fd
remove unionall wrapper
Jul 5, 2021
07f7664
uncomment test
Jul 5, 2021
9699d83
fix Symmetric
Jul 5, 2021
e589378
fix StaticArrays
Jul 5, 2021
12fda11
remove test_types
Jul 6, 2021
085166f
bump version, add compat
Jul 6, 2021
33a6c4c
Merge branch 'master' into ox/project_info
mzgubic Jul 6, 2021
b940980
move StaticArrays to deps
Jul 6, 2021
286ad38
no need to define the StaticArray projectto anymore
Jul 21, 2021
ac01a61
move staticarrays back to extras
Jul 21, 2021
0e5095c
bump dependency to 1
Jul 21, 2021
53aa703
Merge branch 'master' into ox/project_info
Jul 21, 2021
b211f61
fix commas
Jul 21, 2021
d59b920
fix fastmathable tests
Jul 21, 2021
76a2ce2
change tests
Jul 22, 2021
9857958
dont check inference
Jul 22, 2021
ff16952
fix the remaining issues
Jul 22, 2021
494afee
fix
Jul 22, 2021
6d14366
drive by bugfix
Jul 23, 2021
51fccc1
use new CRTU with ProjectTo solving many ⊢ problems
oxinabox Jul 23, 2021
c6ef56b
add comment to existing skipped test
oxinabox Jul 23, 2021
a085a69
Stop trying to inplace into unknown array types
oxinabox Jul 23, 2021
20ce318
stop using isapprox
oxinabox Jul 23, 2021
8789e84
Set version number to 1.0.0-DEV
oxinabox Jul 23, 2021
132f074
bump CRTU
oxinabox Jul 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.8.23"
version = "1.0.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,8 +10,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.10.12"
ChainRulesTestUtils = "0.7.9"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Compat = "3.31"
FiniteDifferences = "0.12.8"
StaticArrays = "1.2"
Expand Down
23 changes: 15 additions & 8 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
sizes = map(size, Xs) # this avoids closing over Xs
project_Xs = map(ProjectTo, Xs)
function hcat_pullback(ȳ)
dY = unthunk(ȳ)
hi = Ref(0) # Ref avoids hi::Core.Box
dXs = map(sizes) do sizeX
dXs = map(project_Xs, sizes) do project, sizeX
ndimsX = length(sizeX)
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
lo = hi[] + 1
hi[] += get(sizeX, 2, 1)
Expand All @@ -95,14 +96,15 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
d > ndimsX ? 1 : (:)
end
end
if ndimsX > 0
dX = if ndimsX > 0
# Here InplaceableThunk breaks @inferred, removed for now
# InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
dY[ind...]
else
# This is a hack to perhaps avoid GPU scalar indexing
sum(view(dY, ind...))
end
return project(dX)
end
return (NoTangent(), dXs...)
end
Expand Down Expand Up @@ -141,10 +143,11 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
Y = vcat(Xs...)
ndimsY = Val(ndims(Y))
sizes = map(size, Xs)
project_Xs = map(ProjectTo, Xs)
function vcat_pullback(ȳ)
dY = unthunk(ȳ)
hi = Ref(0)
dXs = map(sizes) do sizeX
dXs = map(project_Xs, sizes) do project, sizeX
ndimsX = length(sizeX)
lo = hi[] + 1
hi[] += get(sizeX, 1, 1)
Expand All @@ -155,12 +158,13 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
d > ndimsX ? 1 : (:)
end
end
if ndimsX > 0
dX = if ndimsX > 0
# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
dY[ind...]
else
sum(view(dY, ind...))
end
return project(dX)
end
return (NoTangent(), dXs...)
end
Expand Down Expand Up @@ -195,10 +199,11 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
ndimsY = Val(ndims(Y))
sizes = map(size, Xs)
project_Xs = map(ProjectTo, Xs)
function cat_pullback(ȳ)
dY = unthunk(ȳ)
prev = fill(0, _val(ndimsY)) # note that Y always has 1-based indexing, even if X isa OffsetArray
dXs = map(sizes) do sizeX
dXs = map(project_Xs, sizes) do project, sizeX
ndimsX = length(sizeX)
index = ntuple(ndimsY) do d
if d in cdims
Expand All @@ -210,12 +215,13 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
for d in cdims
prev[d] += get(sizeX, d, 1)
end
if ndimsX > 0
dX = if ndimsX > 0
# InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
dY[index...]
else
sum(view(dY, index...))
end
return project(dX)
end
return (NoTangent(), dXs...)
end
Expand All @@ -231,9 +237,10 @@ function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
cols = size(Y,2)
ndimsY = Val(ndims(Y))
sizes = map(size, values)
project_Vs = map(ProjectTo, values)
function hvcat_pullback(dY)
prev = fill(0, 2)
dXs = map(sizes) do sizeX
dXs = map(project_Vs, sizes) do project, sizeX
ndimsX = length(sizeX)
index = ntuple(ndimsY) do d
if d in (1, 2)
Expand All @@ -247,7 +254,7 @@ function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
prev[2] = 0
prev[1] += get(sizeX, 1, 1)
end
dY[index...]
project(dY[index...])
end
return (NoTangent(), NoTangent(), dXs...)
end
Expand Down
81 changes: 51 additions & 30 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ function rrule(
A::AbstractVecOrMat{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
return (
NoTangent(),
InplaceableThunk(
X̄ -> mul!(X̄, Ȳ, B', true, true),
@thunk(Ȳ * B'),
),
InplaceableThunk(
X̄ -> mul!(X̄, A', Ȳ, true, true),
@thunk(A' * Ȳ),
)
dA = InplaceableThunk(
X̄ -> mul!(X̄, Ȳ, B', true, true),
@thunk(project_A(Ȳ * B')),
)
dB = InplaceableThunk(
X̄ -> mul!(X̄, A', Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
)
return NoTangent(), dA, dB
end
return A * B, times_pullback
end
Expand All @@ -46,18 +46,20 @@ function rrule(
A::AbstractVector{<:CommutativeMulNumber},
B::AbstractMatrix{<:CommutativeMulNumber},
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
@assert size(B, 1) === 1 # otherwise primal would have failed.
return (
NoTangent(),
InplaceableThunk(
X̄ -> mul!(X̄, Ȳ, vec(B'), true, true),
@thunk(Ȳ * vec(B')),
@thunk(project_A(Ȳ * vec(B'))),
),
InplaceableThunk(
X̄ -> mul!(X̄, A', Ȳ, true, true),
@thunk(A' * Ȳ),
@thunk(project_B(A' * Ȳ)),
)
)
end
Expand All @@ -67,14 +69,16 @@ end
function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
return (
NoTangent(),
@thunk(dot(Ȳ, B)'),
@thunk(project_A(dot(Ȳ, B)')),
InplaceableThunk(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(A' * Ȳ),
@thunk(project_B(A' * Ȳ)),
)
)
end
Expand All @@ -84,15 +88,17 @@ end
function rrule(
::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
Ȳ = unthunk(ȳ)
return (
NoTangent(),
InplaceableThunk(
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(A' * Ȳ),
@thunk(project_B(A' * Ȳ)),
),
@thunk(dot(Ȳ, B)'),
@thunk(project_A(dot(Ȳ, B)')),
)
end
return A * B, times_pullback
Expand All @@ -109,27 +115,31 @@ function rrule(
B::AbstractVecOrMat{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
project_z = ProjectTo(z)

# The useful case, mul! fused with +
function muladd_pullback_1(ȳ)
Ȳ = unthunk(ȳ)
matmul = (
InplaceableThunk(
dA -> mul!(dA, Ȳ, B', true, true),
@thunk(Ȳ * B'),
@thunk(project_A(Ȳ * B')),
),
InplaceableThunk(
dB -> mul!(dB, A', Ȳ, true, true),
@thunk(A' * Ȳ),
@thunk(project_B(A' * Ȳ)),
)
)
addon = if z isa Bool
NoTangent()
elseif z isa Number
@thunk(sum(Ȳ))
@thunk(project_z(sum(Ȳ)))
else
InplaceableThunk(
dz -> sum!(dz, Ȳ; init=false),
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
@thunk(project_z(sum!(similar(z, eltype(Ȳ)), Ȳ))),
)
end
(NoTangent(), matmul..., addon)
Expand All @@ -143,18 +153,22 @@ function rrule(
v::AbstractVector{<:CommutativeMulNumber},
z::CommutativeMulNumber,
)
project_ut = ProjectTo(ut)
project_v = ProjectTo(v)
project_z = ProjectTo(z)

# This case is dot(u,v)+z, but would also match signature above.
function muladd_pullback_2(ȳ)
dy = unthunk(ȳ)
ut_thunk = InplaceableThunk(
dut -> dut .+= v' .* dy,
@thunk(v' .* dy),
@thunk(project_ut(v' .* dy)),
)
v_thunk = InplaceableThunk(
dv -> dv .+= ut' .* dy,
@thunk(ut' .* dy),
@thunk(project_v(ut' .* dy)),
)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy))
end
return muladd(ut, v, z), muladd_pullback_2
end
Expand All @@ -165,21 +179,25 @@ function rrule(
vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
)
project_u = ProjectTo(u)
project_vt = ProjectTo(vt)
project_z = ProjectTo(z)

# Outer product, just broadcasting
function muladd_pullback_3(ȳ)
Ȳ = unthunk(ȳ)
proj = (
@thunk(vec(sum(Ȳ .* conj.(vt), dims=2))),
@thunk(vec(sum(u .* conj.(Ȳ), dims=1))'),
@thunk(project_u(vec(sum(Ȳ .* conj.(vt), dims=2)))),
@thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')),
)
addon = if z isa Bool
NoTangent()
elseif z isa Number
@thunk(sum(Ȳ))
@thunk(project_z(sum(Ȳ)))
else
InplaceableThunk(
dz -> sum!(dz, Ȳ; init=false),
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
@thunk(project_z(sum!(similar(z, eltype(Ȳ)), Ȳ))),
)
end
(NoTangent(), proj..., addon)
Expand All @@ -202,7 +220,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))
∂B = last(dB_pb(unthunk(dBᵀ)))

(NoTangent(), ∂A, ∂B)
end
Expand All @@ -214,6 +232,9 @@ end
#####

function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
project_A = ProjectTo(A)
project_B = ProjectTo(B)

Y = A \ B
function backslash_pullback(ȳ)
Ȳ = unthunk(ȳ)
Expand All @@ -222,9 +243,9 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
Ā = -B̄ * Y'
Ā = add!!(Ā, (B - A * Y) * B̄' / A')
Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A))
project_A(Ā)
end
∂B = @thunk A' \ Ȳ
∂B = @thunk project_B(A' \ Ȳ)
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback
Expand Down
5 changes: 4 additions & 1 deletion src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ function rrule(::Type{T}, x::Real) where {T<:Complex}
return (T(x), Complex_pullback)
end
function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
project_x = ProjectTo(x)
project_y = ProjectTo(y)

function Complex_pullback(Ω̄)
ΔΩ = unthunk(Ω̄)
return (NoTangent(), real(ΔΩ), imag(ΔΩ))
return (NoTangent(), project_x(real(ΔΩ)), project_y(imag(ΔΩ)))
end
return (T(x, y), Complex_pullback)
end
Expand Down
11 changes: 11 additions & 0 deletions src/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ if VERSION ≥ v"1.4"
end

function rrule(::typeof(evalpoly), x, p)
y, ys = _evalpoly_intermediates(x, p)
project_x = ProjectTo(x)
project_p = p isa Tuple ? identity : ProjectTo(p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NoTangent(), project_x(∂x), project_p(∂p)
end
return y, evalpoly_pullback
end

function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
Expand Down
4 changes: 3 additions & 1 deletion src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ let
end

function rrule(::typeof(*), x::Number, y::Number)
project_x = ProjectTo(x)
project_y = ProjectTo(y)
function times_pullback(Ω̇)
ΔΩ = unthunk(Ω̇)
return (NoTangent(), ΔΩ * y', x' * ΔΩ)
return (NoTangent(), project_x(ΔΩ * y'), project_y(x' * ΔΩ))
end
return x * y, times_pullback
end
Expand Down
Loading