import Random
import StatsBase
using CairoMakie
using Distributions
using Statistics
# Global Makie theme for every figure below: font size 18, no grid lines,
# no top/right axis spines, and legends drawn without a frame.
set_theme!(
    fontsize = 18,
    Axis = (;
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    ),
    Legend = (; framevisible = false),
)
Sequential Monte Carlo
0.1 Load packages and Makie theme
1 First example
1.1 Generate some data
# Seed the global RNG so the simulated data set is reproducible.
Random.seed!(123)

# Ground-truth parameters of the data-generating normal distribution.
true_μ = 20
true_σ = 5
true_dist = Normal(true_μ, true_σ)

# Simulated observations: 100 draws from the true distribution.
y = rand(true_dist, 100)

# Density plot of the simulated sample.
density(y)
1.2 Define the model
# Independent priors on the model parameters: a Normal(0, 10) prior on μ and
# a Normal(0, 10) prior truncated below at 0 on σ.
prior_dists = (;
    μ = Normal(0, 10),
    σ = truncated(Normal(0, 10), lower=0),
)
"""
    log_likelihood(θ, y)

Return the log-likelihood of the observations `y` under a `Normal(μ, σ)` model.

# Arguments
- `θ`: any iterable that unpacks to `(μ, σ)`.
- `y`: collection of observations.

# Returns
The sum of `logpdf(Normal(μ, σ), y)` over the observations, or `-Inf` when `σ`
is not strictly positive. The guard means a parameter vector with an invalid
scale is scored as impossible instead of being passed to the `Normal`
constructor.
"""
function log_likelihood(θ, y)
    μ, σ = θ
    σ > 0 || return -Inf
    return sum(logpdf(Normal(μ, σ), y))
end
"""
    log_prior(θ, dists = prior_dists)

Return the joint log prior density of the parameter vector `θ`.

# Arguments
- `θ`: any iterable that unpacks to `(μ, σ)`.
- `dists`: object with fields `μ` and `σ` holding the prior distribution of
  each parameter. Defaults to the global `prior_dists`, so the one-argument
  call `log_prior(θ)` keeps working.

# Returns
`logpdf(dists.μ, μ) + logpdf(dists.σ, σ)`, i.e. the priors are treated as
independent.
"""
function log_prior(θ, dists = prior_dists)
    μ, σ = θ
    return logpdf(dists.μ, μ) + logpdf(dists.σ, σ)
end
"""
    sample_prior(prior_dists, n)

Draw `n` parameter vectors from the priors.

# Arguments
- `prior_dists`: object with fields `μ` and `σ`, each something `rand` can
  sample from.
- `n`: number of draws.

# Returns
An `n × 2` matrix: column 1 holds the draws from `prior_dists.μ`, column 2 the
draws from `prior_dists.σ`, so each row is one `(μ, σ)` pair.
"""
function sample_prior(prior_dists, n)
    μ = rand(prior_dists.μ, n)
    σ = rand(prior_dists.σ, n)
    return hcat(μ, σ)
end
sample_prior (generic function with 1 method)
1.3 Sampling
"""
    SMC(y, prior_dists, num_particles, num_iterations)

Run a simple particle sampler for the parameters `(μ, σ)` given data `y`.

The particles are initialised with `sample_prior(prior_dists, num_particles)`.
Each iteration then

1. weights every particle by `exp(log_likelihood(p, y) + log_prior(p))`,
2. resamples `num_particles` particles with replacement, proportionally to
   those weights,
3. perturbs every resampled particle with independent `Normal(0, 0.5)` noise
   on both coordinates.

# Arguments
- `y`: observations passed on to `log_likelihood`.
- `prior_dists`: priors used to draw the initial particles.
- `num_particles`: number of particles (rows of the result).
- `num_iterations`: number of weight / resample / move rounds.

# Returns
A `num_particles × 2` matrix whose rows are the final `(μ, σ)` particles.

# Throws
An `ErrorException` if no particle has a finite log-weight in some iteration,
because resampling is then undefined.

NOTE(review): the weights call `log_prior(p)` without passing `prior_dists`,
so the prior density used for weighting is whatever `log_prior` refers to by
default. Confirm it matches the `prior_dists` argument.
"""
function SMC(y, prior_dists, num_particles, num_iterations)
    particles = sample_prior(prior_dists, num_particles)

    # Random-walk perturbations for (μ, σ); loop-invariant, so built once.
    proposal_dist = (Normal(0, 0.5), Normal(0, 0.5))

    for _ in 1:num_iterations
        log_weights = [log_likelihood(p, y) + log_prior(p) for p in eachrow(particles)]

        # Shift by the largest log-weight before exponentiating. Resampling only
        # depends on the weights up to a common factor, and the shift keeps the
        # best particle at weight 1 so the weights cannot all underflow to zero.
        max_log_weight = maximum(log_weights)
        isfinite(max_log_weight) ||
            error("SMC: no particle has a finite log-weight; cannot resample")
        weights = exp.(log_weights .- max_log_weight)

        # Resample particles based on weights.
        indices = StatsBase.sample(
            1:num_particles, StatsBase.Weights(weights), num_particles; replace = true
        )
        particles = particles[indices, :]

        # Move particles: μ noise is drawn before σ noise for each particle.
        for i in axes(particles, 1)
            particles[i, 1] += rand(proposal_dist[1])
            particles[i, 2] += rand(proposal_dist[2])
        end
    end
    return particles
end
# Sampler settings.
num_particles = 2000
num_iterations = 20

# Run the sampler on the simulated data; the plot below reads column 1 of the
# result as μ and column 2 as σ.
posterior_samples = SMC(y, prior_dists, num_particles, num_iterations)
# Scatter plot of the sampled (μ, σ) particles, with reference lines for the
# true parameters (red) and the sample mean / standard deviation of `y` (orange).
# Wrapped in `let` so `fig` does not leak into global scope; the block
# evaluates to the figure so it is displayed.
let
    fig = Figure(size = (800, 600))
    Axis(fig[1, 1]; xlabel = "μ", ylabel = "σ")
    scatter!(
        posterior_samples[:, 1], posterior_samples[:, 2];
        markersize = 10, color = (:black, 0.4),
    )

    # Only the vertical lines carry a label, so each pair gets one legend entry.
    vlines!(true_μ, color = :red, linewidth = 2, label = "true parameters\nμ and σ")
    hlines!(true_σ, color = :red, linewidth = 2)
    vlines!(mean(y), color = :orange, linewidth = 2, label = "sampling mean\nand std")
    hlines!(std(y), color = :orange, linewidth = 2)
    axislegend()
    fig
end