Logistic growth – a state-space model fitted with MCMC

1 Load packages

Code
import Random
import StatsPlots

using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack

# Global Makie theme for every figure below: larger font, no grid lines,
# no top/right spines, and legends drawn without a frame.
set_theme!(;
    fontsize = 18,
    Axis = (
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    ),
    Legend = (framevisible = false,),
)

2 Generate data

Code
"""
    generate_data(n_observations; σ_p, σ_o, r, K, x₀, rng = Random.default_rng())

Simulate a discrete-time logistic-growth state-space model for `n_observations` steps.

Three series are produced for `t = 1, …, n_observations`, all starting from `x₀`:

- `s`: deterministic process model, `s[t] = (1 + r(1 - s[t-1]/K)) s[t-1]`
- `x`: hidden state with multiplicative process noise `ε[t] ~ Normal(0, σ_p)`,
  `x[t] = (1 + r(1 - x[t-1]/K) + ε[t]) x[t-1]`
- `y`: observations, `y[t] ~ Gamma` with mean `x[t]` and standard deviation `σ_o`
  (shape `x²/σ_o²`, scale `σ_o²/x`)

# Keywords
- `σ_p`: standard deviation of the process noise `ε`
- `σ_o`: standard deviation of the observation noise
- `r`: growth rate
- `K`: carrying capacity
- `x₀`: population size before the first time step
- `rng`: random number generator. It defaults to the global RNG, so results after
  `Random.seed!` are the same as without passing it.

# Returns
A named tuple `(; ts, s, x, y, parameter)`. `parameter = (; σ_o, r, K, x₀)` holds the true
values of the parameters that the model later estimates. `σ_p` is not included.

# Throws
`DomainError` if the hidden state becomes non-positive (or non-finite), because the Gamma
observation model is undefined there.
"""
function generate_data(n_observations; σ_p, σ_o, r, K, x₀, rng = Random.default_rng())
    ts = 1:n_observations
    n = length(ts)

    s = Vector{Float64}(undef, n)
    x = Vector{Float64}(undef, n)
    y = Vector{Float64}(undef, n)
    ε = rand(rng, Normal(0, σ_p), n)

    # Both trajectories start from the same initial value x₀.
    s_prev = x_prev = x₀
    for t in ts
        s[t] = (1 + r * (1 - s_prev / K)) * s_prev
        x[t] = (1 + r * (1 - x_prev / K) + ε[t]) * x_prev

        # Fail with a clear message instead of an opaque Gamma-constructor error.
        x[t] > 0 || throw(DomainError(x[t],
            "hidden state x[$t] must be positive for the Gamma observation model"))
        y[t] = rand(rng, Gamma(x[t]^2 / σ_o^2, σ_o^2 / x[t]))

        s_prev, x_prev = s[t], x[t]
    end

    return (; ts, s, x, y, parameter = (; σ_o, r, K, x₀))
end

# Seed the global RNG so the simulated data set (and everything downstream) is reproducible.
Random.seed!(123)
true_solution = generate_data(100; σ_p = 0.05, σ_o = 20.0, r = 0.1, K = 400, x₀ = 20.0);

# Overview: noisy observations y, the hidden state x, and the noise-free trajectory s.
let
    @unpack ts, y, x, s = true_solution

    fig = Figure(size = (750, 300))
    ax = Axis(fig[1, 1]; xlabel = "time", ylabel = "population size")

    scatter!(ax, ts, y; color = :steelblue4, label = "observations: y")
    lines!(ax, ts, x; color = :blue, label = "true hidden state: x")
    lines!(ax, ts, s; color = :red, label = "process-model state: s")

    Legend(fig[1, 2], ax)
    fig
end

3 Define the posterior

Code
# Prior distributions for all unknowns.
# The key order matters. It fixes the order of the fields in `my_transform`, the
# layout of the flat parameter vector, and the chain column names
# (r, K, x₀, σ_o, then ε_1 … ε_T).
my_priors = (;
    r   = truncated(Normal(0.1, 0.02); lower = 0),  # growth rate
    K   = truncated(Normal(500, 120); lower = 0),   # carrying capacity
    x₀  = truncated(Normal(15, 20); lower = 0),     # initial population size
    σ_o = truncated(Normal(15, 5); lower = 0),      # observation-noise sd
    ε   = Normal(0, 0.1),                           # process noise, shared by every time step
)

# Everything the log posterior needs. Calling an instance on a parameter named
# tuple evaluates the unnormalized log posterior (functor method defined below).
# `@proto` comes from ProtoStructs.jl. It is presumably used so the definition can be
# edited and re-run during interactive prototyping (NOTE(review): confirm intent).
@proto struct StateSpaceModel
    ts::UnitRange{Int64}      # time steps; also used as indices into `y` and `ε`
    y::Vector{Float64}        # observed population sizes
    prior_dists::NamedTuple   # prior distribution per parameter (e.g. `my_priors`)
    nparameter::Int64         # length of the unconstrained parameter vector
    transformation            # TransformVariables transformation: ℝⁿ → parameter named tuple
end

"""
    (problem::StateSpaceModel)(θ)

Evaluate the unnormalized log posterior of the logistic-growth state-space model.

`θ` is a named tuple with scalar fields `r`, `K`, `x₀`, `σ_o` and a vector `ε` that has one
process-noise term per time step. The hidden state is propagated as
`x[t] = (1 + r(1 - x[t-1]/K) + ε[t]) x[t-1]`, starting from `x₀`. Each observation `y[t]` is
scored under a Gamma distribution with mean `x[t]` and standard deviation `σ_o`.

Returns `-Inf` in two cases:
- the prior density is zero;
- the hidden state is not a positive finite number. The Gamma observation model is
  undefined there, so the proposal is rejected rather than raising an error.
"""
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o, ε = θ

    # Log prior. `ε` is a vector of iid terms; every other entry is a scalar.
    logprior = 0.0
    for k in keys(prior_dists)
        if k == :ε
            logprior += sum(logpdf.(prior_dists[k], θ[k]))
        else
            logprior += logpdf(prior_dists[k], θ[k])
        end
    end

    # Outside the prior support the likelihood need not be defined, so stop early.
    if logprior == -Inf
        return -Inf
    end

    loglikelihood = 0.0
    x = 0.0
    for t in ts
        # Process equation.
        x_last = t == 1 ? x₀ : x
        x = (1 + r * (1 - x_last / K) + ε[t]) * x_last

        # Observation equation. The Gamma model needs a strictly positive, finite mean;
        # this check also rejects NaN/Inf states.
        if !(isfinite(x) && x > 0)
            return -Inf
        end
        shape = x^2 / σ_o^2
        scale = σ_o^2 / x   # separate name: must not overwrite the argument `θ`
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end

    return loglikelihood + logprior
end

"""
    sample_prior(problem; transform_p = false)

Draw one parameter set from the priors in `problem.prior_dists`.

Scalar entries get a single draw. The process noise `ε` gets one draw per time step in
`problem.ts`. Returns a named tuple with the same keys (in the same order) as
`prior_dists`. With `transform_p = true`, it instead returns the unconstrained vector
obtained by inverting `problem.transformation`.
"""
function sample_prior(problem; transform_p = false)
    @unpack ts, prior_dists, transformation = problem

    param_names = keys(prior_dists)
    draws = [name == :ε ? rand(prior_dists[name], length(ts)) : rand(prior_dists[name])
             for name in param_names]
    p = (; zip(param_names, draws)...)

    return transform_p ? inverse(transformation, p) : p
end


"""
    sample_initial_values(prob, sol; transform_p = false)

Build MCMC starting values close to the true parameters in `sol.parameter`.

Each scalar parameter is set to its true value times `1 + δ`, where
`δ ~ Normal(0, 0.01)`. The process noise `ε` starts at zero for every time step.
Returns a named tuple keyed like `prob.prior_dists`. With `transform_p = true`, it
returns the corresponding unconstrained vector instead.
"""
function sample_initial_values(prob, sol; transform_p = false)
    @unpack ts, prior_dists, transformation = prob
    @unpack parameter = sol

    start = map(collect(keys(prior_dists))) do name
        name == :ε && return zeros(length(ts))
        jitter = rand(Normal(0.0, 0.01))
        return (1 + jitter) * parameter[name]
    end
    p = (; zip(keys(prior_dists), start)...)

    transform_p || return p
    return inverse(transformation, p)
end
    
# Map an unconstrained vector in ℝⁿ to the model parameters:
# - r, x₀ and σ_o are kept positive by `asℝ₊`;
# - K stays on the real line (its truncated prior gives -Inf for negative values);
# - ε gets one unconstrained entry per observation.
my_transform = as((r = asℝ₊, K = asℝ, x₀ = asℝ₊, σ_o = asℝ₊,
                   ε = as(Array, length(true_solution.y))))

# Derive the number of free parameters from the transformation rather than
# hard-coding `4 + length(y)`, so the two cannot drift apart.
problem = StateSpaceModel(true_solution.ts, true_solution.y, my_priors,
                          dimension(my_transform), my_transform)

# Log density on the unconstrained space. TransformedLogDensity adds the
# log-Jacobian of the transformation.
ℓ = TransformedLogDensity(problem.transformation, problem)
posterior(x) = LogDensityProblems.logdensity(ℓ, x)
posterior(sample_prior(problem; transform_p = true))
-8654.414043066796

4 Prior predictive check

Code
# Prior predictive check: hidden-state trajectories simulated from prior draws.
# Top panel: with the sampled process noise ε. Bottom panel: deterministic process
# model only. The simulated data are overlaid on both panels.
let
    nsamples = 200
    @unpack ts = problem

    # Propagate the state equation for one parameter draw `p`.
    # With `process_noise = false`, the ε terms are ignored.
    function simulate_state(p; process_noise)
        @unpack r, K, x₀, ε = p
        x = zeros(length(ts))
        for t in ts
            x_last = t == 1 ? x₀ : x[t-1]
            noise = process_noise ? ε[t] : 0.0
            x[t] = (1 + r * (1 - x_last / K) + noise) * x_last
        end
        return x
    end

    # Overlay the simulated data set on `ax`.
    function add_truth!(ax)
        scatter!(ax, true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(ax, true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(ax, true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
    end

    fig = Figure(; size = (700, 800))
    ax_noise = Axis(fig[1, 1]; title = "prior draws with process noise ε",
                    ylabel = "population size")
    ax_det = Axis(fig[2, 1]; title = "prior draws, process model only",
                  xlabel = "time", ylabel = "population size")

    # Each panel uses its own `nsamples` prior draws, top panel first.
    for (ax, process_noise) in ((ax_noise, true), (ax_det, false))
        for _ in 1:nsamples
            p = sample_prior(problem)
            lines!(ax, ts, simulate_state(p; process_noise); color = (:black, 0.1))
        end
        add_truth!(ax)
    end

    Legend(fig[1:2, 2], ax_noise)
    fig
end

5 Sampling

Code
# MCMC settings: iterations per chain, number of chains, thinning interval.
nsamples = 1_000_000
nchains = 4
L = 1     # forwarded as `adaptive_rwm`'s `L` keyword (NOTE(review): tempering levels? confirm)
thin = 100
nparameter = problem.nparameter
# Raw (unconstrained) draws, laid out chain × parameter × stored draw.
post_raw = zeros(nchains, nparameter, nsamples ÷ thin)

# One adaptive random-walk Metropolis run (`algorithm = :am`) per chain.
# Each iteration writes only its own slice post_raw[chain, :, :].
Threads.@threads for chain in 1:nchains
    # Start near the true parameters. A prior draw,
    # `sample_prior(problem; transform_p = true)`, is an alternative starting point.
    x_init = sample_initial_values(problem, true_solution; transform_p = true)
    result = adaptive_rwm(x_init, posterior, nsamples;
                          algorithm = :am, b = 1, L = L, thin = thin, progress = false)
    post_raw[chain, :, :] = result.X
end

Transform the samples back to the original (constrained) parameter space:

Code
# Constrained draws laid out draw × parameter × chain, the layout passed to `Chains`.
# Columns 1-4 hold r, K, x₀, σ_o; columns 5:end hold the ε vector.
post = zeros(nsamples ÷ thin, nparameter, nchains)
for c in axes(post, 3), i in axes(post, 1)
    θ = transform(problem.transformation, post_raw[c, :, i])
    vals = values(θ)
    post[i, 1:4, c] .= vals[1:4]
    post[i, 5:end, c] .= vals[5]
end

6 Convergence diagnostics

6.1 Rhat and effective sample size

Code
# collect parameter names
p_names = collect(keys(problem.prior_dists))[1:4]
last_pname = collect(keys(problem.prior_dists))[5]
last_pnames = [Symbol(last_pname, "_", i) for i in 1:length(true_solution.y)]
append!(p_names, last_pnames)

burnin = nsamples ÷ thin ÷ 2

chn = Chains(post[burnin:end, :, :], p_names)
chn[:, 1:9, :]
Chains MCMC chain (5001×9×4 Array{Float64, 3}):

Iterations        = 1:1:5001
Number of chains  = 4
Samples per chain = 5001
parameters        = r, K, x₀, σ_o, ε_1, ε_2, ε_3, ε_4, ε_5

Summary Statistics
  parameters       mean       std      mcse    ess_bulk    ess_tail      rhat  ⋯
      Symbol    Float64   Float64   Float64     Float64     Float64   Float64  ⋯

           r     0.1090    0.0151    0.0004   1563.8626   1950.8757    1.0071  ⋯
           K   409.5450   48.8434    1.5094   1123.2964   1042.4883    1.0070  ⋯
          x₀    13.9495    3.5431    0.0561   3866.2916   7428.3721    1.0033  ⋯
         σ_o    18.1061    1.8164    0.0337   2860.9175   5353.2925    1.0021  ⋯
         ε_1    -0.0085    0.1000    0.0015   4288.8237   7171.7700    1.0007  ⋯
         ε_2    -0.0230    0.1012    0.0015   4574.9005   8666.9285    1.0013  ⋯
         ε_3    -0.0018    0.0958    0.0015   3928.2181   8308.7558    1.0008  ⋯
         ε_4    -0.0167    0.0969    0.0015   4129.6575   8338.4868    1.0012  ⋯
         ε_5     0.0027    0.0948    0.0014   4372.9207   8447.8361    1.0007  ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

           r     0.0797     0.0986     0.1089     0.1192     0.1384
           K   327.5882   373.9473   404.5582   440.4740   517.1499
          x₀     8.0987    11.4233    13.6010    16.0539    21.8519
         σ_o    14.8112    16.8379    18.0067    19.2922    21.8794
         ε_1    -0.1997    -0.0772    -0.0086     0.0584     0.1886
         ε_2    -0.2227    -0.0907    -0.0230     0.0449     0.1744
         ε_3    -0.1894    -0.0669    -0.0021     0.0636     0.1840
         ε_4    -0.2081    -0.0813    -0.0172     0.0481     0.1732
         ε_5    -0.1823    -0.0611     0.0025     0.0660     0.1901

6.2 Pair plot for the model parameters

Code
let truth = PairPlots.Truth(true_solution.parameter)
    # Joint posterior of r, K, x₀ and σ_o, with the true values marked.
    pairplot(chn[:, 1:4, :], truth)
end

6.3 Trace plots for the model parameters

Code
# Trace and density plots for r, K, x₀ and σ_o.
chn[:, 1:4, :] |> StatsPlots.plot

6.4 Trace plots for (some) state parameters

Code
# Trace and density plots for the first five process-noise terms ε_1 … ε_5.
chn[:, 5:9, :] |> StatsPlots.plot

7 Posterior predictive check

Code
"""
    sample_posterior(data, problem, burnin)

Pick one stored MCMC draw at random and map it to the constrained parameter space.

`data` is laid out chain × parameter × draw, like `post_raw`. The chain is chosen
uniformly, and the draw is chosen uniformly from indices `burnin` through the last
one. Returns the named tuple produced by `problem.transformation`.
"""
function sample_posterior(data, problem, burnin)
    chain = sample(1:size(data, 1))
    draw = sample(burnin:size(data, 3))
    raw = data[chain, :, draw]
    return transform(problem.transformation, raw)
end

# Posterior predictive check: hidden-state trajectories reconstructed from posterior
# draws taken after burn-in.
# Top panel: with the sampled process noise ε. Bottom panel: deterministic process
# model only. The simulated data are drawn first on both panels.
let
    ndraws = 200
    @unpack ts = problem

    # Propagate the state equation for one posterior draw `p`.
    # With `process_noise = false`, the ε terms are ignored.
    function simulate_state(p; process_noise)
        @unpack r, K, x₀, ε = p
        x = zeros(length(ts))
        for t in ts
            x_last = t == 1 ? x₀ : x[t-1]
            noise = process_noise ? ε[t] : 0.0
            x[t] = (1 + r * (1 - x_last / K) + noise) * x_last
        end
        return x
    end

    # Overlay the simulated data set on `ax`.
    function add_truth!(ax)
        scatter!(ax, true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(ax, true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(ax, true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
    end

    fig = Figure(size = (800, 800))
    ax_noise = Axis(fig[1, 1]; title = "posterior draws with process noise ε",
                    ylabel = "value")
    ax_det = Axis(fig[2, 1]; title = "posterior draws, process model only",
                  xlabel = "time", ylabel = "value")

    # Each panel uses its own `ndraws` posterior draws, top panel first.
    # The noisy trajectories overlap more, hence the lower opacity.
    for (ax, process_noise, alpha) in ((ax_noise, true, 0.02), (ax_det, false, 0.05))
        add_truth!(ax)
        for _ in 1:ndraws
            p = sample_posterior(post_raw, problem, burnin)
            lines!(ax, ts, simulate_state(p; process_noise); color = (:black, alpha))
        end
    end

    Legend(fig[1:2, 2], ax_noise)
    fig
end