Coverage for pesummary/cli/summaryplots.py: 51.9%

104 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5from pesummary.utils.utils import logger, make_dir 

6from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser 

7from pesummary.gw.plots.latex_labels import GWlatex_labels 

8from pesummary.core.plots.latex_labels import latex_labels 

9from pesummary.gw.plots.main import _PlotGeneration 

10from pesummary.core.cli.actions import DictionaryAction 

11from pesummary.gw.file.read import read 

12 

13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

14 

15 

class PlotGeneration(object):
    """Wrapper class for _CorePlotGeneration, _GWPlotGeneration and
    _PublicGWPlotGeneration

    Plots are generated immediately on construction.

    Parameters
    ----------
    inputs: argparse.Namespace
        Namespace object containing the command line options
    colors: list, optional
        colors that you wish to use to distinguish different result files
    gw: bool, optional
        if True, use the GW plotting classes (the public GW variant when
        `inputs.public` is truthy); otherwise use the Core plotting class.
        Default False

    Attributes
    ----------
    ligo_skymap_PID:
        only set when `gw` is True; copied from the underlying GW plotting
        object after its plots have been generated
    """
    def __init__(self, inputs, colors="default", gw=False):
        self.inputs = inputs
        self.colors = colors
        self.gw = gw
        self.generate_plots()

    def generate_plots(self):
        """Generate all plots for all result files passed
        """
        logger.info("Starting to generate plots")
        # `plotter` rather than `object` so the builtin is not shadowed
        if self.gw and self.inputs.public:
            plotter = _PublicGWPlotGeneration(self.inputs, colors=self.colors)
        elif self.gw:
            plotter = _GWPlotGeneration(self.inputs, colors=self.colors)
        else:
            plotter = _CorePlotGeneration(self.inputs, colors=self.colors)
        plotter.generate_plots()
        # only the GW plotting classes expose `ligo_skymap_PID`; a single
        # assignment covers both the public and non-public GW paths
        if self.gw:
            self.ligo_skymap_PID = plotter.ligo_skymap_PID
        logger.info("Finished generating plots")

47 

48 

class _CorePlotGeneration(object):
    """Build and drive the plotting object of the Core module

    Parameters
    ----------
    inputs:
        object holding the command line options. It must also provide a
        `grab_key_data_from_result_files` method, so it is presumably a
        PESummary input class rather than a bare argparse.Namespace
    colors: list, optional
        accepted for interface compatibility only; the colors forwarded to
        the plotting object are read from `inputs.colors`
    """
    def __init__(self, inputs, colors="default"):
        # Core-module plotting class; deliberately shadows the GW
        # `_PlotGeneration` imported at module level
        from pesummary.core.plots.main import _PlotGeneration

        key_data = inputs.grab_key_data_from_result_files()
        # expert plots are produced unless explicitly disabled
        expert_plots = not inputs.disable_expert
        # keyword arguments listed in the same order as the attributes are
        # read from `inputs`
        constructor_kwargs = dict(
            webdir=inputs.webdir,
            labels=inputs.labels,
            samples=inputs.samples,
            kde_plot=inputs.kde_plot,
            existing_labels=inputs.existing_labels,
            existing_injection_data=inputs.existing_injection_data,
            existing_samples=inputs.existing_samples,
            same_parameters=inputs.same_parameters,
            injection_data=inputs.injection_data,
            colors=inputs.colors,
            custom_plotting=inputs.custom_plotting,
            add_to_existing=inputs.add_to_existing,
            priors=inputs.priors,
            include_prior=inputs.include_prior,
            weights=inputs.weights,
            disable_comparison=inputs.disable_comparison,
            linestyles=inputs.linestyles,
            disable_interactive=inputs.disable_interactive,
            disable_corner=inputs.disable_corner,
            multi_process=inputs.multi_process,
            mcmc_samples=inputs.mcmc_samples,
            corner_params=inputs.corner_params,
            expert_plots=expert_plots,
            checkpoint=inputs.restart_from_checkpoint,
            key_data=key_data,
        )
        self.plotting_object = _PlotGeneration(**constructor_kwargs)

    def generate_plots(self):
        """Delegate plot generation to the Core plotting object
        """
        self.plotting_object.generate_plots()

87 

88 

class _GWPlotGeneration(object):
    """Build and drive the plotting object of the GW module

    Parameters
    ----------
    inputs:
        object holding the command line options. It must also provide a
        `grab_key_data_from_result_files` method, so it is presumably a
        PESummary input class rather than a bare argparse.Namespace
    colors: list, optional
        accepted for interface compatibility only; the colors forwarded to
        the plotting object are read from `inputs.colors`

    Attributes
    ----------
    ligo_skymap_PID:
        the `ligo_skymap_PID` attribute of the GW plotting object, captured
        when this wrapper is constructed
    """
    def __init__(self, inputs, colors="default"):
        # GW-module plotting class (same object as the module-level import)
        from pesummary.gw.plots.main import _PlotGeneration

        key_data = inputs.grab_key_data_from_result_files()
        # expert plots are produced unless explicitly disabled
        expert_plots = not inputs.disable_expert
        # keyword arguments listed in the same order as the attributes are
        # read from `inputs`
        constructor_kwargs = dict(
            webdir=inputs.webdir,
            labels=inputs.labels,
            samples=inputs.samples,
            kde_plot=inputs.kde_plot,
            existing_labels=inputs.existing_labels,
            existing_injection_data=inputs.existing_injection_data,
            existing_samples=inputs.existing_samples,
            existing_file_kwargs=inputs.existing_file_kwargs,
            existing_approximant=inputs.existing_approximant,
            existing_metafile=inputs.existing_metafile,
            same_parameters=inputs.same_parameters,
            injection_data=inputs.injection_data,
            result_files=inputs.result_files,
            file_kwargs=inputs.file_kwargs,
            colors=inputs.colors,
            custom_plotting=inputs.custom_plotting,
            add_to_existing=inputs.add_to_existing,
            priors=inputs.priors,
            no_ligo_skymap=inputs.no_ligo_skymap,
            nsamples_for_skymap=inputs.nsamples_for_skymap,
            detectors=inputs.detectors,
            maxL_samples=inputs.maxL_samples,
            gwdata=inputs.gwdata,
            calibration=inputs.calibration,
            psd=inputs.psd,
            approximant=inputs.approximant,
            multi_threading_for_skymap=inputs.multi_threading_for_skymap,
            pepredicates_probs=inputs.pepredicates_probs,
            include_prior=inputs.include_prior,
            publication=inputs.publication,
            existing_psd=inputs.existing_psd,
            existing_calibration=inputs.existing_calibration,
            weights=inputs.weights,
            linestyles=inputs.linestyles,
            disable_comparison=inputs.disable_comparison,
            disable_interactive=inputs.disable_interactive,
            disable_corner=inputs.disable_corner,
            publication_kwargs=inputs.publication_kwargs,
            multi_process=inputs.multi_process,
            mcmc_samples=inputs.mcmc_samples,
            skymap=inputs.skymap,
            existing_skymap=inputs.existing_skymap,
            corner_params=inputs.corner_params,
            preliminary_pages=inputs.preliminary_pages,
            expert_plots=expert_plots,
            checkpoint=inputs.restart_from_checkpoint,
            key_data=key_data,
        )
        self.plotting_object = _PlotGeneration(**constructor_kwargs)
        self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID

    def generate_plots(self):
        """Delegate plot generation to the GW plotting object
        """
        self.plotting_object.generate_plots()

145 

146 

class _PublicGWPlotGeneration(object):
    """Build and drive the plotting object of the public GW module
    (`pesummary.gw.plots.public`)

    Unlike `_GWPlotGeneration`, no `key_data` is computed or forwarded.

    Parameters
    ----------
    inputs:
        object holding the command line options (presumably a PESummary
        input class; it must expose every attribute read below)
    colors: list, optional
        accepted for interface compatibility only; the colors forwarded to
        the plotting object are read from `inputs.colors`

    Attributes
    ----------
    ligo_skymap_PID:
        the `ligo_skymap_PID` attribute of the public GW plotting object,
        captured when this wrapper is constructed
    """
    def __init__(self, inputs, colors="default"):
        # public GW plotting class; shadows the module-level GW import
        from pesummary.gw.plots.public import _PlotGeneration

        # expert plots are produced unless explicitly disabled
        expert_plots = not inputs.disable_expert
        # keyword arguments listed in the same order as the attributes are
        # read from `inputs`
        constructor_kwargs = dict(
            webdir=inputs.webdir,
            labels=inputs.labels,
            samples=inputs.samples,
            kde_plot=inputs.kde_plot,
            existing_labels=inputs.existing_labels,
            existing_injection_data=inputs.existing_injection_data,
            existing_samples=inputs.existing_samples,
            existing_file_kwargs=inputs.existing_file_kwargs,
            existing_approximant=inputs.existing_approximant,
            existing_metafile=inputs.existing_metafile,
            same_parameters=inputs.same_parameters,
            injection_data=inputs.injection_data,
            result_files=inputs.result_files,
            file_kwargs=inputs.file_kwargs,
            colors=inputs.colors,
            custom_plotting=inputs.custom_plotting,
            add_to_existing=inputs.add_to_existing,
            priors=inputs.priors,
            no_ligo_skymap=inputs.no_ligo_skymap,
            nsamples_for_skymap=inputs.nsamples_for_skymap,
            detectors=inputs.detectors,
            maxL_samples=inputs.maxL_samples,
            gwdata=inputs.gwdata,
            calibration=inputs.calibration,
            psd=inputs.psd,
            approximant=inputs.approximant,
            multi_threading_for_skymap=inputs.multi_threading_for_skymap,
            pepredicates_probs=inputs.pepredicates_probs,
            include_prior=inputs.include_prior,
            publication=inputs.publication,
            existing_psd=inputs.existing_psd,
            existing_calibration=inputs.existing_calibration,
            weights=inputs.weights,
            linestyles=inputs.linestyles,
            disable_comparison=inputs.disable_comparison,
            disable_interactive=inputs.disable_interactive,
            disable_corner=inputs.disable_corner,
            publication_kwargs=inputs.publication_kwargs,
            multi_process=inputs.multi_process,
            mcmc_samples=inputs.mcmc_samples,
            skymap=inputs.skymap,
            existing_skymap=inputs.existing_skymap,
            corner_params=inputs.corner_params,
            preliminary_pages=inputs.preliminary_pages,
            expert_plots=expert_plots,
            checkpoint=inputs.restart_from_checkpoint,
        )
        self.plotting_object = _PlotGeneration(**constructor_kwargs)
        self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID

    def generate_plots(self):
        """Delegate plot generation to the public GW plotting object
        """
        self.plotting_object.generate_plots()

202 

203 

class ArgumentParser(_ArgumentParser):
    """Command line parser for `summaryplots`

    Extends the PESummary argument parser with the options needed to produce
    individual plots.
    """
    def _pesummary_options(self):
        """Return the PESummary options extended with the
        `summaryplots`-specific ones
        """
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--plot": {
                    "help": "name of the plot you wish to produce",
                    "choices": [
                        "1d_histogram", "sample_evolution", "autocorrelation",
                        "skymap"
                    ],
                    # argparse does not check defaults against `choices`,
                    # and `main` only has handlers for the values above, so
                    # the default must be one of them
                    "default": "1d_histogram"
                },
                "--parameters": {
                    "nargs": "+",
                    "help": (
                        "parameters you wish to plot (not needed for the "
                        "'skymap' plot)"
                    ),
                },
                "--plot_kwargs": {
                    "help": "Optional plotting kwargs",
                    "default": {},
                    "nargs": "+",
                    "action": DictionaryAction
                },
                "--inj": {
                    "help": "Injected value for each result file",
                    # one numeric value per result file; without `nargs`
                    # and `type` a single string was stored and its
                    # characters were treated as separate injected values
                    "nargs": "+",
                    "type": float,
                },
            }
        )
        return options

233 

234 

def check_inputs(opts):
    """Check that the inputs are compatible with `summaryplots`

    Missing labels default to `0, 1, ...`, and missing injected values
    default to NaN. A single (scalar or string) injected value is wrapped in
    a list, and every injected value is converted to `float`. `burnin` is
    converted to `int` when given.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options; `samples`, `labels`, `inj` and `burnin`
        are read, and all but `samples` may be rewritten in place

    Returns
    -------
    opts: argparse.Namespace
        the same object, updated in place

    Raises
    ------
    pesummary.utils.exceptions.InputError
        if the number of labels or injected values does not match the
        number of result files
    """
    base = "Please provide {} for each result file"
    nfiles = len(opts.samples)
    if opts.inj is None:
        opts.inj = [float("nan")] * nfiles
    elif isinstance(opts.inj, (str, int, float)):
        # a single value (e.g. `--inj` parsed without `nargs`) must not be
        # iterated character by character
        opts.inj = [opts.inj]
    if opts.labels is None:
        opts.labels = list(range(nfiles))

    mismatch = None
    if len(opts.labels) != nfiles:
        mismatch = "a label"
    elif len(opts.inj) != nfiles:
        mismatch = "the injected value"
    if mismatch is not None:
        # imported lazily: only needed on the error path
        from pesummary.utils.exceptions import InputError
        raise InputError(base.format(mismatch))

    # the plotting functions need numbers, not the raw command line strings
    opts.inj = [float(value) for value in opts.inj]
    if opts.burnin is not None:
        opts.burnin = int(opts.burnin)
    return opts

252 

253 

def read_input_file(path_to_file):
    """Load a result file with PESummary's GW reader

    Parameters
    ----------
    path_to_file: str
        location of the result file

    Returns
    -------
    the object returned by `pesummary.gw.file.read.read` for this file
    """
    from pesummary.gw.file.read import read as _gw_read

    return _gw_read(path_to_file)

266 

267 

def oned_histogram_plot(opts):
    """Produce a 1d histogram for every requested parameter of every
    result file

    Each file in `opts.samples` is loaded with `read`. The samples of each
    parameter in `opts.parameters`, with the first `opts.burnin` entries
    dropped, are passed to `_PlotGeneration._oned_histogram_plot` together
    with `opts.webdir` (presumably the output directory), the file's label
    and injected value, the parameter's latex label and the `kde_plot` flag.
    """
    for idx, result_file in enumerate(opts.samples):
        result = read(result_file)
        for param in opts.parameters:
            posterior = result.samples_dict[param][opts.burnin:]
            _PlotGeneration._oned_histogram_plot(
                opts.webdir, opts.labels[idx], param, posterior,
                latex_labels[param], opts.inj[idx], kde=opts.kde_plot
            )

279 

280 

def sample_evolution_plot(opts):
    """Produce a sample evolution plot for every requested parameter of
    every result file

    Each file in `opts.samples` is loaded with `read`. The samples of each
    parameter in `opts.parameters`, with the first `opts.burnin` entries
    dropped, are passed to `_PlotGeneration._sample_evolution_plot` together
    with `opts.webdir`, the file's label and injected value and the
    parameter's latex label.
    """
    for idx, result_file in enumerate(opts.samples):
        result = read(result_file)
        for param in opts.parameters:
            chain = result.samples_dict[param][opts.burnin:]
            _PlotGeneration._sample_evolution_plot(
                opts.webdir, opts.labels[idx], param, chain,
                latex_labels[param], opts.inj[idx]
            )

292 

293 

def autocorrelation_plot(opts):
    """Produce an autocorrelation plot for every requested parameter of
    every result file

    Each file in `opts.samples` is loaded with `read`. The samples of each
    parameter in `opts.parameters`, with the first `opts.burnin` entries
    dropped, are passed to `_PlotGeneration._autocorrelation_plot` together
    with `opts.webdir` and the file's label.
    """
    for idx, result_file in enumerate(opts.samples):
        result = read(result_file)
        for param in opts.parameters:
            chain = result.samples_dict[param][opts.burnin:]
            _PlotGeneration._autocorrelation_plot(
                opts.webdir, opts.labels[idx], param, chain
            )

304 

305 

def skymap_plot(opts):
    """Produce a skymap for every result file

    Each file in `opts.samples` is loaded with `read`. Its "ra" and "dec"
    samples, with the first `opts.burnin` entries dropped, are passed to
    `_PlotGeneration._skymap_plot` together with `opts.webdir` and the
    file's label. `opts.parameters` is not used.
    """
    for idx, result_file in enumerate(opts.samples):
        posterior = read(result_file).samples_dict
        right_ascension = posterior["ra"][opts.burnin:]
        declination = posterior["dec"][opts.burnin:]
        _PlotGeneration._skymap_plot(
            opts.webdir, right_ascension, declination, opts.labels[idx]
        )

315 

316 

def main(args=None):
    """The main interface for `summaryplots`

    Parses the command line, validates the requested plot, checks the
    inputs with `check_inputs`, creates `webdir` and dispatches to the
    matching plotting function.

    Parameters
    ----------
    args: list, optional
        command line arguments, passed straight to `parser.parse_args`
    """
    # make the GW latex labels available to the plotting functions
    latex_labels.update(GWlatex_labels)
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        [
            "--webdir", "--samples", "--labels", "--plot", "--parameters",
            "--kde_plot", "--burnin", "--disable_comparison", "--plot_kwargs",
            "--inj"
        ]
    )
    opts = parser.parse_args(args=args)
    func_map = {
        "1d_histogram": oned_histogram_plot,
        "sample_evolution": sample_evolution_plot,
        "autocorrelation": autocorrelation_plot,
        "skymap": skymap_plot
    }
    # Validate before touching the filesystem. An unsupported plot (e.g. a
    # parser default outside `choices`) or a missing --parameters would
    # otherwise surface as a KeyError/TypeError after `webdir` was created.
    if opts.plot not in func_map:
        parser.error(
            "Unknown plot '{}'. Please choose from: {}".format(
                opts.plot, ", ".join(func_map)
            )
        )
    if opts.plot != "skymap" and not opts.parameters:
        parser.error(
            "Please provide --parameters for the '{}' plot".format(opts.plot)
        )
    opts = check_inputs(opts)
    make_dir(opts.webdir)
    func_map[opts.plot](opts)

339 

340 

# Script entry point: run `main` with its default arguments when this module
# is executed directly rather than imported.
if __name__ == "__main__":
    main()