Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- SDE3
version:
- '1'
- '1.11'
- 'lts'
steps:
- uses: actions/checkout@v4
Expand Down
9 changes: 8 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module SciMLSensitivity

# Enzyme is not compatible with Julia 1.12+
const ENZYME_ENABLED = VERSION < v"1.12"

using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff,
AutoReverseDiff, AutoTracker, AutoZygote
using Accessors: @reset
Expand Down Expand Up @@ -45,7 +48,9 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit,
# AD Backends
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent,
AbstractThunk, AbstractTangent
using Enzyme: Enzyme
@static if ENZYME_ENABLED
using Enzyme: Enzyme
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enzyme is loadable on 1.12 always, so this isn't required atm

end
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Tracker: Tracker, TrackedArray
Expand Down Expand Up @@ -86,6 +91,8 @@ include("sde_tools.jl")

export extract_local_sensitivities

export ENZYME_ENABLED

export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityFunction,
ODEAdjointProblem, AdjointSensitivityIntegrand,
SDEAdjointProblem, RODEAdjointProblem, SensitivityAlg,
Expand Down
178 changes: 90 additions & 88 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,105 +1285,107 @@ function DiffEqBase._concrete_solve_adjoint(
p)
end

function DiffEqBase._concrete_solve_adjoint(
prob::Union{SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem
},
alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
du0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
mode = sensealg.mode

# Force no FunctionWrappers for Enzyme
_prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) )

diff_func = (u0,
p) -> solve(_prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

splitmode = if mode isa Enzyme.ForwardMode
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode isa Enzyme.ReverseMode
Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
end
@static if ENZYME_ENABLED
function DiffEqBase._concrete_solve_adjoint(
prob::Union{SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem
},
alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
du0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
mode = sensealg.mode

forward,
reverse = Enzyme.autodiff_thunk(
splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
tape, result,
shadow_result = forward(
Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))

function enzyme_sensitivity_backpass(Δ)
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
for (x, y) in zip(shadow_result.u, Δ.u)
x .= y
end
else
error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
# Force no FunctionWrappers for Enzyme
_prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) )

diff_func = (u0,
p) -> solve(_prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

splitmode = if mode isa Enzyme.ForwardMode
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode isa Enzyme.ReverseMode
Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
end
reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)

forward,
reverse = Enzyme.autodiff_thunk(
splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
tape, result,
shadow_result = forward(
Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))

function enzyme_sensitivity_backpass(Δ)
if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
for (x, y) in zip(shadow_result.u, Δ.u)
x .= y
end
else
error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
end
reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
result, enzyme_sensitivity_backpass
end
result, enzyme_sensitivity_backpass
end

# NOTE: This is needed to prevent a method ambiguity error
function DiffEqBase._concrete_solve_adjoint(
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))

du0 = make_zero(u0)
dp = make_zero(p)
mode = sensealg.mode
# NOTE: This is needed to prevent a method ambiguity error
function DiffEqBase._concrete_solve_adjoint(
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
u0, p, originator::SciMLBase.ADOriginator,
args...; kwargs...)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))

f = (u0,
p) -> solve(prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)
du0 = make_zero(u0)
dp = make_zero(p)
mode = sensealg.mode

splitmode = if mode isa Forward
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode === Reverse
ReverseSplitWithPrimal
end
f = (u0,
p) -> solve(prob, alg, args...; u0 = u0, p = p,
sensealg = SensitivityADPassThrough(),
kwargs_filtered...)

forward,
reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated,
Duplicated{typeof(u0)}, Duplicated{typeof(p)})
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
splitmode = if mode isa Forward
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
elseif mode === nothing || mode === Reverse
ReverseSplitWithPrimal
end

function enzyme_sensitivity_backpass(Δ)
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
forward,
reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated,
Duplicated{typeof(u0)}, Duplicated{typeof(p)})
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))

function enzyme_sensitivity_backpass(Δ)
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
sol, enzyme_sensitivity_backpass
end
sol, enzyme_sensitivity_backpass
end

const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
Expand Down
19 changes: 15 additions & 4 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,21 @@ EnzymeAdjoint(mode = nothing)

Currently fails on almost every solver.
"""
struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
@static if ENZYME_ENABLED
struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
end
else
# Dummy type for Julia 1.12+ - Enzyme is not loaded on this version
struct EnzymeAdjoint{M <: Nothing} <:
AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
mode::M
function EnzymeAdjoint(mode = nothing)
error("EnzymeAdjoint is not supported on Julia 1.12+. Please use a different sensitivity algorithm.")
end
end
end

"""
Expand Down
96 changes: 52 additions & 44 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,18 @@ easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true)))
_,
easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
_,
easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if SciMLSensitivity.ENZYME_ENABLED
_,
easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
_,
easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -179,11 +181,13 @@ easy_res143 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true)))
_,
easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if SciMLSensitivity.ENZYME_ENABLED
_,
easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res145 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -212,11 +216,13 @@ easy_res143k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = ReverseDiffVJP(true)))
_,
easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
@static if SciMLSensitivity.ENZYME_ENABLED
_,
easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
end
_,
easy_res145k = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
Expand Down Expand Up @@ -1049,34 +1055,36 @@ function dynamics!(du, u, p, t)
du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2])
end

function backsolve_grad(sol, lqr_params, checkpointing)
bwd_sol = solve(
ODEAdjointProblem(sol,
BacksolveAdjoint(autojacvec = EnzymeVJP(),
checkpointing = checkpointing),
@static if SciMLSensitivity.ENZYME_ENABLED
function backsolve_grad(sol, lqr_params, checkpointing)
bwd_sol = solve(
ODEAdjointProblem(sol,
BacksolveAdjoint(autojacvec = EnzymeVJP(),
checkpointing = checkpointing),
Tsit5(),
nothing, nothing, nothing, nothing, nothing,
(x, lqr_params, t) -> cost(x, lqr_params)),
Tsit5(),
nothing, nothing, nothing, nothing, nothing,
(x, lqr_params, t) -> cost(x, lqr_params)),
Tsit5(),
dense = false,
save_everystep = false)

bwd_sol.u[end][1:(end - x_dim)]
#fwd_sol, bwd_sol
end

x0 = ones(x_dim)
fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params),
Tsit5(), abstol = 1e-9, reltol = 1e-9,
u0 = x0,
p = params,
dense = false,
save_everystep = false)
save_everystep = true)

bwd_sol.u[end][1:(end - x_dim)]
#fwd_sol, bwd_sol
end
backsolve_results = backsolve_grad(fwd_sol, params, false)
backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)

x0 = ones(x_dim)
fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params),
Tsit5(), abstol = 1e-9, reltol = 1e-9,
u0 = x0,
p = params,
dense = false,
save_everystep = true)

backsolve_results = backsolve_grad(fwd_sol, params, false)
backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)

@test backsolve_results != backsolve_checkpointing_results
@test backsolve_results != backsolve_checkpointing_results
end

int_u0,
int_p = adjoint_sensitivities(fwd_sol, Tsit5(),
Expand Down
Loading
Loading