Coverage for pesummary/core/plots/seaborn/violin.py: 64.9%

94 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from scipy.stats import gaussian_kde 

4import numpy as np 

5from pesummary.core.plots.palette import color_palette 

6from pesummary.core.plots.seaborn import SEABORN 

7from .kde import _BaseKDE 

# ``seaborn`` is optional. When ``SEABORN`` is truthy (presumably meaning that
# seaborn could be imported -- see pesummary.core.plots.seaborn), the seaborn
# modules/classes extended below are imported. Otherwise stand-ins with the
# same attribute names are defined so that this module still imports and the
# class definitions and module-level assignments below do not fail.
if SEABORN:
    from seaborn import categorical
    from seaborn import _base
    from seaborn._stats.density import KDE as _DensityKDE
else:
    class _DensityKDE(object):
        """Dummy class for the KDE class to inherit
        """

    # stand-in for the ``seaborn._base`` module namespace
    class _base(object):
        class HueMapping(object):
            """Dummy class for the HueMapping to inherit
            """

    # stand-in for the ``seaborn.categorical`` module namespace
    class categorical(object):
        class _CategoricalPlotter(object):
            """Dummy class for the _CategoricalPlotter to inherit
            """

        def violinplot(*args, **kwargs):
            """Dummy function to call when seaborn is unavailable

            Raises
            ------
            ValueError
                always
            """
            raise ValueError("Unable to produce violinplot with 'seaborn'")


__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

34 

35 

class DensityKDE(_BaseKDE, _DensityKDE):
    """Subclass of `seaborn._stats.density.KDE` whose density is evaluated
    with a user supplied kernel instead of a fixed one

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._stats.density.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default {}
    **kwargs: dict
        all kwargs passed to the `seaborn._stats.density.KDE` class
    """
    def _fit(self, fit_data, orient, **kwargs):
        """Pick the ``orient`` column out of ``fit_data`` and fit it with the
        next ``_fit`` in the MRO

        Parameters
        ----------
        fit_data: mapping
            container indexed by column name (presumably a pandas DataFrame
            supplied by seaborn -- TODO confirm)
        orient: str
            name of the column holding the samples to fit
        **kwargs: dict
            forwarded unchanged to the parent ``_fit``
        """
        # NOTE(review): only the ``orient`` column is forwarded; any other
        # columns of ``fit_data`` (e.g. a weight column, if present) are not
        # passed on -- confirm this is intended.
        samples = fit_data[orient]
        return super()._fit(samples, **kwargs)

54 

55 

class HueMapping(_base.HueMapping):
    """Extension of seaborn's ``HueMapping`` that allows the left and right
    halves of split violins to be coloured independently, and allows an
    explicit list of colours (set through ``violinplot(colors=...)``) to be
    handed out one per lookup.

    A ``palette`` dictionary may contain the special keys ``"left"`` and
    ``"right"``. A value containing ``"color:"`` assigns the single colour
    written after the colon to that side. Any other value is passed to
    ``color_palette`` and successive lookups for that side cycle through the
    resulting colours.
    """
    # Class-level defaults, kept for backwards compatibility. Every instance
    # gets its own copies in ``__init__`` so that counters and palettes do not
    # leak from one plot into the next.
    ind = {"left": 0, "right": 0, "num": 0}
    _palette_dict = {"left": False, "right": False}
    _lookup_table = {"left": None, "right": None}

    def __init__(self, *args, **kwargs):
        """Reset the per-plot state, then defer to the parent constructor

        Parameters
        ----------
        *args: tuple
            all args passed to the parent ``HueMapping``
        **kwargs: dict
            all kwargs passed to the parent ``HueMapping``
        """
        # set before ``super().__init__`` because the parent constructor may
        # call ``categorical_mapping``, which writes into these dictionaries
        self.ind = {"left": 0, "right": 0, "num": 0}
        self._palette_dict = {"left": False, "right": False}
        self._lookup_table = {"left": None, "right": None}
        super().__init__(*args, **kwargs)

    def _lookup_single(self, key):
        """Return the colour to use for hue level ``key``

        Parameters
        ----------
        key: object
            hue level being looked up

        Returns
        -------
        color:
            the next entry of ``colorlist`` when one was given to
            ``violinplot`` and no palette is set; otherwise the colour for the
            ``"left"``/``"right"`` side, or the parent class's colour for any
            other level
        """
        # ``colorlist`` is a module global assigned by ``violinplot``. Use
        # ``globals().get`` so that hue plots made without ``violinplot``
        # having run do not raise NameError.
        colors = globals().get("colorlist", None)
        # check for different colored left and right violins
        if colors is not None and self.palette is None:
            color = colors[self.ind["num"]]
            self.ind["num"] += 1
            return color
        if key not in self._palette_dict:
            return super()._lookup_single(key)
        if self._palette_dict[key]:
            table = self._lookup_table[key]
            # cycle through the palette instead of raising IndexError once
            # more lookups than palette entries have been made
            color = table[self.ind[key] % len(table)]
        else:
            # single-colour side: the colour was stored in the lookup table
            # returned by ``categorical_mapping`` (presumably kept by seaborn
            # as ``self.lookup_table``)
            color = self.lookup_table[key]
        self.ind[key] += 1
        return color

    def categorical_mapping(self, data, palette, order):
        """Extend the parent categorical mapping with entries for the special
        ``"left"`` and ``"right"`` palette keys

        Parameters
        ----------
        data:
            hue data passed to the parent ``categorical_mapping``
        palette:
            palette passed to the parent ``categorical_mapping``. When a dict,
            its ``"left"``/``"right"`` entries are handled here
        order:
            level order passed to the parent ``categorical_mapping``

        Returns
        -------
        levels, lookup_table:
            as returned by the parent, with ``lookup_table["left"]`` and/or
            ``lookup_table["right"]`` set when present in ``palette``
        """
        levels, lookup_table = super().categorical_mapping(data, palette, order)
        if isinstance(palette, dict):
            for key in ["left", "right"]:
                if key not in palette:
                    continue
                if "color:" in palette[key]:
                    # e.g. "color: red" -> "red"
                    _color = palette[key].replace(" ", "").split(":")[1]
                    lookup_table[key] = _color
                else:
                    # a palette for this side: keep 10 colours to cycle
                    # through and use the first as the representative colour
                    self._palette_dict[key] = True
                    self._lookup_table[key] = color_palette(
                        palette[key], n_colors=10
                    )
                    _color = color_palette(palette[key], n_colors=1)[0]
                    lookup_table[key] = _color
        return levels, lookup_table

90 

91 

class _CategoricalPlotter(categorical._CategoricalPlotter):
    """Extension of seaborn's ``_CategoricalPlotter`` whose ``plot_violins``
    forwards the kernel chosen in ``violinplot`` to the KDE class
    """
    def plot_violins(self, *args, **kwargs):
        """Add ``kde_kernel``/``kde_kwargs`` to ``kde_kws`` and defer to the
        parent ``plot_violins``

        Parameters
        ----------
        *args: tuple
            all args passed to the parent ``plot_violins`` method
        **kwargs: dict
            all kwargs passed to the parent ``plot_violins`` method. Must
            contain ``kde_kws``
        """
        _kwargs = kwargs.copy()
        # copy so that the caller's ``kde_kws`` dictionary is not mutated
        kde_kws = dict(_kwargs["kde_kws"])
        # ``KDE``/``KDE_kwargs`` are module globals assigned by ``violinplot``.
        # Fall back to the ``violinplot`` defaults so this method does not
        # raise NameError if it is reached before ``violinplot`` has run.
        kde_kws.update({
            "kde_kernel": globals().get("KDE", gaussian_kde),
            "kde_kwargs": globals().get("KDE_kwargs", {}),
        })
        # NOTE(review): presumably removed because the KDE class consuming
        # ``kde_kws`` (``categorical.KDE``, patched to ``DensityKDE`` in this
        # module) does not accept them -- confirm
        kde_kws.pop("gridsize", None)
        kde_kws.pop("bw_adjust", None)
        _kwargs["kde_kws"] = kde_kws
        return super().plot_violins(*args, **_kwargs)

101 

102 

# Replace seaborn's internals with the customised classes from this module so
# that seaborn's violin plotting uses them.
# NOTE(review): these are attribute assignments on seaborn's own modules, so
# they presumably affect other seaborn calls made after this module is
# imported -- confirm this is acceptable.
_base.HueMapping = HueMapping
categorical.KDE = DensityKDE
categorical._CategoricalPlotter = _CategoricalPlotter

106 

107 

def violinplot(
    *args, kde_kernel=gaussian_kde, kde_kwargs=None, inj=None, colors=None,
    **kwargs
):
    """Extension of the seaborn.categorical.violinplot function to allow for
    a custom kde_kernel and associated kwargs.

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn.categorical.violinplot` function
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default None, which is
        treated as {}
    inj: float, optional
        injected value. Currently ignored, but kept for backwards compatibility
    colors: list, optional
        list of colors to use for each violin. Default None
    **kwargs: dict
        all kwargs passed to the `seaborn.categorical.violinplot` class

    Returns
    -------
    whatever `seaborn.categorical.violinplot` returns
    """
    global KDE
    global KDE_kwargs
    global colorlist
    # these module globals are read by the patched ``_CategoricalPlotter``
    # and ``HueMapping`` classes while seaborn draws the plot
    KDE = kde_kernel
    # build a fresh dict per call rather than sharing a mutable default
    KDE_kwargs = {} if kde_kwargs is None else kde_kwargs
    colorlist = colors
    try:
        return categorical.violinplot(*args, **kwargs)
    finally:
        # ``HueMapping`` is patched into ``seaborn._base`` for every plot, so
        # clear the per-call colour list to stop it leaking into later plots
        colorlist = None

138 

139 

def split_dataframe(
    left, right, labels, left_label="left", right_label="right",
    weights_left=None, weights_right=None
):
    """Generate a pandas DataFrame containing two sets of distributions -- one
    set for the left hand side of the violins, and one set for the right hand
    side of the violins

    Parameters
    ----------
    left: np.ndarray
        array of samples representing the left hand side of the violins
    right: np.ndarray
        array of samples representing the right hand side of the violins
    labels: np.array
        array containing the label associated with each violin
    left_label: str, optional
        value stored in the 'side' column for samples taken from `left`.
        Default 'left'
    right_label: str, optional
        value stored in the 'side' column for samples taken from `right`.
        Default 'right'
    weights_left: list, optional
        one array of weights per violin for the samples in `left`. If only
        `weights_right` is given, the left hand samples get unit weights
    weights_right: list, optional
        one array of weights per violin for the samples in `right`. If only
        `weights_left` is given, the right hand samples get unit weights

    Returns
    -------
    df: pandas.DataFrame
        DataFrame with columns 'data', 'side' and 'label', plus 'weights' when
        either `weights_left` or `weights_right` is provided. Rows are ordered
        violin by violin, with the left hand samples before the right hand
        samples of each violin

    Raises
    ------
    ValueError
        if `left`, `right` and `labels` have different lengths, or if the
        weights do not match the samples
    """
    import pandas

    nviolin = len(left)
    # ``a != b != c`` only compares neighbouring pairs, so require all three
    # lengths to be equal explicitly
    if not len(left) == len(right) == len(labels):
        raise ValueError("Please ensure that 'left' == 'right' == 'labels'")

    data, sides, violin_labels = [], [], []
    for label, left_samples, right_samples in zip(labels, left, right):
        data.extend(left_samples)
        data.extend(right_samples)
        sides.extend([left_label] * len(left_samples))
        sides.extend([right_label] * len(right_samples))
        violin_labels.extend(
            [label] * (len(left_samples) + len(right_samples))
        )
    df = pandas.DataFrame(
        data={"data": data, "side": sides, "label": violin_labels}
    )
    if weights_left is None and weights_right is None:
        return df

    # at least one set of weights was given; default the other to unit weights
    if weights_right is None:
        weights_right = [np.ones(len(samples)) for samples in right]
    elif weights_left is None:
        weights_left = [np.ones(len(samples)) for samples in left]
    if any(len(kwarg) != nviolin for kwarg in [weights_left, weights_right]):
        raise ValueError(
            "Please ensure that 'weights_left' and 'weights_right' contain "
            "one set of weights per violin"
        )

    weights = []
    for num, (_left, _right) in enumerate(zip(weights_left, weights_right)):
        # a per-violin check stops mismatched lengths from cancelling out and
        # silently misaligning weights and samples
        if len(_left) != len(left[num]) or len(_right) != len(right[num]):
            raise ValueError(
                "Please ensure that the weights for violin {} have the same "
                "length as its samples".format(num)
            )
        weights.extend(_left)
        weights.extend(_right)
    df["weights"] = weights
    return df