@jondeuce
Contributor

I noticed that OptimiserChain was not marked @functor; this PR marks it as such to allow one to modify the rules internal to OptimiserChain via fmap.

I also modified the optimiser chain state to be a Tuple instead of an AbstractArray for better type inference in apply!(o::OptimiserChain, ...). If this is considered breaking/undesired I can remove it from the PR. Example of the improved inference:

using Optimisers, InteractiveUtils  # InteractiveUtils provides @code_warntype outside the REPL

x = zeros(Float32, 3)
dx = zero(x)
rule = AdamW()
state = Optimisers.setup(rule, x)
@code_warntype Optimisers.apply!(rule, state.state, x, dx)

Before:

MethodInstance for Optimisers.apply!(::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}, ::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, ::Vector{Float32}, ::Vector{Float32})
  from apply!(o::OptimiserChain, states, x, dx, dxs...) in Optimisers at /home/jdoucette/.julia/dev/Optimisers/src/rules.jl:623
Arguments
  #self#::Core.Const(Optimisers.apply!)
  o::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}
  states::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
  x::Vector{Float32}
  dx@_5::Vector{Float32}
  dxs::Tuple{}
Locals
  @_7::Union{Nothing, Tuple{Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}}
  new_states::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
  @_9::Int64
  @_10::Int64
  @_11::Int64
  state::Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}
  opt::Union{Adam{Float32}, WeightDecay{Float32}}
  i::Int64
  dx@_15::Any
Body::Tuple{Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Any}
1 ─       (dx@_15 = dx@_5)
│         (new_states = Optimisers.similar(states))
│   %3  = Base.getproperty(o, :opts)::Tuple{Adam{Float32}, WeightDecay{Float32}}
│   %4  = Optimisers.zip(%3, states)::Base.Iterators.Zip{Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}
│   %5  = Optimisers.enumerate(%4)::Base.Iterators.Enumerate{Base.Iterators.Zip{Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}
│         (@_7 = Base.iterate(%5))
│   %7  = (@_7::Union{Nothing, Tuple{Tuple{Int64, Tuple{Adam{Float32}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}} === nothing)::Bool
│   %8  = Base.not_int(%7)::Bool
└──       goto #4 if not %8
2 ┄ %10 = @_7::Tuple{Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}
│   %11 = Core.getfield(%10, 1)::Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}
│   %12 = Base.indexed_iterate(%11, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (i = Core.getfield(%12, 1))
│         (@_11 = Core.getfield(%12, 2))
│   %15 = Base.indexed_iterate(%11, 2, @_11::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Int64}, Any[Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Core.Const(3)])
│   %16 = Core.getfield(%15, 1)::Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
│   %17 = Base.indexed_iterate(%16, 1)::Core.PartialStruct(Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Int64}, Any[Union{Adam{Float32}, WeightDecay{Float32}}, Core.Const(2)])
│         (opt = Core.getfield(%17, 1))
│         (@_10 = Core.getfield(%17, 2))
│   %20 = Base.indexed_iterate(%16, 2, @_10::Core.Const(2))::Core.PartialStruct(Tuple{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Int64}, Any[Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Core.Const(3)])
│         (state = Core.getfield(%20, 1))
│   %22 = Core.getfield(%10, 2)::Tuple{Int64, Tuple{Int64, Int64}}
│   %23 = Core.tuple(opt, state, x, dx@_15)::Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Vector{Float32}, Any}
│   %24 = Core._apply_iterate(Base.iterate, Optimisers.apply!, %23, dxs)::Tuple{Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Any}
│   %25 = Base.indexed_iterate(%24, 1)::Core.PartialStruct(Tuple{Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Int64}, Any[Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Core.Const(2)])
│   %26 = Core.getfield(%25, 1)::Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}
│         (@_9 = Core.getfield(%25, 2))
│   %28 = Base.indexed_iterate(%24, 2, @_9::Core.Const(2))::Tuple{Any, Int64}
│         (dx@_15 = Core.getfield(%28, 1))
│         Base.setindex!(new_states, %26, i)
│         (@_7 = Base.iterate(%5, %22))
│   %32 = (@_7 === nothing)::Bool
│   %33 = Base.not_int(%32)::Bool
└──       goto #4 if not %33
3 ─       goto #2
4 ┄ %36 = Core.tuple(new_states, dx@_15)::Tuple{Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Any}
└──       return %36

After:

MethodInstance for Optimisers.apply!(::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}, ::Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, ::Vector{Float32}, ::Vector{Float32})
  from apply!(o::OptimiserChain, states, x, dx, dxs...) in Optimisers at /home/jdoucette/.julia/dev/Optimisers/src/rules.jl:625
Arguments
  #self#::Core.Const(Optimisers.apply!)
  o::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}
  states::Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}
  x::Vector{Float32}
  dx::Vector{Float32}
  dxs::Tuple{}
Locals
  #81::Optimisers.var"#81#82"{Vector{Float32}, Tuple{}}
Body::Tuple{Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}}
1 ─ %1  = Optimisers.:(var"#81#82")::Core.Const(Optimisers.var"#81#82")
│   %2  = Core.typeof(x)::Core.Const(Vector{Float32})
│   %3  = Core.typeof(dxs)::Core.Const(Tuple{})
│   %4  = Core.apply_type(%1, %2, %3)::Core.Const(Optimisers.var"#81#82"{Vector{Float32}, Tuple{}})
│         (#81 = %new(%4, x, dxs))
│   %6  = #81::Optimisers.var"#81#82"{Vector{Float32}, Tuple{}}
│   %7  = (:init,)::Core.Const((:init,))
│   %8  = Core.apply_type(Core.NamedTuple, %7)::Core.Const(NamedTuple{(:init,)})
│   %9  = ()::Core.Const(())
│   %10 = Core.tuple(%9, dx)::Tuple{Tuple{}, Vector{Float32}}
│   %11 = Core.tuple(%10)::Tuple{Tuple{Tuple{}, Vector{Float32}}}
│   %12 = (%8)(%11)::NamedTuple{(:init,), Tuple{Tuple{Tuple{}, Vector{Float32}}}}
│   %13 = Core.kwfunc(Optimisers.foldl)::Core.Const(Base.var"#foldl##kw"())
│   %14 = Base.getproperty(o, :opts)::Tuple{Adam{Float32}, WeightDecay{Float32}}
│   %15 = Base.broadcasted(Optimisers.tuple, %14, states)::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(tuple), Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}}}
│   %16 = Base.materialize(%15)::Tuple{Tuple{Adam{Float32}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Tuple{WeightDecay{Float32}, Nothing}}
│   %17 = (%13)(%12, Optimisers.foldl, %6, %16)::Core.PartialStruct(Tuple{Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}}, Any[Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), 
Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}, Any[Core.Const(+), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), 
Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Any[Core.Const(*), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), 
Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}, Any[Vector{Float32}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}, Any[Core.Const(-), Core.PartialStruct(Tuple{Int64, Float32}, Any[Core.Const(1), Float32]), Core.Const(nothing)])]), Core.Const(nothing)]), Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}, Any[Core.Const(+), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, 
typeof(-), Tuple{Int64, Float32}}}}}}, Float32}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Any[Core.Const(sqrt), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}, Any[Vector{Float32}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}, Any[Core.Const(-), Core.PartialStruct(Tuple{Int64, Float32}, Any[Core.Const(1), Float32]), Core.Const(nothing)])]), Core.Const(nothing)])]), Core.Const(nothing)]), Float32]), Core.Const(nothing)])]), Core.Const(nothing)]), Float32]), Tuple{Base.OneTo{Int64}}]), Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}]), Tuple{Base.OneTo{Int64}}])])
└──       return %17
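
The two listings differ mainly in how the per-rule states are stored and traversed. Below is a minimal sketch of that contrast in plain Julia, with no packages. Everything named here (`Momentum`, `Scale`, `init`, `apply`, `chain_vector`, `chain_tuple`) is a hypothetical stand-in for illustration, not the Optimisers.jl API. The first function follows the loop over a `Vector` visible in the "Before" listing, and the second follows the `foldl` over `tuple.(rules, states)` with `init = ((), dx)` visible in the "After" listing.

```julia
# Hypothetical stand-in rules; these are NOT the Optimisers.jl types.
struct Momentum          # stateful rule, loosely analogous to Adam
    rho::Float32
end
struct Scale             # stateless rule, loosely analogous to WeightDecay
    factor::Float32
end

init(::Momentum, x) = zero(x)
init(::Scale, x) = nothing

# Each rule returns (new_state, new_gradient).
function apply(r::Momentum, state, x, dx)
    v = r.rho .* state .+ dx
    return (v, v)
end
apply(r::Scale, state, x, dx) = (state, r.factor .* dx)

# Vector of states: the element type is a Union, so every state read
# inside the loop is only known to be "Nothing or some state".
function chain_vector(rules, states::AbstractVector, x, dx)
    new_states = similar(states)
    for (i, (rule, state)) in enumerate(zip(rules, states))
        new_states[i], dx = apply(rule, state, x, dx)
    end
    return new_states, dx
end

# Tuple of states: fold over (rule, state) pairs, growing a tuple whose
# type records the concrete type of each state by position.
function chain_tuple(rules::Tuple, states::Tuple, x, dx)
    foldl(tuple.(rules, states); init = ((), dx)) do (acc, g), (rule, state)
        new_state, new_g = apply(rule, state, x, g)
        return (acc..., new_state), new_g
    end
end

rules = (Momentum(0.9f0), Scale(0.5f0))
x = zeros(Float32, 3)
dx = ones(Float32, 3)

states_vector = [init(r, x) for r in rules]   # Vector with a Union element type
states_tuple = map(r -> init(r, x), rules)    # Tuple{Vector{Float32}, Nothing}

sv, gv = chain_vector(rules, states_vector, x, dx)
st, gt = chain_tuple(rules, states_tuple, x, dx)

# Both versions compute the same update; only the container type differs.
@assert gv == gt
@assert Tuple(sv) == st
```

Calling `@code_warntype chain_vector(rules, states_vector, x, dx)` and `@code_warntype chain_tuple(rules, states_tuple, x, dx)` is a quick way to compare the two on a given Julia version; the exact inference results can vary between releases, so treat the sketch as an illustration of the pattern rather than a benchmark.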

Member

@ToucheSir ToucheSir left a comment

Looks pretty reasonable to me. I recall we had some discussion about OptimiserChain using a vector for state instead of a tuple, but I don't remember the conclusion. @darsnack might know.

  julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
- (vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+ (vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
Member

@ToucheSir ToucheSir Oct 26, 2022

I wouldn't have expected fun = nothing to change to fun = () with this PR. Is this the output you see locally? If so, can you change this codeblock to a jldoctest so that we catch changes like this in the future?

Member

Changed by #106 I think.

I guess we just need to remember the syntax for naming a doctest block, so that this one uses m, st from the one before.

Contributor Author

@jondeuce jondeuce Oct 26, 2022

Yes, that change does not arise from this PR; sorry for the confusion. I changed nothing to () for consistency with the doctest above it. I can also look up the syntax to make that block into a proper doctest, @mcabbott.

Contributor Author

Fixed now.

Member

@mcabbott mcabbott left a comment

We should do this, I think.

@mcabbott mcabbott merged commit 0b2d32b into FluxML:master Nov 27, 2022
3 participants