Coverage for pesummary/cli/summaryjscompare.py: 94.8%
135 statements
#!/usr/bin/env python3
"""
Interface to generate a JS test comparison between two results files.

This plots the difference in CDF between two sets of samples and adds the JS
test statistic with uncertainty estimated by bootstrapping.
"""

from collections import namedtuple
import numpy as np
import os
import pandas as pd
from scipy.stats import binom
import matplotlib
import matplotlib.style
import matplotlib.pyplot as plt

from pesummary.io import read
from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
from pesummary.core.plots.figure import figure
from pesummary.gw.plots.bounds import default_bounds
from pesummary.gw.plots.latex_labels import GWlatex_labels
from pesummary.core.plots.latex_labels import latex_labels
from pesummary.utils.tqdm import tqdm
from pesummary.utils.utils import jensen_shannon_divergence
from pesummary.utils.utils import _check_latex_install, get_matplotlib_style_file, logger

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

matplotlib.style.use(get_matplotlib_style_file())
_check_latex_install()


class ArgumentParser(_ArgumentParser):
    def _pesummary_options(self):
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--ntests": {
                    "type": int,
                    "default": 100,
                    "help": "Number of iterations for bootstrapping",
                },
                "--main_keys": {
                    "nargs": "+",
                    "default": [
                        "theta_jn", "chirp_mass", "mass_ratio", "tilt_1",
                        "tilt_2", "luminosity_distance", "ra", "dec", "a_1",
                        "a_2"
                    ],
                    "help": "List of parameter names",
                },
                "--event": {
                    "type": str,
                    "required": True,
                    "help": "Label, e.g. the event name",
                },
                "--samples": {
                    "short": "-s",
                    "type": str,
                    "nargs": 2,
                    "help": "Paths to a pair of results files to compare",
                },
                "--labels": {
                    "short": "-l",
                    "type": str,
                    "nargs": 2,
                    "help": "Pair of labels for each result file",
                },
            }
        )
        options["--nsamples"]["default"] = 10000
        return options


def load_data(data_file):
    """ Read in a data file and return its samples dictionary """
    f = read(data_file, package="gw", disable_prior=True)
    return f.samples_dict
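
# What load_data returns, as a sketch (the file name is a hypothetical
# placeholder): a dictionary-like object keyed by parameter name, with a 1d
# array of posterior samples per key.
#
#     samples = load_data("GW150914_posterior.hdf5")
#     chirp_mass = samples["chirp_mass"]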


def js_bootstrap(key, resultA, resultB, nsamples, ntests):
    """
    Evaluate the mean JS divergence with bootstrapping

    key: string, name of the posterior parameter
    resultA: first full set of posterior samples
    resultB: second full set of posterior samples
    nsamples: number of samples to downsample each full sample set to
    ntests: number of iterations over different nsamples realisations
    returns: 1d array of length ntests
    """
    samplesA = resultA[key]
    samplesB = resultB[key]

    # Get minimum number of samples to use
    nsamples = min([nsamples, len(samplesA), len(samplesB)])

    xlow, xhigh = None, None
    if key in default_bounds.keys():
        bounds = default_bounds[key]
        if "low" in bounds.keys():
            xlow = bounds["low"]
        if "high" in bounds.keys():
            if isinstance(bounds["high"], str) and "mass_1" in bounds["high"]:
                xhigh = np.min([np.max(samplesA), np.max(samplesB)])
            else:
                xhigh = bounds["high"]

    js_array = np.zeros(ntests)
    for j in tqdm(range(ntests)):
        bootA = np.random.choice(samplesA, size=nsamples, replace=False)
        bootB = np.random.choice(samplesB, size=nsamples, replace=False)
        js_array[j] = np.nan_to_num(
            jensen_shannon_divergence(
                [bootA, bootB],
                kde=ReflectionBoundedKDE, xlow=xlow, xhigh=xhigh
            ), nan=0.0
        )
    return js_array
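
# A minimal sketch of calling js_bootstrap directly, with synthetic Gaussian
# "posteriors" (illustration only, not pesummary output):
#
#     resultA = {"chirp_mass": np.random.normal(30.0, 2.0, size=5000)}
#     resultB = {"chirp_mass": np.random.normal(30.1, 2.0, size=5000)}
#     js = js_bootstrap("chirp_mass", resultA, resultB, nsamples=1000, ntests=10)
#     # js has shape (10,); near-identical inputs give values close to 0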


def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    """ Return the median of jsvalues together with its quantile errors """
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    quants = np.percentile(jsvalues, quants_to_compute * 100)
    # Declare the fields that are actually used downstream ("plus"/"minus",
    # not "lower"/"upper") and return an instance of the namedtuple rather
    # than mutating the namedtuple class itself
    summary = namedtuple("summary", ["median", "plus", "minus"])
    return summary(
        median=quants[1], plus=quants[2] - quants[1], minus=quants[1] - quants[0]
    )
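
# For example, applied to the js array from the sketch above:
#
#     summary = calc_median_error(js)
#     summary.median, summary.minus, summary.plus
#     # median JS estimate with its 16%/84% quantile offsets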


def bin_series_and_calc_cdf(x, y, bins=100):
    """
    Bin two unequal-length series into equal bins and calculate their
    cumulative distribution functions in order to generate pp-plots
    """
    boundaries = sorted(x)[:: round(len(x) / bins) + 1]
    labels = [(boundaries[i] + boundaries[i + 1]) / 2 for i in range(len(boundaries) - 1)]
    # Bin the two series into equal bins
    try:
        xb = pd.cut(x, bins=boundaries, labels=labels)
        yb = pd.cut(y, bins=boundaries, labels=labels)
        # Get value counts for each bin and sort by bin
        xhist = xb.value_counts().sort_index(ascending=True) / len(xb)
        yhist = yb.value_counts().sort_index(ascending=True) / len(yb)
        # Make the binned counts cumulative
        for ser in [xhist, yhist]:
            ttl = 0
            for idx, val in ser.items():
                ttl += val
                ser.loc[idx] = ttl
    except ValueError:
        xhist = np.linspace(0, 1, 1000)
        yhist = np.linspace(0, 1, 1000)
    return xhist, yhist
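
# A minimal sketch with two uniform toy series (illustration only):
#
#     x = np.random.uniform(0, 1, 2000)
#     y = np.random.uniform(0, 1, 2000)
#     xhist, yhist = bin_series_and_calc_cdf(x, y)
#     # xhist and yhist are CDF values on a shared binning; a pp-plot draws
#     # yhist - xhist against xhist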


def calculate_CI(len_samples, confidence_level=0.95, n_points=1001):
    """
    Calculate confidence intervals
    (https://git.ligo.org/lscsoft/bilby/blob/master/bilby/core/result.py#L1578)
    len_samples: length of the posterior samples
    confidence_level: confidence level, default 95%
    n_points: number of points at which to evaluate the confidence region
    """
    x_values = np.linspace(0, 1, n_points)
    N = len_samples
    edge_of_bound = (1.0 - confidence_level) / 2.0
    upper = binom.ppf(1 - edge_of_bound, N, x_values) / N
    lower = binom.ppf(edge_of_bound, N, x_values) / N
    lower[0] = 0
    upper[0] = 0
    return x_values, upper, lower
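
# The band comes from the fact that, for N independent draws from the
# reference distribution, the number of samples below the x-quantile is
# Binomial(N, x), so the pointwise bounds are binomial quantiles scaled by N:
#
#     x_values, upper, lower = calculate_CI(1000, confidence_level=0.95)
#     # the band upper - lower is widest near x = 0.5 and shrinks as N grows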


def pp_plot(event, resultA, resultB, labelA, labelB, main_keys, nsamples, js_data, webdir):
    """
    Produce a pp-plot between samplesA and samplesB for a set of
    parameters (main_keys) for a given event.
    The JS divergence for each pair of samples is shown in the legend.
    """
    fig, ax = figure(figsize=(6, 5), gca=True)
    latex_labels.update(GWlatex_labels)
    for key in main_keys:
        # Check the key exists in both sets of samples
        if key not in resultA or key not in resultB:
            logger.debug(f"Missing key {key}")
            continue
        # Get minimum number of samples to use
        nsamples = min([nsamples, len(resultA[key]), len(resultB[key])])
        # Resample to nsamples
        lp = np.random.choice(resultA[key], size=nsamples, replace=False)
        bp = np.random.choice(resultB[key], size=nsamples, replace=False)
        # Bin the posterior samples into equal-length bins and make cumulative
        xhist, yhist = bin_series_and_calc_cdf(bp, lp)
        summary = js_data[key]
        logger.debug(f"JS {key}: {summary.median}, {summary.minus}, {summary.plus}")
        fmt = "{{0:{0}}}".format(".5f").format
        if key not in list(latex_labels.keys()):
            latex_labels[key] = key.replace("_", " ")
        ax.plot(
            xhist,
            yhist - xhist,
            label=latex_labels[key]
            + r" ${{{0}}}_{{-{1}}}^{{+{2}}}$".format(
                fmt(summary.median), fmt(summary.minus), fmt(summary.plus)
            ),
            linewidth=1,
            linestyle="-",
        )
    for confidence in [0.68, 0.95, 0.997]:
        x_values, upper, lower = calculate_CI(nsamples, confidence_level=confidence)
        ax.fill_between(
            x_values, lower - x_values, upper - x_values, linewidth=1,
            color="k", alpha=0.1,
        )
    ax.set_xlabel(f"{labelA} CDF")
    ax.set_ylabel(f"{labelA} CDF - {labelB} CDF")
    ax.set_xlim(0, 1)
    ax.legend(loc="upper right", ncol=4, fontsize=6)
    ax.set_title(r"{} N samples={:.0f}".format(event, nsamples))
    ax.grid()
    fig.tight_layout()
    plt.savefig(os.path.join(webdir, "{}-comparison-{}-{}.png".format(event, labelA, labelB)))


def parse_cmd_line(args=None):
    _parser = ArgumentParser(description=__doc__)
    _parser.add_known_options_to_parser(
        [
            "--seed", "--nsamples", "--webdir", "--ntests", "--main_keys",
            "--labels", "--samples", "--event"
        ]
    )
    args, unknown = _parser.parse_known_args(args=args)
    return args


def main(args=None):
    args = parse_cmd_line(args=args)

    # Set the random seed
    np.random.seed(seed=args.seed)

    # Read in the keys to compare
    main_keys = args.main_keys

    # Read in the results and labels
    resultA = load_data(args.samples[0])
    labelA = args.labels[0]
    resultB = load_data(args.samples[1])
    labelB = args.labels[1]

    logger.debug("Evaluating JS divergence...")
    js_data = dict()
    js = np.zeros((args.ntests, len(main_keys)))
    for i, key in enumerate(main_keys):
        js[:, i] = js_bootstrap(key, resultA, resultB, args.nsamples, ntests=args.ntests)
        js_data[key] = calc_median_error(js[:, i])

    logger.debug("Making pp-plot...")
    pp_plot(
        args.event, resultA, resultB, labelA, labelB, main_keys,
        args.nsamples, js_data=js_data, webdir=args.webdir
    )


if __name__ == "__main__":
    main()