Random walk Metropolis

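\[ \begin{align} &p(\theta \mid \text{data}) \propto p(\text{data} \mid \theta) \cdot p(\theta) \\ &r_{M} = \min\left(1, \frac{p(\theta^{*} \mid \text{data})}{p(\theta_{t} \mid \text{data})}\right) \end{align} \]

Here $\theta^{*}$ is a proposal drawn from a symmetric random walk around $\theta_t$. It is accepted as $\theta_{t+1}$ with probability $r_M$; otherwise $\theta_{t+1} = \theta_t$.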

1 Setup

using CairoMakie
using Distributions
using LinearAlgebra
using MCMCChains
using PairPlots
using Random
using Statistics

import StatsPlots

# Global plot theme: larger text, no grid lines, and only the left/bottom
# spines drawn on every Axis.
set_theme!(;
    fontsize = 18,
    Axis = (
        xgridvisible = false,
        ygridvisible = false,
        topspinevisible = false,
        rightspinevisible = false,
    ),
)

2 Generate data

# Fix the global RNG so that the simulated data and every later draw from the
# global RNG (chain initialisation, proposals, accept/reject) are reproducible.
Random.seed!(42)

μ, σ = 5, 2                  # true mean and standard deviation of the data
y = rand(Normal(μ, σ), 500)  # 500 simulated observations
hist(y)

3 Define Likelihood and Prior

# Priors: μ ~ Normal(0, 3) truncated to [-10, 10];
#         σ ~ Normal(0, 1) truncated to [0, 10] (keeps σ non-negative).
prior_μ = truncated(Normal(0, 3), -10, 10)
prior_σ = truncated(Normal(0, 1), 0, 10)

# Random-walk step sizes for (μ, σ). NOTE: the sampler passes
# `Diagonal(propσ)` to `MvNormal` as a covariance matrix, so these entries
# act as proposal *variances*, not standard deviations.
propσ = [1e-2, 1e-3]

"""
    posterior(θ)

Return the unnormalized log-posterior `log p(y | θ) + log p(θ)` for the
parameter vector `θ = [μ, σ]`. It uses the global data `y` and the global
priors `prior_μ` (on `θ[1]`) and `prior_σ` (on `θ[2]`).

Returns `-Inf` when `θ` falls outside the prior support, so a Metropolis
step always rejects such a proposal.
"""
function posterior(θ)
    ## prior
    log_prior = logpdf(prior_μ, θ[1]) + logpdf(prior_σ, θ[2])
    # Outside the prior support (e.g. σ < 0) the posterior density is zero.
    # Returning early also avoids building `Normal(μ, σ)` with σ < 0, which
    # would throw.
    isfinite(log_prior) || return -Inf

    ## likelihood
    # Sum of logpdf over all observations. A single distribution object is
    # used, and the accumulator is always Float64 (no Int → Float64 switch).
    log_likelihood = loglikelihood(Normal(θ[1], θ[2]), y)

    ## unnormalized log-posterior
    return log_likelihood + log_prior
end
posterior (generic function with 1 method)

4 Run MCMC

n_samples = 20_000
burnin = n_samples ÷ 2
nchains = 6
nparameter = 2
accepted_θ = zeros(nchains, nparameter, n_samples)  # chain × parameter × iteration
accepted = zeros(nchains)                            # accepted proposals per chain
θ = zeros(nparameter)                                # current state, reused per chain

"""
    rwm_chain!(samples, logdensity, θ, propcov) -> Int

Run one random-walk Metropolis chain starting at `θ`.

- Proposals are drawn from `MvNormal(θ, propcov)`.
- `logdensity(θ)` must return the unnormalized log-target.
- The state after iteration `k` is written to `samples[:, k]`, so `samples`
  is a parameters × iterations matrix (a view is fine).
- `θ` is updated in place and ends at the final state of the chain.

Returns the number of accepted proposals.
"""
function rwm_chain!(samples::AbstractMatrix, logdensity, θ::AbstractVector, propcov)
    logp = logdensity(θ)
    naccepted = 0
    for k in axes(samples, 2)
        ## new proposal (symmetric, so no Hastings correction is needed)
        θstar = rand(MvNormal(θ, propcov))

        ## evaluate prior + likelihood
        logpstar = logdensity(θstar)

        ## M-H step: accept with probability min(1, p*/p). In log space the
        ## cap is log(1) = 0.
        if log(rand()) < min(logpstar - logp, 0)
            naccepted += 1
            # Update in place instead of rebinding: rebinding a global inside a
            # top-level loop is a soft-scope local in non-interactive scripts.
            θ .= θstar
            logp = logpstar
        end

        samples[:, k] .= θ
    end
    return naccepted
end

for n in 1:nchains
    # initialise each chain with a draw from the prior
    θ[1] = rand(prior_μ)
    θ[2] = rand(prior_σ)
    accepted[n] = rwm_chain!(view(accepted_θ, n, :, :), posterior, θ, Diagonal(propσ))
end


accepted ./ n_samples  # fraction of accepted proposals, one entry per chain
6-element Vector{Float64}:
 0.60775
 0.61345
 0.61275
 0.61275
 0.6182
 0.6232

5 Convergence

chn = Chains(permutedims(accepted_θ, (3, 2, 1)), [:μ, :σ])  # iterations × parameters × chains
Chains MCMC chain (20000×2×6 Array{Float64, 3}):

Iterations        = 1:1:20000
Number of chains  = 6
Samples per chain = 20000
parameters        = μ, σ

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           μ    4.9236    0.4876    0.0228   6761.4570   3574.8038    1.0009   ⋯
           σ    1.9804    0.0950    0.0031   3527.9637   3195.3202    1.0019   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           μ    4.7698    4.8933    4.9534    5.0139    5.1282
           σ    1.8595    1.9349    1.9759    2.0200    2.1108

6 Trace plot and densities of the MCMC samples

6.1 Including the burn-in samples

"""
    trace_plot(; burnin)

Plot, for μ (row 1) and σ (row 2), the trace of every chain (left column) and
a kernel density estimate of its draws (right column).

Data come from the global `accepted_θ`. Only iterations `burnin:end` are used:
`burnin` is the index of the first kept iteration, so `burnin = 1` keeps every
sample. The trace x-axis counts iterations from 0 at `burnin`.

Returns the `Figure`.
"""
function trace_plot(; burnin)
    fig = Figure()

    titles = ["μ", "σ"]
    colors = Makie.wong_colors()
    kept = burnin:size(accepted_θ, 3)

    for i in eachindex(titles)
        trace_ax = Axis(fig[i, 1]; title = titles[i])
        density_ax = Axis(fig[i, 2])

        for n in 1:nchains
            # cycle the palette so that more chains than colors does not error
            c = colors[mod1(n, length(colors))]
            draws = accepted_θ[n, i, kept]

            lines!(trace_ax, kept .- burnin, draws; color = (c, 0.5))

            # `density!` is a KDE and has no `bins` attribute, so none is passed.
            density!(density_ax, draws;
                color = (c, 0.1),
                strokecolor = (c, 1),
                strokewidth = 2, strokearound = false)
        end
    end

    rowgap!(fig.layout, 1, 5)
    return fig
end

trace_plot(burnin = 1)  # start at iteration 1, i.e. keep every sample

6.2 Burn-in samples discarded

trace_plot(burnin = burnin)  # start at iteration n_samples ÷ 2, dropping the first half

6.2.1 Alternatively, use the plotting recipe from StatsPlots

chn[burnin:end, :, :] |> StatsPlots.plot  # traces + densities of the post-burn-in draws

7 Pair plot

7.1 Including the burn-in samples

chn |> pairplot  # all iterations, burn-in included

7.2 Burn-in samples discarded

chn[burnin:end, :, :] |> pairplot  # post-burn-in iterations only

8 Posterior predictive check

begin
    # Posterior predictive check: overlay Normal(μ, σ) densities for a random
    # subset of post-burn-in posterior draws (grey), their pointwise mean (red,
    # a Monte Carlo estimate of the posterior predictive density), and a
    # histogram of the observed data (blue).
    fig = Figure()

    ax = Axis(fig[1, 1]; title="Posterior predictive check")

    ## pool post-burn-in draws from all chains
    μs = vec(accepted_θ[:, 1, burnin:end])
    σs = vec(accepted_θ[:, 2, burnin:end])

    ## random subset of draws to plot (each draw used at most once)
    npredsamples = 500
    ns = sample(1:length(μs), npredsamples;
                replace=false)

    ## evaluation grid, padded 4 units beyond the observed data range
    minx, maxx = minimum(y) - 4, maximum(y) + 4
    nxvals = 200
    xvals = LinRange(minx, maxx, nxvals)
    pred = zeros(npredsamples, nxvals)

    ## calculate and plot each predictive sample
    # The draws are indexed directly instead of assigning to `μ`/`σ`. In a
    # notebook/REPL, assigning an existing global inside a top-level loop
    # overwrites it, which would clobber the true parameters `μ, σ` used to
    # simulate `y`.
    for (i, j) in enumerate(ns)
        yvals = pdf.(Normal(μs[j], σs[j]), xvals)
        pred[i, :] = yvals

        lines!(ax, xvals, yvals;
                color=(:grey, 0.1))
    end

    ## mean of the predicted densities
    meany = vec(mean(pred, dims=1))
    lines!(ax, xvals, meany;
            linewidth=3,
            color=(:red, 0.8))

    ## histogram of the data
    hist!(ax, y;
        normalization=:pdf,
        bins=25,
        color=(:blue, 0.3),
        strokecolor=(:white, 0.5),
        strokewidth=1)

    fig
end