Coverage for pesummary/utils/kde_list.py: 65.8%
38 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1import numpy as np
2from scipy.stats import gaussian_kde
3import multiprocessing
class KDEList(object):
    """Class to generate and evaluate a set of KDEs at the same points

    One KDE is built for every entry of ``samples`` and stored in the
    ``kdes`` attribute (a 1d object array). Calling the instance evaluates
    the selected KDEs at a common set of points.

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples to generate kdes. Every entry must be a list
        or an np.ndarray
    kde: func, optional
        kde function to use. Default scipy.stats.gaussian_kde
    kde_kwargs: dict, optional
        kwargs to pass to kde. Default None, which passes no extra kwargs
    pts: np.ndarray, optional
        points to evaluate the KDEs at when none are provided at call
        time. Default None

    Raises
    ------
    ValueError
        if any entry of ``samples`` is not a list or an np.ndarray
    """
    def __init__(self, samples, kde=gaussian_kde, kde_kwargs=None, pts=None):
        self.samples = samples
        if not all(isinstance(_, (list, np.ndarray)) for _ in self.samples):
            raise ValueError("2d array of samples must be provided")
        # a None sentinel avoids sharing one mutable dict between instances
        if kde_kwargs is None:
            kde_kwargs = {}
        built = [kde(_, **kde_kwargs) for _ in self.samples]
        # fill a pre-sized object array element by element: np.array(built)
        # would unpack sequence-like kde objects into extra dimensions
        self.kdes = np.empty(len(built), dtype=object)
        for num, single_kde in enumerate(built):
            self.kdes[num] = single_kde
        self.pts = pts

    def __call__(self, pts=None, multi_process=1, idx=None):
        """Evaluate the stored KDEs at a set of points

        Parameters
        ----------
        pts: np.ndarray, optional
            points to evaluate the KDEs at. If None, the points stored at
            initialisation are used
        multi_process: int, optional
            if 1 the KDEs are evaluated one after another in this process,
            otherwise the value is passed to multiprocessing.Pool and the
            KDEs are evaluated with pool.map. Default 1
        idx: int/list/np.ndarray, optional
            index, or indices, of the KDEs to evaluate. If None, all KDEs
            are evaluated. Default None

        Returns
        -------
        out: np.ndarray
            array with one entry per evaluated KDE. When ``idx`` is neither
            None, a list nor an np.ndarray, only the evaluation of that
            single KDE is returned

        Raises
        ------
        ValueError
            if no points are provided here and none were stored
        """
        if pts is None and self.pts is None:
            raise ValueError(
                "Please provide a set of points to evaluate the KDE"
            )
        elif pts is None:
            pts = self.pts
        # remember whether a single index was requested so that the result
        # can be unwrapped before it is returned
        singular_idx = (
            idx is not None and not isinstance(idx, (list, np.ndarray))
        )
        if idx is None:
            idx = np.ones(len(self.kdes), dtype=bool)
        elif singular_idx:
            idx = np.array([idx])
        selected = self.kdes[idx]
        if multi_process == 1:
            out = np.array(
                [self._evaluate_single_kde(kde, pts) for kde in selected]
            )
        else:
            # a plain list of tuples is all pool.map needs; wrapping the
            # pairs in an object array only adds fragile shape inference
            args = [(kde, pts) for kde in selected]
            with multiprocessing.Pool(multi_process) as pool:
                out = np.array(
                    pool.map(KDEList._wrapper_for_evaluate_single_kde, args)
                )
        if singular_idx:
            return out[0]
        return out

    def evaluate(self, **kwargs):
        """Evaluate the KDEs at the points stored at initialisation

        Parameters
        ----------
        **kwargs: dict, optional
            all kwargs are passed to the __call__ method

        Raises
        ------
        ValueError
            if no points were stored at initialisation
        """
        if self.pts is None:
            raise ValueError(
                "No points stored. Please use the __call__ method and provide "
                "a list of points or re-initalise class"
            )
        return self(self.pts, **kwargs)

    @staticmethod
    def _wrapper_for_evaluate_single_kde(args):
        """Unpack a (kde, pts) pair and evaluate it. Used as the target of
        pool.map, which passes a single argument per task
        """
        return KDEList._evaluate_single_kde(*args)

    @staticmethod
    def _evaluate_single_kde(kde, pts):
        """Return the result of calling ``kde`` at ``pts``"""
        return kde(pts)