diff --git a/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
index defb9214..a1ce5a1b 100644
--- a/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
+++ b/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
@@ -56,7 +56,7 @@ function create_dummy_linear_gaussian_model(
 )
     # Generate model matrices/vectors
     μ0 = rand(rng, T, D_outer + D_inner)
-    Σ0s = [
+    Σ0 = [
         rand_cov(rng, T, D_outer) zeros(T, D_outer, D_inner)
         zeros(T, D_inner, D_outer) rand_cov(rng, T, D_inner)
     ]
@@ -76,39 +76,63 @@ function create_dummy_linear_gaussian_model(
     R = rand_cov(rng, T, Dy; scale=obs_noise_scale)
 
     # Create full model
-    full_model = create_homogeneous_linear_gaussian_model(μ0, Σ0s, A, b, Q, H, c, R)
-
-    # Create hierarchical model
-    outer_prior = HomogeneousGaussianPrior(μ0[1:D_outer], Σ0s[1:D_outer, 1:D_outer])
-
-    outer_dyn = HomogeneousLinearGaussianLatentDynamics(
-        A[1:D_outer, 1:D_outer], b[1:D_outer], Q[1:D_outer, 1:D_outer]
+    full_model = create_homogeneous_linear_gaussian_model(
+        μ0, PDMat(Σ0), A, b, PDMat(Q), H, c, PDMat(R)
     )
 
+    outer_prior, outer_dyn = if static_arrays
+        prior = HomogeneousGaussianPrior(
+            SVector{D_outer,T}(μ0[1:D_outer]),
+            PDMat(SMatrix{D_outer,D_outer,T}(Σ0[1:D_outer, 1:D_outer])),
+        )
+        dyn = HomogeneousLinearGaussianLatentDynamics(
+            SMatrix{D_outer,D_outer,T}(A[1:D_outer, 1:D_outer]),
+            SVector{D_outer,T}(b[1:D_outer]),
+            PDMat(SMatrix{D_outer,D_outer,T}(Q[1:D_outer, 1:D_outer])),
+        )
+        prior, dyn
+    else
+        prior = HomogeneousGaussianPrior(μ0[1:D_outer], PDMat(Σ0[1:D_outer, 1:D_outer]))
+        dyn = HomogeneousLinearGaussianLatentDynamics(
+            A[1:D_outer, 1:D_outer], b[1:D_outer], PDMat(Q[1:D_outer, 1:D_outer])
+        )
+        prior, dyn
+    end
+
     inner_prior, inner_dyn = if static_arrays
         prior = InnerPrior(
             SVector{D_inner,T}(μ0[(D_outer + 1):end]),
-            SMatrix{D_inner,D_inner,T}(Σ0s[(D_outer + 1):end, (D_outer + 1):end]),
+            PDMat(SMatrix{D_inner,D_inner,T}(Σ0[(D_outer + 1):end, (D_outer + 1):end])),
         )
         dyn = InnerDynamics(
             SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, (D_outer + 1):end]),
             SVector{D_inner,T}(b[(D_outer + 1):end]),
             SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, 1:D_outer]),
-            SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end]),
+            PDMat(SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end])),
        )
         prior, dyn
     else
-        prior = InnerPrior(μ0[(D_outer + 1):end], Σ0s[(D_outer + 1):end, (D_outer + 1):end])
+        prior = InnerPrior(
+            μ0[(D_outer + 1):end], PDMat(Σ0[(D_outer + 1):end, (D_outer + 1):end])
+        )
         dyn = InnerDynamics(
             A[(D_outer + 1):end, (D_outer + 1):end],
             b[(D_outer + 1):end],
             A[(D_outer + 1):end, 1:D_outer],
-            Q[(D_outer + 1):end, (D_outer + 1):end],
+            PDMat(Q[(D_outer + 1):end, (D_outer + 1):end]),
         )
         prior, dyn
     end
 
-    obs = HomogeneousLinearGaussianObservationProcess(H[:, (D_outer + 1):end], c, R)
+    obs = if static_arrays
+        HomogeneousLinearGaussianObservationProcess(
+            SMatrix{Dy,D_inner,T}(H[:, (D_outer + 1):end]),
+            SVector{Dy,T}(c),
+            PDMat(SMatrix{Dy,Dy,T}(R)),
+        )
+    else
+        HomogeneousLinearGaussianObservationProcess(H[:, (D_outer + 1):end], c, PDMat(R))
+    end
 
     hier_model = HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, obs)
 
     return full_model, hier_model
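Note on the `PDMat` wrappers introduced above: wrapping a covariance in `PDMat` factorises it once at construction and caches the Cholesky factor, which later solves, quadratic forms, and sampling can reuse. A minimal standalone sketch of that behaviour (the matrices here are illustrative, not from the test suite):

```julia
using LinearAlgebra, PDMats, StaticArrays

# A random covariance, shifted to be comfortably positive definite
Σ = let X = randn(3, 3)
    X * X' + 3I
end

P = PDMat(Σ)             # factorises once and caches the Cholesky
L = cholesky(P).L        # reuses the cached factor, no refactorisation
x = P \ ones(3)          # solves through the cached factor
q = invquad(P, ones(3))  # quadratic form ones(3)' * inv(P) * ones(3)

# The static_arrays branch relies on PDMat also accepting static matrices:
Ps = PDMat(SMatrix{2,2}(2.0, 0.5, 0.5, 1.0))
```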
diff --git a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
index 58887d9e..e2aae457 100644
--- a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
+++ b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
@@ -1,4 +1,5 @@
 using StaticArrays
+import PDMats: PDMat
 
 function create_linear_gaussian_model(
     rng::AbstractRNG,
@@ -29,7 +30,9 @@ function create_linear_gaussian_model(
         R = SMatrix{Dy,Dy,T}(R)
     end
 
-    return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
+    return create_homogeneous_linear_gaussian_model(
+        μ0, PDMat(Σ0), A, b, PDMat(Q), H, c, PDMat(R)
+    )
 end
 
 function _compute_joint(model, T::Integer)
diff --git a/GeneralisedFilters/src/GFTest/proposals.jl b/GeneralisedFilters/src/GFTest/proposals.jl
index 9b4ca3d3..2059b482 100644
--- a/GeneralisedFilters/src/GFTest/proposals.jl
+++ b/GeneralisedFilters/src/GFTest/proposals.jl
@@ -16,11 +16,18 @@ struct OptimalProposal{
 end
 
 function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...)
-    A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
-    H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...)
-    Σ = inv(inv(Q) + H' * inv(R) * H)
-    μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c))
-    return MvNormal(μ, Σ)
+    # Get parameters
+    dyn_params = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
+    obs_params = GeneralisedFilters.calc_params(prop.obs, step; kwargs...)
+    A, b, Q = dyn_params
+
+    # Predicted state: p(x_t | x_{t-1})
+    state = MvNormal(A * x + b, Q)
+
+    # Update with observation: p(x_t | x_{t-1}, y_t)
+    state, _ = GeneralisedFilters.kalman_update(state, obs_params, y)
+
+    return state
 end
 
 """
diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl
index 8d025352..a3b3464c 100644
--- a/GeneralisedFilters/src/GeneralisedFilters.jl
+++ b/GeneralisedFilters/src/GeneralisedFilters.jl
@@ -1,7 +1,7 @@
 module GeneralisedFilters
 
 using AbstractMCMC: AbstractMCMC, AbstractSampler
-import Distributions: MvNormal
+import Distributions: MvNormal, params
 import Random: AbstractRNG, default_rng, rand
 import SSMProblems: prior, dyn, obs
 using OffsetArrays
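The rewritten `OptimalProposal` above is algebraically equivalent to the information-form expression it replaces: a Kalman update of the one-step predictive N(Ax + b, Q) yields the same moments. A standalone check with made-up parameters (none of these values come from the package):

```julia
using LinearAlgebra

A = [0.9 0.1; 0.0 0.8]; b = [0.1, -0.2]; Q = [0.5 0.1; 0.1 0.4]
H = [1.0 0.0]; c = [0.0]; R = fill(0.2, 1, 1)
x = [1.0, 2.0]; y = [1.3]

# Information form (the removed code path)
Σ_info = inv(inv(Q) + H' * inv(R) * H)
μ_info = Σ_info * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c))

# Kalman-update form (the new code path, written out by hand)
μ_pred = A * x + b
S = H * Q * H' + R          # innovation covariance
K = Q * H' / S              # Kalman gain
μ_kf = μ_pred + K * (y - c - H * μ_pred)
Σ_kf = Q - K * H * Q

@assert μ_info ≈ μ_kf && Σ_info ≈ Σ_kf
```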
diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl
index 9c699363..d39de21f 100644
--- a/GeneralisedFilters/src/algorithms/kalman.jl
+++ b/GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,6 +1,7 @@
 export KalmanFilter, filter
 using CUDA: i32
-import PDMats: PDMat
+import PDMats: PDMat, X_A_Xt
+import LinearAlgebra: Symmetric
 
 export KalmanFilter, KF, KalmanSmoother, KS
 export BackwardInformationPredictor
@@ -12,7 +13,7 @@ KF() = KalmanFilter()
 
 function initialise(rng::AbstractRNG, prior::GaussianPrior, filter::KalmanFilter; kwargs...)
     μ0, Σ0 = calc_initial(prior; kwargs...)
-    return GaussianDistribution(μ0, Σ0)
+    return MvNormal(μ0, Σ0)
 end
 
 function predict(
@@ -20,53 +21,54 @@ function predict(
     dyn::LinearGaussianLatentDynamics,
     algo::KalmanFilter,
     iter::Integer,
-    state::GaussianDistribution,
+    state::MvNormal,
     observation=nothing;
     kwargs...,
 )
-    params = calc_params(dyn, iter; kwargs...)
-    state = kalman_predict(state, params)
+    dyn_params = calc_params(dyn, iter; kwargs...)
+    state = kalman_predict(state, dyn_params)
     return state
 end
 
-function kalman_predict(state, params)
-    μ, Σ = mean_cov(state)
-    A, b, Q = params
+function kalman_predict(state, dyn_params)
+    μ, Σ = params(state)
+    A, b, Q = dyn_params
     μ̂ = A * μ + b
-    Σ̂ = A * Σ * A' + Q
-    return GaussianDistribution(μ̂, Σ̂)
+    Σ̂ = X_A_Xt(Σ, A) + Q
+    return MvNormal(μ̂, Σ̂)
 end
 
 function update(
     obs::LinearGaussianObservationProcess,
     algo::KalmanFilter,
     iter::Integer,
-    state::GaussianDistribution,
+    state::MvNormal,
     observation::AbstractVector;
     kwargs...,
 )
-    params = calc_params(obs, iter; kwargs...)
-    state, ll = kalman_update(state, params, observation)
+    obs_params = calc_params(obs, iter; kwargs...)
+    state, ll = kalman_update(state, obs_params, observation)
     return state, ll
 end
 
-function kalman_update(state, params, observation)
-    μ, Σ = mean_cov(state)
-    H, c, R = params
+function kalman_update(state, obs_params, observation)
+    μ, Σ = params(state)
+    H, c, R = obs_params
 
-    # Update state
+    # Compute innovation distribution
     m = H * μ + c
-    y = observation - m
-    S = H * Σ * H' + R
-    S = (S + S') / 2 # force symmetric; TODO: replace with SA-compatibile hermitianpart
-    S_chol = cholesky(S)
-    K = Σ * H' / S_chol # Zygote errors when using PDMat in solve
+    S = PDMat(X_A_Xt(Σ, H) + R)
+    ȳ = observation - m
+    K = Σ * H' / S
 
-    state = GaussianDistribution(μ + K * y, Σ - K * H * Σ)
+    # Update parameters using Joseph form to ensure numerical stability
+    μ̂ = μ + K * ȳ
+    Σ̂ = PDMat(X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K))
+    state = MvNormal(μ̂, Σ̂)
 
     # Compute log-likelihood
-    ll = logpdf(MvNormal(m, PDMat(S_chol)), observation)
+    ll = logpdf(MvNormal(m, S), observation)
 
     return state, ll
 end
@@ -146,25 +148,36 @@ function smooth(
     return back_state, ll
 end
 
+import LinearAlgebra: eigen
+
 function backward(
     rng::AbstractRNG,
     model::LinearGaussianStateSpaceModel,
     algo::KalmanSmoother,
     iter::Integer,
-    back_state,
+    back_state::MvNormal,
     obs;
     states_cache,
     kwargs...,
 )
-    μ, Σ = mean_cov(back_state)
-    μ_pred, Σ_pred = mean_cov(states_cache.proposed_states[iter + 1])
-    μ_filt, Σ_filt = mean_cov(states_cache.filtered_states[iter])
+    # Extract filtered and predicted states
+    μ_filt, Σ_filt = params(states_cache.filtered_states[iter])
+    μ_pred, Σ_pred = params(states_cache.proposed_states[iter + 1])
+    μ_back, Σ_back = params(back_state)
+
+    dyn_params = calc_params(model.dyn, iter + 1; kwargs...)
+    A, b, Q = dyn_params
+
+    G = Σ_filt * A' / Σ_pred
+    μ̂ = μ_filt + G * (μ_back - μ_pred)
+
+    # Σ_pred - Σ_back may be singular (even though it is PSD) so cannot use X_A_Xt with Cholesky
+    Σ̂ = Σ_filt + G * (Σ_back - Σ_pred) * G'
 
-    G = Σ_filt * model.dyn.A' * inv(Σ_pred)
-    μ = μ_filt .+ G * (μ .- μ_pred)
-    Σ = Σ_filt .+ G * (Σ .- Σ_pred) * G'
+    # Force symmetry
+    Σ̂ = PDMat(Symmetric(Σ̂))
 
-    return GaussianDistribution(μ, Σ)
+    return MvNormal(μ̂, Σ̂)
 end
 
 ## BACKWARD INFORMATION FILTER #############################################################
@@ -182,7 +195,7 @@ struct BackwardInformationPredictor <: AbstractBackwardPredictor end
 """
     backward_initialise(rng, obs, algo, iter, y; kwargs...)
 
-Initialise a backward predictor at time `T` with observation `y`, forming the likelihood 
+Initialise a backward predictor at time `T` with observation `y`, forming the likelihood
 p(y_T | x_T).
 """
 function backward_initialise(
@@ -197,7 +210,7 @@ function backward_initialise(
     R_inv = inv(R)
     λ = H' * R_inv * (y - c)
     Ω = H' * R_inv * H
-    return InformationDistribution(λ, Ω)
+    return InformationLikelihood(λ, Ω)
 end
 
 """
@@ -210,7 +223,7 @@ function backward_predict(
     dyn::LinearGaussianLatentDynamics,
     algo::BackwardInformationPredictor,
     iter::Integer,
-    state::InformationDistribution;
+    state::InformationLikelihood;
     kwargs...,
 )
     λ, Ω = natural_params(state)
@@ -223,7 +236,7 @@ function backward_predict(
     Ω̂ = A' * (I - Ω * F * inv(Λ) * F') * Ω * A
     λ̂ = A' * (I - Ω * F * inv(Λ) * F') * m
 
-    return InformationDistribution(λ̂, Ω̂)
+    return InformationLikelihood(λ̂, Ω̂)
 end
 
 """
@@ -235,7 +248,7 @@ function backward_update(
     obs::LinearGaussianObservationProcess,
     algo::BackwardInformationPredictor,
     iter::Integer,
-    state::InformationDistribution,
+    state::InformationLikelihood,
     y;
     kwargs...,
 )
@@ -246,5 +259,5 @@ function backward_update(
     λ̂ = λ + H' * R_inv * (y - c)
     Ω̂ = Ω + H' * R_inv * H
 
-    return InformationDistribution(λ̂, Ω̂)
+    return InformationLikelihood(λ̂, Ω̂)
 end
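A quick numerical check of the Joseph-form update adopted in `kalman_update`: it matches the textbook covariance Σ − KHΣ while remaining symmetric positive definite by construction, which is what allows the result to be wrapped straight back into a `PDMat`. Standalone sketch with illustrative matrices:

```julia
using LinearAlgebra, PDMats

Σ = PDMat([1.0 0.3; 0.3 0.5])
H = [1.0 0.0]
R = fill(0.1, 1, 1)

S = X_A_Xt(Σ, H) + R     # innovation covariance H * Σ * H' + R
K = Matrix(Σ) * H' / S   # Kalman gain

# Joseph form (I - KH) Σ (I - KH)' + K R K' versus the textbook Σ - K H Σ
Σ_joseph   = X_A_Xt(Σ, I - K * H) + X_A_Xt(PDMat(R), K)
Σ_textbook = Matrix(Σ) - K * H * Matrix(Σ)

@assert Σ_joseph ≈ Σ_textbook
@assert isposdef(Symmetric(Σ_joseph))
```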
""" function backward_initialise( @@ -197,7 +210,7 @@ function backward_initialise( R_inv = inv(R) λ = H' * R_inv * (y - c) Ω = H' * R_inv * H - return InformationDistribution(λ, Ω) + return InformationLikelihood(λ, Ω) end """ @@ -210,7 +223,7 @@ function backward_predict( dyn::LinearGaussianLatentDynamics, algo::BackwardInformationPredictor, iter::Integer, - state::InformationDistribution; + state::InformationLikelihood; kwargs..., ) λ, Ω = natural_params(state) @@ -223,7 +236,7 @@ function backward_predict( Ω̂ = A' * (I - Ω * F * inv(Λ) * F') * Ω * A λ̂ = A' * (I - Ω * F * inv(Λ) * F') * m - return InformationDistribution(λ̂, Ω̂) + return InformationLikelihood(λ̂, Ω̂) end """ @@ -235,7 +248,7 @@ function backward_update( obs::LinearGaussianObservationProcess, algo::BackwardInformationPredictor, iter::Integer, - state::InformationDistribution, + state::InformationLikelihood, y; kwargs..., ) @@ -246,5 +259,5 @@ function backward_update( λ̂ = λ + H' * R_inv * (y - c) Ω̂ = Ω + H' * R_inv * H - return InformationDistribution(λ̂, Ω̂) + return InformationLikelihood(λ̂, Ω̂) end diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index c4a058ea..0d984940 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -16,7 +16,8 @@ end function SSMProblems.simulate( rng::AbstractRNG, prop::AbstractProposal, iter::Integer, state, observation; kwargs... ) - return rand(rng, SSMProblems.distribution(prop, iter, state, observation; kwargs...)) + dist = SSMProblems.distribution(prop, iter, state, observation; kwargs...) + return SSMProblems.simulate_from_dist(rng, dist) end function SSMProblems.logdensity( diff --git a/GeneralisedFilters/src/ancestor_sampling.jl b/GeneralisedFilters/src/ancestor_sampling.jl index 8348a95a..deb92b45 100644 --- a/GeneralisedFilters/src/ancestor_sampling.jl +++ b/GeneralisedFilters/src/ancestor_sampling.jl @@ -13,7 +13,7 @@ function ancestor_weight( algo::RBPF, iter::Integer, state::RBState, - ref_state::RBState{<:Any,<:InformationDistribution}; + ref_state::RBState{<:Any,<:InformationLikelihood}; kwargs..., ) trans_weight = ancestor_weight( @@ -52,9 +52,9 @@ p(y_{t+1:T} | x_{t+1}). This Gaussian implementation is based on Lemma 1 of https://arxiv.org/pdf/1505.06357 """ function compute_marginal_predictive_likelihood( - forward_dist::GaussianDistribution, backward_dist::InformationDistribution + forward_dist::MvNormal, backward_dist::InformationLikelihood ) - μ, Σ = mean_cov(forward_dist) + μ, Σ = params(forward_dist) λ, Ω = natural_params(backward_dist) Γ = cholesky(Σ).L diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 6aba7cda..938134e8 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -98,34 +98,40 @@ end ## GAUSSIAN STATES ######################################################################### -struct GaussianDistribution{PT,ΣT} - μ::PT - Σ::ΣT -end +""" + InformationLikelihood -function mean_cov(state::GaussianDistribution) - return state.μ, state.Σ -end +A container representing an unnormalized Gaussian likelihood p(y | x) in information form, +parameterized by natural parameters (λ, Ω). 
diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl
index 6aba7cda..938134e8 100644
--- a/GeneralisedFilters/src/containers.jl
+++ b/GeneralisedFilters/src/containers.jl
@@ -98,34 +98,40 @@ end
 
 ## GAUSSIAN STATES #########################################################################
 
-struct GaussianDistribution{PT,ΣT}
-    μ::PT
-    Σ::ΣT
-end
+"""
+    InformationLikelihood
 
-function mean_cov(state::GaussianDistribution)
-    return state.μ, state.Σ
-end
+A container representing an unnormalised Gaussian likelihood p(y | x) in information form,
+parameterised by the natural parameters (λ, Ω).
+
+The unnormalised log-likelihood is given by:
+    log p(y | x) = λ'x - (1/2)x'Ωx + const
 
-struct InformationDistribution{λT,ΩT}
+This representation is particularly useful in backward filtering algorithms and
+Rao-Blackwellised particle filtering, where it represents the predictive likelihood
+p(y_{t:T} | x_t) of the current and future observations.
+
+# Fields
+- `λ::λT`: The natural parameter vector (information vector)
+- `Ω::ΩT`: The natural parameter matrix (information/precision matrix)
+
+# See also
+- [`natural_params`](@ref): Extract the natural parameters (λ, Ω)
+- [`BackwardInformationPredictor`](@ref): Algorithm that uses this representation
+"""
+struct InformationLikelihood{λT,ΩT}
     λ::λT
     Ω::ΩT
 end
 
-function natural_params(state::InformationDistribution)
-    return state.λ, state.Ω
-end
+"""
+    natural_params(state::InformationLikelihood)
 
-# Conversions — explicit since these may fail if the covariance/precision is not invertible
-function GaussianDistribution(state::InformationDistribution)
-    λ, Ω = natural_params(state)
-    Σ = inv(Ω)
-    μ = Σ * λ
-    return GaussianDistribution(μ, Σ)
-end
-function InformationDistribution(state::GaussianDistribution)
-    μ, Σ = mean_cov(state)
-    Ω = inv(Σ)
-    λ = Ω * μ
-    return InformationDistribution(λ, Ω)
+Extract the natural parameters (λ, Ω) from an InformationLikelihood.
+
+Returns a tuple `(λ, Ω)` where λ is the information vector and Ω is the
+information/precision matrix.
+"""
+function natural_params(state::InformationLikelihood)
+    return state.λ, state.Ω
 end
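The identity in the new docstring can be verified directly: for y | x ~ N(Hx + c, R), the x-dependent part of log p(y | x) is exactly λ'x - (1/2)x'Ωx with λ = H'R⁻¹(y - c) and Ω = H'R⁻¹H. A standalone check with illustrative values:

```julia
using LinearAlgebra

H = [1.0 0.5; 0.2 1.0]
c = [0.1, -0.3]
R = [0.2 0.0; 0.0 0.3]
y = [0.7, 0.4]

λ = H' * inv(R) * (y - c)
Ω = H' * inv(R) * H

f(x) = dot(λ, x) - dot(x, Ω * x) / 2                      # information form
g(x) = -dot(y - c - H * x, inv(R) * (y - c - H * x)) / 2  # direct Gaussian form

x1, x2 = [0.3, -0.1], [-1.0, 2.0]
# The two forms agree up to an x-independent constant:
@assert f(x1) - f(x2) ≈ g(x1) - g(x2)
```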
diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl
index 57be3122..418bba22 100644
--- a/GeneralisedFilters/src/models/linear_gaussian.jl
+++ b/GeneralisedFilters/src/models/linear_gaussian.jl
@@ -7,9 +7,21 @@ export HomogeneousGaussianPrior
 export HomogeneousLinearGaussianLatentDynamics
 export HomogeneousLinearGaussianObservationProcess
 
-import SSMProblems: distribution
-import Distributions: MvNormal
+import SSMProblems: distribution, simulate_from_dist
+import Distributions: MvNormal, params
 import LinearAlgebra: cholesky
+import Random: AbstractRNG, randn
+import PDMats: PDMat
+using StaticArrays
+
+# Custom sampling for MvNormal with static arrays to return SVector instead of Vector
+function SSMProblems.simulate_from_dist(
+    rng::AbstractRNG, d::MvNormal{T,<:PDMat{T,<:SMatrix{D,D,T}},SVector{D,T}}
+) where {T,D}
+    μ, Σ = params(d)
+    z = @SVector randn(rng, D)
+    return μ + cholesky(Σ).L * z
+end
 
 abstract type GaussianPrior <: StatePrior end
@@ -49,11 +61,6 @@ const LinearGaussianStateSpaceModel = StateSpaceModel{
     <:GaussianPrior,<:LinearGaussianLatentDynamics,<:LinearGaussianObservationProcess
 }
 
-function rb_eltype(model::LinearGaussianStateSpaceModel)
-    μ0, Σ0 = calc_initial(model.prior)
-    return Gaussian{typeof(μ0),typeof(Σ0)}
-end
-
 #######################
 #### DISTRIBUTIONS ####
 #######################
@@ -81,7 +88,7 @@ end
 #### HOMOGENEOUS LINEAR GAUSSIAN MODEL ####
 ###########################################
 
-struct HomogeneousGaussianPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: GaussianPrior
+struct HomogeneousGaussianPrior{XT<:AbstractVector,ΣT<:PDMat} <: GaussianPrior
     μ0::XT
     Σ0::ΣT
 end
@@ -89,7 +96,7 @@ calc_μ0(prior::HomogeneousGaussianPrior; kwargs...) = prior.μ0
 calc_Σ0(prior::HomogeneousGaussianPrior; kwargs...) = prior.Σ0
 
 struct HomogeneousLinearGaussianLatentDynamics{
-    AT<:AbstractMatrix,bT<:AbstractVector,QT<:AbstractMatrix
+    AT<:AbstractMatrix,bT<:AbstractVector,QT<:PDMat
 } <: LinearGaussianLatentDynamics
     A::AT
     b::bT
@@ -100,7 +107,7 @@ calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn
 calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.Q
 
 struct HomogeneousLinearGaussianObservationProcess{
-    HT<:AbstractMatrix,cT<:AbstractVector,RT<:AbstractMatrix
+    HT<:AbstractMatrix,cT<:AbstractVector,RT<:PDMat
 } <: LinearGaussianObservationProcess
     H::HT
     c::cT
diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl
index 14130e16..faa0e9f7 100644
--- a/GeneralisedFilters/test/runtests.jl
+++ b/GeneralisedFilters/test/runtests.jl
@@ -107,6 +107,7 @@ end
     using SSMProblems
     using StableRNGs
    using StaticArrays
+    using PDMats
 
     D = 2
     rng = StableRNG(1234)
@@ -123,16 +124,18 @@ end
     R = @SMatrix rand(rng, D, D)
     R = R * R'
 
-    model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
+    model = create_homogeneous_linear_gaussian_model(
+        μ0, PDMat(Σ0), A, b, PDMat(Q), H, c, PDMat(R)
+    )
 
     _, _, ys = sample(rng, model, 2)
 
     state, _ = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys)
 
     # Verify returned values are still StaticArrays
-    # @test ys[2] isa SVector{D,Float64}  # TODO: this fails due to use of MvNormal
+    @test ys[2] isa SVector{D,Float64}
     @test state.μ isa SVector{D,Float64}
-    @test state.Σ isa SMatrix{D,D,Float64}
+    @test state.Σ isa PDMat{Float64,SMatrix{D,D,Float64,D * D}}
 end
 
 @testitem "Kalman smoother test" begin
@@ -147,7 +150,9 @@ end
 
     for Dy in Dys
         rng = StableRNG(1234)
-        model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy)
+        model = GeneralisedFilters.GFTest.create_linear_gaussian_model(
+            rng, Dx, Dy; static_arrays=true
+        )
         _, _, ys = sample(rng, model, 2)
 
         states, ll = GeneralisedFilters.smooth(rng, model, KalmanSmoother(), ys)
@@ -578,6 +583,8 @@ end
         rng, D_outer, D_inner, D_obs, T; static_arrays=true
     )
     _, _, ys = sample(rng, full_model, K)
+    # Convert to static arrays
+    ys = [SVector{1,T}(y) for y in ys]
 
     # Kalman smoother
     state, _ = GeneralisedFilters.smooth(
@@ -589,8 +596,8 @@ end
     ref_traj = nothing
     trajectory_samples = []
 
+    cb = GeneralisedFilters.DenseAncestorCallback(nothing)
     for i in 1:N_steps
-        cb = GeneralisedFilters.DenseAncestorCallback(nothing)
         bf_state, _ = GeneralisedFilters.filter(
             rng, hier_model, rbpf, ys; ref_state=ref_traj, callback=cb
         )
@@ -622,9 +629,10 @@ end
             μ_filt = trajectory_samples[i][t].z.μ
             Σ_filt = trajectory_samples[i][t].z.Σ
             μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].x
-            Σ_pred = A * Σ_filt * A' + Q
+            Σ_pred = X_A_Xt(Σ_filt, A) + Q
+            Σ_pred = PDMat(Symmetric(Σ_pred))
 
-            G = Σ_filt * A' * inv(Σ_pred)
+            G = Σ_filt * A' / Σ_pred
             μ = μ_filt .+ G * (μ .- μ_pred)
             Σ = Σ_filt .+ G * (Σ .- Σ_pred) * G'
         end
@@ -649,7 +657,7 @@ end
     using LogExpFunctions
     import SSMProblems: prior, dyn, obs
 
-    import GeneralisedFilters: resampler, resample, move, RBState, InformationDistribution
+    import GeneralisedFilters: resampler, resample, move, RBState, InformationLikelihood
 
     using OffsetArrays
 
@@ -678,7 +686,7 @@ end
     N_steps = N_burnin + N_sample
     rbpf = RBPF(BF(N_particles; threshold=0.8), KalmanFilter())
     ref_traj = nothing
-    predictive_likelihoods = Vector{InformationDistribution{Vector{T},Matrix{T}}}(undef, K)
+    predictive_likelihoods = Vector{InformationLikelihood{Vector{T},Matrix{T}}}(undef, K)
     trajectory_samples = []
 
     for i in 1:N_steps
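Context for the `simulate_from_dist` specialisation in `models/linear_gaussian.jl` and the re-enabled static-array tests: `rand` on an `MvNormal` allocates a plain `Vector` even when its parameters are static, whereas drawing through the cached Cholesky factor keeps the `SVector` type. A standalone illustration (assuming Distributions, PDMats, and StaticArrays):

```julia
using Distributions, LinearAlgebra, PDMats, Random, StaticArrays

rng = Xoshiro(1)
μ = @SVector [0.0, 1.0]
Σ = PDMat(SMatrix{2,2}(1.0, 0.2, 0.2, 0.5))
d = MvNormal(μ, Σ)

rand(rng, d) isa SVector             # false: rand falls back to Vector{Float64}

z = @SVector randn(rng, 2)
(μ + cholesky(Σ).L * z) isa SVector  # true: the approach taken in the override
```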
diff --git a/SSMProblems/src/SSMProblems.jl b/SSMProblems/src/SSMProblems.jl
index 669d2cf0..57c33dd8 100644
--- a/SSMProblems/src/SSMProblems.jl
+++ b/SSMProblems/src/SSMProblems.jl
@@ -11,6 +11,7 @@ import Distributions: logpdf
 export StatePrior, LatentDynamics, ObservationProcess
 export AbstractStateSpaceModel, StateSpaceModel
 export prior, dyn, obs
+export simulate_from_dist
 
 """
 Initial state prior of a state space model.
@@ -107,6 +108,19 @@ function distribution(obs::ObservationProcess, step::Integer, state; kwargs...)
     throw(MethodError(distribution, (obs, step, state, kwargs...)))
 end
 
+"""
+    simulate_from_dist(rng::AbstractRNG, dist)
+
+Sample from a distribution object. This is a fallback method that can be overridden for
+custom distributions or to provide optimized sampling for specific types.
+
+The default implementation simply calls `rand(rng, dist)`. A common use case is to provide
+specialized sampling that returns static arrays instead of regular vectors.
+"""
+function simulate_from_dist(rng::AbstractRNG, dist)
+    return rand(rng, dist)
+end
+
 """
     simulate([rng::AbstractRNG], prior::StatePrior; kwargs...)
 
@@ -121,7 +135,7 @@ corresponding `distribution()` method.
 See also [`StatePrior`](@ref).
 """
 function simulate(rng::AbstractRNG, prior::StatePrior; kwargs...)
-    return rand(rng, distribution(prior; kwargs...))
+    return simulate_from_dist(rng, distribution(prior; kwargs...))
 end
 simulate(prior::StatePrior; kwargs...) = simulate(default_rng(), prior; kwargs...)
 
@@ -141,7 +155,7 @@ See also [`LatentDynamics`](@ref).
 function simulate(
     rng::AbstractRNG, dyn::LatentDynamics, step::Integer, prev_state; kwargs...
 )
-    return rand(rng, distribution(dyn, step, prev_state; kwargs...))
+    return simulate_from_dist(rng, distribution(dyn, step, prev_state; kwargs...))
 end
 function simulate(dynamics::LatentDynamics, prev_state, step; kwargs...)
     return simulate(default_rng(), dynamics, prev_state, step; kwargs...)
 end
@@ -163,7 +177,7 @@ See also [`ObservationProcess`](@ref).
 function simulate(
     rng::AbstractRNG, obs::ObservationProcess, step::Integer, state; kwargs...
 )
-    return rand(rng, distribution(obs, step, state; kwargs...))
+    return simulate_from_dist(rng, distribution(obs, step, state; kwargs...))
 end
 function simulate(obs::ObservationProcess, step::Integer, state; kwargs...)
     return simulate(default_rng(), obs, step, state; kwargs...)
 end
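With the `simulate_from_dist` hook in place, downstream code can specialise sampling for its own distribution-like objects while everything else keeps the `rand` fallback. A hypothetical example (`QuantisedGaussian` is invented purely for illustration and is not part of SSMProblems):

```julia
using Random
import SSMProblems: simulate_from_dist

# A made-up distribution-like type, used only to demonstrate the extension point
struct QuantisedGaussian
    μ::Float64
    σ::Float64
end

# Specialised sampler: draw a normal variate, then quantise to one decimal place
function simulate_from_dist(rng::AbstractRNG, d::QuantisedGaussian)
    return round(d.μ + d.σ * randn(rng); digits=1)
end

simulate_from_dist(Xoshiro(1), QuantisedGaussian(0.0, 1.0))
```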