Coverage for pesummary/cli/summaryjscompare.py: 94.8%

135 statements  

coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

#!/usr/bin/env python3

"""
Interface to generate JS test comparison between two results files.

This plots the difference in CDF between two sets of samples and adds the JS test
statistic with uncertainty estimated by bootstrapping.
"""

from collections import namedtuple
import numpy as np
import os
import pandas as pd
from scipy.stats import binom
import matplotlib
import matplotlib.style
import matplotlib.pyplot as plt

from pesummary.io import read
from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
from pesummary.core.plots.figure import figure
from pesummary.gw.plots.bounds import default_bounds
from pesummary.gw.plots.latex_labels import GWlatex_labels
from pesummary.core.plots.latex_labels import latex_labels
from pesummary.utils.tqdm import tqdm
from pesummary.utils.utils import jensen_shannon_divergence
from pesummary.utils.utils import _check_latex_install, get_matplotlib_style_file, logger

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

matplotlib.style.use(get_matplotlib_style_file())
_check_latex_install()


class ArgumentParser(_ArgumentParser):
    def _pesummary_options(self):
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--ntests": {
                    "type": int,
                    "default": 100,
                    "help": "Number of iterations for bootstrapping",
                },
                "--main_keys": {
                    "nargs": "+",
                    "default": [
                        "theta_jn", "chirp_mass", "mass_ratio", "tilt_1",
                        "tilt_2", "luminosity_distance", "ra", "dec", "a_1",
                        "a_2"
                    ],
                    "help": "List of parameter names"
                },
                "--event": {
                    "type": str,
                    "required": True,
                    "help": "Label, e.g. the event name"
                },
                "--samples": {
                    "short": "-s",
                    "type": str,
                    "nargs": 2,
                    "help": "Paths to a pair of results files to compare."
                },
                "--labels": {
                    "short": "-l",
                    "type": str,
                    "nargs": 2,
                    "help": "Pair of labels for each result file",
                },
            }
        )
        options["--nsamples"]["default"] = 10000
        return options


def load_data(data_file):
    """ Read in a data file and return its samples dictionary """
    f = read(data_file, package="gw", disable_prior=True)
    return f.samples_dict
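
# A minimal usage sketch, assuming a single-analysis results file whose
# samples_dict maps parameter names to sample arrays (the file path below
# is hypothetical):
#
#     samples = load_data("GW150914_posterior_samples.h5")
#     print(samples["chirp_mass"].mean())
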

def js_bootstrap(key, resultA, resultB, nsamples, ntests):
    """
    Evaluate the JS divergence between two posteriors, with bootstrapping

    key: posterior parameter name
    resultA: first full set of posterior samples
    resultB: second full set of posterior samples
    nsamples: number of samples to draw from each full sample set
    ntests: number of bootstrap realisations
    returns: 1D array of length ntests
    """
    samplesA = resultA[key]
    samplesB = resultB[key]

    # Never draw more samples than the smaller of the two sets contains
    nsamples = min([nsamples, len(samplesA), len(samplesB)])

    # Look up the natural bounds of the parameter, if any, so the bounded
    # KDE does not leak probability outside the allowed range
    xlow, xhigh = None, None
    if key in default_bounds.keys():
        bounds = default_bounds[key]
        if "low" in bounds.keys():
            xlow = bounds["low"]
        if "high" in bounds.keys():
            if isinstance(bounds["high"], str) and "mass_1" in bounds["high"]:
                xhigh = np.min([np.max(samplesA), np.max(samplesB)])
            else:
                xhigh = bounds["high"]

    js_array = np.zeros(ntests)
    for j in tqdm(range(ntests)):
        bootA = np.random.choice(samplesA, size=nsamples, replace=False)
        bootB = np.random.choice(samplesB, size=nsamples, replace=False)
        js_array[j] = np.nan_to_num(
            jensen_shannon_divergence(
                [bootA, bootB], kde=ReflectionBoundedKDE, xlow=xlow, xhigh=xhigh
            )
        )
    return js_array
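
# A sketch of a direct call, assuming resultA and resultB are samples
# dictionaries returned by load_data() above:
#
#     js = js_bootstrap("chirp_mass", resultA, resultB, nsamples=5000, ntests=100)
#     print(js.mean(), js.std())
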

def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    """ Summarise bootstrap JS values by their median and quantile errors """
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    quants = np.percentile(jsvalues, quants_to_compute * 100)
    # Return a namedtuple instance; the original declared fields
    # ("lower"/"upper") that did not match the attributes it set
    summary = namedtuple("summary", ["median", "plus", "minus"])
    return summary(
        median=quants[1],
        plus=quants[2] - quants[1],
        minus=quants[1] - quants[0],
    )
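
# For example, given the bootstrap array from js_bootstrap():
#
#     summary = calc_median_error(js)
#     print(summary.median, summary.plus, summary.minus)
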

def bin_series_and_calc_cdf(x, y, bins=100):
    """
    Bin two unequal length series into equal bins and calculate their
    cumulative distribution functions, in order to generate pp-plots
    """
    boundaries = sorted(x)[:: round(len(x) / bins) + 1]
    labels = [(boundaries[i] + boundaries[i + 1]) / 2 for i in range(len(boundaries) - 1)]
    # Bin the two series into equal bins
    try:
        xb = pd.cut(x, bins=boundaries, labels=labels)
        yb = pd.cut(y, bins=boundaries, labels=labels)
        # Get value counts for each bin and sort by bin
        xhist = xb.value_counts().sort_index(ascending=True) / len(xb)
        yhist = yb.value_counts().sort_index(ascending=True) / len(yb)
        # Make cumulative
        for ser in [xhist, yhist]:
            ttl = 0
            for idx, val in ser.items():
                ttl += val
                ser.loc[idx] = ttl
    except ValueError:
        # Fall back to two identical uniform CDFs if the binning fails
        xhist = np.linspace(0, 1, 1000)
        yhist = np.linspace(0, 1, 1000)
    return xhist, yhist
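
# A minimal sketch with synthetic data:
#
#     a = np.random.normal(0, 1, 8000)
#     b = np.random.normal(0.1, 1, 6000)
#     acdf, bcdf = bin_series_and_calc_cdf(a, b)
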

def calculate_CI(len_samples, confidence_level=0.95, n_points=1001):
    """
    Calculate the confidence band of a pp-plot
    (https://git.ligo.org/lscsoft/bilby/blob/master/bilby/core/result.py#L1578)

    len_samples: length of posterior samples
    confidence_level: confidence level, default 0.95
    n_points: number of points at which to evaluate the confidence region
    """
    x_values = np.linspace(0, 1, n_points)
    N = len_samples
    edge_of_bound = (1.0 - confidence_level) / 2.0
    # Note: the original code swapped the lower/upper names
    lower = binom.ppf(edge_of_bound, N, x_values) / N
    upper = binom.ppf(1 - edge_of_bound, N, x_values) / N
    # binom.ppf does not always return 0 at x=0, so pin both bounds there
    lower[0] = 0
    upper[0] = 0
    return x_values, upper, lower
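
# For instance, the 95% band for 10000 samples:
#
#     x, upper, lower = calculate_CI(10000)
#     # upper - x and lower - x bound the expected scatter of the
#     # CDF difference in a pp-plot
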

def pp_plot(event, resultA, resultB, labelA, labelB, main_keys, nsamples, js_data, webdir):
    """
    Produce a pp-plot between resultA and resultB for a set of parameters
    (main_keys) for a given event. The JS divergence for each pair of
    samples is shown in the legend.
    """
    fig, ax = figure(figsize=(6, 5), gca=True)

    latex_labels.update(GWlatex_labels)
    for key_index, key in enumerate(main_keys):
        # Check the key exists in both sets of samples
        if key not in resultA or key not in resultB:
            logger.debug(f"Missing key {key}")
            continue
        # Get minimum number of samples to use
        nsamples = min([nsamples, len(resultA[key]), len(resultB[key])])

        # Resample to nsamples
        lp = np.random.choice(resultA[key], size=nsamples, replace=False)
        bp = np.random.choice(resultB[key], size=nsamples, replace=False)

        # Bin the posterior samples into equal length bins and calculate
        # their cumulative distributions
        xhist, yhist = bin_series_and_calc_cdf(bp, lp)

        summary = js_data[key]
        logger.debug(f"JS {key}: {summary.median}, {summary.minus}, {summary.plus}")
        fmt = "{:.5f}".format

        if key not in list(latex_labels.keys()):
            latex_labels[key] = key.replace("_", " ")
        ax.plot(
            xhist,
            yhist - xhist,
            label=latex_labels[key]
            + r" ${{{0}}}_{{-{1}}}^{{+{2}}}$".format(fmt(summary.median), fmt(summary.minus), fmt(summary.plus)),
            linewidth=1,
            linestyle="-",
        )

    # Shade the 1, 2 and 3 sigma confidence bands
    for confidence in [0.68, 0.95, 0.997]:
        x_values, upper, lower = calculate_CI(nsamples, confidence_level=confidence)
        ax.fill_between(
            x_values, lower - x_values, upper - x_values, linewidth=1, color="k", alpha=0.1,
        )
    ax.set_xlabel(f"{labelA} CDF")
    ax.set_ylabel(f"{labelA} CDF - {labelB} CDF")
    ax.set_xlim(0, 1)
    ax.legend(loc="upper right", ncol=4, fontsize=6)
    ax.set_title(r"{} N samples={:.0f}".format(event, nsamples))
    ax.grid()
    fig.tight_layout()
    plt.savefig(os.path.join(webdir, "{}-comparison-{}-{}.png".format(event, labelA, labelB)))
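
# pp_plot is normally driven by main() below; a direct call might look like
# this, assuming js_data was built with calc_median_error as in main():
#
#     pp_plot("GW150914", resultA, resultB, "A", "B",
#             ["chirp_mass"], 5000, js_data=js_data, webdir=".")
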

def parse_cmd_line(args=None):
    _parser = ArgumentParser(description=__doc__)
    _parser.add_known_options_to_parser(
        [
            "--seed", "--nsamples", "--webdir", "--ntests", "--main_keys",
            "--labels", "--samples", "--event"
        ]
    )
    args, unknown = _parser.parse_known_args(args=args)
    return args


def main(args=None):
    args = parse_cmd_line(args=args)

    # Set the random seed
    np.random.seed(seed=args.seed)

    # Read in the keys to compare
    main_keys = args.main_keys

    # Read in the results and labels
    resultA = load_data(args.samples[0])
    labelA = args.labels[0]
    resultB = load_data(args.samples[1])
    labelB = args.labels[1]

    logger.debug("Evaluating JS divergence...")
    js_data = dict()
    js = np.zeros((args.ntests, len(main_keys)))
    for i, key in enumerate(main_keys):
        # Skip parameters missing from either result, as pp_plot does;
        # js_bootstrap would otherwise raise a KeyError
        if key not in resultA or key not in resultB:
            logger.debug(f"Missing key {key}")
            continue
        js[:, i] = js_bootstrap(key, resultA, resultB, args.nsamples, ntests=args.ntests)
        js_data[key] = calc_median_error(js[:, i])
    logger.debug("Making pp-plot...")
    pp_plot(args.event, resultA, resultB, labelA, labelB, main_keys, args.nsamples, js_data=js_data, webdir=args.webdir)


if __name__ == "__main__":
    main()
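
# Example invocation, a sketch assuming pesummary exposes this module as the
# `summaryjscompare` console script (the file names below are hypothetical):
#
#     summaryjscompare --event GW150914 -s analysisA.h5 analysisB.h5 \
#         -l A B --webdir ./out --ntests 100 --nsamples 10000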