Source code for postprocessing.Cornerplot

import numpy as np
import matplotlib.pyplot as plt
import corner
import matplotlib.lines as mlines

def overlapping_corner_three(array1, array2, array3, param_labels=None,sample_labels=None, save_plot=False, filename="corner_plot.pdf"):
    """
    Creates a corner plot using three 2D arrays, overlapping them with different colors,
    and displays histograms with consistent binning for comparison. Histograms are unfilled,
    the y-axis is kept on the first histogram, the x-axis is kept on the last histogram, and
    the last histogram has a label beneath the x-axis showing the parameter it represents.

    Parameters:
    -----------
    array1, array2, array3 : 2D arrays
        The input datasets. Each should be of shape (n_samples, n_params), where `n_samples`
        is the number of samples, and `n_params` is the number of parameters. Arrays can
        have different lengths (n_samples).
    param_labels : list of str, optional
        The labels for each parameter to use in the corner plot. Defaults to None, in which
        case generic labels will be used.
    save_plot : bool, optional
        Whether to save the plot to a file. Defaults to False.
    filename : str, optional
        The filename to save the plot if `save_plot=True`. Defaults to "corner_plot.pdf".
    
    Returns:
    --------
    None
        Displays the corner plot.
    """

    # Set up colors for the three datasets
    colors = ['red', 'blue', 'green']

    # Compute consistent bin edges based on combined data
    def get_consistent_bins(arrays, n_params, n_bins=50):
        bin_edges = []
        for i in range(n_params):
            combined_data = np.concatenate([array[:, i] for array in arrays])
            bins = np.histogram_bin_edges(combined_data, bins=n_bins)
            bin_edges.append(bins)
        return bin_edges

    # Compute bin edges for all parameters
    arrays = [array1, array2, array3]
    n_params = array1.shape[1]
    bin_edges = get_consistent_bins(arrays, n_params)

    # Create the corner plot for the first dataset (without 1D histograms)
    figure = corner.corner(array1, smooth=0.9,
                           label_kwargs=dict(fontsize=17),
                           title_kwargs=dict(fontsize=16),
                           levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
                           labels=param_labels, color=colors[0], show_titles=False, 
                           plot_datapoints=False, plot_density=True, fill_contours=True)
    
    # Overlap the second and third datasets onto the existing figure
    corner.corner(array2, smooth=0.9,
                  label_kwargs=dict(fontsize=17),
                  title_kwargs=dict(fontsize=16),
                  levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
                  labels=param_labels, color=colors[1], fig=figure, show_titles=False, 
                  plot_datapoints=False, plot_density=True, fill_contours=True)
    corner.corner(array3, smooth=0.9,
                  label_kwargs=dict(fontsize=17),
                  title_kwargs=dict(fontsize=16),
                  levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
                  labels=param_labels, color=colors[2], fig=figure, show_titles=False, 
                  plot_datapoints=False, plot_density=True, fill_contours=True)

    # Add histograms with consistent binning to the 1D marginal plots (diagonal panels)
    axes = np.array(figure.axes).reshape((n_params, n_params))
    for i in range(n_params):
        ax = axes[i, i]
        
        # Clear the existing diagonal histogram (which was automatically created by `corner`)
        ax.clear()
        param_name = param_labels[i] if param_labels else f"Param {i+1}"
        ax.set_title(param_name, fontsize=20, pad=75)  # Increase padding to make space for mean/std
        
        # Plot unfilled histograms with consistent bins for each dataset on the diagonal
        for array, label, color, bins in zip(arrays, sample_labels, colors, [bin_edges[i]]*3):
            ax.hist(array[:, i], bins=bins, histtype='step', color=color, label=label, density=True)

        # Compute and add mean and 68% quantile range in LaTeX format
        for idx, (array, color) in enumerate(zip(arrays, colors), start=1):
            mean = np.mean(array[:, i])
            lower_quantile = np.percentile(array[:, i], 16)
            upper_quantile = np.percentile(array[:, i], 84)
            lower_diff = mean - lower_quantile
            upper_diff = upper_quantile - mean
            # Use LaTeX-like format for displaying the uncertainties
            ax.text(0.5, 0.9 + 0.17 * idx, 
                    f"${mean:.3f}_{{-{lower_diff:.3f}}}^{{+{upper_diff:.3f}}}$", 
                    color=color, transform=ax.transAxes, fontsize=16, ha='center')

        # Keep y-axis visible for the first histogram, x-axis visible for the last histogram
        
        ax.set_yticks([])  # Hide y-axis for all other diagonal histograms

        if i == n_params - 1:
            ax.xaxis.set_visible(True)  # Keep x-axis visible for the last histogram
            ax.set_xlabel(param_labels[i]) # Add label beneath the x-axis for the last histogram
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
        else:
            ax.set_xticks([])  # Hide x-axis for all other diagonal histograms
    # Adjust layout: Reduce the spacing between the panels
    figure.subplots_adjust(top=0.9, wspace=0.05, hspace=0.05)

    # Add the legend
    handles = [
        mlines.Line2D([], [], color=colors[i], label=sample_labels[i]) for i in range(len(sample_labels))
    ]
    figure.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.85, 0.85), fontsize=20, frameon=False)
    
    # Set tick label font size for non-diagonal plots
    label_font_size = 13
    for ax in figure.get_axes():
        ax.tick_params(axis='both', labelsize=label_font_size)
    ax = axes[-1, -1]
    ax.annotate(param_labels[-1], xy=(0.5, -0.38), xycoords='axes fraction', fontsize=20, ha='center', va='center')
    # Show or save the plot
    if save_plot:
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Corner plot saved as {filename}.")
    else:
        plt.show()
        


[docs] def overlapping_corner_three(array1, array2, param_labels=None, sample_labels=None, save_plot=False, filename="corner_plot.pdf"): """ Creates a corner plot using three 2D arrays, overlapping them with different colors, and displays histograms with consistent binning for comparison. Histograms are unfilled, the y-axis is kept on the first histogram, the x-axis is kept on the last histogram, and the last histogram has a label beneath the x-axis showing the parameter it represents. Parameters: ----------- array1, array2 : 2D arrays The input datasets. Each should be of shape (n_samples, n_params), where `n_samples` is the number of samples, and `n_params` is the number of parameters. Arrays can have different lengths (n_samples). param_labels : list of str, optional The labels for each parameter to use in the corner plot. Defaults to None, in which case generic labels will be used. save_plot : bool, optional Whether to save the plot to a file. Defaults to False. filename : str, optional The filename to save the plot if `save_plot=True`. Defaults to "corner_plot.pdf". Returns: -------- None Displays the corner plot. """ # Set up colors for the three datasets colors = ['red', 'blue'] # Compute consistent bin edges based on combined data def get_consistent_bins(arrays, n_params, n_bins=50): bin_edges = [] for i in range(n_params): combined_data = np.concatenate([array[:, i] for array in arrays]) bins = np.histogram_bin_edges(combined_data, bins=n_bins) bin_edges.append(bins) return bin_edges # Compute bin edges for all parameters arrays = [array1, array2,] n_params = array1.shape[1] bin_edges = get_consistent_bins(arrays, n_params) # Create the corner plot for the first dataset (without 1D histograms) figure = corner.corner(array1, smooth=0.9, label_kwargs=dict(fontsize=17), title_kwargs=dict(fontsize=16), levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)), labels=param_labels, color=colors[0], show_titles=False, plot_datapoints=False, plot_density=True, fill_contours=True) # Overlap the second and third datasets onto the existing figure corner.corner(array2, smooth=0.9, label_kwargs=dict(fontsize=17), title_kwargs=dict(fontsize=16), levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)), labels=param_labels, color=colors[1], fig=figure, show_titles=False, plot_datapoints=False, plot_density=True, fill_contours=True) # Add histograms with consistent binning to the 1D marginal plots (diagonal panels) axes = np.array(figure.axes).reshape((n_params, n_params)) for i in range(n_params): ax = axes[i, i] # Clear the existing diagonal histogram (which was automatically created by `corner`) ax.clear() param_name = param_labels[i] if param_labels else f"Param {i+1}" ax.set_title(param_name, fontsize=20, pad=55) # Increase padding to make space for mean/std # Plot unfilled histograms with consistent bins for each dataset on the diagonal for array, label, color, bins in zip(arrays, sample_labels, colors, [bin_edges[i]]*3): ax.hist(array[:, i], bins=bins, histtype='step', color=color, label=label, density=True) # Compute and add mean and 68% quantile range in LaTeX format for idx, (array, color) in enumerate(zip(arrays, colors), start=1): mean = np.mean(array[:, i]) lower_quantile = np.percentile(array[:, i], 16) upper_quantile = np.percentile(array[:, i], 84) lower_diff = mean - lower_quantile upper_diff = upper_quantile - mean # Use LaTeX-like format for displaying the uncertainties ax.text(0.5, 0.9 + 0.17 * idx, f"${mean:.3f}_{{-{lower_diff:.3f}}}^{{+{upper_diff:.3f}}}$", color=color, transform=ax.transAxes, fontsize=16, ha='center') # Keep y-axis visible for the first histogram, x-axis visible for the last histogram ax.set_yticks([]) # Hide y-axis for all other diagonal histograms if i == n_params - 1: ax.xaxis.set_visible(True) # Keep x-axis visible for the last histogram ax.set_xlabel(param_labels[i]) # Add label beneath the x-axis for the last histogram plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor') else: ax.set_xticks([]) # Hide x-axis for all other diagonal histograms # Adjust layout: Reduce the spacing between the panels figure.subplots_adjust(top=0.9, wspace=0.05, hspace=0.05) # Add the legend handles = [ mlines.Line2D([], [], color=colors[i], label=sample_labels[i]) for i in range(len(sample_labels)) ] figure.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.95, 0.95), fontsize=20, frameon=False) # Set tick label font size for non-diagonal plots label_font_size = 13 for ax in figure.get_axes(): ax.tick_params(axis='both', labelsize=label_font_size) ax = axes[-1, -1] ax.annotate(param_labels[-1], xy=(0.5, -0.38), xycoords='axes fraction', fontsize=20, ha='center', va='center') # Show or save the plot if save_plot: plt.savefig(filename, dpi=300, bbox_inches='tight') print(f"Corner plot saved as {filename}.") else: plt.show()