Coverage for pesummary/core/plots/publication.py: 71.0%

259 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4from matplotlib import gridspec 

5from scipy.stats import gaussian_kde 

6import copy 

7 

8from pesummary.core.plots.figure import figure 

9from .corner import hist2d 

10from pesummary import conf 

11 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Baseline keyword arguments for ``ax.legend``. Consumers take a copy and
# layer their own options on top instead of modifying this dict in place.
DEFAULT_LEGEND_KWARGS = dict(loc="best", frameon=False)

14 

15 

def pcolormesh(
    x, y, density, ax=None, levels=None, smooth=None, bins=None, label=None,
    level_kwargs=None, range=None, grid=True, legend=False, legend_kwargs=None,
    weights=None, **kwargs
):
    """Generate a colormesh plot on a given axis

    Parameters
    ----------
    x: np.ndarray
        array of floats for the x axis
    y: np.ndarray
        array of floats for the y axis
    density: np.ndarray
        2d array of probabilities
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting. Although optional in the
        signature, an axis must be provided: it is used unconditionally
    levels: list, optional
        contour levels to show on the plot. Default None (no contours)
    smooth: float, optional
        sigma to use for Gaussian smoothing of `density`. Default, no
        smoothing applied
    bins: optional
        unused. Accepted so that this function can be passed as `_function`
        to `twod_contour_plot`
    label: str, optional
        legend label for the contour. Only used when `legend` is True and
        `levels` is not None
    level_kwargs: dict, optional
        optional kwargs to use for ax.contour. Default None (no extra kwargs)
    range: optional
        unused. Accepted for call compatibility with `twod_contour_plot`
    grid: Bool, optional
        if True and no `zorder` is given in kwargs, draw the mesh at zorder
        -10 (presumably so that grid lines are drawn above it). Otherwise
        the mesh uses kwargs["zorder"], defaulting to 10
    legend: Bool, optional
        if True (and `levels` is not None) add a legend entry for the contour
    legend_kwargs: dict, optional
        kwargs for ax.legend, applied on top of DEFAULT_LEGEND_KWARGS
    weights: optional
        must be None; weighted data is not supported
    **kwargs: dict, optional
        all additional kwargs passed to ax.pcolormesh. Pass cmap="off"
        (case-insensitive) to skip the colormesh and draw only the contours

    Returns
    -------
    ax: the axis used for plotting

    Raises
    ------
    ValueError
        if `weights` is not None
    """
    # Validate before doing any (possibly expensive) smoothing work
    if weights is not None:
        raise ValueError(
            "This function does not currently support weighted data"
        )
    level_kwargs = {} if level_kwargs is None else level_kwargs
    legend_kwargs = {} if legend_kwargs is None else legend_kwargs
    if smooth is not None:
        # ``scipy.ndimage.filters`` is a deprecated alias namespace; the
        # filter lives directly in ``scipy.ndimage``
        from scipy.ndimage import gaussian_filter
        density = gaussian_filter(density, sigma=smooth)
    _cmap = kwargs.get("cmap", None)
    _off = isinstance(_cmap, str) and _cmap.lower() == "off"
    if grid and "zorder" not in kwargs:
        _zorder = -10
    else:
        _zorder = kwargs.pop("zorder", 10)
    if not _off:
        ax.pcolormesh(x, y, density, zorder=_zorder, **kwargs)
    if levels is not None:
        CS = ax.contour(x, y, density, levels=levels, **level_kwargs)
        if legend:
            _legend_kwargs = DEFAULT_LEGEND_KWARGS.copy()
            _legend_kwargs.update(legend_kwargs)
            # ``ContourSet.collections`` was deprecated in Matplotlib 3.8
            # (a ContourSet is now a single Collection without a legend
            # handler). Build a proxy line styled like the first contour
            # level and add it, data-free, to the axis so ``ax.legend``
            # collects it alongside any other labelled artists.
            proxy = CS.legend_elements()[0][0]
            proxy.set_label(label)
            ax.add_line(proxy)
            ax.legend(**_legend_kwargs)
    return ax

67 

68 

def analytic_twod_contour_plot(*args, smooth=None, **kwargs):
    """Draw a 2d contour plot from an analytic (already gridded) PDF.

    Thin wrapper around `twod_contour_plot` that swaps the default
    sample-based plotting backend for `pcolormesh`. The positional arguments
    are therefore the grid coordinates followed by the gridded
    probabilities, rather than posterior samples.

    Parameters
    ----------
    *args: tuple
        forwarded unchanged to twod_contour_plot
    smooth: float, optional
        sigma of the Gaussian smoothing applied to the probabilities.
        Default None, i.e. no smoothing
    **kwargs: dict, optional
        forwarded unchanged to twod_contour_plot
    """
    forced_options = {"smooth": smooth, "_function": pcolormesh}
    return twod_contour_plot(*args, **forced_options, **kwargs)

84 

85 

def twod_contour_plot(
    x, y, *args, rangex=None, rangey=None, fig=None, ax=None, return_ax=False,
    levels=[0.9], bins=300, smooth=7, xlabel=None, ylabel=None,
    fontsize={"label": 12}, grid=True, label=None, truth=None,
    _function=hist2d, truth_lines=True, truth_kwargs={}, weights=None,
    _default_truth_kwargs={
        "marker": 'o', "markeredgewidth": 2, "markersize": 6, "color": 'k'
    }, **kwargs
):
    """Generate a 2d contour contour plot for 2 marginalized posterior
    distributions

    Parameters
    ----------
    x: np.array
        array of posterior samples to use for the x axis
    y: np.array
        array of posterior samples to use for the y axis
    *args: tuple
        additional positional arguments passed to `_function` after x and y
    rangex: tuple, optional
        range over which to plot the x axis. Default (min(x), max(x))
    rangey: tuple, optional
        range over which to plot the y axis. Default (min(y), max(y))
    fig: matplotlib.figure.Figure, optional
        figure you wish to use for plotting
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting. If an axis is given without a
        figure, the axis is always returned
    return_ax: Bool, optional
        if True return the axis used for plotting. Else return the figure
    levels: list, optional
        levels you wish to use for the 2d contours. Default [0.9]
    bins: int, optional
        number of bins to use for gridding 2d parameter space. Default 300
    smooth: int, optional
        how much smoothing you wish to use for the 2d contours
    xlabel: str, optional
        label to use for the xaxis
    ylabel: str, optional
        label to use for the yaxis
    fontsize: dict, optional
        dictionary containing the fontsize to use for the plot. The "label"
        entry is used for the axis labels
    grid: Bool, optional
        if True, add a grid to the plot
    label: str, optional
        label to use for a given contour
    truth: list, optional
        the true value of the posterior. `truth` is a list of length 2 with
        first element being the true x value and second element being the true
        y value
    truth_lines: Bool, optional
        if True, add vertical and horizontal lines spanning the 2d space to show
        injected value
    truth_kwargs: dict, optional
        kwargs to use to indicate truth. These override
        `_default_truth_kwargs` for this call only
    weights: np.array, optional
        sample weights passed to `_function`
    **kwargs: dict, optional
        all additional kwargs are passed to `_function`, by default the
        `pesummary.core.plots.corner.hist2d` function. If no "range" kwarg
        is given, it is set to [[xlow, xhigh], [ylow, yhigh]]

    Returns
    -------
    fig or ax: the figure, or the axis when `return_ax` is True or when only
        an axis was supplied
    """
    if fig is None and ax is None:
        fig, ax = figure(gca=True)
    elif fig is None:
        # only an axis was supplied: there is no figure to hand back
        return_ax = True
    elif ax is None:
        ax = fig.gca()

    xlow, xhigh = np.min(x), np.max(x)
    ylow, yhigh = np.min(y), np.max(y)
    if rangex is not None:
        xlow, xhigh = rangex
    if rangey is not None:
        ylow, yhigh = rangey
    if "range" not in kwargs:
        kwargs["range"] = [[xlow, xhigh], [ylow, yhigh]]

    _function(
        x, y, *args, ax=ax, levels=levels, bins=bins, smooth=smooth,
        label=label, grid=grid, weights=weights, **kwargs
    )
    if truth is not None:
        # Merge into a fresh dict: updating `_default_truth_kwargs` in place
        # would modify the shared default and leak `truth_kwargs` into every
        # later call of this function
        _truth_kwargs = dict(_default_truth_kwargs)
        _truth_kwargs.update(truth_kwargs)
        ax.plot(*truth, **_truth_kwargs)
        if truth_lines:
            ax.axvline(
                truth[0], color=_truth_kwargs["color"], linewidth=0.5
            )
            ax.axhline(
                truth[1], color=_truth_kwargs["color"], linewidth=0.5
            )
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize["label"])
    ax.grid(grid)
    if fig is not None:
        fig.tight_layout()
    if return_ax:
        return ax
    return fig

183 

184 

def comparison_twod_contour_plot(
    x, y, labels=None, plot_density=None, rangex=None, rangey=None,
    legend_kwargs={"loc": "best", "frameon": False},
    colors=list(conf.colorcycle), linestyles=None, **kwargs
):
    """Generate a comparison 2d contour contour plot for 2 marginalized
    posterior distributions from multiple analyses

    Parameters
    ----------
    x: np.ndarray
        2d array of posterior samples to use for the x axis; array for each
        analysis
    y: np.ndarray
        2d array of posterior samples to use for the y axis; array for each
        analysis
    labels: list, optional
        labels to assign to each contour
    plot_density: str/list, optional
        label of the analysis you wish to plot the density for, or a list
        of such labels. If you wish to plot all, simply pass
        `plot_density='both'`. Ignored when `labels` is None
    rangex: tuple, optional
        range over which to plot the x axis. Default spans all analyses
    rangey: tuple, optional
        range over which to plot the y axis. Default spans all analyses
    legend_kwargs: dict, optional
        kwargs to use for the legend
    colors: list, optional
        list of colors to use for each contour
    linestyles: list, optional
        linestyles to use for each contour
    **kwargs: dict, optional
        all additional kwargs are passed to the
        `pesummary.core.plots.publication.twod_contour_plot` function

    Returns
    -------
    fig: matplotlib.figure.Figure containing all contours and a legend
    """
    if labels is None and plot_density is not None:
        # densities are selected by label, so without labels none can match
        plot_density = None
    if labels is None:
        labels = [None] * len(x)

    xlow = np.min([np.min(_x) for _x in x])
    xhigh = np.max([np.max(_x) for _x in x])
    ylow = np.min([np.min(_y) for _y in y])
    yhigh = np.max([np.max(_y) for _y in y])
    if rangex is None:
        rangex = [xlow, xhigh]
    if rangey is None:
        rangey = [ylow, yhigh]

    fig = None
    for num, (_x, _y) in enumerate(zip(x, y)):
        _label = labels[num]
        # Decide per analysis using a separate local. Overwriting
        # `plot_density` itself would turn it into a bool after the first
        # iteration, so later analyses could never be selected.
        if plot_density is None:
            _plot_density = False
        elif isinstance(plot_density, list):
            _plot_density = _label in plot_density
        else:
            _plot_density = plot_density == _label or plot_density == "both"

        _color = colors[num] if colors is not None else None
        _linestyle = linestyles[num] if linestyles is not None else None
        fig = twod_contour_plot(
            _x, _y, plot_density=_plot_density, label=_label, fig=fig,
            rangex=rangex, rangey=rangey, color=_color, linestyles=_linestyle,
            **kwargs
        )
    ax = fig.gca()
    ax.legend(**legend_kwargs)
    return fig

263 

264 

def _triangle_axes(
    figsize=(8, 8), width_ratios=[4, 1], height_ratios=[1, 4], wspace=0.0,
    hspace=0.0,
):
    """Initialize the axes for a 2d triangle plot

    Parameters
    ----------
    figsize: tuple, optional
        figure size you wish to use. Default (8, 8)
    width_ratios: list, optional
        ratio of widths for the triangular axis. Default 4:1
    height_ratios: list, optional
        ratio of heights for the triangular axis. Default 1:4
    wspace: float, optional
        horizontal space between the axis. Default 0.0
    hspace: float, optional
        vertical space between the axis. Default 0.0

    Returns
    -------
    fig: matplotlib.figure.Figure
        the newly created figure
    ax1, ax2, ax3, ax4: matplotlib axes
        the cells of a 2x2 GridSpec in row-major order (top-left, top-right,
        bottom-left, bottom-right). Minor ticks are enabled on ax1, ax3 and
        ax4, and the x tick labels of ax1 and the y tick labels of ax4 are
        removed
    """
    fig = figure(figsize=figsize, gca=False)
    gs = gridspec.GridSpec(
        2, 2, width_ratios=width_ratios, height_ratios=height_ratios,
        wspace=wspace, hspace=hspace
    )
    ax1, ax2, ax3, ax4 = [fig.add_subplot(gs[idx]) for idx in range(4)]
    for _ax in (ax1, ax3, ax4):
        _ax.minorticks_on()
    ax1.xaxis.set_ticklabels([])
    ax4.yaxis.set_ticklabels([])
    return fig, ax1, ax2, ax3, ax4

302 

303 

304def _generate_triangle_plot( 

305 *args, function=None, fig_kwargs={}, existing_figure=None, **kwargs 

306): 

307 """Generate a triangle plot according to a given function 

308 

309 Parameters 

310 ---------- 

311 *args: tuple 

312 all args passed to function 

313 function: func, optional 

314 function you wish to use to generate triangle plot. Default 

315 _triangle_plot 

316 **kwargs: dict, optional 

317 all kwargs passed to function 

318 """ 

319 if existing_figure is None: 

320 fig, ax1, ax2, ax3, ax4 = _triangle_axes(**fig_kwargs) 

321 ax2.axis("off") 

322 else: 

323 fig, ax1, ax3, ax4 = existing_figure 

324 if function is None: 

325 function = _triangle_plot 

326 return function(fig, [ax1, ax3, ax4], *args, **kwargs) 

327 

328 

def triangle_plot(*args, **kwargs):
    """Draw a triangle plot: a central 2d panel flanked by two 1d panels.

    The large central axis shows the 2d marginalized posterior. The 1d
    marginalized posterior for x is drawn above it, and the 1d marginalized
    posterior for y is drawn to its right.

    Parameters
    ----------
    x: list
        samples for the x axis; pass a list of sample arrays to overlay
        several analyses
    y: list
        samples for the y axis, structured like `x`
    kde: Bool/func, optional
        kde used to smooth the 1d marginalized posteriors. Pass kde=False to
        draw histograms instead. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde used to smooth the 2d marginalized posterior. Default None
    npoints: int, optional
        number of points at which the 1d kde is evaluated
    kde_kwargs: dict, optional
        kwargs forwarded to the 1d kde. Entries 'x_axis' / 'y_axis' may hold
        axis-specific kwargs
    kde_2d_kwargs: dict, optional
        kwargs forwarded to the 2d kde
    hist_kwargs: dict, optional
        kwargs for the 1d histograms when kde=False
    fill: Bool, optional
        whether to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha used for the fill
    levels: list, optional
        levels for the 2d contours
    smooth: dict/float, optional
        smoothing for the 2d contours. Supply a dict keyed by label to use a
        different amount for each analysis
    colors: list, optional
        one color per analysis
    xlabel: str, optional
        x axis label
    ylabel: str, optional
        y axis label
    fontsize: dict, optional
        fontsizes for the labels and the legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        one linestyle per analysis
    linewidths: list, optional
        one linewidth per analysis
    plot_density: Bool, optional
        whether to draw the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles to mark with dashed lines on the 1d panels. Default None
    percentile_plot: list, optional
        labels of the analyses whose percentiles are drawn. Default None,
        meaning no percentile lines are drawn
    fig_kwargs: dict, optional
        kwargs forwarded to _triangle_axes when a new figure is created
    existing_figure: tuple, optional
        (fig, ax1, ax3, ax4) to draw on instead of creating a new figure
    labels: list, optional
        one label per set of samples
    rangex: tuple, optional
        x range to plot
    rangey: tuple, optional
        y range to plot
    grid: Bool, optional
        whether to show a grid on every axis. Default False
    legend_kwargs: dict, optional
        kwargs for the legend. Default {"loc": "best", "frameon": False}
    truth: list, optional
        [x, y] true values to mark on the plot
    **kwargs: dict
        any other kwargs are forwarded to the 2d contour function
        (twod_contour_plot, and from there to corner.hist2d)
    """
    plotter = _triangle_plot
    return _generate_triangle_plot(*args, function=plotter, **kwargs)

398 

399 

def analytic_triangle_plot(*args, **kwargs):
    """Draw a triangle plot from analytic probability densities for x, y and xy.

    The figure and axes are created for you (unless `existing_figure` is
    given), so only the data needs to be supplied.

    Parameters
    ----------
    x: list
        grid points for the x axis
    y: list
        grid points for the y axis
    probs_x: list
        probabilities evaluated at `x`
    probs_y: list
        probabilities evaluated at `y`
    probs_xy: list
        2d grid of probabilities for xy
    smooth: float, optional
        smoothing applied to probs_xy. Default None (no smoothing)
    cmap: str, optional
        name of the colormap used for the 2d panel
    fig_kwargs: dict, optional
        kwargs forwarded to _triangle_axes when a new figure is created
    existing_figure: tuple, optional
        (fig, ax1, ax3, ax4) to draw on instead of creating a new figure
    """
    plotter = _analytic_triangle_plot
    return _generate_triangle_plot(*args, function=plotter, **kwargs)

427 

428 

def _analytic_triangle_plot(
    fig, axes, x, y, probs_x, probs_y, probs_xy, smooth=None, xlabel=None,
    ylabel=None, grid=True, **kwargs
):
    """Generate a triangle plot given probability densities for x, y and xy.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots [top 1d axis, 2d axis, right 1d axis]
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing applied
    xlabel: str, optional
        label to use for the x axis
    ylabel: str, optional
        label to use for the y axis
    grid: Bool, optional
        if True, add a grid to the plot
    **kwargs: dict, optional
        forwarded to analytic_twod_contour_plot. If
        kwargs["level_kwargs"]["colors"] is given, the 1d curves use that
        color (or its first entry when a sequence is given).
        kwargs["fontsize"]["label"] sets the axis label size (default 12)

    Returns
    -------
    fig, ax1, ax3, ax4: the figure and the three axes used
    """
    ax1, ax3, ax4 = axes
    analytic_twod_contour_plot(
        x, y, probs_xy, ax=ax3, smooth=smooth, grid=grid, **kwargs
    )
    color = None
    level_kwargs = kwargs.get("level_kwargs", None)
    if level_kwargs is not None and "colors" in level_kwargs:
        _colors = level_kwargs["colors"]
        # contour accepts either a single color string or a sequence of
        # colors. Indexing a string would yield its first character
        # (e.g. "black" -> "b", which is blue), so use strings as-is.
        color = _colors if isinstance(_colors, str) else _colors[0]
    ax1.plot(x, probs_x, color=color)
    ax4.plot(probs_y, y, color=color)
    fontsize = kwargs.get("fontsize", {"label": 12})
    if xlabel is not None:
        ax3.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax3.set_ylabel(ylabel, fontsize=fontsize["label"])
    ax1.grid(grid)
    if grid:
        ax3.grid(grid, zorder=10)
    ax4.grid(grid)
    # keep the 1d panels aligned with the 2d panel
    ax1.set_xlim(ax3.get_xlim())
    ax4.set_ylim(ax3.get_ylim())
    fig.tight_layout()
    return fig, ax1, ax3, ax4

486 

487 

def _triangle_plot(
    fig, axes, x, y, kde=gaussian_kde, npoints=100, kde_kwargs={}, fill=True,
    fill_alpha=0.5, levels=[0.9], smooth=7, colors=list(conf.colorcycle),
    xlabel=None, ylabel=None, fontsize={"legend": 12, "label": 12},
    linestyles=None, linewidths=None, plot_density=True, percentiles=None,
    percentile_plot=None, fig_kwargs={}, labels=None, plot_datapoints=False,
    rangex=None, rangey=None, grid=False, latex_friendly=False, kde_2d=None,
    kde_2d_kwargs={}, legend_kwargs={"loc": "best", "frameon": False},
    truth=None, hist_kwargs={"density": True, "bins": 50},
    _contour_function=twod_contour_plot, weights=None, **kwargs
):
    """Base function to generate a triangular plot

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots [top 1d axis, 2d axis, right 1d axis]
    x: list
        list of samples for the x axis. A list of arrays overlays several
        analyses
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde to use for smoothing the 1d marginalized posterior distribution. If
        you do not want to use KDEs, simply pass kde=False and histograms are
        drawn instead. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde to use for smoothing the 2d marginalized posterior distribution.
        default None
    npoints: int, optional
        number of points to use for the 1d kde
    kde_kwargs: dict, optional
        optional kwargs which are passed directly to the kde function.
        kde_kwargs to be passed to the kde on the y axis may be specified
        by the dictionary entry 'y_axis'. kde_kwargs to be passed to the kde on
        the x axis may be specified by the dictionary entry 'x_axis'. An axis
        without its own entry uses the remaining (shared) kwargs; the
        'x_axis'/'y_axis' entries themselves are never passed to the kde
    kde_2d_kwargs: dict, optional
        optional kwargs which are passed directly to the 2d kde function
    hist_kwargs: dict, optional
        kwargs passed to the 1d histograms when kde=False
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha to use for fill
    levels: list, optional
        levels you wish to use for the 2d contours
    smooth: dict/float, optional
        how much smoothing you wish to use for the 2d contours. If you wish
        to use different smoothing for different contours, then provide a dict
        with keys given by the label
    colors: list, optional
        list of colors you wish to use for each analysis
    xlabel: str, optional
        xlabel you wish to use for the plot
    ylabel: str, optional
        ylabel you wish to use for the plot
    fontsize: dict, optional
        dictionary giving the fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles you wish to use for each analysis
    linewidths: list, optional
        linewidths you wish to use for each analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles to mark with dashed lines on the 1d axes. They are
        computed from the unweighted samples. Default None
    percentile_plot: list, optional
        labels of the analyses whose percentiles are drawn. Default None,
        meaning no percentile lines are drawn
    fig_kwargs: dict, optional
        accepted for call compatibility; not used by this function
    labels: list, optional
        label associated with each set of samples
    plot_datapoints: Bool, optional
        forwarded to the contour function
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    grid: Bool, optional
        if True, show a grid on all axes
    latex_friendly: Bool, optional
        if True, escape underscores in the labels shown in the legend
    legend_kwargs: dict, optional
        optional kwargs for the legend. Default {"loc": "best", "frameon": False}
    truth: list, optional
        [x, y] true values; drawn as thin black lines on the 1d axes and
        forwarded to the contour function
    weights: np.ndarray, optional
        sample weights forwarded to the 1d kde/histograms and to the contour
        function. The same weights are used for every analysis
    **kwargs: dict
        all kwargs are passed to `_contour_function` (by default
        twod_contour_plot, and from there to the corner.hist2d function)

    Returns
    -------
    fig, ax1, ax3, ax4: the figure and the three axes used

    Raises
    ------
    ValueError
        if fewer colors/linestyles/linewidths than analyses are supplied, or
        the number of labels does not match the number of analyses
    """
    ax1, ax3, ax4 = axes
    if not isinstance(x[0], (list, np.ndarray)):
        x, y = np.atleast_2d(x), np.atleast_2d(y)
    _base_error = "Please provide {} for each analysis"
    if len(colors) < len(x):
        raise ValueError(_base_error.format("a single color"))
    if linestyles is None:
        linestyles = ["-"] * len(x)
    elif len(linestyles) < len(x):
        raise ValueError(_base_error.format("a single linestyle"))
    if linewidths is None:
        linewidths = [None] * len(x)
    elif len(linewidths) < len(x):
        raise ValueError(_base_error.format("a single linewidth"))
    if labels is None:
        labels = [None] * len(x)
    elif len(labels) != len(x):
        raise ValueError(_base_error.format("a label"))

    # Labels shown in the legend, built once as a new list so the caller's
    # `labels` is never modified. `None` labels are left untouched.
    if latex_friendly:
        display_labels = [
            _label.replace("_", "\\_") if _label is not None else None
            for _label in labels
        ]
    else:
        display_labels = list(labels)

    # Axis-specific kde kwargs take precedence; the selector keys themselves
    # must never be forwarded to the kde
    _shared_kde_kwargs = {
        key: value for key, value in kde_kwargs.items()
        if key not in ("x_axis", "y_axis")
    }
    x_kde_kwargs = kde_kwargs.get("x_axis", _shared_kde_kwargs)
    y_kde_kwargs = kde_kwargs.get("y_axis", _shared_kde_kwargs)

    xlow = np.min([np.min(_x) for _x in x])
    xhigh = np.max([np.max(_x) for _x in x])
    ylow = np.min([np.min(_y) for _y in y])
    yhigh = np.max([np.max(_y) for _y in y])
    if rangex is not None:
        xlow, xhigh = rangex
    if rangey is not None:
        ylow, yhigh = rangey

    for num, x_samples in enumerate(x):
        y_samples = y[num]
        label, display_label = labels[num], display_labels[num]
        plot_kwargs = dict(
            color=colors[num], linewidth=linewidths[num],
            linestyle=linestyles[num]
        )
        if kde:
            grid_x = np.linspace(xlow, xhigh, npoints)
            density_x = kde(x_samples, weights=weights, **x_kde_kwargs)(grid_x)
            ax1.plot(grid_x, density_x, **plot_kwargs)
            if fill:
                ax1.fill_between(
                    grid_x, 0, density_x, alpha=fill_alpha, **plot_kwargs
                )
            grid_y = np.linspace(ylow, yhigh, npoints)
            density_y = kde(y_samples, weights=weights, **y_kde_kwargs)(grid_y)
            # the legend is built from the handles on ax4
            ax4.plot(density_y, grid_y, label=display_label, **plot_kwargs)
            if fill:
                ax4.fill_betweenx(
                    grid_y, 0, density_y, alpha=fill_alpha, **plot_kwargs
                )
        else:
            histtype = "stepfilled" if fill else "step"
            ax1.hist(
                x_samples, histtype=histtype, weights=weights, **hist_kwargs,
                **plot_kwargs
            )
            # label the y histogram too, otherwise the legend (built from the
            # handles on ax4) is empty in histogram mode
            ax4.hist(
                y_samples, histtype=histtype, weights=weights,
                orientation="horizontal", label=display_label,
                **hist_kwargs, **plot_kwargs
            )
        if percentiles is not None and percentile_plot is not None and (
            label in percentile_plot or display_label in percentile_plot
        ):
            # fall back to 1.75 when no linewidth was given for this analysis
            _linewidth = plot_kwargs["linewidth"]
            if _linewidth is None:
                _linewidth = 1.75
            for value in np.atleast_1d(np.percentile(x_samples, percentiles)):
                ax1.axvline(value, linestyle="--", linewidth=_linewidth)
            for value in np.atleast_1d(np.percentile(y_samples, percentiles)):
                ax4.axhline(value, linestyle="--", linewidth=_linewidth)
        if isinstance(smooth, dict):
            # look up by the raw label, falling back to the displayed one
            _smooth = smooth[label] if label in smooth else smooth[display_label]
        else:
            _smooth = smooth
        _contour_function(
            x_samples, y_samples, ax=ax3, levels=levels, smooth=_smooth,
            rangex=[xlow, xhigh], rangey=[ylow, yhigh], color=colors[num],
            linestyles=linestyles[num], weights=weights,
            plot_density=plot_density, contour_kwargs=dict(
                linestyles=[linestyles[num]], linewidths=linewidths[num]
            ), plot_datapoints=plot_datapoints, kde=kde_2d,
            kde_kwargs=kde_2d_kwargs, grid=False, truth=truth, **kwargs
        )

    if truth is not None:
        ax1.axvline(truth[0], color='k', linewidth=0.5)
        ax4.axhline(truth[1], color='k', linewidth=0.5)
    if xlabel is not None:
        ax3.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax3.set_ylabel(ylabel, fontsize=fontsize["label"])
    if not all(_label is None for _label in labels):
        # copy: writing into `legend_kwargs` would mutate the shared default
        # (or the caller's dict) and leak into later calls
        _legend_kwargs = dict(legend_kwargs)
        _legend_kwargs["fontsize"] = fontsize["legend"]
        ax3.legend(*ax4.get_legend_handles_labels(), **_legend_kwargs)
    ax1.grid(grid)
    ax3.grid(grid)
    ax4.grid(grid)
    # align the 2d panel with the 1d panels
    ax3.set_xlim(ax1.get_xlim())
    ax3.set_ylim(ax4.get_ylim())
    return fig, ax1, ax3, ax4

690 

691 

692def _generate_reverse_triangle_plot( 

693 *args, xlabel=None, ylabel=None, function=None, existing_figure=None, **kwargs 

694): 

695 """Generate a reverse triangle plot according to a given function 

696 

697 Parameters 

698 ---------- 

699 *args: tuple 

700 all args passed to function 

701 xlabel: str, optional 

702 label to use for the x axis 

703 ylabel: str, optional 

704 label to use for the y axis 

705 function: func, optional 

706 function to use to generate triangle plot. Default _triangle_plot 

707 **kwargs: dict, optional 

708 all kwargs passed to function 

709 """ 

710 if existing_figure is None: 

711 fig, ax1, ax2, ax3, ax4 = _triangle_axes( 

712 width_ratios=[1, 4], height_ratios=[4, 1] 

713 ) 

714 ax3.axis("off") 

715 else: 

716 fig, ax1, ax2, ax4 = existing_figure 

717 if function is None: 

718 function = _triangle_plot 

719 fig, ax4, ax2, ax1 = function(fig, [ax4, ax2, ax1], *args, **kwargs) 

720 ax2.axis("off") 

721 ax4.spines["right"].set_visible(False) 

722 ax4.spines["top"].set_visible(False) 

723 ax4.spines["left"].set_visible(False) 

724 ax4.set_yticks([]) 

725 

726 ax1.spines["right"].set_visible(False) 

727 ax1.spines["top"].set_visible(False) 

728 ax1.spines["bottom"].set_visible(False) 

729 ax1.set_xticks([]) 

730 

731 _fontsize = kwargs.get("fontsize", {"label": 12})["label"] 

732 if xlabel is not None: 

733 ax4.set_xlabel(xlabel, fontsize=_fontsize) 

734 if ylabel is not None: 

735 ax1.set_ylabel(ylabel, fontsize=_fontsize) 

736 return fig, ax1, ax2, ax4 

737 

738 

def reverse_triangle_plot(*args, **kwargs):
    """Draw a triangle plot whose 1d panels sit below and to the left.

    The 2d marginalized posterior occupies the large top-right cell (with its
    own frame hidden). The 1d marginalized posterior for x is drawn beneath
    it, and the 1d marginalized posterior for y to its left. The x and y axis
    labels are placed on these two 1d panels.

    Parameters
    ----------
    x: list
        samples for the x axis; pass a list of sample arrays to overlay
        several analyses
    y: list
        samples for the y axis, structured like `x`
    kde: Bool/func, optional
        kde used to smooth the 1d marginalized posteriors. Pass kde=False to
        draw histograms instead. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde used to smooth the 2d marginalized posterior. Default None
    npoints: int, optional
        number of points at which the 1d kde is evaluated
    kde_kwargs: dict, optional
        kwargs forwarded to the 1d kde. Entries 'x_axis' / 'y_axis' may hold
        axis-specific kwargs
    kde_2d_kwargs: dict, optional
        kwargs forwarded to the 2d kde
    fill: Bool, optional
        whether to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha used for the fill
    levels: list, optional
        levels for the 2d contours
    smooth: dict/float, optional
        smoothing for the 2d contours. Supply a dict keyed by label to use a
        different amount for each analysis
    colors: list, optional
        one color per analysis
    xlabel: str, optional
        label placed on the bottom 1d panel
    ylabel: str, optional
        label placed on the left 1d panel
    fontsize: dict, optional
        fontsizes for the labels and the legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        one linestyle per analysis
    linewidths: list, optional
        one linewidth per analysis
    plot_density: Bool, optional
        whether to draw the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles to mark with dashed lines on the 1d panels. Default None
    percentile_plot: list, optional
        labels of the analyses whose percentiles are drawn. Default None,
        meaning no percentile lines are drawn
    fig_kwargs: dict, optional
        accepted but ignored: this layout always uses fixed width/height
        ratios for its axes
    existing_figure: tuple, optional
        (fig, ax1, ax2, ax4) to draw on instead of creating a new figure
    labels: list, optional
        one label per set of samples
    rangex: tuple, optional
        x range to plot
    rangey: tuple, optional
        y range to plot
    grid: Bool, optional
        whether to show a grid on every axis. Default False
    legend_kwargs: dict, optional
        kwargs for the legend. Default {"loc": "best", "frameon": False}
    **kwargs: dict
        any other kwargs are forwarded to the 2d contour function
        (twod_contour_plot, and from there to corner.hist2d)
    """
    plotter = _triangle_plot
    return _generate_reverse_triangle_plot(*args, function=plotter, **kwargs)

812 

813 

def analytic_reverse_triangle_plot(*args, **kwargs):
    """Draw a reverse triangle plot from analytic densities for x, y and xy.

    Same layout as `reverse_triangle_plot` (1d panels below and to the left
    of the 2d panel), but driven by pre-computed probabilities rather than
    samples. The figure and axes are created for you unless
    `existing_figure` is given.

    Parameters
    ----------
    x: list
        grid points for the x axis
    y: list
        grid points for the y axis
    probs_x: list
        probabilities evaluated at `x`
    probs_y: list
        probabilities evaluated at `y`
    probs_xy: list
        2d grid of probabilities for xy
    smooth: float, optional
        smoothing applied to probs_xy. Default None (no smoothing)
    xlabel: str, optional
        label placed on the bottom 1d panel
    ylabel: str, optional
        label placed on the left 1d panel
    cmap: str, optional
        name of the colormap used for the 2d panel
    existing_figure: tuple, optional
        (fig, ax1, ax2, ax4) to draw on instead of creating a new figure
    """
    plotter = _analytic_triangle_plot
    return _generate_reverse_triangle_plot(*args, function=plotter, **kwargs)