Coverage for pesummary/core/plots/seaborn/kde.py: 82.2%
45 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from scipy import stats
4from pesummary.core.plots.seaborn import SEABORN
if not SEABORN:
    # seaborn could not be used: define minimal stand-ins so that the class
    # definitions and attribute assignments later in this module still work.
    class _StatisticsKDE:
        """Placeholder base class used in place of `seaborn._statistics.KDE`
        """

    class distributions:
        """Placeholder namespace used in place of `seaborn.distributions`
        """

        class _DistributionPlotter:
            """Placeholder base class used in place of
            `seaborn.distributions._DistributionPlotter`
            """

        def kdeplot(*args, **kwargs):
            """Placeholder plotting function; always raises a ValueError
            """
            raise ValueError("Unable to produce kdeplot with 'seaborn'")
else:
    from seaborn._statistics import KDE as _StatisticsKDE
    from seaborn import distributions


__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>", "Seaborn authors"]
class _BaseKDE(object):
    """Extension of the `seaborn._statistics.KDE` to allow for custom
    kde_kernel

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde. None is treated as the default
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default None, which
        is treated as an empty dict. The dict is copied and never modified
    **kwargs: dict
        `bw_method` and `bw_adjust` are stored as attributes; `gridsize`,
        `cut`, `clip` and `cumulative` are passed to the parent class; every
        other kwarg is added to the kwargs passed to the kde_kernel
    """
    # kwargs stored on the instance and used when fitting the kernel
    _BANDWIDTH_KWARGS = ("bw_method", "bw_adjust")
    # kwargs forwarded untouched to the parent initialiser
    _PARENT_KWARGS = ("gridsize", "cut", "clip", "cumulative")

    def __init__(
        self, *args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs
    ):
        # work on a private copy: a shared default (or the caller's own dict)
        # must not accumulate the extra kwargs collected below
        kernel_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
        bandwidth_kwargs = {}
        parent_kwargs = {}
        for key, item in kwargs.items():
            if key in self._BANDWIDTH_KWARGS:
                bandwidth_kwargs[key] = item
            elif key in self._PARENT_KWARGS:
                parent_kwargs[key] = item
            else:
                kernel_kwargs[key] = item
        super().__init__(*args, **parent_kwargs)
        # Assign the bandwidth settings only after the parent initialiser has
        # run, so a parent that sets its own `bw_method` / `bw_adjust`
        # defaults cannot overwrite the requested values.
        # NOTE(review): seaborn's KDE.__init__ appears to assign both
        # attributes -- confirm against the installed seaborn version.
        for key, item in bandwidth_kwargs.items():
            setattr(self, key, item)
        if kde_kernel is None:
            kde_kernel = stats.gaussian_kde
        self._kde_kernel = kde_kernel
        self._kde_kwargs = kernel_kwargs

    def _fit(self, fit_data, weights=None):
        """Fit the kde_kernel to the data and apply the bw_adjust scaling

        Parameters
        ----------
        fit_data: array-like
            data passed as the first argument to the kde_kernel
        weights: array-like, optional
            weights passed to the kde_kernel. Only passed when not None

        Returns
        -------
        kde: object
            the fitted kernel, with its bandwidth factor multiplied by
            `bw_adjust`
        """
        # copy so that per-call settings (notably `weights`) are not stored
        # on the instance and silently reused by a later call
        fit_kws = dict(self._kde_kwargs)
        fit_kws["bw_method"] = self.bw_method
        if weights is not None:
            fit_kws["weights"] = weights

        kde = self._kde_kernel(fit_data, **fit_kws)
        kde.set_bandwidth(kde.factor * self.bw_adjust)
        return kde
class StatisticsKDE(_BaseKDE, _StatisticsKDE):
    """KDE estimator combining `_BaseKDE` with `seaborn._statistics.KDE`
    so that a custom kde_kernel can be used

    Parameters
    ----------
    *args: tuple
        positional arguments forwarded to the `seaborn._statistics.KDE`
        class
    kde_kernel: func, optional
        kernel used to evaluate the KDE. Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        extra kwargs handed to the kde_kernel. Default: no extra kwargs
    **kwargs: dict
        keyword arguments forwarded to the `seaborn._statistics.KDE` class
    """
class _DistributionPlotter(distributions._DistributionPlotter):
    """Extension of `distributions._DistributionPlotter` which injects the
    module level `KDE` kernel and `KDE_kwargs` into the estimator kwargs
    """

    def _compute_univariate_density(
        self, data_variable, common_norm, common_grid, estimate_kws,
        log_scale, **kwargs
    ):
        """Add the kernel settings to `estimate_kws` (modified in place) and
        defer to the parent implementation
        """
        # read both module globals before touching `estimate_kws`
        kernel, kernel_kwargs = KDE, KDE_kwargs
        estimate_kws["kde_kernel"] = kernel
        estimate_kws["kde_kwargs"] = kernel_kwargs
        parent = super()
        return parent._compute_univariate_density(
            data_variable, common_norm, common_grid, estimate_kws,
            log_scale, **kwargs
        )
# `_DistributionPlotter._compute_univariate_density` reads the module globals
# `KDE` and `KDE_kwargs`. They are otherwise only created by the first call
# to `kdeplot`, so define defaults here to avoid a NameError if the patched
# plotter is used before then.
KDE = stats.gaussian_kde
KDE_kwargs = {}

# replace the estimator and plotter looked up on `distributions` with the
# extended versions defined in this module
distributions.KDE = StatisticsKDE
distributions._DistributionPlotter = _DistributionPlotter
def kdeplot(*args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs):
    """Extension of the seaborn.distributions.kdeplot function to allow for
    a custom kde_kernel and associated kwargs.

    The kernel and its kwargs are stored in the module level variables `KDE`
    and `KDE_kwargs` before `distributions.kdeplot` is called.

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn.distributions.kdeplot` function
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default None, which
        is treated as an empty dict. The dict is copied and never modified
    **kwargs: dict
        all kwargs passed to the `seaborn.distributions.kdeplot` function

    Returns
    -------
    the return value of `distributions.kdeplot`
    """
    global KDE
    global KDE_kwargs
    KDE = kde_kernel
    # use a fresh dict per call: a shared default (or the caller's own dict)
    # must not be altered by the code which later consumes `KDE_kwargs`
    KDE_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    return distributions.kdeplot(*args, **kwargs)