Coverage for pesummary/gw/cli/inputs.py: 78.7%

894 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import ast 

4import os 

5import numpy as np 

6import pesummary.core.cli.inputs 

7from pesummary.gw.file.read import read as GWRead 

8from pesummary.gw.file.psd import PSD 

9from pesummary.gw.file.calibration import Calibration 

10from pesummary.utils.decorators import deprecation 

11from pesummary.utils.exceptions import InputError 

12from pesummary.utils.utils import logger 

13from pesummary import conf 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17 

class _GWInput(pesummary.core.cli.inputs._Input):
    """Super class to handle gw specific command line inputs

    Extends the core ``_Input`` class with gravitational-wave specific
    handling: PSDs, calibration envelopes, approximants, skymaps, GraceDB
    downloads and source-classification probabilities.
    """

    @staticmethod
    def grab_data_from_metafile(
        existing_file, webdir, compare=None, nsamples=None, **kwargs
    ):
        """Grab data from an existing PESummary metafile

        Parameters
        ----------
        existing_file: str
            path to the existing metafile
        webdir: str
            the directory to store the existing configuration file
        compare: list, optional
            list of labels for events stored in an existing metafile that you
            wish to compare
        nsamples: int, optional
            number of samples to draw from each analysis -- presumably
            forwarded unchanged to the core implementation; confirm there
        """
        # when kwargs are regenerated, point the 'psd' kwarg back at the
        # PSDs already stored inside the metafile rather than at a path
        _replace_kwargs = {
            "psd": "{file}.psd['{label}']"
        }
        if "psd_default" in kwargs.keys():
            _replace_kwargs["psd_default"] = kwargs["psd_default"]
        # delegate the generic loading to the core class, but read with the
        # gw-specific read function
        data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
            existing_file, webdir, compare=compare, read_function=GWRead,
            nsamples=nsamples, _replace_with_pesummary_kwargs=_replace_kwargs,
            **kwargs
        )
        # re-open the metafile to pull out the gw-specific attributes
        f = GWRead(existing_file)

        labels = data["labels"]

        # per-label PSDs, each IFO entry wrapped in a PSD object
        psd = {i: {} for i in labels}
        if f.psd is not None and f.psd != {}:
            for i in labels:
                if i in f.psd.keys() and f.psd[i] != {}:
                    psd[i] = {
                        ifo: PSD(f.psd[i][ifo]) for ifo in f.psd[i].keys()
                    }
        # per-label calibration envelopes, wrapped in Calibration objects
        calibration = {i: {} for i in labels}
        if f.calibration is not None and f.calibration != {}:
            for i in labels:
                if i in f.calibration.keys() and f.calibration[i] != {}:
                    calibration[i] = {
                        ifo: Calibration(f.calibration[i][ifo]) for ifo in
                        f.calibration[i].keys()
                    }
        # skymaps are optional: older metafiles may not carry the attribute
        skymap = {i: None for i in labels}
        if hasattr(f, "skymap") and f.skymap is not None and f.skymap != {}:
            for i in labels:
                if i in f.skymap.keys() and len(f.skymap[i]):
                    skymap[i] = f.skymap[i]
        data.update(
            {
                # 'indicies' maps the (possibly compared/subset) labels back
                # to their positions in the metafile
                "approximant": {
                    i: j for i, j in zip(
                        labels, [f.approximant[ind] for ind in data["indicies"]]
                    )
                },
                "psd": psd,
                "calibration": calibration,
                "skymap": skymap
            }
        )
        return data

84 

85 @property 

86 def grab_data_kwargs(self): 

87 kwargs = super(_GWInput, self).grab_data_kwargs 

88 for _property in ["f_start", "f_low", "f_ref", "f_final", "delta_f"]: 

89 if getattr(self, _property) is None: 

90 setattr(self, "_{}".format(_property), [None] * len(self.labels)) 

91 elif len(getattr(self, _property)) == 1 and len(self.labels) != 1: 

92 setattr( 

93 self, "_{}".format(_property), 

94 getattr(self, _property) * len(self.labels) 

95 ) 

96 if self.opts.approximant is None: 

97 approx = [None] * len(self.labels) 

98 else: 

99 approx = self.opts.approximant 

100 self.approximant_flags = self.opts.approximant_flags 

101 resume_file = [ 

102 os.path.join( 

103 self.webdir, "checkpoint", 

104 "{}_conversion_class.pickle".format(label) 

105 ) for label in self.labels 

106 ] 

107 

108 try: 

109 for num, label in enumerate(self.labels): 

110 try: 

111 psd = self.psd[label] 

112 except KeyError: 

113 psd = {} 

114 kwargs[label].update(dict( 

115 evolve_spins_forwards=self.evolve_spins_forwards, 

116 evolve_spins_backwards=self.evolve_spins_backwards, 

117 f_start=self.f_start[num], f_low=self.f_low[num], 

118 approximant_flags=self.approximant_flags.get(label, {}), 

119 approximant=approx[num], f_ref=self.f_ref[num], 

120 NRSur_fits=self.NRSur_fits, return_kwargs=True, 

121 multipole_snr=self.calculate_multipole_snr, 

122 precessing_snr=self.calculate_precessing_snr, 

123 psd=psd, f_final=self.f_final[num], 

124 waveform_fits=self.waveform_fits, 

125 multi_process=self.opts.multi_process, 

126 redshift_method=self.redshift_method, 

127 cosmology=self.cosmology, 

128 no_conversion=self.no_conversion, 

129 add_zero_spin=True, delta_f=self.delta_f[num], 

130 psd_default=self.psd_default, 

131 disable_remnant=self.disable_remnant, 

132 force_BBH_remnant_computation=self.force_BBH_remnant_computation, 

133 resume_file=resume_file[num], 

134 restart_from_checkpoint=self.restart_from_checkpoint, 

135 force_BH_spin_evolution=self.force_BH_spin_evolution, 

136 )) 

137 return kwargs 

138 except IndexError: 

139 logger.warning( 

140 "Unable to find an f_ref, f_low and approximant for each " 

141 "label. Using and f_ref={}, f_low={} and approximant={} " 

142 "for all result files".format( 

143 self.f_ref[0], self.f_low[0], approx[0] 

144 ) 

145 ) 

146 for num, label in enumerate(self.labels): 

147 kwargs[label].update(dict( 

148 evolve_spins_forwards=self.evolve_spins_forwards, 

149 evolve_spins_backwards=self.evolve_spins_backwards, 

150 f_start=self.f_start[0], f_low=self.f_low[0], 

151 approximant=approx[0], f_ref=self.f_ref[0], 

152 NRSur_fits=self.NRSur_fits, return_kwargs=True, 

153 multipole_snr=self.calculate_multipole_snr, 

154 precessing_snr=self.calculate_precessing_snr, 

155 psd=self.psd[self.labels[0]], f_final=self.f_final[0], 

156 waveform_fits=self.waveform_fits, 

157 multi_process=self.opts.multi_process, 

158 redshift_method=self.redshift_method, 

159 cosmology=self.cosmology, 

160 no_conversion=self.no_conversion, 

161 add_zero_spin=True, delta_f=self.delta_f[0], 

162 psd_default=self.psd_default, 

163 disable_remnant=self.disable_remnant, 

164 force_BBH_remnant_computation=self.force_BBH_remnant_computation, 

165 resume_file=resume_file[num], 

166 restart_from_checkpoint=self.restart_from_checkpoint, 

167 force_BH_spin_evolution=self.force_BH_spin_evolution 

168 )) 

169 return kwargs 

170 

171 @staticmethod 

172 def grab_data_from_file( 

173 file, label, webdir, config=None, injection=None, file_format=None, 

174 nsamples=None, disable_prior_sampling=False, **kwargs 

175 ): 

176 """Grab data from a result file containing posterior samples 

177 

178 Parameters 

179 ---------- 

180 file: str 

181 path to the result file 

182 label: str 

183 label that you wish to use for the result file 

184 config: str, optional 

185 path to a configuration file used in the analysis 

186 injection: str, optional 

187 path to an injection file used in the analysis 

188 file_format, str, optional 

189 the file format you wish to use when loading. Default None. 

190 If None, the read function loops through all possible options 

191 """ 

192 data = pesummary.core.cli.inputs._Input.grab_data_from_file( 

193 file, label, webdir, config=config, injection=injection, 

194 read_function=GWRead, file_format=file_format, nsamples=nsamples, 

195 disable_prior_sampling=disable_prior_sampling, **kwargs 

196 ) 

197 try: 

198 for _kwgs in data["file_kwargs"][label]: 

199 if "approximant_flags" in _kwgs["meta_data"]: 

200 for key, item in _kwargs["meta_data"]["approximant_flags"].items(): 

201 warning_cond = ( 

202 key in self.approximant_flags[label] and 

203 self.approximant_flags[label][key] != item 

204 ) 

205 if warning_cond: 

206 logger.warning( 

207 "Approximant flag {}={} found in result file for {}. " 

208 "Ignoring and using the provided values {}={}".format( 

209 key, self.approximant_flags[label][key], label, 

210 key, item 

211 ) 

212 ) 

213 else: 

214 self.approximant_flags[label][key] = item 

215 except Exception: 

216 pass 

217 return data 

218 

219 @property 

220 def reweight_samples(self): 

221 return self._reweight_samples 

222 

223 @reweight_samples.setter 

224 def reweight_samples(self, reweight_samples): 

225 from pesummary.gw.reweight import options 

226 self._reweight_samples = self._check_reweight_samples( 

227 reweight_samples, options 

228 ) 

229 

230 def _set_samples(self, *args, **kwargs): 

231 super(_GWInput, self)._set_samples(*args, **kwargs) 

232 if "calibration" not in self.priors: 

233 self.priors["calibration"] = { 

234 label: {} for label in self.labels 

235 } 

236 

    def _set_corner_params(self, corner_params):
        """Return the list of parameters to use in the corner plot.

        If no parameters are supplied, the gw defaults from ``conf`` are
        used. Otherwise the supplied list is merged with the defaults and
        any user-supplied parameter missing from one or more analyses is
        dropped.
        """
        corner_params = super(_GWInput, self)._set_corner_params(corner_params)
        if corner_params is None:
            logger.debug(
                "Using the default corner parameters: {}".format(
                    ", ".join(conf.gw_corner_parameters)
                )
            )
        else:
            # keep a copy of the user-provided list so we only prune
            # user-supplied parameters, never the gw defaults
            _corner_params = corner_params
            # NOTE(review): list(set(...)) does not preserve ordering, so
            # the corner-plot parameter order may vary between runs --
            # confirm whether a deterministic order is required
            corner_params = list(set(conf.gw_corner_parameters + corner_params))
            for param in _corner_params:
                _data = self.samples
                # drop any requested parameter absent from one or more
                # analyses, since the corner plot needs it everywhere
                if not all(param in _data[label].keys() for label in self.labels):
                    corner_params.remove(param)
            logger.debug(
                "Generating a corner plot with the following "
                "parameters: {}".format(", ".join(corner_params))
            )
        return corner_params

257 

258 @property 

259 def cosmology(self): 

260 return self._cosmology 

261 

262 @cosmology.setter 

263 def cosmology(self, cosmology): 

264 from pesummary.gw.cosmology import available_cosmologies 

265 

266 if cosmology.lower() not in available_cosmologies: 

267 logger.warning( 

268 "Unrecognised cosmology: {}. Using {} as default".format( 

269 cosmology, conf.cosmology 

270 ) 

271 ) 

272 cosmology = conf.cosmology 

273 else: 

274 logger.debug("Using the {} cosmology".format(cosmology)) 

275 self._cosmology = cosmology 

276 

    @property
    def approximant(self):
        """Dictionary mapping each label to its waveform approximant"""
        return self._approximant

    @approximant.setter
    def approximant(self, approximant):
        # first assignment: build the label -> approximant mapping
        if not hasattr(self, "_approximant"):
            # {} is used as a sentinel meaning 'not yet resolved'
            approximant_list = {i: {} for i in self.labels}
            if approximant is None:
                logger.warning(
                    "No approximant passed. Waveform plots will not be "
                    "generated"
                )
            elif approximant is not None:
                if len(approximant) != len(self.labels):
                    raise InputError(
                        "Please pass an approximant for each result file"
                    )
                approximant_list = {
                    i: j for i, j in zip(self.labels, approximant)
                }
            self._approximant = approximant_list
        else:
            # subsequent assignment: resolve any remaining {} sentinel to
            # None. NOTE(review): the 'break' means only the first
            # unresolved label is set to None per call -- confirm this is
            # intentional rather than a missing loop over all labels
            for num, i in enumerate(self._approximant.keys()):
                if self._approximant[i] == {}:
                    if num == 0:
                        logger.warning(
                            "No approximant passed. Waveform plots will not be "
                            "generated"
                        )
                    self._approximant[i] = None
                    break

309 

    @property
    def approximant_flags(self):
        """Dictionary of per-label approximant flags"""
        return self._approximant_flags

    @approximant_flags.setter
    def approximant_flags(self, approximant_flags):
        # only ever parse the flags once; later assignments are no-ops
        if hasattr(self, "_approximant_flags"):
            return
        _approximant_flags = {key: {} for key in self.labels}
        # flags arrive keyed as 'LABEL:FLAG' with the flag value as item
        for key, item in approximant_flags.items():
            _label, key = key.split(":")
            if _label not in self.labels:
                raise ValueError(
                    "Unable to assign waveform flags for label:{} because "
                    "it does not exist. Available labels are: {}. Approximant "
                    "flags must be provided in the form LABEL:FLAG:VALUE".format(
                        _label, ", ".join(self.labels)
                    )
                )
            _approximant_flags[_label][key] = item
        self._approximant_flags = _approximant_flags

331 

332 @property 

333 def gracedb_server(self): 

334 return self._gracedb_server 

335 

336 @gracedb_server.setter 

337 def gracedb_server(self, gracedb_server): 

338 if gracedb_server is None: 

339 self._gracedb_server = conf.gracedb_server 

340 else: 

341 logger.debug( 

342 "Using '{}' as the GraceDB server".format(gracedb_server) 

343 ) 

344 self._gracedb_server = gracedb_server 

345 

    @property
    def gracedb(self):
        """GraceDB ID associated with the event"""
        return self._gracedb

    @gracedb.setter
    def gracedb(self, gracedb):
        self._gracedb = gracedb
        if gracedb is not None:
            from pesummary.gw.gracedb import get_gracedb_data, HTTPError
            from json.decoder import JSONDecodeError

            # GraceDB IDs start with G/g (event) or S (superevent)
            first_letter = gracedb[0]
            if first_letter != "G" and first_letter != "g" and first_letter != "S":
                logger.warning(
                    "Invalid GraceDB ID passed. The GraceDB ID must be of the "
                    "form G0000 or S0000. Ignoring input."
                )
                self._gracedb = None
                return
            _error = (
                "Unable to download data from Gracedb because {}. Only storing "
                "the GraceDB ID in the metafile"
            )
            # best-effort download: on failure only the ID is stored
            try:
                logger.info(
                    "Downloading {} from gracedb for {}".format(
                        ", ".join(self.gracedb_data), gracedb
                    )
                )
                # NOTE: 'json' here shadows the stdlib module name; it is
                # a plain dict of downloaded data
                json = get_gracedb_data(
                    gracedb, info=self.gracedb_data,
                    service_url=self.gracedb_server
                )
                json["id"] = gracedb
            except (HTTPError, RuntimeError, JSONDecodeError) as e:
                logger.warning(_error.format(e))
                json = {"id": gracedb}

            # attach the downloaded metadata to every analysis
            for label in self.labels:
                self.file_kwargs[label]["meta_data"]["gracedb"] = json

387 @property 

388 def detectors(self): 

389 return self._detectors 

390 

391 @detectors.setter 

392 def detectors(self, detectors): 

393 detector = {} 

394 if not detectors: 

395 for i in self.samples.keys(): 

396 params = list(self.samples[i].keys()) 

397 individual_detectors = [] 

398 for j in params: 

399 if "optimal_snr" in j and j != "network_optimal_snr": 

400 det = j.split("_optimal_snr")[0] 

401 individual_detectors.append(det) 

402 individual_detectors = sorted( 

403 [str(i) for i in individual_detectors]) 

404 if individual_detectors: 

405 detector[i] = "_".join(individual_detectors) 

406 else: 

407 detector[i] = None 

408 else: 

409 detector = detectors 

410 logger.debug("The detector network is %s" % (detector)) 

411 self._detectors = detector 

412 

413 @property 

414 def skymap(self): 

415 return self._skymap 

416 

417 @skymap.setter 

418 def skymap(self, skymap): 

419 if not hasattr(self, "_skymap"): 

420 self._skymap = {i: None for i in self.labels} 

421 

    @property
    def calibration_definition(self):
        """Dictionary mapping each label to 'data' or 'template'"""
        return self._calibration_definition

    @calibration_definition.setter
    def calibration_definition(self, calibration_definition):
        # no calibration files supplied: nothing to define
        if not len(self.opts.calibration):
            self._calibration_definition = None
            return
        if len(calibration_definition) == 1:
            logger.info(
                f"Assuming that the calibration correction was applied to "
                f"'{calibration_definition[0]}' for all analyses"
            )
            # NOTE: this mutates the passed-in list in place when a list is
            # provided (broadcast to one entry per label)
            calibration_definition *= len(self.labels)
        elif len(calibration_definition) != len(self.labels):
            raise ValueError(
                f"Please provide a calibration definition for each analysis "
                f"({len(self.labels)}) or a single definition to use for all "
                f"analyses"
            )
        if any(_ not in ["data", "template"] for _ in calibration_definition):
            raise ValueError(
                "Calibration definitions must be either 'data' or 'template'"
            )
        self._calibration_definition = {
            label: calibration_definition[num] for num, label in
            enumerate(self.labels)
        }

451 

    @property
    def calibration(self):
        """Dictionary of calibration-spline posterior data per label"""
        return self._calibration

    @calibration.setter
    def calibration(self, calibration):
        # only populate once; later assignments are no-ops
        if not hasattr(self, "_calibration"):
            data = {i: {} for i in self.labels}
            # first, collect calibration *priors* either from the supplied
            # files or from per-label command-line options
            if calibration != {}:
                prior_data = self.get_psd_or_calibration_data(
                    calibration, self.extract_calibration_data_from_file,
                    type=self.calibration_definition[self.labels[0]]
                )
                self.add_to_prior_dict("calibration", prior_data)
            else:
                prior_data = {i: {} for i in self.labels}
                for label in self.labels:
                    if hasattr(self.opts, "{}_calibration".format(label)):
                        cal_data = getattr(self.opts, "{}_calibration".format(label))
                        if cal_data != {} and cal_data is not None:
                            prior_data[label] = {
                                ifo: self.extract_calibration_data_from_file(
                                    cal_data[ifo], type=self.calibration_definition[label]
                                ) for ifo in cal_data.keys()
                            }
                if not all(prior_data[i] == {} for i in self.labels):
                    self.add_to_prior_dict("calibration", prior_data)
                else:
                    self.add_to_prior_dict("calibration", {})
            # second, extract the calibration-spline *posterior* from each
            # result file (reusing already-open files where possible)
            for num, i in enumerate(self.result_files):
                _opened = self._open_result_files
                if i in _opened.keys() and _opened[i] is not None:
                    f = self._open_result_files[i]
                else:
                    f = GWRead(i, disable_prior=True)
                try:
                    calibration_data = f.interpolate_calibration_spline_posterior()
                except Exception as e:
                    logger.warning(
                        "Failed to extract calibration data from the result "
                        "file: {} because {}".format(i, e)
                    )
                    calibration_data = None
                labels = list(self.samples.keys())
                if calibration_data is None:
                    # sentinel entry marking 'no calibration data'
                    data[labels[num]] = {
                        None: None
                    }
                elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                    # a PESummary metafile holds one entry per analysis.
                    # NOTE(review): this inner loop reuses (and clobbers)
                    # the outer loop variable 'num', which skews label
                    # indexing on any subsequent outer iteration -- confirm
                    # whether a metafile is only ever the sole result file
                    for num in range(len(calibration_data[0])):
                        data[labels[num]] = {
                            j: k for j, k in zip(
                                calibration_data[1][num],
                                calibration_data[0][num]
                            )
                        }
                else:
                    # single-analysis file: (data, ifo-names) pair
                    data[labels[num]] = {
                        j: k for j, k in zip(
                            calibration_data[1], calibration_data[0]
                        )
                    }
            self._calibration = data

515 

516 @property 

517 def psd(self): 

518 return self._psd 

519 

520 @psd.setter 

521 def psd(self, psd): 

522 if not hasattr(self, "_psd"): 

523 data = {i: {} for i in self.labels} 

524 if psd != {}: 

525 data = self.get_psd_or_calibration_data( 

526 psd, self.extract_psd_data_from_file 

527 ) 

528 else: 

529 for label in self.labels: 

530 if hasattr(self.opts, "{}_psd".format(label)): 

531 psd_data = getattr(self.opts, "{}_psd".format(label)) 

532 if psd_data != {} and psd_data is not None: 

533 data[label] = { 

534 ifo: self.extract_psd_data_from_file( 

535 psd_data[ifo], IFO=ifo 

536 ) for ifo in psd_data.keys() 

537 } 

538 self._psd = data 

539 

    @property
    def nsamples_for_skymap(self):
        """Number of posterior samples used when generating the skymap"""
        return self._nsamples_for_skymap

    @nsamples_for_skymap.setter
    def nsamples_for_skymap(self, nsamples_for_skymap):
        self._nsamples_for_skymap = nsamples_for_skymap
        if nsamples_for_skymap is not None:
            self._nsamples_for_skymap = int(nsamples_for_skymap)
            number_of_samples = [
                data.number_of_samples for label, data in self.samples.items()
            ]
            # cap the request at the size of the smallest analysis so every
            # file can supply the requested number of samples
            if not all(i > self._nsamples_for_skymap for i in number_of_samples):
                min_arg = np.argmin(number_of_samples)
                logger.warning(
                    "You have specified that you would like to use {} "
                    "samples to generate the skymap but the file {} only "
                    "has {} samples. Reducing the number of samples to "
                    "generate the skymap to {}".format(
                        self._nsamples_for_skymap, self.result_files[min_arg],
                        number_of_samples[min_arg], number_of_samples[min_arg]
                    )
                )
                self._nsamples_for_skymap = int(number_of_samples[min_arg])

564 

565 @property 

566 def gwdata(self): 

567 return self._gwdata 

568 

569 @gwdata.setter 

570 def gwdata(self, gwdata): 

571 from pesummary.gw.file.strain import StrainDataDict 

572 

573 self._gwdata = gwdata 

574 if gwdata is not None: 

575 if isinstance(gwdata, dict): 

576 for i in gwdata.keys(): 

577 if not os.path.isfile(gwdata[i]): 

578 raise InputError( 

579 "The file {} does not exist. Please check the path " 

580 "to your strain file".format(gwdata[i]) 

581 ) 

582 self._gwdata = StrainDataDict.read(gwdata) 

583 else: 

584 logger.warning( 

585 "Please provide gwdata as a dictionary with keys " 

586 "displaying the channel and item giving the path to the " 

587 "strain file" 

588 ) 

589 self._gwdata = None 

590 

591 @property 

592 def evolve_spins_forwards(self): 

593 return self._evolve_spins_forwards 

594 

595 @evolve_spins_forwards.setter 

596 def evolve_spins_forwards(self, evolve_spins_forwards): 

597 self._evolve_spins_forwards = evolve_spins_forwards 

598 _msg = "Spins will be evolved up to {}" 

599 if evolve_spins_forwards: 

600 logger.info(_msg.format("Schwarzschild ISCO frequency")) 

601 self._evolve_spins_forwards = 6. ** -0.5 

602 

603 @property 

604 def evolve_spins_backwards(self): 

605 return self._evolve_spins_backwards 

606 

607 @evolve_spins_backwards.setter 

608 def evolve_spins_backwards(self, evolve_spins_backwards): 

609 self._evolve_spins_backwards = evolve_spins_backwards 

610 _msg = ( 

611 "Spins will be evolved backwards to an infinite separation using the '{}' " 

612 "method" 

613 ) 

614 if isinstance(evolve_spins_backwards, (str, bytes)): 

615 logger.info(_msg.format(evolve_spins_backwards)) 

616 elif evolve_spins_backwards is None: 

617 logger.info(_msg.format("precession_averaged")) 

618 self._evolve_spins_backwards = "precession_averaged" 

619 

620 @property 

621 def NRSur_fits(self): 

622 return self._NRSur_fits 

623 

624 @NRSur_fits.setter 

625 def NRSur_fits(self, NRSur_fits): 

626 self._NRSur_fits = NRSur_fits 

627 base = ( 

628 "Using the '{}' NRSurrogate model to calculate the remnant " 

629 "quantities" 

630 ) 

631 if isinstance(NRSur_fits, (str, bytes)): 

632 logger.info(base.format(NRSur_fits)) 

633 self._NRSur_fits = NRSur_fits 

634 elif NRSur_fits is None: 

635 from pesummary.gw.conversions.nrutils import NRSUR_MODEL 

636 

637 logger.info(base.format(NRSUR_MODEL)) 

638 self._NRSur_fits = NRSUR_MODEL 

639 

640 @property 

641 def waveform_fits(self): 

642 return self._waveform_fits 

643 

644 @waveform_fits.setter 

645 def waveform_fits(self, waveform_fits): 

646 self._waveform_fits = waveform_fits 

647 if waveform_fits: 

648 logger.info( 

649 "Evaluating the remnant quantities using the provided " 

650 "approximant" 

651 ) 

652 

653 @property 

654 def f_low(self): 

655 return self._f_low 

656 

657 @f_low.setter 

658 def f_low(self, f_low): 

659 self._f_low = f_low 

660 if f_low is not None: 

661 self._f_low = [float(i) for i in f_low] 

662 

663 @property 

664 def f_start(self): 

665 return self._f_start 

666 

667 @f_start.setter 

668 def f_start(self, f_start): 

669 self._f_start = f_start 

670 if f_start is not None: 

671 self._f_start = [float(i) for i in f_start] 

672 

673 @property 

674 def f_ref(self): 

675 return self._f_ref 

676 

677 @f_ref.setter 

678 def f_ref(self, f_ref): 

679 self._f_ref = f_ref 

680 if f_ref is not None: 

681 self._f_ref = [float(i) for i in f_ref] 

682 

683 @property 

684 def f_final(self): 

685 return self._f_final 

686 

687 @f_final.setter 

688 def f_final(self, f_final): 

689 self._f_final = f_final 

690 if f_final is not None: 

691 self._f_final = [float(i) for i in f_final] 

692 

693 @property 

694 def delta_f(self): 

695 return self._delta_f 

696 

697 @delta_f.setter 

698 def delta_f(self, delta_f): 

699 self._delta_f = delta_f 

700 if delta_f is not None: 

701 self._delta_f = [float(i) for i in delta_f] 

702 

    @property
    def psd_default(self):
        """Default PSD: a pycbc analytic PSD function, a '{file}.psd[...]'
        template string, or None"""
        return self._psd_default

    @psd_default.setter
    def psd_default(self, psd_default):
        self._psd_default = psd_default
        # 'stored:LABEL' means: reuse the PSD already stored in the
        # metafile under that label
        if "stored:" in psd_default:
            label = psd_default.split("stored:")[1]
            self._psd_default = "{file}.psd['%s']" % (label)
            return
        # otherwise resolve the name to a pycbc analytic PSD function
        try:
            from pycbc import psd
            psd_default = getattr(psd, psd_default)
        except ImportError:
            logger.warning(
                "Unable to import 'pycbc'. Unable to generate a default PSD"
            )
            psd_default = None
        except AttributeError:
            # named PSD not available in pycbc: fall back to the package
            # default (the 'pycbc import' above succeeded in this branch)
            logger.warning(
                "'pycbc' does not have the '{}' psd available. Using '{}' as "
                "the default PSD".format(psd_default, conf.psd)
            )
            psd_default = getattr(psd, conf.psd)
        except ValueError as e:
            logger.warning("Setting 'psd_default' to None because {}".format(e))
            psd_default = None
        self._psd_default = psd_default

732 

    @property
    def pepredicates_probs(self):
        """Per-label source classification probabilities (or None on failure)"""
        return self._pepredicates_probs

    @pepredicates_probs.setter
    def pepredicates_probs(self, pepredicates_probs):
        from pesummary.gw.classification import PEPredicates

        # compute classification probabilities independently for each
        # analysis; failures are logged and stored as None
        classifications = {}
        for num, i in enumerate(list(self.samples.keys())):
            try:
                classifications[i] = PEPredicates(
                    self.samples[i]
                ).dual_classification()
            except Exception as e:
                logger.warning(
                    "Failed to generate source classification probabilities "
                    "because {}".format(e)
                )
                classifications[i] = None
        if self.mcmc_samples:
            # with mcmc chains, report a single averaged classification
            # under the first label; all chains must have succeeded
            if any(_probs is None for _probs in classifications.values()):
                classifications[self.labels[0]] = None
                logger.warning(
                    "Unable to average classification probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "classifications"
                )
            else:
                logger.debug(
                    "Averaging classification probability across mcmc samples"
                )
                # mean over chains for every prior/probability pair,
                # rounded to 3 decimal places
                classifications[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in classifications.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in
                    list(classifications.values())[0].items()
                }
        self._pepredicates_probs = classifications

774 

    @property
    def pastro_probs(self):
        """Per-label em_bright probabilities (or None on failure)"""
        return self._pastro_probs

    @pastro_probs.setter
    def pastro_probs(self, pastro_probs):
        from pesummary.gw.classification import PAstro

        # compute em_bright probabilities independently for each analysis;
        # failures are logged and stored as None (mirrors
        # pepredicates_probs above)
        probabilities = {}
        for num, i in enumerate(list(self.samples.keys())):
            try:
                probabilities[i] = PAstro(self.samples[i]).dual_classification()
            except Exception as e:
                logger.warning(
                    "Failed to generate em_bright probabilities because "
                    "{}".format(e)
                )
                probabilities[i] = None
        if self.mcmc_samples:
            # with mcmc chains, report a single averaged probability under
            # the first label; all chains must have succeeded
            if any(_probs is None for _probs in probabilities.values()):
                probabilities[self.labels[0]] = None
                logger.warning(
                    "Unable to average em_bright probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "probabilities"
                )
            else:
                logger.debug(
                    "Averaging em_bright probability across mcmc samples"
                )
                # mean over chains for every prior/probability pair,
                # rounded to 3 decimal places
                probabilities[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in probabilities.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in list(probabilities.values())[0].items()
                }
        self._pastro_probs = probabilities

813 

    @property
    def preliminary_pages(self):
        """Dictionary mapping each label to whether its pages are preliminary"""
        return self._preliminary_pages

    @preliminary_pages.setter
    def preliminary_pages(self, preliminary_pages):
        # an analysis is 'preliminary' when any of the data required for
        # reproducibility (conf.gw_reproducibility) is missing for it
        required = conf.gw_reproducibility
        self._preliminary_pages = {label: False for label in self.labels}
        for num, label in enumerate(self.labels):
            for attr in required:
                _property = getattr(self, attr)
                if isinstance(_property, dict):
                    # dict-valued attributes: missing or empty entry
                    if label not in _property.keys():
                        self._preliminary_pages[label] = True
                    elif not len(_property[label]):
                        self._preliminary_pages[label] = True
                elif isinstance(_property, list):
                    # list-valued attributes are indexed by label position
                    if _property[num] is None:
                        self._preliminary_pages[label] = True
        if any(value for value in self._preliminary_pages.values()):
            _labels = [
                label for label, value in self._preliminary_pages.items() if
                value
            ]
            # 'analys{}' is completed to 'analysis'/'analyses' below
            msg = (
                "Unable to reproduce the {} analys{} because no {} data was "
                "provided. 'Preliminary' watermarks will be added to the final "
                "html pages".format(
                    ", ".join(_labels), "es" if len(_labels) > 1 else "is",
                    " or ".join(required)
                )
            )
            logger.warning(msg)

847 

848 @staticmethod 

849 def _extract_IFO_data_from_file(file, cls, desc, IFO=None, **kwargs): 

850 """Return IFO data stored in a file 

851 

852 Parameters 

853 ---------- 

854 file: path 

855 path to a file containing the IFO data 

856 cls: obj 

857 class you wish to use when loading the file. This class must have 

858 a '.read' method 

859 desc: str 

860 description of the IFO data stored in the file 

861 IFO: str, optional 

862 the IFO which the data belongs to 

863 """ 

864 general = ( 

865 "Failed to read in %s data because {}. The %s plot will not be " 

866 "generated and the %s data will not be added to the metafile." 

867 ) % (desc, desc, desc) 

868 try: 

869 return cls.read(file, IFO=IFO, **kwargs) 

870 except FileNotFoundError: 

871 logger.warning( 

872 general.format("the file {} does not exist".format(file)) 

873 ) 

874 return {} 

875 except ValueError as e: 

876 logger.warning(general.format(e)) 

877 return {} 

878 

879 @staticmethod 

880 def extract_psd_data_from_file(file, IFO=None): 

881 """Return the data stored in a psd file 

882 

883 Parameters 

884 ---------- 

885 file: path 

886 path to a file containing the psd data 

887 """ 

888 from pesummary.gw.file.psd import PSD 

889 return _GWInput._extract_IFO_data_from_file(file, PSD, "PSD", IFO=IFO) 

890 

891 @staticmethod 

892 def extract_calibration_data_from_file(file, type="data", **kwargs): 

893 """Return the data stored in a calibration file 

894 

895 Parameters 

896 ---------- 

897 file: path 

898 path to a file containing the calibration data 

899 """ 

900 from pesummary.gw.file.calibration import Calibration 

901 return _GWInput._extract_IFO_data_from_file( 

902 file, Calibration, "calibration", type=type, **kwargs 

903 ) 

904 

905 @staticmethod 

906 def get_ifo_from_file_name(file): 

907 """Return the IFO from the file name 

908 

909 Parameters 

910 ---------- 

911 file: str 

912 path to the file 

913 """ 

914 file_name = file.split("/")[-1] 

915 if any(j in file_name for j in ["H1", "_0", "IFO0"]): 

916 ifo = "H1" 

917 elif any(j in file_name for j in ["L1", "_1", "IFO1"]): 

918 ifo = "L1" 

919 elif any(j in file_name for j in ["V1", "_2", "IFO2"]): 

920 ifo = "V1" 

921 else: 

922 ifo = file_name 

923 return ifo 

924 

    def get_psd_or_calibration_data(self, input, executable, **kwargs):
        """Return a dictionary containing the psd or calibration data

        Parameters
        ----------
        input: list/dict
            list/dict containing paths to calibration/psd files
        executable: func
            executable that is used to extract the data from the calibration/psd
            files
        **kwargs: dict, optional
            extra kwargs forwarded to the executable
        """
        data = {}
        if input == {} or input == []:
            return data
        if isinstance(input, dict):
            keys = list(input.keys())
        # case 1: dict of IFO -> list of files, one file per analysis
        if isinstance(input, dict) and isinstance(input[keys[0]], list):
            if not all(len(input[i]) == len(self.labels) for i in list(keys)):
                raise InputError(
                    "Please ensure the number of calibration/psd files matches "
                    "the number of result files passed"
                )
            for idx in range(len(input[keys[0]])):
                data[self.labels[idx]] = {
                    i: executable(input[i][idx], IFO=i, **kwargs) for i in list(keys)
                }
        # case 2: dict of IFO -> single file, shared by all analyses
        elif isinstance(input, dict):
            for i in self.labels:
                data[i] = {
                    j: executable(input[j], IFO=j, **kwargs) for j in list(input.keys())
                }
        # case 3: bare list of files; the IFO is inferred from each name
        elif isinstance(input, list):
            for i in self.labels:
                data[i] = {
                    self.get_ifo_from_file_name(j): executable(
                        j, IFO=self.get_ifo_from_file_name(j), **kwargs
                    ) for j in input
                }
        else:
            raise InputError(
                "Did not understand the psd/calibration input. Please use the "
                "following format 'H1:path/to/file'"
            )
        return data

969 

970 def grab_priors_from_inputs(self, priors): 

971 def read_func(data, **kwargs): 

972 from pesummary.gw.file.read import read as GWRead 

973 data = GWRead(data, **kwargs) 

974 data.generate_all_posterior_samples() 

975 return data 

976 

977 return super(_GWInput, self).grab_priors_from_inputs( 

978 priors, read_func=read_func, read_kwargs=self.grab_data_kwargs 

979 ) 

980 

    def grab_key_data_from_result_files(self):
        """Grab the mean, median, maxL and standard deviation for all
        parameters for each result file, and add a two-sided 90% HPD
        interval for a fixed set of physically bounded parameters.
        """
        from pesummary.utils.kde_list import KDEList
        from pesummary.gw.plots.plot import _return_bounds
        from pesummary.utils.credible_interval import (
            hpd_two_sided_credible_interval
        )
        from pesummary.utils.bounded_1d_kde import bounded_1d_kde
        # base statistics come from the core implementation
        key_data = super(_GWInput, self).grab_key_data_from_result_files()
        # parameters with hard physical bounds, so a bounded KDE is required
        bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
        for param in bounded_parameters:
            xlow, xhigh = _return_bounds(param, [])
            # samples for this parameter from every analysis that has it
            _samples = {
                key: val[param] for key, val in self.samples.items()
                if param in val.keys()
            }
            # NOTE(review): the guard looks like it was meant to be
            # `len(_)` (skip empty sample sets) rather than `len(_samples)`,
            # which is constant over the comprehension -- confirm intent
            _min = [np.min(_) for _ in _samples.values() if len(_samples)]
            _max = [np.max(_) for _ in _samples.values() if len(_samples)]
            if not len(_min):
                # no analysis provides this parameter
                continue
            _min = np.min(_min)
            _max = np.max(_max)
            # common grid covering the full range of all analyses
            x = np.linspace(_min, _max, 1000)
            try:
                kdes = KDEList(
                    list(_samples.values()), kde=bounded_1d_kde,
                    kde_kwargs={"xlow": xlow, "xhigh": xhigh}
                )
            except Exception as e:
                # best-effort: skip the HPD for this parameter on failure
                logger.warning(
                    "Unable to compute the HPD interval for {} because {}".format(
                        param, e
                    )
                )
                continue
            pdfs = kdes(x)
            for num, key in enumerate(_samples.keys()):
                # xlow/xhigh are reused here as the interval edges; the KDE
                # bounds above have already been consumed at this point
                [xlow, xhigh], _ = hpd_two_sided_credible_interval(
                    [], 90, x=x, pdf=pdfs[num]
                )
                key_data[key][param]["90% HPD"] = [xlow, xhigh]
                # mark all non-bounded parameters as having no HPD entry
                for _param in self.samples[key].keys():
                    if _param in bounded_parameters:
                        continue
                    key_data[key][_param]["90% HPD"] = float("nan")
        return key_data

1029 

1030 

class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class to handle and store sample specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # copying of files is deferred; the metafile stage handles it
        kwargs.update({"ignore_copy": True})
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=[
                "evolve_spins_forwards",
                "evolve_spins_backwards",
                "NRSur_fits",
                "calculate_multipole_snr",
                "calculate_precessing_snr",
                "f_start",
                "f_low",
                "f_ref",
                "f_final",
                "psd",
                "waveform_fits",
                "redshift_method",
                "cosmology",
                "no_conversion",
                "delta_f",
                "psd_default",
                "disable_remnant",
                "force_BBH_remnant_computation",
                "force_BH_spin_evolution"
            ], **kwargs
        )
        if self._restarted_from_checkpoint:
            # state was restored from a checkpoint file; nothing else to set
            return
        if self.existing is not None:
            # pull approximant/psd/calibration/skymap data out of the
            # existing PESummary metafile so it can be merged with new runs
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_approximant = self.existing_data["approximant"]
            self.existing_psd = self.existing_data["psd"]
            self.existing_calibration = self.existing_data["calibration"]
            self.existing_skymap = self.existing_data["skymap"]
        else:
            self.existing_approximant = None
            self.existing_psd = None
            self.existing_calibration = None
            self.existing_skymap = None
        # NOTE: several of these assignments invoke property setters with
        # side effects, so the order below matters
        self.approximant = self.opts.approximant
        self.approximant_flags = self.opts.approximant_flags
        self.detectors = None
        self.skymap = None
        self.calibration_definition = self.opts.calibration_definition
        self.calibration = self.opts.calibration
        self.gwdata = self.opts.gwdata
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        # the passed value is ignored; maxL values are always recomputed
        # from the result files
        key_data = self.grab_key_data_from_result_files()
        maxL_samples = {
            i: {
                j: key_data[i][j]["maxL"] for j in key_data[i].keys()
            } for i in key_data.keys()
        }
        for i in self.labels:
            # record which approximant produced each analysis
            maxL_samples[i]["approximant"] = self.approximant[i]
        self._maxL_samples = maxL_samples

1099 

1100 

class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Class to handle and store plotting specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.nsamples_for_skymap = self.opts.nsamples_for_skymap
        self.sensitivity = self.opts.sensitivity
        self.no_ligo_skymap = self.opts.no_ligo_skymap
        # by default all processes are available for skymap generation
        self.multi_threading_for_skymap = self.multi_process
        if not self.no_ligo_skymap and self.multi_process > 1:
            # split the available processes between the (slow) ligo.skymap
            # call and all other plots
            total = self.multi_process
            self.multi_threading_for_plots = int(total / 2.)
            self.multi_threading_for_skymap = total - self.multi_threading_for_plots
            logger.info(
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots".format(
                    self.multi_threading_for_skymap,
                    "es " if self.multi_threading_for_skymap > 1 else " ",
                    self.multi_threading_for_plots,
                    "es " if self.multi_threading_for_plots > 1 else " "
                )
            )
        # filled in later by the plotting/classification stages
        self.preliminary_pages = None
        self.pepredicates_probs = []
        self.pastro_probs = []

1126 

1127 

class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Class to handle and store webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.gracedb_server = self.opts.gracedb_server
        self.gracedb_data = self.opts.gracedb_data
        self.gracedb = self.opts.gracedb
        self.public = self.opts.public
        # these attributes are normally assigned by PlottingInput; provide
        # safe defaults when the plotting stage has not run
        for _attr, _default in (
            ("preliminary_pages", None),
            ("pepredicates_probs", []),
            ("pastro_probs", []),
        ):
            if not hasattr(self, _attr):
                setattr(self, _attr, _default)

1143 

1144 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Class to handle and store webpage and plotting specific command line
    arguments
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        # delegate to the next class in the MRO
        return super().default_directories

    @property
    def default_files_to_copy(self):
        # delegate to the next class in the MRO
        return super().default_files_to_copy

1160 

class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Class to handle and store metafile specific command line arguments
    """
    @property
    def default_directories(self):
        # extend the core directory list with gw-specific subdirectories
        dirs = super(MetaFileInput, self).default_directories
        dirs += ["psds", "calibration"]
        return dirs

    def copy_files(self):
        """Write the stored PSDs and calibration envelopes to the webdir,
        then hand over to the core copy_files implementation.
        """
        _error = "Failed to save the {} to file"
        for label in self.labels:
            if self.psd[label] != {}:
                for ifo in self.psd[label].keys():
                    # only PSD instances know how to serialise themselves;
                    # warn and skip anything else rather than failing
                    if not isinstance(self.psd[label][ifo], PSD):
                        logger.warning(_error.format("{} PSD".format(ifo)))
                        continue
                    self.psd[label][ifo].save_to_file(
                        os.path.join(self.webdir, "psds", "{}_{}_psd.dat".format(
                            label, ifo
                        ))
                    )
            if label in self.priors["calibration"].keys():
                if self.priors["calibration"][label] != {}:
                    for ifo in self.priors["calibration"][label].keys():
                        # same best-effort policy for calibration envelopes
                        _instance = isinstance(
                            self.priors["calibration"][label][ifo], Calibration
                        )
                        if not _instance:
                            logger.warning(
                                _error.format(
                                    "{} calibration envelope".format(
                                        ifo
                                    )
                                )
                            )
                            continue
                        self.priors["calibration"][label][ifo].save_to_file(
                            os.path.join(self.webdir, "calibration", "{}_{}_cal.txt".format(
                                label, ifo
                            ))
                        )
        return super(MetaFileInput, self).copy_files()

1204 

1205 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Class to handle and store webpage, plotting and metafile specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        # delegate to the next class in the MRO
        return super().default_directories

    @property
    def default_files_to_copy(self):
        # delegate to the next class in the MRO
        return super().default_files_to_copy

1223 

@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias for WebpagePlusPlottingPlusMetaFileInput; kept for
    backwards compatibility and emits a deprecation warning on use.
    """
    pass

1231 

1232 

1233class IMRCTInput(pesummary.core.cli.inputs._Input): 

1234 """Class to handle the TGR specific command line arguments 

1235 """ 

    @property
    def labels(self):
        # analysis labels; validated and possibly rewritten by the setter
        return self._labels

1239 

    @labels.setter
    def labels(self, labels):
        """Validate the user-supplied labels for the IMRCT test.

        Accepts either exactly ['inspiral', 'postinspiral'], pairs of
        '{analysis}:inspiral'/'{analysis}:postinspiral' labels when comparing
        multiple analyses, or metafile-style labels which are normalised to
        ['inspiral', 'postinspiral'].
        """
        self._labels = labels
        if len(labels) % 2 != 0:
            # labels must come in inspiral/postinspiral pairs
            raise ValueError(
                "The IMRCT test requires 2 results files for each analysis. "
            )
        elif len(labels) > 2:
            # comparing more than one analysis: every label must carry an
            # ':inspiral' or ':postinspiral' suffix
            cond = all(
                ":inspiral" in label or ":postinspiral" in label for label in
                labels
            )
            if not cond:
                raise ValueError(
                    "To compare 2 or more analyses, please provide labels as "
                    "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                    "the analysis label"
                )
            else:
                # the analysis name is whatever precedes ':inspiral'
                self.analysis_label = [
                    label.split(":inspiral")[0]
                    for label in labels
                    if ":inspiral" in label and ":postinspiral" not in label
                ]
                if len(self.analysis_label) != len(self.result_files) / 2:
                    raise ValueError(
                        "When comparing more than 2 analyses, labels must "
                        "be of the form '{}:inspiral' and '{}:postinspiral'."
                    )
                logger.info(
                    "Using the labels: {} to distinguish analyses".format(
                        ", ".join(self.analysis_label)
                    )
                )
        elif sorted(labels) != ["inspiral", "postinspiral"]:
            # two labels that are not the literal inspiral/postinspiral pair:
            # only valid when the inputs are PESummary metafiles
            if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
                meta_file_labels = []
                for suffix in [":inspiral", ":postinspiral"]:
                    if any(suffix in label for label in labels):
                        ind = [
                            num for num, label in enumerate(labels) if
                            suffix in label
                        ]
                        if len(ind) > 1:
                            raise ValueError(
                                "Please provide a single {} label".format(
                                    suffix.split(":")[1]
                                )
                            )
                        meta_file_labels.append(
                            labels[ind[0]].split(suffix)[0]
                        )
                    else:
                        raise ValueError(
                            "Please provide labels as {inspiral_label}:inspiral "
                            "and {postinspiral_label}:postinspiral where "
                            "inspiral_label and postinspiral_label are the "
                            "PESummary labels for the inspiral and postinspiral "
                            "analyses respectively. "
                        )
                if len(self.result_files) == 1:
                    # both analyses live in the same metafile
                    logger.info(
                        "Using the {} samples for the inspiral analysis and {} "
                        "samples for the postinspiral analysis from the file "
                        "{}".format(
                            meta_file_labels[0], meta_file_labels[1],
                            self.result_files[0]
                        )
                    )
                elif len(self.result_files) == 2:
                    # one metafile per analysis
                    logger.info(
                        "Using the {} samples for the inspiral analysis from "
                        "the file {}. Using the {} samples for the "
                        "postinspiral analysis from the file {}".format(
                            meta_file_labels[0], self.result_files[0],
                            meta_file_labels[1], self.result_files[1]
                        )
                    )
                else:
                    raise ValueError(
                        "Currently, you can only provide at most 2 pesummary "
                        "metafiles. If one is provided, both the inspiral and "
                        "postinspiral are extracted from that single file. If "
                        "two are provided, the inspiral is extracted from one "
                        "file and the postinspiral is extracted from the other."
                    )
                # normalise to the canonical label pair and remember which
                # metafile analyses they map to
                self._labels = ["inspiral", "postinspiral"]
                self._meta_file_labels = meta_file_labels
                self.analysis_label = ["primary"]
            else:
                raise ValueError(
                    "The IMRCT test requires an inspiral and postinspiral result "
                    "file. Please indicate which file is the inspiral and which "
                    "is postinspiral by providing these exact labels to the "
                    "summarytgr executable"
                )
        else:
            # exactly ['inspiral', 'postinspiral']
            self.analysis_label = ["primary"]

1338 

1339 def _extract_stored_approximant(self, opened_file, label): 

1340 """Extract the approximant used for a given analysis stored in a 

1341 PESummary metafile 

1342 

1343 Parameters 

1344 ---------- 

1345 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1346 opened metafile that contains the analysis 'label' 

1347 label: str 

1348 analysis label which is stored in the PESummary metafile 

1349 """ 

1350 if opened_file.approximant is not None: 

1351 if label not in opened_file.labels: 

1352 raise ValueError( 

1353 "Invalid label {}. The list of available labels are {}".format( 

1354 label, ", ".join(opened_file.labels) 

1355 ) 

1356 ) 

1357 _index = opened_file.labels.index(label) 

1358 return opened_file.approximant[_index] 

1359 return 

1360 

1361 def _extract_stored_remnant_fits(self, opened_file, label): 

1362 """Extract the remnant fits used for a given analysis stored in a 

1363 PESummary metafile 

1364 

1365 Parameters 

1366 ---------- 

1367 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1368 opened metafile that contains the analysis 'label' 

1369 label: str 

1370 analysis label which is stored in the PESummary metafile 

1371 """ 

1372 fits = {} 

1373 fit_strings = [ 

1374 "final_spin_NR_fits", "final_mass_NR_fits" 

1375 ] 

1376 if label not in opened_file.labels: 

1377 raise ValueError( 

1378 "Invalid label {}. The list of available labels are {}".format( 

1379 label, ", ".join(opened_file.labels) 

1380 ) 

1381 ) 

1382 _index = opened_file.labels.index(label) 

1383 _meta_data = opened_file.extra_kwargs[_index] 

1384 if "meta_data" in _meta_data.keys(): 

1385 for key in fit_strings: 

1386 if key in _meta_data["meta_data"].keys(): 

1387 fits[key] = _meta_data["meta_data"][key] 

1388 if len(fits): 

1389 return fits 

1390 return 

1391 

1392 def _extract_stored_cutoff_frequency(self, opened_file, label): 

1393 """Extract the cutoff frequencies used for a given analysis stored in a 

1394 PESummary metafile 

1395 

1396 Parameters 

1397 ---------- 

1398 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1399 opened metafile that contains the analysis 'label' 

1400 label: str 

1401 analysis label which is stored in the PESummary metafile 

1402 """ 

1403 frequencies = {} 

1404 if opened_file.config is not None: 

1405 if label not in opened_file.labels: 

1406 raise ValueError( 

1407 "Invalid label {}. The list of available labels are {}".format( 

1408 label, ", ".join(opened_file.labels) 

1409 ) 

1410 ) 

1411 if opened_file.config[label] is not None: 

1412 _config = opened_file.config[label] 

1413 if "config" in _config.keys(): 

1414 if "maximum-frequency" in _config["config"].keys(): 

1415 frequencies["fhigh"] = _config["config"][ 

1416 "maximum-frequency" 

1417 ] 

1418 if "minimum-frequency" in _config["config"].keys(): 

1419 frequencies["flow"] = _config["config"][ 

1420 "minimum-frequency" 

1421 ] 

1422 elif "lalinference" in _config.keys(): 

1423 if "fhigh" in _config["lalinference"].keys(): 

1424 frequencies["fhigh"] = _config["lalinference"][ 

1425 "fhigh" 

1426 ] 

1427 if "flow" in _config["lalinference"].keys(): 

1428 frequencies["flow"] = _config["lalinference"][ 

1429 "flow" 

1430 ] 

1431 return frequencies 

1432 return 

1433 

    @property
    def samples(self):
        # MultiAnalysisSamplesDict built by the setter from the result files
        return self._samples

1437 

    @samples.setter
    def samples(self, samples):
        """Read the result files and collect the posterior samples, together
        with any stored approximant, cutoff frequency and remnant-fit
        information, for the inspiral and postinspiral analyses.
        """
        from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
        # open every result file; 'samples' itself is unused, the paths come
        # from self.result_files set in __init__
        self._read_samples = {
            _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
                self.labels, self.result_files
            )
        }
        _samples_dict = {}
        _approximant_dict = {}
        _cutoff_frequency_dict = {}
        _remnant_fits_dict = {}
        for label, _open in self._read_samples.items():
            if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
                # a PESummary metafile holding several analyses; the labels
                # setter must have recorded which stored analyses to use
                if not len(self._meta_file_labels):
                    raise ValueError(
                        "Currently you can only pass a file containing a "
                        "single analysis or a valid PESummary metafile "
                        "containing multiple analyses"
                    )
                _labels = _open.labels
                if len(self._read_samples) == 1:
                    # single metafile: both analyses come from this file
                    _samples_dict = {
                        label: _open.samples_dict[meta_file_label] for
                        label, meta_file_label in zip(
                            self.labels, self._meta_file_labels
                        )
                    }
                    for label, meta_file_label in zip(self.labels, self._meta_file_labels):
                        _stored_approx = self._extract_stored_approximant(
                            _open, meta_file_label
                        )
                        _stored_frequencies = self._extract_stored_cutoff_frequency(
                            _open, meta_file_label
                        )
                        _stored_remnant_fits = self._extract_stored_remnant_fits(
                            _open, meta_file_label
                        )
                        if _stored_approx is not None:
                            _approximant_dict[label] = _stored_approx
                        if _stored_remnant_fits is not None:
                            _remnant_fits_dict[label] = _stored_remnant_fits
                        if _stored_frequencies is not None:
                            # the inspiral uses a maximum frequency, the
                            # postinspiral a minimum frequency
                            if label == "inspiral":
                                if "fhigh" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "fhigh"
                                    ]
                            if label == "postinspiral":
                                if "flow" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "flow"
                                    ]
                    # everything was extracted from this single file
                    break
                else:
                    # one metafile per analysis: pick the stored analysis
                    # matching this label's position
                    ind = self.labels.index(label)
                    _samples_dict[label] = _open.samples_dict[
                        self._meta_file_labels[ind]
                    ]
                    _stored_approx = self._extract_stored_approximant(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_frequencies = self._extract_stored_cutoff_frequency(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_remnant_fits = self._extract_stored_remnant_fits(
                        _open, self._meta_file_labels[ind]
                    )
                    if _stored_approx is not None:
                        _approximant_dict[label] = _stored_approx
                    if _stored_remnant_fits is not None:
                        _remnant_fits_dict[label] = _stored_remnant_fits
                    if _stored_frequencies is not None:
                        if label == "inspiral":
                            if "fhigh" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "fhigh"
                                ]
                        if label == "postinspiral":
                            if "flow" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "flow"
                                ]
            else:
                # plain single-analysis result file
                _samples_dict[label] = _open.samples_dict
                extra_kwargs = _open.extra_kwargs
                # bilby result files store the approximant and frequency
                # bounds in the likelihood's waveform arguments
                if "pe_algorithm" in extra_kwargs["sampler"].keys():
                    if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
                        try:
                            subkwargs = extra_kwargs["other"]["likelihood"][
                                "waveform_arguments"
                            ]
                            _approximant_dict[label] = (
                                subkwargs["waveform_approximant"]
                            )
                            if "inspiral" in label and "postinspiral" not in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["maximum_frequency"]
                                )
                            elif "postinspiral" in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["minimum_frequency"]
                                )
                        except KeyError:
                            # best-effort: fall back to command-line values
                            pass
        self._samples = MultiAnalysisSamplesDict(_samples_dict)
        # only create the attribute when something was found; downstream
        # code uses getattr/try-except to detect their absence
        if len(_approximant_dict):
            self._approximant_dict = _approximant_dict
        if len(_cutoff_frequency_dict):
            self._cutoff_frequency_dict = _cutoff_frequency_dict
        if len(_remnant_fits_dict):
            self._remnant_fits_dict = _remnant_fits_dict

1550 

    @property
    def imrct_kwargs(self):
        # kwargs forwarded to the IMRCT calculation; see the setter for
        # defaults and string-to-literal conversion
        return self._imrct_kwargs

1554 

1555 @imrct_kwargs.setter 

1556 def imrct_kwargs(self, imrct_kwargs): 

1557 test_kwargs = dict(N_bins=101) 

1558 try: 

1559 test_kwargs.update(imrct_kwargs) 

1560 except AttributeError: 

1561 test_kwargs = test_kwargs 

1562 

1563 for key, value in test_kwargs.items(): 

1564 try: 

1565 test_kwargs[key] = ast.literal_eval(value) 

1566 except ValueError: 

1567 pass 

1568 self._imrct_kwargs = test_kwargs 

1569 

    @property
    def meta_data(self):
        # per-analysis summary of frequencies, approximants and remnant fits
        return self._meta_data

1573 

1574 @meta_data.setter 

1575 def meta_data(self, meta_data): 

1576 self._meta_data = {} 

1577 for num, _inspiral in enumerate(self.inspiral_keys): 

1578 frequency_dict = dict() 

1579 approximant_dict = dict() 

1580 remnant_dict = dict() 

1581 zipped = zip( 

1582 [self.cutoff_frequency, self.approximant, None], 

1583 [frequency_dict, approximant_dict, remnant_dict], 

1584 ["cutoff_frequency", "approximant", "remnant_fits"] 

1585 ) 

1586 _inspiral_string = self.inspiral_keys[num] 

1587 _postinspiral_string = self.postinspiral_keys[num] 

1588 for _list, _dict, name in zipped: 

1589 if _list is not None and len(_list) == len(self.labels): 

1590 inspiral_ind = self.labels.index(_inspiral_string) 

1591 postinspiral_ind = self.labels.index(_postinspiral_string) 

1592 _dict["inspiral"] = _list[inspiral_ind] 

1593 _dict["postinspiral"] = _list[postinspiral_ind] 

1594 elif _list is not None: 

1595 raise ValueError( 

1596 "Please provide a 'cutoff_frequency' and 'approximant' " 

1597 "for each file" 

1598 ) 

1599 else: 

1600 try: 

1601 if name == "cutoff_frequency": 

1602 if "inspiral" in self._cutoff_frequency_dict.keys(): 

1603 _dict["inspiral"] = self._cutoff_frequency_dict[ 

1604 "inspiral" 

1605 ] 

1606 if "postinspiral" in self._cutoff_frequency_dict.keys(): 

1607 _dict["postinspiral"] = self._cutoff_frequency_dict[ 

1608 "postinspiral" 

1609 ] 

1610 elif name == "approximant": 

1611 if "inspiral" in self._approximant_dict.keys(): 

1612 _dict["inspiral"] = self._approximant_dict[ 

1613 "inspiral" 

1614 ] 

1615 if "postinspiral" in self._approximant_dict.keys(): 

1616 _dict["postinspiral"] = self._approximant_dict[ 

1617 "postinspiral" 

1618 ] 

1619 elif name == "remnant_fits": 

1620 if "inspiral" in self._remnant_fits_dict.keys(): 

1621 _dict["inspiral"] = self._remnant_fits_dict[ 

1622 "inspiral" 

1623 ] 

1624 if "postinspiral" in self._remnant_fits_dict.keys(): 

1625 _dict["postinspiral"] = self._remnant_fits_dict[ 

1626 "postinspiral" 

1627 ] 

1628 except (AttributeError, KeyError, TypeError): 

1629 _dict["inspiral"] = None 

1630 _dict["postinspiral"] = None 

1631 

1632 self._meta_data[self.analysis_label[num]] = { 

1633 "inspiral maximum frequency (Hz)": frequency_dict["inspiral"], 

1634 "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"], 

1635 "inspiral approximant": approximant_dict["inspiral"], 

1636 "postinspiral approximant": approximant_dict["postinspiral"], 

1637 "inspiral remnant fits": remnant_dict["inspiral"], 

1638 "postinspiral remnant fits": remnant_dict["postinspiral"] 

1639 } 

1640 

    def __init__(self, opts):
        """Parse the summarytgr command line options for the IMRCT test.

        Parameters
        ----------
        opts: argparse.Namespace
            parsed command line options
        """
        self.opts = opts
        self.existing = None
        self.webdir = self.opts.webdir
        self.user = None
        self.baseurl = None
        self.result_files = self.opts.samples
        # NOTE: the labels and samples assignments invoke property setters
        # with validation/IO side effects, so the order here matters
        self.labels = self.opts.labels
        self.samples = self.opts.samples
        # identify the inspiral analyses and pair them with the matching
        # postinspiral analyses
        self.inspiral_keys = [
            key for key in self.samples.keys() if "inspiral" in key
            and "postinspiral" not in key
        ]
        self.postinspiral_keys = [
            key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
        ]
        try:
            self.imrct_kwargs = self.opts.imrct_kwargs
        except AttributeError:
            # option not provided on the command line; use the defaults
            self.imrct_kwargs = {}
        for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
            _attr = getattr(self.opts, _arg)
            # when provided, there must be one value per result file
            if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
                raise ValueError("Please provide a {} for each file".format(_arg))
            setattr(self, _arg, _attr)
        self.meta_data = None
        self.default_directories = ["samples", "plots", "js", "html", "css"]
        self.publication = False
        self.make_directories()