Our current `rrule` for sparse matrix-vector products is very inefficient and causes out-of-memory errors with large sparse CPU or GPU arrays. `rrule(*, sparse(A), x)` is currently implemented like this:
```julia
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    # ...
    dA = @thunk(project_A(Ȳ * B'))
    # ...
end
```
So we first compute the dense `Ȳ * B'`, which can easily exceed available memory if `A` is very large but very sparse, and only then project it back to a sparse tangent.
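Here is a rough storage estimate that puts numbers on the memory problem. The sizes are assumed for illustration only: a square `10^6 × 10^6` Float64 `A` with `10^7` stored entries, not taken from any real workload. `dense_bytes` and `sparse_bytes` are hypothetical helpers that count only the value and index arrays and ignore array headers.

```julia
# Rough storage estimate (values plus indices only, ignoring array headers).
# Assumed sizes for illustration: a 10^6 × 10^6 Float64 A with 10^7 stored entries.

# Bytes for a dense m×n result such as Ȳ * B' with element type T.
dense_bytes(m, n, T = Float64) = m * n * sizeof(T)

# Bytes for a SparseMatrixCSC with n columns and k stored entries:
# k values, k row indices and n + 1 column pointers.
sparse_bytes(n, k, T = Float64, Ti = Int64) = k * (sizeof(T) + sizeof(Ti)) + (n + 1) * sizeof(Ti)

m = n = 10^6
k = 10^7
println("dense Ȳ * B': ", dense_bytes(m, n) / 1e12, " TB")
println("sparse A:     ", sparse_bytes(n, k) / 1e6, " MB")
```

With these sizes, the dense `Ȳ * B'` needs about 8 TB, while `A` itself needs about 168 MB.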
The best way to fix this, at least when `Ȳ` and `B` are vectors, might be to add a dedicated read-only "vector outer product" array type for vector * adjoint-vector products that computes `getindex` on the fly (this might be useful in general, too). Or do we already have such a type somewhere?
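A minimal sketch of such a read-only outer-product type follows. It assumes that `Ȳ` and `B` are vectors and that `A` is a `SparseMatrixCSC` from the `SparseArrays` stdlib. `OuterProduct` and `restrict_to_pattern` are hypothetical names used for illustration; they are not existing ChainRules or SparseArrays API. The sketch does not assume whether `ProjectTo` would read only the stored positions of a lazy matrix, so it writes the gather out explicitly. It also skips the element-type projection that `ProjectTo` performs.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical read-only lazy outer product (not an existing ChainRules type):
# P[i, j] == x[i] * conj(y[j]), the (i, j) entry of x * y', computed on the fly.
struct OuterProduct{T,X<:AbstractVector,Y<:AbstractVector} <: AbstractMatrix{T}
    x::X
    y::Y
end

function OuterProduct(x::AbstractVector, y::AbstractVector)
    T = promote_type(eltype(x), eltype(y))
    return OuterProduct{T,typeof(x),typeof(y)}(x, y)
end

Base.size(P::OuterProduct) = (length(P.x), length(P.y))
Base.getindex(P::OuterProduct, i::Int, j::Int) = P.x[i] * conj(P.y[j])

# Hypothetical gather onto the sparsity pattern of A: reads M only at the
# nnz(A) stored positions of A, so a lazy M is never materialized densely.
function restrict_to_pattern(A::SparseMatrixCSC, M::AbstractMatrix)
    size(A) == size(M) || throw(DimensionMismatch("A and M must have the same size"))
    rows = rowvals(A)
    vals = Vector{eltype(M)}(undef, nnz(A))
    for j in axes(A, 2), k in nzrange(A, j)
        vals[k] = M[rows[k], j]
    end
    return SparseMatrixCSC(size(A)..., copy(A.colptr), rows[1:nnz(A)], vals)
end

# Small usage example: a 3×3 sparse A with three stored entries.
A = sparse([1, 3, 2], [1, 1, 3], [1.0, 2.0, 3.0], 3, 3)
ȳ = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
P = OuterProduct(ȳ, b)          # stores only ȳ and b, no 3×3 array
dA = restrict_to_pattern(A, P)  # entries of ȳ * b' at A's stored positions
```

Because `P` stores only the two vectors, this path needs O(m + n) memory for the factors plus O(nnz(A)) for the result. The dense `m × n` intermediate is never allocated.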