
Commit f5ad214

penelopeysm, mhauru, and sunxd3 authored
Compatibility with DynamicPPL 0.38 + InitContext (#2676)
* Import `varname_leaves` etc from AbstractPPL instead
* [no ci] initial updates for InitContext
* [no ci] More fixes
* [no ci] Fix pMCMC
* [no ci] Fix Gibbs
* [no ci] More fixes, reexport InitFrom
* Fix a bunch of tests; I'll let CI tell me what's still broken...
* Remove comment
* Fix more tests
* More test fixes
* Fix more tests
* fix GeneralizedExtremeValue numerical test
* fix sample method
* fix ESS reproducibility
* Fix externalsampler test correctly
* Fix everything (I _think_)
* Add changelog
* Fix remaining tests (for real this time)
* Specify default chain type in Turing
* fix DPPL revision
* Fix changelog to mention unwrapped NT / Dict for initial_params
* Remove references to islinked, set_flag, unset_flag
* use `setleafcontext(::Model, ::AbstractContext)`
* Fix for upstream removal of default_chain_type
* Add clarifying comment for IS test
* Revert ESS test (and add some numerical accuracy checks)
* istrans -> is_transformed
* Remove `loadstate` and `resume_from`
* Remove a Sampler test
* Paper over one crack
* fix `resume_from`
* remove a `Sampler` test
* Update HISTORY.md (Co-authored-by: Markus Hauru <[email protected]>)
* Remove `Sampler`, remove `InferenceAlgorithm`, transfer `initialstep`, `init_strategy`, and other functions from DynamicPPL to Turing (#2689)
  * Remove `Sampler` and move its interface to Turing
  * Test fixes (this is admittedly quite tiring)
  * Fix a couple of Gibbs tests (no doubt there are more)
  * actually fix the Gibbs ones
  * actually fix it this time
  * fix typo
  * point to breaking
  * Improve loadstate implementation
  * Re-add tests that were removed from DynamicPPL
  * Fix qualifier in src/mcmc/external_sampler.jl (Co-authored-by: Xianda Sun <[email protected]>)
  * Remove the default argument for initial_params
  * [skip ci] Remove DynamicPPL sources
* Fix a word in changelog
* Improve changelog
* Add PNTDist to changelog

Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
1 parent 385f161 commit f5ad214


41 files changed (+1157, -850 lines)

HISTORY.md

Lines changed: 53 additions & 0 deletions
# 0.41.0

## DynamicPPL 0.38

Turing.jl v0.41 brings with it all the underlying changes in DynamicPPL 0.38.
Please see [the DynamicPPL changelog](https://github.com/TuringLang/DynamicPPL.jl/blob/main/HISTORY.md) for full details; in this section we only describe the changes that will directly affect end-users of Turing.jl.

### Performance

A number of functions, such as `returned` and `predict`, have substantially better performance in this release.

### `ProductNamedTupleDistribution`

`Distributions.ProductNamedTupleDistribution` can now be used on the right-hand side of `~` in Turing models.
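For illustration, a minimal sketch of what this enables (the model is hypothetical, not taken from this commit; it assumes the `product_distribution((a=..., b=...))` constructor from Distributions.jl, which returns a `ProductNamedTupleDistribution`):

```julia
using Turing, Distributions

@model function demo(y)
    # `x` is NamedTuple-variate: `x.a` follows a Normal, `x.b` a Gamma.
    x ~ product_distribution((a=Normal(0, 1), b=Gamma(2, 1)))
    y ~ Normal(x.a, sqrt(x.b))
end

chain = sample(demo(1.5), NUTS(), 100)
```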

### Initial parameters

**Initial parameters for MCMC sampling must now be specified in a different form.**
You still need to use the `initial_params` keyword argument to `sample`, but the allowed values are different.
For almost all samplers in Turing.jl (except `Emcee`) this should now be a `DynamicPPL.AbstractInitStrategy`.

There are three kinds of initialisation strategies provided out of the box with Turing.jl (they are exported, so you can use these directly with `using Turing`):

- `InitFromPrior()`: Sample from the prior distribution. This is the default for most samplers in Turing.jl (if you don't specify `initial_params`).
- `InitFromUniform(a, b)`: Sample uniformly from `[a, b]` in linked space. This is the default for Hamiltonian samplers. If `a` and `b` are not specified, it defaults to `[-2, 2]`, which preserves the behaviour of previous versions (and mimics that of Stan).
- `InitFromParams(p)`: Explicitly provide a set of initial parameters. **Note: `p` must be either a `NamedTuple` or an `AbstractDict{<:VarName}`; it can no longer be a `Vector`.** Parameters must be provided in unlinked space, even if the sampler later performs linking.
  - For this release of Turing.jl, you can also provide a `NamedTuple` or `AbstractDict{<:VarName}` directly and it will be automatically wrapped in `InitFromParams` for you. This is an intermediate measure for backwards compatibility, and will eventually be removed.

This change was made because `Vector`s are semantically ambiguous.
It is not clear which element of the vector corresponds to which variable in the model, nor is it clear whether the parameters are in linked or unlinked space.
Previously, both of these would depend on the internal structure of the VarInfo, which is an implementation detail.
In contrast, the behaviour of `AbstractDict`s and `NamedTuple`s is invariant to the ordering of variables, and it is also easier for readers to understand which variable is being set to which value.

If you were previously using `varinfo[:]` to extract a vector of initial parameters, you can now use `Dict(k => varinfo[k] for k in keys(varinfo))` to extract a `Dict` of initial parameters.

For more details about initialisation you can also refer to [the main TuringLang docs](https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters), and/or the [DynamicPPL API docs](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.InitFromPrior).
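A minimal sketch of the new interface (the model and the values here are illustrative):

```julia
using Turing

@model function gdemo(y)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1); lower=0)
    y ~ Normal(μ, σ)
end
model = gdemo(1.5)

# Sample initial parameters from the prior (the default for most samplers).
sample(model, NUTS(), 100; initial_params=InitFromPrior())

# Sample uniformly from [-2, 2] in linked space (the default for Hamiltonian samplers).
sample(model, NUTS(), 100; initial_params=InitFromUniform(-2, 2))

# Explicit initial values, given in unlinked space as a NamedTuple.
sample(model, NUTS(), 100; initial_params=InitFromParams((μ=0.0, σ=1.0)))
```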

### `resume_from` and `loadstate`

The `resume_from` keyword argument to `sample` has been removed.
Instead of `sample(...; resume_from=chain)` you can use `sample(...; initial_state=loadstate(chain))`, which is entirely equivalent.
`loadstate` is now exported from Turing instead of DynamicPPL.

Note that `loadstate` only works for `MCMCChains.Chains`.
FlexiChains users should consult the FlexiChains docs directly, where this functionality is described in detail.
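For example (a sketch, reusing `model` from the snippet above; `save_state=true` stores the final sampler state in the chain's metadata so that `loadstate` can retrieve it):

```julia
chain1 = sample(model, NUTS(), 100; save_state=true)
# Continue sampling from where chain1 left off.
chain2 = sample(model, NUTS(), 100; initial_state=loadstate(chain1))
```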

### `pointwise_logdensities`

`pointwise_logdensities(model, chn)`, `pointwise_loglikelihoods(...)`, and `pointwise_prior_logdensities(...)` now return an `MCMCChains.Chains` object if `chn` is itself an `MCMCChains.Chains` object.
The old behaviour of returning an `OrderedDict` is still available: you just need to pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chn, OrderedDict)`.
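In code (a sketch, reusing `model` from above; `OrderedDict` is re-exported by Turing):

```julia
chain = sample(model, NUTS(), 100)
pointwise_logdensities(model, chain)              # returns an MCMCChains.Chains
pointwise_logdensities(model, chain, OrderedDict) # old behaviour: an OrderedDict
```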

## Initial step in MCMC sampling

HMC and NUTS samplers no longer take an extra single step before starting the chain.
This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided).
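As a sketch (assuming `discard_initial=0` so that no draws are dropped, and reusing `model` from above), the first stored draw should now be exactly the supplied initial parameters:

```julia
chain = sample(model, NUTS(), 100;
               initial_params=InitFromParams((μ=0.0, σ=1.0)),
               discard_initial=0)
# chain[1] corresponds to (μ=0.0, σ=1.0).
```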

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"

 [extensions]
 TuringDynamicHMCExt = "DynamicHMC"
-TuringOptimExt = "Optim"
+TuringOptimExt = ["Optim", "AbstractPPL"]

 [compat]
 ADTypes = "1.9"

@@ -64,7 +64,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.37.2"
+DynamicPPL = "0.38"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.3"

docs/src/api.md

Lines changed: 10 additions & 0 deletions
@@ -75,6 +75,16 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` module
 | `RepeatSampler` | [`Turing.Inference.RepeatSampler`](@ref) | A sampler that runs multiple times on the same variable |
 | `externalsampler` | [`Turing.Inference.externalsampler`](@ref) | Wrap an external sampler for use in Turing |

+### Initialisation strategies
+
+Turing.jl provides several strategies to initialise parameters for models.
+
+| Exported symbol | Documentation | Description |
+|:----------------- |:--------------------------------------- |:--------------------------------------------------------------- |
+| `InitFromPrior` | [`DynamicPPL.InitFromPrior`](@extref) | Obtain initial parameters from the prior distribution |
+| `InitFromUniform` | [`DynamicPPL.InitFromUniform`](@extref) | Obtain initial parameters by sampling uniformly in linked space |
+| `InitFromParams` | [`DynamicPPL.InitFromParams`](@extref) | Manually specify (possibly a subset of) initial parameters |
+
 ### Variational inference

 See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.

ext/TuringDynamicHMCExt.jl

Lines changed: 6 additions & 10 deletions
@@ -44,26 +44,22 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
     stepsize::S
 end

-function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS})
-    return DynamicPPL.SampleFromUniform()
-end
-
-function DynamicPPL.initialstep(
+function Turing.Inference.initialstep(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:DynamicNUTS},
+    spl::DynamicNUTS,
     vi::DynamicPPL.AbstractVarInfo;
     kwargs...,
 )
     # Ensure that initial sample is in unconstrained space.
-    if !DynamicPPL.islinked(vi)
+    if !DynamicPPL.is_transformed(vi)
         vi = DynamicPPL.link!!(vi, model)
         vi = last(DynamicPPL.evaluate!!(model, vi))
     end

     # Define log-density function.
     ℓ = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
     )

     # Perform initial step.

@@ -84,14 +80,14 @@ end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:DynamicNUTS},
+    spl::DynamicNUTS,
     state::DynamicNUTSState;
     kwargs...,
 )
     # Compute next sample.
     vi = state.vi
     ℓ = state.logdensity
-    steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
+    steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
     Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

     # Create next sample and state.

ext/TuringOptimExt.jl

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 module TuringOptimExt

 using Turing: Turing
+using AbstractPPL: AbstractPPL
 import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
 using Optim: Optim

@@ -186,7 +187,7 @@ function _optimize(
         f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
     )
     vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
-    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
+    iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
     vns_vals_iter = mapreduce(collect, vcat, iters)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)

src/Turing.jl

Lines changed: 11 additions & 2 deletions
@@ -73,7 +73,10 @@ using DynamicPPL:
     conditioned,
     to_submodel,
     LogDensityFunction,
-    @addlogprob!
+    @addlogprob!,
+    InitFromPrior,
+    InitFromUniform,
+    InitFromParams
 using StatsBase: predict
 using OrderedCollections: OrderedDict

@@ -148,11 +151,17 @@ export
     fix,
     unfix,
     OrderedDict, # OrderedCollections
+    # Initialisation strategies for models
+    InitFromPrior,
+    InitFromUniform,
+    InitFromParams,
     # Point estimates - Turing.Optimisation
     # The MAP and MLE exports are only needed for the Optim.jl interface.
     maximum_a_posteriori,
     maximum_likelihood,
     MAP,
-    MLE
+    MLE,
+    # Chain save/resume
+    loadstate

 end

src/mcmc/Inference.jl

Lines changed: 20 additions & 34 deletions
@@ -13,7 +13,6 @@ using DynamicPPL:
     # or implement it for other VarInfo types and export it from DPPL.
     all_varnames_grouped_by_symbol,
     syms,
-    islinked,
     setindex!!,
     push!!,
     setlogp!!,

@@ -23,12 +22,7 @@ using DynamicPPL:
     getsym,
     getdist,
     Model,
-    Sampler,
-    SampleFromPrior,
-    SampleFromUniform,
-    DefaultContext,
-    set_flag!,
-    unset_flag!
+    DefaultContext
 using Distributions, Libtask, Bijectors
 using DistributionsAD: VectorOfMultivariate
 using LinearAlgebra

@@ -55,12 +49,9 @@ import Random
 import MCMCChains
 import StatsBase: predict

-export InferenceAlgorithm,
-    Hamiltonian,
+export Hamiltonian,
     StaticHamiltonian,
     AdaptiveHamiltonian,
-    SampleFromUniform,
-    SampleFromPrior,
     MH,
     ESS,
     Emcee,

@@ -78,13 +69,16 @@ export InferenceAlgorithm,
     RepeatSampler,
     Prior,
     predict,
-    externalsampler
+    externalsampler,
+    init_strategy,
+    loadstate

-###############################################
-# Abstract interface for inference algorithms #
-###############################################
+#########################################
+# Generic AbstractMCMC methods dispatch #
+#########################################

-include("algorithm.jl")
+const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
+include("abstractmcmc.jl")

 ####################
 # Sampler wrappers #

@@ -262,13 +256,13 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     dicts = map(ts) do t
         # In general getparams returns a dict of VarName => values. We need to also
         # split it up into constituent elements using
-        # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
+        # `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
         # won't understand it.
         vals = getparams(model, t)
         nms_and_vs = if isempty(vals)
             Tuple{VarName,Any}[]
         else
-            iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+            iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
             mapreduce(collect, vcat, iters)
         end
         nms = map(first, nms_and_vs)

@@ -315,11 +309,10 @@ end
 getlogevidence(transitions, sampler, state) = missing

 # Default MCMCChains.Chains constructor.
-# This is type piracy (at least for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Union{Transition,AbstractVarInfo}},
-    model::AbstractModel,
-    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
+    ts::Vector{<:Transition},
+    model::DynamicPPL.Model,
+    spl::AbstractSampler,
     state,
     chain_type::Type{MCMCChains.Chains};
     save_state=false,

@@ -378,11 +371,10 @@ function AbstractMCMC.bundle_samples(
     return sort_chain ? sort(chain) : chain
 end

-# This is type piracy (for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Union{Transition,AbstractVarInfo}},
-    model::AbstractModel,
-    spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
+    ts::Vector{<:Transition},
+    model::DynamicPPL.Model,
+    spl::AbstractSampler,
     state,
     chain_type::Type{Vector{NamedTuple}};
     kwargs...,

@@ -423,7 +415,7 @@ function group_varnames_by_symbol(vns)
     return d
 end

-function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
+function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
     nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
     return setinfo(c, merge(nt, c.info))
 end

@@ -442,18 +434,12 @@ include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")

-#################################################
-# Generic AbstractMCMC methods dispatch         #
-#################################################
-
-include("abstractmcmc.jl")
-
 ################
 # Typing tools #
 ################

 function DynamicPPL.get_matching_type(
-    spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
+    spl::Union{PG,SMC}, vi, ::Type{TV}
 ) where {T,N,TV<:Array{T,N}}
     return Array{T,N}
 end
