@@ -2,8 +2,8 @@ module BracketingNonlinearSolveChainRulesCoreExt
22
33using BracketingNonlinearSolve: bracketingnonlinear_solve_up, CommonSolve, SciMLBase
44using CommonSolve: solve
5- using SciMLBase: IntervalNonlinearProblem
6- using ForwardDiff
5+ using SciMLBase: IntervalNonlinearProblem, unwrapped_f
6+ using ForwardDiff: derivative, gradient
77using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk
88
99function ChainRulesCore. rrule (
@@ -13,16 +13,16 @@ function ChainRulesCore.rrule(
1313)
1414 out = solve (prob, alg)
1515 u = out. u
16- f = SciMLBase . unwrapped_f (prob. f)
16+ f = unwrapped_f (prob. f)
1717 function ∇bracketingnonlinear_solve_up (Δ)
1818 Δ = Δ isa AbstractThunk ? unthunk (Δ) : Δ
1919 # Δ = dg/du
2020 Δ isa Tangent ? delu = Δ. u : delu = Δ
21- λ = only (ForwardDiff . derivative (u -> f (u, p), only (u)) \ delu)
21+ λ = only (derivative (u -> f (u, p), only (u)) \ delu)
2222 if p isa Number
23- dgdp = - λ * ForwardDiff . derivative (p -> f (u, p), p)
23+ dgdp = - λ * derivative (p -> f (u, p), p)
2424 else
25- dgdp = - λ * ForwardDiff . gradient (p -> f (u, p), p)
25+ dgdp = - λ * gradient (p -> f (u, p), p)
2626 end
2727 return (NoTangent (), NoTangent (), NoTangent (),
2828 dgdp, NoTangent (),
@@ -31,4 +31,4 @@ function ChainRulesCore.rrule(
3131 return out, ∇bracketingnonlinear_solve_up
3232end
3333
34- end
34+ end
0 commit comments