Logistic growth - MCMC without modelling the hidden states v1

1 Load packages

Code
import Random
import StatsPlots

using AdaptiveMCMC
using CairoMakie
using Distributions
using LinearAlgebra
using LogDensityProblems
using MCMCChains
using PairPlots
using ProtoStructs
using Statistics
using TransformVariables
using TransformedLogDensities
using UnPack

# Global Makie theme for every figure below: font size 18, axes without grid lines
# and without top/right spines, legends without a frame.
set_theme!(;
    fontsize = 18,
    Axis = (
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    ),
    Legend = (framevisible = false,),
)

2 Generate data

Code
"""
    generate_data(n_observations; σ_p, σ_o, r, K, x₀, rng = Random.default_rng())

Simulate `n_observations` time steps of a discrete logistic growth model together
with noisy observations.

For `t = 1, …, n_observations` three series are produced, all started from `x₀`:

- `s`: the noise-free trajectory, `s[t] = (1 + r*(1 - s[t-1]/K)) * s[t-1]`.
- `x`: the same recursion with the growth factor perturbed by `ε[t] ~ Normal(0, σ_p)`.
- `y`: one draw per step from `Gamma(x[t]^2/σ_o^2, σ_o^2/x[t])` (shape, scale), i.e. a
  Gamma distribution with mean `x[t]` and standard deviation `σ_o`.

# Keywords
- `σ_p`: standard deviation of the process noise.
- `σ_o`: standard deviation of the observation noise.
- `r`: growth rate.
- `K`: carrying capacity.
- `x₀`: state before the first time step.
- `rng`: random number generator used for all draws; defaults to the global default
  RNG, so calls that do not pass `rng` behave as before.

# Returns
A named tuple `(; ts, s, x, y, parameter)` where `ts = 1:n_observations` and
`parameter = (; σ_o, r, K, x₀)`.
"""
function generate_data(n_observations; σ_p, σ_o, r, K, x₀, rng = Random.default_rng())
    ts = 1:n_observations
    T = length(ts)

    s = Vector{Float64}(undef, T)
    x = Vector{Float64}(undef, T)
    y = Vector{Float64}(undef, T)
    # All process-noise terms are drawn first, before any observation is drawn.
    ε = rand(rng, Normal(0, σ_p), T)

    for t in ts
        # The first step starts from the initial state x₀.
        x_lastt = t == 1 ? x₀ : x[t-1]
        s_lastt = t == 1 ? x₀ : s[t-1]

        s[t] = (1 + r*(1 - s_lastt/K)) * s_lastt
        x[t] = (1 + r*(1 - x_lastt/K) + ε[t]) * x_lastt
        y[t] = rand(rng, Gamma(x[t]^2 / σ_o^2, σ_o^2 / x[t]))
    end

    return (; ts, s, x, y, parameter = (; σ_o, r, K, x₀))
end

# Seed the global RNG so the simulated data set is the same on every run.
Random.seed!(123)
true_solution = generate_data(
    100;
    σ_p = 0.05,
    σ_o = 20.0,
    r = 0.1,
    K = 400,
    x₀ = 20.0,
);

# Plot the simulated observations together with the noisy and noise-free trajectories.
let
    @unpack ts, s, x, y = true_solution

    fig = Figure(; size = (750, 300))
    ax = Axis(fig[1, 1]; xlabel = "time", ylabel = "population size")

    scatter!(ax, ts, y; color = :steelblue4, label = "observations: y")
    lines!(ax, ts, x; color = :blue, label = "true hidden state: x")
    lines!(ax, ts, s; color = :red, label = "process-model state: s")

    Legend(fig[1, 2], ax)
    fig
end

3 Define the posterior

Code
# Prior distribution for each model parameter: a Normal truncated below at 0.
my_priors = let
    lower_truncated_normal(μ, σ) = truncated(Normal(μ, σ); lower = 0)

    (
        r = lower_truncated_normal(0.1, 0.02),
        K = lower_truncated_normal(500, 120),
        x₀ = lower_truncated_normal(15, 20),
        σ_o = lower_truncated_normal(15, 5),
    )
end

# Bundles the data and prior information needed to evaluate the log-posterior.
# Instances are callable: `problem(θ)` returns the log-posterior at `θ`.
# Fields are positional in the constructor, so their order matters.
@proto struct StateSpaceModel
    ts::UnitRange{Int64} # time steps; iterated over and used to index `y`
    y::Vector{Float64} # observations, one per time step
    prior_dists::NamedTuple # prior distribution per parameter, keyed by parameter name
    nparameter::Int64 # number of model parameters (used to size the sample arrays)
    transformation # mapping between unconstrained vectors and named parameters
end

# Log-posterior (up to an additive constant) of the deterministic logistic growth
# model with Gamma-distributed observations.
#
# Arguments
#   θ: parameter container with entries `r`, `K`, `x₀`, `σ_o`, accessible both by
#      property and by `θ[k]` for every key `k` of `problem.prior_dists`
#      (e.g. a NamedTuple).
#
# Returns
#   `loglikelihood + logprior`, or `-Inf` as soon as the prior log-density is `-Inf`
#   or the process state becomes non-positive.
function (problem::StateSpaceModel)(θ)
    @unpack ts, y, prior_dists = problem
    @unpack r, K, x₀, σ_o = θ

    # Start from a Float64 zero so the accumulator does not change type in the loop.
    logprior = 0.0
    for k in keys(prior_dists)
        logprior += logpdf(prior_dists[k], θ[k])
    end

    # Outside the prior support: skip the likelihood entirely.
    if logprior == -Inf
        return -Inf
    end

    obs_var = σ_o^2  # loop-invariant observation variance
    loglikelihood = 0.0
    x = 0.0
    for t in ts
        # process equation (the first step starts from x₀)
        x_last = t == 1 ? x₀ : x
        x = (1 + r*(1 - x_last/K)) * x_last

        # observation equation: the Gamma shape and scale below need x > 0
        if x <= 0
            return -Inf
        end
        shape = x^2 / obs_var
        scale = obs_var / x
        loglikelihood += logpdf(Gamma(shape, scale), y[t])
    end

    return loglikelihood + logprior
end

"""
    sample_prior(problem; transform_p = false)

Draw one value from every distribution in `problem.prior_dists`.

# Keywords
- `transform_p`: when `true`, return `inverse(problem.transformation, p)` instead of
  the named tuple `p` of draws.

# Returns
A named tuple with the same keys as `problem.prior_dists`, or its image under the
inverse transformation when `transform_p` is `true`.
"""
function sample_prior(problem; transform_p = false)
    @unpack prior_dists, transformation = problem

    # `map` over a NamedTuple keeps the keys and visits the entries in order, so the
    # draws are made in the same sequence as an explicit loop over the keys.
    p = map(rand, prior_dists)
    if transform_p
        return inverse(transformation, p)
    end

    return p
end


"""
    sample_initial_values(prob, sol; transform_p = false)

Build starting values by perturbing the parameter values stored in `sol.parameter`.

For every key `k` of `prob.prior_dists`, `sol.parameter[k]` is multiplied by
`1 + δ` with an independent `δ ~ Normal(0, 0.01)`.

# Keywords
- `transform_p`: when `true`, return `inverse(prob.transformation, p)` instead of the
  named tuple `p` of perturbed values.

# Returns
A named tuple with the same keys as `prob.prior_dists`, or its image under the
inverse transformation when `transform_p` is `true`.
"""
function sample_initial_values(prob, sol; transform_p = false)
    @unpack prior_dists, transformation = prob
    @unpack parameter = sol

    pnames = keys(prior_dists)
    # One independent relative perturbation per parameter, drawn in key order.
    jittered = map(k -> (1 + rand(Normal(0.0, 0.01))) * parameter[k], pnames)

    p = NamedTuple{pnames}(jittered)
    if transform_p
        return inverse(transformation, p)
    end

    return p
end
    
# Transformation between unconstrained vectors and the named parameters. `r`, `x₀`
# and `σ_o` are mapped to the positive reals. `K` is left on the whole real line,
# although its prior is truncated at 0.
my_transform = as((r = asℝ₊, K = asℝ, x₀ = asℝ₊, σ_o = asℝ₊))
# The parameter count is taken from the priors instead of being hard-coded.
problem = StateSpaceModel(
    true_solution.ts, true_solution.y, my_priors, length(my_priors), my_transform
)

# Log-density in the unconstrained space; `posterior` evaluates it at a vector `x`.
ℓ = TransformedLogDensity(problem.transformation, problem)
posterior(x) = LogDensityProblems.logdensity(ℓ, x)
# Smoke test: evaluate the log-density at a random draw from the prior.
posterior(sample_prior(problem; transform_p = true))
-2229.0066154324477

4 Prior predictive check

Code
# Prior predictive check: deterministic trajectories simulated from prior draws,
# overlaid with the simulated data.
let
    n_draws = 200
    @unpack ts = problem

    fig = Figure(; size = (700, 400))
    ax = Axis(fig[1, 1]; xlabel = "time", ylabel = "population size")
    for _ in 1:n_draws
        # σ_o only enters the observation model, so it is not needed here.
        @unpack r, K, x₀ = sample_prior(problem)

        x = zeros(length(ts))
        for t in ts
            # process equation
            x_last = t == 1 ? x₀ : x[t-1]
            x[t] = (1 + r*(1 - x_last/K)) * x_last
        end

        lines!(ax, ts, x; color = (:black, 0.1))
    end

    scatter!(ax, true_solution.ts, true_solution.y;
             color = :steelblue4, label = "observations: y")
    lines!(ax, true_solution.ts, true_solution.x;
           color = :blue, label = "true hidden state: x")
    lines!(ax, true_solution.ts, true_solution.s;
           color = :red, label = "process-model state: s")

    # The labels above are only displayed if a legend is actually drawn.
    Legend(fig[1, 2], ax)

    fig
end

5 Sampling

Code
# Sampler settings. `L` and `thin` are forwarded to `adaptive_rwm` under the same names.
nsamples = 1_000_000
nchains = 4
L = 1
thin = 100
nparameter = problem.nparameter

# Raw (unconstrained) samples, indexed as [chain, parameter, stored sample].
post_raw = zeros(nchains, nparameter, div(nsamples, thin))

# One independent chain per loop iteration; each iteration writes only its own slice.
Threads.@threads for chain in 1:nchains
    # Start near the data-generating parameters. Alternative: start from a prior draw
    # with `sample_prior(problem; transform_p = true)`.
    init_x = sample_initial_values(problem, true_solution; transform_p = true)
    out = adaptive_rwm(
        init_x, posterior, nsamples;
        algorithm = :am, b = 1, L = L, thin = thin, progress = false,
    )
    post_raw[chain, :, :] = out.X
end

Transform the samples back to the original parameter space:

Code
# Map every stored sample from the unconstrained space back to the parameter space.
# `post` is indexed as [stored sample, parameter, chain].
post = zeros(nsamples ÷ thin, nparameter, nchains)
for c in 1:nchains
    for i in 1:(nsamples ÷ thin)
        # Fill the whole parameter dimension rather than a hard-coded `1:4`.
        post[i, :, c] .= collect(transform(problem.transformation, post_raw[c, :, i]))
    end
end

6 Convergence diagnostics

6.1 Rhat and effective sample size

Code
# Parameter names, in the order of the priors.
p_names = [keys(problem.prior_dists)...]
# Midpoint of the stored samples. The slice below starts at this index, so the
# sample at index `burnin` itself is kept.
burnin = (nsamples ÷ thin) ÷ 2

chn = Chains(post[burnin:end, :, :], p_names)
Chains MCMC chain (5001×4×4 Array{Float64, 3}):

Iterations        = 1:1:5001
Number of chains  = 4
Samples per chain = 5001
parameters        = r, K, x₀, σ_o

Summary Statistics
  parameters       mean       std      mcse     ess_bulk     ess_tail      rha ⋯
      Symbol    Float64   Float64   Float64      Float64      Float64   Float6 ⋯

           r     0.0813    0.0052    0.0000   20161.4627   19506.9542    1.000 ⋯
           K   387.2089    7.6274    0.0537   20176.8304   19687.9238    1.000 ⋯
          x₀    27.9200    3.1544    0.0222   20285.5176   19121.5713    1.000 ⋯
         σ_o    32.9005    1.9322    0.0136   20173.0757   19824.2450    1.000 ⋯
                                                               2 columns omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
      Symbol    Float64    Float64    Float64    Float64    Float64 

           r     0.0717     0.0777     0.0810     0.0846     0.0920
           K   372.7365   382.0207   387.0613   392.1432   402.8794
          x₀    21.9440    25.7643    27.8505    29.9939    34.3157
         σ_o    29.3240    31.5760    32.8216    34.1645    36.9307

6.2 Pair plot for model parameters

Code
# Pair plot of the chains; the data-generating parameter values are passed as the truth.
pairplot(chn, PairPlots.Truth(true_solution.parameter))

6.3 Trace plot for model parameters

Code
# Plot the chains with StatsPlots (qualified call, because StatsPlots is only `import`ed).
StatsPlots.plot(chn)

7 Posterior predictive check

Code
"""
    sample_posterior(data, problem, burnin)

Pick one stored draw at random and map it through `problem.transformation`.

`data` is indexed as `[chain, parameter, sample]`. The chain index is sampled from
`1:size(data, 1)` and the sample index from `burnin:size(data, 3)`.
"""
function sample_posterior(data, problem, burnin)
    nchains, _, nsamples = size(data)
    chain = sample(1:nchains)
    draw = sample(burnin:nsamples)
    return transform(problem.transformation, data[chain, :, draw])
end

# Posterior predictive check: the simulated data overlaid with deterministic
# trajectories computed from random posterior draws.
let
    fig = Figure(; size = (800, 500))
    ax = Axis(fig[1, 1]; ylabel = "value")

    obs_ts = true_solution.ts
    scatter!(ax, obs_ts, true_solution.y; color = :steelblue4, label = "observations: y")
    lines!(ax, obs_ts, true_solution.x; color = :blue, label = "true hidden state: x")
    lines!(ax, obs_ts, true_solution.s; color = :red, label = "process-model state: s")

    n_draws = 200
    for _ in 1:n_draws
        draw = sample_posterior(post_raw, problem, burnin)
        r, K, x₀ = draw.r, draw.K, draw.x₀

        trajectory = zeros(length(problem.ts))
        for t in problem.ts
            # process equation (the first step starts from x₀)
            previous = t == 1 ? x₀ : trajectory[t-1]
            trajectory[t] = (1 + r*(1 - previous/K)) * previous
        end

        lines!(ax, obs_ts, trajectory; color = (:black, 0.02))
    end

    Legend(fig[1, 2], ax)

    fig
end