---
title: "Logistic growth - state space model with MCMC"
engine: julia
bibliography: lit.bib
format:
html:
toc: true
number-sections: true
code-fold: show
code-tools: true
---
# Load packages
```{julia}
#| code-fold: true
import Random
import StatsPlots
using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack
# Global Makie theme for every figure in this document: larger font, no grid
# lines, no top/right axis spines, and legends drawn without a frame.
set_theme!(
fontsize = 18,
Axis = (; xgridvisible = false, ygridvisible = false,
topspinevisible = false, rightspinevisible = false),
Legend = (; framevisible = false))
```
# Generate data
```{julia}
#| code-fold: true
"""
    generate_data(n_observations; σ_p, σ_o, r, K, x₀)

Simulate `n_observations` time steps of discrete logistic growth.

Three series are produced:
- `s`: the deterministic trajectory (no process noise),
- `x`: the hidden state, where the per-step growth factor is perturbed by
  `Normal(0, σ_p)` noise,
- `y`: observations drawn from a Gamma distribution with mean `x[t]` and
  standard deviation `σ_o`.

Returns a named tuple `(; ts, s, x, y, parameter)` where `parameter` holds
`σ_o`, `r`, `K` and `x₀`.
"""
function generate_data(n_observations; σ_p, σ_o, r, K, x₀)
    ts = 1:n_observations
    nsteps = length(ts)

    # Per-capita growth factor of the noise-free logistic map.
    growth(population) = 1 + r*(1 - population/K)

    # Deterministic trajectory; it consumes no random numbers.
    skeleton = Vector{Float64}(undef, nsteps)
    for t in ts
        previous = t == 1 ? x₀ : skeleton[t-1]
        skeleton[t] = growth(previous) * previous
    end

    # All process noise is drawn up front, then one observation per time step.
    process_noise = rand(Normal(0, σ_p), nsteps)
    hidden = Vector{Float64}(undef, nsteps)
    observed = Vector{Float64}(undef, nsteps)
    for t in ts
        previous = t == 1 ? x₀ : hidden[t-1]
        hidden[t] = (growth(previous) + process_noise[t]) * previous
        # Gamma(shape, scale) with mean hidden[t] and standard deviation σ_o.
        observed[t] = rand(Gamma(hidden[t]^2 / σ_o^2, σ_o^2 / hidden[t]))
    end

    return (; ts, s = skeleton, x = hidden, y = observed,
        parameter = (; σ_o, r, K, x₀))
end
# Fix the seed so the synthetic data set is reproducible.
Random.seed!(123)
true_solution = generate_data(100; σ_p = 0.05, σ_o = 20.0, r = 0.1, K = 400, x₀ = 20.0);

# Overview of the simulated observations and both underlying trajectories.
let
    @unpack ts, s, x, y = true_solution
    figure = Figure(size = (750, 300))
    axis = Axis(figure[1, 1]; xlabel = "time", ylabel = "population size")
    scatter!(axis, ts, y; color = :steelblue4, label = "observations: y")
    lines!(axis, ts, x; color = :blue, label = "true hidden state: x")
    lines!(axis, ts, s; color = :red, label = "process-model state: s")
    Legend(figure[1, 2], axis)
    figure
end
```
# Define the posterior
```{julia}
# Prior distributions, keyed by parameter name. The scalar parameters use
# normal priors truncated at zero; `ε` is the prior of each per-time-step
# process error.
my_priors = (;
    r = truncated(Normal(0.1, 0.02); lower = 0),
    K = truncated(Normal(500, 120); lower = 0),
    x₀ = truncated(Normal(15, 20); lower = 0),
    σ_o = truncated(Normal(15, 5); lower = 0),
    ε = Normal(0, 0.1),
)
# Container for everything the log-posterior needs. Instances are callable:
# `problem(θ)` evaluates the log posterior at the named tuple `θ`.
@proto struct StateSpaceModel
# time indices; used to index both `y` and the process errors `ε`
ts::UnitRange{Int64}
# observed population sizes
y::Vector{Float64}
# prior distributions keyed by parameter name (the key `:ε` is treated as a
# vector-valued parameter with one entry per time step)
prior_dists::NamedTuple
# total number of scalar parameters (model parameters plus process errors)
nparameter::Int64
# TransformVariables transformation between the unconstrained sampling space
# and the named tuple of parameters
transformation
end
# Log posterior (up to an additive constant) of the state space model.
#
# `θ` is a named tuple with the scalar parameters `r`, `K`, `x₀`, `σ_o` and
# the vector of process errors `ε` (one entry per time step in `problem.ts`).
# The result is the sum of
#   * the log prior of every entry of `problem.prior_dists`, and
#   * the log likelihood of the observations `problem.y`, where each `y[t]`
#     follows a Gamma distribution with mean `x[t]` and standard deviation
#     `σ_o`, and `x` is the hidden state produced by the process equation.
# Returns `-Inf` when the prior has zero density at `θ` or when the hidden
# state leaves the open interval (0, Inf).
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o, ε = θ

    # Float accumulator, so the variable does not change type in the loop.
    logprior = 0.0
    for k in keys(prior_dists)
        if k == :ε
            # ε is a vector: one independent prior term per time step.
            logprior += sum(logpdf.(prior_dists[k], θ[k]))
        else
            logprior += logpdf(prior_dists[k], θ[k])
        end
    end
    if logprior == -Inf
        return -Inf
    end

    loglikelihood = 0.0
    x = 0.0
    for t in ts
        # process equation
        x_last = t == 1 ? x₀ : x
        x = (1 + r*(1 - x_last/K) + ε[t]) * x_last

        # observation equation
        # The chained comparison is false for NaN as well, so a degenerate
        # state is rejected here instead of making `Gamma` throw.
        if !(0 < x < Inf)
            return -Inf
        end
        # Named `shape`/`scale` (not `α`/`θ`) so the argument `θ` is not
        # overwritten inside the loop.
        shape = x^2 / σ_o^2
        scale = σ_o^2 / x
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end
    return loglikelihood + logprior
end
"""
    sample_prior(problem; transform_p = false)

Draw one value from every prior in `problem.prior_dists`; the `:ε` entry is
drawn once per time step in `problem.ts`. Returns a named tuple with the same
keys as the priors, or, when `transform_p` is true, the result of applying
`inverse(problem.transformation, ...)` to that named tuple.
"""
function sample_prior(problem; transform_p = false)
    @unpack ts, prior_dists, transformation = problem
    nsteps = length(ts)

    # Priors are visited in key order, so the sequence of random draws is fixed.
    draws = (;
        (
            name => (name == :ε ? rand(prior_dists[name], nsteps) :
                     rand(prior_dists[name]))
            for name in keys(prior_dists)
        )...,
    )
    return transform_p ? inverse(transformation, draws) : draws
end
"""
    sample_initial_values(prob, sol; transform_p = false)

Build starting values for a chain from the known parameters in
`sol.parameter`: every scalar parameter is multiplied by `1 + δ` with
`δ ~ Normal(0, 0.01)`, and the process errors `ε` start at zero. Returns a
named tuple keyed like `prob.prior_dists`, or, when `transform_p` is true, the
result of applying `inverse(prob.transformation, ...)` to it.
"""
function sample_initial_values(prob, sol; transform_p = false)
    @unpack ts, prior_dists, transformation = prob
    @unpack parameter = sol

    # Small multiplicative perturbation so that chains do not start identically.
    jitter(value) = (1 + rand(Normal(0.0, 0.01))) * value

    start = (;
        (
            name => (name == :ε ? zeros(length(ts)) : jitter(parameter[name]))
            for name in keys(prior_dists)
        )...,
    )
    return transform_p ? inverse(transformation, start) : start
end
# Transformation between the unconstrained vector used by the sampler and the
# named tuple of parameters. All four scalar parameters have priors truncated
# at zero, so each is mapped to the positive reals (K included, so the sampler
# never proposes K < 0, which has zero prior density); the process errors are
# unconstrained, one per observation.
my_transform = as((r = asℝ₊, K = asℝ₊, x₀ = asℝ₊, σ_o = asℝ₊,
    ε = as(Array, length(true_solution.y))))

# The number of sampled parameters is taken from the transformation itself so
# that it cannot drift out of sync with the definition above.
problem = StateSpaceModel(true_solution.ts, true_solution.y, my_priors,
    dimension(my_transform), my_transform)

# Log density on the unconstrained space (includes the log Jacobian term).
ℓ = TransformedLogDensity(problem.transformation, problem)
posterior(x) = LogDensityProblems.logdensity(ℓ, x)

# Smoke test: evaluate the log posterior at one draw from the prior.
posterior(sample_prior(problem; transform_p = true))
```
# Prior predictive check
```{julia}
#| code-fold: true
# Prior predictive check: hidden-state trajectories simulated from prior draws.
# Top panel includes the process errors ε, bottom panel is the deterministic
# process model only. The simulated data are overlaid on both panels.
let
    nsamples = 200
    @unpack ts = problem
    fig = Figure(; size = (700, 800))

    for (row, with_noise) in ((1, true), (2, false))
        ax = Axis(fig[row, 1])

        for _ in 1:nsamples
            @unpack r, K, x₀, ε = sample_prior(problem)
            x = zeros(length(ts))
            for t in ts
                # process equation
                x_last = t == 1 ? x₀ : x[t-1]
                if with_noise
                    x[t] = (1 + r*(1 - x_last/K) + ε[t]) * x_last
                else
                    x[t] = (1 + r*(1 - x_last/K)) * x_last
                end
            end
            lines!(ax, ts, x; color = (:black, 0.1))
        end

        scatter!(ax, true_solution.ts, true_solution.y;
            color = :steelblue4, label = "observations: y")
        lines!(ax, true_solution.ts, true_solution.x;
            color = :blue, label = "true hidden state: x")
        lines!(ax, true_solution.ts, true_solution.s;
            color = :red, label = "process-model state: s")
    end

    fig
end
```
# Sampling
```{julia}
# Sampler settings: iterations per chain, number of chains, the `L` option
# passed through to `adaptive_rwm`, and the thinning interval.
nsamples = 1_000_000
nchains = 4
L = 1
thin = 100
nparameter = problem.nparameter

# Stored draws on the unconstrained scale, laid out as (chain, parameter, draw).
post_raw = zeros(nchains, nparameter, nsamples ÷ thin)

# One chain per loop iteration; every chain writes to its own slice.
# NOTE(review): `b = 1` presumably sets the burn-in so that the sampler returns
# exactly `nsamples ÷ thin` draws — confirm against the AdaptiveMCMC docs.
Threads.@threads for chain in 1:nchains
    # Start close to the data-generating parameters, with zero process errors.
    start = sample_initial_values(problem, true_solution; transform_p = true)
    result = adaptive_rwm(start, posterior, nsamples;
        algorithm = :am, b = 1, L = L, thin = thin, progress = false)
    post_raw[chain, :, :] = result.X
end
```
Transform the samples back to the original parameter space:
```{julia}
# Draws on the original scale, laid out as (draw, parameter, chain).
post = zeros(nsamples ÷ thin, nparameter, nchains)
for c in 1:nchains, i in 1:(nsamples ÷ thin)
    constrained = transform(problem.transformation, post_raw[c, :, i])
    # Flatten the named tuple (four scalars followed by the ε vector) into one row.
    post[i, :, c] .= vcat(values(constrained)...)
end
```
# Convergence diagnostics
## Rhat and effective sample size
```{julia}
#| code-fold: true
# Column names for `post`: the four scalar parameters keep their prior keys,
# the vector-valued fifth parameter gets one name per time step (ε_1, ε_2, …).
last_pname = keys(problem.prior_dists)[5]
last_pnames = [Symbol(last_pname, "_", i) for i in eachindex(true_solution.y)]
p_names = vcat(collect(keys(problem.prior_dists)[1:4]), last_pnames)

# Keep the second half of the stored draws.
burnin = (nsamples ÷ thin) ÷ 2
chn = Chains(post[burnin:end, :, :], p_names)

# Show the summary for the first nine columns only.
chn[:, 1:9, :]
```
## Pair plot for the model parameters
```{julia}
#| code-fold: true
# Pair plot of the first four chain columns; the data-generating parameter
# values are passed as `PairPlots.Truth`.
pairplot(chn[:, 1:4, :], PairPlots.Truth(true_solution.parameter))
```
## Trace plots for the model parameters
```{julia}
#| code-fold: true
# Plot the first four chain columns (the scalar model parameters).
StatsPlots.plot(chn[:, 1:4, :])
```
## Trace plots for (some) state parameters
```{julia}
#| code-fold: true
# Plot chain columns 5 to 9, i.e. the first five process errors ε_1 … ε_5.
StatsPlots.plot(chn[:, 5:9, :])
```
# Posterior predictive check
```{julia}
#| code-fold: true
"""
    sample_posterior(data, problem, burnin)

Pick one stored draw at random from `data`, a three-dimensional array indexed
as (chain, parameter, draw), using only draws at index `burnin` or later, and
return it transformed with `problem.transformation`.
"""
function sample_posterior(data, problem, burnin)
    nchains, _, ndraws = size(data)
    chain = sample(1:nchains)
    draw = sample(burnin:ndraws)
    return transform(problem.transformation, data[chain, :, draw])
end
# Posterior predictive check: hidden-state trajectories simulated from
# posterior draws, drawn over the simulated data. Top panel includes the
# process errors ε, bottom panel is the deterministic process model only.
let
    fig = Figure(size = (800, 800))
    panels = (
        (row = 1, axis_options = (; ylabel = "value"),
            with_noise = true, alpha = 0.02),
        (row = 2, axis_options = (; xlabel = "time", ylabel = "value"),
            with_noise = false, alpha = 0.05),
    )

    for panel in panels
        ax = Axis(fig[panel.row, 1]; panel.axis_options...)
        scatter!(ax, true_solution.ts, true_solution.y;
            color = :steelblue4, label = "observations: y")
        lines!(ax, true_solution.ts, true_solution.x;
            color = :blue, label = "true hidden state: x")
        lines!(ax, true_solution.ts, true_solution.s;
            color = :red, label = "process-model state: s")

        for _ in 1:200
            @unpack r, K, x₀, σ_o, ε = sample_posterior(post_raw, problem, burnin)
            x = zeros(length(problem.ts))
            for t in problem.ts
                # process equation
                x_last = t == 1 ? x₀ : x[t-1]
                if panel.with_noise
                    x[t] = (1 + r*(1 - x_last/K) + ε[t]) * x_last
                else
                    x[t] = (1 + r*(1 - x_last/K)) * x_last
                end
            end
            lines!(ax, true_solution.ts, x; color = (:black, panel.alpha))
        end

        # One shared legend, built from the top panel and spanning both rows.
        if panel.row == 1
            Legend(fig[1:2, 2], ax)
        end
    end

    fig
end
```