Coverage for pesummary/utils/kde_list.py: 84.2%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1import numpy as np 

2from scipy.stats import gaussian_kde 

3import multiprocessing 

4 

5 

class KDEList(object):
    """Class to generate and evaluate a set of KDEs at the same points

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples to generate kdes. One KDE is built per entry
        and every entry must itself be a list or np.ndarray
    kde: func, optional
        kde function to use. Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        kwargs to pass to kde. Default None, which passes no extra kwargs
    pts: array-like, optional
        points to evaluate the KDEs at when none are given at call time.
        Default None

    Attributes
    ----------
    samples:
        the samples provided at construction
    kdes: np.ndarray
        object array holding one KDE per entry in samples
    pts:
        the stored evaluation points (may be None)

    Raises
    ------
    ValueError
        if any entry of samples is not a list or np.ndarray
    """
    def __init__(self, samples, kde=gaussian_kde, kde_kwargs=None, pts=None):
        # A None sentinel avoids a single dict object being shared as the
        # default between every instance of the class
        if kde_kwargs is None:
            kde_kwargs = {}
        self.samples = samples
        if not all(isinstance(_, (list, np.ndarray)) for _ in self.samples):
            raise ValueError("2d array of samples must be provided")
        # dtype=object keeps the KDE instances as opaque array elements so
        # they can be selected with boolean or integer indexing later on
        self.kdes = np.array(
            [kde(_, **kde_kwargs) for _ in self.samples], dtype=object
        )
        self.pts = pts

    def __call__(self, pts=None, multi_process=1, idx=None):
        """Evaluate the stored KDEs at a common set of points

        Parameters
        ----------
        pts: array-like, optional
            points to evaluate the KDEs at. Default None, in which case the
            points stored on the instance are used
        multi_process: int, optional
            when equal to 1 the KDEs are evaluated one after another in this
            process, otherwise the value is used as the size of a
            multiprocessing.Pool. Default 1
        idx: int/list/np.ndarray, optional
            index, or collection of indices, selecting which KDEs to
            evaluate. Default None, meaning all KDEs

        Returns
        -------
        out: np.ndarray
            one row of evaluations per selected KDE. When idx is a single
            index only the evaluations for that KDE are returned

        Raises
        ------
        ValueError
            if no points are passed and none are stored on the instance
        """
        if pts is None and self.pts is None:
            raise ValueError("Please provide a set of points to evaluate the KDE")
        elif pts is None:
            pts = self.pts
        # Remember whether a lone index was requested so that the leading
        # axis added below can be stripped again before returning
        singular_idx = not isinstance(idx, (list, np.ndarray)) and idx is not None
        if idx is None:
            idx = np.ones(len(self.kdes), dtype=bool)
        elif singular_idx:
            idx = np.array([idx])
        selected = self.kdes[idx]
        if multi_process == 1:
            out = np.array(
                [self._evaluate_single_kde(kde, pts) for kde in selected]
            )
        else:
            # A plain list of tuples is all pool.map needs; it avoids asking
            # numpy to infer a shape from heterogeneous (kde, pts) pairs
            args = [(kde, pts) for kde in selected]
            with multiprocessing.Pool(multi_process) as pool:
                out = np.array(
                    pool.map(KDEList._wrapper_for_evaluate_single_kde, args)
                )
        if singular_idx:
            return out[0]
        return out

    def evaluate(self, **kwargs):
        """Evaluate the KDEs at the points stored on the instance

        Parameters
        ----------
        **kwargs: dict, optional
            all kwargs passed to KDEList.__call__

        Raises
        ------
        ValueError
            if no points are stored on the instance
        """
        if self.pts is None:
            raise ValueError(
                "No points stored. Please use the __call__ method and provide "
                "a list of points or re-initalise class"
            )
        return self.__call__(self.pts, **kwargs)

    @staticmethod
    def _wrapper_for_evaluate_single_kde(args):
        """Unpack a (kde, pts) pair and evaluate it; pool.map passes each
        work item as a single argument
        """
        return KDEList._evaluate_single_kde(*args)

    @staticmethod
    def _evaluate_single_kde(kde, pts):
        """Return the result of calling a single kde with pts"""
        return kde(pts)