Coverage for pesummary/gw/classification.py: 86.2%

174 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import importlib 

4import os 

5import numpy as np 

6from scipy.special import logsumexp 

7from pesummary.gw.cosmology import hubble_distance, hubble_parameter 

8from pesummary.utils.utils import logger 

9 

# Module maintainers and contact addresses
__author__ = [
    "Anarya Ray <anarya.ray@ligo.org>",
    "Charlie Hoy <charlie.hoy@ligo.org>"
]

14 

15class _Base(): 

16 """Base class for generating classification probabilities 

17 

18 Parameters 

19 ---------- 

20 samples: dict 

21 dictionary of posterior samples to use for generating classification 

22 probabilities 

23 

24 Attributes 

25 ---------- 

26 available_plots: list 

27 list of available plotting types 

28 """ 

29 def __init__(self, samples): 

30 self.module = self.check_for_install() 

31 if not isinstance(samples, dict): 

32 raise ValueError("Please provide samples as dictionary") 

33 if not all(_p in samples.keys() for _p in self.required_parameters): 

34 from pesummary.utils.samples_dict import SamplesDict 

35 samples = SamplesDict(samples) 

36 samples.generate_all_posterior_samples(disable_remnant=True) 

37 if not all(_p in samples.keys() for _p in self.required_parameters): 

38 raise ValueError( 

39 "Failed to compute classification probabilities because " 

40 "the following parameters are required: {}".format( 

41 ", ".join(self.required_parameters) 

42 ) 

43 ) 

44 self.samples = {key: np.array(value) for key, value in samples.items()} 

45 

46 @property 

47 def required_parameters(self): 

48 return ["mass_1_source", "mass_2_source"] 

49 

50 @property 

51 def available_plots(self): 

52 return ["bar"] 

53 

54 @classmethod 

55 def from_file(cls, filename, **kwargs): 

56 """Initiate the classification class with samples stored in file 

57 

58 Parameters 

59 ---------- 

60 filename: str 

61 path to file you wish to initiate the classification class with 

62 """ 

63 from pesummary.io import read 

64 f = read(filename) 

65 samples = f.samples_dict 

66 return cls(samples, **kwargs) 

67 

68 @classmethod 

69 def classification_from_file(cls, filename, **kwargs): 

70 """Initiate the classification class with samples stored in file and 

71 return a dictionary containing the classification probabilities 

72 

73 Parameters 

74 ---------- 

75 filename: str 

76 path to file you wish to initiate the classification class with 

77 **kwargs: dict, optional 

78 all kwargs passed to cls.from_file() 

79 """ 

80 _cls = cls.from_file(filename, **kwargs) 

81 return _cls.classification() 

82 

83 def check_for_install(self, package=None): 

84 """Check that the required package is installed. If the package 

85 is not installed, raise an ImportError 

86 

87 Parameters 

88 ---------- 

89 package: str, optional 

90 name of package to check for install. Default None 

91 """ 

92 if package is None: 

93 package = self.package 

94 if isinstance(package, str): 

95 package = [package] 

96 _not_available = [] 

97 for _package in package: 

98 try: 

99 return importlib.import_module(_package) 

100 except ModuleNotFoundError: 

101 _not_available.append(_package) 

102 if len(_not_available): 

103 raise ImportError( 

104 "Unable to import {}. Unable to compute classification " 

105 "probabilities".format(" or ".join(package)) 

106 ) 

107 

108 @staticmethod 

109 def round_probabilities(ptable, rounding=5): 

110 """Round the entries of a probability table 

111 

112 Parameters 

113 ---------- 

114 ptable: dict 

115 probability table 

116 rounding: int 

117 number of decimal places to round the entries of the probability 

118 table 

119 """ 

120 for key, value in ptable.items(): 

121 ptable[key] = np.round(value, rounding) 

122 return ptable 

123 

124 def plot(self, probabilities, type="bar"): 

125 """Generate a plot showing the classification probabilities 

126 

127 Parameters 

128 ---------- 

129 probabilities: dict 

130 dictionary giving the classification probabilities 

131 type: str, optional 

132 type of plot to produce 

133 """ 

134 if type not in self.available_plots: 

135 raise ValueError( 

136 "Unknown plot '{}'. Please select a plot from {}".format( 

137 type, ", ".join(self.available_plots) 

138 ) 

139 ) 

140 return getattr(self, "_{}_plot".format(type))(probabilities) 

141 

142 def _bar_plot(self, probabilities): 

143 """Generate a bar plot showing classification probabilities 

144 

145 Parameters 

146 ---------- 

147 samples: dict 

148 samples to use for plotting 

149 probabilities: dict 

150 dictionary giving the classification probabilities. 

151 """ 

152 from pesummary.gw.plots.plot import _classification_plot 

153 return _classification_plot(probabilities) 

154 

155 def save_to_file(self, file_name, probabilities, outdir="./", **kwargs): 

156 """Save classification data to json file 

157 

158 Parameters 

159 ---------- 

160 file_name: str 

161 name of the file name that you wish to use 

162 probabilities: dict 

163 dictionary of probabilities you wish to save to file 

164 """ 

165 from pesummary.io import write 

166 write( 

167 list(probabilities.keys()), list(probabilities.values()), 

168 file_format="json", outdir=outdir, filename=file_name, 

169 dataset_name=None, indent=None, **kwargs 

170 ) 

171 

172 

class PAstro(_Base):
    r"""Class for generating source classification probabilities, i.e.
    the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS). We use a rate and evidence based estimate, as detailed in
    https://dcc.ligo.org/LIGO-G2301521 and described below:

    The probability for a given classification is simply:

    .. math::
        P(H_{\alpha}|d) = (1 - P_{\text{Terr}}^{\text{pipeline}})
        \frac{R_{\alpha}Z_{\alpha}}{\sum_{\beta}R_{\beta}Z_{\beta}}

    where :math:`Z_{\alpha}` is the Bayesian evidence for each category,
    estimated as,

    .. math::
        Z_{\alpha} = \frac{Z_{PE}}{N_{samp}}
        \sum_{i\sim\text{posterior}}^{N_{samp}}
        \frac{p(m_{1s,i}, m_{2s,i}, z_{i}|\alpha)}
        {p(m_{1d,i}, m_{2d,i}, d_{L,i})
        \frac{dd_{L}}{dz}\frac{1}{(1+z_{i})^{2}}}

    and we use the following straw-person population prior for classifying the
    sources into different astrophysical categories

    .. math::
        p(m_{1s}, m_{2s}, z|\alpha) \propto
        \frac{m_{1s}^{\alpha}m_{2s}^{\beta}}
        {\min(m_{1s}, m_{2s,max})^{\beta+1} - m_{2s,min}^{\beta+1}}
        \frac{dV_{c}}{dz}\frac{1}{1+z}

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities
    category_data: dict or str, optional
        dictionary of summary data (rates and population hyper parameters) for
        each category, or path to a yaml file storing this data under the
        'pop_prior' and 'Rates' keys. Default None
    distance_prior: class, optional
        class describing the distance prior used when generating the posterior
        samples. It must have a method `ln_prob` for returning the log prior
        probability for a given distance. Default
        bilby.gw.prior.UniformSourceFrame
    cosmology: str, optional
        cosmology you wish to use. Default Planck15
    terrestrial_probability: float, optional
        probability that the observed gravitational-wave is of terrestrial
        origin. Default None.
    catch_terrestrial_probability_error: bool, optional
        catch the ValueError raised when no terrestrial_probability is provided.
        If True, terrestrial_probability is set to 0. Default False

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities
    """
    # classification categories returned by self.classification()
    defaults = {"BBH": None, "BNS": None, "NSBH": None}
    def __init__(
        self, samples, category_data=None, distance_prior=None,
        cosmology="Planck15", terrestrial_probability=None,
        catch_terrestrial_probability_error=False
    ):
        self.package = ["bilby.gw.prior"]
        super(PAstro, self).__init__(samples)
        self.distance_prior = distance_prior
        if distance_prior is None:
            logger.debug(
                f"No distance prior provided. Assuming the posterior samples "
                f"were obtained with a 'UniformSourceFrame' prior (with a "
                f"{cosmology} cosmology), as defined in 'bilby'."
            )
            # pad the prior bounds so that every posterior sample lies within
            # the assumed prior support
            self.distance_prior = self.module.UniformSourceFrame(
                minimum=float(np.min(self.samples["luminosity_distance"]) * 0.5),
                maximum=float(np.max(self.samples["luminosity_distance"]) * 1.5),
                name="luminosity_distance", unit="Mpc",
                cosmology=cosmology
            )
        # category_data may be a path to a yaml file; guard with an isinstance
        # check because os.path.isfile raises a TypeError (not caught by a
        # plain `is not None` check) when passed a dict
        if isinstance(category_data, (str, os.PathLike)) and os.path.isfile(
            category_data
        ):
            import yaml
            with open(category_data, "r") as f:
                config = yaml.full_load(f)
            category_data = config["pop_prior"]
            for key, value in config["Rates"].items():
                category_data[key]["rate"] = float(value)
        self.category_data = category_data
        self.cosmology = cosmology
        self.terrestrial_probability = terrestrial_probability
        self.catch_terrestrial_probability_error = catch_terrestrial_probability_error

    @property
    def required_parameters(self):
        params = super(PAstro, self).required_parameters
        params.extend(["luminosity_distance", "redshift"])
        return params

    @property
    def available_plots(self):
        # advertise the 'samples' plot implemented by self._samples_plot;
        # without this override self.plot(probabilities, type="samples")
        # would raise a ValueError
        plots = super(PAstro, self).available_plots
        plots.append("samples")
        return plots

    def _salpeter_prior(self, alpha, m1_bounds, m2_bounds, zmax, beta):
        """Calculate and return the log probabilities assuming a Salpeter population
        prior

        Parameters
        ----------
        alpha: float
            index of the powerlaw for the primary mass prior
        m1_bounds: list
            list of length 2 which contains the minimum (index 0) and maximum (index 1)
            primary mass
        m2_bounds: list
            list of length 2 which contains the minimum (index 0) and maximum (index 1)
            secondary mass
        zmax: float
            maximum redshift
        beta: float
            index of the powerlaw for the secondary mass prior
        """
        # normalisation of the primary mass powerlaw; alpha == -1 integrates
        # to a log rather than a powerlaw
        if alpha != -1:
            upper = m1_bounds[1]**(1. + alpha)
            lower = m1_bounds[0]**(1. + alpha)
            log_m1_norm = np.log((1. + alpha) / (upper - lower))
        else:
            log_m1_norm = -np.log(np.log(m1_bounds[1] / m1_bounds[0]))
        # per-sample upper bound on the secondary mass: min(m1, m2_max)
        m2_max = np.min(
            np.array(
                [
                    m2_bounds[1] * np.ones(len(self.samples["mass_1_source"])),
                    self.samples["mass_1_source"]
                ]
            ), axis=0
        )
        if beta != -1:
            upper = m2_max**(1. + beta)
            lower = m2_bounds[0]**(1. + beta)
            log_m2_norm = np.log((1. + beta) / (upper - lower))
        else:
            log_m2_norm = -np.log(np.log(m2_max / m2_bounds[0]))
        z_prior = self.module.UniformSourceFrame(
            name="redshift", minimum=0., maximum=zmax, unit=None
        )
        logprob = (
            alpha * np.log(self.samples["mass_1_source"]) +
            beta * np.log(self.samples["mass_2_source"]) +
            log_m1_norm + log_m2_norm +
            z_prior.ln_prob(self.samples["redshift"])
        )
        # NaNs arise from samples outside the prior support; assign them zero
        # probability
        logprob[np.isnan(logprob)] = -np.inf
        # zero the probability (log -> -inf) of samples violating the mass
        # ordering or the category's mass bounds
        logprob += np.log(
            (
                (self.samples["mass_1_source"] >= self.samples["mass_2_source"]) *
                (self.samples["mass_1_source"] <= m1_bounds[1]) *
                (self.samples["mass_1_source"] >= m1_bounds[0]) *
                (m2_max >= self.samples["mass_2_source"]) *
                (self.samples["mass_2_source"] >= m2_bounds[0])
            ).astype(int)
        )
        return logprob

    def classification(self, rounding=5):
        """Return a dictionary containing the classification probabilities

        Parameters
        ----------
        rounding: int, optional
            number of decimal places to round the entries of the probability
            table. If None, no rounding is performed. Default 5

        Raises
        ------
        ValueError
            if no category data is available, if the category data is
            incomplete, if no terrestrial probability is available (and
            catch_terrestrial_probability_error is False), or if the
            terrestrial probability is >= 1
        """
        if self.category_data is None:
            raise ValueError(
                "No category data provided to estimate rate weighted evidence. "
                "Unable to calculate source probabilities."
            )
        required_data = [
            "rate", "alpha", "m1_bounds", "m2_bounds", "zmax", "beta"
        ]
        for value in self.category_data.values():
            if not all(_ in value.keys() for _ in required_data):
                raise ValueError(
                    "Please provide {} for each category".format(
                        ", ".join(required_data)
                    )
                )
        if self.terrestrial_probability is None:
            if self.catch_terrestrial_probability_error:
                logger.debug(
                    "Setting terrestrial probability to 0 for classification "
                    "probabilities"
                )
                self.terrestrial_probability = 0.
            else:
                raise ValueError(
                    "Please provide a terrestrial probability in order to calculate "
                    "classification probabilities. Alternatively pass the kwarg "
                    "catch_terrestrial_probability_error=True"
                )
        elif self.terrestrial_probability >= 1:
            raise ValueError(
                "Terrestrial probability >= 1 meaning that there is no "
                "probability that the source is a BBH, NSBH or BNS"
            )
        # evaluate population prior
        pop_log_priors = {
            category: self._salpeter_prior(
                config["alpha"], config["m1_bounds"], config["m2_bounds"],
                config["zmax"], config["beta"]
            ) for category, config in self.category_data.items()
        }
        # evaluate pe-prior
        pe_log_prior = self.distance_prior.ln_prob(
            self.samples["luminosity_distance"]
        )

        # evaluate detector frame to source frame jacobian
        hd = hubble_distance(self.cosmology)
        hp = hubble_parameter(self.cosmology, self.samples["redshift"])
        ddL_dz = (
            self.samples["luminosity_distance"] / (1 + self.samples["redshift"]) +
            (1. + self.samples["redshift"]) * hd / hp
        )
        log_jacobian = -np.log(ddL_dz) - 2. * np.log1p(self.samples["redshift"])

        # compute evidence: Monte-Carlo average of the prior ratio over the
        # posterior samples, evaluated in log-space for numerical stability
        rate_weighted_evidence = {
            category: config["rate"] * np.exp(
                logsumexp(
                    pop_log_priors[category] - pe_log_prior -
                    np.log(len(self.samples["mass_1_source"])) + log_jacobian
                )
            ) for category, config in self.category_data.items()
        }
        # compute p_astro: share the astrophysical probability mass between
        # the categories in proportion to their rate-weighted evidence
        total_evidence = np.sum(list(rate_weighted_evidence.values()))
        ptable = {
            category: (1. - self.terrestrial_probability) * rz / total_evidence
            for category, rz in rate_weighted_evidence.items()
        }
        ptable["Terrestrial"] = self.terrestrial_probability
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable

    def _samples_plot(self, probabilities):
        """Generate a sample distribution plot showing classification
        probabilities

        Parameters
        ----------
        probabilities: dict
            dictionary giving the classification probabilities.
        """
        from pesummary.gw.plots.plot import _classification_samples_plot
        return _classification_samples_plot(
            self.samples["mass_1_source"], self.samples["mass_2_source"],
            probabilities
        )

419 

420 

class EMBright(_Base):
    """Class for generating EM-Bright classification probabilities, i.e.
    the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities
    """
    # classification categories returned by self.classification()
    defaults = {"HasNS": None, "HasRemnant": None, "HasMassGap": None}
    def __init__(self, samples, **kwargs):
        # the ligo.em_bright package provides the classification machinery
        self.package = ["ligo.em_bright.em_bright"]
        super().__init__(samples)

    @property
    def required_parameters(self):
        # spin magnitudes and tilts are required on top of the source-frame
        # masses
        return super().required_parameters + ["a_1", "a_2", "tilt_1", "tilt_2"]

    def classification(self, rounding=5, **kwargs):
        """Return a dictionary containing the EM-Bright probabilities

        Parameters
        ----------
        rounding: int, optional
            number of decimal places to round the entries of the probability
            table. If None, no rounding is performed. Default 5
        """
        estimates = self.module.source_classification_pe_from_table(self.samples)
        ptable = {
            "HasNS": estimates[0],
            "HasRemnant": estimates[1],
            "HasMassGap": estimates[2],
        }
        if rounding is None:
            return ptable
        return self.round_probabilities(ptable, rounding=rounding)

459 

460 

class Classify(_Base):
    """Class for generating source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).
    """
    @property
    def required_parameters(self):
        # union of the parameters required by PAstro and EMBright
        params = super(Classify, self).required_parameters
        params.extend(
            ["luminosity_distance", "redshift", "a_1", "a_2", "tilt_1", "tilt_2"]
        )
        return params

    def check_for_install(self, *args, **kwargs):
        """Do nothing: this class has no backend package of its own. PAstro
        and EMBright each check their own requirements when instantiated by
        self.classification()
        """
        pass

    @classmethod
    def classification_from_file(cls, filename, **kwargs):
        """Initiate the classification class with samples stored in file and
        return a dictionary containing the classification probabilities

        Parameters
        ----------
        filename: str
            path to file you wish to initiate the classification class with
        **kwargs: dict, optional
            all kwargs passed to cls.from_file()
        """
        probabilities = PAstro.from_file(filename, **kwargs).classification()
        embright = EMBright.from_file(filename, **kwargs).classification()
        probabilities.update(embright)
        return probabilities

    def classification(self, rounding=5, **kwargs):
        """Return a dictionary containing both the source classification and
        the EM-Bright probabilities

        Parameters
        ----------
        rounding: int, optional
            number of decimal places to round the entries of the probability
            table. Default 5
        **kwargs: dict, optional
            all kwargs passed to the PAstro and EMBright constructors
        """
        probabilities = PAstro(self.samples, **kwargs).classification(
            rounding=rounding
        )
        embright = EMBright(self.samples, **kwargs).classification(
            rounding=rounding
        )
        probabilities.update(embright)
        return probabilities

510 

511 

def classify(*args, **kwargs):
    """Generate source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities
    """
    classifier = Classify(*args)
    return classifier.classification(**kwargs)