Coverage for pesummary/gw/cli/inputs.py: 78.7%
894 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import ast
4import os
5import numpy as np
6import pesummary.core.cli.inputs
7from pesummary.gw.file.read import read as GWRead
8from pesummary.gw.file.psd import PSD
9from pesummary.gw.file.calibration import Calibration
10from pesummary.utils.decorators import deprecation
11from pesummary.utils.exceptions import InputError
12from pesummary.utils.utils import logger
13from pesummary import conf
15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
18class _GWInput(pesummary.core.cli.inputs._Input):
19 """Super class to handle gw specific command line inputs
20 """
21 @staticmethod
22 def grab_data_from_metafile(
23 existing_file, webdir, compare=None, nsamples=None, **kwargs
24 ):
25 """Grab data from an existing PESummary metafile
27 Parameters
28 ----------
29 existing_file: str
30 path to the existing metafile
31 webdir: str
32 the directory to store the existing configuration file
33 compare: list, optional
34 list of labels for events stored in an existing metafile that you
35 wish to compare
36 """
37 _replace_kwargs = {
38 "psd": "{file}.psd['{label}']"
39 }
40 if "psd_default" in kwargs.keys():
41 _replace_kwargs["psd_default"] = kwargs["psd_default"]
42 data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
43 existing_file, webdir, compare=compare, read_function=GWRead,
44 nsamples=nsamples, _replace_with_pesummary_kwargs=_replace_kwargs,
45 **kwargs
46 )
47 f = GWRead(existing_file)
49 labels = data["labels"]
51 psd = {i: {} for i in labels}
52 if f.psd is not None and f.psd != {}:
53 for i in labels:
54 if i in f.psd.keys() and f.psd[i] != {}:
55 psd[i] = {
56 ifo: PSD(f.psd[i][ifo]) for ifo in f.psd[i].keys()
57 }
58 calibration = {i: {} for i in labels}
59 if f.calibration is not None and f.calibration != {}:
60 for i in labels:
61 if i in f.calibration.keys() and f.calibration[i] != {}:
62 calibration[i] = {
63 ifo: Calibration(f.calibration[i][ifo]) for ifo in
64 f.calibration[i].keys()
65 }
66 skymap = {i: None for i in labels}
67 if hasattr(f, "skymap") and f.skymap is not None and f.skymap != {}:
68 for i in labels:
69 if i in f.skymap.keys() and len(f.skymap[i]):
70 skymap[i] = f.skymap[i]
71 data.update(
72 {
73 "approximant": {
74 i: j for i, j in zip(
75 labels, [f.approximant[ind] for ind in data["indicies"]]
76 )
77 },
78 "psd": psd,
79 "calibration": calibration,
80 "skymap": skymap
81 }
82 )
83 return data
85 @property
86 def grab_data_kwargs(self):
87 kwargs = super(_GWInput, self).grab_data_kwargs
88 for _property in ["f_start", "f_low", "f_ref", "f_final", "delta_f"]:
89 if getattr(self, _property) is None:
90 setattr(self, "_{}".format(_property), [None] * len(self.labels))
91 elif len(getattr(self, _property)) == 1 and len(self.labels) != 1:
92 setattr(
93 self, "_{}".format(_property),
94 getattr(self, _property) * len(self.labels)
95 )
96 if self.opts.approximant is None:
97 approx = [None] * len(self.labels)
98 else:
99 approx = self.opts.approximant
100 self.approximant_flags = self.opts.approximant_flags
101 resume_file = [
102 os.path.join(
103 self.webdir, "checkpoint",
104 "{}_conversion_class.pickle".format(label)
105 ) for label in self.labels
106 ]
108 try:
109 for num, label in enumerate(self.labels):
110 try:
111 psd = self.psd[label]
112 except KeyError:
113 psd = {}
114 kwargs[label].update(dict(
115 evolve_spins_forwards=self.evolve_spins_forwards,
116 evolve_spins_backwards=self.evolve_spins_backwards,
117 f_start=self.f_start[num], f_low=self.f_low[num],
118 approximant_flags=self.approximant_flags.get(label, {}),
119 approximant=approx[num], f_ref=self.f_ref[num],
120 NRSur_fits=self.NRSur_fits, return_kwargs=True,
121 multipole_snr=self.calculate_multipole_snr,
122 precessing_snr=self.calculate_precessing_snr,
123 psd=psd, f_final=self.f_final[num],
124 waveform_fits=self.waveform_fits,
125 multi_process=self.opts.multi_process,
126 redshift_method=self.redshift_method,
127 cosmology=self.cosmology,
128 no_conversion=self.no_conversion,
129 add_zero_spin=True, delta_f=self.delta_f[num],
130 psd_default=self.psd_default,
131 disable_remnant=self.disable_remnant,
132 force_BBH_remnant_computation=self.force_BBH_remnant_computation,
133 resume_file=resume_file[num],
134 restart_from_checkpoint=self.restart_from_checkpoint,
135 force_BH_spin_evolution=self.force_BH_spin_evolution,
136 ))
137 return kwargs
138 except IndexError:
139 logger.warning(
140 "Unable to find an f_ref, f_low and approximant for each "
141 "label. Using and f_ref={}, f_low={} and approximant={} "
142 "for all result files".format(
143 self.f_ref[0], self.f_low[0], approx[0]
144 )
145 )
146 for num, label in enumerate(self.labels):
147 kwargs[label].update(dict(
148 evolve_spins_forwards=self.evolve_spins_forwards,
149 evolve_spins_backwards=self.evolve_spins_backwards,
150 f_start=self.f_start[0], f_low=self.f_low[0],
151 approximant=approx[0], f_ref=self.f_ref[0],
152 NRSur_fits=self.NRSur_fits, return_kwargs=True,
153 multipole_snr=self.calculate_multipole_snr,
154 precessing_snr=self.calculate_precessing_snr,
155 psd=self.psd[self.labels[0]], f_final=self.f_final[0],
156 waveform_fits=self.waveform_fits,
157 multi_process=self.opts.multi_process,
158 redshift_method=self.redshift_method,
159 cosmology=self.cosmology,
160 no_conversion=self.no_conversion,
161 add_zero_spin=True, delta_f=self.delta_f[0],
162 psd_default=self.psd_default,
163 disable_remnant=self.disable_remnant,
164 force_BBH_remnant_computation=self.force_BBH_remnant_computation,
165 resume_file=resume_file[num],
166 restart_from_checkpoint=self.restart_from_checkpoint,
167 force_BH_spin_evolution=self.force_BH_spin_evolution
168 ))
169 return kwargs
171 @staticmethod
172 def grab_data_from_file(
173 file, label, webdir, config=None, injection=None, file_format=None,
174 nsamples=None, disable_prior_sampling=False, **kwargs
175 ):
176 """Grab data from a result file containing posterior samples
178 Parameters
179 ----------
180 file: str
181 path to the result file
182 label: str
183 label that you wish to use for the result file
184 config: str, optional
185 path to a configuration file used in the analysis
186 injection: str, optional
187 path to an injection file used in the analysis
188 file_format, str, optional
189 the file format you wish to use when loading. Default None.
190 If None, the read function loops through all possible options
191 """
192 data = pesummary.core.cli.inputs._Input.grab_data_from_file(
193 file, label, webdir, config=config, injection=injection,
194 read_function=GWRead, file_format=file_format, nsamples=nsamples,
195 disable_prior_sampling=disable_prior_sampling, **kwargs
196 )
197 try:
198 for _kwgs in data["file_kwargs"][label]:
199 if "approximant_flags" in _kwgs["meta_data"]:
200 for key, item in _kwargs["meta_data"]["approximant_flags"].items():
201 warning_cond = (
202 key in self.approximant_flags[label] and
203 self.approximant_flags[label][key] != item
204 )
205 if warning_cond:
206 logger.warning(
207 "Approximant flag {}={} found in result file for {}. "
208 "Ignoring and using the provided values {}={}".format(
209 key, self.approximant_flags[label][key], label,
210 key, item
211 )
212 )
213 else:
214 self.approximant_flags[label][key] = item
215 except Exception:
216 pass
217 return data
219 @property
220 def reweight_samples(self):
221 return self._reweight_samples
223 @reweight_samples.setter
224 def reweight_samples(self, reweight_samples):
225 from pesummary.gw.reweight import options
226 self._reweight_samples = self._check_reweight_samples(
227 reweight_samples, options
228 )
230 def _set_samples(self, *args, **kwargs):
231 super(_GWInput, self)._set_samples(*args, **kwargs)
232 if "calibration" not in self.priors:
233 self.priors["calibration"] = {
234 label: {} for label in self.labels
235 }
237 def _set_corner_params(self, corner_params):
238 corner_params = super(_GWInput, self)._set_corner_params(corner_params)
239 if corner_params is None:
240 logger.debug(
241 "Using the default corner parameters: {}".format(
242 ", ".join(conf.gw_corner_parameters)
243 )
244 )
245 else:
246 _corner_params = corner_params
247 corner_params = list(set(conf.gw_corner_parameters + corner_params))
248 for param in _corner_params:
249 _data = self.samples
250 if not all(param in _data[label].keys() for label in self.labels):
251 corner_params.remove(param)
252 logger.debug(
253 "Generating a corner plot with the following "
254 "parameters: {}".format(", ".join(corner_params))
255 )
256 return corner_params
258 @property
259 def cosmology(self):
260 return self._cosmology
262 @cosmology.setter
263 def cosmology(self, cosmology):
264 from pesummary.gw.cosmology import available_cosmologies
266 if cosmology.lower() not in available_cosmologies:
267 logger.warning(
268 "Unrecognised cosmology: {}. Using {} as default".format(
269 cosmology, conf.cosmology
270 )
271 )
272 cosmology = conf.cosmology
273 else:
274 logger.debug("Using the {} cosmology".format(cosmology))
275 self._cosmology = cosmology
277 @property
278 def approximant(self):
279 return self._approximant
281 @approximant.setter
282 def approximant(self, approximant):
283 if not hasattr(self, "_approximant"):
284 approximant_list = {i: {} for i in self.labels}
285 if approximant is None:
286 logger.warning(
287 "No approximant passed. Waveform plots will not be "
288 "generated"
289 )
290 elif approximant is not None:
291 if len(approximant) != len(self.labels):
292 raise InputError(
293 "Please pass an approximant for each result file"
294 )
295 approximant_list = {
296 i: j for i, j in zip(self.labels, approximant)
297 }
298 self._approximant = approximant_list
299 else:
300 for num, i in enumerate(self._approximant.keys()):
301 if self._approximant[i] == {}:
302 if num == 0:
303 logger.warning(
304 "No approximant passed. Waveform plots will not be "
305 "generated"
306 )
307 self._approximant[i] = None
308 break
310 @property
311 def approximant_flags(self):
312 return self._approximant_flags
314 @approximant_flags.setter
315 def approximant_flags(self, approximant_flags):
316 if hasattr(self, "_approximant_flags"):
317 return
318 _approximant_flags = {key: {} for key in self.labels}
319 for key, item in approximant_flags.items():
320 _label, key = key.split(":")
321 if _label not in self.labels:
322 raise ValueError(
323 "Unable to assign waveform flags for label:{} because "
324 "it does not exist. Available labels are: {}. Approximant "
325 "flags must be provided in the form LABEL:FLAG:VALUE".format(
326 _label, ", ".join(self.labels)
327 )
328 )
329 _approximant_flags[_label][key] = item
330 self._approximant_flags = _approximant_flags
332 @property
333 def gracedb_server(self):
334 return self._gracedb_server
336 @gracedb_server.setter
337 def gracedb_server(self, gracedb_server):
338 if gracedb_server is None:
339 self._gracedb_server = conf.gracedb_server
340 else:
341 logger.debug(
342 "Using '{}' as the GraceDB server".format(gracedb_server)
343 )
344 self._gracedb_server = gracedb_server
346 @property
347 def gracedb(self):
348 return self._gracedb
350 @gracedb.setter
351 def gracedb(self, gracedb):
352 self._gracedb = gracedb
353 if gracedb is not None:
354 from pesummary.gw.gracedb import get_gracedb_data, HTTPError
355 from json.decoder import JSONDecodeError
357 first_letter = gracedb[0]
358 if first_letter != "G" and first_letter != "g" and first_letter != "S":
359 logger.warning(
360 "Invalid GraceDB ID passed. The GraceDB ID must be of the "
361 "form G0000 or S0000. Ignoring input."
362 )
363 self._gracedb = None
364 return
365 _error = (
366 "Unable to download data from Gracedb because {}. Only storing "
367 "the GraceDB ID in the metafile"
368 )
369 try:
370 logger.info(
371 "Downloading {} from gracedb for {}".format(
372 ", ".join(self.gracedb_data), gracedb
373 )
374 )
375 json = get_gracedb_data(
376 gracedb, info=self.gracedb_data,
377 service_url=self.gracedb_server
378 )
379 json["id"] = gracedb
380 except (HTTPError, RuntimeError, JSONDecodeError) as e:
381 logger.warning(_error.format(e))
382 json = {"id": gracedb}
384 for label in self.labels:
385 self.file_kwargs[label]["meta_data"]["gracedb"] = json
387 @property
388 def detectors(self):
389 return self._detectors
391 @detectors.setter
392 def detectors(self, detectors):
393 detector = {}
394 if not detectors:
395 for i in self.samples.keys():
396 params = list(self.samples[i].keys())
397 individual_detectors = []
398 for j in params:
399 if "optimal_snr" in j and j != "network_optimal_snr":
400 det = j.split("_optimal_snr")[0]
401 individual_detectors.append(det)
402 individual_detectors = sorted(
403 [str(i) for i in individual_detectors])
404 if individual_detectors:
405 detector[i] = "_".join(individual_detectors)
406 else:
407 detector[i] = None
408 else:
409 detector = detectors
410 logger.debug("The detector network is %s" % (detector))
411 self._detectors = detector
413 @property
414 def skymap(self):
415 return self._skymap
417 @skymap.setter
418 def skymap(self, skymap):
419 if not hasattr(self, "_skymap"):
420 self._skymap = {i: None for i in self.labels}
422 @property
423 def calibration_definition(self):
424 return self._calibration_definition
426 @calibration_definition.setter
427 def calibration_definition(self, calibration_definition):
428 if not len(self.opts.calibration):
429 self._calibration_definition = None
430 return
431 if len(calibration_definition) == 1:
432 logger.info(
433 f"Assuming that the calibration correction was applied to "
434 f"'{calibration_definition[0]}' for all analyses"
435 )
436 calibration_definition *= len(self.labels)
437 elif len(calibration_definition) != len(self.labels):
438 raise ValueError(
439 f"Please provide a calibration definition for each analysis "
440 f"({len(self.labels)}) or a single definition to use for all "
441 f"analyses"
442 )
443 if any(_ not in ["data", "template"] for _ in calibration_definition):
444 raise ValueError(
445 "Calibration definitions must be either 'data' or 'template'"
446 )
447 self._calibration_definition = {
448 label: calibration_definition[num] for num, label in
449 enumerate(self.labels)
450 }
452 @property
453 def calibration(self):
454 return self._calibration
456 @calibration.setter
457 def calibration(self, calibration):
458 if not hasattr(self, "_calibration"):
459 data = {i: {} for i in self.labels}
460 if calibration != {}:
461 prior_data = self.get_psd_or_calibration_data(
462 calibration, self.extract_calibration_data_from_file,
463 type=self.calibration_definition[self.labels[0]]
464 )
465 self.add_to_prior_dict("calibration", prior_data)
466 else:
467 prior_data = {i: {} for i in self.labels}
468 for label in self.labels:
469 if hasattr(self.opts, "{}_calibration".format(label)):
470 cal_data = getattr(self.opts, "{}_calibration".format(label))
471 if cal_data != {} and cal_data is not None:
472 prior_data[label] = {
473 ifo: self.extract_calibration_data_from_file(
474 cal_data[ifo], type=self.calibration_definition[label]
475 ) for ifo in cal_data.keys()
476 }
477 if not all(prior_data[i] == {} for i in self.labels):
478 self.add_to_prior_dict("calibration", prior_data)
479 else:
480 self.add_to_prior_dict("calibration", {})
481 for num, i in enumerate(self.result_files):
482 _opened = self._open_result_files
483 if i in _opened.keys() and _opened[i] is not None:
484 f = self._open_result_files[i]
485 else:
486 f = GWRead(i, disable_prior=True)
487 try:
488 calibration_data = f.interpolate_calibration_spline_posterior()
489 except Exception as e:
490 logger.warning(
491 "Failed to extract calibration data from the result "
492 "file: {} because {}".format(i, e)
493 )
494 calibration_data = None
495 labels = list(self.samples.keys())
496 if calibration_data is None:
497 data[labels[num]] = {
498 None: None
499 }
500 elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
501 for num in range(len(calibration_data[0])):
502 data[labels[num]] = {
503 j: k for j, k in zip(
504 calibration_data[1][num],
505 calibration_data[0][num]
506 )
507 }
508 else:
509 data[labels[num]] = {
510 j: k for j, k in zip(
511 calibration_data[1], calibration_data[0]
512 )
513 }
514 self._calibration = data
516 @property
517 def psd(self):
518 return self._psd
520 @psd.setter
521 def psd(self, psd):
522 if not hasattr(self, "_psd"):
523 data = {i: {} for i in self.labels}
524 if psd != {}:
525 data = self.get_psd_or_calibration_data(
526 psd, self.extract_psd_data_from_file
527 )
528 else:
529 for label in self.labels:
530 if hasattr(self.opts, "{}_psd".format(label)):
531 psd_data = getattr(self.opts, "{}_psd".format(label))
532 if psd_data != {} and psd_data is not None:
533 data[label] = {
534 ifo: self.extract_psd_data_from_file(
535 psd_data[ifo], IFO=ifo
536 ) for ifo in psd_data.keys()
537 }
538 self._psd = data
540 @property
541 def nsamples_for_skymap(self):
542 return self._nsamples_for_skymap
544 @nsamples_for_skymap.setter
545 def nsamples_for_skymap(self, nsamples_for_skymap):
546 self._nsamples_for_skymap = nsamples_for_skymap
547 if nsamples_for_skymap is not None:
548 self._nsamples_for_skymap = int(nsamples_for_skymap)
549 number_of_samples = [
550 data.number_of_samples for label, data in self.samples.items()
551 ]
552 if not all(i > self._nsamples_for_skymap for i in number_of_samples):
553 min_arg = np.argmin(number_of_samples)
554 logger.warning(
555 "You have specified that you would like to use {} "
556 "samples to generate the skymap but the file {} only "
557 "has {} samples. Reducing the number of samples to "
558 "generate the skymap to {}".format(
559 self._nsamples_for_skymap, self.result_files[min_arg],
560 number_of_samples[min_arg], number_of_samples[min_arg]
561 )
562 )
563 self._nsamples_for_skymap = int(number_of_samples[min_arg])
565 @property
566 def gwdata(self):
567 return self._gwdata
569 @gwdata.setter
570 def gwdata(self, gwdata):
571 from pesummary.gw.file.strain import StrainDataDict
573 self._gwdata = gwdata
574 if gwdata is not None:
575 if isinstance(gwdata, dict):
576 for i in gwdata.keys():
577 if not os.path.isfile(gwdata[i]):
578 raise InputError(
579 "The file {} does not exist. Please check the path "
580 "to your strain file".format(gwdata[i])
581 )
582 self._gwdata = StrainDataDict.read(gwdata)
583 else:
584 logger.warning(
585 "Please provide gwdata as a dictionary with keys "
586 "displaying the channel and item giving the path to the "
587 "strain file"
588 )
589 self._gwdata = None
591 @property
592 def evolve_spins_forwards(self):
593 return self._evolve_spins_forwards
595 @evolve_spins_forwards.setter
596 def evolve_spins_forwards(self, evolve_spins_forwards):
597 self._evolve_spins_forwards = evolve_spins_forwards
598 _msg = "Spins will be evolved up to {}"
599 if evolve_spins_forwards:
600 logger.info(_msg.format("Schwarzschild ISCO frequency"))
601 self._evolve_spins_forwards = 6. ** -0.5
603 @property
604 def evolve_spins_backwards(self):
605 return self._evolve_spins_backwards
607 @evolve_spins_backwards.setter
608 def evolve_spins_backwards(self, evolve_spins_backwards):
609 self._evolve_spins_backwards = evolve_spins_backwards
610 _msg = (
611 "Spins will be evolved backwards to an infinite separation using the '{}' "
612 "method"
613 )
614 if isinstance(evolve_spins_backwards, (str, bytes)):
615 logger.info(_msg.format(evolve_spins_backwards))
616 elif evolve_spins_backwards is None:
617 logger.info(_msg.format("precession_averaged"))
618 self._evolve_spins_backwards = "precession_averaged"
620 @property
621 def NRSur_fits(self):
622 return self._NRSur_fits
624 @NRSur_fits.setter
625 def NRSur_fits(self, NRSur_fits):
626 self._NRSur_fits = NRSur_fits
627 base = (
628 "Using the '{}' NRSurrogate model to calculate the remnant "
629 "quantities"
630 )
631 if isinstance(NRSur_fits, (str, bytes)):
632 logger.info(base.format(NRSur_fits))
633 self._NRSur_fits = NRSur_fits
634 elif NRSur_fits is None:
635 from pesummary.gw.conversions.nrutils import NRSUR_MODEL
637 logger.info(base.format(NRSUR_MODEL))
638 self._NRSur_fits = NRSUR_MODEL
640 @property
641 def waveform_fits(self):
642 return self._waveform_fits
644 @waveform_fits.setter
645 def waveform_fits(self, waveform_fits):
646 self._waveform_fits = waveform_fits
647 if waveform_fits:
648 logger.info(
649 "Evaluating the remnant quantities using the provided "
650 "approximant"
651 )
653 @property
654 def f_low(self):
655 return self._f_low
657 @f_low.setter
658 def f_low(self, f_low):
659 self._f_low = f_low
660 if f_low is not None:
661 self._f_low = [float(i) for i in f_low]
663 @property
664 def f_start(self):
665 return self._f_start
667 @f_start.setter
668 def f_start(self, f_start):
669 self._f_start = f_start
670 if f_start is not None:
671 self._f_start = [float(i) for i in f_start]
673 @property
674 def f_ref(self):
675 return self._f_ref
677 @f_ref.setter
678 def f_ref(self, f_ref):
679 self._f_ref = f_ref
680 if f_ref is not None:
681 self._f_ref = [float(i) for i in f_ref]
683 @property
684 def f_final(self):
685 return self._f_final
687 @f_final.setter
688 def f_final(self, f_final):
689 self._f_final = f_final
690 if f_final is not None:
691 self._f_final = [float(i) for i in f_final]
693 @property
694 def delta_f(self):
695 return self._delta_f
697 @delta_f.setter
698 def delta_f(self, delta_f):
699 self._delta_f = delta_f
700 if delta_f is not None:
701 self._delta_f = [float(i) for i in delta_f]
703 @property
704 def psd_default(self):
705 return self._psd_default
707 @psd_default.setter
708 def psd_default(self, psd_default):
709 self._psd_default = psd_default
710 if "stored:" in psd_default:
711 label = psd_default.split("stored:")[1]
712 self._psd_default = "{file}.psd['%s']" % (label)
713 return
714 try:
715 from pycbc import psd
716 psd_default = getattr(psd, psd_default)
717 except ImportError:
718 logger.warning(
719 "Unable to import 'pycbc'. Unable to generate a default PSD"
720 )
721 psd_default = None
722 except AttributeError:
723 logger.warning(
724 "'pycbc' does not have the '{}' psd available. Using '{}' as "
725 "the default PSD".format(psd_default, conf.psd)
726 )
727 psd_default = getattr(psd, conf.psd)
728 except ValueError as e:
729 logger.warning("Setting 'psd_default' to None because {}".format(e))
730 psd_default = None
731 self._psd_default = psd_default
733 @property
734 def pepredicates_probs(self):
735 return self._pepredicates_probs
737 @pepredicates_probs.setter
738 def pepredicates_probs(self, pepredicates_probs):
739 from pesummary.gw.classification import PEPredicates
741 classifications = {}
742 for num, i in enumerate(list(self.samples.keys())):
743 try:
744 classifications[i] = PEPredicates(
745 self.samples[i]
746 ).dual_classification()
747 except Exception as e:
748 logger.warning(
749 "Failed to generate source classification probabilities "
750 "because {}".format(e)
751 )
752 classifications[i] = None
753 if self.mcmc_samples:
754 if any(_probs is None for _probs in classifications.values()):
755 classifications[self.labels[0]] = None
756 logger.warning(
757 "Unable to average classification probabilities across "
758 "mcmc chains because one or more chains failed to estimate "
759 "classifications"
760 )
761 else:
762 logger.debug(
763 "Averaging classification probability across mcmc samples"
764 )
765 classifications[self.labels[0]] = {
766 prior: {
767 key: np.round(np.mean(
768 [val[prior][key] for val in classifications.values()]
769 ), 3) for key in _probs.keys()
770 } for prior, _probs in
771 list(classifications.values())[0].items()
772 }
773 self._pepredicates_probs = classifications
775 @property
776 def pastro_probs(self):
777 return self._pastro_probs
779 @pastro_probs.setter
780 def pastro_probs(self, pastro_probs):
781 from pesummary.gw.classification import PAstro
783 probabilities = {}
784 for num, i in enumerate(list(self.samples.keys())):
785 try:
786 probabilities[i] = PAstro(self.samples[i]).dual_classification()
787 except Exception as e:
788 logger.warning(
789 "Failed to generate em_bright probabilities because "
790 "{}".format(e)
791 )
792 probabilities[i] = None
793 if self.mcmc_samples:
794 if any(_probs is None for _probs in probabilities.values()):
795 probabilities[self.labels[0]] = None
796 logger.warning(
797 "Unable to average em_bright probabilities across "
798 "mcmc chains because one or more chains failed to estimate "
799 "probabilities"
800 )
801 else:
802 logger.debug(
803 "Averaging em_bright probability across mcmc samples"
804 )
805 probabilities[self.labels[0]] = {
806 prior: {
807 key: np.round(np.mean(
808 [val[prior][key] for val in probabilities.values()]
809 ), 3) for key in _probs.keys()
810 } for prior, _probs in list(probabilities.values())[0].items()
811 }
812 self._pastro_probs = probabilities
814 @property
815 def preliminary_pages(self):
816 return self._preliminary_pages
818 @preliminary_pages.setter
819 def preliminary_pages(self, preliminary_pages):
820 required = conf.gw_reproducibility
821 self._preliminary_pages = {label: False for label in self.labels}
822 for num, label in enumerate(self.labels):
823 for attr in required:
824 _property = getattr(self, attr)
825 if isinstance(_property, dict):
826 if label not in _property.keys():
827 self._preliminary_pages[label] = True
828 elif not len(_property[label]):
829 self._preliminary_pages[label] = True
830 elif isinstance(_property, list):
831 if _property[num] is None:
832 self._preliminary_pages[label] = True
833 if any(value for value in self._preliminary_pages.values()):
834 _labels = [
835 label for label, value in self._preliminary_pages.items() if
836 value
837 ]
838 msg = (
839 "Unable to reproduce the {} analys{} because no {} data was "
840 "provided. 'Preliminary' watermarks will be added to the final "
841 "html pages".format(
842 ", ".join(_labels), "es" if len(_labels) > 1 else "is",
843 " or ".join(required)
844 )
845 )
846 logger.warning(msg)
848 @staticmethod
849 def _extract_IFO_data_from_file(file, cls, desc, IFO=None, **kwargs):
850 """Return IFO data stored in a file
852 Parameters
853 ----------
854 file: path
855 path to a file containing the IFO data
856 cls: obj
857 class you wish to use when loading the file. This class must have
858 a '.read' method
859 desc: str
860 description of the IFO data stored in the file
861 IFO: str, optional
862 the IFO which the data belongs to
863 """
864 general = (
865 "Failed to read in %s data because {}. The %s plot will not be "
866 "generated and the %s data will not be added to the metafile."
867 ) % (desc, desc, desc)
868 try:
869 return cls.read(file, IFO=IFO, **kwargs)
870 except FileNotFoundError:
871 logger.warning(
872 general.format("the file {} does not exist".format(file))
873 )
874 return {}
875 except ValueError as e:
876 logger.warning(general.format(e))
877 return {}
879 @staticmethod
880 def extract_psd_data_from_file(file, IFO=None):
881 """Return the data stored in a psd file
883 Parameters
884 ----------
885 file: path
886 path to a file containing the psd data
887 """
888 from pesummary.gw.file.psd import PSD
889 return _GWInput._extract_IFO_data_from_file(file, PSD, "PSD", IFO=IFO)
891 @staticmethod
892 def extract_calibration_data_from_file(file, type="data", **kwargs):
893 """Return the data stored in a calibration file
895 Parameters
896 ----------
897 file: path
898 path to a file containing the calibration data
899 """
900 from pesummary.gw.file.calibration import Calibration
901 return _GWInput._extract_IFO_data_from_file(
902 file, Calibration, "calibration", type=type, **kwargs
903 )
905 @staticmethod
906 def get_ifo_from_file_name(file):
907 """Return the IFO from the file name
909 Parameters
910 ----------
911 file: str
912 path to the file
913 """
914 file_name = file.split("/")[-1]
915 if any(j in file_name for j in ["H1", "_0", "IFO0"]):
916 ifo = "H1"
917 elif any(j in file_name for j in ["L1", "_1", "IFO1"]):
918 ifo = "L1"
919 elif any(j in file_name for j in ["V1", "_2", "IFO2"]):
920 ifo = "V1"
921 else:
922 ifo = file_name
923 return ifo
925 def get_psd_or_calibration_data(self, input, executable, **kwargs):
926 """Return a dictionary containing the psd or calibration data
928 Parameters
929 ----------
930 input: list/dict
931 list/dict containing paths to calibration/psd files
932 executable: func
933 executable that is used to extract the data from the calibration/psd
934 files
935 """
936 data = {}
937 if input == {} or input == []:
938 return data
939 if isinstance(input, dict):
940 keys = list(input.keys())
941 if isinstance(input, dict) and isinstance(input[keys[0]], list):
942 if not all(len(input[i]) == len(self.labels) for i in list(keys)):
943 raise InputError(
944 "Please ensure the number of calibration/psd files matches "
945 "the number of result files passed"
946 )
947 for idx in range(len(input[keys[0]])):
948 data[self.labels[idx]] = {
949 i: executable(input[i][idx], IFO=i, **kwargs) for i in list(keys)
950 }
951 elif isinstance(input, dict):
952 for i in self.labels:
953 data[i] = {
954 j: executable(input[j], IFO=j, **kwargs) for j in list(input.keys())
955 }
956 elif isinstance(input, list):
957 for i in self.labels:
958 data[i] = {
959 self.get_ifo_from_file_name(j): executable(
960 j, IFO=self.get_ifo_from_file_name(j), **kwargs
961 ) for j in input
962 }
963 else:
964 raise InputError(
965 "Did not understand the psd/calibration input. Please use the "
966 "following format 'H1:path/to/file'"
967 )
968 return data
970 def grab_priors_from_inputs(self, priors):
971 def read_func(data, **kwargs):
972 from pesummary.gw.file.read import read as GWRead
973 data = GWRead(data, **kwargs)
974 data.generate_all_posterior_samples()
975 return data
977 return super(_GWInput, self).grab_priors_from_inputs(
978 priors, read_func=read_func, read_kwargs=self.grab_data_kwargs
979 )
981 def grab_key_data_from_result_files(self):
982 """Grab the mean, median, maxL and standard deviation for all
983 parameters for all each result file
984 """
985 from pesummary.utils.kde_list import KDEList
986 from pesummary.gw.plots.plot import _return_bounds
987 from pesummary.utils.credible_interval import (
988 hpd_two_sided_credible_interval
989 )
990 from pesummary.utils.bounded_1d_kde import bounded_1d_kde
991 key_data = super(_GWInput, self).grab_key_data_from_result_files()
992 bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
993 for param in bounded_parameters:
994 xlow, xhigh = _return_bounds(param, [])
995 _samples = {
996 key: val[param] for key, val in self.samples.items()
997 if param in val.keys()
998 }
999 _min = [np.min(_) for _ in _samples.values() if len(_samples)]
1000 _max = [np.max(_) for _ in _samples.values() if len(_samples)]
1001 if not len(_min):
1002 continue
1003 _min = np.min(_min)
1004 _max = np.max(_max)
1005 x = np.linspace(_min, _max, 1000)
1006 try:
1007 kdes = KDEList(
1008 list(_samples.values()), kde=bounded_1d_kde,
1009 kde_kwargs={"xlow": xlow, "xhigh": xhigh}
1010 )
1011 except Exception as e:
1012 logger.warning(
1013 "Unable to compute the HPD interval for {} because {}".format(
1014 param, e
1015 )
1016 )
1017 continue
1018 pdfs = kdes(x)
1019 for num, key in enumerate(_samples.keys()):
1020 [xlow, xhigh], _ = hpd_two_sided_credible_interval(
1021 [], 90, x=x, pdf=pdfs[num]
1022 )
1023 key_data[key][param]["90% HPD"] = [xlow, xhigh]
1024 for _param in self.samples[key].keys():
1025 if _param in bounded_parameters:
1026 continue
1027 key_data[key][_param]["90% HPD"] = float("nan")
1028 return key_data
1031class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
1032 """Class to handle and store sample specific command line arguments
1033 """
1034 def __init__(self, *args, **kwargs):
1035 kwargs.update({"ignore_copy": True})
1036 super(SamplesInput, self).__init__(
1037 *args, gw=True, extra_options=[
1038 "evolve_spins_forwards",
1039 "evolve_spins_backwards",
1040 "NRSur_fits",
1041 "calculate_multipole_snr",
1042 "calculate_precessing_snr",
1043 "f_start",
1044 "f_low",
1045 "f_ref",
1046 "f_final",
1047 "psd",
1048 "waveform_fits",
1049 "redshift_method",
1050 "cosmology",
1051 "no_conversion",
1052 "delta_f",
1053 "psd_default",
1054 "disable_remnant",
1055 "force_BBH_remnant_computation",
1056 "force_BH_spin_evolution"
1057 ], **kwargs
1058 )
1059 if self._restarted_from_checkpoint:
1060 return
1061 if self.existing is not None:
1062 self.existing_data = self.grab_data_from_metafile(
1063 self.existing_metafile, self.existing,
1064 compare=self.compare_results
1065 )
1066 self.existing_approximant = self.existing_data["approximant"]
1067 self.existing_psd = self.existing_data["psd"]
1068 self.existing_calibration = self.existing_data["calibration"]
1069 self.existing_skymap = self.existing_data["skymap"]
1070 else:
1071 self.existing_approximant = None
1072 self.existing_psd = None
1073 self.existing_calibration = None
1074 self.existing_skymap = None
1075 self.approximant = self.opts.approximant
1076 self.approximant_flags = self.opts.approximant_flags
1077 self.detectors = None
1078 self.skymap = None
1079 self.calibration_definition = self.opts.calibration_definition
1080 self.calibration = self.opts.calibration
1081 self.gwdata = self.opts.gwdata
1082 self.maxL_samples = []
1084 @property
1085 def maxL_samples(self):
1086 return self._maxL_samples
1088 @maxL_samples.setter
1089 def maxL_samples(self, maxL_samples):
1090 key_data = self.grab_key_data_from_result_files()
1091 maxL_samples = {
1092 i: {
1093 j: key_data[i][j]["maxL"] for j in key_data[i].keys()
1094 } for i in key_data.keys()
1095 }
1096 for i in self.labels:
1097 maxL_samples[i]["approximant"] = self.approximant[i]
1098 self._maxL_samples = maxL_samples
1101class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
1102 """Class to handle and store plottig specific command line arguments
1103 """
1104 def __init__(self, *args, **kwargs):
1105 super(PlottingInput, self).__init__(*args, **kwargs)
1106 self.nsamples_for_skymap = self.opts.nsamples_for_skymap
1107 self.sensitivity = self.opts.sensitivity
1108 self.no_ligo_skymap = self.opts.no_ligo_skymap
1109 self.multi_threading_for_skymap = self.multi_process
1110 if not self.no_ligo_skymap and self.multi_process > 1:
1111 total = self.multi_process
1112 self.multi_threading_for_plots = int(total / 2.)
1113 self.multi_threading_for_skymap = total - self.multi_threading_for_plots
1114 logger.info(
1115 "Assigning {} process{}to skymap generation and {} process{}to "
1116 "other plots".format(
1117 self.multi_threading_for_skymap,
1118 "es " if self.multi_threading_for_skymap > 1 else " ",
1119 self.multi_threading_for_plots,
1120 "es " if self.multi_threading_for_plots > 1 else " "
1121 )
1122 )
1123 self.preliminary_pages = None
1124 self.pepredicates_probs = []
1125 self.pastro_probs = []
1128class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
1129 """Class to handle and store webpage specific command line arguments
1130 """
1131 def __init__(self, *args, **kwargs):
1132 super(WebpageInput, self).__init__(*args, **kwargs)
1133 self.gracedb_server = self.opts.gracedb_server
1134 self.gracedb_data = self.opts.gracedb_data
1135 self.gracedb = self.opts.gracedb
1136 self.public = self.opts.public
1137 if not hasattr(self, "preliminary_pages"):
1138 self.preliminary_pages = None
1139 if not hasattr(self, "pepredicates_probs"):
1140 self.pepredicates_probs = []
1141 if not hasattr(self, "pastro_probs"):
1142 self.pastro_probs = []
1145class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
1146 """Class to handle and store webpage and plotting specific command line
1147 arguments
1148 """
1149 def __init__(self, *args, **kwargs):
1150 super(WebpagePlusPlottingInput, self).__init__(*args, **kwargs)
1152 @property
1153 def default_directories(self):
1154 return super(WebpagePlusPlottingInput, self).default_directories
1156 @property
1157 def default_files_to_copy(self):
1158 return super(WebpagePlusPlottingInput, self).default_files_to_copy
1161class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
1162 """Class to handle and store metafile specific command line arguments
1163 """
1164 @property
1165 def default_directories(self):
1166 dirs = super(MetaFileInput, self).default_directories
1167 dirs += ["psds", "calibration"]
1168 return dirs
1170 def copy_files(self):
1171 _error = "Failed to save the {} to file"
1172 for label in self.labels:
1173 if self.psd[label] != {}:
1174 for ifo in self.psd[label].keys():
1175 if not isinstance(self.psd[label][ifo], PSD):
1176 logger.warning(_error.format("{} PSD".format(ifo)))
1177 continue
1178 self.psd[label][ifo].save_to_file(
1179 os.path.join(self.webdir, "psds", "{}_{}_psd.dat".format(
1180 label, ifo
1181 ))
1182 )
1183 if label in self.priors["calibration"].keys():
1184 if self.priors["calibration"][label] != {}:
1185 for ifo in self.priors["calibration"][label].keys():
1186 _instance = isinstance(
1187 self.priors["calibration"][label][ifo], Calibration
1188 )
1189 if not _instance:
1190 logger.warning(
1191 _error.format(
1192 "{} calibration envelope".format(
1193 ifo
1194 )
1195 )
1196 )
1197 continue
1198 self.priors["calibration"][label][ifo].save_to_file(
1199 os.path.join(self.webdir, "calibration", "{}_{}_cal.txt".format(
1200 label, ifo
1201 ))
1202 )
1203 return super(MetaFileInput, self).copy_files()
1206class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
1207 """Class to handle and store webpage, plotting and metafile specific command
1208 line arguments
1209 """
1210 def __init__(self, *args, **kwargs):
1211 super(WebpagePlusPlottingPlusMetaFileInput, self).__init__(
1212 *args, **kwargs
1213 )
1215 @property
1216 def default_directories(self):
1217 return super(WebpagePlusPlottingPlusMetaFileInput, self).default_directories
1219 @property
1220 def default_files_to_copy(self):
1221 return super(WebpagePlusPlottingPlusMetaFileInput, self).default_files_to_copy
1224@deprecation(
1225 "The GWInput class is deprecated. Please use either the BaseInput, "
1226 "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
1227 "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
1228)
1229class GWInput(WebpagePlusPlottingPlusMetaFileInput):
1230 pass
1233class IMRCTInput(pesummary.core.cli.inputs._Input):
1234 """Class to handle the TGR specific command line arguments
1235 """
1236 @property
1237 def labels(self):
1238 return self._labels
1240 @labels.setter
1241 def labels(self, labels):
1242 self._labels = labels
1243 if len(labels) % 2 != 0:
1244 raise ValueError(
1245 "The IMRCT test requires 2 results files for each analysis. "
1246 )
1247 elif len(labels) > 2:
1248 cond = all(
1249 ":inspiral" in label or ":postinspiral" in label for label in
1250 labels
1251 )
1252 if not cond:
1253 raise ValueError(
1254 "To compare 2 or more analyses, please provide labels as "
1255 "'{}:inspiral' and '{}:postinspiral' where {} indicates "
1256 "the analysis label"
1257 )
1258 else:
1259 self.analysis_label = [
1260 label.split(":inspiral")[0]
1261 for label in labels
1262 if ":inspiral" in label and ":postinspiral" not in label
1263 ]
1264 if len(self.analysis_label) != len(self.result_files) / 2:
1265 raise ValueError(
1266 "When comparing more than 2 analyses, labels must "
1267 "be of the form '{}:inspiral' and '{}:postinspiral'."
1268 )
1269 logger.info(
1270 "Using the labels: {} to distinguish analyses".format(
1271 ", ".join(self.analysis_label)
1272 )
1273 )
1274 elif sorted(labels) != ["inspiral", "postinspiral"]:
1275 if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
1276 meta_file_labels = []
1277 for suffix in [":inspiral", ":postinspiral"]:
1278 if any(suffix in label for label in labels):
1279 ind = [
1280 num for num, label in enumerate(labels) if
1281 suffix in label
1282 ]
1283 if len(ind) > 1:
1284 raise ValueError(
1285 "Please provide a single {} label".format(
1286 suffix.split(":")[1]
1287 )
1288 )
1289 meta_file_labels.append(
1290 labels[ind[0]].split(suffix)[0]
1291 )
1292 else:
1293 raise ValueError(
1294 "Please provide labels as {inspiral_label}:inspiral "
1295 "and {postinspiral_label}:postinspiral where "
1296 "inspiral_label and postinspiral_label are the "
1297 "PESummary labels for the inspiral and postinspiral "
1298 "analyses respectively. "
1299 )
1300 if len(self.result_files) == 1:
1301 logger.info(
1302 "Using the {} samples for the inspiral analysis and {} "
1303 "samples for the postinspiral analysis from the file "
1304 "{}".format(
1305 meta_file_labels[0], meta_file_labels[1],
1306 self.result_files[0]
1307 )
1308 )
1309 elif len(self.result_files) == 2:
1310 logger.info(
1311 "Using the {} samples for the inspiral analysis from "
1312 "the file {}. Using the {} samples for the "
1313 "postinspiral analysis from the file {}".format(
1314 meta_file_labels[0], self.result_files[0],
1315 meta_file_labels[1], self.result_files[1]
1316 )
1317 )
1318 else:
1319 raise ValueError(
1320 "Currently, you can only provide at most 2 pesummary "
1321 "metafiles. If one is provided, both the inspiral and "
1322 "postinspiral are extracted from that single file. If "
1323 "two are provided, the inspiral is extracted from one "
1324 "file and the postinspiral is extracted from the other."
1325 )
1326 self._labels = ["inspiral", "postinspiral"]
1327 self._meta_file_labels = meta_file_labels
1328 self.analysis_label = ["primary"]
1329 else:
1330 raise ValueError(
1331 "The IMRCT test requires an inspiral and postinspiral result "
1332 "file. Please indicate which file is the inspiral and which "
1333 "is postinspiral by providing these exact labels to the "
1334 "summarytgr executable"
1335 )
1336 else:
1337 self.analysis_label = ["primary"]
1339 def _extract_stored_approximant(self, opened_file, label):
1340 """Extract the approximant used for a given analysis stored in a
1341 PESummary metafile
1343 Parameters
1344 ----------
1345 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1346 opened metafile that contains the analysis 'label'
1347 label: str
1348 analysis label which is stored in the PESummary metafile
1349 """
1350 if opened_file.approximant is not None:
1351 if label not in opened_file.labels:
1352 raise ValueError(
1353 "Invalid label {}. The list of available labels are {}".format(
1354 label, ", ".join(opened_file.labels)
1355 )
1356 )
1357 _index = opened_file.labels.index(label)
1358 return opened_file.approximant[_index]
1359 return
1361 def _extract_stored_remnant_fits(self, opened_file, label):
1362 """Extract the remnant fits used for a given analysis stored in a
1363 PESummary metafile
1365 Parameters
1366 ----------
1367 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1368 opened metafile that contains the analysis 'label'
1369 label: str
1370 analysis label which is stored in the PESummary metafile
1371 """
1372 fits = {}
1373 fit_strings = [
1374 "final_spin_NR_fits", "final_mass_NR_fits"
1375 ]
1376 if label not in opened_file.labels:
1377 raise ValueError(
1378 "Invalid label {}. The list of available labels are {}".format(
1379 label, ", ".join(opened_file.labels)
1380 )
1381 )
1382 _index = opened_file.labels.index(label)
1383 _meta_data = opened_file.extra_kwargs[_index]
1384 if "meta_data" in _meta_data.keys():
1385 for key in fit_strings:
1386 if key in _meta_data["meta_data"].keys():
1387 fits[key] = _meta_data["meta_data"][key]
1388 if len(fits):
1389 return fits
1390 return
1392 def _extract_stored_cutoff_frequency(self, opened_file, label):
1393 """Extract the cutoff frequencies used for a given analysis stored in a
1394 PESummary metafile
1396 Parameters
1397 ----------
1398 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1399 opened metafile that contains the analysis 'label'
1400 label: str
1401 analysis label which is stored in the PESummary metafile
1402 """
1403 frequencies = {}
1404 if opened_file.config is not None:
1405 if label not in opened_file.labels:
1406 raise ValueError(
1407 "Invalid label {}. The list of available labels are {}".format(
1408 label, ", ".join(opened_file.labels)
1409 )
1410 )
1411 if opened_file.config[label] is not None:
1412 _config = opened_file.config[label]
1413 if "config" in _config.keys():
1414 if "maximum-frequency" in _config["config"].keys():
1415 frequencies["fhigh"] = _config["config"][
1416 "maximum-frequency"
1417 ]
1418 if "minimum-frequency" in _config["config"].keys():
1419 frequencies["flow"] = _config["config"][
1420 "minimum-frequency"
1421 ]
1422 elif "lalinference" in _config.keys():
1423 if "fhigh" in _config["lalinference"].keys():
1424 frequencies["fhigh"] = _config["lalinference"][
1425 "fhigh"
1426 ]
1427 if "flow" in _config["lalinference"].keys():
1428 frequencies["flow"] = _config["lalinference"][
1429 "flow"
1430 ]
1431 return frequencies
1432 return
1434 @property
1435 def samples(self):
1436 return self._samples
1438 @samples.setter
1439 def samples(self, samples):
1440 from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
1441 self._read_samples = {
1442 _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
1443 self.labels, self.result_files
1444 )
1445 }
1446 _samples_dict = {}
1447 _approximant_dict = {}
1448 _cutoff_frequency_dict = {}
1449 _remnant_fits_dict = {}
1450 for label, _open in self._read_samples.items():
1451 if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
1452 if not len(self._meta_file_labels):
1453 raise ValueError(
1454 "Currently you can only pass a file containing a "
1455 "single analysis or a valid PESummary metafile "
1456 "containing multiple analyses"
1457 )
1458 _labels = _open.labels
1459 if len(self._read_samples) == 1:
1460 _samples_dict = {
1461 label: _open.samples_dict[meta_file_label] for
1462 label, meta_file_label in zip(
1463 self.labels, self._meta_file_labels
1464 )
1465 }
1466 for label, meta_file_label in zip(self.labels, self._meta_file_labels):
1467 _stored_approx = self._extract_stored_approximant(
1468 _open, meta_file_label
1469 )
1470 _stored_frequencies = self._extract_stored_cutoff_frequency(
1471 _open, meta_file_label
1472 )
1473 _stored_remnant_fits = self._extract_stored_remnant_fits(
1474 _open, meta_file_label
1475 )
1476 if _stored_approx is not None:
1477 _approximant_dict[label] = _stored_approx
1478 if _stored_remnant_fits is not None:
1479 _remnant_fits_dict[label] = _stored_remnant_fits
1480 if _stored_frequencies is not None:
1481 if label == "inspiral":
1482 if "fhigh" in _stored_frequencies.keys():
1483 _cutoff_frequency_dict[label] = _stored_frequencies[
1484 "fhigh"
1485 ]
1486 if label == "postinspiral":
1487 if "flow" in _stored_frequencies.keys():
1488 _cutoff_frequency_dict[label] = _stored_frequencies[
1489 "flow"
1490 ]
1491 break
1492 else:
1493 ind = self.labels.index(label)
1494 _samples_dict[label] = _open.samples_dict[
1495 self._meta_file_labels[ind]
1496 ]
1497 _stored_approx = self._extract_stored_approximant(
1498 _open, self._meta_file_labels[ind]
1499 )
1500 _stored_frequencies = self._extract_stored_cutoff_frequency(
1501 _open, self._meta_file_labels[ind]
1502 )
1503 _stored_remnant_fits = self._extract_stored_remnant_fits(
1504 _open, self._meta_file_labels[ind]
1505 )
1506 if _stored_approx is not None:
1507 _approximant_dict[label] = _stored_approx
1508 if _stored_remnant_fits is not None:
1509 _remnant_fits_dict[label] = _stored_remnant_fits
1510 if _stored_frequencies is not None:
1511 if label == "inspiral":
1512 if "fhigh" in _stored_frequencies.keys():
1513 _cutoff_frequency_dict[label] = _stored_frequencies[
1514 "fhigh"
1515 ]
1516 if label == "postinspiral":
1517 if "flow" in _stored_frequencies.keys():
1518 _cutoff_frequency_dict[label] = _stored_frequencies[
1519 "flow"
1520 ]
1521 else:
1522 _samples_dict[label] = _open.samples_dict
1523 extra_kwargs = _open.extra_kwargs
1524 if "pe_algorithm" in extra_kwargs["sampler"].keys():
1525 if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
1526 try:
1527 subkwargs = extra_kwargs["other"]["likelihood"][
1528 "waveform_arguments"
1529 ]
1530 _approximant_dict[label] = (
1531 subkwargs["waveform_approximant"]
1532 )
1533 if "inspiral" in label and "postinspiral" not in label:
1534 _cutoff_frequency_dict[label] = (
1535 subkwargs["maximum_frequency"]
1536 )
1537 elif "postinspiral" in label:
1538 _cutoff_frequency_dict[label] = (
1539 subkwargs["minimum_frequency"]
1540 )
1541 except KeyError:
1542 pass
1543 self._samples = MultiAnalysisSamplesDict(_samples_dict)
1544 if len(_approximant_dict):
1545 self._approximant_dict = _approximant_dict
1546 if len(_cutoff_frequency_dict):
1547 self._cutoff_frequency_dict = _cutoff_frequency_dict
1548 if len(_remnant_fits_dict):
1549 self._remnant_fits_dict = _remnant_fits_dict
1551 @property
1552 def imrct_kwargs(self):
1553 return self._imrct_kwargs
1555 @imrct_kwargs.setter
1556 def imrct_kwargs(self, imrct_kwargs):
1557 test_kwargs = dict(N_bins=101)
1558 try:
1559 test_kwargs.update(imrct_kwargs)
1560 except AttributeError:
1561 test_kwargs = test_kwargs
1563 for key, value in test_kwargs.items():
1564 try:
1565 test_kwargs[key] = ast.literal_eval(value)
1566 except ValueError:
1567 pass
1568 self._imrct_kwargs = test_kwargs
1570 @property
1571 def meta_data(self):
1572 return self._meta_data
1574 @meta_data.setter
1575 def meta_data(self, meta_data):
1576 self._meta_data = {}
1577 for num, _inspiral in enumerate(self.inspiral_keys):
1578 frequency_dict = dict()
1579 approximant_dict = dict()
1580 remnant_dict = dict()
1581 zipped = zip(
1582 [self.cutoff_frequency, self.approximant, None],
1583 [frequency_dict, approximant_dict, remnant_dict],
1584 ["cutoff_frequency", "approximant", "remnant_fits"]
1585 )
1586 _inspiral_string = self.inspiral_keys[num]
1587 _postinspiral_string = self.postinspiral_keys[num]
1588 for _list, _dict, name in zipped:
1589 if _list is not None and len(_list) == len(self.labels):
1590 inspiral_ind = self.labels.index(_inspiral_string)
1591 postinspiral_ind = self.labels.index(_postinspiral_string)
1592 _dict["inspiral"] = _list[inspiral_ind]
1593 _dict["postinspiral"] = _list[postinspiral_ind]
1594 elif _list is not None:
1595 raise ValueError(
1596 "Please provide a 'cutoff_frequency' and 'approximant' "
1597 "for each file"
1598 )
1599 else:
1600 try:
1601 if name == "cutoff_frequency":
1602 if "inspiral" in self._cutoff_frequency_dict.keys():
1603 _dict["inspiral"] = self._cutoff_frequency_dict[
1604 "inspiral"
1605 ]
1606 if "postinspiral" in self._cutoff_frequency_dict.keys():
1607 _dict["postinspiral"] = self._cutoff_frequency_dict[
1608 "postinspiral"
1609 ]
1610 elif name == "approximant":
1611 if "inspiral" in self._approximant_dict.keys():
1612 _dict["inspiral"] = self._approximant_dict[
1613 "inspiral"
1614 ]
1615 if "postinspiral" in self._approximant_dict.keys():
1616 _dict["postinspiral"] = self._approximant_dict[
1617 "postinspiral"
1618 ]
1619 elif name == "remnant_fits":
1620 if "inspiral" in self._remnant_fits_dict.keys():
1621 _dict["inspiral"] = self._remnant_fits_dict[
1622 "inspiral"
1623 ]
1624 if "postinspiral" in self._remnant_fits_dict.keys():
1625 _dict["postinspiral"] = self._remnant_fits_dict[
1626 "postinspiral"
1627 ]
1628 except (AttributeError, KeyError, TypeError):
1629 _dict["inspiral"] = None
1630 _dict["postinspiral"] = None
1632 self._meta_data[self.analysis_label[num]] = {
1633 "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
1634 "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
1635 "inspiral approximant": approximant_dict["inspiral"],
1636 "postinspiral approximant": approximant_dict["postinspiral"],
1637 "inspiral remnant fits": remnant_dict["inspiral"],
1638 "postinspiral remnant fits": remnant_dict["postinspiral"]
1639 }
1641 def __init__(self, opts):
1642 self.opts = opts
1643 self.existing = None
1644 self.webdir = self.opts.webdir
1645 self.user = None
1646 self.baseurl = None
1647 self.result_files = self.opts.samples
1648 self.labels = self.opts.labels
1649 self.samples = self.opts.samples
1650 self.inspiral_keys = [
1651 key for key in self.samples.keys() if "inspiral" in key
1652 and "postinspiral" not in key
1653 ]
1654 self.postinspiral_keys = [
1655 key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
1656 ]
1657 try:
1658 self.imrct_kwargs = self.opts.imrct_kwargs
1659 except AttributeError:
1660 self.imrct_kwargs = {}
1661 for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
1662 _attr = getattr(self.opts, _arg)
1663 if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
1664 raise ValueError("Please provide a {} for each file".format(_arg))
1665 setattr(self, _arg, _attr)
1666 self.meta_data = None
1667 self.default_directories = ["samples", "plots", "js", "html", "css"]
1668 self.publication = False
1669 self.make_directories()