Coverage for pesummary/core/plots/seaborn/kde.py: 8.5%
176 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4import warnings
5from scipy import stats
6from seaborn.distributions import (
7 _DistributionPlotter as SeabornDistributionPlotter, KDE as SeabornKDE,
8)
9from seaborn.utils import _normalize_kwargs, _check_argument
11import pandas as pd
13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>", "Seaborn authors"]
class KDE(SeabornKDE):
    """Extension of the `seaborn._statistics.KDE` to allow for custom
    kde_kernel

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. The mapping is
        copied on construction, so the object supplied by the caller is
        never modified. Default None, which is treated as {}
    **kwargs: dict
        all kwargs passed to the `seaborn._statistics.KDE` class
    """
    def __init__(
        self, *args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs
    ):
        super(KDE, self).__init__(*args, **kwargs)
        self._kde_kernel = kde_kernel
        # Keep a private copy: a shared (default or caller owned) dictionary
        # must not be modified by the fits performed by this instance
        self._kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)

    def _fit(self, fit_data, weights=None):
        """Fit the custom kde_kernel and apply the bw_adjust scaling

        Parameters
        ----------
        fit_data: array-like
            data passed as the first argument to the kde_kernel
        weights: array-like, optional
            weights for `fit_data`. Only forwarded to the kde_kernel when
            not None

        Returns
        -------
        kde: object
            the object returned by the kde_kernel, after its bandwidth has
            been set to `kde.factor * self.bw_adjust`
        """
        # Build the keyword arguments from a fresh copy on every call so a
        # `weights` entry from an earlier fit cannot leak into a later fit
        # that is made without weights
        fit_kws = dict(self._kde_kwargs)
        # `self.bw_method` always takes precedence over any `bw_method`
        # entry supplied through `kde_kwargs`
        fit_kws["bw_method"] = self.bw_method
        if weights is not None:
            fit_kws["weights"] = weights

        kde = self._kde_kernel(fit_data, **fit_kws)
        kde.set_bandwidth(kde.factor * self.bw_adjust)
        return kde
class _DistributionPlotter(SeabornDistributionPlotter):
    """Extension of the `seaborn._statistics._DistributionPlotter` to allow for
    the custom KDE method to be used

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics._DistributionPlotter` class
    **kwargs: dict
        all kwargs passed to the `seaborn._statistics._DistributionPlotter`
        class
    """
    def __init__(self, *args, **kwargs):
        super(_DistributionPlotter, self).__init__(*args, **kwargs)

    def plot_univariate_density(
        self,
        multiple,
        common_norm,
        common_grid,
        fill,
        legend,
        estimate_kws,
        variance_atol,
        **plot_kws,
    ):
        """Draw a one dimensional density curve for each subset of the data

        Parameters
        ----------
        multiple: str
            how to combine the curves of several subsets. Must be one of
            "layer", "stack" or "fill"
        common_norm: bool
            passed to `_compute_univariate_density`. If True, the density of
            each subset is scaled by its share of the observations
        common_grid: bool
            passed to `_compute_univariate_density`. Forced to True when
            variables other than "x" and "y" are assigned and `multiple` is
            "stack" or "fill"
        fill: bool or None
            if True, draw filled curves rather than lines. If None, the
            curves are filled only when `multiple` is "stack" or "fill"
        legend: bool
            if True and a "hue" variable is assigned, add a legend
        estimate_kws: dict
            kwargs used to construct the `KDE` estimator
        variance_atol: float
            absolute tolerance below which the variance of a subset is
            treated as zero, in which case the subset is skipped
        **plot_kws: dict
            kwargs forwarded to the matplotlib artists. "alpha" is used as
            the transparency of the curves (default 0.25 when `multiple` is
            "layer", 0.75 otherwise). "alpha_shade" is accepted but is not
            used by this method
        """
        import matplotlib as mpl
        # Handle conditional defaults
        if fill is None:
            fill = multiple in ("stack", "fill")

        # Preprocess the matplotlib keyword dictionaries
        if fill:
            artist = mpl.collections.PolyCollection
        else:
            artist = mpl.lines.Line2D
        plot_kws = _normalize_kwargs(plot_kws, artist)

        # `alpha_shade` is not used when drawing the curves. Discard it once,
        # up front, so that it is never forwarded to a matplotlib artist
        # irrespective of which of the branches below is executed
        plot_kws.pop("alpha_shade", None)

        # Input checking
        _check_argument("multiple", ["layer", "stack", "fill"], multiple)

        # Always share the evaluation grid when stacking
        subsets = bool(set(self.variables) - {"x", "y"})
        if subsets and multiple in ("stack", "fill"):
            common_grid = True

        # Check if the data axis is log scaled
        log_scale = self._log_scaled(self.data_variable)

        # Do the computation
        densities = self._compute_univariate_density(
            self.data_variable,
            common_norm,
            common_grid,
            estimate_kws,
            log_scale,
            variance_atol,
        )

        # Note: raises when no hue and multiple != layer. A problem?
        densities, baselines = self._resolve_multiple(densities, multiple)

        # Control the interaction with autoscaling by defining sticky_edges
        # i.e. we don't want autoscale margins below the density curve
        sticky_density = (0, 1) if multiple == "fill" else (0, np.inf)

        if multiple == "fill":
            # Filled plots should not have any margins
            sticky_support = densities.index.min(), densities.index.max()
        else:
            sticky_support = []

        # Handle default visual attributes. `default_color` is only needed
        # (and only defined) when there is no hue mapping
        if "hue" not in self.variables:
            if self.ax is None:
                color = plot_kws.pop("color", None)
                default_color = "C0" if color is None else color
            else:
                # Draw a temporary empty artist ("scout") to find out which
                # color matplotlib would assign, then remove it again
                if fill:
                    if self.var_types[self.data_variable] == "datetime":
                        # Avoid drawing empty fill_between on date axis
                        # https://github.com/matplotlib/matplotlib/issues/17586
                        scout = None
                        default_color = plot_kws.pop(
                            "color", plot_kws.pop("facecolor", None)
                        )
                        if default_color is None:
                            default_color = "C0"
                    else:
                        scout = self.ax.fill_between([], [], **plot_kws)
                        default_color = tuple(scout.get_facecolor().squeeze())
                    plot_kws.pop("color", None)
                else:
                    scout, = self.ax.plot([], [], **plot_kws)
                    default_color = scout.get_color()
                if scout is not None:
                    scout.remove()

        # The color is passed to the artists explicitly below
        plot_kws.pop("color", None)

        default_alpha = .25 if multiple == "layer" else .75
        alpha = plot_kws.pop("alpha", default_alpha)  # TODO make parameter?

        # Now iterate through the subsets and draw the densities
        # We go backwards so stacked densities read from top-to-bottom
        for sub_vars, _ in self.iter_data("hue", reverse=True):

            # Extract the support grid and density curve for this level
            key = tuple(sub_vars.items())
            try:
                density = densities[key]
            except KeyError:
                # No density was stored for this subset
                continue
            support = density.index
            fill_from = baselines[key]

            ax = self._get_axes(sub_vars)

            # Modify the matplotlib attributes from semantic mapping
            if "hue" in self.variables:
                color = self._hue_map(sub_vars["hue"])
            else:
                color = default_color

            artist_kws = self._artist_kws(
                plot_kws, fill, False, multiple, color, alpha
            )

            # Either plot a curve with observation values on the x axis
            if "x" in self.variables:

                if fill:
                    artist = ax.fill_between(
                        support, fill_from, density, **artist_kws
                    )
                else:
                    artist, = ax.plot(support, density, **artist_kws)

                artist.sticky_edges.x[:] = sticky_support
                artist.sticky_edges.y[:] = sticky_density

            # Or plot a curve with observation values on the y axis
            else:
                if fill:
                    artist = ax.fill_betweenx(
                        support, fill_from, density, **artist_kws
                    )
                else:
                    artist, = ax.plot(density, support, **artist_kws)

                artist.sticky_edges.x[:] = sticky_density
                artist.sticky_edges.y[:] = sticky_support

        # --- Finalize the plot ----

        ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
        default_x = default_y = ""
        if self.data_variable == "x":
            default_y = "Density"
        if self.data_variable == "y":
            default_x = "Density"
        self._add_axis_labels(ax, default_x, default_y)

        if "hue" in self.variables and legend:
            from functools import partial
            if fill:
                artist = partial(mpl.patches.Patch)
            else:
                artist = partial(mpl.lines.Line2D, [], [])

            ax_obj = self.ax if self.ax is not None else self.facets
            self._add_legend(
                ax_obj, artist, fill, False, multiple, alpha, plot_kws, {},
            )

    def _compute_univariate_density(
        self,
        data_variable,
        common_norm,
        common_grid,
        estimate_kws,
        log_scale,
        variance_atol,
    ):
        """Evaluate the custom KDE for each subset of the data

        Parameters
        ----------
        data_variable: str
            name of the variable the density is estimated for
        common_norm: bool
            if True, scale the density of each subset by the fraction of
            observations it contains. Ignored (treated as False) when only
            "x" and/or "y" variables are assigned
        common_grid: bool
            if True and variables other than "x" and "y" are assigned, the
            support of the estimator is defined once from all observations
        estimate_kws: dict
            kwargs used to construct the `KDE` estimator
        log_scale: bool
            if True, the support returned by the estimator is transformed
            with `10 ** support` before it is stored
        variance_atol: float
            absolute tolerance used when comparing the variance of a subset
            with zero

        Returns
        -------
        densities: dict
            `pandas.Series` of the density, indexed by the support, for each
            subset. The keys are `tuple(sub_vars.items())`. Subsets with a
            zero (within `variance_atol`) or NaN variance are not included
        """
        # Initialize the estimator object
        estimator = KDE(**estimate_kws)

        all_data = self.plot_data.dropna()

        if set(self.variables) - {"x", "y"}:
            if common_grid:
                all_observations = self.comp_data.dropna()
                estimator.define_support(all_observations[data_variable])
        else:
            common_norm = False

        densities = {}

        for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

            # Extract the data points from this sub set and remove nulls
            sub_data = sub_data.dropna()
            observations = sub_data[data_variable]

            # A density cannot be estimated from data without any spread
            observation_variance = observations.var()
            no_variance = (
                np.isclose(observation_variance, 0, atol=variance_atol)
                or np.isnan(observation_variance)
            )
            if no_variance:
                msg = "Dataset has 0 variance; skipping density estimate."
                warnings.warn(msg, UserWarning)
                continue

            # Extract the weights for this subset of observations
            if "weights" in self.variables:
                weights = sub_data["weights"]
            else:
                weights = None

            # Estimate the density of observations at this level
            density, support = estimator(observations, weights=weights)

            if log_scale:
                support = np.power(10, support)

            # Apply a scaling factor so that the integral over all subsets is 1
            if common_norm:
                density *= len(sub_data) / len(all_data)

            # Store the density for this level
            key = tuple(sub_vars.items())
            densities[key] = pd.Series(density, index=support)

        return densities
def kdeplot(
    x=None,  # Allow positional x, because behavior will not change with reorg
    *,
    y=None,
    shade=None,  # Note "soft" deprecation, explained below
    vertical=False,  # Deprecated
    kernel=None,  # Deprecated
    bw=None,  # Deprecated
    gridsize=200,  # TODO maybe depend on uni/bivariate?
    cut=3, clip=None, legend=True, cumulative=False,
    shade_lowest=None,  # Deprecated, controlled with levels now
    cbar=False, cbar_ax=None, cbar_kws=None,
    ax=None,

    # New params
    weights=None,  # TODO note that weights is grouped with semantics
    hue=None, palette=None, hue_order=None, hue_norm=None,
    multiple="layer", common_norm=True, common_grid=False,
    levels=10, thresh=.05,
    bw_method="scott", bw_adjust=1, log_scale=None,
    color=None, fill=None, kde_kernel=stats.gaussian_kde, kde_kwargs=None,
    variance_atol=1e-8,

    # Renamed params
    data=None, data2=None,

    **kwargs,
):
    """Plot a univariate or bivariate kernel density estimate, with the
    option of using a custom kde_kernel

    Parameters
    ----------
    x, y: optional
        data, or names of variables in `data`, for the x and y axis
    shade: bool, optional
        alias for `fill`. If not None it overrides `fill`
    vertical: bool, optional
        deprecated. If True, a FutureWarning is raised and `x` and `y` are
        swapped
    kernel: optional
        deprecated. If not None, a UserWarning is raised and the value is
        otherwise ignored
    bw: optional
        deprecated. If not None, a FutureWarning is raised and the value is
        used for `bw_method`
    gridsize, cut, clip, cumulative: optional
        passed to the `KDE` estimator
    legend: bool, optional
        passed to the plotting method. Default True
    shade_lowest: bool, optional
        deprecated. If not None, a UserWarning is raised. If True, `thresh`
        is set to 0
    cbar, cbar_ax, cbar_kws: optional
        passed to the bivariate plotting method
    ax: optional
        axes to draw on. Default None, in which case
        `matplotlib.pyplot.gca()` is used
    weights, hue: optional
        further semantic variables collected together with `x` and `y`
    palette, hue_order, hue_norm: optional
        passed to the `map_hue` method of the plotter
    multiple, common_grid: optional
        passed to the univariate plotting method
    common_norm: bool, optional
        passed to the plotting method. Default True
    levels, thresh: optional
        passed to the bivariate plotting method. An `n_levels` entry in
        `kwargs` takes precedence over `levels`
    bw_method, bw_adjust: optional
        passed to the `KDE` estimator
    log_scale: optional
        passed to the plotter when the axes are attached
    color: optional
        color of the plot. For univariate data it is added to the plotting
        kwargs when not None
    fill: bool, optional
        passed to the plotting method. Overridden by `shade` when `shade` is
        not None
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde, which is also used when None is given
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. A copy is taken for
        every call so the object supplied by the caller is not modified.
        Default None, which is treated as {}
    variance_atol: float, optional
        absolute tolerance used by the univariate density computation when
        checking for data with zero variance. Default 1e-8
    data: optional
        passed to the plotter as the input data structure
    data2: optional
        deprecated name for `y`. If not None, a FutureWarning is raised
    **kwargs: dict
        all other kwargs are passed to the plotting method

    Returns
    -------
    ax: matplotlib axes
        the axes the density was drawn on
    """
    if kde_kernel is None:
        kde_kernel = stats.gaussian_kde

    # Work on a private copy: the estimator adds entries to this dictionary
    # and a shared object would carry them over to subsequent calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)

    # Handle deprecation of `data2` as name for y variable
    if data2 is not None:

        y = data2

        # If `data2` is present, we need to check for the `data` kwarg being
        # used to pass a vector for `x`. We'll reassign the vectors and warn.
        # We need this check because just passing a vector to `data` is now
        # technically valid.

        x_passed_as_data = (
            x is None
            and data is not None
            and np.ndim(data) == 1
        )

        if x_passed_as_data:
            msg = "Use `x` and `y` rather than `data` and `data2`"
            x = data
        else:
            msg = "The `data2` param is now named `y`; please update your code"

        warnings.warn(msg, FutureWarning)

    # Handle deprecation of `vertical`
    if vertical:
        msg = (
            "The `vertical` parameter is deprecated and will be removed in a "
            "future version. Assign the data to the `y` variable instead."
        )
        warnings.warn(msg, FutureWarning)
        x, y = y, x

    # Handle deprecation of `bw`
    if bw is not None:
        msg = (
            "The `bw` parameter is deprecated in favor of `bw_method` and "
            f"`bw_adjust`. Using {bw} for `bw_method`, but please "
            "see the docs for the new parameters and update your code."
        )
        warnings.warn(msg, FutureWarning)
        bw_method = bw

    # Handle deprecation of `kernel`
    if kernel is not None:
        msg = (
            "Support for alternate kernels has been removed. "
            "Using Gaussian kernel."
        )
        warnings.warn(msg, UserWarning)

    # Handle deprecation of shade_lowest
    if shade_lowest is not None:
        if shade_lowest:
            thresh = 0
        msg = (
            "`shade_lowest` is now deprecated in favor of `thresh`. "
            f"Setting `thresh={thresh}`, but please update your code."
        )
        warnings.warn(msg, UserWarning)

    # Handle `n_levels`
    # This was never in the formal API but it was processed, and appeared in an
    # example. We can treat as an alias for `levels` now and deprecate later.
    levels = kwargs.pop("n_levels", levels)

    # Handle "soft" deprecation of shade. `shade` is not really the right
    # terminology here, but unlike some of the other deprecated parameters it
    # is probably very commonly used and much harder to remove. This is
    # therefore going to be a longer process where, first, `fill` will be
    # introduced and be used throughout the documentation. In 0.12, when
    # kwarg-only enforcement hits, we can remove the shade/shade_lowest out of
    # the function signature all together and pull them out of the kwargs.
    # Then we can actually fire a FutureWarning, and eventually remove.
    if shade is not None:
        fill = shade

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

    # `locals()` is used to collect the semantic variables, so the (possibly
    # reassigned) `x` and `y` names above must keep their names
    p = _DistributionPlotter(
        data=data,
        variables=_DistributionPlotter.get_semantics(locals()),
    )

    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    # Check for a specification that lacks x/y data and return early
    if not p.has_xy_data:
        return ax

    # Pack the kwargs for statistics.KDE
    estimate_kws = dict(
        bw_method=bw_method,
        bw_adjust=bw_adjust,
        gridsize=gridsize,
        cut=cut,
        clip=clip,
        cumulative=cumulative,
        kde_kernel=kde_kernel,
        kde_kwargs=kde_kwargs
    )

    p._attach(ax, allowed_types=["numeric", "datetime"], log_scale=log_scale)

    if p.univariate:

        plot_kws = kwargs.copy()
        if color is not None:
            plot_kws["color"] = color

        p.plot_univariate_density(
            multiple=multiple,
            common_norm=common_norm,
            common_grid=common_grid,
            fill=fill,
            legend=legend,
            estimate_kws=estimate_kws,
            variance_atol=variance_atol,
            **plot_kws,
        )

    else:

        p.plot_bivariate_density(
            common_norm=common_norm,
            fill=fill,
            levels=levels,
            thresh=thresh,
            legend=legend,
            color=color,
            cbar=cbar,
            cbar_ax=cbar_ax,
            cbar_kws=cbar_kws,
            estimate_kws=estimate_kws,
            **kwargs,
        )

    return ax