add sparse rrule #579
function rrule(::typeof(sparse), A::Union{AbstractVector, AbstractMatrix})
    function sparse_pullback(Ω̄)
        return NoTangent(), Ω̄
Is this OK, or do we need to project or something?
    return sparse(I, J, V, m, n, combine), sparse_pullback
end

function rrule(::typeof(sparse), A::Union{AbstractVector, AbstractMatrix})
Should these rules be attached to sparse, or one step later to things like Type{<:SparseMatrixCSC}? I don't have examples, but catching as many paths as possible, such as T(A) and convert(T, A), sounds desirable.
We may also need rules for SparseMatrixCSC, but we certainly need a rule for sparse, since it calls sparse! for the heavy lifting.
Yes, I meant for the 1-arg sparse. The one which takes all the vectors is a different story.
Ah sorry, I didn't notice what you were pointing at. I'm not very confident about defining rrules for constructors. Should it be rrule(::Type{<:SparseMatrixCSC}, ...) or rrule(::Type{<:SparseMatrixCSC{Tv,Ti}}, ...), or are the two equivalent?
I fiddle until it works, but my guess is that the first would capture all the more specific cases. Maybe not convert, though.
OK, I defined the rrules for the type and removed this one. I separately checked that Zygote's gradients of sparse(A) and sparse(v) go through.
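For readers following the thread, here is a minimal sketch of what a type-level rule of the kind discussed above might look like. It is not the code from this PR: the names SparseMatrixCSC_pullback and project_A are hypothetical, and the use of ProjectTo is only one possible answer to the earlier question about projecting. It assumes only ChainRulesCore is loaded, since defining this method next to a package that already provides the rule would overwrite it.

```julia
using ChainRulesCore
using SparseArrays

# Illustrative only: a reverse rule attached to the constructor type rather
# than to `sparse`, so that calls like `SparseMatrixCSC(A)` are covered.
function ChainRulesCore.rrule(::Type{T}, A::AbstractMatrix) where {T<:SparseMatrixCSC}
    # Assumption: project the cotangent back onto the tangent space of `A`.
    project_A = ProjectTo(A)
    function SparseMatrixCSC_pullback(Ω̄)
        # No tangent for the type argument; the conversion itself is linear,
        # so the cotangent passes through (after projection).
        return NoTangent(), project_A(Ω̄)
    end
    return T(A), SparseMatrixCSC_pullback
end

# Usage: call the rule directly on a small dense matrix.
A = [1.0 0.0; 0.0 2.0]
S, back = ChainRulesCore.rrule(SparseMatrixCSC, A)
∂self, ∂A = back(ones(2, 2))
```

Dispatching on Type{T} with T<:SparseMatrixCSC also matches fully parameterized calls such as SparseMatrixCSC{Float64,Int}(A), which is the sense in which the broader signature covers the more specific one.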
Looks good to me.
Partial replacement for #246
cc @sethaxen