Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function create_dummy_linear_gaussian_model(
)
# Generate model matrices/vectors
μ0 = rand(rng, T, D_outer + D_inner)
Σ0s = [
Σ0 = [
rand_cov(rng, T, D_outer) zeros(T, D_outer, D_inner)
zeros(T, D_inner, D_outer) rand_cov(rng, T, D_inner)
]
Expand All @@ -76,39 +76,63 @@ function create_dummy_linear_gaussian_model(
R = rand_cov(rng, T, Dy; scale=obs_noise_scale)

# Create full model
full_model = create_homogeneous_linear_gaussian_model(μ0, Σ0s, A, b, Q, H, c, R)

# Create hierarchical model
outer_prior = HomogeneousGaussianPrior(μ0[1:D_outer], Σ0s[1:D_outer, 1:D_outer])

outer_dyn = HomogeneousLinearGaussianLatentDynamics(
A[1:D_outer, 1:D_outer], b[1:D_outer], Q[1:D_outer, 1:D_outer]
full_model = create_homogeneous_linear_gaussian_model(
μ0, PDMat(Σ0), A, b, PDMat(Q), H, c, PDMat(R)
)

outer_prior, outer_dyn = if static_arrays
prior = HomogeneousGaussianPrior(
SVector{D_outer,T}(μ0[1:D_outer]),
PDMat(SMatrix{D_outer,D_outer,T}(Σ0[1:D_outer, 1:D_outer])),
)
dyn = HomogeneousLinearGaussianLatentDynamics(
SMatrix{D_outer,D_outer,T}(A[1:D_outer, 1:D_outer]),
SVector{D_outer,T}(b[1:D_outer]),
PDMat(SMatrix{D_outer,D_outer,T}(Q[1:D_outer, 1:D_outer])),
)
prior, dyn
else
prior = HomogeneousGaussianPrior(μ0[1:D_outer], PDMat(Σ0[1:D_outer, 1:D_outer]))
dyn = HomogeneousLinearGaussianLatentDynamics(
A[1:D_outer, 1:D_outer], b[1:D_outer], PDMat(Q[1:D_outer, 1:D_outer])
)
prior, dyn
end

inner_prior, inner_dyn = if static_arrays
prior = InnerPrior(
SVector{D_inner,T}(μ0[(D_outer + 1):end]),
SMatrix{D_inner,D_inner,T}(Σ0s[(D_outer + 1):end, (D_outer + 1):end]),
PDMat(SMatrix{D_inner,D_inner,T}(Σ0[(D_outer + 1):end, (D_outer + 1):end])),
)
dyn = InnerDynamics(
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, (D_outer + 1):end]),
SVector{D_inner,T}(b[(D_outer + 1):end]),
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, 1:D_outer]),
SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end]),
PDMat(SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end])),
)
prior, dyn
else
prior = InnerPrior(μ0[(D_outer + 1):end], Σ0s[(D_outer + 1):end, (D_outer + 1):end])
prior = InnerPrior(
μ0[(D_outer + 1):end], PDMat(Σ0[(D_outer + 1):end, (D_outer + 1):end])
)
dyn = InnerDynamics(
A[(D_outer + 1):end, (D_outer + 1):end],
b[(D_outer + 1):end],
A[(D_outer + 1):end, 1:D_outer],
Q[(D_outer + 1):end, (D_outer + 1):end],
PDMat(Q[(D_outer + 1):end, (D_outer + 1):end]),
)
prior, dyn
end

obs = HomogeneousLinearGaussianObservationProcess(H[:, (D_outer + 1):end], c, R)
obs = if static_arrays
HomogeneousLinearGaussianObservationProcess(
SMatrix{Dy,D_inner,T}(H[:, (D_outer + 1):end]),
SVector{Dy,T}(c),
PDMat(SMatrix{Dy,Dy,T}(R)),
)
else
HomogeneousLinearGaussianObservationProcess(H[:, (D_outer + 1):end], c, PDMat(R))
end
hier_model = HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, obs)

return full_model, hier_model
Expand Down
5 changes: 4 additions & 1 deletion GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using StaticArrays
import PDMats: PDMat

function create_linear_gaussian_model(
rng::AbstractRNG,
Expand Down Expand Up @@ -29,7 +30,9 @@ function create_linear_gaussian_model(
R = SMatrix{Dy,Dy,T}(R)
end

return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
return create_homogeneous_linear_gaussian_model(
μ0, PDMat(Σ0), A, b, PDMat(Q), H, c, PDMat(R)
)
end

function _compute_joint(model, T::Integer)
Expand Down
17 changes: 12 additions & 5 deletions GeneralisedFilters/src/GFTest/proposals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ struct OptimalProposal{
end

function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...)
A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...)
Σ = inv(inv(Q) + H' * inv(R) * H)
μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c))
return MvNormal(μ, Σ)
# Get parameters
dyn_params = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
obs_params = GeneralisedFilters.calc_params(prop.obs, step; kwargs...)
A, b, Q = dyn_params

# Predicted state: p(x_t | x_{t-1})
state = MvNormal(A * x + b, Q)

# Update with observation: p(x_t | x_{t-1}, y_t)
state, _ = GeneralisedFilters.kalman_update(state, obs_params, y)

return state
end

"""
Expand Down
2 changes: 1 addition & 1 deletion GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module GeneralisedFilters

using AbstractMCMC: AbstractMCMC, AbstractSampler
import Distributions: MvNormal
import Distributions: MvNormal, params
import Random: AbstractRNG, default_rng, rand
import SSMProblems: prior, dyn, obs
using OffsetArrays
Expand Down
89 changes: 51 additions & 38 deletions GeneralisedFilters/src/algorithms/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export KalmanFilter, filter
using CUDA: i32
import PDMats: PDMat
import PDMats: PDMat, X_A_Xt
import LinearAlgebra: Symmetric

export KalmanFilter, KF, KalmanSmoother, KS
export BackwardInformationPredictor
Expand All @@ -12,61 +13,62 @@ KF() = KalmanFilter()

function initialise(rng::AbstractRNG, prior::GaussianPrior, filter::KalmanFilter; kwargs...)
μ0, Σ0 = calc_initial(prior; kwargs...)
return GaussianDistribution(μ0, Σ0)
return MvNormal(μ0, Σ0)
end

function predict(
rng::AbstractRNG,
dyn::LinearGaussianLatentDynamics,
algo::KalmanFilter,
iter::Integer,
state::GaussianDistribution,
state::MvNormal,
observation=nothing;
kwargs...,
)
params = calc_params(dyn, iter; kwargs...)
state = kalman_predict(state, params)
dyn_params = calc_params(dyn, iter; kwargs...)
state = kalman_predict(state, dyn_params)
return state
end

function kalman_predict(state, params)
μ, Σ = mean_cov(state)
A, b, Q = params
function kalman_predict(state, dyn_params)
μ, Σ = params(state)
A, b, Q = dyn_params

μ̂ = A * μ + b
Σ̂ = A * Σ * A' + Q
return GaussianDistribution(μ̂, Σ̂)
Σ̂ = X_A_Xt(Σ, A) + Q
return MvNormal(μ̂, Σ̂)
end

function update(
obs::LinearGaussianObservationProcess,
algo::KalmanFilter,
iter::Integer,
state::GaussianDistribution,
state::MvNormal,
observation::AbstractVector;
kwargs...,
)
params = calc_params(obs, iter; kwargs...)
state, ll = kalman_update(state, params, observation)
obs_params = calc_params(obs, iter; kwargs...)
state, ll = kalman_update(state, obs_params, observation)
return state, ll
end

function kalman_update(state, params, observation)
μ, Σ = mean_cov(state)
H, c, R = params
function kalman_update(state, obs_params, observation)
μ, Σ = params(state)
H, c, R = obs_params

# Update state
# Compute innovation distribution
m = H * μ + c
y = observation - m
S = H * Σ * H' + R
S = (S + S') / 2 # force symmetric; TODO: replace with SA-compatible hermitianpart
S_chol = cholesky(S)
K = Σ * H' / S_chol # Zygote errors when using PDMat in solve
S = PDMat(X_A_Xt(Σ, H) + R)
ȳ = observation - m
K = Σ * H' / S

state = GaussianDistribution(μ + K * y, Σ - K * H * Σ)
# Update parameters using Joseph form to ensure numerical stability
μ̂ = μ + K * ȳ
Σ̂ = PDMat(X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K))
state = MvNormal(μ̂, Σ̂)

# Compute log-likelihood
ll = logpdf(MvNormal(m, PDMat(S_chol)), observation)
ll = logpdf(MvNormal(m, S), observation)

return state, ll
end
Expand Down Expand Up @@ -146,25 +148,36 @@ function smooth(
return back_state, ll
end

import LinearAlgebra: eigen

function backward(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel,
algo::KalmanSmoother,
iter::Integer,
back_state,
back_state::MvNormal,
obs;
states_cache,
kwargs...,
)
μ, Σ = mean_cov(back_state)
μ_pred, Σ_pred = mean_cov(states_cache.proposed_states[iter + 1])
μ_filt, Σ_filt = mean_cov(states_cache.filtered_states[iter])
# Extract filtered and predicted states
μ_filt, Σ_filt = params(states_cache.filtered_states[iter])
μ_pred, Σ_pred = params(states_cache.proposed_states[iter + 1])
μ_back, Σ_back = params(back_state)

dyn_params = calc_params(model.dyn, iter + 1; kwargs...)
A, b, Q = dyn_params

G = Σ_filt * A' / Σ_pred
μ̂ = μ_filt + G * (μ_back - μ_pred)

# Σ_back - Σ_pred is negative semi-definite (Σ_pred - Σ_back is PSD and may be singular),
# so it cannot be passed through the Cholesky-based X_A_Xt; use a plain product instead
Σ̂ = Σ_filt + G * (Σ_back - Σ_pred) * G'

G = Σ_filt * model.dyn.A' * inv(Σ_pred)
μ = μ_filt .+ G * (μ .- μ_pred)
Σ = Σ_filt .+ G * (Σ .- Σ_pred) * G'
# Force symmetry
Σ̂ = PDMat(Symmetric(Σ̂))

return GaussianDistribution(μ, Σ)
return MvNormal(μ̂, Σ̂)
end

## BACKWARD INFORMATION FILTER #############################################################
Expand All @@ -182,7 +195,7 @@ struct BackwardInformationPredictor <: AbstractBackwardPredictor end
"""
backward_initialise(rng, obs, algo, iter, y; kwargs...)

Initialise a backward predictor at time `T` with observation `y`, forming the likelihood
Initialise a backward predictor at time `T` with observation `y`, forming the likelihood
p(y_T | x_T).
"""
function backward_initialise(
Expand All @@ -197,7 +210,7 @@ function backward_initialise(
R_inv = inv(R)
λ = H' * R_inv * (y - c)
Ω = H' * R_inv * H
return InformationDistribution(λ, Ω)
return InformationLikelihood(λ, Ω)
end

"""
Expand All @@ -210,7 +223,7 @@ function backward_predict(
dyn::LinearGaussianLatentDynamics,
algo::BackwardInformationPredictor,
iter::Integer,
state::InformationDistribution;
state::InformationLikelihood;
kwargs...,
)
λ, Ω = natural_params(state)
Expand All @@ -223,7 +236,7 @@ function backward_predict(
Ω̂ = A' * (I - Ω * F * inv(Λ) * F') * Ω * A
λ̂ = A' * (I - Ω * F * inv(Λ) * F') * m

return InformationDistribution(λ̂, Ω̂)
return InformationLikelihood(λ̂, Ω̂)
end

"""
Expand All @@ -235,7 +248,7 @@ function backward_update(
obs::LinearGaussianObservationProcess,
algo::BackwardInformationPredictor,
iter::Integer,
state::InformationDistribution,
state::InformationLikelihood,
y;
kwargs...,
)
Expand All @@ -246,5 +259,5 @@ function backward_update(
λ̂ = λ + H' * R_inv * (y - c)
Ω̂ = Ω + H' * R_inv * H

return InformationDistribution(λ̂, Ω̂)
return InformationLikelihood(λ̂, Ω̂)
end
3 changes: 2 additions & 1 deletion GeneralisedFilters/src/algorithms/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ end
function SSMProblems.simulate(
rng::AbstractRNG, prop::AbstractProposal, iter::Integer, state, observation; kwargs...
)
return rand(rng, SSMProblems.distribution(prop, iter, state, observation; kwargs...))
dist = SSMProblems.distribution(prop, iter, state, observation; kwargs...)
return SSMProblems.simulate_from_dist(rng, dist)
end

function SSMProblems.logdensity(
Expand Down
6 changes: 3 additions & 3 deletions GeneralisedFilters/src/ancestor_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function ancestor_weight(
algo::RBPF,
iter::Integer,
state::RBState,
ref_state::RBState{<:Any,<:InformationDistribution};
ref_state::RBState{<:Any,<:InformationLikelihood};
kwargs...,
)
trans_weight = ancestor_weight(
Expand Down Expand Up @@ -52,9 +52,9 @@ p(y_{t+1:T} | x_{t+1}).
This Gaussian implementation is based on Lemma 1 of https://arxiv.org/pdf/1505.06357
"""
function compute_marginal_predictive_likelihood(
forward_dist::GaussianDistribution, backward_dist::InformationDistribution
forward_dist::MvNormal, backward_dist::InformationLikelihood
)
μ, Σ = mean_cov(forward_dist)
μ, Σ = params(forward_dist)
λ, Ω = natural_params(backward_dist)
Γ = cholesky(Σ).L

Expand Down
Loading
Loading