From ce97457ec11281ccf50d06835c7f7f36bc566dab Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 20 Oct 2024 11:47:32 +0100 Subject: [PATCH 1/9] Add auxiliary particle filter --- Project.toml | 1 + src/GeneralisedFilters.jl | 17 ++++++ src/algorithms/apf.jl | 114 ++++++++++++++++++++++++++++++++++++ src/algorithms/bootstrap.jl | 19 +++++- src/containers.jl | 9 ++- src/resamplers.jl | 21 ++++++- 6 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 src/algorithms/apf.jl diff --git a/Project.toml b/Project.toml index 97bd2148..465cd53d 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] DataStructures = "0.18.20" GaussianDistributions = "0.5.2" +SSMProblems = "0.3.0" StatsBase = "0.34.3" [extras] diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index da9b8d25..255fe903 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -41,6 +41,22 @@ Perform a combined predict and update call on a single iteration of the filter. """ function step end +""" + reset_weights!(log_weights, filter) + +Reset container log-weights after a resampling step +""" +function reset_weights! end + +""" + update_weights! +""" +function update_weights! end + +function log_marginal end + +function update_ref! end + function initialise(model, alg; kwargs...) return initialise(default_rng(), model, alg; kwargs...) end @@ -106,6 +122,7 @@ include("models/hierarchical.jl") # Filtering/smoothing algorithms include("algorithms/bootstrap.jl") +include("algorithms/apf.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl new file mode 100644 index 00000000..80fe9770 --- /dev/null +++ b/src/algorithms/apf.jl @@ -0,0 +1,114 @@ +export AuxiliaryParticleFilter, APF + +struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter + N::Integer + resampler::RS + aux::Vector # Auxiliary weights +end + +function AuxiliaryParticleFilter( + N::Integer, threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) + conditional_resampler = ESSResampler(threshold, resampler) + return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N)) +end + +const APF = AuxiliaryParticleFilter + +function initialise( + rng::AbstractRNG, + model::StateSpaceModel{T}, + filter::AuxiliaryParticleFilter; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) where {T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) + initial_weights = fill(-log(T(filter.N)), filter.N) + + return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state) +end + +function update_weights!( + rng::AbstractRNG, filter, model, step, states, observation; kwargs... +) + simulation_weights = eta(rng, model, step, states, observation) + return states.log_weights += simulation_weights +end + +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::AuxiliaryParticleFilter, + step::Integer, + states::ParticleContainer{T}, + observation; + ref_state::Union{Nothing,AbstractVector{T}}=nothing, + kwargs..., +) where {T} + # states = update_weights!(rng, filter.eta, model, step, states.filtered, observation; kwargs...) + + # Compute auxilary weights + # POC: use the simplest approximation to the predictive likelihood + # Ideally should be something like update_weights!(filter, ...) + auxiliary_weights = map( + x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), + states.filtered.particles, + ) + state.filtered.log_weights .+= auxiliary_weights + filter.aux = auxiliary_weights + + states.proposed = resample(rng, filter.resampler, states.filtered, filter) + states.proposed.particles = map( + x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), + states.proposed.particles, + ) + + return update_ref!(states, ref_state, step) +end + +function update( + model::StateSpaceModel{T}, + filter::AuxiliaryParticleFilter, + step::Integer, + states::ParticleContainer, + observation; + kwargs..., +) where {T} + @debug "step $step" + log_increments = map( + x -> SSMProblems.logdensity(model.obs, step - 1, x, observation; kwargs...), + collect(states.proposed.particles), + ) + + states.filtered.log_weights = states.proposed.log_weights + log_increments + states.filtered.particles = states.proposed.particles + + return (states, logsumexp(log_increments) - log(T(filter.N))) +end + +function step( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + alg::AuxiliaryParticleFilter, + iter::Integer, + state, + observation; + kwargs..., +) + proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...) + filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...) + + return filtered_state, ll +end + +function reset_weights!( + state::ParticleState{T,WT}, idxs, filter::AuxiliaryParticleFilter +) where {T,WT<:Real} + # From Choping: An Introduction to sequential monte carlo, section 10.3.3 + state.log_weights = state.log_weights[idxs] - filter.aux[idxs] + return state +end + +function logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index ac8ed1eb..bb6ba5cc 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -25,7 +25,9 @@ function initialise( initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) initial_weights = zeros(T, filter.N) - return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state) + return update_ref!( + ParticleContainer(initial_states, initial_weights), ref_state, filter + ) end function predict( @@ -37,13 +39,13 @@ function predict( ref_state::Union{Nothing,AbstractVector{T}}=nothing, kwargs..., ) where {T} - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered) + states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(states.proposed), ) - return update_ref!(states, ref_state, step) + return update_ref!(states, ref_state, filter, step) end function update( @@ -64,3 +66,14 @@ function update( return states, logmarginal(states) end + +function reset_weights!( + state::ParticleState{T,WT}, idxs, filter::BootstrapFilter +) where {T,WT<:Real} + fill!(state.log_weights, -log(WT(length(state.particles)))) + return state +end + +function logmarginal(states::ParticleContainer, ::BootstrapFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end diff --git a/src/containers.jl b/src/containers.jl index e7e652b4..1bf26336 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -111,7 +111,10 @@ function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real} end function update_ref!( - pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0 + pc::ParticleContainer{T}, + ref_state::Union{Nothing,AbstractVector{T}}, + ::AbstractFilter, + step::Integer=0, ) where {T} # this comes from Nicolas Chopin's package particles if !isnothing(ref_state) @@ -122,10 +125,6 @@ function update_ref!( return pc end -function logmarginal(states::ParticleContainer) - return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) -end - ## SPARSE PARTICLE STORAGE ################################################################# Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a) diff --git a/src/resamplers.jl b/src/resamplers.jl index d0e74e51..ec9aeeac 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -8,8 +8,8 @@ export Multinomial, Systematic, Metropolis, Rejection abstract type AbstractResampler end function resample( - rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT} -) where {PT,WT} + rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT}, filter::U +) where {PT,WT,U<:AbstractFilter} weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) @@ -36,6 +36,23 @@ function resample( return new_state, idxs end +# TODO: combine this with above definition +function resample( + rng::AbstractRNG, + resampler::AbstractResampler, + states::RaoBlackwellisedParticleState{T,M,ZT}, +) where {T,M,ZT} + weights = StatsBase.weights(states) + idxs = sample_ancestors(rng, resampler, weights) + + new_state = RaoBlackwellisedParticleState( + deepcopy(states.x_particles[:, idxs]), + deepcopy(states.z_particles[idxs]), + CUDA.zeros(T, length(states)), + ) + return reset_weights!(state, idxs, filter) +end + ## CONDITIONAL RESAMPLING ################################################################## abstract type AbstractConditionalResampler <: AbstractResampler end From 9dc1f996edf14ab8a48f16cd431a271cf54b4612 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 20 Oct 2024 11:49:38 +0100 Subject: [PATCH 2/9] Foramt --- src/resamplers.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/resamplers.jl b/src/resamplers.jl index ec9aeeac..43da798c 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -8,8 +8,11 @@ export Multinomial, Systematic, Metropolis, Rejection abstract type AbstractResampler end function resample( - rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT}, filter::U -) where {PT,WT,U<:AbstractFilter} + rng::AbstractRNG, + resampler::AbstractResampler, + states::ParticleState{PT,WT}, + filter::AbstractFilter, +) where {PT,WT} weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) From 840d7a9cc7fc11bba136c362cd5c1d9e50d17f65 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 20 Oct 2024 12:10:53 +0100 Subject: [PATCH 3/9] Proper weights --- src/algorithms/apf.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index 80fe9770..4e7ccc1f 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -50,10 +50,14 @@ function predict( # Compute auxilary weights # POC: use the simplest approximation to the predictive likelihood # Ideally should be something like update_weights!(filter, ...) - auxiliary_weights = map( + predicted = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), states.filtered.particles, ) + auxiliary_weights = map( + x -> SSMProblems.logdensity(model.obs, step - 1, x, observation; kwargs...), + predicted, + ) state.filtered.log_weights .+= auxiliary_weights filter.aux = auxiliary_weights From 7daa3d16e4bb06d9a1cdacd3621bd2971e241706 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 20:45:31 +0100 Subject: [PATCH 4/9] Mean transition --- Project.toml | 2 +- src/GeneralisedFilters.jl | 2 ++ src/algorithms/apf.jl | 7 +++---- src/containers.jl | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 465cd53d..6a105b7b 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] DataStructures = "0.18.20" GaussianDistributions = "0.5.2" -SSMProblems = "0.3.0" +SSMProblems = "0.4.0" StatsBase = "0.34.3" [extras] diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index 255fe903..2ad72746 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -73,6 +73,7 @@ function filter( callback=nothing, kwargs..., ) + println("1") states = initialise(rng, model, alg; kwargs...) log_evidence = zero(eltype(model)) @@ -93,6 +94,7 @@ function filter( observations::AbstractVector; kwargs..., ) + println("2") return filter(default_rng(), model, alg, observations; kwargs...) end diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index 4e7ccc1f..74694f17 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -51,12 +51,11 @@ function predict( # POC: use the simplest approximation to the predictive likelihood # Ideally should be something like update_weights!(filter, ...) predicted = map( - x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), + x -> mean(SSMProblems.distribution(model.dyn, step, x; kwargs...)), states.filtered.particles, ) auxiliary_weights = map( - x -> SSMProblems.logdensity(model.obs, step - 1, x, observation; kwargs...), - predicted, + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted ) state.filtered.log_weights .+= auxiliary_weights filter.aux = auxiliary_weights @@ -80,7 +79,7 @@ function update( ) where {T} @debug "step $step" log_increments = map( - x -> SSMProblems.logdensity(model.obs, step - 1, x, observation; kwargs...), + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), collect(states.proposed.particles), ) diff --git a/src/containers.jl b/src/containers.jl index 1bf26336..dcc92004 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -179,7 +179,7 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64}) end function insert!( - tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64} + tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{<:Integer} ) where {T} # parents of new generation parents = getindex(tree.leaves, ancestors) @@ -212,7 +212,7 @@ function expand!(tree::ParticleTree) return tree end -function get_offspring(a::AbstractVector{Int64}) +function get_offspring(a::AbstractVector{<:Integer}) offspring = zero(a) for i in a offspring[i] += 1 From 6564f5481803b6faad1e654e34dc5b1a8cbab17a Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:03:18 +0100 Subject: [PATCH 5/9] Merge conflict --- src/GeneralisedFilters.jl | 2 -- src/algorithms/bootstrap.jl | 7 +++++++ src/algorithms/rbpf.jl | 2 +- src/containers.jl | 2 +- test/runtests.jl | 42 +++++++++++++++++++++++++++++++++++++ 5 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index 2ad72746..255fe903 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -73,7 +73,6 @@ function filter( callback=nothing, kwargs..., ) - println("1") states = initialise(rng, model, alg; kwargs...) log_evidence = zero(eltype(model)) @@ -94,7 +93,6 @@ function filter( observations::AbstractVector; kwargs..., ) - println("2") return filter(default_rng(), model, alg, observations; kwargs...) end diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index bb6ba5cc..8ec54a1b 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -5,6 +5,13 @@ struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter resampler::RS end +function BootstrapFilter( + N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) + conditional_resampler = ESSResampler(threshold, resampler) + return BootstrapFilter(N, conditional_resampler) +end + """Shorthand for `BootstrapFilter`""" const BF = BootstrapFilter diff --git a/src/algorithms/rbpf.jl b/src/algorithms/rbpf.jl index 4b1b7000..fcadb07a 100644 --- a/src/algorithms/rbpf.jl +++ b/src/algorithms/rbpf.jl @@ -72,7 +72,7 @@ end function predict( rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, t::Integer, states; kwargs... ) - states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered) + states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered, algo) states.proposed.particles = map( x -> marginal_predict(rng, model, algo, t, x; kwargs...), diff --git a/src/containers.jl b/src/containers.jl index dcc92004..6be3e695 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -1,5 +1,5 @@ using DataStructures: Stack -using Random: rand +import Random: rand ## GAUSSIAN STATES ######################################################################### diff --git a/test/runtests.jl b/test/runtests.jl index 86c0f96f..65e614ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -122,6 +122,48 @@ end @test llkf ≈ llbf atol = 0.1 end +@testitem "APF filter test" begin + using GeneralisedFilters + using SSMProblems + using StableRNGs + using PDMats + using LinearAlgebra + using Random: randexp + + T = Float32 + rng = StableRNG(1234) + σx², σy² = randexp(rng, T, 2) + + # initial state distribution + μ0 = zeros(T, 2) + Σ0 = PDMat(T[1 0; 0 1]) + + # state transition equation + A = T[1 1; 0 1] + b = T[0; 0] + Q = PDiagMat([σx²; 0]) + + # observation equation + H = T[1 0] + c = T[0;] + R = [σy²;;] + + # when working with PDMats, the Kalman filter doesn't play nicely without this + function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} + return PDMat(Symmetric(mat)) + end + + model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) + _, _, data = sample(rng, model, 20) + + bf = APF(2^10; threshold=0.8) + _, llbf = GeneralisedFilters.filter(rng, model, bf, data) + _, llkf = GeneralisedFilters.filter(rng, model, KF(), data) + + # since this is log valued, we can up the tolerance + @test llkf ≈ llbf atol = 2 +end + @testitem "Forward algorithm test" begin using GeneralisedFilters using Distributions From 72282af033eb4987690fd8600eaa6388b0f0f71a Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:28:05 +0100 Subject: [PATCH 6/9] APF --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 65e614ea..f9bb510f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -156,7 +156,7 @@ end model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) _, _, data = sample(rng, model, 20) - bf = APF(2^10; threshold=0.8) + bf = APF(2^10, threshold=0.8) _, llbf = GeneralisedFilters.filter(rng, model, bf, data) _, llkf = GeneralisedFilters.filter(rng, model, KF(), data) From 9a8b28e9c6efc888e6bc8028cd6c211d0a26dab6 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:34:01 +0100 Subject: [PATCH 7/9] new interface, fix pre-sampling --- src/algorithms/apf.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index 74694f17..e7652424 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -1,13 +1,13 @@ export AuxiliaryParticleFilter, APF -struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter +mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter N::Integer resampler::RS aux::Vector # Auxiliary weights end function AuxiliaryParticleFilter( - N::Integer, threshold::Real=1.0, resampler::AbstractResampler=Systematic() + N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N)) @@ -25,7 +25,7 @@ function initialise( initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) initial_weights = fill(-log(T(filter.N)), filter.N) - return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state) + return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter) end function update_weights!( @@ -57,16 +57,16 @@ function predict( auxiliary_weights = map( x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted ) - state.filtered.log_weights .+= auxiliary_weights + states.filtered.log_weights .+= auxiliary_weights filter.aux = auxiliary_weights - states.proposed = resample(rng, filter.resampler, states.filtered, filter) + states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), states.proposed.particles, ) - return update_ref!(states, ref_state, step) + return update_ref!(states, ref_state, filter, step) end function update( From b78efceaf0052080702f64c4dfac86b6dc1255d6 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:51:18 +0100 Subject: [PATCH 8/9] test --- src/algorithms/apf.jl | 2 +- test/runtests.jl | 45 +++---------------------------------------- 2 files changed, 4 insertions(+), 43 deletions(-) diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index e7652424..85f993de 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -7,7 +7,7 @@ mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: Abst end function AuxiliaryParticleFilter( - N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() + N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N)) diff --git a/test/runtests.jl b/test/runtests.jl index f9bb510f..4d85a0ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -109,7 +109,9 @@ end _, _, data = sample(rng, model, 20) bf = BF(2^12; threshold=0.8) + apf = APF(2^10, threshold=1.) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) + _, llapf= GeneralisedFilters.filter(rng, model, apf, data) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) xs = bf_state.filtered.particles @@ -120,48 +122,7 @@ end # since this is log valued, we can up the tolerance @test llkf ≈ llbf atol = 0.1 -end - -@testitem "APF filter test" begin - using GeneralisedFilters - using SSMProblems - using StableRNGs - using PDMats - using LinearAlgebra - using Random: randexp - - T = Float32 - rng = StableRNG(1234) - σx², σy² = randexp(rng, T, 2) - - # initial state distribution - μ0 = zeros(T, 2) - Σ0 = PDMat(T[1 0; 0 1]) - - # state transition equation - A = T[1 1; 0 1] - b = T[0; 0] - Q = PDiagMat([σx²; 0]) - - # observation equation - H = T[1 0] - c = T[0;] - R = [σy²;;] - - # when working with PDMats, the Kalman filter doesn't play nicely without this - function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} - return PDMat(Symmetric(mat)) - end - - model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) - _, _, data = sample(rng, model, 20) - - bf = APF(2^10, threshold=0.8) - _, llbf = GeneralisedFilters.filter(rng, model, bf, data) - _, llkf = GeneralisedFilters.filter(rng, model, KF(), data) - - # since this is log valued, we can up the tolerance - @test llkf ≈ llbf atol = 2 + @test llkf ≈ llapf atol = 2 end @testitem "Forward algorithm test" begin From c791095301de4ebf464d3cbe39d6db660c0f0b2d Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 27 Oct 2024 12:32:48 +0000 Subject: [PATCH 9/9] Merge gpu code --- src/GeneralisedFilters.jl | 2 ++ src/algorithms/apf.jl | 15 +++++++-------- src/algorithms/bootstrap.jl | 22 +++++++--------------- src/algorithms/rbpf.jl | 2 +- src/containers.jl | 6 +++++- src/resamplers.jl | 33 ++++++++------------------------- 6 files changed, 30 insertions(+), 50 deletions(-) diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index 255fe903..d53589f6 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -13,6 +13,8 @@ using NNlib abstract type AbstractFilter <: AbstractSampler end +abstract type AbstractParticleFilter{N} <: AbstractFilter end + """ predict([rng,] model, alg, iter, state; kwargs...) diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index 85f993de..f4a9df6a 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -1,7 +1,6 @@ export AuxiliaryParticleFilter, APF -mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter - N::Integer +mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N} resampler::RS aux::Vector # Auxiliary weights end @@ -10,7 +9,7 @@ function AuxiliaryParticleFilter( N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) - return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N)) + return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N)) end const APF = AuxiliaryParticleFilter @@ -18,12 +17,12 @@ const APF = AuxiliaryParticleFilter function initialise( rng::AbstractRNG, model::StateSpaceModel{T}, - filter::AuxiliaryParticleFilter; + filter::AuxiliaryParticleFilter{N}, ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., -) where {T} - initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) - initial_weights = fill(-log(T(filter.N)), filter.N) +) where {N,T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) + initial_weights = zeros(T, N) return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter) end @@ -86,7 +85,7 @@ function update( states.filtered.log_weights = states.proposed.log_weights + log_increments states.filtered.particles = states.proposed.particles - return (states, logsumexp(log_increments) - log(T(filter.N))) + return states, logmarginal(states, filter) end function step( diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index 8ec54a1b..a5087616 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -1,7 +1,6 @@ export BootstrapFilter, BF -struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter - N::Integer +struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N} resampler::RS end @@ -9,28 +8,21 @@ function BootstrapFilter( N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter(N, conditional_resampler) + return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler) end """Shorthand for `BootstrapFilter`""" const BF = BootstrapFilter -function BootstrapFilter( - N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) - conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter(N, conditional_resampler) -end - function initialise( rng::AbstractRNG, model::StateSpaceModel{T}, - filter::BootstrapFilter; + filter::BootstrapFilter{N}; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., -) where {T} - initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) - initial_weights = zeros(T, filter.N) +) where {N,T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) + initial_weights = zeros(T, N) return update_ref!( ParticleContainer(initial_states, initial_weights), ref_state, filter @@ -71,7 +63,7 @@ function update( states.filtered.log_weights = states.proposed.log_weights + log_increments states.filtered.particles = states.proposed.particles - return states, logmarginal(states) + return states, logmarginal(states, filter) end function reset_weights!( diff --git a/src/algorithms/rbpf.jl b/src/algorithms/rbpf.jl index fcadb07a..11b47d50 100644 --- a/src/algorithms/rbpf.jl +++ b/src/algorithms/rbpf.jl @@ -108,7 +108,7 @@ function update( states.filtered.log_weights = states.proposed.log_weights + log_increments - return states, logmarginal(states) + return states, logmarginal(states, algo) end ################################# diff --git a/src/containers.jl b/src/containers.jl index 6be3e695..e31c7f07 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -105,11 +105,15 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles) Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i] # Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i] -function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real} +function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real} fill!(state.log_weights, zero(WT)) return state.log_weights end +function logmarginal(states::ParticleContainer, ::AbstractFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end + function update_ref!( pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, diff --git a/src/resamplers.jl b/src/resamplers.jl index 43da798c..60a7b36d 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -11,13 +11,12 @@ function resample( rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT}, - filter::AbstractFilter, + filter::AbstractFilter; + weights::AbstractVector{WT}=StatsBase.weights(states) ) where {PT,WT} - weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) - - new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states))) - + new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states))) + reset_weights!(new_state, idxs, filter) return new_state, idxs end @@ -26,8 +25,9 @@ function resample( rng::AbstractRNG, resampler::AbstractResampler, states::RaoBlackwellisedParticleState{T,M,ZT}, + ::AbstractFilter; + weights=StatsBase.weights(states) ) where {T,M,ZT} - weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) new_state = RaoBlackwellisedParticleState( @@ -39,23 +39,6 @@ function resample( return new_state, idxs end -# TODO: combine this with above definition -function resample( - rng::AbstractRNG, - resampler::AbstractResampler, - states::RaoBlackwellisedParticleState{T,M,ZT}, -) where {T,M,ZT} - weights = StatsBase.weights(states) - idxs = sample_ancestors(rng, resampler, weights) - - new_state = RaoBlackwellisedParticleState( - deepcopy(states.x_particles[:, idxs]), - deepcopy(states.z_particles[idxs]), - CUDA.zeros(T, length(states)), - ) - return reset_weights!(state, idxs, filter) -end - ## CONDITIONAL RESAMPLING ################################################################## abstract type AbstractConditionalResampler <: AbstractResampler end @@ -69,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler end function resample( - rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT} + rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}, filter::AbstractFilter ) where {PT,WT} n = length(state) # TODO: computing weights twice. Should create a wrapper to avoid this @@ -78,7 +61,7 @@ function resample( @debug "ESS: $ess" if cond_resampler.threshold * n ≥ ess - return resample(rng, cond_resampler.resampler, state) + return resample(rng, cond_resampler.resampler, state, filter; weights=weights) else return deepcopy(state), collect(1:n) end