Coverage for pesummary/core/plots/publication.py: 80.5%
261 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4from matplotlib import gridspec
5from scipy.stats import gaussian_kde
6import copy
8from pesummary.core.plots.figure import figure
9from .corner import hist2d
10from pesummary import conf
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Baseline legend settings; pcolormesh copies this dict and layers the
# caller's legend_kwargs on top of it
DEFAULT_LEGEND_KWARGS = dict(loc="best", frameon=False)
def pcolormesh(
    x, y, density, ax=None, levels=None, smooth=None, bins=None, label=None,
    level_kwargs={}, range=None, grid=True, legend=False, legend_kwargs={},
    weights=None, **kwargs
):
    """Generate a colormesh plot on a given axis

    Parameters
    ----------
    x: np.ndarray
        array of floats for the x axis
    y: np.ndarray
        array of floats for the y axis
    density: np.ndarray
        2d array of probabilities
    ax: matplotlib.axes._subplots.AxesSubplot
        axis you wish to use for plotting. Although the default is None, an
        axis must be provided because all drawing is done through it
    levels: list, optional
        contour levels to show on the plot. Default None (no contours drawn)
    smooth: float, optional
        sigma to use for gaussian smoothing of `density`. Default, no
        smoothing applied
    bins: optional
        accepted so that this function shares a call signature with the
        other 2d plotting functions; not used here
    label: str, optional
        label attached to the contour set when `legend` is True
    level_kwargs: dict, optional
        optional kwargs to use for ax.contour
    range: optional
        accepted so that this function shares a call signature with the
        other 2d plotting functions; not used here
    grid: Bool, optional
        if True and no `zorder` is given in kwargs, the colormesh is drawn
        with zorder=-10. Otherwise the `zorder` kwarg is used (default 10)
    legend: Bool, optional
        if True, and contours are drawn, add a legend to the axis
    legend_kwargs: dict, optional
        kwargs for ax.legend; merged on top of DEFAULT_LEGEND_KWARGS
    weights: optional
        must be None; weighted data is not supported
    **kwargs: dict, optional
        all additional kwargs passed to ax.pcolormesh. Passing cmap="off"
        (case insensitive) disables the colormesh entirely

    Returns
    -------
    ax: the axis that was passed in

    Raises
    ------
    ValueError: if `weights` is not None
    """
    # validate before doing any work
    if weights is not None:
        raise ValueError(
            "This function does not currently support weighted data"
        )
    if smooth is not None:
        # import from scipy.ndimage directly: the scipy.ndimage.filters
        # namespace is deprecated
        from scipy.ndimage import gaussian_filter
        density = gaussian_filter(density, sigma=smooth)
    _cmap = kwargs.get("cmap", None)
    _off = isinstance(_cmap, str) and _cmap.lower() == "off"
    if grid and "zorder" not in kwargs:
        _zorder = -10
    else:
        # always pop so that zorder is not passed twice to ax.pcolormesh
        _zorder = kwargs.pop("zorder", 10)
    if not _off:
        ax.pcolormesh(x, y, density, zorder=_zorder, **kwargs)
    if levels is not None:
        CS = ax.contour(x, y, density, levels=levels, **level_kwargs)
        # the legend entry is attached to the contour set, so it is only
        # meaningful when contours have been drawn
        if legend:
            _legend_kwargs = DEFAULT_LEGEND_KWARGS.copy()
            _legend_kwargs.update(legend_kwargs)
            CS.set_label(label)
            ax.legend(**_legend_kwargs)
    return ax
def analytic_twod_contour_plot(*args, smooth=None, **kwargs):
    """Draw a 2d contour plot from an analytic PDF rather than from samples

    This is `twod_contour_plot` with the drawing routine swapped for
    `pcolormesh`.

    Parameters
    ----------
    *args: tuple
        positional arguments forwarded to twod_contour_plot
    smooth: float, optional
        degree of smoothing to apply to the probabilities. Default None
    **kwargs: dict, optional
        keyword arguments forwarded to twod_contour_plot
    """
    plotter = twod_contour_plot
    return plotter(*args, _function=pcolormesh, smooth=smooth, **kwargs)
def twod_contour_plot(
    x, y, *args, rangex=None, rangey=None, fig=None, ax=None, return_ax=False,
    levels=[0.9], bins=300, smooth=7, xlabel=None, ylabel=None,
    fontsize={"label": 12}, grid=True, label=None, truth=None,
    _function=hist2d, truth_lines=True, truth_kwargs={}, weights=None,
    _default_truth_kwargs={
        "marker": 'o', "markeredgewidth": 2, "markersize": 6, "color": 'k'
    }, **kwargs
):
    """Generate a 2d contour plot for 2 marginalized posterior
    distributions

    Parameters
    ----------
    x: np.array
        array of posterior samples to use for the x axis
    y: np.array
        array of posterior samples to use for the y axis
    *args: tuple
        additional positional arguments forwarded to `_function`
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    fig: matplotlib.figure.Figure, optional
        figure you wish to use for plotting
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting. If an axis is given without a
        figure, the axis is returned
    return_ax: Bool, optional
        if True return the axis used for plotting. Else return the figure
    levels: list, optional
        levels you wish to use for the 2d contours. Default [0.9]
    bins: int, optional
        number of bins to use for gridding 2d parameter space. Default 300
    smooth: int, optional
        how much smoothing you wish to use for the 2d contours
    xlabel: str, optional
        label to use for the xaxis
    ylabel: str, optional
        label to use for the yaxis
    fontsize: dict, optional
        dictionary containing the fontsize to use for the plot. The "label"
        entry is used for the axis labels
    grid: Bool, optional
        if True, add a grid to the plot
    label: str, optional
        label to use for a given contour
    truth: list, optional
        the true value of the posterior. `truth` is a list of length 2 with
        first element being the true x value and second element being the true
        y value
    truth_lines: Bool, optional
        if True, add vertical and horizontal lines spanning the 2d space to show
        injected value
    truth_kwargs: dict, optional
        kwargs to use to indicate truth. These override the entries in
        `_default_truth_kwargs` for this call only
    weights: optional
        forwarded to `_function`
    **kwargs: dict, optional
        all additional kwargs are passed to `_function` (by default the
        `pesummary.core.plots.corner.hist2d` function). If "range" is not
        given it is built from rangex/rangey, or the extent of the samples

    Returns
    -------
    the axis if `return_ax` is True (or if only `ax` was given), else the
    figure
    """
    if fig is None and ax is None:
        fig, ax = figure(gca=True)
    elif fig is None:
        # only an axis was supplied, so there is no figure to hand back
        return_ax = True
    elif ax is None:
        ax = fig.gca()

    if "range" not in kwargs:
        if rangex is not None:
            xlow, xhigh = rangex
        else:
            xlow, xhigh = np.min(x), np.max(x)
        if rangey is not None:
            ylow, yhigh = rangey
        else:
            ylow, yhigh = np.min(y), np.max(y)
        kwargs["range"] = [[xlow, xhigh], [ylow, yhigh]]

    _function(
        x, y, *args, ax=ax, levels=levels, bins=bins, smooth=smooth,
        label=label, grid=grid, weights=weights, **kwargs
    )
    if truth is not None:
        # merge into a new dict: updating the default argument in place would
        # leak one call's truth styling into every later call
        _truth_kwargs = {**_default_truth_kwargs, **truth_kwargs}
        ax.plot(*truth, **_truth_kwargs)
        if truth_lines:
            ax.axvline(
                truth[0], color=_truth_kwargs["color"], linewidth=0.5
            )
            ax.axhline(
                truth[1], color=_truth_kwargs["color"], linewidth=0.5
            )
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize["label"])
    ax.grid(grid)
    if fig is not None:
        fig.tight_layout()
    if return_ax:
        return ax
    return fig
def comparison_twod_contour_plot(
    x, y, labels=None, plot_density=None, rangex=None, rangey=None,
    legend_kwargs={"loc": "best", "frameon": False}, fig=None, ax=None,
    colors=list(conf.colorcycle), linestyles=None, add_legend=True, **kwargs
):
    """Generate a comparison 2d contour plot for 2 marginalized
    posterior distributions from multiple analyses

    Parameters
    ----------
    x: np.ndarray
        2d array of posterior samples to use for the x axis; array for each
        analysis
    y: np.ndarray
        2d array of posterior samples to use for the y axis; array for each
        analysis
    labels: list, optional
        labels to assign to each contour
    plot_density: str/list, optional
        label of the analysis you wish to plot the density for, or a list of
        such labels. If you wish to plot the density for every analysis,
        simply pass `plot_density='both'`. Ignored if `labels` is None
    rangex: tuple, optional
        range over which to plot the x axis. Default, the extent of all
        x samples
    rangey: tuple, optional
        range over which to plot the y axis. Default, the extent of all
        y samples
    legend_kwargs: dict, optional
        kwargs to use for the legend
    fig: matplotlib.figure.Figure, optional
        figure you wish to use for plotting
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting
    colors: list, optional
        list of colors to use for each contour
    linestyles: list, optional
        linestyles to use for each contour
    add_legend: Bool, optional
        if True, add a legend to the axis. Default True
    **kwargs: dict, optional
        all additional kwargs are passed to the
        `pesummary.core.plots.publication.twod_contour_plot` function

    Returns
    -------
    whatever the final `twod_contour_plot` call returned: the figure, or the
    axis when only `ax` was supplied
    """
    _ax = ax
    _fig = fig
    if labels is None:
        # densities are selected by label, so without labels none can match
        plot_density = None
        labels = [None] * len(x)

    if rangex is None:
        rangex = [
            np.min([np.min(_x) for _x in x]),
            np.max([np.max(_x) for _x in x])
        ]
    if rangey is None:
        rangey = [
            np.min([np.min(_y) for _y in y]),
            np.max([np.max(_y) for _y in y])
        ]

    result = fig
    for num, (_x, _y) in enumerate(zip(x, y)):
        _label = labels[num]
        # decide per analysis, leaving the user's `plot_density` untouched so
        # that every analysis is compared against the original request
        if plot_density is None:
            _plot_density = False
        elif isinstance(plot_density, list):
            _plot_density = _label in plot_density
        else:
            _plot_density = bool(
                plot_density == _label or plot_density == "both"
            )

        _color = colors[num] if colors is not None else None
        _linestyle = linestyles[num] if linestyles is not None else None
        result = twod_contour_plot(
            _x, _y, plot_density=_plot_density, label=_label, fig=_fig,
            ax=_ax, rangex=rangex, rangey=rangey, color=_color,
            linestyles=_linestyle, **kwargs
        )
        if ax is None:
            # no axis supplied: reuse the figure/axis created by the first call
            _fig = result
            _ax = result.gca()
        # when an axis was supplied `result` may itself be that axis, so it
        # must not be fed back in as `fig` for the next analysis
    if add_legend:
        _ax.legend(**legend_kwargs)
    return result
def _triangle_axes(
    figsize=(8, 8), width_ratios=[4, 1], height_ratios=[1, 4], wspace=0.0,
    hspace=0.0,
):
    """Create the figure and the 2x2 grid of axes used by triangle plots

    Parameters
    ----------
    figsize: tuple, optional
        figure size you wish to use. Default (8, 8)
    width_ratios: list, optional
        ratio of the column widths of the grid. Default 4:1
    height_ratios: list, optional
        ratio of the row heights of the grid. Default 1:4
    wspace: float, optional
        horizontal space between the axes. Default 0.0
    hspace: float, optional
        vertical space between the axes. Default 0.0

    Returns
    -------
    fig, ax1, ax2, ax3, ax4: the figure followed by the axes for grid cells
    0, 1, 2 and 3 respectively
    """
    fig = figure(figsize=figsize, gca=False)
    layout = gridspec.GridSpec(
        2, 2, wspace=wspace, hspace=hspace, width_ratios=width_ratios,
        height_ratios=height_ratios
    )
    ax1, ax2, ax3, ax4 = [fig.add_subplot(layout[cell]) for cell in range(4)]
    # ax2 is the only axis which does not get minor ticks
    for _axis in (ax1, ax3, ax4):
        _axis.minorticks_on()
    ax1.xaxis.set_ticklabels([])
    ax4.yaxis.set_ticklabels([])
    return fig, ax1, ax2, ax3, ax4
306def _generate_triangle_plot(
307 *args, function=None, fig_kwargs={}, existing_figure=None, **kwargs
308):
309 """Generate a triangle plot according to a given function
311 Parameters
312 ----------
313 *args: tuple
314 all args passed to function
315 function: func, optional
316 function you wish to use to generate triangle plot. Default
317 _triangle_plot
318 **kwargs: dict, optional
319 all kwargs passed to function
320 """
321 if existing_figure is None:
322 fig, ax1, ax2, ax3, ax4 = _triangle_axes(**fig_kwargs)
323 ax2.axis("off")
324 else:
325 fig, ax1, ax3, ax4 = existing_figure
326 if function is None:
327 function = _triangle_plot
328 return function(fig, [ax1, ax3, ax4], *args, **kwargs)
def triangle_plot(*args, **kwargs):
    """Generate a triangular plot made of 3 axis. One central axis showing the
    2d marginalized posterior and two smaller axes showing the marginalized 1d
    posterior distribution (above and to the right of central axis)

    All arguments are forwarded, via `_generate_triangle_plot`, to
    `_triangle_plot`.

    Parameters
    ----------
    x: list
        list of samples for the x axis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde to use for smoothing the 1d marginalized posterior distribution. If
        you do not want to use KDEs, simply pass kde=False. Default
        scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde to use for smoothing the 2d marginalized posterior distribution.
        Default None
    npoints: int, optional
        number of points to use for the 1d kde. Default 100
    kde_kwargs: dict, optional
        optional kwargs which are passed directly to the kde function
    kde_2d_kwargs: dict, optional
        optional kwargs which are passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions. Default True
    fill_alpha: float, optional
        alpha to use for fill. Default 0.5
    levels: list, optional
        levels you wish to use for the 2d contours. Default [0.9]
    smooth: dict/float, optional
        how much smoothing you wish to use for the 2d contours. If you wish
        to use different smoothing for different contours, then provide a dict
        with keys given by the label
    colors: list, optional
        list of colors you wish to use for each analysis
    xlabel: str, optional
        xlabel you wish to use for the plot
    ylabel: str, optional
        ylabel you wish to use for the plot
    fontsize: dict, optional
        dictionary giving the fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles you wish to use for each analysis
    linewidths: list, optional
        linewidths you wish to use for each analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles you wish to plot. Default None
    percentile_plot: list, optional
        labels of the analyses for which percentile lines are drawn. When
        omitted, no percentile lines are drawn
    fig_kwargs: dict, optional
        optional kwargs passed directly to the _triangle_axes function
    existing_figure: tuple, optional
        tuple of (fig, ax1, ax3, ax4) to draw on instead of a new figure
    labels: list, optional
        label associated with each set of samples
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    grid: Bool, optional
        if True, show a grid on all axes. Default False
    legend_kwargs: dict, optional
        optional kwargs for the legend. Default {"loc": "best", "frameon": False}
    **kwargs: dict
        all additional kwargs are passed on to the 2d contour function
    """
    plotter = _triangle_plot
    return _generate_triangle_plot(*args, function=plotter, **kwargs)
def analytic_triangle_plot(*args, **kwargs):
    """Generate a triangle plot given probability densities for x, y and xy.

    All arguments are forwarded, via `_generate_triangle_plot`, to
    `_analytic_triangle_plot`. The figure and axes are supplied by
    `_generate_triangle_plot`, so positional arguments start at `x`.

    Parameters
    ----------
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing applied
    cmap: str, optional
        name of cmap to use for plotting
    fig_kwargs: dict, optional
        kwargs passed to _triangle_axes when a new figure is created
    existing_figure: tuple, optional
        tuple of (fig, ax1, ax3, ax4) to draw on instead of a new figure
    """
    plotter = _analytic_triangle_plot
    return _generate_triangle_plot(*args, function=plotter, **kwargs)
def _analytic_triangle_plot(
    fig, axes, x, y, probs_x, probs_y, probs_xy, smooth=None, xlabel=None,
    ylabel=None, grid=True, **kwargs
):
    """Draw a triangle plot from probability densities for x, y and xy.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots (ax1, ax3, ax4) associated with the figure: ax1
        shows probs_x, ax3 the 2d density and ax4 probs_y
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing applied
    xlabel: str, optional
        label to use for the x axis
    ylabel: str, optional
        label to use for the y axis
    grid: Bool, optional
        if True, add a grid to the plot
    **kwargs: dict, optional
        forwarded to analytic_twod_contour_plot. "level_kwargs" and
        "fontsize" are additionally read here for the 1d line color and the
        axis label size

    Returns
    -------
    fig, ax1, ax3, ax4
    """
    ax1, ax3, ax4 = axes
    analytic_twod_contour_plot(
        x, y, probs_xy, ax=ax3, smooth=smooth, grid=grid, **kwargs
    )
    # draw the 1d curves in the color of the first contour, when one is given
    contour_style = kwargs.get("level_kwargs", None)
    line_color = None
    if contour_style is not None and "colors" in contour_style:
        line_color = contour_style["colors"][0]
    ax1.plot(x, probs_x, color=line_color)
    ax4.plot(probs_y, y, color=line_color)

    sizes = kwargs.get("fontsize", {"label": 12})
    if xlabel is not None:
        ax3.set_xlabel(xlabel, fontsize=sizes["label"])
    if ylabel is not None:
        ax3.set_ylabel(ylabel, fontsize=sizes["label"])

    ax1.grid(grid)
    if grid:
        ax3.grid(grid, zorder=10)
    ax4.grid(grid)
    # keep the 1d panels aligned with the 2d panel
    ax1.set_xlim(ax3.get_xlim())
    ax4.set_ylim(ax3.get_ylim())
    fig.tight_layout()
    return fig, ax1, ax3, ax4
def _triangle_plot(
    fig, axes, x, y, kde=gaussian_kde, npoints=100, kde_kwargs={}, fill=True,
    fill_alpha=0.5, levels=[0.9], smooth=7, colors=list(conf.colorcycle),
    xlabel=None, ylabel=None, fontsize={"legend": 12, "label": 12},
    linestyles=None, linewidths=None, plot_density=True, percentiles=None,
    percentile_plot=None, fig_kwargs={}, labels=None, plot_datapoints=False,
    rangex=None, rangey=None, grid=False, latex_friendly=False, kde_2d=None,
    kde_2d_kwargs={}, legend_kwargs={"loc": "best", "frameon": False},
    truth=None, hist_kwargs={"density": True, "bins": 50},
    _contour_function=twod_contour_plot, weights=None, **kwargs
):
    """Base function to generate a triangular plot

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots (ax1, ax3, ax4) associated with the figure: ax1
        shows the 1d x posterior, ax3 the 2d contours and ax4 the 1d y
        posterior
    x: list
        list of samples for the x axis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde to use for smoothing the 1d marginalized posterior distribution. If
        you do not want to use KDEs, simply pass kde=False, in which case
        histograms are drawn. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde to use for smoothing the 2d marginalized posterior distribution.
        default None
    npoints: int, optional
        number of points to use for the 1d kde
    kde_kwargs: dict, optional
        optional kwargs which are passed directly to the kde function.
        kde_kwargs to be passed to the kde on the y axis may be specified
        by the dictionary entry 'y_axis'. kde_kwargs to be passed to the kde on
        the x axis may be specified by the dictionary entry 'x_axis'. An axis
        without its own entry receives the remaining (shared) entries
    kde_2d_kwargs: dict, optional
        optional kwargs which are passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha to use for fill
    levels: list, optional
        levels you wish to use for the 2d contours
    smooth: dict/float, optional
        how much smoothing you wish to use for the 2d contours. If you wish
        to use different smoothing for different contours, then provide a dict
        with keys given by the label
    colors: list, optional
        list of colors you wish to use for each analysis
    xlabel: str, optional
        xlabel you wish to use for the plot
    ylabel: str, optional
        ylabel you wish to use for the plot
    fontsize: dict, optional
        dictionary giving the fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles you wish to use for each analysis
    linewidths: list, optional
        linewidths you wish to use for each analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles you wish to plot; the first two computed values are
        drawn as dashed lines. Default None
    percentile_plot: list, optional
        labels of the analyses to plot percentiles for. If None, no
        percentile lines are drawn
    fig_kwargs: dict, optional
        accepted for interface compatibility; not used by this function
    labels: list, optional
        label associated with each set of samples
    plot_datapoints: Bool, optional
        forwarded to the 2d contour function. Default False
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    grid: Bool, optional
        if True, show a grid on all axes
    latex_friendly: Bool, optional
        if True, underscores in the displayed labels are escaped
    legend_kwargs: dict, optional
        optional kwargs for the legend. Default {"loc": "best", "frameon": False}
    truth: list, optional
        [x, y] truth values; forwarded to the 2d contour function and drawn
        as lines on the 1d axes
    hist_kwargs: dict, optional
        kwargs passed to the 1d histograms when kde is False
    weights: optional
        forwarded to the kde, the histograms and the 2d contour function
    **kwargs: dict
        all kwargs are passed to the 2d contour function

    Returns
    -------
    fig, ax1, ax3, ax4

    Raises
    ------
    ValueError: if fewer colors, linestyles or linewidths than analyses are
    given, or if the number of labels differs from the number of analyses
    """
    ax1, ax3, ax4 = axes
    if not isinstance(x[0], (list, np.ndarray)):
        x, y = np.atleast_2d(x), np.atleast_2d(y)
    n_analyses = len(x)
    _base_error = "Please provide {} for each analysis"
    if len(colors) < n_analyses:
        raise ValueError(_base_error.format("a single color"))
    if linestyles is None:
        linestyles = ["-"] * n_analyses
    elif len(linestyles) < n_analyses:
        raise ValueError(_base_error.format("a single linestyle"))
    if linewidths is None:
        linewidths = [None] * n_analyses
    elif len(linewidths) < n_analyses:
        raise ValueError(_base_error.format("a single linewidth"))
    if labels is None:
        labels = [None] * n_analyses
    elif len(labels) != n_analyses:
        raise ValueError(_base_error.format("a label"))

    xlow = np.min([np.min(_x) for _x in x])
    xhigh = np.max([np.max(_x) for _x in x])
    ylow = np.min([np.min(_y) for _y in y])
    yhigh = np.max([np.max(_y) for _y in y])
    if rangex is not None:
        xlow, xhigh = rangex
    if rangey is not None:
        ylow, yhigh = rangey

    # The per-axis entries are containers of kwargs, not kde kwargs
    # themselves, so strip them before using the dict as shared kwargs
    _shared_kde_kwargs = {
        key: value for key, value in kde_kwargs.items()
        if key not in ("x_axis", "y_axis")
    }
    _x_kde_kwargs = kde_kwargs.get("x_axis", _shared_kde_kwargs)
    _y_kde_kwargs = kde_kwargs.get("y_axis", _shared_kde_kwargs)

    for num in range(n_analyses):
        plot_kwargs = dict(
            color=colors[num], linewidth=linewidths[num],
            linestyle=linestyles[num]
        )
        # Keep the caller's label for dictionary/list lookups and only escape
        # the copy which is shown in the legend
        _label = labels[num]
        _display_label = _label
        if latex_friendly and isinstance(_label, str):
            _display_label = _label.replace("_", r"\_")

        if kde:
            _kde = kde(x[num], weights=weights, **_x_kde_kwargs)
            _x = np.linspace(xlow, xhigh, npoints)
            _y = _kde(_x)
            ax1.plot(_x, _y, **plot_kwargs)
            if fill:
                ax1.fill_between(_x, 0, _y, alpha=fill_alpha, **plot_kwargs)
            _y = np.linspace(ylow, yhigh, npoints)
            _kde = kde(y[num], weights=weights, **_y_kde_kwargs)
            _x = _kde(_y)
            ax4.plot(_x, _y, label=_display_label, **plot_kwargs)
            if fill:
                ax4.fill_betweenx(_y, 0, _x, alpha=fill_alpha, **plot_kwargs)
        else:
            histtype = "stepfilled" if fill else "step"
            ax1.hist(
                x[num], histtype=histtype, weights=weights, **hist_kwargs,
                **plot_kwargs
            )
            ax4.hist(
                y[num], histtype=histtype, weights=weights,
                orientation="horizontal", **hist_kwargs, **plot_kwargs
            )
        if percentiles is not None and percentile_plot is not None and (
            _label in percentile_plot or _display_label in percentile_plot
        ):
            _linewidth = plot_kwargs.get("linewidth", 1.75)
            _percentiles = np.percentile(x[num], percentiles)
            ax1.axvline(
                _percentiles[0], linestyle="--", linewidth=_linewidth
            )
            ax1.axvline(
                _percentiles[1], linestyle="--", linewidth=_linewidth
            )
            _percentiles = np.percentile(y[num], percentiles)
            ax4.axhline(
                _percentiles[0], linestyle="--", linewidth=_linewidth
            )
            ax4.axhline(
                _percentiles[1], linestyle="--", linewidth=_linewidth
            )
        if isinstance(smooth, dict):
            # prefer the caller's label; fall back to the escaped form
            _key = _label if _label in smooth else _display_label
            _smooth = smooth[_key]
        else:
            _smooth = smooth
        _contour_function(
            x[num], y[num], ax=ax3, levels=levels, smooth=_smooth,
            rangex=[xlow, xhigh], rangey=[ylow, yhigh], color=colors[num],
            linestyles=linestyles[num], weights=weights,
            plot_density=plot_density, contour_kwargs=dict(
                linestyles=[linestyles[num]], linewidths=linewidths[num]
            ), plot_datapoints=plot_datapoints, kde=kde_2d,
            kde_kwargs=kde_2d_kwargs, grid=False, truth=truth, **kwargs
        )

    if truth is not None:
        ax1.axvline(truth[0], color='k', linewidth=0.5)
        ax4.axhline(truth[1], color='k', linewidth=0.5)
    if xlabel is not None:
        ax3.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax3.set_ylabel(ylabel, fontsize=fontsize["label"])
    if not all(label is None for label in labels):
        # copy so that neither the default dict nor the caller's dict is
        # modified
        _legend_kwargs = dict(legend_kwargs)
        _legend_kwargs["fontsize"] = fontsize["legend"]
        # the labelled artists live on ax4 but the legend is shown on ax3
        ax3.legend(*ax4.get_legend_handles_labels(), **_legend_kwargs)
    ax1.grid(grid)
    ax3.grid(grid)
    ax4.grid(grid)
    # align the 2d panel with the limits of the 1d panels
    xlims = ax1.get_xlim()
    ax3.set_xlim(xlims)
    ylims = ax4.get_ylim()
    ax3.set_ylim(ylims)
    return fig, ax1, ax3, ax4
694def _generate_reverse_triangle_plot(
695 *args, xlabel=None, ylabel=None, function=None, existing_figure=None, **kwargs
696):
697 """Generate a reverse triangle plot according to a given function
699 Parameters
700 ----------
701 *args: tuple
702 all args passed to function
703 xlabel: str, optional
704 label to use for the x axis
705 ylabel: str, optional
706 label to use for the y axis
707 function: func, optional
708 function to use to generate triangle plot. Default _triangle_plot
709 **kwargs: dict, optional
710 all kwargs passed to function
711 """
712 if existing_figure is None:
713 fig, ax1, ax2, ax3, ax4 = _triangle_axes(
714 width_ratios=[1, 4], height_ratios=[4, 1]
715 )
716 ax3.axis("off")
717 else:
718 fig, ax1, ax2, ax4 = existing_figure
719 if function is None:
720 function = _triangle_plot
721 fig, ax4, ax2, ax1 = function(fig, [ax4, ax2, ax1], *args, **kwargs)
722 ax2.axis("off")
723 ax4.spines["right"].set_visible(False)
724 ax4.spines["top"].set_visible(False)
725 ax4.spines["left"].set_visible(False)
726 ax4.set_yticks([])
728 ax1.spines["right"].set_visible(False)
729 ax1.spines["top"].set_visible(False)
730 ax1.spines["bottom"].set_visible(False)
731 ax1.set_xticks([])
733 _fontsize = kwargs.get("fontsize", {"label": 12})["label"]
734 if xlabel is not None:
735 ax4.set_xlabel(xlabel, fontsize=_fontsize)
736 if ylabel is not None:
737 ax1.set_ylabel(ylabel, fontsize=_fontsize)
738 return fig, ax1, ax2, ax4
def reverse_triangle_plot(*args, **kwargs):
    """Generate a triangular plot made of 3 axis. One central axis showing the
    2d marginalized posterior and two smaller axes showing the marginalized 1d
    posterior distribution (below and to the left of central axis). Only two
    axes are plotted, each below the 1d marginalized posterior distribution

    All arguments are forwarded to `_generate_reverse_triangle_plot`, which
    draws with `_triangle_plot`.

    Parameters
    ----------
    x: list
        list of samples for the x axis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde to use for smoothing the 1d marginalized posterior distribution. If
        you do not want to use KDEs, simply pass kde=False. Default
        scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde to use for smoothing the 2d marginalized posterior distribution.
        Default None
    npoints: int, optional
        number of points to use for the 1d kde. Default 100
    kde_kwargs: dict, optional
        optional kwargs which are passed directly to the kde function.
        kde_kwargs to be passed to the kde on the y axis may be specified
        by the dictionary entry 'y_axis'. kde_kwargs to be passed to the kde on
        the x axis may be specified by the dictionary entry 'x_axis'.
    kde_2d_kwargs: dict, optional
        optional kwargs which are passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions. Default True
    fill_alpha: float, optional
        alpha to use for fill. Default 0.5
    levels: list, optional
        levels you wish to use for the 2d contours. Default [0.9]
    smooth: dict/float, optional
        how much smoothing you wish to use for the 2d contours. If you wish
        to use different smoothing for different contours, then provide a dict
        with keys given by the label
    colors: list, optional
        list of colors you wish to use for each analysis
    xlabel: str, optional
        xlabel you wish to use for the plot
    ylabel: str, optional
        ylabel you wish to use for the plot
    fontsize: dict, optional
        dictionary giving the fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles you wish to use for each analysis
    linewidths: list, optional
        linewidths you wish to use for each analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles you wish to plot. Default None
    percentile_plot: list, optional
        labels of the analyses for which percentile lines are drawn. When
        omitted, no percentile lines are drawn
    fig_kwargs: dict, optional
        accepted for interface compatibility; the reverse layout is always
        created with fixed width and height ratios
    existing_figure: tuple, optional
        tuple of (fig, ax1, ax2, ax4) to draw on instead of a new figure
    labels: list, optional
        label associated with each set of samples
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    legend_kwargs: dict, optional
        optional kwargs for the legend. Default {"loc": "best", "frameon": False}
    **kwargs: dict
        all additional kwargs are passed on to the 2d contour function
    """
    plotter = _triangle_plot
    return _generate_reverse_triangle_plot(*args, function=plotter, **kwargs)
def analytic_reverse_triangle_plot(*args, **kwargs):
    """Generate a reverse triangle plot given probability densities for x, y
    and xy.

    All arguments are forwarded to `_generate_reverse_triangle_plot`, which
    draws with `_analytic_triangle_plot`. The figure and axes are supplied by
    `_generate_reverse_triangle_plot`, so positional arguments start at `x`.

    Parameters
    ----------
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing applied
    cmap: str, optional
        name of cmap to use for plotting
    xlabel: str, optional
        label to use for the x axis
    ylabel: str, optional
        label to use for the y axis
    existing_figure: tuple, optional
        tuple of (fig, ax1, ax2, ax4) to draw on instead of a new figure
    """
    plotter = _analytic_triangle_plot
    return _generate_reverse_triangle_plot(*args, function=plotter, **kwargs)