Coverage for pesummary/gw/plots/plot.py: 77.8%

650 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-01-15 17:49 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import ( 

4 logger, number_of_columns_for_legend, _check_latex_install, 

5) 

6from pesummary.utils.decorators import no_latex_plot 

7from pesummary.gw.plots.bounds import default_bounds 

8from pesummary.core.plots.figure import figure, subplots, ExistingFigure 

9from pesummary.core.plots.plot import _default_legend_kwargs 

10from pesummary import conf 

11 

12import os 

13import matplotlib.style 

14import numpy as np 

15import math 

16from scipy.ndimage import gaussian_filter 

17from astropy.time import Time 

18 

19_check_latex_install() 

20 

21from lal import MSUN_SI, PC_SI, CreateDict 

22 

23__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

24try: 

25 import lalsimulation as lalsim 

26 LALSIMULATION = True 

27except ImportError: 

28 LALSIMULATION = None 

29 

30 

def _return_bounds(param, samples, comparison=False):
    """Return the default lower and upper bounds for a given param

    Bounds are looked up in ``default_bounds``. Parameters without an entry
    (or without a 'low'/'high' key) get ``None`` for the missing bound.

    Parameters
    ----------
    param: str
        name of the parameter you wish to get bounds for
    samples: list/np.ndarray
        array of posterior samples for param, or a list of such arrays when
        ``comparison`` is True
    comparison: Bool, optional
        True if samples is a list of arrays of posterior samples

    Returns
    -------
    xlow: float/None
        lower bound for param, or None if there is no default lower bound
    xhigh: float/None
        upper bound for param, or None if there is no default upper bound.
        When the default upper bound is a string mentioning 'mass_1', the
        largest value found in ``samples`` is returned instead
    """
    if param not in default_bounds.keys():
        return None, None

    entry = default_bounds[param]
    lower = entry["low"] if "low" in entry.keys() else None
    if "high" not in entry.keys():
        return lower, None

    upper = entry["high"]
    if not (isinstance(upper, str) and "mass_1" in upper):
        return lower, upper

    # The upper bound is expressed relative to mass_1 rather than as a number,
    # so fall back to the largest value present in the samples
    if comparison:
        return lower, np.max([np.max(chain) for chain in samples])
    return lower, np.max(samples)

57 

58 

def _add_default_bounds_to_kde_kwargs_dict(
    kde_kwargs, param, samples, comparison=False
):
    """Add default kde bounds to a dictionary of kwargs

    The dictionary is updated in place (keys 'xlow', 'xhigh' and
    'kde_kernel') and also returned for convenience.

    Parameters
    ----------
    kde_kwargs: dict
        dictionary of kwargs to pass to the kde class
    param: str
        name of the parameter you wish to plot
    samples: list
        list of samples for param (a list of arrays when ``comparison`` is
        True)
    comparison: Bool, optional
        True if samples is a list of arrays of posterior samples

    Returns
    -------
    kde_kwargs: dict
        the same dictionary that was passed in, now containing the bounds
        and the bounded kde kernel
    """
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    lower, upper = _return_bounds(param, samples, comparison=comparison)
    additions = (
        ("xlow", lower), ("xhigh", upper), ("kde_kernel", bounded_1d_kde)
    )
    for key, value in additions:
        kde_kwargs[key] = value
    return kde_kwargs

80 

81 

def _1d_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list/np.ndarray
        array of samples for param
    *args: tuple
        all args passed directly to pesummary.core.plots.plot._1d_histogram_plot
        function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied
        before use, so the caller's dictionary is never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot

    # Work on a private copy: the bounds helper writes 'xlow', 'xhigh' and
    # 'kde_kernel' into the dict it is given, which previously mutated the
    # caller's dict (or the shared mutable default) and leaked bounds into
    # later calls made with bounded=False
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

110 

111 

def _1d_histogram_plot_mcmc(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for set of
    mcmc chains

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param, one array per chain
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied
        before use, so the caller's dictionary is never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_mcmc

    # Work on a private copy: the bounds helper writes into the dict it is
    # given, which previously mutated the caller's dict (or the shared
    # mutable default) and leaked bounds into later bounded=False calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_histogram_plot_mcmc(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

140 

141 

def _1d_histogram_plot_bootstrap(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate a bootstrapped 1d histogram plot for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    args: tuple
        all args passed to
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied
        before use, so the caller's dictionary is never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_bootstrap

    # Work on a private copy: the bounds helper writes into the dict it is
    # given, which previously mutated the caller's dict (or the shared
    # mutable default) and leaked bounds into later bounded=False calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot_bootstrap(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

173 

174 

def _1d_comparison_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, max_vline=2,
    legend_kwargs=_default_legend_kwargs, **kwargs
):
    """Generate a plot to compare the 1d_histogram plots for a given
    parameter for different approximants.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param, one array per dataset
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_comparison_histogram_plot function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied
        before use, so the caller's dictionary is never modified
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    max_vline: int, optional
        if number of peaks < max_vline draw peaks as vertical lines rather
        than histogramming the data
    legend_kwargs: dict, optional
        kwargs passed through to
        pesummary.core.plots.plot._1d_comparison_histogram_plot to control
        the legend
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_comparison_histogram_plot function
    """
    from pesummary.core.plots.plot import _1d_comparison_histogram_plot

    # Work on a private copy: the bounds helper writes into the dict it is
    # given, which previously mutated the caller's dict (or the shared
    # mutable default) and leaked bounds into later bounded=False calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_comparison_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, max_vline=max_vline,
        legend_kwargs=legend_kwargs, **kwargs
    )

208 

209 

def _make_corner_plot(samples, latex_labels, corner_parameters=None, **kwargs):
    """Generate the corner plots for a given approximant

    Thin wrapper around pesummary.core.plots.plot._make_corner_plot which
    defaults to the gravitational-wave corner parameters.

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters to include in the corner plot. Default
        conf.gw_corner_parameters
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the object returned by pesummary.core.plots.plot._make_corner_plot
    """
    from pesummary.core.plots.plot import _make_corner_plot

    if corner_parameters is None:
        corner_parameters = conf.gw_corner_parameters

    return _make_corner_plot(
        samples, latex_labels, corner_parameters=corner_parameters, **kwargs
    )

234 

235 

def _make_source_corner_plot(samples, latex_labels, **kwargs):
    """Generate the source frame corner plot for a given approximant

    The plotted parameters are fixed to
    conf.gw_source_frame_corner_parameters.

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the first element of the object returned by
    pesummary.core.plots.plot._make_corner_plot
    """
    from pesummary.core.plots.plot import _make_corner_plot

    return _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_source_frame_corner_parameters, **kwargs
    )[0]

258 

259 

def _make_extrinsic_corner_plot(samples, latex_labels, **kwargs):
    """Generate the extrinsic parameter corner plot for a given approximant

    The plotted parameters are fixed to conf.gw_extrinsic_corner_parameters.

    Parameters
    ----------
    samples: nd list
        nd list of samples for each parameter for a given approximant
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the first element of the object returned by
    pesummary.core.plots.plot._make_corner_plot
    """
    from pesummary.core.plots.plot import _make_corner_plot

    return _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_extrinsic_corner_parameters, **kwargs
    )[0]

282 

283 

def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    **kwargs
):
    """Generate a corner plot which overlays multiple datasets

    Thin wrapper around
    pesummary.core.plots.plot._make_comparison_corner_plot that falls back
    to conf.gw_corner_parameters when no corner parameters are given.

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot
    colors: list, optional
        unique colors for each dataset
    **kwargs: dict
        all kwargs are passed to `corner.corner`
    """
    from pesummary.core.plots.plot import (
        _make_comparison_corner_plot as _core_comparison_corner_plot
    )

    parameters = (
        conf.gw_corner_parameters if corner_parameters is None
        else corner_parameters
    )
    return _core_comparison_corner_plot(
        samples, latex_labels, corner_parameters=parameters, colors=colors,
        **kwargs
    )

313 

314 

def __antenna_response(name, ra, dec, psi, time_gps):
    """Calculate the antenna response function

    Parameters
    ----------
    name: str
        name (prefix, e.g. "H1") of the detector you wish to calculate the
        antenna response function for
    ra: float
        right ascension of the source in radians
    dec: float
        declination of the source in radians
    psi: float
        polarisation of the source in radians
    time_gps: float
        gps time of merger

    Returns
    -------
    fplus: float/np.ndarray
        plus polarisation antenna response
    fcross: float/np.ndarray
        cross polarisation antenna response

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    """
    # Fail fast before doing any sidereal-time work when lalsimulation is
    # unavailable
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")
    # Following 8 lines taken from pycbc.detector.Detector: compute the
    # Greenwich mean sidereal time at time_gps and hence the hour angle
    from astropy.units.si import sday
    reference_time = 1126259462.0
    gmst_reference = Time(
        reference_time, format='gps', scale='utc', location=(0, 0)
    ).sidereal_time('mean').rad
    dphase = (time_gps - reference_time) / float(sday.si.scale) * (2.0 * np.pi)
    gmst = (gmst_reference + dphase) % (2.0 * np.pi)
    corrected_ra = gmst - ra
    detector = lalsim.DetectorPrefixToLALDetector(str(name))

    x0 = -np.cos(psi) * np.sin(corrected_ra) - \
        np.sin(psi) * np.cos(corrected_ra) * np.sin(dec)
    x1 = -np.cos(psi) * np.cos(corrected_ra) + \
        np.sin(psi) * np.sin(corrected_ra) * np.sin(dec)
    x2 = np.sin(psi) * np.cos(dec)
    x = np.array([x0, x1, x2])
    dx = detector.response.dot(x)

    y0 = np.sin(psi) * np.sin(corrected_ra) - \
        np.cos(psi) * np.cos(corrected_ra) * np.sin(dec)
    y1 = np.sin(psi) * np.cos(corrected_ra) + \
        np.cos(psi) * np.sin(corrected_ra) * np.sin(dec)
    y2 = np.cos(psi) * np.cos(dec)
    y = np.array([y0, y1, y2])
    dy = detector.response.dot(y)

    # sum over the 3 spatial components; axis=0 keeps one value per sample
    # when ra/dec/psi are arrays
    if hasattr(dx, "shape"):
        fplus = (x * dx - y * dy).sum(axis=0)
        fcross = (x * dy + y * dx).sum(axis=0)
    else:
        fplus = (x * dx - y * dy).sum()
        fcross = (x * dy + y * dx).sum()

    return fplus, fcross

370 

371 

@no_latex_plot
def _waveform_plot(
    detectors, maxL_params, color=None, label=None, fig=None, ax=None,
    **kwargs
):
    """Plot the maximum likelihood waveform for a given approximant.

    One frequency-domain strain curve is drawn per detector, restricted to
    the open interval (f_start, f_max).

    Parameters
    ----------
    detectors: list
        list of detectors that you want to generate waveforms for
    maxL_params: dict
        dictionary of maximum likelihood parameter values. Must contain
        'mass_1' and 'approximant'
    color: list, optional
        color for each detector. Default uses gwpy's GW_OBSERVATORY_COLORS
    label: list, optional
        legend label for each detector. Default is the detector name
    fig: figure, optional
        existing figure to plot on
    ax: axis, optional
        existing axis to plot on. If given, fig must also be given
    kwargs: dict
        dictionary of optional keyword arguments: 'f_low' (default 5.),
        'f_start' (default 5.), 'f_max' (default 1000.), 'f_ref' (default
        f_start), 'delta_f' (default 1/256) and 'approximant_flags'
        (default {})

    Returns
    -------
    fig: the figure containing the plot, or None if maxL_params["mass_1"]
        is NaN
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    from pesummary.gw.waveform import fd_waveform
    if math.isnan(maxL_params["mass_1"]):
        return
    logger.debug("Generating the maximum likelihood waveform plot")
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")

    if ax is None:
        if fig is None:
            fig, ax = figure(gca=True)
        else:
            ax = fig.gca()
    elif fig is None:
        raise ValueError("Please provide a figure for plotting")

    if color is None:
        color = [GW_OBSERVATORY_COLORS[ifo] for ifo in detectors]
    elif len(color) != len(detectors):
        raise ValueError(
            "Please provide a list of colors for each detector"
        )
    if label is None:
        label = detectors
    elif len(label) != len(detectors):
        raise ValueError(
            "Please provide a list of labels for each detector"
        )

    f_low = kwargs.get("f_low", 5.)
    f_start = kwargs.get("f_start", 5.)
    f_max = kwargs.get("f_max", 1000.)
    flags = kwargs.get("approximant_flags", {})
    delta_f = kwargs.get("delta_f", 1. / 256)
    f_ref = kwargs.get("f_ref", f_start)
    for ifo, ifo_color, ifo_label in zip(detectors, color, label):
        strain = fd_waveform(
            maxL_params, maxL_params["approximant"], delta_f, f_start,
            f_max, f_ref=f_ref, project=ifo, flags=flags
        )
        freqs = strain.frequencies.value
        in_band = (freqs > f_start) & (freqs < f_max)
        ax.plot(
            freqs[in_band], np.abs(strain)[in_band], color=ifo_color,
            linewidth=1.0, label=ifo_label
        )
    # shade the region between the waveform start and the analysis f_low
    if f_start < f_low:
        ax.axvspan(f_start, f_low, alpha=0.1, color='grey')
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig

444 

445 

@no_latex_plot
def _waveform_comparison_plot(maxL_params_list, colors, labels,
                              **kwargs):
    """Generate a plot which compares the maximum likelihood waveforms for
    each approximant, projected onto the H1 detector.

    Parameters
    ----------
    maxL_params_list: list
        list of dictionaries containing the maximum likelihood parameter
        values for each approximant. The optional keys 'f_start' (default
        20.), 'f_low' (default 20.), 'f_final' (default 1024.), 'f_ref'
        (default 20.) and 'approximant_flags' (default {}) control how each
        waveform is generated
    colors: list
        list of colors to be used to differentiate the different approximants
    labels: list
        legend label for each entry in maxL_params_list
    kwargs: dict
        not used; accepted for interface compatibility

    Returns
    -------
    fig: the figure containing the comparison plot
    """
    logger.debug("Generating the maximum likelihood waveform comparison plot "
                 "for H1")
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")

    fig, ax = figure(gca=True)
    for num, i in enumerate(maxL_params_list):
        _kwargs = {
            "f_start": i.get("f_start", 20.),
            "f_low": i.get("f_low", 20.),
            "f_max": i.get("f_final", 1024.),
            "f_ref": i.get("f_ref", 20.),
            "approximant_flags": i.get("approximant_flags", {})
        }
        # _waveform_plot draws onto the shared axis; its return value is
        # not needed here
        _ = _waveform_plot(
            ["H1"], i, fig=fig, ax=ax, color=[colors[num]],
            label=[labels[num]], **_kwargs
        )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(visible=True)
    ax.legend(loc="best")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    fig.tight_layout()
    return fig

491 

492 

def _ligo_skymap_plot(ra, dec, dist=None, savedir="./", nprocess=1,
                      downsampled=False, label="pesummary", time=None,
                      distance_map=True, multi_resolution=True,
                      injection=None, **kwargs):
    """Plot the sky location of the source for a given approximant using the
    ligo.skymap package

    A clustered sky KDE is built from the samples, written to
    '<savedir>/<label>_skymap.fits', read back and plotted.

    Parameters
    ----------
    ra: list
        list of samples for right ascension
    dec: list
        list of samples for declination
    dist: list, optional
        list of samples for the luminosity distance
    savedir: str, optional
        path to the directory where you would like to save the output files
    nprocess: int, optional
        number of jobs used when building the sky KDE (passed as `jobs`)
    downsampled: Bool, optional
        Boolean for whether the samples have been downsampled or not
    label: str, optional
        label used to name the output fits file
    time: float/list, optional
        gps time (or samples of it, which are averaged) stored in the
        skymap metadata as 'gps_time'
    distance_map: Bool, optional
        Boolean for whether or not to produce a distance map
    multi_resolution: Bool, optional
        Boolean for whether or not to generate a multiresolution HEALPix map
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians
    kwargs: dict
        not used; accepted for interface compatibility

    Returns
    -------
    fig: the figure returned by _ligo_skymap_plot_from_array
    """
    from ligo.skymap.bayestar import rasterize
    from ligo.skymap import io
    from ligo.skymap.kde import Clustered2DSkyKDE, Clustered2Plus1DSkyKDE

    if dist is not None and distance_map:
        pts = np.column_stack((ra, dec, dist))
        cls = Clustered2Plus1DSkyKDE
    else:
        pts = np.column_stack((ra, dec))
        cls = Clustered2DSkyKDE
    skypost = cls(pts, trials=5, jobs=nprocess)
    hpmap = skypost.as_healpix()
    if not multi_resolution:
        hpmap = rasterize(hpmap)
    hpmap.meta['creator'] = "pesummary"
    hpmap.meta['origin'] = 'LIGO/Virgo'
    hpmap.meta['gps_creation_time'] = Time.now().gps
    if dist is not None:
        hpmap.meta["distmean"] = float(np.mean(dist))
        hpmap.meta["diststd"] = float(np.std(dist))
    if time is not None:
        if isinstance(time, (float, int)):
            _time = time
        else:
            _time = np.mean(time)
        hpmap.meta["gps_time"] = _time

    filename = os.path.join(savedir, "%s_skymap.fits" % (label))
    io.write_sky_map(filename, hpmap, nest=True)
    # presumably read back so that the plot shows exactly what was written
    # to disk
    skymap, metadata = io.fits.read_sky_map(filename, nest=None)
    return _ligo_skymap_plot_from_array(
        skymap, nsamples=len(ra), downsampled=downsampled, injection=injection
    )[0]

559 

560 

def _ligo_skymap_plot_from_array(
    skymap, nsamples=None, downsampled=False, contour=(50, 90),
    annotate=True, fig=None, ax=None, colors="k", injection=None
):
    """Generate a skymap with `ligo.skymap` based on an array of probabilities

    Parameters
    ----------
    skymap: np.array
        array of probabilities (nested HEALPix ordering)
    nsamples: int, optional
        number of samples used
    downsampled: Bool, optional
        If True, add a header to the skymap saying that this plot is downsampled
    contour: list, optional
        list of contours to be plotted on the skymap. Default 50, 90
    annotate: Bool, optional
        If True, annotate the figure by adding the 90% and 50% sky areas
        by default
    fig: matplotlib.figure.Figure, optional
        Existing figure to add the plot to
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        Existing axis to add the plot to
    colors: str/list
        colors to use for the contours
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians

    Returns
    -------
    fig: ExistingFigure/None
        the figure wrapped in an ExistingFigure. None when an existing ax
        was provided without a fig
    ax: the axis the skymap was drawn on
    """
    import healpy as hp
    from ligo.skymap import plot

    if fig is None and ax is None:
        fig = figure(gca=False)
        ax = fig.add_subplot(111, projection='astro hours mollweide')
    elif ax is None:
        ax = fig.gca()

    ax.grid(visible=True)
    nside = hp.npix2nside(len(skymap))
    deg2perpix = hp.nside2pixarea(nside, degrees=True)
    probperdeg2 = skymap / deg2perpix

    if downsampled:
        ax.set_title("Downsampled to %s" % (nsamples), fontdict={'fontsize': 11})

    vmax = probperdeg2.max()
    # NOTE(review): relies on a colormap named "cylon" being registered
    # (presumably by importing ligo.skymap.plot) -- confirm
    ax.imshow_hpx((probperdeg2, 'ICRS'), nested=True, vmin=0.,
                  vmax=vmax, cmap="cylon")
    cls, cs = _ligo_skymap_contours(ax, skymap, contour=contour, colors=colors)
    if annotate:
        text = []
        pp = np.round(contour).astype(int)
        # number of pixels inside each credible region times the pixel area
        ii = np.round(
            np.searchsorted(np.sort(cls), contour) * deg2perpix).astype(int)
        for i, p in zip(ii, pp):
            text.append(u'{:d}% area: {:d} deg²'.format(p, i))
        ax.text(1, 1.05, '\n'.join(text), transform=ax.transAxes, ha='right',
                fontsize=10)
    plot.outline_text(ax)
    if injection is not None and len(injection) == 2:
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        _inj = SkyCoord(*injection, unit=u.rad)
        ax.scatter(
            _inj.ra.value, _inj.dec.value, marker="*", color="orange",
            edgecolors='k', linewidth=1.75, s=100, zorder=100,
            transform=ax.get_transform('world')
        )

    # fig can only still be None here if the caller supplied ax but not fig
    if fig is None:
        return None, ax
    return ExistingFigure(fig), ax

632 

633 

def _ligo_skymap_comparion_plot_from_array(
    skymaps, colors, labels, contour=(50, 90), show_probability_map=None,
    injection=None
):
    """Generate a skymap with `ligo.skymap` which compares arrays of
    probabilities

    Parameters
    ----------
    skymaps: list
        list of skymap probabilities
    colors: list
        list of colors to use for each skymap
    labels: list
        list of labels associated with each skymap
    contour: list, optional
        contours you wish to display on the comparison plot
    show_probability_map: int, optional
        the index of the skymap you wish to show the probability
        map for. Default None
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians.
        Only plotted when show_probability_map is given

    Returns
    -------
    fig: the figure containing the comparison plot
    """
    # importing ligo.skymap.plot makes the 'astro hours mollweide'
    # projection available to add_subplot below
    from ligo.skymap import plot
    import matplotlib.lines as mlines
    ncols = number_of_columns_for_legend(labels)
    fig = figure(gca=False)
    ax = fig.add_subplot(111, projection='astro hours mollweide')
    ax.grid(visible=True)
    lines = []
    for num, skymap in enumerate(skymaps):
        if isinstance(show_probability_map, int) and show_probability_map == num:
            # _ligo_skymap_plot_from_array draws both the probability map and
            # this skymap's contours, so they must not be drawn again below
            _, ax = _ligo_skymap_plot_from_array(
                skymap, nsamples=None, downsampled=False, contour=contour,
                annotate=False, ax=ax, colors=colors[num], injection=injection,
            )
        else:
            _ligo_skymap_contours(
                ax, skymap, contour=contour, colors=colors[num],
            )
        # proxy artist so each skymap gets a legend entry in its color
        lines.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols, handles=lines)
    return fig

677 

678 

def _ligo_skymap_contours(ax, skymap, contour=[50, 90], colors='k'):
    """Draw labelled credible-region contours on a ligo.skymap axis

    Parameters
    ----------
    ax: matplotlib.axes._subplots.AxesSubplot, optional
        Existing axis to add the plot to
    skymap: np.array
        array of probabilities
    contour: list
        list contours you wish to plot (in percent)
    colors: str/list
        colors to use for the contours

    Returns
    -------
    credible_levels: np.array
        greedy credible level of every pixel, in percent
    contour_set:
        the contour set returned by ``ax.contour_hpx``
    """
    from ligo.skymap import postprocess

    # convert the per-pixel greedy credible level from a fraction to percent
    credible_levels = postprocess.find_greedy_credible_levels(skymap) * 100
    contour_set = ax.contour_hpx(
        (credible_levels, 'ICRS'),
        nested=True,
        colors=colors,
        linewidths=0.5,
        levels=contour,
    )
    ax.clabel(contour_set, fmt=r'%g\%%', fontsize=6, inline=True)
    return credible_levels, contour_set

700 

701 

def _default_skymap_plot(ra, dec, weights=None, injection=None, **kwargs):
    """Plot the default sky location of the source for a given approximant

    A smoothed 2d histogram of the samples is drawn on a mollweide
    projection together with its 50% and 90% contours and their areas.

    Parameters
    ----------
    ra: list
        list of samples for right ascension
    dec: list
        list of samples for declination
    weights: list, optional
        weight for each sample. Default None (equal weights)
    injection: list, optional
        list containing the injected value of ra and dec
    kwargs: dict
        optional keyword arguments. 'smooth' sets the width of the gaussian
        filter applied to the histogram (default 0.9)

    Returns
    -------
    fig: the figure containing the skymap
    """
    from .cmap import register_cylon, unregister_cylon
    # register the cylon cmap; the try/finally guarantees it is unregistered
    # again even if plotting fails part way through
    register_cylon()
    try:
        ra = [-i + np.pi for i in ra]
        logger.debug("Generating the sky map plot")
        fig, orig_ax = figure(gca=True)
        # hide the frame of the underlying axis; it is only used as an
        # anchor for the area annotation below
        for side in ("left", "right", "top", "bottom"):
            orig_ax.spines[side].set_visible(False)
        orig_ax.set_yticks([])
        orig_ax.set_xticks([])
        ax = fig.add_subplot(
            111, projection="mollweide",
            facecolor=(1.0, 0.939165516411, 0.880255669068)
        )
        ax.cla()
        ax.set_title("Preliminary", fontdict={'fontsize': 11})
        ax.grid(visible=True)
        levels = [0.9, 0.5]

        # weights=None is numpy's default, i.e. unweighted
        H, X, Y = np.histogram2d(ra, dec, bins=50, weights=weights)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        Hflat = H.flatten()
        indicies = np.argsort(Hflat)[::-1]
        Hflat = Hflat[indicies]

        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # histogram value enclosing each requested probability mass; fall
        # back to the peak value when no bin is below the threshold
        V = np.empty(len(levels))
        for num, i in enumerate(levels):
            try:
                V[num] = Hflat[CF <= i][-1]
            except IndexError:
                V[num] = Hflat[0]
        V.sort()
        # contour levels must be strictly increasing: nudge duplicates down
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on every side, copying the edge
        # values into the inner padding ring (presumably so contours close
        # cleanly at the edges)
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])

        ax.pcolormesh(X2, Y2, H2.T, vmin=0., vmax=H2.T.max(), cmap="cylon")
        cs = ax.contour(X2, Y2, H2.T, V, colors="k", linewidths=0.5)
        if injection is not None:
            ax.scatter(
                -injection[0] + np.pi, injection[1], marker="*",
                color=conf.injection_color, edgecolors='k', linewidth=1.75,
                s=100
            )
        # V is sorted ascending, so the lower level is the 90% region
        fmt = {l: s for l, s in zip(cs.levels, [r"$90\%$", r"$50\%$"])}
        ax.clabel(cs, fmt=fmt, fontsize=8, inline=True)
        text = []
        for path, j in zip(cs.get_paths(), [90, 50]):
            area = 0.
            for poly in path.to_polygons():
                x = poly[:, 0]
                y = poly[:, 1]
                area += 0.5 * np.sum(y[:-1] * np.diff(x) - x[:-1] * np.diff(y))
            # NOTE(review): shoelace area in (ra, dec) radians converted to
            # deg^2 with no cos(dec) factor -- confirm this approximation of
            # the solid angle is intended
            area = int(np.abs(area) * (180 / np.pi) * (180 / np.pi))
            text.append(r'{:d}\% area: {:d} deg²'.format(int(j), area))
        orig_ax.text(1.0, 1.05, '\n'.join(text[::-1]), transform=ax.transAxes,
                     ha='right', fontsize=10)
        xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
        ax.set_xticks(xticks)
        ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
        # 3.82 ~ 12 / pi converts radians to hours
        labels = [
            r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1))) for i in xticks
        ]
        ax.set_xticklabels(labels[::-1], fontsize=10)
        ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                            r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
        ax.grid(visible=True)
    finally:
        # unregister the cylon cmap
        unregister_cylon()
    return fig

813 

814 

def _sky_map_comparison_plot(ra_list, dec_list, labels, colors, **kwargs):
    """Generate a plot that compares the sky location for multiple approximants

    For each dataset the 50% and 90% contours of a smoothed 2d histogram of
    the samples are drawn on a shared mollweide projection.

    Parameters
    ----------
    ra_list: 2d list
        list of samples for right ascension for each approximant
    dec_list: 2d list
        list of samples for declination for each approximant
    labels: list
        legend label for each approximant
    colors: list
        list of colors to be used to differentiate the different approximants
    kwargs: dict
        optional keyword arguments. 'smooth' sets the width of the gaussian
        filter applied to each histogram (default 0.9)

    Returns
    -------
    fig: the figure containing the comparison plot
    """
    ra_list = [[-i + np.pi for i in j] for j in ra_list]
    logger.debug("Generating the sky map comparison plot")
    fig = figure(gca=False)
    ax = fig.add_subplot(
        111, projection="mollweide",
    )
    ax.cla()
    ax.set_title("Preliminary", fontdict={'fontsize': 11})
    ax.grid(visible=True)
    levels = [0.9, 0.5]
    for num, i in enumerate(ra_list):
        H, X, Y = np.histogram2d(i, dec_list[num], bins=50)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        Hflat = H.flatten()
        indicies = np.argsort(Hflat)[::-1]
        Hflat = Hflat[indicies]

        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # histogram value enclosing each requested probability mass; fall
        # back to the peak value when no bin is below the threshold
        V = np.empty(len(levels))
        for num2, j in enumerate(levels):
            try:
                V[num2] = Hflat[CF <= j][-1]
            except IndexError:
                V[num2] = Hflat[0]
        V.sort()
        # contour levels must be strictly increasing: nudge duplicates down
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on every side, copying the edge
        # values into the inner padding ring
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])
        ax.contour(X2, Y2, H2.T, V, colors=colors[num], linewidths=2.0)
        # empty line acting as the legend entry for this dataset
        ax.plot([], [], color=colors[num], linewidth=2.0, label=labels[num])
    ncols = number_of_columns_for_legend(labels)
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols)
    xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
    ax.set_xticks(xticks)
    ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
    # 3.82 ~ 12 / pi converts radians to hours
    tick_labels = [
        r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1))) for i in xticks
    ]
    ax.set_xticklabels(tick_labels[::-1], fontsize=10)
    # ^{\circ} renders both with mathtext and with a real LaTeX install,
    # unlike \degree which needs the gensymb package under usetex
    ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                        r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
    ax.grid(visible=True)
    return fig

899 

900 

def __get_cutoff_indices(flow, fhigh, df, N):
    """
    Return the frequency-series indices bounding an overlap calculation.

    Parameters
    ----------
    flow: float
        The frequency (in Hz) of the lower index. A falsy value (None or 0)
        selects index 1.
    fhigh: float
        The frequency (in Hz) of the upper index. A falsy value (None or 0)
        selects the Nyquist index int((N + 1) / 2).
    df: float
        The frequency step (in Hz) of the frequency series.
    N: int
        The number of points in the **time** series. Can be odd
        or even.

    Returns
    -------
    kmin: int
    kmax: int
    """
    lower_index = int(flow / df) if flow else 1
    upper_index = int(fhigh / df) if fhigh else int((N + 1) / 2.)
    return lower_index, upper_index

932 

933 

934@no_latex_plot 

935def _sky_sensitivity(network, resolution, maxL_params, **kwargs): 

936 """Generate the sky sensitivity for a given network 

937 

938 Parameters 

939 ---------- 

940 network: list 

941 list of detectors you want included in your sky sensitivity plot 

942 resolution: float 

943 resolution of the skymap 

944 maxL_params: dict 

945 dictionary of waveform parameters for the maximum likelihood waveform 

946 """ 

947 logger.debug("Generating the sky sensitivity for %s" % (network)) 

948 if not LALSIMULATION: 

949 raise Exception("LALSimulation could not be imported. Please install " 

950 "LALSuite to be able to use all features") 

951 delta_frequency = kwargs.get("delta_f", 1. / 256) 

952 minimum_frequency = kwargs.get("f_min", 20.) 

953 maximum_frequency = kwargs.get("f_max", 1000.) 

954 frequency_array = np.arange(minimum_frequency, maximum_frequency, 

955 delta_frequency) 

956 

957 approx = lalsim.GetApproximantFromString(maxL_params["approximant"]) 

958 mass_1 = maxL_params["mass_1"] * MSUN_SI 

959 mass_2 = maxL_params["mass_2"] * MSUN_SI 

960 luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6 

961 iota, S1x, S1y, S1z, S2x, S2y, S2z = \ 

962 lalsim.SimInspiralTransformPrecessingNewInitialConditions( 

963 maxL_params["iota"], maxL_params["phi_jl"], maxL_params["tilt_1"], 

964 maxL_params["tilt_2"], maxL_params["phi_12"], maxL_params["a_1"], 

965 maxL_params["a_2"], mass_1, mass_2, kwargs.get("f_ref", 10.), 

966 maxL_params["phase"]) 

967 h_plus, h_cross = lalsim.SimInspiralChooseFDWaveform( 

968 mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota, 

969 maxL_params["phase"], 0.0, 0.0, 0.0, delta_frequency, minimum_frequency, 

970 maximum_frequency, kwargs.get("f_ref", 10.), None, approx) 

971 h_plus = h_plus.data.data 

972 h_cross = h_cross.data.data 

973 h_plus = h_plus[:len(frequency_array)] 

974 h_cross = h_cross[:len(frequency_array)] 

975 psd = {} 

976 psd["H1"] = psd["L1"] = np.array([ 

977 lalsim.SimNoisePSDaLIGOZeroDetHighPower(i) for i in frequency_array]) 

978 psd["V1"] = np.array([lalsim.SimNoisePSDVirgo(i) for i in frequency_array]) 

979 kmin, kmax = __get_cutoff_indices(minimum_frequency, maximum_frequency, 

980 delta_frequency, (len(h_plus) - 1) * 2) 

981 ra = np.arange(-np.pi, np.pi, resolution) 

982 dec = np.arange(-np.pi, np.pi, resolution) 

983 X, Y = np.meshgrid(ra, dec) 

984 N = np.zeros([len(dec), len(ra)]) 

985 

986 indices = np.ndindex(len(ra), len(dec)) 

987 for ind in indices: 

988 ar = {} 

989 SNR = {} 

990 for i in network: 

991 ard = __antenna_response(i, ra[ind[0]], dec[ind[1]], 

992 maxL_params["psi"], maxL_params["geocent_time"]) 

993 ar[i] = [ard[0], ard[1]] 

994 strain = np.array(h_plus * ar[i][0] + h_cross * ar[i][1]) 

995 integrand = np.conj(strain[kmin:kmax]) * strain[kmin:kmax] / psd[i][kmin:kmax] 

996 integrand = integrand[:-1] 

997 SNR[i] = np.sqrt(4 * delta_frequency * np.sum(integrand).real) 

998 ar[i][0] *= SNR[i] 

999 ar[i][1] *= SNR[i] 

1000 numerator = 0.0 

1001 denominator = 0.0 

1002 for i in network: 

1003 numerator += sum(i**2 for i in ar[i]) 

1004 denominator += SNR[i]**2 

1005 N[ind[1]][ind[0]] = (((numerator / denominator)**0.5)) 

1006 fig = figure(gca=False) 

1007 ax = fig.add_subplot(111, projection="hammer") 

1008 ax.cla() 

1009 ax.grid(visible=True) 

1010 ax.pcolormesh(X, Y, N) 

1011 ax.set_xticklabels([ 

1012 r"$22^{h}$", r"$20^{h}$", r"$18^{h}$", r"$16^{h}$", r"$14^{h}$", 

1013 r"$12^{h}$", r"$10^{h}$", r"$8^{h}$", r"$6^{h}$", r"$4^{h}$", 

1014 r"$2^{h}$"]) 

1015 return fig 

1016 

1017 

1018@no_latex_plot 

1019def _time_domain_waveform( 

1020 detectors, maxL_params, color=None, label=None, fig=None, ax=None, 

1021 **kwargs 

1022): 

1023 """ 

1024 Plot the maximum likelihood waveform for a given approximant 

1025 in the time domain. 

1026 

1027 Parameters 

1028 ---------- 

1029 detectors: list 

1030 list of detectors that you want to generate waveforms for 

1031 maxL_params: dict 

1032 dictionary of maximum likelihood parameter values 

1033 kwargs: dict 

1034 dictionary of optional keyword arguments 

1035 """ 

1036 from gwpy.timeseries import TimeSeries 

1037 from gwpy.plot.colors import GW_OBSERVATORY_COLORS 

1038 from pesummary.gw.waveform import td_waveform 

1039 from pesummary.utils.samples_dict import SamplesDict 

1040 if math.isnan(maxL_params["mass_1"]): 

1041 return 

1042 logger.debug("Generating the maximum likelihood waveform time domain plot") 

1043 if not LALSIMULATION: 

1044 raise Exception("lalsimulation could not be imported. please install " 

1045 "lalsuite to be able to use all features") 

1046 

1047 approximant = maxL_params["approximant"] 

1048 minimum_frequency = kwargs.get("f_low", 5.) 

1049 starting_frequency = kwargs.get("f_start", 5.) 

1050 approximant_flags = kwargs.get("approximant_flags", {}) 

1051 _samples = SamplesDict( 

1052 { 

1053 key: [item] for key, item in maxL_params.items() if 

1054 key != "approximant" 

1055 } 

1056 ) 

1057 _samples.generate_all_posterior_samples(disable_remnant=True) 

1058 _samples = {key: item[0] for key, item in _samples.items()} 

1059 chirptime = lalsim.SimIMRPhenomXASDuration( 

1060 _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI, 

1061 _samples.get("spin_1z", 0), _samples.get("spin_2z", 0), 

1062 minimum_frequency 

1063 ) 

1064 duration = np.max([2**np.ceil(np.log2(chirptime)), 1.0]) 

1065 if (fig is None) and (ax is None): 

1066 fig, ax = figure(gca=True) 

1067 elif ax is None: 

1068 ax = fig.gca() 

1069 elif fig is None: 

1070 raise ValueError("Please provide a figure for plotting") 

1071 if color is None: 

1072 color = [GW_OBSERVATORY_COLORS[i] for i in detectors] 

1073 elif len(color) != len(detectors): 

1074 raise ValueError( 

1075 "Please provide a list of colors for each detector" 

1076 ) 

1077 if label is None: 

1078 label = detectors 

1079 elif len(label) != len(detectors): 

1080 raise ValueError( 

1081 "Please provide a list of labels for each detector" 

1082 ) 

1083 for num, i in enumerate(detectors): 

1084 ht = td_waveform( 

1085 maxL_params, approximant, kwargs.get("delta_t", 1. / 4096.), 

1086 starting_frequency, f_ref=kwargs.get("f_ref", 10.), project=i, 

1087 flags=approximant_flags 

1088 ) 

1089 ax.plot( 

1090 ht.times.value, ht, color=color[num], linewidth=1.0, 

1091 label=label[num] 

1092 ) 

1093 ax.set_xlim( 

1094 [ 

1095 maxL_params["geocent_time"] - 0.75 * duration, 

1096 maxL_params["geocent_time"] + duration / 4 

1097 ] 

1098 ) 

1099 ax.set_xlabel(r"Time $[s]$") 

1100 ax.set_ylabel(r"Strain") 

1101 ax.grid(visible=True) 

1102 ax.legend(loc="best") 

1103 fig.tight_layout() 

1104 return fig 

1105 

1106 

1107@no_latex_plot 

1108def _time_domain_waveform_comparison_plot(maxL_params_list, colors, labels, 

1109 **kwargs): 

1110 """Generate a plot which compares the maximum likelihood waveforms for 

1111 each approximant. 

1112 

1113 Parameters 

1114 ---------- 

1115 maxL_params_list: list 

1116 list of dictionaries containing the maximum likelihood parameter 

1117 values for each approximant 

1118 colors: list 

1119 list of colors to be used to differentiate the different approximants 

1120 approximant_labels: list, optional 

1121 label to prepend the approximant in the legend 

1122 kwargs: dict 

1123 dictionary of optional keyword arguments 

1124 """ 

1125 from gwpy.timeseries import TimeSeries 

1126 logger.debug("Generating the maximum likelihood time domain waveform " 

1127 "comparison plot for H1") 

1128 if not LALSIMULATION: 

1129 raise Exception("LALSimulation could not be imported. Please install " 

1130 "LALSuite to be able to use all features") 

1131 fig, ax = figure(gca=True) 

1132 for num, i in enumerate(maxL_params_list): 

1133 _kwargs = { 

1134 "f_start": i.get("f_start", 20.), 

1135 "f_low": i.get("f_low", 20.), 

1136 "f_max": i.get("f_final", 1024.), 

1137 "f_ref": i.get("f_ref", 20.), 

1138 "approximant_flags": i.get("approximant_flags", {}) 

1139 } 

1140 _ = _time_domain_waveform( 

1141 ["H1"], i, fig=fig, ax=ax, color=[colors[num]], 

1142 label=[labels[num]], **_kwargs 

1143 ) 

1144 ax.set_xlabel(r"Time $[s]$") 

1145 ax.set_ylabel(r"Strain") 

1146 ax.grid(visible=True) 

1147 ax.legend(loc="best") 

1148 fig.tight_layout() 

1149 return fig 

1150 

1151 

1152def _psd_plot(frequencies, strains, colors=None, labels=None, fmin=None, fmax=None): 

1153 """Superimpose all PSD plots onto a single figure. 

1154 

1155 Parameters 

1156 ---------- 

1157 frequencies: nd list 

1158 list of all frequencies used for each psd file 

1159 strains: nd list 

1160 list of all strains used for each psd file 

1161 colors: optional, list 

1162 list of colors to be used to differentiate the different PSDs 

1163 labels: optional, list 

1164 list of lavels for each PSD 

1165 fmin: optional, float 

1166 starting frequency of the plot 

1167 fmax: optional, float 

1168 maximum frequency of the plot 

1169 """ 

1170 from gwpy.plot.colors import GW_OBSERVATORY_COLORS 

1171 fig, ax = figure(gca=True) 

1172 if not colors and all(i in GW_OBSERVATORY_COLORS.keys() for i in labels): 

1173 colors = [GW_OBSERVATORY_COLORS[i] for i in labels] 

1174 elif not colors: 

1175 colors = ['r', 'b', 'orange', 'c', 'g', 'purple'] 

1176 while len(colors) <= len(labels): 

1177 colors += colors 

1178 for num, i in enumerate(frequencies): 

1179 ff = np.array(i) 

1180 ss = np.array(strains[num]) 

1181 cond = np.ones_like(strains[num], dtype=bool) 

1182 if fmin is not None: 

1183 cond *= ff >= fmin 

1184 if fmax is not None: 

1185 cond *= ff <= fmax 

1186 i = ff[cond] 

1187 strains[num] = ss[cond] 

1188 ax.loglog(i, strains[num], color=colors[num], label=labels[num]) 

1189 ax.tick_params(which="both", bottom=True, length=3, width=1) 

1190 ax.set_xlabel(r"Frequency $[\mathrm{Hz}]$") 

1191 ax.set_ylabel(r"Power Spectral Density [$\mathrm{strain}^{2}/\mathrm{Hz}$]") 

1192 ax.legend(loc="best") 

1193 fig.tight_layout() 

1194 return fig 

1195 

1196 

1197@no_latex_plot 

1198def _calibration_envelope_plot(frequency, calibration_envelopes, ifos, 

1199 colors=None, prior=[], definition="data"): 

1200 """Generate a plot showing the calibration envelope 

1201 

1202 Parameters 

1203 ---------- 

1204 frequency: array 

1205 frequency bandwidth that you would like to use 

1206 calibration_envelopes: nd list 

1207 list containing the calibration envelope data for different IFOs 

1208 ifos: list 

1209 list of IFOs that are associated with the calibration envelopes 

1210 colors: list, optional 

1211 list of colors to be used to differentiate the different calibration 

1212 envelopes 

1213 prior: list, optional 

1214 list containing the prior calibration envelope data for different IFOs 

1215 definition: str, optional 

1216 definition used for the prior calibration envelope data 

1217 """ 

1218 from gwpy.plot.colors import GW_OBSERVATORY_COLORS 

1219 

1220 def interpolate_calibration(data): 

1221 """Interpolate the calibration data using spline 

1222 

1223 Parameters 

1224 ---------- 

1225 data: np.ndarray 

1226 array containing the calibration data 

1227 """ 

1228 interp = [ 

1229 np.interp(frequency, data[:, 0], data[:, j], left=k, right=k) 

1230 for j, k in zip(range(1, 7), [1, 0, 1, 0, 1, 0]) 

1231 ] 

1232 amp_median = (interp[0] - 1) * 100 

1233 phase_median = interp[1] * 180. / np.pi 

1234 amp_lower_sigma = (interp[2] - 1) * 100 

1235 phase_lower_sigma = interp[3] * 180. / np.pi 

1236 amp_upper_sigma = (interp[4] - 1) * 100 

1237 phase_upper_sigma = interp[5] * 180. / np.pi 

1238 data_dict = { 

1239 "amplitude": { 

1240 "median": amp_median, 

1241 "lower": amp_lower_sigma, 

1242 "upper": amp_upper_sigma 

1243 }, 

1244 "phase": { 

1245 "median": phase_median, 

1246 "lower": phase_lower_sigma, 

1247 "upper": phase_upper_sigma 

1248 } 

1249 } 

1250 return data_dict 

1251 

1252 fig, (ax1, ax2) = subplots(2, 1, sharex=True, gca=False) 

1253 if not colors and all(i in GW_OBSERVATORY_COLORS.keys() for i in ifos): 

1254 colors = [GW_OBSERVATORY_COLORS[i] for i in ifos] 

1255 elif not colors: 

1256 colors = ['r', 'b', 'orange', 'c', 'g', 'purple'] 

1257 while len(colors) <= len(ifos): 

1258 colors += colors 

1259 

1260 for num, i in enumerate(calibration_envelopes): 

1261 calibration_envelopes[num] = np.array(calibration_envelopes[num]) 

1262 

1263 for num, i in enumerate(calibration_envelopes): 

1264 calibration_data = interpolate_calibration(i) 

1265 if prior != []: 

1266 prior_data = interpolate_calibration(prior[num]) 

1267 ax1.plot( 

1268 frequency, calibration_data["amplitude"]["upper"], color=colors[num], 

1269 linestyle="-", label=ifos[num] 

1270 ) 

1271 ax1.plot( 

1272 frequency, calibration_data["amplitude"]["lower"], color=colors[num], 

1273 linestyle="-" 

1274 ) 

1275 ax1.set_ylabel(r"Amplitude deviation $[\%]$", fontsize=10) 

1276 ax1.legend(loc="best") 

1277 ax2.plot( 

1278 frequency, calibration_data["phase"]["upper"], color=colors[num], 

1279 linestyle="-", label=ifos[num] 

1280 ) 

1281 ax2.plot( 

1282 frequency, calibration_data["phase"]["lower"], color=colors[num], 

1283 linestyle="-" 

1284 ) 

1285 ax2.set_ylabel(r"Phase deviation $[\degree]$", fontsize=10) 

1286 if prior != []: 

1287 ax1.fill_between( 

1288 frequency, prior_data["amplitude"]["upper"], 

1289 prior_data["amplitude"]["lower"], color=colors[num], alpha=0.2 

1290 ) 

1291 ax2.fill_between( 

1292 frequency, prior_data["phase"]["upper"], 

1293 prior_data["phase"]["lower"], color=colors[num], alpha=0.2 

1294 ) 

1295 

1296 ax1.set_title(f"Calibration correction applied to {definition}") 

1297 ax1.set_xscale('log') 

1298 ax2.set_xscale('log') 

1299 ax2.set_xlabel(r"Frequency $[Hz]$") 

1300 fig.tight_layout() 

1301 return fig 

1302 

1303 

1304def _strain_plot(strain, maxL_params, **kwargs): 

1305 """Generate a plot showing the strain data and the maxL waveform 

1306 

1307 Parameters 

1308 ---------- 

1309 strain: gwpy.timeseries 

1310 timeseries containing the strain data 

1311 maxL_samples: dict 

1312 dictionary of maximum likelihood parameter values 

1313 """ 

1314 logger.debug("Generating the strain plot") 

1315 from pesummary.gw.conversions import time_in_each_ifo 

1316 from gwpy.timeseries import TimeSeries 

1317 

1318 fig, axs = subplots(nrows=len(strain.keys()), sharex=True) 

1319 time = maxL_params["geocent_time"] 

1320 delta_t = 1. / 4096. 

1321 minimum_frequency = kwargs.get("f_min", 5.) 

1322 t_start = time - 15.0 

1323 t_finish = time + 0.06 

1324 time_array = np.arange(t_start, t_finish, delta_t) 

1325 

1326 approx = lalsim.GetApproximantFromString(maxL_params["approximant"]) 

1327 mass_1 = maxL_params["mass_1"] * MSUN_SI 

1328 mass_2 = maxL_params["mass_2"] * MSUN_SI 

1329 luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6 

1330 phase = maxL_params["phase"] if "phase" in maxL_params.keys() else 0.0 

1331 cartesian = [ 

1332 "iota", "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z" 

1333 ] 

1334 if not all(param in maxL_params.keys() for param in cartesian): 

1335 if "phi_jl" in maxL_params.keys(): 

1336 iota, S1x, S1y, S1z, S2x, S2y, S2z = \ 

1337 lalsim.SimInspiralTransformPrecessingNewInitialConditions( 

1338 maxL_params["theta_jn"], maxL_params["phi_jl"], 

1339 maxL_params["tilt_1"], maxL_params["tilt_2"], 

1340 maxL_params["phi_12"], maxL_params["a_1"], 

1341 maxL_params["a_2"], mass_1, mass_2, kwargs.get("f_ref", 10.), 

1342 phase 

1343 ) 

1344 else: 

1345 iota, S1x, S1y, S1z, S2x, S2y, S2z = maxL_params["iota"], 0., 0., \ 

1346 0., 0., 0., 0. 

1347 else: 

1348 iota, S1x, S1y, S1z, S2x, S2y, S2z = [ 

1349 maxL_params[param] for param in cartesian 

1350 ] 

1351 h_plus, h_cross = lalsim.SimInspiralChooseTDWaveform( 

1352 mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota, 

1353 phase, 0.0, 0.0, 0.0, delta_t, minimum_frequency, 

1354 kwargs.get("f_ref", 10.), None, approx) 

1355 h_plus = TimeSeries( 

1356 h_plus.data.data[:], dt=h_plus.deltaT, t0=h_plus.epoch 

1357 ) 

1358 h_cross = TimeSeries( 

1359 h_cross.data.data[:], dt=h_cross.deltaT, t0=h_cross.epoch 

1360 ) 

1361 

1362 for num, key in enumerate(list(strain.keys())): 

1363 ifo_time = time_in_each_ifo(key, maxL_params["ra"], maxL_params["dec"], 

1364 maxL_params["geocent_time"]) 

1365 

1366 asd = strain[key].asd(8, 4, method="median") 

1367 strain_data_frequency = strain[key].fft() 

1368 asd_interp = asd.interpolate(float(np.array(strain_data_frequency.df))) 

1369 asd_interp = asd_interp[:len(strain_data_frequency)] 

1370 strain_data_time = (strain_data_frequency / asd_interp).ifft() 

1371 strain_data_time = strain_data_time.highpass(30) 

1372 strain_data_time = strain_data_time.lowpass(300) 

1373 

1374 ar = __antenna_response(key, maxL_params["ra"], maxL_params["dec"], 

1375 maxL_params["psi"], maxL_params["geocent_time"]) 

1376 

1377 h_t = ar[0] * h_plus + ar[1] * h_cross 

1378 h_t_frequency = h_t.fft() 

1379 asd_interp = asd.interpolate(float(np.array(h_t_frequency.df))) 

1380 asd_interp = asd_interp[:len(h_t_frequency)] 

1381 h_t_time = (h_t_frequency / asd_interp).ifft() 

1382 h_t_time = h_t_time.highpass(30) 

1383 h_t_time = h_t_time.lowpass(300) 

1384 h_t_time.times = [float(np.array(i)) + ifo_time for i in h_t.times] 

1385 

1386 strain_data_crop = strain_data_time.crop(ifo_time - 0.2, ifo_time + 0.06) 

1387 try: 

1388 h_t_time = h_t_time.crop(ifo_time - 0.2, ifo_time + 0.06) 

1389 except Exception: 

1390 pass 

1391 max_strain = np.max(strain_data_crop).value 

1392 

1393 axs[num].plot(strain_data_crop, color='grey', alpha=0.75, label="data") 

1394 axs[num].plot(h_t_time, color='orange', label="template") 

1395 axs[num].set_xlim([ifo_time - 0.2, ifo_time + 0.06]) 

1396 if not math.isnan(max_strain): 

1397 axs[num].set_ylim([-max_strain * 1.5, max_strain * 1.5]) 

1398 axs[num].set_ylabel("Whitened %s strain" % (key), fontsize=8) 

1399 axs[num].grid(False) 

1400 axs[num].legend(loc="best", prop={'size': 8}) 

1401 axs[-1].set_xlabel("Time $[s]$", fontsize=16) 

1402 fig.tight_layout() 

1403 return fig 

1404 

1405 

1406def _format_prob(prob): 

1407 """Format the probabilities for use with _classification_plot 

1408 """ 

1409 if prob >= 1: 

1410 return '100%' 

1411 elif prob <= 0: 

1412 return '0%' 

1413 elif prob > 0.99: 

1414 return '>99%' 

1415 elif prob < 0.01: 

1416 return '<1%' 

1417 else: 

1418 try: 

1419 return '{}%'.format(int(np.round(100 * prob))) 

1420 except ValueError: 

1421 return '{}%'.format(np.round(100 * prob)) 

1422 

1423 

1424@no_latex_plot 

1425def _classification_plot(classification): 

1426 """Generate a bar chart showing the source classifications probabilities 

1427 

1428 Parameters 

1429 ---------- 

1430 classification: dict 

1431 dictionary of source classifications 

1432 """ 

1433 probs, names = zip( 

1434 *sorted(zip(classification.values(), classification.keys()))) 

1435 with matplotlib.style.context([ 

1436 "seaborn-v0_8-white", 

1437 { 

1438 "font.size": 12, 

1439 "ytick.labelsize": 12, 

1440 }, 

1441 ]): 

1442 fig, ax = figure(figsize=(2.5, 2), gca=True) 

1443 ax.barh(names, probs) 

1444 for i, prob in enumerate(probs): 

1445 ax.annotate(_format_prob(prob), (0, i), (4, 0), 

1446 textcoords='offset points', ha='left', va='center') 

1447 ax.set_xlim(0, 1) 

1448 ax.set_xticks([]) 

1449 ax.tick_params(left=False) 

1450 for side in ['top', 'bottom', 'right']: 

1451 ax.spines[side].set_visible(False) 

1452 fig.tight_layout() 

1453 return fig