Coverage for pesummary/gw/cli/inputs.py: 75.2%
978 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import ast
4import os
5import numpy as np
6import pesummary.core.cli.inputs
7from pesummary.gw.file.read import read as GWRead
8from pesummary.gw.file.psd import PSD
9from pesummary.gw.file.calibration import Calibration
10from pesummary.utils.decorators import deprecation
11from pesummary.utils.exceptions import InputError
12from pesummary.utils.utils import logger
13from pesummary import conf
15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
18class _GWInput(pesummary.core.cli.inputs._Input):
19 """Super class to handle gw specific command line inputs
20 """
@staticmethod
def grab_data_from_metafile(
    existing_file, webdir, compare=None, nsamples=None, **kwargs
):
    """Load the data stored in an existing PESummary metafile

    The generic loading is delegated to
    ``pesummary.core.cli.inputs._Input.grab_data_from_metafile`` using the
    ``GWRead`` reader. The returned dictionary is then extended with the
    approximant, PSD, calibration and skymap stored for each analysis.

    Parameters
    ----------
    existing_file: str
        path to the existing metafile
    webdir: str
        the directory to store the existing configuration file
    compare: list, optional
        list of labels for events stored in an existing metafile that you
        wish to compare
    nsamples: int, optional
        forwarded unchanged to the core implementation
    **kwargs: dict, optional
        forwarded to the core implementation. When ``psd_default`` is
        present it is also added to the kwargs replaced with PESummary
        specific values

    Returns
    -------
    data: dict
        the core data with the additional keys "approximant", "psd",
        "calibration" and "skymap", each keyed by label
    """
    replace_kwargs = {"psd": "{file}.psd['{label}']"}
    if "psd_default" in kwargs.keys():
        replace_kwargs["psd_default"] = kwargs["psd_default"]
    data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
        existing_file, webdir, compare=compare, read_function=GWRead,
        nsamples=nsamples, _replace_with_pesummary_kwargs=replace_kwargs,
        **kwargs
    )
    opened = GWRead(existing_file)
    labels = data["labels"]

    def _load_per_ifo(stored, loader):
        # One {ifo: loader(...)} dictionary per label; labels with nothing
        # stored keep an empty dictionary
        loaded = {label: {} for label in labels}
        if stored is None or stored == {}:
            return loaded
        for label in labels:
            if label not in stored.keys() or stored[label] == {}:
                continue
            loaded[label] = {
                ifo: loader(stored[label][ifo]) for ifo in stored[label].keys()
            }
        return loaded

    psd = _load_per_ifo(opened.psd, PSD)
    calibration = _load_per_ifo(opened.calibration, Calibration)

    # Older files may not provide a skymap attribute at all
    skymap = dict.fromkeys(labels)
    stored_skymaps = getattr(opened, "skymap", None)
    if stored_skymaps is not None and stored_skymaps != {}:
        for label in labels:
            if label in stored_skymaps.keys() and len(stored_skymaps[label]):
                skymap[label] = stored_skymaps[label]

    approximant = dict(
        zip(labels, [opened.approximant[index] for index in data["indicies"]])
    )
    data.update(
        {
            "approximant": approximant,
            "psd": psd,
            "calibration": calibration,
            "skymap": skymap
        }
    )
    return data
@property
def grab_data_kwargs(self):
    """Return the keyword arguments used to load and convert each analysis

    The kwargs returned by the parent class are extended, for every label,
    with the gravitational-wave specific conversion options. Before that,
    the per-analysis settings (f_start, f_low, f_ref, f_final and delta_f)
    are normalised in place. None becomes one None per label, and a single
    value is repeated for every label. If a per-analysis setting (or the
    approximant list) still has too few entries, the first entry of every
    setting is used for all labels instead.

    Returns
    -------
    kwargs: dict
        dictionary keyed by label containing the keyword arguments for that
        analysis
    """
    kwargs = super(_GWInput, self).grab_data_kwargs
    nlabels = len(self.labels)
    for _property in ["f_start", "f_low", "f_ref", "f_final", "delta_f"]:
        value = getattr(self, _property)
        if value is None:
            setattr(self, "_{}".format(_property), [None] * nlabels)
        elif len(value) == 1 and nlabels != 1:
            setattr(self, "_{}".format(_property), value * nlabels)
    if self.opts.approximant is None:
        approx = [None] * nlabels
    else:
        approx = self.opts.approximant
    # the approximant_flags setter only stores the flags on first assignment
    self.approximant_flags = self.opts.approximant_flags
    resume_file = [
        os.path.join(
            self.webdir, "checkpoint",
            "{}_conversion_class.pickle".format(label)
        ) for label in self.labels
    ]

    def _conversion_kwargs(label, idx, resume_idx):
        # `idx` selects the entry of each per-analysis setting. `resume_idx`
        # selects the checkpoint file, which is always specific to `label`.
        # An IndexError is raised when a per-analysis setting is too short.
        return dict(
            evolve_spins_forwards=self.evolve_spins_forwards,
            evolve_spins_backwards=self.evolve_spins_backwards,
            f_start=self.f_start[idx], f_low=self.f_low[idx],
            approximant_flags=self.approximant_flags.get(label, {}),
            approximant=approx[idx], f_ref=self.f_ref[idx],
            NRSur_fits=self.NRSur_fits, return_kwargs=True,
            multipole_snr=self.calculate_multipole_snr,
            precessing_snr=self.calculate_precessing_snr,
            psd=self.psd.get(self.labels[idx], {}),
            f_final=self.f_final[idx],
            waveform_fits=self.waveform_fits,
            multi_process=self.opts.multi_process,
            redshift_method=self.redshift_method,
            cosmology=self.cosmology,
            no_conversion=self.no_conversion,
            add_zero_spin=True, delta_f=self.delta_f[idx],
            psd_default=self.psd_default,
            disable_remnant=self.disable_remnant,
            force_BBH_remnant_computation=self.force_BBH_remnant_computation,
            resume_file=resume_file[resume_idx],
            restart_from_checkpoint=self.restart_from_checkpoint,
            force_BH_spin_evolution=self.force_BH_spin_evolution,
        )

    try:
        for num, label in enumerate(self.labels):
            kwargs[label].update(_conversion_kwargs(label, num, num))
        return kwargs
    except IndexError:
        logger.warning(
            "Unable to find an f_ref, f_low and approximant for each "
            "label. Using f_ref={}, f_low={} and approximant={} "
            "for all result files".format(
                self.f_ref[0], self.f_low[0], approx[0]
            )
        )
    # Fallback: use the first entry of every per-analysis setting for all
    # labels. The approximant flags and the PSD lookup now match the
    # primary path. Previously the flags were dropped here, and a missing
    # PSD for the first label raised a KeyError.
    for num, label in enumerate(self.labels):
        kwargs[label].update(_conversion_kwargs(label, 0, num))
    return kwargs
@staticmethod
def grab_data_from_file(
    file, label, webdir, config=None, injection=None, file_format=None,
    nsamples=None, disable_prior_sampling=False, **kwargs
):
    """Grab data from a result file containing posterior samples

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    file_format, str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options
    **kwargs: dict, optional
        all other kwargs are passed to
        pesummary.core.cli.inputs._Input.grab_data_from_file. If an
        ``approximant_flags`` dictionary is among them, it is updated in
        place with any approximant flags stored in the result file meta
        data that were not explicitly provided. Conflicting values are
        ignored with a warning.
    """
    data = pesummary.core.cli.inputs._Input.grab_data_from_file(
        file, label, webdir, config=config, injection=injection,
        read_function=GWRead, file_format=file_format, nsamples=nsamples,
        disable_prior_sampling=disable_prior_sampling, **kwargs
    )
    # This is a staticmethod, so the user-provided flags must come from the
    # kwargs rather than `self`. The previous implementation referenced
    # `self` and a misspelled `_kwargs`, and silently swallowed the
    # resulting NameError, so the merge below never ran.
    provided_flags = kwargs.get("approximant_flags")
    if not isinstance(provided_flags, dict):
        return data
    try:
        file_kwargs = data["file_kwargs"][label]
    except (KeyError, TypeError):
        return data
    if isinstance(file_kwargs, dict):
        file_kwargs = [file_kwargs]
    for _kwgs in file_kwargs:
        try:
            stored_flags = _kwgs["meta_data"]["approximant_flags"]
        except (KeyError, TypeError):
            continue
        if not isinstance(stored_flags, dict):
            continue
        for key, item in stored_flags.items():
            if key in provided_flags and provided_flags[key] != item:
                logger.warning(
                    "Approximant flag {}={} found in result file for {}. "
                    "Ignoring and using the provided values {}={}".format(
                        key, item, label, key, provided_flags[key]
                    )
                )
            else:
                provided_flags[key] = item
    return data
@property
def reweight_samples(self):
    """Return the validated reweighting option"""
    return self._reweight_samples

@reweight_samples.setter
def reweight_samples(self, reweight_samples):
    # Validate against the reweighting options known to pesummary.gw.reweight
    from pesummary.gw.reweight import options
    checked = self._check_reweight_samples(reweight_samples, options)
    self._reweight_samples = checked
def _set_samples(self, *args, **kwargs):
    """Set the samples via the parent class, then ensure that a
    "calibration" prior entry exists (one empty dictionary per label)."""
    super(_GWInput, self)._set_samples(*args, **kwargs)
    if "calibration" in self.priors:
        return
    self.priors["calibration"] = dict((label, {}) for label in self.labels)
def _set_corner_params(self, corner_params):
    """Return the parameters to include in the corner plot

    Parameters
    ----------
    corner_params: list, optional
        parameters requested by the user

    Returns
    -------
    corner_params: list/None
        None when the parent class returns None, in which case the defaults
        in conf.gw_corner_parameters are used. Otherwise the union of
        conf.gw_corner_parameters and the requested parameters, with the
        defaults first and duplicates removed. Requested parameters that
        are missing from the samples of any label are dropped.
    """
    corner_params = super(_GWInput, self)._set_corner_params(corner_params)
    if corner_params is None:
        logger.debug(
            "Using the default corner parameters: {}".format(
                ", ".join(conf.gw_corner_parameters)
            )
        )
        return corner_params
    requested = corner_params
    missing = {
        param for param in requested
        if not all(param in self.samples[label].keys() for label in self.labels)
    }
    # dict.fromkeys deduplicates while keeping a deterministic order. A
    # set() reordered the parameters run-to-run, and list.remove raised a
    # ValueError when a missing parameter was requested twice.
    combined = dict.fromkeys(conf.gw_corner_parameters + requested)
    corner_params = [param for param in combined if param not in missing]
    logger.debug(
        "Generating a corner plot with the following "
        "parameters: {}".format(", ".join(corner_params))
    )
    return corner_params
@property
def cosmology(self):
    """Return the name of the cosmology to use"""
    return self._cosmology

@cosmology.setter
def cosmology(self, cosmology):
    # The name is matched case-insensitively but stored as given. Unknown
    # names fall back to conf.cosmology.
    from pesummary.gw.cosmology import available_cosmologies

    if cosmology.lower() in available_cosmologies:
        logger.debug("Using the {} cosmology".format(cosmology))
    else:
        logger.warning(
            "Unrecognised cosmology: {}. Using {} as default".format(
                cosmology, conf.cosmology
            )
        )
        cosmology = conf.cosmology
    self._cosmology = cosmology
@property
def approximant(self):
    """Return the approximant stored for each label

    A dictionary keyed by label, populated by the setter.
    """
    return self._approximant

@approximant.setter
def approximant(self, approximant):
    """Store the approximant for each label

    Parameters
    ----------
    approximant: list, optional
        one approximant per label, matched to the labels in order. If None,
        every label is assigned an empty dictionary

    Raises
    ------
    InputError
        if the number of approximants does not match the number of labels
    """
    # NOTE(review): layout reconstructed from a whitespace-stripped listing;
    # confirm the nesting of the `else` branch below
    # First assignment: build the {label: approximant} mapping
    if not hasattr(self, "_approximant"):
        approximant_list = {i: {} for i in self.labels}
        if approximant is None:
            logger.warning(
                "No approximant passed. Waveform plots will not be "
                "generated"
            )
        elif approximant is not None:
            if len(approximant) != len(self.labels):
                raise InputError(
                    "Please pass an approximant for each result file"
                )
            approximant_list = {
                i: j for i, j in zip(self.labels, approximant)
            }
        self._approximant = approximant_list
    # Later assignments ignore the passed value and instead replace an
    # unset ({}) entry with None
    else:
        for num, i in enumerate(self._approximant.keys()):
            if self._approximant[i] == {}:
                # the warning is only emitted when the first label is unset
                if num == 0:
                    logger.warning(
                        "No approximant passed. Waveform plots will not be "
                        "generated"
                    )
                self._approximant[i] = None
                # NOTE(review): `break` stops after the first unset entry, so
                # later labels that are also {} keep {}; confirm intended
                break
@property
def approximant_flags(self):
    """Return the approximant flags for each label

    A dictionary keyed by label; each value maps flag name to flag value.
    """
    return self._approximant_flags

@approximant_flags.setter
def approximant_flags(self, approximant_flags):
    """Store the approximant flags provided for each label

    Only the first assignment has any effect.

    Parameters
    ----------
    approximant_flags: dict
        dictionary whose keys have the form "LABEL:FLAG" and whose values
        are the flag values. None is treated as no flags

    Raises
    ------
    ValueError
        if a key is not of the form "LABEL:FLAG" or refers to a label that
        does not exist
    """
    if hasattr(self, "_approximant_flags"):
        return
    _approximant_flags = {key: {} for key in self.labels}
    for name, item in (approximant_flags or {}).items():
        parts = name.split(":")
        if len(parts) != 2:
            # previously surfaced as an opaque tuple-unpacking ValueError
            raise ValueError(
                "Unable to parse the approximant flag '{}'. Approximant "
                "flags must be provided in the form LABEL:FLAG:VALUE".format(
                    name
                )
            )
        _label, key = parts
        if _label not in self.labels:
            raise ValueError(
                "Unable to assign waveform flags for label:{} because "
                "it does not exist. Available labels are: {}. Approximant "
                "flags must be provided in the form LABEL:FLAG:VALUE".format(
                    _label, ", ".join(self.labels)
                )
            )
        _approximant_flags[_label][key] = item
    self._approximant_flags = _approximant_flags
@property
def gracedb_server(self):
    """Return the GraceDB server URL to query"""
    return self._gracedb_server

@gracedb_server.setter
def gracedb_server(self, gracedb_server):
    # Fall back to conf.gracedb_server when no server is given
    if gracedb_server is not None:
        logger.debug(
            "Using '{}' as the GraceDB server".format(gracedb_server)
        )
        self._gracedb_server = gracedb_server
        return
    self._gracedb_server = conf.gracedb_server
@property
def gracedb(self):
    """Return the GraceDB ID, or None if none was given or it was invalid"""
    return self._gracedb

@gracedb.setter
def gracedb(self, gracedb):
    """Store the GraceDB ID and add its data to every label's meta data

    Parameters
    ----------
    gracedb: str, optional
        the GraceDB ID. IDs that do not start with "G", "g" or "S" are
        ignored with a warning. The entries in self.gracedb_data are
        downloaded from self.gracedb_server; if that fails, only the ID is
        stored in the meta data of each label
    """
    self._gracedb = gracedb
    if gracedb is None:
        return
    from pesummary.gw.gracedb import get_gracedb_data, HTTPError
    from json.decoder import JSONDecodeError

    # Checking the length first means an empty ID is rejected with the
    # warning below instead of raising an IndexError
    if not len(gracedb) or gracedb[0] not in ("G", "g", "S"):
        logger.warning(
            "Invalid GraceDB ID passed. The GraceDB ID must be of the "
            "form G0000 or S0000. Ignoring input."
        )
        self._gracedb = None
        return
    _error = (
        "Unable to download data from Gracedb because {}. Only storing "
        "the GraceDB ID in the metafile"
    )
    try:
        logger.info(
            "Downloading {} from gracedb for {}".format(
                ", ".join(self.gracedb_data), gracedb
            )
        )
        gracedb_json = get_gracedb_data(
            gracedb, info=self.gracedb_data,
            service_url=self.gracedb_server
        )
        gracedb_json["id"] = gracedb
    except (HTTPError, RuntimeError, JSONDecodeError) as e:
        logger.warning(_error.format(e))
        gracedb_json = {"id": gracedb}

    # NOTE: the same dictionary object is shared by every label
    for label in self.labels:
        self.file_kwargs[label]["meta_data"]["gracedb"] = gracedb_json
@property
def terrestrial_probability(self):
    """Return the terrestrial probability for each analysis

    A list with one entry per label (a float, or None when unavailable).
    """
    return self._terrestrial_probability

@terrestrial_probability.setter
def terrestrial_probability(self, terrestrial_probability):
    """Store the terrestrial probability for each analysis

    Parameters
    ----------
    terrestrial_probability: list, optional
        a single value to use for every analysis, or one value per
        analysis. When None and a GraceDB ID is available, the value is
        downloaded from GraceDB on a best-effort basis. Failures caught
        here are logged and the probability is set to None for every
        analysis.

    Raises
    ------
    ValueError
        if the number of values provided is neither 1 nor the number of
        analyses
    """
    if terrestrial_probability is None and self.gracedb is not None:
        logger.info(
            "No terrestrial probability provided. Trying to download "
            "from gracedb"
        )
        from pesummary.core.fetch import download_and_read_file
        from urllib.error import HTTPError
        import json
        try:
            downloaded = download_and_read_file(
                f"{self.gracedb_server}/superevents/{self.gracedb}/"
                f"files/p_astro.json", read_file=False,
                outdir=f"{self.webdir}/samples", delete_on_exit=False
            )
            with open(downloaded, "r") as f:
                data = json.load(f)
            self._terrestrial_probability = [float(data["Terrestrial"])]
        except (RuntimeError, KeyError, TypeError, ValueError) as e:
            # ValueError includes json.JSONDecodeError. A p_astro file
            # without a usable "Terrestrial" entry (KeyError/TypeError/
            # ValueError) previously propagated out of this best-effort path.
            logger.warning(
                "Unable to grab terrestrial probability from gracedb "
                "because {}".format(e)
            )
            self._terrestrial_probability = [None]
        except HTTPError as e:
            # Downloading the superevent p_astro.json failed. Fall back to
            # the p_astro file of the pipeline that submitted the preferred
            # event.
            from pesummary.gw.gracedb import get_gracedb_data, get_gracedb_file
            try:
                preferred = get_gracedb_data(
                    self.gracedb, info=["preferred_event_data"],
                    service_url=self.gracedb_server
                )["preferred_event_data"]["submitter"]
                _pipelines = [
                    "pycbc", "gstlal", "mbta", "spiir"
                ]
                _filename = None
                # the last matching pipeline in this list wins
                for _pipe in _pipelines:
                    if _pipe in preferred:
                        _filename = f"{_pipe}.p_astro.json"
                if _filename is None:
                    raise e
                data = get_gracedb_file(
                    self.gracedb, _filename, service_url=self.gracedb_server
                )
                with open(f"{self.webdir}/samples/{_filename}", "w") as json_file:
                    json.dump(data, json_file)
                self._terrestrial_probability = [float(data["Terrestrial"])]
            except Exception as exc:
                logger.warning(
                    "Unable to grab terrestrial probability from gracedb "
                    "because {}".format(exc)
                )
                self._terrestrial_probability = [None]
        self._terrestrial_probability *= len(self.labels)
    elif terrestrial_probability is None:
        self._terrestrial_probability = [None] * len(self.labels)
    else:
        if len(terrestrial_probability) == 1 and len(self.labels) > 1:
            logger.debug(
                f"Assuming a terrestrial probability: "
                f"{terrestrial_probability} for all analyses"
            )
            self._terrestrial_probability = [
                float(terrestrial_probability[0])
            ] * len(self.labels)
        elif len(terrestrial_probability) == len(self.labels):
            self._terrestrial_probability = [
                float(_) for _ in terrestrial_probability
            ]
        else:
            raise ValueError(
                "Please provide a terrestrial probability for each "
                "analysis, or a single value to be used for all analyses"
            )
@property
def detectors(self):
    """Return the detector network for each label"""
    return self._detectors

@detectors.setter
def detectors(self, detectors):
    # When no network is given, it is inferred per label from the
    # "<IFO>_optimal_snr" parameters: sorted IFO names joined by "_", or
    # None if there are no such parameters
    if detectors:
        network = detectors
    else:
        network = {}
        for label in self.samples.keys():
            parameter_names = list(self.samples[label].keys())
            ifos = sorted(
                str(name.split("_optimal_snr")[0]) for name in parameter_names
                if "optimal_snr" in name and name != "network_optimal_snr"
            )
            network[label] = "_".join(ifos) if ifos else None
    logger.debug("The detector network is %s" % (network))
    self._detectors = network
@property
def skymap(self):
    """Return the skymap stored for each label"""
    return self._skymap

@skymap.setter
def skymap(self, skymap):
    # The assigned value is ignored: on first assignment every label is
    # initialised to None; later assignments do nothing
    if hasattr(self, "_skymap"):
        return
    self._skymap = dict.fromkeys(self.labels)
@property
def calibration_definition(self):
    """Return the calibration definition ("data" or "template") per label,
    or None when no calibration files were provided"""
    return self._calibration_definition

@calibration_definition.setter
def calibration_definition(self, calibration_definition):
    """Store the calibration definition for each label

    Parameters
    ----------
    calibration_definition: list
        a single definition to use for all analyses, or one per analysis.
        Each entry must be either 'data' or 'template'

    Raises
    ------
    ValueError
        if the number of definitions is neither 1 nor the number of labels,
        or if a definition is not 'data' or 'template'
    """
    calibration_files = self.opts.calibration
    # a None option is treated like an empty one instead of raising TypeError
    if calibration_files is None or not len(calibration_files):
        self._calibration_definition = None
        return
    if len(calibration_definition) == 1:
        logger.info(
            f"Assuming that the calibration correction was applied to "
            f"'{calibration_definition[0]}' for all analyses"
        )
        # Build a new sequence: `*=` extended the caller's list (e.g. the
        # parsed command line option) in place
        calibration_definition = calibration_definition * len(self.labels)
    elif len(calibration_definition) != len(self.labels):
        raise ValueError(
            f"Please provide a calibration definition for each analysis "
            f"({len(self.labels)}) or a single definition to use for all "
            f"analyses"
        )
    if any(_ not in ["data", "template"] for _ in calibration_definition):
        raise ValueError(
            "Calibration definitions must be either 'data' or 'template'"
        )
    self._calibration_definition = {
        label: calibration_definition[num] for num, label in
        enumerate(self.labels)
    }
@property
def calibration(self):
    """Return the calibration data for each label

    A dictionary keyed by label, populated by the setter.
    """
    return self._calibration

@calibration.setter
def calibration(self, calibration):
    """Store calibration priors and extract calibration posteriors

    Only the first assignment has any effect. If `calibration` is not an
    empty dictionary, the files it refers to are read with
    get_psd_or_calibration_data and added to the "calibration" priors,
    and every label keeps an empty calibration dictionary. Otherwise any
    ``<label>_calibration`` command line options are read as priors, and
    the calibration spline posterior is interpolated from each result
    file.

    Parameters
    ----------
    calibration: dict/list
        paths to calibration files, or {} when none were provided
    """
    # NOTE(review): layout reconstructed from a whitespace-stripped listing;
    # confirm that the loop over result files belongs to the `else` branch
    if not hasattr(self, "_calibration"):
        data = {i: {} for i in self.labels}
        if calibration != {}:
            # the calibration definition of the first label is used for
            # every file passed here
            prior_data = self.get_psd_or_calibration_data(
                calibration, self.extract_calibration_data_from_file,
                type=self.calibration_definition[self.labels[0]]
            )
            self.add_to_prior_dict("calibration", prior_data)
        else:
            prior_data = {i: {} for i in self.labels}
            for label in self.labels:
                if hasattr(self.opts, "{}_calibration".format(label)):
                    cal_data = getattr(self.opts, "{}_calibration".format(label))
                    if cal_data != {} and cal_data is not None:
                        # NOTE(review): the calibration_definition setter
                        # stores None when opts.calibration is empty, which
                        # would make this lookup raise a TypeError; confirm
                        # per-label files cannot be given without it
                        prior_data[label] = {
                            ifo: self.extract_calibration_data_from_file(
                                cal_data[ifo], type=self.calibration_definition[label]
                            ) for ifo in cal_data.keys()
                        }
            if not all(prior_data[i] == {} for i in self.labels):
                self.add_to_prior_dict("calibration", prior_data)
            else:
                self.add_to_prior_dict("calibration", {})
            for num, i in enumerate(self.result_files):
                # reuse an already opened result file when one is cached
                _opened = self._open_result_files
                if i in _opened.keys() and _opened[i] is not None:
                    f = self._open_result_files[i]
                else:
                    f = GWRead(i, disable_prior=True)
                try:
                    calibration_data = f.interpolate_calibration_spline_posterior()
                except Exception as e:
                    logger.warning(
                        "Failed to extract calibration data from the result "
                        "file: {} because {}".format(i, e)
                    )
                    calibration_data = None
                # NOTE(review): result files are matched to labels by
                # position in self.samples; assumes both share one order
                labels = list(self.samples.keys())
                if calibration_data is None:
                    data[labels[num]] = {
                        None: None
                    }
                elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                    # A PESummary file holds several analyses. The inner loop
                    # rebinds `num` and indexes the labels from 0, so this
                    # presumably assumes the PESummary file is the first
                    # result file -- TODO confirm
                    for num in range(len(calibration_data[0])):
                        data[labels[num]] = {
                            j: k for j, k in zip(
                                calibration_data[1][num],
                                calibration_data[0][num]
                            )
                        }
                else:
                    # keys come from calibration_data[1] (presumably IFO
                    # names) and values from calibration_data[0]
                    data[labels[num]] = {
                        j: k for j, k in zip(
                            calibration_data[1], calibration_data[0]
                        )
                    }
        self._calibration = data
@property
def psd(self):
    """Return the PSD data for each label ({ifo: PSD} dictionaries)"""
    return self._psd

@psd.setter
def psd(self, psd):
    # Only the first assignment has any effect. A non-empty `psd` is read
    # with get_psd_or_calibration_data; otherwise any "<label>_psd" options
    # are read per IFO.
    if hasattr(self, "_psd"):
        return
    if psd != {}:
        self._psd = self.get_psd_or_calibration_data(
            psd, self.extract_psd_data_from_file
        )
        return
    data = {label: {} for label in self.labels}
    for label in self.labels:
        option = "{}_psd".format(label)
        if not hasattr(self.opts, option):
            continue
        psd_data = getattr(self.opts, option)
        if psd_data == {} or psd_data is None:
            continue
        data[label] = {
            ifo: self.extract_psd_data_from_file(psd_data[ifo], IFO=ifo)
            for ifo in psd_data.keys()
        }
    self._psd = data
@property
def nsamples_for_skymap(self):
    """Return the number of samples used to generate the skymap, or None"""
    return self._nsamples_for_skymap

@nsamples_for_skymap.setter
def nsamples_for_skymap(self, nsamples_for_skymap):
    """Store the number of samples to use when generating the skymap

    The value is capped at the smallest number of samples available in any
    analysis, with a warning when it has to be reduced.

    Parameters
    ----------
    nsamples_for_skymap: int, optional
        requested number of samples. None leaves the value unset
    """
    self._nsamples_for_skymap = nsamples_for_skymap
    if nsamples_for_skymap is None:
        return
    self._nsamples_for_skymap = int(nsamples_for_skymap)
    number_of_samples = [
        data.number_of_samples for data in self.samples.values()
    ]
    # `>=`: an analysis with exactly the requested number of samples is
    # sufficient. The previous `>` produced a spurious "reducing" warning.
    if all(i >= self._nsamples_for_skymap for i in number_of_samples):
        return
    min_arg = np.argmin(number_of_samples)
    logger.warning(
        "You have specified that you would like to use {} "
        "samples to generate the skymap but the file {} only "
        "has {} samples. Reducing the number of samples to "
        "generate the skymap to {}".format(
            self._nsamples_for_skymap, self.result_files[min_arg],
            number_of_samples[min_arg], number_of_samples[min_arg]
        )
    )
    self._nsamples_for_skymap = int(number_of_samples[min_arg])
@property
def gwdata(self):
    """Return the loaded strain data, or None"""
    return self._gwdata

@gwdata.setter
def gwdata(self, gwdata):
    # gwdata must map channel names to strain files. Every file must exist
    # (InputError otherwise) before the data is read with
    # StrainDataDict.read. A non-dictionary input is ignored with a warning.
    from pesummary.gw.file.strain import StrainDataDict

    self._gwdata = gwdata
    if gwdata is None:
        return
    if not isinstance(gwdata, dict):
        logger.warning(
            "Please provide gwdata as a dictionary with keys "
            "displaying the channel and item giving the path to the "
            "strain file"
        )
        self._gwdata = None
        return
    for channel in gwdata.keys():
        path = gwdata[channel]
        if not os.path.isfile(path):
            raise InputError(
                "The file {} does not exist. Please check the path "
                "to your strain file".format(path)
            )
    self._gwdata = StrainDataDict.read(gwdata)
@property
def evolve_spins_forwards(self):
    """Return the forward spin-evolution setting"""
    return self._evolve_spins_forwards

@evolve_spins_forwards.setter
def evolve_spins_forwards(self, evolve_spins_forwards):
    # A truthy request is replaced by 6 ** -0.5, which the log message
    # associates with the Schwarzschild ISCO frequency; falsy values are
    # stored unchanged
    self._evolve_spins_forwards = evolve_spins_forwards
    if not evolve_spins_forwards:
        return
    _msg = "Spins will be evolved up to {}"
    logger.info(_msg.format("Schwarzschild ISCO frequency"))
    self._evolve_spins_forwards = 6. ** -0.5
@property
def evolve_spins_backwards(self):
    """Return the backward spin-evolution method"""
    return self._evolve_spins_backwards

@evolve_spins_backwards.setter
def evolve_spins_backwards(self, evolve_spins_backwards):
    # None selects the "precession_averaged" method. Strings are logged and
    # stored as given, and any other value is stored silently.
    _msg = (
        "Spins will be evolved backwards to an infinite separation using the '{}' "
        "method"
    )
    self._evolve_spins_backwards = evolve_spins_backwards
    if evolve_spins_backwards is None:
        logger.info(_msg.format("precession_averaged"))
        self._evolve_spins_backwards = "precession_averaged"
    elif isinstance(evolve_spins_backwards, (str, bytes)):
        logger.info(_msg.format(evolve_spins_backwards))
@property
def NRSur_fits(self):
    """Return the NRSurrogate model used for the remnant quantities"""
    return self._NRSur_fits

@NRSur_fits.setter
def NRSur_fits(self, NRSur_fits):
    # None selects pesummary.gw.conversions.nrutils.NRSUR_MODEL; strings are
    # logged and stored as given; any other value is stored silently
    base = (
        "Using the '{}' NRSurrogate model to calculate the remnant "
        "quantities"
    )
    self._NRSur_fits = NRSur_fits
    if NRSur_fits is None:
        from pesummary.gw.conversions.nrutils import NRSUR_MODEL

        logger.info(base.format(NRSUR_MODEL))
        self._NRSur_fits = NRSUR_MODEL
    elif isinstance(NRSur_fits, (str, bytes)):
        logger.info(base.format(NRSur_fits))
@property
def waveform_fits(self):
    """Return whether the remnant quantities use the provided approximant"""
    return self._waveform_fits

@waveform_fits.setter
def waveform_fits(self, waveform_fits):
    self._waveform_fits = waveform_fits
    if not waveform_fits:
        return
    logger.info(
        "Evaluating the remnant quantities using the provided "
        "approximant"
    )
@property
def f_low(self):
    """Return the `f_low` value for each analysis (floats), or None"""
    return self._f_low

@f_low.setter
def f_low(self, f_low):
    # None is stored as-is; otherwise every entry is cast to float
    self._f_low = f_low
    if f_low is None:
        return
    self._f_low = list(map(float, f_low))
@property
def f_start(self):
    """Return the `f_start` value for each analysis (floats), or None"""
    return self._f_start

@f_start.setter
def f_start(self, f_start):
    # None is stored as-is; otherwise every entry is cast to float
    self._f_start = f_start
    if f_start is None:
        return
    self._f_start = list(map(float, f_start))
@property
def f_ref(self):
    """Return the `f_ref` value for each analysis (floats), or None"""
    return self._f_ref

@f_ref.setter
def f_ref(self, f_ref):
    # None is stored as-is; otherwise every entry is cast to float
    self._f_ref = f_ref
    if f_ref is None:
        return
    self._f_ref = list(map(float, f_ref))
@property
def f_final(self):
    """Return the `f_final` value for each analysis (floats), or None"""
    return self._f_final

@f_final.setter
def f_final(self, f_final):
    # None is stored as-is; otherwise every entry is cast to float
    self._f_final = f_final
    if f_final is None:
        return
    self._f_final = list(map(float, f_final))
@property
def delta_f(self):
    """Return the `delta_f` value for each analysis (floats), or None"""
    return self._delta_f

@delta_f.setter
def delta_f(self, delta_f):
    # None is stored as-is; otherwise every entry is cast to float
    self._delta_f = delta_f
    if delta_f is None:
        return
    self._delta_f = list(map(float, delta_f))
@property
def psd_default(self):
    """Return the default PSD: a pycbc.psd function, a template string
    pointing at a stored PSD, or None"""
    return self._psd_default

@psd_default.setter
def psd_default(self, psd_default):
    # "stored:LABEL" refers to a PSD stored in the metafile and is kept as a
    # template string. Any other name is looked up in pycbc.psd, falling
    # back to conf.psd for unknown names and to None when pycbc is
    # unavailable or a ValueError is raised.
    self._psd_default = psd_default
    if "stored:" in psd_default:
        stored_label = psd_default.split("stored:")[1]
        self._psd_default = "{file}.psd['%s']" % (stored_label)
        return
    resolved = None
    try:
        from pycbc import psd
        resolved = getattr(psd, psd_default)
    except ImportError:
        logger.warning(
            "Unable to import 'pycbc'. Unable to generate a default PSD"
        )
    except AttributeError:
        logger.warning(
            "'pycbc' does not have the '{}' psd available. Using '{}' as "
            "the default PSD".format(psd_default, conf.psd)
        )
        resolved = getattr(psd, conf.psd)
    except ValueError as e:
        logger.warning("Setting 'psd_default' to None because {}".format(e))
        resolved = None
    self._psd_default = resolved
@property
def pastro_probs(self):
    """Return the source classification probabilities for each label

    A dictionary keyed by label; each value maps "default" to the
    probabilities produced by pesummary.gw.classification.PAstro.
    """
    return self._pastro_probs

@pastro_probs.setter
def pastro_probs(self, pastro_probs):
    """Compute the source classification probabilities for each analysis

    The assigned value is ignored. For every label a PAstro classification
    is computed and written to "<label>.pesummary.p_astro.json" in the
    samples directory; on failure PAstro.defaults is stored instead. For
    MCMC samples the first label holds the average over all chains.
    """
    from pesummary.gw.classification import PAstro
    import importlib

    def _load_distance_prior():
        # Return the analytic luminosity distance prior as an object, or
        # None so that PAstro falls back to its stored prior
        prior_repr = None
        try:
            prior_repr = self.priors["analytic"]["luminosity_distance"]
            cls = prior_repr.split("(")[0]
            module = ".".join(cls.split(".")[:-1])
            cls = cls.split(".")[-1]
            cls = getattr(importlib.import_module(module), cls, cls)
            args = "(".join(prior_repr.split("(")[1:])[:-1]
            return cls.from_repr(args)
        except KeyError:
            logger.debug(
                f"Unable to find a distance prior. Defaulting to stored "
                f"prior in pesummary.gw.classification.PAstro for "
                f"source classification probabilities"
            )
        except (AttributeError, ImportError, ValueError):
            # ImportError: the prior's module is not installed. ValueError:
            # importlib.import_module("") when the stored prior has no
            # module path. Both previously propagated out of this setter.
            logger.debug(
                f"Unable to load distance prior: {prior_repr}. "
                f"Defaulting to stored prior in "
                f"pesummary.gw.classification.PAstro for source "
                f"classification probabilities"
            )
        return None

    classifications = {}
    for num, i in enumerate(list(self.samples.keys())):
        distance_prior = _load_distance_prior()
        try:
            _cls = PAstro(
                self.samples[i], category_data=self.pastro_category_file,
                terrestrial_probability=self.terrestrial_probability[num],
                distance_prior=distance_prior,
                catch_terrestrial_probability_error=self.catch_terrestrial_probability_error
            )
            classifications[i] = {"default": _cls.classification()}
            try:
                _cls.save_to_file(
                    f"{i}.pesummary.p_astro.json",
                    classifications[i]["default"],
                    outdir=f"{self.webdir}/samples",
                    overwrite=True
                )
            except FileNotFoundError as e:
                logger.warning(
                    f"Failed to write PAstro probabilities to file "
                    f"because {e}"
                )
        except Exception as e:
            logger.warning(
                "Failed to generate source classification probabilities "
                "because {}".format(e)
            )
            classifications[i] = {"default": PAstro.defaults}
    if self.mcmc_samples:
        if any(_probs is None for _probs in classifications.values()):
            classifications[self.labels[0]] = None
            logger.warning(
                "Unable to average classification probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "classifications"
            )
        else:
            logger.debug(
                "Averaging classification probability across mcmc samples"
            )
            classifications[self.labels[0]] = {
                prior: {
                    key: np.round(np.mean(
                        [val[prior][key] for val in classifications.values()]
                    ), 3) for key in _probs.keys()
                } for prior, _probs in
                list(classifications.values())[0].items()
            }
    self._pastro_probs = classifications
@property
def embright_probs(self):
    """Return the EM-bright probabilities for each label"""
    return self._embright_probs

@embright_probs.setter
def embright_probs(self, embright_probs):
    # The assigned value is ignored. Each label gets an EMBright
    # classification (written to "<label>.pesummary.em_bright.json"), or
    # EMBright.defaults on failure. For MCMC samples the first label holds
    # the average over all chains.
    from pesummary.gw.classification import EMBright

    probabilities = {}
    for label in list(self.samples.keys()):
        try:
            classifier = EMBright(self.samples[label])
            probabilities[label] = {"default": classifier.classification()}
            try:
                classifier.save_to_file(
                    f"{label}.pesummary.em_bright.json",
                    probabilities[label]["default"],
                    outdir=f"{self.webdir}/samples",
                    overwrite=True
                )
            except FileNotFoundError as e:
                logger.warning(
                    f"Failed to write EM bright probabilities to file "
                    f"because {e}"
                )
        except Exception as e:
            logger.warning(
                "Failed to generate em_bright probabilities because "
                "{}".format(e)
            )
            probabilities[label] = {"default": EMBright.defaults}
    if self.mcmc_samples:
        if any(entry is None for entry in probabilities.values()):
            probabilities[self.labels[0]] = None
            logger.warning(
                "Unable to average em_bright probabilities across "
                "mcmc chains because one or more chains failed to estimate "
                "probabilities"
            )
        else:
            logger.debug(
                "Averaging em_bright probability across mcmc samples"
            )
            first = list(probabilities.values())[0]
            probabilities[self.labels[0]] = {
                prior: {
                    key: np.round(
                        np.mean([val[prior][key] for val in probabilities.values()]),
                        3
                    )
                    for key in entries.keys()
                }
                for prior, entries in first.items()
            }
    self._embright_probs = probabilities
@property
def preliminary_pages(self):
    """Return, per label, whether the pages are marked 'Preliminary'"""
    return self._preliminary_pages

@preliminary_pages.setter
def preliminary_pages(self, preliminary_pages):
    # The assigned value is ignored. A label is flagged when any attribute
    # named in conf.gw_reproducibility is missing or empty for it.
    required = conf.gw_reproducibility
    self._preliminary_pages = {label: False for label in self.labels}
    for idx, label in enumerate(self.labels):
        for attr in required:
            value = getattr(self, attr)
            if isinstance(value, dict):
                missing = label not in value.keys() or not len(value[label])
            elif isinstance(value, list):
                missing = value[idx] is None
            else:
                missing = False
            if missing:
                self._preliminary_pages[label] = True
    _labels = [
        label for label, flagged in self._preliminary_pages.items() if flagged
    ]
    if not _labels:
        return
    msg = (
        "Unable to reproduce the {} analys{} because no {} data was "
        "provided. 'Preliminary' watermarks will be added to the final "
        "html pages".format(
            ", ".join(_labels), "es" if len(_labels) > 1 else "is",
            " or ".join(required)
        )
    )
    logger.warning(msg)
@staticmethod
def _extract_IFO_data_from_file(file, cls, desc, IFO=None, **kwargs):
    """Load IFO data from a file with ``cls.read``

    Parameters
    ----------
    file: path
        path to a file containing the IFO data
    cls: obj
        class used to load the file. It must provide a '.read' method
    desc: str
        description of the IFO data stored in the file (used in warnings)
    IFO: str, optional
        the IFO which the data belongs to
    **kwargs: dict, optional
        passed to ``cls.read``

    Returns
    -------
    the object returned by ``cls.read``, or {} if the file does not exist
    or ``cls.read`` raised a ValueError (a warning is logged)
    """
    general = (
        "Failed to read in %s data because {}. The %s plot will not be "
        "generated and the %s data will not be added to the metafile."
    ) % ((desc,) * 3)
    try:
        return cls.read(file, IFO=IFO, **kwargs)
    except FileNotFoundError:
        reason = "the file {} does not exist".format(file)
    except ValueError as e:
        reason = e
    logger.warning(general.format(reason))
    return {}
@staticmethod
def extract_psd_data_from_file(file, IFO=None):
    """Load the PSD stored in a file

    Parameters
    ----------
    file: path
        path to a file containing the psd data
    IFO: str, optional
        the IFO which the PSD belongs to

    Returns
    -------
    the result of PSD.read, or {} if the file could not be read
    """
    from pesummary.gw.file.psd import PSD as _PSD
    return _GWInput._extract_IFO_data_from_file(file, _PSD, "PSD", IFO=IFO)
@staticmethod
def extract_calibration_data_from_file(file, type="data", **kwargs):
    """Load the calibration data stored in a file

    Parameters
    ----------
    file: path
        path to a file containing the calibration data
    type: str, optional
        calibration definition forwarded to Calibration.read. Default "data"
    **kwargs: dict, optional
        also forwarded to Calibration.read (e.g. IFO)

    Returns
    -------
    the result of Calibration.read, or {} if the file could not be read
    """
    from pesummary.gw.file.calibration import Calibration as _Calibration
    loader_kwargs = dict(kwargs, type=type)
    return _GWInput._extract_IFO_data_from_file(
        file, _Calibration, "calibration", **loader_kwargs
    )
1038 @staticmethod
1039 def get_ifo_from_file_name(file):
1040 """Return the IFO from the file name
1042 Parameters
1043 ----------
1044 file: str
1045 path to the file
1046 """
1047 file_name = file.split("/")[-1]
1048 if any(j in file_name for j in ["H1", "_0", "IFO0"]):
1049 ifo = "H1"
1050 elif any(j in file_name for j in ["L1", "_1", "IFO1"]):
1051 ifo = "L1"
1052 elif any(j in file_name for j in ["V1", "_2", "IFO2"]):
1053 ifo = "V1"
1054 else:
1055 ifo = file_name
1056 return ifo
1058 def get_psd_or_calibration_data(self, input, executable, **kwargs):
1059 """Return a dictionary containing the psd or calibration data
1061 Parameters
1062 ----------
1063 input: list/dict
1064 list/dict containing paths to calibration/psd files
1065 executable: func
1066 executable that is used to extract the data from the calibration/psd
1067 files
1068 """
1069 data = {}
1070 if input == {} or input == []:
1071 return data
1072 if isinstance(input, dict):
1073 keys = list(input.keys())
1074 if isinstance(input, dict) and isinstance(input[keys[0]], list):
1075 if not all(len(input[i]) == len(self.labels) for i in list(keys)):
1076 raise InputError(
1077 "Please ensure the number of calibration/psd files matches "
1078 "the number of result files passed"
1079 )
1080 for idx in range(len(input[keys[0]])):
1081 data[self.labels[idx]] = {
1082 i: executable(input[i][idx], IFO=i, **kwargs) for i in list(keys)
1083 }
1084 elif isinstance(input, dict):
1085 for i in self.labels:
1086 data[i] = {
1087 j: executable(input[j], IFO=j, **kwargs) for j in list(input.keys())
1088 }
1089 elif isinstance(input, list):
1090 for i in self.labels:
1091 data[i] = {
1092 self.get_ifo_from_file_name(j): executable(
1093 j, IFO=self.get_ifo_from_file_name(j), **kwargs
1094 ) for j in input
1095 }
1096 else:
1097 raise InputError(
1098 "Did not understand the psd/calibration input. Please use the "
1099 "following format 'H1:path/to/file'"
1100 )
1101 return data
1103 def grab_priors_from_inputs(self, priors):
1104 def read_func(data, **kwargs):
1105 from pesummary.gw.file.read import read as GWRead
1106 data = GWRead(data, **kwargs)
1107 data.generate_all_posterior_samples()
1108 return data
1110 return super(_GWInput, self).grab_priors_from_inputs(
1111 priors, read_func=read_func, read_kwargs=self.grab_data_kwargs
1112 )
1114 def grab_key_data_from_result_files(self):
1115 """Grab the mean, median, maxL and standard deviation for all
1116 parameters for all each result file
1117 """
1118 from pesummary.utils.kde_list import KDEList
1119 from pesummary.gw.plots.plot import _return_bounds
1120 from pesummary.utils.credible_interval import (
1121 hpd_two_sided_credible_interval
1122 )
1123 from pesummary.utils.bounded_1d_kde import bounded_1d_kde
1124 key_data = super(_GWInput, self).grab_key_data_from_result_files()
1125 bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
1126 for param in bounded_parameters:
1127 xlow, xhigh = _return_bounds(param, [])
1128 _samples = {
1129 key: val[param] for key, val in self.samples.items()
1130 if param in val.keys()
1131 }
1132 _min = [np.min(_) for _ in _samples.values() if len(_samples)]
1133 _max = [np.max(_) for _ in _samples.values() if len(_samples)]
1134 if not len(_min):
1135 continue
1136 _min = np.min(_min)
1137 _max = np.max(_max)
1138 x = np.linspace(_min, _max, 1000)
1139 try:
1140 kdes = KDEList(
1141 list(_samples.values()), kde=bounded_1d_kde,
1142 kde_kwargs={"xlow": xlow, "xhigh": xhigh}
1143 )
1144 except Exception as e:
1145 logger.warning(
1146 "Unable to compute the HPD interval for {} because {}".format(
1147 param, e
1148 )
1149 )
1150 continue
1151 pdfs = kdes(x)
1152 for num, key in enumerate(_samples.keys()):
1153 [xlow, xhigh], _ = hpd_two_sided_credible_interval(
1154 [], 90, x=x, pdf=pdfs[num]
1155 )
1156 key_data[key][param]["90% HPD"] = [xlow, xhigh]
1157 for _param in self.samples[key].keys():
1158 if _param in bounded_parameters:
1159 continue
1160 key_data[key][_param]["90% HPD"] = float("nan")
1161 return key_data
class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class to handle and store sample specific command line arguments

    In addition to the generic sample options, this stores the
    gravitational-wave specific options read from ``self.opts``. When an
    existing metafile is provided, it also stores the approximant, PSD,
    calibration and skymap data already in that metafile.
    """
    def __init__(self, *args, **kwargs):
        kwargs["ignore_copy"] = True
        gw_extra_options = [
            "evolve_spins_forwards",
            "evolve_spins_backwards",
            "NRSur_fits",
            "calculate_multipole_snr",
            "calculate_precessing_snr",
            "f_start",
            "f_low",
            "f_ref",
            "f_final",
            "psd",
            "waveform_fits",
            "redshift_method",
            "cosmology",
            "no_conversion",
            "delta_f",
            "psd_default",
            "disable_remnant",
            "force_BBH_remnant_computation",
            "force_BH_spin_evolution"
        ]
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=gw_extra_options, **kwargs
        )
        if self._restarted_from_checkpoint:
            return
        # Data stored in an existing metafile; every value is None when no
        # existing metafile was given.
        has_existing = self.existing is not None
        if has_existing:
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
        for field in ("approximant", "psd", "calibration", "skymap"):
            value = self.existing_data[field] if has_existing else None
            setattr(self, "existing_{}".format(field), value)
        # The assignments below keep their original order, as some of them
        # may be property setters that depend on earlier assignments.
        opts = self.opts
        for option in (
            "approximant", "gracedb_server", "gracedb_data", "gracedb",
            "pastro_category_file", "terrestrial_probability",
            "catch_terrestrial_probability_error", "approximant_flags",
        ):
            setattr(self, option, getattr(opts, option))
        self.detectors = None
        self.skymap = None
        for option in ("calibration_definition", "calibration", "gwdata"):
            setattr(self, option, getattr(opts, option))
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        """Maximum likelihood value of every parameter, keyed by label"""
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        # The value assigned is ignored: the maxL samples are always rebuilt
        # from the key data of the result files.
        key_data = self.grab_key_data_from_result_files()
        maxL_by_label = {}
        for label, summaries in key_data.items():
            maxL_by_label[label] = {
                param: summary["maxL"] for param, summary in summaries.items()
            }
        for label in self.labels:
            maxL_by_label[label]["approximant"] = self.approximant[label]
        self._maxL_samples = maxL_by_label
class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Class to handle and store plotting specific command line arguments

    When ligo.skymap is enabled and more than one process is available, the
    processes are split between skymap generation and the remaining plots.
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        opts = self.opts
        self.nsamples_for_skymap = opts.nsamples_for_skymap
        self.sensitivity = opts.sensitivity
        self.no_ligo_skymap = opts.no_ligo_skymap
        self.multi_threading_for_skymap = self.multi_process
        if not self.no_ligo_skymap and self.multi_process > 1:
            total = self.multi_process
            self.multi_threading_for_plots = int(total / 2.)
            self.multi_threading_for_skymap = (
                total - self.multi_threading_for_plots
            )

            def _plural(count):
                return "es " if count > 1 else " "

            n_skymap = self.multi_threading_for_skymap
            n_plots = self.multi_threading_for_plots
            logger.info(
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots".format(
                    n_skymap, _plural(n_skymap), n_plots, _plural(n_plots)
                )
            )
        self.preliminary_pages = None
        self.pastro_probs = []
        self.embright_probs = []
        # merge the default p_astro and EM-bright probabilities of each label
        self.classification_probs = {}
        for key, pastro in self.pastro_probs.items():
            merged = {}
            merged.update(pastro["default"])
            merged.update(self.embright_probs[key]["default"])
            self.classification_probs[key] = {"default": merged}
class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Class to handle and store webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.public = self.opts.public
        # only assign the defaults for attributes that are not yet available
        for name, default in (
            ("preliminary_pages", None),
            ("pastro_probs", []),
            ("embright_probs", []),
        ):
            if not hasattr(self, name):
                setattr(self, name, default)
        # merge the default p_astro and EM-bright probabilities of each label
        self.classification_probs = {}
        for key, pastro in self.pastro_probs.items():
            merged = {}
            merged.update(pastro["default"])
            merged.update(self.embright_probs[key]["default"])
            self.classification_probs[key] = {"default": merged}
class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Class to handle and store webpage and plotting specific command line
    arguments

    Combines PlottingInput and WebpageInput; everything is inherited in
    method resolution order.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Directories to create, as defined by the next class in the MRO"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy, as defined by the next class in the MRO"""
        return super().default_files_to_copy
class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Class to handle and store metafile specific command line arguments
    """
    @property
    def default_directories(self):
        """Directories to create, including 'psds' and 'calibration'"""
        # Build a new list rather than extending the parent's result in
        # place: if the parent returns a list stored on the instance, `+=`
        # would append duplicate entries on every access.
        dirs = super(MetaFileInput, self).default_directories
        return list(dirs) + ["psds", "calibration"]

    def copy_files(self):
        """Save the PSDs and calibration envelopes, then copy the remaining
        files via the parent implementation

        For each label:
        * PSDs are written to ``<webdir>/psds/<label>_<ifo>_psd.dat``
        * calibration envelopes from the priors are written to
          ``<webdir>/calibration/<label>_<ifo>_cal.txt``

        Entries that are not ``PSD``/``Calibration`` instances are skipped
        with a warning.
        """
        _error = "Failed to save the {} to file"
        for label in self.labels:
            if self.psd[label] != {}:
                for ifo in self.psd[label].keys():
                    psd = self.psd[label][ifo]
                    if not isinstance(psd, PSD):
                        logger.warning(_error.format("{} PSD".format(ifo)))
                        continue
                    psd.save_to_file(
                        os.path.join(self.webdir, "psds", "{}_{}_psd.dat".format(
                            label, ifo
                        ))
                    )
            calibration = self.priors["calibration"]
            if label in calibration.keys() and calibration[label] != {}:
                for ifo in calibration[label].keys():
                    envelope = calibration[label][ifo]
                    if not isinstance(envelope, Calibration):
                        logger.warning(
                            _error.format(
                                "{} calibration envelope".format(ifo)
                            )
                        )
                        continue
                    envelope.save_to_file(
                        os.path.join(self.webdir, "calibration", "{}_{}_cal.txt".format(
                            label, ifo
                        ))
                    )
        return super(MetaFileInput, self).copy_files()
class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Class to handle and store webpage, plotting and metafile specific
    command line arguments

    Combines MetaFileInput with WebpagePlusPlottingInput; everything is
    inherited in method resolution order.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Directories to create, as defined by the next class in the MRO"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy, as defined by the next class in the MRO"""
        return super().default_files_to_copy
@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated: use WebpagePlusPlottingPlusMetaFileInput (or one of the
    more specific input classes) instead. Adds nothing to its parent.
    """
1387class IMRCTInput(pesummary.core.cli.inputs._Input):
1388 """Class to handle the TGR specific command line arguments
1389 """
    @property
    def labels(self):
        """Labels identifying the inspiral and postinspiral analyses"""
        return self._labels

    @labels.setter
    def labels(self, labels):
        """Validate the labels passed for the IMRCT test and work out how they
        map onto analyses

        Accepted layouts (the number of labels must be even):

        * more than 2 labels: every label must contain ':inspiral' or
          ':postinspiral'. ``self.analysis_label`` is set to the text before
          ':inspiral' in each inspiral label.
        * 2 (or fewer) labels that are not exactly 'inspiral' and
          'postinspiral': every result file must be a PESummary metafile and
          the labels must be '{label}:inspiral' and '{label}:postinspiral',
          where '{label}' is the label stored in the metafile. ``_labels`` is
          then replaced by ['inspiral', 'postinspiral'] and the metafile
          labels are stored in ``self._meta_file_labels``.
        * exactly 'inspiral' and 'postinspiral'.

        ``self.result_files`` must already be set, as it is read here.
        Note that ``self._labels`` is assigned before validation, so it holds
        the rejected labels if a ValueError is raised.

        Parameters
        ----------
        labels: list
            labels passed on the command line

        Raises
        ------
        ValueError
            if the labels do not follow one of the layouts above
        """
        self._labels = labels
        if len(labels) % 2 != 0:
            raise ValueError(
                "The IMRCT test requires 2 results files for each analysis. "
            )
        elif len(labels) > 2:
            cond = all(
                ":inspiral" in label or ":postinspiral" in label for label in
                labels
            )
            if not cond:
                raise ValueError(
                    "To compare 2 or more analyses, please provide labels as "
                    "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                    "the analysis label"
                )
            else:
                # ':postinspiral' does not contain ':inspiral', but the second
                # condition guards against labels that contain both
                self.analysis_label = [
                    label.split(":inspiral")[0]
                    for label in labels
                    if ":inspiral" in label and ":postinspiral" not in label
                ]
                # one inspiral label is expected per pair of result files
                if len(self.analysis_label) != len(self.result_files) / 2:
                    raise ValueError(
                        "When comparing more than 2 analyses, labels must "
                        "be of the form '{}:inspiral' and '{}:postinspiral'."
                    )
                logger.info(
                    "Using the labels: {} to distinguish analyses".format(
                        ", ".join(self.analysis_label)
                    )
                )
        elif sorted(labels) != ["inspiral", "postinspiral"]:
            if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
                # meta_file_labels is always ordered [inspiral, postinspiral],
                # independent of the order in which the labels were passed.
                # NOTE(review): self.result_files keeps the command line order,
                # so with two metafiles and the postinspiral label passed first
                # the log message below (and the zip of labels/result_files in
                # the samples setter) appear to pair the wrong file — confirm.
                meta_file_labels = []
                for suffix in [":inspiral", ":postinspiral"]:
                    if any(suffix in label for label in labels):
                        ind = [
                            num for num, label in enumerate(labels) if
                            suffix in label
                        ]
                        if len(ind) > 1:
                            raise ValueError(
                                "Please provide a single {} label".format(
                                    suffix.split(":")[1]
                                )
                            )
                        # strip the suffix to recover the metafile label
                        meta_file_labels.append(
                            labels[ind[0]].split(suffix)[0]
                        )
                    else:
                        raise ValueError(
                            "Please provide labels as {inspiral_label}:inspiral "
                            "and {postinspiral_label}:postinspiral where "
                            "inspiral_label and postinspiral_label are the "
                            "PESummary labels for the inspiral and postinspiral "
                            "analyses respectively. "
                        )
                if len(self.result_files) == 1:
                    logger.info(
                        "Using the {} samples for the inspiral analysis and {} "
                        "samples for the postinspiral analysis from the file "
                        "{}".format(
                            meta_file_labels[0], meta_file_labels[1],
                            self.result_files[0]
                        )
                    )
                elif len(self.result_files) == 2:
                    logger.info(
                        "Using the {} samples for the inspiral analysis from "
                        "the file {}. Using the {} samples for the "
                        "postinspiral analysis from the file {}".format(
                            meta_file_labels[0], self.result_files[0],
                            meta_file_labels[1], self.result_files[1]
                        )
                    )
                else:
                    raise ValueError(
                        "Currently, you can only provide at most 2 pesummary "
                        "metafiles. If one is provided, both the inspiral and "
                        "postinspiral are extracted from that single file. If "
                        "two are provided, the inspiral is extracted from one "
                        "file and the postinspiral is extracted from the other."
                    )
                # downstream code refers to the analyses by these fixed labels
                self._labels = ["inspiral", "postinspiral"]
                self._meta_file_labels = meta_file_labels
                self.analysis_label = ["primary"]
            else:
                raise ValueError(
                    "The IMRCT test requires an inspiral and postinspiral result "
                    "file. Please indicate which file is the inspiral and which "
                    "is postinspiral by providing these exact labels to the "
                    "summarytgr executable"
                )
        else:
            self.analysis_label = ["primary"]
1493 def _extract_stored_approximant(self, opened_file, label):
1494 """Extract the approximant used for a given analysis stored in a
1495 PESummary metafile
1497 Parameters
1498 ----------
1499 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1500 opened metafile that contains the analysis 'label'
1501 label: str
1502 analysis label which is stored in the PESummary metafile
1503 """
1504 if opened_file.approximant is not None:
1505 if label not in opened_file.labels:
1506 raise ValueError(
1507 "Invalid label {}. The list of available labels are {}".format(
1508 label, ", ".join(opened_file.labels)
1509 )
1510 )
1511 _index = opened_file.labels.index(label)
1512 return opened_file.approximant[_index]
1513 return
1515 def _extract_stored_remnant_fits(self, opened_file, label):
1516 """Extract the remnant fits used for a given analysis stored in a
1517 PESummary metafile
1519 Parameters
1520 ----------
1521 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1522 opened metafile that contains the analysis 'label'
1523 label: str
1524 analysis label which is stored in the PESummary metafile
1525 """
1526 fits = {}
1527 fit_strings = [
1528 "final_spin_NR_fits", "final_mass_NR_fits"
1529 ]
1530 if label not in opened_file.labels:
1531 raise ValueError(
1532 "Invalid label {}. The list of available labels are {}".format(
1533 label, ", ".join(opened_file.labels)
1534 )
1535 )
1536 _index = opened_file.labels.index(label)
1537 _meta_data = opened_file.extra_kwargs[_index]
1538 if "meta_data" in _meta_data.keys():
1539 for key in fit_strings:
1540 if key in _meta_data["meta_data"].keys():
1541 fits[key] = _meta_data["meta_data"][key]
1542 if len(fits):
1543 return fits
1544 return
1546 def _extract_stored_cutoff_frequency(self, opened_file, label):
1547 """Extract the cutoff frequencies used for a given analysis stored in a
1548 PESummary metafile
1550 Parameters
1551 ----------
1552 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1553 opened metafile that contains the analysis 'label'
1554 label: str
1555 analysis label which is stored in the PESummary metafile
1556 """
1557 frequencies = {}
1558 if opened_file.config is not None:
1559 if label not in opened_file.labels:
1560 raise ValueError(
1561 "Invalid label {}. The list of available labels are {}".format(
1562 label, ", ".join(opened_file.labels)
1563 )
1564 )
1565 if opened_file.config[label] is not None:
1566 _config = opened_file.config[label]
1567 if "config" in _config.keys():
1568 if "maximum-frequency" in _config["config"].keys():
1569 frequencies["fhigh"] = _config["config"][
1570 "maximum-frequency"
1571 ]
1572 if "minimum-frequency" in _config["config"].keys():
1573 frequencies["flow"] = _config["config"][
1574 "minimum-frequency"
1575 ]
1576 elif "lalinference" in _config.keys():
1577 if "fhigh" in _config["lalinference"].keys():
1578 frequencies["fhigh"] = _config["lalinference"][
1579 "fhigh"
1580 ]
1581 if "flow" in _config["lalinference"].keys():
1582 frequencies["flow"] = _config["lalinference"][
1583 "flow"
1584 ]
1585 return frequencies
1586 return
1588 @property
1589 def samples(self):
1590 return self._samples
1592 @samples.setter
1593 def samples(self, samples):
1594 from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
1595 self._read_samples = {
1596 _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
1597 self.labels, self.result_files
1598 )
1599 }
1600 _samples_dict = {}
1601 _approximant_dict = {}
1602 _cutoff_frequency_dict = {}
1603 _remnant_fits_dict = {}
1604 for label, _open in self._read_samples.items():
1605 if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
1606 if not len(self._meta_file_labels):
1607 raise ValueError(
1608 "Currently you can only pass a file containing a "
1609 "single analysis or a valid PESummary metafile "
1610 "containing multiple analyses"
1611 )
1612 _labels = _open.labels
1613 if len(self._read_samples) == 1:
1614 _samples_dict = {
1615 label: _open.samples_dict[meta_file_label] for
1616 label, meta_file_label in zip(
1617 self.labels, self._meta_file_labels
1618 )
1619 }
1620 for label, meta_file_label in zip(self.labels, self._meta_file_labels):
1621 _stored_approx = self._extract_stored_approximant(
1622 _open, meta_file_label
1623 )
1624 _stored_frequencies = self._extract_stored_cutoff_frequency(
1625 _open, meta_file_label
1626 )
1627 _stored_remnant_fits = self._extract_stored_remnant_fits(
1628 _open, meta_file_label
1629 )
1630 if _stored_approx is not None:
1631 _approximant_dict[label] = _stored_approx
1632 if _stored_remnant_fits is not None:
1633 _remnant_fits_dict[label] = _stored_remnant_fits
1634 if _stored_frequencies is not None:
1635 if label == "inspiral":
1636 if "fhigh" in _stored_frequencies.keys():
1637 _cutoff_frequency_dict[label] = _stored_frequencies[
1638 "fhigh"
1639 ]
1640 if label == "postinspiral":
1641 if "flow" in _stored_frequencies.keys():
1642 _cutoff_frequency_dict[label] = _stored_frequencies[
1643 "flow"
1644 ]
1645 break
1646 else:
1647 ind = self.labels.index(label)
1648 _samples_dict[label] = _open.samples_dict[
1649 self._meta_file_labels[ind]
1650 ]
1651 _stored_approx = self._extract_stored_approximant(
1652 _open, self._meta_file_labels[ind]
1653 )
1654 _stored_frequencies = self._extract_stored_cutoff_frequency(
1655 _open, self._meta_file_labels[ind]
1656 )
1657 _stored_remnant_fits = self._extract_stored_remnant_fits(
1658 _open, self._meta_file_labels[ind]
1659 )
1660 if _stored_approx is not None:
1661 _approximant_dict[label] = _stored_approx
1662 if _stored_remnant_fits is not None:
1663 _remnant_fits_dict[label] = _stored_remnant_fits
1664 if _stored_frequencies is not None:
1665 if label == "inspiral":
1666 if "fhigh" in _stored_frequencies.keys():
1667 _cutoff_frequency_dict[label] = _stored_frequencies[
1668 "fhigh"
1669 ]
1670 if label == "postinspiral":
1671 if "flow" in _stored_frequencies.keys():
1672 _cutoff_frequency_dict[label] = _stored_frequencies[
1673 "flow"
1674 ]
1675 else:
1676 _samples_dict[label] = _open.samples_dict
1677 extra_kwargs = _open.extra_kwargs
1678 if "pe_algorithm" in extra_kwargs["sampler"].keys():
1679 if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
1680 try:
1681 subkwargs = extra_kwargs["other"]["likelihood"][
1682 "waveform_arguments"
1683 ]
1684 _approximant_dict[label] = (
1685 subkwargs["waveform_approximant"]
1686 )
1687 if "inspiral" in label and "postinspiral" not in label:
1688 _cutoff_frequency_dict[label] = (
1689 subkwargs["maximum_frequency"]
1690 )
1691 elif "postinspiral" in label:
1692 _cutoff_frequency_dict[label] = (
1693 subkwargs["minimum_frequency"]
1694 )
1695 except KeyError:
1696 pass
1697 self._samples = MultiAnalysisSamplesDict(_samples_dict)
1698 if len(_approximant_dict):
1699 self._approximant_dict = _approximant_dict
1700 if len(_cutoff_frequency_dict):
1701 self._cutoff_frequency_dict = _cutoff_frequency_dict
1702 if len(_remnant_fits_dict):
1703 self._remnant_fits_dict = _remnant_fits_dict
1705 @property
1706 def imrct_kwargs(self):
1707 return self._imrct_kwargs
1709 @imrct_kwargs.setter
1710 def imrct_kwargs(self, imrct_kwargs):
1711 test_kwargs = dict(N_bins=101)
1712 try:
1713 test_kwargs.update(imrct_kwargs)
1714 except AttributeError:
1715 test_kwargs = test_kwargs
1717 for key, value in test_kwargs.items():
1718 try:
1719 test_kwargs[key] = ast.literal_eval(value)
1720 except ValueError:
1721 pass
1722 self._imrct_kwargs = test_kwargs
1724 @property
1725 def meta_data(self):
1726 return self._meta_data
1728 @meta_data.setter
1729 def meta_data(self, meta_data):
1730 self._meta_data = {}
1731 for num, _inspiral in enumerate(self.inspiral_keys):
1732 frequency_dict = dict()
1733 approximant_dict = dict()
1734 remnant_dict = dict()
1735 zipped = zip(
1736 [self.cutoff_frequency, self.approximant, None],
1737 [frequency_dict, approximant_dict, remnant_dict],
1738 ["cutoff_frequency", "approximant", "remnant_fits"]
1739 )
1740 _inspiral_string = self.inspiral_keys[num]
1741 _postinspiral_string = self.postinspiral_keys[num]
1742 for _list, _dict, name in zipped:
1743 if _list is not None and len(_list) == len(self.labels):
1744 inspiral_ind = self.labels.index(_inspiral_string)
1745 postinspiral_ind = self.labels.index(_postinspiral_string)
1746 _dict["inspiral"] = _list[inspiral_ind]
1747 _dict["postinspiral"] = _list[postinspiral_ind]
1748 elif _list is not None:
1749 raise ValueError(
1750 "Please provide a 'cutoff_frequency' and 'approximant' "
1751 "for each file"
1752 )
1753 else:
1754 try:
1755 if name == "cutoff_frequency":
1756 if "inspiral" in self._cutoff_frequency_dict.keys():
1757 _dict["inspiral"] = self._cutoff_frequency_dict[
1758 "inspiral"
1759 ]
1760 if "postinspiral" in self._cutoff_frequency_dict.keys():
1761 _dict["postinspiral"] = self._cutoff_frequency_dict[
1762 "postinspiral"
1763 ]
1764 elif name == "approximant":
1765 if "inspiral" in self._approximant_dict.keys():
1766 _dict["inspiral"] = self._approximant_dict[
1767 "inspiral"
1768 ]
1769 if "postinspiral" in self._approximant_dict.keys():
1770 _dict["postinspiral"] = self._approximant_dict[
1771 "postinspiral"
1772 ]
1773 elif name == "remnant_fits":
1774 if "inspiral" in self._remnant_fits_dict.keys():
1775 _dict["inspiral"] = self._remnant_fits_dict[
1776 "inspiral"
1777 ]
1778 if "postinspiral" in self._remnant_fits_dict.keys():
1779 _dict["postinspiral"] = self._remnant_fits_dict[
1780 "postinspiral"
1781 ]
1782 except (AttributeError, KeyError, TypeError):
1783 _dict["inspiral"] = None
1784 _dict["postinspiral"] = None
1786 self._meta_data[self.analysis_label[num]] = {
1787 "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
1788 "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
1789 "inspiral approximant": approximant_dict["inspiral"],
1790 "postinspiral approximant": approximant_dict["postinspiral"],
1791 "inspiral remnant fits": remnant_dict["inspiral"],
1792 "postinspiral remnant fits": remnant_dict["postinspiral"]
1793 }
    def __init__(self, opts):
        """Store the IMRCT specific command line arguments and create the
        output directories

        Parameters
        ----------
        opts: object
            parsed command line arguments; every option is read as an
            attribute (e.g. ``opts.webdir``, ``opts.samples``)

        Raises
        ------
        ValueError
            if a cutoff_frequency, approximant, links_to_pe_pages or f_low is
            given but not one per label, or if the labels are invalid (see
            the ``labels`` setter)
        """
        self.opts = opts
        self.existing = None
        self.webdir = self.opts.webdir
        self.user = None
        self.baseurl = None
        # result_files must be set before the labels: the labels setter reads
        # self.result_files
        self.result_files = self.opts.samples
        self.labels = self.opts.labels
        # the samples setter ignores the value assigned and reads
        # self.labels and self.result_files instead
        self.samples = self.opts.samples
        # labels of the inspiral analyses, and the matching postinspiral
        # labels obtained by substituting 'postinspiral' for 'inspiral'
        self.inspiral_keys = [
            key for key in self.samples.keys() if "inspiral" in key
            and "postinspiral" not in key
        ]
        self.postinspiral_keys = [
            key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
        ]
        try:
            self.imrct_kwargs = self.opts.imrct_kwargs
        except AttributeError:
            # presumably opts has no imrct_kwargs option: use the defaults
            self.imrct_kwargs = {}
        # per-file options: when given (and non-empty) there must be exactly
        # one entry per label
        for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
            _attr = getattr(self.opts, _arg)
            if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
                raise ValueError("Please provide a {} for each file".format(_arg))
            setattr(self, _arg, _attr)
        # the meta_data setter ignores the value assigned and builds the meta
        # data from the attributes set above
        self.meta_data = None
        self.default_directories = ["samples", "plots", "js", "html", "css"]
        self.publication = False
        # NOTE(review): make_directories is inherited; presumably it creates
        # default_directories under webdir — confirm
        self.make_directories()