Coverage for pesummary/core/plots/publication.py: 81.3%
257 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4from matplotlib import gridspec
5from scipy.stats import gaussian_kde
6import copy
8from pesummary.core.plots.figure import figure
9from .corner import hist2d
10from pesummary import conf
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Baseline legend options; pcolormesh layers the caller's legend_kwargs on a copy
DEFAULT_LEGEND_KWARGS = dict(loc="best", frameon=False)
def pcolormesh(
    x, y, density, ax=None, levels=None, smooth=None, bins=None, label=None,
    level_kwargs={}, range=None, grid=True, legend=False, legend_kwargs={},
    **kwargs
):
    """Generate a colormesh plot on a given axis

    Parameters
    ----------
    x: np.ndarray
        array of floats for the x axis
    y: np.ndarray
        array of floats for the y axis
    density: np.ndarray
        2d array of probabilities
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting. If None, a new figure and axis are
        created
    levels: list, optional
        contour levels to show on the plot. Default None, no contours drawn
    smooth: float, optional
        sigma to use for smoothing. Default, no smoothing applied
    bins: int, optional
        accepted so that this function can be called with the same arguments
        as the sample based 2d plotting function; not used here
    label: str, optional
        label attached to the contour when `legend=True`
    level_kwargs: dict, optional
        optional kwargs to use for ax.contour
    range: list, optional
        accepted for call compatibility; not used here
    grid: Bool, optional
        if True and no `zorder` kwarg is given, the mesh is drawn with
        zorder=-10. Otherwise the `zorder` kwarg is used (default 10)
    legend: Bool, optional
        if True (and `levels` is provided), label the contour and add a
        legend to the axis
    legend_kwargs: dict, optional
        kwargs for ax.legend. These update a copy of DEFAULT_LEGEND_KWARGS
    **kwargs: dict, optional
        all additional kwargs passed to ax.pcolormesh. Passing cmap='off'
        (case insensitive) disables the colormesh entirely

    Returns
    -------
    ax: the axis that was used for plotting
    """
    if ax is None:
        # honour the documented `ax=None` default rather than failing on None
        _, ax = figure(gca=True)
    if smooth is not None:
        # `scipy.ndimage.filters` is a deprecated namespace; the filter is
        # available directly from `scipy.ndimage`
        from scipy.ndimage import gaussian_filter
        density = gaussian_filter(density, sigma=smooth)
    _cmap = kwargs.get("cmap", None)
    _off = isinstance(_cmap, str) and _cmap.lower() == "off"
    if grid and "zorder" not in kwargs:
        _zorder = -10
    else:
        # remove zorder from kwargs so it is not passed twice below
        _zorder = kwargs.pop("zorder", 10)
    if not _off:
        ax.pcolormesh(x, y, density, zorder=_zorder, **kwargs)
    if levels is not None:
        CS = ax.contour(x, y, density, levels=levels, **level_kwargs)
        if legend:
            _legend_kwargs = DEFAULT_LEGEND_KWARGS.copy()
            _legend_kwargs.update(legend_kwargs)
            # NOTE(review): `ContourSet.collections` is reported as deprecated
            # in recent Matplotlib releases -- confirm against the supported
            # Matplotlib version
            CS.collections[0].set_label(label)
            ax.legend(**_legend_kwargs)
    return ax
def analytic_twod_contour_plot(*args, smooth=None, **kwargs):
    """Draw a 2d contour plot from an analytic PDF evaluated on a grid

    This is `twod_contour_plot` with `pcolormesh` used as the drawing
    function.

    Parameters
    ----------
    *args: tuple
        positional arguments handed on to twod_contour_plot
    smooth: float, optional
        amount of smoothing applied to the probabilities. Default None
    **kwargs: dict, optional
        remaining keyword arguments handed on to twod_contour_plot
    """
    # kept as a separate mapping so that a duplicated key in kwargs still
    # raises, exactly as an explicit keyword would
    contour_options = {"smooth": smooth, "_function": pcolormesh}
    return twod_contour_plot(*args, **contour_options, **kwargs)
def twod_contour_plot(
    x, y, *args, rangex=None, rangey=None, fig=None, ax=None, return_ax=False,
    levels=[0.9], bins=300, smooth=7, xlabel=None, ylabel=None,
    fontsize={"label": 12}, grid=True, label=None, truth=None,
    _function=hist2d, truth_lines=True, truth_kwargs={},
    _default_truth_kwargs={
        "marker": 'o', "markeredgewidth": 2, "markersize": 6, "color": 'k'
    }, **kwargs
):
    """Generate a 2d contour plot for 2 marginalized posterior distributions

    Parameters
    ----------
    x: np.array
        array of posterior samples to use for the x axis
    y: np.array
        array of posterior samples to use for the y axis
    *args: tuple
        additional positional arguments passed to `_function` after x and y
    rangex: tuple, optional
        range over which to plot the x axis. Default, the span of x
    rangey: tuple, optional
        range over which to plot the y axis. Default, the span of y
    fig: matplotlib.figure.Figure, optional
        figure you wish to use for plotting
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        axis you wish to use for plotting
    return_ax: Bool, optional
        if True return the axis used for plotting. Else return the figure.
        Forced to True when an axis is provided without a figure
    levels: list, optional
        levels you wish to use for the 2d contours. Default [0.9]
    bins: int, optional
        number of bins to use for gridding 2d parameter space. Default 300
    smooth: int, optional
        how much smoothing you wish to use for the 2d contours
    xlabel: str, optional
        label to use for the xaxis
    ylabel: str, optional
        label to use for the yaxis
    fontsize: dict, optional
        dictionary containing the fontsize to use for the plot. The 'label'
        entry is used for the axis labels
    grid: Bool, optional
        if True, add a grid to the plot
    label: str, optional
        label to use for a given contour
    truth: list, optional
        the true value of the posterior. `truth` is a list of length 2 with
        first element being the true x value and second element being the true
        y value
    truth_lines: Bool, optional
        if True, add vertical and horizontal lines spanning the 2d space to show
        injected value
    truth_kwargs: dict, optional
        kwargs to use to indicate truth. These override the entries of
        `_default_truth_kwargs` for this call only
    **kwargs: dict, optional
        all additional kwargs are passed to `_function`, by default the
        `pesummary.core.plots.corner.hist2d` function. If 'range' is not
        provided it is built from rangex and rangey

    Returns
    -------
    the axis if `return_ax` is True, otherwise the figure
    """
    if fig is None and ax is None:
        fig, ax = figure(gca=True)
    elif fig is None and ax is not None:
        # no figure to hand back, so the axis is the only sensible return
        return_ax = True
    elif ax is None:
        ax = fig.gca()

    xlow, xhigh = np.min(x), np.max(x)
    ylow, yhigh = np.min(y), np.max(y)
    if rangex is not None:
        xlow, xhigh = rangex
    if rangey is not None:
        ylow, yhigh = rangey
    if "range" not in kwargs:
        kwargs["range"] = [[xlow, xhigh], [ylow, yhigh]]

    _function(
        x, y, *args, ax=ax, levels=levels, bins=bins, smooth=smooth,
        label=label, grid=grid, **kwargs
    )
    if truth is not None:
        # merge into a fresh dict: updating `_default_truth_kwargs` in place
        # would leak this call's styling into every later call
        _truth_kwargs = {**_default_truth_kwargs, **truth_kwargs}
        ax.plot(*truth, **_truth_kwargs)
        if truth_lines:
            ax.axvline(
                truth[0], color=_truth_kwargs["color"], linewidth=0.5
            )
            ax.axhline(
                truth[1], color=_truth_kwargs["color"], linewidth=0.5
            )
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize["label"])
    ax.grid(grid)
    if fig is not None:
        fig.tight_layout()
    if return_ax:
        return ax
    return fig
def comparison_twod_contour_plot(
    x, y, labels=None, plot_density=None, rangex=None, rangey=None,
    legend_kwargs={"loc": "best", "frameon": False},
    colors=list(conf.colorcycle), linestyles=None, **kwargs
):
    """Generate a comparison 2d contour plot for 2 marginalized posterior
    distributions from multiple analyses

    Parameters
    ----------
    x: np.ndarray
        2d array of posterior samples to use for the x axis; array for each
        analysis
    y: np.ndarray
        2d array of posterior samples to use for the y axis; array for each
        analysis
    labels: list, optional
        labels to assign to each contour
    plot_density: str/list, optional
        label of the analysis you wish to plot the density for, or a list of
        such labels. If you wish to plot the density for every analysis,
        simply pass `plot_density='both'`. Ignored when `labels` is None
    rangex: tuple, optional
        range over which to plot the x axis. Default, the span of all x
    rangey: tuple, optional
        range over which to plot the y axis. Default, the span of all y
    legend_kwargs: dict, optional
        kwargs to use for the legend
    colors: list, optional
        list of colors to use for each contour
    linestyles: list, optional
        linestyles to use for each contour
    **kwargs: dict, optional
        all additional kwargs are passed to the
        `pesummary.core.plots.publication.twod_contour_plot` function

    Returns
    -------
    fig: the figure returned by the last twod_contour_plot call
    """
    if labels is None:
        # without labels there is nothing to match `plot_density` against
        plot_density = None
        labels = [None] * len(x)

    xlow = np.min([np.min(_x) for _x in x])
    xhigh = np.max([np.max(_x) for _x in x])
    ylow = np.min([np.min(_y) for _y in y])
    yhigh = np.max([np.max(_y) for _y in y])
    if rangex is None:
        rangex = [xlow, xhigh]
    if rangey is None:
        rangey = [ylow, yhigh]

    fig = None
    for num, (_x, _y) in enumerate(zip(x, y)):
        _label = labels[num]
        # decide per analysis without overwriting `plot_density`; reusing the
        # argument as the per-iteration flag made every analysis after the
        # first compare against a bool and never match
        if plot_density is None:
            _plot_density = False
        elif plot_density == _label:
            _plot_density = True
        elif isinstance(plot_density, list):
            _plot_density = _label in plot_density
        else:
            _plot_density = plot_density == "both"

        _color = colors[num] if colors is not None else None
        _linestyle = linestyles[num] if linestyles is not None else None
        # passing the previous figure back in draws every analysis on the
        # same axis
        fig = twod_contour_plot(
            _x, _y, plot_density=_plot_density, label=_label, fig=fig,
            rangex=rangex, rangey=rangey, color=_color, linestyles=_linestyle,
            **kwargs
        )
    ax = fig.gca()
    ax.legend(**legend_kwargs)
    return fig
def _triangle_axes(
    figsize=(8, 8), width_ratios=[4, 1], height_ratios=[1, 4], wspace=0.0,
    hspace=0.0,
):
    """Initialize the axes for a 2d triangle plot

    Parameters
    ----------
    figsize: tuple, optional
        figure size you wish to use. Default (8, 8)
    width_ratios: list, optional
        ratio of widths for the triangular axis. Default 4:1
    height_ratios: list, optional
        ratio of heights for the triangular axis. Default 1:4
    wspace: float, optional
        horizontal space between the axis. Default 0.0
    hspace: float, optional
        vertical space between the axis. Default 0.0

    Returns
    -------
    fig, ax1, ax2, ax3, ax4: the figure and the four subplots of the 2x2
        grid, in grid index order 0 to 3
    """
    fig = figure(figsize=figsize, gca=False)
    gs = gridspec.GridSpec(
        2, 2, width_ratios=width_ratios, height_ratios=height_ratios,
        wspace=wspace, hspace=hspace
    )
    ax1, ax2, ax3, ax4 = (fig.add_subplot(gs[index]) for index in (0, 1, 2, 3))
    # ax2 is left without minor ticks, as in the original layout
    for _axis in (ax1, ax3, ax4):
        _axis.minorticks_on()
    ax1.xaxis.set_ticklabels([])
    ax4.yaxis.set_ticklabels([])
    return fig, ax1, ax2, ax3, ax4
300def _generate_triangle_plot(
301 *args, function=None, fig_kwargs={}, existing_figure=None, **kwargs
302):
303 """Generate a triangle plot according to a given function
305 Parameters
306 ----------
307 *args: tuple
308 all args passed to function
309 function: func, optional
310 function you wish to use to generate triangle plot. Default
311 _triangle_plot
312 **kwargs: dict, optional
313 all kwargs passed to function
314 """
315 if existing_figure is None:
316 fig, ax1, ax2, ax3, ax4 = _triangle_axes(**fig_kwargs)
317 ax2.axis("off")
318 else:
319 fig, ax1, ax3, ax4 = existing_figure
320 if function is None:
321 function = _triangle_plot
322 return function(fig, [ax1, ax3, ax4], *args, **kwargs)
def triangle_plot(*args, **kwargs):
    """Generate a triangular plot made of 3 axes: a central axis showing the
    2d marginalized posterior and two smaller axes showing the 1d
    marginalized posteriors (above and to the right of the central axis).

    All arguments are forwarded to `_generate_triangle_plot`, which draws
    with `_triangle_plot`.

    Parameters
    ----------
    x: list
        list of samples for the x axis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde used to smooth the 1d marginalized posteriors. Pass kde=False to
        draw histograms instead. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde used to smooth the 2d marginalized posterior. Default None
    npoints: int, optional
        number of points at which the 1d kde is evaluated
    kde_kwargs: dict, optional
        kwargs passed directly to the kde function
    kde_2d_kwargs: dict, optional
        kwargs passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha to use for the fill
    levels: list, optional
        levels to use for the 2d contours
    smooth: dict/float, optional
        amount of smoothing for the 2d contours. To smooth each analysis
        differently, provide a dict keyed by label
    colors: list, optional
        colors to use, one per analysis
    xlabel: str, optional
        label for the x axis
    ylabel: str, optional
        label for the y axis
    fontsize: dict, optional
        fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles to use, one per analysis
    linewidths: list, optional
        linewidths to use, one per analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles to mark on the 1d axes. Default None
    percentile_plot: list, optional
        labels of the analyses for which percentiles are marked
    fig_kwargs: dict, optional
        kwargs passed directly to the _triangle_axes function
    labels: list, optional
        label associated with each set of samples
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    grid: Bool, optional
        if True, show a grid on all axes. Default False
    legend_kwargs: dict, optional
        kwargs for the legend. Default {"loc": "best", "frameon": False}
    **kwargs: dict
        all additional kwargs are passed to the corner.hist2d function
    """
    # separate mapping so a caller supplied `function` still raises as a
    # duplicated keyword
    plot_options = {"function": _triangle_plot}
    return _generate_triangle_plot(*args, **plot_options, **kwargs)
def analytic_triangle_plot(*args, **kwargs):
    """Generate a triangle plot given probability densities for x, y and xy.

    All arguments are forwarded to `_generate_triangle_plot`, which creates
    (or reuses) the figure and axes and then draws with
    `_analytic_triangle_plot`. The figure and axes are therefore not
    passed by the caller.

    Parameters
    ----------
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing
        applied
    cmap: str, optional
        name of cmap to use for plotting
    """
    # separate mapping so a caller supplied `function` still raises as a
    # duplicated keyword
    plot_options = {"function": _analytic_triangle_plot}
    return _generate_triangle_plot(*args, **plot_options, **kwargs)
def _analytic_triangle_plot(
    fig, axes, x, y, probs_x, probs_y, probs_xy, smooth=None, xlabel=None,
    ylabel=None, grid=True, **kwargs
):
    """Draw a triangle plot from probability densities for x, y and xy.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots to draw on: the 1d x axis, the 2d axis and the
        1d y axis, in that order
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing
        applied
    xlabel: str, optional
        label to use for the x axis
    ylabel: str, optional
        label to use for the y axis
    grid: Bool, optional
        if True, add a grid to the plot
    **kwargs: dict, optional
        passed to analytic_twod_contour_plot. 'level_kwargs' and 'fontsize'
        are also read here

    Returns
    -------
    fig, ax1, ax3, ax4: the figure and the three axes that were drawn on
    """
    top_axis, main_axis, side_axis = axes
    analytic_twod_contour_plot(
        x, y, probs_xy, ax=main_axis, smooth=smooth, grid=grid, **kwargs
    )
    # draw the 1d curves in the colour of the first contour, when one is given
    contour_style = kwargs.get("level_kwargs", None)
    has_colour = contour_style is not None and "colors" in contour_style
    line_colour = contour_style["colors"][0] if has_colour else None
    top_axis.plot(x, probs_x, color=line_colour)
    side_axis.plot(probs_y, y, color=line_colour)

    label_sizes = kwargs.get("fontsize", {"label": 12})
    if xlabel is not None:
        main_axis.set_xlabel(xlabel, fontsize=label_sizes["label"])
    if ylabel is not None:
        main_axis.set_ylabel(ylabel, fontsize=label_sizes["label"])

    top_axis.grid(grid)
    if grid:
        main_axis.grid(grid, zorder=10)
    side_axis.grid(grid)

    # align the 1d axes with the limits of the 2d axis
    top_axis.set_xlim(main_axis.get_xlim())
    side_axis.set_ylim(main_axis.get_ylim())
    fig.tight_layout()
    return fig, top_axis, main_axis, side_axis
def _triangle_plot(
    fig, axes, x, y, kde=gaussian_kde, npoints=100, kde_kwargs={}, fill=True,
    fill_alpha=0.5, levels=[0.9], smooth=7, colors=list(conf.colorcycle),
    xlabel=None, ylabel=None, fontsize={"legend": 12, "label": 12},
    linestyles=None, linewidths=None, plot_density=True, percentiles=None,
    percentile_plot=None, fig_kwargs={}, labels=None, plot_datapoints=False,
    rangex=None, rangey=None, grid=False, latex_friendly=False, kde_2d=None,
    kde_2d_kwargs={}, legend_kwargs={"loc": "best", "frameon": False},
    truth=None, hist_kwargs={"density": True, "bins": 50},
    _contour_function=twod_contour_plot, **kwargs
):
    """Base function to generate a triangular plot

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        figure on which to make the plots
    axes: list
        the three subplots to draw on: the 1d x axis, the 2d axis and the
        1d y axis, in that order
    x: list
        list of samples for the x axis. Either one set of samples or one set
        per analysis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde to use for smoothing the 1d marginalized posterior distribution. If
        you do not want to use KDEs, simply pass kde=False. Default
        scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde to use for smoothing the 2d marginalized posterior distribution.
        default None
    npoints: int, optional
        number of points to use for the 1d kde
    kde_kwargs: dict, optional
        optional kwargs which are passed directly to the kde function.
        kde_kwargs to be passed to the kde on the y axis may be specified
        by the dictionary entry 'y_axis'. kde_kwargs to be passed to the kde on
        the x axis may be specified by the dictionary entry 'x_axis'. An axis
        without its own entry receives the remaining (non axis specific)
        entries
    kde_2d_kwargs: dict, optional
        optional kwargs which are passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha to use for fill
    levels: list, optional
        levels you wish to use for the 2d contours
    smooth: dict/float, optional
        how much smoothing you wish to use for the 2d contours. If you wish
        to use different smoothing for different contours, then provide a dict
        with keys given by the label
    colors: list, optional
        list of colors you wish to use for each analysis
    xlabel: str, optional
        xlabel you wish to use for the plot
    ylabel: str, optional
        ylabel you wish to use for the plot
    fontsize: dict, optional
        dictionary giving the fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles you wish to use for each analysis
    linewidths: list, optional
        linewidths you wish to use for each analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles you wish to plot. A dashed line is drawn on the 1d axes
        at each percentile. Default None
    percentile_plot: list, optional
        labels of the analyses to plot percentiles for. If None, no
        percentile lines are drawn
    fig_kwargs: dict, optional
        accepted for call compatibility; not used by this function
    labels: list, optional
        label associated with each set of samples
    plot_datapoints: Bool, optional
        passed to the 2d contour function. Default False
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    grid: Bool, optional
        if True, show a grid on all axes
    latex_friendly: Bool, optional
        if True, escape underscores in the labels shown in the legend
    legend_kwargs: dict, optional
        optional kwargs for the legend. Default {"loc": "best", "frameon": False}.
        The 'fontsize' entry is always taken from `fontsize['legend']`
    truth: list, optional
        true [x, y] value. Marked with lines on the 1d axes and passed to the
        2d contour function
    hist_kwargs: dict, optional
        kwargs passed to ax.hist when kde=False
    **kwargs: dict
        all kwargs are passed to the corner.hist2d function

    Returns
    -------
    fig, ax1, ax3, ax4: the figure and the three axes that were drawn on

    Raises
    ------
    ValueError
        if fewer colors, linestyles or linewidths than analyses are given, or
        if the number of labels does not match the number of analyses
    """
    ax1, ax3, ax4 = axes
    if not isinstance(x[0], (list, np.ndarray)):
        # a single analysis was given; promote to a list of analyses
        x, y = np.atleast_2d(x), np.atleast_2d(y)
    _base_error = "Please provide {} for each analysis"
    if len(colors) < len(x):
        raise ValueError(_base_error.format("a single color"))
    if linestyles is None:
        linestyles = ["-"] * len(x)
    elif len(linestyles) < len(x):
        raise ValueError(_base_error.format("a single linestyle"))
    if linewidths is None:
        linewidths = [None] * len(x)
    elif len(linewidths) < len(x):
        raise ValueError(_base_error.format("a single linewidth"))
    if labels is None:
        labels = [None] * len(x)
    elif len(labels) != len(x):
        raise ValueError(_base_error.format("a label"))

    # Labels as shown in the legend. `labels` itself is left untouched so that
    # look-ups keyed by label (`smooth`, `percentile_plot`) keep working, and
    # missing (None) labels are skipped rather than crashing on `.replace`
    if latex_friendly:
        _legend_labels = [
            _label.replace("_", "\\_") if isinstance(_label, str) else _label
            for _label in labels
        ]
    else:
        _legend_labels = list(labels)

    # kwargs shared by both 1d kdes; the axis specific entries must not be
    # passed to the kde itself
    _shared_kde_kwargs = {
        key: value for key, value in kde_kwargs.items()
        if key not in ("x_axis", "y_axis")
    }
    _x_kde_kwargs = kde_kwargs.get("x_axis", _shared_kde_kwargs)
    _y_kde_kwargs = kde_kwargs.get("y_axis", _shared_kde_kwargs)

    xlow = np.min([np.min(_x) for _x in x])
    xhigh = np.max([np.max(_x) for _x in x])
    ylow = np.min([np.min(_y) for _y in y])
    yhigh = np.max([np.max(_y) for _y in y])
    if rangex is not None:
        xlow, xhigh = rangex
    if rangey is not None:
        ylow, yhigh = rangey
    for num in range(len(x)):
        plot_kwargs = dict(
            color=colors[num], linewidth=linewidths[num],
            linestyle=linestyles[num]
        )
        if kde:
            _kde = kde(x[num], **_x_kde_kwargs)
            _x = np.linspace(xlow, xhigh, npoints)
            _y = _kde(_x)
            ax1.plot(_x, _y, **plot_kwargs)
            if fill:
                ax1.fill_between(_x, 0, _y, alpha=fill_alpha, **plot_kwargs)
            _y = np.linspace(ylow, yhigh, npoints)
            _kde = kde(y[num], **_y_kde_kwargs)
            _x = _kde(_y)
            # the legend is later built from the handles on this axis
            ax4.plot(_x, _y, label=_legend_labels[num], **plot_kwargs)
            if fill:
                ax4.fill_betweenx(_y, 0, _x, alpha=fill_alpha, **plot_kwargs)
        else:
            histtype = "stepfilled" if fill else "step"
            ax1.hist(x[num], histtype=histtype, **hist_kwargs, **plot_kwargs)
            ax4.hist(
                y[num], histtype=histtype, orientation="horizontal",
                **hist_kwargs, **plot_kwargs
            )
        if percentiles is not None:
            if percentile_plot is not None and labels[num] in percentile_plot:
                _line_kwargs = dict(linestyle="--", linewidth=linewidths[num])
                for _value in np.atleast_1d(
                    np.percentile(x[num], percentiles)
                ):
                    ax1.axvline(_value, **_line_kwargs)
                for _value in np.atleast_1d(
                    np.percentile(y[num], percentiles)
                ):
                    ax4.axhline(_value, **_line_kwargs)
        if isinstance(smooth, dict):
            _smooth = smooth[labels[num]]
        else:
            _smooth = smooth
        _contour_function(
            x[num], y[num], ax=ax3, levels=levels, smooth=_smooth,
            rangex=[xlow, xhigh], rangey=[ylow, yhigh], color=colors[num],
            linestyles=linestyles[num],
            plot_density=plot_density, contour_kwargs=dict(
                linestyles=[linestyles[num]], linewidths=linewidths[num]
            ), plot_datapoints=plot_datapoints, kde=kde_2d,
            kde_kwargs=kde_2d_kwargs, grid=False, truth=truth, **kwargs
        )

    if truth is not None:
        ax1.axvline(truth[0], color='k', linewidth=0.5)
        ax4.axhline(truth[1], color='k', linewidth=0.5)
    if xlabel is not None:
        ax3.set_xlabel(xlabel, fontsize=fontsize["label"])
    if ylabel is not None:
        ax3.set_ylabel(ylabel, fontsize=fontsize["label"])
    if not all(label is None for label in labels):
        # build a new dict: writing into `legend_kwargs` would modify the
        # shared default (or the caller's dict) for every later call
        _legend_kwargs = dict(legend_kwargs, fontsize=fontsize["legend"])
        ax3.legend(*ax4.get_legend_handles_labels(), **_legend_kwargs)
    ax1.grid(grid)
    ax3.grid(grid)
    ax4.grid(grid)
    # align the 2d axis with the limits of the 1d axes
    xlims = ax1.get_xlim()
    ax3.set_xlim(xlims)
    ylims = ax4.get_ylim()
    ax3.set_ylim(ylims)
    return fig, ax1, ax3, ax4
685def _generate_reverse_triangle_plot(
686 *args, xlabel=None, ylabel=None, function=None, existing_figure=None, **kwargs
687):
688 """Generate a reverse triangle plot according to a given function
690 Parameters
691 ----------
692 *args: tuple
693 all args passed to function
694 xlabel: str, optional
695 label to use for the x axis
696 ylabel: str, optional
697 label to use for the y axis
698 function: func, optional
699 function to use to generate triangle plot. Default _triangle_plot
700 **kwargs: dict, optional
701 all kwargs passed to function
702 """
703 if existing_figure is None:
704 fig, ax1, ax2, ax3, ax4 = _triangle_axes(
705 width_ratios=[1, 4], height_ratios=[4, 1]
706 )
707 ax3.axis("off")
708 else:
709 fig, ax1, ax2, ax4 = existing_figure
710 if function is None:
711 function = _triangle_plot
712 fig, ax4, ax2, ax1 = function(fig, [ax4, ax2, ax1], *args, **kwargs)
713 ax2.axis("off")
714 ax4.spines["right"].set_visible(False)
715 ax4.spines["top"].set_visible(False)
716 ax4.spines["left"].set_visible(False)
717 ax4.set_yticks([])
719 ax1.spines["right"].set_visible(False)
720 ax1.spines["top"].set_visible(False)
721 ax1.spines["bottom"].set_visible(False)
722 ax1.set_xticks([])
724 _fontsize = kwargs.get("fontsize", {"label": 12})["label"]
725 if xlabel is not None:
726 ax4.set_xlabel(xlabel, fontsize=_fontsize)
727 if ylabel is not None:
728 ax1.set_ylabel(ylabel, fontsize=_fontsize)
729 return fig, ax1, ax2, ax4
def reverse_triangle_plot(*args, **kwargs):
    """Generate a triangular plot made of 3 axes: a central axis showing the
    2d marginalized posterior and two smaller axes showing the 1d
    marginalized posteriors (below and to the left of the central axis).
    Only two axes are plotted, each below the 1d marginalized posterior
    distribution.

    All arguments are forwarded to `_generate_reverse_triangle_plot`, which
    draws with `_triangle_plot`.

    Parameters
    ----------
    x: list
        list of samples for the x axis
    y: list
        list of samples for the y axis
    kde: Bool/func, optional
        kde used to smooth the 1d marginalized posteriors. Pass kde=False to
        draw histograms instead. Default scipy.stats.gaussian_kde
    kde_2d: func, optional
        kde used to smooth the 2d marginalized posterior. Default None
    npoints: int, optional
        number of points at which the 1d kde is evaluated
    kde_kwargs: dict, optional
        kwargs passed directly to the kde function. Entries under the keys
        'x_axis' and 'y_axis' are used for the kde on the x and y axis
        respectively
    kde_2d_kwargs: dict, optional
        kwargs passed directly to the 2d kde function
    fill: Bool, optional
        whether or not to fill the 1d posterior distributions
    fill_alpha: float, optional
        alpha to use for the fill
    levels: list, optional
        levels to use for the 2d contours
    smooth: dict/float, optional
        amount of smoothing for the 2d contours. To smooth each analysis
        differently, provide a dict keyed by label
    colors: list, optional
        colors to use, one per analysis
    xlabel: str, optional
        label for the x axis
    ylabel: str, optional
        label for the y axis
    fontsize: dict, optional
        fontsize for the labels and legend. Default
        {'legend': 12, 'label': 12}
    linestyles: list, optional
        linestyles to use, one per analysis
    linewidths: list, optional
        linewidths to use, one per analysis
    plot_density: Bool, optional
        whether or not to plot the density on the 2d contour. Default True
    percentiles: list, optional
        percentiles to mark on the 1d axes. Default None
    percentile_plot: list, optional
        labels of the analyses for which percentiles are marked
    fig_kwargs: dict, optional
        kwargs intended for the _triangle_axes function
    labels: list, optional
        label associated with each set of samples
    rangex: tuple, optional
        range over which to plot the x axis
    rangey: tuple, optional
        range over which to plot the y axis
    legend_kwargs: dict, optional
        kwargs for the legend. Default {"loc": "best", "frameon": False}
    **kwargs: dict
        all kwargs are passed to the corner.hist2d function
    """
    # separate mapping so a caller supplied `function` still raises as a
    # duplicated keyword
    plot_options = {"function": _triangle_plot}
    return _generate_reverse_triangle_plot(*args, **plot_options, **kwargs)
def analytic_reverse_triangle_plot(*args, **kwargs):
    """Generate a reverse triangle plot given probability densities for x, y
    and xy.

    All arguments are forwarded to `_generate_reverse_triangle_plot`, which
    creates (or reuses) the figure and axes and then draws with
    `_analytic_triangle_plot`. The figure and axes are therefore not passed
    by the caller.

    Parameters
    ----------
    x: list
        list of points to use for the x axis
    y: list
        list of points to use for the y axis
    probs_x: list
        list of probabilities associated with x
    probs_y: list
        list of probabilities associated with y
    probs_xy: list
        2d list of probabilities for xy
    smooth: float, optional
        degree of smoothing to apply to probs_xy. Default no smoothing
        applied
    cmap: str, optional
        name of cmap to use for plotting
    """
    # separate mapping so a caller supplied `function` still raises as a
    # duplicated keyword
    plot_options = {"function": _analytic_triangle_plot}
    return _generate_reverse_triangle_plot(*args, **plot_options, **kwargs)