Coverage for pesummary/gw/plots/plot.py: 77.8%
650 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import (
4 logger, number_of_columns_for_legend, _check_latex_install,
5)
6from pesummary.utils.decorators import no_latex_plot
7from pesummary.gw.plots.bounds import default_bounds
8from pesummary.core.plots.figure import figure, subplots, ExistingFigure
9from pesummary.core.plots.plot import _default_legend_kwargs
10from pesummary import conf
12import os
13import matplotlib.style
14import numpy as np
15import math
16from scipy.ndimage import gaussian_filter
17from astropy.time import Time
19_check_latex_install()
21from lal import MSUN_SI, PC_SI, CreateDict
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# lalsimulation is optional. LALSIMULATION records whether it could be
# imported so that functions which need it can raise a helpful error.
try:
    import lalsimulation as lalsim
except ImportError:
    LALSIMULATION = None
else:
    LALSIMULATION = True
def _return_bounds(param, samples, comparison=False):
    """Look up the default plotting bounds for a given param

    Parameters
    ----------
    param: str
        name of the parameter you wish to get bounds for
    samples: list/np.ndarray
        array of posterior samples for param, or a list of such arrays when
        comparison is True
    comparison: Bool, optional
        True if samples is a list of array's of posterior samples

    Returns
    -------
    xlow: float/None
        the "low" entry of default_bounds[param], or None if absent
    xhigh: float/None
        the "high" entry of default_bounds[param]. If that entry is a string
        containing "mass_1", the maximum sample value is used instead. None
        if absent
    """
    if param not in default_bounds.keys():
        return None, None
    bounds = default_bounds[param]
    xlow = bounds["low"] if "low" in bounds.keys() else None
    if "high" not in bounds.keys():
        return xlow, None
    upper = bounds["high"]
    # a string upper bound referring to mass_1 means "bounded by the data":
    # use the largest sample
    if not (isinstance(upper, str) and "mass_1" in upper):
        return xlow, upper
    if comparison:
        return xlow, np.max([np.max(arr) for arr in samples])
    return xlow, np.max(samples)
def _add_default_bounds_to_kde_kwargs_dict(
    kde_kwargs, param, samples, comparison=False
):
    """Return a copy of a dictionary of kde kwargs with the default bounds
    added

    Parameters
    ----------
    kde_kwargs: dict
        dictionary of kwargs to pass to the kde class. It is not modified;
        None is treated as an empty dictionary
    param: str
        name of the parameter you wish to plot
    samples: list
        list of samples for param
    comparison: Bool, optional
        True if samples is a list of arrays of samples

    Returns
    -------
    kde_kwargs: dict
        a new dictionary containing every item of the input plus 'xlow',
        'xhigh' (from _return_bounds) and 'kde_kernel' (bounded_1d_kde)
    """
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    xlow, xhigh = _return_bounds(param, samples, comparison=comparison)
    # Work on a copy: callers often pass a shared default dictionary, and
    # writing into it would leak these bounds into unrelated later calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    kde_kwargs["xlow"] = xlow
    kde_kwargs["xhigh"] = xhigh
    kde_kwargs["kde_kernel"] = bounded_1d_kde
    return kde_kwargs
def _1d_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: np.ndarray
        array of samples for param
    *args: tuple
        all args passed directly to pesummary.core.plots.plot._1d_histogram_plot
        function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot function

    Returns
    -------
    the value returned by pesummary.core.plots.plot._1d_histogram_plot
    """
    from pesummary.core.plots.plot import _1d_histogram_plot

    # copy so that neither the caller's dictionary nor a default is mutated
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )
def _1d_histogram_plot_mcmc(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for set of
    mcmc chains

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param (presumably one per chain)
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function

    Returns
    -------
    the value returned by pesummary.core.plots.plot._1d_histogram_plot_mcmc
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_mcmc

    # copy so that neither the caller's dictionary nor a default is mutated
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_histogram_plot_mcmc(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )
def _1d_histogram_plot_bootstrap(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate a bootstrapped 1d histogram plot for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    args: tuple
        all args passed to
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function

    Returns
    -------
    the value returned by
    pesummary.core.plots.plot._1d_histogram_plot_bootstrap
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_bootstrap

    # copy so that neither the caller's dictionary nor a default is mutated
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot_bootstrap(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )
def _1d_comparison_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, max_vline=2,
    legend_kwargs=_default_legend_kwargs, **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different approximants.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param, one per dataset
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_comparison_histogram_plot function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    max_vline: int, optional
        if number of peaks < max_vline draw peaks as vertical lines rather
        than histogramming the data
    legend_kwargs: dict, optional
        kwargs passed to the legend
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_comparison_histogram_plot function

    Returns
    -------
    the value returned by
    pesummary.core.plots.plot._1d_comparison_histogram_plot
    """
    from pesummary.core.plots.plot import _1d_comparison_histogram_plot

    # copy so that neither the caller's dictionary nor a default is mutated
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_comparison_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, max_vline=max_vline,
        legend_kwargs=legend_kwargs, **kwargs
    )
def _make_corner_plot(samples, latex_labels, corner_parameters=None, **kwargs):
    """Generate the corner plots for a given approximant

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters to include in the corner plot. Default
        conf.gw_corner_parameters
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the value returned by pesummary.core.plots.plot._make_corner_plot
    """
    from pesummary.core.plots.plot import _make_corner_plot

    if corner_parameters is None:
        corner_parameters = conf.gw_corner_parameters

    return _make_corner_plot(
        samples, latex_labels, corner_parameters=corner_parameters, **kwargs
    )
def _make_source_corner_plot(samples, latex_labels, **kwargs):
    """Generate the corner plot of the source frame parameters for a given
    approximant

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot. The plotted parameters
        are always conf.gw_source_frame_corner_parameters

    Returns
    -------
    the first element of the value returned by
    pesummary.core.plots.plot._make_corner_plot (presumably the figure)
    """
    from pesummary.core.plots.plot import _make_corner_plot

    return _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_source_frame_corner_parameters, **kwargs
    )[0]
def _make_extrinsic_corner_plot(samples, latex_labels, **kwargs):
    """Generate the corner plot of the extrinsic parameters for a given
    approximant

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot. The plotted parameters
        are always conf.gw_extrinsic_corner_parameters

    Returns
    -------
    the first element of the value returned by
    pesummary.core.plots.plot._make_corner_plot (presumably the figure)
    """
    from pesummary.core.plots.plot import _make_corner_plot

    return _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_extrinsic_corner_parameters, **kwargs
    )[0]
def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    **kwargs
):
    """Generate a corner plot which overlays multiple datasets

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default
        conf.gw_corner_parameters
    colors: list, optional
        unique colors for each dataset
    **kwargs: dict
        all kwargs are passed to `corner.corner`

    Returns
    -------
    the value returned by
    pesummary.core.plots.plot._make_comparison_corner_plot
    """
    from pesummary.core.plots.plot import (
        _make_comparison_corner_plot as _core_comparison_corner_plot
    )

    parameters = (
        conf.gw_corner_parameters if corner_parameters is None
        else corner_parameters
    )
    return _core_comparison_corner_plot(
        samples, latex_labels, corner_parameters=parameters, colors=colors,
        **kwargs
    )
def __antenna_response(name, ra, dec, psi, time_gps):
    """Calculate the antenna response function

    Parameters
    ----------
    name: str
        name of the detector you wish to calculate the antenna response
        function for. Passed to lalsimulation.DetectorPrefixToLALDetector
    ra: float/np.ndarray
        right ascension of the source. A 1d array may be given to evaluate
        several sky positions at once
    dec: float/np.ndarray
        declination of the source. Must have the same shape as ra
    psi: float
        polarisation of the source
    time_gps: float
        gps time of merger

    Returns
    -------
    fplus: float/np.ndarray
        plus polarisation antenna response, one value per sky position
    fcross: float/np.ndarray
        cross polarisation antenna response, one value per sky position

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    """
    # check for lalsimulation before the (comparatively slow) astropy
    # sidereal time calculation
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")
    from astropy.units.si import sday

    # Greenwich mean sidereal time calculation taken from
    # pycbc.detector.Detector
    reference_time = 1126259462.0
    gmst_reference = Time(
        reference_time, format='gps', scale='utc', location=(0, 0)
    ).sidereal_time('mean').rad
    dphase = (time_gps - reference_time) / float(sday.si.scale) * (2.0 * np.pi)
    gmst = (gmst_reference + dphase) % (2.0 * np.pi)
    corrected_ra = gmst - ra
    detector = lalsim.DetectorPrefixToLALDetector(str(name))

    cos_psi, sin_psi = np.cos(psi), np.sin(psi)
    cos_ra, sin_ra = np.cos(corrected_ra), np.sin(corrected_ra)
    cos_dec, sin_dec = np.cos(dec), np.sin(dec)

    # the two vectors x and y which are contracted with the detector
    # response tensor to give F+ and Fx
    x = np.array([
        -cos_psi * sin_ra - sin_psi * cos_ra * sin_dec,
        -cos_psi * cos_ra + sin_psi * sin_ra * sin_dec,
        sin_psi * cos_dec,
    ])
    y = np.array([
        sin_psi * sin_ra - cos_psi * cos_ra * sin_dec,
        sin_psi * cos_ra + cos_psi * sin_ra * sin_dec,
        cos_psi * cos_dec,
    ])
    dx = detector.response.dot(x)
    dy = detector.response.dot(y)
    # summing over axis 0 handles both a single sky position (x has shape
    # (3,)) and several sky positions (x has shape (3, n))
    fplus = (x * dx - y * dy).sum(axis=0)
    fcross = (x * dy + y * dx).sum(axis=0)
    return fplus, fcross
@no_latex_plot
def _waveform_plot(
    detectors, maxL_params, color=None, label=None, fig=None, ax=None,
    **kwargs
):
    """Plot the frequency domain maximum likelihood waveform projected onto
    each detector.

    Parameters
    ----------
    detectors: list
        list of detectors that you want to generate waveforms for
    maxL_params: dict
        dictionary of maximum likelihood parameter values. Must contain
        "approximant" and "mass_1"; nothing is plotted (None is returned)
        if mass_1 is nan
    color: list, optional
        one color per detector. Default from gwpy's GW_OBSERVATORY_COLORS
    label: list, optional
        one legend label per detector. Default the detector names
    fig: figure, optional
        existing figure to plot on. Required if ax is given
    ax: axis, optional
        existing axis to plot on. Default fig.gca()
    kwargs: dict
        dictionary of optional keyword arguments: "f_low", "f_start", "f_max",
        "f_ref", "delta_f" and "approximant_flags"

    Returns
    -------
    fig: figure
        the figure containing the plot, or None if mass_1 is nan
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    from pesummary.gw.waveform import fd_waveform

    if math.isnan(maxL_params["mass_1"]):
        return
    logger.debug("Generating the maximum likelihood waveform plot")
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")

    if fig is None:
        if ax is not None:
            raise ValueError("Please provide a figure for plotting")
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()

    if color is None:
        color = [GW_OBSERVATORY_COLORS[det] for det in detectors]
    elif len(color) != len(detectors):
        raise ValueError(
            "Please provide a list of colors for each detector"
        )
    if label is None:
        label = detectors
    elif len(label) != len(detectors):
        raise ValueError(
            "Please provide a list of labels for each detector"
        )

    f_low = kwargs.get("f_low", 5.)
    f_start = kwargs.get("f_start", 5.)
    f_max = kwargs.get("f_max", 1000.)
    flags = kwargs.get("approximant_flags", {})
    delta_f = kwargs.get("delta_f", 1. / 256)
    f_ref = kwargs.get("f_ref", f_start)
    for det, det_color, det_label in zip(detectors, color, label):
        ht = fd_waveform(
            maxL_params, maxL_params["approximant"], delta_f, f_start, f_max,
            f_ref=f_ref, project=det, flags=flags
        )
        freqs = ht.frequencies.value
        # only plot strictly inside (f_start, f_max)
        in_band = (freqs > f_start) & (freqs < f_max)
        ax.plot(
            freqs[in_band], np.abs(ht)[in_band], color=det_color,
            linewidth=1.0, label=det_label
        )
    # shade the band between the starting and minimum frequencies
    if f_start < f_low:
        ax.axvspan(f_start, f_low, alpha=0.1, color='grey')

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
@no_latex_plot
def _waveform_comparison_plot(maxL_params_list, colors, labels,
                              **kwargs):
    """Generate a plot comparing the H1-projected maximum likelihood
    waveforms of several analyses.

    Parameters
    ----------
    maxL_params_list: list
        list of dictionaries containing the maximum likelihood parameter
        values for each approximant. Optional keys "f_start", "f_low",
        "f_final", "f_ref" and "approximant_flags" control the waveform
    colors: list
        list of colors to be used to differentiate the different approximants
    labels: list
        legend label for each entry of maxL_params_list
    kwargs: dict
        dictionary of optional keyword arguments (unused)

    Returns
    -------
    fig: figure
        the figure containing the comparison plot
    """
    logger.debug("Generating the maximum likelihood waveform comparison plot "
                 "for H1")
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")

    fig, ax = figure(gca=True)
    for num, params in enumerate(maxL_params_list):
        waveform_kwargs = dict(
            f_start=params.get("f_start", 20.),
            f_low=params.get("f_low", 20.),
            f_max=params.get("f_final", 1024.),
            f_ref=params.get("f_ref", 20.),
            approximant_flags=params.get("approximant_flags", {}),
        )
        # draws onto the shared axis; the returned figure is not needed
        _waveform_plot(
            ["H1"], params, fig=fig, ax=ax, color=[colors[num]],
            label=[labels[num]], **waveform_kwargs
        )

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(visible=True)
    ax.legend(loc="best")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    fig.tight_layout()
    return fig
def _ligo_skymap_plot(ra, dec, dist=None, savedir="./", nprocess=1,
                      downsampled=False, label="pesummary", time=None,
                      distance_map=True, multi_resolution=True,
                      injection=None, **kwargs):
    """Plot the sky location of the source for a given approximant using the
    ligo.skymap package

    A clustered sky KDE is fitted to the samples, written to
    "<savedir>/<label>_skymap.fits", read back and plotted.

    Parameters
    ----------
    ra: list
        list of samples for right ascension
    dec: list
        list of samples for declination
    dist: list
        list of samples for the luminosity distance
    savedir: str
        path to the directory where you would like to save the output files
    nprocess: int
        number of parallel jobs passed to the ligo.skymap KDE (``jobs``)
    downsampled: Bool
        Boolean for whether the samples have been downsampled or not
    label: str
        label used to name the output fits file
    time: float/list, optional
        gps time (or samples of it, averaged) stored in the fits header
    distance_map: Bool
        Boolean for whether or not to produce a distance map
    multi_resolution: Bool
        Boolean for whether or not to generate a multiresolution HEALPix map
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians
    kwargs: dict
        optional keyword arguments

    Returns
    -------
    fig: figure
        the skymap figure
    """
    from ligo.skymap.bayestar import rasterize
    from ligo.skymap import io
    from ligo.skymap.kde import Clustered2DSkyKDE, Clustered2Plus1DSkyKDE

    if dist is not None and distance_map:
        kde_class, columns = Clustered2Plus1DSkyKDE, (ra, dec, dist)
    else:
        kde_class, columns = Clustered2DSkyKDE, (ra, dec)
    skypost = kde_class(np.column_stack(columns), trials=5, jobs=nprocess)
    hpmap = skypost.as_healpix()
    if not multi_resolution:
        hpmap = rasterize(hpmap)

    hpmap.meta.update({
        'creator': "pesummary",
        'origin': 'LIGO/Virgo',
        'gps_creation_time': Time.now().gps,
    })
    if dist is not None:
        hpmap.meta["distmean"] = float(np.mean(dist))
        hpmap.meta["diststd"] = float(np.std(dist))
    if time is not None:
        hpmap.meta["gps_time"] = (
            time if isinstance(time, (float, int)) else np.mean(time)
        )

    filename = os.path.join(savedir, "%s_skymap.fits" % (label))
    io.write_sky_map(filename, hpmap, nest=True)
    skymap, _metadata = io.fits.read_sky_map(filename, nest=None)
    fig, _ax = _ligo_skymap_plot_from_array(
        skymap, nsamples=len(ra), downsampled=downsampled, injection=injection
    )
    return fig
def _ligo_skymap_plot_from_array(
    skymap, nsamples=None, downsampled=False, contour=[50, 90],
    annotate=True, fig=None, ax=None, colors="k", injection=None
):
    """Generate a skymap with `ligo.skymap` based on an array of probabilities

    Parameters
    ----------
    skymap: np.array
        array of probabilities (nested HEALPix ordering is passed to the
        plotting calls)
    nsamples: int, optional
        number of samples used
    downsampled: Bool, optional
        If True, add a header to the skymap saying that this plot is downsampled
    contour: list, optional
        list of contours to be plotted on the skymap. Default 50, 90
    annotate: Bool, optional
        If True, annotate the figure by adding the 90% and 50% sky areas
        by default
    fig: figure, optional
        Existing figure to add the plot to
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        Existing axis to add the plot to
    colors: str/list
        colors to use for the contours
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians

    Returns
    -------
    fig: ExistingFigure/None
        the figure wrapped in ExistingFigure, or None when only ax was given
    ax:
        the axis that was plotted on
    """
    import healpy as hp
    from ligo.skymap import plot

    if ax is None:
        if fig is None:
            fig = figure(gca=False)
            ax = fig.add_subplot(111, projection='astro hours mollweide')
        else:
            ax = fig.gca()

    ax.grid(visible=True)
    pixel_area = hp.nside2pixarea(hp.npix2nside(len(skymap)), degrees=True)
    density = skymap / pixel_area

    if downsampled:
        ax.set_title("Downsampled to %s" % (nsamples), fontdict={'fontsize': 11})

    ax.imshow_hpx((density, 'ICRS'), nested=True, vmin=0.,
                  vmax=density.max(), cmap="cylon")
    credible_levels, _ = _ligo_skymap_contours(
        ax, skymap, contour=contour, colors=colors
    )
    if annotate:
        # number of pixels inside each credible level times the pixel area
        areas = np.round(
            np.searchsorted(np.sort(credible_levels), contour) * pixel_area
        ).astype(int)
        percents = np.round(contour).astype(int)
        summary = '\n'.join(
            u'{:d}% area: {:d} deg²'.format(percent, area)
            for area, percent in zip(areas, percents)
        )
        ax.text(1, 1.05, summary, transform=ax.transAxes, ha='right',
                fontsize=10)
        plot.outline_text(ax)

    if injection is not None and len(injection) == 2:
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        injection_coord = SkyCoord(*injection, unit=u.rad)
        ax.scatter(
            injection_coord.ra.value, injection_coord.dec.value, marker="*",
            color="orange", edgecolors='k', linewidth=1.75, s=100, zorder=100,
            transform=ax.get_transform('world')
        )

    return (fig, ax) if fig is None else (ExistingFigure(fig), ax)
def _ligo_skymap_comparion_plot_from_array(
    skymaps, colors, labels, contour=[50, 90], show_probability_map=None,
    injection=None
):
    """Generate a skymap with `ligo.skymap` which compares arrays of
    probabilities

    Parameters
    ----------
    skymaps: list
        list of skymap probabilities
    colors: list
        list of colors to use for each skymap
    labels: list
        list of labels associated with each skymap
    contour: list, optional
        contours you wish to display on the comparison plot
    show_probability_map: int, optional
        the index of the skymap you wish to show the probability
        map for. Default None
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians.
        It is plotted whether or not a probability map is shown

    Returns
    -------
    fig: figure
        the comparison figure
    """
    # importing ligo.skymap.plot registers the 'astro hours mollweide'
    # projection used below
    from ligo.skymap import plot
    import matplotlib.lines as mlines

    ncols = number_of_columns_for_legend(labels)
    fig = figure(gca=False)
    ax = fig.add_subplot(111, projection='astro hours mollweide')
    ax.grid(visible=True)
    lines = []
    injection_plotted = False
    for num, skymap in enumerate(skymaps):
        if isinstance(show_probability_map, int) and show_probability_map == num:
            # draws the probability map, its contours and the injection, so
            # the contours must not be drawn a second time
            _, ax = _ligo_skymap_plot_from_array(
                skymap, nsamples=None, downsampled=False, contour=contour,
                annotate=False, ax=ax, colors=colors[num], injection=injection,
            )
            injection_plotted = True
        else:
            _ligo_skymap_contours(
                ax, skymap, contour=contour, colors=colors[num],
            )
        lines.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))

    if not injection_plotted and injection is not None and len(injection) == 2:
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        _inj = SkyCoord(*injection, unit=u.rad)
        ax.scatter(
            _inj.ra.value, _inj.dec.value, marker="*", color="orange",
            edgecolors='k', linewidth=1.75, s=100, zorder=100,
            transform=ax.get_transform('world')
        )

    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols, handles=lines)
    return fig
def _ligo_skymap_contours(ax, skymap, contour=[50, 90], colors='k'):
    """Draw labelled credible-region contours on a ligo.skymap axis

    Parameters
    ----------
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        Existing axis to add the plot to
    skymap: np.array
        array of probabilities
    contour: list
        list contours you wish to plot, as percentages
    colors: str/list
        colors to use for the contours

    Returns
    -------
    cls: np.array
        greedy credible level of each pixel, as a percentage
    cs:
        the contour set returned by ``ax.contour_hpx``
    """
    from ligo.skymap.postprocess import find_greedy_credible_levels

    credible_levels = 100 * find_greedy_credible_levels(skymap)
    contours = ax.contour_hpx(
        (credible_levels, 'ICRS'), nested=True, colors=colors,
        linewidths=0.5, levels=contour,
    )
    ax.clabel(contours, fmt=r'%g\%%', fontsize=6, inline=True)
    return credible_levels, contours
def _default_skymap_plot(ra, dec, weights=None, injection=None, **kwargs):
    """Plot the default sky location of the source for a given approximant

    The samples are binned into a 50x50 histogram in (RA, dec), smoothed
    with a gaussian filter and drawn on a mollweide projection together with
    the 50% and 90% contours of the smoothed histogram.

    Parameters
    ----------
    ra: list
        list of samples for right ascension in radians
    dec: list
        list of samples for declination in radians
    weights: list, optional
        weight for each sample, passed directly to np.histogram2d
    injection: list, optional
        list containing the injected value of ra and dec
    kwargs: dict
        optional keyword arguments. "smooth" sets the standard deviation of
        the gaussian filter applied to the histogram (default 0.9)

    Returns
    -------
    fig: figure
        the figure containing the sky map
    """
    from .cmap import register_cylon, unregister_cylon

    # the cylon cmap is only registered while this plot is drawn; the
    # try/finally guarantees it is unregistered even if plotting fails
    register_cylon()
    try:
        # flip RA (ra -> pi - ra) to match the reversed hour tick labels
        ra = np.pi - np.asarray(ra)
        logger.debug("Generating the sky map plot")
        fig, orig_ax = figure(gca=True)
        for side in ("left", "right", "top", "bottom"):
            orig_ax.spines[side].set_visible(False)
        orig_ax.set_yticks([])
        orig_ax.set_xticks([])
        ax = fig.add_subplot(
            111, projection="mollweide",
            facecolor=(1.0, 0.939165516411, 0.880255669068)
        )
        ax.cla()
        ax.set_title("Preliminary", fontdict={'fontsize': 11})
        ax.grid(visible=True)
        levels = [0.9, 0.5]

        # np.histogram2d treats weights=None as unit weights
        H, X, Y = np.histogram2d(ra, dec, bins=50, weights=weights)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        # histogram values ordered from highest to lowest density
        Hflat = np.sort(H.flatten())[::-1]
        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # density threshold enclosing each requested probability
        V = np.empty(len(levels))
        for num, level in enumerate(levels):
            enclosed = Hflat[CF <= level]
            # fall back to the peak density when no bin satisfies the level
            V[num] = enclosed[-1] if len(enclosed) else Hflat[0]
        V.sort()
        # contour levels must be strictly increasing: nudge duplicates down
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on each side, copying the edge
        # values into the inner ring of padding, and extend the bin centres
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])

        ax.pcolormesh(X2, Y2, H2.T, vmin=0., vmax=H2.T.max(), cmap="cylon")
        cs = ax.contour(X2, Y2, H2.T, V, colors="k", linewidths=0.5)
        if injection is not None:
            ax.scatter(
                -injection[0] + np.pi, injection[1], marker="*",
                color=conf.injection_color, edgecolors='k', linewidth=1.75,
                s=100
            )
        fmt = {
            lev: lab for lev, lab in zip(cs.levels, [r"$90\%$", r"$50\%$"])
        }
        ax.clabel(cs, fmt=fmt, fontsize=8, inline=True)

        # planar (shoelace) area enclosed by each contour in (RA, dec)
        # radians, converted to square degrees
        text = []
        for path, percent in zip(cs.get_paths(), [90, 50]):
            area = 0.
            for poly in path.to_polygons():
                x, y = poly[:, 0], poly[:, 1]
                area += 0.5 * np.sum(y[:-1] * np.diff(x) - x[:-1] * np.diff(y))
            area = int(np.abs(area) * (180 / np.pi) * (180 / np.pi))
            text.append(r'{:d}\% area: {:d} deg²'.format(int(percent), area))
        orig_ax.text(
            1.0, 1.05, '\n'.join(text[::-1]), transform=ax.transAxes,
            ha='right', fontsize=10
        )

        xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
        ax.set_xticks(xticks)
        ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
        hour_labels = [
            r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1)))
            for i in xticks
        ]
        ax.set_xticklabels(hour_labels[::-1], fontsize=10)
        ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                            r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
        ax.grid(visible=True)
    finally:
        # unregister the cylon cmap
        unregister_cylon()
    return fig
def _sky_map_comparison_plot(ra_list, dec_list, labels, colors, **kwargs):
    """Generate a plot that compares the sky location for multiple approximants

    For each dataset the samples are binned into a 50x50 histogram in
    (RA, dec), smoothed, and the 50% and 90% contours are drawn in that
    dataset's color on a shared mollweide projection.

    Parameters
    ----------
    ra_list: 2d list
        list of samples for right ascension (radians) for each dataset
    dec_list: 2d list
        list of samples for declination (radians) for each dataset
    labels: list
        legend label for each dataset
    colors: list
        list of colors to be used to differentiate the different datasets
    kwargs: dict
        optional keyword arguments. "smooth" sets the standard deviation of
        the gaussian filter applied to each histogram (default 0.9)

    Returns
    -------
    fig: figure
        the comparison figure
    """
    # flip RA (ra -> pi - ra) to match the reversed hour tick labels
    ra_list = [np.pi - np.asarray(samples) for samples in ra_list]
    logger.debug("Generating the sky map comparison plot")
    fig = figure(gca=False)
    ax = fig.add_subplot(
        111, projection="mollweide",
    )
    ax.cla()
    ax.set_title("Preliminary", fontdict={'fontsize': 11})
    ax.grid(visible=True)
    levels = [0.9, 0.5]
    for num, ra in enumerate(ra_list):
        H, X, Y = np.histogram2d(ra, dec_list[num], bins=50)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        # histogram values ordered from highest to lowest density
        Hflat = np.sort(H.flatten())[::-1]
        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # density threshold enclosing each requested probability
        V = np.empty(len(levels))
        for num2, level in enumerate(levels):
            enclosed = Hflat[CF <= level]
            # fall back to the peak density when no bin satisfies the level
            V[num2] = enclosed[-1] if len(enclosed) else Hflat[0]
        V.sort()
        # contour levels must be strictly increasing: nudge duplicates down
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on each side, copying the edge
        # values into the inner ring of padding, and extend the bin centres
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])
        ax.contour(X2, Y2, H2.T, V, colors=colors[num], linewidths=2.0)
        # empty line so that the dataset gets a legend entry
        ax.plot([], [], color=colors[num], linewidth=2.0, label=labels[num])
    ncols = number_of_columns_for_legend(labels)
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols)
    xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
    ax.set_xticks(xticks)
    ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
    hour_labels = [
        r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1))) for i in xticks
    ]
    ax.set_xticklabels(hour_labels[::-1], fontsize=10)
    # ^{\circ} renders with both mathtext and usetex, unlike \degree
    ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                        r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
    ax.grid(visible=True)
    return fig
def __get_cutoff_indices(flow, fhigh, df, N):
    """
    Gets the indices of a frequency series at which to stop an overlap
    calculation.

    Parameters
    ----------
    flow: float
        The frequency (in Hz) of the lower index.
    fhigh: float
        The frequency (in Hz) of the upper index.
    df: float
        The frequency step (in Hz) of the frequency series.
    N: int
        The number of points in the **time** series. Can be odd
        or even.

    Returns
    -------
    kmin: int
        index of the first bin to include: int(flow / df), or 1 (skipping
        the zero-frequency bin) when flow is not given
    kmax: int
        index one past the last bin to include: int(fhigh / df) clipped to
        int((N + 1) / 2.), or int((N + 1) / 2.) when fhigh is not given
    """
    # largest (exclusive) upper index for a frequency series derived from N
    # time samples; used both as the default and as an upper clip so that
    # fhigh can never point beyond the end of the series
    max_index = int((N + 1) / 2.)
    kmin = int(flow / df) if flow else 1
    kmax = min(int(fhigh / df), max_index) if fhigh else max_index
    return kmin, kmax
@no_latex_plot
def _sky_sensitivity(network, resolution, maxL_params, **kwargs):
    """Generate the sky sensitivity for a given network

    For each sky position the antenna responses of every detector are
    weighted by the optimal SNR the maximum likelihood waveform would have
    in that detector. The plotted quantity is
    sqrt(sum_i (F+_i^2 + Fx_i^2) SNR_i^2 / sum_i SNR_i^2).

    Parameters
    ----------
    network: list
        list of detectors you want included in your sky sensitivity plot.
        A PSD is only defined for "H1", "L1" and "V1"
    resolution: float
        grid spacing (radians) of the skymap in both RA and declination
    maxL_params: dict
        dictionary of waveform parameters for the maximum likelihood waveform
    kwargs: dict, optional
        "delta_f", "f_min", "f_max" and "f_ref" control the waveform generation

    Returns
    -------
    fig: figure
        the sky sensitivity figure
    """
    logger.debug("Generating the sky sensitivity for %s" % (network))
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")
    delta_frequency = kwargs.get("delta_f", 1. / 256)
    minimum_frequency = kwargs.get("f_min", 20.)
    maximum_frequency = kwargs.get("f_max", 1000.)
    reference_frequency = kwargs.get("f_ref", 10.)

    approx = lalsim.GetApproximantFromString(maxL_params["approximant"])
    mass_1 = maxL_params["mass_1"] * MSUN_SI
    mass_2 = maxL_params["mass_2"] * MSUN_SI
    luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6
    iota, S1x, S1y, S1z, S2x, S2y, S2z = \
        lalsim.SimInspiralTransformPrecessingNewInitialConditions(
            maxL_params["iota"], maxL_params["phi_jl"], maxL_params["tilt_1"],
            maxL_params["tilt_2"], maxL_params["phi_12"], maxL_params["a_1"],
            maxL_params["a_2"], mass_1, mass_2, reference_frequency,
            maxL_params["phase"])
    h_plus, h_cross = lalsim.SimInspiralChooseFDWaveform(
        mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota,
        maxL_params["phase"], 0.0, 0.0, 0.0, delta_frequency, minimum_frequency,
        maximum_frequency, reference_frequency, None, approx)
    h_plus = h_plus.data.data
    h_cross = h_cross.data.data

    # The lalsimulation frequency series start at 0 Hz, so bin k lies at
    # k * delta_frequency. Integrate over bins [kmin, kmax - 1) and evaluate
    # the PSDs on exactly those frequencies so that PSD and strain line up.
    kmin, kmax = __get_cutoff_indices(
        minimum_frequency, maximum_frequency, delta_frequency,
        (len(h_plus) - 1) * 2
    )
    kmax = min(kmax, len(h_plus))
    band = slice(kmin, kmax - 1)
    frequencies = np.arange(kmin, kmax - 1) * delta_frequency
    hp_band = h_plus[band]
    hc_band = h_cross[band]
    psd = {}
    psd["H1"] = psd["L1"] = np.array([
        lalsim.SimNoisePSDaLIGOZeroDetHighPower(f) for f in frequencies])
    psd["V1"] = np.array([lalsim.SimNoisePSDVirgo(f) for f in frequencies])

    # The optimal SNR^2 of F+ h+ + Fx hx is
    # F+^2 <h+|h+> + Fx^2 <hx|hx> + 2 F+ Fx <h+|hx> with
    # <a|b> = 4 df sum Re(conj(a) b) / S, so the frequency integrals only
    # need to be evaluated once per detector rather than once per pixel
    inner_products = {}
    for det in network:
        weight = 4 * delta_frequency / psd[det]
        inner_products[det] = (
            np.sum(np.abs(hp_band) ** 2 * weight),
            np.sum(np.abs(hc_band) ** 2 * weight),
            np.sum((np.conj(hp_band) * hc_band).real * weight),
        )

    ra = np.arange(-np.pi, np.pi, resolution)
    # declination only spans [-pi/2, pi/2]
    dec = np.arange(-np.pi / 2, np.pi / 2, resolution)
    X, Y = np.meshgrid(ra, dec)
    numerator = np.zeros(X.size)
    denominator = np.zeros(X.size)
    for det in network:
        hphp, hchc, hphc = inner_products[det]
        # antenna responses for every pixel at once
        fplus, fcross = __antenna_response(
            det, X.ravel(), Y.ravel(), maxL_params["psi"],
            maxL_params["geocent_time"]
        )
        snr_squared = (
            fplus ** 2 * hphp + fcross ** 2 * hchc + 2 * fplus * fcross * hphc
        )
        numerator += (fplus ** 2 + fcross ** 2) * snr_squared
        denominator += snr_squared
    # rows follow dec and columns follow ra, matching meshgrid(ra, dec)
    sensitivity = np.sqrt(numerator / denominator).reshape(X.shape)

    fig = figure(gca=False)
    ax = fig.add_subplot(111, projection="hammer")
    ax.cla()
    ax.grid(visible=True)
    ax.pcolormesh(X, Y, sensitivity)
    ax.set_xticklabels([
        r"$22^{h}$", r"$20^{h}$", r"$18^{h}$", r"$16^{h}$", r"$14^{h}$",
        r"$12^{h}$", r"$10^{h}$", r"$8^{h}$", r"$6^{h}$", r"$4^{h}$",
        r"$2^{h}$"])
    return fig
@no_latex_plot
def _time_domain_waveform(
    detectors, maxL_params, color=None, label=None, fig=None, ax=None,
    **kwargs
):
    """
    Plot the maximum likelihood waveform for a given approximant
    in the time domain.

    Parameters
    ----------
    detectors: list
        list of detectors that you want to generate waveforms for
    maxL_params: dict
        dictionary of maximum likelihood parameter values. Must contain
        'approximant', 'mass_1', 'mass_2' and 'geocent_time'
    color: list, optional
        one color per detector. Defaults to the gwpy observatory colors
    label: list, optional
        one legend label per detector. Defaults to the detector names
    fig: optional
        existing figure to draw on. A new figure is created when neither
        fig nor ax is provided
    ax: optional
        existing axis to draw on. Requires fig to also be provided
    kwargs: dict
        dictionary of optional keyword arguments. Recognised keys are
        'f_low' (frequency used to estimate the signal duration and hence
        the plotted time window, default 5 Hz), 'f_start' (starting
        frequency of the waveform, default 5 Hz), 'f_ref' (reference
        frequency, default 10 Hz), 'delta_t' (sampling interval, default
        1/4096 s) and 'approximant_flags'

    Returns
    -------
    fig:
        the figure containing the waveform(s), or None when
        maxL_params["mass_1"] is NaN

    Raises
    ------
    ValueError
        if ax is provided without fig, or if color/label do not contain one
        entry per detector
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    from pesummary.gw.waveform import td_waveform
    from pesummary.utils.samples_dict import SamplesDict
    if math.isnan(maxL_params["mass_1"]):
        return
    logger.debug("Generating the maximum likelihood waveform time domain plot")
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")

    approximant = maxL_params["approximant"]
    minimum_frequency = kwargs.get("f_low", 5.)
    starting_frequency = kwargs.get("f_start", 5.)
    approximant_flags = kwargs.get("approximant_flags", {})
    # Wrap the maximum likelihood point as a single-sample SamplesDict so that
    # generate_all_posterior_samples can presumably fill in quantities that
    # are read below (e.g. spin_1z, spin_2z)
    _samples = SamplesDict(
        {
            key: [item] for key, item in maxL_params.items() if
            key != "approximant"
        }
    )
    _samples.generate_all_posterior_samples(disable_remnant=True)
    _samples = {key: item[0] for key, item in _samples.items()}
    chirptime = lalsim.SimIMRPhenomXASDuration(
        _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI,
        _samples.get("spin_1z", 0), _samples.get("spin_2z", 0),
        minimum_frequency
    )
    # round the duration up to a power of two (at least 1s); this only sets
    # the plotted time window
    duration = np.max([2**np.ceil(np.log2(chirptime)), 1.0])
    if (fig is None) and (ax is None):
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()
    elif fig is None:
        raise ValueError("Please provide a figure for plotting")
    if color is None:
        color = [GW_OBSERVATORY_COLORS[i] for i in detectors]
    elif len(color) != len(detectors):
        raise ValueError(
            "Please provide a list of colors for each detector"
        )
    if label is None:
        label = detectors
    elif len(label) != len(detectors):
        raise ValueError(
            "Please provide a list of labels for each detector"
        )
    for num, i in enumerate(detectors):
        ht = td_waveform(
            maxL_params, approximant, kwargs.get("delta_t", 1. / 4096.),
            starting_frequency, f_ref=kwargs.get("f_ref", 10.), project=i,
            flags=approximant_flags
        )
        ax.plot(
            ht.times.value, ht, color=color[num], linewidth=1.0,
            label=label[num]
        )
    ax.set_xlim(
        [
            maxL_params["geocent_time"] - 0.75 * duration,
            maxL_params["geocent_time"] + duration / 4
        ]
    )
    ax.set_xlabel(r"Time $[s]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
@no_latex_plot
def _time_domain_waveform_comparison_plot(maxL_params_list, colors, labels,
                                          **kwargs):
    """Generate a plot which compares the maximum likelihood waveforms for
    each approximant.

    Every waveform is projected onto the H1 detector and drawn on a shared
    axis.

    Parameters
    ----------
    maxL_params_list: list
        list of dictionaries containing the maximum likelihood parameter
        values for each approximant. The optional keys 'f_start', 'f_low',
        'f_ref' (each defaulting to 20 Hz) and 'approximant_flags' control
        how each waveform is generated
    colors: list
        list of colors to be used to differentiate the different approximants
    labels: list
        legend label for each approximant
    kwargs: dict
        accepted for backwards compatibility; currently unused

    Returns
    -------
    fig:
        figure containing the comparison

    Raises
    ------
    Exception
        if LALSimulation could not be imported
    """
    logger.debug("Generating the maximum likelihood time domain waveform "
                 "comparison plot for H1")
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")
    fig, ax = figure(gca=True)
    for num, maxL_params in enumerate(maxL_params_list):
        _kwargs = {
            "f_start": maxL_params.get("f_start", 20.),
            "f_low": maxL_params.get("f_low", 20.),
            "f_max": maxL_params.get("f_final", 1024.),
            "f_ref": maxL_params.get("f_ref", 20.),
            "approximant_flags": maxL_params.get("approximant_flags", {})
        }
        # draws onto the shared fig/ax; the returned figure is not needed
        _time_domain_waveform(
            ["H1"], maxL_params, fig=fig, ax=ax, color=[colors[num]],
            label=[labels[num]], **_kwargs
        )
    ax.set_xlabel(r"Time $[s]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
def _psd_plot(frequencies, strains, colors=None, labels=None, fmin=None, fmax=None):
    """Superimpose all PSD plots onto a single figure.

    Parameters
    ----------
    frequencies: nd list
        list of all frequencies used for each psd file
    strains: nd list
        list of all strains used for each psd file. The list is not modified
    colors: optional, list
        list of colors to be used to differentiate the different PSDs. If
        not provided, the gwpy observatory colors are used when every label
        is a known observatory; otherwise a fixed cycle of colors is used
    labels: optional, list
        list of labels for each PSD. If not provided, the PSDs are unlabelled
    fmin: optional, float
        starting frequency of the plot
    fmax: optional, float
        maximum frequency of the plot

    Returns
    -------
    fig:
        figure containing the PSDs
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    fig, ax = figure(gca=True)
    if labels is None:
        labels = [None] * len(frequencies)
    if not colors and all(i in GW_OBSERVATORY_COLORS.keys() for i in labels):
        colors = [GW_OBSERVATORY_COLORS[i] for i in labels]
    elif not colors:
        palette = ['r', 'b', 'orange', 'c', 'g', 'purple']
        colors = [palette[num % len(palette)] for num in range(len(labels))]
    for num, freqs in enumerate(frequencies):
        ff = np.asarray(freqs)
        ss = np.asarray(strains[num])
        # restrict to [fmin, fmax] using local arrays so that the caller's
        # ``strains`` list is left untouched
        cond = np.ones_like(ss, dtype=bool)
        if fmin is not None:
            cond &= ff >= fmin
        if fmax is not None:
            cond &= ff <= fmax
        ax.loglog(ff[cond], ss[cond], color=colors[num], label=labels[num])
    ax.tick_params(which="both", bottom=True, length=3, width=1)
    ax.set_xlabel(r"Frequency $[\mathrm{Hz}]$")
    ax.set_ylabel(r"Power Spectral Density [$\mathrm{strain}^{2}/\mathrm{Hz}$]")
    if any(label is not None for label in labels):
        ax.legend(loc="best")
    fig.tight_layout()
    return fig
@no_latex_plot
def _calibration_envelope_plot(frequency, calibration_envelopes, ifos,
                               colors=None, prior=None, definition="data"):
    """Generate a plot showing the calibration envelope

    Parameters
    ----------
    frequency: array
        frequency bandwidth that you would like to use
    calibration_envelopes: nd list
        list containing the calibration envelope data for different IFOs.
        The list is not modified
    ifos: list
        list of IFOs that are associated with the calibration envelopes
    colors: list, optional
        list of colors to be used to differentiate the different calibration
        envelopes
    prior: list, optional
        list containing the prior calibration envelope data for different
        IFOs. When provided (and non-empty), the prior envelope is shaded
    definition: str, optional
        definition used for the prior calibration envelope data

    Returns
    -------
    fig:
        figure with the amplitude (top) and phase (bottom) envelopes
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS

    def interpolate_calibration(data):
        """Linearly interpolate the calibration data onto ``frequency``

        Parameters
        ----------
        data: np.ndarray
            array containing the calibration data. Column 0 is the frequency;
            columns 1-6 are the median, lower and upper values, each given
            as an (amplitude, phase) pair

        Returns
        -------
        data_dict: dict
            amplitude deviation (in percent) and phase deviation (in degrees)
            for the median, lower and upper bounds
        """
        # outside the tabulated range the amplitude is held at 1 and the
        # phase at 0, i.e. zero deviation
        interp = [
            np.interp(frequency, data[:, 0], data[:, j], left=k, right=k)
            for j, k in zip(range(1, 7), [1, 0, 1, 0, 1, 0])
        ]
        amp_median = (interp[0] - 1) * 100
        phase_median = interp[1] * 180. / np.pi
        amp_lower_sigma = (interp[2] - 1) * 100
        phase_lower_sigma = interp[3] * 180. / np.pi
        amp_upper_sigma = (interp[4] - 1) * 100
        phase_upper_sigma = interp[5] * 180. / np.pi
        data_dict = {
            "amplitude": {
                "median": amp_median,
                "lower": amp_lower_sigma,
                "upper": amp_upper_sigma
            },
            "phase": {
                "median": phase_median,
                "lower": phase_lower_sigma,
                "upper": phase_upper_sigma
            }
        }
        return data_dict

    fig, (ax1, ax2) = subplots(2, 1, sharex=True, gca=False)
    if not colors and all(i in GW_OBSERVATORY_COLORS.keys() for i in ifos):
        colors = [GW_OBSERVATORY_COLORS[i] for i in ifos]
    elif not colors:
        palette = ['r', 'b', 'orange', 'c', 'g', 'purple']
        colors = [palette[num % len(palette)] for num in range(len(ifos))]

    # work on converted copies rather than overwriting the caller's list
    envelopes = [np.array(envelope) for envelope in calibration_envelopes]
    # ``len`` rather than ``!= []`` so that numpy array priors also work
    has_prior = prior is not None and len(prior) > 0

    for num, envelope in enumerate(envelopes):
        calibration_data = interpolate_calibration(envelope)
        ax1.plot(
            frequency, calibration_data["amplitude"]["upper"], color=colors[num],
            linestyle="-", label=ifos[num]
        )
        ax1.plot(
            frequency, calibration_data["amplitude"]["lower"], color=colors[num],
            linestyle="-"
        )
        ax2.plot(
            frequency, calibration_data["phase"]["upper"], color=colors[num],
            linestyle="-", label=ifos[num]
        )
        ax2.plot(
            frequency, calibration_data["phase"]["lower"], color=colors[num],
            linestyle="-"
        )
        if has_prior:
            prior_data = interpolate_calibration(np.asarray(prior[num]))
            ax1.fill_between(
                frequency, prior_data["amplitude"]["upper"],
                prior_data["amplitude"]["lower"], color=colors[num], alpha=0.2
            )
            ax2.fill_between(
                frequency, prior_data["phase"]["upper"],
                prior_data["phase"]["lower"], color=colors[num], alpha=0.2
            )

    ax1.set_ylabel(r"Amplitude deviation $[\%]$", fontsize=10)
    ax2.set_ylabel(r"Phase deviation $[\degree]$", fontsize=10)
    if envelopes:
        ax1.legend(loc="best")
    ax1.set_title(f"Calibration correction applied to {definition}")
    ax1.set_xscale('log')
    ax2.set_xscale('log')
    ax2.set_xlabel(r"Frequency $[Hz]$")
    fig.tight_layout()
    return fig
def _strain_plot(strain, maxL_params, **kwargs):
    """Generate a plot showing the strain data and the maxL waveform

    For each detector, the strain data and the projected maximum likelihood
    waveform are whitened by the ASD of the data and filtered with a 30 Hz
    highpass and a 300 Hz lowpass. They are then plotted from 0.2 s before
    to 0.06 s after the signal arrives at that detector.

    Parameters
    ----------
    strain: dict
        dictionary of gwpy.timeseries.TimeSeries objects containing the
        strain data, keyed by detector name
    maxL_params: dict
        dictionary of maximum likelihood parameter values
    kwargs: dict, optional
        'f_min' (starting frequency of the waveform, default 5 Hz), 'f_ref'
        (reference frequency, default 10 Hz) and 'delta_t' (sampling interval
        of the waveform, default 1/4096 s)

    Returns
    -------
    fig:
        figure with one panel per detector

    Raises
    ------
    Exception
        if LALSimulation could not be imported
    """
    logger.debug("Generating the strain plot")
    from pesummary.gw.conversions import time_in_each_ifo
    from gwpy.timeseries import TimeSeries
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")

    fig, axs = subplots(nrows=len(strain.keys()), sharex=True)
    # with a single detector ``subplots`` may return a lone axis (as
    # matplotlib does) rather than an array; wrap it so it can be indexed
    if not isinstance(axs, (list, tuple, np.ndarray)):
        axs = [axs]
    delta_t = kwargs.get("delta_t", 1. / 4096.)
    minimum_frequency = kwargs.get("f_min", 5.)
    reference_frequency = kwargs.get("f_ref", 10.)

    approx = lalsim.GetApproximantFromString(maxL_params["approximant"])
    mass_1 = maxL_params["mass_1"] * MSUN_SI
    mass_2 = maxL_params["mass_2"] * MSUN_SI
    luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6
    phase = maxL_params["phase"] if "phase" in maxL_params.keys() else 0.0
    cartesian = [
        "iota", "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
    ]
    if not all(param in maxL_params.keys() for param in cartesian):
        if "phi_jl" in maxL_params.keys():
            iota, S1x, S1y, S1z, S2x, S2y, S2z = \
                lalsim.SimInspiralTransformPrecessingNewInitialConditions(
                    maxL_params["theta_jn"], maxL_params["phi_jl"],
                    maxL_params["tilt_1"], maxL_params["tilt_2"],
                    maxL_params["phi_12"], maxL_params["a_1"],
                    maxL_params["a_2"], mass_1, mass_2, reference_frequency,
                    phase
                )
        else:
            # no spin information available: treat as non-spinning
            iota, S1x, S1y, S1z, S2x, S2y, S2z = maxL_params["iota"], 0., 0., \
                0., 0., 0., 0.
    else:
        iota, S1x, S1y, S1z, S2x, S2y, S2z = [
            maxL_params[param] for param in cartesian
        ]
    h_plus, h_cross = lalsim.SimInspiralChooseTDWaveform(
        mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota,
        phase, 0.0, 0.0, 0.0, delta_t, minimum_frequency,
        reference_frequency, None, approx)
    h_plus = TimeSeries(
        h_plus.data.data[:], dt=h_plus.deltaT, t0=h_plus.epoch
    )
    h_cross = TimeSeries(
        h_cross.data.data[:], dt=h_cross.deltaT, t0=h_cross.epoch
    )

    for num, key in enumerate(list(strain.keys())):
        ifo_time = time_in_each_ifo(key, maxL_params["ra"], maxL_params["dec"],
                                    maxL_params["geocent_time"])

        # whiten the data by its own ASD, then band-limit
        asd = strain[key].asd(8, 4, method="median")
        strain_data_frequency = strain[key].fft()
        asd_interp = asd.interpolate(float(np.array(strain_data_frequency.df)))
        asd_interp = asd_interp[:len(strain_data_frequency)]
        strain_data_time = (strain_data_frequency / asd_interp).ifft()
        strain_data_time = strain_data_time.highpass(30)
        strain_data_time = strain_data_time.lowpass(300)

        # project the template onto this detector and process it identically
        ar = __antenna_response(key, maxL_params["ra"], maxL_params["dec"],
                                maxL_params["psi"], maxL_params["geocent_time"])

        h_t = ar[0] * h_plus + ar[1] * h_cross
        h_t_frequency = h_t.fft()
        asd_interp = asd.interpolate(float(np.array(h_t_frequency.df)))
        asd_interp = asd_interp[:len(h_t_frequency)]
        h_t_time = (h_t_frequency / asd_interp).ifft()
        h_t_time = h_t_time.highpass(30)
        h_t_time = h_t_time.lowpass(300)
        h_t_time.times = [float(np.array(i)) + ifo_time for i in h_t.times]

        strain_data_crop = strain_data_time.crop(ifo_time - 0.2, ifo_time + 0.06)
        try:
            h_t_time = h_t_time.crop(ifo_time - 0.2, ifo_time + 0.06)
        except Exception as exc:
            # cropping the template is best-effort: the full template is
            # plotted instead and the view is restricted by set_xlim below
            logger.debug(
                "Unable to crop the template for %s: %s" % (key, exc)
            )
        max_strain = np.max(strain_data_crop).value

        axs[num].plot(strain_data_crop, color='grey', alpha=0.75, label="data")
        axs[num].plot(h_t_time, color='orange', label="template")
        axs[num].set_xlim([ifo_time - 0.2, ifo_time + 0.06])
        if not math.isnan(max_strain):
            axs[num].set_ylim([-max_strain * 1.5, max_strain * 1.5])
        axs[num].set_ylabel("Whitened %s strain" % (key), fontsize=8)
        axs[num].grid(False)
        axs[num].legend(loc="best", prop={'size': 8})
    axs[-1].set_xlabel("Time $[s]$", fontsize=16)
    fig.tight_layout()
    return fig
1406def _format_prob(prob):
1407 """Format the probabilities for use with _classification_plot
1408 """
1409 if prob >= 1:
1410 return '100%'
1411 elif prob <= 0:
1412 return '0%'
1413 elif prob > 0.99:
1414 return '>99%'
1415 elif prob < 0.01:
1416 return '<1%'
1417 else:
1418 try:
1419 return '{}%'.format(int(np.round(100 * prob)))
1420 except ValueError:
1421 return '{}%'.format(np.round(100 * prob))
@no_latex_plot
def _classification_plot(classification):
    """Draw a horizontal bar chart of source classification probabilities.

    Classes are ordered by increasing probability (ties broken by name) and
    each bar is annotated with the label returned by ``_format_prob``.

    Parameters
    ----------
    classification: dict
        mapping of source classification name to its probability

    Returns
    -------
    fig:
        figure containing the bar chart
    """
    ranked = sorted(zip(classification.values(), classification.keys()))
    probs, names = zip(*ranked)
    plot_style = [
        "seaborn-v0_8-white",
        {"font.size": 12, "ytick.labelsize": 12},
    ]
    with matplotlib.style.context(plot_style):
        fig, ax = figure(figsize=(2.5, 2), gca=True)
        ax.barh(names, probs)
        for row, value in enumerate(probs):
            ax.annotate(
                _format_prob(value), (0, row), (4, 0),
                textcoords='offset points', ha='left', va='center'
            )
        ax.set_xlim(0, 1)
        ax.set_xticks([])
        ax.tick_params(left=False)
        for side in ('top', 'bottom', 'right'):
            ax.spines[side].set_visible(False)
        fig.tight_layout()
    return fig