Coverage for pesummary/core/plots/plot.py: 65.3%

372 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import ( 

4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin, 

5) 

6from pesummary.core.plots.corner import corner 

7from pesummary.core.plots.figure import figure, ExistingFigure 

8from pesummary import conf 

9 

10import matplotlib.lines as mlines 

11import copy 

12from itertools import cycle 

13 

14import numpy as np 

15from scipy import signal 

16 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Import-time check provided by pesummary.utils.utils (presumably verifies
# that a LaTeX installation is usable for matplotlib text rendering).
_check_latex_install()

# Default keyword arguments passed to ``ax.legend`` by the 1d comparison
# plots. The legend is anchored just above the top of the axes (y=1.02) and
# stretched across the full axes width (mode="expand").
_default_legend_kwargs = {
    "bbox_to_anchor": (0.0, 1.02, 1.0, 0.102),
    "loc": 3,
    "handlelength": 3,
    "mode": "expand",
    "borderaxespad": 0.0,
}

24 

25 

def _update_1d_comparison_legend(legend, linestyles, linewidth=1.75):
    """Restyle the line handles of a 1d comparison plot legend.

    Each legend handle is paired with the corresponding entry of
    ``linestyles`` (extra handles or styles are ignored). Every paired handle
    gets the common ``linewidth`` and its own line style.

    Parameters
    ----------
    legend: matplotlib.legend.Legend
        legend whose handles should be updated
    linestyles: list
        linestyle to apply to each handle, in order
    linewidth: float, optional
        linewidth applied to every handle. Default 1.75
    """
    # matplotlib >= 3.7.0 exposes ``legend_handles``; older releases only
    # provide the camel-case ``legendHandles`` attribute
    if hasattr(legend, "legend_handles"):
        legend_lines = legend.legend_handles
    else:
        legend_lines = legend.legendHandles
    for line_handle, line_style in zip(legend_lines, linestyles):
        line_handle.set_linewidth(linewidth)
        line_handle.set_linestyle(line_style)

36 

37 

def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Generate the autocorrelation function for a set of samples for a given
    parameter for a given approximant.

    Only the second half of ``samples`` is used (presumably to discard
    burn-in). The autocorrelation function (ACF) is normalised so that its
    zero-lag value is 1.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers used to draw the ACF. Default 0.5
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the autocorrelation plot
    """
    import warnings

    logger.debug("Generating the autocorrelation function for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    # Scope the RuntimeWarning suppression to this function. The previous
    # module-level ``warnings.filterwarnings`` call silenced RuntimeWarnings
    # for the whole process after the first call. A constant chain has zero
    # variance, so ``acf / acf[0]`` below is 0 / 0.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        samples = samples[len(samples) // 2:]
        centred = samples - np.mean(samples)
        reversed_conj = np.conj(centred[::-1])
        # Convolving with the reversed conjugate gives the full (2N - 1) point
        # autocovariance. Its zero-lag term sits at index N - 1, and
        # ifftshift moves it to index 0.
        acf = np.fft.ifftshift(
            signal.fftconvolve(reversed_conj, centred, mode="full")
        )
        N = np.array(samples).shape[0]
        acf = acf[0:N]
        # Hack to make test pass with python3.8
        if color == "$":
            color = conf.color
        ax.plot(
            acf / acf[0], linestyle=" ", marker="o", markersize=markersize,
            color=color
        )
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("lag")
    ax.set_ylabel("ACF")
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

83 

84 

def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Overlay the autocorrelation function of ``param`` for every mcmc chain
    on a single set of axes.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains. Colors
        are reused cyclically if there are more chains than colors
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one autocorrelation curve per chain
    """
    chain_colors = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        fig = _autocorrelation_plot(
            param, chain, fig=fig, markersize=1.25,
            color=next(chain_colors), grid=grid
        )
    return fig

109 

110 

def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter each sample of ``param`` against its index to show how the
    samples evolve.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected. Accepted for interface compatibility but not
        currently drawn on the plot
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    color: str, optional
        color you wish to use to plot the scatter points. Ignored when ``z``
        is provided
    markersize: float, optional
        marker size passed to ``ax.scatter``. Default 0.5
    grid: Bool, optional
        if True, plot a grid
    z: array-like, optional
        values used to color each point. If provided, a colorbar is added
    z_label: str, optional
        label for the colorbar (only used when ``z`` is provided)
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)
    show_colorbar = z is not None
    point_colors = z if show_colorbar else color
    scatter = ax.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=point_colors, **kwargs
    )
    if show_colorbar:
        colorbar = fig.colorbar(scatter)
        if z_label is not None:
            colorbar.set_label(z_label)
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("samples")
    ax.set_ylabel(latex_label)
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

158 

159 

def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Overlay the sample evolution of ``param`` for every mcmc chain on a
    single set of axes.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected. Not forwarded: each chain is plotted with
        ``inj_value=None``
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one scatter series per chain
    """
    chain_colors = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        fig = _sample_evolution_plot(
            param, chain, latex_label, inj_value=None, fig=fig,
            markersize=1.25, color=next(chain_colors), grid=grid
        )
    return fig

190 

191 

def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", weights=None, **kwargs
):
    """Generate the cumulative distribution function for a given parameter for
    a given approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title to the 1d cdf plot showing the median and
        symmetric 90% credible interval. When ``weights`` are provided these
        statistics are weighted
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    weights: list, optional
        list of weights for samples. When provided, the CDF is estimated from
        a 50-bin weighted histogram. Default None
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the CDF
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d CDF for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    if weights is None:
        # empirical CDF: the sorted samples are assigned evenly spaced
        # cumulative probabilities from 0 to 1
        sorted_samples = np.sort(samples)
        ax.plot(
            sorted_samples, np.linspace(0, 1, len(sorted_samples)),
            color=color, linestyle=linestyle, **kwargs
        )
    else:
        hist, bin_edges = np.histogram(
            samples, bins=50, density=True, weights=weights
        )
        total = np.cumsum(hist)
        total /= total[-1]
        # The cumulative sum up to bin i is the probability mass below that
        # bin's *right* edge. Evaluate the CDF at the bin edges, starting at
        # 0 on the leftmost edge, rather than at the bin centres.
        ax.plot(
            bin_edges, np.concatenate([[0.0], total]), color=color,
            linestyle=linestyle, **kwargs
        )
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    if title:
        if weights is None:
            lower_percentile, upper_percentile = np.percentile(
                samples, [5, 95]
            )
            median = np.median(samples)
        else:
            # use weighted statistics so the title agrees with the
            # weighted CDF drawn above
            weighted = Array(samples, weights=weights)
            percentile = weighted.credible_interval([5, 95])
            lower_percentile, upper_percentile = percentile[0], percentile[1]
            median = weighted.average("median")
        upper = np.round(upper_percentile - median, 2)
        lower = np.round(median - lower_percentile, 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, upper, lower))
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

258 

259 

def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the CDF of ``param`` for every mcmc chain on a single set of
    axes. The Gelman-Rubin statistic of the chains is shown as the title.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one CDF per chain
    """
    chain_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, color=next(chain_colors),
            title=False, grid=grid, **kwargs
        )
    convergence = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(convergence))
    return fig

291 

292 

def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, weights=None,
    **kwargs
):
    """Generate a plot to compare the cdfs for a given parameter for different
    analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        label for each analysis, shown in the legend
    linestyles: list, optional
        linestyle for each analysis. Default "-" for all analyses
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, escape underscores in the labels so they render with LaTeX.
        Default False
    weights: list, optional
        list of weights to use for each analysis. Default None

    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one CDF per analysis

    Raises
    ------
    ValueError
        if the number of weight sets does not match the number of analyses
    """
    logger.debug("Generating the 1d comparison CDF for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if weights is None:
        weights = [None] * len(samples)
    if len(weights) != len(samples):
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if latex_friendly:
        # Escape once, into a new list, so the caller's labels are untouched.
        # This also works for tuples. "\\_" is the literal two-character
        # string backslash-underscore.
        labels = [_label.replace("_", "\\_") for _label in labels]
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, analysis_samples in enumerate(samples):
        fig = _1d_cdf_plot(
            param, analysis_samples, latex_label, fig=fig, color=colors[num],
            title=False, grid=grid, linestyle=linestyles[num],
            weights=weights[num], **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

355 

356 

def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Generate a plot to display a PDF that has been evaluated on a grid

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    x: array-like
        points at which the PDF has been evaluated
    pdf: array-like
        value of the PDF at each point in ``x``
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list, optional
        accepted for interface compatibility; not currently plotted
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        interval
    color: str, optional
        color you wish to use to plot the PDF
    autoscale: Bool, optional
        if True, restore the x limits matplotlib chose after drawing the PDF
        (this overrides ``xlims``)
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth used for the PDF and for the credible interval lines
    injection_color: str, optional
        color of vertical line showing the injected value
    **plot_kwargs: dict, optional
        accepted for interface compatibility; currently unused

    Returns
    -------
    fig: matplotlib.pyplot.figure
        the figure, or the axis when ``ax`` was provided without ``fig``
    """
    from pesummary.utils.array import Array

    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # store the grid as the array values and the PDF as its weights so the
    # Array statistics (credible interval, median) can be reused
    pdf = Array(x, weights=pdf)

    # Apply ``linewidth`` to the PDF itself, as documented. Previously it
    # only reached the percentile lines.
    ax.plot(
        pdf, pdf.weights, color=color, linestyle=linestyle,
        linewidth=linewidth, label=label
    )
    _xlims = ax.get_xlim()
    percentile = pdf.credible_interval([5, 95])
    median = pdf.average("median")
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.round(median - percentile[0], 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, upper, lower))
    if plot_percentile:
        for pp in percentile:
            ax.axvline(
                pp, color=color, linestyle="--", linewidth=linewidth
            )
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

447 

448 

def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"fill": True, "alpha": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde
    hist: Bool, optional
        if True, plot a histogram
    prior: list
        list of prior samples for param
    weights: list
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        interval (and the 90% HPD interval when it is given in ``key_data``)
    color: str, optional
        color you wish to use to plot the histogram/kde
    autoscale: Bool, optional
        if True, restore the x limits matplotlib chose after drawing the
        posterior (this overrides ``xlims``)
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class. A ``kde_kernel`` entry is
        extracted and passed separately
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist. These override
        ``_default_hist_kwargs``
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    plot_hdp: Bool, optional
        if True and ``key_data`` provides a "90% HPD" interval, plot dotted
        vertical lines at its bounds
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples is <= max_vline, draw each unique
        value as a vertical line rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed summary statistics. Keys read: "5th percentile",
        "95th percentile", "median" and, optionally, "90% HPD". If None,
        the statistics are computed from ``samples``
    **plot_kwargs: dict, optional
        all additional kwargs passed to ax.hist and kdeplot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        the figure, or the axis when ``ax`` was provided without ``fig``
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s" % (param))
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # Fallback so ``autoscale`` cannot raise NameError when neither a
    # histogram nor a kde is drawn (hist=False, kde=False).
    _xlims = ax.get_xlim()
    unique_samples = set(samples)
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only label the first line so the legend has a single entry
            _label = label if _ind == 0 else None
            ax.axvline(_sample, color=color, label=_label)
        _xlims = ax.get_xlim()
    else:
        if hist:
            # Merge into a new dict. Updating the mutable default in place
            # would leak these settings into every later call.
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = _hist_kwargs.copy()
                _prior_hist_kwargs["histtype"] = "bar"
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            from pesummary.core.plots.seaborn.kde import kdeplot
            _kde_kwargs = kde_kwargs.copy()
            # copy the default rather than aliasing it: the original
            # ``kwargs = _default_kde_kwargs`` permanently stored weights and
            # plot_kwargs from previous calls in the shared default dict
            kwargs = dict(_default_kde_kwargs)
            kwargs.update({
                "kde_kwargs": _kde_kwargs,
                "kde_kernel": _kde_kwargs.pop("kde_kernel", None),
                "weights": weights
            })
            kwargs.update(plot_kwargs)
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **kwargs
                )

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data.keys():
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    if plot_percentile:
        for pp in percentile:
            ax.axvline(
                pp, color=color, linestyle="--",
                linewidth=hist_kwargs.get("linewidth", 1.75)
            )
    if plot_hdp and isinstance(hdp, (list, np.ndarray)):
        for pp in hdp:
            ax.axvline(
                pp, color=color, linestyle=":",
                linewidth=hist_kwargs.get("linewidth", 1.75)
            )
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if not isinstance(hdp, (list, np.ndarray)) and np.isnan(hdp):
            _base += r"$"
            ax.set_title(_base)
        else:
            # append the HPD interval next to the credible interval
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
            ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

618 

619 

def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of ``param`` for every mcmc chain on a single
    set of axes. The Gelman-Rubin statistic of the chains is shown as the
    title.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one histogram per chain
    """
    chain_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        fig = _1d_histogram_plot(
            param, chain, latex_label, color=next(chain_colors), title=False,
            autoscale=False, fig=fig, **kwargs
        )
    convergence = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(convergence))
    return fig

649 

650 

def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Overlay the 1d distributions of many random subsets of ``samples`` to
    show how stable the estimate is.

    Each of the ``ntests`` subsets contains ``nsamples`` samples drawn
    without replacement. If ``nsamples`` exceeds the number of available
    samples, half of the available samples are drawn instead.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different tests
    nsamples: int, optional
        number of samples to randomly draw from samples. Default 1000
    ntests: int, optional
        number of tests to perform. Default 100
    shade, plot_percentile, kde, hist: optional
        forwarded to _1d_histogram_plot
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one curve per test
    """
    if nsamples > len(samples):
        nsamples = len(samples) // 2
    subsets = []
    for _ in range(ntests):
        subsets.append(
            np.random.choice(samples, size=nsamples, replace=False)
        )
    test_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for subset in subsets:
        fig = _1d_histogram_plot(
            param, subset, latex_label, color=next(test_colors), title=False,
            autoscale=False, fig=fig, shade=shade,
            plot_percentile=plot_percentile, kde=kde, hist=hist, **kwargs
        )
    ax.set_title("Ntests: {}, Nsamples per test: {}".format(ntests, nsamples))
    fig.tight_layout()
    return fig

692 

693 

def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k",
    weights=None, **kwargs
):
    """Generate a plot to compare the 1d histogram plots for a given
    parameter for different analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        label for each analysis, shown in the legend
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool
        if True, plot a kde
    hist: Bool
        if True, plot a histogram
    linestyles: list
        list of linestyles for each set of samples
    kde_kwargs: dict, optional
        kwargs passed to _1d_histogram_plot as ``kde_kwargs``
    hist_kwargs: dict, optional
        kwargs passed to _1d_histogram_plot as ``hist_kwargs``. The linewidth
        is always overridden to 2.5; the caller's dict is not modified
    max_vline: int, optional
        passed to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure. Default (8, 6)
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, escape underscores in the labels so they render with LaTeX.
        Default False
    max_inj_line: int, optional
        if more unique injection values than this are provided, no injection
        lines are drawn. Default 1
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one histogram/kde per analysis

    Raises
    ------
    ValueError
        if the number of weights, injection values or injection colors does
        not match the number of analyses
    """
    logger.debug("Generating the 1d comparison histogram plot for %s" % (param))
    nanalyses = len(samples)
    if linestyles is None:
        linestyles = ["-"] * nanalyses
    if weights is None:
        weights = [None] * nanalyses
    if len(weights) != nanalyses:
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if inj_value is None:
        inj_value = [None] * nanalyses
    elif isinstance(inj_value, (list, np.ndarray)):
        if len(inj_value) != nanalyses:
            raise ValueError(
                "Please provide an injection for each analysis or a single "
                "injection value which will be used for all histograms"
            )
    else:
        inj_value = [inj_value] * nanalyses

    if isinstance(injection_color, str):
        injection_color = [injection_color] * nanalyses
    elif len(injection_color) != nanalyses:
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array([_ for _ in inj_value if _ is not None]).flatten()
    if len(set(flat_injection)) > max_inj_line:
        logger.warning(
            "Number of unique injection values ({}) is more than the maximum "
            "allowed injection value ({}). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`".format(
                len(set(flat_injection)), max_inj_line
            )
        )
        inj_value = [None] * nanalyses

    if latex_friendly:
        # escape once, into a new list, so the caller's labels are untouched
        labels = [_label.replace("_", "\\_") for _label in labels]

    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    # build a new dict: ``hist_kwargs.update`` mutated both the caller's dict
    # and the shared mutable default
    _hist_kwargs = {**hist_kwargs, "linewidth": 2.5}
    for num, analysis_samples in enumerate(samples):
        fig = _1d_histogram_plot(
            param, analysis_samples, latex_label, kde=kde, hist=hist,
            kde_kwargs=kde_kwargs, max_vline=max_vline, grid=grid,
            title=False, autoscale=False, label=labels[num],
            color=colors[num], fig=fig, hist_kwargs=_hist_kwargs,
            inj_value=inj_value[num], injection_color=injection_color[num],
            weights=weights[num], linestyle=linestyles[num],
            _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

805 

806 

def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Generate a horizontal box plot to compare the 1d distributions of a
    given parameter for different analyses

    Each box is annotated with its label, centred at the midpoint of the
    combined sample range, and the y ticks are removed.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors for each analysis (currently unused)
    latex_label: str
        latex label for param
    labels: list
        label for each analysis, written above its box
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the box plot
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, ax = figure(gca=True)
    maximum = np.max([np.max(i) for i in samples])
    minimum = np.min([np.min(i) for i in samples])
    middle = (maximum + minimum) * 0.5
    # Do not pass ``labels=`` to boxplot: it was renamed ``tick_labels`` in
    # matplotlib 3.9, and the tick labels are cleared below anyway. The
    # labels are drawn as annotations instead. whis=np.inf makes the whiskers
    # span the full data range.
    ax.boxplot(samples, widths=0.2, vert=False, whis=np.inf)
    for num, i in enumerate(labels):
        ax.annotate(i, xy=(middle, 1), xytext=(middle, num + 1.0 + 0.2), ha="center")
    ax.set_yticks([])
    ax.set_xlabel(latex_label)
    fig.tight_layout()
    ax.grid(visible=grid)
    return fig

840 

841 

def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plot for a single set of samples

    Parameters
    ----------
    samples: dict
        mapping of parameter name to the samples for that parameter
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters to include in the corner plot. Any that are not in
        ``parameters`` are dropped with a warning. Default: all ``parameters``
    parameters: list, optional
        parameters available in ``samples``. Default ``list(samples.keys())``
    **kwargs: dict
        passed to ``corner``, overriding ``conf.corner_kwargs``

    Returns
    -------
    figure: ExistingFigure
        the corner plot
    included_parameters: list
        the parameters that were plotted
    data: dict
        "width"/"height" of a subplot, "seperation" between the first two
        subplots (None if there is only one) and "x0"/"y0" of the first
        subplot, all in pixels (inches * dpi)

    Raises
    ------
    ValueError
        if none of ``corner_parameters`` are available
    """
    logger.debug("Generating the corner plot")
    # set the default kwargs
    default_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is not None:
        included_parameters = [i for i in corner_parameters if i in parameters]
        excluded_parameters = [i for i in corner_parameters if i not in parameters]
        if len(excluded_parameters):
            plural = len(excluded_parameters) > 1
            logger.warning(
                f"Removing the parameter{'s' if plural else ''}: "
                f"{', '.join(excluded_parameters)} from the corner plot as "
                f"{'they are' if plural else 'it is'} not available in the "
                f"posterior table. This may affect the truth lines if "
                f"provided."
            )
        if not len(included_parameters):
            raise ValueError(
                "None of the chosen parameters are in the posterior "
                "samples table. Please choose other parameters to plot"
            )
    else:
        included_parameters = parameters
    # one row per parameter; transposed below to (nsamples, nparameters)
    xs = np.zeros([len(included_parameters), len(samples[parameters[0]])])
    for num, i in enumerate(included_parameters):
        xs[num] = samples[i]
    default_kwargs.update(kwargs)
    default_kwargs["range"] = [1.0] * len(included_parameters)
    default_kwargs["labels"] = [latex_labels[i] for i in included_parameters]

    _figure = ExistingFigure(corner(xs.T, included_parameters, **default_kwargs))
    # grab the axes of the subplots
    axes = _figure.get_axes()
    axes_of_interest = axes[:2]
    location = []
    for i in axes_of_interest:
        # window extent is in display units; convert to inches, then back to
        # pixels with the figure dpi
        extent = i.get_window_extent().transformed(_figure.dpi_scale_trans.inverted())
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
        width, height = extent.width, extent.height
    width *= _figure.dpi
    height *= _figure.dpi
    try:
        seperation = float(abs(location[0][0] - location[1][0]) - width)
    except IndexError:
        seperation = None
    # explicitly cast to float to ensure correct rendering in JS
    # https://git.ligo.org/lscsoft/pesummary/-/issues/332
    data = {
        "width": float(width),
        "height": float(height),
        "seperation": seperation,
        "x0": float(location[0][0]),
        # y0 is the second coordinate; location[0][0] (x0) was used before
        "y0": float(location[0][1]),
    }
    return _figure, included_parameters, data

916 

917 

def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    latex_friendly=True, **kwargs
):
    """Generate a corner plot which contains multiple datasets

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default: the
        parameters common to every dataset
    colors: list, optional
        unique colors for each dataset
    latex_friendly: Bool, optional
        if True, escape underscores in the labels so they render with LaTeX.
        Default True
    **kwargs: dict
        all kwargs are passed to `corner.corner`. The caller's
        ``hist_kwargs`` dict, if given, is not modified

    Returns
    -------
    fig: ExistingFigure
        corner plot with every dataset overlaid and a legend

    Raises
    ------
    ValueError
        if no datasets are provided or there are fewer colors than datasets
    """
    if not len(samples):
        raise ValueError("Please provide at least one dataset to plot")
    parameters = corner_parameters
    if corner_parameters is None:
        # only plot parameters shared by every dataset, keeping the order of
        # the first dataset
        _parameters = [list(_samples.keys()) for _samples in samples.values()]
        parameters = [
            i for i in _parameters[0] if all(i in _params for _params in _parameters)
        ]
    if len(samples.keys()) > len(colors):
        raise ValueError("Please provide a unique color for each dataset")

    # copy so the caller's hist_kwargs dict is not mutated
    base_hist_kwargs = dict(kwargs.get("hist_kwargs", dict()))
    base_hist_kwargs["density"] = True
    wanted = set(parameters)
    lines = []
    fig = None
    for num, (label, posterior) in enumerate(samples.items()):
        if latex_friendly:
            label = label.replace("_", "\\_")
        lines.append(mlines.Line2D([], [], color=colors[num], label=label))
        _samples = {
            param: value for param, value in posterior.items() if param in
            wanted
        }
        kwargs["hist_kwargs"] = {**base_hist_kwargs, "color": colors[num]}
        # the first dataset creates the figure; later ones draw onto it
        extra = {"fig": fig} if num else {}
        fig, _, _ = _make_corner_plot(
            _samples, latex_labels, corner_parameters=corner_parameters,
            parameters=parameters, color=colors[num], **extra, **kwargs
        )
    fig.legend(handles=lines, loc="upper right")
    return fig