Coverage for pesummary/gw/plots/publication.py: 84.1%

227 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import logger, number_of_columns_for_legend 

4import seaborn 

5from pesummary.core.plots.figure import figure 

6from pesummary.core.plots.seaborn import violin 

7from pesummary.utils.bounded_2d_kde import Bounded_2d_kde 

8from pesummary.gw.plots.bounds import default_bounds 

9from pesummary.gw.plots.cmap import colormap_with_fixed_hue 

10from pesummary.gw.conversions import mchirp_from_m1_m2, q_from_m1_m2 

11import numpy as np 

12import copy 

13 

# Module authors, each given as a "Name <email>" string.
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
    "Michael Puerrer <michael.puerrer@ligo.org>",
]

18 

19 

def chirp_mass_and_q_from_mass1_mass2(pts):
    """Map component-mass samples onto chirp mass and mass ratio

    Parameters
    ----------
    pts: numpy.array
        array whose first row holds the mass1 samples and whose second row
        holds the mass2 samples

    Returns
    -------
    numpy.ndarray
        array whose first row is the chirp mass and whose second row is the
        mass ratio, as computed by ``mchirp_from_m1_m2`` and ``q_from_m1_m2``
    """
    mass_1, mass_2 = np.atleast_2d(pts)
    chirp_mass = mchirp_from_m1_m2(mass_1, mass_2)
    mass_ratio = q_from_m1_m2(mass_1, mass_2)
    return np.vstack([chirp_mass, mass_ratio])

34 

35 

def _return_bounds(parameters, T=True):
    """Look up the KDE bounds for a pair of parameters

    Parameters
    ----------
    parameters: list
        the parameters being plotted; the first is treated as the x axis and
        the second as the y axis
    T: Bool, optional
        if True, a string upper bound that mentions "mass_1" switches on the
        ``chirp_mass_and_q_from_mass1_mass2`` transform and is replaced by 1.
        If False, every string upper bound is replaced by None

    Returns
    -------
    tuple
        (transform, xlow, xhigh, ylow, yhigh). Any entry not provided by
        ``default_bounds`` is None
    """
    def _axis_bounds(name):
        # (transform, low, high) for a single parameter name
        axis_transform = low = high = None
        if name not in default_bounds:
            return axis_transform, low, high
        bounds = default_bounds[name]
        if "low" in bounds:
            low = bounds["low"]
        if "high" in bounds:
            upper = bounds["high"]
            if not isinstance(upper, str):
                high = upper
            elif T and "mass_1" in upper:
                axis_transform = chirp_mass_and_q_from_mass1_mass2
                high = 1.
            # any other string upper bound leaves ``high`` as None
        return axis_transform, low, high

    x_transform, xlow, xhigh = _axis_bounds(parameters[0])
    y_transform, ylow, yhigh = _axis_bounds(parameters[1])
    # either axis may request the transform; both request the same function
    transform = x_transform if x_transform is not None else y_transform
    return transform, xlow, xhigh, ylow, yhigh

72 

73 

def twod_contour_plots(
    parameters, samples, labels, latex_labels, colors=None, linestyles=None,
    return_ax=False, plot_datapoints=False, smooth=None, latex_friendly=False,
    levels=[0.9], legend_kwargs={
        "bbox_to_anchor": (0., 1.02, 1., .102), "loc": 3, "handlelength": 3,
        "mode": "expand", "borderaxespad": 0., "handleheight": 1.75
    }, **kwargs
):
    """Generate 2d contour plots for a set of samples for given parameters

    Contours are drawn with a bounded 2d KDE whose bounds come from
    ``_return_bounds(parameters)``. When every parameter name contains
    "mass_1" or "mass_2", the triangle with vertices (0, 0), (0, 1000),
    (1000, 1000), i.e. the region above the line y = x, is shaded grey.

    Parameters
    ----------
    parameters: list
        names of the two parameters that you wish to plot (x axis first)
    samples: nd list
        one entry per set of samples; entry[0] holds the samples for
        parameters[0] and entry[1] those for parameters[1]
    labels: list
        list of labels corresponding to each set of samples
    latex_labels: dict
        dictionary of latex labels, indexed by parameter name
    colors: list, optional
        colours for each set of samples, forwarded unchanged to the core
        contour routine
    linestyles: list, optional
        linestyle for each set of samples. Default "-" for every set
    return_ax: Bool, optional
        if True, return (fig, ax) rather than just fig
    plot_datapoints: Bool, optional
        forwarded to the core contour routine
    smooth: optional
        forwarded to the core contour routine
    latex_friendly: Bool, optional
        not used by this function; kept for interface compatibility
    levels: list, optional
        contour levels forwarded to the core contour routine
    legend_kwargs: dict, optional
        kwargs passed to ``ax.legend``. "ncol" is always set from the number
        of labels. The dict is copied and never modified
    **kwargs: dict, optional
        forwarded to the core contour routine. "kde" and "kde_kwargs" are
        overwritten with the bounded KDE settings

    Returns
    -------
    fig or (fig, ax)
    """
    from pesummary.core.plots.publication import (
        comparison_twod_contour_plot as core
    )
    from matplotlib.patches import Polygon

    logger.debug("Generating 2d contour plots for %s" % ("_and_".join(parameters)))
    # Work on a copy: the default dict is shared between calls and a
    # caller-supplied dict must not be modified behind their back.
    legend_kwargs = dict(legend_kwargs)
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    # NOTE(review): this figure is replaced by the one returned from ``core``
    # below; kept in case the core routine relies on a current figure --
    # confirm and drop if not.
    fig, ax1 = figure(gca=True)
    transform, xlow, xhigh, ylow, yhigh = _return_bounds(parameters)
    kwargs.update(
        {
            "kde": Bounded_2d_kde, "kde_kwargs": {
                "transform": transform, "xlow": xlow, "xhigh": xhigh,
                "ylow": ylow, "yhigh": yhigh
            }
        }
    )
    # ``colors`` is forwarded as given; None leaves the choice of default
    # colours to the core routine.
    fig = core(
        [i[0] for i in samples], [i[1] for i in samples], colors=colors,
        labels=labels, xlabel=latex_labels[parameters[0]], smooth=smooth,
        ylabel=latex_labels[parameters[1]], linestyles=linestyles,
        plot_datapoints=plot_datapoints, levels=levels, **kwargs
    )
    ax1 = fig.gca()
    if all("mass_1" in i or "mass_2" in i for i in parameters):
        reg = Polygon([[0, 0], [0, 1000], [1000, 1000]], color='gray', alpha=0.75)
        ax1.add_patch(reg)
    legend_kwargs["ncol"] = number_of_columns_for_legend(labels)
    legend = ax1.legend(**legend_kwargs)
    # legend line width mirrors the requested handle height (default 1.)
    for line in legend.get_lines():
        line.set_linewidth(legend_kwargs.get("handleheight", 1.))
    fig.tight_layout()
    if return_ax:
        return fig, ax1
    return fig

136 

137 

def _setup_triangle_plot(parameters, kwargs):
    """Modify a dictionary of kwargs for bounded KDEs

    The 2d KDE is set to ``Bounded_2d_kde`` (with the transform and bounds
    from ``_return_bounds(parameters)``) and the 1d KDE to ``bounded_1d_kde``
    (with bounds from ``_return_bounds(parameters, T=False)``). Any existing
    "kde_2d", "kde_2d_kwargs", "kde" and "kde_kwargs" entries are overwritten.

    Parameters
    ----------
    parameters: list
        list of parameters being plotted; the first two are used as the
        x and y axes
    kwargs: dict
        kwargs to be passed to pesummary.gw.plots.publication.triangle_plot
        or pesummary.gw.plots.publication.reverse_triangle_plot. Updated in
        place and also returned

    Returns
    -------
    dict
        the updated ``kwargs``

    Raises
    ------
    ValueError
        if fewer than two parameters are provided
    """
    # Validate before anything else: with a single parameter
    # ``_return_bounds`` would fail with an opaque IndexError.
    if len(parameters) == 0:
        raise ValueError("Please provide a list of parameters")
    if len(parameters) < 2:
        raise ValueError(
            "Please provide two parameters: one for the x axis and one for "
            "the y axis"
        )
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    transform, xlow, xhigh, ylow, yhigh = _return_bounds(parameters)
    kwargs.update(
        {
            "kde_2d": Bounded_2d_kde, "kde_2d_kwargs": {
                "transform": transform, "xlow": xlow, "xhigh": xhigh,
                "ylow": ylow, "yhigh": yhigh
            }, "kde": bounded_1d_kde
        }
    )
    # The 1d KDEs act on the untransformed samples, so string bounds become
    # None rather than the transformed value.
    _, xlow, xhigh, ylow, yhigh = _return_bounds(parameters, T=False)
    kwargs["kde_kwargs"] = {
        "x_axis": {"xlow": xlow, "xhigh": xhigh},
        "y_axis": {"xlow": ylow, "xhigh": yhigh}
    }
    return kwargs

168 

169 

def triangle_plot(*args, parameters=[], **kwargs):
    """Draw a triangle plot using KDEs bounded for the given parameters

    This wraps pesummary.core.plots.publication.triangle_plot, which draws a
    central axis with the 2d marginalized posterior and two smaller axes with
    the 1d marginalized posteriors (above and to the right of the central
    axis).

    Parameters
    ----------
    *args: tuple
        forwarded to pesummary.core.plots.publication.triangle_plot
    parameters: list
        the parameters being plotted; used to choose the KDE bounds
    kwargs: dict, optional
        forwarded to pesummary.core.plots.publication.triangle_plot after the
        bounded KDE settings from ``_setup_triangle_plot`` are added

    Returns
    -------
    whatever pesummary.core.plots.publication.triangle_plot returns
    """
    from pesummary.core.plots.publication import (
        triangle_plot as _core_triangle_plot,
    )
    return _core_triangle_plot(*args, **_setup_triangle_plot(parameters, kwargs))

187 

188 

def reverse_triangle_plot(*args, parameters=[], **kwargs):
    """Draw a reversed triangle plot using KDEs bounded for the parameters

    This wraps pesummary.core.plots.publication.reverse_triangle_plot, which
    draws a central axis with the 2d marginalized posterior and two smaller
    axes with the 1d marginalized posteriors (below and to the left of the
    central axis).

    Parameters
    ----------
    *args: tuple
        forwarded to pesummary.core.plots.publication.reverse_triangle_plot
    parameters: list
        the parameters being plotted; used to choose the KDE bounds
    kwargs: dict, optional
        forwarded to pesummary.core.plots.publication.reverse_triangle_plot
        after the bounded KDE settings from ``_setup_triangle_plot`` are added

    Returns
    -------
    whatever pesummary.core.plots.publication.reverse_triangle_plot returns
    """
    from pesummary.core.plots.publication import (
        reverse_triangle_plot as _core_reverse_triangle_plot,
    )
    return _core_reverse_triangle_plot(
        *args, **_setup_triangle_plot(parameters, kwargs)
    )

209 

210 

def violin_plots(
    parameter, samples, labels, latex_labels, inj_values=None, cut=0,
    _default_kwargs={"palette": "pastel", "inner": "line", "outer": "percent: 90"},
    latex_friendly=True, **kwargs
):
    """Generate violin plots for a set of parameters and samples

    Parameters
    ----------
    parameter: str
        the name of the parameter that you wish to plot
    samples: nd list
        list of samples for each parameter, passed as ``data`` to
        ``violin.violinplot``
    labels: list
        list of labels corresponding to each set of samples
    latex_labels: dict
        dictionary of latex labels, indexed by parameter name
    inj_values: list, optional
        list of injected values for each set of samples
    cut: float, optional
        forwarded to ``violin.violinplot``
    _default_kwargs: dict, optional
        base plotting kwargs for ``violin.violinplot``; entries in ``kwargs``
        take precedence. The dict is never modified
    latex_friendly: Bool, optional
        if True, underscores in the labels are escaped as "\\_"
    **kwargs: dict, optional
        extra kwargs for ``violin.violinplot``

    Returns
    -------
    fig
        the figure containing the violin plot
    """
    logger.debug("Generating violin plots for %s" % (parameter))
    fig, ax1 = figure(gca=True)
    # Merge into a fresh dict. Updating ``_default_kwargs`` in place would
    # permanently change the shared default (and any caller-supplied dict),
    # so kwargs from one call would leak into every later call.
    plot_kwargs = {**_default_kwargs, **kwargs}
    ax1 = violin.violinplot(
        data=samples, cut=cut, ax=ax1, scale="width", inj=inj_values,
        **plot_kwargs
    )
    if latex_friendly:
        # presumably so that labels survive LaTeX text rendering
        labels = [label.replace("_", "\\_") for label in labels]
    ax1.set_xticklabels(labels)
    for tick_label in ax1.get_xmajorticklabels():
        tick_label.set_rotation(30)
    ax1.set_ylabel(latex_labels[parameter])
    fig.tight_layout()
    return fig

247 

248 

def _subsampled_spin_kde(magnitude, cos_tilt):
    """Build a bounded 2d KDE from a random half of the (magnitude, cos tilt)
    samples, bounded to [0, 0.99] x [-1, 1]"""
    pts = np.array([magnitude, cos_tilt])
    n_samples = pts.shape[1]
    selected = np.random.choice(n_samples, n_samples // 2, replace=False)
    mask = np.zeros(n_samples, dtype=bool)
    mask[selected] = True
    return Bounded_2d_kde(pts[:, mask], xlow=0, xhigh=.99, ylow=-1, yhigh=1)


def _spin_disk_grid_helper():
    """Return (transform, grid_helper) for one floating half-disk axis"""
    from matplotlib.projections import PolarAxes
    from matplotlib.transforms import Affine2D
    from mpl_toolkits.axisartist.grid_finder import MaxNLocator
    import mpl_toolkits.axisartist.floating_axes as floating_axes
    import mpl_toolkits.axisartist.angle_helper as angle_helper

    # degrees (shifted by 90) -> radians -> polar
    tr = (
        Affine2D().translate(90, 0) + Affine2D().scale(np.pi / 180., 1.)
        + PolarAxes.PolarTransform()
    )
    grid_helper = floating_axes.GridHelperCurveLinear(
        tr, extremes=(0, 180, 0, .99),
        grid_locator1=angle_helper.LocatorD(7),
        grid_locator2=MaxNLocator(5),
        tick_formatter1=angle_helper.FormatterDMS(),
        tick_formatter2=None
    )
    return tr, grid_helper


def _spin_disk_patches(X, Y, H, dr, dcost, cmap, vmin, vmax):
    """Return a PatchCollection with one annular wedge per grid cell, coloured
    by H with colour limits (vmin, vmax)"""
    from matplotlib.patches import Wedge
    from matplotlib.collections import PatchCollection

    patches = []
    for theta, radius in zip(X.ravel(), Y.ravel()):
        cos_low = np.cos((theta - 90) * np.pi / 180)
        # Round-off in the degree/radian round trip can push cos_low + dcost
        # marginally above 1 for the last tilt bin, where arccos would return
        # NaN; clip to arccos' domain.
        cos_high = np.clip(cos_low + dcost, -1., 1.)
        theta_next = np.arccos(cos_high) * 180. / np.pi + 90.
        patches.append(
            Wedge((0., 0.), radius + dr, theta_next, theta, width=dr)
        )
    collection = PatchCollection(
        patches, cmap=cmap, edgecolors='face', zorder=10
    )
    collection.set_clim(vmin, vmax)
    collection.set_array(np.array(H.ravel()[:len(patches)]))
    return collection


def spin_distribution_plots(
    parameters, samples, label, color=None, cmap=None, annotate=False,
    show_label=True, colorbar=False, vmin=0.,
    vmax=np.log(1.0 + np.exp(1.) * 3.024)
):
    """Generate spin distribution plots for a set of parameters and samples

    Two half disks are drawn side by side, one for each binary component.
    Each is coloured by log(1 + e * pdf), where pdf is a bounded 2d KDE of
    the spin magnitude and cosine of the tilt angle, built from a random half
    of the samples (so repeated calls differ unless numpy's global random
    state is seeded).

    Parameters
    ----------
    parameters: list
        list of parameters; must contain "a_1", "a_2", "cos_tilt_1" and
        "cos_tilt_2"
    samples: nd list
        list of samples for each spin component, in the same order as
        ``parameters``
    label: str
        the label corresponding to the set of samples; drawn as a title
        unless None
    color: str, optional
        color to use for plotting
    cmap: str, optional
        cmap to use for plotting. cmap is preferentially chosen over color
    annotate: Bool, optional
        if True, label the magnitude and tilt directions
    show_label: Bool, optional
        if True, add labels indicating which side of the spin disk corresponds
        to which binary component
    colorbar: Bool, optional
        if True, add a horizontal colorbar below the disks. Both disks use
        the same cmap and (vmin, vmax), so one colorbar serves both
    vmin: float, optional
        lower colour limit applied to log(1 + e * pdf)
    vmax: float, optional
        upper colour limit applied to log(1 + e * pdf). The default is
        log(1 + e * 3.024); the origin of 3.024 is not documented here

    Returns
    -------
    fig
        the figure containing both spin disks

    Raises
    ------
    ValueError
        if neither ``color`` nor ``cmap`` is provided
    """
    logger.debug("Generating spin distribution plots for %s" % (label))
    from matplotlib import patheffects as PathEffects
    from matplotlib.transforms import ScaledTranslation
    import mpl_toolkits.axisartist.floating_axes as floating_axes

    if color is not None and cmap is None:
        cmap = colormap_with_fixed_hue(color)
    elif color is None and cmap is None:
        raise ValueError(
            "Please provide either a single color or a cmap to use for plotting"
        )

    spin1 = samples[parameters.index("a_1")]
    spin2 = samples[parameters.index("a_2")]
    costheta1 = samples[parameters.index("cos_tilt_1")]
    costheta2 = samples[parameters.index("cos_tilt_2")]

    kde1 = _subsampled_spin_kde(spin1, costheta1)
    kde2 = _subsampled_spin_kde(spin2, costheta2)

    # 24 x 24 grid of (magnitude, cos tilt) cells
    rs = np.linspace(0, .99, 25)
    dr = np.abs(rs[1] - rs[0])
    costs = np.linspace(-1, 1, 25)
    dcost = np.abs(costs[1] - costs[0])
    COSTS, RS = np.meshgrid(costs[:-1], rs[:-1])
    X = np.arccos(COSTS) * 180 / np.pi + 90.
    Y = RS

    # evaluate the KDEs at the cell centres
    cell_centres = np.vstack([RS.ravel() + dr / 2, COSTS.ravel() + dcost / 2])
    scale = np.exp(1.0)
    H1 = np.log(1.0 + scale * kde1(cell_centres))
    H2 = np.log(1.0 + scale * kde2(cell_centres))

    fig = figure(figsize=(6, 6), gca=False)

    # Spin 1: first subplot (121)
    _, grid_helper = _spin_disk_grid_helper()
    ax1 = floating_axes.FloatingSubplot(fig, 121, grid_helper=grid_helper)
    fig.add_subplot(ax1)

    ax1.axis["bottom"].toggle(all=False)
    ax1.axis["top"].toggle(all=True)
    ax1.axis["top"].major_ticks.set_tick_out(True)
    ax1.axis["top"].set_axis_direction("top")
    ax1.axis["top"].set_ticklabel_direction('+')

    ax1.axis["left"].major_ticks.set_tick_out(True)
    ax1.axis["left"].set_axis_direction('right')
    # shift the radial tick labels 7 points along x
    offset_transform = ScaledTranslation(7.0 / 72., 0 / 72., fig.dpi_scale_trans)
    ax1.axis["left"].major_ticklabels.set(figure=fig,
                                          transform=offset_transform)
    ax1.add_collection(
        _spin_disk_patches(X, Y, H1, dr, dcost, cmap, vmin, vmax)
    )

    # Spin 2: second subplot (122), mirrored via an inverted x axis
    tr, grid_helper = _spin_disk_grid_helper()
    ax2 = floating_axes.FloatingSubplot(fig, 122, grid_helper=grid_helper)
    ax2.invert_xaxis()
    fig.add_subplot(ax2)

    # Label angles on the outside
    ax2.axis["bottom"].toggle(all=False)
    ax2.axis["top"].toggle(all=True)
    ax2.axis["top"].set_axis_direction("top")
    ax2.axis["top"].major_ticks.set_tick_out(True)

    # Remove radial labels
    ax2.axis["left"].major_ticks.set_tick_out(True)
    ax2.axis["left"].toggle(ticklabels=False)
    ax2.axis["left"].major_ticklabels.set_visible(False)
    # Also have radial ticks for the lower half of the right semidisk
    ax2.axis["right"].major_ticks.set_tick_out(True)

    p = _spin_disk_patches(X, Y, H2, dr, dcost, cmap, vmin, vmax)
    ax2.add_collection(p)

    # Event name top, spin labels bottom
    if label is not None:
        ax2.text(0.16, 1.25, label, fontsize=18, horizontalalignment='center')
    if show_label:
        ax2.text(1.25, -1.15, r'$c{S}_{1}/(Gm_1^2)$', fontsize=14)
        ax2.text(-.5, -1.15, r'$c{S}_{2}/(Gm_2^2)$', fontsize=14)
    if annotate:
        annotate_scale = 1.0
        aux_ax = ax2.get_aux_axes(tr)
        aux_ax.text(
            50 * annotate_scale, 0.35 * annotate_scale, r'$\mathrm{magnitude}$',
            fontsize=20, zorder=10
        )
        aux_ax.text(
            45 * annotate_scale, 1.2 * annotate_scale, r'$\mathrm{tilt}$',
            fontsize=20, zorder=10
        )
        arrow = aux_ax.annotate(
            "", xy=(55, 1.158 * annotate_scale), xycoords='data',
            xytext=(35, 1.158 * annotate_scale), textcoords='data',
            arrowprops=dict(
                arrowstyle="->", color="k", shrinkA=2, shrinkB=2, patchA=None,
                patchB=None, connectionstyle='arc3,rad=-0.16'
            )
        )
        arrow.arrow_patch.set_path_effects(
            [PathEffects.Stroke(linewidth=2, foreground="w"), PathEffects.Normal()]
        )
        arrow = aux_ax.annotate(
            "", xy=(35, 0.55 * annotate_scale), xycoords='data',
            xytext=(150, 0. * annotate_scale), textcoords='data',
            arrowprops=dict(
                arrowstyle="->", color="k", shrinkA=2, shrinkB=2, patchA=None,
                patchB=None
            ), zorder=100
        )
        arrow.arrow_patch.set_path_effects(
            [
                PathEffects.Stroke(linewidth=0.3, foreground="k"),
                PathEffects.Normal()
            ]
        )
    fig.subplots_adjust(wspace=0.295)
    if colorbar:
        cax = fig.add_axes([0.22, 0.05, 0.55, 0.02])
        fig.colorbar(
            p, cax=cax, orientation="horizontal", pad=0.2, shrink=0.5,
            label='posterior probability per pixel'
        )
    return fig