Logistic growth - MCMC without modelling the hidden states v3

1 Load packages

Code
import Random
import StatsPlots

using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack

# Global Makie theme: larger font, no grid lines, no top/right spines,
# and frameless legends.
let
    axis_style = (;
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    )
    legend_style = (; framevisible = false)

    set_theme!(; fontsize = 18, Axis = axis_style, Legend = legend_style)
end

2 Generate data

Code
"""
    generate_data(n_observations; σ_p, σ_o, r, K, x₀)

Simulate `n_observations` steps of a discrete logistic growth model.

Three series are produced over `ts = 1:n_observations`:

- `s`: the deterministic logistic recursion started at `x₀`,
- `x`: the same recursion with a multiplicative `Normal(0, σ_p)` perturbation of the
  growth factor at every step,
- `y`: one Gamma draw per step with shape `x[t]^2 / σ_o^2` and scale `σ_o^2 / x[t]`.

Returns the named tuple `(; ts, s, x, y, parameter)` where `parameter` holds the
values `σ_o, σ_p, r, K, x₀` that were passed in.
"""
function generate_data(n_observations; σ_p, σ_o, r, K, x₀)
    ts = 1:n_observations
    n = length(ts)

    s = Vector{Float64}(undef, n)
    x = similar(s)
    y = similar(s)

    # All process-noise terms are drawn up front, before any observation draw.
    noise = rand(Normal(0, σ_p), n)

    # Both recursions start from the same initial value.
    s_prev = x₀
    x_prev = x₀
    for t in ts
        s[t] = (1 + r*(1 - s_prev/K)) * s_prev
        x[t] = (1 + r*(1 - x_prev/K) + noise[t]) * x_prev
        y[t] = rand(Gamma(x[t]^2 / σ_o^2, σ_o^2 / x[t]))

        s_prev = s[t]
        x_prev = x[t]
    end

    return (; ts, s, x, y, parameter = (; σ_o, σ_p, r, K, x₀))
end

# Fix the global RNG seed, then simulate the reference data set used below.
Random.seed!(123)
true_solution = generate_data(
    100;
    σ_p = 0.05,
    σ_o = 20.0,
    r = 0.1,
    K = 400,
    x₀ = 20.0,
);

# Plot the simulated observations together with the noisy and the
# deterministic trajectories, with the legend in a second column.
let
    @unpack ts, s, x, y = true_solution

    fig = Figure(size = (750, 300))
    ax = Axis(fig[1, 1]; xlabel = "time", ylabel = "population size")

    scatter!(ts, y, color = :steelblue4, label = "observations: y")
    lines!(ts, x, color = :blue, label = "true hidden state: x")
    lines!(ts, s, color = :red, label = "process-model state: s")

    Legend(fig[1, 2], ax)
    fig
end

3 Define the posterior

Code
# Prior distributions, one per model parameter. Every prior is a normal
# distribution truncated to the non-negative half-line.
my_priors = let
    positive_normal(μ, σ) = truncated(Normal(μ, σ); lower = 0)

    (;
        r = positive_normal(0.1, 0.02),
        K = positive_normal(500, 120),
        x₀ = positive_normal(15, 20),
        σ_o = positive_normal(15, 5),
        σ_p = positive_normal(0.05, 0.01),
    )
end

# Container for everything the log-posterior needs. Instances are callable:
# `problem(θ)` evaluates the log-posterior at the parameter named tuple `θ`.
@proto struct StateSpaceModel
    # Time indices; the log-posterior loops over them and treats `t == 1` as the first step.
    ts::UnitRange{Int64} 
    # Observations, indexed by the entries of `ts`.
    y::Vector{Float64}
    # Named tuple of prior distributions; its keys are the parameter names.
    prior_dists::NamedTuple
    # Number of model parameters (used to size the sample arrays).
    nparameter::Int64
    # Transformation object passed to `TransformedLogDensity`, `transform` and `inverse`.
    transformation
end

# Evaluate the log-posterior of `problem` at the parameter named tuple `θ`
# (fields `r`, `K`, `x₀`, `σ_o`, `σ_p`).
#
# The value is the sum of the log-prior of every entry in `problem.prior_dists`
# and a Gamma log-likelihood of each observation. No hidden state is carried
# along: the predicted state at step `t` is obtained from the previous
# *observation* `y[t-1]` (or from `x₀` at `t == 1`).
#
# Returns `-Inf` when the log-prior is `-Inf` or when a predicted state is not
# strictly positive (the Gamma parameters would be invalid).
#
# NOTE(review): a fresh process-noise term is drawn with `rand` at every step,
# so repeated calls with the same `θ` return different values. Confirm that
# the sampler is meant to work with such a noisy log-density.
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o, σ_p = θ

    # Float accumulator so the variable never changes type.
    logprior = 0.0
    for k in keys(prior_dists)
        logprior += logpdf(prior_dists[k], θ[k])
    end

    # Outside the prior support: skip the likelihood entirely.
    if logprior == -Inf
        return -Inf
    end

    loglikelihood = 0.0
    for t in ts
        # process equation (conditioned on the previous observation)
        x_last = t == 1 ? x₀ : y[t-1]
        ε = rand(Normal(0, σ_p))
        x = (1 + r*(1 - x_last/K) + ε) * x_last

        # observation equation
        if x <= 0
            return -Inf
        end
        # Gamma with mean x and standard deviation σ_o. Named `shape`/`scale`
        # so the argument `θ` is not overwritten.
        shape = x^2 / σ_o^2
        scale = σ_o^2 / x
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end

    return loglikelihood + logprior
end

"""
    sample_prior(problem; transform_p = false)

Draw one value from every distribution in `problem.prior_dists` (in key order)
and return them as a named tuple with the same keys.

With `transform_p = true` the draw is passed through
`inverse(problem.transformation, p)` and that result is returned instead.
"""
function sample_prior(problem; transform_p = false)
    @unpack prior_dists, transformation = problem

    # Build the named tuple directly; no untyped intermediate vector.
    p = (; (k => rand(prior_dists[k]) for k in keys(prior_dists))...)

    return transform_p ? inverse(transformation, p) : p
end


"""
    sample_initial_values(prob, sol; transform_p = false)

Return starting values close to the parameters stored in `sol.parameter`.

For every key of `prob.prior_dists` the corresponding entry of `sol.parameter`
is multiplied by `1 + δ` with `δ ~ Normal(0, 0.01)`. The result is a named
tuple with the keys of `prob.prior_dists`.

With `transform_p = true` the values are passed through
`inverse(prob.transformation, p)` and that result is returned instead.
"""
function sample_initial_values(prob, sol; transform_p = false)
    @unpack prior_dists, transformation = prob
    @unpack parameter = sol

    # Small relative jitter around the reference parameter values.
    p = (;
        (k => (1 + rand(Normal(0.0, 0.01))) * parameter[k] for k in keys(prior_dists))...
    )

    return transform_p ? inverse(transformation, p) : p
end
    
# Map every parameter to the positive half-line, matching the support of the
# priors (all are truncated at 0). `K` previously used `asℝ`, which let the
# sampler propose negative carrying capacities that the prior then rejected.
my_transform = as((r = asℝ₊, K = asℝ₊, x₀ = asℝ₊, σ_o = asℝ₊, σ_p = asℝ₊))
problem = StateSpaceModel(true_solution.ts, true_solution.y, my_priors, 5, my_transform)

# Log-density on the unconstrained space: wraps `problem` with its
# transformation. The assignment target `ℓ` is required by `posterior` below.
ℓ = TransformedLogDensity(problem.transformation, problem)

# Log-posterior as a function of the unconstrained parameter vector `x`.
posterior(x) = LogDensityProblems.logdensity(ℓ, x)

# Smoke test: evaluate at one transformed draw from the prior.
posterior(sample_prior(problem; transform_p = true))
-704.9373768923347

4 Prior predictive check

Code
# Prior predictive check. Top panel: deterministic trajectories for parameter
# draws from the prior. Bottom panel: the same recursion with process noise
# added at every step. The simulated data are overlaid on both panels.
let
    n_draws = 200
    ts = problem.ts
    fig = Figure(; size = (800, 800))

    # Draw the reference data set onto the current axis.
    function overlay_reference!()
        scatter!(true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
        return nothing
    end

    # --- panel 1: process model without noise -------------------------------
    Axis(fig[1, 1])
    for _ in 1:n_draws
        θ = sample_prior(problem)

        traj = zeros(length(ts))
        for t in ts
            # process equation
            prev = t == 1 ? θ.x₀ : traj[t-1]
            traj[t] = (1 + θ.r*(1 - prev/θ.K)) * prev
        end

        lines!(ts, traj; color = (:black, 0.1))
    end
    overlay_reference!()

    # --- panel 2: process model with noise ----------------------------------
    Axis(fig[2, 1])
    for _ in 1:n_draws
        θ = sample_prior(problem)

        traj = zeros(length(ts))
        for t in ts
            # process equation
            prev = t == 1 ? θ.x₀ : traj[t-1]
            ε = rand(Normal(0, θ.σ_p))
            traj[t] = (1 + θ.r*(1 - prev/θ.K) + ε) * prev
        end

        lines!(ts, traj; color = (:black, 0.1))
    end
    overlay_reference!()

    fig
end

5 Sampling

Code
# Sampler settings.
nsamples = 1_000_000                 # iterations requested per chain
nchains = 4                          # independent chains, one per loop iteration
L = 2                                # forwarded to `adaptive_rwm` as `L`
thin = 100                           # forwarded to `adaptive_rwm` as `thin`
nparameter = problem.nparameter

# Raw (unconstrained) draws, laid out as chain × parameter × kept draw.
post_raw = zeros(nchains, nparameter, nsamples ÷ thin)

Threads.@threads for chain in 1:nchains
    # Chains start near the data-generating parameters (≈1 % relative jitter).
    # Alternative: start from a prior draw via
    # `sample_prior(problem; transform_p = true)`.
    x_start = sample_initial_values(problem, true_solution; transform_p = true)

    result = adaptive_rwm(
        x_start,
        posterior,
        nsamples;
        algorithm = :am,
        b = 1,
        L = L,
        thin = thin,
        progress = false,
    )

    # Each chain writes to its own slice, so iterations do not overlap.
    post_raw[chain, :, :] = result.X
end

Transform the samples back to the original parameter space:

Code
# Constrained draws, laid out as kept draw × parameter × chain
# (the layout passed to `Chains` below).
post = zeros(nsamples ÷ thin, nparameter, nchains)
for c in 1:nchains, i in 1:(nsamples ÷ thin)
    raw_draw = post_raw[c, :, i]
    constrained = transform(problem.transformation, raw_draw)
    post[i, :, c] .= collect(constrained)
end

6 Convergence diagnostics

6.1 Rhat and effective sample size

Code
# Parameter names in the order of the prior named tuple.
p_names = collect(keys(problem.prior_dists))

# The first half of the kept draws is treated as burn-in. The slice below is
# inclusive of index `burnin`, so `nsamples ÷ thin - burnin + 1` draws remain.
burnin = nsamples ÷ thin ÷ 2
post_kept = post[burnin:end, :, :]

chn = Chains(post_kept, p_names)
Chains MCMC chain (5001×5×4 Array{Float64, 3}):

Iterations        = 1:1:5001
Number of chains  = 4
Samples per chain = 5001
parameters        = r, K, x₀, σ_o, σ_p

Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

           r     0.0908    0.0181    0.0014   179.2726   318.1972    1.0520    ⋯
           K   411.4144   56.2869    2.8384   311.1559   361.7539    1.0817    ⋯
          x₀    28.3845   10.4171    0.7486   217.7648   270.6981    1.0629    ⋯
         σ_o    28.7636    1.7470    0.1313   186.8495   349.8429    1.0763    ⋯
         σ_p     0.0423    0.0081    0.0006   171.3517   300.5455    1.0877    ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
      Symbol    Float64    Float64    Float64    Float64    Float64 

           r     0.0531     0.0798     0.0904     0.1069     0.1204
           K   324.5475   374.9129   412.3955   429.1943   555.4846
          x₀     7.8504    20.1290    27.7525    37.9074    46.8730
         σ_o    25.4150    27.3036    28.5326    30.0089    32.5175
         σ_p     0.0256     0.0369     0.0441     0.0466     0.0551

6.2 Pair plot for model parameters

Code
# Pairwise posterior plot with the data-generating parameter values marked.
truth_marker = PairPlots.Truth(true_solution.parameter)
pairplot(chn, truth_marker)

6.3 Trace plot for model parameters

Code
# Plot the chains object with StatsPlots (qualified call; StatsPlots is only `import`ed).
StatsPlots.plot(chn)

7 Posterior predictive check

Code
"""
    sample_posterior(data, problem, burnin)

Pick one posterior draw at random and return it on the constrained scale.

`data` is indexed as `data[chain, :, draw]`. A chain index is sampled from all
chains and a draw index from `burnin` up to the last draw; the selected raw
vector is mapped through `transform(problem.transformation, ·)`.
"""
function sample_posterior(data, problem, burnin)
    n_chain, _, n_draw = size(data)

    chain = sample(1:n_chain)
    draw = sample(burnin:n_draw)

    return transform(problem.transformation, data[chain, :, draw])
end

# Posterior predictive check. Top panel: deterministic trajectories for
# parameter draws from the posterior. Bottom panel: the same recursion with
# process noise added at every step. The simulated data are drawn first on
# both panels; one shared legend is built from the top axis.
let
    n_draws = 200
    ts = problem.ts
    fig = Figure(size = (800, 800))

    # Draw the reference data set onto the current axis.
    function overlay_reference!()
        scatter!(true_solution.ts, true_solution.y, color = :steelblue4, label = "observations: y")
        lines!(true_solution.ts, true_solution.x, color = :blue, label = "true hidden state: x")
        lines!(true_solution.ts, true_solution.s, color = :red, label = "process-model state: s")
        return nothing
    end

    # --- panel 1: process model without noise -------------------------------
    ax = Axis(fig[1, 1]; ylabel = "value")
    overlay_reference!()

    for _ in 1:n_draws
        θ = sample_posterior(post_raw, problem, burnin)

        traj = zeros(length(ts))
        for t in ts
            # process equation
            prev = t == 1 ? θ.x₀ : traj[t-1]
            traj[t] = (1 + θ.r*(1 - prev/θ.K)) * prev
        end

        lines!(true_solution.ts, traj, color = (:black, 0.05))
    end

    Legend(fig[1:2, 2], ax)

    # --- panel 2: process model with noise ----------------------------------
    Axis(fig[2, 1]; ylabel = "value")
    overlay_reference!()

    for _ in 1:n_draws
        θ = sample_posterior(post_raw, problem, burnin)

        traj = zeros(length(ts))
        for t in ts
            # process equation
            prev = t == 1 ? θ.x₀ : traj[t-1]
            ε = rand(Normal(0, θ.σ_p))
            traj[t] = (1 + θ.r*(1 - prev/θ.K) + ε) * prev
        end

        lines!(true_solution.ts, traj, color = (:black, 0.05))
    end

    fig
end