Logistic growth - MCMC without modelling the hidden states v2

1 Load packages

Code
import Random
import StatsPlots

using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack

# Global Makie theme applied to every figure below: font size 18, no grid
# lines, no top/right axis spines, and legends drawn without a frame.
set_theme!(
    fontsize = 18,
    Axis = (; xgridvisible = false, ygridvisible = false,
            topspinevisible = false, rightspinevisible = false),
    Legend = (; framevisible = false))

2 Generate data

Code
"""
    generate_data(n_observations; σ_p, σ_o, r, K, x₀)

Simulate `n_observations` steps of a discrete-time logistic growth model.

Three series are produced, all started from `x₀`:
- `s`: the trajectory without process noise,
- `x`: the hidden state, whose per-step growth factor is perturbed by
  `Normal(0, σ_p)` noise,
- `y`: Gamma-distributed observations of `x` with shape `x[t]^2 / σ_o^2` and
  scale `σ_o^2 / x[t]` (i.e. mean `x[t]` and standard deviation `σ_o`).

Returns the named tuple `(; ts, s, x, y, parameter)`, where `ts` is
`1:n_observations` and `parameter` collects the keyword arguments.
"""
function generate_data(n_observations; σ_p, σ_o, r, K, x₀)
    ts = 1:n_observations
    n = length(ts)

    deterministic = Vector{Float64}(undef, n)
    hidden = Vector{Float64}(undef, n)
    observed = Vector{Float64}(undef, n)

    # All process-noise terms are drawn up front, before any observation noise.
    noise = rand(Normal(0, σ_p), n)

    prev_det = x₀
    prev_hidden = x₀
    for t in ts
        deterministic[t] = (1 + r*(1 - prev_det/K)) * prev_det
        hidden[t] = (1 + r*(1 - prev_hidden/K) + noise[t]) * prev_hidden

        shape = hidden[t]^2 / σ_o^2
        scale = σ_o^2 / hidden[t]
        observed[t] = rand(Gamma(shape, scale))

        prev_det = deterministic[t]
        prev_hidden = hidden[t]
    end

    return (; ts, s = deterministic, x = hidden, y = observed,
            parameter = (; σ_o, σ_p, r, K, x₀))
end

# Fix the global RNG seed so the simulated data set is reproducible.
Random.seed!(123)
# Simulate 100 time steps; `true_solution.parameter` keeps the values used here
# so later cells can compare the posterior against them.
true_solution = generate_data(100; σ_p = 0.05, σ_o = 20.0, r = 0.1, K = 400, x₀ = 20.0);

# Overview figure: the noisy observations together with the hidden state and
# the noise-free process-model trajectory of the simulated data set.
let
    sol = true_solution
    fig = Figure(size = (750, 300))
    ax = Axis(fig[1, 1]; xlabel = "time", ylabel = "population size")

    scatter!(ax, sol.ts, sol.y; color = :steelblue4, label = "observations: y")
    lines!(ax, sol.ts, sol.x; color = :blue, label = "true hidden state: x")
    lines!(ax, sol.ts, sol.s; color = :red, label = "process-model state: s")

    # The legend sits in its own column to the right of the axis.
    Legend(fig[1, 2], ax)
    fig
end

3 Define the posterior

Code
# Prior distributions, one per model parameter. Every prior is a Normal
# truncated to the non-negative half-line.
my_priors = let
    # Normal(μ, σ) restricted to values ≥ 0.
    nonneg_normal(μ, σ) = truncated(Normal(μ, σ); lower = 0)

    (;
        r = nonneg_normal(0.1, 0.02),
        K = nonneg_normal(500, 120),
        x₀ = nonneg_normal(15, 20),
        σ_o = nonneg_normal(15, 5),
        σ_p = nonneg_normal(0.05, 0.01),
    )
end

# Bundles everything the log-posterior evaluation needs. Instances are callable
# (see the method defined on `StateSpaceModel` below).
@proto struct StateSpaceModel
    # Time indices of the observations; also used to index into `y`.
    ts::UnitRange{Int64} 
    # Observed population sizes.
    y::Vector{Float64}
    # Prior distribution for each parameter, keyed by parameter name.
    prior_dists::NamedTuple
    # Number of model parameters; used to size the sample arrays.
    nparameter::Int64
    # TransformVariables transformation between the sampler's vector and the
    # named parameters.
    transformation
end

# Log-posterior (up to a constant) of the parameters `θ`, a collection with
# fields `r`, `K`, `x₀`, `σ_o`, `σ_p` that can also be indexed by those names.
#
# The hidden state is not part of `θ`: every call simulates a fresh
# process-noise trajectory with `rand`, so repeated calls with the same `θ`
# generally return different values.
#
# Returns `-Inf` when the log-prior is `-Inf` or when the simulated state
# becomes non-positive (the Gamma observation model needs `x > 0`); otherwise
# returns log-likelihood + log-prior.
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o, σ_p = θ

    # Floating-point accumulator, so the variable never changes type.
    logprior = 0.0
    for k in keys(prior_dists)
        logprior += logpdf(prior_dists[k], θ[k])
    end

    # Outside the prior support: skip the simulation entirely.
    if logprior == -Inf
        return -Inf
    end

    loglikelihood = 0.0
    x = 0.0
    for t in ts
        # process equation
        x_last = t == 1 ? x₀ : x
        ε = rand(Normal(0, σ_p))
        x = (1 + r*(1 - x_last/K) + ε) * x_last

        # observation equation
        if x <= 0
            return -Inf
        end
        # Gamma with mean x and standard deviation σ_o. The scale gets its own
        # name so it no longer overwrites the parameter argument `θ`.
        shape = x^2 / σ_o^2
        scale = σ_o^2 / x
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end

    return loglikelihood + logprior
end

"""
    sample_prior(problem; transform_p = false)

Draw one value from every distribution in `problem.prior_dists`.

Returns a named tuple with the same keys as `problem.prior_dists`. With
`transform_p = true` the draw is instead passed through
`inverse(problem.transformation, ·)` and that result is returned.
"""
function sample_prior(problem; transform_p = false)
    @unpack prior_dists, transformation = problem

    # `map` over a NamedTuple keeps the names and visits the fields in order,
    # so no untyped intermediate vector is needed.
    p = map(rand, prior_dists)

    return transform_p ? inverse(transformation, p) : p
end


"""
    sample_initial_values(prob, sol; transform_p = false)

Build a starting point by jittering the values in `sol.parameter`.

For every key of `prob.prior_dists`, the matching entry of `sol.parameter` is
multiplied by `1 + ε` with `ε ~ Normal(0, 0.01)`. Returns a named tuple with
those keys. With `transform_p = true` the point is instead passed through
`inverse(prob.transformation, ·)` and that result is returned.
"""
function sample_initial_values(prob, sol; transform_p = false)
    @unpack prior_dists, transformation = prob
    @unpack parameter = sol

    pnames = keys(prior_dists)
    # Mapping over the key tuple yields a concretely typed tuple in key order.
    jittered = map(k -> (1 + rand(Normal(0.0, 0.01))) * parameter[k], pnames)
    p = NamedTuple{pnames}(jittered)

    return transform_p ? inverse(transformation, p) : p
end
    
# Transformation between the sampler's unconstrained vector and the named
# model parameters.
# NOTE(review): K uses asℝ although its prior is truncated at 0; negative K is
# only rejected through the -Inf log-prior — confirm this is intended.
my_transform = as((r = asℝ₊, K = asℝ, x₀ = asℝ₊, σ_o = asℝ₊, σ_p = asℝ₊))
# The parameter count is taken from the priors instead of being hard-coded.
problem = StateSpaceModel(true_solution.ts, true_solution.y, my_priors,
                          length(my_priors), my_transform)

# Log-density of `problem` expressed on the transformed (sampler) space.
# The left-hand side `ℓ` was missing on this line although `posterior` uses it.
ℓ = TransformedLogDensity(problem.transformation, problem)
posterior(x) = LogDensityProblems.logdensity(ℓ, x)
# Smoke test: evaluate the log-posterior once at a prior draw.
posterior(sample_prior(problem; transform_p = true))
-40605.63841959801

4 Prior predictive check

Code
# Prior predictive check: trajectories simulated from prior draws, overlaid
# with the simulated data set. Top panel: process model without noise;
# bottom panel: process model with process noise.
let
    n_draws = 200
    @unpack ts = problem

    # One trajectory of the process model; `noisy` switches process noise on.
    function trajectory(r, K, x₀, σ_p; noisy)
        path = zeros(length(ts))
        for t in ts
            previous = t == 1 ? x₀ : path[t-1]
            if noisy
                ε = rand(Normal(0, σ_p))
                path[t] = (1 + r*(1 - previous/K) + ε) * previous
            else
                path[t] = (1 + r*(1 - previous/K)) * previous
            end
        end
        return path
    end

    # Draw the observations and both reference trajectories on the current axis.
    function overlay_data!()
        scatter!(true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
    end

    fig = Figure(; size = (800, 800))
    for (row, noisy) in enumerate((false, true))
        Axis(fig[row, 1])
        for _ in 1:n_draws
            @unpack r, K, x₀, σ_p = sample_prior(problem)
            lines!(ts, trajectory(r, K, x₀, σ_p; noisy); color = (:black, 0.1))
        end
        overlay_data!()
    end

    fig
end

5 Sampling

Code
# Sampler configuration.
nsamples = 1_000_000                  # iterations requested per chain
nchains = 4                           # number of chains
L = 1                                 # forwarded to `adaptive_rwm` as `L`
thin = 100                            # forwarded to `adaptive_rwm` as `thin`
nparameter = problem.nparameter
# Draws on the sampler's scale, indexed (chain, parameter, kept sample).
post_raw = zeros(nchains, nparameter, nsamples ÷ thin)

Threads.@threads for chain in 1:nchains
    # Chains start near the data-generating parameters rather than at a prior
    # draw (`sample_prior(problem; transform_p = true)` would be the alternative).
    start = sample_initial_values(problem, true_solution; transform_p = true)
    result = adaptive_rwm(start, posterior, nsamples;
                          algorithm = :am, b = 1, L = L, thin = thin,
                          progress = false)
    post_raw[chain, :, :] = result.X
end

Transform the samples back to the original parameter space:

Code
# Draws mapped through `problem.transformation`, laid out as
# (kept sample, parameter, chain), the layout handed to `Chains` below.
post = zeros(nsamples ÷ thin, nparameter, nchains)
for chain in 1:nchains, draw in 1:(nsamples ÷ thin)
    constrained = transform(problem.transformation, post_raw[chain, :, draw])
    post[draw, :, chain] .= collect(constrained)
end

6 Convergence diagnostics

6.1 Rhat and effective sample size

Code
# Parameter names in the order of the prior definition.
p_names = [keys(problem.prior_dists)...]
# Index at which the retained samples start: half of the kept draws.
burnin = (nsamples ÷ thin) ÷ 2

# Only samples from index `burnin` onwards enter the diagnostics.
chn = Chains(post[burnin:end, :, :], p_names)
Chains MCMC chain (5001×5×4 Array{Float64, 3}):

Iterations        = 1:1:5001
Number of chains  = 4
Samples per chain = 5001
parameters        = r, K, x₀, σ_o, σ_p

Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

           r     0.0922    0.0051    0.0008    41.7129    42.3606    2.2790    ⋯
           K   394.1455   17.2363    2.7178    40.3628        NaN    5.8685    ⋯
          x₀    23.2861    3.3008    0.5184    40.9575        NaN    2.9960    ⋯
         σ_o    25.4705    1.7812    0.2793    40.8037        NaN    3.8455    ⋯
         σ_p     0.0413    0.0082    0.0013    40.6588    48.1797    4.1841    ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
      Symbol    Float64    Float64    Float64    Float64    Float64 

           r     0.0816     0.0884     0.0924     0.0964     0.0994
           K   358.1780   388.1525   397.0915   404.1561   420.0461
          x₀    18.9239    21.4854    23.2282    25.0299    30.1694
         σ_o    22.8621    23.9970    25.2035    27.2304    28.4546
         σ_p     0.0266     0.0381     0.0395     0.0438     0.0565

6.2 Pair plot for model parameters

Code
# Pair plot of the chains, with the data-generating parameter values passed in
# as `PairPlots.Truth` for reference.
pairplot(chn, PairPlots.Truth(true_solution.parameter))

6.3 Trace plot for model parameters

Code
# Plot the chains with StatsPlots; the function is module-qualified because
# StatsPlots is only `import`ed, not brought into scope with `using`.
StatsPlots.plot(chn)

7 Posterior predictive check

Code
"""
    sample_posterior(data, problem, burnin)

Pick one stored draw at random and map it through `problem.transformation`.

`data` is indexed `(chain, parameter, sample)`. The chain index is drawn from
all chains and the sample index from `burnin:size(data, 3)`.
"""
function sample_posterior(data, problem, burnin)
    n_chains, _, n_draws = size(data)
    chain = sample(1:n_chains)
    draw = sample(burnin:n_draws)
    return transform(problem.transformation, data[chain, :, draw])
end

# Posterior predictive check: trajectories simulated from posterior draws on
# top of the simulated data set. Top panel: process model without noise;
# bottom panel: process model with process noise.
let
    fig = Figure(size = (800, 800))

    # Draw the observations and both reference trajectories on the current axis.
    function overlay_data!()
        scatter!(true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
    end

    # One trajectory of the process model for the parameter draw `pars`;
    # `noisy` switches process noise on.
    function trajectory(pars; noisy)
        path = zeros(length(problem.ts))
        for t in problem.ts
            previous = t == 1 ? pars.x₀ : path[t-1]
            if noisy
                ε = rand(Normal(0, pars.σ_p))
                path[t] = (1 + pars.r*(1 - previous/pars.K) + ε) * previous
            else
                path[t] = (1 + pars.r*(1 - previous/pars.K)) * previous
            end
        end
        return path
    end

    top = Axis(fig[1, 1]; ylabel = "value")
    overlay_data!()
    for _ in 1:200
        pars = sample_posterior(post_raw, problem, burnin)
        lines!(true_solution.ts, trajectory(pars; noisy = false), color = (:black, 0.02))
    end

    # One legend, built from the top axis, spans both rows.
    Legend(fig[1:2, 2], top)

    Axis(fig[2, 1]; ylabel = "value")
    overlay_data!()
    for _ in 1:200
        pars = sample_posterior(post_raw, problem, burnin)
        lines!(true_solution.ts, trajectory(pars; noisy = true), color = (:black, 0.05))
    end

    fig
end