Coverage for pesummary/gw/cli/inputs.py: 77.2%

986 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-01-15 17:49 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import ast 

4import os 

5import numpy as np 

6import pesummary.core.cli.inputs 

7from pesummary.gw.file.read import read as GWRead 

8from pesummary.gw.file.psd import PSD 

9from pesummary.gw.file.calibration import Calibration 

10from pesummary.utils.decorators import deprecation 

11from pesummary.utils.exceptions import InputError 

12from pesummary.utils.utils import logger 

13from pesummary import conf 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17 

18class _GWInput(pesummary.core.cli.inputs._Input): 

19 """Super class to handle gw specific command line inputs 

20 """ 

    @staticmethod
    def grab_data_from_metafile(
        existing_file, webdir, compare=None, nsamples=None, **kwargs
    ):
        """Grab data from an existing PESummary metafile

        Parameters
        ----------
        existing_file: str
            path to the existing metafile
        webdir: str
            the directory to store the existing configuration file
        compare: list, optional
            list of labels for events stored in an existing metafile that you
            wish to compare
        nsamples: int, optional
            number of samples to draw from each analysis
        **kwargs: dict, optional
            additional kwargs passed to the core
            ``_Input.grab_data_from_metafile`` method
        """
        # when a stored PSD is requested via the 'psd' kwarg, replace it with
        # a template string pointing at the PSD stored in the metafile
        _replace_kwargs = {
            "psd": "{file}.psd['{label}']"
        }
        if "psd_default" in kwargs.keys():
            _replace_kwargs["psd_default"] = kwargs["psd_default"]
        data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
            existing_file, webdir, compare=compare, read_function=GWRead,
            nsamples=nsamples, _replace_with_pesummary_kwargs=_replace_kwargs,
            **kwargs
        )
        # re-open the metafile to pull out GW-specific products (PSDs,
        # calibration envelopes, skymaps, approximants)
        f = GWRead(existing_file)

        labels = data["labels"]

        # PSDs stored in the metafile, wrapped in the PSD class per IFO
        psd = {i: {} for i in labels}
        if f.psd is not None and f.psd != {}:
            for i in labels:
                if i in f.psd.keys() and f.psd[i] != {}:
                    psd[i] = {
                        ifo: PSD(f.psd[i][ifo]) for ifo in f.psd[i].keys()
                    }
        # calibration envelopes stored in the metafile, wrapped per IFO
        calibration = {i: {} for i in labels}
        if f.calibration is not None and f.calibration != {}:
            for i in labels:
                if i in f.calibration.keys() and f.calibration[i] != {}:
                    calibration[i] = {
                        ifo: Calibration(f.calibration[i][ifo]) for ifo in
                        f.calibration[i].keys()
                    }
        # skymaps stored in the metafile (None when absent/empty)
        skymap = {i: None for i in labels}
        if hasattr(f, "skymap") and f.skymap is not None and f.skymap != {}:
            for i in labels:
                if i in f.skymap.keys() and len(f.skymap[i]):
                    skymap[i] = f.skymap[i]
        data.update(
            {
                "approximant": {
                    i: j for i, j in zip(
                        labels, [f.approximant[ind] for ind in data["indicies"]]
                    )
                },
                "psd": psd,
                "calibration": calibration,
                "skymap": skymap
            }
        )
        return data

84 

    @property
    def grab_data_kwargs(self):
        """Return a dict, keyed by analysis label, of keyword arguments to
        pass to the read/conversion machinery for each result file.

        Single-valued frequency options are broadcast across all labels;
        if indexing per-label values fails, the first value is reused for
        every result file.
        """
        kwargs = super(_GWInput, self).grab_data_kwargs
        # broadcast a single provided value across every analysis
        for _property in ["f_start", "f_low", "f_ref", "f_final", "delta_f"]:
            if getattr(self, _property) is None:
                setattr(self, "_{}".format(_property), [None] * len(self.labels))
            elif len(getattr(self, _property)) == 1 and len(self.labels) != 1:
                setattr(
                    self, "_{}".format(_property),
                    getattr(self, _property) * len(self.labels)
                )
        if self.opts.approximant is None:
            approx = [None] * len(self.labels)
        else:
            approx = self.opts.approximant
        self.approximant_flags = self.opts.approximant_flags
        # per-label checkpoint file used to resume the conversion class
        resume_file = [
            os.path.join(
                self.webdir, "checkpoint",
                "{}_conversion_class.pickle".format(label)
            ) for label in self.labels
        ]

        try:
            for num, label in enumerate(self.labels):
                # missing PSD for a label is not an error; fall back to {}
                try:
                    psd = self.psd[label]
                except KeyError:
                    psd = {}
                kwargs[label].update(dict(
                    evolve_spins_forwards=self.evolve_spins_forwards,
                    evolve_spins_backwards=self.evolve_spins_backwards,
                    f_start=self.f_start[num], f_low=self.f_low[num],
                    approximant_flags=self.approximant_flags.get(label, {}),
                    approximant=approx[num], f_ref=self.f_ref[num],
                    NRSur_fits=self.NRSur_fits, return_kwargs=True,
                    multipole_snr=self.calculate_multipole_snr,
                    precessing_snr=self.calculate_precessing_snr,
                    psd=psd, f_final=self.f_final[num],
                    waveform_fits=self.waveform_fits,
                    multi_process=self.opts.multi_process,
                    redshift_method=self.redshift_method,
                    cosmology=self.cosmology,
                    no_conversion=self.no_conversion,
                    add_zero_spin=True, delta_f=self.delta_f[num],
                    psd_default=self.psd_default,
                    disable_remnant=self.disable_remnant,
                    force_BBH_remnant_computation=self.force_BBH_remnant_computation,
                    resume_file=resume_file[num],
                    restart_from_checkpoint=self.restart_from_checkpoint,
                    force_BH_spin_evolution=self.force_BH_spin_evolution,
                ))
            return kwargs
        except IndexError:
            # fewer values than labels: reuse the first value everywhere
            logger.warning(
                "Unable to find an f_ref, f_low and approximant for each "
                "label. Using and f_ref={}, f_low={} and approximant={} "
                "for all result files".format(
                    self.f_ref[0], self.f_low[0], approx[0]
                )
            )
            # NOTE(review): unlike the branch above, this fallback omits the
            # 'approximant_flags' kwarg and indexes self.psd[self.labels[0]]
            # without guarding against KeyError — confirm this is intended
            for num, label in enumerate(self.labels):
                kwargs[label].update(dict(
                    evolve_spins_forwards=self.evolve_spins_forwards,
                    evolve_spins_backwards=self.evolve_spins_backwards,
                    f_start=self.f_start[0], f_low=self.f_low[0],
                    approximant=approx[0], f_ref=self.f_ref[0],
                    NRSur_fits=self.NRSur_fits, return_kwargs=True,
                    multipole_snr=self.calculate_multipole_snr,
                    precessing_snr=self.calculate_precessing_snr,
                    psd=self.psd[self.labels[0]], f_final=self.f_final[0],
                    waveform_fits=self.waveform_fits,
                    multi_process=self.opts.multi_process,
                    redshift_method=self.redshift_method,
                    cosmology=self.cosmology,
                    no_conversion=self.no_conversion,
                    add_zero_spin=True, delta_f=self.delta_f[0],
                    psd_default=self.psd_default,
                    disable_remnant=self.disable_remnant,
                    force_BBH_remnant_computation=self.force_BBH_remnant_computation,
                    resume_file=resume_file[num],
                    restart_from_checkpoint=self.restart_from_checkpoint,
                    force_BH_spin_evolution=self.force_BH_spin_evolution
                ))
            return kwargs

170 

171 @staticmethod 

172 def grab_data_from_file( 

173 file, label, webdir, config=None, injection=None, file_format=None, 

174 nsamples=None, disable_prior_sampling=False, **kwargs 

175 ): 

176 """Grab data from a result file containing posterior samples 

177 

178 Parameters 

179 ---------- 

180 file: str 

181 path to the result file 

182 label: str 

183 label that you wish to use for the result file 

184 config: str, optional 

185 path to a configuration file used in the analysis 

186 injection: str, optional 

187 path to an injection file used in the analysis 

188 file_format, str, optional 

189 the file format you wish to use when loading. Default None. 

190 If None, the read function loops through all possible options 

191 """ 

192 data = pesummary.core.cli.inputs._Input.grab_data_from_file( 

193 file, label, webdir, config=config, injection=injection, 

194 read_function=GWRead, file_format=file_format, nsamples=nsamples, 

195 disable_prior_sampling=disable_prior_sampling, **kwargs 

196 ) 

197 try: 

198 for _kwgs in data["file_kwargs"][label]: 

199 if "approximant_flags" in _kwgs["meta_data"]: 

200 for key, item in _kwargs["meta_data"]["approximant_flags"].items(): 

201 warning_cond = ( 

202 key in self.approximant_flags[label] and 

203 self.approximant_flags[label][key] != item 

204 ) 

205 if warning_cond: 

206 logger.warning( 

207 "Approximant flag {}={} found in result file for {}. " 

208 "Ignoring and using the provided values {}={}".format( 

209 key, self.approximant_flags[label][key], label, 

210 key, item 

211 ) 

212 ) 

213 else: 

214 self.approximant_flags[label][key] = item 

215 except Exception: 

216 pass 

217 return data 

218 

219 @property 

220 def reweight_samples(self): 

221 return self._reweight_samples 

222 

223 @reweight_samples.setter 

224 def reweight_samples(self, reweight_samples): 

225 from pesummary.gw.reweight import options 

226 self._reweight_samples = self._check_reweight_samples( 

227 reweight_samples, options 

228 ) 

229 

230 def _set_samples(self, *args, **kwargs): 

231 super(_GWInput, self)._set_samples(*args, **kwargs) 

232 if "calibration" not in self.priors: 

233 self.priors["calibration"] = { 

234 label: {} for label in self.labels 

235 } 

236 if "calibration_raw" not in self.priors: 

237 self.priors["calibration_raw"] = { 

238 label: {} for label in self.labels 

239 } 

240 

    def _set_corner_params(self, corner_params):
        """Return the list of parameters to use for the corner plot.

        User-supplied parameters are merged with the default GW corner
        parameters; any user-supplied parameter that is not present in every
        analysis is dropped. Returns None when no parameters were supplied
        (the defaults are then used downstream).
        """
        corner_params = super(_GWInput, self)._set_corner_params(corner_params)
        if corner_params is None:
            logger.debug(
                "Using the default corner parameters: {}".format(
                    ", ".join(conf.gw_corner_parameters)
                )
            )
        else:
            _corner_params = corner_params
            # merge with the defaults; set() makes the final order arbitrary
            corner_params = list(set(conf.gw_corner_parameters + corner_params))
            # only the user-supplied parameters are checked for availability;
            # NOTE(review): default parameters absent from an analysis are
            # not removed here — confirm downstream tolerates that
            for param in _corner_params:
                _data = self.samples
                if not all(param in _data[label].keys() for label in self.labels):
                    corner_params.remove(param)
            logger.debug(
                "Generating a corner plot with the following "
                "parameters: {}".format(", ".join(corner_params))
            )
        return corner_params

261 

262 @property 

263 def cosmology(self): 

264 return self._cosmology 

265 

266 @cosmology.setter 

267 def cosmology(self, cosmology): 

268 from pesummary.gw.cosmology import available_cosmologies 

269 

270 if cosmology.lower() not in available_cosmologies: 

271 logger.warning( 

272 "Unrecognised cosmology: {}. Using {} as default".format( 

273 cosmology, conf.cosmology 

274 ) 

275 ) 

276 cosmology = conf.cosmology 

277 else: 

278 logger.debug("Using the {} cosmology".format(cosmology)) 

279 self._cosmology = cosmology 

280 

    @property
    def approximant(self):
        """dict: approximant to use for each analysis label"""
        return self._approximant

    @approximant.setter
    def approximant(self, approximant):
        # first assignment: map each label to its approximant, or warn when
        # none was provided (entries default to {})
        if not hasattr(self, "_approximant"):
            approximant_list = {i: {} for i in self.labels}
            if approximant is None:
                logger.warning(
                    "No approximant passed. Waveform plots will not be "
                    "generated"
                )
            elif approximant is not None:
                if len(approximant) != len(self.labels):
                    raise InputError(
                        "Please pass an approximant for each result file"
                    )
                approximant_list = {
                    i: j for i, j in zip(self.labels, approximant)
                }
            self._approximant = approximant_list
        else:
            # later assignments: replace the first still-empty entry with
            # None and stop; the warning is only emitted when that entry is
            # the first key.
            # NOTE(review): the `break` means only one empty entry is fixed
            # per assignment — confirm this is intended
            for num, i in enumerate(self._approximant.keys()):
                if self._approximant[i] == {}:
                    if num == 0:
                        logger.warning(
                            "No approximant passed. Waveform plots will not be "
                            "generated"
                        )
                    self._approximant[i] = None
                    break

313 

314 @property 

315 def approximant_flags(self): 

316 return self._approximant_flags 

317 

318 @approximant_flags.setter 

319 def approximant_flags(self, approximant_flags): 

320 if hasattr(self, "_approximant_flags"): 

321 return 

322 _approximant_flags = {key: {} for key in self.labels} 

323 for key, item in approximant_flags.items(): 

324 _label, key = key.split(":") 

325 if _label not in self.labels: 

326 raise ValueError( 

327 "Unable to assign waveform flags for label:{} because " 

328 "it does not exist. Available labels are: {}. Approximant " 

329 "flags must be provided in the form LABEL:FLAG:VALUE".format( 

330 _label, ", ".join(self.labels) 

331 ) 

332 ) 

333 _approximant_flags[_label][key] = item 

334 self._approximant_flags = _approximant_flags 

335 

336 @property 

337 def gracedb_server(self): 

338 return self._gracedb_server 

339 

340 @gracedb_server.setter 

341 def gracedb_server(self, gracedb_server): 

342 if gracedb_server is None: 

343 self._gracedb_server = conf.gracedb_server 

344 else: 

345 logger.debug( 

346 "Using '{}' as the GraceDB server".format(gracedb_server) 

347 ) 

348 self._gracedb_server = gracedb_server 

349 

    @property
    def gracedb(self):
        """str or None: the GraceDB ID associated with the analyses"""
        return self._gracedb

    @gracedb.setter
    def gracedb(self, gracedb):
        self._gracedb = gracedb
        if gracedb is not None:
            from pesummary.gw.gracedb import get_gracedb_data, HTTPError
            from json.decoder import JSONDecodeError

            # a valid ID starts with G/g/S (e.g. G0000 or S0000);
            # NOTE(review): lowercase 'g' is accepted but lowercase 's' is
            # not — confirm whether that asymmetry is intended
            first_letter = gracedb[0]
            if first_letter != "G" and first_letter != "g" and first_letter != "S":
                logger.warning(
                    "Invalid GraceDB ID passed. The GraceDB ID must be of the "
                    "form G0000 or S0000. Ignoring input."
                )
                self._gracedb = None
                return
            _error = (
                "Unable to download data from Gracedb because {}. Only storing "
                "the GraceDB ID in the metafile"
            )
            # best-effort download: on failure only the ID is stored
            try:
                logger.info(
                    "Downloading {} from gracedb for {}".format(
                        ", ".join(self.gracedb_data), gracedb
                    )
                )
                json = get_gracedb_data(
                    gracedb, info=self.gracedb_data,
                    service_url=self.gracedb_server
                )
                json["id"] = gracedb
            except (HTTPError, RuntimeError, JSONDecodeError) as e:
                logger.warning(_error.format(e))
                json = {"id": gracedb}

            # record the (possibly minimal) GraceDB payload for every label
            for label in self.labels:
                self.file_kwargs[label]["meta_data"]["gracedb"] = json

390 

    @property
    def terrestrial_probability(self):
        """list: terrestrial probability for each analysis (None entries when
        unavailable)"""
        return self._terrestrial_probability

    @terrestrial_probability.setter
    def terrestrial_probability(self, terrestrial_probability):
        # when nothing is provided but a GraceDB ID is known, try to fetch
        # the p_astro file from GraceDB (best-effort, several fallbacks)
        if terrestrial_probability is None and self.gracedb is not None:
            logger.info(
                "No terrestrial probability provided. Trying to download "
                "from gracedb"
            )
            from pesummary.core.fetch import download_and_read_file
            from urllib.error import HTTPError
            from json.decoder import JSONDecodeError
            import json
            try:
                # first attempt: the superevent-level p_astro.json
                ff = download_and_read_file(
                    f"{self.gracedb_server}/superevents/{self.gracedb}/"
                    f"files/p_astro.json", read_file=False,
                    outdir=f"{self.webdir}/samples", delete_on_exit=False
                )
                with open(ff, "r") as f:
                    data = json.load(f)
                self._terrestrial_probability = [float(data["Terrestrial"])]
            except (RuntimeError, JSONDecodeError) as e:
                logger.warning(
                    "Unable to grab terrestrial probability from gracedb "
                    "because {}".format(e)
                )
                self._terrestrial_probability = [None]
            except HTTPError as e:
                # second attempt: a pipeline-specific p_astro file inferred
                # from the preferred event's submitter
                from pesummary.gw.gracedb import get_gracedb_data, get_gracedb_file
                try:
                    preferred = get_gracedb_data(
                        self.gracedb, info=["preferred_event_data"],
                        service_url=self.gracedb_server
                    )["preferred_event_data"]["submitter"]
                    _pipelines = [
                        "pycbc", "gstlal", "mbta", "spiir"
                    ]
                    _filename = None
                    for _pipe in _pipelines:
                        if _pipe in preferred:
                            _filename = f"{_pipe}.p_astro.json"
                    if _filename is None:
                        # no recognised pipeline: re-raise the original error
                        raise e
                    data = get_gracedb_file(
                        self.gracedb, _filename, service_url=self.gracedb_server
                    )
                    with open(f"{self.webdir}/samples/{_filename}", "w") as json_file:
                        json.dump(data, json_file)
                    self._terrestrial_probability = [float(data["Terrestrial"])]
                except Exception as e:
                    logger.warning(
                        "Unable to grab terrestrial probability from gracedb "
                        "because {}".format(e)
                    )
                    self._terrestrial_probability = [None]
            # broadcast the single downloaded/failed value to every label
            self._terrestrial_probability *= len(self.labels)
        elif terrestrial_probability is None:
            self._terrestrial_probability = [None] * len(self.labels)
        else:
            # explicit values: either one value for all analyses or one per
            # analysis
            if len(terrestrial_probability) == 1 and len(self.labels) > 1:
                logger.debug(
                    f"Assuming a terrestrial probability: "
                    f"{terrestrial_probability} for all analyses"
                )
                self._terrestrial_probability = [
                    float(terrestrial_probability[0])
                ] * len(self.labels)
            elif len(terrestrial_probability) == len(self.labels):
                self._terrestrial_probability = [
                    float(_) for _ in terrestrial_probability
                ]
            else:
                raise ValueError(
                    "Please provide a terrestrial probability for each "
                    "analysis, or a single value to be used for all analyses"
                )

470 

471 @property 

472 def detectors(self): 

473 return self._detectors 

474 

475 @detectors.setter 

476 def detectors(self, detectors): 

477 detector = {} 

478 if not detectors: 

479 for i in self.samples.keys(): 

480 params = list(self.samples[i].keys()) 

481 individual_detectors = [] 

482 for j in params: 

483 if "optimal_snr" in j and j != "network_optimal_snr": 

484 det = j.split("_optimal_snr")[0] 

485 individual_detectors.append(det) 

486 individual_detectors = sorted( 

487 [str(i) for i in individual_detectors]) 

488 if individual_detectors: 

489 detector[i] = "_".join(individual_detectors) 

490 else: 

491 detector[i] = None 

492 else: 

493 detector = detectors 

494 logger.debug("The detector network is %s" % (detector)) 

495 self._detectors = detector 

496 

497 @property 

498 def skymap(self): 

499 return self._skymap 

500 

501 @skymap.setter 

502 def skymap(self, skymap): 

503 if not hasattr(self, "_skymap"): 

504 self._skymap = {i: None for i in self.labels} 

505 

506 @property 

507 def calibration_definition(self): 

508 return self._calibration_definition 

509 

510 @calibration_definition.setter 

511 def calibration_definition(self, calibration_definition): 

512 if not len(self.opts.calibration): 

513 self._calibration_definition = None 

514 return 

515 if len(calibration_definition) == 1: 

516 logger.info( 

517 f"Assuming that the calibration correction was applied to " 

518 f"'{calibration_definition[0]}' for all analyses" 

519 ) 

520 calibration_definition *= len(self.labels) 

521 elif len(calibration_definition) != len(self.labels): 

522 raise ValueError( 

523 f"Please provide a calibration definition for each analysis " 

524 f"({len(self.labels)}) or a single definition to use for all " 

525 f"analyses" 

526 ) 

527 if any(_ not in ["data", "template"] for _ in calibration_definition): 

528 raise ValueError( 

529 "Calibration definitions must be either 'data' or 'template'" 

530 ) 

531 self._calibration_definition = { 

532 label: calibration_definition[num] for num, label in 

533 enumerate(self.labels) 

534 } 

535 

    @property
    def calibration(self):
        """dict: calibration spline posterior data per analysis label"""
        return self._calibration

    @calibration.setter
    def calibration(self, calibration):
        # only run once; later assignments are no-ops
        if not hasattr(self, "_calibration"):
            data = {i: {} for i in self.labels}
            if calibration != {}:
                # calibration envelopes provided on the command line: store
                # both the raw ('template', unmodified) and the
                # definition-adjusted versions as priors
                prior_data_raw = self.get_psd_or_calibration_data(
                    calibration, self.extract_calibration_data_from_file,
                    type="template"  # do not do any modification
                )
                self.add_to_prior_dict("calibration_raw", prior_data_raw)
                prior_data = self.get_psd_or_calibration_data(
                    calibration, self.extract_calibration_data_from_file,
                    type=self.calibration_definition[self.labels[0]]
                )
                self.add_to_prior_dict("calibration", prior_data)
            else:
                # otherwise look for per-label '<label>_calibration' options
                prior_data = {i: {} for i in self.labels}
                prior_data_raw = {i: {} for i in self.labels}
                for label in self.labels:
                    if hasattr(self.opts, "{}_calibration".format(label)):
                        cal_data = getattr(self.opts, "{}_calibration".format(label))
                        if cal_data != {} and cal_data is not None:
                            prior_data[label] = {
                                ifo: self.extract_calibration_data_from_file(
                                    cal_data[ifo], type=self.calibration_definition[label]
                                ) for ifo in cal_data.keys()
                            }
                            prior_data_raw[label] = {
                                ifo: self.extract_calibration_data_from_file(
                                    cal_data[ifo], type="template"  # do not do any modification
                                ) for ifo in cal_data.keys()
                            }
                # NOTE(review): both branches below store the same two prior
                # dicts — the if/else looks redundant; confirm
                if not all(prior_data[i] == {} for i in self.labels):
                    self.add_to_prior_dict("calibration_raw", prior_data_raw)
                    self.add_to_prior_dict("calibration", prior_data)
                else:
                    self.add_to_prior_dict("calibration_raw", prior_data_raw)
                    self.add_to_prior_dict("calibration", prior_data)
            # extract calibration spline posteriors from each result file
            for num, i in enumerate(self.result_files):
                # reuse an already-open file handle when available
                _opened = self._open_result_files
                if i in _opened.keys() and _opened[i] is not None:
                    f = self._open_result_files[i]
                else:
                    f = GWRead(i, disable_prior=True)
                try:
                    calibration_data = f.interpolate_calibration_spline_posterior()
                except Exception as e:
                    logger.warning(
                        "Failed to extract calibration data from the result "
                        "file: {} because {}".format(i, e)
                    )
                    calibration_data = None
                labels = list(self.samples.keys())
                if calibration_data is None:
                    data[labels[num]] = {
                        None: None
                    }
                elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                    # NOTE(review): this inner loop rebinds `num`, shadowing
                    # the outer enumerate index — confirm the mapping of
                    # analyses to labels is as intended for metafiles
                    for num in range(len(calibration_data[0])):
                        data[labels[num]] = {
                            j: k for j, k in zip(
                                calibration_data[1][num],
                                calibration_data[0][num]
                            )
                        }
                else:
                    data[labels[num]] = {
                        j: k for j, k in zip(
                            calibration_data[1], calibration_data[0]
                        )
                    }
            self._calibration = data

612 

613 @property 

614 def psd(self): 

615 return self._psd 

616 

617 @psd.setter 

618 def psd(self, psd): 

619 if not hasattr(self, "_psd"): 

620 data = {i: {} for i in self.labels} 

621 if psd != {}: 

622 data = self.get_psd_or_calibration_data( 

623 psd, self.extract_psd_data_from_file 

624 ) 

625 else: 

626 for label in self.labels: 

627 if hasattr(self.opts, "{}_psd".format(label)): 

628 psd_data = getattr(self.opts, "{}_psd".format(label)) 

629 if psd_data != {} and psd_data is not None: 

630 data[label] = { 

631 ifo: self.extract_psd_data_from_file( 

632 psd_data[ifo], IFO=ifo 

633 ) for ifo in psd_data.keys() 

634 } 

635 self._psd = data 

636 

637 @property 

638 def nsamples_for_skymap(self): 

639 return self._nsamples_for_skymap 

640 

641 @nsamples_for_skymap.setter 

642 def nsamples_for_skymap(self, nsamples_for_skymap): 

643 self._nsamples_for_skymap = nsamples_for_skymap 

644 if nsamples_for_skymap is not None: 

645 self._nsamples_for_skymap = int(nsamples_for_skymap) 

646 number_of_samples = [ 

647 data.number_of_samples for label, data in self.samples.items() 

648 ] 

649 if not all(i > self._nsamples_for_skymap for i in number_of_samples): 

650 min_arg = np.argmin(number_of_samples) 

651 logger.warning( 

652 "You have specified that you would like to use {} " 

653 "samples to generate the skymap but the file {} only " 

654 "has {} samples. Reducing the number of samples to " 

655 "generate the skymap to {}".format( 

656 self._nsamples_for_skymap, self.result_files[min_arg], 

657 number_of_samples[min_arg], number_of_samples[min_arg] 

658 ) 

659 ) 

660 self._nsamples_for_skymap = int(number_of_samples[min_arg]) 

661 

662 @property 

663 def gwdata(self): 

664 return self._gwdata 

665 

666 @gwdata.setter 

667 def gwdata(self, gwdata): 

668 from pesummary.gw.file.strain import StrainDataDict 

669 

670 self._gwdata = gwdata 

671 if gwdata is not None: 

672 if isinstance(gwdata, dict): 

673 for i in gwdata.keys(): 

674 if not os.path.isfile(gwdata[i]): 

675 raise InputError( 

676 "The file {} does not exist. Please check the path " 

677 "to your strain file".format(gwdata[i]) 

678 ) 

679 self._gwdata = StrainDataDict.read(gwdata) 

680 else: 

681 logger.warning( 

682 "Please provide gwdata as a dictionary with keys " 

683 "displaying the channel and item giving the path to the " 

684 "strain file" 

685 ) 

686 self._gwdata = None 

687 

688 @property 

689 def evolve_spins_forwards(self): 

690 return self._evolve_spins_forwards 

691 

692 @evolve_spins_forwards.setter 

693 def evolve_spins_forwards(self, evolve_spins_forwards): 

694 self._evolve_spins_forwards = evolve_spins_forwards 

695 _msg = "Spins will be evolved up to {}" 

696 if evolve_spins_forwards: 

697 logger.info(_msg.format("Schwarzschild ISCO frequency")) 

698 self._evolve_spins_forwards = 6. ** -0.5 

699 

700 @property 

701 def evolve_spins_backwards(self): 

702 return self._evolve_spins_backwards 

703 

704 @evolve_spins_backwards.setter 

705 def evolve_spins_backwards(self, evolve_spins_backwards): 

706 self._evolve_spins_backwards = evolve_spins_backwards 

707 _msg = ( 

708 "Spins will be evolved backwards to an infinite separation using the '{}' " 

709 "method" 

710 ) 

711 if isinstance(evolve_spins_backwards, (str, bytes)): 

712 logger.info(_msg.format(evolve_spins_backwards)) 

713 elif evolve_spins_backwards is None: 

714 logger.info(_msg.format("precession_averaged")) 

715 self._evolve_spins_backwards = "precession_averaged" 

716 

717 @property 

718 def NRSur_fits(self): 

719 return self._NRSur_fits 

720 

721 @NRSur_fits.setter 

722 def NRSur_fits(self, NRSur_fits): 

723 self._NRSur_fits = NRSur_fits 

724 base = ( 

725 "Using the '{}' NRSurrogate model to calculate the remnant " 

726 "quantities" 

727 ) 

728 if isinstance(NRSur_fits, (str, bytes)): 

729 logger.info(base.format(NRSur_fits)) 

730 self._NRSur_fits = NRSur_fits 

731 elif NRSur_fits is None: 

732 from pesummary.gw.conversions.nrutils import NRSUR_MODEL 

733 

734 logger.info(base.format(NRSUR_MODEL)) 

735 self._NRSur_fits = NRSUR_MODEL 

736 

737 @property 

738 def waveform_fits(self): 

739 return self._waveform_fits 

740 

741 @waveform_fits.setter 

742 def waveform_fits(self, waveform_fits): 

743 self._waveform_fits = waveform_fits 

744 if waveform_fits: 

745 logger.info( 

746 "Evaluating the remnant quantities using the provided " 

747 "approximant" 

748 ) 

749 

750 @property 

751 def f_low(self): 

752 return self._f_low 

753 

754 @f_low.setter 

755 def f_low(self, f_low): 

756 self._f_low = f_low 

757 if f_low is not None: 

758 self._f_low = [float(i) for i in f_low] 

759 

760 @property 

761 def f_start(self): 

762 return self._f_start 

763 

764 @f_start.setter 

765 def f_start(self, f_start): 

766 self._f_start = f_start 

767 if f_start is not None: 

768 self._f_start = [float(i) for i in f_start] 

769 

770 @property 

771 def f_ref(self): 

772 return self._f_ref 

773 

774 @f_ref.setter 

775 def f_ref(self, f_ref): 

776 self._f_ref = f_ref 

777 if f_ref is not None: 

778 self._f_ref = [float(i) for i in f_ref] 

779 

780 @property 

781 def f_final(self): 

782 return self._f_final 

783 

784 @f_final.setter 

785 def f_final(self, f_final): 

786 self._f_final = f_final 

787 if f_final is not None: 

788 self._f_final = [float(i) for i in f_final] 

789 

790 @property 

791 def delta_f(self): 

792 return self._delta_f 

793 

794 @delta_f.setter 

795 def delta_f(self, delta_f): 

796 self._delta_f = delta_f 

797 if delta_f is not None: 

798 self._delta_f = [float(i) for i in delta_f] 

799 

    @property
    def psd_default(self):
        """PSD used when none is provided: either a pycbc PSD function, a
        template string referring to a PSD stored in a metafile, or None"""
        return self._psd_default

    @psd_default.setter
    def psd_default(self, psd_default):
        self._psd_default = psd_default
        # 'stored:<label>' refers to a PSD saved in an existing metafile;
        # keep it as a template string to be resolved later
        if "stored:" in psd_default:
            label = psd_default.split("stored:")[1]
            self._psd_default = "{file}.psd['%s']" % (label)
            return
        # otherwise interpret the name as a pycbc analytic PSD function
        try:
            from pycbc import psd
            psd_default = getattr(psd, psd_default)
        except ImportError:
            logger.warning(
                "Unable to import 'pycbc'. Unable to generate a default PSD"
            )
            psd_default = None
        except AttributeError:
            # pycbc imported but the named PSD does not exist: fall back to
            # the configured default (assumes conf.psd is a valid pycbc name)
            logger.warning(
                "'pycbc' does not have the '{}' psd available. Using '{}' as "
                "the default PSD".format(psd_default, conf.psd)
            )
            psd_default = getattr(psd, conf.psd)
        except ValueError as e:
            logger.warning("Setting 'psd_default' to None because {}".format(e))
            psd_default = None
        self._psd_default = psd_default

829 

    @property
    def pastro_probs(self):
        """dict: source classification probabilities per analysis label"""
        return self._pastro_probs

    @pastro_probs.setter
    def pastro_probs(self, pastro_probs):
        from pesummary.gw.classification import PAstro

        classifications = {}
        for num, i in enumerate(list(self.samples.keys())):
            # try to rebuild the luminosity distance prior from its stored
            # repr string, e.g. 'module.Class(args)'; fall back to the
            # classifier's internal prior on any failure
            try:
                import importlib
                distance_prior = self.priors["analytic"]["luminosity_distance"]
                cls = distance_prior.split("(")[0]
                module = ".".join(cls.split(".")[:-1])
                cls = cls.split(".")[-1]
                cls = getattr(importlib.import_module(module), cls, cls)
                args = "(".join(distance_prior.split("(")[1:])[:-1]
                distance_prior = cls.from_repr(args)
            except KeyError:
                logger.debug(
                    f"Unable to find a distance prior. Defaulting to stored "
                    f"prior in pesummary.gw.classification.PAstro for "
                    f"source classification probabilities"
                )
                distance_prior = None
            except AttributeError:
                logger.debug(
                    f"Unable to load distance prior: {distance_prior}. "
                    f"Defaulting to stored prior in "
                    f"pesummary.gw.classification.PAstro for source "
                    f"classification probabilities"
                )
                distance_prior = None
            # compute and persist the classification; on failure fall back
            # to the class defaults
            try:
                _cls = PAstro(
                    self.samples[i], category_data=self.pastro_category_file,
                    terrestrial_probability=self.terrestrial_probability[num],
                    distance_prior=distance_prior,
                    catch_terrestrial_probability_error=self.catch_terrestrial_probability_error
                )
                classifications[i] = {"default": _cls.classification()}
                try:
                    _cls.save_to_file(
                        f"{i}.pesummary.p_astro.json",
                        classifications[i]["default"],
                        outdir=f"{self.webdir}/samples",
                        overwrite=True
                    )
                except FileNotFoundError as e:
                    logger.warning(
                        f"Failed to write PAstro probabilities to file "
                        f"because {e}"
                    )
            except Exception as e:
                logger.warning(
                    "Failed to generate source classification probabilities "
                    "because {}".format(e)
                )
                classifications[i] = {"default": PAstro.defaults}
        # for mcmc chains, collapse per-chain classifications into a single
        # averaged classification stored under the first label
        if self.mcmc_samples:
            if any(_probs is None for _probs in classifications.values()):
                classifications[self.labels[0]] = None
                logger.warning(
                    "Unable to average classification probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "classifications"
                )
            else:
                logger.debug(
                    "Averaging classification probability across mcmc samples"
                )
                classifications[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in classifications.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in
                    list(classifications.values())[0].items()
                }
        self._pastro_probs = classifications

911 

    @property
    def embright_probs(self):
        """dict: EM-bright probabilities per analysis label"""
        return self._embright_probs

    @embright_probs.setter
    def embright_probs(self, embright_probs):
        from pesummary.gw.classification import EMBright

        probabilities = {}
        for num, i in enumerate(list(self.samples.keys())):
            # compute and persist the EM-bright probabilities; on failure
            # fall back to the class defaults
            try:
                _cls = EMBright(self.samples[i])
                probabilities[i] = {"default": _cls.classification()}
                try:
                    _cls.save_to_file(
                        f"{i}.pesummary.em_bright.json",
                        probabilities[i]["default"],
                        outdir=f"{self.webdir}/samples",
                        overwrite=True
                    )
                except FileNotFoundError as e:
                    logger.warning(
                        f"Failed to write EM bright probabilities to file "
                        f"because {e}"
                    )
            except Exception as e:
                logger.warning(
                    "Failed to generate em_bright probabilities because "
                    "{}".format(e)
                )
                probabilities[i] = {"default": EMBright.defaults}
        # for mcmc chains, collapse per-chain probabilities into a single
        # averaged set stored under the first label
        if self.mcmc_samples:
            if any(_probs is None for _probs in probabilities.values()):
                probabilities[self.labels[0]] = None
                logger.warning(
                    "Unable to average em_bright probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "probabilities"
                )
            else:
                logger.debug(
                    "Averaging em_bright probability across mcmc samples"
                )
                probabilities[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in probabilities.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in list(probabilities.values())[0].items()
                }
        self._embright_probs = probabilities

963 

964 @property 

965 def preliminary_pages(self): 

966 return self._preliminary_pages 

967 

968 @preliminary_pages.setter 

969 def preliminary_pages(self, preliminary_pages): 

970 required = conf.gw_reproducibility 

971 self._preliminary_pages = {label: False for label in self.labels} 

972 for num, label in enumerate(self.labels): 

973 for attr in required: 

974 _property = getattr(self, attr) 

975 if isinstance(_property, dict): 

976 if label not in _property.keys(): 

977 self._preliminary_pages[label] = True 

978 elif not len(_property[label]): 

979 self._preliminary_pages[label] = True 

980 elif isinstance(_property, list): 

981 if _property[num] is None: 

982 self._preliminary_pages[label] = True 

983 if any(value for value in self._preliminary_pages.values()): 

984 _labels = [ 

985 label for label, value in self._preliminary_pages.items() if 

986 value 

987 ] 

988 msg = ( 

989 "Unable to reproduce the {} analys{} because no {} data was " 

990 "provided. 'Preliminary' watermarks will be added to the final " 

991 "html pages".format( 

992 ", ".join(_labels), "es" if len(_labels) > 1 else "is", 

993 " or ".join(required) 

994 ) 

995 ) 

996 logger.warning(msg) 

997 

998 @staticmethod 

999 def _extract_IFO_data_from_file(file, cls, desc, IFO=None, **kwargs): 

1000 """Return IFO data stored in a file 

1001 

1002 Parameters 

1003 ---------- 

1004 file: path 

1005 path to a file containing the IFO data 

1006 cls: obj 

1007 class you wish to use when loading the file. This class must have 

1008 a '.read' method 

1009 desc: str 

1010 description of the IFO data stored in the file 

1011 IFO: str, optional 

1012 the IFO which the data belongs to 

1013 """ 

1014 general = ( 

1015 "Failed to read in %s data because {}. The %s plot will not be " 

1016 "generated and the %s data will not be added to the metafile." 

1017 ) % (desc, desc, desc) 

1018 try: 

1019 return cls.read(file, IFO=IFO, **kwargs) 

1020 except FileNotFoundError: 

1021 logger.warning( 

1022 general.format("the file {} does not exist".format(file)) 

1023 ) 

1024 return {} 

1025 except ValueError as e: 

1026 logger.warning(general.format(e)) 

1027 return {} 

1028 

1029 @staticmethod 

1030 def extract_psd_data_from_file(file, IFO=None): 

1031 """Return the data stored in a psd file 

1032 

1033 Parameters 

1034 ---------- 

1035 file: path 

1036 path to a file containing the psd data 

1037 """ 

1038 from pesummary.gw.file.psd import PSD 

1039 return _GWInput._extract_IFO_data_from_file(file, PSD, "PSD", IFO=IFO) 

1040 

1041 @staticmethod 

1042 def extract_calibration_data_from_file(file, type="data", **kwargs): 

1043 """Return the data stored in a calibration file 

1044 

1045 Parameters 

1046 ---------- 

1047 file: path 

1048 path to a file containing the calibration data 

1049 """ 

1050 from pesummary.gw.file.calibration import Calibration 

1051 return _GWInput._extract_IFO_data_from_file( 

1052 file, Calibration, "calibration", type=type, **kwargs 

1053 ) 

1054 

1055 @staticmethod 

1056 def get_ifo_from_file_name(file): 

1057 """Return the IFO from the file name 

1058 

1059 Parameters 

1060 ---------- 

1061 file: str 

1062 path to the file 

1063 """ 

1064 file_name = file.split("/")[-1] 

1065 if any(j in file_name for j in ["H1", "_0", "IFO0"]): 

1066 ifo = "H1" 

1067 elif any(j in file_name for j in ["L1", "_1", "IFO1"]): 

1068 ifo = "L1" 

1069 elif any(j in file_name for j in ["V1", "_2", "IFO2"]): 

1070 ifo = "V1" 

1071 else: 

1072 ifo = file_name 

1073 return ifo 

1074 

1075 def get_psd_or_calibration_data(self, input, executable, **kwargs): 

1076 """Return a dictionary containing the psd or calibration data 

1077 

1078 Parameters 

1079 ---------- 

1080 input: list/dict 

1081 list/dict containing paths to calibration/psd files 

1082 executable: func 

1083 executable that is used to extract the data from the calibration/psd 

1084 files 

1085 """ 

1086 data = {} 

1087 if input == {} or input == []: 

1088 return data 

1089 if isinstance(input, dict): 

1090 keys = list(input.keys()) 

1091 if isinstance(input, dict) and isinstance(input[keys[0]], list): 

1092 if not all(len(input[i]) == len(self.labels) for i in list(keys)): 

1093 raise InputError( 

1094 "Please ensure the number of calibration/psd files matches " 

1095 "the number of result files passed" 

1096 ) 

1097 for idx in range(len(input[keys[0]])): 

1098 data[self.labels[idx]] = { 

1099 i: executable(input[i][idx], IFO=i, **kwargs) for i in list(keys) 

1100 } 

1101 elif isinstance(input, dict): 

1102 for i in self.labels: 

1103 data[i] = { 

1104 j: executable(input[j], IFO=j, **kwargs) for j in list(input.keys()) 

1105 } 

1106 elif isinstance(input, list): 

1107 for i in self.labels: 

1108 data[i] = { 

1109 self.get_ifo_from_file_name(j): executable( 

1110 j, IFO=self.get_ifo_from_file_name(j), **kwargs 

1111 ) for j in input 

1112 } 

1113 else: 

1114 raise InputError( 

1115 "Did not understand the psd/calibration input. Please use the " 

1116 "following format 'H1:path/to/file'" 

1117 ) 

1118 return data 

1119 

1120 def grab_priors_from_inputs(self, priors): 

1121 def read_func(data, **kwargs): 

1122 from pesummary.gw.file.read import read as GWRead 

1123 data = GWRead(data, **kwargs) 

1124 data.generate_all_posterior_samples() 

1125 return data 

1126 

1127 return super(_GWInput, self).grab_priors_from_inputs( 

1128 priors, read_func=read_func, read_kwargs=self.grab_data_kwargs 

1129 ) 

1130 

    def grab_key_data_from_result_files(self):
        """Grab the mean, median, maxL and standard deviation for all
        parameters for each result file, and add a two-sided 90% HPD
        interval for a fixed set of physically bounded parameters.
        """
        from pesummary.utils.kde_list import KDEList
        from pesummary.gw.plots.plot import _return_bounds
        from pesummary.utils.credible_interval import (
            hpd_two_sided_credible_interval
        )
        from pesummary.utils.bounded_1d_kde import bounded_1d_kde
        key_data = super(_GWInput, self).grab_key_data_from_result_files()
        # parameters with hard physical boundaries; their intervals are
        # estimated with a boundary-aware KDE
        bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
        for param in bounded_parameters:
            xlow, xhigh = _return_bounds(param, [])
            # samples for this parameter from every analysis that has it
            _samples = {
                key: val[param] for key, val in self.samples.items()
                if param in val.keys()
            }
            # NOTE(review): the 'if len(_samples)' guard is constant for the
            # whole comprehension -- it looks like 'if len(_)' may have been
            # intended; confirm before changing
            _min = [np.min(_) for _ in _samples.values() if len(_samples)]
            _max = [np.max(_) for _ in _samples.values() if len(_samples)]
            if not len(_min):
                continue
            _min = np.min(_min)
            _max = np.max(_max)
            # common evaluation grid across all analyses for this parameter
            x = np.linspace(_min, _max, 1000)
            try:
                kdes = KDEList(
                    list(_samples.values()), kde=bounded_1d_kde,
                    kde_kwargs={"xlow": xlow, "xhigh": xhigh}
                )
            except Exception as e:
                logger.warning(
                    "Unable to compute the HPD interval for {} because {}".format(
                        param, e
                    )
                )
                continue
            pdfs = kdes(x)
            for num, key in enumerate(_samples.keys()):
                # xlow/xhigh are deliberately re-bound to the interval edges;
                # the bounds are re-fetched at the top of the outer loop
                [xlow, xhigh], _ = hpd_two_sided_credible_interval(
                    [], 90, x=x, pdf=pdfs[num]
                )
                key_data[key][param]["90% HPD"] = [xlow, xhigh]
                # give every non-bounded parameter a NaN placeholder so all
                # entries expose the same keys
                for _param in self.samples[key].keys():
                    if _param in bounded_parameters:
                        continue
                    key_data[key][_param]["90% HPD"] = float("nan")
        return key_data

1179 

1180 

class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class to handle and store sample specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # always skip the core copy step; gw specific file handling is done
        # by the gw subclasses
        kwargs.update({"ignore_copy": True})
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=[
                "evolve_spins_forwards",
                "evolve_spins_backwards",
                "NRSur_fits",
                "calculate_multipole_snr",
                "calculate_precessing_snr",
                "f_start",
                "f_low",
                "f_ref",
                "f_final",
                "psd",
                "waveform_fits",
                "redshift_method",
                "cosmology",
                "no_conversion",
                "delta_f",
                "psd_default",
                "disable_remnant",
                "force_BBH_remnant_computation",
                "force_BH_spin_evolution"
            ], **kwargs
        )
        if self._restarted_from_checkpoint:
            # state below was already restored from the checkpoint file
            return
        if self.existing is not None:
            # reuse data products stored in an existing metafile
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_approximant = self.existing_data["approximant"]
            self.existing_psd = self.existing_data["psd"]
            self.existing_calibration = self.existing_data["calibration"]
            self.existing_skymap = self.existing_data["skymap"]
        else:
            self.existing_approximant = None
            self.existing_psd = None
            self.existing_calibration = None
            self.existing_skymap = None
        self.approximant = self.opts.approximant
        self.gracedb_server = self.opts.gracedb_server
        self.gracedb_data = self.opts.gracedb_data
        self.gracedb = self.opts.gracedb
        self.pastro_category_file = self.opts.pastro_category_file
        self.terrestrial_probability = self.opts.terrestrial_probability
        self.catch_terrestrial_probability_error = self.opts.catch_terrestrial_probability_error
        self.approximant_flags = self.opts.approximant_flags
        self.detectors = None
        self.skymap = None
        self.calibration_definition = self.opts.calibration_definition
        self.calibration = self.opts.calibration
        self.gwdata = self.opts.gwdata
        # assigning triggers the maxL_samples setter below, which ignores the
        # assigned value and recomputes the table from the result files
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        # mapping of label -> {parameter: maxL value, "approximant": name}
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        # the assigned value is ignored; the maximum likelihood samples are
        # always derived from the key data of each result file
        key_data = self.grab_key_data_from_result_files()
        maxL_samples = {
            i: {
                j: key_data[i][j]["maxL"] for j in key_data[i].keys()
            } for i in key_data.keys()
        }
        for i in self.labels:
            maxL_samples[i]["approximant"] = self.approximant[i]
        self._maxL_samples = maxL_samples

1255 

1256 

class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Class to handle and store plotting specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.nsamples_for_skymap = self.opts.nsamples_for_skymap
        self.sensitivity = self.opts.sensitivity
        self.no_ligo_skymap = self.opts.no_ligo_skymap
        self.multi_threading_for_skymap = self.multi_process
        if not self.no_ligo_skymap and self.multi_process > 1:
            # split the available processes between the ligo.skymap call and
            # the remaining plotting jobs
            total = self.multi_process
            self.multi_threading_for_plots = int(total / 2.)
            self.multi_threading_for_skymap = total - self.multi_threading_for_plots
            logger.info(
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots".format(
                    self.multi_threading_for_skymap,
                    "es " if self.multi_threading_for_skymap > 1 else " ",
                    self.multi_threading_for_plots,
                    "es " if self.multi_threading_for_plots > 1 else " "
                )
            )
        self.preliminary_pages = None
        # these assignments trigger the property setters defined on _GWInput;
        # the setters ignore the assigned value and store dicts recomputed
        # from self.samples, which is why .keys() is valid below
        self.pastro_probs = []
        self.embright_probs = []
        self.classification_probs = {}
        for key in self.pastro_probs.keys():
            # merge the pastro and em_bright probabilities into a single
            # per-label classification table
            self.classification_probs[key] = {"default": {}}
            self.classification_probs[key]["default"].update(
                self.pastro_probs[key]["default"]
            )
            self.classification_probs[key]["default"].update(
                self.embright_probs[key]["default"]
            )

1291 

1292 

class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Class to handle and store webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.public = self.opts.public
        # these attributes may already have been set by PlottingInput in the
        # combined workflow; only compute them when missing. Assigning []
        # triggers the property setters, which store recomputed dicts
        if not hasattr(self, "preliminary_pages"):
            self.preliminary_pages = None
        if not hasattr(self, "pastro_probs"):
            self.pastro_probs = []
        if not hasattr(self, "embright_probs"):
            self.embright_probs = []
        self.classification_probs = {}
        for key in self.pastro_probs.keys():
            # merge the pastro and em_bright probabilities into a single
            # per-label classification table
            self.classification_probs[key] = {"default": {}}
            self.classification_probs[key]["default"].update(
                self.pastro_probs[key]["default"]
            )
            self.classification_probs[key]["default"].update(
                self.embright_probs[key]["default"]
            )

1314 

1315 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Class to handle and store webpage and plotting specific command line
    arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpagePlusPlottingInput, self).__init__(*args, **kwargs)

    # re-expose the parent properties so this class resolves them explicitly
    # through its own MRO
    @property
    def default_directories(self):
        return super(WebpagePlusPlottingInput, self).default_directories

    @property
    def default_files_to_copy(self):
        return super(WebpagePlusPlottingInput, self).default_files_to_copy

1330 

1331 

class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Class to handle and store metafile specific command line arguments
    """
    @property
    def default_directories(self):
        # extend the core directory list with the gw specific subdirectories
        dirs = super(MetaFileInput, self).default_directories
        dirs += ["psds", "calibration"]
        return dirs

    def copy_files(self):
        """Write the stored PSD and calibration data to disk, then copy the
        remaining files with the core implementation.
        """
        _error = "Failed to save the {} to file"
        for label in self.labels:
            if self.psd[label] != {}:
                for ifo in self.psd[label].keys():
                    # only PSD instances know how to serialise themselves;
                    # warn and skip anything else
                    if not isinstance(self.psd[label][ifo], PSD):
                        logger.warning(_error.format("{} PSD".format(ifo)))
                        continue
                    self.psd[label][ifo].save_to_file(
                        os.path.join(self.webdir, "psds", "{}_{}_psd.dat".format(
                            label, ifo
                        ))
                    )
            if label in self.priors["calibration_raw"].keys():
                if self.priors["calibration_raw"][label] != {}:
                    for ifo in self.priors["calibration_raw"][label].keys():
                        # same guard for calibration envelopes
                        _instance = isinstance(
                            self.priors["calibration_raw"][label][ifo], Calibration
                        )
                        if not _instance:
                            logger.warning(
                                _error.format(
                                    "{} calibration envelope".format(
                                        ifo
                                    )
                                )
                            )
                            continue
                        self.priors["calibration_raw"][label][ifo].save_to_file(
                            os.path.join(self.webdir, "calibration", "{}_{}_cal.txt".format(
                                label, ifo
                            ))
                        )
        return super(MetaFileInput, self).copy_files()

1375 

1376 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Class to handle and store webpage, plotting and metafile specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpagePlusPlottingPlusMetaFileInput, self).__init__(
            *args, **kwargs
        )

    # re-expose the parent properties so this class resolves them explicitly
    # through its own MRO
    @property
    def default_directories(self):
        return super(WebpagePlusPlottingPlusMetaFileInput, self).default_directories

    @property
    def default_files_to_copy(self):
        return super(WebpagePlusPlottingPlusMetaFileInput, self).default_files_to_copy

1393 

1394 

@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias of WebpagePlusPlottingPlusMetaFileInput kept for
    backwards compatibility; emits a deprecation warning on use
    """
    pass

1402 

1403 

1404class IMRCTInput(pesummary.core.cli.inputs._Input): 

1405 """Class to handle the TGR specific command line arguments 

1406 """ 

1407 @property 

1408 def labels(self): 

1409 return self._labels 

1410 

    @labels.setter
    def labels(self, labels):
        """Validate and store the analysis labels.

        Labels must come in inspiral/postinspiral pairs: either exactly
        ['inspiral', 'postinspiral'], or '{name}:inspiral' /
        '{name}:postinspiral' pairs when comparing several analyses, or
        '{pesummary_label}:inspiral' / '{pesummary_label}:postinspiral' when
        the samples are extracted from PESummary metafiles.
        """
        self._labels = labels
        if len(labels) % 2 != 0:
            raise ValueError(
                "The IMRCT test requires 2 results files for each analysis. "
            )
        elif len(labels) > 2:
            # comparing more than one analysis: every label must carry an
            # explicit '{name}:inspiral' / '{name}:postinspiral' form
            cond = all(
                ":inspiral" in label or ":postinspiral" in label for label in
                labels
            )
            if not cond:
                raise ValueError(
                    "To compare 2 or more analyses, please provide labels as "
                    "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                    "the analysis label"
                )
            else:
                # one analysis name per inspiral label; note that
                # 'x:postinspiral' does not contain the substring ':inspiral'
                self.analysis_label = [
                    label.split(":inspiral")[0]
                    for label in labels
                    if ":inspiral" in label and ":postinspiral" not in label
                ]
                if len(self.analysis_label) != len(self.result_files) / 2:
                    raise ValueError(
                        "When comparing more than 2 analyses, labels must "
                        "be of the form '{}:inspiral' and '{}:postinspiral'."
                    )
                logger.info(
                    "Using the labels: {} to distinguish analyses".format(
                        ", ".join(self.analysis_label)
                    )
                )
        elif sorted(labels) != ["inspiral", "postinspiral"]:
            # exactly two non-standard labels: only allowed when reading
            # PESummary metafiles, with the metafile analysis names encoded
            # in the label prefixes
            if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
                meta_file_labels = []
                for suffix in [":inspiral", ":postinspiral"]:
                    if any(suffix in label for label in labels):
                        ind = [
                            num for num, label in enumerate(labels) if
                            suffix in label
                        ]
                        if len(ind) > 1:
                            raise ValueError(
                                "Please provide a single {} label".format(
                                    suffix.split(":")[1]
                                )
                            )
                        meta_file_labels.append(
                            labels[ind[0]].split(suffix)[0]
                        )
                    else:
                        raise ValueError(
                            "Please provide labels as {inspiral_label}:inspiral "
                            "and {postinspiral_label}:postinspiral where "
                            "inspiral_label and postinspiral_label are the "
                            "PESummary labels for the inspiral and postinspiral "
                            "analyses respectively. "
                        )
                if len(self.result_files) == 1:
                    logger.info(
                        "Using the {} samples for the inspiral analysis and {} "
                        "samples for the postinspiral analysis from the file "
                        "{}".format(
                            meta_file_labels[0], meta_file_labels[1],
                            self.result_files[0]
                        )
                    )
                elif len(self.result_files) == 2:
                    logger.info(
                        "Using the {} samples for the inspiral analysis from "
                        "the file {}. Using the {} samples for the "
                        "postinspiral analysis from the file {}".format(
                            meta_file_labels[0], self.result_files[0],
                            meta_file_labels[1], self.result_files[1]
                        )
                    )
                else:
                    raise ValueError(
                        "Currently, you can only provide at most 2 pesummary "
                        "metafiles. If one is provided, both the inspiral and "
                        "postinspiral are extracted from that single file. If "
                        "two are provided, the inspiral is extracted from one "
                        "file and the postinspiral is extracted from the other."
                    )
                # normalise to the canonical labels; the metafile analysis
                # names are kept for the samples setter
                self._labels = ["inspiral", "postinspiral"]
                self._meta_file_labels = meta_file_labels
                self.analysis_label = ["primary"]
            else:
                raise ValueError(
                    "The IMRCT test requires an inspiral and postinspiral result "
                    "file. Please indicate which file is the inspiral and which "
                    "is postinspiral by providing these exact labels to the "
                    "summarytgr executable"
                )
        else:
            self.analysis_label = ["primary"]

1509 

1510 def _extract_stored_approximant(self, opened_file, label): 

1511 """Extract the approximant used for a given analysis stored in a 

1512 PESummary metafile 

1513 

1514 Parameters 

1515 ---------- 

1516 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1517 opened metafile that contains the analysis 'label' 

1518 label: str 

1519 analysis label which is stored in the PESummary metafile 

1520 """ 

1521 if opened_file.approximant is not None: 

1522 if label not in opened_file.labels: 

1523 raise ValueError( 

1524 "Invalid label {}. The list of available labels are {}".format( 

1525 label, ", ".join(opened_file.labels) 

1526 ) 

1527 ) 

1528 _index = opened_file.labels.index(label) 

1529 return opened_file.approximant[_index] 

1530 return 

1531 

1532 def _extract_stored_remnant_fits(self, opened_file, label): 

1533 """Extract the remnant fits used for a given analysis stored in a 

1534 PESummary metafile 

1535 

1536 Parameters 

1537 ---------- 

1538 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1539 opened metafile that contains the analysis 'label' 

1540 label: str 

1541 analysis label which is stored in the PESummary metafile 

1542 """ 

1543 fits = {} 

1544 fit_strings = [ 

1545 "final_spin_NR_fits", "final_mass_NR_fits" 

1546 ] 

1547 if label not in opened_file.labels: 

1548 raise ValueError( 

1549 "Invalid label {}. The list of available labels are {}".format( 

1550 label, ", ".join(opened_file.labels) 

1551 ) 

1552 ) 

1553 _index = opened_file.labels.index(label) 

1554 _meta_data = opened_file.extra_kwargs[_index] 

1555 if "meta_data" in _meta_data.keys(): 

1556 for key in fit_strings: 

1557 if key in _meta_data["meta_data"].keys(): 

1558 fits[key] = _meta_data["meta_data"][key] 

1559 if len(fits): 

1560 return fits 

1561 return 

1562 

1563 def _extract_stored_cutoff_frequency(self, opened_file, label): 

1564 """Extract the cutoff frequencies used for a given analysis stored in a 

1565 PESummary metafile 

1566 

1567 Parameters 

1568 ---------- 

1569 opened_file: pesummary.gw.file.formats.pesummary.PESummary 

1570 opened metafile that contains the analysis 'label' 

1571 label: str 

1572 analysis label which is stored in the PESummary metafile 

1573 """ 

1574 frequencies = {} 

1575 if opened_file.config is not None: 

1576 if label not in opened_file.labels: 

1577 raise ValueError( 

1578 "Invalid label {}. The list of available labels are {}".format( 

1579 label, ", ".join(opened_file.labels) 

1580 ) 

1581 ) 

1582 if opened_file.config[label] is not None: 

1583 _config = opened_file.config[label] 

1584 if "config" in _config.keys(): 

1585 if "maximum-frequency" in _config["config"].keys(): 

1586 frequencies["fhigh"] = _config["config"][ 

1587 "maximum-frequency" 

1588 ] 

1589 if "minimum-frequency" in _config["config"].keys(): 

1590 frequencies["flow"] = _config["config"][ 

1591 "minimum-frequency" 

1592 ] 

1593 elif "lalinference" in _config.keys(): 

1594 if "fhigh" in _config["lalinference"].keys(): 

1595 frequencies["fhigh"] = _config["lalinference"][ 

1596 "fhigh" 

1597 ] 

1598 if "flow" in _config["lalinference"].keys(): 

1599 frequencies["flow"] = _config["lalinference"][ 

1600 "flow" 

1601 ] 

1602 return frequencies 

1603 return 

1604 

1605 @property 

1606 def samples(self): 

1607 return self._samples 

1608 

    @samples.setter
    def samples(self, samples):
        """Read the result files and store their posterior samples.

        Plain result files contribute their samples directly; PESummary
        metafiles have the relevant analyses selected via
        'self._meta_file_labels'. Where available, the approximant, cutoff
        frequencies and remnant fits used for each analysis are also
        extracted and stored for the meta_data setter.
        """
        from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
        self._read_samples = {
            _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
                self.labels, self.result_files
            )
        }
        _samples_dict = {}
        _approximant_dict = {}
        _cutoff_frequency_dict = {}
        _remnant_fits_dict = {}
        for label, _open in self._read_samples.items():
            if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
                # a PESummary metafile containing multiple analyses; only
                # valid when the labels setter stored _meta_file_labels
                if not len(self._meta_file_labels):
                    raise ValueError(
                        "Currently you can only pass a file containing a "
                        "single analysis or a valid PESummary metafile "
                        "containing multiple analyses"
                    )
                _labels = _open.labels
                if len(self._read_samples) == 1:
                    # a single metafile provides both analyses: extract all
                    # of them in one pass and stop iterating (break below)
                    _samples_dict = {
                        label: _open.samples_dict[meta_file_label] for
                        label, meta_file_label in zip(
                            self.labels, self._meta_file_labels
                        )
                    }
                    for label, meta_file_label in zip(self.labels, self._meta_file_labels):
                        _stored_approx = self._extract_stored_approximant(
                            _open, meta_file_label
                        )
                        _stored_frequencies = self._extract_stored_cutoff_frequency(
                            _open, meta_file_label
                        )
                        _stored_remnant_fits = self._extract_stored_remnant_fits(
                            _open, meta_file_label
                        )
                        if _stored_approx is not None:
                            _approximant_dict[label] = _stored_approx
                        if _stored_remnant_fits is not None:
                            _remnant_fits_dict[label] = _stored_remnant_fits
                        if _stored_frequencies is not None:
                            # the inspiral analysis is characterised by its
                            # maximum frequency, the postinspiral by its
                            # minimum frequency
                            if label == "inspiral":
                                if "fhigh" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "fhigh"
                                    ]
                            if label == "postinspiral":
                                if "flow" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "flow"
                                    ]
                    break
                else:
                    # one metafile per analysis: pick the analysis matching
                    # this label's position
                    ind = self.labels.index(label)
                    _samples_dict[label] = _open.samples_dict[
                        self._meta_file_labels[ind]
                    ]
                    _stored_approx = self._extract_stored_approximant(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_frequencies = self._extract_stored_cutoff_frequency(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_remnant_fits = self._extract_stored_remnant_fits(
                        _open, self._meta_file_labels[ind]
                    )
                    if _stored_approx is not None:
                        _approximant_dict[label] = _stored_approx
                    if _stored_remnant_fits is not None:
                        _remnant_fits_dict[label] = _stored_remnant_fits
                    if _stored_frequencies is not None:
                        if label == "inspiral":
                            if "fhigh" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "fhigh"
                                ]
                        if label == "postinspiral":
                            if "flow" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "flow"
                                ]
            else:
                # a plain result file containing a single analysis
                _samples_dict[label] = _open.samples_dict
                extra_kwargs = _open.extra_kwargs
                if "pe_algorithm" in extra_kwargs["sampler"].keys():
                    if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
                        # bilby result files store the approximant and the
                        # frequency bounds in the likelihood metadata
                        try:
                            subkwargs = extra_kwargs["other"]["likelihood"][
                                "waveform_arguments"
                            ]
                            _approximant_dict[label] = (
                                subkwargs["waveform_approximant"]
                            )
                            if "inspiral" in label and "postinspiral" not in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["maximum_frequency"]
                                )
                            elif "postinspiral" in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["minimum_frequency"]
                                )
                        except KeyError:
                            pass
        self._samples = MultiAnalysisSamplesDict(_samples_dict)
        # only store the extracted metadata when something was found; the
        # meta_data setter treats the absence of these attributes as unknown
        if len(_approximant_dict):
            self._approximant_dict = _approximant_dict
        if len(_cutoff_frequency_dict):
            self._cutoff_frequency_dict = _cutoff_frequency_dict
        if len(_remnant_fits_dict):
            self._remnant_fits_dict = _remnant_fits_dict

1721 

1722 @property 

1723 def imrct_kwargs(self): 

1724 return self._imrct_kwargs 

1725 

1726 @imrct_kwargs.setter 

1727 def imrct_kwargs(self, imrct_kwargs): 

1728 test_kwargs = dict(N_bins=101) 

1729 try: 

1730 test_kwargs.update(imrct_kwargs) 

1731 except AttributeError: 

1732 test_kwargs = test_kwargs 

1733 

1734 for key, value in test_kwargs.items(): 

1735 try: 

1736 test_kwargs[key] = ast.literal_eval(value) 

1737 except ValueError: 

1738 pass 

1739 self._imrct_kwargs = test_kwargs 

1740 

1741 @property 

1742 def meta_data(self): 

1743 return self._meta_data 

1744 

    @meta_data.setter
    def meta_data(self, meta_data):
        """Collect the cutoff frequency, approximant and remnant fits used
        for each inspiral/postinspiral pair.

        Command line values take precedence; otherwise the values extracted
        from the result files by the samples setter are used, and None is
        stored when neither is available.
        """
        self._meta_data = {}
        for num, _inspiral in enumerate(self.inspiral_keys):
            frequency_dict = dict()
            approximant_dict = dict()
            remnant_dict = dict()
            # command line sources first; remnant fits can only come from the
            # result files, hence the None placeholder
            zipped = zip(
                [self.cutoff_frequency, self.approximant, None],
                [frequency_dict, approximant_dict, remnant_dict],
                ["cutoff_frequency", "approximant", "remnant_fits"]
            )
            _inspiral_string = self.inspiral_keys[num]
            _postinspiral_string = self.postinspiral_keys[num]
            for _list, _dict, name in zipped:
                if _list is not None and len(_list) == len(self.labels):
                    inspiral_ind = self.labels.index(_inspiral_string)
                    postinspiral_ind = self.labels.index(_postinspiral_string)
                    _dict["inspiral"] = _list[inspiral_ind]
                    _dict["postinspiral"] = _list[postinspiral_ind]
                elif _list is not None:
                    raise ValueError(
                        "Please provide a 'cutoff_frequency' and 'approximant' "
                        "for each file"
                    )
                else:
                    # fall back to values extracted from the result files;
                    # the _*_dict attributes only exist when the samples
                    # setter found something, hence the except clause below
                    try:
                        if name == "cutoff_frequency":
                            if "inspiral" in self._cutoff_frequency_dict.keys():
                                _dict["inspiral"] = self._cutoff_frequency_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._cutoff_frequency_dict.keys():
                                _dict["postinspiral"] = self._cutoff_frequency_dict[
                                    "postinspiral"
                                ]
                        elif name == "approximant":
                            if "inspiral" in self._approximant_dict.keys():
                                _dict["inspiral"] = self._approximant_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._approximant_dict.keys():
                                _dict["postinspiral"] = self._approximant_dict[
                                    "postinspiral"
                                ]
                        elif name == "remnant_fits":
                            if "inspiral" in self._remnant_fits_dict.keys():
                                _dict["inspiral"] = self._remnant_fits_dict[
                                    "inspiral"
                                ]
                            if "postinspiral" in self._remnant_fits_dict.keys():
                                _dict["postinspiral"] = self._remnant_fits_dict[
                                    "postinspiral"
                                ]
                    except (AttributeError, KeyError, TypeError):
                        _dict["inspiral"] = None
                        _dict["postinspiral"] = None

            # NOTE(review): plain key access below raises KeyError when only
            # one of the pair was found in the extracted metadata -- confirm
            # whether partial metadata can occur in practice
            self._meta_data[self.analysis_label[num]] = {
                "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
                "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
                "inspiral approximant": approximant_dict["inspiral"],
                "postinspiral approximant": approximant_dict["postinspiral"],
                "inspiral remnant fits": remnant_dict["inspiral"],
                "postinspiral remnant fits": remnant_dict["postinspiral"]
            }

    def __init__(self, opts):
        """Initialise the IMRCT input handler.

        Parameters
        ----------
        opts: argparse.Namespace
            namespace containing the parsed command line options
        """
        self.opts = opts
        self.existing = None
        self.webdir = self.opts.webdir
        self.user = None
        self.baseurl = None
        self.result_files = self.opts.samples
        # order matters: the labels setter validates against result_files,
        # and the samples setter reads both labels and result_files
        self.labels = self.opts.labels
        self.samples = self.opts.samples
        self.inspiral_keys = [
            key for key in self.samples.keys() if "inspiral" in key
            and "postinspiral" not in key
        ]
        self.postinspiral_keys = [
            key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
        ]
        try:
            self.imrct_kwargs = self.opts.imrct_kwargs
        except AttributeError:
            # the option was not provided; the setter supplies the defaults
            self.imrct_kwargs = {}
        for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
            _attr = getattr(self.opts, _arg)
            # each per-file option must either be absent/empty or match the
            # number of labels
            if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
                raise ValueError("Please provide a {} for each file".format(_arg))
            setattr(self, _arg, _attr)
        self.meta_data = None
        self.default_directories = ["samples", "plots", "js", "html", "css"]
        self.publication = False
        self.make_directories()