Coverage for pesummary/gw/cli/inputs.py: 79.6%

846 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import ast 

4import os 

5import numpy as np 

6import pesummary.core.cli.inputs 

7from pesummary.gw.file.read import read as GWRead 

8from pesummary.gw.file.psd import PSD 

9from pesummary.gw.file.calibration import Calibration 

10from pesummary.utils.decorators import deprecation 

11from pesummary.utils.exceptions import InputError 

12from pesummary.utils.utils import logger 

13from pesummary import conf 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17 

18class _GWInput(pesummary.core.cli.inputs._Input): 

19 """Super class to handle gw specific command line inputs 

20 """ 

@staticmethod
def grab_data_from_metafile(
    existing_file, webdir, compare=None, nsamples=None, **kwargs
):
    """Read an existing PESummary metafile and attach the GW specific
    products (approximant, psd, calibration and skymap) for each label

    Parameters
    ----------
    existing_file: str
        path to the existing metafile
    webdir: str
        the directory to store the existing configuration file
    compare: list, optional
        list of labels for events stored in an existing metafile that you
        wish to compare
    nsamples: optional
        passed unchanged to the core grab_data_from_metafile function
    **kwargs: dict, optional
        all additional kwargs are passed to the core
        grab_data_from_metafile function. If 'psd_default' is present it is
        also added to the replacement kwargs

    Returns
    -------
    data: dict
        the dictionary returned by the core function, updated with the
        'approximant', 'psd', 'calibration' and 'skymap' entries
    """
    replacements = {"psd": "{file}.psd['{label}']"}
    if "psd_default" in kwargs.keys():
        replacements["psd_default"] = kwargs["psd_default"]
    core_reader = pesummary.core.cli.inputs._Input.grab_data_from_metafile
    data = core_reader(
        existing_file, webdir, compare=compare, read_function=GWRead,
        nsamples=nsamples, _replace_with_pesummary_kwargs=replacements,
        **kwargs
    )
    opened = GWRead(existing_file)
    labels = data["labels"]

    # every label gets an entry, even when nothing is stored for it
    psd = {label: {} for label in labels}
    if opened.psd is not None and opened.psd != {}:
        for label in labels:
            if label in opened.psd.keys() and opened.psd[label] != {}:
                psd[label] = {
                    ifo: PSD(opened.psd[label][ifo])
                    for ifo in opened.psd[label].keys()
                }

    calibration = {label: {} for label in labels}
    if opened.calibration is not None and opened.calibration != {}:
        for label in labels:
            if (
                label in opened.calibration.keys()
                and opened.calibration[label] != {}
            ):
                calibration[label] = {
                    ifo: Calibration(opened.calibration[label][ifo])
                    for ifo in opened.calibration[label].keys()
                }

    skymap = {label: None for label in labels}
    if (
        hasattr(opened, "skymap")
        and opened.skymap is not None
        and opened.skymap != {}
    ):
        for label in labels:
            if label in opened.skymap.keys() and len(opened.skymap[label]):
                skymap[label] = opened.skymap[label]

    approximant = dict(
        zip(labels, [opened.approximant[ind] for ind in data["indicies"]])
    )
    gw_products = {
        "approximant": approximant,
        "psd": psd,
        "calibration": calibration,
        "skymap": skymap
    }
    data.update(gw_products)
    return data

84 

@property
def grab_data_kwargs(self):
    """Return the core grab_data_kwargs extended, for every label, with the
    GW specific conversion options

    The 'f_low', 'f_ref', 'f_final' and 'delta_f' attributes are first
    normalised in place: None becomes a list of None (one per label) and a
    single-element list is repeated for every label. If a per-label value
    cannot be found (IndexError) the values of the first entry are used for
    all labels and a warning is printed.

    Returns
    -------
    kwargs: dict
        dictionary keyed by label
    """
    kwargs = super(_GWInput, self).grab_data_kwargs
    for _property in ["f_low", "f_ref", "f_final", "delta_f"]:
        if getattr(self, _property) is None:
            setattr(self, "_{}".format(_property), [None] * len(self.labels))
        elif len(getattr(self, _property)) == 1 and len(self.labels) != 1:
            setattr(
                self, "_{}".format(_property),
                getattr(self, _property) * len(self.labels)
            )
    if self.opts.approximant is None:
        approx = [None] * len(self.labels)
    else:
        approx = self.opts.approximant
    resume_file = [
        os.path.join(
            self.webdir, "checkpoint",
            "{}_conversion_class.pickle".format(label)
        ) for label in self.labels
    ]

    def _psd_for(label):
        # a label without a stored PSD is not an error; use an empty dict
        try:
            return self.psd[label]
        except KeyError:
            return {}

    def _conversion_kwargs(num, index, psd):
        # 'num' is the position of the label (the resume file is always
        # per-label); 'index' selects which f_low/f_ref/... entry to use
        return dict(
            evolve_spins_forwards=self.evolve_spins_forwards,
            evolve_spins_backwards=self.evolve_spins_backwards,
            f_low=self.f_low[index],
            approximant=approx[index], f_ref=self.f_ref[index],
            NRSur_fits=self.NRSur_fits, return_kwargs=True,
            multipole_snr=self.calculate_multipole_snr,
            precessing_snr=self.calculate_precessing_snr,
            psd=psd, f_final=self.f_final[index],
            waveform_fits=self.waveform_fits,
            multi_process=self.opts.multi_process,
            redshift_method=self.redshift_method,
            cosmology=self.cosmology,
            no_conversion=self.no_conversion,
            add_zero_spin=True, delta_f=self.delta_f[index],
            psd_default=self.psd_default,
            disable_remnant=self.disable_remnant,
            force_BBH_remnant_computation=self.force_BBH_remnant_computation,
            resume_file=resume_file[num],
            restart_from_checkpoint=self.restart_from_checkpoint,
            force_BH_spin_evolution=self.force_BH_spin_evolution,
        )

    try:
        for num, label in enumerate(self.labels):
            kwargs[label].update(
                _conversion_kwargs(num, num, _psd_for(label))
            )
        return kwargs
    except IndexError:
        logger.warning(
            "Unable to find an f_ref, f_low and approximant for each "
            "label. Using and f_ref={}, f_low={} and approximant={} "
            "for all result files".format(
                self.f_ref[0], self.f_low[0], approx[0]
            )
        )
        # same KeyError tolerance as the per-label path above
        shared_psd = _psd_for(self.labels[0])
        for num, label in enumerate(self.labels):
            kwargs[label].update(_conversion_kwargs(num, 0, shared_psd))
        return kwargs

168 

@staticmethod
def grab_data_from_file(
    file, label, webdir, config=None, injection=None, file_format=None,
    nsamples=None, disable_prior_sampling=False, **kwargs
):
    """Read a result file containing posterior samples with the GW
    specific read function

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    webdir: str
        passed unchanged to the core grab_data_from_file function
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    file_format: str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options
    nsamples: optional
        passed unchanged to the core grab_data_from_file function
    disable_prior_sampling: Bool, optional
        passed unchanged to the core grab_data_from_file function
    **kwargs: dict, optional
        all additional kwargs are passed to the core function
    """
    core_reader = pesummary.core.cli.inputs._Input.grab_data_from_file
    return core_reader(
        file, label, webdir, config=config, injection=injection,
        read_function=GWRead, file_format=file_format, nsamples=nsamples,
        disable_prior_sampling=disable_prior_sampling, **kwargs
    )

196 

@property
def reweight_samples(self):
    """The value stored by the reweight_samples setter"""
    return self._reweight_samples

@reweight_samples.setter
def reweight_samples(self, reweight_samples):
    """Check the requested reweighting against the GW reweight options and
    store the result
    """
    from pesummary.gw.reweight import options as gw_options

    checked = self._check_reweight_samples(reweight_samples, gw_options)
    self._reweight_samples = checked

207 

def _set_samples(self, *args, **kwargs):
    """Set the samples through the parent class and make sure that the
    priors always contain a 'calibration' entry (an empty dictionary per
    label when none is present)
    """
    super(_GWInput, self)._set_samples(*args, **kwargs)
    if "calibration" in self.priors:
        return
    # a fresh dict per label so that entries are never shared
    self.priors["calibration"] = {name: {} for name in self.labels}

214 

def _set_corner_params(self, corner_params):
    """Return the parameters to include in the corner plot

    The parent class first processes the input. When it returns None the
    default GW corner parameters are used (None is returned). Otherwise
    the requested parameters are merged with the defaults, and every
    requested parameter which is not available in the samples for all
    labels is dropped.

    Parameters
    ----------
    corner_params: list
        the requested corner parameters; handed to the parent class first

    Returns
    -------
    corner_params: list or None
        the merged list of corner parameters, or None for the defaults
    """
    corner_params = super(_GWInput, self)._set_corner_params(corner_params)
    if corner_params is None:
        logger.debug(
            "Using the default corner parameters: {}".format(
                ", ".join(conf.gw_corner_parameters)
            )
        )
        return corner_params

    requested = corner_params
    _data = self.samples
    # collect the unavailable parameters as a set: a parameter that was
    # requested twice must only be dropped once (the merged list below is
    # already de-duplicated, so removing it twice would raise ValueError)
    unavailable = {
        param for param in requested
        if not all(param in _data[label].keys() for label in self.labels)
    }
    corner_params = [
        param for param in set(conf.gw_corner_parameters + requested)
        if param not in unavailable
    ]
    logger.debug(
        "Generating a corner plot with the following "
        "parameters: {}".format(", ".join(corner_params))
    )
    return corner_params

235 

@property
def cosmology(self):
    """The name of the cosmology stored by the setter"""
    return self._cosmology

@cosmology.setter
def cosmology(self, cosmology):
    """Store the requested cosmology, falling back to the configured
    default (with a warning) when its lower-cased name is not one of the
    available cosmologies
    """
    from pesummary.gw.cosmology import available_cosmologies

    if cosmology.lower() in available_cosmologies:
        logger.debug("Using the {} cosmology".format(cosmology))
        self._cosmology = cosmology
        return
    logger.warning(
        "Unrecognised cosmology: {}. Using {} as default".format(
            cosmology, conf.cosmology
        )
    )
    self._cosmology = conf.cosmology

254 

@property
def approximant(self):
    """Dictionary, keyed by label, of the stored approximants"""
    return self._approximant

@approximant.setter
def approximant(self, approximant):
    """Store one approximant per label

    When no approximant has been stored yet, the input must either be None
    (an empty dictionary is stored for every label and a warning is
    printed) or contain one entry per label. When approximants are already
    stored, the first label whose entry is still an empty dictionary is set
    to None and the search stops.

    Raises
    ------
    InputError
        if the number of approximants does not match the number of labels
    """
    if not hasattr(self, "_approximant"):
        approximant_list = {label: {} for label in self.labels}
        if approximant is None:
            logger.warning(
                "No approximant passed. Waveform plots will not be "
                "generated"
            )
        else:
            if len(approximant) != len(self.labels):
                raise InputError(
                    "Please pass an approximant for each result file"
                )
            approximant_list = dict(zip(self.labels, approximant))
        self._approximant = approximant_list
        return

    for num, label in enumerate(self._approximant.keys()):
        if not (self._approximant[label] == {}):
            continue
        # only warn when it is the very first label that is missing
        if num == 0:
            logger.warning(
                "No approximant passed. Waveform plots will not be "
                "generated"
            )
        self._approximant[label] = None
        break

287 

@property
def gracedb_server(self):
    """The GraceDB server stored by the setter"""
    return self._gracedb_server

@gracedb_server.setter
def gracedb_server(self, gracedb_server):
    """Store the GraceDB server to use; None selects the configured
    default
    """
    if gracedb_server is not None:
        logger.debug(
            "Using '{}' as the GraceDB server".format(gracedb_server)
        )
        self._gracedb_server = gracedb_server
        return
    self._gracedb_server = conf.gracedb_server

301 

@property
def gracedb(self):
    """The GraceDB ID stored by the setter (None if invalid or not given)"""
    return self._gracedb

@gracedb.setter
def gracedb(self, gracedb):
    """Store the GraceDB ID and try to download the associated data

    An ID which does not start with 'G', 'g' or 'S' (including an empty
    ID) is ignored with a warning. For a valid ID the data is requested
    from the GraceDB server; if the download fails only the ID itself is
    stored. The resulting dictionary is added to the 'meta_data' of every
    label in 'file_kwargs'.

    Parameters
    ----------
    gracedb: str
        the GraceDB ID, or None to skip everything
    """
    self._gracedb = gracedb
    if gracedb is None:
        return

    from pesummary.gw.gracedb import get_gracedb_data, HTTPError
    from json.decoder import JSONDecodeError

    # slice instead of index so that an empty ID takes the 'invalid ID'
    # path below rather than raising an IndexError
    if gracedb[:1] not in ("G", "g", "S"):
        logger.warning(
            "Invalid GraceDB ID passed. The GraceDB ID must be of the "
            "form G0000 or S0000. Ignoring input."
        )
        self._gracedb = None
        return
    _error = (
        "Unable to download data from Gracedb because {}. Only storing "
        "the GraceDB ID in the metafile"
    )
    try:
        logger.info(
            "Downloading {} from gracedb for {}".format(
                ", ".join(self.gracedb_data), gracedb
            )
        )
        gracedb_json = get_gracedb_data(
            gracedb, info=self.gracedb_data,
            service_url=self.gracedb_server
        )
        gracedb_json["id"] = gracedb
    except (HTTPError, RuntimeError, JSONDecodeError) as e:
        logger.warning(_error.format(e))
        gracedb_json = {"id": gracedb}

    for label in self.labels:
        self.file_kwargs[label]["meta_data"]["gracedb"] = gracedb_json

342 

@property
def detectors(self):
    """The detector network stored by the setter"""
    return self._detectors

@detectors.setter
def detectors(self, detectors):
    """Store the detector network

    When no detectors are given, the network for each analysis is inferred
    from the '<IFO>_optimal_snr' parameters in the samples (excluding
    'network_optimal_snr'); the sorted detector names are joined with
    underscores, or None is stored when none are found. Otherwise the
    input is stored unchanged.
    """
    if detectors:
        detector = detectors
    else:
        detector = {}
        for label in self.samples.keys():
            found = sorted(
                str(name.split("_optimal_snr")[0])
                for name in list(self.samples[label].keys())
                if "optimal_snr" in name and name != "network_optimal_snr"
            )
            detector[label] = "_".join(found) if found else None
    logger.debug("The detector network is %s" % (detector))
    self._detectors = detector

368 

@property
def skymap(self):
    """Dictionary, keyed by label, of the stored skymaps"""
    return self._skymap

@skymap.setter
def skymap(self, skymap):
    """Initialise the skymap store with None for every label. The input is
    not used, and an already existing store is left untouched
    """
    if hasattr(self, "_skymap"):
        return
    self._skymap = {label: None for label in self.labels}

377 

@property
def calibration(self):
    """Dictionary, keyed by label, of the stored calibration data"""
    return self._calibration

@calibration.setter
def calibration(self, calibration):
    """Store the calibration priors and the calibration data found in the
    result files

    Nothing is done when calibration data has already been stored. The
    calibration files passed on the command line (either through the
    'calibration' input or through per-label '<label>_calibration'
    options) are added to the prior dictionary, while the data stored for
    each label is extracted from the result files themselves.

    Parameters
    ----------
    calibration: dict
        the calibration input; an empty dictionary means that the
        per-label options are inspected instead
    """
    if hasattr(self, "_calibration"):
        return

    data = {label: {} for label in self.labels}
    if calibration != {}:
        prior_data = self.get_psd_or_calibration_data(
            calibration, self.extract_calibration_data_from_file
        )
        self.add_to_prior_dict("calibration", prior_data)
    else:
        prior_data = {label: {} for label in self.labels}
        for label in self.labels:
            option = "{}_calibration".format(label)
            if not hasattr(self.opts, option):
                continue
            cal_data = getattr(self.opts, option)
            if cal_data != {} and cal_data is not None:
                prior_data[label] = {
                    ifo: self.extract_calibration_data_from_file(
                        cal_data[ifo]
                    ) for ifo in cal_data.keys()
                }
        if all(prior_data[label] == {} for label in self.labels):
            self.add_to_prior_dict("calibration", {})
        else:
            self.add_to_prior_dict("calibration", prior_data)

    for num, result_file in enumerate(self.result_files):
        # reuse an already opened result file when there is one
        _opened = self._open_result_files
        if result_file in _opened.keys() and _opened[result_file] is not None:
            f = self._open_result_files[result_file]
        else:
            f = GWRead(result_file, disable_prior=True)
        try:
            calibration_data = f.interpolate_calibration_spline_posterior()
        except Exception as e:
            logger.warning(
                "Failed to extract calibration data from the result "
                "file: {} because {}".format(result_file, e)
            )
            calibration_data = None
        labels = list(self.samples.keys())
        if calibration_data is None:
            data[labels[num]] = {None: None}
        elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
            # a PESummary metafile provides one entry per stored analysis
            for idx in range(len(calibration_data[0])):
                data[labels[idx]] = dict(
                    zip(calibration_data[1][idx], calibration_data[0][idx])
                )
        else:
            data[labels[num]] = dict(
                zip(calibration_data[1], calibration_data[0])
            )
    self._calibration = data

440 

@property
def psd(self):
    """Dictionary, keyed by label, of the stored psd data"""
    return self._psd

@psd.setter
def psd(self, psd):
    """Store the psd data for each label

    Nothing is done when psd data has already been stored. A non-empty
    input is read through get_psd_or_calibration_data; otherwise the
    per-label '<label>_psd' options are inspected.
    """
    if hasattr(self, "_psd"):
        return

    data = {label: {} for label in self.labels}
    if psd != {}:
        data = self.get_psd_or_calibration_data(
            psd, self.extract_psd_data_from_file
        )
    else:
        for label in self.labels:
            option = "{}_psd".format(label)
            if not hasattr(self.opts, option):
                continue
            psd_data = getattr(self.opts, option)
            if psd_data != {} and psd_data is not None:
                data[label] = {
                    ifo: self.extract_psd_data_from_file(
                        psd_data[ifo], IFO=ifo
                    ) for ifo in psd_data.keys()
                }
    self._psd = data

464 

@property
def nsamples_for_skymap(self):
    """The number of samples to use when generating the skymap"""
    return self._nsamples_for_skymap

@nsamples_for_skymap.setter
def nsamples_for_skymap(self, nsamples_for_skymap):
    """Store the number of samples to use when generating the skymap

    The requested number is converted to an integer. When at least one
    analysis has fewer samples than requested, a warning is printed and
    the number is reduced to the smallest number of samples available.

    Parameters
    ----------
    nsamples_for_skymap: int
        the requested number of samples, or None for no restriction
    """
    self._nsamples_for_skymap = nsamples_for_skymap
    if nsamples_for_skymap is None:
        return

    self._nsamples_for_skymap = int(nsamples_for_skymap)
    number_of_samples = [
        data.number_of_samples for data in self.samples.values()
    ]
    # '>=' rather than '>': a file with exactly the requested number of
    # samples needs no reduction and must not trigger the warning
    if all(i >= self._nsamples_for_skymap for i in number_of_samples):
        return
    min_arg = np.argmin(number_of_samples)
    logger.warning(
        "You have specified that you would like to use {} "
        "samples to generate the skymap but the file {} only "
        "has {} samples. Reducing the number of samples to "
        "generate the skymap to {}".format(
            self._nsamples_for_skymap, self.result_files[min_arg],
            number_of_samples[min_arg], number_of_samples[min_arg]
        )
    )
    self._nsamples_for_skymap = int(number_of_samples[min_arg])

489 

@property
def gwdata(self):
    """The strain data stored by the setter"""
    return self._gwdata

@gwdata.setter
def gwdata(self, gwdata):
    """Read and store the strain data

    Parameters
    ----------
    gwdata: dict
        dictionary whose items are paths to the strain files. Any other
        non-None input is ignored with a warning

    Raises
    ------
    InputError
        if one of the strain files does not exist
    """
    from pesummary.gw.file.strain import StrainDataDict

    self._gwdata = gwdata
    if gwdata is None:
        return
    if not isinstance(gwdata, dict):
        logger.warning(
            "Please provide gwdata as a dictionary with keys "
            "displaying the channel and item giving the path to the "
            "strain file"
        )
        self._gwdata = None
        return
    # check every path before attempting to read anything
    for channel in gwdata.keys():
        if not os.path.isfile(gwdata[channel]):
            raise InputError(
                "The file {} does not exist. Please check the path "
                "to your strain file".format(gwdata[channel])
            )
    self._gwdata = StrainDataDict.read(gwdata)

515 

@property
def evolve_spins_forwards(self):
    """The value stored by the evolve_spins_forwards setter"""
    return self._evolve_spins_forwards

@evolve_spins_forwards.setter
def evolve_spins_forwards(self, evolve_spins_forwards):
    """Store the forward spin evolution setting. A truthy input is
    replaced by 6 ** -0.5 (logged as the Schwarzschild ISCO frequency);
    a falsy input is stored unchanged
    """
    if not evolve_spins_forwards:
        self._evolve_spins_forwards = evolve_spins_forwards
        return
    template = "Spins will be evolved up to {}"
    self._evolve_spins_forwards = evolve_spins_forwards
    logger.info(template.format("Schwarzschild ISCO frequency"))
    self._evolve_spins_forwards = 6. ** -0.5

527 

@property
def evolve_spins_backwards(self):
    """The value stored by the evolve_spins_backwards setter"""
    return self._evolve_spins_backwards

@evolve_spins_backwards.setter
def evolve_spins_backwards(self, evolve_spins_backwards):
    """Store the backward spin evolution method. A string is stored as
    given, None selects 'precession_averaged' and any other value is
    stored unchanged without logging
    """
    template = (
        "Spins will be evolved backwards to an infinite separation using the '{}' "
        "method"
    )
    self._evolve_spins_backwards = evolve_spins_backwards
    if evolve_spins_backwards is None:
        logger.info(template.format("precession_averaged"))
        self._evolve_spins_backwards = "precession_averaged"
    elif isinstance(evolve_spins_backwards, (str, bytes)):
        logger.info(template.format(evolve_spins_backwards))

544 

@property
def NRSur_fits(self):
    """The value stored by the NRSur_fits setter"""
    return self._NRSur_fits

@NRSur_fits.setter
def NRSur_fits(self, NRSur_fits):
    """Store the NRSurrogate model used to calculate the remnant
    quantities. A string is stored as given, None selects NRSUR_MODEL
    from pesummary.gw.conversions.nrutils and any other value is stored
    unchanged without logging
    """
    message = (
        "Using the '{}' NRSurrogate model to calculate the remnant "
        "quantities"
    )
    self._NRSur_fits = NRSur_fits
    if NRSur_fits is None:
        from pesummary.gw.conversions.nrutils import NRSUR_MODEL

        logger.info(message.format(NRSUR_MODEL))
        self._NRSur_fits = NRSUR_MODEL
    elif isinstance(NRSur_fits, (str, bytes)):
        logger.info(message.format(NRSur_fits))
        self._NRSur_fits = NRSur_fits

564 

@property
def waveform_fits(self):
    """The value stored by the waveform_fits setter"""
    return self._waveform_fits

@waveform_fits.setter
def waveform_fits(self, waveform_fits):
    """Store the waveform_fits setting; a truthy value is additionally
    logged
    """
    self._waveform_fits = waveform_fits
    if not waveform_fits:
        return
    logger.info(
        "Evaluating the remnant quantities using the provided "
        "approximant"
    )

577 

@property
def f_low(self):
    """The list of f_low values stored by the setter (or None)"""
    return self._f_low

@f_low.setter
def f_low(self, f_low):
    """Store f_low with every entry converted to a float; None is stored
    unchanged
    """
    self._f_low = f_low
    if f_low is None:
        return
    self._f_low = list(map(float, f_low))

587 

@property
def f_ref(self):
    """The list of f_ref values stored by the setter (or None)"""
    return self._f_ref

@f_ref.setter
def f_ref(self, f_ref):
    """Store f_ref with every entry converted to a float; None is stored
    unchanged
    """
    self._f_ref = f_ref
    if f_ref is None:
        return
    self._f_ref = list(map(float, f_ref))

597 

@property
def f_final(self):
    """The list of f_final values stored by the setter (or None)"""
    return self._f_final

@f_final.setter
def f_final(self, f_final):
    """Store f_final with every entry converted to a float; None is
    stored unchanged
    """
    self._f_final = f_final
    if f_final is None:
        return
    self._f_final = list(map(float, f_final))

607 

@property
def delta_f(self):
    """The list of delta_f values stored by the setter (or None)"""
    return self._delta_f

@delta_f.setter
def delta_f(self, delta_f):
    """Store delta_f with every entry converted to a float; None is
    stored unchanged
    """
    self._delta_f = delta_f
    if delta_f is None:
        return
    self._delta_f = list(map(float, delta_f))

617 

@property
def psd_default(self):
    """The default psd stored by the setter"""
    return self._psd_default

@psd_default.setter
def psd_default(self, psd_default):
    """Store the default psd

    An input containing 'stored:<label>' is translated to the template
    "{file}.psd['<label>']". Any other input is looked up by name in
    pycbc.psd: None is stored when pycbc cannot be imported (or a
    ValueError is raised), and the configured default psd is used when
    pycbc.psd has no attribute with that name.
    """
    self._psd_default = psd_default
    if "stored:" in psd_default:
        stored_label = psd_default.split("stored:")[1]
        self._psd_default = "{file}.psd['%s']" % (stored_label)
        return
    try:
        from pycbc import psd as pycbc_psd
        resolved = getattr(pycbc_psd, psd_default)
    except ImportError:
        logger.warning(
            "Unable to import 'pycbc'. Unable to generate a default PSD"
        )
        resolved = None
    except AttributeError:
        logger.warning(
            "'pycbc' does not have the '{}' psd available. Using '{}' as "
            "the default PSD".format(psd_default, conf.psd)
        )
        resolved = getattr(pycbc_psd, conf.psd)
    except ValueError as e:
        logger.warning("Setting 'psd_default' to None because {}".format(e))
        resolved = None
    self._psd_default = resolved

647 

@property
def pepredicates_probs(self):
    """Dictionary, keyed by label, of the stored classifications"""
    return self._pepredicates_probs

@pepredicates_probs.setter
def pepredicates_probs(self, pepredicates_probs):
    """Compute and store the PEPredicates classification for every
    analysis. The input is not used.

    An analysis for which the classification fails is stored as None.
    When mcmc samples are in use, the entry of the first label is replaced
    by the average over all chains (rounded to 3 decimal places), or by
    None when at least one chain failed.
    """
    from pesummary.gw.classification import PEPredicates

    classifications = {}
    for label in list(self.samples.keys()):
        try:
            result = PEPredicates(self.samples[label]).dual_classification()
            classifications[label] = result
        except Exception as e:
            logger.warning(
                "Failed to generate source classification probabilities "
                "because {}".format(e)
            )
            classifications[label] = None

    if self.mcmc_samples:
        if any(_probs is None for _probs in classifications.values()):
            classifications[self.labels[0]] = None
            logger.warning(
                "Unable to average classification probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "classifications"
            )
        else:
            logger.debug(
                "Averaging classification probability across mcmc samples"
            )
            # the first chain defines which priors/keys are averaged
            template = list(classifications.values())[0]
            averaged = {}
            for prior, _probs in template.items():
                averaged[prior] = {
                    key: np.round(np.mean(
                        [val[prior][key] for val in classifications.values()]
                    ), 3) for key in _probs.keys()
                }
            classifications[self.labels[0]] = averaged
    self._pepredicates_probs = classifications

689 

@property
def pastro_probs(self):
    """Dictionary, keyed by label, of the stored probabilities"""
    return self._pastro_probs

@pastro_probs.setter
def pastro_probs(self, pastro_probs):
    """Compute and store the PAstro classification for every analysis.
    The input is not used.

    An analysis for which the classification fails is stored as None.
    When mcmc samples are in use, the entry of the first label is replaced
    by the average over all chains (rounded to 3 decimal places), or by
    None when at least one chain failed.
    """
    from pesummary.gw.classification import PAstro

    probabilities = {}
    for label in list(self.samples.keys()):
        try:
            result = PAstro(self.samples[label]).dual_classification()
            probabilities[label] = result
        except Exception as e:
            logger.warning(
                "Failed to generate em_bright probabilities because "
                "{}".format(e)
            )
            probabilities[label] = None

    if self.mcmc_samples:
        if any(_probs is None for _probs in probabilities.values()):
            probabilities[self.labels[0]] = None
            logger.warning(
                "Unable to average em_bright probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "probabilities"
            )
        else:
            logger.debug(
                "Averaging em_bright probability across mcmc samples"
            )
            # the first chain defines which priors/keys are averaged
            template = list(probabilities.values())[0]
            averaged = {}
            for prior, _probs in template.items():
                averaged[prior] = {
                    key: np.round(np.mean(
                        [val[prior][key] for val in probabilities.values()]
                    ), 3) for key in _probs.keys()
                }
            probabilities[self.labels[0]] = averaged
    self._pastro_probs = probabilities

728 

@property
def preliminary_pages(self):
    """Dictionary, keyed by label, flagging the preliminary analyses"""
    return self._preliminary_pages

@preliminary_pages.setter
def preliminary_pages(self, preliminary_pages):
    """Flag every analysis for which one of the attributes listed in
    conf.gw_reproducibility is missing. The input is not used.

    A dictionary attribute counts as missing when the label is absent or
    its entry is empty; a list attribute when its entry for the label is
    None. A warning listing the flagged labels is printed if there are any.
    """
    required = conf.gw_reproducibility
    self._preliminary_pages = {label: False for label in self.labels}
    for num, label in enumerate(self.labels):
        for attr in required:
            stored = getattr(self, attr)
            if isinstance(stored, dict):
                if label not in stored.keys() or not len(stored[label]):
                    self._preliminary_pages[label] = True
            elif isinstance(stored, list) and stored[num] is None:
                self._preliminary_pages[label] = True

    flagged = [
        label for label, value in self._preliminary_pages.items() if value
    ]
    if not flagged:
        return
    msg = (
        "Unable to reproduce the {} analys{} because no {} data was "
        "provided. 'Preliminary' watermarks will be added to the final "
        "html pages".format(
            ", ".join(flagged), "es" if len(flagged) > 1 else "is",
            " or ".join(required)
        )
    )
    logger.warning(msg)

762 

@staticmethod
def _extract_IFO_data_from_file(file, cls, desc, IFO=None):
    """Read the IFO data stored in a file, returning an empty dictionary
    (with a warning) if the file does not exist or a ValueError is raised

    Parameters
    ----------
    file: path
        path to a file containing the IFO data
    cls: obj
        class you wish to use when loading the file. This class must have
        a '.read' method
    desc: str
        description of the IFO data stored in the file
    IFO: str, optional
        the IFO which the data belongs to
    """
    general = (
        "Failed to read in %s data because {}. The %s plot will not be "
        "generated and the %s data will not be added to the metafile."
    ) % (desc, desc, desc)
    try:
        return cls.read(file, IFO=IFO)
    except FileNotFoundError:
        reason = "the file {} does not exist".format(file)
    except ValueError as e:
        reason = e
    logger.warning(general.format(reason))
    return {}

793 

@staticmethod
def extract_psd_data_from_file(file, IFO=None):
    """Read the data stored in a psd file with the PSD class

    Parameters
    ----------
    file: path
        path to a file containing the psd data
    IFO: str, optional
        the IFO which the psd belongs to
    """
    from pesummary.gw.file.psd import PSD as psd_class

    reader = _GWInput._extract_IFO_data_from_file
    return reader(file, psd_class, "PSD", IFO=IFO)

805 

@staticmethod
def extract_calibration_data_from_file(file, **kwargs):
    """Read the data stored in a calibration file with the Calibration
    class

    Parameters
    ----------
    file: path
        path to a file containing the calibration data
    **kwargs: dict, optional
        all kwargs are passed to _extract_IFO_data_from_file
    """
    from pesummary.gw.file.calibration import Calibration as cal_class

    reader = _GWInput._extract_IFO_data_from_file
    return reader(file, cal_class, "calibration", **kwargs)

819 

820 @staticmethod 

821 def get_ifo_from_file_name(file): 

822 """Return the IFO from the file name 

823 

824 Parameters 

825 ---------- 

826 file: str 

827 path to the file 

828 """ 

829 file_name = file.split("/")[-1] 

830 if any(j in file_name for j in ["H1", "_0", "IFO0"]): 

831 ifo = "H1" 

832 elif any(j in file_name for j in ["L1", "_1", "IFO1"]): 

833 ifo = "L1" 

834 elif any(j in file_name for j in ["V1", "_2", "IFO2"]): 

835 ifo = "V1" 

836 else: 

837 ifo = file_name 

838 return ifo 

839 

def get_psd_or_calibration_data(self, input, executable):
    """Build a dictionary, keyed by label, containing the psd or
    calibration data for each IFO

    Parameters
    ----------
    input: list/dict
        list/dict containing paths to calibration/psd files. A dictionary
        of lists must provide one file per label for every IFO; a
        dictionary of paths or a list of paths is applied to every label
    executable: func
        executable that is used to extract the data from the calibration/psd
        files. It is called with the path and an 'IFO' keyword argument

    Raises
    ------
    InputError
        if the number of files does not match the number of labels, or if
        the input is neither a list nor a dictionary
    """
    data = {}
    if input == {} or input == []:
        return data

    if isinstance(input, dict):
        ifos = list(input.keys())
        if isinstance(input[ifos[0]], list):
            # one file per label for each IFO
            if not all(len(input[ifo]) == len(self.labels) for ifo in ifos):
                raise InputError(
                    "Please ensure the number of calibration/psd files matches "
                    "the number of result files passed"
                )
            for idx in range(len(input[ifos[0]])):
                data[self.labels[idx]] = {
                    ifo: executable(input[ifo][idx], IFO=ifo) for ifo in ifos
                }
        else:
            # a single file per IFO, shared by every label
            for label in self.labels:
                data[label] = {
                    ifo: executable(input[ifo], IFO=ifo)
                    for ifo in list(input.keys())
                }
    elif isinstance(input, list):
        # the IFO has to be guessed from each file name
        for label in self.labels:
            data[label] = {
                self.get_ifo_from_file_name(path): executable(
                    path, IFO=self.get_ifo_from_file_name(path)
                ) for path in input
            }
    else:
        raise InputError(
            "Did not understand the psd/calibration input. Please use the "
            "following format 'H1:path/to/file'"
        )
    return data

884 

def grab_priors_from_inputs(self, priors):
    """Read the prior files through the parent class, using the GW read
    function and generating all posterior samples for each file

    Parameters
    ----------
    priors:
        the prior input; passed unchanged to the parent class
    """
    def read_func(data, **kwargs):
        from pesummary.gw.file.read import read as gw_read

        opened = gw_read(data, **kwargs)
        opened.generate_all_posterior_samples()
        return opened

    parent = super(_GWInput, self)
    return parent.grab_priors_from_inputs(
        priors, read_func=read_func, read_kwargs=self.grab_data_kwargs
    )

895 

def grab_key_data_from_result_files(self):
    """Grab the key data returned by the parent class for each result file
    and add a '90% HPD' entry for every parameter

    For the bounded parameters ('mass_ratio', 'a_1', 'a_2' and
    'lambda_tilde') the interval is computed from a bounded KDE evaluated
    on a common grid of 1000 points. All other parameters of an analysis
    which contains the bounded parameter are assigned nan.
    """
    from pesummary.utils.kde_list import KDEList
    from pesummary.gw.plots.plot import _return_bounds
    from pesummary.utils.credible_interval import (
        hpd_two_sided_credible_interval
    )
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    key_data = super(_GWInput, self).grab_key_data_from_result_files()
    bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
    for param in bounded_parameters:
        xlow, xhigh = _return_bounds(param, [])
        # only the analyses which contain this parameter
        available = {
            label: samples[param]
            for label, samples in self.samples.items()
            if param in samples.keys()
        }
        minima = [np.min(values) for values in available.values()]
        maxima = [np.max(values) for values in available.values()]
        if not len(minima):
            continue
        x = np.linspace(np.min(minima), np.max(maxima), 1000)
        try:
            kdes = KDEList(
                list(available.values()), kde=bounded_1d_kde,
                kde_kwargs={"xlow": xlow, "xhigh": xhigh}
            )
        except Exception as e:
            logger.warning(
                "Unable to compute the HPD interval for {} because {}".format(
                    param, e
                )
            )
            continue
        pdfs = kdes(x)
        for num, label in enumerate(available.keys()):
            [lower, upper], _ = hpd_two_sided_credible_interval(
                [], 90, x=x, pdf=pdfs[num]
            )
            key_data[label][param]["90% HPD"] = [lower, upper]
            for other in self.samples[label].keys():
                if other not in bounded_parameters:
                    key_data[label][other]["90% HPD"] = float("nan")
    return key_data

944 

945 

class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class which handles and stores the sample specific command line
    arguments for GW analyses
    """
    def __init__(self, *args, **kwargs):
        kwargs["ignore_copy"] = True
        gw_options = [
            "evolve_spins_forwards",
            "evolve_spins_backwards",
            "NRSur_fits",
            "calculate_multipole_snr",
            "calculate_precessing_snr",
            "f_low",
            "f_ref",
            "f_final",
            "psd",
            "waveform_fits",
            "redshift_method",
            "cosmology",
            "no_conversion",
            "delta_f",
            "psd_default",
            "disable_remnant",
            "force_BBH_remnant_computation",
            "force_BH_spin_evolution"
        ]
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=gw_options, **kwargs
        )
        # nothing more to set up when restarting from a checkpoint
        if self._restarted_from_checkpoint:
            return
        if self.existing is None:
            self.existing_approximant = None
            self.existing_psd = None
            self.existing_calibration = None
            self.existing_skymap = None
        else:
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_approximant = self.existing_data["approximant"]
            self.existing_psd = self.existing_data["psd"]
            self.existing_calibration = self.existing_data["calibration"]
            self.existing_skymap = self.existing_data["skymap"]
        self.approximant = self.opts.approximant
        self.detectors = None
        self.skymap = None
        self.calibration = self.opts.calibration
        self.gwdata = self.opts.gwdata
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        """Dictionary, keyed by label, of the maximum likelihood samples"""
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        """Collect the 'maxL' entry of the key data for every parameter of
        every analysis and add the approximant of each label. The input is
        not used
        """
        key_data = self.grab_key_data_from_result_files()
        collected = {}
        for label in key_data.keys():
            collected[label] = {
                param: key_data[label][param]["maxL"]
                for param in key_data[label].keys()
            }
        for label in self.labels:
            collected[label]["approximant"] = self.approximant[label]
        self._maxL_samples = collected

1011 

1012 

class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Handle and store the plotting specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        """Initialise the parent classes and store the plotting specific
        attributes

        Parameters
        ----------
        *args: tuple
            positional arguments forwarded to the parent initialiser
        **kwargs: dict
            keyword arguments forwarded to the parent initialiser
        """
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.nsamples_for_skymap = self.opts.nsamples_for_skymap
        self.sensitivity = self.opts.sensitivity
        self.no_ligo_skymap = self.opts.no_ligo_skymap
        self.multi_threading_for_skymap = self.multi_process
        if self.multi_process > 1 and not self.no_ligo_skymap:
            # share the available processes between the skymap and the
            # remaining plots; the skymap gets the larger half
            available = self.multi_process
            self.multi_threading_for_plots = int(available / 2.)
            self.multi_threading_for_skymap = (
                available - self.multi_threading_for_plots
            )
            message = (
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots"
            )
            skymap_suffix = (
                "es " if self.multi_threading_for_skymap > 1 else " "
            )
            plots_suffix = (
                "es " if self.multi_threading_for_plots > 1 else " "
            )
            logger.info(
                message.format(
                    self.multi_threading_for_skymap, skymap_suffix,
                    self.multi_threading_for_plots, plots_suffix
                )
            )
        self.preliminary_pages = None
        self.pepredicates_probs = []
        self.pastro_probs = []

1038 

1039 

class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Handle and store the webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        """Initialise the parent classes and store the webpage specific
        attributes

        Parameters
        ----------
        *args: tuple
            positional arguments forwarded to the parent initialiser
        **kwargs: dict
            keyword arguments forwarded to the parent initialiser
        """
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.gracedb_server = self.opts.gracedb_server
        self.gracedb_data = self.opts.gracedb_data
        self.gracedb = self.opts.gracedb
        self.public = self.opts.public
        # only fill in a default when the attribute has not been set already
        fallbacks = (
            ("preliminary_pages", None),
            ("pepredicates_probs", []),
            ("pastro_probs", []),
        )
        for attribute, default in fallbacks:
            if not hasattr(self, attribute):
                setattr(self, attribute, default)

1055 

1056 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Handle and store both the webpage and the plotting specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initialisers
        """
        parent = super(WebpagePlusPlottingInput, self)
        parent.__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Return the default directories reported by the parent classes
        """
        parent = super(WebpagePlusPlottingInput, self)
        return parent.default_directories

    @property
    def default_files_to_copy(self):
        """Return the default files to copy reported by the parent classes
        """
        parent = super(WebpagePlusPlottingInput, self)
        return parent.default_files_to_copy

1071 

1072 

class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Handle and store the metafile specific command line arguments
    """
    @property
    def default_directories(self):
        """Return the parent's default directories with the 'psds' and
        'calibration' directories appended
        """
        directories = super(MetaFileInput, self).default_directories
        directories += ["psds", "calibration"]
        return directories

    def copy_files(self):
        """Save the PSDs and calibration envelopes for each label to the
        'psds' and 'calibration' directories inside the web directory, then
        run the parent's `copy_files`

        Entries which are not `PSD` / `Calibration` instances are skipped
        with a warning.
        """
        failure = "Failed to save the {} to file"
        for label in self.labels:
            label_psds = self.psd[label]
            if label_psds != {}:
                for ifo, psd in label_psds.items():
                    if isinstance(psd, PSD):
                        filename = "{}_{}_psd.dat".format(label, ifo)
                        psd.save_to_file(
                            os.path.join(self.webdir, "psds", filename)
                        )
                    else:
                        logger.warning(
                            failure.format("{} PSD".format(ifo))
                        )
            envelopes = self.priors["calibration"]
            if label in envelopes.keys() and envelopes[label] != {}:
                for ifo, envelope in envelopes[label].items():
                    if isinstance(envelope, Calibration):
                        filename = "{}_{}_cal.txt".format(label, ifo)
                        envelope.save_to_file(
                            os.path.join(self.webdir, "calibration", filename)
                        )
                    else:
                        logger.warning(
                            failure.format(
                                "{} calibration envelope".format(ifo)
                            )
                        )
        return super(MetaFileInput, self).copy_files()

1116 

1117 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Handle and store the webpage, plotting and metafile specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initialisers
        """
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        parent.__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Return the default directories reported by the parent classes
        """
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        return parent.default_directories

    @property
    def default_files_to_copy(self):
        """Return the default files to copy reported by the parent classes
        """
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        return parent.default_files_to_copy

1134 

1135 

@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias of WebpagePlusPlottingPlusMetaFileInput. It adds no
    behaviour of its own
    """

1143 

1144 

1145class IMRCTInput(pesummary.core.cli.inputs._Input): 

1146 """Class to handle the TGR specific command line arguments 

1147 """ 

@property
def labels(self):
    """Return the stored analysis labels
    """
    return self._labels

@labels.setter
def labels(self, labels):
    """Validate and store the analysis labels

    Parameters
    ----------
    labels: list
        list of labels. Either exactly 'inspiral' and 'postinspiral', or
        labels of the form '{}:inspiral' and '{}:postinspiral'

    Raises
    ------
    ValueError
        if an odd number of labels is provided, if the labels do not follow
        the expected form, or if more than 2 PESummary metafiles are given
    """
    self._labels = labels
    n_labels = len(labels)
    if n_labels % 2 != 0:
        raise ValueError(
            "The IMRCT test requires 2 results files for each analysis. "
        )

    # more than one analysis: every label must carry a suffix
    if n_labels > 2:
        suffixed = all(
            ":inspiral" in label or ":postinspiral" in label
            for label in labels
        )
        if not suffixed:
            raise ValueError(
                "To compare 2 or more analyses, please provide labels as "
                "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                "the analysis label"
            )
        self.analysis_label = [
            label.split(":inspiral")[0] for label in labels
            if ":inspiral" in label and ":postinspiral" not in label
        ]
        if len(self.analysis_label) != len(self.result_files) / 2:
            raise ValueError(
                "When comparing more than 2 analyses, labels must "
                "be of the form '{}:inspiral' and '{}:postinspiral'."
            )
        logger.info(
            "Using the labels: {} to distinguish analyses".format(
                ", ".join(self.analysis_label)
            )
        )
        return

    # a single analysis labelled with the exact names
    if sorted(labels) == ["inspiral", "postinspiral"]:
        self.analysis_label = ["primary"]
        return

    # otherwise the labels must point to analyses stored in metafiles
    if not all(self.is_pesummary_metafile(ff) for ff in self.result_files):
        raise ValueError(
            "The IMRCT test requires an inspiral and postinspiral result "
            "file. Please indicate which file is the inspiral and which "
            "is postinspiral by providing these exact labels to the "
            "summarytgr executable"
        )
    stored_labels = []
    for suffix in (":inspiral", ":postinspiral"):
        matches = [label for label in labels if suffix in label]
        if not matches:
            raise ValueError(
                "Please provide labels as {inspiral_label}:inspiral "
                "and {postinspiral_label}:postinspiral where "
                "inspiral_label and postinspiral_label are the "
                "PESummary labels for the inspiral and postinspiral "
                "analyses respectively. "
            )
        if len(matches) > 1:
            raise ValueError(
                "Please provide a single {} label".format(
                    suffix.split(":")[1]
                )
            )
        stored_labels.append(matches[0].split(suffix)[0])
    n_files = len(self.result_files)
    if n_files == 1:
        logger.info(
            "Using the {} samples for the inspiral analysis and {} "
            "samples for the postinspiral analysis from the file "
            "{}".format(
                stored_labels[0], stored_labels[1], self.result_files[0]
            )
        )
    elif n_files == 2:
        logger.info(
            "Using the {} samples for the inspiral analysis from "
            "the file {}. Using the {} samples for the "
            "postinspiral analysis from the file {}".format(
                stored_labels[0], self.result_files[0],
                stored_labels[1], self.result_files[1]
            )
        )
    else:
        raise ValueError(
            "Currently, you can only provide at most 2 pesummary "
            "metafiles. If one is provided, both the inspiral and "
            "postinspiral are extracted from that single file. If "
            "two are provided, the inspiral is extracted from one "
            "file and the postinspiral is extracted from the other."
        )
    self._labels = ["inspiral", "postinspiral"]
    self._meta_file_labels = stored_labels
    self.analysis_label = ["primary"]

1250 

1251 def _extract_stored_approximant(self, opened_file, label): 

1252 """Extract the approximant used for a given analysis stored in a 

1253 PESummary metafile 

1254 

1255 Parameters 

1256 ---------- 

1257 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1258 opened metafile that contains the analysis 'label' 

1259 label: str 

1260 analysis label which is stored in the PESummary metafile 

1261 """ 

1262 if opened_file.approximant is not None: 

1263 if label not in opened_file.labels: 

1264 raise ValueError( 

1265 "Invalid label {}. The list of available labels are {}".format( 

1266 label, ", ".join(opened_file.labels) 

1267 ) 

1268 ) 

1269 _index = opened_file.labels.index(label) 

1270 return opened_file.approximant[_index] 

1271 return 

1272 

1273 def _extract_stored_remnant_fits(self, opened_file, label): 

1274 """Extract the remnant fits used for a given analysis stored in a 

1275 PESummary metafile 

1276 

1277 Parameters 

1278 ---------- 

1279 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1280 opened metafile that contains the analysis 'label' 

1281 label: str 

1282 analysis label which is stored in the PESummary metafile 

1283 """ 

1284 fits = {} 

1285 fit_strings = [ 

1286 "final_spin_NR_fits", "final_mass_NR_fits" 

1287 ] 

1288 if label not in opened_file.labels: 

1289 raise ValueError( 

1290 "Invalid label {}. The list of available labels are {}".format( 

1291 label, ", ".join(opened_file.labels) 

1292 ) 

1293 ) 

1294 _index = opened_file.labels.index(label) 

1295 _meta_data = opened_file.extra_kwargs[_index] 

1296 if "meta_data" in _meta_data.keys(): 

1297 for key in fit_strings: 

1298 if key in _meta_data["meta_data"].keys(): 

1299 fits[key] = _meta_data["meta_data"][key] 

1300 if len(fits): 

1301 return fits 

1302 return 

1303 

1304 def _extract_stored_cutoff_frequency(self, opened_file, label): 

1305 """Extract the cutoff frequencies used for a given analysis stored in a 

1306 PESummary metafile 

1307 

1308 Parameters 

1309 ---------- 

1310 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1311 opened metafile that contains the analysis 'label' 

1312 label: str 

1313 analysis label which is stored in the PESummary metafile 

1314 """ 

1315 frequencies = {} 

1316 if opened_file.config is not None: 

1317 if label not in opened_file.labels: 

1318 raise ValueError( 

1319 "Invalid label {}. The list of available labels are {}".format( 

1320 label, ", ".join(opened_file.labels) 

1321 ) 

1322 ) 

1323 if opened_file.config[label] is not None: 

1324 _config = opened_file.config[label] 

1325 if "config" in _config.keys(): 

1326 if "maximum-frequency" in _config["config"].keys(): 

1327 frequencies["fhigh"] = _config["config"][ 

1328 "maximum-frequency" 

1329 ] 

1330 if "minimum-frequency" in _config["config"].keys(): 

1331 frequencies["flow"] = _config["config"][ 

1332 "minimum-frequency" 

1333 ] 

1334 elif "lalinference" in _config.keys(): 

1335 if "fhigh" in _config["lalinference"].keys(): 

1336 frequencies["fhigh"] = _config["lalinference"][ 

1337 "fhigh" 

1338 ] 

1339 if "flow" in _config["lalinference"].keys(): 

1340 frequencies["flow"] = _config["lalinference"][ 

1341 "flow" 

1342 ] 

1343 return frequencies 

1344 return 

1345 

@property
def samples(self):
    """Return the stored samples for each label
    """
    return self._samples

@samples.setter
def samples(self, samples):
    """Read the result files and store the samples for each label, together
    with any approximant, cutoff frequency and remnant fits found in the
    files

    Parameters
    ----------
    samples: object
        unused. The files listed in `self.result_files` are read instead

    Raises
    ------
    ValueError
        if a file contains multiple analyses but no metafile labels were
        identified when the labels were set
    """
    from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
    self._read_samples = {
        _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
            self.labels, self.result_files
        )
    }
    _samples_dict = {}
    _approximant_dict = {}
    _cutoff_frequency_dict = {}
    _remnant_fits_dict = {}
    # '_meta_file_labels' is only assigned by one branch of the labels
    # setter, so fall back to an empty list to reach the ValueError below
    # instead of failing with an AttributeError
    meta_file_labels = getattr(self, "_meta_file_labels", None) or []
    # which stored frequency is the cutoff for each part of the test
    _cutoff_names = {"inspiral": "fhigh", "postinspiral": "flow"}
    for label, _open in self._read_samples.items():
        if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
            if not len(meta_file_labels):
                raise ValueError(
                    "Currently you can only pass a file containing a "
                    "single analysis or a valid PESummary metafile "
                    "containing multiple analyses"
                )
            if len(self._read_samples) == 1:
                # a single metafile provides every analysis
                pairs = list(zip(self.labels, meta_file_labels))
            else:
                pairs = [
                    (label, meta_file_labels[self.labels.index(label)])
                ]
            for _label, _stored_label in pairs:
                _samples_dict[_label] = _open.samples_dict[_stored_label]
                _stored_approx = self._extract_stored_approximant(
                    _open, _stored_label
                )
                _stored_frequencies = self._extract_stored_cutoff_frequency(
                    _open, _stored_label
                )
                _stored_remnant_fits = self._extract_stored_remnant_fits(
                    _open, _stored_label
                )
                if _stored_approx is not None:
                    _approximant_dict[_label] = _stored_approx
                if _stored_remnant_fits is not None:
                    _remnant_fits_dict[_label] = _stored_remnant_fits
                if _stored_frequencies is not None:
                    _name = _cutoff_names.get(_label)
                    if _name is not None and _name in _stored_frequencies.keys():
                        _cutoff_frequency_dict[_label] = (
                            _stored_frequencies[_name]
                        )
        else:
            _samples_dict[label] = _open.samples_dict
            extra_kwargs = _open.extra_kwargs
            if "pe_algorithm" in extra_kwargs["sampler"].keys():
                if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
                    # best effort: leave the entries unset when the file does
                    # not store the waveform arguments
                    try:
                        subkwargs = extra_kwargs["other"]["likelihood"][
                            "waveform_arguments"
                        ]
                        _approximant_dict[label] = (
                            subkwargs["waveform_approximant"]
                        )
                        # 'inspiral' is a substring of 'postinspiral', hence
                        # the explicit exclusion
                        if "inspiral" in label and "postinspiral" not in label:
                            _cutoff_frequency_dict[label] = (
                                subkwargs["maximum_frequency"]
                            )
                        elif "postinspiral" in label:
                            _cutoff_frequency_dict[label] = (
                                subkwargs["minimum_frequency"]
                            )
                    except KeyError:
                        pass
    self._samples = MultiAnalysisSamplesDict(_samples_dict)
    # the private dictionaries are only created when something was found
    if len(_approximant_dict):
        self._approximant_dict = _approximant_dict
    if len(_cutoff_frequency_dict):
        self._cutoff_frequency_dict = _cutoff_frequency_dict
    if len(_remnant_fits_dict):
        self._remnant_fits_dict = _remnant_fits_dict

1462 

@property
def imrct_kwargs(self):
    """Return the stored keyword arguments for the IMRCT test
    """
    return self._imrct_kwargs

@imrct_kwargs.setter
def imrct_kwargs(self, imrct_kwargs):
    """Merge the provided keyword arguments with the defaults and store
    them, converting string values to Python literals where possible

    Parameters
    ----------
    imrct_kwargs: dict
        keyword arguments for the IMRCT test. Anything which cannot be
        used to update a dictionary (e.g. None) leaves the defaults
        untouched
    """
    test_kwargs = dict(N_bins=101)
    try:
        test_kwargs.update(imrct_kwargs)
    except (AttributeError, TypeError):
        # dict.update raises a TypeError for non-iterables such as None
        pass

    for key, value in test_kwargs.items():
        try:
            test_kwargs[key] = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError):
            # not a Python literal (e.g. free-form text, or a value which is
            # not a string): keep the value as it was provided
            pass
    self._imrct_kwargs = test_kwargs

1481 

@property
def meta_data(self):
    """Return the stored meta data for each analysis
    """
    return self._meta_data

@meta_data.setter
def meta_data(self, meta_data):
    """Build and store the meta data for each analysis

    For the cutoff frequency and the approximant, the values provided by
    the user are used when one is given per label. Otherwise the values
    collected from the result files are used, and None is stored for
    anything which could not be found. The remnant fits are always taken
    from the values collected from the result files.

    Parameters
    ----------
    meta_data: object
        unused. The meta data is always rebuilt from the instance state

    Raises
    ------
    ValueError
        if a cutoff frequency or approximant is provided but not for each
        label
    """
    self._meta_data = {}
    for num, _inspiral_string in enumerate(self.inspiral_keys):
        _postinspiral_string = self.postinspiral_keys[num]
        # (name, values provided by the user, attribute holding the values
        # collected from the result files)
        sources = (
            ("cutoff_frequency", self.cutoff_frequency, "_cutoff_frequency_dict"),
            ("approximant", self.approximant, "_approximant_dict"),
            ("remnant_fits", None, "_remnant_fits_dict"),
        )
        collected = {}
        for name, _list, stored_attr in sources:
            _dict = {}
            if _list is not None and len(_list) == len(self.labels):
                inspiral_ind = self.labels.index(_inspiral_string)
                postinspiral_ind = self.labels.index(_postinspiral_string)
                _dict["inspiral"] = _list[inspiral_ind]
                _dict["postinspiral"] = _list[postinspiral_ind]
            elif _list is not None:
                raise ValueError(
                    "Please provide a 'cutoff_frequency' and 'approximant' "
                    "for each file"
                )
            else:
                # the attribute only exists when something was collected
                # from the result files; it is keyed by the analysis labels.
                # Both entries are always populated so that the dictionary
                # built below cannot fail on a missing key
                stored = getattr(self, stored_attr, None)
                if not isinstance(stored, dict):
                    stored = {}
                _dict["inspiral"] = stored.get(_inspiral_string)
                _dict["postinspiral"] = stored.get(_postinspiral_string)
            collected[name] = _dict

        frequency_dict = collected["cutoff_frequency"]
        approximant_dict = collected["approximant"]
        remnant_dict = collected["remnant_fits"]
        self._meta_data[self.analysis_label[num]] = {
            "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
            "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
            "inspiral approximant": approximant_dict["inspiral"],
            "postinspiral approximant": approximant_dict["postinspiral"],
            "inspiral remnant fits": remnant_dict["inspiral"],
            "postinspiral remnant fits": remnant_dict["postinspiral"]
        }

1567 

def __init__(self, opts):
    """Store the command line options needed for the IMRCT test and create
    the output directories

    Parameters
    ----------
    opts: object
        parsed command line arguments. The attributes 'webdir', 'samples',
        'labels', 'cutoff_frequency', 'approximant', 'links_to_pe_pages'
        and 'f_low' are read; 'imrct_kwargs' is read if it is available

    Raises
    ------
    ValueError
        if 'cutoff_frequency', 'approximant', 'links_to_pe_pages' or
        'f_low' is non-empty but does not have one entry per label
    """
    self.opts = opts
    self.existing = None
    self.webdir = self.opts.webdir
    self.user = None
    self.baseurl = None
    self.result_files = self.opts.samples
    self.labels = self.opts.labels
    self.samples = self.opts.samples
    # 'inspiral' is a substring of 'postinspiral', hence the exclusion
    inspiral_only = []
    for key in self.samples.keys():
        if "inspiral" in key and "postinspiral" not in key:
            inspiral_only.append(key)
    self.inspiral_keys = inspiral_only
    self.postinspiral_keys = [
        key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
    ]
    try:
        self.imrct_kwargs = self.opts.imrct_kwargs
    except AttributeError:
        self.imrct_kwargs = {}
    per_file_options = (
        "cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"
    )
    for option in per_file_options:
        value = getattr(self.opts, option)
        # None and empty values are accepted as they are
        mismatch = (
            value is not None and len(value)
            and len(value) != len(self.labels)
        )
        if mismatch:
            raise ValueError(
                "Please provide a {} for each file".format(option)
            )
        setattr(self, option, value)
    self.meta_data = None
    self.default_directories = ["samples", "plots", "js", "html", "css"]
    self.publication = False
    self.make_directories()