Coverage for pesummary/gw/plots/main.py: 67.5%
437 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import os
7from pesummary.core.plots.main import _PlotGeneration as _BasePlotGeneration
8from pesummary.core.plots.latex_labels import latex_labels
9from pesummary.core.plots import interactive
10from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
11from pesummary.gw.plots.latex_labels import GWlatex_labels
12from pesummary.utils.utils import logger, resample_posterior_distribution
13from pesummary.utils.decorators import no_latex_plot
14from pesummary.gw.plots import publication
15from pesummary.gw.plots import plot as gw
17import multiprocessing as mp
18import numpy as np
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Merge the gravitational-wave specific LaTeX labels into the shared
# ``latex_labels`` mapping imported from pesummary.core.plots.latex_labels so
# that lookups such as ``latex_labels[p]`` in this module also resolve GW
# parameters. NOTE: this mutates the imported mapping at import time.
latex_labels.update(GWlatex_labels)
24class _PlotGeneration(_BasePlotGeneration):
def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_file_kwargs=None, existing_samples=None,
    existing_metafile=None, same_parameters=None, injection_data=None,
    result_files=None, file_kwargs=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, no_ligo_skymap=False,
    nsamples_for_skymap=None, detectors=None, maxL_samples=None,
    gwdata=None, calibration=None, psd=None,
    multi_threading_for_skymap=None, approximant=None,
    pepredicates_probs=None, include_prior=False, publication=False,
    existing_approximant=None, existing_psd=None, existing_calibration=None,
    existing_weights=None, weights=None, disable_comparison=False,
    linestyles=None, disable_interactive=False, disable_corner=False,
    publication_kwargs=None, multi_process=1, mcmc_samples=False,
    skymap=None, existing_skymap=None, corner_params=None,
    preliminary_pages=False, expert_plots=True, checkpoint=False,
    key_data=None
):
    """Initialise the gravitational-wave plot generator

    The generic arguments are forwarded to the core ``_PlotGeneration``
    base class. The GW-specific arguments are stored as attributes, and the
    GW plot types are registered in ``self.plot_type_dictionary``. The
    comparison plot types are registered only when ``self.make_comparison``
    is set.

    Parameters
    ----------
    priors: dict, optional
        prior samples. ``None`` (the default) is treated as an empty dict.
    publication_kwargs: dict, optional
        extra options for the publication plots, e.g. ``"gridsize"``.
        ``None`` (the default) is treated as an empty dict.
    preliminary_pages: bool or dict, optional
        either a dict mapping each label to a bool, or a single bool that
        is applied to every label in ``self.labels``
    skymap: dict, optional
        precomputed skymap per label. When ``None``, every label maps to
        ``None``.
    existing_skymap: optional
        skymap data belonging to an existing metafile. It is stored
        unchanged as ``self.existing_skymap``.
    """
    # Fresh containers per call instead of shared mutable defaults
    if priors is None:
        priors = {}
    if publication_kwargs is None:
        publication_kwargs = {}
    super(_PlotGeneration, self).__init__(
        savedir=savedir, webdir=webdir, labels=labels,
        samples=samples, kde_plot=kde_plot, existing_labels=existing_labels,
        existing_injection_data=existing_injection_data,
        existing_samples=existing_samples,
        existing_weights=existing_weights,
        same_parameters=same_parameters,
        injection_data=injection_data, mcmc_samples=mcmc_samples,
        colors=colors, custom_plotting=custom_plotting,
        add_to_existing=add_to_existing, priors=priors,
        include_prior=include_prior, weights=weights,
        disable_comparison=disable_comparison, linestyles=linestyles,
        disable_interactive=disable_interactive, disable_corner=disable_corner,
        multi_process=multi_process, corner_params=corner_params,
        expert_plots=expert_plots, checkpoint=checkpoint, key_data=key_data
    )
    # A single bool applies to every label; a dict is used as given
    if isinstance(preliminary_pages, dict):
        self.preliminary_pages = preliminary_pages
    else:
        self.preliminary_pages = {
            label: bool(preliminary_pages) for label in self.labels
        }
    # Comparison pages are preliminary if any individual page is
    self.preliminary_comparison_pages = any(self.preliminary_pages.values())
    self.package = "gw"
    self.file_kwargs = file_kwargs
    self.existing_file_kwargs = existing_file_kwargs
    self.no_ligo_skymap = no_ligo_skymap
    self.nsamples_for_skymap = nsamples_for_skymap
    self.detectors = detectors
    self.maxL_samples = maxL_samples
    self.gwdata = gwdata
    if skymap is None:
        skymap = {label: None for label in self.labels}
    self.skymap = skymap
    # Previously this stored ``skymap``, silently discarding the argument
    self.existing_skymap = existing_skymap
    self.calibration = calibration
    self.existing_calibration = existing_calibration
    self.psd = psd
    self.existing_psd = existing_psd
    self.multi_threading_for_skymap = multi_threading_for_skymap
    self.approximant = approximant
    self.existing_approximant = existing_approximant
    self.pepredicates_probs = pepredicates_probs
    self.publication = publication
    self.publication_kwargs = publication_kwargs
    # label -> PID of the ligo.skymap subprocess launched by skymap_plot
    self._ligo_skymap_PID = {}

    self.plot_type_dictionary.update({
        "psd": self.psd_plot,
        "calibration": self.calibration_plot,
        "twod_histogram": self.twod_histogram_plot,
        "skymap": self.skymap_plot,
        "waveform_fd": self.waveform_fd_plot,
        "waveform_td": self.waveform_td_plot,
        "data": self.gwdata_plots,
        "violin": self.violin_plot,
        "spin_disk": self.spin_dist_plot,
        "pepredicates": self.pepredicates_plot
    })
    if self.make_comparison:
        self.plot_type_dictionary.update({
            "skymap_comparison": self.skymap_comparison_plot,
            "waveform_comparison_fd": self.waveform_comparison_fd_plot,
            "waveform_comparison_td": self.waveform_comparison_td_plot,
            "2d_comparison_contour": self.twod_comparison_contour_plot,
        })
@property
def ligo_skymap_PID(self):
    """Mapping from result-file label to the PID of the ``ligo.skymap``
    subprocess that ``skymap_plot`` launched for that label."""
    pid_lookup = self._ligo_skymap_PID
    return pid_lookup
def generate_plots(self):
    """Generate every plot for every result file

    The calibration plot is attempted when calibration data was supplied or
    a ``"calibration"`` prior exists. The PSD plot is attempted when PSD data
    was supplied. Everything else is delegated to the base class.
    """
    calibration_requested = bool(self.calibration) or (
        "calibration" in list(self.priors.keys())
    )
    if calibration_requested:
        self.try_to_make_a_plot("calibration")
    if self.psd:
        self.try_to_make_a_plot("psd")
    super(_PlotGeneration, self).generate_plots()
def _generate_plots(self, label):
    """Generate every single-result-file plot for ``label``

    The base-class plots run first, followed by the GW-specific ones. The
    PE-predicates plot needs probabilities for ``label``, and the data
    plots need strain data.
    """
    super(_PlotGeneration, self)._generate_plots(label)
    always_attempted = (
        "twod_histogram", "skymap", "waveform_td", "waveform_fd"
    )
    for plot_type in always_attempted:
        self.try_to_make_a_plot(plot_type, label=label)
    if self.pepredicates_probs[label] is not None:
        self.try_to_make_a_plot("pepredicates", label=label)
    if self.gwdata:
        self.try_to_make_a_plot("data", label=label)
def _generate_comparison_plots(self):
    """Generate every plot that compares the result files

    The skymap and waveform comparisons are always attempted. The
    publication-style comparisons (2d contours, violin and spin-disk plots)
    are attempted only when ``self.publication`` is set.
    """
    super(_PlotGeneration, self)._generate_comparison_plots()
    for plot_type in (
        "skymap_comparison", "waveform_comparison_td", "waveform_comparison_fd"
    ):
        self.try_to_make_a_plot(plot_type)
    if not self.publication:
        return
    for plot_type in ("2d_comparison_contour", "violin", "spin_disk"):
        self.try_to_make_a_plot(plot_type)
@staticmethod
def _corner_plot(
    savedir, label, samples, latex_labels, webdir, params, preliminary=False,
    checkpoint=False
):
    """Generate the corner plots for a given set of samples

    Three figures are written to ``<savedir>/corner``:
    ``<label>_all_density_plots.png``, ``<label>_sourceframe.png`` and
    ``<label>_extrinsic.png``. Whenever the all-density figure is
    (re)generated, the parameter list and data returned by
    ``gw._make_corner_plot`` are also inserted into
    ``<webdir>/js/combine_corner.js``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary of samples for a given result file
    latex_labels: dict
        dictionary of latex labels
    webdir: str
        directory where the javascript is written
    params: list
        parameters passed to ``gw._make_corner_plot`` as
        ``corner_parameters``
    preliminary: Bool, optional
        intended to add a preliminary watermark to the plot.
        NOTE(review): accepted but not currently applied; the figures are
        written with ``fig.savefig`` directly.
    checkpoint: Bool, optional
        if True, skip any figure whose output file already exists

    Raises
    ------
    IndexError
        if ``combine_corner.js`` has no ``var list = {}`` or
        ``var data = {}`` line
    """
    import warnings

    def _skip(filename):
        # Reuse an existing figure only when checkpointing is enabled
        return os.path.isfile(filename) and checkpoint

    def _insert_after(lines, marker, text):
        # Insert ``text`` after the first line containing ``marker``
        ind = [
            linenumber for linenumber, line in enumerate(lines)
            if marker in line
        ][0]
        lines.insert(ind + 1, text)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        filename = os.path.join(
            savedir, "corner", "{}_all_density_plots.png".format(label)
        )
        if not _skip(filename):
            fig, params, data = gw._make_corner_plot(
                samples, latex_labels, corner_parameters=params
            )
            fig.savefig(filename)
            fig.close()
            js_path = os.path.join(webdir, "js", "combine_corner.js")
            # Read once and write once, through context managers, so the
            # file handles are always closed
            with open(js_path) as js_file:
                combine_corner = js_file.readlines()
            params = [str(i) for i in params]
            _insert_after(
                combine_corner, "var list = {}",
                " list['{}'] = {};\n".format(label, params)
            )
            _insert_after(
                combine_corner, "var data = {}",
                " data['{}'] = {};\n".format(label, data)
            )
            with open(js_path, "w") as js_file:
                js_file.writelines(combine_corner)

        filename = os.path.join(
            savedir, "corner", "{}_sourceframe.png".format(label)
        )
        if not _skip(filename):
            fig = gw._make_source_corner_plot(samples, latex_labels)
            fig.savefig(filename)
            fig.close()
        filename = os.path.join(
            savedir, "corner", "{}_extrinsic.png".format(label)
        )
        if not _skip(filename):
            fig = gw._make_extrinsic_corner_plot(samples, latex_labels)
            fig.savefig(filename)
            fig.close()
def twod_histogram_plot(self, label):
    """Generate the 2d triangle plots for a given result file

    A plot is queued for every parameter pair in ``pesummary.conf.gw_2d_plots``
    whose parameters are all present in the samples for ``label``. The
    queued jobs are run with ``self.pool.starmap`` through
    ``self._try_to_make_a_plot``.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary import conf

    error_message = (
        "Failed to generate %s-%s triangle plot because {}"
    )
    jobs = []
    for pair in conf.gw_2d_plots:
        if not all(p in self.samples[label] for p in pair):
            continue
        job_arguments = [
            self.savedir, label, pair,
            [self.samples[label][p] for p in pair],
            [latex_labels[p] for p in pair],
            [self.injection_data[label][p] for p in pair],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((
            job_arguments, self._triangle_plot,
            error_message % (pair[0], pair[1])
        ))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _triangle_plot(
    savedir, label, params, samples, latex_labels, injection, preliminary=False,
    checkpoint=False
):
    """Generate a 2d triangle plot for a pair of parameters

    The figure is written to ``<savedir>/<label>_2d_posterior_<x>_<y>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    params: list
        names of the two parameters (x then y)
    samples: list
        samples for each of the two parameters
    latex_labels: list
        latex labels for the two parameters (x then y)
    injection: list
        injected values for the two parameters. If any value is ``None`` or
        NaN, no truth values are drawn. The list is not modified.
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    """
    import math

    filename = os.path.join(
        savedir, "{}_2d_posterior_{}_{}.png".format(
            label, params[0], params[1]
        )
    )
    if os.path.isfile(filename) and checkpoint:
        return
    # Import lazily so that checkpointed runs skip the plotting machinery
    from pesummary.core.plots.publication import triangle_plot

    # Draw truth values only when every injected value is known. The
    # caller's list is left untouched.
    if any(value is None or math.isnan(value) for value in injection):
        truth = None
    else:
        truth = injection
    fig, _, _, _ = triangle_plot(
        *samples, kde=False, parameters=params, xlabel=latex_labels[0],
        ylabel=latex_labels[1], plot_datapoints=True, plot_density=False,
        levels=[1e-8], fill=False, grid=True, linewidths=[1.75],
        percentiles=[5, 95], percentile_plot=[label], labels=[label],
        truth=truth
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def skymap_plot(self, label):
    """Generate the skymap plot(s) for a given result file

    A default skymap is always produced with ``_skymap_plot``. Further
    plots are made only when ``ligo.skymap`` is importable and
    ``self.no_ligo_skymap`` is False:

    - if no precomputed skymap exists for ``label``, ``_ligo_skymap_plot``
      is launched in a ``multiprocessing.Process`` and its PID is recorded
      in ``self._ligo_skymap_PID[label]``;
    - otherwise the precomputed ``self.skymap[label]`` array is plotted
      with ``_ligo_skymap_array_plot``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    try:
        import ligo.skymap  # noqa: F401
        ligo_skymap_available = True
    except ImportError:
        ligo_skymap_available = False

    samples = self.samples[label]
    if self.mcmc_samples:
        samples = samples.combine
    injected = self.injection_data[label]
    _injection = [injected["ra"], injected["dec"]]
    self._skymap_plot(
        self.savedir, samples["ra"], samples["dec"], label,
        self.weights[label], _injection,
        preliminary=self.preliminary_pages[label]
    )

    if not ligo_skymap_available or self.no_ligo_skymap:
        return
    if self.skymap[label] is not None:
        self._ligo_skymap_array_plot(
            self.savedir, self.skymap[label], label,
            self.preliminary_pages[label]
        )
        return

    from pesummary.utils.utils import RedirectLogger

    logger.info("Launching subprocess to generate skymap plot with "
                "ligo.skymap")
    try:
        merger_time = samples["geocent_time"]
    except KeyError:
        logger.warning(
            "Unable to find 'geocent_time' in the posterior table for {}. "
            "The ligo.skymap fits file will therefore not store the "
            "DATE_OBS field in the header".format(label)
        )
        merger_time = None
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        worker = mp.Process(
            target=self._ligo_skymap_plot,
            args=[
                self.savedir, samples["ra"], samples["dec"],
                samples["luminosity_distance"], merger_time,
                label, self.nsamples_for_skymap, self.webdir,
                self.multi_threading_for_skymap, _injection,
                self.preliminary_pages[label]
            ]
        )
        worker.start()
        self._ligo_skymap_PID[label] = worker.pid
@staticmethod
@no_latex_plot
def _skymap_plot(
    savedir, ra, dec, label, weights, injection=None, preliminary=False
):
    """Generate the default skymap plot for a set of ra/dec samples

    The figure is saved through ``_PlotGeneration.save`` as
    ``<savedir>/<label>_skymap``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        samples for right ascension
    dec: pesummary.utils.utils.Array
        samples for declination
    label: str
        the label corresponding to the results file
    weights: list
        weights for the samples
    injection: list, optional
        injected ra and dec; ignored if either value is NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    injection_is_nan = injection is not None and any(
        math.isnan(value) for value in injection
    )
    figure = gw._default_skymap_plot(
        ra, dec, weights, injection=None if injection_is_nan else injection
    )
    _PlotGeneration.save(
        figure, os.path.join(savedir, "{}_skymap".format(label)),
        preliminary=preliminary
    )
@staticmethod
@no_latex_plot
def _ligo_skymap_plot(savedir, ra, dec, dist, time, label, nsamples_for_skymap,
                      webdir, multi_threading_for_skymap, injection,
                      preliminary=False):
    """Generate a skymap plot for a given set of samples using the
    ligo.skymap package

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        samples for right ascension
    dec: pesummary.utils.utils.Array
        samples for declination
    dist: pesummary.utils.utils.Array
        samples for luminosity distance
    time: pesummary.utils.utils.Array
        samples for the geocentric time of merger, or None. They are passed
        through unchanged and are not resampled with ra, dec and dist.
    label: str
        the label corresponding to the results file
    nsamples_for_skymap: int
        if not None, ra, dec and dist are resampled to this many samples
        before the skymap is generated
    webdir: str
        the fits file directory passed on is ``<webdir>/samples``
    multi_threading_for_skymap: int
        forwarded to ``gw._ligo_skymap_plot`` as ``nprocess``
    injection: list
        injected ra and dec; dropped if either value is NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    downsampled = nsamples_for_skymap is not None
    if downsampled:
        ra, dec, dist = resample_posterior_distribution(
            [ra, dec, dist], nsamples_for_skymap
        )
    if injection is not None:
        if any(math.isnan(value) for value in injection):
            injection = None
    figure = gw._ligo_skymap_plot(
        ra, dec, dist=dist, savedir=os.path.join(webdir, "samples"),
        nprocess=multi_threading_for_skymap, downsampled=downsampled,
        label=label, time=time, injection=injection
    )
    _PlotGeneration.save(
        figure, os.path.join(savedir, "{}_skymap".format(label)),
        preliminary=preliminary
    )
@staticmethod
@no_latex_plot
def _ligo_skymap_array_plot(savedir, skymap, label, preliminary=False):
    """Plot a skymap from a probability array already produced by
    ``ligo.skymap``

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    skymap: np.ndarray
        array of skymap probabilities
    label: str
        the label corresponding to the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    figure = gw._ligo_skymap_plot_from_array(skymap)
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(figure, destination, preliminary=preliminary)
def waveform_fd_plot(self, label):
    """Generate a frequency domain waveform plot for a given result file

    Nothing is plotted when the approximant for ``label`` is an empty dict.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    if self.approximant[label] != {}:
        self._waveform_fd_plot(
            self.savedir, self.detectors[label], self.maxL_samples[label],
            label, preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint, **self.file_kwargs[label]["meta_data"]
        )
@staticmethod
def _waveform_fd_plot(
    savedir, detectors, maxL_samples, label, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a frequency domain waveform plot for a given detector
    network and set of samples

    The figure is written to ``<savedir>/<label>_waveform.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    detectors: str or None
        underscore-separated detector names, e.g. ``"H1_L1"`` (the string
        is split on ``"_"``). ``None`` means ``["H1", "L1"]``.
    maxL_samples: dict
        dictionary of maximum likelihood values
    label: str
        the label corresponding to the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    **kwargs: dict, optional
        file meta data; ``f_low`` (default 20.0), ``f_final`` (default
        1024.) and ``f_ref`` (default 20.) are read
    """
    filename = os.path.join(savedir, "{}_waveform.png".format(label))
    if os.path.isfile(filename) and checkpoint:
        return
    detector_list = (
        ["H1", "L1"] if detectors is None else detectors.split("_")
    )
    fig = gw._waveform_plot(
        detector_list, maxL_samples, f_min=kwargs.get("f_low", 20.0),
        f_max=kwargs.get("f_final", 1024.),
        f_ref=kwargs.get("f_ref", 20.)
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def waveform_td_plot(self, label):
    """Generate a time domain waveform plot for a given result file

    Nothing is plotted when the approximant for ``label`` is an empty dict.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    if self.approximant[label] != {}:
        self._waveform_td_plot(
            self.savedir, self.detectors[label], self.maxL_samples[label],
            label, preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint, **self.file_kwargs[label]["meta_data"]
        )
@staticmethod
def _waveform_td_plot(
    savedir, detectors, maxL_samples, label, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a time domain waveform plot for a given detector network
    and set of samples

    The figure is written to ``<savedir>/<label>_waveform_time_domain.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    detectors: str or None
        underscore-separated detector names, e.g. ``"H1_L1"`` (the string
        is split on ``"_"``). ``None`` means ``["H1", "L1"]``.
    maxL_samples: dict
        dictionary of maximum likelihood values
    label: str
        the label corresponding to the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    **kwargs: dict, optional
        file meta data; ``f_low`` (default 20.0), ``f_final`` (default
        1024.) and ``f_ref`` (default 20.) are read
    """
    filename = os.path.join(
        savedir, "{}_waveform_time_domain.png".format(label)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    detector_list = (
        ["H1", "L1"] if detectors is None else detectors.split("_")
    )
    fig = gw._time_domain_waveform(
        detector_list, maxL_samples, f_min=kwargs.get("f_low", 20.0),
        f_max=kwargs.get("f_final", 1024.),
        f_ref=kwargs.get("f_ref", 20.)
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def gwdata_plots(self, label):
    """Generate all plots associated with the gwdata

    A strain plot, a spectrogram and an omegascan are attempted, in that
    order. Each is handed to ``self._try_to_make_a_plot`` with its own error
    message. The omegascan is centred on the GPS time and window returned
    by ``determine_gps_time_and_window``.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary.utils.utils import determine_gps_time_and_window

    base_error = "Failed to generate a %s because {}"
    gps_time, window = determine_gps_time_and_window(
        self.maxL_samples, self.labels
    )
    plot_jobs = (
        (self.strain_plot, [label], "strain_plot"),
        (self.spectrogram_plot, [], "spectrogram plot"),
        (self.omegascan_plot, [gps_time, window], "omegascan plot"),
    )
    for plot_function, plot_arguments, plot_name in plot_jobs:
        self._try_to_make_a_plot(
            plot_arguments, plot_function, base_error % (plot_name)
        )
def strain_plot(self, label):
    """Generate a plot comparing the data with the maxL waveform for a
    given result file

    The plot is produced by ``_strain_plot`` in a separate
    ``multiprocessing.Process``. ``self.checkpoint`` is forwarded so that an
    existing strain plot is reused when checkpointing.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    logger.info("Launching subprocess to generate strain plot")
    process = mp.Process(
        target=self._strain_plot,
        args=[
            self.savedir, self.gwdata, self.maxL_samples[label], label,
            self.checkpoint
        ]
    )
    process.start()
@staticmethod
def _strain_plot(savedir, gwdata, maxL_samples, label, checkpoint=False):
    """Generate a strain plot for a given set of samples

    The figure is written to ``<savedir>/<label>_strain.png``.

    Parameters
    ----------
    savedir: str
        the directory to save the plot
    gwdata: dict
        dictionary of strain data for each detector
    maxL_samples: dict
        dictionary of maximum likelihood values
    label: str
        the label corresponding to the results file
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    """
    target = os.path.join(savedir, "{}_strain.png".format(label))
    already_done = checkpoint and os.path.isfile(target)
    if not already_done:
        _PlotGeneration.save(gw._strain_plot(gwdata, maxL_samples), target)
def spectrogram_plot(self):
    """Generate a spectrogram plot for every detector in ``self.gwdata``,
    saved under ``self.savedir``."""
    self._spectrogram_plot(self.savedir, self.gwdata)
@staticmethod
def _spectrogram_plot(savedir, strain):
    """Generate a spectrogram plot for every detector

    One figure per detector is saved through ``_PlotGeneration.save`` as
    ``<savedir>/spectrogram_<detector>``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    """
    from pesummary.gw.plots import detchar

    for detector, figure in detchar.spectrogram(strain).items():
        destination = os.path.join(savedir, "spectrogram_{}".format(detector))
        _PlotGeneration.save(figure, destination)
def omegascan_plot(self, gps_time, window):
    """Generate an omegascan plot for every detector in ``self.gwdata``

    Parameters
    ----------
    gps_time: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    self._omegascan_plot(self.savedir, self.gwdata, gps_time, window)
@staticmethod
def _omegascan_plot(savedir, strain, gps, window):
    """Generate an omegascan plot for every detector

    One figure per detector, from ``detchar.omegascan``, is passed to
    ``_PlotGeneration.save`` as ``<savedir>/omegascan_<detector>``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    gps: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    from pesummary.gw.plots import detchar

    figs = detchar.omegascan(strain, gps, window=window)
    for det, fig in figs.items():
        _PlotGeneration.save(
            fig, os.path.join(savedir, "omegascan_{}".format(det))
        )
def skymap_comparison_plot(self, label):
    """Generate a plot comparing the skymaps of all result files

    Parameters
    ----------
    label: str
        unused; every label in ``self.labels`` is included
    """
    shared = self.same_samples
    self._skymap_comparison_plot(
        self.savedir, shared["ra"], shared["dec"], self.labels, self.colors,
        self.preliminary_comparison_pages, self.checkpoint
    )
@staticmethod
def _skymap_comparison_plot(
    savedir, ra, dec, labels, colors, preliminary=False, checkpoint=False
):
    """Generate a plot comparing the skymaps of several result files

    The figure is written to ``<savedir>/combined_skymap.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: dict
        right ascension samples keyed by result-file label
    dec: dict
        declination samples keyed by result-file label
    labels: list
        labels to include, in plotting order
    colors: list
        colors used to distinguish the result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    """
    target = os.path.join(savedir, "combined_skymap.png")
    if checkpoint and os.path.isfile(target):
        return
    figure = gw._sky_map_comparison_plot(
        [ra[name] for name in labels], [dec[name] for name in labels],
        labels, colors
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def waveform_comparison_fd_plot(self, label):
    """Generate a plot comparing the frequency domain waveforms

    Nothing is plotted if any result file has an empty-dict approximant.

    Parameters
    ----------
    label: str
        unused; every label in ``self.labels`` is included
    """
    missing_approximant = any(
        self.approximant[name] == {} for name in self.labels
    )
    if not missing_approximant:
        self._waveform_comparison_fd_plot(
            self.savedir, self.maxL_samples, self.labels, self.colors,
            preliminary=self.preliminary_comparison_pages,
            checkpoint=self.checkpoint, **self.file_kwargs
        )
@staticmethod
def _waveform_comparison_fd_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a plot comparing the frequency domain waveforms

    The figure is written to ``<savedir>/compare_waveforms.png``. The
    frequency range is the largest ``f_low`` (default 20.) and the smallest
    ``f_final`` (default 1024.) across the files. ``f_ref`` (default 20.)
    is taken from the first label.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        maximum likelihood samples keyed by result-file label
    labels: list
        labels to include, in plotting order
    colors: list
        colors used to distinguish the result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    **kwargs: dict
        file kwargs keyed by label, each with a ``"meta_data"`` dict
    """
    target = os.path.join(savedir, "compare_waveforms.png")
    if checkpoint and os.path.isfile(target):
        return
    maxL_per_label = [maxL_samples[name] for name in labels]
    meta_per_label = [kwargs[name]["meta_data"] for name in labels]
    f_min = np.max([meta.get("f_low", 20.) for meta in meta_per_label])
    f_max = np.min([meta.get("f_final", 1024.) for meta in meta_per_label])
    f_ref = kwargs[labels[0]]["meta_data"].get("f_ref", 20.)
    figure = gw._waveform_comparison_plot(
        maxL_per_label, colors, labels, f_min=f_min, f_max=f_max, f_ref=f_ref
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def waveform_comparison_td_plot(self, label):
    """Generate a plot comparing the time domain waveforms

    Nothing is plotted if any result file has an empty-dict approximant.
    Earlier versions called the frequency-domain helper here by mistake.

    Parameters
    ----------
    label: str
        unused; every label in ``self.labels`` is included
    """
    if any(self.approximant[i] == {} for i in self.labels):
        return

    self._waveform_comparison_td_plot(
        self.savedir, self.maxL_samples, self.labels, self.colors,
        preliminary=self.preliminary_comparison_pages,
        checkpoint=self.checkpoint
    )
@staticmethod
def _waveform_comparison_td_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False
):
    """Generate a plot comparing the time domain waveforms

    The figure is written to ``<savedir>/compare_time_domain_waveforms.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        maximum likelihood samples keyed by result-file label
    labels: list
        labels to include, in plotting order
    colors: list
        colors used to distinguish the result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    """
    target = os.path.join(savedir, "compare_time_domain_waveforms.png")
    if checkpoint and os.path.isfile(target):
        return
    figure = gw._time_domainwaveform_comparison_plot(
        [maxL_samples[name] for name in labels], colors, labels
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def twod_comparison_contour_plot(self, label):
    """Generate 2d comparison contour plots for a fixed set of parameter
    pairs

    Pairs that are not present in every result file are skipped with a
    warning. The KDE grid size comes from ``publication_kwargs["gridsize"]``
    (default 100).

    Parameters
    ----------
    label: str
        unused; every label in ``self.labels`` is included
    """
    error_message = (
        "Failed to generate a 2d contour plot for %s because {}"
    )
    twod_plots = [
        ["mass_ratio", "chi_eff"], ["mass_1", "mass_2"],
        ["luminosity_distance", "chirp_mass_source"],
        ["mass_1_source", "mass_2_source"],
        ["theta_jn", "luminosity_distance"],
        ["network_optimal_snr", "chirp_mass_source"]
    ]
    if "gridsize" in self.publication_kwargs.keys():
        gridsize = int(self.publication_kwargs["gridsize"])
    else:
        gridsize = 100
    for pair in twod_plots:
        joined = " and ".join(pair)
        common_to_all = all(
            all(param in self.samples[name].keys() for param in pair)
            for name in self.labels
        )
        if not common_to_all:
            logger.warning(
                "Failed to generate 2d contour plots for {} because {} are not "
                "common in all result files".format(joined, joined)
            )
            continue
        pair_samples = [
            [self.samples[name][param] for param in pair]
            for name in self.labels
        ]
        self._try_to_make_a_plot(
            [
                self.savedir, pair, pair_samples, self.labels, latex_labels,
                self.colors, self.linestyles, gridsize,
                self.preliminary_comparison_pages, self.checkpoint
            ],
            self._twod_comparison_contour_plot,
            error_message % (joined)
        )
@staticmethod
def _twod_comparison_contour_plot(
    savedir, plot_parameters, samples, labels, latex_labels, colors,
    linestyles, gridsize, preliminary=False, checkpoint=False
):
    """Generate a 2d comparison contour plot for a given set of samples

    The figure is written to
    ``<savedir>/publication/2d_contour_plot_<p1>_and_<p2>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    plot_parameters: list
        parameters to use for the 2d contour plot
    samples: list
        samples for each parameter, per result file
    labels: list
        labels used to distinguish each result file
    latex_labels: dict
        latex labels keyed by parameter
    colors: list
        colors used to distinguish each result file
    linestyles: list
        linestyles used to distinguish each result file
    gridsize: int
        the number of points to use when estimating the KDE
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    """
    stem = "2d_contour_plot_{}.png".format("_and_".join(plot_parameters))
    target = os.path.join(savedir, "publication", stem)
    if checkpoint and os.path.isfile(target):
        return
    figure = publication.twod_contour_plots(
        plot_parameters, samples, labels, latex_labels, colors=colors,
        linestyles=linestyles, gridsize=gridsize
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def violin_plot(self, label):
    """Generate violin plots comparing certain parameters across all
    result files

    Parameters that are not present in every result file are skipped with a
    warning. Previously execution fell through to the sample lookup, which
    raised a KeyError and aborted the remaining violin plots.

    Parameters
    ----------
    label: str
        unused; every label in ``self.labels`` is included
    """
    error_message = (
        "Failed to generate a violin plot for %s because {}"
    )
    violin_plots = ["mass_ratio", "chi_eff", "chi_p", "luminosity_distance"]

    for plot in violin_plots:
        if not all(plot in self.samples[j].keys() for j in self.labels):
            logger.warning(
                "Failed to generate violin plots for {} because {} is not "
                "common in all result files".format(plot, plot)
            )
            continue
        injection = [self.injection_data[name][plot] for name in self.labels]
        samples = [self.samples[name][plot] for name in self.labels]
        arguments = [
            self.savedir, plot, samples, self.labels, latex_labels[plot],
            injection, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._violin_plot, error_message % (plot)
        )
@staticmethod
def _violin_plot(
    savedir, plot_parameter, samples, labels, latex_label, inj_values=None,
    preliminary=False, checkpoint=False, kde=ReflectionBoundedKDE,
    default_bounds=True
):
    """Generate a violin plot for a given set of samples

    The figure is written to
    ``<savedir>/publication/violin_plot_<plot_parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    plot_parameter: str
        name of the parameter you wish to generate a violin plot for
    samples: list
        samples for the parameter, one entry per result file
    labels: list
        labels used to distinguish each result file
    latex_label: str
        latex label corresponding to the parameter.
        NOTE(review): not used by this body; the module-level
        ``latex_labels`` mapping is what is passed to
        ``publication.violin_plots``.
    inj_values: list, optional
        injected values for each result file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, do nothing
    kde: optional
        KDE class forwarded to ``publication.violin_plots``
    default_bounds: Bool, optional
        if True, the KDE ``xlow``/``xhigh`` come from ``gw._return_bounds``;
        otherwise both are None
    """
    target = os.path.join(
        savedir, "publication", "violin_plot_{}.png".format(plot_parameter)
    )
    if checkpoint and os.path.isfile(target):
        return
    if default_bounds:
        xlow, xhigh = gw._return_bounds(
            plot_parameter, samples, comparison=True
        )
    else:
        xlow = xhigh = None
    figure = publication.violin_plots(
        plot_parameter, samples, labels, latex_labels, kde=kde,
        kde_kwargs={"xlow": xlow, "xhigh": xhigh}, inj_values=inj_values
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def spin_dist_plot(self, label):
    """Produce one spin disk plot per result file

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. This argument
        is not read: a plot is attempted for every label in ``self.labels``
    """
    required = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    template = "Failed to generate a spin disk plot for %s because {}"

    for index, result_label in enumerate(self.labels):
        available = self.samples[result_label].keys()
        if all(name in available for name in required):
            plot_args = [
                self.savedir,
                required,
                [self.samples[result_label][name] for name in required],
                result_label,
                self.colors[index],
                self.preliminary_comparison_pages,
                self.checkpoint,
            ]
            self._try_to_make_a_plot(
                plot_args, self._spin_dist_plot, template % (result_label)
            )
        else:
            # This result file lacks at least one spin parameter: warn and
            # move on to the next file
            logger.warning(
                "Failed to generate spin disk plots because {} are not "
                "common in all result files".format(" and ".join(required))
            )
@staticmethod
def _spin_dist_plot(
    savedir, parameters, samples, label, color, preliminary=False,
    checkpoint=False
):
    """Draw and save the spin disk plot for a single result file

    Parameters
    ----------
    savedir: str
        directory containing the ``publication`` sub-directory the plot is
        written to
    parameters: list
        names of the spin parameters being plotted
    samples: list
        samples for each entry in ``parameters``
    label: str
        label of the result file; also used in the output filename
    color: str
        colour forwarded to ``publication.spin_distribution_plots``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    output = os.path.join(
        savedir, "publication", "spin_disk_plot_{}.png".format(label)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = publication.spin_distribution_plots(
        parameters, samples, label, color=color
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def pepredicates_plot(self, label):
    """Produce the PEPredicates classification plots for one result file,
    once with the default prior and once with the population prior

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    if self.mcmc_samples:
        # MCMC results hold several chains; plot their combination
        posterior = self.samples[label].combine
    else:
        posterior = self.samples[label]

    probabilities = self.pepredicates_probs[label]
    for prior_name, use_population in (("default", False), ("population", True)):
        self._pepredicates_plot(
            self.savedir, posterior, label, probabilities[prior_name],
            population_prior=use_population,
            preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint
        )
@staticmethod
@no_latex_plot
def _pepredicates_plot(
    savedir, samples, label, probabilities, population_prior=False,
    preliminary=False, checkpoint=False
):
    """Generate the PEPredicates scatter and bar plots for a given set of
    samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    samples: dict
        dictionary of samples for each parameter
    label: str
        the label corresponding to the result file
    probabilities: dict
        dictionary of classification probabilities
    population_prior: Bool, optional
        if True, the samples will be reweighted according to a population
        prior
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, plots whose output file already exists are not regenerated
    """
    from pesummary.gw.classification import PEPredicates

    if population_prior:
        outputs = [
            ("pepredicates", "{}_population_pepredicates.png"),
            ("bar", "{}_population_pepredicates_bar.png"),
        ]
    else:
        outputs = [
            ("pepredicates", "{}_default_pepredicates.png"),
            ("bar", "{}_default_pepredicates_bar.png"),
        ]

    # When checkpointing, only regenerate plots which do not already exist
    pending = []
    for plot_type, template in outputs:
        filename = os.path.join(savedir, template.format(label))
        if os.path.isfile(filename) and checkpoint:
            continue
        pending.append((plot_type, filename))

    # Nothing to draw: avoid building the PEPredicates object entirely
    if not pending:
        return

    _pepredicates = PEPredicates(samples)
    for plot_type, filename in pending:
        fig = _pepredicates.plot(
            type=plot_type, population=population_prior,
            probabilities=probabilities
        )
        _PlotGeneration.save(fig, filename, preliminary=preliminary)
def psd_plot(self, label):
    """Generate a PSD plot for every result file which stores PSD data

    Parameters
    ----------
    label: str
        the label corresponding to the result file. Note that this argument
        is not read: every label in ``self.labels`` is considered
    """
    error_message = (
        "Failed to generate a PSD plot for %s because {}"
    )

    for result_label in self.labels:
        ifos = list(self.psd[result_label].keys())
        # An empty mapping, or a mapping whose only key is ``None``, means
        # this result file stores no PSD. Skip it and carry on; returning
        # here would silently drop the PSD plots of every later result file
        if ifos == [None] or not ifos:
            continue

        # Reset the frequency bounds for each result file so that the
        # f_low / f_final of one file never leak into another file's plot
        meta_data = self.file_kwargs[result_label]["meta_data"]
        fmin = meta_data["f_low"] if "f_low" in meta_data.keys() else None
        fmax = meta_data["f_final"] if "f_final" in meta_data.keys() else None

        # Each stored PSD is column-stacked: column 0 is frequency and
        # column 1 is strain
        frequencies = [np.array(self.psd[result_label][i]).T[0] for i in ifos]
        strains = [np.array(self.psd[result_label][i]).T[1] for i in ifos]
        arguments = [
            self.savedir, frequencies, strains, fmin, fmax, ifos, result_label,
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._psd_plot, error_message % (result_label)
        )
@staticmethod
def _psd_plot(
    savedir, frequencies, strains, fmin, fmax, psd_labels, label, checkpoint=False
):
    """Draw and save the PSD plot for a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        PSD frequencies, one array per IFO
    strains: list
        PSD strains, one array per IFO
    fmin: float
        frequency at which to start the PSD plot
    fmax: float
        frequency at which to end the PSD plot
    psd_labels: list
        names of the IFOs used
    label: str
        label distinguishing the result file; used in the output filename
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    output = os.path.join(savedir, "{}_psd_plot.png".format(label))
    if checkpoint and os.path.isfile(output):
        return
    figure = gw._psd_plot(
        frequencies, strains, labels=psd_labels, fmin=fmin, fmax=fmax
    )
    _PlotGeneration.save(figure, output)
def calibration_plot(self, label):
    """Generate a calibration envelope plot for every result file which
    stores calibration data

    Parameters
    ----------
    label: str
        the label corresponding to the result file. Note that this argument
        is not read: every label in ``self.labels`` is considered
    """
    error_message = (
        "Failed to generate calibration plot for %s because {}"
    )
    # Shared frequency grid: 20 up to (but excluding) 1024 in steps of 0.25
    frequencies = np.arange(20., 1024., 1. / 4)

    for result_label in self.labels:
        ifos = list(self.calibration[result_label].keys())
        # An empty mapping, or a mapping whose only key is ``None``, means
        # this result file stores no calibration data. Skip it and carry on;
        # returning here would silently drop every later result file
        if ifos == [None] or not ifos:
            continue

        calibration_data = [
            self.calibration[result_label][i] for i in ifos
        ]
        if "calibration" in self.priors.keys():
            prior = [
                self.priors["calibration"][result_label][i] for i in ifos
            ]
        else:
            prior = None
        arguments = [
            self.savedir, frequencies, calibration_data, ifos, prior,
            result_label, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._calibration_plot, error_message % (result_label)
        )
@staticmethod
def _calibration_plot(
    savedir, frequencies, calibration_data, calibration_labels, prior, label,
    checkpoint=False
):
    """Draw and save the calibration envelope plot for a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        frequencies used to interpolate the calibration data
    calibration_data: list
        calibration data, one entry per IFO
    calibration_labels: list
        names of the IFOs used
    prior: list
        priors used for each IFO (may be None)
    label: str
        label distinguishing the result file; used in the output filename
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    target = os.path.join(savedir, "{}_calibration_plot.png".format(label))
    if checkpoint and os.path.isfile(target):
        return
    envelope = gw._calibration_envelope_plot(
        frequencies, calibration_data, calibration_labels, prior=prior
    )
    _PlotGeneration.save(envelope, target)
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Write the interactive (HTML) source and extrinsic corner plots for a
    given set of samples

    Parameters
    ----------
    savedir: str
        directory containing the ``corner`` sub-directory the HTML files are
        written to
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.utils.Array objects that
        contain samples for each parameter
    latex_labels: dict
        latex labels keyed by parameter name
    checkpoint: Bool, optional
        if True, a corner plot whose HTML file already exists is skipped
    """
    # (filename template, parameters to include, extra kwargs forwarded to
    # ``interactive.corner``); parameters keep the ordering of ``samples``
    figures = (
        (
            "{}_interactive_source.html",
            [
                "luminosity_distance", "mass_1_source", "mass_2_source",
                "total_mass_source", "chirp_mass_source", "redshift"
            ],
            {"dimensions": {"width": 900, "height": 900}},
        ),
        (
            "{}_interactive_extrinsic.html",
            ["luminosity_distance", "psi", "ra", "dec"],
            {},
        ),
    )
    for template, wanted, extra_kwargs in figures:
        html_file = os.path.join(savedir, "corner", template.format(label))
        if checkpoint and os.path.isfile(html_file):
            continue
        selected = [name for name in samples.keys() if name in wanted]
        interactive.corner(
            [samples[name] for name in selected],
            [latex_labels[name] for name in selected],
            write_to_html_file=html_file,
            **extra_kwargs
        )