Coverage for pesummary/gw/plots/plot.py: 60.2%

640 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import ( 

4 logger, number_of_columns_for_legend, _check_latex_install, 

5) 

6from pesummary.utils.decorators import no_latex_plot 

7from pesummary.gw.plots.bounds import default_bounds 

8from pesummary.core.plots.figure import figure, subplots, ExistingFigure 

9from pesummary.core.plots.plot import _default_legend_kwargs 

10from pesummary import conf 

11 

12import os 

13import matplotlib.style 

14import numpy as np 

15import math 

16from scipy.ndimage import gaussian_filter 

17from astropy.time import Time 

18 

19_check_latex_install() 

20 

21from lal import MSUN_SI, PC_SI, CreateDict 

22 

23__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

24try: 

25 import lalsimulation as lalsim 

26 LALSIMULATION = True 

27except ImportError: 

28 LALSIMULATION = None 

29 

30 

def _return_bounds(param, samples, comparison=False):
    """Return the default (lower, upper) kde bounds for a parameter

    Parameters
    ----------
    param: str
        name of the parameter you wish to get bounds for
    samples: list/np.ndarray
        array or list of array of posterior samples for param
    comparison: Bool, optional
        True if samples is a list of array's of posterior samples

    Returns
    -------
    xlow, xhigh
        the bounds stored in `default_bounds` for param ('low' and 'high'
        keys). Either is None when not available. A string upper bound that
        refers to "mass_1" is replaced by the maximum of the samples
    """
    if param not in default_bounds:
        return None, None
    bounds = default_bounds[param]
    lower = bounds["low"] if "low" in bounds else None
    if "high" not in bounds:
        return lower, None
    upper = bounds["high"]
    if isinstance(upper, str) and "mass_1" in upper:
        # the bound depends on the data: use the largest sample (over all
        # arrays when comparing several sets of samples)
        if comparison:
            upper = np.max([np.max(chain) for chain in samples])
        else:
            upper = np.max(samples)
    return lower, upper

57 

58 

def _add_default_bounds_to_kde_kwargs_dict(
    kde_kwargs, param, samples, comparison=False
):
    """Insert the default kde bounds and the bounded kde kernel into a
    dictionary of kwargs

    Parameters
    ----------
    kde_kwargs: dict
        dictionary of kwargs to pass to the kde class. Updated in place
    param: str
        name of the parameter you wish to plot
    samples: list
        list of samples for param
    comparison: Bool, optional
        True if samples is a list of arrays of samples

    Returns
    -------
    dict
        the same `kde_kwargs` object with 'xlow', 'xhigh' (from
        `_return_bounds`) and 'kde_kernel' (`bounded_1d_kde`) set
    """
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    lower, upper = _return_bounds(param, samples, comparison=comparison)
    entries = (
        ("xlow", lower), ("xhigh", upper), ("kde_kernel", bounded_1d_kde)
    )
    for key, value in entries:
        kde_kwargs[key] = value
    return kde_kwargs

80 

81 

def _1d_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for a given
    approximant.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: np.ndarray
        array of samples for param
    *args: tuple
        all args passed directly to pesummary.core.plots.plot._1d_histogram_plot
        function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified. Default None (no extra kwargs)
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot

    # work on a fresh dict: the bounds helper updates its argument in place,
    # which would otherwise leak bounds into the caller's dict (or into a
    # shared default) and affect later calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

110 

111 

def _1d_histogram_plot_mcmc(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate the 1d histogram plot for a given parameter for set of
    mcmc chains

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param, one per chain
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified. Default None (no extra kwargs)
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_mcmc function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_mcmc

    # work on a fresh dict so bounds added below never leak into the
    # caller's dict or into later calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_histogram_plot_mcmc(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

140 

141 

def _1d_histogram_plot_bootstrap(
    param, samples, *args, kde_kwargs=None, bounded=True, **kwargs
):
    """Generate a bootstrapped 1d histogram plot for a given parameter

    Parameters
    ----------
    param: str
        name of the parameter that you wish to plot
    samples: np.ndarray
        array of samples for param
    args: tuple
        all args passed to
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified. Default None (no extra kwargs)
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_histogram_plot_bootstrap function
    """
    from pesummary.core.plots.plot import _1d_histogram_plot_bootstrap

    # work on a fresh dict so bounds added below never leak into the
    # caller's dict or into later calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples
        )
    return _1d_histogram_plot_bootstrap(
        param, samples, *args, kde_kwargs=kde_kwargs, **kwargs
    )

173 

174 

def _1d_comparison_histogram_plot(
    param, samples, *args, kde_kwargs=None, bounded=True, max_vline=2,
    legend_kwargs=_default_legend_kwargs, **kwargs
):
    """Generate the a plot to compare the 1d_histogram plots for a given
    parameter for different approximants.

    Parameters
    ----------
    param: str
        name of the parameter you wish to plot
    samples: list
        list of arrays of samples for param, one per approximant
    *args: tuple
        all args passed directly to
        pesummary.core.plots.plot._1d_comparison_histogram_plot function
    kde_kwargs: dict, optional
        optional kwargs passed to the kde class. The dictionary is copied and
        never modified. Default None (no extra kwargs)
    bounded: Bool, optional
        if True, pass default 'xlow' and 'xhigh' arguments to the kde class
    max_vline: int, optional
        if number of peaks < max_vline draw peaks as vertical lines rather
        than histogramming the data
    legend_kwargs: dict, optional
        kwargs passed through to the core plotting function for the legend
    **kwargs: dict, optional
        all additional kwargs passed to the
        pesummary.core.plots.plot._1d_comparison_histogram_plot function
    """
    from pesummary.core.plots.plot import _1d_comparison_histogram_plot

    # work on a fresh dict so bounds added below never leak into the
    # caller's dict or into later calls
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    if bounded:
        kde_kwargs = _add_default_bounds_to_kde_kwargs_dict(
            kde_kwargs, param, samples, comparison=True
        )
    return _1d_comparison_histogram_plot(
        param, samples, *args, kde_kwargs=kde_kwargs, max_vline=max_vline,
        legend_kwargs=legend_kwargs, **kwargs
    )

208 

209 

def _make_corner_plot(samples, latex_labels, corner_parameters=None, **kwargs):
    """Generate the corner plot for a given approximant

    Parameters
    ----------
    samples: nd list
        samples for each parameter, passed directly to
        pesummary.core.plots.plot._make_corner_plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        parameters to include in the corner plot. Default
        `conf.gw_corner_parameters`
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the output of pesummary.core.plots.plot._make_corner_plot, unchanged
    """
    from pesummary.core.plots.plot import _make_corner_plot

    parameters = (
        conf.gw_corner_parameters if corner_parameters is None
        else corner_parameters
    )
    return _make_corner_plot(
        samples, latex_labels, corner_parameters=parameters, **kwargs
    )

234 

235 

def _make_source_corner_plot(samples, latex_labels, **kwargs):
    """Generate a corner plot of the default source-frame parameters

    Parameters
    ----------
    samples: nd list
        samples for each parameter, passed directly to
        pesummary.core.plots.plot._make_corner_plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the first element of the output of
    pesummary.core.plots.plot._make_corner_plot (presumably the figure)
    """
    from pesummary.core.plots.plot import _make_corner_plot

    output = _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_source_frame_corner_parameters, **kwargs
    )
    return output[0]

258 

259 

def _make_extrinsic_corner_plot(samples, latex_labels, **kwargs):
    """Generate a corner plot of the default extrinsic parameters

    Parameters
    ----------
    samples: nd list
        samples for each parameter, passed directly to
        pesummary.core.plots.plot._make_corner_plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    **kwargs: dict, optional
        all additional kwargs passed to
        pesummary.core.plots.plot._make_corner_plot

    Returns
    -------
    the first element of the output of
    pesummary.core.plots.plot._make_corner_plot (presumably the figure)
    """
    from pesummary.core.plots.plot import _make_corner_plot

    output = _make_corner_plot(
        samples, latex_labels,
        corner_parameters=conf.gw_extrinsic_corner_parameters, **kwargs
    )
    return output[0]

282 

283 

def _make_comparison_corner_plot(
    samples, latex_labels, corner_parameters=None, colors=conf.corner_colors,
    **kwargs
):
    """Generate a corner plot which contains multiple datasets

    Parameters
    ----------
    samples: dict
        nested dictionary containing the label as key and SamplesDict as item
        for each dataset you wish to plot
    latex_labels: dict
        dictionary of latex labels for each parameter
    corner_parameters: list, optional
        corner parameters you wish to include in the plot. Default
        `conf.gw_corner_parameters`
    colors: list, optional
        unique colors for each dataset
    **kwargs: dict
        all kwargs are passed to
        pesummary.core.plots.plot._make_comparison_corner_plot
    """
    from pesummary.core.plots.plot import _make_comparison_corner_plot

    parameters = (
        conf.gw_corner_parameters if corner_parameters is None
        else corner_parameters
    )
    return _make_comparison_corner_plot(
        samples, latex_labels, corner_parameters=parameters, colors=colors,
        **kwargs
    )

313 

314 

def __antenna_response(name, ra, dec, psi, time_gps):
    """Calculate the plus and cross antenna response of a detector

    Parameters
    ----------
    name: str
        detector prefix passed to
        `lalsimulation.DetectorPrefixToLALDetector`
    ra: float
        right ascension of the source
    dec: float
        declination of the source
    psi: float
        polarisation of the source
    time_gps: float
        gps time of merger

    Returns
    -------
    fplus, fcross
        plus and cross antenna response

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    """
    # The sidereal-time calculation follows pycbc.detector.Detector
    from astropy.units.si import sday

    two_pi = 2.0 * np.pi
    reference_gps = 1126259462.0
    reference_gmst = Time(
        reference_gps, format='gps', scale='utc', location=(0, 0)
    ).sidereal_time('mean').rad
    # advance the reference GMST by the elapsed fraction of a sidereal day
    elapsed_phase = (time_gps - reference_gps) / float(sday.si.scale) * two_pi
    # Greenwich hour angle: GMST - ra
    hour_angle = (reference_gmst + elapsed_phase) % two_pi - ra
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")
    response = lalsim.DetectorPrefixToLALDetector(str(name)).response

    cos_psi, sin_psi = np.cos(psi), np.sin(psi)
    cos_ha, sin_ha = np.cos(hour_angle), np.sin(hour_angle)
    sin_dec, cos_dec = np.sin(dec), np.cos(dec)

    x = np.array([
        -cos_psi * sin_ha - sin_psi * cos_ha * sin_dec,
        -cos_psi * cos_ha + sin_psi * sin_ha * sin_dec,
        sin_psi * cos_dec,
    ])
    y = np.array([
        sin_psi * sin_ha - cos_psi * cos_ha * sin_dec,
        sin_psi * cos_ha + cos_psi * sin_ha * sin_dec,
        cos_psi * cos_dec,
    ])
    dx = response.dot(x)
    dy = response.dot(y)

    # sum over the 3 spatial components; axis=0 keeps one value per sample
    # when array inputs are given
    axis = 0 if hasattr(dx, "shape") else None
    fplus = (x * dx - y * dy).sum(axis=axis)
    fcross = (x * dy + y * dx).sum(axis=axis)
    return fplus, fcross

370 

371 

@no_latex_plot
def _waveform_plot(
    detectors, maxL_params, color=None, label=None, fig=None, ax=None,
    **kwargs
):
    """Plot the maximum likelihood waveform for a given approximant.

    The frequency-domain waveform is projected onto each detector and its
    amplitude is plotted on log-log axes.

    Parameters
    ----------
    detectors: list
        list of detectors that you want to generate waveforms for
    maxL_params: dict
        dictionary of maximum likelihood parameter values. Must contain
        'mass_1' and 'approximant'
    color: list, optional
        one color per detector. Default uses
        gwpy.plot.colors.GW_OBSERVATORY_COLORS
    label: list, optional
        one legend label per detector. Default is the detector names
    fig: optional
        existing figure to plot on. Required when ax is given
    ax: optional
        existing axis to plot on. Default fig.gca(), or a new axis when no
        figure is given
    kwargs: dict
        dictionary of optional keyword arguments. Recognised keys:
        'f_low' (default 5.), 'f_start' (default 5.), 'f_max' (default
        1000.), 'delta_f' (default 1/256), 'f_ref' (default f_start) and
        'approximant_flags' (default {})

    Returns
    -------
    fig
        the figure, or None when maxL_params['mass_1'] is NaN

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    ValueError
        if ax is given without fig, or if color/label do not have one entry
        per detector
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    from pesummary.gw.waveform import fd_waveform
    if math.isnan(maxL_params["mass_1"]):
        return
    logger.debug("Generating the maximum likelihood waveform plot")
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")

    if (fig is None) and (ax is None):
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()
    elif fig is None:
        raise ValueError("Please provide a figure for plotting")
    if color is None:
        color = [GW_OBSERVATORY_COLORS[i] for i in detectors]
    elif len(color) != len(detectors):
        raise ValueError(
            "Please provide a list of colors for each detector"
        )
    if label is None:
        label = detectors
    elif len(label) != len(detectors):
        raise ValueError(
            "Please provide a list of labels for each detector"
        )
    minimum_frequency = kwargs.get("f_low", 5.)
    starting_frequency = kwargs.get("f_start", 5.)
    maximum_frequency = kwargs.get("f_max", 1000.)
    delta_frequency = kwargs.get("delta_f", 1. / 256)
    reference_frequency = kwargs.get("f_ref", starting_frequency)
    approximant_flags = kwargs.get("approximant_flags", {})
    for num, detector in enumerate(detectors):
        ht = fd_waveform(
            maxL_params, maxL_params["approximant"], delta_frequency,
            starting_frequency, maximum_frequency, f_ref=reference_frequency,
            project=detector, flags=approximant_flags
        )
        frequencies = ht.frequencies.value
        mask = (
            (frequencies > starting_frequency) &
            (frequencies < maximum_frequency)
        )
        ax.plot(
            frequencies[mask], np.abs(ht)[mask], color=color[num],
            linewidth=1.0, label=label[num]
        )
    # shade the band between f_start and f_low once: drawing it inside the
    # detector loop stacks one alpha=0.1 span per detector and darkens it
    if starting_frequency < minimum_frequency:
        ax.axvspan(starting_frequency, minimum_frequency, alpha=0.1, color='grey')
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig

444 

445 

@no_latex_plot
def _waveform_comparison_plot(maxL_params_list, colors, labels,
                              **kwargs):
    """Generate a plot which compares the maximum likelihood waveforms for
    each approximant, projected onto H1.

    Parameters
    ----------
    maxL_params_list: list
        list of dictionaries containing the maximum likelihood parameter
        values for each approximant. The optional keys 'f_start', 'f_low',
        'f_ref' (default 20.), 'f_final' (default 1024.) and
        'approximant_flags' (default {}) control the waveform generation
    colors: list
        list of colors to be used to differentiate the different approximants
    labels: list
        legend label for each approximant
    kwargs: dict
        dictionary of optional keyword arguments (not used)
    """
    logger.debug("Generating the maximum likelihood waveform comparison plot "
                 "for H1")
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")

    fig, ax = figure(gca=True)
    for num, params in enumerate(maxL_params_list):
        waveform_kwargs = dict(
            f_start=params.get("f_start", 20.),
            f_low=params.get("f_low", 20.),
            f_max=params.get("f_final", 1024.),
            f_ref=params.get("f_ref", 20.),
            approximant_flags=params.get("approximant_flags", {}),
        )
        # draw onto the shared axis; the returned figure is not needed
        _waveform_plot(
            ["H1"], params, fig=fig, ax=ax, color=[colors[num]],
            label=[labels[num]], **waveform_kwargs
        )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(visible=True)
    ax.legend(loc="best")
    ax.set_xlabel(r"Frequency $[Hz]$")
    ax.set_ylabel(r"Strain")
    fig.tight_layout()
    return fig

491 

492 

def _ligo_skymap_plot(ra, dec, dist=None, savedir="./", nprocess=1,
                      downsampled=False, label="pesummary", time=None,
                      distance_map=True, multi_resolution=True,
                      injection=None, **kwargs):
    """Plot the sky location of the source for a given approximant using the
    ligo.skymap package

    A clustered KDE sky map is built from the samples, written to
    `<savedir>/<label>_skymap.fits`, read back and then plotted.

    Parameters
    ----------
    ra: list
        list of samples for right ascension
    dec: list
        list of samples for declination
    dist: list, optional
        list of samples for the luminosity distance
    savedir: str, optional
        path to the directory where you would like to save the output files
    nprocess: int, optional
        number of jobs passed to the ligo.skymap KDE class. Default 1
    downsampled: Bool, optional
        Boolean for whether the samples have been downsampled or not
    label: str, optional
        label used to name the output fits file. Default "pesummary"
    time: float/list, optional
        GPS time stored in the fits metadata. If not a single float/int, the
        mean is stored
    distance_map: Bool, optional
        if True and dist is provided, build a (ra, dec, dist) KDE
    multi_resolution: Bool, optional
        if False, rasterize the multi-resolution HEALPix map before writing
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians
    kwargs: dict
        optional keyword arguments (not used)

    Returns
    -------
    fig
        the figure returned by `_ligo_skymap_plot_from_array`
    """
    from ligo.skymap.bayestar import rasterize
    from ligo.skymap import io
    from ligo.skymap.kde import Clustered2DSkyKDE, Clustered2Plus1DSkyKDE

    if dist is None or not distance_map:
        kde_class = Clustered2DSkyKDE
        points = np.column_stack((ra, dec))
    else:
        kde_class = Clustered2Plus1DSkyKDE
        points = np.column_stack((ra, dec, dist))
    hpmap = kde_class(points, trials=5, jobs=nprocess).as_healpix()
    if not multi_resolution:
        hpmap = rasterize(hpmap)

    hpmap.meta['creator'] = "pesummary"
    hpmap.meta['origin'] = 'LIGO/Virgo'
    hpmap.meta['gps_creation_time'] = Time.now().gps
    if dist is not None:
        hpmap.meta["distmean"] = float(np.mean(dist))
        hpmap.meta["diststd"] = float(np.std(dist))
    if time is not None:
        hpmap.meta["gps_time"] = (
            time if isinstance(time, (float, int)) else np.mean(time)
        )

    fits_path = os.path.join(savedir, "%s_skymap.fits" % (label))
    io.write_sky_map(fits_path, hpmap, nest=True)
    skymap, _ = io.fits.read_sky_map(fits_path, nest=None)
    return _ligo_skymap_plot_from_array(
        skymap, nsamples=len(ra), downsampled=downsampled, injection=injection
    )[0]

559 

560 

def _ligo_skymap_plot_from_array(
    skymap, nsamples=None, downsampled=False, contour=[50, 90],
    annotate=True, fig=None, ax=None, colors="k", injection=None
):
    """Generate a skymap with `ligo.skymap` based on an array of probabilities

    Parameters
    ----------
    skymap: np.array
        array of probabilities, one per HEALPix pixel in NESTED ordering
    nsamples: int, optional
        number of samples used
    downsampled: Bool, optional
        If True, add a title stating how many samples the plot was
        downsampled to
    contour: list, optional
        list of contours to be plotted on the skymap. Default 50, 90
    annotate: Bool, optional
        If True, annotate the figure by adding the sky area enclosed by each
        contour (90% and 50% by default)
    fig: optional
        existing figure. When given without ax, fig.gca() is used
    ax: optional
        Existing axis to add the plot to. Must support `imshow_hpx`
    colors: str/list
        colors to use for the contours
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians

    Returns
    -------
    fig, ax
        fig is wrapped in ExistingFigure, or is None when only ax was given
    """
    import healpy as hp
    from ligo.skymap import plot

    if ax is None:
        if fig is None:
            fig = figure(gca=False)
            ax = fig.add_subplot(111, projection='astro hours mollweide')
        else:
            ax = fig.gca()

    ax.grid(visible=True)
    pixel_area = hp.nside2pixarea(hp.npix2nside(len(skymap)), degrees=True)
    density = skymap / pixel_area

    if downsampled:
        ax.set_title("Downsampled to %s" % (nsamples), fontdict={'fontsize': 11})

    ax.imshow_hpx(
        (density, 'ICRS'), nested=True, vmin=0., vmax=density.max(),
        cmap="cylon"
    )
    credible_levels, _ = _ligo_skymap_contours(
        ax, skymap, contour=contour, colors=colors
    )
    if annotate:
        percents = np.round(contour).astype(int)
        # number of pixels inside each credible region times the pixel area
        areas = np.round(
            np.searchsorted(np.sort(credible_levels), contour) * pixel_area
        ).astype(int)
        summary = [
            u'{:d}% area: {:d} deg²'.format(percent, area, grouping=True)
            for area, percent in zip(areas, percents)
        ]
        ax.text(1, 1.05, '\n'.join(summary), transform=ax.transAxes,
                ha='right', fontsize=10)
        plot.outline_text(ax)
    if injection is not None and len(injection) == 2:
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        injected = SkyCoord(*injection, unit=u.rad)
        ax.scatter(
            injected.ra.value, injected.dec.value, marker="*", color="orange",
            edgecolors='k', linewidth=1.75, s=100, zorder=100,
            transform=ax.get_transform('world')
        )

    return (None if fig is None else ExistingFigure(fig)), ax

632 

633 

def _ligo_skymap_comparion_plot_from_array(
    skymaps, colors, labels, contour=[50, 90], show_probability_map=None,
    injection=None
):
    """Generate a skymap with `ligo.skymap` which compares arrays of
    probabilities

    Parameters
    ----------
    skymaps: list
        list of skymap probabilities
    colors: list
        list of colors to use for each skymap
    labels: list
        list of labels associated with each skymap
    contour: list, optional
        contours you wish to display on the comparison plot
    show_probability_map: int, optional
        the index of the skymap you wish to show the probability
        map for. Default None
    injection: list, optional
        List containing RA and DEC of the injection. Both must be in radians.
        The injection is drawn exactly once, whether or not a probability map
        is shown

    Returns
    -------
    fig
        the figure containing the comparison
    """
    # presumably imported for its side effect of making the
    # 'astro hours mollweide' projection available
    from ligo.skymap import plot  # noqa: F401
    import matplotlib.lines as mlines

    ncols = number_of_columns_for_legend(labels)
    fig = figure(gca=False)
    ax = fig.add_subplot(111, projection='astro hours mollweide')
    ax.grid(visible=True)
    plot_injection = injection is not None and len(injection) == 2
    lines = []
    for num, skymap in enumerate(skymaps):
        if isinstance(show_probability_map, int) and show_probability_map == num:
            # this call draws the probability map, this skymap's contours and
            # the injection, so none of them are repeated below
            _, ax = _ligo_skymap_plot_from_array(
                skymap, nsamples=None, downsampled=False, contour=contour,
                annotate=False, ax=ax, colors=colors[num], injection=injection,
            )
            plot_injection = False
        else:
            _ligo_skymap_contours(
                ax, skymap, contour=contour, colors=colors[num],
            )
        lines.append(mlines.Line2D([], [], color=colors[num], label=labels[num]))
    if plot_injection:
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        _inj = SkyCoord(*injection, unit=u.rad)
        ax.scatter(
            _inj.ra.value, _inj.dec.value, marker="*", color="orange",
            edgecolors='k', linewidth=1.75, s=100, zorder=100,
            transform=ax.get_transform('world')
        )
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols, handles=lines)
    return fig

677 

678 

def _ligo_skymap_contours(ax, skymap, contour=[50, 90], colors='k'):
    """Plot credible-region contours on a ligo.skymap skymap

    Parameters
    ----------
    ax: matplotlib axis
        Existing axis to add the plot to. Must support `contour_hpx`
    skymap: np.array
        array of probabilities, one per HEALPix pixel in NESTED ordering
    contour: list
        list contours (in percent) you wish to plot
    colors: str/list
        colors to use for the contours

    Returns
    -------
    cls: np.array
        greedy credible level of each pixel, in percent
    cs:
        the contour set returned by `ax.contour_hpx`
    """
    from ligo.skymap import postprocess

    credible_levels = postprocess.find_greedy_credible_levels(skymap) * 100
    contour_set = ax.contour_hpx(
        (credible_levels, 'ICRS'), nested=True, colors=colors,
        linewidths=0.5, levels=contour
    )
    ax.clabel(contour_set, fmt=r'%g\%%', fontsize=6, inline=True)
    return credible_levels, contour_set

700 

701 

def _contour_polygons_per_level(contour_set):
    """Return, for each level of a matplotlib contour set, a list of the
    (N, 2) vertex arrays of the lines drawn at that level.

    Supports matplotlib < 3.8 (one Collection per level, reached through
    `ContourSet.collections`, which was removed in matplotlib 3.10) and
    matplotlib >= 3.8 (the ContourSet is itself a Collection holding one,
    possibly compound, Path per level).
    """
    if hasattr(contour_set, "get_paths"):
        return [
            path.to_polygons(closed_only=False)
            for path in contour_set.get_paths()
        ]
    return [
        [path.vertices for path in collection.get_paths()]
        for collection in contour_set.collections
    ]


def _default_skymap_plot(ra, dec, weights=None, injection=None, **kwargs):
    """Plot the default sky location of the source for a given approximant

    The samples are binned in a 50x50 histogram, smoothed with a gaussian
    filter and drawn on a mollweide projection with their 50% and 90%
    contours and the corresponding sky areas.

    Parameters
    ----------
    ra: list
        list of samples for right ascension (radians)
    dec: list
        list of samples for declination (radians)
    weights: list, optional
        weight for each sample, passed to np.histogram2d
    injection: list, optional
        list containing the injected value of ra and dec
    kwargs: dict
        optional keyword arguments. 'smooth' sets the width of the gaussian
        filter in histogram bins (default 0.9)

    Returns
    -------
    fig
        the figure containing the sky map
    """
    from .cmap import register_cylon, unregister_cylon
    # register the cylon cmap for the duration of this plot only; the
    # finally clause unregisters it even if plotting raises
    register_cylon()
    try:
        ra = [-i + np.pi for i in ra]
        logger.debug("Generating the sky map plot")
        # gca=False: a default cartesian axis would remain underneath the
        # mollweide axis added below
        fig = figure(gca=False)
        ax = fig.add_subplot(
            111, projection="mollweide",
            facecolor=(1.0, 0.939165516411, 0.880255669068)
        )
        ax.cla()
        ax.set_title("Preliminary", fontdict={'fontsize': 11})
        ax.grid(visible=True)
        levels = [0.9, 0.5]

        if weights is None:
            H, X, Y = np.histogram2d(ra, dec, bins=50)
        else:
            H, X, Y = np.histogram2d(ra, dec, bins=50, weights=weights)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        # bin values sorted from highest to lowest density
        Hflat = np.sort(H.flatten())[::-1]

        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # density threshold enclosing each requested probability mass; fall
        # back to the peak density when no bin lies within the level
        V = np.empty(len(levels))
        for num, level in enumerate(levels):
            enclosed = Hflat[CF <= level]
            V[num] = enclosed[-1] if enclosed.size else Hflat[0]
        V.sort()
        # contour levels must be strictly increasing
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on each side so contours close
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])

        ax.pcolormesh(X2, Y2, H2.T, vmin=0., vmax=H2.T.max(), cmap="cylon")
        cs = ax.contour(X2, Y2, H2.T, V, colors="k", linewidths=0.5)

        # measure the enclosed areas (shoelace formula) before clabel
        # (inline=True) removes the parts of the contours under the labels
        text = []
        for polygons, percent in zip(_contour_polygons_per_level(cs), [90, 50]):
            area = 0.
            for vertices in polygons:
                x = vertices[:, 0]
                y = vertices[:, 1]
                area += 0.5 * np.sum(y[:-1] * np.diff(x) - x[:-1] * np.diff(y))
            area = int(np.abs(area) * (180 / np.pi) * (180 / np.pi))
            text.append(u'{:d}% area: {:d} deg²'.format(
                int(percent), area, grouping=True))

        if injection is not None:
            ax.scatter(
                -injection[0] + np.pi, injection[1], marker="*",
                color=conf.injection_color, edgecolors='k', linewidth=1.75, s=100
            )
        fmt = {l: s for l, s in zip(cs.levels, [r"$90\%$", r"$50\%$"])}
        ax.clabel(cs, fmt=fmt, fontsize=8, inline=True)
        ax.text(1, 1.05, '\n'.join(text[::-1]), transform=ax.transAxes, ha='right',
                fontsize=10)
        xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
        ax.set_xticks(xticks)
        ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
        # convert the (flipped) longitude ticks to hours of right ascension
        labels = [r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1))) for i in xticks]
        ax.set_xticklabels(labels[::-1], fontsize=10)
        ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                            r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
        ax.grid(visible=True)
        return fig
    finally:
        # unregister the cylon cmap
        unregister_cylon()

808 

809 

def _sky_map_comparison_plot(ra_list, dec_list, labels, colors, **kwargs):
    """Generate a plot that compares the sky location for multiple approximants

    For each dataset the samples are binned in a 50x50 histogram, smoothed
    with a gaussian filter, and the 50% and 90% contours are drawn.

    Parameters
    ----------
    ra_list: 2d list
        list of samples for right ascension (radians) for each dataset
    dec_list: 2d list
        list of samples for declination (radians) for each dataset
    labels: list
        legend label for each dataset
    colors: list
        list of colors to be used to differentiate the different datasets
    kwargs: dict
        optional keyword arguments. 'smooth' sets the width of the gaussian
        filter in histogram bins (default 0.9)

    Returns
    -------
    fig
        the figure containing the comparison
    """
    import matplotlib.lines as mlines

    ra_list = [[-i + np.pi for i in j] for j in ra_list]
    logger.debug("Generating the sky map comparison plot")
    fig = figure(gca=False)
    ax = fig.add_subplot(
        111, projection="mollweide",
        facecolor=(1.0, 0.939165516411, 0.880255669068)
    )
    ax.cla()
    ax.grid(visible=True)
    levels = [0.9, 0.5]
    handles = []
    for num, i in enumerate(ra_list):
        H, X, Y = np.histogram2d(i, dec_list[num], bins=50)
        H = gaussian_filter(H, kwargs.get("smooth", 0.9))
        # bin values sorted from highest to lowest density
        Hflat = np.sort(H.flatten())[::-1]

        CF = np.cumsum(Hflat)
        CF /= CF[-1]

        # density threshold enclosing each requested probability mass; fall
        # back to the peak density when no bin lies within the level
        V = np.empty(len(levels))
        for num2, level in enumerate(levels):
            enclosed = Hflat[CF <= level]
            V[num2] = enclosed[-1] if enclosed.size else Hflat[0]
        V.sort()
        # contour levels must be strictly increasing
        m = np.diff(V) == 0
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # pad the histogram by two bins on each side so contours close
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate([X1[0] + np.array([-2, -1]) * np.diff(X1[:2]), X1,
                             X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]), ])
        Y2 = np.concatenate([Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]), Y1,
                             Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]), ])
        ax.contour(X2, Y2, H2.T, V, colors=colors[num], linewidths=2.0)
        # proxy legend entry: labelling through ContourSet.collections does
        # not work on matplotlib >= 3.10, where that attribute was removed
        handles.append(
            mlines.Line2D(
                [], [], color=colors[num], linewidth=2.0, label=labels[num]
            )
        )
    ncols = number_of_columns_for_legend(labels)
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, borderaxespad=0.,
              mode="expand", ncol=ncols, handles=handles)
    xticks = np.arange(-np.pi, np.pi + np.pi / 6, np.pi / 4)
    ax.set_xticks(xticks)
    ax.set_yticks([-np.pi / 3, -np.pi / 6, 0, np.pi / 6, np.pi / 3])
    # convert the (flipped) longitude ticks to hours of right ascension
    tick_labels = [
        r"$%s^{h}$" % (int(np.round((i + np.pi) * 3.82, 1))) for i in xticks
    ]
    ax.set_xticklabels(tick_labels[::-1], fontsize=10)
    ax.set_yticklabels([r"$-60^{\circ}$", r"$-30^{\circ}$", r"$0^{\circ}$",
                        r"$30^{\circ}$", r"$60^{\circ}$"], fontsize=10)
    ax.grid(visible=True)
    return fig

894 

895 

def __get_cutoff_indices(flow, fhigh, df, N):
    """
    Gets the indices of a frequency series at which to start and stop an
    overlap calculation.

    Parameters
    ----------
    flow: float
        The frequency (in Hz) of the lower index. If falsy (None or 0) the
        lower index is 1.
    fhigh: float
        The frequency (in Hz) of the upper index. If falsy (None or 0) the
        upper index is int((N + 1) / 2); a larger fhigh is capped at that
        same value.
    df: float
        The frequency step (in Hz) of the frequency series.
    N: int
        The number of points in the **time** series. Can be odd
        or even.

    Returns
    -------
    kmin: int
    kmax: int
    """
    # int() truncates towards zero, i.e. the floor for these positive values
    kmax_limit = int((N + 1) / 2.)
    kmin = int(flow / df) if flow else 1
    if fhigh:
        # never point beyond the frequency series (same cap as pycbc's
        # get_cutoff_indices)
        kmax = min(int(fhigh / df), kmax_limit)
    else:
        kmax = kmax_limit
    return kmin, kmax

927 

928 

@no_latex_plot
def _sky_sensitivity(network, resolution, maxL_params, **kwargs):
    """Generate the sky sensitivity for a given network

    For each point on a (ra, dec) grid the optimal SNR of the maximum
    likelihood waveform is computed in every detector and the quantity
    ``sqrt(sum_i rho_i^2 (F+_i^2 + Fx_i^2) / sum_i rho_i^2)`` is plotted on a
    Hammer projection.

    Parameters
    ----------
    network: list
        list of detectors you want included in your sky sensitivity plot.
        Only "H1", "L1" and "V1" have a noise PSD available
    resolution: float
        grid spacing (in radians) of the skymap
    maxL_params: dict
        dictionary of waveform parameters for the maximum likelihood waveform
    **kwargs: dict, optional
        "delta_f", "f_min", "f_max" and "f_ref" override the default
        frequency spacing (1/256 Hz), minimum frequency (20 Hz), maximum
        frequency (1000 Hz) and reference frequency (10 Hz)

    Returns
    -------
    fig
        figure containing the sky sensitivity map

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    KeyError
        if a detector in ``network`` has no PSD available
    """
    logger.debug("Generating the sky sensitivity for %s" % (network))
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")
    delta_frequency = kwargs.get("delta_f", 1. / 256)
    minimum_frequency = kwargs.get("f_min", 20.)
    maximum_frequency = kwargs.get("f_max", 1000.)
    reference_frequency = kwargs.get("f_ref", 10.)

    approx = lalsim.GetApproximantFromString(maxL_params["approximant"])
    mass_1 = maxL_params["mass_1"] * MSUN_SI
    mass_2 = maxL_params["mass_2"] * MSUN_SI
    luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6
    # The first argument of this transform is theta_jn (not the inclination);
    # fall back to "iota" only when theta_jn is unavailable
    if "theta_jn" in maxL_params:
        theta_jn = maxL_params["theta_jn"]
    else:
        theta_jn = maxL_params["iota"]
    iota, S1x, S1y, S1z, S2x, S2y, S2z = \
        lalsim.SimInspiralTransformPrecessingNewInitialConditions(
            theta_jn, maxL_params["phi_jl"], maxL_params["tilt_1"],
            maxL_params["tilt_2"], maxL_params["phi_12"], maxL_params["a_1"],
            maxL_params["a_2"], mass_1, mass_2, reference_frequency,
            maxL_params["phase"])
    h_plus, h_cross = lalsim.SimInspiralChooseFDWaveform(
        mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota,
        maxL_params["phase"], 0.0, 0.0, 0.0, delta_frequency, minimum_frequency,
        maximum_frequency, reference_frequency, None, approx)
    h_plus = h_plus.data.data
    h_cross = h_cross.data.data

    # The frequency series returned by lalsimulation start at 0 Hz, so bin k
    # sits at k * delta_f. The PSD must be evaluated on exactly these bins.
    kmin, kmax = __get_cutoff_indices(minimum_frequency, maximum_frequency,
                                      delta_frequency, (len(h_plus) - 1) * 2)
    band = slice(kmin, kmax)
    # the final bin of the band is excluded from the SNR integral
    frequencies = (np.arange(len(h_plus))[band] * delta_frequency)[:-1]
    hp = h_plus[band][:-1]
    hc = h_cross[band][:-1]

    psd_functions = {
        "H1": lalsim.SimNoisePSDaLIGOZeroDetHighPower,
        "L1": lalsim.SimNoisePSDaLIGOZeroDetHighPower,
        "V1": lalsim.SimNoisePSDVirgo,
    }
    # Noise-weighted inner products of the two polarisations. They do not
    # depend on sky position, so they are computed once per detector. Assuming
    # real antenna responses F+ and Fx (as returned by __antenna_response):
    # |F+ h+ + Fx hx|^2 = F+^2 |h+|^2 + Fx^2 |hx|^2 + 2 F+ Fx Re(conj(h+) hx)
    inner_products = {}
    for ifo in network:
        weight = 4 * delta_frequency / np.array(
            [psd_functions[ifo](f) for f in frequencies]
        )
        inner_products[ifo] = (
            np.sum(weight * np.abs(hp) ** 2),
            np.sum(weight * np.abs(hc) ** 2),
            np.sum(weight * (np.conj(hp) * hc).real),
        )

    ra = np.arange(-np.pi, np.pi, resolution)
    # declination (latitude) spans [-pi/2, pi/2]
    dec = np.arange(-np.pi / 2, np.pi / 2, resolution)
    X, Y = np.meshgrid(ra, dec)
    N = np.zeros([len(dec), len(ra)])

    for i_ra, i_dec in np.ndindex(len(ra), len(dec)):
        numerator = 0.0
        denominator = 0.0
        for ifo in network:
            response = __antenna_response(
                ifo, ra[i_ra], dec[i_dec], maxL_params["psi"],
                maxL_params["geocent_time"]
            )
            fplus, fcross = response[0], response[1]
            plus_norm, cross_norm, mixed = inner_products[ifo]
            # optimal SNR^2 in this detector at this sky position
            snr_squared = (
                fplus ** 2 * plus_norm + fcross ** 2 * cross_norm
                + 2 * fplus * fcross * mixed
            )
            numerator += snr_squared * (fplus ** 2 + fcross ** 2)
            denominator += snr_squared
        N[i_dec, i_ra] = (numerator / denominator) ** 0.5

    fig = figure(gca=False)
    ax = fig.add_subplot(111, projection="hammer")
    ax.cla()
    ax.grid(visible=True)
    ax.pcolormesh(X, Y, N)
    ax.set_xticklabels([
        r"$22^{h}$", r"$20^{h}$", r"$18^{h}$", r"$16^{h}$", r"$14^{h}$",
        r"$12^{h}$", r"$10^{h}$", r"$8^{h}$", r"$6^{h}$", r"$4^{h}$",
        r"$2^{h}$"])
    return fig

1011 

1012 

@no_latex_plot
def _time_domain_waveform(
    detectors, maxL_params, color=None, label=None, fig=None, ax=None,
    **kwargs
):
    """
    Plot the maximum likelihood waveform for a given approximant
    in the time domain.

    Parameters
    ----------
    detectors: list
        list of detectors that you want to generate waveforms for
    maxL_params: dict
        dictionary of maximum likelihood parameter values. Must contain an
        "approximant" entry alongside the waveform parameters
    color: list, optional
        one color per detector. Defaults to the gwpy observatory colors
    label: list, optional
        one legend label per detector. Defaults to the detector names
    fig: optional
        figure to plot on. A new figure is created when neither ``fig`` nor
        ``ax`` is given
    ax: optional
        axis to plot on. Defaults to ``fig.gca()`` when only ``fig`` is given
    kwargs: dict
        dictionary of optional keyword arguments. Recognised keys are
        "f_low" (frequency used to estimate the signal duration, default
        5 Hz), "f_start" (starting frequency of the waveform, default 5 Hz),
        "f_ref" (reference frequency, default 10 Hz), "delta_t" (sampling
        interval, default 1/4096 s) and "approximant_flags" (default {})

    Returns
    -------
    fig
        the figure containing the waveforms, or None if "mass_1" is NaN

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    ValueError
        if ``ax`` is given without ``fig``, or if ``color``/``label`` do not
        have one entry per detector
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    from pesummary.gw.waveform import td_waveform
    from pesummary.utils.samples_dict import SamplesDict
    if math.isnan(maxL_params["mass_1"]):
        return
    logger.debug("Generating the maximum likelihood waveform time domain plot")
    if not LALSIMULATION:
        raise Exception("lalsimulation could not be imported. please install "
                        "lalsuite to be able to use all features")

    approximant = maxL_params["approximant"]
    minimum_frequency = kwargs.get("f_low", 5.)
    starting_frequency = kwargs.get("f_start", 5.)
    approximant_flags = kwargs.get("approximant_flags", {})
    # Wrap the single point in a SamplesDict so that any missing parameters
    # needed below (component masses, aligned spins) can be derived
    _samples = SamplesDict(
        {
            key: [item] for key, item in maxL_params.items() if
            key != "approximant"
        }
    )
    _samples.generate_all_posterior_samples(disable_remnant=True)
    _samples = {key: item[0] for key, item in _samples.items()}
    # The estimated signal duration is only used to choose the x-axis window
    chirptime = lalsim.SimIMRPhenomXASDuration(
        _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI,
        _samples.get("spin_1z", 0), _samples.get("spin_2z", 0),
        minimum_frequency
    )
    # round up to a power of two, with a minimum window of 1 s
    duration = np.max([2**np.ceil(np.log2(chirptime)), 1.0])
    if (fig is None) and (ax is None):
        fig, ax = figure(gca=True)
    elif ax is None:
        ax = fig.gca()
    elif fig is None:
        raise ValueError("Please provide a figure for plotting")
    if color is None:
        color = [GW_OBSERVATORY_COLORS[i] for i in detectors]
    elif len(color) != len(detectors):
        raise ValueError(
            "Please provide a list of colors for each detector"
        )
    if label is None:
        label = detectors
    elif len(label) != len(detectors):
        raise ValueError(
            "Please provide a list of labels for each detector"
        )
    for num, i in enumerate(detectors):
        ht = td_waveform(
            maxL_params, approximant, kwargs.get("delta_t", 1. / 4096.),
            starting_frequency, f_ref=kwargs.get("f_ref", 10.), project=i,
            flags=approximant_flags
        )
        ax.plot(
            ht.times.value, ht, color=color[num], linewidth=1.0,
            label=label[num]
        )
    # show 3/4 of the window before the merger time and 1/4 after it
    ax.set_xlim(
        [
            maxL_params["geocent_time"] - 0.75 * duration,
            maxL_params["geocent_time"] + duration / 4
        ]
    )
    ax.set_xlabel(r"Time $[s]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig

1100 

1101 

@no_latex_plot
def _time_domain_waveform_comparison_plot(maxL_params_list, colors, labels,
                                          **kwargs):
    """Generate a plot which compares the maximum likelihood waveforms for
    each approximant, projected onto the H1 detector.

    Parameters
    ----------
    maxL_params_list: list
        list of dictionaries containing the maximum likelihood parameter
        values for each approximant. Optional "f_start", "f_low", "f_final",
        "f_ref" and "approximant_flags" entries control the waveform
        generation for that approximant
    colors: list
        list of colors to be used to differentiate the different approximants
    labels: list
        legend label for each approximant
    kwargs: dict
        dictionary of optional keyword arguments (currently unused)

    Returns
    -------
    fig
        figure containing all waveforms

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    """
    logger.debug("Generating the maximum likelihood time domain waveform "
                 "comparison plot for H1")
    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")
    fig, ax = figure(gca=True)
    for num, i in enumerate(maxL_params_list):
        # per-approximant generation settings, falling back to 20 Hz defaults
        _kwargs = {
            "f_start": i.get("f_start", 20.),
            "f_low": i.get("f_low", 20.),
            "f_max": i.get("f_final", 1024.),
            "f_ref": i.get("f_ref", 20.),
            "approximant_flags": i.get("approximant_flags", {})
        }
        # all waveforms are drawn onto the shared figure/axis
        _ = _time_domain_waveform(
            ["H1"], i, fig=fig, ax=ax, color=[colors[num]],
            label=[labels[num]], **_kwargs
        )
    ax.set_xlabel(r"Time $[s]$")
    ax.set_ylabel(r"Strain")
    ax.grid(visible=True)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig

1145 

1146 

def _psd_plot(frequencies, strains, colors=None, labels=None, fmin=None, fmax=None):
    """Superimpose all PSD plots onto a single figure.

    Parameters
    ----------
    frequencies: nd list
        list of all frequencies used for each psd file
    strains: nd list
        list of all strains used for each psd file. The list is not modified
    colors: optional, list
        list of colors to be used to differentiate the different PSDs. When
        not provided, the gwpy observatory colors are used if every label is
        a known detector, otherwise a default set of colors. Colors are
        reused cyclically when there are more PSDs than colors
    labels: optional, list
        list of labels for each PSD. When not provided the PSDs are not
        labelled and no legend is drawn
    fmin: optional, float
        starting frequency of the plot
    fmax: optional, float
        maximum frequency of the plot

    Returns
    -------
    fig
        figure containing the PSDs
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS
    fig, ax = figure(gca=True)
    if labels is None:
        labels = [None] * len(frequencies)
    if not colors:
        if labels and all(i in GW_OBSERVATORY_COLORS.keys() for i in labels):
            colors = [GW_OBSERVATORY_COLORS[i] for i in labels]
        else:
            colors = ['r', 'b', 'orange', 'c', 'g', 'purple']
    for num, freqs in enumerate(frequencies):
        ff = np.array(freqs)
        ss = np.array(strains[num])
        # restrict to [fmin, fmax] without modifying the caller's data
        cond = np.ones_like(ss, dtype=bool)
        if fmin is not None:
            cond &= ff >= fmin
        if fmax is not None:
            cond &= ff <= fmax
        ax.loglog(
            ff[cond], ss[cond], color=colors[num % len(colors)],
            label=labels[num]
        )
    ax.tick_params(which="both", bottom=True, length=3, width=1)
    ax.set_xlabel(r"Frequency $[\mathrm{Hz}]$")
    ax.set_ylabel(r"Power Spectral Density [$\mathrm{strain}^{2}/\mathrm{Hz}$]")
    if any(label is not None for label in labels):
        ax.legend(loc="best")
    fig.tight_layout()
    return fig

1190 

1191 

@no_latex_plot
def _calibration_envelope_plot(frequency, calibration_envelopes, ifos,
                               colors=None, prior=None, definition="data"):
    """Generate a plot showing the calibration envelope

    Parameters
    ----------
    frequency: array
        frequency bandwidth that you would like to use
    calibration_envelopes: nd list
        list containing the calibration envelope data for different IFOs.
        Each entry is a 2d array whose first column is frequency followed by
        the median, lower and upper amplitude and phase columns (see
        ``interpolate_calibration``). The list is not modified
    ifos: list
        list of IFOs that are associated with the calibration envelopes
    colors: list, optional
        list of colors to be used to differentiate the different calibration
        envelopes. Colors are reused cyclically if fewer than IFOs
    prior: list, optional
        list containing the prior calibration envelope data for different
        IFOs, in the same format as ``calibration_envelopes``. None or an
        empty list means no prior is shown
    definition: str, optional
        definition used for the prior calibration envelope data

    Returns
    -------
    fig
        figure with amplitude (top) and phase (bottom) panels
    """
    from gwpy.plot.colors import GW_OBSERVATORY_COLORS

    def interpolate_calibration(data):
        """Linearly interpolate the calibration data onto ``frequency``

        Parameters
        ----------
        data: np.ndarray
            2d array whose first column is frequency and whose next six
            columns are, in order, the median amplitude, median phase, lower
            amplitude, lower phase, upper amplitude and upper phase

        Returns
        -------
        data_dict: dict
            amplitude deviations in percent and phase deviations in degrees
            (converted from radians), each with "median", "lower" and
            "upper" entries
        """
        # outside the tabulated range amplitudes are held at 1 (no
        # deviation) and phases at 0
        interp = [
            np.interp(frequency, data[:, 0], data[:, j], left=k, right=k)
            for j, k in zip(range(1, 7), [1, 0, 1, 0, 1, 0])
        ]
        amp_median = (interp[0] - 1) * 100
        phase_median = interp[1] * 180. / np.pi
        amp_lower_sigma = (interp[2] - 1) * 100
        phase_lower_sigma = interp[3] * 180. / np.pi
        amp_upper_sigma = (interp[4] - 1) * 100
        phase_upper_sigma = interp[5] * 180. / np.pi
        data_dict = {
            "amplitude": {
                "median": amp_median,
                "lower": amp_lower_sigma,
                "upper": amp_upper_sigma
            },
            "phase": {
                "median": phase_median,
                "lower": phase_lower_sigma,
                "upper": phase_upper_sigma
            }
        }
        return data_dict

    fig, (ax1, ax2) = subplots(2, 1, sharex=True, gca=False)
    if not colors:
        if ifos and all(i in GW_OBSERVATORY_COLORS.keys() for i in ifos):
            colors = [GW_OBSERVATORY_COLORS[i] for i in ifos]
        else:
            colors = ['r', 'b', 'orange', 'c', 'g', 'purple']

    # convert to arrays without modifying the caller's list
    envelopes = [np.array(envelope) for envelope in calibration_envelopes]
    has_prior = prior is not None and len(prior) > 0

    for num, envelope in enumerate(envelopes):
        color = colors[num % len(colors)]
        calibration_data = interpolate_calibration(envelope)
        ax1.plot(
            frequency, calibration_data["amplitude"]["upper"], color=color,
            linestyle="-", label=ifos[num]
        )
        ax1.plot(
            frequency, calibration_data["amplitude"]["lower"], color=color,
            linestyle="-"
        )
        ax2.plot(
            frequency, calibration_data["phase"]["upper"], color=color,
            linestyle="-", label=ifos[num]
        )
        ax2.plot(
            frequency, calibration_data["phase"]["lower"], color=color,
            linestyle="-"
        )
        if has_prior:
            prior_data = interpolate_calibration(np.array(prior[num]))
            ax1.fill_between(
                frequency, prior_data["amplitude"]["upper"],
                prior_data["amplitude"]["lower"], color=color, alpha=0.2
            )
            ax2.fill_between(
                frequency, prior_data["phase"]["upper"],
                prior_data["phase"]["lower"], color=color, alpha=0.2
            )

    ax1.set_ylabel(r"Amplitude deviation $[\%]$", fontsize=10)
    ax2.set_ylabel(r"Phase deviation $[\degree]$", fontsize=10)
    if envelopes:
        ax1.legend(loc="best")
    ax1.set_title(f"Calibration correction applied to {definition}")
    ax1.set_xscale('log')
    ax2.set_xscale('log')
    ax2.set_xlabel(r"Frequency $[Hz]$")
    fig.tight_layout()
    return fig

1297 

1298 

def _strain_plot(strain, maxL_params, **kwargs):
    """Generate a plot showing the whitened strain data and the maxL waveform

    Both the data and the template are whitened with the median ASD of the
    data, high-passed at 30 Hz and low-passed at 300 Hz. They are shown from
    0.2 s before to 0.06 s after the arrival time in each detector.

    Parameters
    ----------
    strain: dict
        mapping of detector name to gwpy TimeSeries containing the strain
        data
    maxL_params: dict
        dictionary of maximum likelihood parameter values
    kwargs: dict, optional
        "f_min" (starting frequency of the template, default 5 Hz) and
        "f_ref" (reference frequency, default 10 Hz)

    Returns
    -------
    fig
        figure with one panel per detector

    Raises
    ------
    Exception
        if lalsimulation could not be imported
    """
    logger.debug("Generating the strain plot")
    from pesummary.gw.conversions import time_in_each_ifo
    from gwpy.timeseries import TimeSeries

    if not LALSIMULATION:
        raise Exception("LALSimulation could not be imported. Please install "
                        "LALSuite to be able to use all features")

    fig, axs = subplots(nrows=len(strain.keys()), sharex=True)
    # ``subplots`` may return a lone Axes (rather than an array) when there
    # is only one detector; make it indexable in every case
    axs = np.atleast_1d(axs)
    delta_t = 1. / 4096.
    minimum_frequency = kwargs.get("f_min", 5.)
    reference_frequency = kwargs.get("f_ref", 10.)

    approx = lalsim.GetApproximantFromString(maxL_params["approximant"])
    mass_1 = maxL_params["mass_1"] * MSUN_SI
    mass_2 = maxL_params["mass_2"] * MSUN_SI
    luminosity_distance = maxL_params["luminosity_distance"] * PC_SI * 10**6
    phase = maxL_params["phase"] if "phase" in maxL_params.keys() else 0.0
    cartesian = [
        "iota", "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
    ]
    if not all(param in maxL_params.keys() for param in cartesian):
        if "phi_jl" in maxL_params.keys():
            # convert the precessing spin angles to cartesian components
            iota, S1x, S1y, S1z, S2x, S2y, S2z = \
                lalsim.SimInspiralTransformPrecessingNewInitialConditions(
                    maxL_params["theta_jn"], maxL_params["phi_jl"],
                    maxL_params["tilt_1"], maxL_params["tilt_2"],
                    maxL_params["phi_12"], maxL_params["a_1"],
                    maxL_params["a_2"], mass_1, mass_2, reference_frequency,
                    phase
                )
        else:
            # no spin information: generate a non-spinning template
            iota, S1x, S1y, S1z, S2x, S2y, S2z = maxL_params["iota"], 0., 0., \
                0., 0., 0., 0.
    else:
        iota, S1x, S1y, S1z, S2x, S2y, S2z = [
            maxL_params[param] for param in cartesian
        ]
    h_plus, h_cross = lalsim.SimInspiralChooseTDWaveform(
        mass_1, mass_2, S1x, S1y, S1z, S2x, S2y, S2z, luminosity_distance, iota,
        phase, 0.0, 0.0, 0.0, delta_t, minimum_frequency,
        reference_frequency, None, approx)
    h_plus = TimeSeries(
        h_plus.data.data[:], dt=h_plus.deltaT, t0=h_plus.epoch
    )
    h_cross = TimeSeries(
        h_cross.data.data[:], dt=h_cross.deltaT, t0=h_cross.epoch
    )

    for num, key in enumerate(list(strain.keys())):
        ifo_time = time_in_each_ifo(key, maxL_params["ra"], maxL_params["dec"],
                                    maxL_params["geocent_time"])

        # whiten the data with its own median ASD, then band-pass
        asd = strain[key].asd(8, 4, method="median")
        strain_data_frequency = strain[key].fft()
        asd_interp = asd.interpolate(float(np.array(strain_data_frequency.df)))
        asd_interp = asd_interp[:len(strain_data_frequency)]
        strain_data_time = (strain_data_frequency / asd_interp).ifft()
        strain_data_time = strain_data_time.highpass(30)
        strain_data_time = strain_data_time.lowpass(300)

        # project the template onto the detector and process it identically
        ar = __antenna_response(key, maxL_params["ra"], maxL_params["dec"],
                                maxL_params["psi"], maxL_params["geocent_time"])

        h_t = ar[0] * h_plus + ar[1] * h_cross
        h_t_frequency = h_t.fft()
        asd_interp = asd.interpolate(float(np.array(h_t_frequency.df)))
        asd_interp = asd_interp[:len(h_t_frequency)]
        h_t_time = (h_t_frequency / asd_interp).ifft()
        h_t_time = h_t_time.highpass(30)
        h_t_time = h_t_time.lowpass(300)
        # shift the template times to the arrival time in this detector
        h_t_time.times = [float(np.array(i)) + ifo_time for i in h_t.times]

        window_start = ifo_time - 0.2
        window_end = ifo_time + 0.06
        strain_data_crop = strain_data_time.crop(window_start, window_end)
        try:
            h_t_time = h_t_time.crop(window_start, window_end)
        except Exception as exc:
            # best effort: fall back to plotting the uncropped template
            logger.debug(
                "Unable to crop the %s template: %s" % (key, exc)
            )
        max_strain = np.max(strain_data_crop).value

        axs[num].plot(strain_data_crop, color='grey', alpha=0.75, label="data")
        axs[num].plot(h_t_time, color='orange', label="template")
        axs[num].set_xlim([window_start, window_end])
        if not math.isnan(max_strain):
            axs[num].set_ylim([-max_strain * 1.5, max_strain * 1.5])
        axs[num].set_ylabel("Whitened %s strain" % (key), fontsize=8)
        axs[num].grid(False)
        axs[num].legend(loc="best", prop={'size': 8})
    axs[-1].set_xlabel("Time $[s]$", fontsize=16)
    fig.tight_layout()
    return fig

1399 

1400 

1401def _format_prob(prob): 

1402 """Format the probabilities for use with _classification_plot 

1403 """ 

1404 if prob >= 1: 

1405 return '100%' 

1406 elif prob <= 0: 

1407 return '0%' 

1408 elif prob > 0.99: 

1409 return '>99%' 

1410 elif prob < 0.01: 

1411 return '<1%' 

1412 else: 

1413 return '{}%'.format(int(np.round(100 * prob))) 

1414 

1415 

@no_latex_plot
def _classification_plot(classification):
    """Draw a horizontal bar chart of source classification probabilities

    Bars are ordered by increasing probability (ties broken by name) and each
    bar is annotated with its probability as formatted by ``_format_prob``.

    Parameters
    ----------
    classification: dict
        mapping of source class name to probability
    """
    ranked = sorted(
        (value, key) for key, value in classification.items()
    )
    probs, names = zip(*ranked)
    style = [
        "seaborn-v0_8-white",
        {"font.size": 12, "ytick.labelsize": 12},
    ]
    with matplotlib.style.context(style):
        fig, ax = figure(figsize=(2.5, 2), gca=True)
        ax.barh(names, probs)
        for position, prob in enumerate(probs):
            # place the text 4 points to the right of the bar's origin
            ax.annotate(
                _format_prob(prob), (0, position), (4, 0),
                textcoords='offset points', ha='left', va='center'
            )
        ax.set_xlim(0, 1)
        ax.set_xticks([])
        ax.tick_params(left=False)
        for side in ('top', 'bottom', 'right'):
            ax.spines[side].set_visible(False)
        fig.tight_layout()
    return fig