Coverage for pesummary/core/plots/plot.py: 49.7%

366 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import ( 

4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin, 

5) 

6from pesummary.core.plots.seaborn.kde import kdeplot 

7from pesummary.core.plots.corner import corner 

8from pesummary.core.plots.figure import figure, ExistingFigure 

9from pesummary import conf 

10 

11import matplotlib.lines as mlines 

12import copy 

13from itertools import cycle 

14 

15import numpy as np 

16from scipy import signal 

17 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Run the latex check imported from pesummary.utils.utils once, at import time.
_check_latex_install()

# Default keyword arguments forwarded to ``ax.legend`` by the comparison
# plots defined in this module.
_default_legend_kwargs = {
    "bbox_to_anchor": (0.0, 1.02, 1.0, 0.102),
    "loc": 3,
    "handlelength": 3,
    "mode": "expand",
    "borderaxespad": 0.0,
}

25 

26 

def _update_1d_comparison_legend(legend, linestyles, linewidth=1.75):
    """Restyle the handles of a 1d comparison plot legend.

    Parameters
    ----------
    legend: object
        legend whose handles are exposed as ``legend_handles`` or, failing
        that, as ``legendHandles``
    linestyles: list
        linestyle to apply to each handle, paired in order; pairing stops at
        the shorter of the two sequences
    linewidth: float, optional
        linewidth applied to every restyled handle. Default 1.75
    """
    try:
        entries = legend.legend_handles
    except AttributeError:
        # matplotlib < 3.7.0 only provides the camelCase attribute
        entries = legend.legendHandles
    for entry, linestyle in zip(entries, linestyles):
        entry.set_linewidth(linewidth)
        entry.set_linestyle(linestyle)

37 

38 

def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Plot the autocorrelation function (ACF) of the samples of a parameter.

    Only the second half of ``samples`` is used. The ACF is normalised by its
    first element and drawn as unconnected markers against the lag.

    Parameters
    ----------
    param: str
        name of the parameter, used in the debug log message
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. A new one is made if None
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers. Default 0.5
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the autocorrelation plot
    """
    import warnings

    # NOTE: this adds a process-wide filter, so RuntimeWarnings remain
    # silenced after this function has returned.
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    logger.debug("Generating the autocorrelation function for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)

    # keep only the second half of the samples
    second_half = samples[int(len(samples) / 2):]
    centred = second_half - np.mean(second_half)
    reversed_conjugate = np.conj(centred[::-1])
    correlation = signal.fftconvolve(reversed_conjugate, centred, mode="full")
    n_kept = np.array(second_half).shape[0]
    acf = np.fft.ifftshift(correlation)[0:n_kept]

    # Hack to make test pass with python3.8
    if color == "$":
        color = conf.color
    ax.plot(
        acf / acf[0], linestyle=" ", marker="o", markersize=markersize,
        color=color
    )
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("lag")
    ax.set_ylabel("ACF")
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

84 

85 

def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Overlay the autocorrelation function of every mcmc chain on one figure.

    Each chain is drawn by ``_autocorrelation_plot`` with a markersize of 1.25
    and the next color from ``colorcycle``.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one autocorrelation trace per chain
    """
    palette = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _autocorrelation_plot(
            param, chain, fig=fig, markersize=1.25, color=chain_color,
            grid=grid
        )
    return fig

110 

111 

def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter the samples of a parameter against their index.

    Parameters
    ----------
    param: str
        name of the parameter, used in the debug log message
    samples: list
        list of samples for param
    latex_label: str
        latex label for param, used as the y axis label
    inj_value: float
        value that was injected. Accepted but not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. A new one is made if None
    color: str, optional
        color you wish to use to plot the scatter points when z is None
    markersize: float, optional
        size of the scatter points. Default 0.5
    grid: Bool, optional
        if True, plot a grid
    z: array-like, optional
        values used to color the scatter points. When given, a colorbar is
        added to the figure
    z_label: str, optional
        label for the colorbar. Only used when z is given
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)

    has_color_data = z is not None
    point_colors = z if has_color_data else color
    points = ax.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=point_colors, **kwargs
    )
    if has_color_data:
        colorbar = fig.colorbar(points)
        if z_label is not None:
            colorbar.set_label(z_label)
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("samples")
    ax.set_ylabel(latex_label)
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

159 

160 

def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Overlay the sample evolution of every mcmc chain on one figure.

    Each chain is drawn by ``_sample_evolution_plot`` with a markersize of
    1.25 and the next color from ``colorcycle``.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float
        value that was injected. Accepted but not forwarded; every chain is
        plotted with ``inj_value=None``
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one scatter trace per chain
    """
    palette = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _sample_evolution_plot(
            param, chain, latex_label, inj_value=None, fig=fig,
            markersize=1.25, color=chain_color, grid=grid
        )
    return fig

191 

192 

def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", weights=None, **kwargs
):
    """Plot the cumulative distribution function of a parameter.

    Without weights the sorted samples are plotted against evenly spaced
    values between 0 and 1. With weights, a 50 bin weighted histogram is
    accumulated, normalised by its last value and plotted at the bin centres.

    Parameters
    ----------
    param: str
        name of the parameter, used in the debug log message
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. A new one is made if None
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title giving the median and the distances to the 5th
        and 95th percentiles. These are computed from ``samples`` without
        using ``weights``
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    weights: list, optional
        list of weights for samples. Default None
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the CDF
    """
    logger.debug("Generating the 1d CDF for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)

    if weights is not None:
        density, edges = np.histogram(
            samples, bins=50, density=True, weights=weights
        )
        cumulative = np.cumsum(density)
        cumulative /= cumulative[-1]
        centres = 0.5 * (edges[:-1] + edges[1:])
        ax.plot(
            centres, cumulative, color=color, linestyle=linestyle, **kwargs
        )
    else:
        # sort a copy so the caller's samples keep their order
        ordered = copy.deepcopy(samples)
        ordered.sort()
        heights = np.linspace(0, 1, len(ordered))
        ax.plot(ordered, heights, color=color, linestyle=linestyle, **kwargs)

    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    percentile_95 = np.percentile(samples, 95)
    percentile_5 = np.percentile(samples, 5)
    median = np.median(samples)
    plus = np.round(percentile_95 - median, 2)
    minus = np.round(median - percentile_5, 2)
    median = np.round(median, 2)
    if title:
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, plus, minus))
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

259 

260 

def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the CDF of every mcmc chain on one figure.

    Each chain is drawn by ``_1d_cdf_plot`` without a title, and the figure is
    titled with the value returned by ``gelman_rubin(samples)``.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one CDF per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, color=chain_color,
            title=False, grid=grid, **kwargs
        )
    statistic = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(statistic))
    return fig

292 

293 

def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, weights=None,
    **kwargs
):
    """Generate a plot to compare the CDFs of a parameter for several analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    linestyles: list, optional
        linestyle for each analysis. Default "-" for every analysis
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, escape underscores in the labels. The escaping is applied to
        a copy, so the list passed by the caller is left untouched.
        Default False
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one CDF per analysis

    Raises
    ------
    ValueError
        if the number of weight sets differs from the number of analyses
    """
    logger.debug("Generating the 1d comparison CDF for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if weights is None:
        weights = [None] * len(samples)
    if len(weights) != len(samples):
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if latex_friendly:
        # copy once, up front, rather than on every loop iteration
        labels = copy.deepcopy(labels)
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_cdf_plot(
            param, i, latex_label, fig=fig, color=colors[num], title=False,
            grid=grid, linestyle=linestyles[num], weights=weights[num], **kwargs
        )
        if latex_friendly:
            # raw string: "\_" in a normal literal is an invalid escape
            labels[num] = labels[num].replace("_", r"\_")
        handles.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

356 

357 

def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Generate a plot to display a PDF

    ``x`` and ``pdf`` are wrapped in a pesummary ``Array`` (``x`` as the data,
    ``pdf`` as its weights) which supplies the median and the [5, 95]
    credible interval.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot. Not used by this function
    x: array-like
        values at which the PDF is given
    pdf: array-like
        value of the PDF for each entry of x
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list
        list of prior samples for param. Not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title giving the median and the distances to the
        bounds of the [5, 95] credible interval
    color: str, optional
        color you wish to use to plot the PDF
    autoscale: Bool, optional
        if True, restore the x limits recorded right after the PDF was drawn.
        This overrides xlims
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines at the bounds of the [5, 95]
        credible interval
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth of the credible interval lines
    injection_color: str, optional
        color of vertical line showing the injected value
    **plot_kwargs: dict, optional
        accepted but not used by this function

    Returns
    -------
    matplotlib figure, or the axis when an axis was given without a figure
    """
    from pesummary.utils.array import Array

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()

    distribution = Array(x, weights=pdf)
    ax.plot(
        distribution, distribution.weights, color=color, linestyle=linestyle,
        label=label
    )
    # remember the limits chosen for the curve before more lines are added
    drawn_limits = ax.get_xlim()
    interval = distribution.credible_interval([5, 95])
    median = distribution.average("median")
    if title:
        plus = np.round(interval[1] - median, 2)
        minus = np.round(median - interval[0], 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, plus, minus))
    if plot_percentile:
        for bound in interval:
            ax.axvline(
                bound, color=color, linestyle="--", linewidth=linewidth
            )
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")
    if inj_value is not None:
        ax.axvline(inj_value, color=injection_color, **_default_inj_kwargs)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(drawn_limits)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

448 

449 

def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"shade": True, "alpha_shade": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter.

    None of the dictionary arguments (``kde_kwargs``, ``hist_kwargs`` and the
    ``_default_*`` dictionaries) are modified; options are merged into private
    copies so that nothing leaks from one call into the next.

    Parameters
    ----------
    param: str
        name of the parameter, used in the debug log message
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde
    hist: Bool, optional
        if True, plot a histogram
    prior: list
        list of prior samples for param
    weights: list
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title giving the median and the distances to the 5th
        and 95th percentiles, plus the HPD equivalent when key_data provides
        one
    color: str, optional
        color you wish to use for the plot
    autoscale: Bool, optional
        if True, restore the x limits recorded right after the samples were
        drawn. This overrides xlims
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class. The entries "kde_kernel"
        and "variance_atol" are extracted and passed to kdeplot directly
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines at the 5th and 95th percentiles
    plot_hdp: Bool, optional
        if True, and key_data provides a "90% HPD" list or array, plot
        dotted vertical lines at its values
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples is <= max_vline, draw them as vertical
        lines rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed statistics. When given, "5th percentile",
        "95th percentile" and "median" are read from it instead of being
        computed from samples, and "90% HPD" is used if present
    **plot_kwargs: dict, optional
        additional kwargs passed to ax.hist and merged into the kwargs passed
        to kdeplot

    Returns
    -------
    matplotlib figure, or the axis when an axis was given without a figure
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s" % (param))
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # Stays None when nothing is drawn below (hist=False and kde=False), in
    # which case the autoscale step at the end is skipped.
    _xlims = None
    unique_samples = set(samples)
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only label the first line so the legend has a single entry
            _label = label if _ind == 0 else None
            ax.axvline(_sample, color=color, label=_label)
        _xlims = ax.get_xlim()
    else:
        if hist:
            # merge into a new dict; never update the shared default in place
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = _hist_kwargs.copy()
                _prior_hist_kwargs["histtype"] = "bar"
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            _kde_kwargs = kde_kwargs.copy()
            # copy the defaults so per-call options do not persist
            kwargs = dict(_default_kde_kwargs)
            kwargs.update({
                "kde_kernel": _kde_kwargs.pop("kde_kernel", None),
                "variance_atol": _kde_kwargs.pop("variance_atol", 1e-8),
                "kde_kwargs": _kde_kwargs,
                "weights": weights
            })
            kwargs.update(plot_kwargs)
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                # NOTE(review): the prior kde receives the same kwargs,
                # including the weights of the samples -- confirm intended.
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **kwargs
                )

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data.keys():
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    _vline_width = hist_kwargs.get("linewidth", 1.75)
    if plot_percentile:
        for pp in percentile:
            ax.axvline(pp, color=color, linestyle="--", linewidth=_vline_width)
    if plot_hdp and isinstance(hdp, (list, np.ndarray)):
        for pp in hdp:
            ax.axvline(pp, color=color, linestyle=":", linewidth=_vline_width)
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if not isinstance(hdp, (list, np.ndarray)) and np.isnan(hdp):
            _base += r"$"
            ax.set_title(_base)
        else:
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
            ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale and _xlims is not None:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

619 

620 

def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of every mcmc chain on one figure.

    Each chain is drawn by ``_1d_histogram_plot`` with ``title=False`` and
    ``autoscale=False``, and the figure is titled with the value returned by
    ``gelman_rubin(samples)``.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one histogram per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_histogram_plot(
            param, chain, latex_label, color=chain_color, title=False,
            autoscale=False, fig=fig, **kwargs
        )
    statistic = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(statistic))
    return fig

650 

651 

def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Generate a bootstrapped 1d histogram plot for a given parameter

    ``ntests`` subsets of ``nsamples`` samples are drawn without replacement
    and each one is drawn on the same figure by ``_1d_histogram_plot``.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different tests
    nsamples: int, optional
        number of samples to randomly draw from samples. If this is larger
        than the number of samples available, half of the available samples
        are drawn instead. Default 1000
    ntests: int, optional
        number of tests to perform. Default 100
    shade: Bool, optional
        passed to _1d_histogram_plot. Default False
    plot_percentile: Bool, optional
        passed to _1d_histogram_plot. Default False
    kde: Bool, optional
        passed to _1d_histogram_plot. Default True
    hist: Bool, optional
        passed to _1d_histogram_plot. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one curve per test
    """
    available = len(samples)
    if nsamples > available:
        nsamples = int(available / 2)
    draws = []
    for _ in range(ntests):
        draws.append(np.random.choice(samples, size=nsamples, replace=False))
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for draw in draws:
        draw_color = next(palette)
        fig = _1d_histogram_plot(
            param, draw, latex_label, color=draw_color, title=False,
            autoscale=False, fig=fig, shade=shade,
            plot_percentile=plot_percentile, kde=kde, hist=hist, **kwargs
        )
    ax.set_title("Ntests: {}, Nsamples per test: {}".format(ntests, nsamples))
    fig.tight_layout()
    return fig

693 

694 

def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k",
    weights=None, **kwargs
):
    """Generate a plot to compare the 1d histograms of a parameter for
    several analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool
        if True, a kde is plotted
    hist: Bool
        if True, a histogram is plotted
    linestyles: list
        list of linestyles for each set of samples
    kde_kwargs: dict, optional
        optional kwargs passed to _1d_histogram_plot as kde_kwargs
    hist_kwargs: dict, optional
        optional kwargs passed to _1d_histogram_plot as hist_kwargs. The
        "linewidth" entry is always replaced by 2.5; the replacement is made
        on a copy, so the dictionary passed by the caller is left untouched
    max_vline: int, optional
        passed to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure. Default (8, 6)
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, escape underscores in the labels. The escaping is applied to
        a copy of labels. Default False
    max_inj_line: int, optional
        maximum number of unique injection values that will be drawn. When
        exceeded, a warning is logged and no injection is plotted. Default 1
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing one histogram per analysis

    Raises
    ------
    ValueError
        if weights, a list of injection values or a list of injection colors
        does not have one entry per analysis
    """
    logger.debug("Generating the 1d comparison histogram plot for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if inj_value is None:
        inj_value = [None] * len(samples)
    if weights is None:
        weights = [None] * len(samples)
    if len(weights) != len(samples):
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    elif isinstance(inj_value, (list, np.ndarray)) and len(inj_value) != len(samples):
        raise ValueError(
            "Please provide an injection for each analysis or a single "
            "injection value which will be used for all histograms"
        )
    elif not isinstance(inj_value, (list, np.ndarray)):
        inj_value = [inj_value] * len(samples)

    if isinstance(injection_color, str):
        injection_color = [injection_color] * len(samples)
    elif len(injection_color) != len(samples):
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array([_ for _ in inj_value if _ is not None]).flatten()
    if len(set(flat_injection)) > max_inj_line:
        logger.warning(
            "Number of unique injection values ({}) is more than the maximum "
            "allowed injection value ({}). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`".format(
                len(set(flat_injection)), max_inj_line
            )
        )
        inj_value = [None] * len(samples)

    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    # build a new dict: updating hist_kwargs in place would alter the shared
    # default argument and any dictionary handed in by the caller
    hist_kwargs = {**hist_kwargs, "linewidth": 2.5}
    if latex_friendly:
        # copy once, up front, rather than on every loop iteration
        labels = copy.deepcopy(labels)
    for num, i in enumerate(samples):
        if latex_friendly:
            # raw string: "\_" in a normal literal is an invalid escape
            labels[num] = labels[num].replace("_", r"\_")
        fig = _1d_histogram_plot(
            param, i, latex_label, kde=kde, hist=hist, kde_kwargs=kde_kwargs,
            max_vline=max_vline, grid=grid, title=False, autoscale=False,
            label=labels[num], color=colors[num], fig=fig, hist_kwargs=hist_kwargs,
            inj_value=inj_value[num], injection_color=injection_color[num],
            weights=weights[num], linestyle=linestyles[num], _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

806 

807 

def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Generate a horizontal box plot comparing the samples of a parameter.

    Each label is annotated above its box, centred on the midpoint between
    the smallest and largest sample across all analyses. The y ticks are
    removed.

    Parameters
    ----------
    param: str
        name of the parameter, used in the debug log message
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors for the different analyses. Not used by this function
    latex_label: str
        latex label for param, used as the x axis label
    labels: list
        label for each analysis
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: matplotlib.pyplot.figure
        figure containing the box plot
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, ax = figure(gca=True)
    largest = np.max([np.max(chain) for chain in samples])
    smallest = np.min([np.min(chain) for chain in samples])
    centre = (largest + smallest) * 0.5
    ax.boxplot(samples, widths=0.2, vert=False, whis=np.inf, labels=labels)
    # boxes are drawn at y = 1, 2, ...; place each label just above its box
    for position, text in enumerate(labels, start=1):
        ax.annotate(
            text, xy=(centre, 1), xytext=(centre, position + 0.2), ha="center"
        )
    ax.set_yticks([])
    ax.set_xlabel(latex_label)
    fig.tight_layout()
    ax.grid(visible=grid)
    return fig

841 

842 

def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plot for a set of samples

    Parameters
    ----------
    samples: dict
        samples for each parameter, keyed by parameter name
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        if given, only the parameters that also appear in this list are
        included in the plot
    parameters: list, optional
        parameters to consider. Default all keys of samples
    **kwargs: dict, optional
        kwargs merged on top of ``conf.corner_kwargs`` and passed to corner.
        The "range" and "labels" entries are always overwritten

    Returns
    -------
    _figure: ExistingFigure
        the corner plot
    included_parameters: list
        parameters shown in the plot
    data: dict
        geometry, scaled by the figure dpi, measured from the first two axes
        of the figure: "width", "height", "seperation", "x0" and "y0".
        "seperation" is None when the figure has fewer than two axes
    """
    logger.debug("Generating the corner plot")
    # start from the configured defaults and let the caller override them
    corner_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is None:
        included_parameters = parameters
    else:
        included_parameters = [
            name for name in parameters if name in corner_parameters
        ]
    n_included = len(included_parameters)
    xs = np.zeros([n_included, len(samples[parameters[0]])])
    for row, name in enumerate(included_parameters):
        xs[row] = samples[name]
    corner_kwargs.update(kwargs)
    corner_kwargs["range"] = [1.0] * n_included
    corner_kwargs["labels"] = [latex_labels[name] for name in included_parameters]

    _figure = ExistingFigure(corner(xs.T, included_parameters, **corner_kwargs))
    # measure the first two subplots
    to_inches = _figure.dpi_scale_trans.inverted()
    location = []
    for axis in _figure.get_axes()[:2]:
        extent = axis.get_window_extent().transformed(to_inches)
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
    # the size is taken from the last axis measured above
    width = extent.width * _figure.dpi
    height = extent.height * _figure.dpi
    try:
        seperation = abs(location[0][0] - location[1][0]) - width
    except IndexError:
        seperation = None
    # NOTE(review): "y0" is filled with location[0][0], the same x coordinate
    # as "x0", rather than location[0][1] -- confirm whether this is intended.
    data = {
        "width": width, "height": height, "seperation": seperation,
        "x0": location[0][0], "y0": location[0][0]
    }
    return _figure, included_parameters, data

897 

898 

def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    latex_friendly=True, **kwargs
):
    """Generate a corner plot which contains multiple datasets

    The first dataset creates the figure through ``_make_corner_plot``; every
    later dataset is drawn on that same figure.

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default the
        parameters that are common to every dataset
    colors: list, optional
        unique colors for each dataset
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default True
    **kwargs: dict
        all kwargs are passed to `corner.corner`. A "hist_kwargs" entry is
        copied before "density" and "color" are set on it, so the dictionary
        passed by the caller is left untouched

    Returns
    -------
    fig: figure containing every dataset and a legend

    Raises
    ------
    ValueError
        if there are more datasets than colors
    """
    parameters = corner_parameters
    if corner_parameters is None:
        _parameters = [list(_samples.keys()) for _samples in samples.values()]
        parameters = [
            i for i in _parameters[0] if all(i in _params for _params in _parameters)
        ]
    if len(samples.keys()) > len(colors):
        raise ValueError("Please provide a unique color for each dataset")

    # copy so the caller's hist_kwargs dictionary is never modified
    hist_kwargs = dict(kwargs.get("hist_kwargs", dict()))
    hist_kwargs["density"] = True
    lines = []
    for num, (label, posterior) in enumerate(samples.items()):
        if latex_friendly:
            # raw string: "\_" in a normal literal is an invalid escape
            label = label.replace("_", r"\_")
        lines.append(mlines.Line2D([], [], color=colors[num], label=label))
        _samples = {
            param: value for param, value in posterior.items() if param in
            parameters
        }
        # a new dict per dataset, carrying that dataset's color
        kwargs["hist_kwargs"] = dict(hist_kwargs, color=colors[num])
        # only datasets after the first are drawn on an existing figure
        existing = {} if num == 0 else {"fig": fig}
        fig, _, _ = _make_corner_plot(
            _samples, latex_labels, corner_parameters=corner_parameters,
            parameters=parameters, color=colors[num], **existing, **kwargs
        )
    fig.legend(handles=lines, loc="upper right")
    return fig

956 return fig