Estimating ln(Z) with NumPyro + MorphZ#
This notebook demonstrates how to compute the Bayesian evidence (log marginal likelihood, ln Z) using:
NumPyro for posterior sampling (via NUTS)
MorphZ for evidence estimation from posterior samples
1. What We Are Computing#
Given a model with parameters \(\theta\) and data \(y\):
Posterior: \( p(\theta \mid y) \propto p(y \mid \theta)\, p(\theta) \)
Evidence (marginal likelihood): \( Z = p(y) = \int p(y \mid \theta)\, p(\theta)\, d\theta \)
We want \(\ln Z\). NumPyro gives posterior samples, and MorphZ estimates \(Z\) from those samples plus log-posterior evaluations.
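As a point of reference, in low dimensions \(Z\) can be approximated by naive Monte Carlo over the prior: draw parameters from the prior and average the likelihood. This estimator scales poorly and is not what MorphZ does, but it makes the integral concrete (a minimal NumPy sketch for the 1D Gaussian-mean model used later in this notebook):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem matching the notebook: y_i ~ N(mu, sigma^2), mu ~ N(0, tau0^2)
sigma, tau0, n_obs = 0.7, 2.0, 40
y = rng.normal(1.25, sigma, size=n_obs)

# Z = E_prior[ p(y | mu) ]  ~=  mean_k p(y | mu_k) over prior draws mu_k
mu_draws = rng.normal(0.0, tau0, size=200_000)

# Log likelihood of the full dataset at each prior draw (vectorized)
resid = y[None, :] - mu_draws[:, None]                  # (n_draws, n_obs)
loglik = (-0.5 * n_obs * np.log(2 * np.pi * sigma**2)
          - 0.5 * (resid ** 2).sum(axis=1) / sigma**2)  # (n_draws,)

# Average in log space via log-sum-exp for numerical stability
m = loglik.max()
lnz_mc = m + np.log(np.mean(np.exp(loglik - m)))
print('naive Monte Carlo ln(Z):', lnz_mc)
```

With a diffuse prior most draws contribute almost nothing to the average, which is exactly why posterior-sample-based estimators such as MorphZ are preferable beyond a few dimensions.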
2. Colab Setup#
%pip -q install "jax[cpu]" numpyro morphz
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_density
from morphZ import evidence
print('jax:', jax.__version__)
print('numpyro:', numpyro.__version__)
3. Minimal Toy Model: Gaussian Mean Inference#
We infer an unknown mean \(\mu\) with known noise \(\sigma\).
Prior: \(\mu \sim \mathcal{N}(0, \tau_0^2)\)
Likelihood: \(y_i \sim \mathcal{N}(\mu, \sigma^2)\)
rng = np.random.default_rng(0)
# Synthetic dataset
n_obs = 40
mu_true = 1.25
sigma = 0.7
tau0 = 2.0
y = rng.normal(mu_true, sigma, size=n_obs)
print('n_obs =', n_obs)
print('sample mean =', y.mean())
def gaussian_mean_model(y, sigma, tau0):
    mu = numpyro.sample('mu', dist.Normal(0.0, tau0))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)
4. Run NUTS and Collect Posterior Samples#
nuts = NUTS(gaussian_mean_model)
mcmc = MCMC(
    nuts,
    num_warmup=800,
    num_samples=2000,
    num_chains=1,
    progress_bar=True,
)
rng_key = jax.random.PRNGKey(42)
mcmc.run(rng_key, y=jnp.asarray(y), sigma=sigma, tau0=tau0)
samples = mcmc.get_samples(group_by_chain=False)
print('posterior keys:', samples.keys())
print('mu sample shape:', np.asarray(samples['mu']).shape)
5. Build a Log Posterior Callable#
MorphZ needs a callable lp_fn(sample_vector) that returns the same (unnormalized) log posterior that NUTS sampled from.
def build_log_density_fn(model, model_kwargs):
    model_kwargs = jax.tree_util.tree_map(jnp.asarray, model_kwargs)
    def _logpost(params):
        log_prob, _ = log_density(model, (), model_kwargs, params)
        return log_prob
    return jax.jit(_logpost)
model_kwargs = {
    'y': jnp.asarray(y),
    'sigma': sigma,
    'tau0': tau0,
}
logpost_fn = build_log_density_fn(gaussian_mean_model, model_kwargs)
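For a model this simple, the log posterior can also be written out by hand: log prior plus log likelihood, constants included (NumPyro's log_density keeps the Gaussian normalizing constants, so the two should agree). A pure-NumPy sanity-check sketch, independent of NumPyro:

```python
import numpy as np

# Hand-written log joint for the Gaussian-mean model:
#   log N(mu; 0, tau0^2) + sum_i log N(y_i; mu, sigma^2)
def log_posterior_manual(mu, y, sigma, tau0):
    log_prior = -0.5 * np.log(2 * np.pi * tau0**2) - 0.5 * mu**2 / tau0**2
    log_lik = (-0.5 * y.size * np.log(2 * np.pi * sigma**2)
               - 0.5 * np.sum((y - mu) ** 2) / sigma**2)
    return log_prior + log_lik

# Example: one observation at the prior mean with unit scales gives twice
# the standard normal log density at zero.
print(log_posterior_manual(0.0, np.array([0.0]), 1.0, 1.0))
```

Comparing this against logpost_fn on a few posterior draws is a cheap way to catch a mismatched model signature or a forgotten data argument.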
6. Pack Samples into a 2D Array#
MorphZ expects shape = (n_draws, n_parameters). Here we have one parameter (mu).
mu_samples = np.asarray(samples['mu'])
post_smp = mu_samples[:, None] # (n_draws, 1)
def lp_fn(sample_vec):
    params = {'mu': jnp.asarray(sample_vec[0])}
    return float(logpost_fn(params))
lp = np.array([lp_fn(v) for v in post_smp])
print('post_smp shape:', post_smp.shape)
print('lp shape:', lp.shape)
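For models with more than one sample site, the same packing works column-wise: fix a parameter order, stack one column per parameter, and unflatten each row back into a dict inside lp_fn. A minimal sketch with hypothetical site names 'mu' and 'log_sigma' (not from the model above):

```python
import numpy as np

# Hypothetical posterior dict with two scalar sites
samples_2d = {
    'mu': np.array([0.1, 0.2, 0.3]),
    'log_sigma': np.array([-0.5, -0.4, -0.3]),
}

param_order = ['mu', 'log_sigma']

# Pack: one column per parameter -> (n_draws, n_parameters)
post = np.column_stack([samples_2d[name] for name in param_order])

def unpack(sample_vec):
    # Rebuild the dict that a log-density callable expects from one row
    return {name: sample_vec[i] for i, name in enumerate(param_order)}

print(post.shape)       # (3, 2)
print(unpack(post[0]))
```

Keeping param_order explicit guarantees that columns of post_smp, entries of param_names, and the unpacking inside lp_fn all agree.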
7. Sanity Check: lp vs lp_fn#
for i in range(5):
    print(f'i={i} precomputed={lp[i]: .6f} callable={lp_fn(post_smp[i]): .6f}')
8. Estimate ln(Z) with MorphZ#
results = evidence(
    post_samples=post_smp,
    log_posterior_values=lp,
    log_posterior_function=lp_fn,
    n_resamples=1000,
    thin=2,
    kde_fraction=0.6,
    bridge_start_fraction=0.5,
    max_iter=2000,
    tol=1e-4,
    morph_type='indep',
    kde_bw='silverman',
    param_names=['mu'],
    output_path='morphz_numpyro_demo',
    n_estimations=3,
    verbose=False,
    plot=False,
    show_progress=False,
)
results = np.asarray(results)
results
9. Compare to Analytic ln(Z)#
For this conjugate Gaussian setup, we can compute exact ln(Z) for validation.
def analytic_lnz(y, sigma, tau0):
    n = y.size
    C = (sigma ** 2) * np.eye(n) + (tau0 ** 2) * np.ones((n, n))
    sign, logdet = np.linalg.slogdet(C)
    if sign <= 0:
        raise RuntimeError('Covariance matrix is not positive definite.')
    quad = y @ np.linalg.solve(C, y)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)
lnz_true = analytic_lnz(y, sigma, tau0)
lnz_est = results[:, 0]
lnz_err = results[:, 1]
print('MorphZ ln(Z) per run:', lnz_est)
print('MorphZ reported errors:', lnz_err)
print(f'MorphZ mean ln(Z): {lnz_est.mean():.6f}')
print(f'Analytic ln(Z): {lnz_true:.6f}')
print(f'Absolute difference: {abs(lnz_est.mean() - lnz_true):.6f}')
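Completing the square over \(\mu\) gives an equivalent scalar expression for the same integral, a standard conjugate-Gaussian identity. A sketch, assuming the same model as above; it should match the matrix version to floating-point precision:

```python
import numpy as np

# Scalar form of ln Z from completing the square over mu:
#   A = n/sigma^2 + 1/tau0^2,  b = sum(y)/sigma^2
#   ln Z = -(n/2) ln(2 pi sigma^2) - (1/2) ln(tau0^2 A)
#          - sum(y^2)/(2 sigma^2) + b^2/(2A)
def analytic_lnz_scalar(y, sigma, tau0):
    n = y.size
    A = n / sigma**2 + 1.0 / tau0**2
    b = y.sum() / sigma**2
    return (-0.5 * n * np.log(2 * np.pi * sigma**2)
            - 0.5 * np.log(tau0**2 * A)
            - 0.5 * (y ** 2).sum() / sigma**2
            + 0.5 * b**2 / A)
```

Because it never forms the n × n covariance, this form stays O(n) and is convenient for validating against much larger synthetic datasets.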
10. Summary#
The MorphZ inputs are:
post_smp: posterior sample matrix
lp: log posterior values at those samples
lp_fn: callable log posterior
As long as lp[i] == lp_fn(post_smp[i]), your evidence estimate is consistent with your sampled posterior.