Description
(continued from invenia/Nabla.jl#81, cc @willtebbutt)
> This means that we would need a slightly more general interface for linear algebra, and would certainly require different forward- and reverse-mode expressions than DiffRules currently provides.
Agreed; DiffRules currently handles only scalar kernels properly. To support linear algebra, we'd need to add to DiffRules a notion of tensor vs. scalar arguments, support for in-place methods, a way to mark adjoint variables, etc.
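To make the scalar-only limitation concrete, here is a minimal Python analogue of a DiffRules-style rule table. It is a sketch under stated assumptions: the names `SCALAR_RULES`, `diffrule`, and `forward_expr` are hypothetical, and rules are modelled as functions that build derivative expressions as strings from argument symbols (DiffRules itself returns Julia expressions). Every rule takes scalar symbols and returns one scalar partial-derivative expression per argument, so the representation has no room for shapes, in-place outputs, or an incoming adjoint seed.

```python
# Hypothetical analogue of a DiffRules-style table: each entry maps a primal
# function name to a function that, given argument *symbols*, returns one
# symbolic partial-derivative expression per argument.
SCALAR_RULES = {
    "sin": lambda x: [f"cos({x})"],
    "exp": lambda x: [f"exp({x})"],
    "+": lambda x, y: ["1", "1"],
    "*": lambda x, y: [y, x],
}


def diffrule(name, *args):
    """Return the partial-derivative expressions of `name` w.r.t. each argument."""
    if name not in SCALAR_RULES:
        raise KeyError(f"no scalar rule for {name!r}")
    return SCALAR_RULES[name](*args)


def forward_expr(name, args, tangents):
    """Build a forward-mode tangent expression by pairing each partial with
    the matching input tangent; this only works because everything is scalar."""
    partials = diffrule(name, *args)
    terms = [f"({p}) * {t}" for p, t in zip(partials, tangents)]
    return " + ".join(terms)


print(diffrule("sin", "x"))
print(forward_expr("*", ["a", "b"], ["da", "db"]))
```

Matrix multiplication has no entry of this form: its reverse-mode rule has to consume a matrix-valued adjoint seed and produce matrix-valued results, which is exactly the extra information a linear-algebra extension would need to carry.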
> Regarding the "where possible" statement above, there are certainly operations (such as the Cholesky factorisation) for which symbolic expressions might be a bit unwieldy (see here). Now that I think about this a bit more, I'm not sure whether these larger implementations are going to be problematic or not.
>
> Anyway, my point is that, given symbolic expressions for the linear algebra operations, I agree that it's reasonable to hope that compiler optimisations can eliminate redundant code when compiling custom kernel implementations, and that this is a significantly better idea than hand-coding lots of optimisations. (The issue I linked in my previous comment is a good example of this; I would definitely prefer to be able to just forget about it.) However, I would contend that you simply can't handle linear algebra properly without a number of hand-coded symbolic expressions for the forward- and reverse-mode sensitivities, because the underlying routines (BLAS and LAPACK) aren't written in Julia. If at some point in the future we have a native Julia implementation of (for example) LAPACK, then it would be a really good idea to try to produce an AD tool that can generate reasonably well-optimised kernels for each operation. To the best of my knowledge, we shouldn't expect this to happen any time soon (and almost certainly never for BLAS), so a symbolic version of the current implementation of DiffLinearAlgebra will be necessary for Capstan to be able to differentiate arbitrary Julia code even reasonably efficiently.
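As an illustration of the kind of hand-coded reverse-mode expression meant here, consider matrix multiplication: for dense floating-point matrices in Julia, `*` ends up in a BLAS `gemm` call rather than in Julia code an AD tool could trace through. For `C = A * B` with incoming adjoint `Cbar`, the standard sensitivities are `Abar = Cbar * transpose(B)` and `Bbar = transpose(A) * Cbar`. The Python sketch below uses plain lists of rows as matrices; all helper names (`matmul`, `matmul_pullback`, `loss`) are hypothetical and only stand in for the real kernels.

```python
# Minimal sketch: hand-coded reverse-mode sensitivities for C = A * B.
# Matrices are lists of rows; every helper name here is illustrative only.


def matmul(A, B):
    """Plain triple-loop matrix product (stands in for a BLAS gemm call)."""
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[i][p] * B[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


def transpose(A):
    """Return the transpose of a list-of-rows matrix."""
    return [list(row) for row in zip(*A)]


def matmul_pullback(A, B, Cbar):
    """Given the adjoint Cbar of C = A * B, return (Abar, Bbar) using
    Abar = Cbar * B^T and Bbar = A^T * Cbar."""
    return matmul(Cbar, transpose(B)), matmul(transpose(A), Cbar)


def loss(A, B, W):
    """Scalar test function sum(W .* (A * B)); its adjoint w.r.t. C is W."""
    C = matmul(A, B)
    return sum(W[i][j] * C[i][j] for i in range(len(C)) for j in range(len(C[0])))


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]        # 3x2
B = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]        # 2x3
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 1.0, 1.0]]  # 3x3 seed

Abar, Bbar = matmul_pullback(A, B, W)

# Finite-difference check of one entry of Abar.
eps = 1e-6
A_shift = [row[:] for row in A]
A_shift[0][1] += eps
fd = (loss(A_shift, B, W) - loss(A, B, W)) / eps
print(Abar[0][1], fd)
```

Because `loss` is linear in `A`, the finite difference agrees with `Abar[0][1]` up to rounding, which is a cheap sanity check for any hand-written rule of this kind.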
I think there might've been a misunderstanding with my previous post 😛 I'm definitely not arguing that we should express e.g. complex LAPACK kernels symbolically, and I didn't mean to imply that DiffRules and DiffLinearAlgebra were directly competing approaches. On the contrary, I think they're quite complementary: if DiffLinearAlgebra didn't exist, I would eventually need to make a "DiffKernels.jl" anyway. DiffRules is useful for mapping primal functions to derivative functions, and is thus useful when generating e.g. instruction primitives/computation graphs within downstream tools (i.e. it solves the problem "what kernels should I call, and how should I call them?"). DiffLinearAlgebra (as it stands) is useful for providing implementations of these kernels (i.e. it solves the problem "how do I execute the kernels that I'm calling?"). They're both necessary components of the AD ecosystem.
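The division of labour described in the previous paragraph can be sketched as two layers. In the hypothetical Python sketch below, the rule layer (standing in for what an extended DiffRules might provide) returns only a symbolic description of which kernels to call, and with which arguments, to get the reverse-mode sensitivities of `y = dot(x, z)`. The kernel layer (standing in for something like DiffLinearAlgebra) supplies the executable implementations. All names (`REVERSE_RULES`, `KERNELS`, `pullback`, `scale`) are assumptions for illustration, not either package's API.

```python
# Rule layer: maps a primal function to symbolic kernel calls, one per input.
# Each call is (kernel_name, argument_symbols); nothing is computed here.
REVERSE_RULES = {
    # y = dot(x, z)  =>  xbar = ybar * z,  zbar = ybar * x
    "dot": lambda x, z, ybar: [("scale", (ybar, z)), ("scale", (ybar, x))],
}

# Kernel layer: concrete implementations, looked up by name.
KERNELS = {
    "scale": lambda a, v: [a * vi for vi in v],
    "dot": lambda u, v: sum(ui * vi for ui, vi in zip(u, v)),
}


def pullback(primal, arg_symbols, seed_symbol, env):
    """Ask the rule layer which kernels to call, then run them on values in env."""
    calls = REVERSE_RULES[primal](*arg_symbols, seed_symbol)
    results = []
    for kernel_name, symbols in calls:
        values = [env[s] for s in symbols]
        results.append(KERNELS[kernel_name](*values))
    return results


env = {"x": [1.0, 2.0], "z": [3.0, 4.0], "ybar": 2.0}
print(REVERSE_RULES["dot"]("x", "z", "ybar"))  # the plan: which kernels, which args
print(pullback("dot", ("x", "z"), "ybar", env))  # the executed sensitivities
```

One possible benefit of keeping the rule layer symbolic is that a downstream tool can inspect, reorder, or fuse the planned calls before any kernel runs, while the kernel layer remains free to call BLAS or other hand-optimized code.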
As for deciding which computations should be primitives, I think we're already on the same page; a computation should be defined as a primitive if either or both of the following apply:
- it is difficult to express the computation as a composition of existing primitives
- a hand-optimized kernel for the computation is sufficiently more performant than the equivalent composition of existing primitives, even after taking into account potential compiler-level optimizations.