Coverage for pesummary/gw/cli/inputs.py: 79.6%
846 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import ast
4import os
5import numpy as np
6import pesummary.core.cli.inputs
7from pesummary.gw.file.read import read as GWRead
8from pesummary.gw.file.psd import PSD
9from pesummary.gw.file.calibration import Calibration
10from pesummary.utils.decorators import deprecation
11from pesummary.utils.exceptions import InputError
12from pesummary.utils.utils import logger
13from pesummary import conf
15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class _GWInput(pesummary.core.cli.inputs._Input):
    """Super class to handle gw specific command line inputs
    """
    @staticmethod
    def grab_data_from_metafile(
        existing_file, webdir, compare=None, nsamples=None, **kwargs
    ):
        """Grab data from an existing PESummary metafile

        Parameters
        ----------
        existing_file: str
            path to the existing metafile
        webdir: str
            the directory to store the existing configuration file
        compare: list, optional
            list of labels for events stored in an existing metafile that you
            wish to compare
        nsamples: int, optional
            number of samples to draw from the existing metafile
        **kwargs: dict, optional
            forwarded to the core `grab_data_from_metafile`; if a
            'psd_default' key is present it is also used as a replacement
            template below
        """
        # templates used by the core reader to substitute per-label values
        # stored in the metafile
        _replace_kwargs = {
            "psd": "{file}.psd['{label}']"
        }
        if "psd_default" in kwargs.keys():
            _replace_kwargs["psd_default"] = kwargs["psd_default"]
        data = pesummary.core.cli.inputs._Input.grab_data_from_metafile(
            existing_file, webdir, compare=compare, read_function=GWRead,
            nsamples=nsamples, _replace_with_pesummary_kwargs=_replace_kwargs,
            **kwargs
        )
        # re-open the metafile with the GW reader to pull out the GW-specific
        # products (PSDs, calibration envelopes, skymaps, approximants)
        f = GWRead(existing_file)
        labels = data["labels"]
        # default to an empty dict per label; only fill entries the metafile
        # actually provides
        psd = {i: {} for i in labels}
        if f.psd is not None and f.psd != {}:
            for i in labels:
                if i in f.psd.keys() and f.psd[i] != {}:
                    psd[i] = {
                        ifo: PSD(f.psd[i][ifo]) for ifo in f.psd[i].keys()
                    }
        calibration = {i: {} for i in labels}
        if f.calibration is not None and f.calibration != {}:
            for i in labels:
                if i in f.calibration.keys() and f.calibration[i] != {}:
                    calibration[i] = {
                        ifo: Calibration(f.calibration[i][ifo]) for ifo in
                        f.calibration[i].keys()
                    }
        skymap = {i: None for i in labels}
        # older metafiles may not carry a 'skymap' attribute at all
        if hasattr(f, "skymap") and f.skymap is not None and f.skymap != {}:
            for i in labels:
                if i in f.skymap.keys() and len(f.skymap[i]):
                    skymap[i] = f.skymap[i]
        data.update(
            {
                # 'indicies' is the (misspelled) key provided by the core
                # reader mapping labels to their positions in the metafile
                "approximant": {
                    i: j for i, j in zip(
                        labels, [f.approximant[ind] for ind in data["indicies"]]
                    )
                },
                "psd": psd,
                "calibration": calibration,
                "skymap": skymap
            }
        )
        return data
85 @property
86 def grab_data_kwargs(self):
87 kwargs = super(_GWInput, self).grab_data_kwargs
88 for _property in ["f_low", "f_ref", "f_final", "delta_f"]:
89 if getattr(self, _property) is None:
90 setattr(self, "_{}".format(_property), [None] * len(self.labels))
91 elif len(getattr(self, _property)) == 1 and len(self.labels) != 1:
92 setattr(
93 self, "_{}".format(_property),
94 getattr(self, _property) * len(self.labels)
95 )
96 if self.opts.approximant is None:
97 approx = [None] * len(self.labels)
98 else:
99 approx = self.opts.approximant
100 resume_file = [
101 os.path.join(
102 self.webdir, "checkpoint",
103 "{}_conversion_class.pickle".format(label)
104 ) for label in self.labels
105 ]
107 try:
108 for num, label in enumerate(self.labels):
109 try:
110 psd = self.psd[label]
111 except KeyError:
112 psd = {}
113 kwargs[label].update(dict(
114 evolve_spins_forwards=self.evolve_spins_forwards,
115 evolve_spins_backwards=self.evolve_spins_backwards,
116 f_low=self.f_low[num],
117 approximant=approx[num], f_ref=self.f_ref[num],
118 NRSur_fits=self.NRSur_fits, return_kwargs=True,
119 multipole_snr=self.calculate_multipole_snr,
120 precessing_snr=self.calculate_precessing_snr,
121 psd=psd, f_final=self.f_final[num],
122 waveform_fits=self.waveform_fits,
123 multi_process=self.opts.multi_process,
124 redshift_method=self.redshift_method,
125 cosmology=self.cosmology,
126 no_conversion=self.no_conversion,
127 add_zero_spin=True, delta_f=self.delta_f[num],
128 psd_default=self.psd_default,
129 disable_remnant=self.disable_remnant,
130 force_BBH_remnant_computation=self.force_BBH_remnant_computation,
131 resume_file=resume_file[num],
132 restart_from_checkpoint=self.restart_from_checkpoint,
133 force_BH_spin_evolution=self.force_BH_spin_evolution,
134 ))
135 return kwargs
136 except IndexError:
137 logger.warning(
138 "Unable to find an f_ref, f_low and approximant for each "
139 "label. Using and f_ref={}, f_low={} and approximant={} "
140 "for all result files".format(
141 self.f_ref[0], self.f_low[0], approx[0]
142 )
143 )
144 for num, label in enumerate(self.labels):
145 kwargs[label].update(dict(
146 evolve_spins_forwards=self.evolve_spins_forwards,
147 evolve_spins_backwards=self.evolve_spins_backwards,
148 f_low=self.f_low[0],
149 approximant=approx[0], f_ref=self.f_ref[0],
150 NRSur_fits=self.NRSur_fits, return_kwargs=True,
151 multipole_snr=self.calculate_multipole_snr,
152 precessing_snr=self.calculate_precessing_snr,
153 psd=self.psd[self.labels[0]], f_final=self.f_final[0],
154 waveform_fits=self.waveform_fits,
155 multi_process=self.opts.multi_process,
156 redshift_method=self.redshift_method,
157 cosmology=self.cosmology,
158 no_conversion=self.no_conversion,
159 add_zero_spin=True, delta_f=self.delta_f[0],
160 psd_default=self.psd_default,
161 disable_remnant=self.disable_remnant,
162 force_BBH_remnant_computation=self.force_BBH_remnant_computation,
163 resume_file=resume_file[num],
164 restart_from_checkpoint=self.restart_from_checkpoint,
165 force_BH_spin_evolution=self.force_BH_spin_evolution
166 ))
167 return kwargs
169 @staticmethod
170 def grab_data_from_file(
171 file, label, webdir, config=None, injection=None, file_format=None,
172 nsamples=None, disable_prior_sampling=False, **kwargs
173 ):
174 """Grab data from a result file containing posterior samples
176 Parameters
177 ----------
178 file: str
179 path to the result file
180 label: str
181 label that you wish to use for the result file
182 config: str, optional
183 path to a configuration file used in the analysis
184 injection: str, optional
185 path to an injection file used in the analysis
186 file_format, str, optional
187 the file format you wish to use when loading. Default None.
188 If None, the read function loops through all possible options
189 """
190 data = pesummary.core.cli.inputs._Input.grab_data_from_file(
191 file, label, webdir, config=config, injection=injection,
192 read_function=GWRead, file_format=file_format, nsamples=nsamples,
193 disable_prior_sampling=disable_prior_sampling, **kwargs
194 )
195 return data
197 @property
198 def reweight_samples(self):
199 return self._reweight_samples
201 @reweight_samples.setter
202 def reweight_samples(self, reweight_samples):
203 from pesummary.gw.reweight import options
204 self._reweight_samples = self._check_reweight_samples(
205 reweight_samples, options
206 )
208 def _set_samples(self, *args, **kwargs):
209 super(_GWInput, self)._set_samples(*args, **kwargs)
210 if "calibration" not in self.priors:
211 self.priors["calibration"] = {
212 label: {} for label in self.labels
213 }
    def _set_corner_params(self, corner_params):
        """Return the list of parameters used for the corner plot.

        If no parameters were requested, the GW defaults are used. Otherwise
        the request is merged with the GW defaults, and any user-requested
        parameter that is not present in every result file is dropped.
        """
        corner_params = super(_GWInput, self)._set_corner_params(corner_params)
        if corner_params is None:
            logger.debug(
                "Using the default corner parameters: {}".format(
                    ", ".join(conf.gw_corner_parameters)
                )
            )
        else:
            # keep a copy of the user's request: only user-supplied
            # parameters are pruned below, not the GW defaults
            _corner_params = corner_params
            corner_params = list(set(conf.gw_corner_parameters + corner_params))
            for param in _corner_params:
                _data = self.samples
                # drop the parameter unless every analysed label provides it
                if not all(param in _data[label].keys() for label in self.labels):
                    corner_params.remove(param)
            logger.debug(
                "Generating a corner plot with the following "
                "parameters: {}".format(", ".join(corner_params))
            )
        return corner_params
236 @property
237 def cosmology(self):
238 return self._cosmology
240 @cosmology.setter
241 def cosmology(self, cosmology):
242 from pesummary.gw.cosmology import available_cosmologies
244 if cosmology.lower() not in available_cosmologies:
245 logger.warning(
246 "Unrecognised cosmology: {}. Using {} as default".format(
247 cosmology, conf.cosmology
248 )
249 )
250 cosmology = conf.cosmology
251 else:
252 logger.debug("Using the {} cosmology".format(cosmology))
253 self._cosmology = cosmology
    @property
    def approximant(self):
        """Dictionary mapping each label to its waveform approximant."""
        return self._approximant

    @approximant.setter
    def approximant(self, approximant):
        # first assignment initialises the per-label dict; subsequent
        # assignments only normalise left-over empty entries
        if not hasattr(self, "_approximant"):
            approximant_list = {i: {} for i in self.labels}
            if approximant is None:
                logger.warning(
                    "No approximant passed. Waveform plots will not be "
                    "generated"
                )
            elif approximant is not None:
                # one approximant must be supplied per result file
                if len(approximant) != len(self.labels):
                    raise InputError(
                        "Please pass an approximant for each result file"
                    )
                approximant_list = {
                    i: j for i, j in zip(self.labels, approximant)
                }
            self._approximant = approximant_list
        else:
            for num, i in enumerate(self._approximant.keys()):
                if self._approximant[i] == {}:
                    if num == 0:
                        logger.warning(
                            "No approximant passed. Waveform plots will not be "
                            "generated"
                        )
                    self._approximant[i] = None
                    # NOTE(review): the break stops after the first empty
                    # entry, so later labels with empty approximants keep {}
                    # rather than being set to None — confirm this is intended
                    break
288 @property
289 def gracedb_server(self):
290 return self._gracedb_server
292 @gracedb_server.setter
293 def gracedb_server(self, gracedb_server):
294 if gracedb_server is None:
295 self._gracedb_server = conf.gracedb_server
296 else:
297 logger.debug(
298 "Using '{}' as the GraceDB server".format(gracedb_server)
299 )
300 self._gracedb_server = gracedb_server
302 @property
303 def gracedb(self):
304 return self._gracedb
306 @gracedb.setter
307 def gracedb(self, gracedb):
308 self._gracedb = gracedb
309 if gracedb is not None:
310 from pesummary.gw.gracedb import get_gracedb_data, HTTPError
311 from json.decoder import JSONDecodeError
313 first_letter = gracedb[0]
314 if first_letter != "G" and first_letter != "g" and first_letter != "S":
315 logger.warning(
316 "Invalid GraceDB ID passed. The GraceDB ID must be of the "
317 "form G0000 or S0000. Ignoring input."
318 )
319 self._gracedb = None
320 return
321 _error = (
322 "Unable to download data from Gracedb because {}. Only storing "
323 "the GraceDB ID in the metafile"
324 )
325 try:
326 logger.info(
327 "Downloading {} from gracedb for {}".format(
328 ", ".join(self.gracedb_data), gracedb
329 )
330 )
331 json = get_gracedb_data(
332 gracedb, info=self.gracedb_data,
333 service_url=self.gracedb_server
334 )
335 json["id"] = gracedb
336 except (HTTPError, RuntimeError, JSONDecodeError) as e:
337 logger.warning(_error.format(e))
338 json = {"id": gracedb}
340 for label in self.labels:
341 self.file_kwargs[label]["meta_data"]["gracedb"] = json
343 @property
344 def detectors(self):
345 return self._detectors
347 @detectors.setter
348 def detectors(self, detectors):
349 detector = {}
350 if not detectors:
351 for i in self.samples.keys():
352 params = list(self.samples[i].keys())
353 individual_detectors = []
354 for j in params:
355 if "optimal_snr" in j and j != "network_optimal_snr":
356 det = j.split("_optimal_snr")[0]
357 individual_detectors.append(det)
358 individual_detectors = sorted(
359 [str(i) for i in individual_detectors])
360 if individual_detectors:
361 detector[i] = "_".join(individual_detectors)
362 else:
363 detector[i] = None
364 else:
365 detector = detectors
366 logger.debug("The detector network is %s" % (detector))
367 self._detectors = detector
369 @property
370 def skymap(self):
371 return self._skymap
373 @skymap.setter
374 def skymap(self, skymap):
375 if not hasattr(self, "_skymap"):
376 self._skymap = {i: None for i in self.labels}
    @property
    def calibration(self):
        """Dictionary mapping each label to its calibration posterior data."""
        return self._calibration

    @calibration.setter
    def calibration(self, calibration):
        # only initialise once
        if not hasattr(self, "_calibration"):
            data = {i: {} for i in self.labels}
            if calibration != {}:
                # calibration files passed on the command line become the
                # calibration *prior*
                prior_data = self.get_psd_or_calibration_data(
                    calibration, self.extract_calibration_data_from_file
                )
                self.add_to_prior_dict("calibration", prior_data)
            else:
                # no global input: look for per-label '<label>_calibration'
                # options (e.g. from a configuration file)
                prior_data = {i: {} for i in self.labels}
                for label in self.labels:
                    if hasattr(self.opts, "{}_calibration".format(label)):
                        cal_data = getattr(self.opts, "{}_calibration".format(label))
                        if cal_data != {} and cal_data is not None:
                            prior_data[label] = {
                                ifo: self.extract_calibration_data_from_file(
                                    cal_data[ifo]
                                ) for ifo in cal_data.keys()
                            }
                if not all(prior_data[i] == {} for i in self.labels):
                    self.add_to_prior_dict("calibration", prior_data)
                else:
                    self.add_to_prior_dict("calibration", {})
                # additionally try to extract the calibration *posterior*
                # (spline samples) from each result file
                for num, i in enumerate(self.result_files):
                    _opened = self._open_result_files
                    # reuse an already-open handle when available
                    if i in _opened.keys() and _opened[i] is not None:
                        f = self._open_result_files[i]
                    else:
                        f = GWRead(i, disable_prior=True)
                    try:
                        calibration_data = f.interpolate_calibration_spline_posterior()
                    except Exception as e:
                        # best-effort: a failure only disables calibration
                        # data for this file
                        logger.warning(
                            "Failed to extract calibration data from the result "
                            "file: {} because {}".format(i, e)
                        )
                        calibration_data = None
                    labels = list(self.samples.keys())
                    if calibration_data is None:
                        data[labels[num]] = {
                            None: None
                        }
                    elif isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                        # NOTE(review): this inner loop reuses (shadows) the
                        # outer 'num' loop variable; subsequent iterations of
                        # the outer loop see the overwritten value — confirm
                        # this is intended
                        for num in range(len(calibration_data[0])):
                            data[labels[num]] = {
                                j: k for j, k in zip(
                                    calibration_data[1][num],
                                    calibration_data[0][num]
                                )
                            }
                    else:
                        data[labels[num]] = {
                            j: k for j, k in zip(
                                calibration_data[1], calibration_data[0]
                            )
                        }
            self._calibration = data
441 @property
442 def psd(self):
443 return self._psd
445 @psd.setter
446 def psd(self, psd):
447 if not hasattr(self, "_psd"):
448 data = {i: {} for i in self.labels}
449 if psd != {}:
450 data = self.get_psd_or_calibration_data(
451 psd, self.extract_psd_data_from_file
452 )
453 else:
454 for label in self.labels:
455 if hasattr(self.opts, "{}_psd".format(label)):
456 psd_data = getattr(self.opts, "{}_psd".format(label))
457 if psd_data != {} and psd_data is not None:
458 data[label] = {
459 ifo: self.extract_psd_data_from_file(
460 psd_data[ifo], IFO=ifo
461 ) for ifo in psd_data.keys()
462 }
463 self._psd = data
465 @property
466 def nsamples_for_skymap(self):
467 return self._nsamples_for_skymap
469 @nsamples_for_skymap.setter
470 def nsamples_for_skymap(self, nsamples_for_skymap):
471 self._nsamples_for_skymap = nsamples_for_skymap
472 if nsamples_for_skymap is not None:
473 self._nsamples_for_skymap = int(nsamples_for_skymap)
474 number_of_samples = [
475 data.number_of_samples for label, data in self.samples.items()
476 ]
477 if not all(i > self._nsamples_for_skymap for i in number_of_samples):
478 min_arg = np.argmin(number_of_samples)
479 logger.warning(
480 "You have specified that you would like to use {} "
481 "samples to generate the skymap but the file {} only "
482 "has {} samples. Reducing the number of samples to "
483 "generate the skymap to {}".format(
484 self._nsamples_for_skymap, self.result_files[min_arg],
485 number_of_samples[min_arg], number_of_samples[min_arg]
486 )
487 )
488 self._nsamples_for_skymap = int(number_of_samples[min_arg])
490 @property
491 def gwdata(self):
492 return self._gwdata
494 @gwdata.setter
495 def gwdata(self, gwdata):
496 from pesummary.gw.file.strain import StrainDataDict
498 self._gwdata = gwdata
499 if gwdata is not None:
500 if isinstance(gwdata, dict):
501 for i in gwdata.keys():
502 if not os.path.isfile(gwdata[i]):
503 raise InputError(
504 "The file {} does not exist. Please check the path "
505 "to your strain file".format(gwdata[i])
506 )
507 self._gwdata = StrainDataDict.read(gwdata)
508 else:
509 logger.warning(
510 "Please provide gwdata as a dictionary with keys "
511 "displaying the channel and item giving the path to the "
512 "strain file"
513 )
514 self._gwdata = None
516 @property
517 def evolve_spins_forwards(self):
518 return self._evolve_spins_forwards
520 @evolve_spins_forwards.setter
521 def evolve_spins_forwards(self, evolve_spins_forwards):
522 self._evolve_spins_forwards = evolve_spins_forwards
523 _msg = "Spins will be evolved up to {}"
524 if evolve_spins_forwards:
525 logger.info(_msg.format("Schwarzschild ISCO frequency"))
526 self._evolve_spins_forwards = 6. ** -0.5
528 @property
529 def evolve_spins_backwards(self):
530 return self._evolve_spins_backwards
532 @evolve_spins_backwards.setter
533 def evolve_spins_backwards(self, evolve_spins_backwards):
534 self._evolve_spins_backwards = evolve_spins_backwards
535 _msg = (
536 "Spins will be evolved backwards to an infinite separation using the '{}' "
537 "method"
538 )
539 if isinstance(evolve_spins_backwards, (str, bytes)):
540 logger.info(_msg.format(evolve_spins_backwards))
541 elif evolve_spins_backwards is None:
542 logger.info(_msg.format("precession_averaged"))
543 self._evolve_spins_backwards = "precession_averaged"
545 @property
546 def NRSur_fits(self):
547 return self._NRSur_fits
549 @NRSur_fits.setter
550 def NRSur_fits(self, NRSur_fits):
551 self._NRSur_fits = NRSur_fits
552 base = (
553 "Using the '{}' NRSurrogate model to calculate the remnant "
554 "quantities"
555 )
556 if isinstance(NRSur_fits, (str, bytes)):
557 logger.info(base.format(NRSur_fits))
558 self._NRSur_fits = NRSur_fits
559 elif NRSur_fits is None:
560 from pesummary.gw.conversions.nrutils import NRSUR_MODEL
562 logger.info(base.format(NRSUR_MODEL))
563 self._NRSur_fits = NRSUR_MODEL
565 @property
566 def waveform_fits(self):
567 return self._waveform_fits
569 @waveform_fits.setter
570 def waveform_fits(self, waveform_fits):
571 self._waveform_fits = waveform_fits
572 if waveform_fits:
573 logger.info(
574 "Evaluating the remnant quantities using the provided "
575 "approximant"
576 )
578 @property
579 def f_low(self):
580 return self._f_low
582 @f_low.setter
583 def f_low(self, f_low):
584 self._f_low = f_low
585 if f_low is not None:
586 self._f_low = [float(i) for i in f_low]
588 @property
589 def f_ref(self):
590 return self._f_ref
592 @f_ref.setter
593 def f_ref(self, f_ref):
594 self._f_ref = f_ref
595 if f_ref is not None:
596 self._f_ref = [float(i) for i in f_ref]
598 @property
599 def f_final(self):
600 return self._f_final
602 @f_final.setter
603 def f_final(self, f_final):
604 self._f_final = f_final
605 if f_final is not None:
606 self._f_final = [float(i) for i in f_final]
608 @property
609 def delta_f(self):
610 return self._delta_f
612 @delta_f.setter
613 def delta_f(self, delta_f):
614 self._delta_f = delta_f
615 if delta_f is not None:
616 self._delta_f = [float(i) for i in delta_f]
618 @property
619 def psd_default(self):
620 return self._psd_default
622 @psd_default.setter
623 def psd_default(self, psd_default):
624 self._psd_default = psd_default
625 if "stored:" in psd_default:
626 label = psd_default.split("stored:")[1]
627 self._psd_default = "{file}.psd['%s']" % (label)
628 return
629 try:
630 from pycbc import psd
631 psd_default = getattr(psd, psd_default)
632 except ImportError:
633 logger.warning(
634 "Unable to import 'pycbc'. Unable to generate a default PSD"
635 )
636 psd_default = None
637 except AttributeError:
638 logger.warning(
639 "'pycbc' does not have the '{}' psd available. Using '{}' as "
640 "the default PSD".format(psd_default, conf.psd)
641 )
642 psd_default = getattr(psd, conf.psd)
643 except ValueError as e:
644 logger.warning("Setting 'psd_default' to None because {}".format(e))
645 psd_default = None
646 self._psd_default = psd_default
    @property
    def pepredicates_probs(self):
        """Dictionary mapping each label to its source classification
        probabilities (or None on failure)."""
        return self._pepredicates_probs

    @pepredicates_probs.setter
    def pepredicates_probs(self, pepredicates_probs):
        from pesummary.gw.classification import PEPredicates

        classifications = {}
        for num, i in enumerate(list(self.samples.keys())):
            try:
                classifications[i] = PEPredicates(
                    self.samples[i]
                ).dual_classification()
            except Exception as e:
                # best-effort: a failed classification is stored as None
                logger.warning(
                    "Failed to generate source classification probabilities "
                    "because {}".format(e)
                )
                classifications[i] = None
        if self.mcmc_samples:
            # with mcmc chains, the first label holds the chain-averaged
            # probabilities
            if any(_probs is None for _probs in classifications.values()):
                classifications[self.labels[0]] = None
                logger.warning(
                    "Unable to average classification probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "classifications"
                )
            else:
                logger.debug(
                    "Averaging classification probability across mcmc samples"
                )
                # average each prior/key entry over all chains, rounded to
                # 3 decimal places; the key structure is taken from the
                # first chain's result
                classifications[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in classifications.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in
                    list(classifications.values())[0].items()
                }
        self._pepredicates_probs = classifications
    @property
    def pastro_probs(self):
        """Dictionary mapping each label to its em_bright probabilities
        (or None on failure)."""
        return self._pastro_probs

    @pastro_probs.setter
    def pastro_probs(self, pastro_probs):
        from pesummary.gw.classification import PAstro

        probabilities = {}
        for num, i in enumerate(list(self.samples.keys())):
            try:
                probabilities[i] = PAstro(self.samples[i]).dual_classification()
            except Exception as e:
                # best-effort: a failed estimate is stored as None
                logger.warning(
                    "Failed to generate em_bright probabilities because "
                    "{}".format(e)
                )
                probabilities[i] = None
        if self.mcmc_samples:
            # with mcmc chains, the first label holds the chain-averaged
            # probabilities
            if any(_probs is None for _probs in probabilities.values()):
                probabilities[self.labels[0]] = None
                logger.warning(
                    "Unable to average em_bright probabilities across "
                    "mcmc chains because one or more chains failed to estimate "
                    "probabilities"
                )
            else:
                logger.debug(
                    "Averaging em_bright probability across mcmc samples"
                )
                # average each prior/key entry over all chains, rounded to
                # 3 decimal places; mirrors pepredicates_probs above
                probabilities[self.labels[0]] = {
                    prior: {
                        key: np.round(np.mean(
                            [val[prior][key] for val in probabilities.values()]
                        ), 3) for key in _probs.keys()
                    } for prior, _probs in list(probabilities.values())[0].items()
                }
        self._pastro_probs = probabilities
729 @property
730 def preliminary_pages(self):
731 return self._preliminary_pages
733 @preliminary_pages.setter
734 def preliminary_pages(self, preliminary_pages):
735 required = conf.gw_reproducibility
736 self._preliminary_pages = {label: False for label in self.labels}
737 for num, label in enumerate(self.labels):
738 for attr in required:
739 _property = getattr(self, attr)
740 if isinstance(_property, dict):
741 if label not in _property.keys():
742 self._preliminary_pages[label] = True
743 elif not len(_property[label]):
744 self._preliminary_pages[label] = True
745 elif isinstance(_property, list):
746 if _property[num] is None:
747 self._preliminary_pages[label] = True
748 if any(value for value in self._preliminary_pages.values()):
749 _labels = [
750 label for label, value in self._preliminary_pages.items() if
751 value
752 ]
753 msg = (
754 "Unable to reproduce the {} analys{} because no {} data was "
755 "provided. 'Preliminary' watermarks will be added to the final "
756 "html pages".format(
757 ", ".join(_labels), "es" if len(_labels) > 1 else "is",
758 " or ".join(required)
759 )
760 )
761 logger.warning(msg)
763 @staticmethod
764 def _extract_IFO_data_from_file(file, cls, desc, IFO=None):
765 """Return IFO data stored in a file
767 Parameters
768 ----------
769 file: path
770 path to a file containing the IFO data
771 cls: obj
772 class you wish to use when loading the file. This class must have
773 a '.read' method
774 desc: str
775 description of the IFO data stored in the file
776 IFO: str, optional
777 the IFO which the data belongs to
778 """
779 general = (
780 "Failed to read in %s data because {}. The %s plot will not be "
781 "generated and the %s data will not be added to the metafile."
782 ) % (desc, desc, desc)
783 try:
784 return cls.read(file, IFO=IFO)
785 except FileNotFoundError:
786 logger.warning(
787 general.format("the file {} does not exist".format(file))
788 )
789 return {}
790 except ValueError as e:
791 logger.warning(general.format(e))
792 return {}
794 @staticmethod
795 def extract_psd_data_from_file(file, IFO=None):
796 """Return the data stored in a psd file
798 Parameters
799 ----------
800 file: path
801 path to a file containing the psd data
802 """
803 from pesummary.gw.file.psd import PSD
804 return _GWInput._extract_IFO_data_from_file(file, PSD, "PSD", IFO=IFO)
806 @staticmethod
807 def extract_calibration_data_from_file(file, **kwargs):
808 """Return the data stored in a calibration file
810 Parameters
811 ----------
812 file: path
813 path to a file containing the calibration data
814 """
815 from pesummary.gw.file.calibration import Calibration
816 return _GWInput._extract_IFO_data_from_file(
817 file, Calibration, "calibration", **kwargs
818 )
820 @staticmethod
821 def get_ifo_from_file_name(file):
822 """Return the IFO from the file name
824 Parameters
825 ----------
826 file: str
827 path to the file
828 """
829 file_name = file.split("/")[-1]
830 if any(j in file_name for j in ["H1", "_0", "IFO0"]):
831 ifo = "H1"
832 elif any(j in file_name for j in ["L1", "_1", "IFO1"]):
833 ifo = "L1"
834 elif any(j in file_name for j in ["V1", "_2", "IFO2"]):
835 ifo = "V1"
836 else:
837 ifo = file_name
838 return ifo
    def get_psd_or_calibration_data(self, input, executable):
        """Return a dictionary containing the psd or calibration data

        Three input shapes are accepted:
        - dict of IFO -> list of paths (one path per result file)
        - dict of IFO -> single path (shared across all result files)
        - list of paths (IFO inferred from each file name)

        Parameters
        ----------
        input: list/dict
            list/dict containing paths to calibration/psd files
            (note: the name shadows the 'input' builtin, kept for interface
            compatibility)
        executable: func
            executable that is used to extract the data from the calibration/psd
            files
        """
        data = {}
        if input == {} or input == []:
            return data
        if isinstance(input, dict):
            keys = list(input.keys())
        if isinstance(input, dict) and isinstance(input[keys[0]], list):
            # dict of lists: one file per result file per IFO
            if not all(len(input[i]) == len(self.labels) for i in list(keys)):
                raise InputError(
                    "Please ensure the number of calibration/psd files matches "
                    "the number of result files passed"
                )
            for idx in range(len(input[keys[0]])):
                data[self.labels[idx]] = {
                    i: executable(input[i][idx], IFO=i) for i in list(keys)
                }
        elif isinstance(input, dict):
            # dict of single paths: same file reused for every label
            for i in self.labels:
                data[i] = {
                    j: executable(input[j], IFO=j) for j in list(input.keys())
                }
        elif isinstance(input, list):
            # bare list of paths: infer the IFO from each file name
            for i in self.labels:
                data[i] = {
                    self.get_ifo_from_file_name(j): executable(
                        j, IFO=self.get_ifo_from_file_name(j)
                    ) for j in input
                }
        else:
            raise InputError(
                "Did not understand the psd/calibration input. Please use the "
                "following format 'H1:path/to/file'"
            )
        return data
885 def grab_priors_from_inputs(self, priors):
886 def read_func(data, **kwargs):
887 from pesummary.gw.file.read import read as GWRead
888 data = GWRead(data, **kwargs)
889 data.generate_all_posterior_samples()
890 return data
892 return super(_GWInput, self).grab_priors_from_inputs(
893 priors, read_func=read_func, read_kwargs=self.grab_data_kwargs
894 )
    def grab_key_data_from_result_files(self):
        """Grab the mean, median, maxL and standard deviation for all
        parameters for all each result file, and additionally a 90% highest
        posterior density interval for a set of boundary-constrained
        parameters.
        """
        from pesummary.utils.kde_list import KDEList
        from pesummary.gw.plots.plot import _return_bounds
        from pesummary.utils.credible_interval import (
            hpd_two_sided_credible_interval
        )
        from pesummary.utils.bounded_1d_kde import bounded_1d_kde
        key_data = super(_GWInput, self).grab_key_data_from_result_files()
        # parameters with hard physical boundaries, where a bounded KDE is
        # required for a sensible HPD interval
        bounded_parameters = ["mass_ratio", "a_1", "a_2", "lambda_tilde"]
        for param in bounded_parameters:
            xlow, xhigh = _return_bounds(param, [])
            _samples = {
                key: val[param] for key, val in self.samples.items()
                if param in val.keys()
            }
            # NOTE(review): the 'if len(_samples)' guard is loop-invariant;
            # possibly 'len(_)' (per-array) was intended — confirm
            _min = [np.min(_) for _ in _samples.values() if len(_samples)]
            _max = [np.max(_) for _ in _samples.values() if len(_samples)]
            if not len(_min):
                # no result file provides this parameter
                continue
            _min = np.min(_min)
            _max = np.max(_max)
            # common evaluation grid covering all analyses
            x = np.linspace(_min, _max, 1000)
            try:
                kdes = KDEList(
                    list(_samples.values()), kde=bounded_1d_kde,
                    kde_kwargs={"xlow": xlow, "xhigh": xhigh}
                )
            except Exception as e:
                # best-effort: skip the HPD interval for this parameter
                logger.warning(
                    "Unable to compute the HPD interval for {} because {}".format(
                        param, e
                    )
                )
                continue
            pdfs = kdes(x)
            for num, key in enumerate(_samples.keys()):
                [xlow, xhigh], _ = hpd_two_sided_credible_interval(
                    [], 90, x=x, pdf=pdfs[num]
                )
                key_data[key][param]["90% HPD"] = [xlow, xhigh]
                # mark every non-bounded parameter's HPD as undefined
                for _param in self.samples[key].keys():
                    if _param in bounded_parameters:
                        continue
                    key_data[key][_param]["90% HPD"] = float("nan")
        return key_data
class SamplesInput(_GWInput, pesummary.core.cli.inputs.SamplesInput):
    """Class to handle and store sample specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # the GW pipeline copies files later, so suppress the core copy step
        kwargs.update({"ignore_copy": True})
        super(SamplesInput, self).__init__(
            *args, gw=True, extra_options=[
                "evolve_spins_forwards",
                "evolve_spins_backwards",
                "NRSur_fits",
                "calculate_multipole_snr",
                "calculate_precessing_snr",
                "f_low",
                "f_ref",
                "f_final",
                "psd",
                "waveform_fits",
                "redshift_method",
                "cosmology",
                "no_conversion",
                "delta_f",
                "psd_default",
                "disable_remnant",
                "force_BBH_remnant_computation",
                "force_BH_spin_evolution"
            ], **kwargs
        )
        # when resuming from a checkpoint, all state below was already set
        if self._restarted_from_checkpoint:
            return
        if self.existing is not None:
            # pull GW-specific products out of the existing metafile so they
            # can be merged with the new analysis
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_approximant = self.existing_data["approximant"]
            self.existing_psd = self.existing_data["psd"]
            self.existing_calibration = self.existing_data["calibration"]
            self.existing_skymap = self.existing_data["skymap"]
        else:
            self.existing_approximant = None
            self.existing_psd = None
            self.existing_calibration = None
            self.existing_skymap = None
        # these assignments trigger the property setters defined on _GWInput;
        # their order follows the original implementation
        self.approximant = self.opts.approximant
        self.detectors = None
        self.skymap = None
        self.calibration = self.opts.calibration
        self.gwdata = self.opts.gwdata
        self.maxL_samples = []

    @property
    def maxL_samples(self):
        """Dictionary mapping each label to its maximum likelihood values."""
        return self._maxL_samples

    @maxL_samples.setter
    def maxL_samples(self, maxL_samples):
        # the supplied value is ignored; the maxL table is always rebuilt
        # from the result files
        key_data = self.grab_key_data_from_result_files()
        maxL_samples = {
            i: {
                j: key_data[i][j]["maxL"] for j in key_data[i].keys()
            } for i in key_data.keys()
        }
        for i in self.labels:
            maxL_samples[i]["approximant"] = self.approximant[i]
        self._maxL_samples = maxL_samples
class PlottingInput(SamplesInput, pesummary.core.cli.inputs.PlottingInput):
    """Class to handle and store plotting specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.nsamples_for_skymap = self.opts.nsamples_for_skymap
        self.sensitivity = self.opts.sensitivity
        self.no_ligo_skymap = self.opts.no_ligo_skymap
        # by default give the skymap every available process
        self.multi_threading_for_skymap = self.multi_process
        if not self.no_ligo_skymap and self.multi_process > 1:
            # split the available processes roughly in half between skymap
            # generation and the remaining plots
            total = self.multi_process
            self.multi_threading_for_plots = int(total / 2.)
            self.multi_threading_for_skymap = total - self.multi_threading_for_plots
            # the "es "/" " arguments provide both pluralisation and the
            # separating space in the message
            logger.info(
                "Assigning {} process{}to skymap generation and {} process{}to "
                "other plots".format(
                    self.multi_threading_for_skymap,
                    "es " if self.multi_threading_for_skymap > 1 else " ",
                    self.multi_threading_for_plots,
                    "es " if self.multi_threading_for_plots > 1 else " "
                )
            )
        # these assignments trigger the property setters defined on _GWInput
        self.preliminary_pages = None
        self.pepredicates_probs = []
        self.pastro_probs = []
class WebpageInput(SamplesInput, pesummary.core.cli.inputs.WebpageInput):
    """Class to handle and store webpage specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.gracedb_server = self.opts.gracedb_server
        self.gracedb_data = self.opts.gracedb_data
        self.gracedb = self.opts.gracedb
        self.public = self.opts.public
        # when PlottingInput did not run first, initialise the plotting
        # products via their property setters
        for attribute, default in (
            ("preliminary_pages", None),
            ("pepredicates_probs", []),
            ("pastro_probs", []),
        ):
            if not hasattr(self, attribute):
                setattr(self, attribute, default)
class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Class to handle and store webpage and plotting specific command line
    arguments
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """directories to create inside the web directory"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """files to copy into the web directory"""
        return super().default_files_to_copy
class MetaFileInput(SamplesInput, pesummary.core.cli.inputs.MetaFileInput):
    """Class to handle and store metafile specific command line arguments
    """
    @property
    def default_directories(self):
        """directories to create inside the web directory"""
        base = super(MetaFileInput, self).default_directories
        return base + ["psds", "calibration"]

    def copy_files(self):
        """Write the stored PSDs and calibration envelopes to the web
        directory, then copy the remaining files via the parent class.
        Objects that are not of the expected type are skipped with a warning
        rather than raising.
        """
        _error = "Failed to save the {} to file"
        for label in self.labels:
            # iterating directly over the dict is a no-op when it is empty
            for ifo, _psd in self.psd[label].items():
                if not isinstance(_psd, PSD):
                    logger.warning(_error.format("{} PSD".format(ifo)))
                    continue
                _psd.save_to_file(
                    os.path.join(
                        self.webdir, "psds",
                        "{}_{}_psd.dat".format(label, ifo)
                    )
                )
            _calibration = self.priors["calibration"]
            if label in _calibration.keys():
                for ifo, _envelope in _calibration[label].items():
                    if not isinstance(_envelope, Calibration):
                        logger.warning(
                            _error.format(
                                "{} calibration envelope".format(ifo)
                            )
                        )
                        continue
                    _envelope.save_to_file(
                        os.path.join(
                            self.webdir, "calibration",
                            "{}_{}_cal.txt".format(label, ifo)
                        )
                    )
        return super(MetaFileInput, self).copy_files()
class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Class to handle and store webpage, plotting and metafile specific command
    line arguments
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """directories to create inside the web directory"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """files to copy into the web directory"""
        return super().default_files_to_copy
@deprecation(
    "The GWInput class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class GWInput(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias for WebpagePlusPlottingPlusMetaFileInput; emits a
    deprecation warning on use
    """
    pass
1145class IMRCTInput(pesummary.core.cli.inputs._Input):
1146 """Class to handle the TGR specific command line arguments
1147 """
    @property
    def labels(self):
        """list: labels identifying the inspiral/postinspiral analyses
        (validated and possibly normalised by the setter below)
        """
        return self._labels
    @labels.setter
    def labels(self, labels):
        """Validate the provided labels and derive self.analysis_label.

        Accepted forms: exactly ['inspiral', 'postinspiral']; pairs of
        '{analysis}:inspiral'/'{analysis}:postinspiral' labels when comparing
        multiple analyses; or, for PESummary metafiles, a single
        '{label}:inspiral' and '{label}:postinspiral' pair identifying which
        analyses inside the metafile(s) to use.
        """
        self._labels = labels
        # the test always needs an inspiral/postinspiral pair per analysis
        if len(labels) % 2 != 0:
            raise ValueError(
                "The IMRCT test requires 2 results files for each analysis. "
            )
        elif len(labels) > 2:
            # comparing more than one analysis: every label must carry an
            # ':inspiral' or ':postinspiral' suffix
            cond = all(
                ":inspiral" in label or ":postinspiral" in label for label in
                labels
            )
            if not cond:
                # NOTE: the braces here are intentionally literal, they show
                # the expected label format
                raise ValueError(
                    "To compare 2 or more analyses, please provide labels as "
                    "'{}:inspiral' and '{}:postinspiral' where {} indicates "
                    "the analysis label"
                )
            else:
                # one analysis label per inspiral entry
                self.analysis_label = [
                    label.split(":inspiral")[0]
                    for label in labels
                    if ":inspiral" in label and ":postinspiral" not in label
                ]
                if len(self.analysis_label) != len(self.result_files) / 2:
                    raise ValueError(
                        "When comparing more than 2 analyses, labels must "
                        "be of the form '{}:inspiral' and '{}:postinspiral'."
                    )
                logger.info(
                    "Using the labels: {} to distinguish analyses".format(
                        ", ".join(self.analysis_label)
                    )
                )
        elif sorted(labels) != ["inspiral", "postinspiral"]:
            # two labels that are not the plain inspiral/postinspiral pair:
            # only valid when pointing at analyses inside PESummary metafiles
            if all(self.is_pesummary_metafile(ff) for ff in self.result_files):
                meta_file_labels = []
                for suffix in [":inspiral", ":postinspiral"]:
                    if any(suffix in label for label in labels):
                        ind = [
                            num for num, label in enumerate(labels) if
                            suffix in label
                        ]
                        # exactly one label may carry each suffix
                        if len(ind) > 1:
                            raise ValueError(
                                "Please provide a single {} label".format(
                                    suffix.split(":")[1]
                                )
                            )
                        # remember which analysis inside the metafile to use
                        meta_file_labels.append(
                            labels[ind[0]].split(suffix)[0]
                        )
                    else:
                        raise ValueError(
                            "Please provide labels as {inspiral_label}:inspiral "
                            "and {postinspiral_label}:postinspiral where "
                            "inspiral_label and postinspiral_label are the "
                            "PESummary labels for the inspiral and postinspiral "
                            "analyses respectively. "
                        )
                if len(self.result_files) == 1:
                    # both analyses come from the same metafile
                    logger.info(
                        "Using the {} samples for the inspiral analysis and {} "
                        "samples for the postinspiral analysis from the file "
                        "{}".format(
                            meta_file_labels[0], meta_file_labels[1],
                            self.result_files[0]
                        )
                    )
                elif len(self.result_files) == 2:
                    # one metafile per analysis
                    logger.info(
                        "Using the {} samples for the inspiral analysis from "
                        "the file {}. Using the {} samples for the "
                        "postinspiral analysis from the file {}".format(
                            meta_file_labels[0], self.result_files[0],
                            meta_file_labels[1], self.result_files[1]
                        )
                    )
                else:
                    raise ValueError(
                        "Currently, you can only provide at most 2 pesummary "
                        "metafiles. If one is provided, both the inspiral and "
                        "postinspiral are extracted from that single file. If "
                        "two are provided, the inspiral is extracted from one "
                        "file and the postinspiral is extracted from the other."
                    )
                # normalise to the canonical pair and keep the metafile
                # analysis labels for the samples setter
                self._labels = ["inspiral", "postinspiral"]
                self._meta_file_labels = meta_file_labels
                self.analysis_label = ["primary"]
            else:
                raise ValueError(
                    "The IMRCT test requires an inspiral and postinspiral result "
                    "file. Please indicate which file is the inspiral and which "
                    "is postinspiral by providing these exact labels to the "
                    "summarytgr executable"
                )
        else:
            # plain ['inspiral', 'postinspiral'] pair
            self.analysis_label = ["primary"]
1251 def _extract_stored_approximant(self, opened_file, label):
1252 """Extract the approximant used for a given analysis stored in a
1253 PESummary metafile
1255 Parameters
1256 ----------
1257 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1258 opened metafile that contains the analysis 'label'
1259 label: str
1260 analysis label which is stored in the PESummary metafile
1261 """
1262 if opened_file.approximant is not None:
1263 if label not in opened_file.labels:
1264 raise ValueError(
1265 "Invalid label {}. The list of available labels are {}".format(
1266 label, ", ".join(opened_file.labels)
1267 )
1268 )
1269 _index = opened_file.labels.index(label)
1270 return opened_file.approximant[_index]
1271 return
1273 def _extract_stored_remnant_fits(self, opened_file, label):
1274 """Extract the remnant fits used for a given analysis stored in a
1275 PESummary metafile
1277 Parameters
1278 ----------
1279 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1280 opened metafile that contains the analysis 'label'
1281 label: str
1282 analysis label which is stored in the PESummary metafile
1283 """
1284 fits = {}
1285 fit_strings = [
1286 "final_spin_NR_fits", "final_mass_NR_fits"
1287 ]
1288 if label not in opened_file.labels:
1289 raise ValueError(
1290 "Invalid label {}. The list of available labels are {}".format(
1291 label, ", ".join(opened_file.labels)
1292 )
1293 )
1294 _index = opened_file.labels.index(label)
1295 _meta_data = opened_file.extra_kwargs[_index]
1296 if "meta_data" in _meta_data.keys():
1297 for key in fit_strings:
1298 if key in _meta_data["meta_data"].keys():
1299 fits[key] = _meta_data["meta_data"][key]
1300 if len(fits):
1301 return fits
1302 return
1304 def _extract_stored_cutoff_frequency(self, opened_file, label):
1305 """Extract the cutoff frequencies used for a given analysis stored in a
1306 PESummary metafile
1308 Parameters
1309 ----------
1310 opened_file: pesummary.gw.file.formats.pesummary.PESummary
1311 opened metafile that contains the analysis 'label'
1312 label: str
1313 analysis label which is stored in the PESummary metafile
1314 """
1315 frequencies = {}
1316 if opened_file.config is not None:
1317 if label not in opened_file.labels:
1318 raise ValueError(
1319 "Invalid label {}. The list of available labels are {}".format(
1320 label, ", ".join(opened_file.labels)
1321 )
1322 )
1323 if opened_file.config[label] is not None:
1324 _config = opened_file.config[label]
1325 if "config" in _config.keys():
1326 if "maximum-frequency" in _config["config"].keys():
1327 frequencies["fhigh"] = _config["config"][
1328 "maximum-frequency"
1329 ]
1330 if "minimum-frequency" in _config["config"].keys():
1331 frequencies["flow"] = _config["config"][
1332 "minimum-frequency"
1333 ]
1334 elif "lalinference" in _config.keys():
1335 if "fhigh" in _config["lalinference"].keys():
1336 frequencies["fhigh"] = _config["lalinference"][
1337 "fhigh"
1338 ]
1339 if "flow" in _config["lalinference"].keys():
1340 frequencies["flow"] = _config["lalinference"][
1341 "flow"
1342 ]
1343 return frequencies
1344 return
    @property
    def samples(self):
        """MultiAnalysisSamplesDict: posterior samples for the inspiral and
        postinspiral analyses (built by the setter below)
        """
        return self._samples
    @samples.setter
    def samples(self, samples):
        """Read the result files and store the posterior samples, together
        with any approximant, cutoff frequency and remnant fit information
        that can be extracted from the files themselves.
        """
        from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
        # open every result file; 'samples' (the paths) comes via self.opts
        self._read_samples = {
            _label: GWRead(_path, disable_prior=True) for _label, _path in zip(
                self.labels, self.result_files
            )
        }
        _samples_dict = {}
        _approximant_dict = {}
        _cutoff_frequency_dict = {}
        _remnant_fits_dict = {}
        for label, _open in self._read_samples.items():
            if isinstance(_open.samples_dict, MultiAnalysisSamplesDict):
                # a PESummary metafile holding several analyses; we need the
                # metafile labels recorded by the labels setter to pick one
                if not len(self._meta_file_labels):
                    raise ValueError(
                        "Currently you can only pass a file containing a "
                        "single analysis or a valid PESummary metafile "
                        "containing multiple analyses"
                    )
                _labels = _open.labels
                if len(self._read_samples) == 1:
                    # single metafile provides both inspiral and postinspiral
                    _samples_dict = {
                        label: _open.samples_dict[meta_file_label] for
                        label, meta_file_label in zip(
                            self.labels, self._meta_file_labels
                        )
                    }
                    # pull stored metadata for each analysis in the metafile
                    for label, meta_file_label in zip(self.labels, self._meta_file_labels):
                        _stored_approx = self._extract_stored_approximant(
                            _open, meta_file_label
                        )
                        _stored_frequencies = self._extract_stored_cutoff_frequency(
                            _open, meta_file_label
                        )
                        _stored_remnant_fits = self._extract_stored_remnant_fits(
                            _open, meta_file_label
                        )
                        if _stored_approx is not None:
                            _approximant_dict[label] = _stored_approx
                        if _stored_remnant_fits is not None:
                            _remnant_fits_dict[label] = _stored_remnant_fits
                        if _stored_frequencies is not None:
                            # inspiral uses the maximum frequency, the
                            # postinspiral uses the minimum frequency
                            if label == "inspiral":
                                if "fhigh" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "fhigh"
                                    ]
                            if label == "postinspiral":
                                if "flow" in _stored_frequencies.keys():
                                    _cutoff_frequency_dict[label] = _stored_frequencies[
                                        "flow"
                                    ]
                    # the single metafile covered everything; stop reading
                    break
                else:
                    # one metafile per analysis: pick the matching analysis
                    ind = self.labels.index(label)
                    _samples_dict[label] = _open.samples_dict[
                        self._meta_file_labels[ind]
                    ]
                    _stored_approx = self._extract_stored_approximant(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_frequencies = self._extract_stored_cutoff_frequency(
                        _open, self._meta_file_labels[ind]
                    )
                    _stored_remnant_fits = self._extract_stored_remnant_fits(
                        _open, self._meta_file_labels[ind]
                    )
                    if _stored_approx is not None:
                        _approximant_dict[label] = _stored_approx
                    if _stored_remnant_fits is not None:
                        _remnant_fits_dict[label] = _stored_remnant_fits
                    if _stored_frequencies is not None:
                        if label == "inspiral":
                            if "fhigh" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "fhigh"
                                ]
                        if label == "postinspiral":
                            if "flow" in _stored_frequencies.keys():
                                _cutoff_frequency_dict[label] = _stored_frequencies[
                                    "flow"
                                ]
            else:
                # ordinary single-analysis result file
                _samples_dict[label] = _open.samples_dict
                extra_kwargs = _open.extra_kwargs
                # bilby files record the approximant and frequency range in
                # the stored likelihood kwargs; best effort extraction
                if "pe_algorithm" in extra_kwargs["sampler"].keys():
                    if extra_kwargs["sampler"]["pe_algorithm"] == "bilby":
                        try:
                            subkwargs = extra_kwargs["other"]["likelihood"][
                                "waveform_arguments"
                            ]
                            _approximant_dict[label] = (
                                subkwargs["waveform_approximant"]
                            )
                            if "inspiral" in label and "postinspiral" not in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["maximum_frequency"]
                                )
                            elif "postinspiral" in label:
                                _cutoff_frequency_dict[label] = (
                                    subkwargs["minimum_frequency"]
                                )
                        except KeyError:
                            pass
        self._samples = MultiAnalysisSamplesDict(_samples_dict)
        # only store the metadata dicts when something was actually found;
        # the meta_data setter probes for these attributes
        if len(_approximant_dict):
            self._approximant_dict = _approximant_dict
        if len(_cutoff_frequency_dict):
            self._cutoff_frequency_dict = _cutoff_frequency_dict
        if len(_remnant_fits_dict):
            self._remnant_fits_dict = _remnant_fits_dict
    @property
    def imrct_kwargs(self):
        """dict: kwargs to pass to the IMRCT test (always includes 'N_bins')
        """
        return self._imrct_kwargs
1467 @imrct_kwargs.setter
1468 def imrct_kwargs(self, imrct_kwargs):
1469 test_kwargs = dict(N_bins=101)
1470 try:
1471 test_kwargs.update(imrct_kwargs)
1472 except AttributeError:
1473 test_kwargs = test_kwargs
1475 for key, value in test_kwargs.items():
1476 try:
1477 test_kwargs[key] = ast.literal_eval(value)
1478 except ValueError:
1479 pass
1480 self._imrct_kwargs = test_kwargs
    @property
    def meta_data(self):
        """dict: per-analysis metadata (cutoff frequencies, approximants and
        remnant fits) describing the inspiral and postinspiral runs
        """
        return self._meta_data
1486 @meta_data.setter
1487 def meta_data(self, meta_data):
1488 self._meta_data = {}
1489 for num, _inspiral in enumerate(self.inspiral_keys):
1490 frequency_dict = dict()
1491 approximant_dict = dict()
1492 remnant_dict = dict()
1493 zipped = zip(
1494 [self.cutoff_frequency, self.approximant, None],
1495 [frequency_dict, approximant_dict, remnant_dict],
1496 ["cutoff_frequency", "approximant", "remnant_fits"]
1497 )
1498 _inspiral_string = self.inspiral_keys[num]
1499 _postinspiral_string = self.postinspiral_keys[num]
1500 for _list, _dict, name in zipped:
1501 if _list is not None and len(_list) == len(self.labels):
1502 inspiral_ind = self.labels.index(_inspiral_string)
1503 postinspiral_ind = self.labels.index(_postinspiral_string)
1504 _dict["inspiral"] = _list[inspiral_ind]
1505 _dict["postinspiral"] = _list[postinspiral_ind]
1506 elif _list is not None:
1507 raise ValueError(
1508 "Please provide a 'cutoff_frequency' and 'approximant' "
1509 "for each file"
1510 )
1511 else:
1512 try:
1513 if name == "cutoff_frequency":
1514 cond = (
1515 "inspiral" in self._cutoff_frequency_dict.keys()
1516 and "postinspiral" not in
1517 self._cutoff_frequency_dict.keys()
1518 )
1519 if cond:
1520 _dict["inspiral"] = self._cutoff_frequency_dict[
1521 "inspiral"
1522 ]
1523 elif "postinspiral" in self._cutoff_frequency_dict.keys():
1524 _dict["postinspiral"] = self._cutoff_frequency_dict[
1525 "postinspiral"
1526 ]
1527 elif name == "approximant":
1528 cond = (
1529 "inspiral" in self._approximant_dict.keys()
1530 and "postinspiral" not in
1531 self._approximant_dict.keys()
1532 )
1533 if cond:
1534 _dict["inspiral"] = self._approximant_dict[
1535 "inspiral"
1536 ]
1537 elif "postinspiral" in self._approximant_dict.keys():
1538 _dict["postinspiral"] = self._approximant_dict[
1539 "postinspiral"
1540 ]
1541 elif name == "remnant_fits":
1542 cond = (
1543 "inspiral" in self._remnant_fits_dict.keys()
1544 and "postinspiral" not in
1545 self._remnant_fits_dict.keys()
1546 )
1547 if cond:
1548 _dict["inspiral"] = self._remnant_fits_dict[
1549 "inspiral"
1550 ]
1551 elif "postinspiral" in self._remnant_fits_dict.keys():
1552 _dict["postinspiral"] = self._remnant_fits_dict[
1553 "postinspiral"
1554 ]
1555 except (AttributeError, KeyError, TypeError):
1556 _dict["inspiral"] = None
1557 _dict["postinspiral"] = None
1559 self._meta_data[self.analysis_label[num]] = {
1560 "inspiral maximum frequency (Hz)": frequency_dict["inspiral"],
1561 "postinspiral minimum frequency (Hz)": frequency_dict["postinspiral"],
1562 "inspiral approximant": approximant_dict["inspiral"],
1563 "postinspiral approximant": approximant_dict["postinspiral"],
1564 "inspiral remnant fits": remnant_dict["inspiral"],
1565 "postinspiral remnant fits": remnant_dict["postinspiral"]
1566 }
    def __init__(self, opts):
        """Initialise the IMRCT input handler from parsed command line options.

        Parameters
        ----------
        opts: argparse.Namespace
            parsed command line arguments for the summarytgr executable
        """
        self.opts = opts
        self.existing = None
        self.webdir = self.opts.webdir
        self.user = None
        self.baseurl = None
        self.result_files = self.opts.samples
        # the labels setter validates the labels and sets analysis_label
        self.labels = self.opts.labels
        # the samples setter reads the files and extracts stored metadata
        self.samples = self.opts.samples
        self.inspiral_keys = [
            key for key in self.samples.keys() if "inspiral" in key
            and "postinspiral" not in key
        ]
        self.postinspiral_keys = [
            key.replace("inspiral", "postinspiral") for key in self.inspiral_keys
        ]
        try:
            self.imrct_kwargs = self.opts.imrct_kwargs
        except AttributeError:
            # option not provided on the command line; setter applies defaults
            self.imrct_kwargs = {}
        # per-file options must be given once per result file (when given)
        for _arg in ["cutoff_frequency", "approximant", "links_to_pe_pages", "f_low"]:
            _attr = getattr(self.opts, _arg)
            if _attr is not None and len(_attr) and len(_attr) != len(self.labels):
                raise ValueError("Please provide a {} for each file".format(_arg))
            setattr(self, _arg, _attr)
        # the meta_data setter rebuilds from the attributes set above
        self.meta_data = None
        self.default_directories = ["samples", "plots", "js", "html", "css"]
        self.publication = False
        self.make_directories()