Gaussian Shell

Gaussian Shell#

import numpy as np
import dynesty
from dynesty import NestedSampler
from dynesty import utils as dyfunc
import matplotlib.pyplot as plt
import corner
from dynesty import DynamicNestedSampler
 
from scipy.special import logsumexp

# Prior cube: [-L, L]^D 
L = 6.0
ndim = 30
prior_volume = (2*L)**ndim

# Parameters of the twin shells (from paper)
w1 = w2 = 0.1  # thickness of both shells
r1 = r2 = 2.0  # radius of both shells
c1 = np.array([-3.5] + [0.0] * (ndim - 1))  # center of first shell
c2 = np.array([3.5] + [0.0] * (ndim - 1))   # center of second shell

def logprior(theta):
    """Uniform prior over [-L, L]^D."""
    if np.all(np.abs(theta) <= L):
        return -np.log(prior_volume)
    else:
        return -np.inf

def loglikelihood(theta):
    """
    Twin Gaussian shell log-likelihood from equation (38) in the paper.
    L(θ) = (1/√(2πw₁)) * exp[-(|θ-c₁|-r₁)²/(2w₁²)] + (1/√(2πw₂)) * exp[-(|θ-c₂|-r₂)²/(2w₂²)]
    """
    # Distance from centers
    r_c1 = np.linalg.norm(theta - c1)
    r_c2 = np.linalg.norm(theta - c2)
    
    # Log-likelihood components for each shell
    # Note: We work in log space for numerical stability
    log_norm1 = -0.5 * np.log(2 * np.pi * w1**2)
    log_norm2 = -0.5 * np.log(2 * np.pi * w2**2)
    
    log_exp1 = log_norm1 - 0.5 * ((r_c1 - r1) / w1)**2
    log_exp2 = log_norm2 - 0.5 * ((r_c2 - r2) / w2)**2
    
    # Use logsumexp for numerical stability when adding exponentials
    return logsumexp([log_exp1, log_exp2])

def prior_transform(u):
    """
    Map unit cube [0,1]^D to [-L, L]^D.
    u: array in [0,1]^D
    """
    return -L + 2*L*u

# Optional: Visualization function for 2D case
def plot_twin_shells_2d():
    """Plot the twin Gaussian shells for visualization (2D only)."""
    if ndim != 2:
        print("Visualization only available for 2D case")
        return
    
    # Create grid
    x = np.linspace(-6, 6, 200)
    y = np.linspace(-6, 6, 200)
    X, Y = np.meshgrid(x, y)
    
    # Evaluate likelihood on grid
    Z = np.zeros_like(X)
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            theta = np.array([X[i,j], Y[i,j]])
            Z[i,j] = np.exp(loglikelihood(theta))
    
    # Plot
    plt.figure(figsize=(10, 8))
    plt.contourf(X, Y, Z, levels=50, cmap='viridis')
    plt.colorbar(label='Likelihood')
    plt.xlabel('θ₁')
    plt.ylabel('θ₂')
    plt.title('Twin Gaussian Shells')
    
    # Mark centers
    plt.plot(c1[0], c1[1], 'r*', markersize=15, label='Center 1')
    plt.plot(c2[0], c2[1], 'r*', markersize=15, label='Center 2')
    
    # Draw circles showing the shell radii
    circle1 = plt.Circle(c1[:2], r1, fill=False, color='red', linestyle='--', alpha=0.7)
    circle2 = plt.Circle(c2[:2], r2, fill=False, color='red', linestyle='--', alpha=0.7)
    plt.gca().add_patch(circle1)
    plt.gca().add_patch(circle2)
    
    plt.legend()
    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.show()

# Example usage with dynesty
def run_dynesty_example():
    """Example of how to use this with dynesty."""
    try:
        import dynesty
        from dynesty import plotting as dyplot
        
        # Run nested sampling
        sampler = dynesty.NestedSampler(loglikelihood, prior_transform, ndim)
        sampler.run_nested()
        results = sampler.results
        
        print(f"Log evidence: {results.logz[-1]:.2f} ± {results.logzerr[-1]:.2f}")
        
        # Plot results
        fig, axes = dyplot.runplot(results)
        plt.show()
        
        return results
        
    except ImportError:
        print("Dynesty not installed. Install with: pip install dynesty")
        return None

# Analytical log-evidence values from the paper (for reference)
analytical_logz = {
    20: -36.09, # 20D case from paper
    30: -60.13  # 30D case from paper
}



if __name__ == "__main__":
    print(f"Twin Gaussian Shell Problem Setup:")
    print(f"Dimensions: {ndim}")
    print(f"Prior bounds: [{-L}, {L}]^{ndim}")
    print(f"Shell centers: c1={c1}, c2={c2}")
    print(f"Shell radii: r1={r1}, r2={r2}")
    print(f"Shell widths: w1={w1}, w2={w2}")
    
    if ndim == 2:
        print("\nGenerating 2D visualization...")
        plot_twin_shells_2d()
    
    # Uncomment to run dynesty example
    # results = run_dynesty_example()


# -------------------------------
# Run dynesty Nested Sampling
# -------------------------------
NN = 1

log_z_NS = np.zeros(NN)
log_z_NS_err = np.zeros(NN)

for i in range(NN):
# Set up the sampler
    sampler = NestedSampler(
        loglikelihood,
        prior_transform,
        ndim,
        nlive=200,      # number of live points; increase for accuracy
        sample="rwalk",  # random-walk proposals; can try "unif", "slice"
        bound="multi",   # bounding method ("multi" good for multimodal)
    )
    # sampler = DynamicNestedSampler(
    #     loglikelihood,
    #     prior_transform,
    #     ndim,
    #     nlive=4000,
    #     )

    print("Running dynesty Nested Sampling on Gaussian shell...")
    sampler.run_nested(dlogz=0.1,print_progress=True)
    # sampler.run_nested(dlogz_init=0.1,print_progress=True)
    res = sampler.results
    samples, weights = res.samples, np.exp(res.logwt - res.logz[-1])
    posterior_samples = dyfunc.resample_equal(samples, weights)
    # -------------------------------
    # Report results
    # -------------------------------
    log_z_NS[i] = res.logz[-1]
    log_z_NS_err[i] = res.logzerr[-1]
true_logz = -36.09
print("\n==== RESULTS ====")
print(f"dynesty estimated logZ = {log_z_NS.mean():.3f} ± {log_z_NS_err.mean():.3f}")
print(f"True logZ (literature) = {true_logz:.3f}")
print(f"Error = {log_z_NS.mean()- true_logz:.3f}")
Twin Gaussian Shell Problem Setup:
Dimensions: 30
Prior bounds: [-6.0, 6.0]^30
Shell centers: c1=[-3.5  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
  0.   0. ], c2=[3.5 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. ]
Shell radii: r1=2.0, r2=2.0
Shell widths: w1=0.1, w2=0.1
Running dynesty Nested Sampling on Gaussian shell...
11803it [00:26, 437.65it/s, +200 | bound: 394 | nc: 1 | ncall: 563261 | eff(%):  2.132 | loglstar:   -inf <  1.384 <    inf | logz: -55.129 +/-  0.525 | dlogz:  0.000 >  0.100]
==== RESULTS ====
dynesty estimated logZ = -55.129 ± 0.625
True logZ (literature) = -36.090
Error = -19.039
import morphZ
import logging
from morphZ import setup_logging ; setup_logging(level=logging.INFO)

import g_shell as gs
    
# def lnprobfn(theta):
#     """Log-probability combining prior and likelihood."""
#     return logprior(theta) + loglikelihood(theta)
samples = posterior_samples[::5,:]  # total_samples[::20,:]
tot_len , ndim = samples.shape
print('Total samples:', tot_len, 'Dimensions:', ndim)
log_prob = np.zeros(tot_len)

for i in range(tot_len):
    log_prob[i] = gs.lnprobfn(samples[i,:])

log_p_estimate = morphZ.evidence(
samples,
log_prob,
gs.lnprobfn,
n_resamples=10000,
thin=1,n_estimations=5,morph_type="2_group",kde_bw="silverman",output_path='./morphZ_gaussian_shell_group/',pool=6)
print('True:', true_logz)
INFO:morphZ.morph:Using Morph_Group for proposal distribution.
INFO:morphZ.morph:Using multiprocessing with 6 workers for bridge sampling.
Total samples: 2401 Dimensions: 30
INFO:morphZ.bridge_multiprocess:Filtered proposal samples: 9892 valid samples out of 10000 total samples.
Filtered proposal samples: 9892 valid samples out of 10000 total samples.
INFO:morphZ.bridge_multiprocess:iteration: 1 log(z) old: -60.15122964279204 log(z) New: -60.558865261848965
INFO:morphZ.bridge_multiprocess:iteration: 2 log(z) old: -60.558865261848965 log(z) New: -60.62895919337263
INFO:morphZ.bridge_multiprocess:iteration: 3 log(z) old: -60.62895919337263 log(z) New: -60.64142231745201
INFO:morphZ.bridge_multiprocess:iteration: 4 log(z) old: -60.64142231745201 log(z) New: -60.64364930059958
INFO:morphZ.bridge_multiprocess:Converged in 4 iterations. log(z): -60.6414 +/-: 0.0207
Estimation 1/5
iteration: 4 log(z) old: -60.64142231745201 log(z) New: -60.64364930059958
Converged in 4 iterations. log(z): -60.6414 +/-: 0.0207
INFO:morphZ.bridge_multiprocess:Filtered proposal samples: 9886 valid samples out of 10000 total samples.
Filtered proposal samples: 9886 valid samples out of 10000 total samples.
INFO:morphZ.bridge_multiprocess:iteration: 1 log(z) old: -60.16811634825051 log(z) New: -60.55936260439591
INFO:morphZ.bridge_multiprocess:iteration: 2 log(z) old: -60.55936260439591 log(z) New: -60.626861513004556
INFO:morphZ.bridge_multiprocess:iteration: 3 log(z) old: -60.626861513004556 log(z) New: -60.63887484889846
INFO:morphZ.bridge_multiprocess:iteration: 4 log(z) old: -60.63887484889846 log(z) New: -60.641022900052164
INFO:morphZ.bridge_multiprocess:Converged in 4 iterations. log(z): -60.6389 +/-: 0.0207
Estimation 2/5
iteration: 4 log(z) old: -60.63887484889846 log(z) New: -60.641022900052164
Converged in 4 iterations. log(z): -60.6389 +/-: 0.0207
INFO:morphZ.bridge_multiprocess:Filtered proposal samples: 9890 valid samples out of 10000 total samples.
Filtered proposal samples: 9890 valid samples out of 10000 total samples.
INFO:morphZ.bridge_multiprocess:iteration: 1 log(z) old: -59.93196880291696 log(z) New: -60.51689197466937
INFO:morphZ.bridge_multiprocess:iteration: 2 log(z) old: -60.51689197466937 log(z) New: -60.62329054221572
INFO:morphZ.bridge_multiprocess:iteration: 3 log(z) old: -60.62329054221572 log(z) New: -60.64341056148471
INFO:morphZ.bridge_multiprocess:iteration: 4 log(z) old: -60.64341056148471 log(z) New: -60.64723559751507
INFO:morphZ.bridge_multiprocess:Converged in 4 iterations. log(z): -60.6434 +/-: 0.0210
Estimation 3/5
iteration: 4 log(z) old: -60.64341056148471 log(z) New: -60.64723559751507
Converged in 4 iterations. log(z): -60.6434 +/-: 0.0210
INFO:morphZ.bridge_multiprocess:Filtered proposal samples: 9884 valid samples out of 10000 total samples.
Filtered proposal samples: 9884 valid samples out of 10000 total samples.
INFO:morphZ.bridge_multiprocess:iteration: 1 log(z) old: -60.169359518609845 log(z) New: -60.5671203985069
INFO:morphZ.bridge_multiprocess:iteration: 2 log(z) old: -60.5671203985069 log(z) New: -60.63844933680139
INFO:morphZ.bridge_multiprocess:iteration: 3 log(z) old: -60.63844933680139 log(z) New: -60.65170684819988
INFO:morphZ.bridge_multiprocess:iteration: 4 log(z) old: -60.65170684819988 log(z) New: -60.65418468235432
INFO:morphZ.bridge_multiprocess:Converged in 4 iterations. log(z): -60.6517 +/-: 0.0210
Estimation 4/5
iteration: 4 log(z) old: -60.65170684819988 log(z) New: -60.65418468235432
Converged in 4 iterations. log(z): -60.6517 +/-: 0.0210
INFO:morphZ.bridge_multiprocess:Filtered proposal samples: 9871 valid samples out of 10000 total samples.
Filtered proposal samples: 9871 valid samples out of 10000 total samples.
INFO:morphZ.bridge_multiprocess:iteration: 1 log(z) old: -60.17210586426401 log(z) New: -60.57901706041223
INFO:morphZ.bridge_multiprocess:iteration: 2 log(z) old: -60.57901706041223 log(z) New: -60.64786350989249
INFO:morphZ.bridge_multiprocess:iteration: 3 log(z) old: -60.64786350989249 log(z) New: -60.659912686557576
INFO:morphZ.bridge_multiprocess:iteration: 4 log(z) old: -60.659912686557576 log(z) New: -60.662031828652204
INFO:morphZ.bridge_multiprocess:Converged in 4 iterations. log(z): -60.6599 +/-: 0.0207
Estimation 5/5
iteration: 4 log(z) old: -60.659912686557576 log(z) New: -60.662031828652204
Converged in 4 iterations. log(z): -60.6599 +/-: 0.0207

Saved log(z) to ./morphZ_gaussian_shell_group//logz_morph_z_2_group_silverman.txt
True: -36.09
pair_kde = morphZ.PairwiseKDE(samples,'/home/apokrypha/Local NT/paper_tests_morph/morphZ_gaussian_shell/params_MI.json',bw=0.02,verbose=True)
pair_samples = pair_kde.resample(2000)
indep_kde = morphZ.KDE_approx(posterior_samples,bw=0.02)
indep_samples = indep_kde.resample(2000)
fig = corner.corner(
    posterior_samples[::10,:5], bins=20,label_kwargs = {"fontsize": 7},truth_color="dodgerblue",hist_kwargs={"density": True},quantiles=[0.05, 0.5, 0.95],
    show_titles=True,
    fontzise=6,
    title_fmt=".2f",
    plot_datapoints=False,
    fill_contours=True,
    levels=(0.5, 0.8, 0.95),         
    smooth=1.0  
)
corner.corner(
    pair_samples[:,:5],color="red", bins=20,label_kwargs = {"fontsize": 7},truth_color="dodgerblue",hist_kwargs={"density": True},quantiles=[0.05, 0.5, 0.95],
    show_titles=True,
    fig=fig,
    fontzise=6,
    title_fmt=".2f",
    plot_datapoints=False,
    fill_contours=True,
    levels=(0.5, 0.8, 0.95),         
    smooth=1.0  
)
corner.corner(
    indep_samples[:,:5],color="blue", bins=20,label_kwargs = {"fontsize": 7},truth_color="dodgerblue",hist_kwargs={"density": True},quantiles=[0.05, 0.5, 0.95],
    show_titles=True,
    fig=fig,
    fontzise=6,
    title_fmt=".2f",
    plot_datapoints=False,
    fill_contours=True,
    levels=(0.5, 0.8, 0.95),         
    smooth=1.0  
)
plt.show()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 pair_kde = morphZ.PairwiseKDE(samples,'/home/apokrypha/Local NT/paper_tests_morph/morphZ_gaussian_shell/params_MI.json',bw=0.02,verbose=True)
      2 pair_samples = pair_kde.resample(2000)
      3 indep_kde = morphZ.KDE_approx(posterior_samples,bw=0.02)

TypeError: Morph_Pairwise.__init__() got an unexpected keyword argument 'bw'
fig = plt.figure(figsize =(10, 7))
ax = fig.add_subplot(111)

log_p_estimate = np.array(log_p_estimate)

plt.hlines(y=true_logz,colors='grey',linestyle='dotted',xmin=0.5, xmax=5.5,label='True log(z)')

x1= [1,2,3,4]
#median = [np.mean(z_GSS_mcmc),res.logz[-1],np.mean(log_p_estimate)]
#error = [np.std(z_GSS_mcmc),res.logzerr[-1],np.std(log_p_estimate)]
median = [np.mean(log_z_NS),np.mean(log_z_NS),np.mean(log_p_estimate[:,0]),np.mean(log_p_estimate[:,0])]
error = [0,np.mean(log_z_NS_err),np.std(log_p_estimate[:,0]),np.mean(log_p_estimate[:,1])]
ax.errorbar(x1, median, yerr=error, fmt='+', color='black', capsize=5)


ax.set_xticks(x1,labels=[r'NS$_{Emp}$',r'NS$_{Theo}$', r'Morph$_{Emp}$',r'Morph$_{RMSE}$'],fontsize=10)
ax.set_ylabel('log(z)')
ax.legend()
<matplotlib.legend.Legend at 0x7fb2943c5f60>
../../_images/77eb183e9130090adb8238795294fbd26f066d3cf262140c345801196ae3d652.png