Coverage for pesummary/gw/classification.py: 85.8%
183 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4import importlib
6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
9class _Base(object):
10 """Base class for generating classification probabilities
12 Parameters
13 ----------
14 samples: dict
15 dictionary of posterior samples to use for generating classification
16 probabilities
18 Attributes
19 ----------
20 available_plots: list
21 list of available plotting types
23 Methods
24 -------
25 classification:
26 return a dictionary containing the classification probabilities. These
27 probabilities can either be generated from the raw samples or samples
28 reweighted to a population inferred prior
29 dual_classification:
30 return a dictionary containing the classification probabilities
31 generated from the raw samples ('default') and samples reweighted to
32 a population inferred prior ('population')
33 plot:
34 generate a plot showing the classification probabilities
35 """
36 def __init__(self, samples):
37 self.module = self.check_for_install()
38 if not isinstance(samples, dict):
39 raise ValueError("Please provide samples as dictionary")
40 if not all(_p in samples.keys() for _p in self.required_parameters):
41 from pesummary.utils.samples_dict import SamplesDict
42 samples = SamplesDict(samples)
43 samples.generate_all_posterior_samples(disable_remnant=True)
44 if not all(_p in samples.keys() for _p in self.required_parameters):
45 raise ValueError(
46 "Failed to compute classification probabilities because "
47 "the following parameters are required: {}".format(
48 ", ".join(self.required_parameters)
49 )
50 )
51 self.samples = self._convert_samples(samples)
53 @classmethod
54 def from_file(cls, filename):
55 """Initiate the classification class with samples stored in file
57 Parameters
58 ----------
59 filename: str
60 path to file you wish to initiate the classification class with
61 """
62 from pesummary.io import read
63 f = read(filename)
64 samples = f.samples_dict
65 return cls(samples)
67 @classmethod
68 def classification_from_file(cls, filename, **kwargs):
69 """Initiate the classification class with samples stored in file and
70 return a dictionary containing the classification probabilities
72 Parameters
73 ----------
74 filename: str
75 path to file you wish to initiate the classification class with
76 **kwargs: dict, optional
77 all kwargs passed to cls.classification()
78 """
79 _cls = cls.from_file(filename)
80 return _cls.classification(**kwargs)
82 @classmethod
83 def dual_classification_from_file(cls, filename, seed=123456789):
84 """Initiate the classification class with samples stored in file and
85 return a dictionary containing the classification probabilities
86 generated from the raw samples ('default') and samples reweighted to
87 a population inferred prior ('population')
89 Parameters
90 ----------
91 filename: str
92 path to file you wish to initiate the classification class with
93 seed: int, optional
94 random seed to use when reweighing to a population inferred prior
95 """
96 _cls = cls.from_file(filename)
97 return _cls.dual_classification(seed=seed)
99 @property
100 def required_parameters(self):
101 return ["mass_1_source", "mass_2_source", "a_1", "a_2"]
103 @property
104 def available_plots(self):
105 return ["bar"]
107 @staticmethod
108 def round_probabilities(ptable, rounding=5):
109 """Round the entries of a probability table
111 Parameters
112 ----------
113 ptable: dict
114 probability table
115 rounding: int
116 number of decimal places to round the entries of the probability
117 table
118 """
119 for key, value in ptable.items():
120 ptable[key] = np.round(value, rounding)
121 return ptable
123 def check_for_install(self, package=None):
124 """Check that the required package is installed. If the package
125 is not installed, raise an ImportError
127 Parameters
128 ----------
129 package: str, optional
130 name of package to check for install. Default None
131 """
132 if package is None:
133 package = self.package
134 if isinstance(package, str):
135 package = [package]
136 _not_available = []
137 for _package in package:
138 try:
139 return importlib.import_module(_package)
140 except ModuleNotFoundError:
141 _not_available.append(_package)
142 if len(_not_available):
143 raise ImportError(
144 "Unable to import {}. Unable to compute classification "
145 "probabilities".format(" or ".join(package))
146 )
148 def _resample_to_population_prior(self, samples=None):
149 """Use the pepredicates.rewt_approx_massdist_redshift function to
150 reweight a pandas DataFrame to a population informed prior
152 Parameters
153 ----------
154 samples: dict, optional
155 pandas DataFrame containing posterior samples
156 """
157 import copy
158 if not self.__class__.__name__ == "PEPredicates":
159 _module = self.check_for_install(package="pepredicates")
160 else:
161 _module = self.module
162 if samples is None:
163 samples = self.samples
164 _samples = copy.deepcopy(samples)
165 if not all(param in _samples.keys() for param in ["redshift", "dist"]):
166 raise ValueError(
167 "Samples for redshift and distance required for population "
168 "reweighting"
169 )
170 return _module.rewt_approx_massdist_redshift(_samples)
172 def dual_classification(self, seed=123456789):
173 """Return a dictionary containing the classification probabilities
174 generated from the raw samples ('default') and samples reweighted to
175 a population inferred prior ('population')
177 Parameters
178 ----------
179 seed: int, optional
180 random seed to use when reweighing to a population inferred prior
181 """
182 return {
183 "default": self.classification(),
184 "population": self.classification(population=True, seed=seed)
185 }
187 def classification(self, population=False, return_samples=False, seed=123456789):
188 """return a dictionary containing the classification probabilities.
189 These probabilities can either be generated from the raw samples or
190 samples reweighted to a population inferred prior
192 Parameters
193 ----------
194 population: Bool, optional
195 if True, reweight the samples to a population informed prior and
196 then calculate classification probabilities. Default False
197 return_samples: Bool, optional
198 if True, return the samples used as well as the classification
199 probabilities
200 seed: int, optional
201 random seed to use when reweighing to a population inferred prior
202 """
203 if not population:
204 ptable = self._compute_classification_probabilities()
205 if return_samples:
206 return self.samples, ptable
207 return ptable
208 np.random.seed(seed)
209 _samples = PEPredicates._convert_samples(self.samples)
210 reweighted_samples = self._resample_to_population_prior(
211 samples=_samples
212 )
213 ptable = self._compute_classification_probabilities(
214 samples=self._convert_samples(reweighted_samples)
215 )
216 if return_samples:
217 return _samples, ptable
218 return ptable
220 def _compute_classification_probabilities(self, samples=None):
221 """Base function to compute classification probabilities
223 Parameters
224 ----------
225 samples: dict, optional
226 samples to use for computing the classification probabilities
227 Default None.
228 """
229 if samples is None:
230 samples = self.samples
231 return samples, {}
233 def plot(
234 self, samples=None, probabilities=None, type="bar", population=False
235 ):
236 """Generate a plot showing the classification probabilities
238 Parameters
239 ----------
240 samples: dict, optional
241 samples to use for plotting. Default None
242 probabilities: dict, optional
243 dictionary giving the classification probabilities. Default None
244 type: str, optional
245 type of plot to produce
246 population: Bool, optional
247 if True, reweight the posterior samples to a population informed
248 prior before computing the classification probabilities for
249 plotting. Only used when probabilities=None. Default False
250 """
251 if type not in self.available_plots:
252 raise ValueError(
253 "Unknown plot '{}'. Please select a plot from {}".format(
254 type, ", ".join(self.available_plots)
255 )
256 )
257 if (probabilities is None) or ((samples is None) and population):
258 s, p = self.classification(
259 population=population, return_samples=True
260 )
261 if probabilities is None:
262 probabilities = p
263 if ((samples is None) and population):
264 samples = s
265 return getattr(self, "_{}_plot".format(type))(samples, probabilities)
267 def _bar_plot(self, samples, probabilities):
268 """Generate a bar plot showing classification probabilities
270 Parameters
271 ----------
272 samples: dict
273 samples to use for plotting
274 probabilities: dict
275 dictionary giving the classification probabilities.
276 """
277 from pesummary.gw.plots.plot import _classification_plot
278 return _classification_plot(probabilities)
class PEPredicates(_Base):
    """Class for generating source classification probabilities, i.e.
    the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), or a binary originating from the mass gap, p(MassGap)

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities. These
        probabilities can either be generated from the raw samples or samples
        reweighted to a population inferred prior
    dual_classification:
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')
    plot:
        generate a plot showing the classification probabilities
    """
    def __init__(self, samples):
        # must be set before the base class checks the package is installed
        self.package = "pepredicates"
        super(PEPredicates, self).__init__(samples)

    @property
    def _default_probabilities(self):
        """Dictionary mapping each source class label to the corresponding
        predicate from the pepredicates module
        """
        return {
            'BNS': self.module.BNS_p, 'NSBH': self.module.NSBH_p,
            'BBH': self.module.BBH_p, 'MassGap': self.module.MG_p
        }

    @property
    def available_plots(self):
        """Plot types accepted by ``plot``: those of the base class plus
        'pepredicates'
        """
        _plots = super(PEPredicates, self).available_plots
        _plots.extend(["pepredicates"])
        return _plots

    @staticmethod
    def mapping(reverse=False):
        """Return the mapping between the pesummary and pepredicates
        parameter names

        Parameters
        ----------
        reverse: Bool, optional
            if True, return the mapping from pepredicates names to pesummary
            names. If False, return the mapping from pesummary names to
            pepredicates names. Default False
        """
        _mapping = {
            "mass_1_source": "m1_source", "mass_2_source": "m2_source",
            "luminosity_distance": "dist", "redshift": "redshift",
            "a_1": "a1", "a_2": "a2", "tilt_1": "tilt1", "tilt_2": "tilt2"
        }
        if reverse:
            return {item: key for key, item in _mapping.items()}
        return _mapping

    @staticmethod
    def _convert_samples(samples):
        """Convert dictionary of posterior samples to required form
        needed for pepredicates, i.e. a pandas DataFrame with columns named
        according to the pepredicates convention

        If every pepredicates name is already present, all of the samples
        are copied across unchanged. Otherwise only the parameters listed
        in ``mapping`` are kept; each one is taken from the pesummary name
        if available and from the pepredicates name if not. This means that
        samples which have already been converted are not discarded when
        an optional parameter, e.g. the tilt angles, is missing.

        Parameters
        ----------
        samples: dict
            samples to use for computing the classification probabilities
        """
        import pandas as pd
        mapping = PEPredicates.mapping()
        available = samples.keys()
        if all(new in available for new in mapping.values()):
            _samples = samples.copy()
        else:
            _samples = {}
            for original, new in mapping.items():
                if original in available:
                    _samples[new] = samples[original]
                elif new in available:
                    # already using the pepredicates name
                    _samples[new] = samples[new]
        return pd.DataFrame.from_dict(_samples)

    def _compute_classification_probabilities(self, samples=None, rounding=5):
        """Compute classification probabilities

        Parameters
        ----------
        samples: dict, optional
            samples to use for computing the classification probabilities.
            If None, self.samples is used. Default None.
        rounding: int, optional
            number of decimal places to round entries of probability table.
            If None, the probabilities are not rounded. Default 5
        """
        samples, _ = super()._compute_classification_probabilities(
            samples=samples
        )
        ptable = self.module.predicate_table(
            self._default_probabilities, samples
        )
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable

    def _pepredicates_plot(self, samples, probabilities):
        """Generate a plot using the pepredicates.plot_predicates function
        showing classification probabilities

        Parameters
        ----------
        samples: dict
            samples to use for plotting. If None, self.samples is used
        probabilities: dict
            dictionary giving the classification probabilities.
        """
        from pesummary.core.plots.figure import ExistingFigure
        if samples is None:
            from pesummary.utils.utils import logger
            logger.debug(
                "No samples provided for plotting. Using cached array."
            )
            samples = self.samples
        idxs = {
            "BBH": self.module.is_BBH(samples),
            "BNS": self.module.is_BNS(samples),
            "NSBH": self.module.is_NSBH(samples),
            "MassGap": self.module.is_MG(samples)
        }
        return ExistingFigure(
            self.module.plot_predicates(
                idxs, samples, probs=probabilities
            )
        )
class PAstro(_Base):
    """Class for generating EM-Bright classification probabilities, i.e.
    the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities. These
        probabilities can either be generated from the raw samples or samples
        reweighted to a population inferred prior
    dual_classification:
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')
    plot:
        generate a plot showing the classification probabilities
    """
    def __init__(self, samples):
        # must be set before the base class checks the package is installed
        self.package = "ligo.em_bright.em_bright"
        super(PAstro, self).__init__(samples)

    @property
    def required_parameters(self):
        """Parameters which must be present in the posterior samples: those
        of the base class plus the tilt angles
        """
        _parameters = super(PAstro, self).required_parameters
        _parameters.extend(["tilt_1", "tilt_2"])
        return _parameters

    @staticmethod
    def _convert_samples(samples):
        """Convert dictionary of posterior samples to required form
        needed for ligo.computeDiskMass, i.e. a dictionary of numpy arrays
        keyed by the pesummary parameter names

        Any parameter stored under its pepredicates name is renamed to the
        corresponding pesummary name. The samples passed in are not
        modified; a new dictionary is returned.

        Parameters
        ----------
        samples: dict
            samples to use for computing the classification probabilities
        """
        reverse = PEPredicates.mapping(reverse=True)
        # parameters which do not need renaming keep their position
        _samples = {
            key: np.asarray(item) for key, item in samples.items()
            if key not in reverse
        }
        # a pepredicates name takes precedence over an existing pesummary
        # name for the same parameter
        for key, item in reverse.items():
            if key in samples.keys():
                _samples[item] = np.asarray(samples[key])
        return _samples

    def _compute_classification_probabilities(self, samples=None, rounding=5):
        """Compute classification probabilities

        Parameters
        ----------
        samples: dict, optional
            samples to use for computing the classification probabilities.
            If None, self.samples is used. Default None.
        rounding: int, optional
            number of decimal places to round entries of probability table.
            If None, the probabilities are not rounded. Default 5
        """
        samples, _ = super()._compute_classification_probabilities(
            samples=samples
        )
        probs = self.module.source_classification_pe_from_table(samples)
        ptable = {"HasNS": probs[0], "HasRemnant": probs[1]}
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable
class Classify(_Base):
    """Class for generating source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), or a binary originating from the mass gap, p(MassGap), the
    probability that the binary has a neutron star, p(HasNS), and the
    probability that the remnant is observable, p(HasRemnant).
    """
    @property
    def required_parameters(self):
        """Parameters which must be present in the posterior samples: those
        of the base class plus the tilt angles
        """
        _parameters = super(Classify, self).required_parameters
        _parameters.extend(["tilt_1", "tilt_2"])
        return _parameters

    def check_for_install(self, *args, **kwargs):
        """No package is checked here; the PEPredicates and PAstro classes
        created in ``classification`` carry out their own checks
        """
        pass

    def _convert_samples(self, samples):
        """Return the samples unchanged; the PEPredicates and PAstro classes
        created in ``classification`` carry out their own conversions
        """
        return samples

    def classification(self, **kwargs):
        """return a dictionary containing the classification probabilities.
        These probabilities can either be generated from the raw samples or
        samples reweighted to a population inferred prior

        Parameters
        ----------
        population: Bool, optional
            if True, reweight the samples to a population informed prior and
            then calculate classification probabilities. Default False
        return_samples: Bool, optional
            if True, return the samples used as well as the classification
            probabilities, as a (samples, probabilities) tuple. The samples
            returned are those reported by the PEPredicates classification.
            Default False
        **kwargs: dict, optional
            all other kwargs are passed to both PEPredicates.classification
            and PAstro.classification
        """
        # handled here so that both classifications return plain
        # dictionaries which can be merged
        return_samples = kwargs.pop("return_samples", False)
        samples, probs = PEPredicates(self.samples).classification(
            return_samples=True, **kwargs
        )
        pastro = PAstro(self.samples).classification(**kwargs)
        probs.update(pastro)
        if return_samples:
            return samples, probs
        return probs
def classify(*args, **kwargs):
    """Compute the combined source classification and EM-Bright
    probabilities for a set of posterior samples: p(BBH), p(NSBH), p(BNS),
    p(MassGap), p(HasNS) and p(HasRemnant).

    All positional arguments are used to initiate the Classify class and
    all keyword arguments are passed to Classify.classification.

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities
    population: Bool, optional
        if True, reweight the samples to a population informed prior and
        then calculate classification probabilities. Default False
    return_samples: Bool, optional
        if True, return the samples used as well as the classification
        probabilities
    """
    classifier = Classify(*args)
    return classifier.classification(**kwargs)