Coverage for pesummary/core/plots/seaborn/violin.py: 64.9%
94 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from scipy.stats import gaussian_kde
4import numpy as np
5from pesummary.core.plots.palette import color_palette
6from pesummary.core.plots.seaborn import SEABORN
7from .kde import _BaseKDE
if not SEABORN:
    # Fallback used when `SEABORN` is falsy (presumably seaborn could not be
    # imported). These minimal stand-ins only exist so that the classes below
    # can still be defined and this module imported; drawing a violin plot
    # raises a ValueError.
    class _DensityKDE(object):
        """Placeholder base class standing in for
        `seaborn._stats.density.KDE`
        """

    class _base(object):
        class HueMapping(object):
            """Placeholder base class standing in for
            `seaborn._base.HueMapping`
            """

    class categorical(object):
        class _CategoricalPlotter(object):
            """Placeholder base class standing in for
            `seaborn.categorical._CategoricalPlotter`
            """

        def violinplot(*args, **kwargs):
            """Placeholder for `seaborn.categorical.violinplot` which always
            raises a ValueError
            """
            raise ValueError("Unable to produce violinplot with 'seaborn'")
else:
    from seaborn import categorical
    from seaborn import _base
    from seaborn._stats.density import KDE as _DensityKDE

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class DensityKDE(_BaseKDE, _DensityKDE):
    """Variant of `seaborn._stats.density.KDE` which supports a user supplied
    kde_kernel

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._stats.density.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default {}
    **kwargs: dict
        all kwargs passed to the `seaborn._stats.density.KDE` class
    """
    def _fit(self, fit_data, orient, **kwargs):
        """Pick the `orient` column out of `fit_data` and forward it, along
        with any extra kwargs, to the next `_fit` in the MRO
        """
        samples = fit_data[orient]
        return super()._fit(samples, **kwargs)
class HueMapping(_base.HueMapping):
    """Extension of seaborn's `HueMapping` allowing different colors for the
    left and right halves of split violins, and an explicit list of colors
    (the `colors` argument of `violinplot`)

    When `palette` is a dict it may contain "left" and/or "right" keys. Each
    value is either a string of the form "color:<color>" (one color for every
    violin on that side) or anything accepted by `color_palette`. In the
    latter case successive violins on that side step through that palette.
    """
    # Class-level defaults are kept only for backwards compatibility. Every
    # instance receives its own copies in `__init__`: sharing these mutable
    # dicts across instances made counters and palettes leak between plots.
    ind = {"left": 0, "right": 0, "num": 0}
    _palette_dict = {"left": False, "right": False}
    _lookup_table = {"left": None, "right": None}

    def __init__(self, *args, **kwargs):
        # The per-instance state must exist before the parent constructor
        # runs, since it presumably calls `categorical_mapping` (which writes
        # to `_palette_dict`/`_lookup_table`) -- confirm against seaborn
        self.ind = {"left": 0, "right": 0, "num": 0}
        self._palette_dict = {"left": False, "right": False}
        self._lookup_table = {"left": None, "right": None}
        super().__init__(*args, **kwargs)

    def _lookup_single(self, key):
        """Return the color associated with `key`

        When an explicit color list (module global `colorlist`) or a per-side
        palette is in use, every call consumes the next color.
        """
        # an explicit list of colors takes priority when no palette is given
        if colorlist is not None and self.palette is None:
            color = colorlist[self.ind["num"]]
            self.ind["num"] += 1
            return color
        if key not in self._palette_dict:
            return super()._lookup_single(key)
        if self._palette_dict[key]:
            palette = self._lookup_table[key]
            # cycle rather than raise IndexError once the palette is exhausted
            color = palette[self.ind[key] % len(palette)]
        else:
            # NOTE: `lookup_table` (no underscore) is the table held by the
            # parent class -- presumably the one returned by
            # `categorical_mapping`, where "color:" entries were stored
            color = self.lookup_table[key]
        self.ind[key] += 1
        return color

    def categorical_mapping(self, data, palette, order):
        """Extend the parent mapping with custom "left"/"right" colors taken
        from a dict `palette`
        """
        levels, lookup_table = super().categorical_mapping(data, palette, order)
        if not isinstance(palette, dict):
            return levels, lookup_table
        for key in ("left", "right"):
            if key not in palette:
                continue
            if "color:" in palette[key]:
                # a single fixed color, e.g. "color: red"
                _color = palette[key].replace(" ", "").split(":")[1]
                lookup_table[key] = _color
            else:
                # a palette: remember 10 colors to step through per violin
                self._palette_dict[key] = True
                self._lookup_table[key] = color_palette(palette[key], n_colors=10)
                _color = color_palette(palette[key], n_colors=1)[0]
                lookup_table[key] = _color
        return levels, lookup_table
class _CategoricalPlotter(categorical._CategoricalPlotter):
    """Extension of seaborn's `_CategoricalPlotter` which injects the KDE
    kernel and kernel kwargs selected through `violinplot` (module globals
    `KDE` and `KDE_kwargs`) into the violin density estimate
    """
    def plot_violins(self, *args, **kwargs):
        """Forward to the parent `plot_violins` with a modified `kde_kws`

        The custom `kde_kernel`/`kde_kwargs` are added, and `gridsize` and
        `bw_adjust` are removed. The caller's `kde_kws` dict is left
        untouched.
        """
        _kwargs = kwargs.copy()
        # `kwargs.copy()` is shallow: copy the nested dict too so that the
        # update/pop calls below do not mutate the caller's `kde_kws`
        kde_kws = dict(_kwargs["kde_kws"])
        kde_kws.update({"kde_kernel": KDE, "kde_kwargs": KDE_kwargs})
        # NOTE(review): presumably dropped because the custom KDE path does
        # not support them -- confirm
        kde_kws.pop("gridsize", None)
        kde_kws.pop("bw_adjust", None)
        _kwargs["kde_kws"] = kde_kws
        return super().plot_violins(*args, **_kwargs)
# Install the customised classes into the seaborn modules, presumably so that
# `categorical.violinplot` uses them. These are module attribute assignments,
# so they affect every user of these seaborn modules in the process.
_base.HueMapping = HueMapping
categorical.KDE = DensityKDE
categorical._CategoricalPlotter = _CategoricalPlotter
# Module-level state read by the patched seaborn classes in this module:
# `_CategoricalPlotter.plot_violins` reads `KDE`/`KDE_kwargs` and
# `HueMapping._lookup_single` reads `colorlist`. `violinplot` sets them for
# the duration of a call. Defaults are defined here so that seaborn plots made
# through the patched classes do not raise NameError before `violinplot` has
# ever been called.
KDE = gaussian_kde
KDE_kwargs = {}
colorlist = None


def violinplot(
    *args, kde_kernel=gaussian_kde, kde_kwargs=None, inj=None, colors=None,
    **kwargs
):
    """Extension of the seaborn.categorical.violinplot function to allow for
    a custom kde_kernel and associated kwargs.

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn.categorical.violinplot` function
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default None, which
        is treated as {}
    inj: float, optional
        injected value. Currently ignored, but kept for backwards compatibility
    colors: list, optional
        list of colors to use for each violin. Only used when no palette is
        given. Default None
    **kwargs: dict
        all kwargs passed to the `seaborn.categorical.violinplot` class

    Returns
    -------
    whatever `seaborn.categorical.violinplot` returns
    """
    global KDE
    global KDE_kwargs
    global colorlist
    _previous = (KDE, KDE_kwargs, colorlist)
    KDE = kde_kernel
    # avoid a shared mutable default argument
    KDE_kwargs = {} if kde_kwargs is None else kde_kwargs
    colorlist = colors
    try:
        return categorical.violinplot(*args, **kwargs)
    finally:
        # restore the previous state so that this call's kernel/colors do
        # not leak into unrelated seaborn plots made later in the process
        KDE, KDE_kwargs, colorlist = _previous
def split_dataframe(
    left, right, labels, left_label="left", right_label="right",
    weights_left=None, weights_right=None
):
    """Generate a pandas DataFrame containing two sets of distributions -- one
    set for the left hand side of the violins, and one set for the right hand
    side of the violins

    Parameters
    ----------
    left: np.ndarray
        array of samples representing the left hand side of the violins. One
        set of samples per violin
    right: np.ndarray
        array of samples representing the right hand side of the violins. One
        set of samples per violin
    labels: np.array
        array containing the label associated with each violin
    left_label: str, optional
        value stored in the 'side' column for samples taken from `left`.
        Default 'left'
    right_label: str, optional
        value stored in the 'side' column for samples taken from `right`.
        Default 'right'
    weights_left: list, optional
        one array of weights per violin for the samples in `left`. If only
        `weights_right` is provided, unit weights are used. Default None
    weights_right: list, optional
        one array of weights per violin for the samples in `right`. If only
        `weights_left` is provided, unit weights are used. Default None

    Returns
    -------
    df: pandas.DataFrame
        DataFrame with columns 'data', 'side' and 'label', plus 'weights'
        when either weights argument is given. For each violin, its left
        samples are followed by its right samples

    Raises
    ------
    ValueError
        if `left`, `right` and `labels` have different lengths, or if the
        weights do not match the samples
    """
    import pandas

    nviolin = len(left)
    # all three lengths must agree; the chained form `a != b != c` only means
    # `a != b and b != c` and so misses e.g. len(left) == len(right) != len(labels)
    if not len(left) == len(right) == len(labels):
        raise ValueError("Please ensure that 'left' == 'right' == 'labels'")

    data, sides, violin_labels = [], [], []
    for label, left_samples, right_samples in zip(labels, left, right):
        data.extend(left_samples)
        data.extend(right_samples)
        sides.extend([left_label] * len(left_samples))
        sides.extend([right_label] * len(right_samples))
        violin_labels.extend(
            [label] * (len(left_samples) + len(right_samples))
        )
    df = pandas.DataFrame(
        data={"data": data, "side": sides, "label": violin_labels}
    )
    if weights_left is None and weights_right is None:
        return df

    # when only one side is weighted, the other side receives unit weights
    if weights_right is None:
        weights_right = [np.ones(len(samples)) for samples in right]
    elif weights_left is None:
        weights_left = [np.ones(len(samples)) for samples in left]
    if len(weights_left) != nviolin or len(weights_right) != nviolin:
        raise ValueError(
            "Please ensure that 'weights_left' and 'weights_right' contain "
            "one set of weights per violin"
        )
    weights = []
    for num, (w_left, w_right) in enumerate(zip(weights_left, weights_right)):
        # guard against weights silently misaligning with their samples
        if len(w_left) != len(left[num]) or len(w_right) != len(right[num]):
            raise ValueError(
                "Please ensure that the weights for violin {} match the "
                "number of samples".format(num)
            )
        weights.extend(w_left)
        weights.extend(w_right)
    df["weights"] = weights
    return df