Coverage for pesummary/gw/plots/main.py: 65.8%

488 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-01-15 17:49 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5import os 

6 

7from pesummary.core.plots.main import _PlotGeneration as _BasePlotGeneration 

8from pesummary.core.plots.latex_labels import latex_labels 

9from pesummary.core.plots import interactive 

10from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE 

11from pesummary.gw.plots.latex_labels import GWlatex_labels 

12from pesummary.utils.utils import logger, resample_posterior_distribution 

13from pesummary.utils.decorators import no_latex_plot 

14from pesummary.gw.plots import publication 

15from pesummary.gw.plots import plot as gw 

16 

17import multiprocessing as mp 

18import numpy as np 

19 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Extend the core LaTeX label mapping with the GW-specific entries. This
# mutates the imported ``latex_labels`` dict in place at import time, so
# later lookups through ``latex_labels`` in this module also see the
# entries from ``GWlatex_labels``.
latex_labels.update(GWlatex_labels)

22 

23 

24class _PlotGeneration(_BasePlotGeneration): 

def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_file_kwargs=None, existing_samples=None,
    existing_metafile=None, same_parameters=None, injection_data=None,
    result_files=None, file_kwargs=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, no_ligo_skymap=False,
    nsamples_for_skymap=None, detectors=None, maxL_samples=None,
    gwdata=None, calibration_definition=None, calibration=None, psd=None,
    multi_threading_for_skymap=None, approximant=None,
    classification_probs=None, include_prior=False, publication=False,
    existing_approximant=None, existing_psd=None, existing_calibration=None,
    existing_weights=None, weights=None, disable_comparison=False,
    linestyles=None, disable_interactive=False, disable_corner=False,
    publication_kwargs=None, multi_process=1, mcmc_samples=False,
    skymap=None, existing_skymap=None, corner_params=None,
    preliminary_pages=False, expert_plots=True, checkpoint=False,
    key_data=None
):
    """Set up GW-specific plot generation.

    Arguments shared with the core ``_PlotGeneration`` are forwarded to it.
    The GW-only options are stored as attributes of the same name, and the
    GW plot types are registered in ``self.plot_type_dictionary``. The
    comparison plot types are only registered when ``self.make_comparison``
    is True.

    Parameters
    ----------
    priors: dict, optional
        prior information forwarded to the core class. ``None`` (the
        default) is treated as an empty dict
    publication_kwargs: dict, optional
        kwargs for the publication plots. ``None`` (the default) is treated
        as an empty dict
    preliminary_pages: dict or Bool, optional
        either a mapping from label to Bool, or a single Bool that is
        applied to every label
    skymap: dict, optional
        precomputed skymap probability arrays keyed by label. ``None`` means
        there is no precomputed skymap for any label
    existing_skymap: dict, optional
        skymaps for the existing result files. If not given, this falls back
        to ``skymap``, which was the previous (unconditional) behaviour
    """
    # Use fresh dicts per instance rather than shared mutable defaults.
    if priors is None:
        priors = {}
    if publication_kwargs is None:
        publication_kwargs = {}
    super(_PlotGeneration, self).__init__(
        savedir=savedir, webdir=webdir, labels=labels,
        samples=samples, kde_plot=kde_plot, existing_labels=existing_labels,
        existing_injection_data=existing_injection_data,
        existing_samples=existing_samples,
        existing_weights=existing_weights,
        same_parameters=same_parameters,
        injection_data=injection_data, mcmc_samples=mcmc_samples,
        colors=colors, custom_plotting=custom_plotting,
        add_to_existing=add_to_existing, priors=priors,
        include_prior=include_prior, weights=weights,
        disable_comparison=disable_comparison, linestyles=linestyles,
        disable_interactive=disable_interactive, disable_corner=disable_corner,
        multi_process=multi_process, corner_params=corner_params,
        expert_plots=expert_plots, checkpoint=checkpoint, key_data=key_data
    )
    # Normalise to a per-label mapping. A non-dict value is a single flag
    # applied to every label.
    if isinstance(preliminary_pages, dict):
        self.preliminary_pages = preliminary_pages
    else:
        flag = bool(preliminary_pages)
        self.preliminary_pages = {label: flag for label in self.labels}
    self.preliminary_comparison_pages = any(
        value for value in self.preliminary_pages.values()
    )
    self.package = "gw"
    self.file_kwargs = file_kwargs
    self.existing_file_kwargs = existing_file_kwargs
    self.no_ligo_skymap = no_ligo_skymap
    self.nsamples_for_skymap = nsamples_for_skymap
    self.detectors = detectors
    self.maxL_samples = maxL_samples
    self.gwdata = gwdata
    if skymap is None:
        skymap = {label: None for label in self.labels}
    self.skymap = skymap
    # Previously the ``existing_skymap`` argument was silently discarded and
    # ``skymap`` was always stored. Honour it when provided.
    self.existing_skymap = (
        existing_skymap if existing_skymap is not None else skymap
    )
    self.calibration_definition = calibration_definition
    self.calibration = calibration
    self.existing_calibration = existing_calibration
    self.psd = psd
    self.existing_psd = existing_psd
    self.multi_threading_for_skymap = multi_threading_for_skymap
    self.approximant = approximant
    self.existing_approximant = existing_approximant
    self.classification_probs = classification_probs
    self.publication = publication
    self.publication_kwargs = publication_kwargs
    # label -> multiprocessing.Process launched by ``skymap_plot``
    self._ligo_skymap_PID = {}

    self.plot_type_dictionary.update({
        "psd": self.psd_plot,
        "calibration": self.calibration_plot,
        "twod_histogram": self.twod_histogram_plot,
        "skymap": self.skymap_plot,
        "waveform_fd": self.waveform_fd_plot,
        "waveform_td": self.waveform_td_plot,
        "data": self.gwdata_plots,
        "violin": self.violin_plot,
        "spin_disk": self.spin_dist_plot,
        "classification": self.classification_plot
    })
    if self.make_comparison:
        self.plot_type_dictionary.update({
            "skymap_comparison": self.skymap_comparison_plot,
            "waveform_comparison_fd": self.waveform_comparison_fd_plot,
            "waveform_comparison_td": self.waveform_comparison_td_plot,
            "2d_comparison_contour": self.twod_comparison_contour_plot,
        })

117 

@property
def ligo_skymap_PID(self):
    """dict: the ``multiprocessing.Process`` handles for the ligo.skymap
    subprocesses, keyed by result-file label.
    """
    processes = self._ligo_skymap_PID
    return processes

121 

def generate_plots(self):
    """Generate all plots for all result files

    The calibration plot is attempted when calibration data are present or
    the priors contain a ``"calibration"`` entry. The PSD plot is attempted
    when PSD data are present. The core plot generation then runs.
    """
    calibration_requested = (
        bool(self.calibration) or "calibration" in self.priors.keys()
    )
    if calibration_requested:
        self.try_to_make_a_plot("calibration")
    if self.psd:
        self.try_to_make_a_plot("psd")
    super(_PlotGeneration, self).generate_plots()

130 

def _generate_plots(self, label):
    """Generate all plots for a given result file

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    super(_PlotGeneration, self)._generate_plots(label)
    for plot_type in ("twod_histogram", "skymap", "waveform_td", "waveform_fd"):
        self.try_to_make_a_plot(plot_type, label=label)
    # ``classification_probs`` defaults to None in ``__init__``. Indexing it
    # directly raised TypeError, so treat a missing mapping, or a
    # missing/None entry, as "no classification available".
    probs = self.classification_probs
    if probs is not None and label in probs and probs[label] is not None:
        self.try_to_make_a_plot("classification", label=label)
    if self.gwdata:
        self.try_to_make_a_plot("data", label=label)

143 

def _generate_comparison_plots(self):
    """Generate all comparison plots

    The core comparison plots run first, followed by the GW skymap and
    waveform comparisons. The 2d contour, violin and spin disk comparisons
    are only attempted when ``self.publication`` is set.
    """
    super(_PlotGeneration, self)._generate_comparison_plots()
    comparison_types = [
        "skymap_comparison",
        "waveform_comparison_td",
        "waveform_comparison_fd",
    ]
    if self.publication:
        comparison_types.extend(
            ["2d_comparison_contour", "violin", "spin_disk"]
        )
    for plot_type in comparison_types:
        self.try_to_make_a_plot(plot_type)

155 

@staticmethod
def _corner_plot(
    savedir, label, samples, latex_labels, webdir, params, preliminary=False,
    checkpoint=False
):
    """Generate a corner plot for a given set of samples

    Three figures are written to ``<savedir>/corner``:

    * ``{label}_all_density_plots.png``
    * ``{label}_sourceframe.png``
    * ``{label}_extrinsic.png``

    When the first figure is (re)generated, the plotted parameters and the
    data returned by the plotting helper are also inserted into
    ``<webdir>/js/combine_corner.js``, after the lines containing
    ``var list = {}`` and ``var data = {}`` respectively.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary of samples for a given result file
    latex_labels: dict
        dictionary of latex labels
    webdir: str
        directory where the javascript is written
    params: list
        parameters for the full corner plot, passed to the plotting helper
        as ``corner_parameters``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot. NOTE(review): this
        argument is currently not used by this method
    checkpoint: Bool, optional
        if True, skip any figure whose file already exists

    Raises
    ------
    IndexError
        if a marker line cannot be found in ``combine_corner.js``
    """
    import warnings

    corner_dir = os.path.join(savedir, "corner")
    js_file = os.path.join(webdir, "js", "combine_corner.js")

    def _index_after(lines, marker):
        """Return the index just after the first line containing marker."""
        for number, line in enumerate(lines):
            if marker in line:
                return number + 1
        raise IndexError(
            "Unable to find '{}' in {}".format(marker, js_file)
        )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        filename = os.path.join(
            corner_dir, "{}_all_density_plots.png".format(label)
        )
        if not (os.path.isfile(filename) and checkpoint):
            fig, params, data = gw._make_corner_plot(
                samples, latex_labels, corner_parameters=params
            )
            fig.savefig(filename)
            fig.close()
            params = [str(i) for i in params]
            # Read, modify and write the javascript once, using context
            # managers. Previously the file was opened several times and
            # none of the handles were closed.
            with open(js_file) as handle:
                combine_corner = handle.readlines()
            combine_corner.insert(
                _index_after(combine_corner, "var list = {}"),
                " list['{}'] = {};\n".format(label, params)
            )
            combine_corner.insert(
                _index_after(combine_corner, "var data = {}"),
                " data['{}'] = {};\n".format(label, data)
            )
            with open(js_file, "w") as handle:
                handle.writelines(combine_corner)

        for suffix, make_figure in (
            ("sourceframe", gw._make_source_corner_plot),
            ("extrinsic", gw._make_extrinsic_corner_plot),
        ):
            filename = os.path.join(
                corner_dir, "{}_{}.png".format(label, suffix)
            )
            if os.path.isfile(filename) and checkpoint:
                continue
            fig = make_figure(samples, latex_labels)
            fig.savefig(filename)
            fig.close()

246 

def twod_histogram_plot(self, label):
    """Generate 2d histogram ("triangle") plots for a given result file

    One plot is requested for every parameter pair in ``conf.gw_2d_plots``
    whose parameters are all present in the samples for ``label``. The
    requests are dispatched through ``self.pool.starmap`` using
    ``self._try_to_make_a_plot`` and ``self._triangle_plot``.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary import conf

    error_message = (
        "Failed to generate %s-%s triangle plot because {}"
    )
    usable_pairs = []
    for pair in conf.gw_2d_plots:
        if all(p in self.samples[label] for p in pair):
            usable_pairs.append(pair)
    weights = None
    if self.weights is not None:
        weights = self.weights.get(label, None)
    jobs = []
    for pair in usable_pairs:
        plot_args = [
            self.savedir, label, pair,
            [self.samples[label][p] for p in pair],
            [latex_labels[p] for p in pair],
            [self.injection_data[label][p] for p in pair],
            weights, self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append(
            (plot_args, self._triangle_plot, error_message % (pair[0], pair[1]))
        )
    self.pool.starmap(self._try_to_make_a_plot, jobs)

274 

275 @staticmethod 

276 def _triangle_plot( 

277 savedir, label, params, samples, latex_labels, injection, weights, 

278 preliminary=False, checkpoint=False 

279 ): 

280 from pesummary.core.plots.publication import triangle_plot 

281 import math 

282 for num, ii in enumerate(injection): 

283 if math.isnan(ii): 

284 injection[num] = None 

285 

286 if any(ii is None for ii in injection): 

287 truth = None 

288 else: 

289 truth = injection 

290 filename = os.path.join( 

291 savedir, "{}_2d_posterior_{}_{}.png".format( 

292 label, params[0], params[1] 

293 ) 

294 ) 

295 if os.path.isfile(filename) and checkpoint: 

296 return 

297 fig, _, _, _ = triangle_plot( 

298 *samples, kde=False, parameters=params, xlabel=latex_labels[0], 

299 ylabel=latex_labels[1], plot_datapoints=True, plot_density=False, 

300 levels=[1e-8], fill=False, grid=True, linewidths=[1.75], 

301 percentiles=[5, 95], percentile_plot=[label], labels=[label], 

302 truth=truth, weights=weights, data_kwargs={"alpha": 0.3} 

303 ) 

304 _PlotGeneration.save( 

305 fig, filename, preliminary=preliminary 

306 ) 

307 

308 

def skymap_plot(self, label):
    """Generate a skymap plot for a given result file

    A basic skymap is always produced from the ``ra``/``dec`` samples.

    When ``ligo.skymap`` is importable and ``self.no_ligo_skymap`` is False,
    a ligo.skymap plot is also made:

    * from the precomputed probability array ``self.skymap[label]`` when it
      is set, or otherwise
    * in a background ``multiprocessing.Process``, whose handle is stored in
      ``self._ligo_skymap_PID[label]``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    try:
        import ligo.skymap  # noqa: F401
    except ImportError:
        has_ligo_skymap = False
    else:
        has_ligo_skymap = True

    if self.mcmc_samples:
        samples = self.samples[label].combine
    else:
        samples = self.samples[label]
    _injection = [
        self.injection_data[label]["ra"], self.injection_data[label]["dec"]
    ]
    # ``self.weights`` may be None. Use the same guard as
    # ``twod_histogram_plot`` instead of raising TypeError on None[label].
    if self.weights is not None:
        weights = self.weights.get(label, None)
    else:
        weights = None
    self._skymap_plot(
        self.savedir, samples["ra"], samples["dec"], label,
        weights, _injection,
        preliminary=self.preliminary_pages[label]
    )

    if not has_ligo_skymap or self.no_ligo_skymap:
        return
    if self.skymap[label] is not None:
        self._ligo_skymap_array_plot(
            self.savedir, self.skymap[label], label,
            self.preliminary_pages[label]
        )
        return

    from pesummary.utils.utils import RedirectLogger

    logger.info("Launching subprocess to generate skymap plot with "
                "ligo.skymap")
    try:
        _time = samples["geocent_time"]
    except KeyError:
        logger.warning(
            "Unable to find 'geocent_time' in the posterior table for {}. "
            "The ligo.skymap fits file will therefore not store the "
            "DATE_OBS field in the header".format(label)
        )
        _time = None
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        process = mp.Process(
            target=self._ligo_skymap_plot,
            args=[
                self.savedir, samples["ra"], samples["dec"],
                samples["luminosity_distance"], _time,
                label, self.nsamples_for_skymap, self.webdir,
                self.multi_threading_for_skymap, _injection,
                self.preliminary_pages[label]
            ]
        )
        process.start()
    # Keep the handle so comparison plots can wait on this subprocess.
    self._ligo_skymap_PID[label] = process

370 

@staticmethod
@no_latex_plot
def _skymap_plot(
    savedir, ra, dec, label, weights, injection=None, preliminary=False
):
    """Generate a skymap plot for a given set of samples

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        array containing the samples for right ascension
    dec: pesummary.utils.utils.Array
        array containing the samples for declination
    label: str
        the label corresponding to the results file
    weights: list
        list of weights for the samples
    injection: list, optional
        list containing the injected value of ra and dec. It is ignored if
        either value is None or NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    # ``None`` is checked first because ``math.isnan(None)`` raises TypeError.
    if injection is not None and any(
        inj is None or math.isnan(inj) for inj in injection
    ):
        injection = None
    fig = gw._default_skymap_plot(ra, dec, weights, injection=injection)
    _PlotGeneration.save(
        fig, os.path.join(savedir, "{}_skymap".format(label)),
        preliminary=preliminary
    )

404 

@staticmethod
@no_latex_plot
def _ligo_skymap_plot(savedir, ra, dec, dist, time, label, nsamples_for_skymap,
                      webdir, multi_threading_for_skymap, injection,
                      preliminary=False):
    """Generate a skymap plot for a given set of samples using the
    ligo.skymap package

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    ra: pesummary.utils.utils.Array
        array containing the samples for right ascension
    dec: pesummary.utils.utils.Array
        array containing the samples for declination
    dist: pesummary.utils.utils.Array
        array containing the samples for luminosity distance
    time: pesummary.utils.utils.Array or None
        array containing the samples for the geocentric time of merger. It
        is passed through unchanged and is not resampled when
        ``nsamples_for_skymap`` is set
    label: str
        the label corresponding to the results file
    nsamples_for_skymap: int
        the number of samples used to generate skymap. If None, all samples
        are used
    webdir: str
        the plotting helper is given ``<webdir>/samples`` as the directory
        to store the fits file
    multi_threading_for_skymap: int, optional
        passed to the plotting helper as ``nprocess``
    injection: list or None
        injected ra and dec. It is ignored if either value is None or NaN
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    import math

    downsampled = False
    if nsamples_for_skymap is not None:
        ra, dec, dist = resample_posterior_distribution(
            [ra, dec, dist], nsamples_for_skymap
        )
        downsampled = True
    # ``None`` is checked first because ``math.isnan(None)`` raises TypeError.
    if injection is not None and any(
        inj is None or math.isnan(inj) for inj in injection
    ):
        injection = None
    fig = gw._ligo_skymap_plot(
        ra, dec, dist=dist, savedir=os.path.join(webdir, "samples"),
        nprocess=multi_threading_for_skymap, downsampled=downsampled,
        label=label, time=time, injection=injection
    )
    _PlotGeneration.save(
        fig, os.path.join(savedir, "{}_skymap".format(label)),
        preliminary=preliminary
    )

453 

@staticmethod
@no_latex_plot
def _ligo_skymap_array_plot(savedir, skymap, label, preliminary=False):
    """Plot a skymap from a probability array already produced by
    `ligo.skymap`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    skymap: np.ndarray
        array of skymap probabilities
    label: str
        the label corresponding to the results file; the figure is saved as
        ``<savedir>/{label}_skymap``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    """
    figure = gw._ligo_skymap_plot_from_array(skymap)
    destination = os.path.join(savedir, "{}_skymap".format(label))
    _PlotGeneration.save(figure, destination, preliminary=preliminary)

476 

def waveform_fd_plot(self, label):
    """Generate a frequency domain waveform plot for a given result file

    Nothing is done when the approximant for ``label`` is an empty dict.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    if self.approximant[label] == {}:
        return
    savedir = self.savedir
    network = self.detectors[label]
    maxL = self.maxL_samples[label]
    is_preliminary = self.preliminary_pages[label]
    use_checkpoint = self.checkpoint
    meta_data = self.file_kwargs[label]["meta_data"]
    self._waveform_fd_plot(
        savedir, network, maxL, label,
        preliminary=is_preliminary, checkpoint=use_checkpoint, **meta_data
    )

492 

493 @staticmethod 

494 def _waveform_fd_plot( 

495 savedir, detectors, maxL_samples, label, preliminary=False, 

496 checkpoint=False, **kwargs 

497 ): 

498 """Generate a frequency domain waveform plot for a given detector 

499 network and set of samples 

500 

501 Parameters 

502 ---------- 

503 savedir: str 

504 the directory you wish to save the plot in 

505 detectors: list 

506 list of detectors used in your analysis 

507 maxL_samples: dict 

508 dictionary of maximum likelihood values 

509 label: str 

510 the label corresponding to the results file 

511 preliminary: Bool, optional 

512 if True, add a preliminary watermark to the plot 

513 """ 

514 filename = os.path.join(savedir, "{}_waveform.png".format(label)) 

515 if os.path.isfile(filename) and checkpoint: 

516 return 

517 if detectors is None: 

518 detectors = ["H1", "L1"] 

519 else: 

520 detectors = detectors.split("_") 

521 

522 fig = gw._waveform_plot( 

523 detectors, maxL_samples, f_start=kwargs.get("f_start", 20.), 

524 f_low=kwargs.get("f_low", 20.0), 

525 f_max=kwargs.get("f_final", 1024.), 

526 f_ref=kwargs.get("f_ref", 20.), 

527 approximant_flags=kwargs.get("approximant_flags", {}) 

528 ) 

529 _PlotGeneration.save( 

530 fig, filename, preliminary=preliminary 

531 ) 

532 

def waveform_td_plot(self, label):
    """Generate a time domain waveform plot for a given result file

    Nothing is done when the approximant for ``label`` is an empty dict.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    if self.approximant[label] == {}:
        return
    savedir = self.savedir
    network = self.detectors[label]
    maxL = self.maxL_samples[label]
    is_preliminary = self.preliminary_pages[label]
    use_checkpoint = self.checkpoint
    meta_data = self.file_kwargs[label]["meta_data"]
    self._waveform_td_plot(
        savedir, network, maxL, label,
        preliminary=is_preliminary, checkpoint=use_checkpoint, **meta_data
    )

548 

549 @staticmethod 

550 def _waveform_td_plot( 

551 savedir, detectors, maxL_samples, label, preliminary=False, 

552 checkpoint=False, **kwargs 

553 ): 

554 """Generate a time domain waveform plot for a given detector network 

555 and set of samples 

556 

557 Parameters 

558 ---------- 

559 savedir: str 

560 the directory you wish to save the plot in 

561 detectors: list 

562 list of detectors used in your analysis 

563 maxL_samples: dict 

564 dictionary of maximum likelihood values 

565 label: str 

566 the label corresponding to the results file 

567 preliminary: Bool, optional 

568 if True, add a preliminary watermark to the plot 

569 """ 

570 filename = os.path.join( 

571 savedir, "{}_waveform_time_domain.png".format(label) 

572 ) 

573 if os.path.isfile(filename) and checkpoint: 

574 return 

575 if detectors is None: 

576 detectors = ["H1", "L1"] 

577 else: 

578 detectors = detectors.split("_") 

579 

580 fig = gw._time_domain_waveform( 

581 detectors, maxL_samples, f_start=kwargs.get("f_start", 20.), 

582 f_low=kwargs.get("f_low", 20.0), 

583 f_max=kwargs.get("f_final", 1024.), 

584 f_ref=kwargs.get("f_ref", 20.), 

585 approximant_flags=kwargs.get("approximant_flags", {}) 

586 ) 

587 _PlotGeneration.save( 

588 fig, filename, preliminary=preliminary 

589 ) 

590 

def gwdata_plots(self, label):
    """Generate all plots associated with the gwdata

    The strain, spectrogram and omegascan plots are each attempted through
    ``self._try_to_make_a_plot`` with their own error message.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    from pesummary.utils.utils import determine_gps_time_and_window

    base_error = "Failed to generate a %s because {}"
    gps_time, window = determine_gps_time_and_window(
        self.maxL_samples, self.labels
    )
    plots = (
        (self.strain_plot, [label], "strain_plot"),
        (self.spectrogram_plot, [], "spectrogram plot"),
        (self.omegascan_plot, [gps_time, window], "omegascan plot"),
    )
    for plot_func, plot_args, plot_name in plots:
        self._try_to_make_a_plot(plot_args, plot_func, base_error % plot_name)

614 

def strain_plot(self, label):
    """Generate a plot showing the comparison between the data and the
    maxL waveform for a given result file

    The plot is made in a background ``multiprocessing.Process``, which is
    started but not joined here.

    Parameters
    ----------
    label: str
        the label corresponding to the results file
    """
    logger.info("Launching subprocess to generate strain plot")
    # ``self.checkpoint`` is now forwarded. Previously ``_strain_plot``
    # always used its default (False), so existing plots were regenerated.
    process = mp.Process(
        target=self._strain_plot,
        args=[
            self.savedir, self.gwdata, self.maxL_samples[label], label,
            self.checkpoint
        ]
    )
    process.start()

630 

631 @staticmethod 

632 def _strain_plot(savedir, gwdata, maxL_samples, label, checkpoint=False): 

633 """Generate a strain plot for a given set of samples 

634 

635 Parameters 

636 ---------- 

637 savedir: str 

638 the directory to save the plot 

639 gwdata: dict 

640 dictionary of strain data for each detector 

641 maxL_samples: dict 

642 dictionary of maximum likelihood values 

643 label: str 

644 the label corresponding to the results file 

645 """ 

646 filename = os.path.join(savedir, "{}_strain.png".format(label)) 

647 if os.path.isfile(filename) and checkpoint: 

648 return 

649 fig = gw._strain_plot(gwdata, maxL_samples) 

650 _PlotGeneration.save(fig, filename) 

651 

def spectrogram_plot(self):
    """Generate a plot showing the spectrogram for all detectors

    This delegates to ``self._spectrogram_plot`` with ``self.savedir`` and
    ``self.gwdata``. That helper saves the figures itself, so its return
    value is not kept (the unused local ``figs`` has been removed).
    """
    self._spectrogram_plot(self.savedir, self.gwdata)

656 

@staticmethod
def _spectrogram_plot(savedir, strain):
    """Generate and save a spectrogram plot for every detector

    Each figure is saved as ``<savedir>/spectrogram_{detector}``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    """
    from pesummary.gw.plots import detchar

    for ifo, figure in detchar.spectrogram(strain).items():
        destination = os.path.join(savedir, "spectrogram_{}".format(ifo))
        _PlotGeneration.save(figure, destination)

676 

def omegascan_plot(self, gps_time, window):
    """Generate a plot showing the omegascan for all detectors

    This delegates to ``self._omegascan_plot``, which saves the figures
    itself, so its return value is not kept (the unused local ``figs`` has
    been removed).

    Parameters
    ----------
    gps_time: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    self._omegascan_plot(self.savedir, self.gwdata, gps_time, window)

690 

@staticmethod
def _omegascan_plot(savedir, strain, gps, window):
    """Generate and save an omegascan plot for every detector

    Each figure is saved as ``<savedir>/omegascan_{detector}``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    strain: dict
        dictionary of gwpy timeseries objects containing the strain data for
        each IFO
    gps: float
        time around which to centre the omegascan
    window: float
        window around gps time to generate plot for
    """
    from pesummary.gw.plots import detchar

    scans = detchar.omegascan(strain, gps, window=window)
    for ifo, figure in scans.items():
        destination = os.path.join(savedir, "omegascan_{}".format(ifo))
        _PlotGeneration.save(figure, destination)

714 

def skymap_comparison_plot(self, label):
    """Generate a plot to compare skymaps for all result files

    A basic comparison is always produced from the common ``ra``/``dec``
    samples. When ``ligo.skymap`` is importable and ``self.no_ligo_skymap``
    is False, a background process also builds a comparison from the
    per-label fits files in ``<webdir>/samples``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused; all
        labels are compared)
    """
    from pesummary.utils.utils import RedirectLogger

    self._skymap_comparison_plot(
        self.savedir, self.same_samples["ra"], self.same_samples["dec"],
        self.labels, self.colors, self.preliminary_comparison_pages,
        self.checkpoint
    )

    try:
        import ligo.skymap  # noqa: F401
    except ImportError:
        return
    if self.no_ligo_skymap:
        return

    logger.info("Launching subprocess to generate comparison skymap plot with "
                "ligo.skymap")
    samples_dir = os.path.join(self.webdir, "samples")
    fits_files = [
        os.path.join(samples_dir, "{}_skymap.fits".format(_label))
        for _label in self.labels
    ]
    with RedirectLogger("ligo.skymap", level="DEBUG"):
        comparison = mp.Process(
            target=self._ligo_skymap_comparison_plot_from_fits,
            args=[
                self.savedir, fits_files, self.colors, self.labels,
                self.preliminary_comparison_pages, self._ligo_skymap_PID
            ]
        )
        comparison.start()

753 

754 @staticmethod 

755 def _skymap_comparison_plot( 

756 savedir, ra, dec, labels, colors, preliminary=False, checkpoint=False 

757 ): 

758 """Generate a plot to compare skymaps for a given set of samples 

759 

760 Parameters 

761 ---------- 

762 savedir: str 

763 the directory you wish to save the plot in 

764 ra: dict 

765 dictionary of right ascension samples for each result file 

766 dec: dict 

767 dictionary of declination samples for each result file 

768 labels: list 

769 list of labels to distinguish each result file 

770 colors: list 

771 list of colors to be used to distinguish different result files 

772 preliminary: Bool, optional 

773 if True, add a preliminary watermark to the plot 

774 """ 

775 filename = os.path.join(savedir, "combined_skymap.png") 

776 if os.path.isfile(filename) and checkpoint: 

777 return 

778 ra_list = [ra[key] for key in labels] 

779 dec_list = [dec[key] for key in labels] 

780 fig = gw._sky_map_comparison_plot(ra_list, dec_list, labels, colors) 

781 _PlotGeneration.save( 

782 fig, filename, preliminary=preliminary 

783 ) 

784 

@staticmethod
@no_latex_plot
def _ligo_skymap_comparison_plot_from_fits(
    savedir, fits_files, colors, labels, preliminary=False, ligo_skymap_PID=None
):
    """Generate a comparison skymap based on fits files already generated
    with `ligo.skymap`

    For every label with an entry in ``ligo_skymap_PID``, this polls (every
    60 seconds) until that label's fits file exists or its subprocess no
    longer appears to be running. If any fits file is then missing, a
    warning is logged and no plot is made.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    fits_files: list
        list of paths to the fits files
    colors: list
        list of colors to use for each skymap
    labels: list
        list of labels corresponding to each fits file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    ligo_skymap_PID: dict, optional
        label -> ``multiprocessing.Process`` for the ligo.skymap
        subprocesses (only ``.pid`` is read)
    """
    import ligo.skymap.io
    import subprocess
    import time

    if ligo_skymap_PID:
        for label, fits_file in zip(labels, fits_files):
            if label not in ligo_skymap_PID.keys():
                continue
            while not os.path.isfile(fits_file):
                pid = str(ligo_skymap_PID[label].pid)
                try:
                    # List form: no shell is involved.
                    output = str(subprocess.check_output(["ps", "-p", pid]))
                except (subprocess.CalledProcessError, OSError):
                    # ``ps -p`` exits non-zero when the process no longer
                    # exists (or ``ps`` is unavailable). Stop waiting.
                    break
                # The subprocess is treated as finished when ``ps`` no longer
                # reports a "summarypages" command for this pid, or reports
                # it as defunct. This previously referenced the undefined
                # name ``_path`` and raised NameError.
                if "summarypages" not in output or "defunct" in output:
                    break
                # wait for the process to finish
                time.sleep(60)

    skymaps = []
    for fits_file in fits_files:
        try:
            skymap, _ = ligo.skymap.io.read_sky_map(
                fits_file, nest=None
            )
        except FileNotFoundError:
            logger.warning(
                "Failed to find {}. Unable to generate comparison skymap "
                "plot.".format(fits_file)
            )
            return
        skymaps.append(skymap)

    fig = gw._ligo_skymap_comparion_plot_from_array(
        skymaps, colors, labels
    )
    _PlotGeneration.save(
        fig, os.path.join(savedir, "combined_skymap.png"),
        preliminary=preliminary
    )

855 

def waveform_comparison_fd_plot(self, label):
    """Generate a plot to compare the frequency domain waveform

    Nothing is done if any result file has an empty approximant dict.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused; all
        labels are compared)
    """
    missing_approximant = any(
        self.approximant[name] == {} for name in self.labels
    )
    if missing_approximant:
        return
    self._waveform_comparison_fd_plot(
        self.savedir, self.maxL_samples, self.labels, self.colors,
        checkpoint=self.checkpoint,
        preliminary=self.preliminary_comparison_pages,
        **self.file_kwargs
    )

872 

@staticmethod
def _waveform_comparison_fd_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a plot to compare the frequency domain waveforms

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        dictionary of maximum likelihood samples for each result file. It is
        not modified
    labels: list
        list of labels to distinguish each result file
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and ``compare_waveforms.png`` already exists, do nothing
    **kwargs: dict
        per-label file kwargs keyed by label. ``kwargs[label]["meta_data"]``
        supplies ``approximant_flags``, ``f_start``, ``f_low``, ``f_final``
        and ``f_ref``, with defaults 20, 20, 1024 and 20 for the frequencies
    """
    import copy

    filename = os.path.join(savedir, "compare_waveforms.png")
    if os.path.isfile(filename) and checkpoint:
        return
    defaults = {"f_start": 20., "f_low": 20., "f_final": 1024., "f_ref": 20.}
    samples = []
    for label in labels:
        meta_data = kwargs[label]["meta_data"]
        # Shallow copy: the previous version wrote these plotting-only keys
        # straight into the caller's maxL_samples (e.g. self.maxL_samples).
        label_samples = copy.copy(maxL_samples[label])
        label_samples["approximant_flags"] = meta_data.get(
            "approximant_flags", {}
        )
        for freq, default in defaults.items():
            label_samples[freq] = meta_data.get(freq, default)
        samples.append(label_samples)

    fig = gw._waveform_comparison_plot(samples, colors, labels)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )

909 

def waveform_comparison_td_plot(self, label):
    """Generate a plot to compare the time domain waveform

    Nothing is done if any result file has an empty approximant dict.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (unused; all
        labels are compared)
    """
    for name in self.labels:
        if self.approximant[name] == {}:
            return
    plot_args = (
        self.savedir, self.maxL_samples, self.labels, self.colors,
        self.preliminary_comparison_pages, self.checkpoint,
    )
    self._waveform_comparison_td_plot(*plot_args, **self.file_kwargs)

926 

@staticmethod
def _waveform_comparison_td_plot(
    savedir, maxL_samples, labels, colors, preliminary=False,
    checkpoint=False, **kwargs
):
    """Generate a plot to compare the time domain waveforms

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    maxL_samples: dict
        dictionary of maximum likelihood samples for each result file. It is
        not modified
    labels: list
        list of labels to distinguish each result file
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and ``compare_time_domain_waveforms.png`` already exists, do
        nothing
    **kwargs: dict
        per-label file kwargs keyed by label. ``kwargs[label]["meta_data"]``
        supplies ``approximant_flags``, ``f_start``, ``f_low`` and ``f_ref``,
        each frequency defaulting to 20
    """
    import copy

    filename = os.path.join(savedir, "compare_time_domain_waveforms.png")
    if os.path.isfile(filename) and checkpoint:
        return
    defaults = {"f_start": 20., "f_low": 20., "f_ref": 20.}
    samples = []
    for label in labels:
        meta_data = kwargs[label]["meta_data"]
        # Shallow copy: the previous version wrote these plotting-only keys
        # straight into the caller's maxL_samples (e.g. self.maxL_samples).
        label_samples = copy.copy(maxL_samples[label])
        label_samples["approximant_flags"] = meta_data.get(
            "approximant_flags", {}
        )
        for freq, default in defaults.items():
            label_samples[freq] = meta_data.get(freq, default)
        samples.append(label_samples)

    fig = gw._time_domain_waveform_comparison_plot(samples, colors, labels)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )

963 

def twod_comparison_contour_plot(self, label):
    """Generate 2d comparison contour plots for fixed pairs of parameters

    A pair is only plotted when both of its parameters are present in the
    samples of every result file; otherwise a warning is logged and the pair
    is skipped. The KDE grid size comes from ``publication_kwargs["gridsize"]``
    when supplied and defaults to 100.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. It is not
        used: every result file in ``self.labels`` is compared
    """
    error_message = (
        "Failed to generate a 2d contour plot for %s because {}"
    )
    parameter_pairs = [
        ["mass_ratio", "chi_eff"], ["mass_1", "mass_2"],
        ["luminosity_distance", "chirp_mass_source"],
        ["mass_1_source", "mass_2_source"],
        ["theta_jn", "luminosity_distance"],
        ["network_optimal_snr", "chirp_mass_source"]
    ]
    gridsize = 100
    if "gridsize" in self.publication_kwargs.keys():
        gridsize = int(self.publication_kwargs["gridsize"])

    for pair in parameter_pairs:
        joined = " and ".join(pair)
        missing = any(
            parameter not in self.samples[name].keys()
            for name in self.labels for parameter in pair
        )
        if missing:
            logger.warning(
                "Failed to generate 2d contour plots for {} because {} are not "
                "common in all result files".format(joined, joined)
            )
            continue
        pair_samples = [
            [self.samples[name][parameter] for parameter in pair]
            for name in self.labels
        ]
        arguments = [
            self.savedir, pair, pair_samples, self.labels, latex_labels,
            self.colors, self.linestyles, gridsize,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._twod_comparison_contour_plot,
            error_message % joined
        )

1009 

1010 @staticmethod 

1011 def _twod_comparison_contour_plot( 

1012 savedir, plot_parameters, samples, labels, latex_labels, colors, 

1013 linestyles, gridsize, preliminary=False, checkpoint=False 

1014 ): 

1015 """Generate a 2d comparison contour plot for a given set of samples 

1016 

1017 Parameters 

1018 ---------- 

1019 savedir: str 

1020 the directory you wish to save the plot in 

1021 plot_parameters: list 

1022 list of parameters to use for the 2d contour plot 

1023 samples: list 

1024 list of samples for each parameter 

1025 labels: list 

1026 list of labels used to distinguish each result file 

1027 latex_labels: dict 

1028 dictionary containing the latex labels for each parameter 

1029 gridsize: int 

1030 the number of points to use when estimating the KDE 

1031 preliminary: Bool, optional 

1032 if True, add a preliminary watermark to the plot 

1033 """ 

1034 filename = os.path.join( 

1035 savedir, "publication", "2d_contour_plot_{}.png".format( 

1036 "_and_".join(plot_parameters) 

1037 ) 

1038 ) 

1039 if os.path.isfile(filename) and checkpoint: 

1040 return 

1041 fig = publication.twod_contour_plots( 

1042 plot_parameters, samples, labels, latex_labels, colors=colors, 

1043 linestyles=linestyles, gridsize=gridsize 

1044 ) 

1045 _PlotGeneration.save( 

1046 fig, filename, preliminary=preliminary 

1047 ) 

1048 

def violin_plot(self, label):
    """Generate violin plots to compare certain parameters in all result
    files

    A parameter is only plotted when it is present in the samples of every
    result file; otherwise a warning is logged and the parameter is skipped.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. It is not
        used: every result file in ``self.labels`` is compared
    """
    error_message = (
        "Failed to generate a violin plot for %s because {}"
    )
    violin_plots = ["mass_ratio", "chi_eff", "chi_p", "luminosity_distance"]

    for plot in violin_plots:
        if not all(plot in self.samples[j].keys() for j in self.labels):
            logger.warning(
                "Failed to generate violin plots for {} because {} is not "
                "common in all result files".format(plot, plot)
            )
            # skip this parameter: indexing the samples below would
            # otherwise raise a KeyError outside of _try_to_make_a_plot
            continue
        # injected values are only looked up for parameters that will
        # actually be plotted
        injection = [
            self.injection_data[_label][plot] for _label in self.labels
        ]
        samples = [self.samples[i][plot] for i in self.labels]
        arguments = [
            self.savedir, plot, samples, self.labels, latex_labels[plot],
            injection, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._violin_plot, error_message % (plot)
        )

1078 

@staticmethod
def _violin_plot(
    savedir, plot_parameter, samples, labels, latex_label, inj_values=None,
    preliminary=False, checkpoint=False, kde=ReflectionBoundedKDE,
    default_bounds=True
):
    """Draw a violin plot comparing one parameter across result files

    The figure is written to
    ``<savedir>/publication/violin_plot_<plot_parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    plot_parameter: str
        name of the parameter you wish to generate a violin plot for
    samples: list
        list of samples for each result file
    labels: list
        list of labels used to distinguish each result file
    latex_label: str
        latex label corresponding to the parameter
    inj_values: list, optional
        list of injected values, one per result file
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, skip the plot when the output file already exists
    kde: class, optional
        KDE class forwarded to the violin plotting function
    default_bounds: Bool, optional
        if True, compute KDE bounds with ``gw._return_bounds``; otherwise
        leave them unset (None)
    """
    filename = os.path.join(
        savedir, "publication", "violin_plot_{}.png".format(plot_parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    if default_bounds:
        xlow, xhigh = gw._return_bounds(
            plot_parameter, samples, comparison=True
        )
    else:
        xlow = xhigh = None
    kde_kwargs = {"xlow": xlow, "xhigh": xhigh}
    # NOTE(review): the `latex_label` argument is not used; the module-level
    # `latex_labels` mapping is passed instead. Confirm this is intended.
    figure = publication.violin_plots(
        plot_parameter, samples, labels, latex_labels, kde=kde,
        kde_kwargs=kde_kwargs, inj_values=inj_values
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)

1121 

def spin_dist_plot(self, label):
    """Generate one spin disk plot per result file

    A result file is skipped (with a warning) unless its samples contain
    ``a_1``, ``a_2``, ``cos_tilt_1`` and ``cos_tilt_2``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot. It is not
        used: a plot is attempted for every result file in ``self.labels``
    """
    error_message = (
        "Failed to generate a spin disk plot for %s because {}"
    )
    parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    for index, name in enumerate(self.labels):
        available = self.samples[name].keys()
        if any(parameter not in available for parameter in parameters):
            logger.warning(
                "Failed to generate spin disk plots because {} are not "
                "common in all result files".format(
                    " and ".join(parameters)
                )
            )
            continue
        spin_samples = [self.samples[name][parameter] for parameter in parameters]
        arguments = [
            self.savedir, parameters, spin_samples, name, self.colors[index],
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._spin_dist_plot, error_message % (name)
        )

1153 

1154 @staticmethod 

1155 def _spin_dist_plot( 

1156 savedir, parameters, samples, label, color, preliminary=False, 

1157 checkpoint=False 

1158 ): 

1159 """Generate a spin disk plot for a given set of samples 

1160 

1161 Parameters 

1162 ---------- 

1163 preliminary: Bool, optional 

1164 if True, add a preliminary watermark to the plot 

1165 """ 

1166 filename = os.path.join( 

1167 savedir, "publication", "spin_disk_plot_{}.png".format(label) 

1168 ) 

1169 if os.path.isfile(filename) and checkpoint: 

1170 return 

1171 fig = publication.spin_distribution_plots( 

1172 parameters, samples, label, color=color 

1173 ) 

1174 _PlotGeneration.save( 

1175 fig, filename, preliminary=preliminary 

1176 ) 

1177 

def classification_plot(self, label):
    """Generate plots showing source classification probabilities

    When ``self.mcmc_samples`` is set, the combined samples
    (``self.samples[label].combine``) are used; otherwise the samples are
    used as stored.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    samples = self.samples[label]
    if self.mcmc_samples:
        samples = samples.combine
    probabilities = self.classification_probs[label]["default"]
    self._classification_plot(
        self.savedir, samples, label, probabilities,
        preliminary=self.preliminary_pages[label], checkpoint=self.checkpoint
    )

1195 

@staticmethod
@no_latex_plot
def _classification_plot(
    savedir, samples, label, probabilities, preliminary=False,
    checkpoint=False
):
    """Generate source classification bar charts for a given set of samples

    Two plots are written to ``savedir``:

    - ``{label}.pesummary.p_astro.png`` shows the BBH, BNS, NSBH and
      Terrestrial probabilities.
    - ``{label}.pesummary.em_bright.png`` shows the HasNS, HasRemnant and
      HasMassGap probabilities.

    Each classifier object is only constructed when its plot actually needs
    to be (re)made, so checkpointed plots cost nothing.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    samples: dict
        dictionary of samples for each parameter
    label: str
        the label corresponding to the result file
    probabilities: dict
        dictionary of classification probabilities
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, skip any plot whose output file already exists
    """
    from pesummary.gw.classification import PAstro, EMBright

    # (classifier class, output filename template, probabilities shown)
    classifications = [
        (
            PAstro, "{}.pesummary.p_astro.png",
            ["BBH", "BNS", "NSBH", "Terrestrial"]
        ),
        (
            EMBright, "{}.pesummary.em_bright.png",
            ["HasNS", "HasRemnant", "HasMassGap"]
        ),
    ]
    for classifier, template, keys in classifications:
        filename = os.path.join(savedir, template.format(label))
        if os.path.isfile(filename) and checkpoint:
            # checkpointing: keep the existing plot
            continue
        fig = classifier(samples).plot(
            type="bar", probabilities={
                key: value for key, value in probabilities.items() if
                key in keys
            }
        )
        _PlotGeneration.save(fig, filename, preliminary=preliminary)

1253 

def psd_plot(self, label):
    """Generate a PSD plot for every result file that has PSD data

    Result files with no PSD data are skipped, so they no longer prevent
    the remaining files from being plotted. The plotted frequency range is
    taken from each file's own ``meta_data`` (``f_low``/``f_final``) and is
    None when not given. Values never carry over from one result file to
    another.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. It is not used: a plot
        is attempted for every result file in ``self.labels``
    """
    error_message = (
        "Failed to generate a PSD plot for %s because {}"
    )

    for result_label in self.labels:
        psd = self.psd[result_label]
        ifos = list(psd.keys())
        if ifos in ([None], []):
            # no PSD data stored for this result file
            continue
        meta_data = self.file_kwargs[result_label]["meta_data"]
        # read per result file so that one file's limits never leak into
        # the plot of another
        fmin = meta_data.get("f_low", None)
        fmax = meta_data.get("f_final", None)
        frequencies = [np.array(psd[i]).T[0] for i in ifos]
        strains = [np.array(psd[i]).T[1] for i in ifos]
        arguments = [
            self.savedir, frequencies, strains, fmin, fmax, ifos,
            result_label, self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._psd_plot, error_message % (result_label)
        )

1289 

1290 @staticmethod 

1291 def _psd_plot( 

1292 savedir, frequencies, strains, fmin, fmax, psd_labels, label, checkpoint=False 

1293 ): 

1294 """Generate a psd plot for a given set of samples 

1295 

1296 Parameters 

1297 ---------- 

1298 savedir: str 

1299 the directory you wish to save the plot in 

1300 frequencies: list 

1301 list of psd frequencies for each IFO 

1302 strains: list 

1303 list of psd strains for each IFO 

1304 fmin: float 

1305 frequency to start the psd plotting 

1306 fmax: float 

1307 frequency to end the psd plotting 

1308 psd_labels: list 

1309 list of IFOs used 

1310 label: str 

1311 the label used to distinguish the result file 

1312 """ 

1313 filename = os.path.join(savedir, "{}_psd_plot.png".format(label)) 

1314 if os.path.isfile(filename) and checkpoint: 

1315 return 

1316 fig = gw._psd_plot( 

1317 frequencies, strains, labels=psd_labels, fmin=fmin, fmax=fmax 

1318 ) 

1319 _PlotGeneration.save(fig, filename) 

1320 

def calibration_plot(self, label):
    """Generate a calibration plot for every result file that has
    calibration data

    Result files without calibration data are skipped, so they no longer
    prevent the remaining files from being plotted. The calibration
    envelopes are evaluated between 20 Hz and 1024 Hz in steps of 0.25 Hz.

    Parameters
    ----------
    label: str
        the label corresponding to the result file. It is not used: a plot
        is attempted for every result file in ``self.labels``
    """
    error_message = (
        "Failed to generate calibration plot for %s because {}"
    )
    frequencies = np.arange(20., 1024., 1. / 4)

    for result_label in self.labels:
        calibration = self.calibration[result_label]
        ifos = list(calibration.keys())
        if ifos in ([None], []):
            # no calibration data stored for this result file
            continue
        calibration_data = [calibration[i] for i in ifos]
        if "calibration" in self.priors:
            prior = [
                self.priors["calibration"][result_label][i] for i in ifos
            ]
        else:
            prior = None
        arguments = [
            self.savedir, frequencies, calibration_data, ifos, prior,
            result_label, self.calibration_definition[result_label],
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            arguments, self._calibration_plot, error_message % (result_label)
        )

1357 

1358 @staticmethod 

1359 def _calibration_plot( 

1360 savedir, frequencies, calibration_data, calibration_labels, prior, label, 

1361 calibration_definition="data", checkpoint=False 

1362 ): 

1363 """Generate a calibration plot for a given set of samples 

1364 

1365 Parameters 

1366 ---------- 

1367 savedir: str 

1368 the directory you wish to save the plot in 

1369 frequencies: list 

1370 list of frequencies used to interpolate the calibration data 

1371 calibration_data: list 

1372 list of calibration data for each IFO 

1373 calibration_labels: list 

1374 list of IFOs used 

1375 prior: list 

1376 list containing the priors used for each IFO 

1377 label: str 

1378 the label used to distinguish the result file 

1379 calibration_definition: str 

1380 the definition of the calibration prior used (either 'data' or 'template') 

1381 """ 

1382 filename = os.path.join( 

1383 savedir, "{}_calibration_plot.png".format(label) 

1384 ) 

1385 if os.path.isfile(filename) and checkpoint: 

1386 return 

1387 fig = gw._calibration_envelope_plot( 

1388 frequencies, calibration_data, calibration_labels, prior=prior, 

1389 definition=calibration_definition 

1390 ) 

1391 _PlotGeneration.save(fig, filename) 

1392 

1393 @staticmethod 

1394 def _interactive_corner_plot( 

1395 savedir, label, samples, latex_labels, checkpoint=False 

1396 ): 

1397 """Generate an interactive corner plot for a given set of samples 

1398 

1399 Parameters 

1400 ---------- 

1401 savedir: str 

1402 the directory you wish to save the plot in 

1403 label: str 

1404 the label corresponding to the results file 

1405 samples: dict 

1406 dictionary containing PESummary.utils.utils.Array objects that 

1407 contain samples for each parameter 

1408 latex_labels: str 

1409 latex labels for each parameter in samples 

1410 """ 

1411 filename = os.path.join( 

1412 savedir, "corner", "{}_interactive_source.html".format(label) 

1413 ) 

1414 if os.path.isfile(filename) and checkpoint: 

1415 pass 

1416 else: 

1417 source_parameters = [ 

1418 "luminosity_distance", "mass_1_source", "mass_2_source", 

1419 "total_mass_source", "chirp_mass_source", "redshift" 

1420 ] 

1421 parameters = [i for i in samples.keys() if i in source_parameters] 

1422 data = [samples[parameter] for parameter in parameters] 

1423 labels = [latex_labels[parameter] for parameter in parameters] 

1424 _ = interactive.corner( 

1425 data, labels, write_to_html_file=filename, 

1426 dimensions={"width": 900, "height": 900} 

1427 ) 

1428 

1429 filename = os.path.join( 

1430 savedir, "corner", "{}_interactive_extrinsic.html".format(label) 

1431 ) 

1432 if os.path.isfile(filename) and checkpoint: 

1433 pass 

1434 else: 

1435 extrinsic_parameters = ["luminosity_distance", "psi", "ra", "dec"] 

1436 parameters = [i for i in samples.keys() if i in extrinsic_parameters] 

1437 data = [samples[parameter] for parameter in parameters] 

1438 labels = [latex_labels[parameter] for parameter in parameters] 

1439 _ = interactive.corner( 

1440 data, labels, write_to_html_file=filename 

1441 )