Coverage for pesummary/gw/classification.py: 86.2%
174 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import importlib
4import os
5import numpy as np
6from scipy.special import logsumexp
7from pesummary.gw.cosmology import hubble_distance, hubble_parameter
8from pesummary.utils.utils import logger
# Module author metadata; entries follow the "Name <email>" convention
__author__ = [
    "Anarya Ray <anarya.ray@ligo.org>",
    "Charlie Hoy <charlie.hoy@ligo.org>"
]
15class _Base():
16 """Base class for generating classification probabilities
18 Parameters
19 ----------
20 samples: dict
21 dictionary of posterior samples to use for generating classification
22 probabilities
24 Attributes
25 ----------
26 available_plots: list
27 list of available plotting types
28 """
29 def __init__(self, samples):
30 self.module = self.check_for_install()
31 if not isinstance(samples, dict):
32 raise ValueError("Please provide samples as dictionary")
33 if not all(_p in samples.keys() for _p in self.required_parameters):
34 from pesummary.utils.samples_dict import SamplesDict
35 samples = SamplesDict(samples)
36 samples.generate_all_posterior_samples(disable_remnant=True)
37 if not all(_p in samples.keys() for _p in self.required_parameters):
38 raise ValueError(
39 "Failed to compute classification probabilities because "
40 "the following parameters are required: {}".format(
41 ", ".join(self.required_parameters)
42 )
43 )
44 self.samples = {key: np.array(value) for key, value in samples.items()}
46 @property
47 def required_parameters(self):
48 return ["mass_1_source", "mass_2_source"]
50 @property
51 def available_plots(self):
52 return ["bar"]
54 @classmethod
55 def from_file(cls, filename, **kwargs):
56 """Initiate the classification class with samples stored in file
58 Parameters
59 ----------
60 filename: str
61 path to file you wish to initiate the classification class with
62 """
63 from pesummary.io import read
64 f = read(filename)
65 samples = f.samples_dict
66 return cls(samples, **kwargs)
68 @classmethod
69 def classification_from_file(cls, filename, **kwargs):
70 """Initiate the classification class with samples stored in file and
71 return a dictionary containing the classification probabilities
73 Parameters
74 ----------
75 filename: str
76 path to file you wish to initiate the classification class with
77 **kwargs: dict, optional
78 all kwargs passed to cls.from_file()
79 """
80 _cls = cls.from_file(filename, **kwargs)
81 return _cls.classification()
83 def check_for_install(self, package=None):
84 """Check that the required package is installed. If the package
85 is not installed, raise an ImportError
87 Parameters
88 ----------
89 package: str, optional
90 name of package to check for install. Default None
91 """
92 if package is None:
93 package = self.package
94 if isinstance(package, str):
95 package = [package]
96 _not_available = []
97 for _package in package:
98 try:
99 return importlib.import_module(_package)
100 except ModuleNotFoundError:
101 _not_available.append(_package)
102 if len(_not_available):
103 raise ImportError(
104 "Unable to import {}. Unable to compute classification "
105 "probabilities".format(" or ".join(package))
106 )
108 @staticmethod
109 def round_probabilities(ptable, rounding=5):
110 """Round the entries of a probability table
112 Parameters
113 ----------
114 ptable: dict
115 probability table
116 rounding: int
117 number of decimal places to round the entries of the probability
118 table
119 """
120 for key, value in ptable.items():
121 ptable[key] = np.round(value, rounding)
122 return ptable
124 def plot(self, probabilities, type="bar"):
125 """Generate a plot showing the classification probabilities
127 Parameters
128 ----------
129 probabilities: dict
130 dictionary giving the classification probabilities
131 type: str, optional
132 type of plot to produce
133 """
134 if type not in self.available_plots:
135 raise ValueError(
136 "Unknown plot '{}'. Please select a plot from {}".format(
137 type, ", ".join(self.available_plots)
138 )
139 )
140 return getattr(self, "_{}_plot".format(type))(probabilities)
142 def _bar_plot(self, probabilities):
143 """Generate a bar plot showing classification probabilities
145 Parameters
146 ----------
147 samples: dict
148 samples to use for plotting
149 probabilities: dict
150 dictionary giving the classification probabilities.
151 """
152 from pesummary.gw.plots.plot import _classification_plot
153 return _classification_plot(probabilities)
155 def save_to_file(self, file_name, probabilities, outdir="./", **kwargs):
156 """Save classification data to json file
158 Parameters
159 ----------
160 file_name: str
161 name of the file name that you wish to use
162 probabilities: dict
163 dictionary of probabilities you wish to save to file
164 """
165 from pesummary.io import write
166 write(
167 list(probabilities.keys()), list(probabilities.values()),
168 file_format="json", outdir=outdir, filename=file_name,
169 dataset_name=None, indent=None, **kwargs
170 )
class PAstro(_Base):
    r"""Class for generating source classification probabilities, i.e.
    the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS). We use a rate and evidence based estimate, as detailed in
    https://dcc.ligo.org/LIGO-G2301521 and described below:

    The probability for a given classification is simply:

    .. math::

        P(H_{\alpha}|d) = (1 - P_{\text{Terr}}^{pipeline})
        \frac{R_{\alpha}Z_{\alpha}}{\sum_{\beta}R_{\beta}Z_{\beta}}

    where :math:`Z_{\alpha}` is the Bayesian evidence for each category,
    estimated as,

    .. math::

        Z_{\alpha} = \frac{Z_{PE}}{N_{samp}}
        \sum_{i\sim\text{posterior}}^{N_{samp}}
        \frac{p(m_{1s,i},m_{2s,i},z_i|\alpha)}
        {p(m_{1d,i},m_{2d,i},d_{L,i})
        \times\frac{dd_L}{dz}\frac{1}{(1+z_i)^{2}}}

    and we use the following straw-person population prior for classifying the
    sources into different astrophysical categories

    .. math::

        p(m_{1s},m_{2s},z|\alpha) \propto
        \frac{m_{1s}^{\alpha}m_{2s}^{\beta}}
        {\text{min}(m_{1s},m_{2s,max})^{\beta+1}-m_{2s,min}^{\beta+1}}
        \frac{dV_c}{dz}\frac{1}{1+z}

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities
    category_data: dict, optional
        dictionary of summary data (rates and population hyper parameters) for
        each category, or the path to a yaml file containing it. Default None
    distance_prior: class, optional
        class describing the distance prior used when generating the posterior
        samples. It must have a method `ln_prob` for returning the log prior
        probability for a given distance. Default
        bilby.gw.prior.UniformSourceFrame
    cosmology: str, optional
        cosmology you wish to use. Default Planck15
    terrestrial_probability: float, optional
        probability that the observed gravitational-wave is of terrestrial
        origin. Default None.
    catch_terrestrial_probability_error: bool, optional
        catch the ValueError raised when no terrestrial_probability is provided.
        If True, terrestrial_probability is set to 0. Default False

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities
    """
    defaults = {"BBH": None, "BNS": None, "NSBH": None}

    def __init__(
        self, samples, category_data=None, distance_prior=None,
        cosmology="Planck15", terrestrial_probability=None,
        catch_terrestrial_probability_error=False
    ):
        # bilby supplies the UniformSourceFrame prior used below
        self.package = ["bilby.gw.prior"]
        super(PAstro, self).__init__(samples)
        self.distance_prior = distance_prior
        if distance_prior is None:
            logger.debug(
                f"No distance prior provided. Assuming the posterior samples "
                f"were obtained with a 'UniformSourceFrame' prior (with a "
                f"{cosmology} cosmology), as defined in 'bilby'."
            )
            # bound the default prior by the sample range with a 50% margin
            self.distance_prior = self.module.UniformSourceFrame(
                minimum=float(np.min(self.samples["luminosity_distance"]) * 0.5),
                maximum=float(np.max(self.samples["luminosity_distance"]) * 1.5),
                name="luminosity_distance", unit="Mpc",
                cosmology=cosmology
            )
        # category_data may be either a dict of summary data or a path to a
        # yaml file. Only pass path-like objects to os.path.isfile: it only
        # swallows OSError/ValueError, so a dict would raise a TypeError
        if isinstance(category_data, (str, os.PathLike)) and os.path.isfile(
            category_data
        ):
            import yaml
            with open(category_data, "r") as f:
                config = yaml.full_load(f)
            category_data = config["pop_prior"]
            # attach the (float) rate for each category to its prior data
            for key, value in config["Rates"].items():
                category_data[key]["rate"] = float(value)
        self.category_data = category_data
        self.cosmology = cosmology
        self.terrestrial_probability = terrestrial_probability
        self.catch_terrestrial_probability_error = catch_terrestrial_probability_error

    @property
    def required_parameters(self):
        params = super(PAstro, self).required_parameters
        # the rate weighted evidence additionally needs distance and redshift
        params.extend(["luminosity_distance", "redshift"])
        return params

    @property
    def available_plots(self):
        # expose the samples plot (see self._samples_plot) through the public
        # plot() dispatcher alongside the base-class bar plot
        return super(PAstro, self).available_plots + ["samples"]

    def _salpeter_prior(self, alpha, m1_bounds, m2_bounds, zmax, beta):
        """Calculate and return the log probabilities assuming a Salpeter population
        prior

        Parameters
        ----------
        alpha: float
            index of the powerlaw for the primary mass prior
        m1_bounds: list
            list of length 2 which contains the minimum (index 0) and maximum (index 1)
            primary mass
        m2_bounds: list
            list of length 2 which contains the minimum (index 0) and maximum (index 1)
            secondary mass
        zmax: float
            maximum redshift
        beta: float
            index of the powerlaw for the secondary mass prior
        """
        # powerlaw normalisation for the primary mass; alpha == -1 needs the
        # logarithmic special case
        if alpha != -1:
            upper = m1_bounds[1]**(1. + alpha)
            lower = m1_bounds[0]**(1. + alpha)
            log_m1_norm = np.log((1. + alpha) / (upper - lower))
        else:
            log_m1_norm = -np.log(np.log(m1_bounds[1] / m1_bounds[0]))
        # per-sample upper bound on the secondary mass: min(m1, m2_max)
        m2_max = np.min(
            np.array(
                [
                    m2_bounds[1] * np.ones(len(self.samples["mass_1_source"])),
                    self.samples["mass_1_source"]
                ]
            ), axis=0
        )
        # powerlaw normalisation for the secondary mass (beta == -1 special
        # cased as above)
        if beta != -1:
            upper = m2_max**(1. + beta)
            lower = m2_bounds[0]**(1. + beta)
            log_m2_norm = np.log((1. + beta) / (upper - lower))
        else:
            log_m2_norm = -np.log(np.log(m2_max / m2_bounds[0]))
        z_prior = self.module.UniformSourceFrame(
            name="redshift", minimum=0., maximum=zmax, unit=None
        )
        logprob = (
            alpha * np.log(self.samples["mass_1_source"]) +
            beta * np.log(self.samples["mass_2_source"]) +
            log_m1_norm + log_m2_norm +
            z_prior.ln_prob(self.samples["redshift"])
        )
        logprob[np.isnan(logprob)] = -np.inf
        # log(0) = -inf zeroes the weight of any sample outside the allowed
        # mass region for this category
        logprob += np.log(
            (
                (self.samples["mass_1_source"] >= self.samples["mass_2_source"]) *
                (self.samples["mass_1_source"] <= m1_bounds[1]) *
                (self.samples["mass_1_source"] >= m1_bounds[0]) *
                (m2_max >= self.samples["mass_2_source"]) *
                (self.samples["mass_2_source"] >= m2_bounds[0])
            ).astype(int)
        )
        return logprob

    def classification(self, rounding=5):
        """Return a dictionary containing the classification probabilities

        Parameters
        ----------
        rounding: int, optional
            number of decimal places to round the entries of the probability
            table. If None, no rounding is performed. Default 5
        """
        if self.category_data is None:
            raise ValueError(
                "No category data provided to estimate rate weighted evidence. "
                "Unable to calculate source probabilities."
            )
        required_data = [
            "rate", "alpha", "m1_bounds", "m2_bounds", "zmax", "beta"
        ]
        for value in self.category_data.values():
            if not all(_ in value.keys() for _ in required_data):
                raise ValueError(
                    "Please provide {} for each category".format(
                        ", ".join(required_data)
                    )
                )
        if self.terrestrial_probability is None:
            if self.catch_terrestrial_probability_error:
                logger.debug(
                    "Setting terrestrial probability to 0 for classification "
                    "probabilities"
                )
                # NOTE: stored on the instance, so subsequent calls will not
                # raise either
                self.terrestrial_probability = 0.
            else:
                raise ValueError(
                    "Please provide a terrestrial probability in order to calculate "
                    "classification probabilities. Alternatively pass the kwarg "
                    "catch_terrestrial_probability_error=True"
                )
        elif self.terrestrial_probability >= 1:
            raise ValueError(
                "Terrestrial probability >= 1 meaning that there is no "
                "probability that the source is a BBH, NSBH or BNS"
            )
        # evaluate population prior
        pop_log_priors = {
            category: self._salpeter_prior(
                config["alpha"], config["m1_bounds"], config["m2_bounds"],
                config["zmax"], config["beta"]
            ) for category, config in self.category_data.items()
        }
        # evaluate pe-prior
        pe_log_prior = self.distance_prior.ln_prob(
            self.samples["luminosity_distance"]
        )

        # evaluate detector frame to source frame jacobian
        hd = hubble_distance(self.cosmology)
        hp = hubble_parameter(self.cosmology, self.samples["redshift"])
        ddL_dz = (
            self.samples["luminosity_distance"] / (1 + self.samples["redshift"]) +
            (1. + self.samples["redshift"]) * hd / hp
        )
        log_jacobian = -np.log(ddL_dz) - 2. * np.log1p(self.samples["redshift"])

        # compute evidence: Monte-Carlo average of the prior ratio over the
        # posterior samples, weighted by the category rate
        rate_weighted_evidence = {
            category: config["rate"] * np.exp(
                logsumexp(
                    pop_log_priors[category] - pe_log_prior -
                    np.log(len(self.samples["mass_1_source"])) + log_jacobian
                )
            ) for category, config in self.category_data.items()
        }
        # compute p_astro
        total_evidence = np.sum(list(rate_weighted_evidence.values()))
        ptable = {
            category: (1. - self.terrestrial_probability) * rz / total_evidence
            for category, rz in rate_weighted_evidence.items()
        }
        ptable["Terrestrial"] = self.terrestrial_probability
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable

    def _samples_plot(self, probabilities):
        """Generate a sample distribution plot showing classification
        probabilities

        Parameters
        ----------
        probabilities: dict
            dictionary giving the classification probabilities.
        """
        from pesummary.gw.plots.plot import _classification_samples_plot
        return _classification_samples_plot(
            self.samples["mass_1_source"], self.samples["mass_2_source"],
            probabilities
        )
class EMBright(_Base):
    """Class for generating EM-Bright classification probabilities, i.e.
    the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities
    """
    defaults = {"HasNS": None, "HasRemnant": None, "HasMassGap": None}

    def __init__(self, samples, **kwargs):
        # ligo.em_bright provides the trained classifier used below
        self.package = ["ligo.em_bright.em_bright"]
        super(EMBright, self).__init__(samples)

    @property
    def required_parameters(self):
        # the EM-bright classifier additionally needs the spin parameters
        spin_parameters = ["a_1", "a_2", "tilt_1", "tilt_2"]
        params = super(EMBright, self).required_parameters
        params.extend(spin_parameters)
        return params

    def classification(self, rounding=5, **kwargs):
        """Return a dictionary containing the EM-Bright classification
        probabilities

        Parameters
        ----------
        rounding: int, optional
            number of decimal places to round the entries of the probability
            table. If None, no rounding is performed. Default 5
        """
        probs = self.module.source_classification_pe_from_table(self.samples)
        ptable = {
            "HasNS": probs[0], "HasRemnant": probs[1], "HasMassGap": probs[2]
        }
        if rounding is None:
            return ptable
        return self.round_probabilities(ptable, rounding=rounding)
class Classify(_Base):
    """Class for generating source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).
    """
    @property
    def required_parameters(self):
        # union of the parameters needed by PAstro and EMBright
        extra = [
            "luminosity_distance", "redshift", "a_1", "a_2", "tilt_1", "tilt_2"
        ]
        return super(Classify, self).required_parameters + extra

    def check_for_install(self, *args, **kwargs):
        # nothing to import here; PAstro and EMBright run their own checks
        pass

    @classmethod
    def classification_from_file(cls, filename, **kwargs):
        """Initiate the classification class with samples stored in file and
        return a dictionary containing the classification probabilities

        Parameters
        ----------
        filename: str
            path to file you wish to initiate the classification class with
        **kwargs: dict, optional
            all kwargs passed to cls.from_file()
        """
        ptable = PAstro.classification_from_file(filename, **kwargs)
        ptable.update(EMBright.classification_from_file(filename, **kwargs))
        return ptable

    def classification(self, rounding=5, **kwargs):
        """return a dictionary containing the classification probabilities.
        """
        combined = PAstro(self.samples, **kwargs).classification(
            rounding=rounding
        )
        combined.update(
            EMBright(self.samples, **kwargs).classification(rounding=rounding)
        )
        return combined
def classify(*args, **kwargs):
    """Generate source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities
    """
    classifier = Classify(*args)
    return classifier.classification(**kwargs)