Coverage for pesummary/gw/plots/main.py: 69.4%
448 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import os
7from pesummary.core.plots.main import _PlotGeneration as _BasePlotGeneration
8from pesummary.core.plots.latex_labels import latex_labels
9from pesummary.core.plots import interactive
10from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE
11from pesummary.gw.plots.latex_labels import GWlatex_labels
12from pesummary.utils.utils import logger, resample_posterior_distribution
13from pesummary.utils.decorators import no_latex_plot
14from pesummary.gw.plots import publication
15from pesummary.gw.plots import plot as gw
17import multiprocessing as mp
18import numpy as np
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Merge the GW specific LaTeX labels into the ``latex_labels`` dictionary
# imported from pesummary.core.plots.latex_labels. ``update`` modifies that
# imported object in place at import time so that lookups such as
# ``latex_labels[p]`` in this module can resolve GW parameter names.
latex_labels.update(GWlatex_labels)
24class _PlotGeneration(_BasePlotGeneration):
25 def __init__(
26 self, savedir=None, webdir=None, labels=None, samples=None,
27 kde_plot=False, existing_labels=None, existing_injection_data=None,
28 existing_file_kwargs=None, existing_samples=None,
29 existing_metafile=None, same_parameters=None, injection_data=None,
30 result_files=None, file_kwargs=None, colors=None, custom_plotting=None,
31 add_to_existing=False, priors={}, no_ligo_skymap=False,
32 nsamples_for_skymap=None, detectors=None, maxL_samples=None,
33 gwdata=None, calibration_definition=None, calibration=None, psd=None,
34 multi_threading_for_skymap=None, approximant=None,
35 pepredicates_probs=None, include_prior=False, publication=False,
36 existing_approximant=None, existing_psd=None, existing_calibration=None,
37 existing_weights=None, weights=None, disable_comparison=False,
38 linestyles=None, disable_interactive=False, disable_corner=False,
39 publication_kwargs={}, multi_process=1, mcmc_samples=False,
40 skymap=None, existing_skymap=None, corner_params=None,
41 preliminary_pages=False, expert_plots=True, checkpoint=False,
42 key_data=None
43 ):
44 super(_PlotGeneration, self).__init__(
45 savedir=savedir, webdir=webdir, labels=labels,
46 samples=samples, kde_plot=kde_plot, existing_labels=existing_labels,
47 existing_injection_data=existing_injection_data,
48 existing_samples=existing_samples,
49 existing_weights=existing_weights,
50 same_parameters=same_parameters,
51 injection_data=injection_data, mcmc_samples=mcmc_samples,
52 colors=colors, custom_plotting=custom_plotting,
53 add_to_existing=add_to_existing, priors=priors,
54 include_prior=include_prior, weights=weights,
55 disable_comparison=disable_comparison, linestyles=linestyles,
56 disable_interactive=disable_interactive, disable_corner=disable_corner,
57 multi_process=multi_process, corner_params=corner_params,
58 expert_plots=expert_plots, checkpoint=checkpoint, key_data=key_data
59 )
60 self.preliminary_pages = preliminary_pages
61 if not isinstance(self.preliminary_pages, dict):
62 if self.preliminary_pages:
63 self.preliminary_pages = {
64 label: True for label in self.labels
65 }
66 else:
67 self.preliminary_pages = {
68 label: False for label in self.labels
69 }
70 self.preliminary_comparison_pages = any(
71 value for value in self.preliminary_pages.values()
72 )
73 self.package = "gw"
74 self.file_kwargs = file_kwargs
75 self.existing_file_kwargs = existing_file_kwargs
76 self.no_ligo_skymap = no_ligo_skymap
77 self.nsamples_for_skymap = nsamples_for_skymap
78 self.detectors = detectors
79 self.maxL_samples = maxL_samples
80 self.gwdata = gwdata
81 if skymap is None:
82 skymap = {label: None for label in self.labels}
83 self.skymap = skymap
84 self.existing_skymap = skymap
85 self.calibration_definition = calibration_definition
86 self.calibration = calibration
87 self.existing_calibration = existing_calibration
88 self.psd = psd
89 self.existing_psd = existing_psd
90 self.multi_threading_for_skymap = multi_threading_for_skymap
91 self.approximant = approximant
92 self.existing_approximant = existing_approximant
93 self.pepredicates_probs = pepredicates_probs
94 self.publication = publication
95 self.publication_kwargs = publication_kwargs
96 self._ligo_skymap_PID = {}
98 self.plot_type_dictionary.update({
99 "psd": self.psd_plot,
100 "calibration": self.calibration_plot,
101 "twod_histogram": self.twod_histogram_plot,
102 "skymap": self.skymap_plot,
103 "waveform_fd": self.waveform_fd_plot,
104 "waveform_td": self.waveform_td_plot,
105 "data": self.gwdata_plots,
106 "violin": self.violin_plot,
107 "spin_disk": self.spin_dist_plot,
108 "pepredicates": self.pepredicates_plot
109 })
110 if self.make_comparison:
111 self.plot_type_dictionary.update({
112 "skymap_comparison": self.skymap_comparison_plot,
113 "waveform_comparison_fd": self.waveform_comparison_fd_plot,
114 "waveform_comparison_td": self.waveform_comparison_td_plot,
115 "2d_comparison_contour": self.twod_comparison_contour_plot,
116 })
118 @property
119 def ligo_skymap_PID(self):
120 return self._ligo_skymap_PID
122 def generate_plots(self):
123 """Generate all plots for all result files
124 """
125 if self.calibration or "calibration" in list(self.priors.keys()):
126 self.try_to_make_a_plot("calibration")
127 if self.psd:
128 self.try_to_make_a_plot("psd")
129 super(_PlotGeneration, self).generate_plots()
131 def _generate_plots(self, label):
132 """Generate all plots for a given result file
133 """
134 super(_PlotGeneration, self)._generate_plots(label)
135 self.try_to_make_a_plot("twod_histogram", label=label)
136 self.try_to_make_a_plot("skymap", label=label)
137 self.try_to_make_a_plot("waveform_td", label=label)
138 self.try_to_make_a_plot("waveform_fd", label=label)
139 if self.pepredicates_probs[label] is not None:
140 self.try_to_make_a_plot("pepredicates", label=label)
141 if self.gwdata:
142 self.try_to_make_a_plot("data", label=label)
144 def _generate_comparison_plots(self):
145 """Generate all comparison plots
146 """
147 super(_PlotGeneration, self)._generate_comparison_plots()
148 self.try_to_make_a_plot("skymap_comparison")
149 self.try_to_make_a_plot("waveform_comparison_td")
150 self.try_to_make_a_plot("waveform_comparison_fd")
151 if self.publication:
152 self.try_to_make_a_plot("2d_comparison_contour")
153 self.try_to_make_a_plot("violin")
154 self.try_to_make_a_plot("spin_disk")
156 @staticmethod
157 def _corner_plot(
158 savedir, label, samples, latex_labels, webdir, params, preliminary=False,
159 checkpoint=False
160 ):
161 """Generate a corner plot for a given set of samples
163 Parameters
164 ----------
165 savedir: str
166 the directory you wish to save the plot in
167 label: str
168 the label corresponding to the results file
169 samples: dict
170 dictionary of samples for a given result file
171 latex_labels: dict
172 dictionary of latex labels
173 webdir: str
174 directory where the javascript is written
175 preliminary: Bool, optional
176 if True, add a preliminary watermark to the plot
177 """
178 import warnings
180 with warnings.catch_warnings():
181 warnings.simplefilter("ignore")
182 filename = os.path.join(
183 savedir, "corner", "{}_all_density_plots.png".format(label)
184 )
185 if os.path.isfile(filename) and checkpoint:
186 pass
187 else:
188 fig, params, data = gw._make_corner_plot(
189 samples, latex_labels, corner_parameters=params
190 )
191 fig.savefig(filename)
192 fig.close()
193 combine_corner = open(
194 os.path.join(webdir, "js", "combine_corner.js")
195 )
196 combine_corner = combine_corner.readlines()
197 params = [str(i) for i in params]
198 ind = [
199 linenumber for linenumber, line in enumerate(combine_corner)
200 if "var list = {}" in line
201 ][0]
202 combine_corner.insert(
203 ind + 1, " list['{}'] = {};\n".format(label, params)
204 )
205 new_file = open(
206 os.path.join(webdir, "js", "combine_corner.js"), "w"
207 )
208 new_file.writelines(combine_corner)
209 new_file.close()
210 combine_corner = open(
211 os.path.join(webdir, "js", "combine_corner.js")
212 )
213 combine_corner = combine_corner.readlines()
214 params = [str(i) for i in params]
215 ind = [
216 linenumber for linenumber, line in enumerate(combine_corner)
217 if "var data = {}" in line
218 ][0]
219 combine_corner.insert(
220 ind + 1, " data['{}'] = {};\n".format(label, data)
221 )
222 new_file = open(
223 os.path.join(webdir, "js", "combine_corner.js"), "w"
224 )
225 new_file.writelines(combine_corner)
226 new_file.close()
228 filename = os.path.join(
229 savedir, "corner", "{}_sourceframe.png".format(label)
230 )
231 if os.path.isfile(filename) and checkpoint:
232 pass
233 else:
234 fig = gw._make_source_corner_plot(samples, latex_labels)
235 fig.savefig(filename)
236 fig.close()
237 filename = os.path.join(
238 savedir, "corner", "{}_extrinsic.png".format(label)
239 )
240 if os.path.isfile(filename) and checkpoint:
241 pass
242 else:
243 fig = gw._make_extrinsic_corner_plot(samples, latex_labels)
244 fig.savefig(filename)
245 fig.close()
247 def twod_histogram_plot(self, label):
248 """
249 """
250 from pesummary import conf
251 error_message = (
252 "Failed to generate %s-%s triangle plot because {}"
253 )
254 paramset = [
255 params for params in conf.gw_2d_plots if
256 all(p in self.samples[label] for p in params)
257 ]
258 if self.weights is not None:
259 weights = self.weights.get(label, None)
260 else:
261 weights = None
262 arguments = [
263 (
264 [
265 self.savedir, label, params,
266 [self.samples[label][p] for p in params],
267 [latex_labels[p] for p in params],
268 [self.injection_data[label][p] for p in params],
269 weights, self.preliminary_pages[label], self.checkpoint
270 ], self._triangle_plot, error_message % (params[0], params[1])
271 ) for params in paramset
272 ]
273 self.pool.starmap(self._try_to_make_a_plot, arguments)
275 @staticmethod
276 def _triangle_plot(
277 savedir, label, params, samples, latex_labels, injection, weights,
278 preliminary=False, checkpoint=False
279 ):
280 from pesummary.core.plots.publication import triangle_plot
281 import math
282 for num, ii in enumerate(injection):
283 if math.isnan(ii):
284 injection[num] = None
286 if any(ii is None for ii in injection):
287 truth = None
288 else:
289 truth = injection
290 filename = os.path.join(
291 savedir, "{}_2d_posterior_{}_{}.png".format(
292 label, params[0], params[1]
293 )
294 )
295 if os.path.isfile(filename) and checkpoint:
296 return
297 fig, _, _, _ = triangle_plot(
298 *samples, kde=False, parameters=params, xlabel=latex_labels[0],
299 ylabel=latex_labels[1], plot_datapoints=True, plot_density=False,
300 levels=[1e-8], fill=False, grid=True, linewidths=[1.75],
301 percentiles=[5, 95], percentile_plot=[label], labels=[label],
302 truth=truth, weights=weights, data_kwargs={"alpha": 0.3}
303 )
304 _PlotGeneration.save(
305 fig, filename, preliminary=preliminary
306 )
309 def skymap_plot(self, label):
310 """Generate a skymap plot for a given result file
312 Parameters
313 ----------
314 label: str
315 the label for the results file that you wish to plot
316 """
317 try:
318 import ligo.skymap # noqa: F401
319 except ImportError:
320 SKYMAP = False
321 else:
322 SKYMAP = True
324 if self.mcmc_samples:
325 samples = self.samples[label].combine
326 else:
327 samples = self.samples[label]
328 _injection = [
329 self.injection_data[label]["ra"], self.injection_data[label]["dec"]
330 ]
331 self._skymap_plot(
332 self.savedir, samples["ra"], samples["dec"], label,
333 self.weights[label], _injection,
334 preliminary=self.preliminary_pages[label]
335 )
337 if SKYMAP and not self.no_ligo_skymap and self.skymap[label] is None:
338 from pesummary.utils.utils import RedirectLogger
340 logger.info("Launching subprocess to generate skymap plot with "
341 "ligo.skymap")
342 try:
343 _time = samples["geocent_time"]
344 except KeyError:
345 logger.warning(
346 "Unable to find 'geocent_time' in the posterior table for {}. "
347 "The ligo.skymap fits file will therefore not store the "
348 "DATE_OBS field in the header".format(label)
349 )
350 _time = None
351 with RedirectLogger("ligo.skymap", level="DEBUG") as redirector:
352 process = mp.Process(
353 target=self._ligo_skymap_plot,
354 args=[
355 self.savedir, samples["ra"], samples["dec"],
356 samples["luminosity_distance"], _time,
357 label, self.nsamples_for_skymap, self.webdir,
358 self.multi_threading_for_skymap, _injection,
359 self.preliminary_pages[label]
360 ]
361 )
362 process.start()
363 PID = process.pid
364 self._ligo_skymap_PID[label] = PID
365 elif SKYMAP and not self.no_ligo_skymap:
366 self._ligo_skymap_array_plot(
367 self.savedir, self.skymap[label], label,
368 self.preliminary_pages[label]
369 )
371 @staticmethod
372 @no_latex_plot
373 def _skymap_plot(
374 savedir, ra, dec, label, weights, injection=None, preliminary=False
375 ):
376 """Generate a skymap plot for a given set of samples
378 Parameters
379 ----------
380 savedir: str
381 the directory you wish to save the plot in
382 ra: pesummary.utils.utils.Array
383 array containing the samples for right ascension
384 dec: pesummary.utils.utils.Array
385 array containing the samples for declination
386 label: str
387 the label corresponding to the results file
388 weights: list
389 list of weights for the samples
390 injection: list, optional
391 list containing the injected value of ra and dec
392 preliminary: Bool, optional
393 if True, add a preliminary watermark to the plot
394 """
395 import math
397 if injection is not None and any(math.isnan(inj) for inj in injection):
398 injection = None
399 fig = gw._default_skymap_plot(ra, dec, weights, injection=injection)
400 _PlotGeneration.save(
401 fig, os.path.join(savedir, "{}_skymap".format(label)),
402 preliminary=preliminary
403 )
405 @staticmethod
406 @no_latex_plot
407 def _ligo_skymap_plot(savedir, ra, dec, dist, time, label, nsamples_for_skymap,
408 webdir, multi_threading_for_skymap, injection,
409 preliminary=False):
410 """Generate a skymap plot for a given set of samples using the
411 ligo.skymap package
413 Parameters
414 ----------
415 savedir: str
416 the directory you wish to save the plot in
417 ra: pesummary.utils.utils.Array
418 array containing the samples for right ascension
419 dec: pesummary.utils.utils.Array
420 array containing the samples for declination
421 dist: pesummary.utils.utils.Array
422 array containing the samples for luminosity distance
423 time: pesummary.utils.utils.Array
424 array containing the samples for the geocentric time of merger
425 label: str
426 the label corresponding to the results file
427 nsamples_for_skymap: int
428 the number of samples used to generate skymap
429 webdir: str
430 the directory to store the fits file
431 preliminary: Bool, optional
432 if True, add a preliminary watermark to the plot
433 """
434 import math
436 downsampled = False
437 if nsamples_for_skymap is not None:
438 ra, dec, dist = resample_posterior_distribution(
439 [ra, dec, dist], nsamples_for_skymap
440 )
441 downsampled = True
442 if injection is not None and any(math.isnan(inj) for inj in injection):
443 injection = None
444 fig = gw._ligo_skymap_plot(
445 ra, dec, dist=dist, savedir=os.path.join(webdir, "samples"),
446 nprocess=multi_threading_for_skymap, downsampled=downsampled,
447 label=label, time=time, injection=injection
448 )
449 _PlotGeneration.save(
450 fig, os.path.join(savedir, "{}_skymap".format(label)),
451 preliminary=preliminary
452 )
454 @staticmethod
455 @no_latex_plot
456 def _ligo_skymap_array_plot(savedir, skymap, label, preliminary=False):
457 """Generate a skymap based on skymap probability array already generated with
458 `ligo.skymap`
460 Parameters
461 ----------
462 savedir: str
463 the directory you wish to save the plot in
464 skymap: np.ndarray
465 array of skymap probabilities
466 label: str
467 the label corresponding to the results file
468 preliminary: Bool, optional
469 if True, add a preliminary watermark to the plot
470 """
471 fig = gw._ligo_skymap_plot_from_array(skymap)
472 _PlotGeneration.save(
473 fig, os.path.join(savedir, "{}_skymap".format(label)),
474 preliminary=preliminary
475 )
477 def waveform_fd_plot(self, label):
478 """Generate a frequency domain waveform plot for a given result file
480 Parameters
481 ----------
482 label: str
483 the label corresponding to the results file
484 """
485 if self.approximant[label] == {}:
486 return
487 self._waveform_fd_plot(
488 self.savedir, self.detectors[label], self.maxL_samples[label], label,
489 preliminary=self.preliminary_pages[label], checkpoint=self.checkpoint,
490 **self.file_kwargs[label]["meta_data"]
491 )
493 @staticmethod
494 def _waveform_fd_plot(
495 savedir, detectors, maxL_samples, label, preliminary=False,
496 checkpoint=False, **kwargs
497 ):
498 """Generate a frequency domain waveform plot for a given detector
499 network and set of samples
501 Parameters
502 ----------
503 savedir: str
504 the directory you wish to save the plot in
505 detectors: list
506 list of detectors used in your analysis
507 maxL_samples: dict
508 dictionary of maximum likelihood values
509 label: str
510 the label corresponding to the results file
511 preliminary: Bool, optional
512 if True, add a preliminary watermark to the plot
513 """
514 filename = os.path.join(savedir, "{}_waveform.png".format(label))
515 if os.path.isfile(filename) and checkpoint:
516 return
517 if detectors is None:
518 detectors = ["H1", "L1"]
519 else:
520 detectors = detectors.split("_")
522 fig = gw._waveform_plot(
523 detectors, maxL_samples, f_start=kwargs.get("f_start", 20.),
524 f_low=kwargs.get("f_low", 20.0),
525 f_max=kwargs.get("f_final", 1024.),
526 f_ref=kwargs.get("f_ref", 20.),
527 approximant_flags=kwargs.get("approximant_flags", {})
528 )
529 _PlotGeneration.save(
530 fig, filename, preliminary=preliminary
531 )
533 def waveform_td_plot(self, label):
534 """Generate a time domain waveform plot for a given result file
536 Parameters
537 ----------
538 label: str
539 the label corresponding to the results file
540 """
541 if self.approximant[label] == {}:
542 return
543 self._waveform_td_plot(
544 self.savedir, self.detectors[label], self.maxL_samples[label], label,
545 preliminary=self.preliminary_pages[label], checkpoint=self.checkpoint,
546 **self.file_kwargs[label]["meta_data"]
547 )
549 @staticmethod
550 def _waveform_td_plot(
551 savedir, detectors, maxL_samples, label, preliminary=False,
552 checkpoint=False, **kwargs
553 ):
554 """Generate a time domain waveform plot for a given detector network
555 and set of samples
557 Parameters
558 ----------
559 savedir: str
560 the directory you wish to save the plot in
561 detectors: list
562 list of detectors used in your analysis
563 maxL_samples: dict
564 dictionary of maximum likelihood values
565 label: str
566 the label corresponding to the results file
567 preliminary: Bool, optional
568 if True, add a preliminary watermark to the plot
569 """
570 filename = os.path.join(
571 savedir, "{}_waveform_time_domain.png".format(label)
572 )
573 if os.path.isfile(filename) and checkpoint:
574 return
575 if detectors is None:
576 detectors = ["H1", "L1"]
577 else:
578 detectors = detectors.split("_")
580 fig = gw._time_domain_waveform(
581 detectors, maxL_samples, f_start=kwargs.get("f_start", 20.),
582 f_low=kwargs.get("f_low", 20.0),
583 f_max=kwargs.get("f_final", 1024.),
584 f_ref=kwargs.get("f_ref", 20.),
585 approximant_flags=kwargs.get("approximant_flags", {})
586 )
587 _PlotGeneration.save(
588 fig, filename, preliminary=preliminary
589 )
591 def gwdata_plots(self, label):
592 """Generate all plots associated with the gwdata
594 Parameters
595 ----------
596 label: str
597 the label corresponding to the results file
598 """
599 from pesummary.utils.utils import determine_gps_time_and_window
601 base_error = "Failed to generate a %s because {}"
602 gps_time, window = determine_gps_time_and_window(
603 self.maxL_samples, self.labels
604 )
605 functions = [
606 self.strain_plot, self.spectrogram_plot, self.omegascan_plot
607 ]
608 args = [[label], [], [gps_time, window]]
609 func_names = ["strain_plot", "spectrogram plot", "omegascan plot"]
611 for func, args, name in zip(functions, args, func_names):
612 self._try_to_make_a_plot(args, func, base_error % (name))
613 continue
615 def strain_plot(self, label):
616 """Generate a plot showing the comparison between the data and the
617 maxL waveform gfor a given result file
619 Parameters
620 ----------
621 label: str
622 the label corresponding to the results file
623 """
624 logger.info("Launching subprocess to generate strain plot")
625 process = mp.Process(
626 target=self._strain_plot,
627 args=[self.savedir, self.gwdata, self.maxL_samples[label], label]
628 )
629 process.start()
631 @staticmethod
632 def _strain_plot(savedir, gwdata, maxL_samples, label, checkpoint=False):
633 """Generate a strain plot for a given set of samples
635 Parameters
636 ----------
637 savedir: str
638 the directory to save the plot
639 gwdata: dict
640 dictionary of strain data for each detector
641 maxL_samples: dict
642 dictionary of maximum likelihood values
643 label: str
644 the label corresponding to the results file
645 """
646 filename = os.path.join(savedir, "{}_strain.png".format(label))
647 if os.path.isfile(filename) and checkpoint:
648 return
649 fig = gw._strain_plot(gwdata, maxL_samples)
650 _PlotGeneration.save(fig, filename)
652 def spectrogram_plot(self):
653 """Generate a plot showing the spectrogram for all detectors
654 """
655 figs = self._spectrogram_plot(self.savedir, self.gwdata)
657 @staticmethod
658 def _spectrogram_plot(savedir, strain):
659 """Generate a plot showing the spectrogram for all detectors
661 Parameters
662 ----------
663 savedir: str
664 the directory you wish to save the plot in
665 strain: dict
666 dictionary of gwpy timeseries objects containing the strain data for
667 each IFO
668 """
669 from pesummary.gw.plots import detchar
671 figs = detchar.spectrogram(strain)
672 for det, fig in figs.items():
673 _PlotGeneration.save(
674 fig, os.path.join(savedir, "spectrogram_{}".format(det))
675 )
677 def omegascan_plot(self, gps_time, window):
678 """Generate a plot showing the omegascan for all detectors
680 Parameters
681 ----------
682 gps_time: float
683 time around which to centre the omegascan
684 window: float
685 window around gps time to generate plot for
686 """
687 figs = self._omegascan_plot(
688 self.savedir, self.gwdata, gps_time, window
689 )
691 @staticmethod
692 def _omegascan_plot(savedir, strain, gps, window):
693 """Generate a plot showing the spectrogram for all detectors
695 Parameters
696 ----------
697 savedir: str
698 the directory you wish to save the plot in
699 strain: dict
700 dictionary of gwpy timeseries objects containing the strain data for
701 each IFO
702 gps: float
703 time around which to centre the omegascan
704 window: float
705 window around gps time to generate plot for
706 """
707 from pesummary.gw.plots import detchar
709 figs = detchar.omegascan(strain, gps, window=window)
710 for det, fig in figs.items():
711 _PlotGeneration.save(
712 fig, os.path.join(savedir, "omegascan_{}".format(det))
713 )
715 def skymap_comparison_plot(self, label):
716 """Generate a plot to compare skymaps for all result files
718 Parameters
719 ----------
720 label: str
721 the label for the results file that you wish to plot
722 """
723 self._skymap_comparison_plot(
724 self.savedir, self.same_samples["ra"], self.same_samples["dec"],
725 self.labels, self.colors, self.preliminary_comparison_pages,
726 self.checkpoint
727 )
729 @staticmethod
730 def _skymap_comparison_plot(
731 savedir, ra, dec, labels, colors, preliminary=False, checkpoint=False
732 ):
733 """Generate a plot to compare skymaps for a given set of samples
735 Parameters
736 ----------
737 savedir: str
738 the directory you wish to save the plot in
739 ra: dict
740 dictionary of right ascension samples for each result file
741 dec: dict
742 dictionary of declination samples for each result file
743 labels: list
744 list of labels to distinguish each result file
745 colors: list
746 list of colors to be used to distinguish different result files
747 preliminary: Bool, optional
748 if True, add a preliminary watermark to the plot
749 """
750 filename = os.path.join(savedir, "combined_skymap.png")
751 if os.path.isfile(filename) and checkpoint:
752 return
753 ra_list = [ra[key] for key in labels]
754 dec_list = [dec[key] for key in labels]
755 fig = gw._sky_map_comparison_plot(ra_list, dec_list, labels, colors)
756 _PlotGeneration.save(
757 fig, filename, preliminary=preliminary
758 )
760 def waveform_comparison_fd_plot(self, label):
761 """Generate a plot to compare the frequency domain waveform
763 Parameters
764 ----------
765 label: str
766 the label for the results file that you wish to plot
767 """
768 if any(self.approximant[i] == {} for i in self.labels):
769 return
771 self._waveform_comparison_fd_plot(
772 self.savedir, self.maxL_samples, self.labels, self.colors,
773 preliminary=self.preliminary_comparison_pages, checkpoint=self.checkpoint,
774 **self.file_kwargs
775 )
777 @staticmethod
778 def _waveform_comparison_fd_plot(
779 savedir, maxL_samples, labels, colors, preliminary=False,
780 checkpoint=False, **kwargs
781 ):
782 """Generate a plot to compare the frequency domain waveforms
784 Parameters
785 ----------
786 savedir: str
787 the directory you wish to save the plot in
788 maxL_samples: dict
789 dictionary of maximum likelihood samples for each result file
790 labels: list
791 list of labels to distinguish each result file
792 colors: list
793 list of colors to be used to distinguish different result files
794 preliminary: Bool, optional
795 if True, add a preliminary watermark to the plot
796 """
797 filename = os.path.join(savedir, "compare_waveforms.png")
798 if os.path.isfile(filename) and checkpoint:
799 return
800 samples = [maxL_samples[i] for i in labels]
801 for num, i in enumerate(labels):
802 samples[num]["approximant_flags"] = kwargs[i]["meta_data"].get(
803 "approximant_flags", {}
804 )
805 _defaults = [20., 20., 1024., 20.]
806 for freq, default in zip(["f_start", "f_low", "f_final", "f_ref"], _defaults):
807 samples[num][freq] = kwargs[i]["meta_data"].get(freq, default)
809 fig = gw._waveform_comparison_plot(samples, colors, labels)
810 _PlotGeneration.save(
811 fig, filename, preliminary=preliminary
812 )
814 def waveform_comparison_td_plot(self, label):
815 """Generate a plot to compare the time domain waveform
817 Parameters
818 ----------
819 label: str
820 the label for the results file that you wish to plot
821 """
822 if any(self.approximant[i] == {} for i in self.labels):
823 return
825 self._waveform_comparison_td_plot(
826 self.savedir, self.maxL_samples, self.labels, self.colors,
827 self.preliminary_comparison_pages, self.checkpoint,
828 **self.file_kwargs
829 )
831 @staticmethod
832 def _waveform_comparison_td_plot(
833 savedir, maxL_samples, labels, colors, preliminary=False,
834 checkpoint=False, **kwargs
835 ):
836 """Generate a plot to compare the time domain waveforms
838 Parameters
839 ----------
840 savedir: str
841 the directory you wish to save the plot in
842 maxL_samples: dict
843 dictionary of maximum likelihood samples for each result file
844 labels: list
845 list of labels to distinguish each result file
846 colors: list
847 list of colors to be used to distinguish different result files
848 preliminary: Bool, optional
849 if True, add a preliminary watermark to the plot
850 """
851 filename = os.path.join(savedir, "compare_time_domain_waveforms.png")
852 if os.path.isfile(filename) and checkpoint:
853 return
854 samples = [maxL_samples[i] for i in labels]
855 for num, i in enumerate(labels):
856 samples[num]["approximant_flags"] = kwargs[i]["meta_data"].get(
857 "approximant_flags", {}
858 )
859 _defaults = [20., 20., 20.]
860 for freq, default in zip(["f_start", "f_low", "f_ref"], _defaults):
861 samples[num][freq] = kwargs[i]["meta_data"].get(freq, default)
863 fig = gw._time_domain_waveform_comparison_plot(samples, colors, labels)
864 _PlotGeneration.save(
865 fig, filename, preliminary=preliminary
866 )
868 def twod_comparison_contour_plot(self, label):
869 """Generate 2d comparison contour plots
871 Parameters
872 ----------
873 label: str
874 the label for the results file that you wish to plot
875 """
876 error_message = (
877 "Failed to generate a 2d contour plot for %s because {}"
878 )
879 twod_plots = [
880 ["mass_ratio", "chi_eff"], ["mass_1", "mass_2"],
881 ["luminosity_distance", "chirp_mass_source"],
882 ["mass_1_source", "mass_2_source"],
883 ["theta_jn", "luminosity_distance"],
884 ["network_optimal_snr", "chirp_mass_source"]
885 ]
886 gridsize = (
887 int(self.publication_kwargs["gridsize"]) if "gridsize" in
888 self.publication_kwargs.keys() else 100
889 )
890 for plot in twod_plots:
891 if not all(
892 all(
893 i in self.samples[j].keys() for i in plot
894 ) for j in self.labels
895 ):
896 logger.warning(
897 "Failed to generate 2d contour plots for {} because {} are not "
898 "common in all result files".format(
899 " and ".join(plot), " and ".join(plot)
900 )
901 )
902 continue
903 samples = [[self.samples[i][j] for j in plot] for i in self.labels]
904 arguments = [
905 self.savedir, plot, samples, self.labels, latex_labels,
906 self.colors, self.linestyles, gridsize,
907 self.preliminary_comparison_pages, self.checkpoint
908 ]
909 self._try_to_make_a_plot(
910 arguments, self._twod_comparison_contour_plot,
911 error_message % (" and ".join(plot))
912 )
914 @staticmethod
915 def _twod_comparison_contour_plot(
916 savedir, plot_parameters, samples, labels, latex_labels, colors,
917 linestyles, gridsize, preliminary=False, checkpoint=False
918 ):
919 """Generate a 2d comparison contour plot for a given set of samples
921 Parameters
922 ----------
923 savedir: str
924 the directory you wish to save the plot in
925 plot_parameters: list
926 list of parameters to use for the 2d contour plot
927 samples: list
928 list of samples for each parameter
929 labels: list
930 list of labels used to distinguish each result file
931 latex_labels: dict
932 dictionary containing the latex labels for each parameter
933 gridsize: int
934 the number of points to use when estimating the KDE
935 preliminary: Bool, optional
936 if True, add a preliminary watermark to the plot
937 """
938 filename = os.path.join(
939 savedir, "publication", "2d_contour_plot_{}.png".format(
940 "_and_".join(plot_parameters)
941 )
942 )
943 if os.path.isfile(filename) and checkpoint:
944 return
945 fig = publication.twod_contour_plots(
946 plot_parameters, samples, labels, latex_labels, colors=colors,
947 linestyles=linestyles, gridsize=gridsize
948 )
949 _PlotGeneration.save(
950 fig, filename, preliminary=preliminary
951 )
953 def violin_plot(self, label):
954 """Generate violin plot to compare certain parameters in all result
955 files
957 Parameters
958 ----------
959 label: str
960 the label for the results file that you wish to plot
961 """
962 error_message = (
963 "Failed to generate a violin plot for %s because {}"
964 )
965 violin_plots = ["mass_ratio", "chi_eff", "chi_p", "luminosity_distance"]
967 for plot in violin_plots:
968 injection = [self.injection_data[label][plot] for label in self.labels]
969 if not all(plot in self.samples[j].keys() for j in self.labels):
970 logger.warning(
971 "Failed to generate violin plots for {} because {} is not "
972 "common in all result files".format(plot, plot)
973 )
974 samples = [self.samples[i][plot] for i in self.labels]
975 arguments = [
976 self.savedir, plot, samples, self.labels, latex_labels[plot],
977 injection, self.preliminary_comparison_pages, self.checkpoint
978 ]
979 self._try_to_make_a_plot(
980 arguments, self._violin_plot, error_message % (plot)
981 )
@staticmethod
def _violin_plot(
    savedir, plot_parameter, samples, labels, latex_label, inj_values=None,
    preliminary=False, checkpoint=False, kde=ReflectionBoundedKDE,
    default_bounds=True
):
    """Generate a violin plot for a given set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    plot_parameter: str
        name of the parameter you wish to generate a violin plot for
    samples: list
        list of samples for each parameter
    labels: list
        list of labels used to distinguish each result file
    latex_label: str
        latex label corresponding to ``plot_parameter``. If None, the
        module-level latex label for ``plot_parameter`` is used
    inj_values: list, optional
        list of injected values for each sample
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    kde: class, optional
        KDE class passed through to ``publication.violin_plots``
    default_bounds: Bool, optional
        if True, pass the bounds from ``gw._return_bounds`` to the KDE;
        otherwise the KDE receives ``xlow=None`` and ``xhigh=None``
    """
    filename = os.path.join(
        savedir, "publication", "violin_plot_{}.png".format(plot_parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    xlow, xhigh = None, None
    if default_bounds:
        xlow, xhigh = gw._return_bounds(
            plot_parameter, samples, comparison=True
        )
    # publication.violin_plots expects a dictionary of latex labels. Copy
    # the module-level mapping and override this parameter's entry so the
    # caller-supplied ``latex_label`` is honoured instead of ignored.
    _latex_labels = dict(latex_labels)
    if latex_label is not None:
        _latex_labels[plot_parameter] = latex_label
    fig = publication.violin_plots(
        plot_parameter, samples, labels, _latex_labels, kde=kde,
        kde_kwargs={"xlow": xlow, "xhigh": xhigh}, inj_values=inj_values
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def spin_dist_plot(self, label):
    """Attempt a spin disk plot for every result file

    Each result file must contain ``a_1``, ``a_2``, ``cos_tilt_1`` and
    ``cos_tilt_2``; files missing any of them are skipped with a warning.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. Unused: the
        loop below visits every label in ``self.labels``
    """
    error_message = (
        "Failed to generate a spin disk plot for %s because {}"
    )
    parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    for index, result_label in enumerate(self.labels):
        posterior = self.samples[result_label]
        has_all = all(name in posterior.keys() for name in parameters)
        if not has_all:
            joined = " and ".join(parameters)
            logger.warning(
                "Failed to generate spin disk plots because {} are not "
                "common in all result files".format(joined)
            )
            continue
        spin_samples = [posterior[name] for name in parameters]
        self._try_to_make_a_plot(
            [
                self.savedir, parameters, spin_samples, result_label,
                self.colors[index], self.preliminary_comparison_pages,
                self.checkpoint
            ],
            self._spin_dist_plot,
            error_message % (result_label)
        )
@staticmethod
def _spin_dist_plot(
    savedir, parameters, samples, label, color, preliminary=False,
    checkpoint=False
):
    """Produce the spin disk plot for a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in; the figure is written
        to its ``publication`` subdirectory
    parameters: list
        names of the spin parameters being plotted
    samples: list
        samples for each entry of ``parameters``
    label: str
        the label used to distinguish the result file
    color: str
        colour passed to ``publication.spin_distribution_plots``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    target = os.path.join(
        savedir, "publication", "spin_disk_plot_{}.png".format(label)
    )
    already_done = os.path.isfile(target) and checkpoint
    if not already_done:
        figure = publication.spin_distribution_plots(
            parameters, samples, label, color=color
        )
        _PlotGeneration.save(figure, target, preliminary=preliminary)
def pepredicates_plot(self, label):
    """Produce the PEPredicates plots for one result file, first with the
    default prior and then with the population prior

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    samples = (
        self.samples[label].combine if self.mcmc_samples
        else self.samples[label]
    )
    probabilities = self.pepredicates_probs[label]
    for prior_key, use_population in (("default", False), ("population", True)):
        self._pepredicates_plot(
            self.savedir, samples, label, probabilities[prior_key],
            population_prior=use_population,
            preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint
        )
@staticmethod
@no_latex_plot
def _pepredicates_plot(
    savedir, samples, label, probabilities, population_prior=False,
    preliminary=False, checkpoint=False
):
    """Write the PEPredicates classification plot and its bar-chart
    companion for one set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plots in
    samples: dict
        dictionary of samples for each parameter
    label: str
        the label corresponding to the result file
    probabilities: dict
        dictionary of classification probabilities
    population_prior: Bool, optional
        if True, the samples will be reweighted according to a population
        prior and the output filenames use the ``population`` prefix
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plots
    checkpoint: Bool, optional
        if True, skip any plot whose output file already exists
    """
    from pesummary.gw.classification import PEPredicates

    # (plot type, filename template) for each figure, in generation order
    if population_prior:
        outputs = (
            ("pepredicates", "{}_population_pepredicates.png"),
            ("bar", "{}_population_pepredicates_bar.png"),
        )
    else:
        outputs = (
            ("pepredicates", "{}_default_pepredicates.png"),
            ("bar", "{}_default_pepredicates_bar.png"),
        )

    classifier = PEPredicates(samples)
    for plot_type, template in outputs:
        filename = os.path.join(savedir, template.format(label))
        if os.path.isfile(filename) and checkpoint:
            continue
        figure = classifier.plot(
            type=plot_type, population=population_prior,
            probabilities=probabilities
        )
        _PlotGeneration.save(figure, filename, preliminary=preliminary)
def psd_plot(self, label):
    """Attempt a PSD plot for every result file that stores PSD data

    Result files with no PSDs (no keys, or only a ``None`` key) are skipped,
    and the remaining result files are still plotted.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. Unused: the loop below
        visits every label in ``self.labels``
    """
    error_message = (
        "Failed to generate a PSD plot for %s because {}"
    )

    for result_label in self.labels:
        psd_data = self.psd[result_label]
        ifos = list(psd_data.keys())
        # Skip rather than abort: an earlier `return` here prevented every
        # later result file from being plotted
        if ifos == [None] or ifos == []:
            continue
        # Read the frequency bounds per result file so that one file's
        # f_low/f_final is never reused for another file
        meta_data = self.file_kwargs[result_label]["meta_data"]
        fmin = meta_data["f_low"] if "f_low" in meta_data.keys() else None
        fmax = (
            meta_data["f_final"] if "f_final" in meta_data.keys() else None
        )
        # Each PSD is stored as rows of (frequency, strain)
        frequencies = [np.array(psd_data[i]).T[0] for i in ifos]
        strains = [np.array(psd_data[i]).T[1] for i in ifos]
        arguments = [
            self.savedir, frequencies, strains, fmin, fmax, ifos,
            result_label, self.checkpoint
        ]

        self._try_to_make_a_plot(
            arguments, self._psd_plot, error_message % (result_label)
        )
@staticmethod
def _psd_plot(
    savedir, frequencies, strains, fmin, fmax, psd_labels, label, checkpoint=False
):
    """Plot the PSDs stored in a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        PSD frequencies, one array per IFO
    strains: list
        PSD strains, one array per IFO
    fmin: float
        lower frequency bound for the plot
    fmax: float
        upper frequency bound for the plot
    psd_labels: list
        the IFO names, in the same order as ``frequencies``/``strains``
    label: str
        the label used to distinguish the result file
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    output = os.path.join(savedir, "{}_psd_plot.png".format(label))
    skip = os.path.isfile(output) and checkpoint
    if not skip:
        figure = gw._psd_plot(
            frequencies, strains, labels=psd_labels, fmin=fmin, fmax=fmax
        )
        _PlotGeneration.save(figure, output)
def calibration_plot(self, label):
    """Attempt a calibration envelope plot for every result file that stores
    calibration data

    Result files with no calibration data (no keys, or only a ``None`` key)
    are skipped, and the remaining result files are still plotted.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. Unused: the loop below
        visits every label in ``self.labels``
    """
    error_message = (
        "Failed to generate calibration plot for %s because {}"
    )
    # Fixed grid (20 Hz to just below 1024 Hz, 0.25 Hz steps) on which the
    # calibration data are evaluated
    frequencies = np.arange(20., 1024., 1. / 4)

    for result_label in self.labels:
        ifos = list(self.calibration[result_label].keys())
        # Skip rather than abort: an earlier `return` here prevented every
        # later result file from being plotted
        if ifos == [None] or ifos == []:
            continue

        calibration_data = [
            self.calibration[result_label][i] for i in ifos
        ]
        # Only use calibration priors when they exist for this result file;
        # otherwise plot without a prior instead of raising outside the
        # _try_to_make_a_plot guard
        if (
            "calibration" in self.priors.keys()
            and result_label in self.priors["calibration"]
        ):
            prior = [self.priors["calibration"][result_label][i] for i in ifos]
        else:
            prior = None
        arguments = [
            self.savedir, frequencies, calibration_data, ifos, prior,
            result_label, self.calibration_definition[result_label],
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._calibration_plot, error_message % (result_label)
        )
@staticmethod
def _calibration_plot(
    savedir, frequencies, calibration_data, calibration_labels, prior, label,
    calibration_definition="data", checkpoint=False
):
    """Plot the calibration envelopes stored in a single result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        frequencies at which the calibration data are interpolated
    calibration_data: list
        calibration data, one entry per IFO
    calibration_labels: list
        the IFO names, in the same order as ``calibration_data``
    prior: list
        calibration priors for each IFO, or None
    label: str
        the label used to distinguish the result file
    calibration_definition: str, optional
        the definition of the calibration prior used (either 'data' or
        'template')
    checkpoint: Bool, optional
        if True, do nothing when the output file already exists
    """
    output = os.path.join(savedir, "{}_calibration_plot.png".format(label))
    skip = os.path.isfile(output) and checkpoint
    if not skip:
        figure = gw._calibration_envelope_plot(
            frequencies, calibration_data, calibration_labels, prior=prior,
            definition=calibration_definition
        )
        _PlotGeneration.save(figure, output)
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Write interactive HTML corner plots of the source-frame and the
    extrinsic parameters for one result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plots in; files are written to
        its ``corner`` subdirectory
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.utils.Array objects that
        contain samples for each parameter
    latex_labels: dict
        latex label for each parameter in ``samples``
    checkpoint: Bool, optional
        if True, skip any plot whose output file already exists
    """
    # (filename template, candidate parameters, extra kwargs for corner)
    plot_specs = [
        (
            "{}_interactive_source.html",
            [
                "luminosity_distance", "mass_1_source", "mass_2_source",
                "total_mass_source", "chirp_mass_source", "redshift"
            ],
            {"dimensions": {"width": 900, "height": 900}},
        ),
        (
            "{}_interactive_extrinsic.html",
            ["luminosity_distance", "psi", "ra", "dec"],
            {},
        ),
    ]
    for template, candidates, extra_kwargs in plot_specs:
        filename = os.path.join(savedir, "corner", template.format(label))
        if os.path.isfile(filename) and checkpoint:
            continue
        # Keep the ordering of ``samples`` rather than of ``candidates``
        chosen = [name for name in samples.keys() if name in candidates]
        columns = [samples[name] for name in chosen]
        axis_labels = [latex_labels[name] for name in chosen]
        interactive.corner(
            columns, axis_labels, write_to_html_file=filename,
            **extra_kwargs
        )