diff --git a/HISTORY.md b/HISTORY.md index d946a12d2..57ccaecd1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,106 @@ # DynamicPPL Changelog +## 0.38.0 + +### Breaking changes + +#### Introduction of `InitContext` + +DynamicPPL 0.38 introduces a new evaluation context, `InitContext`. +It is used to generate fresh values for random variables in a model. + +Evaluation contexts are stored inside a `DynamicPPL.Model` object, and control what happens with tilde-statements when a model is run. +The two major leaf (basic) contexts are `DefaultContext` and, now, `InitContext`. +`DefaultContext` is the default context, and it simply uses the values that are already stored in the `VarInfo` object passed to the model evaluation function. +On the other hand, `InitContext` ignores values in the VarInfo object and inserts new values obtained from a specified source. +(It follows also that the VarInfo being used may be empty, which means that `InitContext` is now also the way to obtain a fresh VarInfo for a model.) + +DynamicPPL 0.38 provides three flavours of _initialisation strategies_, which are specified as the second argument to `InitContext`: + + - `InitContext(rng, InitFromPrior())`: New values are sampled from the prior distribution (on the right-hand side of the tilde). + - `InitContext(rng, InitFromUniform(a, b))`: New values are sampled uniformly from the interval `[a, b]`, and then invlinked to the support of the distribution on the right-hand side of the tilde. + - `InitContext(rng, InitFromParams(p, fallback))`: New values are obtained by indexing into the `p` object, which can be a `NamedTuple` or `Dict{<:VarName}`. If a variable is not found in `p`, then the `fallback` strategy is used, which is simply another of these strategies. In particular, `InitFromParams` enables the case where different variables are to be initialised from different sources. + +(It is possible to define your own initialisation strategy; users who wish to do so are referred to the DynamicPPL API documentation and source code.) + +**The main impact on the upcoming Turing.jl release** is that, instead of providing initial values for sampling, the user will be expected to provide an initialisation strategy instead. +This is a more flexible approach, and not only solves a number of pre-existing issues with initialisation of Turing models, but also improves the clarity of user code. +In particular: + + - When providing a set of fixed parameters (i.e. `InitFromParams(p)`), `p` must now either be a NamedTuple or a Dict. Previously Vectors were allowed, which is error-prone because the ordering of variables in a VarInfo is not obvious. + - The parameters in `p` must now always be provided in unlinked space (i.e., in the space of the distribution on the right-hand side of the tilde). Previously, whether a parameter was expected to be in linked or unlinked space depended on whether the VarInfo was linked or not, which was confusing. + +#### Removal of `SamplingContext` + +For developers working on DynamicPPL, `InitContext` now completely replaces what used to be `SamplingContext`, `SampleFromPrior`, and `SampleFromUniform`. +Evaluating a model with `SamplingContext(SampleFromPrior())` (e.g. with `DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())` has a direct one-to-one replacement in `DynamicPPL.init!!(model, VarInfo(), InitFromPrior())`. +Please see the docstring of `init!!` for more details. +Likewise `SampleFromUniform()` can be replaced with `InitFromUniform()`. 
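For example, code that previously obtained fresh values by wrapping an evaluation in `SamplingContext` can be migrated roughly as follows (a minimal sketch; the `demo` model is purely illustrative):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end
model = demo()

# DynamicPPL 0.37:
#   _, vi = DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())
# DynamicPPL 0.38: pass an initialisation strategy to `init!!` instead.
# `init!!` returns the model's return value and the updated VarInfo.
_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromPrior())
```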
+`InitFromParams()` provides new functionality which was previously implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`. + +The main impact of these changes will be on those who are implementing samplers or inference algorithms. +The exact way in which this happens will be detailed in the Turing.jl changelog when a new release is made. +Broadly speaking, though, `SamplingContext(MySampler())` will be removed, so if your sampler needs custom behaviour in the tilde-pipeline you will likely have to define your own context; a sketch of such a context is given below, after the `resume_from` note. + +#### Removal of `DynamicPPL.Sampler` + +`DynamicPPL.Sampler` and **all associated interface functions** have also been removed entirely. +If you were using these, the corresponding replacements are: + + - `DynamicPPL.Sampler(S)`: just don't wrap `S`, but make sure `S` subtypes `AbstractMCMC.AbstractSampler` + - `DynamicPPL.initialstep`: directly implement `AbstractMCMC.step` and `AbstractMCMC.step_warmup` as per the AbstractMCMC interface + - `DynamicPPL.loadstate`: `Turing.loadstate` (will be introduced in the next version) + - `DynamicPPL.default_chain_type`: removed, just use the `chain_type` keyword argument directly + - `DynamicPPL.initialsampler`: `Turing.Inference.init_strategy` (will be introduced in the next version; note that this function must return an `AbstractInitStrategy`, see above for explanation) + - `DynamicPPL.default_varinfo`: `Turing.Inference.default_varinfo` (will be introduced in the next version) + - `DynamicPPL.TestUtils.test_sampler` and related methods: removed, please implement your own testing utilities as needed + +#### Simplification of the tilde-pipeline + +There are now only two functions in the tilde-pipeline that need to be overloaded to change the behaviour of tilde-statements, namely `tilde_assume!!` and `tilde_observe!!`. +Other functions such as `tilde_assume` and `assume` (and their `observe` counterparts) have been removed. + +Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other). +The separation of these functions was primarily implemented to avoid performing extra work where it was not needed (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details). + +#### Removal of the `"del"` flag + +Previously, `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`) had a flag called `"del"` for all variables. If it was set to `true`, the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. + +The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged` have also been removed in favour of more specific ones. We've also used this opportunity to rename the `"trans"` flag and the corresponding `istrans` function to be more explicit. The new, exported interface consists of the `is_transformed` and `set_transformed!!` functions. + +#### Removal of `resume_from` + +The `resume_from=chn` keyword argument to `sample` has been removed; please use the `initial_state` keyword argument instead. +`loadstate` will be exported from Turing in its next release.
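To illustrate the earlier point about custom contexts and the simplified tilde-pipeline: a hypothetical parent context that logs every assumed variable before delegating to its child context might look roughly like the sketch below. `LoggingContext` is not part of DynamicPPL, and a real sampler context will typically need to do more than this.

```julia
using DynamicPPL, Distributions

# Hypothetical context: log each `assume` statement, then delegate to the
# wrapped child context.
struct LoggingContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
    context::C
end
DynamicPPL.NodeTrait(::LoggingContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(ctx::LoggingContext) = ctx.context
DynamicPPL.setchildcontext(::LoggingContext, child) = LoggingContext(child)

function DynamicPPL.tilde_assume!!(ctx::LoggingContext, right, vn, vi)
    @info "Assuming variable" vn right
    # Delegate to the child context; `tilde_observe!!` could be forwarded similarly.
    return DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(ctx), right, vn, vi)
end

# Usage sketch: wrap a model's existing context.
@model demo() = x ~ Normal()
model = demo()
logged_model = contextualize(model, LoggingContext(model.context))
```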
+ +#### Change of output type for `pointwise_logdensities` + +The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default, instead of dictionaries of matrices. + +If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`. + +### Other changes + +#### `predict(model, chain; include_all)` + +The `include_all` keyword argument for `predict` now works even when no RNG is specified (previously it would only work when an RNG was explicitly passed). + +#### `DynamicPPL.setleafcontext(model, context)` + +This convenience method has been added to quickly modify the leaf context of a model. + +#### Reimplementation of functions using `InitContext` + +A number of functions have been reimplemented and unified with the help of `InitContext`. +In particular, this release brings substantial performance improvements for `returned` and `predict`. +Their APIs are the same. + +#### Upstreaming of VarName functionality + +The implementation of the `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical, and they are still accessible from the DynamicPPL module (though still not exported). + ## 0.37.5 A minor optimisation for Enzyme AD on DynamicPPL models. diff --git a/Project.toml b/Project.toml index e636199a1..2fe65fd7b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.5" +version = "0.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -48,7 +48,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 97afe3532..22fb89267 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -23,11 +23,11 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.37" +DynamicPPL = "0.38" Enzyme = "0.13" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" PrettyTables = "3" ReverseDiff = "1.15.3" -StableRNGs = "1" +StableRNGs = "1" \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index 47563d9a7..ccd701c6e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -19,7 +19,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.37" +DynamicPPL = "0.38" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" diff --git a/docs/src/api.md b/docs/src/api.md index 999bbe822..80970c0bb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. -These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. 
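For instance, a toy model with one assumed variable and one observation (illustrative only):

```julia
using DynamicPPL, Distributions

# `y` is supplied as an argument, so it is treated as an observation;
# `m` has no value supplied, so it is an assumed (random) variable.
@model function gdemo(y)
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end
```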
```@docs @model @@ -243,14 +243,7 @@ DynamicPPL.TestUtils.AD.ADIncorrectException ## Demo models -DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. - -```@docs -DynamicPPL.TestUtils.test_sampler -DynamicPPL.TestUtils.test_sampler_on_demo_models -DynamicPPL.TestUtils.test_sampler_continuous -DynamicPPL.TestUtils.marginal_mean_of_samples -``` +DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule. ```@docs DynamicPPL.TestUtils.DEMO_MODELS @@ -345,9 +338,8 @@ The [Transformations section below](#Transformations) describes the methods used In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. ```@docs -set_flag! -unset_flag! -is_flagged +is_transformed +set_transformed!! ``` ```@docs @@ -360,6 +352,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -432,8 +431,6 @@ DynamicPPL.StaticTransformation ``` ```@docs -DynamicPPL.istrans -DynamicPPL.settrans!! DynamicPPL.transformation DynamicPPL.link DynamicPPL.invlink @@ -451,8 +448,6 @@ DynamicPPL.maybe_invlink_before_eval!! Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten -DynamicPPL.varname_leaves -DynamicPPL.varname_and_value_leaves ``` ### Evaluation Contexts @@ -465,51 +460,44 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext PrefixContext ConditionContext +InitContext ``` -### Samplers +### VarInfo initialisation -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. ```@docs -SampleFromPrior -SampleFromUniform +DynamicPPL.init!! ``` -Additionally, a generic sampler for inference is implemented. +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: ```@docs -Sampler +InitFromPrior +InitFromUniform +InitFromParams ``` -The default implementation of [`Sampler`](@ref) uses the following unexported functions. +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. 
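For example, a strategy that deterministically initialises every variable at the mean of its prior might look roughly like the sketch below. Note that `InitFromMean` is not part of DynamicPPL, and the four-argument `init(rng, vn, dist, strategy)` signature shown here is an assumption; consult the docstrings below for the exact interface.

```julia
using DynamicPPL, Distributions, Random

# Hypothetical strategy: initialise each variable at the mean of the
# distribution on the right-hand side of its tilde-statement.
struct InitFromMean <: DynamicPPL.AbstractInitStrategy end

# Assumed signature `init(rng, vn, dist, strategy)`; check the docstring below.
function DynamicPPL.init(
    ::Random.AbstractRNG, ::DynamicPPL.VarName, dist::Distribution, ::InitFromMean
)
    return mean(dist)
end

# Usage sketch:
# _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromMean())
```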
```@docs -DynamicPPL.initialstep -DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` -Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. - -```@docs -DynamicPPL.default_varinfo -``` +### Choosing a suitable VarInfo There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: @@ -517,9 +505,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index d592e76b3..35159636f 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,10 +8,9 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - -# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme +# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. -@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing +@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = + nothing end diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..e0163bb35 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -6,7 +6,6 @@ using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) - # Let's make sure that both evaluation and sampling doesn't result in type errors. f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. @@ -21,32 +20,36 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + # Generate a typed varinfo to test model type stability with + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. - issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + # Check type stability of evaluation (i.e. 
DefaultContext) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) + eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl ) + if !eval_issuccess + @debug "Evaluation with typed varinfo failed with the following issues:" + @debug eval_result + end - if !issuccess - # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" - @debug result + # Check type stability of initialisation (i.e. InitContext) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) + init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl + ) + if !init_issuccess + @debug "Initialisation with typed varinfo failed with the following issues:" + @debug init_result end - # If we didn't fail anywhere, we return the type stable one. - return if issuccess + # If neither of them failed, we can return the typed varinfo as it's type stable. + return if (eval_issuccess && init_issuccess) varinfo else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..003372449 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,24 +1,7 @@ module DynamicPPLMCMCChainsExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using MCMCChains: MCMCChains -else - using ..DynamicPPL: DynamicPPL - using ..MCMCChains: MCMCChains -end - -# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata -function DynamicPPL.loadstate(chain::MCMCChains.Chains) - if !haskey(chain.info, :samplerstate) - throw( - ArgumentError( - "The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.", - ), - ) - end - return chain.info[:samplerstate] -end +using DynamicPPL: DynamicPPL, AbstractPPL +using MCMCChains: MCMCChains _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names @@ -28,7 +11,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +25,17 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict( + c::MCMCChains.Chains{Tval}, sample_idx, chain_idx +) where {Tval} + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Tval}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -110,18 +104,36 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - varinfo = DynamicPPL.VarInfo(model) + + # Set up a VarInfo with the right accumulators + varinfo = DynamicPPL.setaccs!!( + DynamicPPL.VarInfo(), + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogJacobianAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), + ), + ) + _, varinfo = DynamicPPL.init!!(model, varinfo) + 
varinfo = DynamicPPL.typed_varinfo(varinfo) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - - vals = DynamicPPL.values_as_in_model(model, false, varinfo) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values varname_vals = mapreduce( collect, vcat, - map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), + map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), ) return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) @@ -144,6 +156,13 @@ function DynamicPPL.predict( end return chain_result[parameter_names] end +function DynamicPPL.predict( + model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false +) + return DynamicPPL.predict( + DynamicPPL.Random.default_rng(), model, chain; include_all=include_all + ) +end function _predictive_samples_to_arrays(predictive_samples) variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() @@ -248,13 +267,302 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. + retval, _ = DynamicPPL.init!!( + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + retval + end +end + +""" + DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains + ::Val{whichlogprob}=Val(:both), + ) + +Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where +the log-density of each variable at each sample is stored (rather than its value). + +`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. + +You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName, +Matrix{Float64}}` instead. + +See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), +[`DynamicPPL.pointwise_prior_logdensities`](@ref). 
+ +# Examples + +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end +demo (generic function with 2 methods) + +julia> # Example observations. + model = demo([1.0, 2.0, 3.0], [4.0]); + +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m]; + info=(varname_to_symbol=Dict( + @varname(s) => :s, + @varname(m) => :m, + ),), + ); + +julia> plds = pointwise_logdensities(model, chain) +Chains MCMC chain (3×6×1 Array{Float64, 3}): + +Iterations = 1:1:3 +Number of chains = 1 +Samples per chain = 3 +parameters = s, m, xs[1], xs[2], xs[3], y +[...] + +julia> plds[:s] +2-dimensional AxisArray{Float64,2,...} with axes: + :iter, 1:1:3 + :chain, 1:1 +And data, a 3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 + +julia> # The above is the same as: + logpdf.(InverseGamma(2, 3), chain[:s]) +3×1 Matrix{Float64}: + -0.8027754226637804 + -1.3822169643436162 + -2.0986122886681096 +``` + +julia> # Alternatively: + plds_dict = pointwise_logdensities(model, chain, OrderedDict) +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] +""" +function DynamicPPL.pointwise_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains, + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob,Tout} + vi = DynamicPPL.VarInfo(model) + acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() + accname = DynamicPPL.accumulator_name(acc) + vi = DynamicPPL.setaccs!!(vi, (acc,)) + parameter_only_chain = MCMCChains.get_sections(chain, :parameters) + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + pointwise_logps = map(iters) do (sample_idx, chain_idx) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Re-evaluate the model + _, vi = DynamicPPL.init!!( + model, + vi, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + DynamicPPL.getacc(vi, Val(accname)).logps + end + + # pointwise_logps is a matrix of OrderedDicts + all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() + for d in pointwise_logps + union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) + end + # this is a 3D array: (iterations, variables, chains) + new_data = [ + get(pointwise_logps[iter, chain], k, missing) for + iter in 1:size(pointwise_logps, 1), k in all_keys, + chain in 1:size(pointwise_logps, 2) + ] + + if Tout == MCMCChains.Chains + return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) + elseif Tout <: AbstractDict + return Tout{DynamicPPL.VarName,Matrix{Float64}}( + k => new_data[:, i, :] for (i, k) in enumerate(all_keys) + ) + end +end + +""" + DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Type{Tout}=MCMCChains.Chains + ) + +Compute the pointwise log-likelihoods of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the likelihood terms. 
+ +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). +""" +function DynamicPPL.pointwise_loglikelihoods( + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood)) +end + +""" + DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, + chain::MCMCChains.Chains + ) + +Compute the pointwise log-prior-densities of the model given the chain. This is the same as +`pointwise_logdensities(model, chain)`, but only including the prior terms. + +See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). +""" +function DynamicPPL.pointwise_prior_logdensities( + model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains +) where {Tout} + return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior)) +end + +""" + logjoint(model::Model, chain::MCMCChains.Chains) + +Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. + chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> logjoint(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -5.440428709758045 + -5.440428709758045 + -5.440428709758045 +``` +""" +function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.logjoint(model, argvals_dict) + end +end + +""" + loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. 
+ chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> loglikelihood(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -2.1447298858494 + -2.1447298858494 + -2.1447298858494 +``` +""" +function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.loglikelihood(model, argvals_dict) + end +end + +""" + logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + +Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # Construct a chain of samples using MCMCChains. + # This sets s = 0.5 and m = 1.0 for all three samples. + chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); + +julia> logprior(demo_model([1., 2.]), chain) +3×1 Matrix{Float64}: + -3.2956988239086447 + -3.2956988239086447 + -3.2956988239086447 +``` +""" +function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) + var_info = DynamicPPL.VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( + vn_parent => DynamicPPL.values_from_chain( + var_info, vn_parent, chain, chain_idx, iteration_idx + ) for vn_parent in keys(var_info) + ) + DynamicPPL.logprior(model, argvals_dict) end end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index f6b352fab..23a3430eb 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,9 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, istrans +using DynamicPPL: DynamicPPL, is_transformed using Mooncake: Mooncake # This is purely an optimisation. 
-Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bdc953a12..f5bd33d6d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -70,10 +70,8 @@ export AbstractVarInfo, acclogjac!!, acclogprior!!, accloglikelihood!!, - is_flagged, - set_flag!, - unset_flag!, - istrans, + is_transformed, + set_transformed!!, link, link!!, invlink, @@ -94,20 +92,22 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, - # Samplers - Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, + # Initialisation + InitContext, + AbstractInitStrategy, + InitFromPrior, + InitFromUniform, + InitFromParams, # Pseudo distributions NamedDist, NoDist, @@ -170,11 +170,15 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/default.jl") +include("contexts/init.jl") +include("contexts/transformation.jl") +include("contexts/prefix.jl") +include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") -include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") @@ -183,10 +187,8 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") -include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ac841baab..ec5e1ea10 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -769,7 +769,7 @@ end # Transformations """ - istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) + is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. @@ -780,27 +780,27 @@ If `vns` is provided, then only check if this/these varname(s) are transformed. Not all implementations of `AbstractVarInfo` support transforming only a subset of the variables. """ -istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) -function istrans(vi::AbstractVarInfo, vns::AbstractVector) - # This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`. +is_transformed(vi::AbstractVarInfo) = is_transformed(vi, collect(keys(vi))) +function is_transformed(vi::AbstractVarInfo, vns::AbstractVector) + # This used to be: `!isempty(vns) && all(Base.Fix1(is_transformed, vi), vns)`. # In theory that should work perfectly fine. For unbeknownst reasons, # Julia 1.10 fails to infer its return type correctly. Thus we use this # slightly longer definition. 
isempty(vns) && return false for vn in vns - istrans(vi, vn) || return false + is_transformed(vi, vn) || return false end return true end """ - settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) + set_transformed!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) -Return `vi` with `istrans(vi, vn)` evaluating to `true`. +Return `vi` with `is_transformed(vi, vn)` evaluating to `true`. -If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. +If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all variables. """ -function settrans!! end +function set_transformed!! end # For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback # method for the case when no `vns` is provided, that would get all the keys from the @@ -827,6 +827,26 @@ end function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + model = setleafcontext(model, DynamicTransformationContext{false}()) + vi = last(evaluate!!(model, vi)) + return set_transformed!!(vi, t) +end +function link!!( + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model +) + b = inverse(t.bijector) + x = vi[:] + y, logjac = with_logabsdet_jacobian(b, x) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return set_transformed!!(vi, t) +end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -846,6 +866,9 @@ end function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) +end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -866,23 +889,13 @@ end function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end - -# Vector-based ones. -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - - # Set parameters and add the logjac term. 
- vi = unflatten(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return settrans!!(vi, t) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + model = setleafcontext(model, DynamicTransformationContext{true}()) + vi = last(evaluate!!(model, vi)) + return set_transformed!!(vi, NoTransformation()) end - function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) @@ -897,7 +910,7 @@ function invlink!!( if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, inv_logjac) end - return settrans!!(vi, NoTransformation()) + return set_transformed!!(vi, NoTransformation()) end """ @@ -919,6 +932,9 @@ end function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) +end """ maybe_invlink_before_eval!!([t::Transformation,] vi, model) @@ -1002,7 +1018,7 @@ function unflatten end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) -Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. +Return reconstructed `val`, possibly linked if `is_transformed(vi, vn)` is `true`. """ function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) f = to_maybe_linked_internal_transform(vi, vn, dist) @@ -1012,7 +1028,7 @@ end """ from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) -Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. +Return reconstructed `val`, possibly invlinked if `is_transformed(vi, vn)` is `true`. """ function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) f = from_maybe_linked_internal_transform(vi, vn, dist) @@ -1069,14 +1085,14 @@ in `varinfo` to a representation compatible with `dist`. If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) - return if istrans(varinfo, vn) + return if is_transformed(varinfo, vn) from_linked_internal_transform(varinfo, vn, dist) else from_internal_transform(varinfo, vn, dist) end end function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) - return if istrans(varinfo, vn) + return if is_transformed(varinfo, vn) from_linked_internal_transform(varinfo, vn) else from_internal_transform(varinfo, vn) diff --git a/src/context_implementations.jl b/src/context_implementations.jl deleted file mode 100644 index 786d7c913..000000000 --- a/src/context_implementations.jl +++ /dev/null @@ -1,171 +0,0 @@ -# assume -""" - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` -""" -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) 
-end -function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) -end - -function tilde_assume(context::PrefixContext, right, vn, vi) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi -) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) -end - -""" - tilde_assume!!(context, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value and updated `vi`. - -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end -end - -# observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe!!(context.context, right, left, vi)`. -""" -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - -function tilde_observe!!(context::AbstractContext, right, left, vn, vi) - return tilde_observe!!(childcontext(context), right, left, vn, vi) -end - -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. - new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end - -""" - tilde_observe!!(context, right, left, vn, vi) - -Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name -and indices; if needed, these can be accessed through this function, though. 
-""" -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) - vi = accumulate_observe!!(vi, right, left, vn) - return left, vi -end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, -) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi -end diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..32a236e8e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,6 +1,3 @@ -# Fallback traits -# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? - """ NodeTrait(context) NodeTrait(f, context) @@ -47,7 +44,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -61,16 +58,17 @@ DynamicTransformationContext{true}() setchildcontext """ - leafcontext(context) + leafcontext(context::AbstractContext) Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. """ -leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context) = context -leafcontext(::IsParent, context) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = + leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context::AbstractContext) = context +leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) """ - setleafcontext(left, right) + setleafcontext(left::AbstractContext, right::AbstractContext) Return `left` but now with its leaf context replaced by `right`. 
@@ -106,675 +104,78 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left, right) +function setleafcontext(left::AbstractContext, right::AbstractContext) return setleafcontext( NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) end -function setleafcontext(::IsParent, ::IsParent, left, right) +function setleafcontext( + ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext +) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left, right) +function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -setleafcontext(::IsLeaf, ::IsParent, left, right) = right -setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right +setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right -# Contexts """ - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo ) -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. 
-""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - -""" - struct DefaultContext <: AbstractContext end - -The `DefaultContext` is used by default to accumulate values like the log joint probability -when running the model. -""" -struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() - -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. - -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - -""" - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. 
-""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn - ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end +`vn` is the VarName on the left-hand side of the tilde statement. +This function should return a tuple `(x, vi)`, where `x` is the sampled value (which +must be in unlinked space!) and `vi` is the updated VarInfo. """ - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - -""" - - ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} - -Model context that contains values that are to be conditioned on. The values -can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or -an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, -@varname(b) => 2)`). The former is more performant, but the latter must be used -when there are varnames that cannot be represented as symbols, e.g. -`@varname(x[1])`. -""" -struct ConditionContext{ - Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext - values::Values - context::Ctx -end - -const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} -const DictConditionContext = ConditionContext{<:AbstractDict} - -# Use DefaultContext as the default base context -function ConditionContext(values::Union{NamedTuple,AbstractDict}) - return ConditionContext(values, DefaultContext()) -end -# Optimisation when there are no values to condition on -ConditionContext(::NamedTuple{()}, context::AbstractContext) = context -# Same as above, and avoids method ambiguity with below -ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context -# Collapse consecutive levels of `ConditionContext`. Note that this overrides -# values inside the child context, thus giving precedence to the outermost -# `ConditionContext`. 
-function ConditionContext(values::NamedTuple, context::NamedConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end -function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::ConditionContext) - return print(io, "ConditionContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) - -""" - hasconditioned(context::AbstractContext, vn::VarName) - -Return `true` if `vn` is found in `context`. -""" -hasconditioned(context::AbstractContext, vn::VarName) = false -hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) -function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end - -""" - getconditioned(context::AbstractContext, vn::VarName) - -Return value of `vn` in `context`. -""" -function getconditioned(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -function getconditioned(context::ConditionContext, vn::VarName) - return getvalue(context.values, vn) -end - -""" - hasconditioned_nested(context, vn) - -Return `true` if `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) - return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) -end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end - -""" - getconditioned_nested(context, vn) - -Return the value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end -function getconditioned_nested(::IsParent, context, vn) - return if hasconditioned(context, vn) - getconditioned(context, vn) - else - getconditioned_nested(childcontext(context), vn) - end -end - -""" - decondition(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer conditioned on. - -Note that this recursively traverses contexts, deconditioning all along the way. - -See also: [`condition`](@ref) -""" -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) 
- return setchildcontext(context, decondition_context(childcontext(context), args...)) -end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end -function decondition_context(context::ConditionContext) - return decondition_context(childcontext(context)) -end -function decondition_context(context::ConditionContext, sym, syms...) - new_values = deepcopy(context.values) - for s in (sym, syms...) - new_values = BangBang.delete!!(new_values, s) - end - return if length(new_values) == 0 - # No more values left, can unwrap - decondition_context(childcontext(context), syms...) - else - ConditionContext( - new_values, decondition_context(childcontext(context), sym, syms...) - ) - end -end -function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} - return ConditionContext( - BangBang.delete!!(context.values, sym), - decondition_context(childcontext(context), vn), - ) -end - -""" - conditioned(context::AbstractContext) - -Return `NamedTuple` of values that are conditioned on under context`. - -Note that this will recursively traverse the context stack and return -a merged version of the condition values. -""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) -function conditioned(context::ConditionContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `conditioned` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, conditioned(childcontext(context))) -end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end - -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext - values::Values - context::Ctx -end - -const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} -const DictFixedContext = FixedContext{<:AbstractDict} - -FixedContext(values) = FixedContext(values, DefaultContext()) - -# Try to avoid nested `FixedContext`. -function FixedContext(values::NamedTuple, context::NamedFixedContext) - # Note that this potentially overrides values from `context`, thus giving - # precedence to the outmost `FixedContext`. - return FixedContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::FixedContext) - return print(io, "FixedContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::FixedContext) = IsParent() -childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) - -""" - hasfixed(context::AbstractContext, vn::VarName) - -Return `true` if a fixed value for `vn` is found in `context`. -""" -hasfixed(context::AbstractContext, vn::VarName) = false -hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) -function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end - -""" - getfixed(context::AbstractContext, vn::VarName) - -Return the fixed value of `vn` in `context`. 
-""" -function getfixed(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) - -""" - hasfixed_nested(context, vn) - -Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) - return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) -end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end - -""" - getfixed_nested(context, vn) - -Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end -function getfixed_nested(::IsParent, context, vn) - return if hasfixed(context, vn) - getfixed(context, vn) - else - getfixed_nested(childcontext(context), vn) - end -end - -""" - fix([context::AbstractContext,] values::NamedTuple) - fix([context::AbstractContext]; values...) - -Return `FixedContext` with `values` and `context` if `values` is non-empty, -otherwise return `context` which is [`DefaultContext`](@ref) by default. - -See also: [`unfix`](@ref) -""" -fix(; values...) = fix(NamedTuple(values)) -fix(values::NamedTuple) = fix(DefaultContext(), values) -function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix((value, values...)) -end -function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) - return fix(DefaultContext(), values) -end -fix(context::AbstractContext, values::NamedTuple{()}) = context -function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) - return FixedContext(values, context) -end -function fix(context::AbstractContext; values...) - return fix(context, NamedTuple(values)) -end -function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix(context, (value, values...)) -end -function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) - return fix(context, Dict(values)) +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + return tilde_assume!!(childcontext(context), right, vn, vi) end """ - unfix(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer fixed. - -Note that this recursively traverses contexts, unfixing all along the way. - -See also: [`fix`](@ref) -""" -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) - return setchildcontext(context, unfix(childcontext(context), args...)) -end -function unfix(context, args...) 
- return unfix(NodeTrait(context), context, args...) -end -function unfix(context::FixedContext) - return unfix(childcontext(context)) -end -function unfix(context::FixedContext, sym) - return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, sym, syms...) - return unfix( - fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), - syms..., + DynamicPPL.tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo ) -end - -function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, vn::VarName) - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) -end -""" - fixed(context::AbstractContext) +This function handles observed variables, which may be: -Return the values that are fixed under `context`. +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. -Note that this will recursively traverse the context stack and return -a merged version of the fix values. -""" -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) -function fixed(context::FixedContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `fixed` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, fixed(childcontext(context))) -end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) - -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. 
- c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); - -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end +Observations of submodels are not yet supported in DynamicPPL. +This function should return a tuple `(left, vi)`, where `left` is the same as the input, and +`vi` is the updated VarInfo. """ - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) - -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. - -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext - -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) - -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` -""" -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, ) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - 
context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl new file mode 100644 index 000000000..d3802de85 --- /dev/null +++ b/src/contexts/conditionfix.jl @@ -0,0 +1,467 @@ +""" + + ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} + +Model context that contains values that are to be conditioned on. The values +can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or +an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, +@varname(b) => 2)`). The former is more performant, but the latter must be used +when there are varnames that cannot be represented as symbols, e.g. +`@varname(x[1])`. +""" +struct ConditionContext{ + Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext +} <: AbstractContext + values::Values + context::Ctx +end + +const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} +const DictConditionContext = ConditionContext{<:AbstractDict} + +# Use DefaultContext as the default base context +function ConditionContext(values::Union{NamedTuple,AbstractDict}) + return ConditionContext(values, DefaultContext()) +end +# Optimisation when there are no values to condition on +ConditionContext(::NamedTuple{()}, context::AbstractContext) = context +# Same as above, and avoids method ambiguity with below +ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context +# Collapse consecutive levels of `ConditionContext`. Note that this overrides +# values inside the child context, thus giving precedence to the outermost +# `ConditionContext`. +function ConditionContext(values::NamedTuple, context::NamedConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end +function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) + +""" + hasconditioned(context::AbstractContext, vn::VarName) + +Return `true` if `vn` is found in `context`. +""" +hasconditioned(context::AbstractContext, vn::VarName) = false +hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) +function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getconditioned(context::AbstractContext, vn::VarName) + +Return value of `vn` in `context`. +""" +function getconditioned(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end + +""" + hasconditioned_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. 
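For orientation, here is a minimal sketch of how the two value containers described above might be constructed and queried with the unexported helpers defined in this file; the variable names and values are purely illustrative:

```julia
using DynamicPPL: ConditionContext, hasconditioned, getconditioned, @varname

# NamedTuple-backed: more performant, but keys must be plain symbols.
ctx_nt = ConditionContext((a = 1.0, b = 2.0))

# Dict-backed: needed for varnames that are not plain symbols, e.g. `x[1]`.
ctx_dict = ConditionContext(Dict(@varname(x[1]) => 1.0))

hasconditioned(ctx_nt, @varname(a))       # true
getconditioned(ctx_dict, @varname(x[1]))  # 1.0
```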
+"""
+function hasconditioned_nested(context::AbstractContext, vn)
+    return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn)
+end
+hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
+function hasconditioned_nested(::IsParent, context, vn)
+    return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
+end
+function hasconditioned_nested(context::PrefixContext, vn)
+    return hasconditioned_nested(collapse_prefix_stack(context), vn)
+end
+
+"""
+    getconditioned_nested(context, vn)
+
+Return the value of the parameter corresponding to `vn` from `context` or its descendants.
+
+This is in contrast to [`getconditioned`](@ref), which only returns the value of `vn` in
+`context`, not recursively looking into its descendants.
+"""
+function getconditioned_nested(context::AbstractContext, vn)
+    return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn)
+end
+function getconditioned_nested(::IsLeaf, context, vn)
+    return error("context $(context) does not contain value for $vn")
+end
+function getconditioned_nested(context::PrefixContext, vn)
+    return getconditioned_nested(collapse_prefix_stack(context), vn)
+end
+function getconditioned_nested(::IsParent, context, vn)
+    return if hasconditioned(context, vn)
+        getconditioned(context, vn)
+    else
+        getconditioned_nested(childcontext(context), vn)
+    end
+end
+
+"""
+    decondition_context(context::AbstractContext, syms...)
+
+Return `context` but with `syms` no longer conditioned on.
+
+Note that this recursively traverses contexts, deconditioning all along the way.
+
+See also: [`condition`](@ref)
+"""
+decondition_context(::IsLeaf, context, args...) = context
+function decondition_context(::IsParent, context, args...)
+    return setchildcontext(context, decondition_context(childcontext(context), args...))
+end
+function decondition_context(context, args...)
+    return decondition_context(NodeTrait(context), context, args...)
+end
+function decondition_context(context::ConditionContext)
+    return decondition_context(childcontext(context))
+end
+function decondition_context(context::ConditionContext, sym, syms...)
+    new_values = deepcopy(context.values)
+    for s in (sym, syms...)
+        new_values = BangBang.delete!!(new_values, s)
+    end
+    return if length(new_values) == 0
+        # No more values left, can unwrap
+        decondition_context(childcontext(context), syms...)
+    else
+        ConditionContext(
+            new_values, decondition_context(childcontext(context), sym, syms...)
+        )
+    end
+end
+function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
+    return ConditionContext(
+        BangBang.delete!!(context.values, sym),
+        decondition_context(childcontext(context), vn),
+    )
+end
+
+"""
+    conditioned(context::AbstractContext)
+
+Return a `NamedTuple` of values that are conditioned on under `context`.
+
+Note that this will recursively traverse the context stack and return
+a merged version of the conditioned values.
+"""
+function conditioned(context::AbstractContext)
+    return conditioned(NodeTrait(conditioned, context), context)
+end
+conditioned(::IsLeaf, context) = NamedTuple()
+conditioned(::IsParent, context) = conditioned(childcontext(context))
+function conditioned(context::ConditionContext)
+    # Note the order of arguments to `merge`. The behavior of the rest of DPPL
+    # is that the outermost `context` takes precedence, hence when resolving
+    # the `conditioned` variables we need to ensure that `context.values` takes
+    # precedence over descendants of `context`.
+    return _merge(context.values, conditioned(childcontext(context)))
+end
+function conditioned(context::PrefixContext)
+    return conditioned(collapse_prefix_stack(context))
+end
+
+struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
+    values::Values
+    context::Ctx
+end
+
+const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}}
+const DictFixedContext = FixedContext{<:AbstractDict}
+
+FixedContext(values) = FixedContext(values, DefaultContext())
+
+# Try to avoid nested `FixedContext`.
+function FixedContext(values::NamedTuple, context::NamedFixedContext)
+    # Note that this potentially overrides values from `context`, thus giving
+    # precedence to the outermost `FixedContext`.
+    return FixedContext(merge(context.values, values), childcontext(context))
+end
+
+function Base.show(io::IO, context::FixedContext)
+    return print(io, "FixedContext($(context.values), $(childcontext(context)))")
+end
+
+NodeTrait(::FixedContext) = IsParent()
+childcontext(context::FixedContext) = context.context
+setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)
+
+"""
+    hasfixed(context::AbstractContext, vn::VarName)
+
+Return `true` if a fixed value for `vn` is found in `context`.
+"""
+hasfixed(context::AbstractContext, vn::VarName) = false
+hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn)
+function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName})
+    return all(Base.Fix1(hasvalue, context.values), vns)
+end
+
+"""
+    getfixed(context::AbstractContext, vn::VarName)
+
+Return the fixed value of `vn` in `context`.
+"""
+function getfixed(context::AbstractContext, vn::VarName)
+    return error("context $(context) does not contain value for $vn")
+end
+getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn)
+
+"""
+    hasfixed_nested(context, vn)
+
+Return `true` if a fixed value for `vn` is found in `context` or any of its descendants.
+
+This is in contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref), which only checks
+for `vn` in `context`, not recursively checking if `vn` is in any of its descendants.
+"""
+function hasfixed_nested(context::AbstractContext, vn)
+    return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn)
+end
+hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn)
+function hasfixed_nested(::IsParent, context, vn)
+    return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn)
+end
+function hasfixed_nested(context::PrefixContext, vn)
+    return hasfixed_nested(collapse_prefix_stack(context), vn)
+end
+
+"""
+    getfixed_nested(context, vn)
+
+Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants.
+
+This is in contrast to [`getfixed`](@ref), which only returns the value of `vn` in `context`,
+not recursively looking into its descendants.
+"""
+function getfixed_nested(context::AbstractContext, vn)
+    return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn)
+end
+function getfixed_nested(::IsLeaf, context, vn)
+    return error("context $(context) does not contain value for $vn")
+end
+function getfixed_nested(context::PrefixContext, vn)
+    return getfixed_nested(collapse_prefix_stack(context), vn)
+end
+function getfixed_nested(::IsParent, context, vn)
+    return if hasfixed(context, vn)
+        getfixed(context, vn)
+    else
+        getfixed_nested(childcontext(context), vn)
+    end
+end
+
+"""
+    fix([context::AbstractContext,] values::NamedTuple)
+    fix([context::AbstractContext]; values...)
+ +Return `FixedContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`unfix`](@ref) +""" +fix(; values...) = fix(NamedTuple(values)) +fix(values::NamedTuple) = fix(DefaultContext(), values) +function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix((value, values...)) +end +function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) + return fix(DefaultContext(), values) +end +fix(context::AbstractContext, values::NamedTuple{()}) = context +function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) + return FixedContext(values, context) +end +function fix(context::AbstractContext; values...) + return fix(context, NamedTuple(values)) +end +function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix(context, (value, values...)) +end +function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) + return fix(context, Dict(values)) +end + +""" + unfix(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer fixed. + +Note that this recursively traverses contexts, unfixing all along the way. + +See also: [`fix`](@ref) +""" +unfix(::IsLeaf, context, args...) = context +function unfix(::IsParent, context, args...) + return setchildcontext(context, unfix(childcontext(context), args...)) +end +function unfix(context, args...) + return unfix(NodeTrait(context), context, args...) +end +function unfix(context::FixedContext) + return unfix(childcontext(context)) +end +function unfix(context::FixedContext, sym) + return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, sym, syms...) + return unfix( + fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), + syms..., + ) +end + +function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, vn::VarName) + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) +end + +""" + fixed(context::AbstractContext) + +Return the values that are fixed under `context`. + +Note that this will recursively traverse the context stack and return +a merged version of the fix values. +""" +fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) +fixed(::IsLeaf, context) = NamedTuple() +fixed(::IsParent, context) = fixed(childcontext(context)) +function fixed(context::FixedContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `fixed` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return _merge(context.values, fixed(childcontext(context))) +end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +########################################################################### +### Interaction of PrefixContext with ConditionContext and FixedContext ### +########################################################################### + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! 
note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. 
+ +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/contexts/default.jl b/src/contexts/default.jl new file mode 100644 index 000000000..ec21e1a56 --- /dev/null +++ b/src/contexts/default.jl @@ -0,0 +1,60 @@ +""" + struct DefaultContext <: AbstractContext end + +`DefaultContext`, as the name suggests, is the default context used when instantiating a +model. + +```jldoctest +julia> @model f() = x ~ Normal(); + +julia> model = f(); model.context +DefaultContext() +``` + +As an evaluation context, the behaviour of `DefaultContext` is to require all variables to be +present in the `AbstractVarInfo` used for evaluation. Thus, semantically, evaluating a model +with `DefaultContext` means 'calculating the log-probability associated with the variables +in the `AbstractVarInfo`'. +""" +struct DefaultContext <: AbstractContext end +NodeTrait(::DefaultContext) = IsLeaf() + +""" + DynamicPPL.tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + ) + +Handle assumed variables. For `DefaultContext`, this function extracts the value associated +with `vn` from `vi`, If `vi` does not contain an appropriate value then this will error. +""" +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +""" + DynamicPPL.tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, + ) + +Handle observed variables. This just accumulates the log-likelihood for `left`. 
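As a rough illustration of the semantics described above, assuming `VarInfo(model)` keeps its usual behaviour of returning a VarInfo populated with values for the model's random variables, evaluating under `DefaultContext` scores the values already stored in the VarInfo rather than drawing new ones:

```julia
using DynamicPPL, Distributions

@model function coin()
    p ~ Beta(2, 2)
    y ~ Bernoulli(p)
end

model = coin() | (; y = true)  # `y` becomes an observation
vi = VarInfo(model)            # VarInfo holding a value for `p`
logjoint(model, vi)            # log p(p) + log p(y = true | p), evaluated under DefaultContext
```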
+"""
+function tilde_observe!!(
+    ::DefaultContext,
+    right::Distribution,
+    left,
+    vn::Union{VarName,Nothing},
+    vi::AbstractVarInfo,
+)
+    vi = accumulate_observe!!(vi, right, left, vn)
+    return left, vi
+end
diff --git a/src/contexts/init.jl b/src/contexts/init.jl
new file mode 100644
index 000000000..534c6a7b0
--- /dev/null
+++ b/src/contexts/init.jl
@@ -0,0 +1,200 @@
+"""
+    AbstractInitStrategy
+
+Abstract type representing the possible ways of initialising new values for
+the random variables in a model (e.g., when creating a new VarInfo).
+
+Any subtype of `AbstractInitStrategy` must implement the
+[`DynamicPPL.init`](@ref) method.
+"""
+abstract type AbstractInitStrategy end
+
+"""
+    init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
+
+Generate a new value for a random variable with the given distribution.
+
+!!! warning "Return values must be unlinked"
+    The values returned by `init` must always be in the untransformed space, i.e.,
+    they must be within the support of the original distribution. That means that,
+    for example, `init(rng, vn, dist, u::InitFromUniform)` will in general return values
+    that are outside the range `[u.lower, u.upper]`.
+"""
+function init end
+
+"""
+    InitFromPrior()
+
+Obtain new values by sampling from the prior distribution.
+"""
+struct InitFromPrior <: AbstractInitStrategy end
+function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
+    return rand(rng, dist)
+end
+
+"""
+    InitFromUniform()
+    InitFromUniform(lower, upper)
+
+Obtain new values by first transforming the distribution of the random variable
+to unconstrained space, then sampling a value uniformly between `lower` and
+`upper`, and transforming that value back to the original space.
+
+If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics
+Stan's default initialisation strategy.
+
+Requires that `lower <= upper`.
+
+# References
+
+[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
+"""
+struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy
+    lower::T
+    upper::T
+    function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat}
+        lower > upper &&
+            throw(ArgumentError("`lower` must be less than or equal to `upper`"))
+        return new{T}(lower, upper)
+    end
+    InitFromUniform() = InitFromUniform(-2.0, 2.0)
+end
+function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform)
+    b = Bijectors.bijector(dist)
+    sz = Bijectors.output_size(b, size(dist))
+    y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...))
+    b_inv = Bijectors.inverse(b)
+    x = b_inv(y)
+    # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398
+    if x isa Array{<:Any,0}
+        x = x[]
+    end
+    return x
+end
+
+"""
+    InitFromParams(
+        params::Union{AbstractDict{<:VarName},NamedTuple},
+        fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
+    )
+
+Obtain new values by extracting them from the given dictionary or NamedTuple.
+
+The parameter `fallback` specifies how new values are to be obtained if they
+cannot be found in `params`, or if they are specified as `missing`. `fallback`
+can either be an initialisation strategy itself, in which case it will be
+used to obtain new values, or it can be `nothing`, in which case an error
+will be thrown. The default for `fallback` is `InitFromPrior()`.
+
+!!! note
+    The values in `params` must be provided in the space of the untransformed
+    distribution.
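A rough usage sketch of the three strategies documented above, via the `init!!` entry point introduced in `src/model.jl` later in this diff; the demo model itself is illustrative:

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

rng = Random.default_rng()

# Fresh values drawn from the prior:
_, vi1 = DynamicPPL.init!!(rng, demo(), VarInfo(), DynamicPPL.InitFromPrior())

# Stan-style uniform initialisation in unconstrained space, then invlinked:
_, vi2 = DynamicPPL.init!!(rng, demo(), VarInfo(), DynamicPPL.InitFromUniform(-2.0, 2.0))

# Supply `m` directly (in unlinked space); fall back to the prior for `s`:
_, vi3 = DynamicPPL.init!!(rng, demo(), VarInfo(), DynamicPPL.InitFromParams((; m = 0.5)))
```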
+""" +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy + params::P + fallback::S + function InitFromParams( + params::AbstractDict{<:VarName}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), + ) + return new{typeof(params),typeof(fallback)}(params, fallback) + end + function InitFromParams( + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + return InitFromParams(to_varname_dict(params), fallback) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + p.fallback === nothing && + error("A `missing` value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + else + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? + x + end + else + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=InitFromPrior()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume!!( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # is_transformed(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) + f = if insert_transformed_value + link_transform(dist) + else + identity + end + y, logjac = with_logabsdet_jacobian(f, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. 
+ insert_transformed_value && set_transformed!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!( + ::InitContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl new file mode 100644 index 000000000..24615e683 --- /dev/null +++ b/src/contexts/prefix.jl @@ -0,0 +1,116 @@ +""" + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} + +Create a context that allows you to use the wrapped `context` when running the model and +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`to_submodel`](@ref) +""" +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn + context::C +end +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) +end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) + +NodeTrait(::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) +end + +""" + prefix(ctx::AbstractContext, vn::VarName) + +Apply the prefixes in the context `ctx` to the variable name `vn`. +""" +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) +end +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) +end +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) +end + +""" + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. + +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. + +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. 
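A small sketch of how nested prefixes accumulate, using the internal `prefix` helper defined above; the variable names are illustrative:

```julia
using DynamicPPL: PrefixContext, prefix, @varname

ctx = PrefixContext(@varname(outer), PrefixContext(@varname(inner)))
prefix(ctx, @varname(x))  # == @varname(outer.inner.x): the outermost prefix ends up leftmost
```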
+""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume!!(new_context, right, new_vn, vi) +end + +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) +end diff --git a/src/transforming.jl b/src/contexts/transformation.jl similarity index 53% rename from src/transforming.jl rename to src/contexts/transformation.jl index 56f861cff..5153f7857 100644 --- a/src/transforming.jl +++ b/src/contexts/transformation.jl @@ -12,13 +12,16 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( - ::DynamicTransformationContext{isinverse}, right, vn, vi +function tilde_assume!!( + ::DynamicTransformationContext{isinverse}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. 
x = vi[vn, right] - if istrans(vi, vn) + if is_transformed(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" else isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" @@ -31,34 +34,12 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) - return tilde_observe!!(DefaultContext(), right, left, vn, vi) -end - -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(t, DynamicTransformationContext{false}(), vi, model) -end - -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) -end - -function _transform!!( - t::AbstractTransformation, - ctx::DynamicTransformationContext, +function tilde_observe!!( + ::DynamicTransformationContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, vi::AbstractVarInfo, - model::Model, ) - # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: - model = contextualize(model, setleafcontext(model.context, ctx)) - vi = settrans!!(last(evaluate!!(model, vi)), t) - return vi -end - -function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..13124e3a7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -1,7 +1,6 @@ module DebugUtils using ..DynamicPPL -using ..DynamicPPL: broadcast_safe, AbstractContext, childcontext using Random: Random using Accessors: Accessors @@ -485,7 +484,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index d311a5f63..8c7b5f7db 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -123,7 +123,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..d6682416b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -95,6 +95,16 @@ function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end +""" + setleafcontext(model::Model, context::AbstractContext) + +Return a new `Model` with its leaf context set to `context`. This is a convenience shortcut +for `contextualize(model, setleafcontext(model.context, context)`). +""" +function setleafcontext(model::Model, context::AbstractContext) + return contextualize(model, setleafcontext(model.context, context)) +end + """ model | (x = 1.0, ...) 
@@ -799,6 +809,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) @@ -815,7 +860,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -829,29 +874,37 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=InitFromPrior()] + ) -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added +using a specified initialisation strategy. -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). +If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. 
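The relationship between `init!!` and the `setleafcontext(model, context)` convenience added above can be sketched as follows, mirroring the implementation shown just below; the model and strategy choices here are illustrative:

```julia
using DynamicPPL, Distributions, Random

@model gauss() = x ~ Normal()
model = gauss()

# Swap the model's leaf context for an InitContext, then evaluate:
ctx = DynamicPPL.InitContext(Random.default_rng(), DynamicPPL.InitFromPrior())
retval, vi = DynamicPPL.evaluate!!(DynamicPPL.setleafcontext(model, ctx), VarInfo())

# Equivalent, via the higher-level entry point:
retval2, vi2 = DynamicPPL.init!!(model, VarInfo(), DynamicPPL.InitFromPrior())
```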
""" -function evaluate_and_sample!!( +function init!!( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) + new_model = setleafcontext(model, InitContext(rng, init_strategy)) + return evaluate!!(new_model, varinfo) end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() +function init!!( + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) + return init!!(Random.default_rng(), model, varinfo, init_strategy) end """ @@ -981,11 +1034,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end @@ -1009,42 +1058,6 @@ function logjoint(model::Model, varinfo::AbstractVarInfo) return getlogjoint(last(evaluate!!(model, varinfo))) end -""" - logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. - -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logjoint(demo_model([1., 2.]), chain); -``` -""" -function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logjoint(model, argvals_dict) - end -end - """ logprior(model::Model, varinfo::AbstractVarInfo) @@ -1067,42 +1080,6 @@ function logprior(model::Model, varinfo::AbstractVarInfo) return getlogprior(last(evaluate!!(model, varinfo))) end -""" - logprior(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. 
- -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> logprior(demo_model([1., 2.]), chain); -``` -""" -function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - logprior(model, argvals_dict) - end -end - """ loglikelihood(model::Model, varinfo::AbstractVarInfo) @@ -1121,61 +1098,8 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getloglikelihood(last(evaluate!!(model, varinfo))) end -""" - loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - -Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. - -# Examples - -```jldoctest -julia> using MCMCChains, Distributions - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end; - -julia> # construct a chain of samples using MCMCChains - chain = Chains(rand(10, 2, 3), [:s, :m]); - -julia> loglikelihood(demo_model([1., 2.]), chain); -``` -""" -function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict{VarName,Any}( - vn_parent => - values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for - vn_parent in keys(var_info) - ) - loglikelihood(model, argvals_dict) - end -end - -""" - predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) - -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches -the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. -""" -function predict( - rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} -) - varinfo = DynamicPPL.VarInfo(model) - return map(chain) do params_varinfo - vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) - return vi - end -end +# Implemented & documented in DynamicPPLMCMCChainsExt +function predict end """ returned(model::Model, parameters::NamedTuple) diff --git a/src/model_utils.jl b/src/model_utils.jl index ac4ec7022..e4c326b39 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -81,7 +81,7 @@ function varname_in_chain!( # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. 
l = AbstractPPL.getoptic(vn) varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) @@ -107,7 +107,7 @@ function values_from_chain( # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) out = Accessors.set( diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 61834ab62..848ecb1f0 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,35 +1,22 @@ """ - PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator + PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator An accumulator that stores the log-probabilities of each variable in a model. -Internally this accumulator stores the log-probabilities in a dictionary, where -the keys are the variable names and the values are vectors of -log-probabilities. Each element in a vector corresponds to one execution of the -model. +Internally this accumulator stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are log-probabilities. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies -which log-probabilities to store in the accumulator. `KeyType` is the type by which variable -names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary -used internally to store the log-probabilities, by default -`OrderedDict{KeyType, Vector{LogProbType}}`. +which log-probabilities to store in the accumulator. """ -struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: - AbstractAccumulator - logps::D -end - -function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) -end +struct PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator + logps::OrderedDict{VarName,LogProbType} -function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob,VarName}() -end - -function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} - logps = OrderedDict{KeyType,Vector{LogProbType}}() - return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) + function PointwiseLogProbAccumulator{whichlogprob}( + d::OrderedDict{VarName,LogProbType}=OrderedDict{VarName,LogProbType}() + ) where {whichlogprob} + return new{whichlogprob}(d) + end end function Base.:(==)( @@ -42,28 +29,14 @@ function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichl return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) end -function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) - logps = acc.logps - # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
- T = last(fieldtypes(eltype(logps))) - logpvec = get!(logps, vn, T()) - return push!(logpvec, logp) -end - -function Base.push!( - acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp -) where {whichlogprob} - return push!(acc, string(vn), logp) -end - function accumulator_name( ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} ) where {whichlogprob} return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) +function _zero(::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}() end reset(acc::PointwiseLogProbAccumulator) = _zero(acc) split(acc::PointwiseLogProbAccumulator) = _zero(acc) @@ -71,21 +44,14 @@ function combine( acc::PointwiseLogProbAccumulator{whichlogprob}, acc2::PointwiseLogProbAccumulator{whichlogprob}, ) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(+, acc.logps, acc2.logps)) end function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - # Note that in only accumulating LogPrior, we effectively ignore logjac - # (since we want to return log densities that don't depend on the - # linking status of the VarInfo). - subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) - push!(acc, vn, subacc.logp) + acc.logps[vn] = logpdf(right, val) end return acc end @@ -99,172 +65,11 @@ function accumulate_observe!!( return acc end if whichlogprob == :both || whichlogprob == :likelihood - # T is the element type of the vectors that are the values of `acc.logps`. Usually - # it's LogProbType. - T = eltype(last(fieldtypes(eltype(acc.logps)))) - subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) - push!(acc, vn, subacc.logp) + acc.logps[vn] = loglikelihood(right, left) end return acc end -""" - pointwise_logdensities( - model::Model, - chain::Chains, - keytype=String, - ::Val{whichlogprob}=Val(:both), - ) - -Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the variables, and values being matrices -of shape `(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. `whichlogprob` specifies -which log-probabilities to compute. It can be `:both`, `:prior`, or -`:likelihood`. - -See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 
1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. - -This is important to keep in mind, in particular if the computation is used -for downstream computations. - -# Examples -## From chain -```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) -julia> using MCMCChains - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - y ~ Normal(m, √s) - end -demo (generic function with 2 methods) - -julia> # Example observations. - model = demo([1.0, 2.0, 3.0], [4.0]); - -julia> # A chain with 3 iterations. - chain = Chains( - reshape(1.:6., 3, 2), - [:s, :m] - ); - -julia> pointwise_logdensities(model, chain) -OrderedDict{String, Matrix{Float64}} with 6 entries: - "s" => [-0.802775; -1.38222; -2.09861;;] - "m" => [-8.91894; -7.51551; -7.46824;;] - "xs[1]" => [-5.41894; -5.26551; -5.63491;;] - "xs[2]" => [-2.91894; -3.51551; -4.13491;;] - "xs[3]" => [-1.41894; -2.26551; -2.96824;;] - "y" => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, String) -OrderedDict{String, Matrix{Float64}} with 6 entries: - "s" => [-0.802775; -1.38222; -2.09861;;] - "m" => [-8.91894; -7.51551; -7.46824;;] - "xs[1]" => [-5.41894; -5.26551; -5.63491;;] - "xs[2]" => [-2.91894; -3.51551; -4.13491;;] - "xs[3]" => [-1.41894; -2.26551; -2.96824;;] - "y" => [-0.918939; -1.51551; -2.13491;;] - -julia> pointwise_logdensities(model, chain, VarName) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. - -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` -""" -function pointwise_logdensities( - model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) -) where {KeyType,whichlogprob} - # Get the data by executing the model once - vi = VarInfo(model) - - # This accumulator tracks the pointwise log-probabilities in a single iteration. - AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} - vi = setaccs!!(vi, (AccType(),)) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - - # Maintain a separate accumulator that isn't tied to a VarInfo but rather - # tracks _all_ iterations. - all_logps = AccType() - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - vi = last(evaluate!!(model, vi)) - - # Get the log-probabilities - this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps - - # Merge into main acc - for (varname, this_lp) in this_iter_logps - # Because `this_lp` is obtained from one model execution, it should only - # contain one variable, hence `only()`. 
- push!(all_logps, varname, only(this_lp)) - end - end - - niters = size(chain, 1) - nchains = size(chain, 3) - logdensities = OrderedDict( - varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps - ) - return logdensities -end - function pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} @@ -274,38 +79,10 @@ function pointwise_logdensities( return getacc(varinfo, Val(accumulator_name(AccType))).logps end -""" - pointwise_loglikelihoods(model, chain[, keytype]) - -Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the likelihood terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} - return pointwise_logdensities(model, chain, T, Val(:likelihood)) -end - function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:likelihood)) end -""" - pointwise_prior_logdensities(model, chain[, keytype]) - -Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain)`, but only -including the prior terms. - -See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). -""" -function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String -) where {T} - return pointwise_logdensities(model, chain, T, Val(:prior)) -end - function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/sampler.jl b/src/sampler.jl deleted file mode 100644 index 27b990336..000000000 --- a/src/sampler.jl +++ /dev/null @@ -1,263 +0,0 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - -# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? -# (Selector has been removed). -""" - Sampler{T} - -Generic sampler type for inference algorithms of type `T` in DynamicPPL. - -`Sampler` should implement the AbstractMCMC interface, and in particular -`AbstractMCMC.step`. A default implementation of the initial sampling step is -provided that supports resuming sampling from a previous state and setting initial -parameter values. 
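The chain-based methods are gone from this file, but the `AbstractVarInfo`-based methods above remain. A rough sketch of how they are called after this change, assuming `pointwise_logdensities` and its two variants are still exported; the result is keyed by `VarName` with a single log-density per variable:

```julia
using DynamicPPL, Distributions

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo([1.0, 2.0])
vi = VarInfo(model)

# All pointwise log-densities (prior and likelihood terms), keyed by VarName.
ℓ = pointwise_logdensities(model, vi)
ℓ[@varname(s)], ℓ[@varname(x[1])]

# Likelihood-only and prior-only variants.
pointwise_loglikelihoods(model, vi)
pointwise_prior_logdensities(model, vi)
```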
It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) -for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) -that specifies how the initial parameter values are sampled if they are not provided. -By default, values are sampled from the prior. -""" -struct Sampler{T} <: AbstractSampler - alg::T -end - -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) - return vi, nothing -end - -""" - default_varinfo(rng, model, sampler) - -Return a default varinfo object for the given `model` and `sampler`. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `model::Model`: Model for which we want to create a varinfo object. -- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. - -# Returns -- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. -""" -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) -end - -function AbstractMCMC.sample( - rng::Random.AbstractRNG, - model::Model, - sampler::Sampler, - N::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, - initial_state=loadstate(resume_from), - kwargs..., -) - return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_state, kwargs... - ) -end - -function AbstractMCMC.sample( - rng::Random.AbstractRNG, - model::Model, - sampler::Sampler, - parallel::AbstractMCMC.AbstractMCMCEnsemble, - N::Integer, - nchains::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, - initial_state=loadstate(resume_from), - kwargs..., -) - return AbstractMCMC.mcmcsample( - rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs... - ) -end - -# initial step: general interface for resuming and -function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... -) - # Sample initial values. - vi = default_varinfo(rng, model, spl) - - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end - - return initialstep(rng, model, spl, vi; initial_params, kwargs...) -end - -""" - loadstate(data) - -Load sampler state from `data`. - -By default, `data` is returned. -""" -loadstate(data) = data - -""" - default_chain_type(sampler) - -Default type of the chain of posterior samples from `sampler`. -""" -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). 
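Since the `AbstractMCMC.step` methods in this deleted file previously handled drawing initial values, a sampler now does this itself. A hedged sketch with a made-up sampler type, mirroring the deleted `SampleFromPrior` step but using `init!!`:

```julia
using AbstractMCMC, DynamicPPL, Random

# Hypothetical sampler; only the initial step is sketched here.
struct PriorSampler <: AbstractMCMC.AbstractSampler end

function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::DynamicPPL.Model, ::PriorSampler; kwargs...
)
    # Draw starting values from the prior into a fresh VarInfo.
    _, vi = DynamicPPL.init!!(rng, model, DynamicPPL.VarInfo(), DynamicPPL.InitFromPrior())
    return vi, vi  # (sample, state)
end
```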
-""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. - new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end - -""" - initialstep(rng, model, sampler, varinfo; kwargs...) - -Perform the initial sampling step of the `sampler` for the `model`. - -The `varinfo` contains the initial samples, which can be provided by the user or -sampled randomly. -""" -function initialstep end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 31b2d2ac6..2ba25f142 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. 
- _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -121,7 +121,7 @@ true Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) +julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) julia> # (✓) Positive probability mass on negative numbers! @@ -129,7 +129,7 @@ julia> # (✓) Positive probability mass on negative numbers! -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) julia> # (✓) No probability mass on negative numbers! @@ -232,24 +232,27 @@ end # Constructor from `Model`. 
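The doctests above use the renamed `set_transformed!!`; the same rename applies to the query function. A short migration sketch (assuming `is_transformed` is exported as described; otherwise qualify with `DynamicPPL.`):

```julia
using DynamicPPL, Distributions

@model demo_ct() = x ~ LogNormal()

model = demo_ct()
vi = VarInfo(model)

# Previously `istrans(vi, @varname(x))` and `settrans!!(vi, true, @varname(x))`.
is_transformed(vi, @varname(x))          # false: values stored in unlinked space
vi_linked = link!!(vi, model)
is_transformed(vi_linked, @varname(x))   # true after linking
```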
function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. @@ -265,12 +268,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -463,42 +466,32 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. -function settrans!!(vi::SimpleVarInfo, trans) - return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) +function set_transformed!!(vi::SimpleVarInfo, trans) + return set_transformed!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) end -function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) +function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Accessors.@set vi.transformation = transformation end -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) +function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) + return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans) +end +function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. + if trans != is_transformed(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", + ) + end end -istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) - -islinked(vi::SimpleVarInfo) = istrans(vi) +is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) +is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi) +function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) + return is_transformed(vi.varinfo, vn) +end +is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values @@ -517,7 +510,7 @@ function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} end """ - logjoint(model::Model, θ) + logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log joint probability of variables `θ` for the probabilistic `model`. @@ -546,10 +539,11 @@ julia> # Truth. -9902.33787706641 ``` """ -logjoint(model::Model, θ) = logjoint(model, SimpleVarInfo(θ)) +logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logjoint(model, SimpleVarInfo(θ)) """ - logprior(model::Model, θ) + logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log prior probability of variables `θ` for the probabilistic `model`. @@ -578,10 +572,11 @@ julia> # Truth. -5000.918938533205 ``` """ -logprior(model::Model, θ) = logprior(model, SimpleVarInfo(θ)) +logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = + logprior(model, SimpleVarInfo(θ)) """ - loglikelihood(model::Model, θ) + loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) Return the log likelihood of variables `θ` for the probabilistic `model`. @@ -610,7 +605,8 @@ julia> # Truth. -4901.418938533205 ``` """ -Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) +Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = + loglikelihood(model, SimpleVarInfo(θ)) # Allow usage of `NamedBijector` too. 
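With the narrowed signatures above, `θ` must be a `NamedTuple` or an `AbstractDict` keyed by `VarName`. A small sketch of both accepted forms:

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

model = demo(2.0)

# NamedTuple of values...
logjoint(model, (m = 0.5,))
logprior(model, (m = 0.5,))
loglikelihood(model, (m = 0.5,))

# ...or a Dict keyed by VarName.
logjoint(model, Dict(@varname(m) => 0.5))
```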
function link!!( @@ -625,7 +621,7 @@ function link!!( if hasacc(vi_new, Val(:LogJacobian)) vi_new = acclogjac!!(vi_new, logjac) end - return settrans!!(vi_new, t) + return set_transformed!!(vi_new, t) end function invlink!!( @@ -643,7 +639,7 @@ function invlink!!( if hasacc(vi_new, Val(:LogJacobian)) vi_new = acclogjac!!(vi_new, inv_logjac) end - return settrans!!(vi_new, NoTransformation()) + return set_transformed!!(vi_new, NoTransformation()) end # With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. diff --git a/src/submodel.jl b/src/submodel.jl index dcb107bb4..145bd42c9 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -8,6 +8,10 @@ struct Submodel{M,AutoPrefix} model::M end +# ---------------------- +# Constructing submodels +# ---------------------- + """ to_submodel(model::Model[, auto_prefix::Bool]) @@ -152,6 +156,26 @@ ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observ """ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) +# --------------------------- +# Submodels in tilde-pipeline +# --------------------------- + +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) + +Evaluate the submodel with the given context. +""" +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo +) + return _evaluate!!(right, vi, context, vn) +end + # When automatic prefixing is used, the submodel itself doesn't carry the # prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel # is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then @@ -193,3 +217,13 @@ function _evaluate!!( # returns a tuple of submodel.model's return value and the new varinfo. return _evaluate!!(model, vi) end + +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 65079f023..f584055b3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,13 +11,12 @@ using Bijectors: Bijectors using Accessors: Accessors # For backwards compat. -using DynamicPPL: varname_leaves, update_values!! +using DynamicPPL: update_values!! include("test_utils/model_interface.jl") include("test_utils/models.jl") include("test_utils/contexts.jl") include("test_utils/varinfo.jl") -include("test_utils/sampler.jl") include("test_utils/ad.jl") end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..aae2e4ec6 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -25,25 +25,49 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. """ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - # `NodeTrait`. node_trait = DynamicPPL.NodeTrait(context) - # Throw error immediately if it it's missing a `NodeTrait` implementation. - node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) - - # To see change, let's make sure we're using a different leaf context than the current. 
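The two tilde methods added in `submodel.jl` above cover the assume case and reject the observe case. A minimal sketch of the supported direction, with automatic prefixing of the submodel's variables:

```julia
using DynamicPPL, Distributions

@model inner() = x ~ Normal()

@model function outer()
    # The left-hand side receives inner()'s return value; its random variables
    # are automatically prefixed, so the variable sampled here is `a.x`.
    a ~ to_submodel(inner())
    return a
end

rand(outer())   # e.g. (var"a.x" = 0.12,)
```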
- leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) + elseif node_trait isa DynamicPPL.IsParent + test_parent_context(context, model) else - DefaultContext() + error("Invalid NodeTrait: $node_trait") end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. (For example, DefaultContext will error with empty + # varinfos.) Thus we only test evaluation with VarInfos that are already + # filled with values. + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + # Set the test context as the new leaf context + new_model = DynamicPPL.setleafcontext(model, context) + # Check that evaluation works + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent + + @testset "get/set leaf and child contexts" begin + # Ensure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. 
- varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/test_utils/sampler.jl b/src/test_utils/sampler.jl deleted file mode 100644 index 71cdb1cac..000000000 --- a/src/test_utils/sampler.jl +++ /dev/null @@ -1,85 +0,0 @@ -# sampler.jl -# ---------- -# -# Utilities to test samplers on models. - -""" - marginal_mean_of_samples(chain, varname) - -Return the mean of variable represented by `varname` in `chain`. -""" -marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) - -""" - test_sampler(models, sampler, args...; kwargs...) - -Test that `sampler` produces correct marginal posterior means on each model in `models`. - -In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `marginal_mean_of_samples(chain, vn)` -for every (leaf) varname `vn` against the corresponding value returned by -[`posterior_mean`](@ref) for each model. - -To change how comparison is done for a particular `chain` type, one can overload -[`marginal_mean_of_samples`](@ref) for the corresponding type. - -# Arguments -- `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. -- `sampler`: The `AbstractMCMC.AbstractSampler` to test. -- `args...`: Arguments forwarded to `sample`. - -# Keyword arguments -- `varnames_filter`: A filter to apply to `varnames(model)`, allowing comparison for only - a subset of the varnames. -- `atol=1e-1`: Absolute tolerance used in `@test`. -- `rtol=1e-3`: Relative tolerance used in `@test`. -- `kwargs...`: Keyword arguments forwarded to `sample`. -""" -function test_sampler( - models, - sampler::AbstractMCMC.AbstractSampler, - args...; - varnames_filter=Returns(true), - atol=1e-1, - rtol=1e-3, - sampler_name=typeof(sampler), - kwargs..., -) - @testset "$(sampler_name) on $(nameof(model))" for model in models - chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - target_values = posterior_mean(model) - for vn in filter(varnames_filter, varnames(model)) - # We want to compare elementwise which can be achieved by - # extracting the leaves of the `VarName` and the corresponding value. 
- for vn_leaf in varname_leaves(vn, get(target_values, vn)) - target_value = get(target_values, vn_leaf) - chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) - @test chain_mean_value ≈ target_value atol = atol rtol = rtol - end - end - end -end - -""" - test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) - -Test `sampler` on every model in [`DEMO_MODELS`](@ref). - -This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. -""" -function test_sampler_on_demo_models( - sampler::AbstractMCMC.AbstractSampler, args...; kwargs... -) - return test_sampler(DEMO_MODELS, sampler, args...; kwargs...) -end - -""" - test_sampler_continuous(sampler, args...; kwargs...) - -Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. - -As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). -""" -function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) - return test_sampler_on_demo_models(sampler, args...; kwargs...) -end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6ca3b9852..89877f385 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -80,7 +80,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) +is_transformed(vi::ThreadSafeVarInfo) = is_transformed(vi.varinfo) function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) @@ -103,17 +103,13 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. 
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) - return settrans!!(last(evaluate!!(model, vi)), t) + model = setleafcontext(model, DynamicTransformationContext{false}()) + return set_transformed!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) - return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) + model = setleafcontext(model, DynamicTransformationContext{true}()) + return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation()) end function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) @@ -185,22 +181,15 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function unset_flag!( - vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false -) - return unset_flag!(vi.varinfo, vn, flag, ignoreable) -end -function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return is_flagged(vi.varinfo, vn, flag) +function set_transformed!!(vi::ThreadSafeVarInfo, val::Bool, vn::VarName) + return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, val, vn) end -function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) +is_transformed(vi::ThreadSafeVarInfo, vn::VarName) = is_transformed(vi.varinfo, vn) +function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) + return is_transformed(vi.varinfo, vns) end -istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) - getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) diff --git a/src/utils.jl b/src/utils.jl index d3371271f..b09bfb9fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) 
.- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### @@ -837,249 +793,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -""" - varname_leaves(vn::VarName, val) - -Return an iterator over all varnames that are represented by `vn` on `val`. - -# Examples -```jldoctest -julia> using DynamicPPL: varname_leaves - -julia> foreach(println, varname_leaves(@varname(x), rand(2))) -x[1] -x[2] - -julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) -x[1:2][1] -x[1:2][2] - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_leaves(@varname(x), x)) -x.y -x.z[1][1] -x.z[2][1] -``` -""" -varname_leaves(vn::VarName, ::Real) = [vn] -function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for - I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] - ) for I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) - end - return Iterators.flatten(iter) -end - -""" - varname_and_value_leaves(vn::VarName, val) - -Return an iterator over all varname-value pairs that are represented by `vn` on `val`. 
- -# Examples -```jldoctest varname-and-value-leaves -julia> using DynamicPPL: varname_and_value_leaves - -julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) -(x[1], 1) -(x[2], 2) - -julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) -(x[1:2][1], 1) -(x[1:2][2], 2) - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(@varname(x), x)) -(x.y, 1) -(x.z[1][1], 2.0) -(x.z[2][1], 3.0) -``` - -There is also some special handling for certain types: - -```jldoctest varname-and-value-leaves -julia> using LinearAlgebra - -julia> x = reshape(1:4, 2, 2); - -julia> # `LowerTriangular` - foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1, 1], 1) -(x[2, 1], 2) -(x[2, 2], 4) - -julia> # `UpperTriangular` - foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1, 1], 1) -(x[1, 2], 3) -(x[2, 2], 4) - -julia> # `Cholesky` with lower-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1, 1], 1.0) -(x.L[2, 1], 0.0) -(x.L[2, 2], 1.0) - -julia> # `Cholesky` with upper-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1, 1], 1.0) -(x.U[1, 2], 0.0) -(x.U[2, 2], 1.0) -``` -""" -function varname_and_value_leaves(vn::VarName, x) - return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) -end - -""" - varname_and_value_leaves(container) - -Return an iterator over all varname-value pairs that are represented by `container`. - -This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container -containing multiple varnames. - -See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref). - -# Examples -```jldoctest varname-and-value-leaves-container -julia> using DynamicPPL: varname_and_value_leaves - -julia> # With an `OrderedDict` - dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(dict)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) - -julia> # With a `NamedTuple` - nt = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(nt)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) -``` -""" -function varname_and_value_leaves(container::OrderedDict) - return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container) -end -function varname_and_value_leaves(container::NamedTuple) - return Iterators.flatten( - varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container) - ) -end - -""" - Leaf{T} - -A container that represents the leaf of a nested structure, implementing -`iterate` to return itself. - -This is particularly useful in conjunction with `Iterators.flatten` to -prevent flattening of nested structures. -""" -struct Leaf{T} - value::T -end - -Leaf(xs...) = Leaf(xs) - -# Allow us to treat `Leaf` as an iterator containing a single element. -# Something like an `[x]` would also be an iterator with a single element, -# but when we call `flatten` on this, it would also iterate over `x`, -# unflattening that too. By making `Leaf` a single-element iterator, which -# returns itself, we can call `iterate` on this as many times as we like -# without causing any change. The result is that `Iterators.flatten` -# will _not_ unflatten `Leaf`s. 
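The leaf-iteration helpers removed above are now reached through AbstractPPL (the call sites earlier in this diff already use `AbstractPPL.varname_leaves`). A short sketch mirroring the removed doctest, assuming the AbstractPPL version behaves the same way:

```julia
using AbstractPPL, DynamicPPL

x = (y = 1, z = [[2.0], [3.0]])

# Previously `DynamicPPL.varname_leaves`; now qualified with AbstractPPL.
foreach(println, AbstractPPL.varname_leaves(@varname(x), x))
# x.y
# x.z[1][1]
# x.z[2][1]
```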
-# Note that this is similar to how `Base.iterate` is implemented for `Real`:: -# -# julia> iterate(1) -# (1, nothing) -# -# One immediate example where this becomes in our scenario is that we might -# have `missing` values in our data, which does _not_ have an `iterate` -# implemented. Calling `Iterators.flatten` on this would cause an error. -Base.iterate(leaf::Leaf) = leaf, nothing -Base.iterate(::Leaf, _) = nothing - -# Convenience. -value(leaf::Leaf) = leaf.value - -# Leaf-types. -varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] -function varname_and_value_leaves_inner( - vn::VarName, val::AbstractArray{<:Union{Real,Missing}} -) - return ( - Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -# Containers. -function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_and_value_leaves_inner( - VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) - ) - end - - return Iterators.flatten(iter) -end -# Special types. -function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) - # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) - else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) - end -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the lower-triangular indices. - for I in CartesianIndices(x) if I[1] >= I[2] - ) -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the upper-triangular indices. - for I in CartesianIndices(x) if I[1] <= I[2] - ) -end - -broadcast_safe(x) = x -broadcast_safe(x::Distribution) = (x,) -broadcast_safe(x::AbstractContext) = (x,) - # Convert (x=1,) to Dict(@varname(x) => 1) function to_varname_dict(nt::NamedTuple) return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) diff --git a/src/varinfo.jl b/src/varinfo.jl index dec4db3ec..734bf3db5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,13 +15,13 @@ not. Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, and `md.flags`. + `md.vns`, `md.ranges` `md.dists`, and `md.is_transformed`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the - value of `flag` corresponding to `vn`. +- `md.is_transformed` is a BitVector of true/false flags for whether a variable has been + transformed. `md.is_transformed[md.idcs[vn]]` is the value corresponding to `vn`. 
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a @@ -56,8 +56,7 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` - flags::Dict{String,BitVector} + is_transformed::BitVector end function Base.:(==)(md1::Metadata, md2::Metadata) @@ -67,7 +66,7 @@ function Base.:(==)(md1::Metadata, md2::Metadata) md1.ranges == md2.ranges && md1.vals == md2.vals && md1.dists == md2.dists && - md1.flags == md2.flags + md1.is_transformed == md2.is_transformed ) end @@ -113,10 +112,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -129,12 +132,14 @@ the given `rng`, `sampler`. instead. """ function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -166,8 +171,8 @@ function metadata_to_varnamedvector(md::Metadata) vns = copy(md.vns) ranges = copy(md.ranges) vals = copy(md.vals) - is_unconstrained = map(Base.Fix1(istrans, md), md.vns) - transforms = map(md.dists, is_unconstrained) do dist, trans + is_trans = map(Base.Fix1(is_transformed, md), md.vns) + transforms = map(md.dists, is_trans) do dist, trans if trans return from_linked_vec_transform(dist) else @@ -176,12 +181,7 @@ function metadata_to_varnamedvector(md::Metadata) end return VarNamedVector( - OrderedDict{eltype(keys(idcs)),Int}(idcs), - vns, - ranges, - vals, - transforms, - is_unconstrained, + OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms, is_trans ) end @@ -195,7 +195,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +203,17 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. 
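For illustration, a minimal sketch of how an initialisation strategy is passed to these constructors and to `init!!` (this assumes `InitFromPrior`, `InitFromUniform`, and `InitFromParams` are exported, as they are used unqualified in the test suite below):

```julia
using DynamicPPL, Distributions
using Random: Xoshiro

@model function demo()
    m ~ Normal()
    s ~ Exponential(1)
end
model = demo()

# Fresh VarInfo with values drawn from the prior (the default strategy).
vi = VarInfo(model, InitFromPrior())

# Draw linked values uniformly from [-2, 2] and map them back to each
# variable's support; an explicit RNG can be passed as the first argument.
vi_unif = VarInfo(Xoshiro(468), model, InitFromUniform(-2.0, 2.0))

# The lower-level constructors accept the same argument.
vi_untyped = DynamicPPL.untyped_varinfo(model, InitFromUniform(-2.0, 2.0))

# `init!!` is the underlying entry point; it returns the model's return value
# together with the updated varinfo.
_, vi_fixed = DynamicPPL.init!!(model, VarInfo(), InitFromParams((; m=0.0, s=1.0)))
vi_fixed[@varname(m)]  # 0.0
```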
""" function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -238,8 +240,8 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + # New is_transformed + sym_is_transformed = meta.is_transformed[inds] # Extract new ranges and vals _ranges = getindex.((meta.ranges,), inds) @@ -255,7 +257,9 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags), + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_is_transformed + ), ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) @@ -270,7 +274,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +282,21 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +304,27 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +332,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +344,16 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -388,7 +402,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.is_transformed) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -404,9 +418,7 @@ Construct an empty type unstable instance of `Metadata`. 
""" function Metadata() vals = Vector{Real}() - flags = Dict{String,BitVector}() - flags["del"] = BitVector() - flags["trans"] = BitVector() + is_transformed = BitVector() return Metadata( Dict{VarName,Int}(), @@ -414,7 +426,7 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - flags, + is_transformed, ) end @@ -431,10 +443,7 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - for k in keys(meta.flags) - empty!(meta.flags[k]) - end - + empty!(meta.is_transformed) return meta end @@ -518,8 +527,9 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va offset = r[end] end - flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) - return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags) + dists = metadata.dists[indices_for_vns] + is_transformed = metadata.is_transformed[indices_for_vns] + return Metadata(indices, vns, ranges, vals, dists, is_transformed) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -590,11 +600,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - flags = Dict{String,BitVector}() - # Initialize the `flags`. - for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) - flags[k] = BitVector() - end + transformed = BitVector() # Range offset. offset = 0 @@ -611,12 +617,10 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - for k in keys(flags) - push!(flags[k], is_flagged(metadata_for_vn, vn, k)) - end + push!(transformed, is_transformed(metadata_for_vn, vn)) end - return Metadata(idcs, vns, ranges, vals, dists, flags) + return Metadata(idcs, vns, ranges, vals, dists, transformed) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -785,35 +789,30 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end -function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) - settrans!!(getmetadata(vi, vn), trans, vn) +function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) + set_transformed!!(getmetadata(vi, vn), val, vn) return vi end -function settrans!!(metadata::Metadata, trans::Bool, vn::VarName) - if trans - set_flag!(metadata, vn, "trans") - else - unset_flag!(metadata, vn, "trans") - end - +function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) + metadata.is_transformed[getidx(metadata, vn)] = val return metadata end -function settrans!!(vi::VarInfo, trans::Bool) +function set_transformed!!(vi::VarInfo, val::Bool) for vn in keys(vi) - settrans!!(vi, trans, vn) + set_transformed!!(vi, val, vn) end return vi end -settrans!!(vi::VarInfo, trans::NoTransformation) = settrans!!(vi, false) +set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) # HACK: This is necessary to make something like `link!!(transformation, vi, model)` # work properly, which will transform the variables according to `transformation` -# and then call `settrans!!(vi, transformation)`. An alternative would be to add +# and then call `set_transformed!!(vi, transformation)`. An alternative would be to add # the `transformation` to the `VarInfo` object, but at the moment doesn't seem # worth it as `VarInfo` has its own way of handling transformations. 
-settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) +set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) """ syms(vi::VarInfo) @@ -853,30 +852,6 @@ all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(v return expr end -# TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called -# set_flag!!. -""" - set_flag!(vi::VarInfo, vn::VarName, flag::String) - -Set `vn`'s value for `flag` to `true` in `vi`. -""" -function set_flag!(vi::VarInfo, vn::VarName, flag::String) - set_flag!(getmetadata(vi, vn), vn, flag) - return vi -end -function set_flag!(md::Metadata, vn::VarName, flag::String) - return md.flags[flag][getidx(md, vn)] = true -end - -function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) - if flag == "del" - # The "del" flag is effectively always set for a VarNamedVector, so this is a no-op. - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end - return vnv -end - #### #### APIs for typed and untyped VarInfo #### @@ -914,8 +889,8 @@ Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) -istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") +is_transformed(vi::VarInfo, vn::VarName) = is_transformed(getmetadata(vi, vn), vn) +is_transformed(md::Metadata, vn::VarName) = md.is_transformed[getidx(md, vn)] getaccs(vi::VarInfo) = vi.accs setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs @@ -970,11 +945,11 @@ end function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` - if ~istrans(vi, vns[1]) + if ~is_transformed(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, true, vn) + vi = set_transformed!!(vi, true, vn) end return vi else @@ -1015,12 +990,12 @@ end f_vns = vi.metadata.$f.vns f_vns = filter_subsumed(vns.$f, f_vns) if !isempty(f_vns) - if !istrans(vi, f_vns[1]) + if !is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, true, vn) + vi = set_transformed!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -1076,17 +1051,18 @@ end function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do - # is just assume that `default_transformation` is the correct one if `istrans(vi)`. - t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() + # is just assume that `default_transformation` is the correct one if + # `is_transformed(vi)`. + t = is_transformed(vi) ? 
default_transformation(model, vi) : NoTransformation() return maybe_invlink_before_eval!!(t, vi, model) end function _invlink!!(vi::UntypedVarInfo, vns) - if istrans(vi, vns[1]) + if is_transformed(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, false, vn) + vi = set_transformed!!(vi, false, vn) end return vi else @@ -1118,12 +1094,12 @@ end quote f_vns = vi.metadata.$f.vns f_vns = filter_subsumed(vns.$f, f_vns) - if istrans(vi, f_vns[1]) + if is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, false, vn) + vi = set_transformed!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1253,7 +1229,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ vals_new = map(vns) do vn # Return early if we're already in unconstrained space. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) + if is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -1267,7 +1243,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Accumulate the log-abs-det jacobian correction. cumulative_logjac += logjac # Mark as transformed. - settrans!!(varinfo, true, vn) + set_transformed!!(varinfo, true, vn) # Return the vectorized transformed value. return yvec end @@ -1288,7 +1264,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.flags, + metadata.is_transformed, ), cumulative_logjac end @@ -1313,7 +1289,7 @@ function _link_metadata!!( # Fix this when attending to issue #653. cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) - settrans!(metadata, true, vn) + set_transformed!(metadata, true, vn) end return metadata, cumulative_logjac end @@ -1428,7 +1404,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Return early if we're already in constrained space OR if we're not # supposed to touch this `vn`. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) + if !is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -1442,7 +1418,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Accumulate the log-abs-det jacobian correction. cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. - settrans!!(varinfo, false, vn) + set_transformed!!(varinfo, false, vn) # Return the vectorized transformed value. 
return xvec end @@ -1463,7 +1439,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.flags, + metadata.is_transformed, ), cumulative_inv_logjac end @@ -1481,67 +1457,32 @@ function _invlink_metadata!!( cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) - settrans!(metadata, false, vn) + set_transformed!(metadata, false, vn) end return metadata, cumulative_inv_logjac end -# TODO(mhauru) The treatment of the case when some variables are linked and others are not -# should be revised. It used to be the case that for UntypedVarInfo `islinked` returned -# whether the first variable was linked. For NTVarInfo we did an OR over the first +# TODO(mhauru) The treatment of the case when some variables are transformed and others are +# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` +# returned whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. """ - islinked(vi::VarInfo) + is_transformed(vi::VarInfo) Check whether `vi` is in the transformed space. Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable (for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `islinked` checks if the number is in the constrained space or the real space. +real numbers. `is_transformed` checks if the number is in the constrained space or the real +space. -If some but only some of the variables in `vi` are linked, this function will return `true`. -This behavior will likely change in the future. +If some but only some of the variables in `vi` are transformed, this function will return +`true`. This behavior will likely change in the future. """ -function islinked(vi::VarInfo) - return any(istrans(vi, vn) for vn in keys(vi)) -end - -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. 
- val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent +function is_transformed(vi::VarInfo) + return any(is_transformed(vi, vn) for vn in keys(vi)) end # The default getindex & setindex!() for get & set values @@ -1648,7 +1589,7 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) for accname in acckeys(vi) push!(lines, (string(accname), getacc(vi, Val(accname)))) end - push!(lines, ("flags", vi.metadata.flags)) + push!(lines, ("is_transformed", vi.metadata.is_transformed)) max_name_length = maximum(map(length ∘ first, lines)) fmt = Printf.Format("%-$(max_name_length)s") vi_str = ( @@ -1722,14 +1663,7 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. val = tovec(r) - md = Metadata( - Dict(vn => 1), - [vn], - [1:length(val)], - val, - [dist], - Dict{String,BitVector}("trans" => [false], "del" => [false]), - ) + md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) @@ -1762,8 +1696,7 @@ function Base.push!(meta::Metadata, vn, r, dist) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.flags["del"], false) - push!(meta.flags["trans"], false) + push!(meta.is_transformed, false) return meta end @@ -1776,56 +1709,6 @@ end # Rand & replaying method for VarInfo # ####################################### -""" - is_flagged(vi::VarInfo, vn::VarName, flag::String) - -Check whether `vn` has a true value for `flag` in `vi`. -""" -function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return is_flagged(getmetadata(vi, vn), vn, flag) -end -function is_flagged(metadata::Metadata, vn::VarName, flag::String) - return metadata.flags[flag][getidx(metadata, vn)] -end -function is_flagged(::VarNamedVector, ::VarName, flag::String) - if flag == "del" - return true - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end -end - -# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector, -# but still having to support the interface based on Metadata too -""" - unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false - -Set `vn`'s value for `flag` to `false` in `vi`. - -Setting some flags for some `VarInfo` types is not possible, and by default attempting to do -so will error. If `ignorable` is set to `true` then this will silently be ignored instead. -""" -function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) - unset_flag!(getmetadata(vi, vn), vn, flag, ignorable) - return vi -end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false) - metadata.flags[flag][getidx(metadata, vn)] = false - return metadata -end - -function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false) - if ignorable - return vnv - end - if flag == "del" - throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector")) - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end - return vnv -end - # TODO: Maybe rename or something? 
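A minimal sketch of the renamed transform-flag interface on a one-variable model (qualified names are used to avoid assuming which functions are exported; `set_transformed!!` is called internally by `link!!` as shown above):

```julia
using DynamicPPL, Distributions

@model logn() = a ~ LogNormal()
model = logn()
vi = VarInfo(model)

# Values start out stored in constrained space.
DynamicPPL.is_transformed(vi, @varname(a))  # false
DynamicPPL.is_transformed(vi)               # false (any-variable check, formerly `islinked`)

# `link!!` moves the stored value to unconstrained space and flags it via
# `set_transformed!!(vi, true, @varname(a))`.
vi = DynamicPPL.link!!(vi, model)
DynamicPPL.is_transformed(vi, @varname(a))  # true

# `invlink!!` undoes the transformation and resets the flag.
vi = DynamicPPL.invlink!!(vi, model)
DynamicPPL.is_transformed(vi, @varname(a))  # false
```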
""" _apply!(kernel!, vi::VarInfo, values, keys) @@ -1900,179 +1783,6 @@ function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) return missing_keys end -""" - setval!(vi::VarInfo, x) - setval!(vi::VarInfo, values, keys) - setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - -Set the values in `vi` to the provided values and leave those which are not present in -`x` or `chains` unchanged. - -## Notes -This is rather limited for two reasons: -1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood, - and therefore suffers from the same limitations as [`subsumes_string`](@ref). -2. It will set every `vn` present in `keys`. It will NOT however - set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`, - representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0]))` will - be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0]))`. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1]` - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 -``` -""" -setval!(vi::VarInfo, x) = setval!(vi, values(x), keys(x)) -setval!(vi::VarInfo, values, keys) = _apply!(_setval_kernel!, vi, values, keys) -function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) -end - -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - end - - return indices -end - -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. 
- -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! - unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..4b2791d19 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -322,28 +322,28 @@ getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) -# TODO(mhauru) Eventually I would like to rename the istrans function to is_unconstrained, -# but that's significantly breaking. +# TODO(mhauru) Eventually I would like to rename the is_transformed function to +# is_unconstrained, but that's significantly breaking. """ - istrans(vnv::VarNamedVector, vn::VarName) + is_transformed(vnv::VarNamedVector, vn::VarName) Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain is all of Euclidean space. 
""" -istrans(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] """ - settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) Set the value for whether `vn` is guaranteed to have been transformed so that all of Euclidean space is its domain. """ -function settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) +function set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val end -function settrans!!(vnv::VarNamedVector, val::Bool, vn::VarName) - settrans!(vnv, val, vn) +function set_transformed!!(vnv::VarNamedVector, val::Bool, vn::VarName) + set_transformed!(vnv, val, vn) return vnv end @@ -548,7 +548,7 @@ julia> vnv[@varname(x)] function reset!(vnv::VarNamedVector, val, vn::VarName) f = from_vec_transform(val) retval = setindex_internal!(vnv, tovec(val), vn, f) - settrans!(vnv, false, vn) + set_transformed!(vnv, false, vn) return retval end @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new @@ -897,7 +902,7 @@ end function reset!!(vnv::VarNamedVector, val, vn::VarName) f = from_vec_transform(val) vnv = setindex_internal!!(vnv, tovec(val), vn, f) - vnv = settrans!!(vnv, false, vn) + vnv = set_transformed!!(vnv, false, vn) return vnv end @@ -1093,13 +1098,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # `vn` is only in `left`. val = getindex_internal(left_vnv, vn) f = gettransform(left_vnv, vn) - is_unconstrained[idx] = istrans(left_vnv, vn) + is_unconstrained[idx] = is_transformed(left_vnv, vn) else # `vn` is either in both or just `right`. # Note that in a `merge` the right value has precedence. 
val = getindex_internal(right_vnv, vn) f = gettransform(right_vnv, vn) - is_unconstrained[idx] = istrans(right_vnv, vn) + is_unconstrained[idx] = is_transformed(right_vnv, vn) end n = length(val) r = (offset + 1):(offset + n) @@ -1148,7 +1153,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) for vn in vnv.varnames if any(subsumes(vn_given, vn) for vn_given in vns_given) insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) - settrans!(vnv_new, istrans(vnv, vn), vn) + set_transformed!(vnv_new, is_transformed(vnv, vn), vn) end end diff --git a/test/Project.toml b/test/Project.toml index d40cb8f39..a3bced4c3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -41,7 +40,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9, 0.10" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 6bc0271e9..d7505aab2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,48 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..b1309254e 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,11 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. 
+ # During the model evaluation, its leaf context is changed to an InitContext, so + # `model_` is not going to be equal to `model`. We can still check equality of `f` + # though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..972d833a5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -20,10 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue - -using EnzymeCore + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) @@ -49,7 +48,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( @@ -92,7 +90,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # here to split up arrays which could potentially have some, # but not all, elements being `missing`. conditioned_vns = mapreduce( - p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + p -> AbstractPPL.varname_leaves(p.first, p.second), vcat, pairs(conditioned_values), ) @@ -103,7 +101,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. 
@test hasconditioned_nested(context, vn) @@ -150,11 +148,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -165,29 +163,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = FixedContext( + (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) + ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. 
- varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) + context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) + new_model = contextualize(model, context) + # Initialize a new varinfo with the prefixed model + _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) @@ -202,22 +201,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin @@ -431,4 +414,246 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + empty_varinfos = [ + ("untyped+metadata", VarInfo()), + ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), + ( + "typed+VNV", + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + ), + ("SVI+NamedTuple", SimpleVarInfo()), + ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. 
+ model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end + + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.is_transformed(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end + + @testset "InitFromPrior" begin + test_generating_new_values(InitFromPrior()) + test_replacing_values(InitFromPrior()) + test_rng_respected(InitFromPrior()) + test_link_status_respected(InitFromPrior()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. 
+ @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end + + @testset "InitFromUniform" begin + test_generating_new_values(InitFromUniform()) + test_replacing_values(InitFromUniform()) + test_rng_respected(InitFromUniform()) + test_link_status_respected(InitFromUniform()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end + + @testset "InitFromParams" begin + test_link_status_respected(InitFromParams((; a=1.0))) + test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0))) + + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. 
+ params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict( + @varname(x) => my_x, @varname(y) => missing + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) + end + end + end + end + end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..8ed29e0c7 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -30,7 +30,7 @@ DynamicPPL.UntypedVarInfo # Evaluation works (and it would even do so in practice), but sampling - # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. + # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. @model function demo4() x ~ Bernoulli() if x @@ -62,33 +62,37 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling + # Check that the inferred varinfo is indeed suitable for evaluation f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed - # If the test failed, check why it didn't infer a typed varinfo + # If the test failed, check what the type stability problem was for + # the typed varinfo. This is mostly useful for debugging from test + # logs. if !is_typed + @info "Model `$(model.f)` is not type stable with typed varinfo." 
typed_vi = DynamicPPL.typed_varinfo(model) - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi + + @info "Evaluating with DefaultContext:" + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo ) - JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi + JET.test_call(f, argtypes) + + @info "Initialising with InitContext:" + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo ) - JET.test_call(f_sample, argtypes_sample) + JET.test_call(f, argtypes) end end end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 986057da0..971956542 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -1,5 +1,9 @@ @testset "DynamicPPLMooncakeExt" begin Mooncake.TestUtils.test_rule( - StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + StableRNG(123456), + is_transformed, + VarInfo(); + unsafe_perturb=true, + interface_only=true, ) end diff --git a/test/linking.jl b/test/linking.jl index cae101c72..2047b9d11 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -50,9 +50,9 @@ end # Specify the link-transform to use. Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) -function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Bool) +function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, is_transformed::Bool) lp = logpdf(dist, x) - if istrans + if is_transformed lp = lp - logabsdetjac(bijector(dist), x) end diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..5c5603aba 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. 
- @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -37,21 +32,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) - # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/model.jl b/test/model.jl index 81f84e548..6ba3bca2a 100644 --- a/test/model.jl +++ b/test/model.jl @@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) + vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym end @@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() logjoint(model, chain) end - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +314,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -347,7 +329,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Extract varnames and values. 
vns_and_vals_xs = map( - collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs + collect ∘ Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs ) vns = map(first, first(vns_and_vals_xs)) vals = map(vns_and_vals_xs) do vns_and_vals @@ -513,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -533,6 +519,23 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")]) end + @testset "include_all=true" begin + inc_predictions = DynamicPPL.predict( + m_lin_reg_test, β_chain; include_all=true + ) + @test Set(keys(inc_predictions)) == + Set([:β, Symbol("y[1]"), Symbol("y[2]")]) + @test inc_predictions[:β] == β_chain[:β] + # check rng is respected + inc_predictions1 = DynamicPPL.predict( + Xoshiro(468), m_lin_reg_test, β_chain; include_all=true + ) + inc_predictions2 = DynamicPPL.predict( + Xoshiro(468), m_lin_reg_test, β_chain; include_all=true + ) + @test all(Array(inc_predictions1) .== Array(inc_predictions2)) + end + @testset "accuracy" begin ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 @@ -559,7 +562,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), ) predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) @test size(multiple_β_chain, 3) == size(predictions, 3) @@ -584,43 +589,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end end @testset "ProductNamedTupleDistribution sampling" begin 
diff --git a/test/model_utils.jl b/test/model_utils.jl index 720ae55aa..af695dbf2 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -6,11 +6,11 @@ chain = make_chain_from_prior(model, 10) for (i, d) in enumerate(value_iterator_from_chain(model, chain)) for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) + val = AbstractPPL.getvalue(d, vn) # Because value_iterator_from_chain groups varnames with # the same parent symbol, we have to ungroup them here - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) + for vn_leaf in AbstractPPL.varname_leaves(vn, val) + val_leaf = AbstractPPL.getvalue(d, vn_leaf) @test val_leaf == chain[i, Symbol(vn_leaf), 1] end end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index cfb222b66..be5f20010 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,4 +1,4 @@ -@testset "logdensities_likelihoods.jl" begin +@testset "pointwise_logdensities.jl" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -39,32 +39,35 @@ end @testset "pointwise_logdensities chain" begin - # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, - # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the version on `Chains`. model = DynamicPPL.TestUtils.demo_assume_index_observe() - # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced - # an impl of this for containers. - # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. vns = DynamicPPL.TestUtils.varnames(model) # Get some random `NamedTuple` samples from the prior. num_iters = 3 vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + chain = Chains( + permutedims(vals_arr), + map(Symbol, vns); + info=(varname_to_symbol=Dict(vn => Symbol(vn) for vn in vns),), + ) # Compute the different pointwise logdensities. logjoints_pointwise = pointwise_logdensities(model, chain) logpriors_pointwise = pointwise_prior_logdensities(model, chain) loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + # Check output type + @test logjoints_pointwise isa MCMCChains.Chains + @test logpriors_pointwise isa MCMCChains.Chains + @test loglikelihoods_pointwise isa MCMCChains.Chains + # Check that they contain the correct variables. - @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) - @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) - @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) - @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) - @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + @test all(Symbol(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(Symbol(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix1(startswith, "x"), String.(keys(logpriors_pointwise))) + @test !any(Symbol(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix1(startswith, "x"), String.(keys(loglikelihoods_pointwise))) # Get the sum of the logjoints for each of the iterations. 
logjoints = [ diff --git a/test/runtests.jl b/test/runtests.jl index 2b92a023d..b6a3f7bf6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,7 +61,6 @@ include("test_util.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") - include("sampler.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") diff --git a/test/sampler.jl b/test/sampler.jl deleted file mode 100644 index 5eb0da057..000000000 --- a/test/sampler.jl +++ /dev/null @@ -1,307 +0,0 @@ -@testset "sampler.jl" begin - @testset "initial_state and resume_from kwargs" begin - # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our - # overloaded method. - @model f() = x ~ Normal() - model = f() - # This sampler just returns the state it was given as its 'sample' - struct S <: AbstractMCMC.AbstractSampler end - function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Sampler{<:S}, - state=nothing; - kwargs..., - ) - if state === nothing - s = rand() - return s, s - else - return state, state - end - end - spl = Sampler(S()) - - function AbstractMCMC.bundle_samples( - samples::Vector{Float64}, - model::Model, - sampler::Sampler{<:S}, - state, - chain_type::Type{MCMCChains.Chains}; - kwargs..., - ) - return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,)) - end - - N_iters, N_chains = 10, 3 - - @testset "single-chain sampling" begin - chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) - initial_value = chn[:x][1] - @test all(chn[:x] .== initial_value) # sanity check - # using `initial_state` - chn2 = sample( - model, - spl, - N_iters; - progress=false, - initial_state=chn.info.samplerstate, - chain_type=MCMCChains.Chains, - ) - @test all(chn2[:x] .== initial_value) - # using `resume_from` - chn3 = sample( - model, - spl, - N_iters; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(chn3[:x] .== initial_value) - end - - @testset "multiple-chain sampling" begin - chn = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - chain_type=MCMCChains.Chains, - ) - initial_value = chn[:x][1, :] - @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check - # using `initial_state` - chn2 = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - initial_state=chn.info.samplerstate, - chain_type=MCMCChains.Chains, - ) - @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) - # using `resume_from` - chn3 = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) - end - end - - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. 
no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - - @testset "Initial parameters" begin - # dummy algorithm that just returns initial value and does not perform any sampling - abstract type OnlyInitAlg end - struct OnlyInitAlgDefault <: OnlyInitAlg end - struct OnlyInitAlgUniform <: OnlyInitAlg end - function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - ::Sampler{<:OnlyInitAlg}, - vi::AbstractVarInfo; - kwargs..., - ) - return vi, nothing - end - - # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() - - for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) - # model with one variable: initialization p = 0.2 - @model function coinflip() - p ~ Beta(1, 1) - return 10 ~ Binomial(25, p) - end - model = coinflip() - sampler = Sampler(alg) - lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test chain[1].metadata.p.vals == [0.2] - @test getlogjoint(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test c[1].metadata.p.vals == [0.2] - @test getlogjoint(c[1]) == lptrue - end - end - - # model with two variables: initialization s = 4, m = -1 - @model function twovars() - s ~ InverseGamma(2, 3) - return m ~ Normal(0, sqrt(s)) - end - model = twovars() - lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test chain[1].metadata.s.vals == [4] - @test chain[1].metadata.m.vals == [-1] - @test getlogjoint(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test c[1].metadata.s.vals == [4] - @test c[1].metadata.m.vals == [-1] - @test getlogjoint(c[1]) == lptrue - end - end - - # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test !ismissing(chain[1].metadata.s.vals[1]) - @test chain[1].metadata.m.vals == [-1] - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test 
!ismissing(c[1].metadata.s.vals[1]) - @test c[1].metadata.m.vals == [-1] - end - end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) - end - end -end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 526fce92c..488cb8941 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -150,9 +150,9 @@ ("Dict", svi_dict), ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.settrans!!(deepcopy(svi_nt), true), - # DynamicPPL.settrans!!(deepcopy(svi_dict), true), - # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), ) # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. @@ -160,7 +160,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -172,7 +172,7 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.istrans(svi) + if DynamicPPL.is_transformed(svi) _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... ) @@ -227,10 +227,12 @@ model = DynamicPPL.TestUtils.demo_dynamic_constraint() # Initialize. - svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) - svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) + svi_vnv = DynamicPPL.set_transformed!!( + SimpleVarInfo(DynamicPPL.VarNamedVector()), true + ) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. 
@@ -270,13 +272,13 @@ vi_linked = DynamicPPL.link!!(vi, model) # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.istrans( + @test !DynamicPPL.is_transformed( DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) - @test !DynamicPPL.istrans(vi_result) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) + @test !DynamicPPL.is_transformed(vi_result) # Set the values to something that is out of domain if we're in constrained space. for vn in keys(vi) diff --git a/test/test_util.jl b/test/test_util.jl index e04486760..164751c7b 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -72,7 +72,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I # We have to use varname_and_value_leaves so that each parameter is a scalar dicts = map(varinfos) do t vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) tuples = mapreduce(collect, vcat, iters) # The following loop is a replacement for: # push!(varnames, map(first, tuples)...) @@ -87,8 +87,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 0421c89e2..522730566 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. 
- sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index ba7c17b34..6b31fbe91 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -22,11 +22,6 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) r = rand(dist) push!!(vi, vn, r, dist) r - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - r else vi[vn] end @@ -42,7 +37,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -53,9 +48,7 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - for flag in keys(meta.flags) - @test meta.flags[flag][ind] == fmeta.flags[flag][tind] - end + @test meta.is_transformed[ind] == fmeta.is_transformed[tind] range = meta.ranges[ind] trange = fmeta.ranges[tind] @test all(meta.vals[range] .== fmeta.vals[trange]) @@ -290,9 +283,8 @@ end @test all_accs_same(vi, vi_orig) end - @testset "flags" begin - # Test flag setting: - # is_flagged, set_flag!, unset_flag! + @testset "is_transformed flag" begin + # Test is_transformed and set_transformed!! function test_varinfo!(vi) vn_x = @varname x dist = Normal(0, 1) r = rand(dist) push!!(vi, vn_x, r, dist) - # del is set by default - @test !is_flagged(vi, vn_x, "del") + # is_transformed is false by default + @test !is_transformed(vi, vn_x) - set_flag!(vi, vn_x, "del") - @test is_flagged(vi, vn_x, "del") + vi = set_transformed!!(vi, true, vn_x) + @test is_transformed(vi, vn_x) - unset_flag!(vi, vn_x, "del") - @test !is_flagged(vi, vn_x, "del") + vi = set_transformed!!(vi, false, vn_x) + @test !is_transformed(vi, vn_x) end vi = VarInfo() test_varinfo!(vi) @@ -325,194 +317,13 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin - @model function testmodel(x) - n = length(x) - s ~ truncated(Normal(); lower=0) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, s^2 * I) - end - - @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV} - n = length(x) - s ~ truncated(Normal(); lower=0) - - m = TV(undef, n) - for i in eachindex(m) - m[i] ~ Normal() - end - - for i in eachindex(x) - x[i] ~ Normal(m[i], s) - end - end - - x = randn(5) - model_mv = testmodel(x) - model_uv = testmodel_univariate(x) - - for model in [model_uv, model_mv] - m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) - s_vns = @varname(s) - - vi_typed = DynamicPPL.typed_varinfo(model) - vi_untyped = DynamicPPL.untyped_varinfo(model) - vi_vnv = DynamicPPL.untyped_vector_varinfo(model) - vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) - - model_name = model == model_uv ? "univariate" : "multivariate" - @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ - vi_untyped, vi_typed, vi_vnv, vi_vnv_typed - ] - Random.seed!(23) - vicopy = deepcopy(vi) - - ### `setval` ### - # TODO(mhauru) The interface here seems inconsistent between Metadata and - # VarNamedVector.
I'm lazy to fix it though, because I think we need to - # rework it soon anyway. - if vi in [vi_vnv, vi_vnv_typed] - DynamicPPL.setval!(vicopy, zeros(5), m_vns) - else - DynamicPPL.setval!(vicopy, (m=zeros(5),)) - end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. - if model == model_uv && vi in [vi_untyped, vi_typed] - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] == vi[s_vns] - - # Ordering is NOT preserved => fails for multivariate model. - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] == vi[s_vns] - - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - DynamicPPL.setval!(vicopy, (s=42,)) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 - end - end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/250 - @model function demo() - return x ~ filldist(MvNormal([1, 100], I), 2) - end - - vi = VarInfo(demo()) - vals_prev = vi.metadata.x.vals - ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] - DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals - end - - @testset "setval! on chain" begin - # Define a helper function - """ - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - - Test `setval!` on `model` and `chain`. - - Worth noting that this only supports models containing symbols of the forms - `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. 
- """ - function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - θ_old = var_info[:] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[:] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end - end - + @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS chain = make_chain_from_prior(model, 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end end @@ -533,25 +344,26 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using InitFromUniform does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using InitFromUniform specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. 
vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) + _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) + @test all(x -> !is_transformed(vi, x), meta.vns) - # Check that linking and invlinking set the `trans` flag accordingly + # Check that linking and invlinking set the `is_transformed` flag accordingly v = copy(meta.vals) vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.vns) + @test all(x -> is_transformed(vi, x), meta.vns) vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !is_transformed(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values @@ -562,14 +374,14 @@ end v_x = copy(meta.x.vals) v_y = copy(meta.y.vals) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> !is_transformed(vi, x), meta.s.vns) + @test all(x -> !is_transformed(vi, x), meta.m.vns) vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) + @test all(x -> is_transformed(vi, x), meta.s.vns) + @test all(x -> is_transformed(vi, x), meta.m.vns) vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> !is_transformed(vi, x), meta.s.vns) + @test all(x -> !is_transformed(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 @@ -588,10 +400,10 @@ end @test !isempty(target_vns) @test !isempty(other_vns) vi = link!!(vi, (vn,), model) - @test all(x -> istrans(vi, x), target_vns) - @test all(x -> !istrans(vi, x), other_vns) + @test all(x -> is_transformed(vi, x), target_vns) + @test all(x -> !is_transformed(vi, x), other_vns) vi = invlink!!(vi, (vn,), model) - @test all(x -> !istrans(vi, x), all_vns) + @test all(x -> !is_transformed(vi, x), all_vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 @test meta.x.vals ≈ v_x atol = 1e-10 @@ -607,10 +419,10 @@ end function test_linked_varinfo(model, vi) # vn and dist are taken from the containing scope - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test istrans(vi, vn) + @test is_transformed(vi, vn) @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @test getloglikelihood(vi) == 0.0 @@ -618,32 +430,32 @@ end @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. 
+ ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi) - - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) + vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) + vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) test_linked_varinfo(model, vi) end @@ -731,7 +543,7 @@ end DynamicPPL.link(varinfo, model) end for vn in keys(varinfo) - @test DynamicPPL.istrans(varinfo_linked, vn) + @test DynamicPPL.is_transformed(varinfo_linked, vn) end @test length(varinfo[:]) > length(varinfo_linked[:]) varinfo_linked_unflattened = DynamicPPL.unflatten( @@ -989,7 +801,7 @@ end varinfo_left = VarInfo(model_left) varinfo_right = VarInfo(model_right) - varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) + varinfo_right = DynamicPPL.set_transformed!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) vns = [@varname(x), @varname(y), @varname(z)] @@ -997,7 +809,7 @@ end # Right has precedence. @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] - @test DynamicPPL.istrans(varinfo_merged, @varname(x)) + @test DynamicPPL.is_transformed(varinfo_merged, @varname(x)) end end @@ -1012,45 +824,6 @@ end @test merge(vi_double, vi_single)[vn] == 1.0 end - @testset "sampling from linked varinfo" begin - # `~` - @model function demo(n=1) - x = Vector(undef, n) - for i in eachindex(x) - x[i] ~ Exponential() - end - return x - end - model1 = demo(1) - varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x[1]), @varname(x[2])] - @test DynamicPPL.istrans(varinfo2, vn) - end - - # `.~` - @model function demo_dot(n=1) - x ~ Exponential() - if n > 1 - y = Vector(undef, n - 1) - y .~ Exponential() - end - return x - end - model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x), @varname(y[1])] - @test DynamicPPL.istrans(varinfo2, vn) - end - end - # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. 
@testset "vector_getranges for `VarInfo`" begin diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..3fd76ffe2 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -570,9 +570,9 @@ end vn = @varname(t[1]) vns = vcat(test_vns, [vn]) vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2) - DynamicPPL.settrans!(vnv, true, @varname(t[1])) + DynamicPPL.set_transformed!(vnv, true, @varname(t[1])) @test vnv[@varname(t[1])] == [4.0] - @test istrans(vnv, @varname(t[1])) + @test is_transformed(vnv, @varname(t[1])) @test subset(vnv, vns) == vnv end end @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different.