Coverage for pesummary/gw/classification.py: 92.9%

183 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4import importlib 

5 

6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

7 

8 

class _Base(object):
    """Base class for generating classification probabilities

    Subclasses are expected to set ``self.package`` before calling
    ``_Base.__init__`` (it is read by ``check_for_install``) and to provide a
    ``_convert_samples`` method (it is called here but not defined here).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    module: module
        the module returned by ``check_for_install``
    samples:
        the posterior samples after being passed through
        ``self._convert_samples``
    required_parameters: list
        list of parameters that must be present in the posterior samples
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities. These
        probabilities can either be generated from the raw samples or samples
        reweighted to a population inferred prior
    dual_classification:
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')
    plot:
        generate a plot showing the classification probabilities
    """
    def __init__(self, samples):
        self.module = self.check_for_install()
        if not isinstance(samples, dict):
            raise ValueError("Please provide samples as dictionary")
        required = self.required_parameters
        if not all(_p in samples.keys() for _p in required):
            # try to derive the missing parameters from those provided
            from pesummary.utils.samples_dict import SamplesDict
            samples = SamplesDict(samples)
            samples.generate_all_posterior_samples(disable_remnant=True)
        if not all(_p in samples.keys() for _p in required):
            raise ValueError(
                "Failed to compute classification probabilities because "
                "the following parameters are required: {}".format(
                    ", ".join(required)
                )
            )
        self.samples = self._convert_samples(samples)

    @classmethod
    def from_file(cls, filename):
        """Initiate the classification class with samples stored in file

        Parameters
        ----------
        filename: str
            path to file you wish to initiate the classification class with
        """
        from pesummary.io import read
        f = read(filename)
        samples = f.samples_dict
        return cls(samples)

    @classmethod
    def classification_from_file(cls, filename, **kwargs):
        """Initiate the classification class with samples stored in file and
        return a dictionary containing the classification probabilities

        Parameters
        ----------
        filename: str
            path to file you wish to initiate the classification class with
        **kwargs: dict, optional
            all kwargs passed to cls.classification()
        """
        _cls = cls.from_file(filename)
        return _cls.classification(**kwargs)

    @classmethod
    def dual_classification_from_file(cls, filename, seed=123456789):
        """Initiate the classification class with samples stored in file and
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')

        Parameters
        ----------
        filename: str
            path to file you wish to initiate the classification class with
        seed: int, optional
            random seed to use when reweighing to a population inferred prior
        """
        _cls = cls.from_file(filename)
        return _cls.dual_classification(seed=seed)

    @property
    def required_parameters(self):
        # a new list is built on every access, so callers may extend it
        return ["mass_1_source", "mass_2_source", "a_1", "a_2"]

    @property
    def available_plots(self):
        # a new list is built on every access, so callers may extend it
        return ["bar"]

    @staticmethod
    def round_probabilities(ptable, rounding=5):
        """Round the entries of a probability table. The table is modified
        in place and also returned

        Parameters
        ----------
        ptable: dict
            probability table
        rounding: int
            number of decimal places to round the entries of the probability
            table
        """
        for key, value in ptable.items():
            ptable[key] = np.round(value, rounding)
        return ptable

    def check_for_install(self, package=None):
        """Check that the required package is installed and return the
        imported module. If a list of packages is provided, the first one
        which can be imported is returned. If no package can be imported,
        raise an ImportError

        Parameters
        ----------
        package: str or list, optional
            name of package(s) to check for install. If None, self.package
            is used. Default None
        """
        if package is None:
            package = self.package
        if isinstance(package, str):
            package = [package]
        _not_available = []
        for _package in package:
            try:
                return importlib.import_module(_package)
            except ModuleNotFoundError:
                _not_available.append(_package)
        # only reached when none of the packages could be imported
        if len(_not_available):
            raise ImportError(
                "Unable to import {}. Unable to compute classification "
                "probabilities".format(" or ".join(package))
            )

    def _resample_to_population_prior(self, samples=None):
        """Use the pepredicates.rewt_approx_massdist_redshift function to
        reweight a pandas DataFrame to a population informed prior

        Parameters
        ----------
        samples: pandas.DataFrame, optional
            pandas DataFrame containing posterior samples. It must contain
            'redshift' and 'dist' columns. If None, self.samples is used.
            Default None

        Raises
        ------
        ValueError
            if 'redshift' or 'dist' is missing from the samples
        """
        import copy
        if self.__class__.__name__ != "PEPredicates":
            # self.module is not pepredicates for other classes
            _module = self.check_for_install(package="pepredicates")
        else:
            _module = self.module
        if samples is None:
            samples = self.samples
        # work on a copy so the samples passed in are left untouched
        _samples = copy.deepcopy(samples)
        if not all(param in _samples.keys() for param in ["redshift", "dist"]):
            raise ValueError(
                "Samples for redshift and distance required for population "
                "reweighting"
            )
        return _module.rewt_approx_massdist_redshift(_samples)

    def dual_classification(self, seed=123456789):
        """Return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')

        Parameters
        ----------
        seed: int, optional
            random seed to use when reweighing to a population inferred prior
        """
        return {
            "default": self.classification(),
            "population": self.classification(population=True, seed=seed)
        }

    def classification(self, population=False, return_samples=False, seed=123456789):
        """return a dictionary containing the classification probabilities.
        These probabilities can either be generated from the raw samples or
        samples reweighted to a population inferred prior

        Parameters
        ----------
        population: Bool, optional
            if True, reweight the samples to a population informed prior and
            then calculate classification probabilities. Default False
        return_samples: Bool, optional
            if True, return the samples used as well as the classification
            probabilities, as a tuple (samples, probabilities). When
            population=True these are the reweighted samples
        seed: int, optional
            random seed to use when reweighing to a population inferred prior.
            Note that this seeds the global numpy random state
        """
        if not population:
            ptable = self._compute_classification_probabilities()
            if return_samples:
                return self.samples, ptable
            return ptable
        np.random.seed(seed)
        # the reweighting function expects the pepredicates naming convention
        _samples = PEPredicates._convert_samples(self.samples)
        reweighted_samples = self._resample_to_population_prior(
            samples=_samples
        )
        # convert back to the form required by this class; these are the
        # samples the probabilities are actually computed from
        used_samples = self._convert_samples(reweighted_samples)
        ptable = self._compute_classification_probabilities(
            samples=used_samples
        )
        if return_samples:
            return used_samples, ptable
        return ptable

    def _compute_classification_probabilities(self, samples=None):
        """Base function to compute classification probabilities. Returns
        a tuple containing the samples to use and an empty probability table

        Parameters
        ----------
        samples: dict, optional
            samples to use for computing the classification probabilities.
            If None, self.samples is used. Default None.
        """
        if samples is None:
            samples = self.samples
        return samples, {}

    def plot(
        self, samples=None, probabilities=None, type="bar", population=False
    ):
        """Generate a plot showing the classification probabilities

        Parameters
        ----------
        samples: dict, optional
            samples to use for plotting. Default None
        probabilities: dict, optional
            dictionary giving the classification probabilities. Default None
        type: str, optional
            type of plot to produce
        population: Bool, optional
            if True, reweight the posterior samples to a population informed
            prior before computing the classification probabilities for
            plotting. Only used when probabilities=None. Default False

        Raises
        ------
        ValueError
            if type is not one of self.available_plots
        """
        if type not in self.available_plots:
            raise ValueError(
                "Unknown plot '{}'. Please select a plot from {}".format(
                    type, ", ".join(self.available_plots)
                )
            )
        if (probabilities is None) or ((samples is None) and population):
            s, p = self.classification(
                population=population, return_samples=True
            )
            if probabilities is None:
                probabilities = p
            if ((samples is None) and population):
                samples = s
        # dispatch to the method named after the requested plot type
        return getattr(self, "_{}_plot".format(type))(samples, probabilities)

    def _bar_plot(self, samples, probabilities):
        """Generate a bar plot showing classification probabilities

        Parameters
        ----------
        samples: dict
            samples to use for plotting. Not used by this plot type
        probabilities: dict
            dictionary giving the classification probabilities.
        """
        from pesummary.gw.plots.plot import _classification_plot
        return _classification_plot(probabilities)

279 

280 

class PEPredicates(_Base):
    """Class for generating source classification probabilities, i.e.
    the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), or a binary originating from the mass gap, p(MassGap)

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities. These
        probabilities can either be generated from the raw samples or samples
        reweighted to a population inferred prior
    dual_classification:
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')
    plot:
        generate a plot showing the classification probabilities
    """
    def __init__(self, samples):
        # must be set before the base class checks for the install
        self.package = "pepredicates"
        super(PEPredicates, self).__init__(samples)

    @property
    def _default_probabilities(self):
        # predicates from the pepredicates module, keyed by the label used
        # in the probability table
        return {
            'BNS': self.module.BNS_p, 'NSBH': self.module.NSBH_p,
            'BBH': self.module.BBH_p, 'MassGap': self.module.MG_p
        }

    @property
    def available_plots(self):
        _plots = super(PEPredicates, self).available_plots
        _plots.extend(["pepredicates"])
        return _plots

    @staticmethod
    def mapping(reverse=False):
        """Return the mapping between the pesummary and pepredicates
        parameter names

        Parameters
        ----------
        reverse: Bool, optional
            if True, return the mapping from pepredicates names to pesummary
            names. Default False
        """
        _mapping = {
            "mass_1_source": "m1_source", "mass_2_source": "m2_source",
            "luminosity_distance": "dist", "redshift": "redshift",
            "a_1": "a1", "a_2": "a2", "tilt_1": "tilt1", "tilt_2": "tilt2"
        }
        if reverse:
            return {item: key for key, item in _mapping.items()}
        return _mapping

    @staticmethod
    def _convert_samples(samples):
        """Convert dictionary of posterior samples to required form
        needed for pepredicates, i.e. a pandas DataFrame using the
        pepredicates naming convention

        Parameters
        ----------
        samples: dict
            samples to use for computing the classification probabilities.
            Parameters may use either the pesummary or the pepredicates
            naming convention
        """
        import pandas as pd
        mapping = PEPredicates.mapping()
        reverse = PEPredicates.mapping(reverse=True)
        available = samples.keys()
        if all(param in available for param in reverse.keys()):
            # already fully converted; keep everything as is
            _samples = samples.copy()
        else:
            # Convert parameter by parameter. The pepredicates name is
            # accepted as a fallback so that samples which are already
            # converted, but lack optional parameters (e.g. the tilts), are
            # not reduced to the few names common to both conventions
            _samples = {}
            for original, new in mapping.items():
                if original in available:
                    _samples[new] = samples[original]
                elif new in available:
                    _samples[new] = samples[new]
        return pd.DataFrame.from_dict(_samples)

    def _compute_classification_probabilities(self, samples=None, rounding=5):
        """Compute classification probabilities

        Parameters
        ----------
        samples: dict, optional
            samples to use for computing the classification probabilities
            Default None.
        rounding: int, optional
            number of decimal places to round entries of probability table.
            If None, the probabilities are not rounded. Default 5
        """
        samples, _ = super()._compute_classification_probabilities(
            samples=samples
        )
        ptable = self.module.predicate_table(
            self._default_probabilities, samples
        )
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable

    def _pepredicates_plot(self, samples, probabilities):
        """Generate a plot using the pepredicates.plot_predicates function
        showing classification probabilities

        Parameters
        ----------
        samples: dict
            samples to use for plotting. If None, self.samples is used
        probabilities: dict
            dictionary giving the classification probabilities.
        """
        from pesummary.core.plots.figure import ExistingFigure
        if samples is None:
            from pesummary.utils.utils import logger
            logger.debug(
                "No samples provided for plotting. Using cached array."
            )
            samples = self.samples
        idxs = {
            "BBH": self.module.is_BBH(samples),
            "BNS": self.module.is_BNS(samples),
            "NSBH": self.module.is_NSBH(samples),
            "MassGap": self.module.is_MG(samples)
        }
        return ExistingFigure(
            self.module.plot_predicates(
                idxs, samples, probs=probabilities
            )
        )

412 

413 

class PAstro(_Base):
    """Class for generating EM-Bright classification probabilities, i.e.
    the probability that the binary has a neutron star, p(HasNS), and
    the probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities

    Attributes
    ----------
    available_plots: list
        list of available plotting types

    Methods
    -------
    classification:
        return a dictionary containing the classification probabilities. These
        probabilities can either be generated from the raw samples or samples
        reweighted to a population inferred prior
    dual_classification:
        return a dictionary containing the classification probabilities
        generated from the raw samples ('default') and samples reweighted to
        a population inferred prior ('population')
    plot:
        generate a plot showing the classification probabilities
    """
    def __init__(self, samples):
        # must be set before the base class checks for the install
        self.package = "ligo.em_bright.em_bright"
        super(PAstro, self).__init__(samples)

    @property
    def required_parameters(self):
        _parameters = super(PAstro, self).required_parameters
        _parameters.extend(["tilt_1", "tilt_2"])
        return _parameters

    @staticmethod
    def _convert_samples(samples):
        """Convert dictionary of posterior samples to required form
        needed for ligo.computeDiskMass, i.e. a dictionary of numpy arrays
        using the pesummary naming convention. The samples passed in are
        not modified

        Parameters
        ----------
        samples: dict
            samples to use for computing the classification probabilities.
            Parameters using the pepredicates naming convention are renamed
        """
        reverse = PEPredicates.mapping(reverse=True)
        # parameters which need no renaming are copied across first
        _samples = {
            key: np.asarray(item) for key, item in samples.items()
            if key not in reverse
        }
        # pepredicates names are then mapped back to pesummary names; these
        # take precedence if both conventions are present
        for key, item in reverse.items():
            if key in samples.keys():
                _samples[item] = np.asarray(samples[key])
        return _samples

    def _compute_classification_probabilities(self, samples=None, rounding=5):
        """Compute classification probabilities

        Parameters
        ----------
        samples: dict, optional
            samples to use for computing the classification probabilities
            Default None.
        rounding: int, optional
            number of decimal places to round entries of probability table.
            If None, the probabilities are not rounded. Default 5
        """
        samples, _ = super()._compute_classification_probabilities(
            samples=samples
        )
        probs = self.module.source_classification_pe_from_table(samples)
        ptable = {"HasNS": probs[0], "HasRemnant": probs[1]}
        if rounding is not None:
            return self.round_probabilities(ptable, rounding=rounding)
        return ptable

495 

496 

class Classify(_Base):
    """Class for generating source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), or a binary originating from the mass gap, p(MassGap), the
    probability that the binary has a neutron star, p(HasNS), and the
    probability that the remnant is observable, p(HasRemnant).
    """
    @property
    def required_parameters(self):
        _parameters = super(Classify, self).required_parameters
        _parameters.extend(["tilt_1", "tilt_2"])
        return _parameters

    def check_for_install(self, *args, **kwargs):
        # nothing to check here: PEPredicates and PAstro each check for their
        # own package when they are initiated in classification()
        pass

    def _convert_samples(self, samples):
        # the samples are stored unchanged; PEPredicates and PAstro each
        # convert them when they are initiated in classification()
        return samples

    def classification(self, **kwargs):
        """return a dictionary containing the classification probabilities.
        These probabilities can either be generated from the raw samples or
        samples reweighted to a population inferred prior

        Parameters
        ----------
        population: Bool, optional
            if True, reweight the samples to a population informed prior and
            then calculate classification probabilities. Default False
        return_samples: Bool, optional
            if True, return the samples used as well as the classification
            probabilities, as a tuple (samples, probabilities). The samples
            returned are those returned by PEPredicates
        """
        probs = PEPredicates(self.samples).classification(**kwargs)
        pastro = PAstro(self.samples).classification(**kwargs)
        if kwargs.get("return_samples", False):
            # both results are (samples, probabilities) tuples, so only the
            # probability tables can be merged
            samples, probs = probs
            probs.update(pastro[1])
            return samples, probs
        probs.update(pastro)
        return probs

535 

536 

def classify(*args, **kwargs):
    """Generate source classification and EM-Bright probabilities,
    i.e. the probability that it is consistent with originating from a binary
    black hole, p(BBH), neutron star black hole, p(NSBH), binary neutron star,
    p(BNS), or a binary originating from the mass gap, p(MassGap), the
    probability that the binary has a neutron star, p(HasNS), and the
    probability that the remnant is observable, p(HasRemnant).

    Parameters
    ----------
    samples: dict
        dictionary of posterior samples to use for generating classification
        probabilities. May be passed positionally or by keyword
    population: Bool, optional
        if True, reweight the samples to a population informed prior and
        then calculate classification probabilities. Default False
    return_samples: Bool, optional
        if True, return the samples used as well as the classification
        probabilities

    Raises
    ------
    TypeError
        if samples is passed both positionally and by keyword
    """
    if "samples" in kwargs:
        # 'samples' initiates the class; it is not a classification() kwarg
        if args:
            raise TypeError(
                "classify() got multiple values for argument 'samples'"
            )
        args = (kwargs.pop("samples"),)
    return Classify(*args).classification(**kwargs)