Coverage for pesummary/gw/plots/main.py: 67.5%

437 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5import os 

6 

7from pesummary.core.plots.main import _PlotGeneration as _BasePlotGeneration 

8from pesummary.core.plots.latex_labels import latex_labels 

9from pesummary.core.plots import interactive 

10from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE 

11from pesummary.gw.plots.latex_labels import GWlatex_labels 

12from pesummary.utils.utils import logger, resample_posterior_distribution 

13from pesummary.utils.decorators import no_latex_plot 

14from pesummary.gw.plots import publication 

15from pesummary.gw.plots import plot as gw 

16 

17import multiprocessing as mp 

18import numpy as np 

19 

20__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

21latex_labels.update(GWlatex_labels) 

22 

23 

24class _PlotGeneration(_BasePlotGeneration): 

def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_file_kwargs=None, existing_samples=None,
    existing_metafile=None, same_parameters=None, injection_data=None,
    result_files=None, file_kwargs=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, no_ligo_skymap=False,
    nsamples_for_skymap=None, detectors=None, maxL_samples=None,
    gwdata=None, calibration=None, psd=None,
    multi_threading_for_skymap=None, approximant=None,
    pepredicates_probs=None, include_prior=False, publication=False,
    existing_approximant=None, existing_psd=None, existing_calibration=None,
    existing_weights=None, weights=None, disable_comparison=False,
    linestyles=None, disable_interactive=False, disable_corner=False,
    publication_kwargs=None, multi_process=1, mcmc_samples=False,
    skymap=None, existing_skymap=None, corner_params=None,
    preliminary_pages=False, expert_plots=True, checkpoint=False,
    key_data=None
):
    """Set up the GW specific plot generation

    The generic arguments are forwarded unchanged to the base class
    constructor; the GW specific arguments are stored as attributes of the
    same name and the GW plot types are registered in
    `self.plot_type_dictionary`.

    Parameters
    ----------
    priors: dict, optional
        prior samples, forwarded to the base class. Default is a new empty
        dictionary
    publication_kwargs: dict, optional
        options for the publication plots (the "gridsize" key is read when
        making 2d comparison contours). Default is a new empty dictionary
    preliminary_pages: bool or dict, optional
        either a dictionary keyed by label, or a single flag which is
        expanded to one entry per label
    skymap: dict, optional
        dictionary keyed by label. Default is None for every label
    existing_skymap: dict, optional
        stored as `self.existing_skymap`. When not provided, `skymap` is
        stored instead
    existing_metafile, result_files: optional
        accepted for interface compatibility; not used by this constructor
    """
    # Never share one dictionary between instances through the defaults
    if priors is None:
        priors = {}
    if publication_kwargs is None:
        publication_kwargs = {}

    super(_PlotGeneration, self).__init__(
        savedir=savedir, webdir=webdir, labels=labels,
        samples=samples, kde_plot=kde_plot, existing_labels=existing_labels,
        existing_injection_data=existing_injection_data,
        existing_samples=existing_samples,
        existing_weights=existing_weights,
        same_parameters=same_parameters,
        injection_data=injection_data, mcmc_samples=mcmc_samples,
        colors=colors, custom_plotting=custom_plotting,
        add_to_existing=add_to_existing, priors=priors,
        include_prior=include_prior, weights=weights,
        disable_comparison=disable_comparison, linestyles=linestyles,
        disable_interactive=disable_interactive, disable_corner=disable_corner,
        multi_process=multi_process, corner_params=corner_params,
        expert_plots=expert_plots, checkpoint=checkpoint, key_data=key_data
    )

    # A single flag applies to every label
    if isinstance(preliminary_pages, dict):
        self.preliminary_pages = preliminary_pages
    else:
        self.preliminary_pages = {
            label: bool(preliminary_pages) for label in self.labels
        }
    self.preliminary_comparison_pages = any(
        value for value in self.preliminary_pages.values()
    )
    self.package = "gw"
    self.file_kwargs = file_kwargs
    self.existing_file_kwargs = existing_file_kwargs
    self.no_ligo_skymap = no_ligo_skymap
    self.nsamples_for_skymap = nsamples_for_skymap
    self.detectors = detectors
    self.maxL_samples = maxL_samples
    self.gwdata = gwdata
    if skymap is None:
        skymap = {label: None for label in self.labels}
    self.skymap = skymap
    # Honour `existing_skymap` when given; previously it was silently
    # replaced by `skymap`, which remains the fallback
    self.existing_skymap = (
        skymap if existing_skymap is None else existing_skymap
    )
    self.calibration = calibration
    self.existing_calibration = existing_calibration
    self.psd = psd
    self.existing_psd = existing_psd
    self.multi_threading_for_skymap = multi_threading_for_skymap
    self.approximant = approximant
    self.existing_approximant = existing_approximant
    self.pepredicates_probs = pepredicates_probs
    self.publication = publication
    self.publication_kwargs = publication_kwargs
    self._ligo_skymap_PID = {}

    self.plot_type_dictionary.update({
        "psd": self.psd_plot,
        "calibration": self.calibration_plot,
        "twod_histogram": self.twod_histogram_plot,
        "skymap": self.skymap_plot,
        "waveform_fd": self.waveform_fd_plot,
        "waveform_td": self.waveform_td_plot,
        "data": self.gwdata_plots,
        "violin": self.violin_plot,
        "spin_disk": self.spin_dist_plot,
        "pepredicates": self.pepredicates_plot
    })
    if self.make_comparison:
        self.plot_type_dictionary.update({
            "skymap_comparison": self.skymap_comparison_plot,
            "waveform_comparison_fd": self.waveform_comparison_fd_plot,
            "waveform_comparison_td": self.waveform_comparison_td_plot,
            "2d_comparison_contour": self.twod_comparison_contour_plot,
        })

116 

@property
def ligo_skymap_PID(self):
    """Dictionary of process IDs for the launched ligo.skymap subprocesses,
    keyed by label
    """
    pids = self._ligo_skymap_PID
    return pids

120 

def generate_plots(self):
    """Generate all plots for all result files

    The calibration plot is attempted when calibration data or a
    "calibration" prior is available, and the psd plot when psd data is
    available. The base class then generates the remaining plots.
    """
    make_plot = self.try_to_make_a_plot
    if self.calibration or "calibration" in self.priors.keys():
        make_plot("calibration")
    if self.psd:
        make_plot("psd")
    super(_PlotGeneration, self).generate_plots()

129 

def _generate_plots(self, label):
    """Generate all plots for a given result file

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    super(_PlotGeneration, self)._generate_plots(label)
    for plot_type in (
            "twod_histogram", "skymap", "waveform_td", "waveform_fd"
    ):
        self.try_to_make_a_plot(plot_type, label=label)
    # `pepredicates_probs` defaults to None in the constructor, so guard
    # against indexing None; a label without an entry is skipped as well
    probabilities = self.pepredicates_probs
    if probabilities is not None and probabilities.get(label) is not None:
        self.try_to_make_a_plot("pepredicates", label=label)
    if self.gwdata:
        self.try_to_make_a_plot("data", label=label)

142 

def _generate_comparison_plots(self):
    """Generate all comparison plots

    The base class comparison plots are made first, followed by the GW
    specific comparisons. The publication style plots are only attempted
    when `self.publication` is set.
    """
    super(_PlotGeneration, self)._generate_comparison_plots()
    always = (
        "skymap_comparison", "waveform_comparison_td",
        "waveform_comparison_fd"
    )
    for plot_type in always:
        self.try_to_make_a_plot(plot_type)
    if not self.publication:
        return
    for plot_type in ("2d_comparison_contour", "violin", "spin_disk"):
        self.try_to_make_a_plot(plot_type)

154 

155 @staticmethod 

156 def _corner_plot( 

157 savedir, label, samples, latex_labels, webdir, params, preliminary=False, 

158 checkpoint=False 

159 ): 

160 """Generate a corner plot for a given set of samples 

161 

162 Parameters 

163 ---------- 

164 savedir: str 

165 the directory you wish to save the plot in 

166 label: str 

167 the label corresponding to the results file 

168 samples: dict 

169 dictionary of samples for a given result file 

170 latex_labels: dict 

171 dictionary of latex labels 

172 webdir: str 

173 directory where the javascript is written 

174 preliminary: Bool, optional 

175 if True, add a preliminary watermark to the plot 

176 """ 

177 import warnings 

178 

179 with warnings.catch_warnings(): 

180 warnings.simplefilter("ignore") 

181 filename = os.path.join( 

182 savedir, "corner", "{}_all_density_plots.png".format(label) 

183 ) 

184 if os.path.isfile(filename) and checkpoint: 

185 pass 

186 else: 

187 fig, params, data = gw._make_corner_plot( 

188 samples, latex_labels, corner_parameters=params 

189 ) 

190 fig.savefig(filename) 

191 fig.close() 

192 combine_corner = open( 

193 os.path.join(webdir, "js", "combine_corner.js") 

194 ) 

195 combine_corner = combine_corner.readlines() 

196 params = [str(i) for i in params] 

197 ind = [ 

198 linenumber for linenumber, line in enumerate(combine_corner) 

199 if "var list = {}" in line 

200 ][0] 

201 combine_corner.insert( 

202 ind + 1, " list['{}'] = {};\n".format(label, params) 

203 ) 

204 new_file = open( 

205 os.path.join(webdir, "js", "combine_corner.js"), "w" 

206 ) 

207 new_file.writelines(combine_corner) 

208 new_file.close() 

209 combine_corner = open( 

210 os.path.join(webdir, "js", "combine_corner.js") 

211 ) 

212 combine_corner = combine_corner.readlines() 

213 params = [str(i) for i in params] 

214 ind = [ 

215 linenumber for linenumber, line in enumerate(combine_corner) 

216 if "var data = {}" in line 

217 ][0] 

218 combine_corner.insert( 

219 ind + 1, " data['{}'] = {};\n".format(label, data) 

220 ) 

221 new_file = open( 

222 os.path.join(webdir, "js", "combine_corner.js"), "w" 

223 ) 

224 new_file.writelines(combine_corner) 

225 new_file.close() 

226 

227 filename = os.path.join( 

228 savedir, "corner", "{}_sourceframe.png".format(label) 

229 ) 

230 if os.path.isfile(filename) and checkpoint: 

231 pass 

232 else: 

233 fig = gw._make_source_corner_plot(samples, latex_labels) 

234 fig.savefig(filename) 

235 fig.close() 

236 filename = os.path.join( 

237 savedir, "corner", "{}_extrinsic.png".format(label) 

238 ) 

239 if os.path.isfile(filename) and checkpoint: 

240 pass 

241 else: 

242 fig = gw._make_extrinsic_corner_plot(samples, latex_labels) 

243 fig.savefig(filename) 

244 fig.close() 

245 

def twod_histogram_plot(self, label):
    """Generate the 2d triangle plots for a given result file

    One plot is attempted for every parameter set listed in
    `pesummary.conf.gw_2d_plots` whose parameters are all present in the
    samples. The work is dispatched through `self.pool`.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary import conf

    error_message = (
        "Failed to generate %s-%s triangle plot because {}"
    )
    arguments = []
    for params in conf.gw_2d_plots:
        if not all(p in self.samples[label] for p in params):
            continue
        plot_args = [
            self.savedir, label, params,
            [self.samples[label][p] for p in params],
            [latex_labels[p] for p in params],
            [self.injection_data[label][p] for p in params],
            self.preliminary_pages[label], self.checkpoint
        ]
        arguments.append(
            (
                plot_args, self._triangle_plot,
                error_message % (params[0], params[1])
            )
        )
    self.pool.starmap(self._try_to_make_a_plot, arguments)

269 

270 @staticmethod 

271 def _triangle_plot( 

272 savedir, label, params, samples, latex_labels, injection, preliminary=False, 

273 checkpoint=False 

274 ): 

275 from pesummary.core.plots.publication import triangle_plot 

276 import math 

277 for num, ii in enumerate(injection): 

278 if math.isnan(ii): 

279 injection[num] = None 

280 

281 if any(ii is None for ii in injection): 

282 truth = None 

283 else: 

284 truth = injection 

285 filename = os.path.join( 

286 savedir, "{}_2d_posterior_{}_{}.png".format( 

287 label, params[0], params[1] 

288 ) 

289 ) 

290 if os.path.isfile(filename) and checkpoint: 

291 return 

292 fig, _, _, _ = triangle_plot( 

293 *samples, kde=False, parameters=params, xlabel=latex_labels[0], 

294 ylabel=latex_labels[1], plot_datapoints=True, plot_density=False, 

295 levels=[1e-8], fill=False, grid=True, linewidths=[1.75], 

296 percentiles=[5, 95], percentile_plot=[label], labels=[label], 

297 truth=truth 

298 ) 

299 _PlotGeneration.save( 

300 fig, filename, preliminary=preliminary 

301 ) 

302 

303 

def skymap_plot(self, label):
    """Generate the skymap plots for a given result file

    The default skymap is always made. When `ligo.skymap` can be imported
    and has not been disabled, either a subprocess is launched to build the
    skymap from the samples (its PID is recorded in
    `self._ligo_skymap_PID`), or the plot is made from the probability
    array already stored in `self.skymap`.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    try:
        import ligo.skymap  # noqa: F401
        have_ligo_skymap = True
    except ImportError:
        have_ligo_skymap = False

    posterior = (
        self.samples[label].combine if self.mcmc_samples
        else self.samples[label]
    )
    injected_position = [
        self.injection_data[label]["ra"], self.injection_data[label]["dec"]
    ]
    self._skymap_plot(
        self.savedir, posterior["ra"], posterior["dec"], label,
        self.weights[label], injected_position,
        preliminary=self.preliminary_pages[label]
    )

    if not have_ligo_skymap or self.no_ligo_skymap:
        return

    if self.skymap[label] is not None:
        # A probability array is already available for this label
        self._ligo_skymap_array_plot(
            self.savedir, self.skymap[label], label,
            self.preliminary_pages[label]
        )
        return

    from pesummary.utils.utils import RedirectLogger

    logger.info(
        "Launching subprocess to generate skymap plot with ligo.skymap"
    )
    try:
        merger_time = posterior["geocent_time"]
    except KeyError:
        logger.warning(
            "Unable to find 'geocent_time' in the posterior table for {}. "
            "The ligo.skymap fits file will therefore not store the "
            "DATE_OBS field in the header".format(label)
        )
        merger_time = None
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        process = mp.Process(
            target=self._ligo_skymap_plot,
            args=[
                self.savedir, posterior["ra"], posterior["dec"],
                posterior["luminosity_distance"], merger_time,
                label, self.nsamples_for_skymap, self.webdir,
                self.multi_threading_for_skymap, injected_position,
                self.preliminary_pages[label]
            ]
        )
        process.start()
    self._ligo_skymap_PID[label] = process.pid

365 

@staticmethod
@no_latex_plot
def _skymap_plot(
    savedir, ra, dec, label, weights, injection=None, preliminary=False
):
    """Generate the default skymap plot for a given set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        array containing the samples for right ascension
    dec: pesummary.utils.utils.Array
        array containing the samples for declination
    label: str
        the label corresponding to the results file
    weights: list
        list of weights for the samples
    injection: list, optional
        list containing the injected value of ra and dec. Ignored when any
        entry is nan
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    from math import isnan

    if injection is not None:
        if any(isnan(value) for value in injection):
            injection = None
    destination = os.path.join(savedir, "{}_skymap".format(label))
    fig = gw._default_skymap_plot(ra, dec, weights, injection=injection)
    _PlotGeneration.save(fig, destination, preliminary=preliminary)

399 

@staticmethod
@no_latex_plot
def _ligo_skymap_plot(savedir, ra, dec, dist, time, label, nsamples_for_skymap,
                      webdir, multi_threading_for_skymap, injection,
                      preliminary=False):
    """Generate a skymap plot for a given set of samples using the
    ligo.skymap package

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        array containing the samples for right ascension
    dec: pesummary.utils.utils.Array
        array containing the samples for declination
    dist: pesummary.utils.utils.Array
        array containing the samples for luminosity distance
    time: pesummary.utils.utils.Array
        array containing the samples for the geocentric time of merger
    label: str
        the label corresponding to the results file
    nsamples_for_skymap: int
        the number of samples used to generate skymap. If None, the
        samples are not resampled
    webdir: str
        the directory to store the fits file
    multi_threading_for_skymap: int
        passed to the plotting routine as `nprocess`
    injection: list
        list containing the injected value of ra and dec. Ignored when any
        entry is nan
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    from math import isnan

    downsampled = nsamples_for_skymap is not None
    if downsampled:
        ra, dec, dist = resample_posterior_distribution(
            [ra, dec, dist], nsamples_for_skymap
        )
    if injection is not None:
        if any(isnan(value) for value in injection):
            injection = None
    fits_directory = os.path.join(webdir, "samples")
    fig = gw._ligo_skymap_plot(
        ra, dec, dist=dist, savedir=fits_directory,
        nprocess=multi_threading_for_skymap, downsampled=downsampled,
        label=label, time=time, injection=injection
    )
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(fig, destination, preliminary=preliminary)

448 

@staticmethod
@no_latex_plot
def _ligo_skymap_array_plot(savedir, skymap, label, preliminary=False):
    """Generate a skymap plot from a probability array that has already
    been generated with `ligo.skymap`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    skymap: np.ndarray
        array of skymap probabilities
    label: str
        the label corresponding to the results file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    destination = os.path.join(savedir, "{}_skymap".format(label))
    fig = gw._ligo_skymap_plot_from_array(skymap)
    _PlotGeneration.save(fig, destination, preliminary=preliminary)

471 

def waveform_fd_plot(self, label):
    """Generate a frequency domain waveform plot for a given result file

    Nothing is plotted when the approximant stored for `label` is an empty
    dictionary.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    approximant_unknown = self.approximant[label] == {}
    if approximant_unknown:
        return
    options = dict(
        preliminary=self.preliminary_pages[label],
        checkpoint=self.checkpoint,
    )
    self._waveform_fd_plot(
        self.savedir, self.detectors[label], self.maxL_samples[label], label,
        **options, **self.file_kwargs[label]["meta_data"]
    )

487 

488 @staticmethod 

489 def _waveform_fd_plot( 

490 savedir, detectors, maxL_samples, label, preliminary=False, 

491 checkpoint=False, **kwargs 

492 ): 

493 """Generate a frequency domain waveform plot for a given detector 

494 network and set of samples 

495 

496 Parameters 

497 ---------- 

498 savedir: str 

499 the directory you wish to save the plot in 

500 detectors: list 

501 list of detectors used in your analysis 

502 maxL_samples: dict 

503 dictionary of maximum likelihood values 

504 label: str 

505 the label corresponding to the results file 

506 preliminary: Bool, optional 

507 if True, add a preliminary watermark to the plot 

508 """ 

509 filename = os.path.join(savedir, "{}_waveform.png".format(label)) 

510 if os.path.isfile(filename) and checkpoint: 

511 return 

512 if detectors is None: 

513 detectors = ["H1", "L1"] 

514 else: 

515 detectors = detectors.split("_") 

516 

517 fig = gw._waveform_plot( 

518 detectors, maxL_samples, f_min=kwargs.get("f_low", 20.0), 

519 f_max=kwargs.get("f_final", 1024.), 

520 f_ref=kwargs.get("f_ref", 20.) 

521 ) 

522 _PlotGeneration.save( 

523 fig, filename, preliminary=preliminary 

524 ) 

525 

def waveform_td_plot(self, label):
    """Generate a time domain waveform plot for a given result file

    Nothing is plotted when the approximant stored for `label` is an empty
    dictionary.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    approximant_unknown = self.approximant[label] == {}
    if approximant_unknown:
        return
    options = dict(
        preliminary=self.preliminary_pages[label],
        checkpoint=self.checkpoint,
    )
    self._waveform_td_plot(
        self.savedir, self.detectors[label], self.maxL_samples[label], label,
        **options, **self.file_kwargs[label]["meta_data"]
    )

541 

542 @staticmethod 

543 def _waveform_td_plot( 

544 savedir, detectors, maxL_samples, label, preliminary=False, 

545 checkpoint=False, **kwargs 

546 ): 

547 """Generate a time domain waveform plot for a given detector network 

548 and set of samples 

549 

550 Parameters 

551 ---------- 

552 savedir: str 

553 the directory you wish to save the plot in 

554 detectors: list 

555 list of detectors used in your analysis 

556 maxL_samples: dict 

557 dictionary of maximum likelihood values 

558 label: str 

559 the label corresponding to the results file 

560 preliminary: Bool, optional 

561 if True, add a preliminary watermark to the plot 

562 """ 

563 filename = os.path.join( 

564 savedir, "{}_waveform_time_domain.png".format(label) 

565 ) 

566 if os.path.isfile(filename) and checkpoint: 

567 return 

568 if detectors is None: 

569 detectors = ["H1", "L1"] 

570 else: 

571 detectors = detectors.split("_") 

572 

573 fig = gw._time_domain_waveform( 

574 detectors, maxL_samples, f_min=kwargs.get("f_low", 20.0), 

575 f_max=kwargs.get("f_final", 1024.), 

576 f_ref=kwargs.get("f_ref", 20.) 

577 ) 

578 _PlotGeneration.save( 

579 fig, filename, preliminary=preliminary 

580 ) 

581 

def gwdata_plots(self, label):
    """Generate all plots associated with the gwdata

    The strain plot, spectrogram plot and omegascan plot are attempted in
    turn through `self._try_to_make_a_plot`.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary.utils.utils import determine_gps_time_and_window

    base_error = "Failed to generate a %s because {}"
    gps_time, window = determine_gps_time_and_window(
        self.maxL_samples, self.labels
    )
    # (function, positional arguments, name used in the error message)
    jobs = (
        (self.strain_plot, [label], "strain_plot"),
        (self.spectrogram_plot, [], "spectrogram plot"),
        (self.omegascan_plot, [gps_time, window], "omegascan plot"),
    )
    for func, func_args, name in jobs:
        self._try_to_make_a_plot(func_args, func, base_error % (name))

605 

def strain_plot(self, label):
    """Generate a plot showing the comparison between the data and the
    maxL waveform for a given result file

    The plot is made in a subprocess. The checkpoint setting is forwarded
    so that an existing strain plot is not regenerated when checkpointing
    is enabled.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    logger.info("Launching subprocess to generate strain plot")
    process = mp.Process(
        target=self._strain_plot,
        args=[
            self.savedir, self.gwdata, self.maxL_samples[label], label,
            # previously omitted, which silently disabled checkpointing
            getattr(self, "checkpoint", False)
        ]
    )
    process.start()

621 

622 @staticmethod 

623 def _strain_plot(savedir, gwdata, maxL_samples, label, checkpoint=False): 

624 """Generate a strain plot for a given set of samples 

625 

626 Parameters 

627 ---------- 

628 savedir: str 

629 the directory to save the plot 

630 gwdata: dict 

631 dictionary of strain data for each detector 

632 maxL_samples: dict 

633 dictionary of maximum likelihood values 

634 label: str 

635 the label corresponding to the results file 

636 """ 

637 filename = os.path.join(savedir, "{}_strain.png".format(label)) 

638 if os.path.isfile(filename) and checkpoint: 

639 return 

640 fig = gw._strain_plot(gwdata, maxL_samples) 

641 _PlotGeneration.save(fig, filename) 

642 

def spectrogram_plot(self):
    """Generate the spectrogram plots for the stored strain data. The
    figures are written to `self.savedir`
    """
    self._spectrogram_plot(self.savedir, self.gwdata)

647 

@staticmethod
def _spectrogram_plot(savedir, strain):
    """Generate and save one spectrogram per detector

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    """
    from pesummary.gw.plots import detchar

    for det, fig in detchar.spectrogram(strain).items():
        destination = os.path.join(savedir, "spectrogram_{}".format(det))
        _PlotGeneration.save(fig, destination)

667 

def omegascan_plot(self, gps_time, window):
    """Generate the omegascan plots for the stored strain data. The
    figures are written to `self.savedir`

    Parameters
    ----------
    gps_time: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    self._omegascan_plot(self.savedir, self.gwdata, gps_time, window)

681 

@staticmethod
def _omegascan_plot(savedir, strain, gps, window):
    """Generate and save one omegascan per detector

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    gps: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    from pesummary.gw.plots import detchar

    for det, fig in detchar.omegascan(strain, gps, window=window).items():
        destination = os.path.join(savedir, "omegascan_{}".format(det))
        _PlotGeneration.save(fig, destination)

705 

def skymap_comparison_plot(self, label):
    """Generate a plot to compare skymaps for all result files

    Parameters
    ----------
    label: str
        not used; the comparison is made across all of `self.labels`
    """
    positional = (
        self.savedir, self.same_samples["ra"], self.same_samples["dec"],
        self.labels, self.colors, self.preliminary_comparison_pages,
        self.checkpoint
    )
    self._skymap_comparison_plot(*positional)

719 

720 @staticmethod 

721 def _skymap_comparison_plot( 

722 savedir, ra, dec, labels, colors, preliminary=False, checkpoint=False 

723 ): 

724 """Generate a plot to compare skymaps for a given set of samples 

725 

726 Parameters 

727 ---------- 

728 savedir: str 

729 the directory you wish to save the plot in 

730 ra: dict 

731 dictionary of right ascension samples for each result file 

732 dec: dict 

733 dictionary of declination samples for each result file 

734 labels: list 

735 list of labels to distinguish each result file 

736 colors: list 

737 list of colors to be used to distinguish different result files 

738 preliminary: Bool, optional 

739 if True, add a preliminary watermark to the plot 

740 """ 

741 filename = os.path.join(savedir, "combined_skymap.png") 

742 if os.path.isfile(filename) and checkpoint: 

743 return 

744 ra_list = [ra[key] for key in labels] 

745 dec_list = [dec[key] for key in labels] 

746 fig = gw._sky_map_comparison_plot(ra_list, dec_list, labels, colors) 

747 _PlotGeneration.save( 

748 fig, filename, preliminary=preliminary 

749 ) 

750 

def waveform_comparison_fd_plot(self, label):
    """Generate a plot to compare the frequency domain waveform

    Nothing is plotted when the approximant stored for any label is an
    empty dictionary.

    Parameters
    ----------
    label: str
        not used; the comparison is made across all of `self.labels`
    """
    for name in self.labels:
        if self.approximant[name] == {}:
            return

    options = dict(
        preliminary=self.preliminary_comparison_pages,
        checkpoint=self.checkpoint,
    )
    self._waveform_comparison_fd_plot(
        self.savedir, self.maxL_samples, self.labels, self.colors,
        **options, **self.file_kwargs
    )

767 

768 @staticmethod 

769 def _waveform_comparison_fd_plot( 

770 savedir, maxL_samples, labels, colors, preliminary=False, 

771 checkpoint=False, **kwargs 

772 ): 

773 """Generate a plot to compare the frequency domain waveforms 

774 

775 Parameters 

776 ---------- 

777 savedir: str 

778 the directory you wish to save the plot in 

779 maxL_samples: dict 

780 dictionary of maximum likelihood samples for each result file 

781 labels: list 

782 list of labels to distinguish each result file 

783 colors: list 

784 list of colors to be used to distinguish different result files 

785 preliminary: Bool, optional 

786 if True, add a preliminary watermark to the plot 

787 """ 

788 filename = os.path.join(savedir, "compare_waveforms.png") 

789 if os.path.isfile(filename) and checkpoint: 

790 return 

791 samples = [maxL_samples[i] for i in labels] 

792 f_min = np.max( 

793 [kwargs[label]["meta_data"].get("f_low", 20.) for label in labels] 

794 ) 

795 f_max = np.min( 

796 [kwargs[label]["meta_data"].get("f_final", 1024.) for label in labels] 

797 ) 

798 f_ref = kwargs[labels[0]]["meta_data"].get("f_ref", 20.) 

799 fig = gw._waveform_comparison_plot( 

800 samples, colors, labels, f_min=f_min, f_max=f_max, 

801 f_ref=f_ref 

802 ) 

803 _PlotGeneration.save( 

804 fig, filename, preliminary=preliminary 

805 ) 

806 

def waveform_comparison_td_plot(self, label):
    """Generate a plot to compare the time domain waveform

    Nothing is plotted when the approximant stored for any label is an
    empty dictionary.

    Parameters
    ----------
    label: str
        not used; the comparison is made across all of `self.labels`
    """
    if any(self.approximant[i] == {} for i in self.labels):
        return

    # Dispatch to the time domain routine. The frequency domain routine was
    # called here before, without the metadata it needs
    self._waveform_comparison_td_plot(
        self.savedir, self.maxL_samples, self.labels, self.colors,
        preliminary=self.preliminary_comparison_pages,
        checkpoint=self.checkpoint
    )

822 

823 @staticmethod 

824 def _waveform_comparison_td_plot( 

825 savedir, maxL_samples, labels, colors, preliminary=False, 

826 checkpoint=False 

827 ): 

828 """Generate a plot to compare the time domain waveforms 

829 

830 Parameters 

831 ---------- 

832 savedir: str 

833 the directory you wish to save the plot in 

834 maxL_samples: dict 

835 dictionary of maximum likelihood samples for each result file 

836 labels: list 

837 list of labels to distinguish each result file 

838 colors: list 

839 list of colors to be used to distinguish different result files 

840 preliminary: Bool, optional 

841 if True, add a preliminary watermark to the plot 

842 """ 

843 filename = os.path.join(savedir, "compare_time_domain_waveforms.png") 

844 if os.path.isfile(filename) and checkpoint: 

845 return 

846 samples = [maxL_samples[i] for i in labels] 

847 fig = gw._time_domainwaveform_comparison_plot(samples, colors, labels) 

848 _PlotGeneration.save( 

849 fig, filename, preliminary=preliminary 

850 ) 

851 

def twod_comparison_contour_plot(self, label):
    """Generate 2d comparison contour plots

    A fixed list of parameter pairs is considered. A pair is skipped, with
    a warning, unless both parameters are present in every result file.

    Parameters
    ----------
    label: str
        not used; the comparison is made across all of `self.labels`
    """
    error_message = (
        "Failed to generate a 2d contour plot for %s because {}"
    )
    twod_plots = [
        ["mass_ratio", "chi_eff"], ["mass_1", "mass_2"],
        ["luminosity_distance", "chirp_mass_source"],
        ["mass_1_source", "mass_2_source"],
        ["theta_jn", "luminosity_distance"],
        ["network_optimal_snr", "chirp_mass_source"]
    ]
    # Number of points used for the KDE; defaults to 100
    gridsize = int(self.publication_kwargs.get("gridsize", 100))
    for plot in twod_plots:
        pair = " and ".join(plot)
        common = all(
            param in self.samples[name].keys()
            for name in self.labels for param in plot
        )
        if not common:
            logger.warning(
                "Failed to generate 2d contour plots for {} because {} are not "
                "common in all result files".format(pair, pair)
            )
            continue
        samples = [
            [self.samples[name][param] for param in plot]
            for name in self.labels
        ]
        arguments = [
            self.savedir, plot, samples, self.labels, latex_labels,
            self.colors, self.linestyles, gridsize,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._twod_comparison_contour_plot,
            error_message % (pair)
        )

897 

@staticmethod
def _twod_comparison_contour_plot(
    savedir, plot_parameters, samples, labels, latex_labels, colors,
    linestyles, gridsize, preliminary=False, checkpoint=False
):
    """Save a 2d contour plot comparing a set of parameters across result
    files

    Parameters
    ----------
    savedir: str
        the directory containing the `publication` folder the plot is
        written to
    plot_parameters: list
        list of parameters to use for the 2d contour plot. Their names are
        joined with `_and_` to build the file name
    samples: list
        list of samples for each parameter
    labels: list
        list of labels used to distinguish each result file
    latex_labels: dict
        dictionary containing the latex labels for each parameter
    colors: list
        colors forwarded to the plotting routine
    linestyles: list
        linestyles forwarded to the plotting routine
    gridsize: int
        the number of points to use when estimating the KDE
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    """
    joined_names = "_and_".join(plot_parameters)
    filename = os.path.join(
        savedir, "publication", "2d_contour_plot_{}.png".format(joined_names)
    )
    already_made = os.path.isfile(filename) and checkpoint
    if not already_made:
        figure = publication.twod_contour_plots(
            plot_parameters, samples, labels, latex_labels, colors=colors,
            linestyles=linestyles, gridsize=gridsize
        )
        _PlotGeneration.save(figure, filename, preliminary=preliminary)

936 

def violin_plot(self, label):
    """Generate violin plots to compare certain parameters in all result
    files

    A parameter that is not present in every result file is skipped with a
    warning rather than attempted.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. It is not
        read here: every plot is built from all of `self.labels`
    """
    error_message = (
        "Failed to generate a violin plot for %s because {}"
    )
    violin_plots = ["mass_ratio", "chi_eff", "chi_p", "luminosity_distance"]

    for plot in violin_plots:
        if not all(plot in self.samples[j].keys() for j in self.labels):
            logger.warning(
                "Failed to generate violin plots for {} because {} is not "
                "common in all result files".format(plot, plot)
            )
            # Indexing the samples below would raise a KeyError for the
            # result file that lacks this parameter, so move on.
            continue
        # Only look up per-file data once the parameter is known to be
        # common to all result files.
        injection = [self.injection_data[i][plot] for i in self.labels]
        samples = [self.samples[i][plot] for i in self.labels]
        arguments = [
            self.savedir, plot, samples, self.labels, latex_labels[plot],
            injection, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._violin_plot, error_message % (plot)
        )

966 

@staticmethod
def _violin_plot(
    savedir, plot_parameter, samples, labels, latex_label, inj_values=None,
    preliminary=False, checkpoint=False, kde=ReflectionBoundedKDE,
    default_bounds=True
):
    """Save a violin plot of a single parameter for a set of result files

    Parameters
    ----------
    savedir: str
        the directory containing the `publication` folder the plot is
        written to
    plot_parameter: str
        name of the parameter you wish to generate a violin plot for
    samples: list
        list of samples for each result file
    labels: list
        list of labels used to distinguish each result file
    latex_label: str
        latex label corresponding to the parameter. NOTE: this argument is
        not read; the module-level `latex_labels` mapping is what gets
        forwarded to the plotting routine
    inj_values: list, optional
        list of injected values for each result file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    kde: func, optional
        KDE forwarded to the plotting routine
    default_bounds: Bool, optional
        if True, the `xlow` and `xhigh` KDE bounds are taken from
        `gw._return_bounds`, otherwise both are None
    """
    filename = os.path.join(
        savedir, "publication", "violin_plot_{}.png".format(plot_parameter)
    )
    already_made = os.path.isfile(filename) and checkpoint
    if already_made:
        return

    if default_bounds:
        lower, upper = gw._return_bounds(
            plot_parameter, samples, comparison=True
        )
    else:
        lower, upper = None, None

    kde_kwargs = {"xlow": lower, "xhigh": upper}
    figure = publication.violin_plots(
        plot_parameter, samples, labels, latex_labels, kde=kde,
        kde_kwargs=kde_kwargs, inj_values=inj_values
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)

1009 

def spin_dist_plot(self, label):
    """Generate a spin disk plot for every result file

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. It is not
        read here: a plot is attempted for each entry of `self.labels`
    """
    parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    error_message = (
        "Failed to generate a spin disk plot for %s because {}"
    )
    for num, result_label in enumerate(self.labels):
        available = self.samples[result_label].keys()
        if any(name not in available for name in parameters):
            logger.warning(
                "Failed to generate spin disk plots because {} are not "
                "common in all result files".format(
                    " and ".join(parameters)
                )
            )
            continue
        spin_samples = [
            self.samples[result_label][name] for name in parameters
        ]
        arguments = [
            self.savedir, parameters, spin_samples, result_label,
            self.colors[num], self.preliminary_comparison_pages,
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._spin_dist_plot, error_message % (result_label)
        )

1041 

@staticmethod
def _spin_dist_plot(
    savedir, parameters, samples, label, color, preliminary=False,
    checkpoint=False
):
    """Save a spin disk plot for a given set of samples

    Parameters
    ----------
    savedir: str
        the directory containing the `publication` folder the plot is
        written to
    parameters: list
        names of the parameters forwarded to the plotting routine
    samples: list
        samples forwarded to the plotting routine
    label: str
        the label of the result file, used in the file name
    color: str
        color forwarded to the plotting routine
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    """
    plot_name = "spin_disk_plot_{}.png".format(label)
    filename = os.path.join(savedir, "publication", plot_name)
    already_made = os.path.isfile(filename) and checkpoint
    if not already_made:
        figure = publication.spin_distribution_plots(
            parameters, samples, label, color=color
        )
        _PlotGeneration.save(figure, filename, preliminary=preliminary)

1065 

def pepredicates_plot(self, label):
    """Generate plots with the PEPredicates package

    Two sets of plots are requested for the result file: one with the
    `default` probabilities and one with the `population` probabilities.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    samples = (
        self.samples[label].combine if self.mcmc_samples
        else self.samples[label]
    )
    for key, use_population in (("default", False), ("population", True)):
        self._pepredicates_plot(
            self.savedir, samples, label,
            self.pepredicates_probs[label][key],
            population_prior=use_population,
            preliminary=self.preliminary_pages[label],
            checkpoint=self.checkpoint
        )

1088 

@staticmethod
@no_latex_plot
def _pepredicates_plot(
    savedir, samples, label, probabilities, population_prior=False,
    preliminary=False, checkpoint=False
):
    """Save the PEPredicates plot and the PEPredicates bar plot for a given
    set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plots in
    samples: dict
        dictionary of samples for each parameter
    label: str
        the label corresponding to the result file
    probabilities: dict
        dictionary of classification probabilities
    population_prior: Bool, optional
        if True, the `population` file names are used instead of the
        `default` ones. The value is also forwarded as `population` to the
        plotting call
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plots
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    """
    from pesummary.gw.classification import PEPredicates

    if population_prior:
        templates = (
            "{}_population_pepredicates.png",
            "{}_population_pepredicates_bar.png",
        )
    else:
        templates = (
            "{}_default_pepredicates.png",
            "{}_default_pepredicates_bar.png",
        )
    targets = [
        os.path.join(savedir, template.format(label))
        for template in templates
    ]

    # The classifier is built even when both plots already exist.
    classifier = PEPredicates(samples)
    for plot_type, filename in zip(("pepredicates", "bar"), targets):
        if os.path.isfile(filename) and checkpoint:
            continue
        figure = classifier.plot(
            type=plot_type, population=population_prior,
            probabilities=probabilities
        )
        _PlotGeneration.save(figure, filename, preliminary=preliminary)

1155 

def psd_plot(self, label):
    """Generate a psd plot for every result file that has PSD data

    A result file without PSD data is skipped; it does not stop the
    remaining result files from being plotted. The frequency limits are
    read from each result file's own `meta_data` and default to None when
    `f_low` / `f_final` are absent.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. It is not read here:
        a plot is attempted for each entry of `self.labels`
    """
    error_message = (
        "Failed to generate a PSD plot for %s because {}"
    )

    for result_label in self.labels:
        ifos = list(self.psd[result_label].keys())
        # No PSD stored for this result file: skip it only, later result
        # files may still have data.
        if not ifos or ifos == [None]:
            continue

        # Resolve the limits per result file so that a file lacking
        # `f_low` / `f_final` does not inherit the previous file's values.
        meta_data = self.file_kwargs[result_label]["meta_data"]
        fmin = meta_data["f_low"] if "f_low" in meta_data.keys() else None
        fmax = (
            meta_data["f_final"] if "f_final" in meta_data.keys() else None
        )

        # Each PSD is transposed so that row 0 holds the frequencies and
        # row 1 the strains.
        frequencies = [
            np.array(self.psd[result_label][i]).T[0] for i in ifos
        ]
        strains = [np.array(self.psd[result_label][i]).T[1] for i in ifos]
        arguments = [
            self.savedir, frequencies, strains, fmin, fmax, ifos,
            result_label, self.checkpoint
        ]

        self._try_to_make_a_plot(
            arguments, self._psd_plot, error_message % (result_label)
        )

1191 

@staticmethod
def _psd_plot(
    savedir, frequencies, strains, fmin, fmax, psd_labels, label, checkpoint=False
):
    """Save a psd plot for a given result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        list of psd frequencies for each IFO
    strains: list
        list of psd strains for each IFO
    fmin: float
        frequency to start the psd plotting, forwarded as `fmin`
    fmax: float
        frequency to end the psd plotting, forwarded as `fmax`
    psd_labels: list
        list of IFOs used
    label: str
        the label used to distinguish the result file
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    """
    filename = os.path.join(savedir, "{}_psd_plot.png".format(label))
    already_made = os.path.isfile(filename) and checkpoint
    if not already_made:
        figure = gw._psd_plot(
            frequencies, strains, labels=psd_labels, fmin=fmin, fmax=fmax
        )
        _PlotGeneration.save(figure, filename)

1222 

def calibration_plot(self, label):
    """Generate a calibration plot for every result file that has
    calibration data

    A result file without calibration data is skipped; it does not stop
    the remaining result files from being plotted.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. It is not read here:
        a plot is attempted for each entry of `self.labels`
    """
    import numpy as np

    error_message = (
        "Failed to generate calibration plot for %s because {}"
    )
    # Frequency grid handed to the plotting routine: 20 up to (but not
    # including) 1024 in steps of 0.25.
    frequencies = np.arange(20., 1024., 1. / 4)

    for result_label in self.labels:
        ifos = list(self.calibration[result_label].keys())
        # No calibration data stored for this result file: skip it only,
        # later result files may still have data.
        if not ifos or ifos == [None]:
            continue

        calibration_data = [
            self.calibration[result_label][i] for i in ifos
        ]
        if "calibration" in self.priors.keys():
            prior = [
                self.priors["calibration"][result_label][i] for i in ifos
            ]
        else:
            prior = None
        arguments = [
            self.savedir, frequencies, calibration_data, ifos, prior,
            result_label, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._calibration_plot,
            error_message % (result_label)
        )

1259 

@staticmethod
def _calibration_plot(
    savedir, frequencies, calibration_data, calibration_labels, prior, label,
    checkpoint=False
):
    """Save a calibration plot for a given result file

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    frequencies: list
        list of frequencies used to interpolate the calibration data
    calibration_data: list
        list of calibration data for each IFO
    calibration_labels: list
        list of IFOs used
    prior: list
        list containing the priors used for each IFO, forwarded as `prior`
    label: str
        the label used to distinguish the result file
    checkpoint: Bool, optional
        if True, leave an already existing plot untouched
    """
    plot_name = "{}_calibration_plot.png".format(label)
    filename = os.path.join(savedir, plot_name)
    already_made = os.path.isfile(filename) and checkpoint
    if not already_made:
        figure = gw._calibration_envelope_plot(
            frequencies, calibration_data, calibration_labels, prior=prior
        )
        _PlotGeneration.save(figure, filename)

1291 

@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Write the interactive `source` and `extrinsic` corner plots for a
    given set of samples

    Parameters
    ----------
    savedir: str
        the directory containing the `corner` folder the html files are
        written to
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.utils.Array objects that
        contain samples for each parameter
    latex_labels: dict
        latex labels for each parameter in samples
    checkpoint: Bool, optional
        if True, leave an already existing html file untouched
    """
    def _select(wanted):
        # Keep the ordering of `samples`, restricted to the wanted names.
        chosen = [name for name in samples.keys() if name in wanted]
        values = [samples[name] for name in chosen]
        titles = [latex_labels[name] for name in chosen]
        return values, titles

    filename = os.path.join(
        savedir, "corner", "{}_interactive_source.html".format(label)
    )
    if not (os.path.isfile(filename) and checkpoint):
        data, labels = _select([
            "luminosity_distance", "mass_1_source", "mass_2_source",
            "total_mass_source", "chirp_mass_source", "redshift"
        ])
        _ = interactive.corner(
            data, labels, write_to_html_file=filename,
            dimensions={"width": 900, "height": 900}
        )

    filename = os.path.join(
        savedir, "corner", "{}_interactive_extrinsic.html".format(label)
    )
    if not (os.path.isfile(filename) and checkpoint):
        data, labels = _select(["luminosity_distance", "psi", "ra", "dec"])
        _ = interactive.corner(
            data, labels, write_to_html_file=filename
        )