Coverage for pesummary/core/plots/seaborn/kde.py: 82.2%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from scipy import stats 

4from pesummary.core.plots.seaborn import SEABORN 

if not SEABORN:
    # seaborn could not be imported: define minimal stand-ins so that the
    # subclasses and the attribute assignments further down this module can
    # still be evaluated, while any real plotting request raises.
    class _StatisticsKDE(object):
        """Placeholder base used instead of `seaborn._statistics.KDE`
        when seaborn is not installed
        """

    class distributions(object):
        class _DistributionPlotter(object):
            """Placeholder base used instead of
            `seaborn.distributions._DistributionPlotter` when seaborn is not
            installed
            """

        def kdeplot(*args, **kwargs):
            """Placeholder that always raises because seaborn is missing
            """
            raise ValueError("Unable to produce kdeplot with 'seaborn'")
else:
    from seaborn._statistics import KDE as _StatisticsKDE
    from seaborn import distributions

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>", "Seaborn authors"]

25 

26 

class _BaseKDE(object):
    """Extension of the `seaborn._statistics.KDE` to allow for custom
    kde_kernel

    Intended to be listed before `seaborn._statistics.KDE` in a class's
    bases so that this `__init__` and `_fit` take precedence.

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. It is called as
        ``kde_kernel(data, bw_method=..., [weights=...], **kde_kwargs)`` and
        the returned object must expose ``factor`` and ``set_bandwidth``
        (as `scipy.stats.gaussian_kde` does). None is treated as the
        default. Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. The dict is copied
        and never modified. Default {}
    **kwargs: dict
        ``bw_method`` and ``bw_adjust`` are stored on the instance;
        ``gridsize``, ``cut``, ``clip`` and ``cumulative`` are passed to the
        `seaborn._statistics.KDE` class; any other kwarg is added to
        ``kde_kwargs`` (overriding an entry of the same name) and forwarded
        to the kde_kernel
    """
    def __init__(self, *args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs):
        # Private copy: never mutate the caller's dict or a shared default
        kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
        base_kwargs = {}
        bandwidth_kwargs = {}
        for key, item in kwargs.items():
            if key in ("bw_method", "bw_adjust"):
                bandwidth_kwargs[key] = item
            elif key in ("gridsize", "cut", "clip", "cumulative"):
                base_kwargs[key] = item
            else:
                kde_kwargs[key] = item
        super().__init__(*args, **base_kwargs)
        # Assign only after super().__init__(): seaborn's KDE.__init__ sets
        # bw_method/bw_adjust to its own defaults, which would otherwise
        # overwrite the values requested by the caller.
        for key, item in bandwidth_kwargs.items():
            setattr(self, key, item)
        if kde_kernel is None:
            kde_kernel = stats.gaussian_kde
        self._kde_kernel = kde_kernel
        self._kde_kwargs = kde_kwargs

    def _fit(self, fit_data, weights=None):
        """Fit the kde_kernel to ``fit_data`` and apply ``bw_adjust``.

        A fresh kwargs dict is built on every call so that ``weights`` from
        one fit cannot leak into a later, unweighted fit.

        Parameters
        ----------
        fit_data: array-like
            data passed as the first argument to the kde_kernel
        weights: array-like, optional
            per-sample weights forwarded to the kde_kernel when given

        Returns
        -------
        the object returned by the kde_kernel, with its bandwidth rescaled
        by ``bw_adjust``
        """
        fit_kws = dict(self._kde_kwargs)
        fit_kws["bw_method"] = self.bw_method
        if weights is not None:
            fit_kws["weights"] = weights

        kde = self._kde_kernel(fit_data, **fit_kws)
        kde.set_bandwidth(kde.factor * self.bw_adjust)
        return kde

68 

69 

class StatisticsKDE(_BaseKDE, _StatisticsKDE):
    """Extension of the `seaborn._statistics.KDE` to allow for custom
    kde_kernel

    `_BaseKDE` is listed first among the bases, so its `__init__` and `_fit`
    take precedence in the MRO over those of `seaborn._statistics.KDE` (or
    over the placeholder base defined when seaborn cannot be imported).

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn._statistics.KDE` class
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. Default {}
    **kwargs: dict
        all kwargs passed to the `seaborn._statistics.KDE` class
    """

86 

87 

class _DistributionPlotter(distributions._DistributionPlotter):
    """seaborn's `_DistributionPlotter` with univariate densities estimated
    by the kernel chosen through this module's `kdeplot` wrapper
    """
    def _compute_univariate_density(
        self, data_variable, common_norm, common_grid, estimate_kws,
        log_scale, **kwargs
    ):
        """Add the module-level ``KDE`` kernel and ``KDE_kwargs`` to the
        estimator options, then delegate to seaborn's implementation.

        ``KDE`` and ``KDE_kwargs`` are module globals assigned by `kdeplot`.
        A new dict is built rather than updating ``estimate_kws`` in place,
        so the mapping owned by the caller is left unmodified.
        """
        estimate_kws = {
            **estimate_kws, "kde_kernel": KDE, "kde_kwargs": KDE_kwargs
        }
        return super()._compute_univariate_density(
            data_variable, common_norm, common_grid, estimate_kws,
            log_scale, **kwargs
        )

98 

99 

# Kernel and kwargs read by `_DistributionPlotter` for univariate densities.
# `kdeplot` overwrites these on every call. Because the patch below is
# installed at import time, these defaults must exist beforehand: without
# them, a seaborn plot made without going through `kdeplot` first would
# raise NameError.
KDE = stats.gaussian_kde
KDE_kwargs = {}

# Swap the extended classes into `seaborn.distributions` so that seaborn
# resolves them when it builds KDE estimators and distribution plotters
distributions.KDE = StatisticsKDE
distributions._DistributionPlotter = _DistributionPlotter

102 

103 

def kdeplot(*args, kde_kernel=stats.gaussian_kde, kde_kwargs=None, **kwargs):
    """Extension of the seaborn.distributions.kdeplot function to allow for
    a custom kde_kernel and associated kwargs.

    Parameters
    ----------
    *args: tuple
        all args passed to the `seaborn.distributions.kdeplot` function
    kde_kernel: func, optional
        kernel you wish to use to evaluate the KDE. Default
        scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        optional kwargs to be passed to the kde_kernel. A copy is stored, so
        the caller's dict is never modified. Default {}
    **kwargs: dict
        all kwargs passed to the `seaborn.distributions.kdeplot` class

    Returns
    -------
    whatever `distributions.kdeplot` returns
    """
    # The selection is handed to `_DistributionPlotter` through module
    # globals because seaborn constructs the plotter internally.
    global KDE
    global KDE_kwargs
    KDE = kde_kernel
    # Fresh dict per call: state added downstream must not persist into the
    # caller's dict or into later calls that rely on the default.
    KDE_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    return distributions.kdeplot(*args, **kwargs)