Gibbs sampling

1 Introduction to the model

We want to estimate the unknown parameters \(\mu\) and \(\sigma\) of a normal distribution. We will use the following model: \[ \begin{aligned} &y = \mu + \varepsilon \\ &\varepsilon \sim \mathcal{N}(0, \sigma) \end{aligned} \]

and the following priors: \[ \begin{aligned} &\mu \sim \mathcal{N}(0, 5) \\ &\sigma \sim \mathcal{N}_+(0, 2) \end{aligned} \] where \(\mathcal{N}_+\) denotes a normal distribution truncated to positive values, i.e. a half-normal prior on the standard deviation \(\sigma\).
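
By Bayes' theorem, the sampler below targets the unnormalised joint posterior \[ p(\mu, \sigma \mid y) \propto \prod_{i=1}^{n} \mathcal{N}(y_i \mid \mu, \sigma) \, \mathcal{N}(\mu \mid 0, 5) \, \mathcal{N}_+(\sigma \mid 0, 2), \] and all computations are carried out on the log scale to avoid numerical underflow.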

2 Gibbs sampling

2.1 Load packages

We will use the CairoMakie package for plotting and the Distributions package for generating data and calculating the likelihood. We will also use the MCMCChains package for checking the convergence of the chains.

Code
using CairoMakie
using Distributions
using MCMCChains

set_theme!(
    fontsize=18,
    Axis=(xgridvisible=false, ygridvisible=false,
          topspinevisible=false, rightspinevisible=false),
)

2.2 Generate data

We generate 500 data points from a normal distribution with the true parameters \(\mu = 5\) and \(\sigma = 2\):

μ, σ = 5, 2
y = rand(Normal(μ, σ), 500)
hist(y)

2.3 Sampling

We start with random values for \(\mu\) and \(\sigma\) drawn from the prior distributions. At each step we propose a new value for either \(\mu\) or \(\sigma\) while keeping the other parameter fixed, and accept or reject the proposal with a Metropolis step. This component-wise scheme is a special case of the Metropolis-Hastings algorithm, often called Metropolis-within-Gibbs.
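
Because the random-walk proposal is a symmetric normal centred on the current value, the Hastings correction cancels and only the ratio of unnormalised posteriors matters. On the log scale, a proposed state \(\theta'\) is accepted when \[ \log u < \log p(\theta' \mid y) - \log p(\theta \mid y), \qquad u \sim \mathrm{Uniform}(0, 1), \] which corresponds to the comparison log(rand()) < r in the code below.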

nchains = 6
nsamples = 5_000
burnin = nsamples ÷ 2
θ = zeros(2, nsamples, nchains)  # parameters × samples × chains

μ_prior = Normal(0, 5)
σ_prior = truncated(Normal(0, 2); lower = 0)

proposals_sigma = [0.5, 0.5]  # standard deviations of the random-walk proposals for μ and σ

for n in 1:nchains
    # draw the initial state from the priors
    θ[:, 1, n] = [rand(μ_prior), rand(σ_prior)]
    logprior_init = logpdf(μ_prior, θ[1, 1, n]) + logpdf(σ_prior, θ[2, 1, n])
    loglikelihood_init = sum(logpdf.(Normal(θ[1, 1, n], θ[2, 1, n]), y))
    current_logposterior = logprior_init + loglikelihood_init
    current_μ, current_σ = θ[:, 1, n]

    for i in 2:nsamples    
        if i % 2 == 0
            # sample new μ
            current_μ = rand(Normal(θ[1, i-1, n], proposals_sigma[1]))
            current_σ = θ[2, i-1, n]
        else
            # sample new σ
            current_μ = θ[1, i-1, n]
            current_σ = rand(Normal(θ[2, i-1, n], proposals_sigma[2]))   
        end
        
        # prior; a proposal outside the support (e.g. σ ≤ 0) has log prior -Inf and is rejected
        logprior = logpdf(μ_prior, current_μ) + logpdf(σ_prior, current_σ)
        if logprior == -Inf
            θ[:, i, n] = θ[:, i-1, n]
            continue
        end
        
        # likelihood
        loglikelihood = sum(logpdf.(Normal(current_μ, current_σ), y))    
        
        # posterior
        logposterior = logprior + loglikelihood
        
        # log acceptance ratio; the proposal is symmetric, so the Hastings correction cancels
        r = logposterior - current_logposterior
        if log(rand()) < r
            θ[:, i, n] = [current_μ, current_σ]
            current_logposterior = logposterior
        else
            θ[:, i, n] = θ[:, i-1, n]
        end
    end
end
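
As a quick check, the pooled posterior draws (after discarding the burn-in) should have means close to the true values \(\mu = 5\) and \(\sigma = 2\). A minimal sketch, reusing the θ array and burnin defined above:

using Statistics

# pool the post-burn-in draws across all chains
μ_draws = vec(θ[1, burnin:end, :])
σ_draws = vec(θ[2, burnin:end, :])

# posterior means and standard deviations
mean(μ_draws), std(μ_draws)  # mean should be close to 5
mean(σ_draws), std(σ_draws)  # mean should be close to 2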

2.4 Plot
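
We plot the trace of \(\mu\) and \(\sigma\) for each chain, the pooled posterior densities, and, in the last column, the posterior densities against the priors. Only samples after the burn-in period are shown.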

Code
let 
    fig = Figure(; size = (1200, 600))
    Axis(fig[1, 1]; ylabel = "μ", title = "mcmc trace")
    for n in 1:nchains
        lines!(burnin:nsamples, θ[1, burnin:end, n])
    end

    Axis(fig[2, 1]; ylabel = "σ")
    for n in 1:nchains
        lines!(burnin:nsamples, θ[2, burnin:end, n])
    end

    Axis(fig[1, 2]; title = "posterior density")
    density!(vec(θ[1, burnin:end, :]))
    
    Axis(fig[2, 2];)
    density!(vec(θ[2, burnin:end, :]))
    
    Axis(fig[1, 3]; title = "posterior vs prior")
    density!(vec(θ[1, burnin:end, :]))
    plot!(μ_prior; color = :red)
    
    Axis(fig[2, 3];)
    density!(vec(θ[2, burnin:end, :]))
    plot!(σ_prior; color = :red)
    
    fig
end

2.5 Convergence diagnostics

Code
Chains(permutedims(θ[:, burnin:end, :], (2, 1, 3)), [:μ, :σ])
Chains MCMC chain (2501×2×6 Array{Float64, 3}):

Iterations        = 1:1:2501
Number of chains  = 6
Samples per chain = 2501
parameters        = μ, σ

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           μ    5.0805    0.0918    0.0027   1202.1637   1179.3778    1.0051   ⋯
           σ    1.9926    0.0622    0.0021    894.8553    849.5681    1.0061   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           μ    4.9048    5.0163    5.0823    5.1417    5.2615
           σ    1.8720    1.9514    1.9908    2.0372    2.1168
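
All \(\hat{R}\) values are close to 1 and the bulk and tail effective sample sizes are in the high hundreds to low thousands, which suggests the chains have converged and mixed well.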

2.6 σ vs μ draws
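
Plotting the draws of \(\sigma\) against \(\mu\) shows how the chains move from their dispersed starting points into the high-density region of the posterior; once the burn-in is removed, only draws from that region remain.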

Code
let 
    fig = Figure(; size = (900, 450))
    Axis(fig[1, 1]; xlabel = "μ", ylabel = "σ", title = "with burn-in period")
    for n in 1:nchains
        scatterlines!(θ[1, :, n], θ[2, :, n])
    end
    
    Axis(fig[1, 2]; xlabel = "μ", ylabel = "σ", title = "burn-in removed")
    for n in 1:nchains
        scatter!(θ[1, burnin:end, n], θ[2, burnin:end, n]; 
                 color = (Makie.wong_colors()[n], 0.2))
    end

    fig
end