Coverage for pesummary/core/plots/plot.py: 65.3%
372 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import (
4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin,
5)
6from pesummary.core.plots.corner import corner
7from pesummary.core.plots.figure import figure, ExistingFigure
8from pesummary import conf
10import matplotlib.lines as mlines
11import copy
12from itertools import cycle
14import numpy as np
15from scipy import signal
17__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
18_check_latex_install()
20_default_legend_kwargs = dict(
21 bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), loc=3, handlelength=3, mode="expand",
22 borderaxespad=0.0,
23)
26def _update_1d_comparison_legend(legend, linestyles, linewidth=1.75):
27 """Update the width and style of lines in the legend for a 1d comparison plot.
28 """
29 try:
30 handles = legend.legend_handles
31 except AttributeError: # matplotlib < 3.7.0
32 handles = legend.legendHandles
33 for handle, style in zip(handles, linestyles):
34 handle.set_linewidth(linewidth)
35 handle.set_linestyle(style)
def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Generate the autocorrelation function for a set of samples for a given
    parameter for a given approximant.

    Only the second half of `samples` is used when computing the
    autocorrelation function.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers. Default 0.5
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        figure containing the autocorrelation plot
    """
    import warnings

    # NOTE: this installs a process-wide filter, RuntimeWarnings remain
    # silenced after this function returns
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    logger.debug("Generating the autocorrelation function for %s", param)
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    # coerce to an array so that plain python lists support the arithmetic
    # below, then keep only the second half of the samples
    samples = np.asarray(samples)
    samples = samples[len(samples) // 2:]
    x = samples - np.mean(samples)
    y = np.conj(x[::-1])
    acf = np.fft.ifftshift(signal.fftconvolve(y, x, mode="full"))
    acf = acf[:samples.shape[0]]
    # Hack to make test pass with python3.8
    if color == "$":
        color = conf.color
    # normalise by the first element (zero lag after the ifftshift)
    ax.plot(
        acf / acf[0], linestyle=" ", marker="o", markersize=markersize,
        color=color
    )
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("lag")
    ax.set_ylabel("ACF")
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Overlay the autocorrelation function of every mcmc chain for a given
    parameter on a single figure

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        colors to cycle through, one is drawn for each mcmc chain
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        figure containing one autocorrelation trace per chain
    """
    colours = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        # every chain is drawn onto the same figure with its own color
        fig = _autocorrelation_plot(
            param, chain, fig=fig, color=next(colours), markersize=1.25,
            grid=grid
        )
    return fig
def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter the samples of a given parameter against their index to show
    how the samples evolve.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param, used as the y axis label
    inj_value: float
        value that was injected. Currently not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color of the scatter points. Only used when z is None
    markersize: float, optional
        size of the scatter points. Default 0.5
    grid: Bool, optional
        if True, plot a grid
    z: list, optional
        values used to color the scatter points. When provided a colorbar is
        added to the figure
    z_label: str, optional
        label for the colorbar. Only used when z is provided
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: figure
        figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        axis = fig.gca()
    else:
        fig, axis = figure(gca=True)
    colour_by_z = z is not None
    points = axis.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=z if colour_by_z else color, **kwargs
    )
    if colour_by_z:
        colourbar = fig.colorbar(points)
        if z_label is not None:
            colourbar.set_label(z_label)
    axis.ticklabel_format(axis="x", style="plain")
    axis.set_xlabel("samples")
    axis.set_ylabel(latex_label)
    axis.grid(visible=grid)
    fig.tight_layout()
    return fig
def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Generate a scatter plot showing the evolution of the samples in each
    mcmc chain for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float
        value that was injected. Forwarded to _sample_evolution_plot
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        figure containing one scatter trace per chain
    """
    cycol = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for ss in samples:
        # forward the injected value rather than silently discarding it
        fig = _sample_evolution_plot(
            param, ss, latex_label, inj_value=inj_value, fig=fig,
            markersize=1.25, color=next(cycol), grid=grid
        )
    return fig
def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", weights=None, **kwargs
):
    """Generate the cumulative distribution function for a given parameter for
    a given approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title to the 1d cdf plot giving the median and
        symmetric 90% credible intervals
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    weights: list, optional
        list of weights for samples. Default None
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: figure
        figure containing the CDF
    """
    logger.debug("Generating the 1d CDF for %s", param)
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    if weights is None:
        # np.sort returns a sorted copy for any array-like, so the caller's
        # samples are left untouched
        sorted_samples = np.sort(samples)
        ax.plot(
            sorted_samples, np.linspace(0, 1, len(sorted_samples)),
            color=color, linestyle=linestyle, **kwargs
        )
    else:
        # weighted CDF: cumulative sum of a 50 bin weighted histogram,
        # normalised to end at 1 and drawn at the bin centres
        hist, bin_edges = np.histogram(
            samples, bins=50, density=True, weights=weights
        )
        total = np.cumsum(hist)
        total /= total[-1]
        ax.plot(
            0.5 * (bin_edges[:-1] + bin_edges[1:]), total, color=color,
            linestyle=linestyle, **kwargs
        )
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    if title:
        # only compute the summary statistics when they are displayed
        # NOTE(review): these statistics ignore `weights` - confirm intended
        median = np.median(samples)
        upper = np.round(np.percentile(samples, 95) - median, 2)
        lower = np.round(median - np.percentile(samples, 5), 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, upper, lower))
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the cumulative distribution function of every mcmc chain for a
    given parameter and quote the Gelman-Rubin statistic in the title

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        colors to cycle through, one is drawn for each mcmc chain
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        figure containing one CDF per chain
    """
    colours = cycle(colorcycle)
    fig, axis = figure(gca=True)
    for chain in samples:
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, title=False, grid=grid,
            color=next(colours), **kwargs
        )
    axis.set_title(f"Gelman-Rubin: {gelman_rubin(samples)}")
    return fig
def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, weights=None,
    **kwargs
):
    """Generate a plot to compare the cdfs for a given parameter for different
    analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    linestyles: list, optional
        linestyle for each analysis. Default "-" for every analysis
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        figure containing one CDF per analysis

    Raises
    ------
    ValueError
        if the number of weights does not match the number of analyses
    """
    logger.debug("Generating the 1d comparison CDF for %s", param)
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if weights is None:
        weights = [None] * len(samples)
    if len(weights) != len(samples):
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if latex_friendly:
        # escape underscores once, on a new list, so the caller's labels are
        # never modified
        labels = [label.replace("_", r"\_") for label in labels]
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_cdf_plot(
            param, i, latex_label, fig=fig, color=colors[num], title=False,
            grid=grid, linestyle=linestyles[num], weights=weights[num],
            **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Plot a PDF that is known on a set of points

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    x: list
        values plotted along the x axis
    pdf: list
        value of the PDF for each entry in x. Used as the weights of x
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list
        list of prior samples for param. Currently not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals
    color: str, optional
        color you wish to use for the PDF and the percentile lines
    autoscale: Bool, optional
        if True, use the x axis limits recorded directly after drawing the
        PDF
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth to use for the percentile lines
    injection_color: str, optional
        color of vertical line showing the injected value
    **plot_kwargs: dict, optional
        currently not used by this function

    Returns
    -------
    fig or ax
        the axis if no figure is available, otherwise the figure
    """
    from pesummary.utils.array import Array

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()

    weighted = Array(x, weights=pdf)

    ax.plot(
        weighted, weighted.weights, color=color, linestyle=linestyle,
        label=label
    )
    # remember the limits chosen for the PDF before any lines are added
    drawn_limits = ax.get_xlim()
    interval = weighted.credible_interval([5, 95])
    centre = weighted.average("median")
    if title:
        plus = np.round(interval[1] - centre, 2)
        minus = np.round(centre - interval[0], 2)
        ax.set_title(
            r"$%s^{+%s}_{-%s}$" % (np.round(centre, 2), plus, minus)
        )
    if plot_percentile:
        for bound in interval:
            ax.axvline(
                bound, color=color, linestyle="--", linewidth=linewidth
            )
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(inj_value, color=injection_color, **_default_inj_kwargs)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(drawn_limits)
    if fig is not None:
        fig.tight_layout()
        return fig
    return ax
def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"fill": True, "alpha": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde
    hist: Bool, optional
        if True, plot a histogram
    prior: list
        list of prior samples for param
    weights: list
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals (and the HPD interval when available)
    color: str, optional
        color you wish to use for the plot
    autoscale: Bool, optional
        if True, use the x axis limits recorded directly after drawing the
        samples
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    plot_hdp: Bool, optional
        if True, plot dotted vertical lines showing the "90% HPD" interval
        stored in key_data, when available
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples is <= max_vline, draw the unique
        samples as vertical lines rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed summary statistics. Must contain "5th percentile",
        "95th percentile" and "median", and may contain "90% HPD". When
        None, the statistics are computed from samples
    **plot_kwargs: dict, optional
        all additional kwargs passed to the histogram and kde plotting
        functions

    Returns
    -------
    fig or ax
        the axis if no figure is available, otherwise the figure
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s", param)
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    unique_samples = set(samples)
    _xlims = None
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only label the first line to avoid duplicate legend entries
            ax.axvline(
                _sample, color=color, label=label if _ind == 0 else None
            )
        _xlims = ax.get_xlim()
    else:
        if hist:
            # merge into a new dict: the default dict is shared between calls
            # and must never be modified in place
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = {**_hist_kwargs, "histtype": "bar"}
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            from pesummary.core.plots.seaborn.kde import kdeplot

            _kde_kwargs = dict(kde_kwargs)
            kde_kernel = _kde_kwargs.pop("kde_kernel", None)
            # build a new dict for the same reason as the histogram kwargs
            kwargs = {
                **_default_kde_kwargs,
                "kde_kwargs": _kde_kwargs,
                "kde_kernel": kde_kernel,
                "weights": weights,
                **plot_kwargs,
            }
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                # NOTE(review): the prior is drawn with the same kwargs as
                # the samples, including `weights` - confirm intended
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **kwargs
                )
    if _xlims is None:
        # neither a histogram nor a kde was drawn; fall back to the current
        # limits so that `autoscale` cannot hit an unbound variable
        _xlims = ax.get_xlim()

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data:
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    has_hdp = isinstance(hdp, (list, np.ndarray))
    vline_width = hist_kwargs.get("linewidth", 1.75)
    if plot_percentile:
        for pp in percentile:
            ax.axvline(pp, color=color, linestyle="--", linewidth=vline_width)
    if plot_hdp and has_hdp:
        for pp in hdp:
            ax.axvline(pp, color=color, linestyle=":", linewidth=vline_width)
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if not has_hdp and np.isnan(hdp):
            _base += r"$"
        else:
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
        ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig
def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of every mcmc chain for a given parameter and
    quote the Gelman-Rubin statistic in the title

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        colors to cycle through, one is drawn for each mcmc chain
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        figure containing one histogram per chain
    """
    colours = cycle(colorcycle)
    fig, axis = figure(gca=True)
    for chain in samples:
        fig = _1d_histogram_plot(
            param, chain, latex_label, fig=fig, title=False, autoscale=False,
            color=next(colours), **kwargs
        )
    axis.set_title(f"Gelman-Rubin: {gelman_rubin(samples)}")
    return fig
def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Overlay the 1d distribution of many random subsets of the samples for
    a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        colors to cycle through, one is drawn for each test
    nsamples: int, optional
        number of samples to randomly draw, without replacement, from
        samples. If larger than the number of available samples, half of the
        available samples are drawn instead. Default 1000
    ntests: int, optional
        number of tests to perform. Default 100
    shade: Bool, optional
        passed to _1d_histogram_plot. Default False
    plot_percentile: Bool, optional
        passed to _1d_histogram_plot. Default False
    kde: Bool, optional
        passed to _1d_histogram_plot. Default True
    hist: Bool, optional
        passed to _1d_histogram_plot. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        figure containing one distribution per test
    """
    available = len(samples)
    if nsamples > available:
        nsamples = available // 2
    # draw every subset up front, before anything is plotted
    subsets = []
    for _ in range(ntests):
        subsets.append(
            np.random.choice(samples, size=nsamples, replace=False)
        )
    colours = cycle(colorcycle)
    fig, axis = figure(gca=True)
    for subset in subsets:
        fig = _1d_histogram_plot(
            param, subset, latex_label, fig=fig, title=False,
            autoscale=False, color=next(colours), shade=shade, kde=kde,
            hist=hist, plot_percentile=plot_percentile, **kwargs
        )
    axis.set_title(f"Ntests: {ntests}, Nsamples per test: {nsamples}")
    fig.tight_layout()
    return fig
def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k",
    weights=None, **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        legend label for each analysis
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool
        if True, plot a kde
    hist: Bool
        if True, plot a histogram
    linestyles: list
        list of linestyles for each set of samples
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist. The "linewidth"
        entry is always replaced with 2.5
    max_vline: int, optional
        passed to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure. Default (8, 6)
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    max_inj_line: int, optional
        maximum number of unique injection values to draw. If exceeded, no
        injection values are drawn. Default 1
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        figure containing one histogram per analysis

    Raises
    ------
    ValueError
        if the number of weights, injection values or injection colors does
        not match the number of analyses
    """
    logger.debug("Generating the 1d comparison histogram plot for %s", param)
    n_analyses = len(samples)
    if linestyles is None:
        linestyles = ["-"] * n_analyses
    if inj_value is None:
        inj_value = [None] * n_analyses
    if weights is None:
        weights = [None] * n_analyses
    if len(weights) != n_analyses:
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if isinstance(inj_value, (list, np.ndarray)):
        if len(inj_value) != n_analyses:
            raise ValueError(
                "Please provide an injection for each analysis or a single "
                "injection value which will be used for all histograms"
            )
    else:
        inj_value = [inj_value] * n_analyses

    if isinstance(injection_color, str):
        injection_color = [injection_color] * n_analyses
    elif len(injection_color) != n_analyses:
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array(
        [_ for _ in inj_value if _ is not None]
    ).flatten()
    n_unique_injections = len(set(flat_injection))
    if n_unique_injections > max_inj_line:
        logger.warning(
            "Number of unique injection values (%s) is more than the maximum "
            "allowed injection value (%s). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`",
            n_unique_injections, max_inj_line
        )
        inj_value = [None] * n_analyses

    if latex_friendly:
        # escape underscores once, on a new list, so the caller's labels are
        # never modified
        labels = [label.replace("_", r"\_") for label in labels]
    # work on a copy: updating in place would modify the caller's dict and
    # the default dict that is shared between calls
    hist_kwargs = {**hist_kwargs, "linewidth": 2.5}
    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_histogram_plot(
            param, i, latex_label, kde=kde, hist=hist, kde_kwargs=kde_kwargs,
            max_vline=max_vline, grid=grid, title=False, autoscale=False,
            label=labels[num], color=colors[num], fig=fig,
            hist_kwargs=hist_kwargs, inj_value=inj_value[num],
            injection_color=injection_color[num], weights=weights[num],
            linestyle=linestyles[num], _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Compare the samples of a given parameter for different analyses with
    horizontal box plots

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors for the different analyses. Currently not used by this
        function
    latex_label: str
        latex label for param, used as the x axis label
    labels: list
        label for each analysis, annotated above the corresponding box
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        figure containing the box plots
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, axis = figure(gca=True)
    # the labels are centred half way between the global extremes
    lowest = np.min([np.min(chain) for chain in samples])
    highest = np.max([np.max(chain) for chain in samples])
    centre = 0.5 * (highest + lowest)
    axis.boxplot(samples, widths=0.2, vert=False, whis=np.inf, labels=labels)
    for position, text in enumerate(labels, start=1):
        axis.annotate(
            text, xy=(centre, 1), xytext=(centre, position + 0.2),
            ha="center"
        )
    axis.set_yticks([])
    axis.set_xlabel(latex_label)
    fig.tight_layout()
    axis.grid(visible=grid)
    return fig
def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plot for a given set of samples

    Parameters
    ----------
    samples: dict
        samples for each parameter, keyed by parameter name
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters you wish to include in the corner plot. Entries which are
        not in parameters are dropped with a warning. Default all parameters
    parameters: list, optional
        parameters which are available for plotting. Default all keys in
        samples
    **kwargs: dict, optional
        all additional kwargs passed to corner, on top of conf.corner_kwargs

    Returns
    -------
    figure: ExistingFigure
        figure containing the corner plot
    included_parameters: list
        parameters which were plotted
    data: dict
        size ("width", "height") in pixels of the last of the first two
        axes, position ("x0", "y0") in pixels of the first axis, and the
        horizontal gap ("seperation") between the first two axes

    Raises
    ------
    ValueError
        if none of corner_parameters are in parameters
    """
    logger.debug("Generating the corner plot")
    # set the default kwargs
    default_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is not None:
        included_parameters = [i for i in corner_parameters if i in parameters]
        excluded_parameters = [
            i for i in corner_parameters if i not in parameters
        ]
        if excluded_parameters:
            plural = len(excluded_parameters) > 1
            logger.warning(
                f"Removing the parameter{'s' if plural else ''}: "
                f"{', '.join(excluded_parameters)} from the corner plot as "
                f"{'they are' if plural else 'it is'} not available in the "
                f"posterior table. This may affect the truth lines if "
                f"provided."
            )
        if not included_parameters:
            raise ValueError(
                "None of the chosen parameters are in the posterior "
                "samples table. Please choose other parameters to plot"
            )
    else:
        included_parameters = parameters
    # size the array from a parameter that is actually plotted
    xs = np.zeros(
        [len(included_parameters), len(samples[included_parameters[0]])]
    )
    for num, i in enumerate(included_parameters):
        xs[num] = samples[i]
    default_kwargs.update(kwargs)
    default_kwargs["range"] = [1.0] * len(included_parameters)
    default_kwargs["labels"] = [latex_labels[i] for i in included_parameters]

    _figure = ExistingFigure(
        corner(xs.T, included_parameters, **default_kwargs)
    )
    # grab the axes of the subplots
    axes = _figure.get_axes()
    axes_of_interest = axes[:2]
    location = []
    for i in axes_of_interest:
        extent = i.get_window_extent().transformed(
            _figure.dpi_scale_trans.inverted()
        )
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
    width = extent.width * _figure.dpi
    height = extent.height * _figure.dpi
    try:
        seperation = float(abs(location[0][0] - location[1][0]) - width)
    except IndexError:
        # only a single axis is available
        seperation = None
    # explicitly cast to float to ensure correct rendering in JS
    # https://git.ligo.org/lscsoft/pesummary/-/issues/332
    data = {
        "width": float(width),
        "height": float(height),
        "seperation": seperation,
        "x0": float(location[0][0]),
        # second entry of each location is the y coordinate
        "y0": float(location[0][1]),
    }
    return _figure, included_parameters, data
def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    latex_friendly=True, **kwargs
):
    """Generate a corner plot which contains multiple datasets

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default the
        parameters which are common to all datasets
    colors: list, optional
        unique colors for each dataset
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default True
    **kwargs: dict
        all kwargs are passed to `corner.corner`

    Returns
    -------
    fig: figure
        figure containing every dataset

    Raises
    ------
    ValueError
        if no datasets are provided, or if there are more datasets than
        colors
    """
    if not samples:
        raise ValueError("Please provide at least one dataset to plot")
    parameters = corner_parameters
    if corner_parameters is None:
        _parameters = [list(_samples.keys()) for _samples in samples.values()]
        # keep the parameters which are common to every dataset
        parameters = [
            i for i in _parameters[0]
            if all(i in _params for _params in _parameters)
        ]
    if len(samples) > len(colors):
        raise ValueError("Please provide a unique color for each dataset")

    # copy so that the dict provided by the caller is never modified
    base_hist_kwargs = dict(kwargs.get("hist_kwargs", {}))
    base_hist_kwargs["density"] = True
    lines = []
    fig = None
    for num, (label, posterior) in enumerate(samples.items()):
        if latex_friendly:
            label = label.replace("_", r"\_")
        lines.append(mlines.Line2D([], [], color=colors[num], label=label))
        _samples = {
            param: value for param, value in posterior.items() if param in
            parameters
        }
        kwargs["hist_kwargs"] = {**base_hist_kwargs, "color": colors[num]}
        # every dataset after the first is drawn onto the existing figure
        existing = {} if num == 0 else {"fig": fig}
        fig, _, _ = _make_corner_plot(
            _samples, latex_labels, corner_parameters=corner_parameters,
            parameters=parameters, color=colors[num], **existing, **kwargs
        )
    fig.legend(handles=lines, loc="upper right")
    return fig