Coverage for pesummary/gw/cli/inputs.py: 75.2%

978 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import ast 

4import os 

5import numpy as np 

6import pesummary.core.cli.inputs 

7from pesummary.gw.file.read import read as GWRead 

8from pesummary.gw.file.psd import PSD 

9from pesummary.gw.file.calibration import Calibration 

10from pesummary.utils.decorators import deprecation 

11from pesummary.utils.exceptions import InputError 

12from pesummary.utils.utils import logger 

13from pesummary import conf 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17 

18class _GWInput(pesummary.core.cli.inputs._Input): 

19 """Super class to handle gw specific command line inputs 

20 """ 

@staticmethod
def grab_data_from_metafile(
    existing_file, webdir, compare=None, nsamples=None, **kwargs
):
    """Grab data from an existing PESummary metafile

    Parameters
    ----------
    existing_file: str
        path to the existing metafile
    webdir: str
        the directory to store the existing configuration file
    compare: list, optional
        list of labels for events stored in an existing metafile that you
        wish to compare
    nsamples: int, optional
        number of samples to draw from each analysis
    """
    # substitute stored PSDs into the conversion kwargs pulled from the file
    _replace_kwargs = {"psd": "{file}.psd['{label}']"}
    if "psd_default" in kwargs.keys():
        _replace_kwargs["psd_default"] = kwargs["psd_default"]
    data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
        existing_file, webdir, compare=compare, read_function=GWRead,
        nsamples=nsamples, _replace_with_pesummary_kwargs=_replace_kwargs,
        **kwargs
    )
    f = GWRead(existing_file)
    labels = data["labels"]

    # collect per-label PSD objects stored in the metafile (if any)
    psd = {label: {} for label in labels}
    if f.psd is not None and f.psd != {}:
        for label in labels:
            if label in f.psd.keys() and f.psd[label] != {}:
                psd[label] = {
                    ifo: PSD(f.psd[label][ifo]) for ifo in f.psd[label].keys()
                }
    # collect per-label calibration envelopes stored in the metafile (if any)
    calibration = {label: {} for label in labels}
    if f.calibration is not None and f.calibration != {}:
        for label in labels:
            if label in f.calibration.keys() and f.calibration[label] != {}:
                calibration[label] = {
                    ifo: Calibration(f.calibration[label][ifo])
                    for ifo in f.calibration[label].keys()
                }
    # collect per-label skymaps stored in the metafile (if any)
    skymap = {label: None for label in labels}
    if hasattr(f, "skymap") and f.skymap is not None and f.skymap != {}:
        for label in labels:
            if label in f.skymap.keys() and len(f.skymap[label]):
                skymap[label] = f.skymap[label]
    # NOTE: "indicies" is the (misspelled) key used by the core metafile
    # reader -- do not "fix" it here
    approximant = {
        label: f.approximant[ind]
        for label, ind in zip(labels, data["indicies"])
    }
    data.update(
        {
            "approximant": approximant,
            "psd": psd,
            "calibration": calibration,
            "skymap": skymap,
        }
    )
    return data

84 

@property
def grab_data_kwargs(self):
    """Return a dictionary of keyword arguments, keyed by analysis label,
    to pass to the sample-conversion machinery.

    Single-valued frequency options (`f_start`, `f_low`, `f_ref`,
    `f_final`, `delta_f`) are broadcast across all labels. If any
    per-label list turns out to be too short (`IndexError`), the values
    for the first analysis are reused for every result file, with a
    warning.
    """
    kwargs = super(_GWInput, self).grab_data_kwargs
    # broadcast single-valued frequency settings across every label
    for _property in ["f_start", "f_low", "f_ref", "f_final", "delta_f"]:
        if getattr(self, _property) is None:
            setattr(self, "_{}".format(_property), [None] * len(self.labels))
        elif len(getattr(self, _property)) == 1 and len(self.labels) != 1:
            setattr(
                self, "_{}".format(_property),
                getattr(self, _property) * len(self.labels)
            )
    if self.opts.approximant is None:
        approx = [None] * len(self.labels)
    else:
        approx = self.opts.approximant
    self.approximant_flags = self.opts.approximant_flags
    resume_file = [
        os.path.join(
            self.webdir, "checkpoint",
            "{}_conversion_class.pickle".format(label)
        ) for label in self.labels
    ]

    def _conversion_kwargs(idx, num, psd, **extra):
        # Single source of truth for the conversion kwargs. 'idx' selects
        # the per-label frequency/approximant settings, 'num' selects the
        # per-label resume file; previously this dict was duplicated
        # verbatim in both the normal and fallback paths below.
        return dict(
            evolve_spins_forwards=self.evolve_spins_forwards,
            evolve_spins_backwards=self.evolve_spins_backwards,
            f_start=self.f_start[idx], f_low=self.f_low[idx],
            approximant=approx[idx], f_ref=self.f_ref[idx],
            NRSur_fits=self.NRSur_fits, return_kwargs=True,
            multipole_snr=self.calculate_multipole_snr,
            precessing_snr=self.calculate_precessing_snr,
            psd=psd, f_final=self.f_final[idx],
            waveform_fits=self.waveform_fits,
            multi_process=self.opts.multi_process,
            redshift_method=self.redshift_method,
            cosmology=self.cosmology,
            no_conversion=self.no_conversion,
            add_zero_spin=True, delta_f=self.delta_f[idx],
            psd_default=self.psd_default,
            disable_remnant=self.disable_remnant,
            force_BBH_remnant_computation=self.force_BBH_remnant_computation,
            resume_file=resume_file[num],
            restart_from_checkpoint=self.restart_from_checkpoint,
            force_BH_spin_evolution=self.force_BH_spin_evolution,
            **extra
        )

    try:
        for num, label in enumerate(self.labels):
            try:
                psd = self.psd[label]
            except KeyError:
                psd = {}
            kwargs[label].update(_conversion_kwargs(
                num, num, psd,
                approximant_flags=self.approximant_flags.get(label, {})
            ))
        return kwargs
    except IndexError:
        # a per-label list was shorter than self.labels: fall back to the
        # first analysis' settings for every result file
        logger.warning(
            "Unable to find an f_ref, f_low and approximant for each "
            "label. Using f_ref={}, f_low={} and approximant={} "
            "for all result files".format(
                self.f_ref[0], self.f_low[0], approx[0]
            )
        )
        for num, label in enumerate(self.labels):
            kwargs[label].update(_conversion_kwargs(
                0, num, self.psd[self.labels[0]]
            ))
        return kwargs

170 

@staticmethod
def grab_data_from_file(
    file, label, webdir, config=None, injection=None, file_format=None,
    nsamples=None, disable_prior_sampling=False, **kwargs
):
    """Grab data from a result file containing posterior samples

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    webdir: str
        the directory to store the configuration file
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    file_format: str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options
    nsamples: int, optional
        number of samples to draw from the result file
    disable_prior_sampling: bool, optional
        if True, do not draw samples from the analytic prior
    """
    data = pesummary.core.cli.inputs._Input.grab_data_from_file(
        file, label, webdir, config=config, injection=injection,
        read_function=GWRead, file_format=file_format, nsamples=nsamples,
        disable_prior_sampling=disable_prior_sampling, **kwargs
    )
    try:
        for _kwgs in data["file_kwargs"][label]:
            if "approximant_flags" in _kwgs["meta_data"]:
                # BUG FIX: this previously read the undefined name
                # '_kwargs' (note the spelling), so the whole loop died
                # with a NameError that the except below swallowed.
                for key, item in _kwgs["meta_data"]["approximant_flags"].items():
                    # NOTE(review): 'self' is not defined inside a
                    # staticmethod, so merging into
                    # self.approximant_flags can still never succeed
                    # here; the merge logic likely belongs on an
                    # instance method -- TODO confirm intended design.
                    warning_cond = (
                        key in self.approximant_flags[label] and
                        self.approximant_flags[label][key] != item
                    )
                    if warning_cond:
                        logger.warning(
                            "Approximant flag {}={} found in result file for {}. "
                            "Ignoring and using the provided values {}={}".format(
                                key, self.approximant_flags[label][key], label,
                                key, item
                            )
                        )
                    else:
                        self.approximant_flags[label][key] = item
    except Exception as e:
        # best effort: failure to merge flags must not abort the read,
        # but leave a trace rather than silently passing
        logger.debug(
            "Unable to extract approximant flags from {} because {}".format(
                file, e
            )
        )
    return data

218 

@property
def reweight_samples(self):
    """Method (if any) used to reweight the posterior samples."""
    return self._reweight_samples

@reweight_samples.setter
def reweight_samples(self, reweight_samples):
    # validate the requested method against the gw-specific options
    from pesummary.gw.reweight import options
    self._reweight_samples = self._check_reweight_samples(
        reweight_samples, options
    )

229 

def _set_samples(self, *args, **kwargs):
    """Load the samples as in the core class, then guarantee that a
    (possibly empty) calibration prior entry exists for every label."""
    super(_GWInput, self)._set_samples(*args, **kwargs)
    if "calibration" not in self.priors:
        self.priors["calibration"] = {
            label: {} for label in self.labels
        }

236 

def _set_corner_params(self, corner_params):
    """Return the parameters to show on the corner plot, extending any
    user-supplied list with the gw defaults and dropping parameters that
    are not present in every analysis."""
    corner_params = super(_GWInput, self)._set_corner_params(corner_params)
    if corner_params is None:
        logger.debug(
            "Using the default corner parameters: {}".format(
                ", ".join(conf.gw_corner_parameters)
            )
        )
    else:
        requested = corner_params
        # merge with the defaults; set() removes duplicates
        corner_params = list(set(conf.gw_corner_parameters + corner_params))
        for param in requested:
            # only keep a requested parameter if every analysis has it
            samples = self.samples
            if not all(param in samples[label].keys() for label in self.labels):
                corner_params.remove(param)
        logger.debug(
            "Generating a corner plot with the following "
            "parameters: {}".format(", ".join(corner_params))
        )
    return corner_params

257 

@property
def cosmology(self):
    """Name of the cosmology used for redshift/distance conversions."""
    return self._cosmology

@cosmology.setter
def cosmology(self, cosmology):
    from pesummary.gw.cosmology import available_cosmologies

    if cosmology.lower() not in available_cosmologies:
        # fall back to the package default rather than failing
        logger.warning(
            "Unrecognised cosmology: {}. Using {} as default".format(
                cosmology, conf.cosmology
            )
        )
        cosmology = conf.cosmology
    else:
        logger.debug("Using the {} cosmology".format(cosmology))
    self._cosmology = cosmology

276 

@property
def approximant(self):
    """Dictionary mapping each label to its waveform approximant."""
    return self._approximant

@approximant.setter
def approximant(self, approximant):
    if not hasattr(self, "_approximant"):
        # first assignment: build the label -> approximant mapping
        approximant_list = {label: {} for label in self.labels}
        if approximant is None:
            logger.warning(
                "No approximant passed. Waveform plots will not be "
                "generated"
            )
        else:
            if len(approximant) != len(self.labels):
                raise InputError(
                    "Please pass an approximant for each result file"
                )
            approximant_list = {
                label: approx
                for label, approx in zip(self.labels, approximant)
            }
        self._approximant = approximant_list
    else:
        # subsequent assignment: normalise the first still-empty entry
        # to None (the 'break' means only one entry is touched per call)
        for num, label in enumerate(self._approximant.keys()):
            if self._approximant[label] == {}:
                if num == 0:
                    logger.warning(
                        "No approximant passed. Waveform plots will not be "
                        "generated"
                    )
                self._approximant[label] = None
                break

309 

@property
def approximant_flags(self):
    """Dictionary mapping each label to its approximant flag settings."""
    return self._approximant_flags

@approximant_flags.setter
def approximant_flags(self, approximant_flags):
    # flags are only parsed once; later assignments are ignored
    if hasattr(self, "_approximant_flags"):
        return
    parsed = {key: {} for key in self.labels}
    for key, item in approximant_flags.items():
        # entries arrive as "LABEL:FLAG" -> VALUE
        _label, key = key.split(":")
        if _label not in self.labels:
            raise ValueError(
                "Unable to assign waveform flags for label:{} because "
                "it does not exist. Available labels are: {}. Approximant "
                "flags must be provided in the form LABEL:FLAG:VALUE".format(
                    _label, ", ".join(self.labels)
                )
            )
        parsed[_label][key] = item
    self._approximant_flags = parsed

331 

@property
def gracedb_server(self):
    """URL of the GraceDB service to query."""
    return self._gracedb_server

@gracedb_server.setter
def gracedb_server(self, gracedb_server):
    if gracedb_server is None:
        # fall back to the package-wide default server
        self._gracedb_server = conf.gracedb_server
    else:
        logger.debug(
            "Using '{}' as the GraceDB server".format(gracedb_server)
        )
        self._gracedb_server = gracedb_server

345 

@property
def gracedb(self):
    """GraceDB ID associated with the analyses (or None)."""
    return self._gracedb

@gracedb.setter
def gracedb(self, gracedb):
    self._gracedb = gracedb
    if gracedb is None:
        return
    from pesummary.gw.gracedb import get_gracedb_data, HTTPError
    from json.decoder import JSONDecodeError

    # GraceDB IDs look like G0000/g0000 (events) or S0000 (superevents)
    first_letter = gracedb[0]
    if first_letter not in ("G", "g", "S"):
        logger.warning(
            "Invalid GraceDB ID passed. The GraceDB ID must be of the "
            "form G0000 or S0000. Ignoring input."
        )
        self._gracedb = None
        return
    _error = (
        "Unable to download data from Gracedb because {}. Only storing "
        "the GraceDB ID in the metafile"
    )
    try:
        logger.info(
            "Downloading {} from gracedb for {}".format(
                ", ".join(self.gracedb_data), gracedb
            )
        )
        json = get_gracedb_data(
            gracedb, info=self.gracedb_data,
            service_url=self.gracedb_server
        )
        json["id"] = gracedb
    except (HTTPError, RuntimeError, JSONDecodeError) as e:
        # best effort: keep only the ID if the download fails
        logger.warning(_error.format(e))
        json = {"id": gracedb}

    for label in self.labels:
        self.file_kwargs[label]["meta_data"]["gracedb"] = json

386 

@property
def terrestrial_probability(self):
    """List of terrestrial probabilities, one per analysis label."""
    return self._terrestrial_probability

@terrestrial_probability.setter
def terrestrial_probability(self, terrestrial_probability):
    if terrestrial_probability is None and self.gracedb is not None:
        # nothing provided but we know the event: try to fetch p_astro
        logger.info(
            "No terrestrial probability provided. Trying to download "
            "from gracedb"
        )
        from pesummary.core.fetch import download_and_read_file
        from urllib.error import HTTPError
        from json.decoder import JSONDecodeError
        import json
        try:
            ff = download_and_read_file(
                f"{self.gracedb_server}/superevents/{self.gracedb}/"
                f"files/p_astro.json", read_file=False,
                outdir=f"{self.webdir}/samples", delete_on_exit=False
            )
            with open(ff, "r") as f:
                data = json.load(f)
            self._terrestrial_probability = [float(data["Terrestrial"])]
        except (RuntimeError, JSONDecodeError) as e:
            logger.warning(
                "Unable to grab terrestrial probability from gracedb "
                "because {}".format(e)
            )
            self._terrestrial_probability = [None]
        except HTTPError as e:
            # no top-level p_astro.json: fall back to the pipeline-specific
            # file attached to the preferred event
            from pesummary.gw.gracedb import get_gracedb_data, get_gracedb_file
            try:
                preferred = get_gracedb_data(
                    self.gracedb, info=["preferred_event_data"],
                    service_url=self.gracedb_server
                )["preferred_event_data"]["submitter"]
                _pipelines = [
                    "pycbc", "gstlal", "mbta", "spiir"
                ]
                _filename = None
                for _pipe in _pipelines:
                    if _pipe in preferred:
                        _filename = f"{_pipe}.p_astro.json"
                if _filename is None:
                    raise e
                data = get_gracedb_file(
                    self.gracedb, _filename, service_url=self.gracedb_server
                )
                with open(f"{self.webdir}/samples/{_filename}", "w") as json_file:
                    json.dump(data, json_file)
                self._terrestrial_probability = [float(data["Terrestrial"])]
            except Exception as e:
                logger.warning(
                    "Unable to grab terrestrial probability from gracedb "
                    "because {}".format(e)
                )
                self._terrestrial_probability = [None]
        # the single downloaded value applies to every analysis
        self._terrestrial_probability *= len(self.labels)
    elif terrestrial_probability is None:
        self._terrestrial_probability = [None] * len(self.labels)
    else:
        if len(terrestrial_probability) == 1 and len(self.labels) > 1:
            logger.debug(
                f"Assuming a terrestrial probability: "
                f"{terrestrial_probability} for all analyses"
            )
            self._terrestrial_probability = [
                float(terrestrial_probability[0])
            ] * len(self.labels)
        elif len(terrestrial_probability) == len(self.labels):
            self._terrestrial_probability = [
                float(_) for _ in terrestrial_probability
            ]
        else:
            raise ValueError(
                "Please provide a terrestrial probability for each "
                "analysis, or a single value to be used for all analyses"
            )

466 

@property
def detectors(self):
    """Dictionary mapping each label to its detector network string."""
    return self._detectors

@detectors.setter
def detectors(self, detectors):
    detector = {}
    if not detectors:
        # infer the network from the '<IFO>_optimal_snr' sample parameters
        for label in self.samples.keys():
            params = list(self.samples[label].keys())
            individual_detectors = []
            for param in params:
                if "optimal_snr" in param and param != "network_optimal_snr":
                    individual_detectors.append(
                        param.split("_optimal_snr")[0]
                    )
            individual_detectors = sorted(
                [str(ifo) for ifo in individual_detectors]
            )
            if individual_detectors:
                detector[label] = "_".join(individual_detectors)
            else:
                detector[label] = None
    else:
        detector = detectors
    logger.debug("The detector network is %s" % (detector))
    self._detectors = detector

492 

@property
def skymap(self):
    """Dictionary mapping each label to its stored skymap (or None)."""
    return self._skymap

@skymap.setter
def skymap(self, skymap):
    # initialised once to empty entries; the argument itself is ignored
    if not hasattr(self, "_skymap"):
        self._skymap = {label: None for label in self.labels}

501 

@property
def calibration_definition(self):
    """Dictionary mapping each label to 'data' or 'template'."""
    return self._calibration_definition

@calibration_definition.setter
def calibration_definition(self, calibration_definition):
    # only meaningful when calibration files were supplied
    if not len(self.opts.calibration):
        self._calibration_definition = None
        return
    if len(calibration_definition) == 1:
        # one definition given: apply it to every analysis
        logger.info(
            f"Assuming that the calibration correction was applied to "
            f"'{calibration_definition[0]}' for all analyses"
        )
        calibration_definition *= len(self.labels)
    elif len(calibration_definition) != len(self.labels):
        raise ValueError(
            f"Please provide a calibration definition for each analysis "
            f"({len(self.labels)}) or a single definition to use for all "
            f"analyses"
        )
    if any(_ not in ["data", "template"] for _ in calibration_definition):
        raise ValueError(
            "Calibration definitions must be either 'data' or 'template'"
        )
    self._calibration_definition = {
        label: calibration_definition[num]
        for num, label in enumerate(self.labels)
    }

531 

@property
def calibration(self):
    """Dictionary mapping each label to its calibration posterior data."""
    return self._calibration

@calibration.setter
def calibration(self, calibration):
    # computed once; later assignments are ignored
    if not hasattr(self, "_calibration"):
        data = {label: {} for label in self.labels}
        if calibration != {}:
            # calibration envelopes supplied on the command line
            prior_data = self.get_psd_or_calibration_data(
                calibration, self.extract_calibration_data_from_file,
                type=self.calibration_definition[self.labels[0]]
            )
            self.add_to_prior_dict("calibration", prior_data)
        else:
            # look for per-label '<label>_calibration' command line options
            prior_data = {label: {} for label in self.labels}
            for label in self.labels:
                if hasattr(self.opts, "{}_calibration".format(label)):
                    cal_data = getattr(
                        self.opts, "{}_calibration".format(label)
                    )
                    if cal_data != {} and cal_data is not None:
                        prior_data[label] = {
                            ifo: self.extract_calibration_data_from_file(
                                cal_data[ifo],
                                type=self.calibration_definition[label]
                            ) for ifo in cal_data.keys()
                        }
            if not all(prior_data[label] == {} for label in self.labels):
                self.add_to_prior_dict("calibration", prior_data)
            else:
                self.add_to_prior_dict("calibration", {})
        for num, result_file in enumerate(self.result_files):
            # reuse an already-open file handle when available
            _opened = self._open_result_files
            if result_file in _opened.keys() and _opened[result_file] is not None:
                f = self._open_result_files[result_file]
            else:
                f = GWRead(result_file, disable_prior=True)
            try:
                calibration_data = f.interpolate_calibration_spline_posterior()
            except Exception as e:
                logger.warning(
                    "Failed to extract calibration data from the result "
                    "file: {} because {}".format(result_file, e)
                )
                calibration_data = None
            labels = list(self.samples.keys())
            if calibration_data is None:
                data[labels[num]] = {
                    None: None
                }
            elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                # NOTE(review): the inner loop variable shadows the outer
                # 'num'; behaviour preserved as-is -- TODO confirm intended
                for num in range(len(calibration_data[0])):
                    data[labels[num]] = {
                        ifo: value for ifo, value in zip(
                            calibration_data[1][num],
                            calibration_data[0][num]
                        )
                    }
            else:
                data[labels[num]] = {
                    ifo: value for ifo, value in zip(
                        calibration_data[1], calibration_data[0]
                    )
                }
        self._calibration = data

595 

@property
def psd(self):
    """Dictionary mapping each label to its PSD data per IFO."""
    return self._psd

@psd.setter
def psd(self, psd):
    # computed once; later assignments are ignored
    if not hasattr(self, "_psd"):
        data = {label: {} for label in self.labels}
        if psd != {}:
            # PSD files supplied on the command line
            data = self.get_psd_or_calibration_data(
                psd, self.extract_psd_data_from_file
            )
        else:
            # look for per-label '<label>_psd' command line options
            for label in self.labels:
                if hasattr(self.opts, "{}_psd".format(label)):
                    psd_data = getattr(self.opts, "{}_psd".format(label))
                    if psd_data != {} and psd_data is not None:
                        data[label] = {
                            ifo: self.extract_psd_data_from_file(
                                psd_data[ifo], IFO=ifo
                            ) for ifo in psd_data.keys()
                        }
        self._psd = data

619 

@property
def nsamples_for_skymap(self):
    """Number of posterior samples used to produce the skymap."""
    return self._nsamples_for_skymap

@nsamples_for_skymap.setter
def nsamples_for_skymap(self, nsamples_for_skymap):
    self._nsamples_for_skymap = nsamples_for_skymap
    if nsamples_for_skymap is not None:
        self._nsamples_for_skymap = int(nsamples_for_skymap)
        number_of_samples = [
            data.number_of_samples for label, data in self.samples.items()
        ]
        # cap the request at the smallest analysis' sample count
        if not all(i > self._nsamples_for_skymap for i in number_of_samples):
            min_arg = np.argmin(number_of_samples)
            logger.warning(
                "You have specified that you would like to use {} "
                "samples to generate the skymap but the file {} only "
                "has {} samples. Reducing the number of samples to "
                "generate the skymap to {}".format(
                    self._nsamples_for_skymap, self.result_files[min_arg],
                    number_of_samples[min_arg], number_of_samples[min_arg]
                )
            )
            self._nsamples_for_skymap = int(number_of_samples[min_arg])

644 

@property
def gwdata(self):
    """Strain data associated with the analyses (or None)."""
    return self._gwdata

@gwdata.setter
def gwdata(self, gwdata):
    from pesummary.gw.file.strain import StrainDataDict

    self._gwdata = gwdata
    if gwdata is None:
        return
    if isinstance(gwdata, dict):
        # every value must be an existing path to a strain file
        for channel in gwdata.keys():
            if not os.path.isfile(gwdata[channel]):
                raise InputError(
                    "The file {} does not exist. Please check the path "
                    "to your strain file".format(gwdata[channel])
                )
        self._gwdata = StrainDataDict.read(gwdata)
    else:
        logger.warning(
            "Please provide gwdata as a dictionary with keys "
            "displaying the channel and item giving the path to the "
            "strain file"
        )
        self._gwdata = None

670 

@property
def evolve_spins_forwards(self):
    """Target (dimensionless) frequency for forwards spin evolution."""
    return self._evolve_spins_forwards

@evolve_spins_forwards.setter
def evolve_spins_forwards(self, evolve_spins_forwards):
    self._evolve_spins_forwards = evolve_spins_forwards
    _msg = "Spins will be evolved up to {}"
    if evolve_spins_forwards:
        logger.info(_msg.format("Schwarzschild ISCO frequency"))
        # v = 1/sqrt(6): the Schwarzschild ISCO velocity
        self._evolve_spins_forwards = 6. ** -0.5

682 

@property
def evolve_spins_backwards(self):
    """Method used for backwards (infinite separation) spin evolution."""
    return self._evolve_spins_backwards

@evolve_spins_backwards.setter
def evolve_spins_backwards(self, evolve_spins_backwards):
    self._evolve_spins_backwards = evolve_spins_backwards
    _msg = (
        "Spins will be evolved backwards to an infinite separation using the '{}' "
        "method"
    )
    if isinstance(evolve_spins_backwards, (str, bytes)):
        logger.info(_msg.format(evolve_spins_backwards))
    elif evolve_spins_backwards is None:
        # None means "use the default method"
        logger.info(_msg.format("precession_averaged"))
        self._evolve_spins_backwards = "precession_averaged"

699 

@property
def NRSur_fits(self):
    """NRSurrogate model used for remnant fits."""
    return self._NRSur_fits

@NRSur_fits.setter
def NRSur_fits(self, NRSur_fits):
    self._NRSur_fits = NRSur_fits
    base = (
        "Using the '{}' NRSurrogate model to calculate the remnant "
        "quantities"
    )
    if isinstance(NRSur_fits, (str, bytes)):
        logger.info(base.format(NRSur_fits))
        self._NRSur_fits = NRSur_fits
    elif NRSur_fits is None:
        # None means "use the package default model"
        from pesummary.gw.conversions.nrutils import NRSUR_MODEL

        logger.info(base.format(NRSUR_MODEL))
        self._NRSur_fits = NRSUR_MODEL

719 

@property
def waveform_fits(self):
    """Whether to compute remnant quantities with the chosen approximant."""
    return self._waveform_fits

@waveform_fits.setter
def waveform_fits(self, waveform_fits):
    self._waveform_fits = waveform_fits
    if waveform_fits:
        logger.info(
            "Evaluating the remnant quantities using the provided "
            "approximant"
        )

732 

@property
def f_low(self):
    """List of low-frequency cutoffs, one per analysis (or None)."""
    return self._f_low

@f_low.setter
def f_low(self, f_low):
    self._f_low = f_low
    if f_low is not None:
        # coerce every entry to float (values may arrive as strings)
        self._f_low = list(map(float, f_low))

742 

@property
def f_start(self):
    """List of spin-evolution start frequencies, one per analysis."""
    return self._f_start

@f_start.setter
def f_start(self, f_start):
    self._f_start = f_start
    if f_start is not None:
        # coerce every entry to float (values may arrive as strings)
        self._f_start = list(map(float, f_start))

752 

@property
def f_ref(self):
    """List of reference frequencies, one per analysis (or None)."""
    return self._f_ref

@f_ref.setter
def f_ref(self, f_ref):
    self._f_ref = f_ref
    if f_ref is not None:
        # coerce every entry to float (values may arrive as strings)
        self._f_ref = list(map(float, f_ref))

762 

@property
def f_final(self):
    """List of maximum frequencies, one per analysis (or None)."""
    return self._f_final

@f_final.setter
def f_final(self, f_final):
    self._f_final = f_final
    if f_final is not None:
        # coerce every entry to float (values may arrive as strings)
        self._f_final = list(map(float, f_final))

772 

@property
def delta_f(self):
    """List of frequency spacings, one per analysis (or None)."""
    return self._delta_f

@delta_f.setter
def delta_f(self, delta_f):
    self._delta_f = delta_f
    if delta_f is not None:
        # coerce every entry to float (values may arrive as strings)
        self._delta_f = list(map(float, delta_f))

782 

@property
def psd_default(self):
    """Default PSD: a pycbc PSD function, a stored-PSD template, or None."""
    return self._psd_default

@psd_default.setter
def psd_default(self, psd_default):
    self._psd_default = psd_default
    if "stored:" in psd_default:
        # 'stored:LABEL' refers to a PSD saved in an existing metafile
        label = psd_default.split("stored:")[1]
        self._psd_default = "{file}.psd['%s']" % (label)
        return
    try:
        from pycbc import psd
        psd_default = getattr(psd, psd_default)
    except ImportError:
        logger.warning(
            "Unable to import 'pycbc'. Unable to generate a default PSD"
        )
        psd_default = None
    except AttributeError:
        # pycbc imported but the named PSD does not exist: use the default
        logger.warning(
            "'pycbc' does not have the '{}' psd available. Using '{}' as "
            "the default PSD".format(psd_default, conf.psd)
        )
        psd_default = getattr(psd, conf.psd)
    except ValueError as e:
        logger.warning("Setting 'psd_default' to None because {}".format(e))
        psd_default = None
    self._psd_default = psd_default

812 

@property
def pastro_probs(self):
    """Source classification probabilities, one entry per analysis."""
    return self._pastro_probs

@pastro_probs.setter
def pastro_probs(self, pastro_probs):
    from pesummary.gw.classification import PAstro

    classifications = {}
    for num, label in enumerate(list(self.samples.keys())):
        # try to rebuild the luminosity-distance prior from its repr so the
        # classification can use the analysis' own prior
        try:
            import importlib
            distance_prior = self.priors["analytic"]["luminosity_distance"]
            cls = distance_prior.split("(")[0]
            module = ".".join(cls.split(".")[:-1])
            cls = cls.split(".")[-1]
            cls = getattr(importlib.import_module(module), cls, cls)
            args = "(".join(distance_prior.split("(")[1:])[:-1]
            distance_prior = cls.from_repr(args)
        except KeyError:
            logger.debug(
                f"Unable to find a distance prior. Defaulting to stored "
                f"prior in pesummary.gw.classification.PAstro for "
                f"source classification probabilities"
            )
            distance_prior = None
        except AttributeError:
            logger.debug(
                f"Unable to load distance prior: {distance_prior}. "
                f"Defaulting to stored prior in "
                f"pesummary.gw.classification.PAstro for source "
                f"classification probabilities"
            )
            distance_prior = None
        try:
            _cls = PAstro(
                self.samples[label], category_data=self.pastro_category_file,
                terrestrial_probability=self.terrestrial_probability[num],
                distance_prior=distance_prior,
                catch_terrestrial_probability_error=self.catch_terrestrial_probability_error
            )
            classifications[label] = {"default": _cls.classification()}
            try:
                _cls.save_to_file(
                    f"{label}.pesummary.p_astro.json",
                    classifications[label]["default"],
                    outdir=f"{self.webdir}/samples",
                    overwrite=True
                )
            except FileNotFoundError as e:
                logger.warning(
                    f"Failed to write PAstro probabilities to file "
                    f"because {e}"
                )
        except Exception as e:
            # fall back to the package defaults rather than failing
            logger.warning(
                "Failed to generate source classification probabilities "
                "because {}".format(e)
            )
            classifications[label] = {"default": PAstro.defaults}
    if self.mcmc_samples:
        # with mcmc chains, report a single averaged classification
        if any(_probs is None for _probs in classifications.values()):
            classifications[self.labels[0]] = None
            logger.warning(
                "Unable to average classification probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "classifications"
            )
        else:
            logger.debug(
                "Averaging classification probability across mcmc samples"
            )
            classifications[self.labels[0]] = {
                prior: {
                    key: np.round(np.mean(
                        [val[prior][key] for val in classifications.values()]
                    ), 3) for key in _probs.keys()
                } for prior, _probs in
                list(classifications.values())[0].items()
            }
    self._pastro_probs = classifications

894 

@property
def embright_probs(self):
    """EM-bright probabilities, one entry per analysis."""
    return self._embright_probs

@embright_probs.setter
def embright_probs(self, embright_probs):
    from pesummary.gw.classification import EMBright

    probabilities = {}
    for num, label in enumerate(list(self.samples.keys())):
        try:
            _cls = EMBright(self.samples[label])
            probabilities[label] = {"default": _cls.classification()}
            try:
                _cls.save_to_file(
                    f"{label}.pesummary.em_bright.json",
                    probabilities[label]["default"],
                    outdir=f"{self.webdir}/samples",
                    overwrite=True
                )
            except FileNotFoundError as e:
                logger.warning(
                    f"Failed to write EM bright probabilities to file "
                    f"because {e}"
                )
        except Exception as e:
            # fall back to the package defaults rather than failing
            logger.warning(
                "Failed to generate em_bright probabilities because "
                "{}".format(e)
            )
            probabilities[label] = {"default": EMBright.defaults}
    if self.mcmc_samples:
        # with mcmc chains, report a single averaged probability
        if any(_probs is None for _probs in probabilities.values()):
            probabilities[self.labels[0]] = None
            logger.warning(
                "Unable to average em_bright probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "probabilities"
            )
        else:
            logger.debug(
                "Averaging em_bright probability across mcmc samples"
            )
            probabilities[self.labels[0]] = {
                prior: {
                    key: np.round(np.mean(
                        [val[prior][key] for val in probabilities.values()]
                    ), 3) for key in _probs.keys()
                } for prior, _probs in list(probabilities.values())[0].items()
            }
    self._embright_probs = probabilities

946 

@property
def preliminary_pages(self):
    """Dictionary mapping each label to whether its pages are preliminary."""
    return self._preliminary_pages

@preliminary_pages.setter
def preliminary_pages(self, preliminary_pages):
    # an analysis is 'preliminary' when any reproducibility input
    # (e.g. psd/config, per conf.gw_reproducibility) is missing or empty
    required = conf.gw_reproducibility
    self._preliminary_pages = {label: False for label in self.labels}
    for num, label in enumerate(self.labels):
        for attr in required:
            _property = getattr(self, attr)
            if isinstance(_property, dict):
                if label not in _property.keys():
                    self._preliminary_pages[label] = True
                elif not len(_property[label]):
                    self._preliminary_pages[label] = True
            elif isinstance(_property, list):
                if _property[num] is None:
                    self._preliminary_pages[label] = True
    if any(value for value in self._preliminary_pages.values()):
        _labels = [
            label for label, value in self._preliminary_pages.items() if
            value
        ]
        msg = (
            "Unable to reproduce the {} analys{} because no {} data was "
            "provided. 'Preliminary' watermarks will be added to the final "
            "html pages".format(
                ", ".join(_labels), "es" if len(_labels) > 1 else "is",
                " or ".join(required)
            )
        )
        logger.warning(msg)

980 

981 @staticmethod 

982 def _extract_IFO_data_from_file(file, cls, desc, IFO=None, **kwargs): 

983 """Return IFO data stored in a file 

984 

985 Parameters 

986 ---------- 

987 file: path 

988 path to a file containing the IFO data 

989 cls: obj 

990 class you wish to use when loading the file. This class must have 

991 a '.read' method 

992 desc: str 

993 description of the IFO data stored in the file 

994 IFO: str, optional 

995 the IFO which the data belongs to 

996 """ 

997 general = ( 

998 "Failed to read in %s data because {}. The %s plot will not be " 

999 "generated and the %s data will not be added to the metafile." 

1000 ) % (desc, desc, desc) 

1001 try: 

1002 return cls.read(file, IFO=IFO, **kwargs) 

1003 except FileNotFoundError: 

1004 logger.warning( 

1005 general.format("the file {} does not exist".format(file)) 

1006 ) 

1007 return {} 

1008 except ValueError as e: 

1009 logger.warning(general.format(e)) 

1010 return {} 

1011 

    @staticmethod
    def extract_psd_data_from_file(file, IFO=None):
        """Return the data stored in a psd file

        Parameters
        ----------
        file: path
            path to a file containing the psd data
        IFO: str, optional
            the IFO which the data belongs to
        """
        from pesummary.gw.file.psd import PSD
        # delegate to the shared loader; returns {} (with a warning) on failure
        return _GWInput._extract_IFO_data_from_file(file, PSD, "PSD", IFO=IFO)

1023 

    @staticmethod
    def extract_calibration_data_from_file(file, type="data", **kwargs):
        """Return the data stored in a calibration file

        Parameters
        ----------
        file: path
            path to a file containing the calibration data
        type: str, optional
            type of calibration data to read; passed through to
            Calibration.read. Default "data"
        """
        from pesummary.gw.file.calibration import Calibration
        # delegate to the shared loader; returns {} (with a warning) on failure
        return _GWInput._extract_IFO_data_from_file(
            file, Calibration, "calibration", type=type, **kwargs
        )

1037 

1038 @staticmethod 

1039 def get_ifo_from_file_name(file): 

1040 """Return the IFO from the file name 

1041 

1042 Parameters 

1043 ---------- 

1044 file: str 

1045 path to the file 

1046 """ 

1047 file_name = file.split("/")[-1] 

1048 if any(j in file_name for j in ["H1", "_0", "IFO0"]): 

1049 ifo = "H1" 

1050 elif any(j in file_name for j in ["L1", "_1", "IFO1"]): 

1051 ifo = "L1" 

1052 elif any(j in file_name for j in ["V1", "_2", "IFO2"]): 

1053 ifo = "V1" 

1054 else: 

1055 ifo = file_name 

1056 return ifo 

1057 

1058 def get_psd_or_calibration_data(self, input, executable, **kwargs): 

1059 """Return a dictionary containing the psd or calibration data 

1060 

1061 Parameters 

1062 ---------- 

1063 input: list/dict 

1064 list/dict containing paths to calibration/psd files 

1065 executable: func 

1066 executable that is used to extract the data from the calibration/psd 

1067 files 

1068 """ 

1069 data = {} 

1070 if input == {} or input == []: 

1071 return data 

1072 if isinstance(input, dict): 

1073 keys = list(input.keys()) 

1074 if isinstance(input, dict) and isinstance(input[keys[0]], list): 

1075 if not all(len(input[i]) == len(self.labels) for i in list(keys)): 

1076 raise InputError( 

1077 "Please ensure the number of calibration/psd files matches " 

1078 "the number of result files passed" 

1079 ) 

1080 for idx in range(len(input[keys[0]])): 

1081 data[self.labels[idx]] = { 

1082 i: executable(input[i][idx], IFO=i, **kwargs) for i in list(keys) 

1083 } 

1084 elif isinstance(input, dict): 

1085 for i in self.labels: 

1086 data[i] = { 

1087 j: executable(input[j], IFO=j, **kwargs) for j in list(input.keys()) 

1088 } 

1089 elif isinstance(input, list): 

1090 for i in self.labels: 

1091 data[i] = { 

1092 self.get_ifo_from_file_name(j): executable( 

1093 j, IFO=self.get_ifo_from_file_name(j), **kwargs 

1094 ) for j in input 

1095 } 

1096 else: 

1097 raise InputError( 

1098 "Did not understand the psd/calibration input. Please use the " 

1099 "following format 'H1:path/to/file'" 

1100 ) 

1101 return data 

1102 

    def grab_priors_from_inputs(self, priors):
        """Read in prior files using the GW specific reader.

        Delegates to the core implementation but supplies a read function
        that also derives all posterior quantities, so the prior parameters
        match the converted posterior parameters.

        Parameters
        ----------
        priors: list/dict
            prior file input as accepted by the core implementation
        """
        def read_func(data, **kwargs):
            # GW specific reader + conversion of all derivable parameters
            from pesummary.gw.file.read import read as GWRead
            data = GWRead(data, **kwargs)
            data.generate_all_posterior_samples()
            return data

        return super(_GWInput, self).grab_priors_from_inputs(
            priors, read_func=read_func, read_kwargs=self.grab_data_kwargs
        )

1113 

    def grab_key_data_from_result_files(self):
        """Grab the mean, median, maxL and standard deviation for all
        parameters for each result file, and add a 90% HPD credible
        interval for selected physically-bounded parameters
        """
        from pesummary.utils.kde_list import KDEList
        from pesummary.gw.plots.plot import _return_bounds
        from pesummary.utils.credible_interval import (
            hpd_two_sided_credible_interval
        )
        from pesummary.utils.bounded_1d_kde import bounded_1d_kde
        key_data = super(_GWInput, self).grab_key_data_from_result_files()
        # parameters with hard physical boundaries need a bounded KDE for a
        # meaningful highest-posterior-density interval
        bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
        for param in bounded_parameters:
            xlow, xhigh = _return_bounds(param, [])
            # only analyses that actually provide this parameter
            _samples = {
                key: val[param] for key, val in self.samples.items()
                if param in val.keys()
            }
            _min = [np.min(_) for _ in _samples.values() if len(_samples)]
            _max = [np.max(_) for _ in _samples.values() if len(_samples)]
            if not len(_min):
                # no analysis provides this parameter
                continue
            _min = np.min(_min)
            _max = np.max(_max)
            # common evaluation grid spanning every analysis
            x = np.linspace(_min, _max, 1000)
            try:
                kdes = KDEList(
                    list(_samples.values()), kde=bounded_1d_kde,
                    kde_kwargs={"xlow": xlow, "xhigh": xhigh}
                )
            except Exception as e:
                logger.warning(
                    "Unable to compute the HPD interval for {} because {}".format(
                        param, e
                    )
                )
                continue
            pdfs = kdes(x)
            for num, key in enumerate(_samples.keys()):
                # NOTE: xlow/xhigh are reused here for the interval edges;
                # they are recomputed via _return_bounds on the next param
                [xlow, xhigh], _ = hpd_two_sided_credible_interval(
                    [], 90, x=x, pdf=pdfs[num]
                )
                key_data[key][param]["90% HPD"] = [xlow, xhigh]
                # non-bounded parameters get a NaN placeholder so every
                # parameter carries the "90% HPD" key
                for _param in self.samples[key].keys():
                    if _param in bounded_parameters:
                        continue
                    key_data[key][_param]["90% HPD"] = float("nan")
        return key_data

1162 

1163 

class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class to handle and store sample specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # file copying is deferred to the webpage/metafile stages
        kwargs.update({"ignore_copy": True})
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=[
                "evolve_spins_forwards",
                "evolve_spins_backwards",
                "NRSur_fits",
                "calculate_multipole_snr",
                "calculate_precessing_snr",
                "f_start",
                "f_low",
                "f_ref",
                "f_final",
                "psd",
                "waveform_fits",
                "redshift_method",
                "cosmology",
                "no_conversion",
                "delta_f",
                "psd_default",
                "disable_remnant",
                "force_BBH_remnant_computation",
                "force_BH_spin_evolution"
            ], **kwargs
        )
        if self._restarted_from_checkpoint:
            # state below was restored from the checkpoint file; skip setup
            return
        if self.existing is not None:
            # grab GW specific data from the existing metafile so it can be
            # combined with the new results
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_approximant = self.existing_data["approximant"]
            self.existing_psd = self.existing_data["psd"]
            self.existing_calibration = self.existing_data["calibration"]
            self.existing_skymap = self.existing_data["skymap"]
        else:
            self.existing_approximant = None
            self.existing_psd = None
            self.existing_calibration = None
            self.existing_skymap = None
        self.approximant = self.opts.approximant
        self.gracedb_server = self.opts.gracedb_server
        self.gracedb_data = self.opts.gracedb_data
        self.gracedb = self.opts.gracedb
        self.pastro_category_file = self.opts.pastro_category_file
        self.terrestrial_probability = self.opts.terrestrial_probability
        self.catch_terrestrial_probability_error = self.opts.catch_terrestrial_probability_error
        self.approximant_flags = self.opts.approximant_flags
        self.detectors = None
        self.skymap = None
        self.calibration_definition = self.opts.calibration_definition
        self.calibration = self.opts.calibration
        self.gwdata = self.opts.gwdata
        # assigning [] triggers the maxL_samples setter below, which ignores
        # the assigned value and recomputes from the result files
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        # dict mapping label -> {parameter: maxL value, "approximant": name}
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        """Collect the maximum likelihood value of every parameter for each
        result file; the assigned value is ignored.
        """
        key_data = self.grab_key_data_from_result_files()
        maxL_samples = {
            i: {
                j: key_data[i][j]["maxL"] for j in key_data[i].keys()
            } for i in key_data.keys()
        }
        for i in self.labels:
            # record the waveform approximant alongside the maxL values
            maxL_samples[i]["approximant"] = self.approximant[i]
        self._maxL_samples = maxL_samples

1238 

1239 

class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Class to handle and store plotting specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.nsamples_for_skymap = self.opts.nsamples_for_skymap
        self.sensitivity = self.opts.sensitivity
        self.no_ligo_skymap = self.opts.no_ligo_skymap
        self.multi_threading_for_skymap = self.multi_process
        if not self.no_ligo_skymap and self.multi_process > 1:
            # split the available processes between ligo.skymap generation
            # and all other plots
            total = self.multi_process
            self.multi_threading_for_plots = int(total / 2.)
            self.multi_threading_for_skymap = total - self.multi_threading_for_plots
            logger.info(
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots".format(
                    self.multi_threading_for_skymap,
                    "es " if self.multi_threading_for_skymap > 1 else " ",
                    self.multi_threading_for_plots,
                    "es " if self.multi_threading_for_plots > 1 else " "
                )
            )
        self.preliminary_pages = None
        # assigning to these properties triggers their setters, which ignore
        # the assigned value and recompute classification probabilities from
        # the posterior samples (see the embright_probs setter; pastro_probs
        # presumably behaves the same — the .keys() call below relies on it)
        self.pastro_probs = []
        self.embright_probs = []
        self.classification_probs = {}
        for key in self.pastro_probs.keys():
            # merge the p_astro and EM bright probabilities into a single
            # per-analysis classification dictionary
            self.classification_probs[key] = {"default": {}}
            self.classification_probs[key]["default"].update(
                self.pastro_probs[key]["default"]
            )
            self.classification_probs[key]["default"].update(
                self.embright_probs[key]["default"]
            )

1274 

1275 

class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Class to handle and store webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.public = self.opts.public
        # these may already be set when PlottingInput ran first in the MRO
        # (e.g. WebpagePlusPlottingInput); only initialise them when missing
        if not hasattr(self, "preliminary_pages"):
            self.preliminary_pages = None
        if not hasattr(self, "pastro_probs"):
            # assignment triggers the property setter, which ignores the
            # assigned value and computes the probabilities (see the
            # embright_probs setter; pastro_probs presumably matches)
            self.pastro_probs = []
        if not hasattr(self, "embright_probs"):
            self.embright_probs = []
        self.classification_probs = {}
        for key in self.pastro_probs.keys():
            # merge p_astro and EM bright probabilities per analysis
            self.classification_probs[key] = {"default": {}}
            self.classification_probs[key]["default"].update(
                self.pastro_probs[key]["default"]
            )
            self.classification_probs[key]["default"].update(
                self.embright_probs[key]["default"]
            )

1297 

1298 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Class to handle and store webpage and plotting specific command line
    arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpagePlusPlottingInput, self).__init__(*args, **kwargs)

    @property
    def default_directories(self):
        # re-exposed so the combined class resolves this via its own MRO
        return super(WebpagePlusPlottingInput, self).default_directories

    @property
    def default_files_to_copy(self):
        # re-exposed so the combined class resolves this via its own MRO
        return super(WebpagePlusPlottingInput, self).default_files_to_copy

1314 

class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Class to handle and store metafile specific command line arguments
    """
    @property
    def default_directories(self):
        # metafile production also needs somewhere to store the PSD and
        # calibration data
        dirs = super(MetaFileInput, self).default_directories
        dirs.extend(["psds", "calibration"])
        return dirs

    def copy_files(self):
        """Write the stored PSD and calibration data to the web directory
        before delegating to the core copy_files implementation.
        """
        _error = "Failed to save the {} to file"
        for label in self.labels:
            for ifo, psd in self.psd[label].items():
                # only genuine PSD objects know how to serialise themselves
                if not isinstance(psd, PSD):
                    logger.warning(_error.format("{} PSD".format(ifo)))
                    continue
                psd.save_to_file(
                    os.path.join(
                        self.webdir, "psds",
                        "{}_{}_psd.dat".format(label, ifo)
                    )
                )
            calibration = self.priors["calibration"]
            if label in calibration.keys():
                for ifo, envelope in calibration[label].items():
                    if not isinstance(envelope, Calibration):
                        logger.warning(
                            _error.format(
                                "{} calibration envelope".format(ifo)
                            )
                        )
                        continue
                    envelope.save_to_file(
                        os.path.join(
                            self.webdir, "calibration",
                            "{}_{}_cal.txt".format(label, ifo)
                        )
                    )
        return super(MetaFileInput, self).copy_files()

1358 

1359 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Class to handle and store webpage, plotting and metafile specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpagePlusPlottingPlusMetaFileInput, self).__init__(
            *args, **kwargs
        )

    @property
    def default_directories(self):
        # resolved through the MRO (MetaFileInput adds 'psds'/'calibration')
        return super(WebpagePlusPlottingPlusMetaFileInput, self).default_directories

    @property
    def default_files_to_copy(self):
        # resolved through the MRO of the combined class
        return super(WebpagePlusPlottingPlusMetaFileInput, self).default_files_to_copy

1376 

1377 

@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    # deprecated alias kept for backwards compatibility; using it emits the
    # deprecation warning above
    pass

1385 

1386 

1387class IMRCTInput(pesummary.core.cli.inputs._Input): 

1388 """Class to handle the TGR specific command line arguments 

1389 """ 

    @property
    def labels(self):
        # list of labels identifying each result file; validated and possibly
        # rewritten by the setter
        return self._labels

1393 

    @labels.setter
    def labels(self, labels):
        """Validate the inspiral/postinspiral labelling of the input files.

        Sets self.analysis_label (and, for metafile input, rewrites
        self._labels and stores self._meta_file_labels).
        """
        self._labels = labels
        # the test needs an inspiral/postinspiral pair per analysis, so an
        # even number of labels is required
        if len(labels) % 2 != 0:
            raise ValueError(
                "The IMRCT test requires 2 results files for each analysis. "
            )
        elif len(labels) > 2:
            # multiple analyses: every label must be tagged ':inspiral' or
            # ':postinspiral'
            cond = all(
                ":inspiral" in label or ":postinspiral" in label for label in
                labels
            )
            if not cond:
                # NOTE: the '{}' placeholders below are intentionally literal
                raise ValueError(
                    "To compare 2 or more analyses, please provide labels as "
                    "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                    "the analysis label"
                )
            else:
                self.analysis_label = [
                    label.split(":inspiral")[0]
                    for label in labels
                    if ":inspiral" in label and ":postinspiral" not in label
                ]
                if len(self.analysis_label) != len(self.result_files) / 2:
                    raise ValueError(
                        "When comparing more than 2 analyses, labels must "
                        "be of the form '{}:inspiral' and '{}:postinspiral'."
                    )
                logger.info(
                    "Using the labels: {} to distinguish analyses".format(
                        ", ".join(self.analysis_label)
                    )
                )
        elif sorted(labels) != ["inspiral", "postinspiral"]:
            # two labels which are not plain 'inspiral'/'postinspiral': only
            # valid for PESummary metafile input where labels take the form
            # '{pesummary_label}:inspiral'
            if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
                meta_file_labels = []
                for suffix in [":inspiral", ":postinspiral"]:
                    if any(suffix in label for label in labels):
                        ind = [
                            num for num, label in enumerate(labels) if
                            suffix in label
                        ]
                        if len(ind) > 1:
                            raise ValueError(
                                "Please provide a single {} label".format(
                                    suffix.split(":")[1]
                                )
                            )
                        meta_file_labels.append(
                            labels[ind[0]].split(suffix)[0]
                        )
                    else:
                        raise ValueError(
                            "Please provide labels as {inspiral_label}:inspiral "
                            "and {postinspiral_label}:postinspiral where "
                            "inspiral_label and postinspiral_label are the "
                            "PESummary labels for the inspiral and postinspiral "
                            "analyses respectively. "
                        )
                if len(self.result_files) == 1:
                    # both analyses live in the same metafile
                    logger.info(
                        "Using the {} samples for the inspiral analysis and {} "
                        "samples for the postinspiral analysis from the file "
                        "{}".format(
                            meta_file_labels[0], meta_file_labels[1],
                            self.result_files[0]
                        )
                    )
                elif len(self.result_files) == 2:
                    # one metafile per analysis
                    logger.info(
                        "Using the {} samples for the inspiral analysis from "
                        "the file {}. Using the {} samples for the "
                        "postinspiral analysis from the file {}".format(
                            meta_file_labels[0], self.result_files[0],
                            meta_file_labels[1], self.result_files[1]
                        )
                    )
                else:
                    raise ValueError(
                        "Currently, you can only provide at most 2 pesummary "
                        "metafiles. If one is provided, both the inspiral and "
                        "postinspiral are extracted from that single file. If "
                        "two are provided, the inspiral is extracted from one "
                        "file and the postinspiral is extracted from the other."
                    )
                # normalise to the canonical labels and remember which
                # metafile analyses they map to
                self._labels = ["inspiral", "postinspiral"]
                self._meta_file_labels = meta_file_labels
                self.analysis_label = ["primary"]
            else:
                raise ValueError(
                    "The IMRCT test requires an inspiral and postinspiral result "
                    "file. Please indicate which file is the inspiral and which "
                    "is postinspiral by providing these exact labels to the "
                    "summarytgr executable"
                )
        else:
            self.analysis_label = ["primary"]

1492 

1493 def _extract_stored_approximant(self, opened_file, label): 

1494 """Extract the approximant used for a given analysis stored in a 

1495 PESummary metafile 

1496 

1497 Parameters 

1498 ---------- 

1499 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1500 opened metafile that contains the analysis 'label' 

1501 label: str 

1502 analysis label which is stored in the PESummary metafile 

1503 """ 

1504 if opened_file.approximant is not None: 

1505 if label not in opened_file.labels: 

1506 raise ValueError( 

1507 "Invalid label {}. The list of available labels are {}".format( 

1508 label, ", ".join(opened_file.labels) 

1509 ) 

1510 ) 

1511 _index = opened_file.labels.index(label) 

1512 return opened_file.approximant[_index] 

1513 return 

1514 

1515 def _extract_stored_remnant_fits(self, opened_file, label): 

1516 """Extract the remnant fits used for a given analysis stored in a 

1517 PESummary metafile 

1518 

1519 Parameters 

1520 ---------- 

1521 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1522 opened metafile that contains the analysis 'label' 

1523 label: str 

1524 analysis label which is stored in the PESummary metafile 

1525 """ 

1526 fits = {} 

1527 fit_strings = [ 

1528 "final_spin_NR_fits", "final_mass_NR_fits" 

1529 ] 

1530 if label not in opened_file.labels: 

1531 raise ValueError( 

1532 "Invalid label {}. The list of available labels are {}".format( 

1533 label, ", ".join(opened_file.labels) 

1534 ) 

1535 ) 

1536 _index = opened_file.labels.index(label) 

1537 _meta_data = opened_file.extra_kwargs[_index] 

1538 if "meta_data" in _meta_data.keys(): 

1539 for key in fit_strings: 

1540 if key in _meta_data["meta_data"].keys(): 

1541 fits[key] = _meta_data["meta_data"][key] 

1542 if len(fits): 

1543 return fits 

1544 return 

1545 

1546 def _extract_stored_cutoff_frequency(self, opened_file, label): 

1547 """Extract the cutoff frequencies used for a given analysis stored in a 

1548 PESummary metafile 

1549 

1550 Parameters 

1551 ---------- 

1552 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1553 opened metafile that contains the analysis 'label' 

1554 label: str 

1555 analysis label which is stored in the PESummary metafile 

1556 """ 

1557 frequencies = {} 

1558 if opened_file.config is not None: 

1559 if label not in opened_file.labels: 

1560 raise ValueError( 

1561 "Invalid label {}. The list of available labels are {}".format( 

1562 label, ", ".join(opened_file.labels) 

1563 ) 

1564 ) 

1565 if opened_file.config[label] is not None: 

1566 _config = opened_file.config[label] 

1567 if "config" in _config.keys(): 

1568 if "maximum-frequency" in _config["config"].keys(): 

1569 frequencies["fhigh"] = _config["config"][ 

1570 "maximum-frequency" 

1571 ] 

1572 if "minimum-frequency" in _config["config"].keys(): 

1573 frequencies["flow"] = _config["config"][ 

1574 "minimum-frequency" 

1575 ] 

1576 elif "lalinference" in _config.keys(): 

1577 if "fhigh" in _config["lalinference"].keys(): 

1578 frequencies["fhigh"] = _config["lalinference"][ 

1579 "fhigh" 

1580 ] 

1581 if "flow" in _config["lalinference"].keys(): 

1582 frequencies["flow"] = _config["lalinference"][ 

1583 "flow" 

1584 ] 

1585 return frequencies 

1586 return 

1587 

    @property
    def samples(self):
        # MultiAnalysisSamplesDict of posterior samples keyed by label;
        # populated by the setter
        return self._samples

1591 

    @samples.setter
    def samples(self, samples):
        """Read the result files and extract the posterior samples together
        with any stored metadata (approximant, cutoff frequencies, remnant
        fits). The metadata dictionaries are only stored as attributes when
        non-empty, so downstream code can detect absence via AttributeError.
        """
        from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
        # read each result file; priors are not needed for the IMRCT test
        self._read_samples = {
            _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
                self.labels, self.result_files
            )
        }
        _samples_dict = {}
        _approximant_dict = {}
        _cutoff_frequency_dict = {}
        _remnant_fits_dict = {}
        for label, _open in self._read_samples.items():
            if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
                # a PESummary metafile containing multiple analyses; the
                # labels setter stored which analyses to use
                if not len(self._meta_file_labels):
                    raise ValueError(
                        "Currently you can only pass a file containing a "
                        "single analysis or a valid PESummary metafile "
                        "containing multiple analyses"
                    )
                _labels = _open.labels
                if len(self._read_samples) == 1:
                    # single metafile: pull both the inspiral and the
                    # postinspiral analysis out of this one file
                    _samples_dict = {
                        label: _open.samples_dict[meta_file_label] for
                        label, meta_file_label in zip(
                            self.labels, self._meta_file_labels
                        )
                    }
                    for label, meta_file_label in zip(self.labels, self._meta_file_labels):
                        _stored_approx = self._extract_stored_approximant(
                            _open, meta_file_label
                        )
                        _stored_frequencies = self._extract_stored_cutoff_frequency(
                            _open, meta_file_label
                        )
                        _stored_remnant_fits = self._extract_stored_remnant_fits(
                            _open, meta_file_label
                        )
                        if _stored_approx is not None:
                            _approximant_dict[label] = _stored_approx
                        if _stored_remnant_fits is not None:
                            _remnant_fits_dict[label] = _stored_remnant_fits
                        if _stored_frequencies is not None:
                            # inspiral uses a maximum frequency cutoff,
                            # postinspiral a minimum frequency cutoff
                            if label == "inspiral":
                                if "fhigh" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "fhigh"
                                    ]
                            if label == "postinspiral":
                                if "flow" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "flow"
                                    ]
                    # both analyses handled; no need to read further files
                    break
                else:
                    # one metafile per analysis
                    ind = self.labels.index(label)
                    _samples_dict[label] = _open.samples_dict[
                        self._meta_file_labels[ind]
                    ]
                    _stored_approx = self._extract_stored_approximant(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_frequencies = self._extract_stored_cutoff_frequency(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_remnant_fits = self._extract_stored_remnant_fits(
                        _open, self._meta_file_labels[ind]
                    )
                    if _stored_approx is not None:
                        _approximant_dict[label] = _stored_approx
                    if _stored_remnant_fits is not None:
                        _remnant_fits_dict[label] = _stored_remnant_fits
                    if _stored_frequencies is not None:
                        if label == "inspiral":
                            if "fhigh" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "fhigh"
                                ]
                        if label == "postinspiral":
                            if "flow" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "flow"
                                ]
            else:
                # a single-analysis result file
                _samples_dict[label] = _open.samples_dict
                extra_kwargs = _open.extra_kwargs
                if "pe_algorithm" in extra_kwargs["sampler"].keys():
                    if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
                        # bilby stores the approximant and cutoff frequencies
                        # in the likelihood waveform arguments
                        try:
                            subkwargs = extra_kwargs["other"]["likelihood"][
                                "waveform_arguments"
                            ]
                            _approximant_dict[label] = (
                                subkwargs["waveform_approximant"]
                            )
                            if "inspiral" in label and "postinspiral" not in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["maximum_frequency"]
                                )
                            elif "postinspiral" in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["minimum_frequency"]
                                )
                        except KeyError:
                            # metadata missing; fall back to command line input
                            pass
        self._samples = MultiAnalysisSamplesDict(_samples_dict)
        # only store non-empty metadata; the meta_data setter detects absence
        # of these attributes via AttributeError
        if len(_approximant_dict):
            self._approximant_dict = _approximant_dict
        if len(_cutoff_frequency_dict):
            self._cutoff_frequency_dict = _cutoff_frequency_dict
        if len(_remnant_fits_dict):
            self._remnant_fits_dict = _remnant_fits_dict

1704 

    @property
    def imrct_kwargs(self):
        # kwargs forwarded to the IMRCT test; defaults merged in by the setter
        return self._imrct_kwargs

1708 

1709 @imrct_kwargs.setter 

1710 def imrct_kwargs(self, imrct_kwargs): 

1711 test_kwargs = dict(N_bins=101) 

1712 try: 

1713 test_kwargs.update(imrct_kwargs) 

1714 except AttributeError: 

1715 test_kwargs = test_kwargs 

1716 

1717 for key, value in test_kwargs.items(): 

1718 try: 

1719 test_kwargs[key] = ast.literal_eval(value) 

1720 except ValueError: 

1721 pass 

1722 self._imrct_kwargs = test_kwargs 

1723 

    @property
    def meta_data(self):
        # dict mapping analysis label -> cutoff frequency / approximant /
        # remnant fit metadata; populated by the setter
        return self._meta_data

1727 

    @meta_data.setter
    def meta_data(self, meta_data):
        """Assemble per-analysis metadata; the assigned value is ignored.

        Command line values take precedence; otherwise fall back to the
        metadata extracted from the input files by the 'samples' setter.
        """
        self._meta_data = {}
        for num, _inspiral in enumerate(self.inspiral_keys):
            frequency_dict = dict()
            approximant_dict = dict()
            remnant_dict = dict()
            # remnant fits are never supplied on the command line, hence None
            zipped = zip(
                [self.cutoff_frequency, self.approximant, None],
                [frequency_dict, approximant_dict, remnant_dict],
                ["cutoff_frequency", "approximant", "remnant_fits"]
            )
            _inspiral_string = self.inspiral_keys[num]
            _postinspiral_string = self.postinspiral_keys[num]
            for _list, _dict, name in zipped:
                if _list is not None and len(_list) == len(self.labels):
                    # values supplied on the command line, one per file
                    inspiral_ind = self.labels.index(_inspiral_string)
                    postinspiral_ind = self.labels.index(_postinspiral_string)
                    _dict["inspiral"] = _list[inspiral_ind]
                    _dict["postinspiral"] = _list[postinspiral_ind]
                elif _list is not None:
                    raise ValueError(
                        "Please provide a 'cutoff_frequency' and 'approximant' "
                        "for each file"
                    )
                else:
                    # fall back to metadata extracted from the input files;
                    # the _*_dict attributes only exist when such metadata was
                    # found, so AttributeError means "nothing stored"
                    try:
                        if name == "cutoff_frequency":
                            if "inspiral" in self._cutoff_frequency_dict.keys():
                                _dict["inspiral"] = self._cutoff_frequency_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._cutoff_frequency_dict.keys():
                                _dict["postinspiral"] = self._cutoff_frequency_dict[
                                    "postinspiral"
                                ]
                        elif name == "approximant":
                            if "inspiral" in self._approximant_dict.keys():
                                _dict["inspiral"] = self._approximant_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._approximant_dict.keys():
                                _dict["postinspiral"] = self._approximant_dict[
                                    "postinspiral"
                                ]
                        elif name == "remnant_fits":
                            if "inspiral" in self._remnant_fits_dict.keys():
                                _dict["inspiral"] = self._remnant_fits_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._remnant_fits_dict.keys():
                                _dict["postinspiral"] = self._remnant_fits_dict[
                                    "postinspiral"
                                ]
                    except (AttributeError, KeyError, TypeError):
                        _dict["inspiral"] = None
                        _dict["postinspiral"] = None

            self._meta_data[self.analysis_label[num]] = {
                "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
                "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
                "inspiral approximant": approximant_dict["inspiral"],
                "postinspiral approximant": approximant_dict["postinspiral"],
                "inspiral remnant fits": remnant_dict["inspiral"],
                "postinspiral remnant fits": remnant_dict["postinspiral"]
            }

    def __init__(self, opts):
        """Store and validate the summarytgr command line options.

        Parameters
        ----------
        opts: argparse.Namespace
            parsed command line options
        """
        self.opts = opts
        self.existing = None
        self.webdir = self.opts.webdir
        self.user = None
        self.baseurl = None
        self.result_files = self.opts.samples
        # setting labels validates the inspiral/postinspiral pairing and
        # sets self.analysis_label
        self.labels = self.opts.labels
        # setting samples reads the result files and extracts any stored
        # metadata (approximant, cutoff frequencies, remnant fits)
        self.samples = self.opts.samples
        self.inspiral_keys = [
            key for key in self.samples.keys() if "inspiral" in key
            and "postinspiral" not in key
        ]
        self.postinspiral_keys = [
            key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
        ]
        try:
            self.imrct_kwargs = self.opts.imrct_kwargs
        except AttributeError:
            # option not provided; the setter fills in the defaults
            self.imrct_kwargs = {}
        for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
            _attr = getattr(self.opts, _arg)
            # when provided, these options must supply one value per file
            if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
                raise ValueError("Please provide a {} for each file".format(_arg))
            setattr(self, _arg, _attr)
        # triggers the meta_data setter; the assigned value is ignored
        self.meta_data = None
        self.default_directories = ["samples", "plots", "js", "html", "css"]
        self.publication = False
        self.make_directories()