#!/usr/bin/env python3
"""
Interface to generate JS test comparison between two results files.

This plots the difference in CDF between two sets of samples and adds the JS test
statistic with uncertainty estimated by bootstrapping.
"""

from collections import namedtuple
import os

import numpy as np
import pandas as pd
from scipy.stats import binom
import matplotlib
import matplotlib.style
import matplotlib.pyplot as plt

from pesummary.io import read
from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
from pesummary.core.plots.figure import figure
from pesummary.gw.plots.bounds import default_bounds
from pesummary.gw.plots.latex_labels import GWlatex_labels
from pesummary.core.plots.latex_labels import latex_labels
from pesummary.utils.tqdm import tqdm
from pesummary.utils.utils import jensen_shannon_divergence
from pesummary.utils.utils import (
    _check_latex_install, get_matplotlib_style_file, logger
)

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

matplotlib.style.use(get_matplotlib_style_file())
_check_latex_install()

class ArgumentParser(_ArgumentParser):
    def _pesummary_options(self):
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--ntests": {
                    "type": int,
                    "default": 100,
                    "help": "Number of iterations for bootstrapping",
                },
                "--main_keys": {
                    "nargs": "+",
                    "default": [
                        "theta_jn", "chirp_mass", "mass_ratio", "tilt_1",
                        "tilt_2", "luminosity_distance", "ra", "dec", "a_1",
                        "a_2"
                    ],
                    "help": "List of parameter names to compare",
                },
                "--event": {
                    "type": str,
                    "required": True,
                    "help": "Label, e.g. the event name",
                },
                "--samples": {
                    "short": "-s",
                    "type": str,
                    "nargs": 2,
                    "help": "Paths to a pair of results files to compare",
                },
                "--labels": {
                    "short": "-l",
                    "type": str,
                    "nargs": 2,
                    "help": "Pair of labels, one for each results file",
                },
            }
        )
        options["--nsamples"]["default"] = 10000
        return options

def load_data(data_file):
    """Read in a results file and return its samples dictionary"""
    f = read(data_file, package="gw", disable_prior=True)
    return f.samples_dict
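# A minimal usage sketch (the file name below is a placeholder, not a file
# shipped with pesummary):
#
#     samples = load_data("GW150914_posterior_samples.h5")
#     print(list(samples.keys()))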

def js_bootstrap(key, resultA, resultB, nsamples, ntests):
    """
    Evaluate the mean JS divergence between two sets of samples with
    bootstrapping

    key: string, name of the posterior parameter
    resultA: first full set of posterior samples
    resultB: second full set of posterior samples
    nsamples: number of samples to downsample each full set to
    ntests: number of bootstrap iterations, each over a different
        downsampled realisation
    returns: 1-dim array of length ntests
    """
    samplesA = resultA[key]
    samplesB = resultB[key]

    # Get the minimum number of samples to use
    nsamples = min([nsamples, len(samplesA), len(samplesB)])

    # Look up default KDE bounds for this parameter, if any
    xlow, xhigh = None, None
    if key in default_bounds.keys():
        bounds = default_bounds[key]
        if "low" in bounds.keys():
            xlow = bounds["low"]
        if "high" in bounds.keys():
            if isinstance(bounds["high"], str) and "mass_1" in bounds["high"]:
                xhigh = np.min([np.max(samplesA), np.max(samplesB)])
            else:
                xhigh = bounds["high"]

    js_array = np.zeros(ntests)
    for j in tqdm(range(ntests)):
        bootA = np.random.choice(samplesA, size=nsamples, replace=False)
        bootB = np.random.choice(samplesB, size=nsamples, replace=False)
        js_array[j] = np.nan_to_num(
            jensen_shannon_divergence(
                [bootA, bootB],
                kde=ReflectionBoundedKDE, xlow=xlow, xhigh=xhigh
            ), nan=0.0
        )
    return js_array
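# Illustrative usage, assuming `resultA` and `resultB` are samples
# dictionaries returned by `load_data` and both contain a "chirp_mass" key:
#
#     js = js_bootstrap("chirp_mass", resultA, resultB, nsamples=5000, ntests=100)
#     summary = calc_median_error(js)
#     print(summary.median, summary.minus, summary.plus)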

def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    """Return the median of `jsvalues` with plus/minus offsets to the
    given quantiles"""
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    lower, median, upper = np.percentile(jsvalues, quants_to_compute * 100)
    summary = namedtuple("summary", ["median", "plus", "minus"])
    return summary(median=median, plus=upper - median, minus=median - lower)
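# A worked sketch with synthetic data: the median of 1000 uniform draws on
# [0, 1] is ~0.5, and the "minus"/"plus" offsets reach down/up to the
# 16th/84th percentiles (~0.16 and ~0.84), i.e. roughly 0.5 -0.34 / +0.34:
#
#     s = calc_median_error(np.random.uniform(0, 1, 1000))
#     print(f"{s.median:.5f} -{s.minus:.5f} / +{s.plus:.5f}")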

def bin_series_and_calc_cdf(x, y, bins=100):
    """
    Bin two unequal-length series into equal bins and calculate their
    cumulative distribution functions in order to generate pp-plots
    """
    # Use evenly spaced order statistics of x as the common bin boundaries
    boundaries = sorted(x)[:: round(len(x) / bins) + 1]
    labels = [
        (boundaries[i] + boundaries[i + 1]) / 2
        for i in range(len(boundaries) - 1)
    ]
    # Bin the two series into the same bins
    try:
        xb = pd.cut(x, bins=boundaries, labels=labels)
        yb = pd.cut(y, bins=boundaries, labels=labels)
        # Get the fraction of counts in each bin and sort by bin
        xhist = xb.value_counts().sort_index(ascending=True) / len(xb)
        yhist = yb.value_counts().sort_index(ascending=True) / len(yb)
        # Make cumulative
        for ser in [xhist, yhist]:
            ttl = 0
            for idx, val in ser.items():
                ttl += val
                ser.loc[idx] = ttl
    except ValueError:
        xhist = np.linspace(0, 1, 1000)
        yhist = np.linspace(0, 1, 1000)
    return xhist, yhist
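# Illustrative usage with two synthetic, unequal-length samples; the two
# returned series hold CDF values evaluated on a common set of bins:
#
#     a = np.random.normal(0, 1, 8000)
#     b = np.random.normal(0, 1, 5000)
#     acdf, bcdf = bin_series_and_calc_cdf(a, b)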

def calculate_CI(len_samples, confidence_level=0.95, n_points=1001):
    """
    Calculate the binomial confidence interval for a pp-plot
    (https://git.ligo.org/lscsoft/bilby/blob/master/bilby/core/result.py#L1578)

    len_samples: length of posterior samples
    confidence_level: confidence level, default 0.95 (95%)
    n_points: number of points over which to evaluate the confidence region
    """
    x_values = np.linspace(0, 1, n_points)
    N = len_samples
    edge_of_bound = (1.0 - confidence_level) / 2.0
    # Following the bilby implementation, `lower` holds the upper quantile
    # and `upper` the lower one; ax.fill_between is insensitive to the order
    lower = binom.ppf(1 - edge_of_bound, N, x_values) / N
    upper = binom.ppf(edge_of_bound, N, x_values) / N
    lower[0] = 0
    upper[0] = 0
    return x_values, upper, lower
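# Rough sanity check: for N = 1000 at the default 95% confidence, the band
# half-width at x = 0.5 should be close to the normal approximation
# 1.96 * sqrt(0.5 * 0.5 / 1000) ~= 0.031:
#
#     x_values, upper, lower = calculate_CI(1000)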

def pp_plot(
    event, resultA, resultB, labelA, labelB, main_keys, nsamples, js_data,
    webdir
):
    """
    Produce a pp-plot comparing samplesA and samplesB for a set of
    parameters (main_keys) for a given event. The JS divergence for each
    pair of samples is shown in the legend.
    """
    fig, ax = figure(figsize=(6, 5), gca=True)

    latex_labels.update(GWlatex_labels)
    for key_index, key in enumerate(main_keys):
        # Check the key exists in both sets of samples
        if key not in resultA or key not in resultB:
            logger.debug(f"Missing key {key}")
            continue
        # Get the minimum number of samples to use
        nsamples = min([nsamples, len(resultA[key]), len(resultB[key])])

        # Resample to nsamples
        lp = np.random.choice(resultA[key], size=nsamples, replace=False)
        bp = np.random.choice(resultB[key], size=nsamples, replace=False)

        # Bin the posterior samples into equal-length bins and calculate
        # their cumulative distributions
        xhist, yhist = bin_series_and_calc_cdf(bp, lp)

        summary = js_data[key]
        logger.debug(f"JS {key}: {summary.median}, {summary.minus}, {summary.plus}")
        # Formatter for the 5 d.p. JS values quoted in the legend
        fmt = "{{0:{0}}}".format(".5f").format

        if key not in list(latex_labels.keys()):
            latex_labels[key] = key.replace("_", " ")
        ax.plot(
            xhist,
            yhist - xhist,
            label=latex_labels[key]
            + r" ${{{0}}}_{{-{1}}}^{{+{2}}}$".format(
                fmt(summary.median), fmt(summary.minus), fmt(summary.plus)
            ),
            linewidth=1,
            linestyle="-",
        )

    # Shade the 1, 2 and 3 sigma confidence bands
    for confidence in [0.68, 0.95, 0.997]:
        x_values, upper, lower = calculate_CI(nsamples, confidence_level=confidence)
        ax.fill_between(
            x_values, lower - x_values, upper - x_values, linewidth=1,
            color="k", alpha=0.1,
        )
    ax.set_xlabel(f"{labelA} CDF")
    ax.set_ylabel(f"{labelA} CDF - {labelB} CDF")
    ax.set_xlim(0, 1)
    ax.legend(loc="upper right", ncol=4, fontsize=6)
    ax.set_title(r"{} N samples={:.0f}".format(event, nsamples))
    ax.grid()
    fig.tight_layout()
    plt.savefig(
        os.path.join(webdir, "{}-comparison-{}-{}.png".format(event, labelA, labelB))
    )
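# Illustrative call (the event name, labels and webdir are placeholders);
# this writes GW150914-comparison-A-B.png into ./out:
#
#     pp_plot("GW150914", resultA, resultB, "A", "B", ["chirp_mass"],
#             5000, js_data=js_data, webdir="./out")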

def parse_cmd_line(args=None):
    """Parse the command line arguments"""
    _parser = ArgumentParser(description=__doc__)
    _parser.add_known_options_to_parser(
        [
            "--seed", "--nsamples", "--webdir", "--ntests", "--main_keys",
            "--labels", "--samples", "--event"
        ]
    )
    args, unknown = _parser.parse_known_args(args=args)
    return args

def main(args=None):
    """Top level interface for `summaryjscompare`"""
    args = parse_cmd_line(args=args)

    # Set the random seed
    np.random.seed(seed=args.seed)

    # Read in the keys to compare
    main_keys = args.main_keys

    # Read in the results files and their labels
    resultA = load_data(args.samples[0])
    labelA = args.labels[0]
    resultB = load_data(args.samples[1])
    labelB = args.labels[1]

    logger.debug("Evaluating JS divergence...")
    js_data = dict()
    js = np.zeros((args.ntests, len(main_keys)))
    for i, key in enumerate(main_keys):
        js[:, i] = js_bootstrap(
            key, resultA, resultB, args.nsamples, ntests=args.ntests
        )
        js_data[key] = calc_median_error(js[:, i])
    logger.debug("Making pp-plot...")
    pp_plot(
        args.event, resultA, resultB, labelA, labelB, main_keys,
        args.nsamples, js_data=js_data, webdir=args.webdir
    )

if __name__ == "__main__":
    main()
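# Example invocation (illustrative; the paths and labels are placeholders):
#
#     summaryjscompare --event GW150914 -s A_result.h5 B_result.h5 \
#         -l analysisA analysisB --webdir ./out --ntests 100 --seed 1234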