Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b1cdc2a
Bump minor version
penelopeysm Aug 8, 2025
2a1ebff
Implement InitContext
penelopeysm Jul 9, 2025
06d0beb
Fix loading order of modules; move `prefix(::Model)` to model.jl
penelopeysm Jul 9, 2025
1f7017a
Add tests for InitContext behaviour
penelopeysm Jul 9, 2025
02ae965
inline `rand(::Distributions.Uniform)`
penelopeysm Jul 9, 2025
55634f4
Document
penelopeysm Jul 9, 2025
d6ba16c
Add a test to check that `init!!` doesn't change linking
penelopeysm Jul 19, 2025
fd78d42
Fix `push!` for VarNamedVector
penelopeysm Jul 20, 2025
d40df7e
Replace `evaluate_and_sample!!` -> `init!!`
penelopeysm Jul 10, 2025
a0f308b
Use `ParamsInit` for `predict`; remove `setval_and_resample!` and fri…
penelopeysm Jul 10, 2025
4ae143c
Use `init!!` for initialisation
penelopeysm Jul 10, 2025
c7e33e7
Paper over the `Sampling->Init` context stack (pending removal of Sam…
penelopeysm Jul 10, 2025
4b3df70
Remove SamplingContext from JETExt to avoid triggering `Sampling->Ini…
penelopeysm Jul 10, 2025
ef92a4b
Remove `predict` on vector of VarInfo
penelopeysm Jul 26, 2025
ec2632b
Fix some tests
penelopeysm Jul 20, 2025
23cafe0
Remove duplicated test
penelopeysm Jul 20, 2025
707bc4e
Remove `SamplingContext` for good
penelopeysm Jul 10, 2025
13988b5
Remove `tilde_assume` as well
penelopeysm Jul 10, 2025
331279c
Split up tilde_observe!! for Distribution / Submodel
penelopeysm Jul 18, 2025
cf87ce7
Move `PrefixContext` to a model field
penelopeysm Aug 5, 2025
c670ef0
Re-add tests and doctests
penelopeysm Aug 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.0"
version = "0.38.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Accessors = "0.1"
Distributions = "0.25"
Documenter = "1"
DocumenterMermaid = "0.1, 0.2"
DynamicPPL = "0.37"
DynamicPPL = "0.38"
FillArrays = "0.13, 1"
ForwardDiff = "0.10, 1"
JET = "0.9, 0.10"
Expand Down
55 changes: 34 additions & 21 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa

A core component of DynamicPPL is the [`@model`](@ref) macro.
It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements.
These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities.
These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities.

```@docs
@model
Expand Down Expand Up @@ -344,6 +344,13 @@ Base.empty!
SimpleVarInfo
```

### Tilde-pipeline

```@docs
tilde_assume!!
tilde_observe!!
```

### Accumulators

The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
Expand Down Expand Up @@ -450,33 +457,45 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
InitContext
```

### Samplers
### VarInfo initialisation

The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
```

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:

```@docs
SampleFromPrior
SampleFromUniform
PriorInit
UniformInit
ParamsInit
```

Additionally, a generic sampler for inference is implemented.
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.

```@docs
DynamicPPL.AbstractInitStrategy
DynamicPPL.init
```

### Samplers

In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -487,7 +506,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
Expand All @@ -502,9 +521,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
tilde_assume
```
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =
Expand Down
15 changes: 5 additions & 10 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,17 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
# Let's make sure that evaluation doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
model, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug result
end

Expand All @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

Expand Down
38 changes: 27 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
Expand All @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -114,9 +123,15 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
Expand Down Expand Up @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
)
retval
end
end

Expand Down
19 changes: 12 additions & 7 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,21 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
assume,
tilde_assume,
# Tilde pipeline
tilde_assume!!,
tilde_observe!!,
# Initialisation
InitContext,
AbstractInitStrategy,
PriorInit,
UniformInit,
ParamsInit,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down Expand Up @@ -170,11 +173,13 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/init.jl")
include("model.jl")
include("prefix.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
Expand Down
47 changes: 25 additions & 22 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ evaluates to a `VarName`, and this will be used in the subsequent checks.
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr))
@gensym vn
return quote
if $(DynamicPPL.contextual_isassumption)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like
# the whole `isassumption` thing to be simplified, though, so I'll
# leave it till later.
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn)
# Considered an assumption by `__model__.context` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
Expand All @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
if !($(DynamicPPL.inargnames)($left_vn, __model__)) ||
$(DynamicPPL.inmissings)($left_vn, __model__)
true
else
$(maybe_view(expr)) === missing
Expand All @@ -99,7 +102,7 @@ isassumption(expr) = :(false)

Return `true` if `vn` is considered an assumption by `context`.
"""
function contextual_isassumption(context::AbstractContext, vn)
function contextual_isassumption(context::AbstractContext, vn::VarName)
if hasconditioned_nested(context, vn)
val = getconditioned_nested(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
Expand All @@ -115,9 +118,7 @@ end

isfixed(expr, vn) = false
function isfixed(::Union{Symbol,Expr}, vn)
return :($(DynamicPPL.contextual_isfixed)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
))
return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn))
end

"""
Expand Down Expand Up @@ -413,7 +414,9 @@ function generate_assign(left, right)
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
$vn = $(DynamicPPL.maybe_prefix)(
$(make_varname_expression(left)), __model__.prefix
)
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
Expand Down Expand Up @@ -448,24 +451,23 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value dist
@gensym left_vn vn isassumption value dist

return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$isassumption = $(DynamicPPL.isassumption(left, vn))
$left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
$isassumption = $(DynamicPPL.isassumption(left, left_vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
$left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# If `left_vn` is not in `argnames`, we need to make sure that the variable is defined.
# (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had
# prefixes applied!)
if !$(DynamicPPL.inargnames)($left_vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
Expand Down Expand Up @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn)
return quote
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
__model__.context,
__model__.prefix,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
Expand Down
Loading
Loading