Estimating ln(Z) with NumPyro + MorphZ

This notebook demonstrates how to compute the Bayesian evidence (log marginal likelihood, ln Z) using:

  • NumPyro for posterior sampling (via NUTS)

  • MorphZ for evidence estimation from posterior samples

1. What We Are Computing

Given a model with parameters \(\theta\) and data \(y\):

Posterior: \( p(\theta \mid y) \propto p(y \mid \theta)\, p(\theta) \)

Evidence (marginal likelihood): \( Z = p(y) = \int p(y \mid \theta)\, p(\theta)\, d\theta \)

We want \(\ln Z\). NumPyro provides posterior samples, and MorphZ estimates \(Z\) from those samples together with evaluations of the unnormalized log posterior, \(\ln p(y \mid \theta) + \ln p(\theta)\).
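For a one-dimensional parameter the evidence integral can also be approximated directly on a grid, which helps build intuition for what \(\ln Z\) measures. The sketch below is not part of MorphZ; the helper `log_evidence_grid` and the toy model (standard normal prior, a single observation with unit noise) are assumptions made purely for illustration.

```python
import numpy as np

def log_evidence_grid(log_joint, lo, hi, n_grid=20001):
    """Approximate ln Z = ln of the integral of exp(log_joint) over [lo, hi]
    with the trapezoid rule on a uniform grid."""
    theta = np.linspace(lo, hi, n_grid)
    log_vals = log_joint(theta)
    shift = log_vals.max()              # subtract the max for numerical stability
    w = np.exp(log_vals - shift)
    h = theta[1] - theta[0]
    integral = h * (w.sum() - 0.5 * (w[0] + w[-1]))
    return shift + np.log(integral)

# Hypothetical toy model: theta ~ N(0, 1), y_obs ~ N(theta, 1)
y_obs = 0.5

def log_joint(theta):
    log_prior = -0.5 * theta**2 - 0.5 * np.log(2 * np.pi)
    log_like = -0.5 * (y_obs - theta) ** 2 - 0.5 * np.log(2 * np.pi)
    return log_prior + log_like

lnz_grid = log_evidence_grid(log_joint, -10.0, 10.0)

# For this toy, y_obs is marginally N(0, 1 + 1), which gives the exact value
lnz_exact = -0.5 * np.log(2 * np.pi * 2.0) - 0.5 * y_obs**2 / 2.0

print('grid ln(Z): ', lnz_grid)
print('exact ln(Z):', lnz_exact)
```

Grid quadrature scales poorly with dimension, which is why sample-based estimators are used for realistic models.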

2. Colab Setup

%pip -q install "jax[cpu]" numpyro morphz
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_density
from morphZ import evidence

print('jax:', jax.__version__)
print('numpyro:', numpyro.__version__)

3. Minimal Toy Model: Gaussian Mean Inference

We infer an unknown mean \(\mu\) with known noise standard deviation \(\sigma\) and a fixed prior width \(\tau_0\).

  • Prior: \(\mu \sim \mathcal{N}(0, \tau_0^2)\)

  • Likelihood: \(y_i \sim \mathcal{N}(\mu, \sigma^2)\)

rng = np.random.default_rng(0)

# Synthetic dataset
n_obs = 40
mu_true = 1.25
sigma = 0.7
tau0 = 2.0
y = rng.normal(mu_true, sigma, size=n_obs)

print('n_obs =', n_obs)
print('sample mean =', y.mean())

def gaussian_mean_model(y, sigma, tau0):
    mu = numpyro.sample('mu', dist.Normal(0.0, tau0))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)
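Because this model is conjugate, the posterior of \(\mu\) is itself Gaussian with a known mean and standard deviation, which gives a reference to compare the NUTS samples against. The following is a minimal sketch of that standard result; the helper name `conjugate_posterior` and the example values are illustrative and not part of the notebook.

```python
import numpy as np

def conjugate_posterior(y, sigma, tau0):
    """Posterior mean and std of mu for y_i ~ N(mu, sigma^2), mu ~ N(0, tau0^2)."""
    y = np.asarray(y, dtype=float)
    # Posterior precision = prior precision + n * likelihood precision
    precision = 1.0 / tau0**2 + y.size / sigma**2
    post_var = 1.0 / precision
    # The prior mean is zero, so only the data term contributes to the mean
    post_mean = post_var * y.sum() / sigma**2
    return post_mean, np.sqrt(post_var)

# Illustrative values only
post_mean, post_std = conjugate_posterior([1.0, 2.0, 3.0], sigma=1.0, tau0=1.0)
print(post_mean, post_std)  # 1.5 0.5
```

Calling it with the notebook's `y`, `sigma`, and `tau0` should give values close to the mean and standard deviation of the sampled `mu`.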

4. Run NUTS and Collect Posterior Samples

nuts = NUTS(gaussian_mean_model)
mcmc = MCMC(
    nuts,
    num_warmup=800,
    num_samples=2000,
    num_chains=1,
    progress_bar=True,
)

rng_key = jax.random.PRNGKey(42)
mcmc.run(rng_key, y=jnp.asarray(y), sigma=sigma, tau0=tau0)
samples = mcmc.get_samples(group_by_chain=False)

print('posterior keys:', samples.keys())
print('mu sample shape:', np.asarray(samples['mu']).shape)

5. Build a Log Posterior Callable

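MorphZ needs a callable lp_fn(sample_vector) that returns the unnormalized log posterior (log prior plus log likelihood) of the model that was sampled. NumPyro's log_density evaluates this log joint density for a dict of parameter values, so we wrap it.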

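def build_log_density_fn(model, model_kwargs):
    # Convert data and hyperparameters to JAX arrays once, outside the jitted function
    model_kwargs = jax.tree_util.tree_map(jnp.asarray, model_kwargs)

    def _logpost(params):
        # log_density returns (log joint density, model trace); keep only the scalar
        log_prob, _ = log_density(model, (), model_kwargs, params)
        return log_prob

    return jax.jit(_logpost)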

model_kwargs = {
    'y': jnp.asarray(y),
    'sigma': sigma,
    'tau0': tau0,
}

logpost_fn = build_log_density_fn(gaussian_mean_model, model_kwargs)
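For a model this small, the log joint density can also be written by hand in NumPy, which gives an independent cross-check on the callable. This is a sketch under the assumption that the callable evaluates exactly \(\ln p(\mu) + \sum_i \ln p(y_i \mid \mu)\) for this model; the function name `log_joint_numpy` is hypothetical.

```python
import numpy as np

def log_joint_numpy(mu, y, sigma, tau0):
    """Hand-written log prior + log likelihood for the Gaussian mean model."""
    y = np.asarray(y, dtype=float)
    # ln N(mu | 0, tau0^2)
    log_prior = -0.5 * (mu / tau0) ** 2 - np.log(tau0) - 0.5 * np.log(2 * np.pi)
    # sum_i ln N(y_i | mu, sigma^2)
    resid = (y - mu) / sigma
    log_like = np.sum(-0.5 * resid**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi))
    return log_prior + log_like

# Tiny check: one observation at 0 with mu = 0 and unit scales gives -ln(2*pi)
val = log_joint_numpy(0.0, [0.0], sigma=1.0, tau0=1.0)
print(val)
```

When comparing against `logpost_fn({'mu': ...})`, expect agreement only to about single precision, since JAX uses 32-bit floats by default.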

6. Pack Samples into a 2D Array

MorphZ expects the samples as a 2D array of shape (n_draws, n_parameters). Here there is one parameter (mu), so the array has a single column.

mu_samples = np.asarray(samples['mu'])
post_smp = mu_samples[:, None]  # (n_draws, 1)

def lp_fn(sample_vec):
    # sample_vec is one row of post_smp; its only entry is mu
    params = {'mu': jnp.asarray(sample_vec[0])}
    return float(logpost_fn(params))

lp = np.array([lp_fn(v) for v in post_smp])

print('post_smp shape:', post_smp.shape)
print('lp shape:', lp.shape)
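With more than one parameter, the same packing step needs a fixed column order so that the callable can map a row back to named parameters. The helpers below are a hypothetical sketch (`pack_samples`, `unpack_vector`, and the parameter name `log_sigma` are not from the notebook) and assume every sampled site is a scalar.

```python
import numpy as np

def pack_samples(samples, names):
    """Stack scalar-parameter draws into a (n_draws, n_parameters) array.

    The column order follows `names`, so the same list must be used to unpack.
    """
    cols = [np.asarray(samples[name]).reshape(-1) for name in names]
    return np.column_stack(cols)

def unpack_vector(vec, names):
    """Map one row of the packed array back to a {name: value} dict."""
    return {name: vec[i] for i, name in enumerate(names)}

# Hypothetical two-parameter posterior with three draws
names = ['mu', 'log_sigma']
toy_samples = {
    'mu': np.array([0.1, 0.2, 0.3]),
    'log_sigma': np.array([-1.0, -0.9, -1.1]),
}

packed = pack_samples(toy_samples, names)
print(packed.shape)                      # (3, 2)
print(unpack_vector(packed[0], names))
```

Keeping `names` in one place and reusing it for both packing and unpacking avoids silently swapping columns.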

7. Sanity Check: lp vs lp_fn

for i in range(5):
    print(f'i={i}  precomputed={lp[i]: .6f}  callable={lp_fn(post_smp[i]): .6f}')
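The printed comparison can be turned into an automated check with a tolerance instead of visual inspection. This is a minimal sketch: the helper `check_lp_consistency` and its default tolerances are assumptions, and the toy inputs stand in for the notebook's `post_smp`, `lp`, and `lp_fn`.

```python
import numpy as np

def check_lp_consistency(post_smp, lp, lp_fn, n_check=50, rtol=1e-5, atol=1e-6, seed=0):
    """Recompute the log posterior at a random subset of draws and compare
    against the precomputed values within a tolerance."""
    rng = np.random.default_rng(seed)
    n_check = min(n_check, len(post_smp))
    idx = rng.choice(len(post_smp), size=n_check, replace=False)
    recomputed = np.array([lp_fn(post_smp[i]) for i in idx])
    return bool(np.allclose(recomputed, lp[idx], rtol=rtol, atol=atol))

# Toy stand-ins: a 1D sample matrix and a simple quadratic log density
toy_smp = np.linspace(-1.0, 1.0, 20)[:, None]

def toy_fn(vec):
    return float(-0.5 * vec[0] ** 2)

toy_lp = np.array([toy_fn(v) for v in toy_smp])

ok = check_lp_consistency(toy_smp, toy_lp, toy_fn)
bad = check_lp_consistency(toy_smp, toy_lp + 1.0, toy_fn)  # deliberately shifted values
print(ok, bad)  # True False
```

A constant offset between the two, as in the `bad` case, would shift the estimated ln(Z) by the same amount, so this check is worth running before the evidence step.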

8. Estimate ln(Z) with MorphZ

results = evidence(
    post_samples=post_smp,
    log_posterior_values=lp,
    log_posterior_function=lp_fn,
    n_resamples=1000,
    thin=2,
    kde_fraction=0.6,
    bridge_start_fraction=0.5,
    max_iter=2000,
    tol=1e-4,
    morph_type='indep',
    kde_bw='silverman',
    param_names=['mu'],
    output_path='morphz_numpyro_demo',
    n_estimations=3,
    verbose=False,
    plot=False,
    show_progress=False,
)

results = np.asarray(results)
results

9. Compare to Analytic ln(Z)

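For this conjugate Gaussian setup, the exact ln(Z) is available in closed form, which lets us validate the MorphZ estimate. Integrating out \(\mu\) leaves \(y \sim \mathcal{N}(0,\, \sigma^2 I + \tau_0^2 \mathbf{1}\mathbf{1}^\top)\), so ln(Z) is the log density of that multivariate normal evaluated at the observed data.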

def analytic_lnz(y, sigma, tau0):
    n = y.size
    # Marginal covariance of y after integrating out mu
    C = (sigma ** 2) * np.eye(n) + (tau0 ** 2) * np.ones((n, n))
    sign, logdet = np.linalg.slogdet(C)
    if sign <= 0:
        raise RuntimeError('Covariance matrix is not positive definite.')
    quad = y @ np.linalg.solve(C, y)
    # Log density of a zero-mean multivariate normal evaluated at y
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

lnz_true = analytic_lnz(y, sigma, tau0)
lnz_est = results[:, 0]
lnz_err = results[:, 1]

print('MorphZ ln(Z) per run:', lnz_est)
print('MorphZ reported errors:', lnz_err)
print(f'MorphZ mean ln(Z): {lnz_est.mean():.6f}')
print(f'Analytic ln(Z):     {lnz_true:.6f}')
print(f'Absolute difference: {abs(lnz_est.mean() - lnz_true):.6f}')
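The covariance used in `analytic_lnz` is a scaled identity plus a rank-one term, so the log determinant and quadratic form can be written without building an n-by-n matrix, using the matrix determinant lemma and the Sherman-Morrison formula. The sketch below is an optional alternative, not part of the notebook; the function names are illustrative, and the dense version is repeated only so that the block runs on its own.

```python
import numpy as np

def analytic_lnz_closed_form(y, sigma, tau0):
    """ln Z for y ~ N(0, sigma^2 I + tau0^2 11^T) without forming the matrix."""
    y = np.asarray(y, dtype=float)
    n = y.size
    s2, t2 = sigma**2, tau0**2
    # Matrix determinant lemma: det(C) = sigma^(2n) * (1 + n * tau0^2 / sigma^2)
    logdet = n * np.log(s2) + np.log1p(n * t2 / s2)
    # Sherman-Morrison: C^{-1} = (I - tau0^2 11^T / (sigma^2 + n tau0^2)) / sigma^2
    quad = (y @ y - t2 * y.sum() ** 2 / (s2 + n * t2)) / s2
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

def analytic_lnz_dense(y, sigma, tau0):
    """Dense reference, same computation as the notebook's analytic_lnz."""
    y = np.asarray(y, dtype=float)
    n = y.size
    C = (sigma**2) * np.eye(n) + (tau0**2) * np.ones((n, n))
    _, logdet = np.linalg.slogdet(C)
    quad = y @ np.linalg.solve(C, y)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

# Synthetic data generated the same way as in the notebook
rng = np.random.default_rng(0)
y_demo = rng.normal(1.25, 0.7, size=40)

lnz_fast = analytic_lnz_closed_form(y_demo, sigma=0.7, tau0=2.0)
lnz_dense = analytic_lnz_dense(y_demo, sigma=0.7, tau0=2.0)
print('closed form:', lnz_fast)
print('dense:      ', lnz_dense)
```

The closed form costs O(n) instead of O(n^3), which only matters for much larger datasets than the 40 observations used here.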

10. Summary

The MorphZ inputs are:

  • post_smp: posterior sample matrix

  • lp: log posterior values at those samples

  • lp_fn: callable log posterior

As long as lp[i] matches lp_fn(post_smp[i]) to floating-point precision for every draw, the evidence estimate is consistent with the posterior you sampled.