Coverage for pesummary/core/plots/plot.py: 49.7%
366 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import (
4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin,
5)
6from pesummary.core.plots.seaborn.kde import kdeplot
7from pesummary.core.plots.corner import corner
8from pesummary.core.plots.figure import figure, ExistingFigure
9from pesummary import conf
11import matplotlib.lines as mlines
12import copy
13from itertools import cycle
15import numpy as np
16from scipy import signal
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Executed once when this module is imported; the helper comes from
# pesummary.utils.utils.
_check_latex_install()

# Default value of the `legend_kwargs` argument of the comparison plots in
# this module; the dictionary is forwarded to `ax.legend`.
_default_legend_kwargs = dict(
    bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), loc=3, handlelength=3, mode="expand",
    borderaxespad=0.0,
)
27def _update_1d_comparison_legend(legend, linestyles, linewidth=1.75):
28 """Update the width and style of lines in the legend for a 1d comparison plot.
29 """
30 try:
31 handles = legend.legend_handles
32 except AttributeError: # matplotlib < 3.7.0
33 handles = legend.legendHandles
34 for handle, style in zip(handles, linestyles):
35 handle.set_linewidth(linewidth)
36 handle.set_linestyle(style)
def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Generate the autocorrelation function for a set of samples for a given
    parameter for a given approximant.

    Only the second half of `samples` is used to compute the autocorrelation
    function, which is normalised by its zero-lag value before plotting.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure
        existing figure you wish to use
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers used to draw the autocorrelation function.
        Default 0.5
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing the autocorrelation plot
    """
    import warnings

    logger.debug("Generating the autocorrelation function for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    # convert up front so that plain lists survive the arithmetic below
    samples = np.asanyarray(samples)
    samples = samples[int(len(samples) / 2):]
    # Hack to make test pass with python3.8
    if color == "$":
        color = conf.color
    # RuntimeWarnings are silenced only for the duration of this call rather
    # than by permanently editing the process-wide warning filters
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        x = samples - np.mean(samples)
        y = np.conj(x[::-1])
        acf = np.fft.ifftshift(signal.fftconvolve(y, x, mode="full"))
        acf = acf[0:len(samples)]
        ax.plot(
            acf / acf[0], linestyle=" ", marker="o", markersize=markersize,
            color=color
        )
        ax.ticklabel_format(axis="x", style="plain")
        ax.set_xlabel("lag")
        ax.set_ylabel("ACF")
        ax.grid(visible=grid)
        fig.tight_layout()
    return fig
def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Draw the autocorrelation function of every mcmc chain on one figure.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure holding one autocorrelation curve per chain
    """
    palette = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _autocorrelation_plot(
            param, chain, fig=fig, markersize=1.25, color=chain_color,
            grid=grid
        )
    return fig
def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter the samples of a parameter against their index to show how
    they evolve.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param, used as the y axis label
    inj_value: float
        value that was injected. Currently not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    color: str, optional
        color you wish to use to plot the scatter points when z is None
    markersize: float, optional
        size of the scatter points. Default 0.5
    grid: Bool, optional
        if True, plot a grid
    z: list, optional
        values used to color the scatter points. When provided a colorbar
        is added to the figure
    z_label: str, optional
        label for the colorbar
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: figure
        the figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)
    with_colorbar = z is not None
    point_colors = z if with_colorbar else color
    scatter = ax.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=point_colors, **kwargs
    )
    if with_colorbar:
        cbar = fig.colorbar(scatter)
        if z_label is not None:
            cbar.set_label(z_label)
    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("samples")
    ax.set_ylabel(latex_label)
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Generate a scatter plot showing the evolution of the samples in each
    mcmc chain for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float
        value that was injected. Forwarded to _sample_evolution_plot
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure holding one scatter series per chain
    """
    cycol = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for ss in samples:
        # forward the caller's inj_value instead of a hard-coded None
        fig = _sample_evolution_plot(
            param, ss, latex_label, inj_value=inj_value, fig=fig,
            markersize=1.25, color=next(cycol), grid=grid
        )
    return fig
def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", weights=None, **kwargs
):
    """Generate the cumulative distribution function for a given parameter for
    a given approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title to the 1d cdf plot giving the median and
        symmetric 90% credible intervals. These are computed from `samples`
        with np.percentile/np.median and do not use `weights`
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    weights: list, optional
        list of weights for samples. When provided the CDF is built from a
        50 bin weighted histogram. Default None
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: figure
        the figure containing the CDF
    """
    logger.debug("Generating the 1d CDF for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    if weights is None:
        # np.sort returns a sorted copy for any sequence, so the caller's
        # samples are untouched and inputs without a .sort method also work
        sorted_samples = np.sort(samples)
        ax.plot(
            sorted_samples, np.linspace(0, 1, len(sorted_samples)), color=color,
            linestyle=linestyle, **kwargs
        )
    else:
        hist, bin_edges = np.histogram(
            samples, bins=50, density=True, weights=weights
        )
        total = np.cumsum(hist)
        total /= total[-1]
        ax.plot(
            0.5 * (bin_edges[:-1] + bin_edges[1:]),
            total, color=color, linestyle=linestyle,
            **kwargs
        )
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    if title:
        # the summary statistics are only needed for the title
        upper_percentile = np.percentile(samples, 95)
        lower_percentile = np.percentile(samples, 5)
        median = np.median(samples)
        upper = np.round(upper_percentile - median, 2)
        lower = np.round(median - lower_percentile, 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, upper, lower))
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the cumulative distribution function of every mcmc chain and
    report the Gelman-Rubin statistic in the title.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        the figure holding one CDF per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, color=chain_color,
            title=False, grid=grid, **kwargs
        )
    ax.set_title("Gelman-Rubin: {}".format(gelman_rubin(samples)))
    return fig
def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, weights=None,
    **kwargs
):
    """Generate a plot to compare the cdfs for a given parameter for different
    analyses.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        label to show in the legend for each analysis
    linestyles: list, optional
        linestyle to use for each analysis. Default "-" for every analysis
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure
        the figure containing the comparison plot

    Raises
    ------
    ValueError
        if the number of weights does not match the number of analyses
    """
    logger.debug("Generating the 1d comparison CDF for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if weights is None:
        weights = [None] * len(samples)
    if len(weights) != len(samples):
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if latex_friendly:
        # copy once so that the caller's labels are never modified
        labels = copy.deepcopy(labels)
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_cdf_plot(
            param, i, latex_label, fig=fig, color=colors[num], title=False,
            grid=grid, linestyle=linestyles[num], weights=weights[num], **kwargs
        )
        if latex_friendly:
            # raw string: "\_" is an invalid escape sequence in a normal literal
            labels[num] = labels[num].replace("_", r"\_")
        handles.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Plot a PDF that is evaluated on a set of points

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    x: list
        points at which the PDF is evaluated
    pdf: list
        value of the PDF at each point in x. Used as the weights of x when
        computing the credible interval and median
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list
        list of prior samples for param. Currently not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals
    color: str, optional
        color you wish to use for the PDF and the percentile lines
    autoscale: Bool, optional
        if True, restore the x limits recorded directly after the PDF was
        drawn
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth to use for the percentile lines
    injection_color: str, optional
        color of vertical line showing the injected value

    Returns
    -------
    matplotlib axis if no figure was created or supplied (only `ax` given),
    otherwise the figure
    """
    from pesummary.utils.array import Array

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()

    pdf = Array(x, weights=pdf)

    ax.plot(pdf, pdf.weights, color=color, linestyle=linestyle, label=label)
    # remember the limits before any vertical lines are added
    initial_xlims = ax.get_xlim()
    interval = pdf.credible_interval([5, 95])
    median = pdf.average("median")
    if title:
        plus = np.round(interval[1] - median, 2)
        minus = np.round(median - interval[0], 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, plus, minus))
    if plot_percentile:
        for bound in interval:
            ax.axvline(
                bound, color=color, linestyle="--", linewidth=linewidth
            )
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(initial_xlims)
    if fig is not None:
        fig.tight_layout()
        return fig
    return ax
def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"shade": True, "alpha_shade": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    None of the dictionary arguments (including the dictionary defaults in
    the signature) are modified by this function; merged copies are used.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde of the samples
    hist: Bool, optional
        if True, plot a histogram
    prior: list
        list of prior samples for param
    weights: list
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals (and the 90% HPD interval when available in key_data)
    color: str, optional
        color you wish to use for the plot
    autoscale: Bool, optional
        if True, restore the x limits recorded directly after the samples
        were drawn
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs, dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    plot_hdp: Bool, optional
        if True, plot dotted vertical lines showing the "90% HPD" interval
        stored in key_data, when it is a list or array
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples <= max_vline draw the unique samples
        as vertical lines rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed summary statistics. When provided, the keys
        "5th percentile", "95th percentile" and "median" are read and
        "90% HPD" is used if present
    **plot_kwargs: dict, optional
        additional kwargs passed to ax.hist and to kdeplot

    Returns
    -------
    matplotlib axis if no figure was created or supplied (only `ax` given),
    otherwise the figure
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s" % (param))
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # x limits recorded right after the samples are drawn. Stays None when
    # neither a histogram nor a kde is requested, which leaves the limits
    # untouched instead of failing on an unbound name
    _xlims = None
    unique_samples = set(samples)
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only the first line carries the legend label
            _label = label if _ind == 0 else None
            ax.axvline(_sample, color=color, label=_label)
        _xlims = ax.get_xlim()
    else:
        if hist:
            # merge into a fresh dict: updating the default in place would
            # leak these settings into every later call
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = {**_hist_kwargs, "histtype": "bar"}
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            _kde_kwargs = kde_kwargs.copy()
            # copy for the same reason as the histogram defaults above
            kwargs = dict(_default_kde_kwargs)
            kwargs.update({
                "kde_kwargs": _kde_kwargs,
                "kde_kernel": _kde_kwargs.pop("kde_kernel", None),
                "variance_atol": _kde_kwargs.pop("variance_atol", 1e-8),
                "weights": weights
            })
            kwargs.update(plot_kwargs)
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                # `weights` belong to the posterior samples, so they are not
                # applied to the prior (matches the histogram branch)
                _prior_kde_kwargs = {**kwargs, "weights": None}
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **_prior_kde_kwargs
                )

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data.keys():
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    # only an explicit pair of bounds is treated as an HPD interval; anything
    # else (e.g. the nan placeholder) falls back to the plain title
    has_hdp = isinstance(hdp, (list, np.ndarray))
    _line_width = hist_kwargs.get("linewidth", 1.75)
    if plot_percentile:
        for pp in percentile:
            ax.axvline(pp, color=color, linestyle="--", linewidth=_line_width)
    if plot_hdp and has_hdp:
        for pp in hdp:
            ax.axvline(pp, color=color, linestyle=":", linewidth=_line_width)
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if has_hdp:
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
        else:
            _base += r"$"
        ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale and _xlims is not None:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig
def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of every mcmc chain and report the
    Gelman-Rubin statistic in the title.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure holding one histogram per chain
    """
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(palette)
        fig = _1d_histogram_plot(
            param, chain, latex_label, color=chain_color, title=False,
            autoscale=False, fig=fig, **kwargs
        )
    ax.set_title("Gelman-Rubin: {}".format(gelman_rubin(samples)))
    return fig
def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Overlay the 1d distribution of many random subsets of the samples

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different tests
    nsamples: int, optional
        number of samples to randomly draw (without replacement) from
        samples. If larger than the number of samples, half of the number of
        samples is used instead. Default 1000
    ntests: int, optional
        number of tests to perform. Default 100
    shade: Bool, optional
        forwarded to _1d_histogram_plot as an additional kwarg. Default False
    plot_percentile: Bool, optional
        forwarded to _1d_histogram_plot. Default False
    kde: Bool, optional
        forwarded to _1d_histogram_plot. Default True
    hist: Bool, optional
        forwarded to _1d_histogram_plot. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure holding one curve per test
    """
    available = len(samples)
    if nsamples > available:
        nsamples = available // 2
    # draw every subset before plotting anything
    subsets = []
    for _ in range(ntests):
        subsets.append(
            np.random.choice(samples, size=nsamples, replace=False)
        )
    palette = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for subset in subsets:
        subset_color = next(palette)
        fig = _1d_histogram_plot(
            param, subset, latex_label, color=subset_color, title=False,
            autoscale=False, fig=fig, shade=shade,
            plot_percentile=plot_percentile, kde=kde, hist=hist, **kwargs
        )
    ax.set_title("Ntests: {}, Nsamples per test: {}".format(ntests, nsamples))
    fig.tight_layout()
    return fig
def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k",
    weights=None, **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different analyses.

    Neither `labels` nor `hist_kwargs` are modified; copies are used.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors to be used to differentiate the different analyses
    latex_label: str
        latex label for param
    labels: list
        label to show in the legend for each analysis
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool
        forwarded to _1d_histogram_plot
    hist: Bool
        forwarded to _1d_histogram_plot
    linestyles: list
        list of linestyles for each set of samples
    kde_kwargs: dict, optional
        forwarded to _1d_histogram_plot
    hist_kwargs: dict, optional
        forwarded to _1d_histogram_plot with "linewidth" set to 2.5
    max_vline: int, optional
        forwarded to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure. Default (8, 6)
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    max_inj_line: int, optional
        maximum number of unique injection values that will be drawn. If
        exceeded, no injection values are plotted. Default 1
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    weights: list, optional
        list of weights to use for each analysis. Default None
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure
        the figure containing the comparison plot

    Raises
    ------
    ValueError
        if the number of weights, injection values or injection colors does
        not match the number of analyses
    """
    logger.debug("Generating the 1d comparison histogram plot for %s" % (param))
    n_analyses = len(samples)
    if linestyles is None:
        linestyles = ["-"] * n_analyses
    if inj_value is None:
        inj_value = [None] * n_analyses
    if weights is None:
        weights = [None] * n_analyses
    if len(weights) != n_analyses:
        raise ValueError(
            "Please provide a set of weights for each analysis"
        )
    if not isinstance(inj_value, (list, np.ndarray)):
        inj_value = [inj_value] * n_analyses
    elif len(inj_value) != n_analyses:
        raise ValueError(
            "Please provide an injection for each analysis or a single "
            "injection value which will be used for all histograms"
        )

    if isinstance(injection_color, str):
        injection_color = [injection_color] * n_analyses
    elif len(injection_color) != n_analyses:
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array([_ for _ in inj_value if _ is not None]).flatten()
    n_unique_injections = len(set(flat_injection))
    if n_unique_injections > max_inj_line:
        logger.warning(
            "Number of unique injection values ({}) is more than the maximum "
            "allowed injection value ({}). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`".format(
                n_unique_injections, max_inj_line
            )
        )
        inj_value = [None] * n_analyses

    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    # build a new dict: updating in place would modify the shared default
    # argument and any dictionary supplied by the caller
    hist_kwargs = {**hist_kwargs, "linewidth": 2.5}
    if latex_friendly:
        # copy once so that the caller's labels are never modified
        labels = copy.deepcopy(labels)
    for num, i in enumerate(samples):
        if latex_friendly:
            # raw string: "\_" is an invalid escape sequence in a normal literal
            labels[num] = labels[num].replace("_", r"\_")
        fig = _1d_histogram_plot(
            param, i, latex_label, kde=kde, hist=hist, kde_kwargs=kde_kwargs,
            max_vline=max_vline, grid=grid, title=False, autoscale=False,
            label=labels[num], color=colors[num], fig=fig, hist_kwargs=hist_kwargs,
            inj_value=inj_value[num], injection_color=injection_color[num],
            weights=weights[num], linestyle=linestyles[num], _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    _update_1d_comparison_legend(legend, linestyles)
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Draw horizontal box plots comparing the samples of several analyses

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each analysis
    colors: list
        list of colors for the different analyses. Currently not used by
        this function
    latex_label: str
        latex label for param, used as the x axis label
    labels: list
        label for each analysis, annotated above the corresponding box
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure
        the figure containing the box plots
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, ax = figure(gca=True)
    largest = np.max([np.max(dataset) for dataset in samples])
    smallest = np.min([np.min(dataset) for dataset in samples])
    # annotations are centred on the midpoint of the overall sample range
    centre = (largest + smallest) * 0.5
    ax.boxplot(samples, widths=0.2, vert=False, whis=np.inf, labels=labels)
    for position, text in enumerate(labels, start=1):
        ax.annotate(
            text, xy=(centre, 1), xytext=(centre, position + 0.2),
            ha="center"
        )
    ax.set_yticks([])
    ax.set_xlabel(latex_label)
    fig.tight_layout()
    ax.grid(visible=grid)
    return fig
def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plot for a set of samples

    Parameters
    ----------
    samples: dict
        dictionary of samples keyed by parameter name
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        if provided, only parameters in this list are included in the plot
    parameters: list, optional
        parameters to consider. Default all keys in samples
    **kwargs: dict, optional
        kwargs that update the defaults in conf.corner_kwargs before being
        passed to corner. "range" and "labels" are always set by this
        function

    Returns
    -------
    _figure: ExistingFigure
        the corner plot
    included_parameters: list
        parameters that were included in the plot
    data: dict
        geometry of the first axes in pixels: "width" and "height" of the
        last of the first two axes, the "seperation" between the first two
        axes (None if there is only one axis) and the "x0", "y0" position
        of the first axis
    """
    logger.debug("Generating the corner plot")
    # set the default kwargs
    default_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is not None:
        included_parameters = [i for i in parameters if i in corner_parameters]
    else:
        included_parameters = parameters
    xs = np.zeros([len(included_parameters), len(samples[parameters[0]])])
    for num, i in enumerate(included_parameters):
        xs[num] = samples[i]
    default_kwargs.update(kwargs)
    default_kwargs["range"] = [1.0] * len(included_parameters)
    default_kwargs["labels"] = [latex_labels[i] for i in included_parameters]

    _figure = ExistingFigure(corner(xs.T, included_parameters, **default_kwargs))
    # grab the axes of the subplots
    axes = _figure.get_axes()
    axes_of_interest = axes[:2]
    location = []
    for i in axes_of_interest:
        extent = i.get_window_extent().transformed(_figure.dpi_scale_trans.inverted())
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
    width, height = extent.width, extent.height
    width *= _figure.dpi
    height *= _figure.dpi
    try:
        seperation = abs(location[0][0] - location[1][0]) - width
    except IndexError:
        # only one axis, so there is no neighbour to measure against
        seperation = None
    data = {
        "width": width, "height": height, "seperation": seperation,
        # each location entry is [x0, y0]; "y0" previously repeated the x value
        "x0": location[0][0], "y0": location[0][1]
    }
    return _figure, included_parameters, data
def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    latex_friendly=True, **kwargs
):
    """Generate a corner plot which contains multiple datasets

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default the
        parameters that are common to every dataset
    colors: list, optional
        unique colors for each dataset
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default True
    **kwargs: dict
        all kwargs are passed to _make_corner_plot. "hist_kwargs" is copied
        and extended with "density" and "color"

    Returns
    -------
    fig: figure
        the corner plot containing every dataset

    Raises
    ------
    ValueError
        if no datasets are provided or if there are fewer colors than
        datasets
    """
    if not samples:
        # without this guard the function fails later on an unbound figure
        raise ValueError("Please provide at least one dataset")
    parameters = corner_parameters
    if corner_parameters is None:
        _parameters = [list(_samples.keys()) for _samples in samples.values()]
        parameters = [
            i for i in _parameters[0] if all(i in _params for _params in _parameters)
        ]
    if len(samples.keys()) > len(colors):
        raise ValueError("Please provide a unique color for each dataset")

    # work on a copy so that a dictionary supplied by the caller is not modified
    hist_kwargs = dict(kwargs.get("hist_kwargs", dict()))
    hist_kwargs["density"] = True
    kwargs["hist_kwargs"] = hist_kwargs
    lines = []
    for num, (label, posterior) in enumerate(samples.items()):
        if latex_friendly:
            # raw string: "\_" is an invalid escape sequence in a normal literal
            label = label.replace("_", r"\_")
        lines.append(mlines.Line2D([], [], color=colors[num], label=label))
        _samples = {
            param: value for param, value in posterior.items() if param in
            parameters
        }
        hist_kwargs["color"] = colors[num]
        if num == 0:
            fig, _, _ = _make_corner_plot(
                _samples, latex_labels, corner_parameters=corner_parameters,
                parameters=parameters, color=colors[num], **kwargs
            )
        else:
            # later datasets are drawn on top of the first dataset's figure
            fig, _, _ = _make_corner_plot(
                _samples, latex_labels, corner_parameters=corner_parameters,
                fig=fig, parameters=parameters, color=colors[num], **kwargs
            )
    fig.legend(handles=lines, loc="upper right")
    return fig