@@ -2,6 +2,7 @@ module Optimisation
22
33using .. Turing
44using NamedArrays: NamedArrays
5+ using AbstractPPL: AbstractPPL
56using DynamicPPL: DynamicPPL
67using LogDensityProblems: LogDensityProblems
78using Optimization: Optimization
@@ -320,7 +321,7 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
320321 # m.values, but they are more convenient to filter when they are VarNames rather than
321322 # Symbols.
322323 vals_dict = Turing. Inference. getparams (log_density. model, log_density. varinfo)
323- iters = map (DynamicPPL . varname_and_value_leaves, keys (vals_dict), values (vals_dict))
324+ iters = map (AbstractPPL . varname_and_value_leaves, keys (vals_dict), values (vals_dict))
324325 vns_and_vals = mapreduce (collect, vcat, iters)
325326 varnames = collect (map (first, vns_and_vals))
326327 # For each symbol s in var_symbols, pick all the values from m.values for which the
@@ -351,7 +352,7 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
351352 varinfo_new = DynamicPPL. unflatten (log_density. ldf. varinfo, solution. u)
352353 # `getparams` performs invlinking if needed
353354 vals = Turing. Inference. getparams (log_density. ldf. model, varinfo_new)
354- iters = map (DynamicPPL . varname_and_value_leaves, keys (vals), values (vals))
355+ iters = map (AbstractPPL . varname_and_value_leaves, keys (vals), values (vals))
355356 vns_vals_iter = mapreduce (collect, vcat, iters)
356357 syms = map (Symbol ∘ first, vns_vals_iter)
357358 vals = map (last, vns_vals_iter)
0 commit comments