diff --git a/HISTORY.md b/HISTORY.md index 24c0df3d0..91306c219 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,40 @@ # DynamicPPL Changelog +## 0.39.0 + +### Breaking changes + +#### Fast Log Density Functions + +This version provides a reimplementation of `LogDensityFunction` that yields performance improvements on the order of 2–10× for both model evaluation and automatic differentiation. +Exact speedups depend on the model size: larger models see smaller speedups because the bulk of the work is done in calls to `logpdf`. + +For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + +As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it. +In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`. +If you were previously relying on this behaviour, you will need to store a VarInfo separately. + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. +This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext` (a migration sketch is given below, after the changelog). + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not affect most users. + +Leaf contexts require no changes, apart from the removal of their `NodeTrait` method. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly; please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### Miscellaneous + +Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. + +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value and a transform function that maps it back to unlinked space (see the second sketch below the changelog). +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). ## 0.38.9 Remove warning when using Enzyme as the AD backend.
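As a minimal migration sketch for the parent-context change above (the `MyParentContext` name is hypothetical; the pattern mirrors the `ParentContext` example in the `setleafcontext` docstring later in this diff):

```julia
using DynamicPPL

# A hypothetical parent context that simply wraps a child context.
struct MyParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext
    context::C
end

# The two methods required by the AbstractParentContext interface:
DynamicPPL.childcontext(ctx::MyParentContext) = ctx.context
DynamicPPL.setchildcontext(::MyParentContext, child) = MyParentContext(child)

# Under DynamicPPL <= 0.38 one would instead have subtyped AbstractContext and defined:
#     DynamicPPL.NodeTrait(::MyParentContext) = DynamicPPL.IsParent()
```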
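And a sketch of the new `init` signature, using a hypothetical strategy (`InitFromMean` is not part of DynamicPPL) that initialises every variable at its prior mean; under the old interface this method would have returned `mean(dist)` alone:

```julia
using DynamicPPL, Distributions, Random

struct InitFromMean <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(
    ::Random.AbstractRNG, ::DynamicPPL.VarName, dist::Distribution, ::InitFromMean
)
    # mean(dist) is already in unlinked (model) space, so the transform that
    # maps the value back to unlinked space is the (typed) identity.
    return mean(dist), DynamicPPL.typed_identity
end
```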
diff --git a/Project.toml b/Project.toml index 23f5eec5b..1b5e52492 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.9" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index c154c5ca5..523889a7a 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -24,7 +24,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" Chairmarks = "1.3.1" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "1" JSON = "1.3.0" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 3af6573cf..e8ffa7e0b 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -98,12 +98,15 @@ function run(; to_json=false) }[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name" + @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try results = benchmark(model, varinfo_choice, adbackend, islinked) + @info " t(eval) = $(results.primal_time)" + @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), (results.grad_time / results.primal_time) catch e + @info "benchmark errored: $e" missing, missing end push!( @@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) ) results_table = Tuple{ - String,Int,String,String,Bool,String,String,String,String,String,String + String, + Int, + String, + String, + Bool, + String, + String, + String, + String, + String, + String, + String, + String, + String, }[] + sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ EmptyCells(5), MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), + MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"], + [colnames[1:5]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String) # Finally that lets us do this division safely speedup_eval = base_eval / head_eval speedup_grad = base_grad / head_grad + # As well as this multiplication, which is t(grad) / t(ref) + head_grad_vs_ref = head_grad * head_eval + base_grad_vs_ref = base_grad * base_eval + speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref push!( results_table, ( @@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String) sprint_float(base_grad), sprint_float(head_grad), sprint_float(speedup_grad), + sprint_float(base_grad_vs_ref), + sprint_float(head_grad_vs_ref), + sprint_float(speedup_grad_vs_ref), ), ) end @@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String) backend=:text, fit_table_in_display_horizontally=false, fit_table_in_display_vertically=false, - table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true), + table_format=TextTableFormat(; + 
horizontal_line_at_merged_column_labels=true, + horizontal_lines_at_data_rows=collect(3:3:length(results_table)), + ), ) println("```") end diff --git a/docs/Project.toml b/docs/Project.toml index 03a3ff0a0..10a4a5c8a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -21,7 +21,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..adb476db5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished by calling [`init!!`](@ref) on: + +```@docs +OnlyAccsVarInfo +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -170,6 +176,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -352,13 +364,6 @@ Base.empty! SimpleVarInfo ``` -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during execution, in what are called accumulators. @@ -463,22 +468,55 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context.
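For instance, here is a hedged sketch of a parent context that logs each assumption before delegating; the `VerboseContext` name is hypothetical and not part of DynamicPPL:

```julia
using DynamicPPL, Distributions

struct VerboseContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext
    context::C
end
DynamicPPL.childcontext(ctx::VerboseContext) = ctx.context
DynamicPPL.setchildcontext(::VerboseContext, child) = VerboseContext(child)

# Optional: intercept tilde_assume!!, then fall through to the child context.
function DynamicPPL.tilde_assume!!(
    ctx::VerboseContext,
    right::Distribution,
    vn::DynamicPPL.VarName,
    vi::DynamicPPL.AbstractVarInfo,
)
    @info "assuming variable $vn"
    return DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(ctx), right, vn, vi)
end
```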
+ +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, the following functions are automatically defined for manipulating context stacks. +They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. It is really a thin wrapper around `evaluate!!` with an `InitContext`. ```@docs -DynamicPPL.init!! +init!! ``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. @@ -491,10 +529,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. +In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant; Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for _get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL._get_range_and_linked), args...
+) = nothing end diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 7ea51918f..4b6d3fb41 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -8,12 +8,8 @@ use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false function DynamicPPL.tweak_adtype( - ad::ADTypes.AutoForwardDiff{chunk_size}, - ::DynamicPPL.Model, - vi::DynamicPPL.AbstractVarInfo, + ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, params::AbstractVector ) where {chunk_size} - params = vi[:] - # Use DynamicPPL tag to improve stack traces # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ # NOTE: DifferentiationInterface disables tag checking if the diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..e74f0b8a9 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -140,6 +140,43 @@ function AbstractMCMC.to_samples( end end +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DynamicPPL.Model, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + save_state=false, + stats=missing, + sort_chain=false, + discard_initial=0, + thinning=1, + kwargs..., +) + bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1)) + + # Add additional MCMC-specific info + info = bare_chain.info + if save_state + info = merge(info, (model=model, sampler=spl, samplerstate=state)) + end + if !ismissing(stats) + info = merge(info, (start_time=stats.start, stop_time=stats.stop)) + end + + # Reconstruct the chain with the extra information + # Yeah, this is quite ugly. Blame MCMCChains. + chain = MCMCChains.Chains( + bare_chain.value.data, + names(bare_chain), + bare_chain.name_map; + info=info, + start=discard_initial + 1, + thin=thinning, + ) + return sort_chain ? sort(chain) : chain +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... ) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. 
Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -5,5 +5,8 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL._get_range_and_linked),Vararg +} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..a885f6a96 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -84,30 +84,41 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, getargnames, extract_priors, values_as_in_model, + # evaluation + evaluate!!, + init!!, # LogDensityFunction LogDensityFunction, - # Contexts + OnlyAccsVarInfo, + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, InitFromParams, + init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, @@ -188,12 +199,14 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -201,7 +214,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..4660a1a31 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -133,3 +133,60 @@ function ParamsWithStats( end return ParamsWithStats(params, stats) end + +""" + ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.LogDensityFunction, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided +`param_vector`. + +This method is intended to replace the old method of obtaining parameters and statistics +via `unflatten` plus re-evaluation. It is faster for two reasons: + +1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as + otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent + MCMC iterations). +2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`. 
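As a usage sketch (assuming `ldf` is a `LogDensityFunction` for the model that produced the samples, `x` is a parameter vector of matching length, and the extra statistics NamedTuple is arbitrary):

```julia
ps = ParamsWithStats(x, ldf, (; acceptance_rate=0.8); include_log_probs=true)
```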
+""" +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.LogDensityFunction{Tlink}, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) where {Tlink} + strategy = InitFromParams( + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), + nothing, + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + _, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy) + params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + lp=DynamicPPL.getlogjoint(vi), + ), + ) + end + return ParamsWithStats(params, stats) +end diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. 
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. 
""" -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. ```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! 
not implemented for context of type $(typeof(context))") end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. """ struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is in contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is in contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...)
-end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return a `NamedTuple` of values that are conditioned on under `context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. """ -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precedence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) end @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is in contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is in contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants.
""" function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. 
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..b70bf2bf1 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -14,14 +14,60 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. 
You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams`, which provides known +parameters for evaluating the model, methods are implemented that return more specific +types. + +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +point about accumulator types only applies when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion +""" +get_param_eltype(::AbstractInitStrategy) = Any + """ InitFromPrior() Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return rand(rng, dist), typed_identity end """ @@ -69,43 +115,61 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return x, typed_identity end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple.
+Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +## Custom parameter storage types -!!! note - The values in `params` must be provided in the space of the untransformed - distribution. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement + +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return InitFromParams(to_varname_dict(params), fallback) - end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) +InitFromParams(params) = InitFromParams(params, InitFromPrior()) + +function init( + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -119,13 +183,99 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? 
- x + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges{Tlink}( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return (@view vr.vect[range_and_linked.range]), transform +end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( @@ -150,15 +300,13 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. 
possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we @@ -166,17 +314,49 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. + y, -inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. 
See also: [`to_submodel`](@ref) """ -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext vn_prefix::Tvn context::C end @@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} end PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) -NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context function setchildcontext(ctx::PrefixContext, child::AbstractContext) return PrefixContext(ctx.vn_prefix, child) @@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`. function prefix(ctx::PrefixContext, vn::VarName) return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) +prefix(::AbstractContext, vn::VarName) = vn +function prefix(ctx::AbstractParentContext, vn::VarName) return prefix(childcontext(ctx), vn) end @@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) ) return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) +prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName) vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) return vn, setchildcontext(ctx, new_ctx) end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c2eee2863 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -10,7 +10,6 @@ Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the how to do the transformation, used by e.g. `SimpleVarInfo`. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7c7438c9f..abfb61c94 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,318 +1,305 @@ -using AbstractMCMC: AbstractModel +using DynamicPPL: + AbstractVarInfo, + AccumulatorTuple, + InitContext, + InitFromParams, + AbstractInitStrategy, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + Model, + ThreadSafeVarInfo, + VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, + VarNamedVector, + default_accumulators, + float_type_with_fallback, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal +using ADTypes: ADTypes +using BangBang: BangBang +using AbstractPPL: AbstractPPL, VarName +using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI +using Random: Random """ - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. 
- -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. -""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) A struct which contains a model, along with all the information necessary to: - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) !!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. 
This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: + +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +# Extended help + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. 
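
For concreteness, a minimal sketch of these two approaches on a toy model (the model and values are illustrative; the two-argument `InitFromParams(params, nothing)` form matches its use elsewhere in this diff):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
model = demo()
vi = VarInfo(model)

# 1. `unflatten` + `evaluate!!`: the parameter vector is written into the
#    VarInfo's metadata, then read back out during evaluation.
vi1 = DynamicPPL.unflatten(vi, [0.5])
_, vi1 = DynamicPPL.evaluate!!(model, vi1)

# 2. `InitFromParams`: parameters are read from a NamedTuple (or Dict); the
#    `nothing` fallback means a missing value is an error, not resampled.
_, vi2 = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams((; x=0.5), nothing))

getlogjoint(vi1) ≈ getlogjoint(vi2)  # true
```
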
+ +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext`. + +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the +LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with +link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle +models which have variable numbers of parameters, or models which may visit random variables +in different orders depending on stochastic control flow. **Indeed, silent errors may occur +with such models.** This is a general limitation of vectorised parameters: the original +`unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, + M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, + F<:Function, + N<:NamedTuple, + ADP<:Union{Nothing,DI.GradientPrep}, +} model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated"
     adtype::AD
-    "(internal use only) gradient preparation object for the model"
-    prep::Union{Nothing,DI.GradientPrep}
+    _getlogdensity::F
+    _iden_varname_ranges::N
+    _varname_ranges::Dict{VarName,RangeAndLinked}
+    _adprep::ADP
+    _dim::Int
+
+    """
+        function LogDensityFunction(
+            model::Model,
+            getlogdensity::Function=getlogjoint_internal,
+            link::Union{Bool,Set{VarName}}=false;
+            adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+        )
+
+    Generate a `LogDensityFunction` for the given model.
+    The `link` argument specifies which VarNames in the model should be linked. This can
+    either be a Bool (if `link=true` all variables are linked; if `link=false` all variables
+    are unlinked); or a `Set{VarName}` specifying exactly which variables should be linked.
+    Any sub-variables of the set's elements will be linked.
+    """
     function LogDensityFunction(
         model::Model,
         getlogdensity::Function=getlogjoint_internal,
-        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
+        link::Union{Bool,Set{VarName}}=false;
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
-        if adtype === nothing
-            prep = nothing
+        # Run the model once to determine variable ranges and linking. Because the
+        # parameters stored in the LogDensityFunction are never used, we can just use
+        # InitFromPrior to create new values. The actual values don't matter, only the
+        # length, since that's used for gradient prep.
+        vi = OnlyAccsVarInfo(AccumulatorTuple((RangeLinkedValueAcc(link),)))
+        _, vi = DynamicPPL.init!!(model, vi, InitFromPrior())
+        rlvacc = first(vi.accs)
+        Tlink, all_iden_ranges, all_ranges, x = get_data(rlvacc)
+        # That gives us all the information we need to create the LogDensityFunction.
+        dim = length(x)
+        # Do AD prep if needed
+        prep = if adtype === nothing
+            nothing
         else
             # Make backend-specific tweaks to the adtype
-            adtype = tweak_adtype(adtype, model, varinfo)
-            # Check whether it is supported
-            is_supported(adtype) ||
-                @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
-            # Get a set of dummy params to use for prep
-            x = [val for val in varinfo[:]]
-            if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    LogDensityAt(model, getlogdensity, varinfo), adtype, x
-                )
-            else
-                prep = DI.prepare_gradient(
-                    logdensity_at,
-                    adtype,
-                    x,
-                    DI.Constant(model),
-                    DI.Constant(getlogdensity),
-                    DI.Constant(varinfo),
-                )
-            end
+            adtype = DynamicPPL.tweak_adtype(adtype, model, x)
+            DI.prepare_gradient(
+                LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
+                adtype,
+                x,
+            )
         end
-        return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}(
-            model, getlogdensity, varinfo, adtype, prep
+        return new{
+            Tlink,
+            typeof(model),
+            typeof(adtype),
+            typeof(getlogdensity),
+            typeof(all_iden_ranges),
+            typeof(prep),
+        }(
+            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
         )
     end
 end
 
+###################################
+# LogDensityProblems.jl interface #
+###################################
 """
-    LogDensityFunction(
-        ldf::LogDensityFunction,
-        adtype::Union{Nothing,ADTypes.AbstractADType}
-    )
-
-Create a new LogDensityFunction using the model and varinfo from the given
-`ldf` argument, but with the AD type set to `adtype`. To remove the AD type,
-pass `nothing` as the second argument.
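
A usage sketch of the new constructor signature documented above (the model is hypothetical; `link=true` requests that all variables be treated as linked, without ever materialising a linked VarInfo, and the AD backend must be loaded as the docstring notes):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
using LogDensityProblems: LogDensityProblems
import ForwardDiff  # the AD backend itself must be loaded

@model function demo()
    s ~ Exponential()
    m ~ Normal(0, sqrt(s))
end

ldf = DynamicPPL.LogDensityFunction(
    demo(), DynamicPPL.getlogjoint_internal, true; adtype=AutoForwardDiff()
)
LogDensityProblems.dimension(ldf)  # 2
# Both coordinates are in unconstrained (linked) space:
LogDensityProblems.logdensity_and_gradient(ldf, [0.0, 0.0])
```
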
-""" -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end -end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) + fast_ldf_accs(getlogdensity::Function) -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. """ -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. - """ - return error(msg) +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end - -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} + model::M + getlogdensity::F + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} + strategy = InitFromParams( + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing + ) + accs = fast_ldf_accs(f.getlogdensity) + _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) + return f.getlogdensity(vi) end -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, +function LogDensityProblems.logdensity( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + )( + params ) - -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. 
Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo -) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) end -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V +function LogDensityProblems.logdensity_and_gradient( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return DI.value_and_gradient( + LogDensityAt{Tlink}( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ), + ldf._adprep, + ldf.adtype, + params, ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. -""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) end -### LogDensityProblems interface - function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils """ tweak_adtype( adtype::ADTypes.AbstractADType, model::Model, - varinfo::AbstractVarInfo, + params::AbstractVector ) Return an 'optimised' form of the adtype. This is useful for doing @@ -323,55 +310,108 @@ model. By default, this just returns the input unchanged. """ -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype - -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. 
However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. -""" -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false - -""" - getmodel(f) - -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model - -""" - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype + +############################## +# RangeLinkedVal accumulator # +############################## + +struct RangeLinkedValueAcc{L<:Union{Bool,Set{VarName}},N<:NamedTuple} <: AbstractAccumulator + should_link::L + current_index::Int + iden_varname_ranges::N + varname_ranges::Dict{VarName,RangeAndLinked} + values::Vector{Any} +end +function RangeLinkedValueAcc(should_link::Union{Bool,Set{VarName}}) + return RangeLinkedValueAcc(should_link, 1, (;), Dict{VarName,RangeAndLinked}(), Any[]) end -""" - getparams(f::LogDensityFunction) +function get_data(rlvacc::RangeLinkedValueAcc) + link_statuses = Bool[] + for ral in rlvacc.iden_varname_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in rlvacc.varname_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end + return ( + Tlink, rlvacc.iden_varname_ranges, rlvacc.varname_ranges, [v for v in rlvacc.values] + ) +end -Return the parameters of the wrapped varinfo as a vector. 
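
The closure-versus-`DI.Constant` trade-off described in the (now-removed) `use_closure` docstring can be reproduced on a toy function; `g` and the constant below are illustrative and not part of DynamicPPL:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
import ForwardDiff

g(x, c) = c * sum(abs2, x)

# 1. Closure over the inactive argument:
DI.gradient(x -> g(x, 2.0), AutoForwardDiff(), [1.0, 2.0])       # [4.0, 8.0]

# 2. Constant context: the inactive argument is passed via `DI.Constant`:
DI.gradient(g, AutoForwardDiff(), [1.0, 2.0], DI.Constant(2.0))  # [4.0, 8.0]
```
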
-""" -getparams(f::LogDensityFunction) = f.varinfo[:] +accumulator_name(::Type{<:RangeLinkedValueAcc}) = :RangeLinkedValueAcc +accumulate_observe!!(acc::RangeLinkedValueAcc, dist, val, vn) = acc +function accumulate_assume!!( + acc::RangeLinkedValueAcc, val, logjac, vn::VarName{sym}, dist::Distribution +) where {sym} + link_this_vn = if acc.should_link isa Bool + acc.should_link + else + # Set{VarName} + any(should_link_vn -> subsumes(should_link_vn, vn), acc.should_link) + end + val = if link_this_vn + to_linked_vec_transform(dist)(val) + else + to_vec_transform(dist)(val) + end + new_values = vcat(acc.values, val) + len = length(val) + range = (acc.current_index):(acc.current_index + len - 1) + ral = RangeAndLinked(range, link_this_vn) + iden_varnames, other_varnames = if getoptic(vn) === identity + merge(acc.iden_varname_ranges, (sym => ral,)), acc.varname_ranges + else + acc.varname_ranges[vn] = ral + acc.iden_varname_ranges, acc.varname_ranges + end + return RangeLinkedValueAcc( + acc.should_link, acc.current_index + len, iden_varnames, other_varnames, new_values + ) +end +function Base.copy(acc::RangeLinkedValueAcc) + return RangeLinkedValueAcc( + acc.should_link, + acc.current_index, + acc.iden_varname_ranges, + copy(acc.varname_ranges), + copy(acc.values), + ) +end +_zero(acc::RangeLinkedValueAcc) = RangeLinkedValueAcc(acc.should_link) +reset(acc::RangeLinkedValueAcc) = _zero(acc) +split(acc::RangeLinkedValueAcc) = _zero(acc) +function combine(acc1::RangeLinkedValueAcc, acc2::RangeLinkedValueAcc) + new_values = vcat(acc1.values, acc2.values) + new_current_index = acc1.current_index + acc2.current_index - 1 + acc2_iden_varnames_shifted = NamedTuple( + k => RangeAndLinked((ral.range .+ (acc1.current_index - 1)), ral.is_linked) for + (k, ral) in pairs(acc2.iden_varname_ranges) + ) + new_iden_varname_ranges = merge(acc1.iden_varname_ranges, acc2_iden_varnames_shifted) + acc2_varname_ranges_shifted = Dict{VarName,RangeAndLinked}() + for (k, ral) in acc2.varname_ranges + acc2_varname_ranges_shifted[k] = RangeAndLinked( + (ral.range .+ (acc1.current_index - 1)), ral.is_linked + ) + end + new_varname_ranges = merge(acc1.varname_ranges, acc2_varname_ranges_shifted) + return RangeLinkedValueAcc( + # TODO: using acc1.should_link is not really 'correct', but `should_link` only + # affects model evaluation and `combine` only runs at the end of model evaluation, + # so it shouldn't matter + acc1.should_link, + new_current_index, + new_iden_varname_ranges, + new_varname_ranges, + new_values, + ) +end diff --git a/src/model.jl b/src/model.jl index edb042ba9..9029318b1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() @@ -881,30 +881,56 @@ end [init_strategy::AbstractInitStrategy=InitFromPrior()] ) -Evaluate the `model` and replace the values of the model's random variables -in the given `varinfo` with new values, using a specified initialisation strategy. -If the values in `varinfo` are not set, they will be added -using a specified initialisation strategy. 
+Evaluate the `model` and replace the values of the model's random variables in the given +`varinfo` with new values, using a specified initialisation strategy. If the values in +`varinfo` are not set, they will be added using a specified initialisation strategy. If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function init!!( +@inline function init!!( + # Note that this `@inline` is mandatory for performance, especially for + # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for + # trivial models) and much slower runtime. rng::Random.AbstractRNG, model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), + vi::AbstractVarInfo, + strategy::AbstractInitStrategy=InitFromPrior(), ) - new_model = setleafcontext(model, InitContext(rng, init_strategy)) - return evaluate!!(new_model, varinfo) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + return if Threads.nthreads() > 1 + # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that + # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` + # won't have been filled with parameters prior to `init!!` being called. + # + # Note that this eltype promotion is only needed for threadsafe evaluation. In an + # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a + # similar method. In other words, it should not be here, and it should not be inside + # `unflatten` either. The problem is performance. Shifting this code around can have + # massive, inexplicable, impacts on performance. This should be investigated + # properly. + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(vi.accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + vi = DynamicPPL.setaccs!!(vi, accs) + tsvi = ThreadSafeVarInfo(resetaccs!!(vi)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + return retval, setaccs!!(tsvi_new.varinfo, DynamicPPL.getaccs(tsvi_new)) + else + return DynamicPPL._evaluate!!(model, resetaccs!!(vi)) + end end -function init!!( - model::Model, - varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=InitFromPrior(), +@inline function init!!( + model::Model, vi::AbstractVarInfo, strategy::AbstractInitStrategy=InitFromPrior() ) - return init!!(Random.default_rng(), model, varinfo, init_strategy) + # This `@inline` is also mandatory for performance + return init!!(Random.default_rng(), model, vi, strategy) end """ @@ -986,9 +1012,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... 
+            )
         else
-            :($matchingvalue(varinfo, model.args.$var))
+            :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var))
         end for var in argnames
     ]
     return quote
@@ -1006,6 +1036,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
     end
 end
 
+"""
+    get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext)
+
+Get the element type of the parameters being used to evaluate a model, using a `varinfo`
+under the given `context`. For example, when evaluating a model with ForwardDiff AD, this
+should return `ForwardDiff.Dual`.
+
+By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact
+that typically, before evaluation, the parameters will have been inserted into the VarInfo's
+metadata field.
+
+For `InitContext`, it's quite different: because `InitContext` is responsible for supplying
+the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside
+it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more
+explanation.
+"""
+function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext)
+    return get_param_eltype(vi, DynamicPPL.childcontext(ctx))
+end
+get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi)
+function get_param_eltype(::AbstractVarInfo, ctx::InitContext)
+    return get_param_eltype(ctx.strategy)
+end
+
 """
     getargnames(model::Model)
 
@@ -1107,11 +1161,6 @@ function predict end
 
 Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
 
-    returned(model::Model, values, keys)
-
-Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
-This method is deprecated; use the NamedTuple or AbstractDict version instead.
-
 # Example
 ```jldoctest
 julia> using DynamicPPL, Distributions
@@ -1141,6 +1190,3 @@ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarN
     # We can't use new_model() because that overwrites it with an InitContext of its own.
     return first(evaluate!!(new_model, vi))
 end
-Base.@deprecate returned(model::Model, values, keys) returned(
-    model, NamedTuple{keys}(values)
-)
diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl
new file mode 100644
index 000000000..940f23124
--- /dev/null
+++ b/src/onlyaccs.jl
@@ -0,0 +1,42 @@
+"""
+    OnlyAccsVarInfo
+
+This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo`
+interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for
+`InitContext`.
+
+Note that this implements almost none of the other AbstractVarInfo interface functions, and
+so using it with a different leaf context such as `DefaultContext` will result in errors.
+
+Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field.
+This is also why it only works with `InitContext`: in this case, the parameters used for
+evaluation are supplied by the context instead of the metadata.
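
A minimal usage sketch, relying on the zero-argument constructor and the `InitContext` method of `tilde_assume!!` defined just below (the model and parameter value are illustrative):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

# No metadata is stored, so parameters must come from the init strategy.
vi = DynamicPPL.OnlyAccsVarInfo()
strategy = DynamicPPL.InitFromParams((; x=0.5), nothing)
_, vi = DynamicPPL.init!!(demo(), vi, strategy)
getlogjoint(vi) ≈ logpdf(Normal(), 0.5)  # true
```
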
+""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# Minimal AbstractVarInfo interface +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +# Ideally, we'd define this together with InitContext, but alas that file comes way before +# this one, and sorting out the include order is a pain. +function tilde_assume!!( + ctx::InitContext, + dist::Distribution, + vn::VarName, + vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, +) + # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can + # cut out a lot of the code above. + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi +end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..8ee850877 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,7 +5,13 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link + DynamicPPL, + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + getlogjoint_internal, + link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -298,8 +304,10 @@ function run_ad( # Benchmark grad_time, primal_time = if benchmark - primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) - grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) + logdensity(ldf, params) # Warm-up + primal_benchmark = @be logdensity($ldf, $params) + logdensity_and_gradient(ldf, params) # Warm-up + grad_benchmark = @be logdensity_and_gradient($ldf, $params) median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. 
""" function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..75fb805dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. + +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. 
+""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 5590ac169..0c014a193 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..12a9ece71 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,30 @@ using Test end end +@testset "ParamsWithStats from LogDensityFunction" begin + @testset "$(m.f)" for m in 
DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + params = [x for x in vi[:]] + + # Get the ParamsWithStats using LogDensityFunction + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) + ps = ParamsWithStats(params, ldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..ae7332a43 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. 
- function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 44db66296..b58c3e7bc 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,16 +14,17 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) + x = varinfo[:] @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..edfd67d18 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,9 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), +ADTYPES = ( + ( + "EnzymeForward", + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + ), + ( + "EnzymeReverse", + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), + ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index fbd868f71..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,49 +1,256 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff +module DynamicPPLLDFTests -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake + +@testset "LogDensityFunction: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked 
+ DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.LogDensityFunction(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end end end -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns - # the unlinked value - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) +@testset "LogDensityFunction: interface" begin + # miscellaneous parts of the LogDensityProblems interface + @testset "dimensions" begin + @model function m1() + x ~ Normal() + y ~ Normal() + return nothing end + model = m1() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 2 + + @model function m2() + x ~ Dirichlet(ones(4)) + y ~ Categorical(x) + return nothing + end + model = m2() + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.dimension(ldf) == 5 + linked_vi = DynamicPPL.link!!(VarInfo(model), model) + ldf = DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi) + @test LogDensityProblems.dimension(ldf) == 4 end @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @model f() = x ~ Normal() + model = f() + # No adtype ldf = DynamicPPL.LogDensityFunction(model) @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{0}() - - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) - @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + # With adtype + ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{1}() end end + +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + +@testset "LogDensityFunction: performance" begin + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(ldf, x)) + @test iszero(bench.allocs) + end + end +end + +@testset "AD with LogDensityFunction" begin + # Used as the ground truth that others are compared against. 
+ ref_adtype = AutoForwardDiff() + + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = [p for p in linked_varinfo[:]] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 5e40635e6..9649aebbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using ForwardDiff using LogDensityProblems using MacroTools using MCMCChains +using Mooncake using StableRNGs using ReverseDiff using Mooncake @@ -57,7 +58,6 @@ include("test_util.jl") include("simple_varinfo.jl") 
include("model.jl") include("distribution_wrappers.jl") - include("logdensityfunction.jl") include("linking.jl") include("serialization.jl") include("pointwise_logdensities.jl") @@ -68,10 +68,11 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("chains.jl") - include("bijector.jl") end if GROUP == "All" || GROUP == "Group2" + include("bijector.jl") + include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") @@ -80,7 +81,6 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..."