
Supporting Linear Algebraic Primitives #10

@jrevels


(continued from invenia/Nabla.jl#81, cc @willtebbutt)

> This means that we would need a slightly more general interface for linear algebra, and would certainly require different forward- and reverse-mode expressions than DiffRules currently provides.

Agreed; DiffRules currently handles only scalar kernels properly. To support linear algebra, we'd need to add to DiffRules a notion of tensor vs. scalar arguments, support for in-place methods, a way to mark adjoint variables, etc.
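To illustrate the gap, here is a minimal Python sketch (all names are hypothetical; this is not DiffRules' API). A scalar rule can be stored as a single derivative expression per function, but even matrix multiplication C = A B needs separate forward-mode (tangent) and reverse-mode (adjoint) expressions, and the adjoints involve transposes rather than elementwise operations.

```python
import math

# Hypothetical scalar rule table: one derivative expression per primal
# function, in the spirit of a scalar-only rule set.
SCALAR_RULES = {
    "sin": math.cos,
    "exp": math.exp,
    "log": lambda x: 1.0 / x,
}


def matmul(A, B):
    """Plain nested-list matrix product."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(M):
    return [list(row) for row in zip(*M)]


def add(X, Y):
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def matmul_forward(A, B, dA, dB):
    """Forward-mode rule for C = A B: dC = dA B + A dB."""
    return add(matmul(dA, B), matmul(A, dB))


def matmul_reverse(A, B, Cbar):
    """Reverse-mode rule for C = A B: Abar = Cbar B^T, Bbar = A^T Cbar."""
    return matmul(Cbar, transpose(B)), matmul(transpose(A), Cbar)


def inner(X, Y):
    """Frobenius inner product <X, Y>."""
    return sum(x * y for rx, ry in zip(X, Y) for x, y in zip(rx, ry))
```

The two matrix rules are mutually consistent when <Cbar, dC> = <Abar, dA> + <Bbar, dB> for any tangents and cotangents, which is a convenient check for hand-written matrix rules.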

> Regarding the "where possible" statement above, there are certainly operations (such as the Cholesky factorisation) for which symbolic expressions might be a bit unwieldy (see here). Now that I think about this a bit more, I'm not sure whether these larger implementations are going to be problematic or not.
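To make the Cholesky point concrete, here is a minimal Python sketch (not DiffLinearAlgebra's implementation; all function names are hypothetical) of the standard forward-mode rule for A = L Lᵀ: dL = L Φ(L⁻¹ dA L⁻ᵀ), where Φ keeps the strict lower triangle and halves the diagonal. Even this forward rule needs triangular solves and a structural projection rather than an elementwise expression.

```python
import math


def cholesky(A):
    """Lower-triangular L with A = L L^T (A symmetric positive definite)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for j in range(n):
        L[j][j] = math.sqrt(A[j][j] - sum(L[j][k] ** 2 for k in range(j)))
        for i in range(j + 1, n):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = s / L[j][j]
    return L


def solve_lower(L, B):
    """Solve L X = B by forward substitution (L lower triangular)."""
    n, m = len(L), len(B[0])
    X = [[0.0] * m for _ in range(n)]
    for c in range(m):
        for i in range(n):
            s = B[i][c] - sum(L[i][k] * X[k][c] for k in range(i))
            X[i][c] = s / L[i][i]
    return X


def transpose(M):
    return [list(row) for row in zip(*M)]


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def phi(M):
    """Keep the strict lower triangle, halve the diagonal, zero the rest."""
    n = len(M)
    return [[M[i][j] if i > j else (0.5 * M[i][j] if i == j else 0.0)
             for j in range(n)] for i in range(n)]


def cholesky_tangent(L, dA):
    """Forward-mode rule for A = L L^T: dL = L phi(L^{-1} dA L^{-T})."""
    Y = solve_lower(L, dA)                       # Y = L^{-1} dA
    S = transpose(solve_lower(L, transpose(Y)))  # S = Y L^{-T}
    return matmul(L, phi(S))
```

Comparing the result against central finite differences of `cholesky` is a quick sanity check for a rule like this.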

> Anyway, my point is that, given symbolic expressions for the linear algebra operations, I agree that it's reasonable to hope that compiler optimisations can eliminate redundant code when compiling custom kernel implementations, and that this is a significantly better idea than hand-coding lots of optimisations. (The issue I linked in my previous comment is a good example of this; I would definitely prefer to be able to just forget about it.) However, I would contend that you simply can't handle linear algebra properly without a number of hand-coded symbolic expressions for the forward- and reverse-mode sensitivities, because the underlying routines (BLAS, LAPACK) aren't written in Julia. If at some point in the future we have a native Julia implementation of (for example) LAPACK, then it would be a really good idea to try to produce an AD tool that can generate reasonably well-optimised kernels for each operation. To the best of my knowledge, we shouldn't expect this to happen any time soon (and almost certainly never for BLAS), so a symbolic version of the current implementation of DiffLinearAlgebra will be necessary for Capstan to be able to differentiate arbitrary Julia code even reasonably efficiently.

I think there might've been a misunderstanding about my previous post 😛 I'm definitely not arguing that we should express, e.g., complex LAPACK kernels symbolically, and I didn't mean to imply that DiffRules and DiffLinearAlgebra were directly competing approaches. On the contrary, I think they're quite complementary: if DiffLinearAlgebra didn't exist, I would eventually need to make a "DiffKernels.jl" anyway. DiffRules maps primal functions to derivative functions, and is thus useful when generating, e.g., instruction primitives or computation graphs within downstream tools (i.e. it solves the problem "what kernels should I call, and how should I call them?"). DiffLinearAlgebra (as it stands) provides implementations of these kernels (i.e. it solves the problem "how do I execute the kernels that I'm calling?"). Both are necessary components of the AD ecosystem.
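The division of labour described above can be sketched as follows. This is a minimal Python illustration with hypothetical names, not the actual API of DiffRules or DiffLinearAlgebra: the rule layer only records which kernels compute each adjoint, and the kernel layer supplies the implementations (which, in practice, could be BLAS/LAPACK calls).

```python
# Layer 1 (rule layer, hypothetical): maps a primal operation to symbolic
# adjoint expressions, written as nested tuples (kernel_name, arg, ...) over
# named symbols. Answers "what kernels should I call, and how?"
REVERSE_RULES = {
    # C = matmul(A, B)  =>  Abar = Cbar B^T,  Bbar = A^T Cbar
    "matmul": {
        "A": ("matmul", "Cbar", ("transpose", "B")),
        "B": ("matmul", ("transpose", "A"), "Cbar"),
    },
}


# Layer 2 (kernel layer, hypothetical): concrete implementations. Answers
# "how do I execute the kernels I'm calling?"
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(M):
    return [list(row) for row in zip(*M)]


KERNELS = {"matmul": matmul, "transpose": transpose}


def evaluate(expr, env):
    """Evaluate a symbolic expression by dispatching to the kernel table."""
    if isinstance(expr, str):
        return env[expr]
    name, *args = expr
    return KERNELS[name](*(evaluate(arg, env) for arg in args))


def adjoints(op, env):
    """Look up the rule for `op`, then execute it with the kernel table."""
    return {var: evaluate(expr, env) for var, expr in REVERSE_RULES[op].items()}
```

Swapping the entries of `KERNELS` (e.g. for optimised or in-place versions) changes how the adjoints are executed without touching the rules themselves.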

As for deciding which computations should be primitives, I think we're already on the same page: a computation should be defined as a primitive if either or both of the following apply:

  1. it is difficult to express the computation as a composition of existing primitives
  2. a hand-optimized kernel for the computation is sufficiently more performant than the equivalent composition of existing primitives, even after taking into account potential compiler-level optimizations.
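Criterion 2 can be made concrete with a small, hypothetical Python sketch of scalar reverse-mode AD. Composing an n×n matrix product from scalar primitives records n^2(2n - 1) tape entries (45 for n = 3), each propagated individually on the reverse pass, whereas a matmul primitive would record one entry and apply Abar = Cbar B^T, Bbar = A^T Cbar as two kernel calls. Tape length is only a rough proxy for cost and says nothing about what compiler-level optimisations could recover.

```python
# Hypothetical minimal scalar reverse-mode AD, used only to count how many
# operations get recorded when matmul is built from scalar primitives.


class Var:
    def __init__(self, value, tape, parents=()):
        self.value = value
        self.grad = 0.0
        self.parents = parents  # tuple of (parent Var, local partial derivative)
        self.tape = tape
        tape.append(self)       # creation order is a valid topological order

    def __add__(self, other):
        return Var(self.value + other.value, self.tape,
                   ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, self.tape,
                   ((self, other.value), (other, self.value)))


def backward(out):
    """Propagate adjoints from `out` back through the whole tape."""
    out.grad = 1.0
    for node in reversed(out.tape):
        for parent, local in node.parents:
            parent.grad += local * node.grad


def scalar_matmul(A, B):
    """C = A B composed purely from scalar + and * primitives."""
    C = []
    for i in range(len(A)):
        row = []
        for j in range(len(B[0])):
            acc = A[i][0] * B[0][j]
            for k in range(1, len(B)):
                acc = acc + A[i][k] * B[k][j]
            row.append(acc)
        C.append(row)
    return C
```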
