Coverage for pesummary/core/plots/seaborn/kde.py: 64.8%

176 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4import warnings 

5from scipy import stats 

6from seaborn.distributions import ( 

7 _DistributionPlotter as SeabornDistributionPlotter, KDE as SeabornKDE, 

8) 

9from seaborn.utils import _normalize_kwargs, _check_argument 

10 

11import pandas as pd 

12 

13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>", "Seaborn authors"] 

14 

15 

class KDE(SeabornKDE):
    """Extension of the `seaborn._statistics.KDE` to allow for custom
    kde_kernel

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. It is called as
        ``kde_kernel(data, bw_method=..., [weights=...], **kde_kwargs)`` and
        must return an object exposing ``factor`` and ``set_bandwidth`` (as
        `scipy.stats.gaussian_kde` does). None is treated as the default.
        Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. The dictionary is
        copied and never modified. Default None (no extra kwargs)
    **kwargs: dict
        all kwargs passed to the `seaborn._statistics.KDE` class
    """
    def __init__(
        self, *args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs
    ):
        super(KDE, self).__init__(*args, **kwargs)
        if kde_kernel is None:
            kde_kernel = stats.gaussian_kde
        self._kde_kernel = kde_kernel
        # Keep a private copy so that a caller's dict (or a shared default)
        # can never be altered through this estimator
        self._kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)

    def _fit(self, fit_data, weights=None):
        """Fit the kde_kernel while adding bw_adjust logic.

        Parameters
        ----------
        fit_data: array-like
            data passed as the first positional argument to `kde_kernel`
        weights: array-like, optional
            per-sample weights, forwarded to `kde_kernel` only when given

        Returns
        -------
        kde: object
            the fitted kernel, with its bandwidth multiplied by `bw_adjust`
        """
        # Build a fresh kwargs dict on every call. Writing `bw_method` and
        # `weights` into `self._kde_kwargs` in place would let weights from a
        # weighted fit leak into every later (unweighted) fit.
        fit_kws = dict(self._kde_kwargs)
        fit_kws["bw_method"] = self.bw_method
        if weights is not None:
            fit_kws["weights"] = weights

        kde = self._kde_kernel(fit_data, **fit_kws)
        kde.set_bandwidth(kde.factor * self.bw_adjust)
        return kde

49 

50 

class _DistributionPlotter(SeabornDistributionPlotter):
    """Extension of the `seaborn._statistics._DistributionPlotter` to allow for
    the custom KDE method to be used

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics._DistributionPlotter` class
    **kwargs: dict
        all kwargs passed to the `seaborn._statistics._DistributionPlotter`
        class
    """
    def __init__(self, *args, **kwargs):
        super(_DistributionPlotter, self).__init__(*args, **kwargs)

    def plot_univariate_density(
        self,
        multiple,
        common_norm,
        common_grid,
        fill,
        legend,
        estimate_kws,
        variance_atol=1e-8,
        **plot_kws,
    ):
        """Draw one density curve per hue level using the custom `KDE`.

        Parameters
        ----------
        multiple: str
            one of "layer", "stack" or "fill"; how hue levels are combined
        common_norm: bool
            if True, scale each level's density by its share of the data
        common_grid: bool
            if True, evaluate all levels on one support grid (forced on when
            stacking/filling with subsets)
        fill: bool or None
            fill under the curves; None means fill only for "stack"/"fill"
        legend: bool
            add a legend when a hue variable is assigned
        estimate_kws: dict
            kwargs used to construct the `KDE` estimator
        variance_atol: float, optional
            absolute tolerance below which a level's variance counts as zero
            (that level is then skipped). Default 1e-8
        **plot_kws: dict
            kwargs forwarded to the matplotlib artists. An ``alpha_shade``
            entry is accepted but discarded
        """
        import matplotlib as mpl

        # Handle conditional defaults
        if fill is None:
            fill = multiple in ("stack", "fill")

        # Preprocess the matplotlib keyword dictionaries
        if fill:
            artist = mpl.collections.PolyCollection
        else:
            artist = mpl.lines.Line2D
        plot_kws = _normalize_kwargs(plot_kws, artist)
        # `alpha_shade` is not a matplotlib property. It must be removed on
        # every code path (hue, facets, datetime axes included), otherwise the
        # artists / legend patches raise on the unexpected keyword.
        plot_kws.pop("alpha_shade", None)

        # Input checking
        _check_argument("multiple", ["layer", "stack", "fill"], multiple)

        # Always share the evaluation grid when stacking
        subsets = bool(set(self.variables) - {"x", "y"})
        if subsets and multiple in ("stack", "fill"):
            common_grid = True

        # Check if the data axis is log scaled
        log_scale = self._log_scaled(self.data_variable)

        # Do the computation
        densities = self._compute_univariate_density(
            self.data_variable,
            common_norm,
            common_grid,
            estimate_kws,
            log_scale,
            variance_atol,
        )

        # Note: raises when no hue and multiple != layer. A problem?
        densities, baselines = self._resolve_multiple(densities, multiple)

        # Control the interaction with autoscaling by defining sticky_edges
        # i.e. we don't want autoscale margins below the density curve
        sticky_density = (0, 1) if multiple == "fill" else (0, np.inf)

        if multiple == "fill":
            # Filled plots should not have any margins
            sticky_support = densities.index.min(), densities.index.max()
        else:
            sticky_support = []

        # Handle default visual attributes
        if "hue" not in self.variables:
            if self.ax is None:
                color = plot_kws.pop("color", None)
                default_color = "C0" if color is None else color
            else:
                if fill:
                    if self.var_types[self.data_variable] == "datetime":
                        # Avoid drawing empty fill_between on date axis
                        # https://github.com/matplotlib/matplotlib/issues/17586
                        scout = None
                        default_color = plot_kws.pop(
                            "color", plot_kws.pop("facecolor", None)
                        )
                        if default_color is None:
                            default_color = "C0"
                    else:
                        # Draw an empty artist to learn the next cycle colour
                        scout = self.ax.fill_between([], [], **plot_kws)
                        default_color = tuple(scout.get_facecolor().squeeze())
                    plot_kws.pop("color", None)
                else:
                    scout, = self.ax.plot([], [], **plot_kws)
                    default_color = scout.get_color()
                if scout is not None:
                    scout.remove()

        plot_kws.pop("color", None)

        default_alpha = .25 if multiple == "layer" else .75
        alpha = plot_kws.pop("alpha", default_alpha)  # TODO make parameter?

        # Now iterate through the subsets and draw the densities
        # We go backwards so stacked densities read from top-to-bottom
        for sub_vars, _ in self.iter_data("hue", reverse=True):

            # Extract the support grid and density curve for this level;
            # levels skipped during estimation have no entry
            key = tuple(sub_vars.items())
            try:
                density = densities[key]
            except KeyError:
                continue
            support = density.index
            fill_from = baselines[key]

            ax = self._get_axes(sub_vars)

            # Modify the matplotlib attributes from semantic mapping
            if "hue" in self.variables:
                color = self._hue_map(sub_vars["hue"])
            else:
                color = default_color

            artist_kws = self._artist_kws(
                plot_kws, fill, False, multiple, color, alpha
            )

            # Either plot a curve with observation values on the x axis
            if "x" in self.variables:

                if fill:
                    artist = ax.fill_between(
                        support, fill_from, density, **artist_kws
                    )
                else:
                    artist, = ax.plot(support, density, **artist_kws)

                artist.sticky_edges.x[:] = sticky_support
                artist.sticky_edges.y[:] = sticky_density

            # Or plot a curve with observation values on the y axis
            else:
                if fill:
                    artist = ax.fill_betweenx(
                        support, fill_from, density, **artist_kws
                    )
                else:
                    artist, = ax.plot(density, support, **artist_kws)

                artist.sticky_edges.x[:] = sticky_density
                artist.sticky_edges.y[:] = sticky_support

        # --- Finalize the plot ----

        ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
        default_x = default_y = ""
        if self.data_variable == "x":
            default_y = "Density"
        if self.data_variable == "y":
            default_x = "Density"
        self._add_axis_labels(ax, default_x, default_y)

        if "hue" in self.variables and legend:
            from functools import partial
            if fill:
                artist = mpl.patches.Patch
            else:
                artist = partial(mpl.lines.Line2D, [], [])

            ax_obj = self.ax if self.ax is not None else self.facets
            self._add_legend(
                ax_obj, artist, fill, False, multiple, alpha, plot_kws, {},
            )

    def _compute_univariate_density(
        self,
        data_variable,
        common_norm,
        common_grid,
        estimate_kws,
        log_scale,
        variance_atol=1e-8,
        warn_singular=True,
    ):
        """Estimate one density curve per hue level with the custom `KDE`.

        Parameters
        ----------
        data_variable: str
            name of the variable ("x" or "y") holding the observations
        common_norm: bool
            if True (and there are subsets), scale each density by the
            fraction of rows in its subset
        common_grid: bool
            if True (and there are subsets), define the support from all
            observations so every level shares one grid
        estimate_kws: dict
            kwargs used to construct the `KDE` estimator
        log_scale: bool
            if truthy, map the support back with ``10 ** support``
        variance_atol: float, optional
            levels whose variance is within this absolute tolerance of zero
            (or is NaN) are skipped. Default 1e-8
        warn_singular: bool, optional
            emit a UserWarning when a level is skipped. Default True

        Returns
        -------
        densities: dict
            maps ``tuple(sub_vars.items())`` to a `pandas.Series` of density
            values indexed by the support
        """
        # Initialize the estimator object
        estimator = KDE(**estimate_kws)

        all_data = self.plot_data.dropna()

        if set(self.variables) - {"x", "y"}:
            if common_grid:
                all_observations = self.comp_data.dropna()
                estimator.define_support(all_observations[data_variable])
        else:
            # A single dataset: nothing to normalise against
            common_norm = False

        densities = {}

        for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

            # Extract the data points from this sub set and remove nulls
            sub_data = sub_data.dropna()
            observations = sub_data[data_variable]

            # A KDE cannot be fit to (numerically) constant data
            observation_variance = observations.var()
            if (
                np.isclose(observation_variance, 0, atol=variance_atol)
                or np.isnan(observation_variance)
            ):
                if warn_singular:
                    msg = "Dataset has 0 variance; skipping density estimate."
                    warnings.warn(msg, UserWarning)
                continue

            # Extract the weights for this subset of observations
            if "weights" in self.variables:
                weights = sub_data["weights"]
            else:
                weights = None

            # Estimate the density of observations at this level
            density, support = estimator(observations, weights=weights)

            if log_scale:
                support = np.power(10, support)

            # Apply a scaling factor so that the integral over all subsets is 1
            if common_norm:
                density *= len(sub_data) / len(all_data)

            # Store the density for this level
            key = tuple(sub_vars.items())
            densities[key] = pd.Series(density, index=support)

        return densities

286 

287 

def kdeplot(
    x=None,  # Allow positional x, because behavior will not change with reorg
    *,
    y=None,
    shade=None,  # Note "soft" deprecation, explained below
    vertical=False,  # Deprecated
    kernel=None,  # Deprecated
    bw=None,  # Deprecated
    gridsize=200,  # TODO maybe depend on uni/bivariate?
    cut=3, clip=None, legend=True, cumulative=False,
    shade_lowest=None,  # Deprecated, controlled with levels now
    cbar=False, cbar_ax=None, cbar_kws=None,
    ax=None,

    # New params
    weights=None,  # TODO note that weights is grouped with semantics
    hue=None, palette=None, hue_order=None, hue_norm=None,
    multiple="layer", common_norm=True, common_grid=False,
    levels=10, thresh=.05,
    bw_method="scott", bw_adjust=1, log_scale=None,
    color=None, fill=None, kde_kernel=stats.gaussian_kde, kde_kwargs=None,
    variance_atol=1e-8,

    # Renamed params
    data=None, data2=None,

    **kwargs,
):
    """Plot a kernel density estimate, optionally with a custom kernel.

    Based on `seaborn.kdeplot`. The univariate path evaluates the density
    with `kde_kernel` (via this module's `KDE`) and skips hue levels whose
    variance is within `variance_atol` of zero. The remaining parameters
    follow the seaborn function this is based on.

    Parameters
    ----------
    kde_kernel: func, optional
        kernel used for univariate densities. None is treated as
        scipy.stats.gaussian_kde. Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        extra kwargs passed to `kde_kernel`. The dict is copied and never
        modified. Default None (no extra kwargs)
    variance_atol: float, optional
        absolute tolerance used to decide that a univariate dataset has zero
        variance. Default 1e-8

    Returns
    -------
    ax: matplotlib.axes.Axes
        the axes the plot was drawn on
    """
    if kde_kernel is None:
        kde_kernel = stats.gaussian_kde
    # Private copy per call: avoids sharing one mutable default between calls
    # and guarantees the caller's dict is never modified downstream
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)

    # Handle deprecation of `data2` as name for y variable
    if data2 is not None:

        y = data2

        # If `data2` is present, we need to check for the `data` kwarg being
        # used to pass a vector for `x`. We'll reassign the vectors and warn.
        # We need this check because just passing a vector to `data` is now
        # technically valid.

        x_passed_as_data = (
            x is None
            and data is not None
            and np.ndim(data) == 1
        )

        if x_passed_as_data:
            msg = "Use `x` and `y` rather than `data` and `data2`"
            x = data
        else:
            msg = "The `data2` param is now named `y`; please update your code"

        warnings.warn(msg, FutureWarning)

    # Handle deprecation of `vertical`
    if vertical:
        msg = (
            "The `vertical` parameter is deprecated and will be removed in a "
            "future version. Assign the data to the `y` variable instead."
        )
        warnings.warn(msg, FutureWarning)
        x, y = y, x

    # Handle deprecation of `bw`
    if bw is not None:
        msg = (
            "The `bw` parameter is deprecated in favor of `bw_method` and "
            f"`bw_adjust`. Using {bw} for `bw_method`, but please "
            "see the docs for the new parameters and update your code."
        )
        warnings.warn(msg, FutureWarning)
        bw_method = bw

    # Handle deprecation of `kernel`
    if kernel is not None:
        msg = (
            "Support for alternate kernels has been removed. "
            "Using Gaussian kernel."
        )
        warnings.warn(msg, UserWarning)

    # Handle deprecation of shade_lowest
    if shade_lowest is not None:
        if shade_lowest:
            thresh = 0
        msg = (
            "`shade_lowest` is now deprecated in favor of `thresh`. "
            f"Setting `thresh={thresh}`, but please update your code."
        )
        warnings.warn(msg, UserWarning)

    # Handle `n_levels`
    # This was never in the formal API but it was processed, and appeared in an
    # example. We can treat as an alias for `levels` now and deprecate later.
    levels = kwargs.pop("n_levels", levels)

    # Handle "soft" deprecation of shade `shade` is not really the right
    # terminology here, but unlike some of the other deprecated parameters it
    # is probably very commonly used and much hard to remove. This is therefore
    # going to be a longer process where, first, `fill` will be introduced and
    # be used throughout the documentation. In 0.12, when kwarg-only
    # enforcement hits, we can remove the shade/shade_lowest out of the
    # function signature all together and pull them out of the kwargs. Then we
    # can actually fire a FutureWarning, and eventually remove.
    if shade is not None:
        fill = shade

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

    # `locals()` supplies the semantic variables (x, y, hue, weights)
    p = _DistributionPlotter(
        data=data,
        variables=_DistributionPlotter.get_semantics(locals()),
    )

    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    # Check for a specification that lacks x/y data and return early
    if not p.has_xy_data:
        return ax

    # Pack the kwargs for statistics.KDE
    estimate_kws = dict(
        bw_method=bw_method,
        bw_adjust=bw_adjust,
        gridsize=gridsize,
        cut=cut,
        clip=clip,
        cumulative=cumulative,
        kde_kernel=kde_kernel,
        kde_kwargs=kde_kwargs,
    )

    p._attach(ax, allowed_types=["numeric", "datetime"], log_scale=log_scale)

    if p.univariate:

        plot_kws = kwargs.copy()
        if color is not None:
            plot_kws["color"] = color

        p.plot_univariate_density(
            multiple=multiple,
            common_norm=common_norm,
            common_grid=common_grid,
            fill=fill,
            legend=legend,
            estimate_kws=estimate_kws,
            variance_atol=variance_atol,
            **plot_kws,
        )

    else:
        # NOTE: `plot_bivariate_density` is not overridden in this module, so
        # it runs seaborn's own implementation, whose KDE estimator (seaborn
        # 0.11) does not accept `kde_kernel`/`kde_kwargs` and would raise a
        # TypeError. Strip them, and tell the user a custom kernel is ignored.
        bivariate_estimate_kws = {
            key: value for key, value in estimate_kws.items()
            if key not in ("kde_kernel", "kde_kwargs")
        }
        if kde_kernel is not stats.gaussian_kde or kde_kwargs:
            warnings.warn(
                "`kde_kernel` and `kde_kwargs` are only used for univariate "
                "plots; falling back to the default seaborn KDE.",
                UserWarning,
            )

        p.plot_bivariate_density(
            common_norm=common_norm,
            fill=fill,
            levels=levels,
            thresh=thresh,
            legend=legend,
            color=color,
            cbar=cbar,
            cbar_ax=cbar_ax,
            cbar_kws=cbar_kws,
            estimate_kws=bivariate_estimate_kws,
            **kwargs,
        )

    return ax