I was hoping someone might be able to explain the current frule for getindex. Say I want to compute $\nabla_a \text{getindex}(a, i)$ with forward-mode AD, where $a$ is a 5-element vector. Intuitively, I would expect to get back a five-element vector whose only nonzero element is at index $i$. But when I try frule((NoTangent(), Ones(5), NoTangent()), getindex, ones(5), 1), I get a tangent of 1, not [1,0,0,0,0]. Why?
Specifically, I would have thought that the proper frule would be something like this:
function ChainRulesCore.frule((_, ẋ), ::typeof(getindex), x::AbstractArray, ix::Integer)
    return x[ix], sparsevec([ix], ẋ[ix], length(x))
end
The frule returns the result of the pushforward, not the gradient (nor the result of the pullback).
It is the result of perturbing the input by a small amount and measuring how much the output changes.
It is thus expected to have the same "type" as the output, which in this case is a scalar.
(This is in contrast to an rrule's pullback, whose result has the same "type" as the input.)
More details can be found in the docs https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html
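To make the distinction concrete, here is a minimal plain-Julia sketch that does not call ChainRules at all. It treats $f(a) = a_i$ as a function with Jacobian $J = e_i^\top$ and writes out the two propagators by hand. The names `pushforward`, `pullback`, `basis`, `a_dot`, and `y_bar` are illustrative choices for this sketch, not part of any library API.

```julia
# Hand-written propagators for f(a) = a[i]; an illustrative sketch, not library code.
# The Jacobian of f is the 1×n row vector e_i'.

i = 1
a = ones(5)

# Pushforward (what an frule computes): J * a_dot.
# The result is a scalar, the same "type" as the output of f.
pushforward(a_dot) = a_dot[i]

# Pullback (what an rrule computes): J' * y_bar.
# The result is a vector, the same "type" as the input of f.
function pullback(y_bar)
    a_bar = zeros(length(a))
    a_bar[i] = y_bar
    return a_bar
end

# Seeding the pushforward with an all-ones tangent gives a scalar.
t = pushforward(ones(5))

# Forward mode recovers the gradient one basis direction at a time.
basis(k) = [j == k ? 1.0 : 0.0 for j in 1:length(a)]
grad_forward = [pushforward(basis(k)) for k in 1:length(a)]

# Reverse mode recovers the same gradient with a single pullback call.
grad_reverse = pullback(1.0)
```

Under these assumptions, the scalar tangent from the question corresponds to `t`: an all-ones input perturbation moves the output by its `i`-th component only. The five-element vector the question expected is what `grad_forward` or `grad_reverse` produce, which takes five pushforward calls but only one pullback call.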