Coverage for pesummary/cli/summaryplots.py: 51.9%

104 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5from pesummary.utils.utils import logger, make_dir 

6from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser 

7from pesummary.gw.plots.latex_labels import GWlatex_labels 

8from pesummary.core.plots.latex_labels import latex_labels 

9from pesummary.gw.plots.main import _PlotGeneration 

10from pesummary.core.cli.actions import DictionaryAction 

11from pesummary.gw.file.read import read 

12 

13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

14 

15 

16class PlotGeneration(object): 

17 """Wrapper class for _GWPlotGeneration and _CorePlotGeneration 

18 

19 Parameters 

20 ---------- 

21 inputs: argparse.Namespace 

22 Namespace object containing the command line options 

23 colors: list, optional 

24 colors that you wish to use to distinguish different result files 

25 """ 

26 def __init__(self, inputs, colors="default", gw=False): 

27 self.inputs = inputs 

28 self.colors = colors 

29 self.gw = gw 

30 self.generate_plots() 

31 

32 def generate_plots(self): 

33 """Generate all plots for all result files passed 

34 """ 

35 logger.info("Starting to generate plots") 

36 if self.gw and self.inputs.public: 

37 object = _PublicGWPlotGeneration(self.inputs, colors=self.colors) 

38 self.ligo_skymap_PID = object.ligo_skymap_PID 

39 elif self.gw: 

40 object = _GWPlotGeneration(self.inputs, colors=self.colors) 

41 else: 

42 object = _CorePlotGeneration(self.inputs, colors=self.colors) 

43 object.generate_plots() 

44 if self.gw: 

45 self.ligo_skymap_PID = object.ligo_skymap_PID 

46 logger.info("Finished generating plots") 

47 

48 

49class _CorePlotGeneration(object): 

50 """Class to generate all plots associated with the Core module 

51 

52 Parameters 

53 ---------- 

54 inputs: argparse.Namespace 

55 Namespace object containing the command line options 

56 colors: list, optional 

57 colors that you wish to use to distinguish different result files 

58 """ 

59 def __init__(self, inputs, colors="default"): 

60 from pesummary.core.plots.main import _PlotGeneration 

61 key_data = inputs.grab_key_data_from_result_files() 

62 expert_plots = not inputs.disable_expert 

63 self.plotting_object = _PlotGeneration( 

64 webdir=inputs.webdir, labels=inputs.labels, 

65 samples=inputs.samples, kde_plot=inputs.kde_plot, 

66 existing_labels=inputs.existing_labels, 

67 existing_injection_data=inputs.existing_injection_data, 

68 existing_samples=inputs.existing_samples, 

69 same_parameters=inputs.same_parameters, 

70 injection_data=inputs.injection_data, 

71 colors=inputs.colors, custom_plotting=inputs.custom_plotting, 

72 add_to_existing=inputs.add_to_existing, priors=inputs.priors, 

73 include_prior=inputs.include_prior, weights=inputs.weights, 

74 disable_comparison=inputs.disable_comparison, 

75 linestyles=inputs.linestyles, 

76 disable_interactive=inputs.disable_interactive, 

77 disable_corner=inputs.disable_corner, 

78 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples, 

79 corner_params=inputs.corner_params, expert_plots=expert_plots, 

80 checkpoint=inputs.restart_from_checkpoint, key_data=key_data 

81 ) 

82 

83 def generate_plots(self): 

84 """Generate all plots within the Core module 

85 """ 

86 self.plotting_object.generate_plots() 

87 

88 

89class _GWPlotGeneration(object): 

90 """Class to generate all plots associated with the GW module 

91 

92 Parameters 

93 ---------- 

94 inputs: argparse.Namespace 

95 Namespace object containing the command line options 

96 colors: list, optional 

97 colors that you wish to use to distinguish different result files 

98 """ 

99 def __init__(self, inputs, colors="default"): 

100 from pesummary.gw.plots.main import _PlotGeneration 

101 key_data = inputs.grab_key_data_from_result_files() 

102 expert_plots = not inputs.disable_expert 

103 self.plotting_object = _PlotGeneration( 

104 webdir=inputs.webdir, labels=inputs.labels, 

105 samples=inputs.samples, kde_plot=inputs.kde_plot, 

106 existing_labels=inputs.existing_labels, 

107 existing_injection_data=inputs.existing_injection_data, 

108 existing_samples=inputs.existing_samples, 

109 existing_file_kwargs=inputs.existing_file_kwargs, 

110 existing_approximant=inputs.existing_approximant, 

111 existing_metafile=inputs.existing_metafile, 

112 same_parameters=inputs.same_parameters, 

113 injection_data=inputs.injection_data, 

114 result_files=inputs.result_files, 

115 file_kwargs=inputs.file_kwargs, 

116 colors=inputs.colors, custom_plotting=inputs.custom_plotting, 

117 add_to_existing=inputs.add_to_existing, priors=inputs.priors, 

118 no_ligo_skymap=inputs.no_ligo_skymap, 

119 nsamples_for_skymap=inputs.nsamples_for_skymap, 

120 detectors=inputs.detectors, maxL_samples=inputs.maxL_samples, 

121 gwdata=inputs.gwdata, calibration=inputs.calibration, 

122 psd=inputs.psd, approximant=inputs.approximant, 

123 multi_threading_for_skymap=inputs.multi_threading_for_skymap, 

124 pepredicates_probs=inputs.pepredicates_probs, 

125 include_prior=inputs.include_prior, publication=inputs.publication, 

126 existing_psd=inputs.existing_psd, 

127 existing_calibration=inputs.existing_calibration, weights=inputs.weights, 

128 linestyles=inputs.linestyles, 

129 calibration_definition=inputs.calibration_definition, 

130 disable_comparison=inputs.disable_comparison, 

131 disable_interactive=inputs.disable_interactive, 

132 disable_corner=inputs.disable_corner, 

133 publication_kwargs=inputs.publication_kwargs, 

134 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples, 

135 skymap=inputs.skymap, existing_skymap=inputs.existing_skymap, 

136 corner_params=inputs.corner_params, 

137 preliminary_pages=inputs.preliminary_pages, expert_plots=expert_plots, 

138 checkpoint=inputs.restart_from_checkpoint, key_data=key_data 

139 ) 

140 self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID 

141 

142 def generate_plots(self): 

143 """Generate all plots within the GW module 

144 """ 

145 self.plotting_object.generate_plots() 

146 

147 

148class _PublicGWPlotGeneration(object): 

149 """Class to generate all plots associated with the GW module 

150 

151 Parameters 

152 ---------- 

153 inputs: argparse.Namespace 

154 Namespace object containing the command line options 

155 colors: list, optional 

156 colors that you wish to use to distinguish different result files 

157 """ 

158 def __init__(self, inputs, colors="default"): 

159 from pesummary.gw.plots.public import _PlotGeneration 

160 expert_plots = not inputs.disable_expert 

161 self.plotting_object = _PlotGeneration( 

162 webdir=inputs.webdir, labels=inputs.labels, 

163 samples=inputs.samples, kde_plot=inputs.kde_plot, 

164 existing_labels=inputs.existing_labels, 

165 existing_injection_data=inputs.existing_injection_data, 

166 existing_samples=inputs.existing_samples, 

167 existing_file_kwargs=inputs.existing_file_kwargs, 

168 existing_approximant=inputs.existing_approximant, 

169 existing_metafile=inputs.existing_metafile, 

170 same_parameters=inputs.same_parameters, 

171 injection_data=inputs.injection_data, 

172 result_files=inputs.result_files, 

173 file_kwargs=inputs.file_kwargs, 

174 colors=inputs.colors, custom_plotting=inputs.custom_plotting, 

175 add_to_existing=inputs.add_to_existing, priors=inputs.priors, 

176 no_ligo_skymap=inputs.no_ligo_skymap, 

177 nsamples_for_skymap=inputs.nsamples_for_skymap, 

178 detectors=inputs.detectors, maxL_samples=inputs.maxL_samples, 

179 gwdata=inputs.gwdata, calibration=inputs.calibration, 

180 psd=inputs.psd, approximant=inputs.approximant, 

181 multi_threading_for_skymap=inputs.multi_threading_for_skymap, 

182 pepredicates_probs=inputs.pepredicates_probs, 

183 include_prior=inputs.include_prior, publication=inputs.publication, 

184 existing_psd=inputs.existing_psd, 

185 existing_calibration=inputs.existing_calibration, weights=inputs.weights, 

186 linestyles=inputs.linestyles, 

187 disable_comparison=inputs.disable_comparison, 

188 disable_interactive=inputs.disable_interactive, 

189 disable_corner=inputs.disable_corner, 

190 publication_kwargs=inputs.publication_kwargs, 

191 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples, 

192 skymap=inputs.skymap, existing_skymap=inputs.existing_skymap, 

193 corner_params=inputs.corner_params, 

194 preliminary_pages=inputs.preliminary_pages, expert_plots=expert_plots, 

195 checkpoint=inputs.restart_from_checkpoint 

196 ) 

197 self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID 

198 

199 def generate_plots(self): 

200 """Generate all plots within the GW module 

201 """ 

202 self.plotting_object.generate_plots() 

203 

204 

205class ArgumentParser(_ArgumentParser): 

206 def _pesummary_options(self): 

207 options = super(ArgumentParser, self)._pesummary_options() 

208 options.update( 

209 { 

210 "--plot": { 

211 "help": "name of the publication plot you wish to produce", 

212 "choices": [ 

213 "1d_histogram", "sample_evolution", "autocorrelation", 

214 "skymap" 

215 ], 

216 "default": "2d_contour" 

217 }, 

218 "--parameters": { 

219 "nargs": "+", 

220 "help": "parameters of the 2d contour plot you wish to make", 

221 }, 

222 "--plot_kwargs": { 

223 "help": "Optional plotting kwargs", 

224 "default": {}, 

225 "nargs": "+", 

226 "action": DictionaryAction 

227 }, 

228 "--inj": { 

229 "help": "Injected value", 

230 }, 

231 } 

232 ) 

233 return options 

234 

235 

236def check_inputs(opts): 

237 """Check that the inputs are compatible with `summaryplots` 

238 """ 

239 from pesummary.utils.exceptions import InputError 

240 

241 base = "Please provide {} for each result file" 

242 if opts.inj is None: 

243 opts.inj = [float("nan")] * len(opts.samples) 

244 if opts.labels is None: 

245 opts.labels = [i for i in range(len(opts.samples))] 

246 if len(opts.samples) != len(opts.labels): 

247 raise InputError(base.format("a label")) 

248 if len(opts.samples) != len(opts.inj): 

249 raise InputError(base.format("the injected value")) 

250 if opts.burnin is not None: 

251 opts.burnin = int(opts.burnin) 

252 return opts 

253 

254 

255def read_input_file(path_to_file): 

256 """Use PESummary to read a result file 

257 

258 Parameters 

259 ---------- 

260 path_to_file: str 

261 path to the results file 

262 """ 

263 from pesummary.gw.file.read import read 

264 

265 f = read(path_to_file) 

266 return f 

267 

268 

269def oned_histogram_plot(opts): 

270 """Make a 1d histogram plot 

271 """ 

272 for num, samples in enumerate(opts.samples): 

273 data = read(samples) 

274 for parameter in opts.parameters: 

275 _PlotGeneration._oned_histogram_plot( 

276 opts.webdir, opts.labels[num], parameter, 

277 data.samples_dict[parameter][opts.burnin:], 

278 latex_labels[parameter], opts.inj[num], kde=opts.kde_plot 

279 ) 

280 

281 

282def sample_evolution_plot(opts): 

283 """Make a sample evolution plot 

284 """ 

285 for num, samples in enumerate(opts.samples): 

286 data = read(samples) 

287 for parameter in opts.parameters: 

288 _PlotGeneration._sample_evolution_plot( 

289 opts.webdir, opts.labels[num], parameter, 

290 data.samples_dict[parameter][opts.burnin:], 

291 latex_labels[parameter], opts.inj[num] 

292 ) 

293 

294 

295def autocorrelation_plot(opts): 

296 """Make an autocorrelation plot 

297 """ 

298 for num, samples in enumerate(opts.samples): 

299 data = read(samples) 

300 for parameter in opts.parameters: 

301 _PlotGeneration._autocorrelation_plot( 

302 opts.webdir, opts.labels[num], parameter, 

303 data.samples_dict[parameter][opts.burnin:] 

304 ) 

305 

306 

307def skymap_plot(opts): 

308 """Make a skymap plot 

309 """ 

310 for num, samples in enumerate(opts.samples): 

311 data = read(samples) 

312 _PlotGeneration._skymap_plot( 

313 opts.webdir, data.samples_dict["ra"][opts.burnin:], 

314 data.samples_dict["dec"][opts.burnin:], opts.labels[num] 

315 ) 

316 

317 

318def main(args=None): 

319 """The main interface for `summaryplots` 

320 """ 

321 latex_labels.update(GWlatex_labels) 

322 parser = ArgumentParser(description=__doc__) 

323 parser.add_known_options_to_parser( 

324 [ 

325 "--webdir", "--samples", "--labels", "--plot", "--parameters", 

326 "--kde_plot", "--burnin", "--disable_comparison", "--plot_kwargs", 

327 "--inj" 

328 ] 

329 ) 

330 opts = parser.parse_args(args=args) 

331 opts = check_inputs(opts) 

332 make_dir(opts.webdir) 

333 func_map = { 

334 "1d_histogram": oned_histogram_plot, 

335 "sample_evolution": sample_evolution_plot, 

336 "autocorrelation": autocorrelation_plot, 

337 "skymap": skymap_plot 

338 } 

339 func_map[opts.plot](opts) 

340 

341 

342if __name__ == "__main__": 

343 main()