Coverage for pesummary/cli/summarypublication.py: 59.6%

178 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5import pesummary 

6from pesummary.gw.file.read import read as GWRead 

7from pesummary.gw.plots.latex_labels import GWlatex_labels 

8from pesummary.gw.plots import publication as pub 

9from pesummary.core.plots import population as pop 

10from pesummary.core.plots.latex_labels import latex_labels 

11from pesummary.utils.utils import make_dir, logger, _check_latex_install 

12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser 

13from pesummary.core.cli.actions import DictionaryAction 

14import seaborn 

15import numpy as np 

16 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Module docstring; it is also passed as the ``description`` of the command
# line parser created in ``main``
__doc__ = """This executable is used to generate publication quality plots given
result files"""
# Executed at import time, before any plot is made. NOTE(review): whether a
# missing LaTeX install raises or only warns is decided by
# pesummary.utils.utils._check_latex_install -- confirm there
_check_latex_install()

21 

22 

class ArgumentParser(_ArgumentParser):
    """Command line parser for the `summarypublication` executable

    Adds the publication-plot specific options to the options provided by
    the core pesummary parser.
    """
    def _pesummary_options(self):
        """Return every option known to this parser

        Returns
        -------
        dict
            the options returned by the parent class, updated with
            ``--plot``, ``--parameters``, ``--publication_kwargs`` and
            ``--levels``. Each value is a dictionary of option settings
            (help, default, nargs, ...), presumably forwarded to argparse by
            the parent class
        """
        plot_choices = [
            "2d_contour", "violin", "spin_disk",
            "population_scatter", "population_scatter_error",
        ]
        publication_options = {
            "--plot": dict(
                help="name of the publication plot you wish to produce",
                default="2d_contour",
                choices=plot_choices,
            ),
            "--parameters": dict(
                nargs="+",
                help="parameters of the 2d contour plot you wish to make",
            ),
            "--publication_kwargs": dict(
                help="Optional kwargs for publication plots",
                nargs="+",
                default={},
                action=DictionaryAction,
            ),
            "--levels": dict(
                default=[0.9],
                nargs="+",
                help="Contour levels you wish to plot",
                type=float,
            ),
        }
        options = super()._pesummary_options()
        options.update(publication_options)
        return options

55 

56 

def draw_specific_samples(param, parameters, samples):
    """Return samples for a given parameter

    Parameters
    ----------
    param: str
        parameter that you wish to get samples for
    parameters: nd list
        list of all parameters stored in each result file. Every entry must
        contain ``param``
    samples: nd list
        list of samples for each result file. Each sample is a row whose
        columns follow the order of the matching entry in ``parameters``

    Returns
    -------
    list
        one list per result file containing only the samples for ``param``

    Raises
    ------
    ValueError
        if ``param`` is missing from one of the entries in ``parameters``
    """
    extracted = []
    for file_parameters, file_samples in zip(parameters, samples):
        # The column of ``param`` can differ between result files, so it is
        # looked up separately for each file and then used for every row
        column = file_parameters.index(param)
        extracted.append([row[column] for row in file_samples])
    return extracted

69 

70 

def default_2d_contour_plot():
    """Return the parameter pairs plotted when none are requested

    Returns
    -------
    list
        a new list of ``[x_parameter, y_parameter]`` pairs
    """
    pairs = (
        ("mass_ratio", "chi_eff"),
        ("mass_1", "mass_2"),
        ("luminosity_distance", "chirp_mass_source"),
        ("mass_1_source", "mass_2_source"),
        ("theta_jn", "luminosity_distance"),
        ("network_optimal_snr", "chirp_mass_source"),
    )
    return [list(pair) for pair in pairs]

80 

81 

def default_violin_plot():
    """Return the parameters plotted as violins when none are requested

    Returns
    -------
    list
        a new list of parameter names
    """
    return ["chi_eff", "chi_p", "mass_ratio", "luminosity_distance"]

87 

88 

def read_samples(result_files):
    """Read and return a list of parameters and samples stored in the result
    files

    Parameters
    ----------
    result_files: list
        list of result files

    Returns
    -------
    parameters: list
        one entry per result file: the parameters it contains, or ``[None]``
        if the file could not be read
    samples: list
        one entry per result file: its samples, or ``[None]`` if the file
        could not be read. Always the same length as ``parameters``, so the
        two lists stay index-aligned with ``result_files``
    """
    parameters = []
    samples = []
    for i in result_files:
        try:
            f = GWRead(i)
            if isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                # Only the first entry of ``f.parameters``/``f.samples`` is
                # used (presumably the first analysis in the metafile)
                file_parameters, file_samples = f.parameters[0], f.samples[0]
            else:
                f.generate_all_posterior_samples()
                file_parameters, file_samples = f.parameters, f.samples
        except Exception:
            # Best effort: an unreadable file is replaced by a placeholder so
            # that later indices still line up with the result files/labels
            logger.warning(
                "Failed to read '{}'. Data will not be added to the "
                "plots".format(i)
            )
            file_parameters, file_samples = [None], [None]
        # Append only once both values are known: a failure part way through
        # reading a file must not leave ``parameters`` and ``samples`` with
        # different lengths
        parameters.append(file_parameters)
        samples.append(file_samples)
    return parameters, samples

118 

119 

def get_colors_and_linestyles(opts):
    """Return a list of colors and a list of linestyles

    Colors come from ``opts.colors`` when given, otherwise from the seaborn
    palette ``opts.palette`` with one color per entry of ``opts.samples``.
    Linestyles come from ``opts.linestyles`` when given. Otherwise the first
    use of a color gets a solid line, and each repeat of that color cycles
    through ``-``, ``--``, ``:``, ``-.`` so that entries sharing a color can
    still be told apart.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Returns
    -------
    colors: list
        list of colors
    linestyles: list
        list of linestyles
    """
    if opts.colors is None:
        colors = seaborn.color_palette(
            palette=opts.palette, n_colors=len(opts.samples)
        ).as_hex()
    else:
        colors = opts.colors
    if opts.linestyles is not None:
        return colors, opts.linestyles
    cycle = ["-", "--", ":", "-."]
    times_seen = {}
    linestyles = []
    for color in colors:
        occurrence = times_seen.get(color, 0)
        linestyles.append(cycle[occurrence % len(cycle)])
        times_seen[color] = occurrence + 1
    return colors, linestyles

142 

143 

def make_2d_contour_plot(opts):
    """Make a 2d contour plot for each requested pair of parameters

    A plot is made for ``opts.parameters`` (which must then contain exactly
    two parameters) or, if no parameters were passed, for each pair returned
    by ``default_2d_contour_plot``. Result files which do not contain both
    parameters of a pair are left out of that plot only. Labels, colors and
    linestyles are filtered alongside the samples so they stay matched to
    the correct result file. Each plot is saved as
    ``2d_contour_plot_<x>_and_<y>.png`` in ``opts.webdir``.

    The optional ``gridsize`` (default 100), ``xlow``, ``xhigh``, ``ylow`` and
    ``yhigh`` entries of ``opts.publication_kwargs`` set the contour grid size
    and the axis limits.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    ValueError
        if ``opts.parameters`` is given but does not contain exactly two
        parameters
    """
    if opts.parameters and len(opts.parameters) != 2:
        raise ValueError("Please pass 2 variables that you wish to plot")
    if opts.parameters:
        default = [opts.parameters]
    else:
        default = default_2d_contour_plot()
    colors, linestyles = get_colors_and_linestyles(opts)
    parameters, samples = read_samples(opts.samples)
    kwargs = opts.publication_kwargs
    gridsize = kwargs.get("gridsize", 100)
    for pair in default:
        # Decide per pair which files can be used. Nothing is removed from
        # ``parameters``/``samples`` so a file lacking one pair is still used
        # for the others
        has_pair = [
            all(param in file_parameters for param in pair)
            for file_parameters in parameters
        ]
        keep = [num for num, ok in enumerate(has_pair) if ok]
        if len(keep) != len(parameters):
            files = [
                opts.samples[num] for num, ok in enumerate(has_pair) if not ok
            ]
            logger.warning(
                "Removing {} from 2d contour plot because the parameters {} are "
                "not in the result file".format(
                    " and ".join(files), " and ".join(pair)
                )
            )
        if not keep:
            # No result file contains this pair; there is nothing to plot
            continue
        twod_samples = []
        for num in keep:
            ind1 = parameters[num].index(pair[0])
            ind2 = parameters[num].index(pair[1])
            twod_samples.append([
                [k[ind1] for k in samples[num]],
                [k[ind2] for k in samples[num]],
            ])
        fig, ax = pub.twod_contour_plots(
            pair, twod_samples, [opts.labels[num] for num in keep],
            latex_labels, colors=[colors[num] for num in keep],
            linestyles=[linestyles[num] for num in keep], gridsize=gridsize,
            levels=opts.levels, return_ax=True
        )
        # Only override the limits the user supplied; keep the automatically
        # chosen value for the other end of the axis
        current_xlow, current_xhigh = ax.get_xlim()
        current_ylow, current_yhigh = ax.get_ylim()
        if "xlow" in kwargs or "xhigh" in kwargs:
            ax.set_xlim([
                float(kwargs.get("xlow", current_xlow)),
                float(kwargs.get("xhigh", current_xhigh)),
            ])
        if "ylow" in kwargs or "yhigh" in kwargs:
            ax.set_ylim([
                float(kwargs.get("ylow", current_ylow)),
                float(kwargs.get("yhigh", current_yhigh)),
            ])
        fig.savefig(
            "%s/2d_contour_plot_%s.png" % (opts.webdir, "_and_".join(pair))
        )
        fig.close()

214 

215 

def make_violin_plot(opts):
    """Make a violin plot for each requested parameter

    A plot is made for the single parameter in ``opts.parameters`` or, if no
    parameters were passed, for each parameter returned by
    ``default_violin_plot``. Result files which do not contain a parameter
    are left out of that parameter's plot only. Each plot is saved as
    ``violin_plot_<parameter>.png`` in ``opts.webdir``. A failure to plot one
    parameter is logged and the remaining parameters are still plotted.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    ValueError
        if ``opts.parameters`` is given but contains more than one parameter
    """
    if opts.parameters and len(opts.parameters) != 1:
        raise ValueError("Please pass a single variable that you wish to plot")
    if opts.parameters:
        default = opts.parameters
    else:
        default = default_violin_plot()
    parameters, samples = read_samples(opts.samples)

    for param in default:
        # Decide per parameter which files can be used; nothing is removed
        # from ``parameters``/``samples`` so later parameters see every file
        has_param = [param in file_parameters for file_parameters in parameters]
        keep = [num for num, ok in enumerate(has_param) if ok]
        if len(keep) != len(parameters):
            files = [
                opts.samples[num] for num, ok in enumerate(has_param) if not ok
            ]
            logger.warning(
                "Removing {} from violin plot because the parameter {} does "
                "not exist in the result file".format(
                    " and ".join(files), param
                )
            )
        if not keep:
            # No result file contains this parameter; nothing to plot
            continue
        try:
            # Collect into a new list so ``samples`` stays intact for the
            # next parameter
            oned_samples = []
            for num in keep:
                ind = parameters[num].index(param)
                oned_samples.append([k[ind] for k in samples[num]])
            fig = pub.violin_plots(
                param, oned_samples, [opts.labels[num] for num in keep],
                latex_labels
            )
            fig.savefig("%s/violin_plot_%s.png" % (opts.webdir, param))
            fig.close()
        except Exception as e:
            logger.warning(
                "Failed to generate a violin plot for %s because %s" % (
                    param, e
                )
            )

252 

253 

def make_spin_disk_plot(opts):
    """Make one spin disk plot per result file

    Each plot is saved as ``spin_disk_plot_<label>.png`` in ``opts.webdir``.
    A result file is skipped, and a message logged, when it does not contain
    all of ``a_1``, ``a_2``, ``cos_tilt_1`` and ``cos_tilt_2`` or when
    producing its plot raises an exception.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments
    """
    required_parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    colors, _linestyles = get_colors_and_linestyles(opts)
    parameters, samples = read_samples(opts.samples)

    for num, file_parameters in enumerate(parameters):
        if any(name not in file_parameters for name in required_parameters):
            logger.info(
                "Failed to generate spin disk plots for %s because "
                "%s are not in the result file" % (
                    opts.labels[num], " and ".join(required_parameters)
                )
            )
            continue
        try:
            columns = [
                file_parameters.index(name) for name in required_parameters
            ]
            spin_samples = [
                [row[column] for row in samples[num]] for column in columns
            ]
            fig = pub.spin_distribution_plots(
                required_parameters, spin_samples, opts.labels[num],
                colors[num]
            )
            fig.savefig(
                "%s/spin_disk_plot_%s.png" % (opts.webdir, opts.labels[num])
            )
            fig.close()
        except Exception as e:
            logger.warning(
                "Failed to generate a spin disk plot for %s because %s" % (
                    opts.labels[num], e
                )
            )

284 

285 

def _median_to_percentile_distances(median, values, percentiles=(5, 95)):
    """Return the absolute distance between ``median`` and each of the given
    ``percentiles`` of ``values``"""
    return [np.abs(median - np.percentile(values, q)) for q in percentiles]


def make_population_scatter_plot(opts):
    """Make a scatter plot showing a population of runs

    Each result file contributes one point at the median of its samples for
    the two parameters in ``opts.parameters``. When ``opts.plot`` contains
    ``"error"`` (i.e. ``population_scatter_error``), error bars spanning the
    5th to 95th percentiles are added. The plot is saved as
    ``event_scatter_plot_<x>_and_<y>.png`` in ``opts.webdir``. At least two
    result files are required; otherwise a warning is logged and no plot is
    made.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    ValueError
        if ``opts.parameters`` does not contain exactly two parameters
    """
    if not opts.parameters or len(opts.parameters) != 2:
        raise ValueError("Please pass 2 variables that you wish to plot")
    if len(opts.samples) < 2:
        logger.warning(
            "A population scatter plot requires more than one result file. "
            "No plot will be produced"
        )
        return
    x_param, y_param = opts.parameters
    parameters, samples = read_samples(opts.samples)
    plotting_data = {}
    xerr, yerr = None, None
    if "error" in opts.plot:
        xerr, yerr = {}, {}
    for num, label in enumerate(opts.labels):
        if not all(param in parameters[num] for param in opts.parameters):
            logger.warning(
                "'{}' does not include samples for '{}' and/or '{}'. This "
                "analysis will not be added to the plot".format(
                    label, x_param, y_param
                )
            )
            continue
        # Extract each parameter's samples once; they are reused for the
        # median and, optionally, the percentiles
        columns = {}
        for param in opts.parameters:
            ind = parameters[num].index(param)
            columns[param] = [k[ind] for k in samples[num]]
        plotting_data[label] = {
            param: np.median(values) for param, values in columns.items()
        }
        if xerr is not None:
            xerr[label] = {
                x_param: _median_to_percentile_distances(
                    plotting_data[label][x_param], columns[x_param]
                )
            }
            yerr[label] = {
                y_param: _median_to_percentile_distances(
                    plotting_data[label][y_param], columns[y_param]
                )
            }
    if not plotting_data:
        logger.warning(
            "None of the result files include samples for '{}' and '{}'. No "
            "plot will be produced".format(x_param, y_param)
        )
        return
    fig = pop.scatter_plot(
        opts.parameters, plotting_data, latex_labels, xerr=xerr, yerr=yerr
    )
    fig.savefig("{}/event_scatter_plot_{}.png".format(
        opts.webdir, "_and_".join(opts.parameters)
    ))
    fig.close()

334 

335 

def main(args=None):
    """Top level interface for `summarypublication`

    Parses the command line, creates the output directory and dispatches to
    the plotting function selected by ``--plot``.

    Parameters
    ----------
    args: list, optional
        arguments passed to ``parser.parse_args``. Default None
    """
    latex_labels.update(GWlatex_labels)
    known_options = [
        "--webdir", "--samples", "--labels", "--plot", "--parameters",
        "--publication_kwargs", "--colors", "--palette", "--linestyles",
        "--levels",
    ]
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(known_options)
    opts = parser.parse_args(args=args)
    make_dir(opts.webdir)
    # Both population plots share one implementation, which checks opts.plot
    # itself to decide whether error bars are drawn
    plotters = dict(
        (("2d_contour", make_2d_contour_plot),
         ("violin", make_violin_plot),
         ("spin_disk", make_spin_disk_plot),
         ("population_scatter", make_population_scatter_plot),
         ("population_scatter_error", make_population_scatter_plot))
    )
    plotter = plotters[opts.plot]
    plotter(opts)

356 

357 

# Run the command line interface when this file is executed directly
if __name__ == "__main__":
    main()