Coverage for pesummary/cli/summaryjscompare.py: 94.8%
#!/usr/bin/env python3
"""
Interface to generate a JS test comparison between two results files.

This plots the difference in CDF between two sets of samples and adds the JS
test statistic, with its uncertainty estimated by bootstrapping.
"""
from collections import namedtuple
import numpy as np
import os
import pandas as pd
from scipy.stats import binom
import matplotlib
import matplotlib.style
import matplotlib.pyplot as plt

from pesummary.io import read
from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
from pesummary.core.plots.figure import figure
from pesummary.gw.plots.bounds import default_bounds
from pesummary.gw.plots.latex_labels import GWlatex_labels
from pesummary.core.plots.latex_labels import latex_labels
from pesummary.utils.tqdm import tqdm
from pesummary.utils.utils import jensen_shannon_divergence
from pesummary.utils.utils import _check_latex_install, get_matplotlib_style_file, logger

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

matplotlib.style.use(get_matplotlib_style_file())
_check_latex_install()


class ArgumentParser(_ArgumentParser):
    def _pesummary_options(self):
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--ntests": {
                    "type": int,
                    "default": 100,
                    "help": "Number of iterations for bootstrapping",
                },
                "--main_keys": {
                    "nargs": "+",
                    "default": [
                        "theta_jn", "chirp_mass", "mass_ratio", "tilt_1",
                        "tilt_2", "luminosity_distance", "ra", "dec", "a_1",
                        "a_2"
                    ],
                    "help": "List of parameter names"
                },
                "--event": {
                    "type": str,
                    "required": True,
                    "help": "Label, e.g. the event name"
                },
                "--samples": {
                    "short": "-s",
                    "type": str,
                    "nargs": 2,
                    "help": "Paths to a pair of results files to compare."
                },
                "--labels": {
                    "short": "-l",
                    "type": str,
                    "nargs": 2,
                    "help": "Pair of labels for each result file",
                },
            }
        )
        options["--nsamples"]["default"] = 10000
        return options


def load_data(data_file):
    """ Read in a data file and return its samples dictionary """
    f = read(data_file, package="gw", disable_prior=True)
    return f.samples_dict
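
# Illustrative usage (the file name is hypothetical); the returned object
# behaves like a dict of posterior sample arrays keyed by parameter name:
#   samples = load_data("GW150914_result.hdf5")
#   samples["chirp_mass"]  # 1d array of posterior samples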


def js_bootstrap(key, resultA, resultB, nsamples, ntests):
    """
    Evaluate the mean JS divergence with bootstrapping

    key: posterior parameter name
    resultA: first full set of posterior samples
    resultB: second full set of posterior samples
    nsamples: number of samples to draw from each full sample set
    ntests: number of bootstrap realisations
    returns: 1d array of length ntests
    """
    samplesA = resultA[key]
    samplesB = resultB[key]

    # Get minimum number of samples to use
    nsamples = min([nsamples, len(samplesA), len(samplesB)])

    # Look up the KDE bounds for this parameter, if any
    xlow, xhigh = None, None
    if key in default_bounds.keys():
        bounds = default_bounds[key]
        if "low" in bounds.keys():
            xlow = bounds["low"]
        if "high" in bounds.keys():
            if isinstance(bounds["high"], str) and "mass_1" in bounds["high"]:
                xhigh = np.min([np.max(samplesA), np.max(samplesB)])
            else:
                xhigh = bounds["high"]

    js_array = np.zeros(ntests)
    for j in tqdm(range(ntests)):
        bootA = np.random.choice(samplesA, size=nsamples, replace=False)
        bootB = np.random.choice(samplesB, size=nsamples, replace=False)
        js_array[j] = np.nan_to_num(
            jensen_shannon_divergence(
                [bootA, bootB], kde=ReflectionBoundedKDE, xlow=xlow, xhigh=xhigh
            )
        )
    return js_array
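
# Illustrative sketch (commented out; toy data, "x" is a hypothetical key not
# in default_bounds): two sample sets drawn from the same distribution should
# give JS estimates close to zero.
#   toyA = {"x": np.random.normal(size=5000)}
#   toyB = {"x": np.random.normal(size=5000)}
#   js = js_bootstrap("x", toyA, toyB, nsamples=1000, ntests=10)
#   js.mean()  # expected to be near 0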


def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    """ Return the median of a set of JS values with quantile-based errors """
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    quants = np.percentile(jsvalues, quants_to_compute * 100)
    Summary = namedtuple("Summary", ["median", "plus", "minus"])
    return Summary(
        median=quants[1], plus=quants[2] - quants[1], minus=quants[1] - quants[0]
    )
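
# Illustrative check (toy numbers): np.percentile interpolates, so
#   calc_median_error(np.array([0.1, 0.2, 0.3, 0.4, 0.5]))
# gives median 0.3 with plus ~ 0.136 and minus ~ 0.136, i.e. the 16th-84th
# percentile spread quoted downstream as median_{-minus}^{+plus}.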


def bin_series_and_calc_cdf(x, y, bins=100):
    """
    Bin two unequal length series into equal bins
    and calculate their cumulative distribution functions
    in order to generate pp-plots
    """
    boundaries = sorted(x)[:: round(len(x) / bins) + 1]
    labels = [(boundaries[i] + boundaries[i + 1]) / 2 for i in range(len(boundaries) - 1)]
    # Bin the two series into equal bins
    try:
        xb = pd.cut(x, bins=boundaries, labels=labels)
        yb = pd.cut(y, bins=boundaries, labels=labels)
        # Get value counts for each bin and sort by bin
        xhist = xb.value_counts().sort_index(ascending=True) / len(xb)
        yhist = yb.value_counts().sort_index(ascending=True) / len(yb)
        # Make cumulative
        for ser in [xhist, yhist]:
            ttl = 0
            for idx, val in ser.items():
                ttl += val
                ser.loc[idx] = ttl
    except ValueError:
        xhist = np.linspace(0, 1, 1000)
        yhist = np.linspace(0, 1, 1000)
    return xhist, yhist
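
# Illustrative sketch (commented out; toy data): two uniform sample sets of
# different lengths should produce CDFs that track each other closely.
#   xh, yh = bin_series_and_calc_cdf(np.random.rand(2000), np.random.rand(1500))
#   np.max(np.abs(np.asarray(xh) - np.asarray(yh)))  # expected to be small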


def calculate_CI(len_samples, confidence_level=0.95, n_points=1001):
    """
    Calculate the binomial confidence band for a pp-plot
    (https://git.ligo.org/lscsoft/bilby/blob/master/bilby/core/result.py#L1578)

    len_samples: length of the posterior sample set
    confidence_level: default 0.95
    n_points: number of points at which to evaluate the confidence region
    """
    x_values = np.linspace(0, 1, n_points)
    N = len_samples
    edge_of_bound = (1.0 - confidence_level) / 2.0
    upper = binom.ppf(1 - edge_of_bound, N, x_values) / N
    lower = binom.ppf(edge_of_bound, N, x_values) / N
    # Pin both bounds to zero at x = 0
    lower[0] = 0
    upper[0] = 0
    return x_values, upper, lower
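
# Illustrative check: for len_samples = 1000 and confidence_level = 0.95 the
# band half-width near x = 0.5 is roughly 1.96 * sqrt(0.5 * 0.5 / 1000) ~ 0.031,
# the Gaussian limit of the binomial quantiles used above.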


def pp_plot(event, resultA, resultB, labelA, labelB, main_keys, nsamples, js_data, webdir):
    """
    Produce a pp-plot comparing resultA and resultB
    for a set of parameters (main_keys) for a given event.
    The JS divergence for each pair of samples is shown in the legend.
    """
    fig, ax = figure(figsize=(6, 5), gca=True)
    latex_labels.update(GWlatex_labels)
    for key_index, key in enumerate(main_keys):
        # Check the key exists in both sets of samples
        if key not in resultA or key not in resultB:
            logger.debug(f"Missing key {key}")
            continue
        # Get minimum number of samples to use
        nsamples = min([nsamples, len(resultA[key]), len(resultB[key])])
        # Resample to nsamples
        lp = np.random.choice(resultA[key], size=nsamples, replace=False)
        bp = np.random.choice(resultB[key], size=nsamples, replace=False)
        # Bin the posterior samples into equal length bins and calculate the
        # cumulative distributions
        xhist, yhist = bin_series_and_calc_cdf(bp, lp)
        summary = js_data[key]
        logger.debug(f"JS {key}: {summary.median}, {summary.minus}, {summary.plus}")
        fmt = "{{0:{0}}}".format(".5f").format
        if key not in list(latex_labels.keys()):
            latex_labels[key] = key.replace("_", " ")
        ax.plot(
            xhist,
            yhist - xhist,
            label=latex_labels[key]
            + r" ${{{0}}}_{{-{1}}}^{{+{2}}}$".format(
                fmt(summary.median), fmt(summary.minus), fmt(summary.plus)
            ),
            linewidth=1,
            linestyle="-",
        )
    # Overlay the 1, 2 and 3 sigma binomial confidence bands
    for confidence in [0.68, 0.95, 0.997]:
        x_values, upper, lower = calculate_CI(nsamples, confidence_level=confidence)
        ax.fill_between(
            x_values, lower - x_values, upper - x_values, linewidth=1, color="k", alpha=0.1,
        )
    ax.set_xlabel(f"{labelB} CDF")
    ax.set_ylabel(f"{labelA} CDF - {labelB} CDF")
    ax.set_xlim(0, 1)
    ax.legend(loc="upper right", ncol=4, fontsize=6)
    ax.set_title(r"{} N samples={:.0f}".format(event, nsamples))
    ax.grid()
    fig.tight_layout()
    plt.savefig(os.path.join(webdir, "{}-comparison-{}-{}.png".format(event, labelA, labelB)))


def parse_cmd_line(args=None):
    _parser = ArgumentParser(description=__doc__)
    _parser.add_known_options_to_parser(
        [
            "--seed", "--nsamples", "--webdir", "--ntests", "--main_keys",
            "--labels", "--samples", "--event"
        ]
    )
    args, unknown = _parser.parse_known_args(args=args)
    return args


def main(args=None):
    args = parse_cmd_line(args=args)

    # Set the random seed
    np.random.seed(seed=args.seed)

    # Read in the keys to compare
    main_keys = args.main_keys

    # Read in the results files and their labels
    resultA = load_data(args.samples[0])
    labelA = args.labels[0]
    resultB = load_data(args.samples[1])
    labelB = args.labels[1]

    logger.debug("Evaluating JS divergence..")
    js_data = dict()
    js = np.zeros((args.ntests, len(main_keys)))
    for i, key in enumerate(main_keys):
        js[:, i] = js_bootstrap(key, resultA, resultB, args.nsamples, ntests=args.ntests)
        js_data[key] = calc_median_error(js[:, i])
    logger.debug("Making pp-plot..")
    pp_plot(
        args.event, resultA, resultB, labelA, labelB, main_keys, args.nsamples,
        js_data=js_data, webdir=args.webdir
    )


if __name__ == "__main__":
    main()