Logistic growth - state space model with MCMC

1 Load packages

Code
import Random
import StatsPlots

using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack

# Global Makie theme for every figure below: larger font, no grid lines,
# no top/right axis spines, and legends drawn without a frame.
set_theme!(;
    fontsize = 18,
    Axis = (
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    ),
    Legend = (framevisible = false,),
)

2 Generate data

Code
# Simulate `n_observations` steps of discrete logistic growth.
#
# Keywords:
#   σ_p  standard deviation of the multiplicative process noise
#   σ_o  standard deviation of the Gamma observation noise
#   r    growth rate
#   K    carrying capacity
#   x₀   population size before the first step
#
# Returns a NamedTuple with
#   ts         the time indices 1:n_observations
#   s          the deterministic trajectory (no process noise)
#   x          the hidden state including process noise
#   y          the noisy observations of x
#   parameter  the data-generating values (σ_o, r, K, x₀)
function generate_data(n_observations; σ_p, σ_o, r, K, x₀)
    ts = 1:n_observations
    n_steps = length(ts)

    deterministic = Vector{Float64}(undef, n_steps)
    hidden = Vector{Float64}(undef, n_steps)
    observed = Vector{Float64}(undef, n_steps)

    # All process deviations are drawn up front, before any observation.
    process_noise = rand(Normal(0, σ_p), n_steps)

    prev_s = x₀
    prev_x = x₀
    for t in ts
        deterministic[t] = (1 + r*(1 - prev_s/K)) * prev_s
        hidden[t] = (1 + r*(1 - prev_x/K) + process_noise[t]) * prev_x

        # Gamma(shape, scale) with mean hidden[t] and variance σ_o^2.
        shape = hidden[t]^2 / σ_o^2
        scale = σ_o^2 / hidden[t]
        observed[t] = rand(Gamma(shape, scale))

        prev_s = deterministic[t]
        prev_x = hidden[t]
    end

    return (; ts, s = deterministic, x = hidden, y = observed,
            parameter = (; σ_o, r, K, x₀))
end

# Fix the seed so the synthetic data set is reproducible, then simulate
# 100 time steps with the data-generating parameter values.
Random.seed!(123)
true_solution = generate_data(
    100;
    σ_p = 0.05,
    σ_o = 20.0,
    r = 0.1,
    K = 400,
    x₀ = 20.0,
);

# Plot the simulated observations together with the hidden state and the
# noise-free process-model trajectory.
let
    figure = Figure(size = (750, 300))
    axis = Axis(figure[1, 1]; xlabel = "time", ylabel = "population size")

    @unpack ts, s, x, y = true_solution
    scatter!(axis, ts, y; color = :steelblue4, label = "observations: y")
    lines!(axis, ts, x; color = :blue, label = "true hidden state: x")
    lines!(axis, ts, s; color = :red, label = "process-model state: s")

    Legend(figure[1, 2], axis)
    figure
end

3 Define the posterior

Code
# Prior distributions, keyed by parameter name. r, K, x₀ and σ_o are scalars
# with normal priors truncated at zero; the single ε prior is applied to
# every element of the vector of process deviations.
my_priors = (
    r = truncated(Normal(0.1, 0.02); lower = 0),
    K = truncated(Normal(500, 120); lower = 0),
    x₀ = truncated(Normal(15, 20); lower = 0),
    σ_o = truncated(Normal(15, 5); lower = 0),
    ε = Normal(0, 0.1),
)

# Container for everything needed to evaluate the log posterior of the
# state-space model; instances are callable (see the method defined below).
# NOTE(review): `@proto` comes from ProtoStructs — presumably used so the
# struct can be redefined while prototyping; confirm.
@proto struct StateSpaceModel
    # time indices of the observations; also used to index `y` and `ε`
    ts::UnitRange{Int64} 
    # observed population sizes, one per time index
    y::Vector{Float64}
    # prior distribution for each parameter, keyed by parameter name
    prior_dists::NamedTuple
    # total length of the flat parameter vector (used to size sample arrays)
    nparameter::Int64
    # mapping between the flat unconstrained vector and the named parameters
    transformation
end

# Evaluate the unnormalised log posterior density at `θ`.
#
# `θ` must provide the fields `r`, `K`, `x₀`, `σ_o` (scalars) and `ε` (one
# process deviation per time index in `problem.ts`).
#
# Returns `logprior + loglikelihood`, or `-Inf` as soon as the prior has zero
# density or the hidden state leaves the support of the observation model.
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o, ε = θ

    # Float accumulator so the variable never changes type inside the loop.
    logprior = 0.0
    for k in keys(prior_dists)
        if k == :ε
            # One shared prior, applied to every process deviation.
            logprior += sum(logpdf.(prior_dists[k], θ[k]))
        else
            logprior += logpdf(prior_dists[k], θ[k])
        end
    end

    # Skip the likelihood entirely when the prior already rules θ out.
    if logprior == -Inf
        return -Inf
    end

    loglikelihood = 0.0
    x = x₀
    for t in ts
        # process equation: logistic growth with a multiplicative deviation
        x = (1 + r*(1 - x/K) + ε[t]) * x

        # A non-positive, NaN or infinite state has no valid Gamma observation
        # model; report zero density instead of letting `Gamma` throw.
        if !(isfinite(x) && x > 0)
            return -Inf
        end

        # observation equation: Gamma with mean x and variance σ_o^2.
        # The locals are deliberately not called θ, which is the argument.
        shape = x^2 / σ_o^2
        scale = σ_o^2 / x
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end

    return loglikelihood + logprior
end

# Draw one complete parameter set from the priors stored in `problem`.
#
# Returns a NamedTuple keyed like `problem.prior_dists`. With
# `transform_p = true` the draw is instead passed through
# `inverse(problem.transformation, ·)`, i.e. returned in the sampler's space.
function sample_prior(problem; transform_p = false)
    @unpack ts, prior_dists, transformation = problem

    # ε needs one draw per time index; every other parameter is a scalar.
    function draw(k)
        dist = prior_dists[k]
        return k == :ε ? rand(dist, length(ts)) : rand(dist)
    end

    p = (; (k => draw(k) for k in keys(prior_dists))...)

    return transform_p ? inverse(transformation, p) : p
end


# Build a starting point for a chain close to the reference values in
# `sol.parameter`.
#
# Each scalar parameter is its reference value multiplied by
# (1 + a Normal(0, 0.01) draw); all process deviations ε start at zero.
# With `transform_p = true` the point is passed through
# `inverse(prob.transformation, ·)` before being returned.
function sample_initial_values(prob, sol; transform_p = false)
    @unpack ts, prior_dists, transformation = prob
    @unpack parameter = sol

    function start_value(k)
        if k == :ε
            return zeros(length(ts))
        end
        jitter = 1 + rand(Normal(0.0, 0.01))
        return jitter * parameter[k]
    end

    p = (; (k => start_value(k) for k in keys(prior_dists))...)

    return transform_p ? inverse(transformation, p) : p
end
    
# Mapping between the sampler's flat unconstrained vector and the named model
# parameters. r, x₀ and σ_o are constrained to be positive; ε is a vector with
# one entry per observation.
# NOTE(review): K uses asℝ although its prior is truncated at 0, so negative
# proposals are only rejected through a log prior of -Inf. asℝ₊ would match
# the other positive parameters — confirm whether this is intentional.
my_transform = as((r = asℝ₊, K = asℝ, x₀ = asℝ₊, σ_o = asℝ₊,
                   ε = as(Array, length(true_solution.y))))

# 4 scalar parameters plus one process deviation per observation.
problem = StateSpaceModel(true_solution.ts, true_solution.y, my_priors,
                          4 + length(true_solution.y), my_transform)

# Log density on the unconstrained space. The left-hand side `ℓ` had been lost
# here, which left the assignment invalid and `posterior` without its target.
ℓ = TransformedLogDensity(problem.transformation, problem)
posterior(x) = LogDensityProblems.logdensity(ℓ, x)

# Sanity check: evaluate the log posterior at one draw from the prior.
posterior(sample_prior(problem; transform_p = true))
-4380.998181757678

4 Prior predictive check

Code
# Prior predictive check: trajectories simulated from prior draws, with the
# process deviations ε (top panel) and without them (bottom panel), each
# overlaid with the simulated data.
let
    n_draws = 200
    @unpack ts = problem

    # Iterate the process equation forward from x₀. `noise[t]` is added to the
    # growth factor at step t; `noise = nothing` gives the deterministic model.
    function trajectory(r, K, x₀, noise = nothing)
        path = zeros(length(ts))
        for t in ts
            previous = t == 1 ? x₀ : path[t-1]
            growth = 1 + r*(1 - previous/K)
            if noise !== nothing
                growth += noise[t]
            end
            path[t] = growth * previous
        end
        return path
    end

    # Observations, hidden state and process-model state of the simulated data.
    function plot_truth!(axis)
        scatter!(axis, true_solution.ts, true_solution.y;
                 color = :steelblue4, label = "observations: y")
        lines!(axis, true_solution.ts, true_solution.x;
               color = :blue, label = "true hidden state: x")
        lines!(axis, true_solution.ts, true_solution.s;
               color = :red, label = "process-model state: s")
        return axis
    end

    figure = Figure(; size = (700, 800))

    with_noise = Axis(figure[1, 1])
    for _ in 1:n_draws
        @unpack r, K, x₀, ε = sample_prior(problem)
        lines!(with_noise, ts, trajectory(r, K, x₀, ε); color = (:black, 0.1))
    end
    plot_truth!(with_noise)

    without_noise = Axis(figure[2, 1])
    for _ in 1:n_draws
        # ε is still drawn as part of the prior sample but not used here.
        @unpack r, K, x₀ = sample_prior(problem)
        lines!(without_noise, ts, trajectory(r, K, x₀); color = (:black, 0.1))
    end
    plot_truth!(without_noise)

    figure
end

5 Sampling

Code
# Sampler settings. Every `thin`-th draw is stored, so each chain contributes
# `nsamples ÷ thin` columns to `post_raw`.
# NOTE(review): `L` and `b` are passed straight to `adaptive_rwm`; presumably
# the number of tempering levels and the burn-in length — confirm against the
# AdaptiveMCMC documentation.
nsamples = 1_000_000
nchains = 4
L = 1
thin = 100
nparameter = problem.nparameter

# Raw draws in the unconstrained space, indexed (chain, parameter, draw).
post_raw = zeros(nchains, nparameter, nsamples ÷ thin)

# One independent chain per loop iteration. Chains start near the
# data-generating values (via `sample_initial_values`), not at a prior draw.
Threads.@threads for chain in 1:nchains
    start = sample_initial_values(problem, true_solution; transform_p = true)
    result = adaptive_rwm(start, posterior, nsamples;
                          algorithm = :am, b = 1, L = L, thin = thin,
                          progress = false)
    post_raw[chain, :, :] = result.X
end

Transform the samples back to the original parameter space:

Code
# Draws on the original parameter scale, indexed (draw, parameter, chain).
# Columns 1-4 hold r, K, x₀, σ_o; the remaining columns hold the ε vector.
post = zeros(nsamples ÷ thin, nparameter, nchains)
for c in 1:nchains, i in 1:(nsamples ÷ thin)
    p = transform(problem.transformation, post_raw[c, :, i])
    post[i, 1:4, c] .= (p.r, p.K, p.x₀, p.σ_o)
    post[i, 5:end, c] .= p.ε
end

6 Convergence diagnostics

6.1 Rhat and effective sample size

Code
# Column names for the chain: the four scalar parameters first, followed by
# one name per process deviation (ε_1, ε_2, ...).
p_names = collect(keys(problem.prior_dists))[1:4]
last_pname = keys(problem.prior_dists)[5]
last_pnames = [Symbol(last_pname, "_", i) for i in eachindex(true_solution.y)]
append!(p_names, last_pnames)

# Keep the stored draws from index `burnin` onwards, i.e. the second half of
# each thinned chain (draw `burnin` itself included).
burnin = (nsamples ÷ thin) ÷ 2

chn = Chains(post[burnin:end, :, :], p_names)

# Summary for the scalar parameters and the first five process deviations.
chn[:, 1:9, :]
Chains MCMC chain (5001×9×4 Array{Float64, 3}):

Iterations        = 1:1:5001
Number of chains  = 4
Samples per chain = 5001
parameters        = r, K, x₀, σ_o, ε_1, ε_2, ε_3, ε_4, ε_5

Summary Statistics
  parameters       mean       std      mcse    ess_bulk    ess_tail      rhat  ⋯
      Symbol    Float64   Float64   Float64     Float64     Float64   Float64  ⋯

           r     0.1091    0.0150    0.0004   1721.9622   1910.2185    1.0036  ⋯
           K   411.5695   49.1072    1.5565   1159.6966    807.0328    1.0022  ⋯
          x₀    13.8230    3.5217    0.0598   3456.9825   7107.3957    1.0013  ⋯
         σ_o    18.1372    1.8768    0.0355   2758.2978   5764.8505    1.0021  ⋯
         ε_1    -0.0086    0.0992    0.0015   4517.7647   8549.7246    1.0007  ⋯
         ε_2    -0.0188    0.1004    0.0016   4043.1604   7433.0475    1.0007  ⋯
         ε_3    -0.0014    0.0959    0.0014   4625.2885   8905.7508    1.0006  ⋯
         ε_4    -0.0166    0.0977    0.0015   4379.2836   8720.8593    1.0007  ⋯
         ε_5     0.0031    0.0956    0.0016   3733.2010   7250.9654    1.0014  ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
      Symbol    Float64    Float64    Float64    Float64    Float64 

           r     0.0806     0.0987     0.1087     0.1191     0.1394
           K   329.4116   375.8507   406.6724   442.7579   521.2672
          x₀     8.0782    11.2829    13.4574    15.9522    21.7183
         σ_o    14.7987    16.8127    18.0061    19.3469    22.1594
         ε_1    -0.2036    -0.0754    -0.0093     0.0579     0.1859
         ε_2    -0.2155    -0.0861    -0.0188     0.0495     0.1774
         ε_3    -0.1906    -0.0659    -0.0020     0.0630     0.1877
         ε_4    -0.2078    -0.0820    -0.0166     0.0486     0.1746
         ε_5    -0.1807    -0.0618     0.0025     0.0678     0.1914

6.2 Pair plot for model parameters

Code
# Pair plot of the four scalar parameters (chain columns 1-4: r, K, x₀, σ_o),
# with the data-generating values passed as the reference ("truth").
pairplot(chn[:, 1:4, :], PairPlots.Truth(true_solution.parameter))

6.3 Trace plot for model parameters

Code
# Plot all chains for the four scalar parameters (chain columns 1-4).
StatsPlots.plot(chn[:, 1:4, :])

6.4 Trace plot for (some) state parameters

Code
# Plot all chains for the first five process deviations (ε_1 to ε_5).
StatsPlots.plot(chn[:, 5:9, :])

7 Posterior predictive check

Code
# Pick one stored draw at random and return it as named parameters.
#
# `data` is indexed (chain, parameter, draw) in the unconstrained space. The
# chain is chosen from all chains and the draw index from `burnin` up to the
# last stored draw; the selected vector is mapped through
# `problem.transformation`.
function sample_posterior(data, problem, burnin)
    nchains, _, nsamples = size(data)
    chain = sample(1:nchains)
    draw = sample(burnin:nsamples)
    return transform(problem.transformation, data[chain, :, draw])
end

# Posterior predictive check: trajectories simulated from posterior draws,
# with the process deviations ε (top panel) and without them (bottom panel),
# each drawn on top of the simulated data.
let
    n_draws = 200
    ts = true_solution.ts

    # Iterate the process equation forward from x₀. `noise[t]` is added to the
    # growth factor at step t; `noise = nothing` gives the deterministic model.
    function trajectory(r, K, x₀, noise = nothing)
        path = zeros(length(problem.ts))
        for t in problem.ts
            previous = t == 1 ? x₀ : path[t-1]
            growth = 1 + r*(1 - previous/K)
            if noise !== nothing
                growth += noise[t]
            end
            path[t] = growth * previous
        end
        return path
    end

    # Observations, hidden state and process-model state of the simulated data.
    function plot_truth!(axis)
        scatter!(axis, ts, true_solution.y;
                 color = :steelblue4, label = "observations: y")
        lines!(axis, ts, true_solution.x;
               color = :blue, label = "true hidden state: x")
        lines!(axis, ts, true_solution.s;
               color = :red, label = "process-model state: s")
        return axis
    end

    figure = Figure(size = (800, 800))

    with_noise = Axis(figure[1, 1]; ylabel = "value")
    plot_truth!(with_noise)
    for _ in 1:n_draws
        @unpack r, K, x₀, ε = sample_posterior(post_raw, problem, burnin)
        lines!(with_noise, ts, trajectory(r, K, x₀, ε); color = (:black, 0.02))
    end

    # One shared legend, built from the labelled series of the top panel.
    Legend(figure[1:2, 2], with_noise)

    without_noise = Axis(figure[2, 1]; xlabel = "time", ylabel = "value")
    plot_truth!(without_noise)
    for _ in 1:n_draws
        # ε is part of the posterior draw but deliberately not used here.
        @unpack r, K, x₀ = sample_posterior(post_raw, problem, burnin)
        lines!(without_noise, ts, trajectory(r, K, x₀); color = (:black, 0.05))
    end

    figure
end