Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"

[extensions]
PreallocationToolsEnzymeCoreExt = "EnzymeCore"
PreallocationToolsForwardDiffExt = "ForwardDiff"
PreallocationToolsReverseDiffExt = "ReverseDiff"
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"
Expand All @@ -23,6 +25,8 @@ ADTypes = "1.16"
Adapt = "4.3.0"
Aqua = "0.8.11"
ArrayInterface = "7.19.0"
Enzyme = "0.13"
EnzymeCore = "0.8"
ForwardDiff = "0.10.38, 1.0.1"
LabelledArrays = "1.16.0"
LinearAlgebra = "1.10"
Expand All @@ -44,6 +48,8 @@ julia = "1.10"
[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -61,4 +67,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "ADTypes", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
test = ["Aqua", "ADTypes", "Enzyme", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
51 changes: 51 additions & 0 deletions ext/PreallocationToolsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module PreallocationToolsEnzymeCoreExt

using PreallocationTools
import EnzymeCore: EnzymeRules, Const, Duplicated

# TODO: Support Batched mode, on 1.11
# if VERSION >= v"1.11.0"
# function tuple_of_vectors(M::Matrix{T}, shape) where {T}
# n, m = size(M)
# return ntuple(m) do i
# vec = Base.wrap(Array, memoryref(M.ref, (i - 1) * n + 1), (n,))
# reshape(vec, shape)
# end
# end
# end

# TODO: Support reverse mode?

function EnzymeRules.forward(config, func::Const{typeof(PreallocationTools.get_tmp)}, ::Type{<:Duplicated},
dc::Duplicated{<:PreallocationTools.DiffCache}, u::Union{Const{T}, Duplicated{T}}) where {T}
du = PreallocationTools.get_tmp(dc.val, u.val)
ddu = PreallocationTools.get_tmp(dc.dval, u.val)
Duplicated(du, ddu)
end

function EnzymeRules.forward(config, func::Const{typeof(PreallocationTools.get_tmp)}, ::Type{<:Duplicated},
dc::Const{<:PreallocationTools.DiffCache}, u::Union{Const{T}, Duplicated{T}}) where {T}
dc = dc.val
du = PreallocationTools.get_tmp(dc, u.val)

# ddu = if isbitstype(T)
# nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
# if nelem > length(dc.dual_du)
# PreallocationTools.enlargediffcache!(dc, nelem)
# end
# PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
# else
# PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
# end

# Enzyme requires that Duplicated types have the same type and structure
# the above code fails since it creates something like a `Base.ReshapedArray{Float64, 2, SubArray{…}, Tuple{}})`

# TODO: How does this interact with Enzyme over ForwardDiff?
ddu = dc.dual_du
resize!(ddu, length(du))

Duplicated(du, reshape(ddu, size(du)))
end

end
42 changes: 42 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module TestEnzyme
using Enzyme
using PreallocationTools
using ForwardDiff

const randmat = rand(5, 3)


function claytonsample!(sto, τ, α; randmat = randmat)
sto = get_tmp(sto, τ)
sto .= randmat
τ == 0 && return sto

n = size(sto, 1)
for i in 1:n
v = sto[i, 2]
u = sto[i, 1]
sto[i, 1] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ) * α
sto[i, 2] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ)
end
return sto
end

sto = similar(randmat)
stod = DiffCache(sto)

d_sto_fwd = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3)
d_sto_enz = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Duplicated(0.3, 1.0), Const(0.0)) |> only

@test d_sto_enz ≈ d_sto_fwd

d_sto_enz2 = Enzyme.autodiff(Forward, claytonsample!, Duplicated(stod, Enzyme.make_zero(stod)), Duplicated(0.3, 1.0), Const(0.0)) |> only
@test d_sto_enz2 ≈ d_sto_fwd

d_sto_enz3 = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Const(0.3), Const(0.0)) |> only
@test all(d_sto_enz3 .== 0.0)

d_sto_enz4 = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Const(0.3), Duplicated(1.0, 1.0)) |> only
d_sto_fwd4 = reshape(ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0])[:, 2], size(sto))
@test d_sto_enz4 ≈ d_sto_fwd4
end # TestEnzyme

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "DiffCache Nested Duals" include("core_nesteddual.jl")
@safetestset "DiffCache Sparsity Support" include("sparsity_support.jl")
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
@safetestset "DiffCache with Enzyme" include("enzyme.jl")
@safetestset "LazyBufferCache" include("lbc.jl")
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
@safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")
Expand Down
Loading