In Zygote.jl, we can take the gradient with respect to all fields of a struct `foo` passed through a function `bar` via `g = Zygote.gradient(f -> bar(f), foo)`. Can this be done in ForwardDiff as well?
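For context on what "the gradient with respect to all fields" means here: Zygote returns a tuple with one entry per argument, and for a struct argument that entry is a NamedTuple keyed by field name. A minimal sketch, assuming Zygote.jl is installed (`gfoo` is just an illustrative variable name):

```julia
using Zygote

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# bar has one argument, so gradient returns a 1-tuple.
# Its single element mirrors Foo's fields: (x = ..., t = ..., c = ...).
(gfoo,) = Zygote.gradient(bar, foo)   # same as gradient(f -> bar(f), foo)

# Analytically: ∂bar/∂x = 1, ∂bar/∂t = -c, ∂bar/∂c = -t
println(gfoo.x, " ", gfoo.t, " ", gfoo.c)
```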
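Reproducer:

```julia
using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

# Works: returns a tuple holding a NamedTuple of per-field gradients
g = Zygote.gradient(f -> bar(f), foo)
println(g)

# Fails: ForwardDiff.gradient expects an array input, not a struct
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)
```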
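One possible workaround, sketched here under the assumption that ForwardDiff.jl is installed: since `ForwardDiff.gradient` works on arrays, flatten the struct's fields into a vector, rebuild a `Foo` from that vector inside the closure, and map the resulting gradient back onto the field names. The helper `struct_gradient` is hypothetical (not part of ForwardDiff); this is a user-side workaround, not a built-in feature.

```julia
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper: gradient of a scalar function of a Foo,
# returned as a NamedTuple keyed by field name (like Zygote's result).
function struct_gradient(fun, s::Foo)
    fields = fieldnames(Foo)                          # (:x, :t, :c)
    v = Float64[getfield(s, n) for n in fields]       # flatten fields to a vector
    gv = ForwardDiff.gradient(w -> fun(Foo(w...)), v) # rebuild Foo from Dual numbers
    return NamedTuple{fields}(Tuple(gv))              # map back onto field names
end

foo = Foo(2, 3, 3e8)
g = struct_gradient(bar, foo)
println(g)
```

This relies on `Foo`'s fields being declared `::Number`: inside the closure, `Foo(w...)` receives ForwardDiff's `Dual` numbers, which are `Real` and therefore fit. If the fields were declared as a concrete type such as `Float64`, the rebuild would fail because a `Dual` cannot be converted to `Float64`.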