Coverage for pesummary/core/plots/plot.py: 85.4%
349 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import (
4 logger, number_of_columns_for_legend, _check_latex_install, gelman_rubin,
5)
6from pesummary.core.plots.seaborn.kde import kdeplot
7from pesummary.core.plots.corner import corner
8from pesummary.core.plots.figure import figure, ExistingFigure
9from pesummary import conf
11import matplotlib.lines as mlines
12import copy
13from itertools import cycle
15import numpy as np
16from scipy import signal
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Runs once, when this module is imported.
_check_latex_install()

# Default keyword arguments handed to ``ax.legend()`` by the comparison plots
# in this module.
_default_legend_kwargs = {
    "bbox_to_anchor": (0.0, 1.02, 1.0, 0.102),
    "loc": 3,
    "handlelength": 3,
    "mode": "expand",
    "borderaxespad": 0.0,
}
def _autocorrelation_plot(
    param, samples, fig=None, color=conf.color, markersize=0.5, grid=True
):
    """Generate the autocorrelation function for a set of samples for a given
    parameter for a given approximant.

    The first half of the samples is discarded. The autocorrelation of the
    mean-subtracted remainder is then plotted against lag, normalised by its
    value at zero lag.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use for the autocorrelation plot
    markersize: float, optional
        size of the markers used to draw the autocorrelation function
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure containing the autocorrelation plot
    """
    import warnings

    logger.debug("Generating the autocorrelation function for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    # Convert up front so that plain Python lists support the array
    # arithmetic below.
    samples = np.asarray(samples)
    samples = samples[len(samples) // 2:]
    # Hack to make test pass with python3.8
    if color == "$":
        color = conf.color
    # Silence RuntimeWarnings only for the duration of this call instead of
    # permanently altering the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        x = samples - np.mean(samples)
        y = np.conj(x[::-1])
        # ifftshift moves the zero-lag term of the full correlation to index 0
        acf = np.fft.ifftshift(signal.fftconvolve(y, x, mode="full"))
        acf = acf[0:samples.shape[0]]
        ax.plot(
            acf / acf[0], linestyle=" ", marker="o", markersize=markersize,
            color=color
        )
        ax.ticklabel_format(axis="x", style="plain")
        ax.set_xlabel("lag")
        ax.set_ylabel("ACF")
        ax.grid(visible=grid)
        fig.tight_layout()
    return fig
def _autocorrelation_plot_mcmc(
    param, samples, colorcycle=conf.colorcycle, grid=True
):
    """Overlay the autocorrelation function of every mcmc chain on one figure

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing a list of samples for param for each mcmc chain
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure holding one autocorrelation curve per chain
    """
    chain_colors = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(chain_colors)
        # every chain is drawn onto the same figure
        fig = _autocorrelation_plot(
            param, chain, fig=fig, color=chain_color, markersize=1.25,
            grid=grid
        )
    return fig
def _sample_evolution_plot(
    param, samples, latex_label, inj_value=None, fig=None, color=conf.color,
    markersize=0.5, grid=True, z=None, z_label=None, **kwargs
):
    """Scatter the samples of a parameter against their index to show how the
    samples evolve.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float
        value that was injected. Accepted for interface compatibility; it is
        not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use to plot the scatter points. Only used when z is
        None
    markersize: float, optional
        size of the scatter points
    grid: Bool, optional
        if True, plot a grid
    z: array-like, optional
        values used to color the scatter points. When given, a colorbar is
        added to the figure
    z_label: str, optional
        label for the colorbar. Only used when z is given
    **kwargs: dict, optional
        all additional kwargs passed to ax.scatter

    Returns
    -------
    fig: figure containing the scatter plot
    """
    logger.debug("Generating the sample scatter plot for %s" % (param))
    if fig is not None:
        ax = fig.gca()
    else:
        fig, ax = figure(gca=True)

    color_by_value = z is not None
    point_colors = z if color_by_value else color
    scatter = ax.scatter(
        range(len(samples)), samples, marker="o", s=markersize,
        c=point_colors, **kwargs
    )
    if color_by_value:
        cbar = fig.colorbar(scatter)
        if z_label is not None:
            cbar.set_label(z_label)

    ax.ticklabel_format(axis="x", style="plain")
    ax.set_xlabel("samples")
    ax.set_ylabel(latex_label)
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _sample_evolution_plot_mcmc(
    param, samples, latex_label, inj_value=None, colorcycle=conf.colorcycle,
    grid=True
):
    """Scatter the samples of every mcmc chain against their index on a
    single figure

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    inj_value: float
        value that was injected. Not forwarded to the per-chain plot, which
        is always called with inj_value=None
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure holding one scatter series per chain
    """
    chain_colors = cycle(colorcycle)
    fig, _ = figure(gca=True)
    for chain in samples:
        chain_color = next(chain_colors)
        fig = _sample_evolution_plot(
            param, chain, latex_label, inj_value=None, fig=fig,
            color=chain_color, markersize=1.25, grid=grid
        )
    return fig
def _1d_cdf_plot(
    param, samples, latex_label, fig=None, color=conf.color, title=True,
    grid=True, linestyle="-", **kwargs
):
    """Generate the cumulative distribution function for a given parameter for
    a given approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use. If None, a new figure is created
    color: str, optional
        color you wish to use to plot the CDF
    title: Bool, optional
        if True, add a title to the 1d cdf plot giving the median and
        symmetric 90% credible intervals
    grid: Bool, optional
        if True, plot a grid
    linestyle: str, optional
        linestyle to use for plotting the CDF. Default "-"
    **kwargs: dict, optional
        all additional kwargs passed to ax.plot

    Returns
    -------
    fig: figure containing the CDF
    """
    logger.debug("Generating the 1d CDF for %s" % (param))
    if fig is None:
        fig, ax = figure(gca=True)
    else:
        ax = fig.gca()
    # np.sort returns a sorted copy, so the caller's samples stay untouched
    # and any sequence (not only objects with a .sort() method) is accepted.
    sorted_samples = np.sort(np.asarray(samples))
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    if title:
        # The summary statistics are only needed for the title.
        upper_percentile = np.percentile(samples, 95)
        lower_percentile = np.percentile(samples, 5)
        median = np.median(samples)
        upper = np.round(upper_percentile - median, 2)
        lower = np.round(median - lower_percentile, 2)
        median = np.round(median, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (median, upper, lower))
    ax.plot(
        sorted_samples, np.linspace(0, 1, len(sorted_samples)), color=color,
        linestyle=linestyle, **kwargs
    )
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_cdf_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, grid=True, **kwargs
):
    """Overlay the cumulative distribution function of every mcmc chain on a
    single figure and title it with the Gelman-Rubin statistic

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array containing the samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    grid: Bool, optional
        if True, plot a grid
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure holding one CDF per chain
    """
    chain_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(chain_colors)
        fig = _1d_cdf_plot(
            param, chain, latex_label, fig=fig, color=chain_color,
            title=False, grid=grid, **kwargs
        )
    # the statistic is computed from all chains together
    ax.set_title("Gelman-Rubin: {}".format(gelman_rubin(samples)))
    return fig
def _1d_cdf_comparison_plot(
    param, samples, colors, latex_label, labels, linestyles=None, grid=True,
    legend_kwargs=_default_legend_kwargs, latex_friendly=False, **kwargs
):
    """Generate a plot to compare the cdfs for a given parameter for different
    approximants.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each approximant
    colors: list
        list of colors to be used to differentiate the different approximants
    latex_label: str
        latex label for param
    labels: list
        legend label for each set of samples
    linestyles: list, optional
        linestyle for each set of samples. Default "-" for every set
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_cdf_plot

    Returns
    -------
    fig: figure containing one CDF per set of samples and a legend
    """
    logger.debug("Generating the 1d comparison CDF for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if latex_friendly:
        # Escape underscores once, on a new list, so the caller's labels are
        # left untouched.
        labels = [label.replace("_", r"\_") for label in labels]
    fig, ax = figure(figsize=(8, 6), gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_cdf_plot(
            param, i, latex_label, fig=fig, color=colors[num], title=False,
            grid=grid, linestyle=linestyles[num], **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    # Matplotlib renamed ``legendHandles`` to ``legend_handles``; prefer the
    # new attribute and fall back to the old one on older releases.
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for num, legobj in enumerate(legend_handles):
        legobj.set_linewidth(1.75)
        legobj.set_linestyle(linestyles[num])
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Cumulative Density Function")
    ax.grid(visible=grid)
    ax.set_ylim([0, 1.05])
    fig.tight_layout()
    return fig
def _1d_analytic_plot(
    param, x, pdf, latex_label, inj_value=None, prior=None, fig=None, ax=None,
    title=True, color=conf.color, autoscale=True, grid=True, set_labels=True,
    plot_percentile=True, xlims=None, label=None, linestyle="-",
    linewidth=1.75, injection_color=conf.injection_color,
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"}, **plot_kwargs
):
    """Generate a plot to display a PDF

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot. Not used by this function
    x: array-like
        values at which the PDF is evaluated
    pdf: array-like
        value of the PDF at each x. Used as the weights of x when computing
        the median and credible interval
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    prior: list
        list of prior samples for param. Not used by this function
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals
    color: str, optional
        color you wish to use for the PDF and the percentile lines
    autoscale: Bool, optional
        if True, the x limits found right after the PDF is drawn are restored
        at the end, overriding xlims
    grid: Bool, optional
        if True, plot a grid
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    xlims: list, optional
        x axis limits you wish to use
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    linewidth: float, optional
        linewidth of the vertical percentile lines
    injection_color: str, optional
        color of vertical line showing the injected value
    **plot_kwargs: dict, optional
        accepted but not used by this function

    Returns
    -------
    the axis when no figure is available (ax given, fig is None), otherwise
    the figure
    """
    from pesummary.utils.array import Array

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()

    distribution = Array(x, weights=pdf)
    ax.plot(
        distribution, distribution.weights, color=color, linestyle=linestyle,
        label=label
    )
    # remember the limits of the curve before any vertical line is drawn
    curve_xlims = ax.get_xlim()
    interval = distribution.credible_interval([5, 95])
    central = distribution.average("median")
    if title:
        plus = np.round(interval[1] - central, 2)
        minus = np.round(central - interval[0], 2)
        central = np.round(central, 2)
        ax.set_title(r"$%s^{+%s}_{-%s}$" % (central, plus, minus))
    if plot_percentile:
        for bound in interval:
            ax.axvline(bound, color=color, linestyle="--", linewidth=linewidth)
    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(inj_value, color=injection_color, **_default_inj_kwargs)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale:
        ax.set_xlim(curve_xlims)
    if fig is not None:
        fig.tight_layout()
        return fig
    return ax
def _1d_histogram_plot(
    param, samples, latex_label, inj_value=None, kde=False, hist=True,
    prior=None, weights=None, fig=None, ax=None, title=True, color=conf.color,
    autoscale=True, grid=True, kde_kwargs={}, hist_kwargs={}, set_labels=True,
    plot_percentile=True, plot_hdp=True, xlims=None, max_vline=1, label=None,
    linestyle="-", injection_color=conf.injection_color, _default_hist_kwargs={
        "density": True, "bins": 50, "histtype": "step", "linewidth": 1.75
    }, _default_kde_kwargs={"shade": True, "alpha_shade": 0.1},
    _default_inj_kwargs={"linewidth": 2.5, "linestyle": "-"},
    key_data=None, **plot_kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: list
        list of samples for param
    latex_label: str
        latex label for param
    inj_value: float, optional
        value that was injected
    kde: Bool, optional
        if True, plot a kde of the samples
    hist: Bool, optional
        if True, plot a histogram
    prior: list
        list of prior samples for param
    weights: list
        list of weights for each sample
    fig: matplotlib.pyplot.figure, optional
        existing figure you wish to use
    ax: matplotlib.pyplot.axes._subplots.AxesSubplot, optional
        existing axis you wish to use
    title: Bool, optional
        if True, add a title showing the median and symmetric 90% credible
        intervals (and the 90% HPD interval when it is available in key_data)
    color: str, optional
        color you wish to use for the histogram/kde and percentile lines
    autoscale: Bool, optional
        if True, the x limits found right after the samples are drawn are
        restored at the end, overriding xlims
    grid: Bool, optional
        if True, plot a grid
    kde_kwargs: dict, optional
        optional kwargs to pass to the kde class
    hist_kwargs: dict, optional
        optional kwargs to pass to matplotlib.pyplot.hist
    set_labels: Bool, optional
        if True, add labels to the axes
    plot_percentile: Bool, optional
        if True, plot dashed vertical lines showing the 90% symmetric credible
        intervals
    plot_hdp: Bool, optional
        if True, plot dotted vertical lines showing the 90% HPD interval when
        it is available in key_data
    xlims: list, optional
        x axis limits you wish to use
    max_vline: int, optional
        if the number of unique samples <= max_vline, draw each unique sample
        as a vertical line rather than histogramming the data
    label: str, optional
        label you wish to use for the plot
    linestyle: str, optional
        linestyle you wish to use for the plot
    injection_color: str, optional
        color of vertical line showing the injected value
    key_data: dict, optional
        precomputed summary statistics. When given, the "5th percentile",
        "95th percentile" and "median" entries (and "90% HPD" if present) are
        used instead of being computed from samples
    **plot_kwargs: dict, optional
        additional kwargs passed to ax.hist and to kdeplot

    Returns
    -------
    the axis when no figure is available (ax given, fig is None), otherwise
    the figure
    """
    from pesummary.utils.array import Array

    logger.debug("Generating the 1d histogram plot for %s" % (param))
    samples = Array(samples, weights=weights)
    if ax is None and fig is None:
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    # Stays None when nothing is drawn (hist and kde both False), in which
    # case there are no data limits to restore at the end.
    _xlims = None
    unique_samples = set(samples)
    if len(unique_samples) <= max_vline:
        for _ind, _sample in enumerate(unique_samples):
            # only label the first line so the legend holds a single entry
            _label = label if _ind == 0 else None
            ax.axvline(_sample, color=color, label=_label)
        _xlims = ax.get_xlim()
    else:
        if hist:
            # Merge into a new dict: updating the default argument in place
            # would leak these settings into every later call.
            _hist_kwargs = {**_default_hist_kwargs, **hist_kwargs}
            ax.hist(
                samples, weights=weights, color=color, label=label,
                linestyle=linestyle, **_hist_kwargs, **plot_kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                _prior_hist_kwargs = dict(_hist_kwargs, histtype="bar")
                ax.hist(
                    prior, color=conf.prior_color, alpha=0.2, edgecolor="w",
                    linestyle=linestyle, **_prior_hist_kwargs, **plot_kwargs
                )
        if kde:
            _kde_kwargs = kde_kwargs.copy()
            # Copy the defaults for the same reason as for the histogram.
            kwargs = dict(_default_kde_kwargs)
            kwargs.update({
                "kde_kwargs": _kde_kwargs,
                "kde_kernel": _kde_kwargs.pop("kde_kernel", None),
                "variance_atol": _kde_kwargs.pop("variance_atol", 1e-8),
                "weights": weights
            })
            kwargs.update(plot_kwargs)
            kdeplot(
                samples, color=color, ax=ax, linestyle=linestyle, **kwargs
            )
            _xlims = ax.get_xlim()
            if prior is not None:
                # ``weights`` belong to the posterior samples; the prior is
                # drawn unweighted, as in the histogram branch.
                _prior_kde_kwargs = dict(kwargs, weights=None)
                kdeplot(
                    prior, color=conf.prior_color, ax=ax, linestyle=linestyle,
                    **_prior_kde_kwargs
                )

    if set_labels:
        ax.set_xlabel(latex_label)
        ax.set_ylabel("Probability Density")

    if inj_value is not None:
        ax.axvline(
            inj_value, color=injection_color, **_default_inj_kwargs
        )
    hdp = float("nan")
    if key_data is not None:
        percentile = [key_data["5th percentile"], key_data["95th percentile"]]
        median = key_data["median"]
        if "90% HPD" in key_data.keys():
            hdp = key_data["90% HPD"]
    else:
        percentile = samples.credible_interval([5, 95])
        median = samples.average("median")
    _linewidth = hist_kwargs.get("linewidth", 1.75)
    if plot_percentile:
        for pp in percentile:
            ax.axvline(pp, color=color, linestyle="--", linewidth=_linewidth)
    if plot_hdp and isinstance(hdp, (list, np.ndarray)):
        for pp in hdp:
            ax.axvline(pp, color=color, linestyle=":", linewidth=_linewidth)
    if title:
        upper = np.round(percentile[1] - median, 2)
        lower = np.abs(np.round(median - percentile[0], 2))
        median = np.round(median, 2)
        _base = r"$%s^{+%s}_{-%s}" % (median, upper, lower)
        if not isinstance(hdp, (list, np.ndarray)) and np.isnan(hdp):
            _base += r"$"
            ax.set_title(_base)
        else:
            upper = np.round(hdp[1] - median, 2)
            lower = np.abs(np.round(median - hdp[0], 2))
            _base += r"\, (\mathrm{CI}) / %s^{+%s}_{-%s}\, (\mathrm{HPD})$" % (
                median, upper, lower
            )
            ax.set_title(_base)
    ax.grid(visible=grid)
    ax.set_xlim(xlims)
    if autoscale and _xlims is not None:
        ax.set_xlim(_xlims)
    if fig is None:
        return ax
    fig.tight_layout()
    return fig
def _1d_histogram_plot_mcmc(
    param, samples, latex_label, colorcycle=conf.colorcycle, **kwargs
):
    """Overlay the 1d histogram of every mcmc chain on a single figure and
    title it with the Gelman-Rubin statistic

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        2d array of samples for param for each mcmc chain
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different mcmc chains
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure holding one histogram per chain
    """
    chain_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for chain in samples:
        chain_color = next(chain_colors)
        fig = _1d_histogram_plot(
            param, chain, latex_label, fig=fig, color=chain_color,
            title=False, autoscale=False, **kwargs
        )
    # the statistic is computed from all chains together
    ax.set_title("Gelman-Rubin: {}".format(gelman_rubin(samples)))
    return fig
def _1d_histogram_plot_bootstrap(
    param, samples, latex_label, colorcycle=conf.colorcycle, nsamples=1000,
    ntests=100, shade=False, plot_percentile=False, kde=True, hist=False,
    **kwargs
):
    """Generate a bootstrapped 1d histogram plot for a given parameter

    ntests random subsets of the samples are drawn without replacement and
    each subset is plotted on the same figure.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    latex_label: str
        latex label for param
    colorcycle: list, str
        color cycle you wish to use for the different tests
    nsamples: int, optional
        number of samples to randomly draw from samples. Default 1000. If
        this exceeds the number of available samples, half of the available
        samples are drawn instead
    ntests: int, optional
        number of tests to perform. Default 100
    shade: Bool, optional
        passed to _1d_histogram_plot. Default False
    plot_percentile: Bool, optional
        passed to _1d_histogram_plot. Default False
    kde: Bool, optional
        passed to _1d_histogram_plot. Default True
    hist: Bool, optional
        passed to _1d_histogram_plot. Default False
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure holding one curve per test
    """
    n_available = len(samples)
    if nsamples > n_available:
        nsamples = n_available // 2
    # all subsets are drawn before anything is plotted
    draws = []
    for _ in range(ntests):
        draws.append(np.random.choice(samples, size=nsamples, replace=False))

    test_colors = cycle(colorcycle)
    fig, ax = figure(gca=True)
    for draw in draws:
        test_color = next(test_colors)
        fig = _1d_histogram_plot(
            param, draw, latex_label, fig=fig, color=test_color, title=False,
            autoscale=False, shade=shade, plot_percentile=plot_percentile,
            kde=kde, hist=hist, **kwargs
        )
    ax.set_title("Ntests: {}, Nsamples per test: {}".format(ntests, nsamples))
    fig.tight_layout()
    return fig
def _1d_comparison_histogram_plot(
    param, samples, colors, latex_label, labels, inj_value=None, kde=False,
    hist=True, linestyles=None, kde_kwargs={}, hist_kwargs={}, max_vline=1,
    figsize=(8, 6), grid=True, legend_kwargs=_default_legend_kwargs,
    latex_friendly=False, max_inj_line=1, injection_color="k", **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different approximants.

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each approximant
    colors: list
        list of colors to be used to differentiate the different approximants
    latex_label: str
        latex label for param
    labels: list
        legend label for each set of samples
    inj_value: float/list, optional
        either a single injection value which will be used for all histograms
        or a list of injection values, one for each histogram
    kde: Bool
        if True, plot a kde of each set of samples
    hist: Bool
        if True, plot a histogram of each set of samples
    linestyles: list
        list of linestyles for each set of samples
    kde_kwargs: dict, optional
        optional kwargs passed to _1d_histogram_plot as kde_kwargs
    hist_kwargs: dict, optional
        optional kwargs passed to _1d_histogram_plot as hist_kwargs. The
        "linewidth" entry is always set to 2.5
    max_vline: int, optional
        passed to _1d_histogram_plot
    figsize: tuple, optional
        size of the figure
    grid: Bool, optional
        if True, plot a grid
    legend_kwargs: dict, optional
        optional kwargs to pass to ax.legend()
    latex_friendly: Bool, optional
        if True, make the label latex friendly. Default False
    max_inj_line: int, optional
        maximum number of unique injection values that may be drawn. If
        exceeded, no injection value is plotted
    injection_color: str/list, optional
        either a single color which will be used for all vertical line showing
        the injected value or a list of colors, one for each injection
    **kwargs: dict, optional
        all additional kwargs passed to _1d_histogram_plot

    Returns
    -------
    fig: figure containing one histogram per set of samples and a legend

    Raises
    ------
    ValueError
        if a list of injection values or injection colors is given whose
        length differs from the number of sets of samples
    """
    logger.debug("Generating the 1d comparison histogram plot for %s" % (param))
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    if inj_value is None:
        inj_value = [None] * len(samples)
    elif isinstance(inj_value, (list, np.ndarray)) and len(inj_value) != len(samples):
        raise ValueError(
            "Please provide an injection for each analysis or a single "
            "injection value which will be used for all histograms"
        )
    elif not isinstance(inj_value, (list, np.ndarray)):
        inj_value = [inj_value] * len(samples)

    if isinstance(injection_color, str):
        injection_color = [injection_color] * len(samples)
    elif len(injection_color) != len(samples):
        raise ValueError(
            "Please provide an injection color for each analysis or a single "
            "injection color which will be used for all lines showing the "
            "injected values"
        )

    flat_injection = np.array([_ for _ in inj_value if _ is not None]).flatten()
    n_unique_injections = len(set(flat_injection))
    if n_unique_injections > max_inj_line:
        logger.warning(
            "Number of unique injection values ({}) is more than the maximum "
            "allowed injection value ({}). Not plotting injection value. If "
            "this is a mistake, please increase `max_inj_line`".format(
                n_unique_injections, max_inj_line
            )
        )
        inj_value = [None] * len(samples)

    if latex_friendly:
        # Escape underscores once, on a new list, so the caller's labels are
        # left untouched.
        labels = [label.replace("_", r"\_") for label in labels]
    # Build a new dict rather than updating in place: the original would
    # modify the caller's dict (or the shared ``{}`` default).
    hist_kwargs = {**hist_kwargs, "linewidth": 2.5}
    fig, ax = figure(figsize=figsize, gca=True)
    handles = []
    for num, i in enumerate(samples):
        fig = _1d_histogram_plot(
            param, i, latex_label, kde=kde, hist=hist, kde_kwargs=kde_kwargs,
            max_vline=max_vline, grid=grid, title=False, autoscale=False,
            label=labels[num], color=colors[num], fig=fig, hist_kwargs=hist_kwargs,
            inj_value=inj_value[num], injection_color=injection_color[num],
            linestyle=linestyles[num], _default_inj_kwargs={
                "linewidth": 4., "linestyle": "-", "alpha": 0.4
            }, **kwargs
        )
        handles.append(
            mlines.Line2D([], [], color=colors[num], label=labels[num])
        )
    ax = fig.gca()
    ncols = number_of_columns_for_legend(labels)
    legend = ax.legend(handles=handles, ncol=ncols, **legend_kwargs)
    # Matplotlib renamed ``legendHandles`` to ``legend_handles``; prefer the
    # new attribute and fall back to the old one on older releases.
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for num, legobj in enumerate(legend_handles):
        legobj.set_linewidth(1.75)
        legobj.set_linestyle(linestyles[num])
    ax.set_xlabel(latex_label)
    ax.set_ylabel("Probability Density")
    ax.autoscale(axis='x')
    ax.grid(visible=grid)
    fig.tight_layout()
    return fig
def _comparison_box_plot(param, samples, colors, latex_label, labels, grid=True):
    """Draw one horizontal box per set of samples to compare their
    distributions for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: 2d list
        list of samples for param for each approximant
    colors: list
        list of colors to be used to differentiate the different approximants.
        Not used by this function
    latex_label: str
        latex label for param
    labels: list
        label for each set of samples, written above the corresponding box
    grid: Bool, optional
        if True, plot a grid

    Returns
    -------
    fig: figure containing the box plot
    """
    logger.debug("Generating the 1d comparison boxplot plot for %s" % (param))
    fig, ax = figure(gca=True)
    highest = np.max([np.max(dataset) for dataset in samples])
    lowest = np.min([np.min(dataset) for dataset in samples])
    # annotations are centred horizontally on the overall sample range
    centre = (highest + lowest) * 0.5
    ax.boxplot(samples, widths=0.2, vert=False, whis=np.inf, labels=labels)
    for position, text in enumerate(labels, start=1):
        ax.annotate(
            text, xy=(centre, 1), xytext=(centre, position + 0.2),
            ha="center"
        )
    ax.set_yticks([])
    ax.set_xlabel(latex_label)
    fig.tight_layout()
    ax.grid(visible=grid)
    return fig
def _make_corner_plot(
    samples, latex_labels, corner_parameters=None, parameters=None, **kwargs
):
    """Generate the corner plots for a given approximant

    Parameters
    ----------
    samples: dict
        samples for each parameter, keyed by parameter name
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        if given, only parameters that also appear in this list are plotted
    parameters: list, optional
        parameters to consider. Default all keys of samples
    **kwargs: dict, optional
        kwargs that override the defaults in conf.corner_kwargs. The "range"
        and "labels" entries are always set by this function

    Returns
    -------
    figure: the corner plot
    included_parameters: list
        parameters shown in the corner plot
    data: dict
        width, height and lower-left corner ("x0", "y0") of the first subplot
        and the horizontal separation between the first two subplots, all
        multiplied by the figure dpi. "seperation" is None when the figure
        has a single subplot
    """
    logger.debug("Generating the corner plot")
    # set the default kwargs
    default_kwargs = conf.corner_kwargs.copy()
    if parameters is None:
        parameters = list(samples.keys())
    if corner_parameters is not None:
        included_parameters = [i for i in parameters if i in corner_parameters]
    else:
        included_parameters = parameters
    xs = np.zeros([len(included_parameters), len(samples[parameters[0]])])
    for num, i in enumerate(included_parameters):
        xs[num] = samples[i]
    default_kwargs.update(kwargs)
    default_kwargs["range"] = [1.0] * len(included_parameters)
    default_kwargs["labels"] = [latex_labels[i] for i in included_parameters]

    _figure = ExistingFigure(corner(xs.T, included_parameters, **default_kwargs))
    # grab the axes of the subplots
    axes = _figure.get_axes()
    axes_of_interest = axes[:2]
    location = []
    for i in axes_of_interest:
        extent = i.get_window_extent().transformed(_figure.dpi_scale_trans.inverted())
        location.append([extent.x0 * _figure.dpi, extent.y0 * _figure.dpi])
        width, height = extent.width, extent.height
        width *= _figure.dpi
        height *= _figure.dpi
    try:
        seperation = abs(location[0][0] - location[1][0]) - width
    except IndexError:
        # only one subplot, so there is no neighbour to measure against
        seperation = None
    # "y0" is the second entry of the stored [x0, y0] pair; the original
    # reported the x coordinate for both keys.
    data = {
        "width": width, "height": height, "seperation": seperation,
        "x0": location[0][0], "y0": location[0][1]
    }
    return _figure, included_parameters, data
861def _make_comparison_corner_plot(
862 samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
863 latex_friendly=True, **kwargs
864):
865 """Generate a corner plot which contains multiple datasets
867 Parameters
868 ----------
869 samples: dict
870 nested dictionary containing the label as key and SamplesDict as item
871 for each dataset you wish to plot
872 latex_labels: dict
873 dictionary of latex labels for each parameter
874 corner_parameters: list, optional
875 corner parameters you wish to include in the plot
876 colors: list, optional
877 unique colors for each dataset
878 latex_friendly: Bool, optional
879 if True, make the label latex friendly. Default True
880 **kwargs: dict
881 all kwargs are passed to `corner.corner`
882 """
883 parameters = corner_parameters
884 if corner_parameters is None:
885 _parameters = [list(_samples.keys()) for _samples in samples.values()]
886 parameters = [
887 i for i in _parameters[0] if all(i in _params for _params in _parameters)
888 ]
889 if len(samples.keys()) > len(colors):
890 raise ValueError("Please provide a unique color for each dataset")
892 hist_kwargs = kwargs.get("hist_kwargs", dict())
893 hist_kwargs["density"] = True
894 lines = []
895 for num, (label, posterior) in enumerate(samples.items()):
896 if latex_friendly:
897 label = copy.deepcopy(label)
898 label = label.replace("_", "\_")
899 lines.append(mlines.Line2D([], [], color=colors[num], label=label))
900 _samples = {
901 param: value for param, value in posterior.items() if param in
902 parameters
903 }
904 hist_kwargs["color"] = colors[num]
905 kwargs.update({"hist_kwargs": hist_kwargs})
906 if num == 0:
907 fig, _, _ = _make_corner_plot(
908 _samples, latex_labels, corner_parameters=corner_parameters,
909 parameters=parameters, color=colors[num], **kwargs
910 )
911 else:
912 fig, _, _ = _make_corner_plot(
913 _samples, latex_labels, corner_parameters=corner_parameters,
914 fig=fig, parameters=parameters, color=colors[num], **kwargs
915 )
916 fig.legend(handles=lines, loc="upper right")
917 lines = []
918 return fig