Coverage for pesummary/core/plots/plot.py: 85.4%

349 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import ( 

4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin, 

5) 

6from pesummary.core.plots.seaborn.kde import kdeplot 

7from pesummary.core.plots.corner import corner 

8from pesummary.core.plots.figure import figure, ExistingFigure 

9from pesummary import conf 

10 

11import matplotlib.lines as mlines 

12import copy 

13from itertools import cycle 

14 

15import numpy as np 

16from scipy import signal 

17 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Invoke the helper imported from pesummary.utils.utils once, at import time.
_check_latex_install()

# Keyword arguments used by default for ``ax.legend`` in the comparison plots
# defined in this module.
_default_legend_kwargs = {
    "bbox_to_anchor": (0.0, 1.02, 1.0, 0.102),
    "loc": 3,
    "handlelength": 3,
    "mode": "expand",
    "borderaxespad": 0.0,
}

25 

26 

def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Plot the normalised autocorrelation function (ACF) of a set of samples.

    Only the second half of `samples` is used. The ACF is computed with an
    FFT based convolution and normalised by its zero-lag value.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot (used for logging only)
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers used to draw the ACF. Default 0.5
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing the autocorrelation plot
    """
    import warnings

    logger.debug("Generating the autocorrelation function for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    samples = np.asarray(samples)
    # only the second half of the chain enters the calculation
    samples = samples[int(len(samples) / 2):]
    N = samples.shape[0]
    # Silence RuntimeWarnings only while the ACF is evaluated. Installing a
    # global filter here would hide unrelated warnings for the rest of the
    # process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        x = samples - np.mean(samples)
        y = np.conj(x[::-1])
        # the "full" convolution has its zero lag in the centre; ifftshift
        # moves it to index 0 so that the first N entries are lags 0..N-1
        acf = np.fft.ifftshift(signal.fftconvolve(y, x, mode="full"))
        acf = acf[0:N]
        normalised_acf = acf / acf[0]
    # Hack to make test pass with python3.8
    if color == "$":
        color = conf.color
    ax.plot(
        normalised_acf, linestyle=" ", marker="o", markersize=markersize,
        color=color
    )
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("lag")
    ax.set_ylabel("ACF")
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

72 

73 

def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Overlay the autocorrelation function of every mcmc chain on one figure.

    Each chain is drawn by `_autocorrelation_plot` on a shared figure, taking
    the next color from `colorcycle` for every chain.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing one autocorrelation curve per chain
    """
    palette = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _autocorrelation_plot(
            param, chain, fig=fig, markersize=1.25, color=chain_color,
            grid=grid
        )
    return fig

98 

99 

def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter the samples of a parameter against their index.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot (used for logging only)
    samples: list
        list of samples for param
    latex_label: str
        latex label for param, used as the y axis label
    inj_value: float, optional
        value that was injected. Accepted for interface compatibility; it is
        not drawn by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use to plot the scatter points. Ignored when `z`
        is provided
    markersize: float, optional
        size of the scatter points. Default 0.5
    grid: Bool, optional
        if True, plot a grid
    z: array-like, optional
        values used to color the scatter points. When provided a colorbar is
        added to the figure
    z_label: str, optional
        label for the colorbar. Only used when `z` is provided
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: figure
        the figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)
    colour_by_z = z is not None
    point_colours = z if colour_by_z else color
    scatter = ax.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=point_colours, **kwargs
    )
    if colour_by_z:
        colorbar = fig.colorbar(scatter)
        if z_label is not None:
            colorbar.set_label(z_label)
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("samples")
    ax.set_ylabel(latex_label)
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

147 

148 

def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Generate a scatter plot showing the evolution of the samples in each
    mcmc chain for a given parameter

    Every chain is drawn by `_sample_evolution_plot` on a shared figure,
    taking the next color from `colorcycle` for every chain.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected. Forwarded unchanged to
        `_sample_evolution_plot` for every chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing one scatter series per chain
    """
    cycol = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for ss in samples:
        # forward the caller's injected value rather than a hard-coded None
        fig = _sample_evolution_plot(
            param, ss, latex_label, inj_value=inj_value, fig=fig,
            markersize=1.25, color=next(cycol), grid=grid
        )
    return fig

179 

180 

def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", **kwargs
):
    """Plot the empirical cumulative distribution function of a parameter.

    The sorted samples are drawn against evenly spaced values between 0 and 1.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot (used for logging only)
    samples: list
        list of samples for param
    latex_label: str
        latex label for param, used as the x axis label
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title giving the median together with the offsets to
        the 95th and 5th percentiles, each rounded to 2 decimal places
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: figure
        the figure containing the CDF
    """
    logger.debug("Generating the 1d CDF for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)
    # sort a copy so that the caller's samples keep their order
    ordered = copy.deepcopy(samples)
    ordered.sort()
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    median = np.median(samples)
    plus = np.round(np.percentile(samples, 95) - median, 2)
    minus = np.round(median - np.percentile(samples, 5), 2)
    if title:
        ax.set_title(
            r"$%s^{+%s}_{-%s}$" % (np.round(median, 2), plus, minus)
        )
    cumulative = np.linspace(0, 1, len(ordered))
    ax.plot(ordered, cumulative, color=color, linestyle=linestyle, **kwargs)
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

235 

236 

def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the cumulative distribution function of every mcmc chain.

    Each chain is drawn by `_1d_cdf_plot` on a shared figure and the title
    reports the value returned by `gelman_rubin(samples)`.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        the figure containing one CDF per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, color=chain_color,
            title=False, grid=grid, **kwargs
        )
    statistic = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(statistic))
    return fig

268 

269 

def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, **kwargs
):
    """Generate a plot to compare the cdfs for a given parameter for different
    analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    linestyles: list, optional
        linestyle for each analysis. Default "-" for every analysis
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        the figure containing one CDF per analysis and a legend
    """
    logger.debug("Generating the 1d comparison CDF for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if latex_friendly:
        # Escape underscores once, on a new list, so the caller's labels are
        # left untouched. The raw string avoids the invalid "\_" escape.
        labels = [_label.replace("_", r"\_") for _label in labels]
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_cdf_plot(
            param, i, latex_label, fig=fig, color=colors[num], title=False,
            grid=grid, linestyle=linestyles[num], **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    # `legendHandles` was renamed to `legend_handles` in newer matplotlib
    # releases; fall back to the old name when the new one is missing
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for num, legobj in enumerate(legend_handles):
        legobj.set_linewidth(1.75)
        legobj.set_linestyle(linestyles[num])
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig

325 

326 

def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Draw a probability density function evaluated on a grid of points.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot. Not used by this function
    x: array-like
        points at which the PDF is evaluated
    pdf: array-like
        value of the PDF at each point in `x`. Used as the weights of `x`
        when computing the median and credible interval
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list, optional
        list of prior samples for param. Not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title giving the median and the offsets to the 95th
        and 5th percentiles
    color: str, optional
        color you wish to use for the PDF and the percentile lines
    autoscale: Bool, optional
        if True, restore the x limits recorded right after the PDF was drawn.
        This takes precedence over `xlims`
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth used for the dashed percentile lines
    injection_color: str, optional
        color of vertical line showing the injected value
    **plot_kwargs: dict, optional
        accepted for interface compatibility. Not used by this function

    Returns
    -------
    matplotlib axis or figure
        `ax` when no figure is available (an axis was supplied without a
        figure), otherwise the figure
    """
    from pesummary.utils.array import Array

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()

    # the grid points weighted by the PDF values
    pdf = Array(x, weights=pdf)
    ax.plot(pdf, pdf.weights, color=color, linestyle=linestyle, label=label)
    # remember the limits chosen for the PDF before any vertical line is added
    _xlims = ax.get_xlim()

    percentile = pdf.credible_interval([5, 95])
    median = pdf.average("median")
    if title:
        plus = np.round(percentile[1] - median, 2)
        minus = np.round(median - percentile[0], 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, plus, minus))
    if plot_percentile:
        for bound in percentile:
            ax.axvline(bound, color=color, linestyle="--", linewidth=linewidth)
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")
    if inj_value is not None:
        ax.axvline(inj_value, color=injection_color, **_default_inj_kwargs)

    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

417 

418 

def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"shade": True, "alpha_shade": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    analysis.

    None of the dictionaries passed in (or used as defaults) are modified:
    the defaults are copied before being merged with the user supplied
    options, so settings from one call never leak into the next one.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot (used for logging only)
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde
    hist: Bool, optional
        if True, plot a histogram
    prior: list, optional
        list of prior samples for param
    weights: list, optional
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title giving the median and symmetric 90% credible
        interval and, when available, the 90% HPD interval
    color: str, optional
        color you wish to use for the histogram/kde and the vertical lines
    autoscale: Bool, optional
        if True, restore the x limits recorded right after the samples were
        drawn. This takes precedence over `xlims`
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    plot_hdp: Bool, optional
        if True, and `key_data` provides a "90% HPD" list/array, plot dotted
        vertical lines showing that interval
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples is <= max_vline, draw each unique
        sample as a vertical line rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed summary statistics. When provided it must contain the keys
        "5th percentile", "95th percentile" and "median"; "90% HPD" is
        optional. When None the statistics are computed from `samples`
    **plot_kwargs: dict, optional
        all additional kwargs are passed to ax.hist and/or kdeplot

    Returns
    -------
    matplotlib axis or figure
        `ax` when no figure is available (an axis was supplied without a
        figure), otherwise the figure
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s" % (param))
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # x limits recorded after the samples are drawn; stays None if nothing is
    # drawn (hist=False and kde=False) so that autoscale can be skipped safely
    _xlims = None
    unique_samples = set(samples)
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only the first line carries the legend label
            _label = label if _ind == 0 else None
            ax.axvline(_sample, color=color, label=_label)
        _xlims = ax.get_xlim()
    else:
        if hist:
            # merge into a new dict: the default is shared between calls
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = dict(_hist_kwargs, histtype="bar")
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            _kde_kwargs = kde_kwargs.copy()
            kde_kernel = _kde_kwargs.pop("kde_kernel", None)
            variance_atol = _kde_kwargs.pop("variance_atol", 1e-8)
            # copy the shared default before adding call specific options
            kwargs = dict(_default_kde_kwargs)
            kwargs.update({
                "kde_kwargs": _kde_kwargs,
                "kde_kernel": kde_kernel,
                "variance_atol": variance_atol,
                "weights": weights
            })
            kwargs.update(plot_kwargs)
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                # NOTE(review): the posterior `weights` are also handed to the
                # prior kde here, unlike the histogram branch - confirm this
                # is intended
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **kwargs
                )

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data:
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    _has_hdp_interval = isinstance(hdp, (list, np.ndarray))
    _line_width = hist_kwargs.get("linewidth", 1.75)
    if plot_percentile:
        for pp in percentile:
            ax.axvline(pp, color=color, linestyle="--", linewidth=_line_width)
    if plot_hdp and _has_hdp_interval:
        for pp in hdp:
            ax.axvline(pp, color=color, linestyle=":", linewidth=_line_width)
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if not _has_hdp_interval and np.isnan(hdp):
            _base += r"$"
            ax.set_title(_base)
        else:
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
            ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale and _xlims is not None:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig

588 

589 

def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of every mcmc chain on a single figure.

    Each chain is drawn by `_1d_histogram_plot` (without a title and without
    autoscaling) and the title reports the value returned by
    `gelman_rubin(samples)`.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure containing one histogram per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_histogram_plot(
            param, chain, latex_label, color=chain_color, title=False,
            autoscale=False, fig=fig, **kwargs
        )
    statistic = gelman_rubin(samples)
    ax.set_title("Gelman-Rubin: {}".format(statistic))
    return fig

619 

620 

def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Overlay the distributions of random sub-selections of the samples.

    `ntests` subsets of `nsamples` samples are drawn without replacement and
    each is plotted with `_1d_histogram_plot` on a shared figure.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different tests
    nsamples: int, optional
        number of samples to randomly draw from samples. If larger than the
        number of available samples, half of the available samples are drawn
        instead. Default 1000
    ntests: int, optional
        number of tests to perform. Default 100
    shade: Bool, optional
        passed to _1d_histogram_plot as an additional kwarg. Default False
    plot_percentile: Bool, optional
        passed to _1d_histogram_plot. Default False
    kde: Bool, optional
        passed to _1d_histogram_plot. Default True
    hist: Bool, optional
        passed to _1d_histogram_plot. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure containing one curve per test
    """
    available = len(samples)
    if available < nsamples:
        nsamples = int(available / 2)
    # all subsets are drawn before anything is plotted
    draws = []
    for _ in range(ntests):
        draws.append(np.random.choice(samples, size=nsamples, replace=False))
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for subset in draws:
        subset_color = next(palette)
        fig = _1d_histogram_plot(
            param, subset, latex_label, color=subset_color, title=False,
            autoscale=False, fig=fig, shade=shade,
            plot_percentile=plot_percentile, kde=kde, hist=hist, **kwargs
        )
    ax.set_title("Ntests: {}, Nsamples per test: {}".format(ntests, nsamples))
    fig.tight_layout()
    return fig

662 

663 

def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k", **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different analyses.

    Neither `hist_kwargs` nor `labels` is modified; local copies are used.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool, optional
        if True, plot a kde
    hist: Bool, optional
        if True, plot a histogram
    linestyles: list, optional
        list of linestyles for each set of samples. Default "-" for all
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist. The "linewidth"
        entry is always replaced with 2.5
    max_vline: int, optional
        passed to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure. Default (8, 6)
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    max_inj_line: int, optional
        maximum number of unique injection values that are drawn. If there
        are more unique values, a warning is logged and none are drawn
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure containing one histogram/kde per analysis and a legend

    Raises
    ------
    ValueError
        if a list of injection values or injection colors is provided whose
        length differs from the number of analyses
    """
    logger.debug("Generating the 1d comparison histogram plot for %s" % (param))
    n_analyses = len(samples)
    if linestyles is None:
        linestyles = ["-"] * n_analyses
    if inj_value is None:
        inj_value = [None] * n_analyses
    elif isinstance(inj_value, (list, np.ndarray)):
        if len(inj_value) != n_analyses:
            raise ValueError(
                "Please provide an injection for each analysis or a single "
                "injection value which will be used for all histograms"
            )
    else:
        inj_value = [inj_value] * n_analyses

    if isinstance(injection_color, str):
        injection_color = [injection_color] * n_analyses
    elif len(injection_color) != n_analyses:
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array([_ for _ in inj_value if _ is not None]).flatten()
    n_unique_injections = len(set(flat_injection))
    if n_unique_injections > max_inj_line:
        logger.warning(
            "Number of unique injection values ({}) is more than the maximum "
            "allowed injection value ({}). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`".format(
                n_unique_injections, max_inj_line
            )
        )
        inj_value = [None] * n_analyses

    if latex_friendly:
        # Escape underscores once, on a new list, so the caller's labels are
        # left untouched. The raw string avoids the invalid "\_" escape.
        labels = [_label.replace("_", r"\_") for _label in labels]
    # Build a new dict instead of updating in place: `hist_kwargs` may be the
    # shared default or a dict owned by the caller
    hist_kwargs = {**hist_kwargs, "linewidth": 2.5}

    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_histogram_plot(
            param, i, latex_label, kde=kde, hist=hist, kde_kwargs=kde_kwargs,
            max_vline=max_vline, grid=grid, title=False, autoscale=False,
            label=labels[num], color=colors[num], fig=fig,
            hist_kwargs=hist_kwargs, inj_value=inj_value[num],
            injection_color=injection_color[num], linestyle=linestyles[num],
            _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    # `legendHandles` was renamed to `legend_handles` in newer matplotlib
    # releases; fall back to the old name when the new one is missing
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for num, legobj in enumerate(legend_handles):
        legobj.set_linewidth(1.75)
        legobj.set_linestyle(linestyles[num])
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig

768 

769 

def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Draw horizontal box plots comparing the samples of several analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot (used for logging only)
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors for the different analyses. Not used by this function
    latex_label: str
        latex label for param, used as the x axis label
    labels: list
        label for each analysis. Each label is written above its box at the
        centre of the overall sample range
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing the box plots
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, ax = figure(gca=True)
    largest = np.max([np.max(dataset) for dataset in samples])
    smallest = np.min([np.min(dataset) for dataset in samples])
    centre = (largest + smallest) * 0.5
    # whis=np.inf extends the whiskers over the full range of the data
    ax.boxplot(samples, widths=0.2, vert=False, whis=np.inf, labels=labels)
    for position, text in enumerate(labels, start=1):
        ax.annotate(
            text, xy=(centre, 1), xytext=(centre, position + 0.2), ha="center"
        )
    ax.set_yticks([])
    ax.set_xlabel(latex_label)
    fig.tight_layout()
    ax.grid(visible=grid)
    return fig

803 

804 

def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plot for a given set of samples

    Parameters
    ----------
    samples: dict
        dictionary of samples keyed by parameter name
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters you wish to include in the corner plot. If None, every
        parameter in `parameters` is included
    parameters: list, optional
        parameters to consider. If None, all keys of `samples` are used
    **kwargs: dict, optional
        kwargs used to update a copy of `conf.corner_kwargs` before it is
        passed to `corner`. The "range" and "labels" entries are always
        overwritten by this function

    Returns
    -------
    figure: ExistingFigure
        the corner plot
    included_parameters: list
        the parameters shown in the corner plot
    data: dict
        geometry of the first subplot(s) in pixels: "width" and "height" of a
        subplot, "seperation" between the first two subplots (None if there
        is only one) and the "x0", "y0" position of the first subplot
    """
    logger.debug("Generating the corner plot")
    # set the default kwargs
    default_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is not None:
        included_parameters = [i for i in parameters if i in corner_parameters]
    else:
        included_parameters = parameters
    xs = np.zeros([len(included_parameters), len(samples[parameters[0]])])
    for num, i in enumerate(included_parameters):
        xs[num] = samples[i]
    default_kwargs.update(kwargs)
    default_kwargs["range"] = [1.0] * len(included_parameters)
    default_kwargs["labels"] = [latex_labels[i] for i in included_parameters]

    _figure = ExistingFigure(corner(xs.T, included_parameters, **default_kwargs))
    # grab the axes of the subplots
    axes = _figure.get_axes()
    axes_of_interest = axes[:2]
    location = []
    for i in axes_of_interest:
        # extent of the subplot in inches; multiplied by the dpi to get pixels
        extent = i.get_window_extent().transformed(_figure.dpi_scale_trans.inverted())
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
    width, height = extent.width, extent.height
    width *= _figure.dpi
    height *= _figure.dpi
    try:
        seperation = abs(location[0][0] - location[1][0]) - width
    except IndexError:
        # only a single subplot exists
        seperation = None
    data = {
        "width": width, "height": height, "seperation": seperation,
        # each entry of `location` is [x0, y0]; y0 is the second element
        "x0": location[0][0], "y0": location[0][1]
    }
    return _figure, included_parameters, data

859 

860 

def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    latex_friendly=True, **kwargs
):
    """Generate a corner plot which contains multiple datasets

    The first dataset creates the figure and every following dataset is drawn
    on top of it. A "hist_kwargs" dictionary supplied through `kwargs` is
    copied, so the caller's dictionary is not modified.

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. If None, the
        parameters common to all datasets are used
    colors: list, optional
        unique colors for each dataset
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default True
    **kwargs: dict
        all kwargs are passed to `_make_corner_plot`

    Returns
    -------
    fig: figure
        the corner plot containing every dataset and a legend

    Raises
    ------
    ValueError
        if no dataset is provided or if there are fewer colors than datasets
    """
    if not samples:
        raise ValueError("Please provide at least one dataset to plot")
    parameters = corner_parameters
    if corner_parameters is None:
        _parameters = [list(_samples.keys()) for _samples in samples.values()]
        # keep only the parameters that are present in every dataset
        parameters = [
            i for i in _parameters[0] if all(i in _params for _params in _parameters)
        ]
    if len(samples.keys()) > len(colors):
        raise ValueError("Please provide a unique color for each dataset")

    # copy so that the caller's dictionary is not modified
    hist_kwargs = dict(kwargs.get("hist_kwargs", {}))
    hist_kwargs["density"] = True
    lines = []
    fig = None
    for num, (label, posterior) in enumerate(samples.items()):
        if latex_friendly:
            # raw string: "\_" is not a valid escape sequence
            label = label.replace("_", r"\_")
        lines.append(mlines.Line2D([], [], color=colors[num], label=label))
        _samples = {
            param: value for param, value in posterior.items() if param in
            parameters
        }
        hist_kwargs["color"] = colors[num]
        kwargs.update({"hist_kwargs": hist_kwargs})
        # every dataset after the first one is drawn on the existing figure
        _existing = {"fig": fig} if num > 0 else {}
        fig, _, _ = _make_corner_plot(
            _samples, latex_labels, corner_parameters=corner_parameters,
            parameters=parameters, color=colors[num], **_existing, **kwargs
        )
    fig.legend(handles=lines, loc="upper right")
    return fig

917 lines = [] 

918 return fig