Coverage for pesummary/gw/plots/main.py: 65.8%
488 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import os
7from pesummary.core.plots.main import _PlotGeneration as _BasePlotGeneration
8from pesummary.core.plots.latex_labels import latex_labels
9from pesummary.core.plots import interactive
10from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
11from pesummary.gw.plots.latex_labels import GWlatex_labels
12from pesummary.utils.utils import logger, resample_posterior_distribution
13from pesummary.utils.decorators import no_latex_plot
14from pesummary.gw.plots import publication
15from pesummary.gw.plots import plot as gw
17import multiprocessing as mp
18import numpy as np
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Merge the GW-specific LaTeX labels into the shared ``latex_labels``
# lookup so that every plotting helper in this module can find them
latex_labels.update(GWlatex_labels)
24class _PlotGeneration(_BasePlotGeneration):
def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_file_kwargs=None, existing_samples=None,
    existing_metafile=None, same_parameters=None, injection_data=None,
    result_files=None, file_kwargs=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, no_ligo_skymap=False,
    nsamples_for_skymap=None, detectors=None, maxL_samples=None,
    gwdata=None, calibration_definition=None, calibration=None, psd=None,
    multi_threading_for_skymap=None, approximant=None,
    classification_probs=None, include_prior=False, publication=False,
    existing_approximant=None, existing_psd=None, existing_calibration=None,
    existing_weights=None, weights=None, disable_comparison=False,
    linestyles=None, disable_interactive=False, disable_corner=False,
    publication_kwargs=None, multi_process=1, mcmc_samples=False,
    skymap=None, existing_skymap=None, corner_params=None,
    preliminary_pages=False, expert_plots=True, checkpoint=False,
    key_data=None
):
    """Store the GW-specific plotting configuration

    Arguments shared with the core plot generator are forwarded to
    ``pesummary.core.plots.main._PlotGeneration``. The remaining GW options
    (skymaps, waveforms, PSDs, calibration, detector data, ...) are stored
    on the instance, and the GW plot types are registered in
    ``self.plot_type_dictionary``.

    Parameters
    ----------
    priors: dict, optional
        forwarded to the core class. ``None`` (the default) is treated as an
        empty dictionary
    publication_kwargs: dict, optional
        options for the publication plots (e.g. ``gridsize``). ``None`` (the
        default) is treated as an empty dictionary
    preliminary_pages: Bool or dict, optional
        either a dict mapping each label to whether its plots get a
        preliminary watermark, or a single value whose truthiness is applied
        to every label
    skymap: dict, optional
        skymap probability arrays keyed by label. Defaults to ``None`` for
        every label
    existing_skymap: dict, optional
        stored as ``self.existing_skymap``
    existing_metafile, result_files: optional
        accepted for interface compatibility; not used by this constructor
    """
    # Avoid sharing one mutable default dictionary between instances
    if priors is None:
        priors = {}
    if publication_kwargs is None:
        publication_kwargs = {}
    super(_PlotGeneration, self).__init__(
        savedir=savedir, webdir=webdir, labels=labels,
        samples=samples, kde_plot=kde_plot, existing_labels=existing_labels,
        existing_injection_data=existing_injection_data,
        existing_samples=existing_samples,
        existing_weights=existing_weights,
        same_parameters=same_parameters,
        injection_data=injection_data, mcmc_samples=mcmc_samples,
        colors=colors, custom_plotting=custom_plotting,
        add_to_existing=add_to_existing, priors=priors,
        include_prior=include_prior, weights=weights,
        disable_comparison=disable_comparison, linestyles=linestyles,
        disable_interactive=disable_interactive, disable_corner=disable_corner,
        multi_process=multi_process, corner_params=corner_params,
        expert_plots=expert_plots, checkpoint=checkpoint, key_data=key_data
    )
    if not isinstance(preliminary_pages, dict):
        # Broadcast a single flag to every label
        preliminary_pages = {
            label: bool(preliminary_pages) for label in self.labels
        }
    self.preliminary_pages = preliminary_pages
    self.preliminary_comparison_pages = any(
        value for value in self.preliminary_pages.values()
    )
    self.package = "gw"
    self.file_kwargs = file_kwargs
    self.existing_file_kwargs = existing_file_kwargs
    self.no_ligo_skymap = no_ligo_skymap
    self.nsamples_for_skymap = nsamples_for_skymap
    self.detectors = detectors
    self.maxL_samples = maxL_samples
    self.gwdata = gwdata
    if skymap is None:
        skymap = {label: None for label in self.labels}
    self.skymap = skymap
    # previously ``skymap`` was stored here, silently discarding the
    # ``existing_skymap`` argument
    self.existing_skymap = existing_skymap
    self.calibration_definition = calibration_definition
    self.calibration = calibration
    self.existing_calibration = existing_calibration
    self.psd = psd
    self.existing_psd = existing_psd
    self.multi_threading_for_skymap = multi_threading_for_skymap
    self.approximant = approximant
    self.existing_approximant = existing_approximant
    self.classification_probs = classification_probs
    self.publication = publication
    self.publication_kwargs = publication_kwargs
    # label -> multiprocessing.Process generating the ligo.skymap skymap
    self._ligo_skymap_PID = {}

    self.plot_type_dictionary.update({
        "psd": self.psd_plot,
        "calibration": self.calibration_plot,
        "twod_histogram": self.twod_histogram_plot,
        "skymap": self.skymap_plot,
        "waveform_fd": self.waveform_fd_plot,
        "waveform_td": self.waveform_td_plot,
        "data": self.gwdata_plots,
        "violin": self.violin_plot,
        "spin_disk": self.spin_dist_plot,
        "classification": self.classification_plot
    })
    if self.make_comparison:
        self.plot_type_dictionary.update({
            "skymap_comparison": self.skymap_comparison_plot,
            "waveform_comparison_fd": self.waveform_comparison_fd_plot,
            "waveform_comparison_td": self.waveform_comparison_td_plot,
            "2d_comparison_contour": self.twod_comparison_contour_plot,
        })
@property
def ligo_skymap_PID(self):
    """dict: processes launched to build the ``ligo.skymap`` skymaps

    Keyed by label. Despite the name, the values are the
    ``multiprocessing.Process`` objects stored by ``skymap_plot``; use
    their ``.pid`` attribute to obtain the process ID.
    """
    processes = self._ligo_skymap_PID
    return processes
def generate_plots(self):
    """Generate all plots for all result files

    The run-level calibration plot (when calibration data or calibration
    priors exist) and PSD plot (when PSD data exists) are attempted first,
    then the core implementation is called.
    """
    wants_calibration = (
        self.calibration or "calibration" in list(self.priors.keys())
    )
    if wants_calibration:
        self.try_to_make_a_plot("calibration")
    if self.psd:
        self.try_to_make_a_plot("psd")
    super(_PlotGeneration, self).generate_plots()
def _generate_plots(self, label):
    """Generate all plots for a given result file

    After the core per-label plots, the 2d histogram, skymap and
    time/frequency domain waveform plots are attempted. The classification
    plot is attempted only when classification probabilities exist for
    ``label``, and the data plots only when ``self.gwdata`` is set.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    super(_PlotGeneration, self)._generate_plots(label)
    for plot_type in ("twod_histogram", "skymap", "waveform_td", "waveform_fd"):
        self.try_to_make_a_plot(plot_type, label=label)
    # ``classification_probs`` defaults to None in __init__; indexing None
    # would raise a TypeError, so guard before the per-label lookup
    probs = self.classification_probs
    if probs is not None and probs[label] is not None:
        self.try_to_make_a_plot("classification", label=label)
    if self.gwdata:
        self.try_to_make_a_plot("data", label=label)
def _generate_comparison_plots(self):
    """Generate all comparison plots

    Runs the core comparison plots, then the skymap and waveform
    comparisons. The publication-only plots (2d contours, violin and spin
    disk) are attempted only when ``self.publication`` is set.
    """
    super(_PlotGeneration, self)._generate_comparison_plots()
    comparison_plots = (
        "skymap_comparison", "waveform_comparison_td", "waveform_comparison_fd"
    )
    for plot_type in comparison_plots:
        self.try_to_make_a_plot(plot_type)
    if not self.publication:
        return
    for plot_type in ("2d_comparison_contour", "violin", "spin_disk"):
        self.try_to_make_a_plot(plot_type)
@staticmethod
def _corner_plot(
    savedir, label, samples, latex_labels, webdir, params, preliminary=False,
    checkpoint=False
):
    """Generate the corner plots for a given set of samples

    Three plots are written to ``savedir/corner``:
    ``<label>_all_density_plots.png``, ``<label>_sourceframe.png`` and
    ``<label>_extrinsic.png``. When the first one is (re)made, the
    parameters and data returned by ``gw._make_corner_plot`` are also
    inserted into ``webdir/js/combine_corner.js``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary of samples for a given result file
    latex_labels: dict
        dictionary of latex labels
    webdir: str
        directory where the javascript is written
    params: list
        parameters passed to ``gw._make_corner_plot`` as
        ``corner_parameters``
    preliminary: Bool, optional
        accepted for interface compatibility; currently not used
    checkpoint: Bool, optional
        if True, skip any plot whose output file already exists
    """
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        filename = os.path.join(
            savedir, "corner", "{}_all_density_plots.png".format(label)
        )
        if not (os.path.isfile(filename) and checkpoint):
            fig, params, data = gw._make_corner_plot(
                samples, latex_labels, corner_parameters=params
            )
            fig.savefig(filename)
            fig.close()
            params = [str(i) for i in params]
            js_path = os.path.join(webdir, "js", "combine_corner.js")
            # Single read-modify-write using context managers so that the
            # file handles are always closed (they were previously leaked)
            with open(js_path) as js_file:
                combine_corner = js_file.readlines()
            ind = [
                linenumber for linenumber, line in enumerate(combine_corner)
                if "var list = {}" in line
            ][0]
            combine_corner.insert(
                ind + 1, " list['{}'] = {};\n".format(label, params)
            )
            ind = [
                linenumber for linenumber, line in enumerate(combine_corner)
                if "var data = {}" in line
            ][0]
            combine_corner.insert(
                ind + 1, " data['{}'] = {};\n".format(label, data)
            )
            with open(js_path, "w") as js_file:
                js_file.writelines(combine_corner)

        for suffix, make_plot in (
            ("sourceframe", gw._make_source_corner_plot),
            ("extrinsic", gw._make_extrinsic_corner_plot),
        ):
            filename = os.path.join(
                savedir, "corner", "{}_{}.png".format(label, suffix)
            )
            if os.path.isfile(filename) and checkpoint:
                continue
            fig = make_plot(samples, latex_labels)
            fig.savefig(filename)
            fig.close()
def twod_histogram_plot(self, label):
    """Generate 2d histogram (triangle) plots for a given result file

    One plot is attempted for every parameter pair in ``conf.gw_2d_plots``
    whose parameters are all present in the samples for ``label``. The
    plots are dispatched through ``self.pool.starmap``, each wrapped in
    ``self._try_to_make_a_plot`` with a per-pair error message.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary import conf

    error_message = (
        "Failed to generate %s-%s triangle plot because {}"
    )
    paramset = [
        params for params in conf.gw_2d_plots if
        all(p in self.samples[label] for p in params)
    ]
    # ``self.weights`` may be None when no result file provides weights
    if self.weights is not None:
        weights = self.weights.get(label, None)
    else:
        weights = None
    arguments = [
        (
            [
                self.savedir, label, params,
                [self.samples[label][p] for p in params],
                [latex_labels[p] for p in params],
                [self.injection_data[label][p] for p in params],
                weights, self.preliminary_pages[label], self.checkpoint
            ], self._triangle_plot, error_message % (params[0], params[1])
        ) for params in paramset
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
@staticmethod
def _triangle_plot(
    savedir, label, params, samples, latex_labels, injection, weights,
    preliminary=False, checkpoint=False
):
    """Generate a 2d posterior (triangle) plot for a pair of parameters

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    params: list
        the two parameter names being plotted
    samples: list
        the samples for each parameter in ``params``
    latex_labels: list
        the latex label for each parameter in ``params``
    injection: list
        the injected value for each parameter. If any entry is None or NaN
        no truth is drawn. The list is not modified
    weights: list or None
        weights for the samples
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    """
    from pesummary.core.plots.publication import triangle_plot
    import math

    filename = os.path.join(
        savedir, "{}_2d_posterior_{}_{}.png".format(
            label, params[0], params[1]
        )
    )
    if os.path.isfile(filename) and checkpoint:
        return
    # Build the truth without mutating the caller's ``injection`` list, and
    # test for None before NaN: math.isnan(None) raises a TypeError
    if any(value is None or math.isnan(value) for value in injection):
        truth = None
    else:
        truth = list(injection)
    fig, _, _, _ = triangle_plot(
        *samples, kde=False, parameters=params, xlabel=latex_labels[0],
        ylabel=latex_labels[1], plot_datapoints=True, plot_density=False,
        levels=[1e-8], fill=False, grid=True, linewidths=[1.75],
        percentiles=[5, 95], percentile_plot=[label], labels=[label],
        truth=truth, weights=weights, data_kwargs={"alpha": 0.3}
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def skymap_plot(self, label):
    """Generate a skymap plot for a given result file

    A default skymap is always made. If ``ligo.skymap`` is importable and
    ``self.no_ligo_skymap`` is False, one of two things also happens:

    * when no precomputed skymap exists for ``label``, a
      ``multiprocessing.Process`` is launched to build one, and it is
      stored in ``self._ligo_skymap_PID``;
    * otherwise the precomputed skymap array is plotted directly.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    try:
        import ligo.skymap  # noqa: F401
    except ImportError:
        SKYMAP = False
    else:
        SKYMAP = True

    if self.mcmc_samples:
        samples = self.samples[label].combine
    else:
        samples = self.samples[label]
    _injection = [
        self.injection_data[label]["ra"], self.injection_data[label]["dec"]
    ]
    # ``self.weights`` may be None; mirror the guard in twod_histogram_plot
    # rather than raising a TypeError
    if self.weights is not None:
        weights = self.weights.get(label, None)
    else:
        weights = None
    self._skymap_plot(
        self.savedir, samples["ra"], samples["dec"], label,
        weights, _injection,
        preliminary=self.preliminary_pages[label]
    )

    if not SKYMAP or self.no_ligo_skymap:
        return
    if self.skymap[label] is not None:
        self._ligo_skymap_array_plot(
            self.savedir, self.skymap[label], label,
            self.preliminary_pages[label]
        )
        return

    from pesummary.utils.utils import RedirectLogger

    logger.info("Launching subprocess to generate skymap plot with "
                "ligo.skymap")
    try:
        _time = samples["geocent_time"]
    except KeyError:
        logger.warning(
            "Unable to find 'geocent_time' in the posterior table for {}. "
            "The ligo.skymap fits file will therefore not store the "
            "DATE_OBS field in the header".format(label)
        )
        _time = None
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        process = mp.Process(
            target=self._ligo_skymap_plot,
            args=[
                self.savedir, samples["ra"], samples["dec"],
                samples["luminosity_distance"], _time,
                label, self.nsamples_for_skymap, self.webdir,
                self.multi_threading_for_skymap, _injection,
                self.preliminary_pages[label]
            ]
        )
        process.start()
        # keep the process object so comparison plots can wait on it
        self._ligo_skymap_PID[label] = process
@staticmethod
@no_latex_plot
def _skymap_plot(
    savedir, ra, dec, label, weights, injection=None, preliminary=False
):
    """Produce the default skymap for a set of ra/dec samples

    The figure is saved as ``<savedir>/<label>_skymap``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved
    ra: pesummary.utils.utils.Array
        right ascension samples
    dec: pesummary.utils.utils.Array
        declination samples
    label: str
        label of the results file
    weights: list
        weights for the samples
    injection: list, optional
        injected ra and dec. Dropped if either value is NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    has_nan = injection is not None and any(
        math.isnan(value) for value in injection
    )
    figure = gw._default_skymap_plot(
        ra, dec, weights, injection=None if has_nan else injection
    )
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(figure, destination, preliminary=preliminary)
@staticmethod
@no_latex_plot
def _ligo_skymap_plot(savedir, ra, dec, dist, time, label, nsamples_for_skymap,
                      webdir, multi_threading_for_skymap, injection,
                      preliminary=False):
    """Produce a skymap for a set of samples with the ``ligo.skymap`` package

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved (as ``<label>_skymap``)
    ra: pesummary.utils.utils.Array
        right ascension samples
    dec: pesummary.utils.utils.Array
        declination samples
    dist: pesummary.utils.utils.Array
        luminosity distance samples
    time: pesummary.utils.utils.Array or None
        geocentric time samples, forwarded unchanged
    label: str
        label of the results file
    nsamples_for_skymap: int or None
        if given, ra/dec/dist are resampled to this many samples first
    webdir: str
        web directory; the fits file goes in ``<webdir>/samples``
    multi_threading_for_skymap: int or None
        passed to ``gw._ligo_skymap_plot`` as ``nprocess``
    injection: list or None
        injected ra and dec. Dropped if either value is NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    downsampled = nsamples_for_skymap is not None
    if downsampled:
        # NOTE(review): ``time`` is not resampled alongside ra/dec/dist;
        # presumably it is only used for the fits header -- confirm
        ra, dec, dist = resample_posterior_distribution(
            [ra, dec, dist], nsamples_for_skymap
        )
    if injection is not None and any(math.isnan(value) for value in injection):
        injection = None
    figure = gw._ligo_skymap_plot(
        ra, dec, dist=dist, savedir=os.path.join(webdir, "samples"),
        nprocess=multi_threading_for_skymap, downsampled=downsampled,
        label=label, time=time, injection=injection
    )
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(figure, destination, preliminary=preliminary)
@staticmethod
@no_latex_plot
def _ligo_skymap_array_plot(savedir, skymap, label, preliminary=False):
    """Plot a skymap from a probability array previously made with
    `ligo.skymap`

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved (as ``<label>_skymap``)
    skymap: np.ndarray
        skymap probabilities
    label: str
        label of the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    figure = gw._ligo_skymap_plot_from_array(skymap)
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(figure, destination, preliminary=preliminary)
def waveform_fd_plot(self, label):
    """Plot the frequency domain maxL waveform for one result file

    Nothing is produced when the stored approximant for ``label`` is an
    empty dict.

    Parameters
    ----------
    label: str
        label of the results file
    """
    has_approximant = self.approximant[label] != {}
    if has_approximant:
        self._waveform_fd_plot(
            self.savedir, self.detectors[label], self.maxL_samples[label],
            label, preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint,
            **self.file_kwargs[label]["meta_data"]
        )
@staticmethod
def _waveform_fd_plot(
    savedir, detectors, maxL_samples, label, preliminary=False,
    checkpoint=False, **kwargs
):
    """Plot the frequency domain maxL waveform for a detector network

    Parameters
    ----------
    savedir: str
        directory in which ``<label>_waveform.png`` is saved
    detectors: str or None
        underscore-separated detector names (split on ``"_"``); ``None``
        falls back to ``["H1", "L1"]``
    maxL_samples: dict
        maximum likelihood values
    label: str
        label of the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    **kwargs: dict
        meta data; ``f_start`` (default 20), ``f_low`` (20), ``f_final``
        (1024, passed as ``f_max``), ``f_ref`` (20) and
        ``approximant_flags`` ({}) are read
    """
    output = os.path.join(savedir, "{}_waveform.png".format(label))
    if checkpoint and os.path.isfile(output):
        return
    ifos = detectors.split("_") if detectors is not None else ["H1", "L1"]
    figure = gw._waveform_plot(
        ifos, maxL_samples,
        f_start=kwargs.get("f_start", 20.),
        f_low=kwargs.get("f_low", 20.0),
        f_max=kwargs.get("f_final", 1024.),
        f_ref=kwargs.get("f_ref", 20.),
        approximant_flags=kwargs.get("approximant_flags", {})
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def waveform_td_plot(self, label):
    """Plot the time domain maxL waveform for one result file

    Nothing is produced when the stored approximant for ``label`` is an
    empty dict.

    Parameters
    ----------
    label: str
        label of the results file
    """
    has_approximant = self.approximant[label] != {}
    if has_approximant:
        self._waveform_td_plot(
            self.savedir, self.detectors[label], self.maxL_samples[label],
            label, preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint,
            **self.file_kwargs[label]["meta_data"]
        )
@staticmethod
def _waveform_td_plot(
    savedir, detectors, maxL_samples, label, preliminary=False,
    checkpoint=False, **kwargs
):
    """Plot the time domain maxL waveform for a detector network

    Parameters
    ----------
    savedir: str
        directory in which ``<label>_waveform_time_domain.png`` is saved
    detectors: str or None
        underscore-separated detector names (split on ``"_"``); ``None``
        falls back to ``["H1", "L1"]``
    maxL_samples: dict
        maximum likelihood values
    label: str
        label of the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    **kwargs: dict
        meta data; ``f_start`` (default 20), ``f_low`` (20), ``f_final``
        (1024, passed as ``f_max``), ``f_ref`` (20) and
        ``approximant_flags`` ({}) are read
    """
    output = os.path.join(
        savedir, "{}_waveform_time_domain.png".format(label)
    )
    if checkpoint and os.path.isfile(output):
        return
    ifos = detectors.split("_") if detectors is not None else ["H1", "L1"]
    figure = gw._time_domain_waveform(
        ifos, maxL_samples,
        f_start=kwargs.get("f_start", 20.),
        f_low=kwargs.get("f_low", 20.0),
        f_max=kwargs.get("f_final", 1024.),
        f_ref=kwargs.get("f_ref", 20.),
        approximant_flags=kwargs.get("approximant_flags", {})
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def gwdata_plots(self, label):
    """Generate all plots associated with the gwdata

    A strain plot, a spectrogram plot and an omegascan plot are attempted,
    each wrapped in ``self._try_to_make_a_plot``. The omegascan is centred
    on the GPS time and window from ``determine_gps_time_and_window``.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary.utils.utils import determine_gps_time_and_window

    base_error = "Failed to generate a %s because {}"
    gps_time, window = determine_gps_time_and_window(
        self.maxL_samples, self.labels
    )
    # (function, positional arguments, human-readable name); a distinct
    # loop variable avoids shadowing the list of arguments being iterated
    plots = [
        (self.strain_plot, [label], "strain plot"),
        (self.spectrogram_plot, [], "spectrogram plot"),
        (self.omegascan_plot, [gps_time, window], "omegascan plot"),
    ]
    for func, func_args, name in plots:
        self._try_to_make_a_plot(func_args, func, base_error % name)
def strain_plot(self, label):
    """Generate a plot showing the comparison between the data and the
    maxL waveform for a given result file

    The plot is made in a separate ``multiprocessing.Process``. That
    process is started here but is not joined or tracked.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    logger.info("Launching subprocess to generate strain plot")
    process = mp.Process(
        target=self._strain_plot,
        # forward ``self.checkpoint`` so the strain plot honours
        # checkpointing like every other plot (it was always False before)
        args=[
            self.savedir, self.gwdata, self.maxL_samples[label], label,
            self.checkpoint
        ]
    )
    process.start()
@staticmethod
def _strain_plot(savedir, gwdata, maxL_samples, label, checkpoint=False):
    """Plot the strain data against the maxL waveform

    Parameters
    ----------
    savedir: str
        directory in which ``<label>_strain.png`` is saved
    gwdata: dict
        strain data for each detector
    maxL_samples: dict
        maximum likelihood values
    label: str
        label of the results file
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    """
    output = os.path.join(savedir, "{}_strain.png".format(label))
    if not (os.path.isfile(output) and checkpoint):
        figure = gw._strain_plot(gwdata, maxL_samples)
        _PlotGeneration.save(figure, output)
def spectrogram_plot(self):
    """Generate and save spectrogram plots from ``self.gwdata``

    The figures are saved inside ``_spectrogram_plot``; nothing is returned.
    """
    # the return value of _spectrogram_plot was previously bound to an
    # unused local
    self._spectrogram_plot(self.savedir, self.gwdata)
@staticmethod
def _spectrogram_plot(savedir, strain):
    """Save one spectrogram per detector returned by ``detchar.spectrogram``

    Parameters
    ----------
    savedir: str
        directory in which ``spectrogram_<detector>`` is saved
    strain: dict
        gwpy timeseries objects containing the strain data for each IFO
    """
    from pesummary.gw.plots import detchar

    for detector, figure in detchar.spectrogram(strain).items():
        destination = os.path.join(savedir, "spectrogram_{}".format(detector))
        _PlotGeneration.save(figure, destination)
def omegascan_plot(self, gps_time, window):
    """Generate and save omegascan plots from ``self.gwdata``

    The figures are saved inside ``_omegascan_plot``; nothing is returned.

    Parameters
    ----------
    gps_time: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    # the return value of _omegascan_plot was previously bound to an
    # unused local
    self._omegascan_plot(self.savedir, self.gwdata, gps_time, window)
@staticmethod
def _omegascan_plot(savedir, strain, gps, window):
    """Generate a plot showing the omegascan for all detectors

    One figure per detector returned by ``detchar.omegascan`` is saved as
    ``<savedir>/omegascan_<detector>``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    gps: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    from pesummary.gw.plots import detchar

    figs = detchar.omegascan(strain, gps, window=window)
    for det, fig in figs.items():
        _PlotGeneration.save(
            fig, os.path.join(savedir, "omegascan_{}".format(det))
        )
def skymap_comparison_plot(self, label):
    """Compare the skymaps of every result file

    The default comparison skymap is always made. If ``ligo.skymap`` is
    importable and ``self.no_ligo_skymap`` is False, a
    ``multiprocessing.Process`` is also launched to build a comparison from
    the ligo.skymap fits files in ``<webdir>/samples``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused)
    """
    from pesummary.utils.utils import RedirectLogger
    self._skymap_comparison_plot(
        self.savedir, self.same_samples["ra"], self.same_samples["dec"],
        self.labels, self.colors, self.preliminary_comparison_pages,
        self.checkpoint
    )

    try:
        import ligo.skymap  # noqa: F401
    except ImportError:
        have_ligo_skymap = False
    else:
        have_ligo_skymap = True
    if not have_ligo_skymap or self.no_ligo_skymap:
        return

    logger.info("Launching subprocess to generate comparison skymap plot with "
                "ligo.skymap")
    fits_files = [
        os.path.join(self.webdir, "samples", "{}_skymap.fits".format(name))
        for name in self.labels
    ]
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        worker = mp.Process(
            target=self._ligo_skymap_comparison_plot_from_fits,
            args=[
                self.savedir, fits_files, self.colors, self.labels,
                self.preliminary_comparison_pages, self._ligo_skymap_PID
            ]
        )
        worker.start()
@staticmethod
def _skymap_comparison_plot(
    savedir, ra, dec, labels, colors, preliminary=False, checkpoint=False
):
    """Overlay the sky localisation of several result files

    Parameters
    ----------
    savedir: str
        directory in which ``combined_skymap.png`` is saved
    ra: dict
        right ascension samples keyed by label
    dec: dict
        declination samples keyed by label
    labels: list
        labels distinguishing each result file
    colors: list
        colors used to distinguish each result file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    """
    output = os.path.join(savedir, "combined_skymap.png")
    if checkpoint and os.path.isfile(output):
        return
    ra_per_label = [ra[name] for name in labels]
    dec_per_label = [dec[name] for name in labels]
    figure = gw._sky_map_comparison_plot(
        ra_per_label, dec_per_label, labels, colors
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
@no_latex_plot
def _ligo_skymap_comparison_plot_from_fits(
    savedir, fits_files, colors, labels, preliminary=False, ligo_skymap_PID=None
):
    """Generate a comparison skymap based on fits files already generated
    with `ligo.skymap`

    If ``ligo_skymap_PID`` is given, first wait for the fits file of every
    label that has an entry in it. ``ps -p <pid>`` is polled every 60
    seconds until the file exists, or until the process is gone or defunct.
    If any fits file cannot be found when reading, a warning is logged and
    no plot is made.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    fits_files: list
        list of paths to the fits files
    colors: list
        list of colors to use for each skymap
    labels: list
        list of labels corresponding to each fits file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    ligo_skymap_PID: dict, optional
        maps labels to the ``multiprocessing.Process`` objects generating
        the fits files; their ``.pid`` attribute is used
    """
    import ligo.skymap.io
    import subprocess
    import time

    if ligo_skymap_PID:
        for label, fits_file in zip(labels, fits_files):
            if label not in ligo_skymap_PID.keys():
                continue
            pid = ligo_skymap_PID[label].pid
            while not os.path.isfile(fits_file):
                try:
                    # argument list instead of a shell string
                    output = str(
                        subprocess.check_output(["ps", "-p", str(pid)])
                    )
                except (subprocess.CalledProcessError, OSError):
                    # ``ps -p`` exits non-zero once the PID no longer
                    # exists; OSError covers ``ps`` being unavailable
                    break
                finished = (
                    "summarypages" not in output or "defunct" in output
                )
                # previously this referenced the undefined name ``_path``,
                # raising a NameError
                if finished and not os.path.isfile(fits_file):
                    break
                # wait for the process to finish
                time.sleep(60)

    skymaps = []
    for fits_file in fits_files:
        try:
            skymap, _ = ligo.skymap.io.read_sky_map(
                fits_file, nest=None
            )
            skymaps.append(skymap)
        except FileNotFoundError:
            logger.warning(
                "Failed to find {}. Unable to generate comparison skymap "
                "plot.".format(fits_file)
            )
            return

    fig = gw._ligo_skymap_comparion_plot_from_array(
        skymaps, colors, labels
    )
    _PlotGeneration.save(
        fig, os.path.join(savedir, "combined_skymap.png"),
        preliminary=preliminary
    )
def waveform_comparison_fd_plot(self, label):
    """Compare the frequency domain maxL waveforms of every result file

    Nothing is produced if any result file has an empty approximant.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused)
    """
    all_have_approximant = all(
        self.approximant[name] != {} for name in self.labels
    )
    if all_have_approximant:
        self._waveform_comparison_fd_plot(
            self.savedir, self.maxL_samples, self.labels, self.colors,
            preliminary=self.preliminary_comparison_pages,
            checkpoint=self.checkpoint, **self.file_kwargs
        )
@staticmethod
def _waveform_comparison_fd_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a plot to compare the frequency domain waveforms

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        dictionary of maximum likelihood samples for each result file. Not
        modified: each entry is shallow-copied before the waveform settings
        are added
    labels: list
        list of labels to distinguish each result file
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    **kwargs: dict
        file kwargs keyed by label. ``approximant_flags`` ({}), ``f_start``
        (20), ``f_low`` (20), ``f_final`` (1024) and ``f_ref`` (20) are read
        from each label's ``meta_data``
    """
    import copy

    filename = os.path.join(savedir, "compare_waveforms.png")
    if os.path.isfile(filename) and checkpoint:
        return
    _defaults = {"f_start": 20., "f_low": 20., "f_final": 1024., "f_ref": 20.}
    samples = []
    for label in labels:
        meta_data = kwargs[label]["meta_data"]
        # copy so that the caller's maxL_samples are not polluted with the
        # waveform settings (they were previously modified in place)
        _samples = copy.copy(maxL_samples[label])
        _samples["approximant_flags"] = meta_data.get("approximant_flags", {})
        for freq, default in _defaults.items():
            _samples[freq] = meta_data.get(freq, default)
        samples.append(_samples)
    fig = gw._waveform_comparison_plot(samples, colors, labels)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def waveform_comparison_td_plot(self, label):
    """Compare the time domain maxL waveforms of every result file

    Nothing is produced if any result file has an empty approximant.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused)
    """
    all_have_approximant = all(
        self.approximant[name] != {} for name in self.labels
    )
    if all_have_approximant:
        self._waveform_comparison_td_plot(
            self.savedir, self.maxL_samples, self.labels, self.colors,
            self.preliminary_comparison_pages, self.checkpoint,
            **self.file_kwargs
        )
@staticmethod
def _waveform_comparison_td_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a plot to compare the time domain waveforms

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        dictionary of maximum likelihood samples for each result file. Not
        modified: each entry is shallow-copied before the waveform settings
        are added
    labels: list
        list of labels to distinguish each result file
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    **kwargs: dict
        file kwargs keyed by label. ``approximant_flags`` ({}), ``f_start``
        (20), ``f_low`` (20) and ``f_ref`` (20) are read from each label's
        ``meta_data``
    """
    import copy

    filename = os.path.join(savedir, "compare_time_domain_waveforms.png")
    if os.path.isfile(filename) and checkpoint:
        return
    _defaults = {"f_start": 20., "f_low": 20., "f_ref": 20.}
    samples = []
    for label in labels:
        meta_data = kwargs[label]["meta_data"]
        # copy so that the caller's maxL_samples are not polluted with the
        # waveform settings (they were previously modified in place)
        _samples = copy.copy(maxL_samples[label])
        _samples["approximant_flags"] = meta_data.get("approximant_flags", {})
        for freq, default in _defaults.items():
            _samples[freq] = meta_data.get(freq, default)
        samples.append(_samples)
    fig = gw._time_domain_waveform_comparison_plot(samples, colors, labels)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def twod_comparison_contour_plot(self, label):
    """Make 2d contour plots comparing every result file

    A fixed list of parameter pairs is used. A pair is skipped, with a
    warning, unless both parameters exist in every result file. The KDE
    grid size comes from ``self.publication_kwargs["gridsize"]`` (default
    100).

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused)
    """
    error_message = (
        "Failed to generate a 2d contour plot for %s because {}"
    )
    parameter_pairs = [
        ["mass_ratio", "chi_eff"], ["mass_1", "mass_2"],
        ["luminosity_distance", "chirp_mass_source"],
        ["mass_1_source", "mass_2_source"],
        ["theta_jn", "luminosity_distance"],
        ["network_optimal_snr", "chirp_mass_source"]
    ]
    if "gridsize" in self.publication_kwargs.keys():
        gridsize = int(self.publication_kwargs["gridsize"])
    else:
        gridsize = 100
    for pair in parameter_pairs:
        joined = " and ".join(pair)
        missing = any(
            any(param not in self.samples[name].keys() for param in pair)
            for name in self.labels
        )
        if missing:
            logger.warning(
                "Failed to generate 2d contour plots for {} because {} are not "
                "common in all result files".format(joined, joined)
            )
            continue
        pair_samples = [
            [self.samples[name][param] for param in pair]
            for name in self.labels
        ]
        arguments = [
            self.savedir, pair, pair_samples, self.labels, latex_labels,
            self.colors, self.linestyles, gridsize,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._twod_comparison_contour_plot,
            error_message % joined
        )
@staticmethod
def _twod_comparison_contour_plot(
    savedir, plot_parameters, samples, labels, latex_labels, colors,
    linestyles, gridsize, preliminary=False, checkpoint=False
):
    """Save a 2d contour plot overlaying several result files

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in. The figure is written to
        its ``publication`` subdirectory
    plot_parameters: list
        the parameters shown in the contour plot
    samples: list
        for each result file, a list holding the samples of each parameter
        in ``plot_parameters``
    labels: list
        list of labels used to distinguish each result file
    latex_labels: dict
        dictionary containing the latex labels for each parameter
    colors: list
        colors passed to the contour plotting routine
    linestyles: list
        linestyles passed to the contour plotting routine
    gridsize: int
        the number of points to use when estimating the KDE
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, it is not regenerated
    """
    basename = "2d_contour_plot_{}.png".format("_and_".join(plot_parameters))
    filename = os.path.join(savedir, "publication", basename)
    already_made = os.path.isfile(filename) and checkpoint
    if already_made:
        return
    figure = publication.twod_contour_plots(
        plot_parameters, samples, labels, latex_labels, colors=colors,
        linestyles=linestyles, gridsize=gridsize
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def violin_plot(self, label):
    """Generate violin plots to compare certain parameters in all result
    files

    A violin plot is attempted for each of ``mass_ratio``, ``chi_eff``,
    ``chi_p`` and ``luminosity_distance``. A parameter that is not present
    in the samples of every result file is skipped with a warning.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. The value is
        not used: all labels in ``self.labels`` are compared
    """
    error_message = (
        "Failed to generate a violin plot for %s because {}"
    )
    violin_plots = ["mass_ratio", "chi_eff", "chi_p", "luminosity_distance"]

    for plot in violin_plots:
        if not all(plot in self.samples[j].keys() for j in self.labels):
            logger.warning(
                "Failed to generate violin plots for {} because {} is not "
                "common in all result files".format(plot, plot)
            )
            # without this skip the sample lookup below raises a KeyError
            # outside of _try_to_make_a_plot, aborting all remaining plots
            continue
        # injected value of ``plot`` for each result file, in label order
        injection = [self.injection_data[i][plot] for i in self.labels]
        samples = [self.samples[i][plot] for i in self.labels]
        arguments = [
            self.savedir, plot, samples, self.labels, latex_labels[plot],
            injection, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._violin_plot, error_message % (plot)
        )
@staticmethod
def _violin_plot(
    savedir, plot_parameter, samples, labels, latex_label, inj_values=None,
    preliminary=False, checkpoint=False, kde=ReflectionBoundedKDE,
    default_bounds=True
):
    """Generate a violin plot for a given set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in. The figure is written to
        its ``publication`` subdirectory
    plot_parameter: str
        name of the parameter you wish to generate a violin plot for
    samples: list
        list of samples of ``plot_parameter``, one entry per result file
    labels: list
        list of labels used to distinguish each result file
    latex_label: str
        latex label corresponding to ``plot_parameter``. It takes precedence
        over the module-level latex label for that parameter
    inj_values: list, optional
        list of injected values, one per result file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, it is not regenerated
    kde: class, optional
        KDE passed to ``publication.violin_plots``. Default
        ReflectionBoundedKDE
    default_bounds: Bool, optional
        if True, the KDE bounds are taken from ``gw._return_bounds``;
        otherwise both bounds are None
    """
    filename = os.path.join(
        savedir, "publication", "violin_plot_{}.png".format(plot_parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    xlow, xhigh = None, None
    if default_bounds:
        xlow, xhigh = gw._return_bounds(
            plot_parameter, samples, comparison=True
        )
    # publication.violin_plots is handed a dictionary of latex labels. Start
    # from the module-level dictionary but make sure the ``latex_label``
    # argument supplied by the caller is the one used for plot_parameter
    parameter_labels = dict(latex_labels)
    if latex_label is not None:
        parameter_labels[plot_parameter] = latex_label
    fig = publication.violin_plots(
        plot_parameter, samples, labels, parameter_labels, kde=kde,
        kde_kwargs={"xlow": xlow, "xhigh": xhigh}, inj_values=inj_values
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def spin_dist_plot(self, label):
    """Attempt a spin disk plot for every result file

    Result files whose samples lack any of ``a_1``, ``a_2``, ``cos_tilt_1``
    or ``cos_tilt_2`` are skipped with a warning.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. The value is
        not used: every label in ``self.labels`` is processed in turn
    """
    spin_parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    template = "Failed to generate a spin disk plot for %s because {}"
    for index, result_label in enumerate(self.labels):
        result_samples = self.samples[result_label]
        if any(p not in result_samples.keys() for p in spin_parameters):
            logger.warning(
                "Failed to generate spin disk plots because {} are not "
                "common in all result files".format(
                    " and ".join(spin_parameters)
                )
            )
            continue
        self._try_to_make_a_plot(
            [
                self.savedir, spin_parameters,
                [result_samples[p] for p in spin_parameters],
                result_label, self.colors[index],
                self.preliminary_comparison_pages, self.checkpoint,
            ],
            self._spin_dist_plot,
            template % result_label,
        )
@staticmethod
def _spin_dist_plot(
    savedir, parameters, samples, label, color, preliminary=False,
    checkpoint=False
):
    """Save a spin disk plot for a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in. The figure is written to
        its ``publication`` subdirectory
    parameters: list
        names of the spin parameters being plotted
    samples: list
        samples for each parameter in ``parameters``
    label: str
        the label used to distinguish the result file
    color: str
        color passed to the spin distribution plotting routine
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, it is not regenerated
    """
    filename = os.path.join(
        savedir, "publication", "spin_disk_plot_{}.png".format(label)
    )
    if not (os.path.isfile(filename) and checkpoint):
        figure = publication.spin_distribution_plots(
            parameters, samples, label, color=color
        )
        _PlotGeneration.save(figure, filename, preliminary=preliminary)
def classification_plot(self, label):
    """Plot the source classification probabilities of one result file

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    result_samples = self.samples[label]
    if self.mcmc_samples:
        # for MCMC results the ``combine`` attribute of the stored samples is
        # used (presumably the chains merged into one set -- confirm)
        result_samples = result_samples.combine
    probabilities = self.classification_probs[label]["default"]
    self._classification_plot(
        self.savedir, result_samples, label, probabilities,
        preliminary=self.preliminary_pages[label],
        checkpoint=self.checkpoint,
    )
@staticmethod
@no_latex_plot
def _classification_plot(
    savedir, samples, label, probabilities, preliminary=False,
    checkpoint=False
):
    """Generate plots with pesummary's PAstro and EMBright classifiers for a
    given set of samples

    Two bar charts are written to ``savedir``:
    ``{label}.pesummary.p_astro.png`` showing the BBH, BNS, NSBH and
    Terrestrial probabilities, and ``{label}.pesummary.em_bright.png``
    showing the HasNS, HasRemnant and HasMassGap probabilities. A classifier
    is only constructed when its plot is actually produced, so checkpointed
    plots cost nothing.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    samples: dict
        dictionary of samples for each parameter
    label: str
        the label corresponding to the result file
    probabilities: dict
        dictionary of classification probabilities
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, plots that already exist are not regenerated
    """
    from pesummary.gw.classification import PAstro, EMBright

    # (classifier, filename tag, probabilities shown in that plot)
    plots = [
        (PAstro, "p_astro", ["BBH", "BNS", "NSBH", "Terrestrial"]),
        (EMBright, "em_bright", ["HasNS", "HasRemnant", "HasMassGap"]),
    ]
    for classifier, tag, keys in plots:
        filename = os.path.join(
            savedir, "{}.pesummary.{}.png".format(label, tag)
        )
        if os.path.isfile(filename) and checkpoint:
            # plot already exists: skip building the classifier entirely
            continue
        fig = classifier(samples).plot(
            type="bar", probabilities={
                key: value for key, value in probabilities.items() if
                key in keys
            }
        )
        _PlotGeneration.save(
            fig, filename, preliminary=preliminary
        )
def psd_plot(self, label):
    """Generate a PSD plot for each result file

    Result files without PSD data are skipped; PSD plots are still attempted
    for the remaining result files.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. The value is not used:
        a plot is attempted for every label in ``self.labels``
    """
    error_message = (
        "Failed to generate a PSD plot for %s because {}"
    )
    for label in self.labels:
        psd = self.psd[label]
        ifos = list(psd.keys())
        # no PSD stored for this result file: nothing to plot here, but
        # later result files may still have PSD data
        if ifos == [] or ifos == [None]:
            continue
        # frequency limits are read per result file so that one file's
        # f_low / f_final never leak into the plot of another file
        meta_data = self.file_kwargs[label]["meta_data"]
        fmin = meta_data.get("f_low", None)
        fmax = meta_data.get("f_final", None)
        # each stored PSD is treated as rows of (frequency, strain) pairs
        frequencies = [np.array(psd[ifo]).T[0] for ifo in ifos]
        strains = [np.array(psd[ifo]).T[1] for ifo in ifos]
        arguments = [
            self.savedir, frequencies, strains, fmin, fmax, ifos, label,
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._psd_plot, error_message % (label)
        )
@staticmethod
def _psd_plot(
    savedir, frequencies, strains, fmin, fmax, psd_labels, label, checkpoint=False
):
    """Save a PSD plot for a single result file as
    ``{label}_psd_plot.png`` inside ``savedir``

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        list of psd frequencies for each IFO
    strains: list
        list of psd strains for each IFO
    fmin: float
        frequency to start the psd plotting
    fmax: float
        frequency to end the psd plotting
    psd_labels: list
        list of IFOs used
    label: str
        the label used to distinguish the result file
    checkpoint: Bool, optional
        if True and the plot already exists, it is not regenerated
    """
    filename = os.path.join(savedir, "{}_psd_plot.png".format(label))
    if checkpoint and os.path.isfile(filename):
        return
    _PlotGeneration.save(
        gw._psd_plot(
            frequencies, strains, labels=psd_labels, fmin=fmin, fmax=fmax
        ),
        filename
    )
def calibration_plot(self, label):
    """Generate a calibration plot for each result file

    Result files without calibration data are skipped; calibration plots are
    still attempted for the remaining result files. The calibration data
    are evaluated on a grid from 20 Hz to 1024 Hz in steps of 0.25 Hz.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. The value is not used:
        a plot is attempted for every label in ``self.labels``
    """
    error_message = (
        "Failed to generate calibration plot for %s because {}"
    )
    frequencies = np.arange(20., 1024., 1. / 4)
    for label in self.labels:
        ifos = list(self.calibration[label].keys())
        # no calibration stored for this result file: nothing to plot here,
        # but later result files may still have calibration data
        if ifos == [] or ifos == [None]:
            continue
        calibration_data = [
            self.calibration[label][i] for i in ifos
        ]
        if "calibration" in self.priors.keys():
            prior = [self.priors["calibration"][label][i] for i in ifos]
        else:
            prior = None
        arguments = [
            self.savedir, frequencies, calibration_data, ifos, prior,
            label, self.calibration_definition[label], self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._calibration_plot, error_message % (label)
        )
@staticmethod
def _calibration_plot(
    savedir, frequencies, calibration_data, calibration_labels, prior, label,
    calibration_definition="data", checkpoint=False
):
    """Save a calibration envelope plot for a single result file as
    ``{label}_calibration_plot.png`` inside ``savedir``

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        list of frequencies used to interpolate the calibration data
    calibration_data: list
        list of calibration data for each IFO
    calibration_labels: list
        list of IFOs used
    prior: list
        list containing the priors used for each IFO (may be None)
    label: str
        the label used to distinguish the result file
    calibration_definition: str
        the definition of the calibration prior used (either 'data' or 'template')
    checkpoint: Bool, optional
        if True and the plot already exists, it is not regenerated
    """
    filename = os.path.join(savedir, "{}_calibration_plot.png".format(label))
    if not (os.path.isfile(filename) and checkpoint):
        envelope = gw._calibration_envelope_plot(
            frequencies, calibration_data, calibration_labels, prior=prior,
            definition=calibration_definition
        )
        _PlotGeneration.save(envelope, filename)
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Write interactive corner plots for a given set of samples

    Two html files are written to the ``corner`` subdirectory of
    ``savedir``: ``{label}_interactive_source.html`` (source-frame masses,
    luminosity distance and redshift, 900x900) and
    ``{label}_interactive_extrinsic.html`` (luminosity distance, psi, ra
    and dec). Only parameters present in ``samples`` are included, in the
    order they appear in ``samples``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary of samples keyed by parameter name
    latex_labels: dict
        latex label for each parameter in samples
    checkpoint: Bool, optional
        if True, html files that already exist are not regenerated
    """
    # (filename template, parameters to include, extra corner kwargs)
    outputs = [
        (
            "{}_interactive_source.html",
            [
                "luminosity_distance", "mass_1_source", "mass_2_source",
                "total_mass_source", "chirp_mass_source", "redshift"
            ],
            {"dimensions": {"width": 900, "height": 900}},
        ),
        (
            "{}_interactive_extrinsic.html",
            ["luminosity_distance", "psi", "ra", "dec"],
            {},
        ),
    ]
    for template, wanted, extra_kwargs in outputs:
        filename = os.path.join(savedir, "corner", template.format(label))
        if os.path.isfile(filename) and checkpoint:
            continue
        chosen = [name for name in samples.keys() if name in wanted]
        interactive.corner(
            [samples[name] for name in chosen],
            [latex_labels[name] for name in chosen],
            write_to_html_file=filename, **extra_kwargs
        )