Skip to content

Commit

Permalink
move branches into InplaceableThunk
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 29, 2021
1 parent 7b50309 commit cae1531
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions src/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,36 @@ end
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
y = LinearAlgebra.norm(x, p)
function norm_pullback_p(Δy)
∂x = if isempty(x) || p == 0
InplaceableThunk(
@thunk(zero.(x) .* (zero(y) * zero(real(Δy)))),
identity,
)
∂x = InplaceableThunk(
# out-of-place versions
if isempty(x) || p == 0
@thunk(zero.(x) .* (zero(y) * zero(real(Δy))))
elseif p == 2
InplaceableThunk(
@thunk(_norm2_back(x, y, Δy)),
dx -> _norm2_back!(dx, x, y, Δy),
)
@thunk(_norm2_back(x, y, Δy))
elseif p == 1
InplaceableThunk(
@thunk(_norm1_back(x, y, Δy)),
dx -> _norm1_back!(dx, x, y, Δy),
)
@thunk(_norm1_back(x, y, Δy))
elseif p == Inf
_normInf_back(x, y, Δy)
@thunk(_normInf_back(x, y, Δy))
elseif p == -Inf
_normInf_back(x, y, Δy)
@thunk(_normInf_back(x, y, Δy))
else
_normp_back_x(x, p, y, Δy)
@thunk(_normp_back_x(x, p, y, Δy))
end,
# in-place versions
if isempty(x) || p == 0
identity
elseif p == 2
dx -> _norm2_back!(dx, x, y, Δy)
elseif p == 1
dx -> _norm1_back!(dx, x, y, Δy)
elseif p == Inf
dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
elseif p == -Inf
dx -> dx .+= _normInf_back(x, y, Δy)
else
dx -> dx .+= _normp_back_x(x, p, y, Δy)
end
)
∂p = @thunk _normp_back_p(x, p, y, Δy)
return (NO_FIELDS, ∂x, ∂p)
end
Expand All @@ -51,12 +59,18 @@ end
function rrule(::typeof(norm), x::AbstractArray{<:Number})
y = LinearAlgebra.norm(x)
function norm_pullback_2(Δy)
∂x = if isempty(x)
zero.(x) .* (zero(y) * zero(real(Δy)))
else
InplaceableThunk(
@thunk(_norm2_back(x, y, Δy)),
dx -> _norm2_back!(dx, x, y, Δy),
∂x = InplaceableThunk(
if isempty(x)
@thunk(zero.(x) .* (zero(y) * zero(real(Δy))))
else
@thunk(_norm2_back(x, y, Δy))
end
,
if isempty(x)
identity
else
dx -> _norm2_back!(dx, x, y, Δy)
end
)
end
return (NO_FIELDS, ∂x)
Expand Down

0 comments on commit cae1531

Please sign in to comment.