Coverage for pesummary/cli/summaryplots.py: 51.9%
104 statements
Generated by coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5from pesummary.utils.utils import logger, make_dir
6from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
7from pesummary.gw.plots.latex_labels import GWlatex_labels
8from pesummary.core.plots.latex_labels import latex_labels
9from pesummary.gw.plots.main import _PlotGeneration
10from pesummary.core.cli.actions import DictionaryAction
11from pesummary.gw.file.read import read
13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class PlotGeneration(object):
    """Wrapper class for _GWPlotGeneration and _CorePlotGeneration

    Instantiating this class generates the plots straight away: `__init__`
    stores its arguments and then calls `generate_plots`.

    Parameters
    ----------
    inputs: argparse.Namespace
        Namespace object containing the command line options
    colors: list, optional
        colors that you wish to use to distinguish different result files
    gw: Bool, optional
        if True, one of the GW plotting classes is used (the public one when
        `inputs.public` is also truthy), otherwise the Core plotting class is
        used

    Attributes
    ----------
    ligo_skymap_PID:
        value of the `ligo_skymap_PID` attribute of the GW plotting object,
        read after the plots have been generated. Only set when `gw` is True
    """
    def __init__(self, inputs, colors="default", gw=False):
        self.inputs = inputs
        self.colors = colors
        self.gw = gw
        self.generate_plots()

    def generate_plots(self):
        """Generate all plots for all result files passed
        """
        logger.info("Starting to generate plots")
        # `plotter` rather than `object`, so the builtin is not shadowed
        if self.gw and self.inputs.public:
            plotter = _PublicGWPlotGeneration(self.inputs, colors=self.colors)
        elif self.gw:
            plotter = _GWPlotGeneration(self.inputs, colors=self.colors)
        else:
            plotter = _CorePlotGeneration(self.inputs, colors=self.colors)
        plotter.generate_plots()
        if self.gw:
            # a single read after plotting covers both GW branches
            self.ligo_skymap_PID = plotter.ligo_skymap_PID
        logger.info("Finished generating plots")
class _CorePlotGeneration(object):
    """Build and drive the plotting object of the Core module

    Parameters
    ----------
    inputs: argparse.Namespace
        Namespace object containing the command line options
    colors: list, optional
        colors that you wish to use to distinguish different result files.
        NOTE(review): this argument is never read in this class; the value
        forwarded to the plotting object is `inputs.colors`
    """
    # attributes of `inputs` that are handed to the plotting object as keyword
    # arguments of the same name (read in this order)
    _FORWARDED = (
        "webdir", "labels", "samples", "kde_plot", "existing_labels",
        "existing_injection_data", "existing_samples", "same_parameters",
        "injection_data", "colors", "custom_plotting", "add_to_existing",
        "priors", "include_prior", "weights", "disable_comparison",
        "linestyles", "disable_interactive", "disable_corner",
        "multi_process", "mcmc_samples", "corner_params",
    )

    def __init__(self, inputs, colors="default"):
        from pesummary.core.plots.main import _PlotGeneration

        key_data = inputs.grab_key_data_from_result_files()
        expert_plots = not inputs.disable_expert
        settings = {name: getattr(inputs, name) for name in self._FORWARDED}
        # the remaining keywords do not map one-to-one onto an attribute name
        settings["expert_plots"] = expert_plots
        settings["checkpoint"] = inputs.restart_from_checkpoint
        settings["key_data"] = key_data
        self.plotting_object = _PlotGeneration(**settings)

    def generate_plots(self):
        """Delegate to the `generate_plots` method of the Core plotting object
        """
        self.plotting_object.generate_plots()
class _GWPlotGeneration(object):
    """Build and drive the plotting object of the GW module

    Parameters
    ----------
    inputs: argparse.Namespace
        Namespace object containing the command line options
    colors: list, optional
        colors that you wish to use to distinguish different result files.
        NOTE(review): this argument is never read in this class; the value
        forwarded to the plotting object is `inputs.colors`

    Attributes
    ----------
    ligo_skymap_PID:
        copied from the plotting object as soon as it has been constructed
    """
    # attributes of `inputs` that are handed to the plotting object as keyword
    # arguments of the same name (read in this order)
    _FORWARDED = (
        "webdir", "labels", "samples", "kde_plot", "existing_labels",
        "existing_injection_data", "existing_samples",
        "existing_file_kwargs", "existing_approximant", "existing_metafile",
        "same_parameters", "injection_data", "result_files", "file_kwargs",
        "colors", "custom_plotting", "add_to_existing", "priors",
        "no_ligo_skymap", "nsamples_for_skymap", "detectors", "maxL_samples",
        "gwdata", "calibration", "psd", "approximant",
        "multi_threading_for_skymap", "pepredicates_probs", "include_prior",
        "publication", "existing_psd", "existing_calibration", "weights",
        "linestyles", "calibration_definition", "disable_comparison",
        "disable_interactive", "disable_corner", "publication_kwargs",
        "multi_process", "mcmc_samples", "skymap", "existing_skymap",
        "corner_params", "preliminary_pages",
    )

    def __init__(self, inputs, colors="default"):
        from pesummary.gw.plots.main import _PlotGeneration

        key_data = inputs.grab_key_data_from_result_files()
        expert_plots = not inputs.disable_expert
        settings = {name: getattr(inputs, name) for name in self._FORWARDED}
        # the remaining keywords do not map one-to-one onto an attribute name
        settings["expert_plots"] = expert_plots
        settings["checkpoint"] = inputs.restart_from_checkpoint
        settings["key_data"] = key_data
        self.plotting_object = _PlotGeneration(**settings)
        self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID

    def generate_plots(self):
        """Delegate to the `generate_plots` method of the GW plotting object
        """
        self.plotting_object.generate_plots()
class _PublicGWPlotGeneration(object):
    """Build and drive the plotting object of the public GW module

    Unlike `_GWPlotGeneration`, the plotting class is imported from
    `pesummary.gw.plots.public` and neither `key_data` nor
    `calibration_definition` is passed to it.

    Parameters
    ----------
    inputs: argparse.Namespace
        Namespace object containing the command line options
    colors: list, optional
        colors that you wish to use to distinguish different result files.
        NOTE(review): this argument is never read in this class; the value
        forwarded to the plotting object is `inputs.colors`

    Attributes
    ----------
    ligo_skymap_PID:
        copied from the plotting object as soon as it has been constructed
    """
    # attributes of `inputs` that are handed to the plotting object as keyword
    # arguments of the same name (read in this order)
    _FORWARDED = (
        "webdir", "labels", "samples", "kde_plot", "existing_labels",
        "existing_injection_data", "existing_samples",
        "existing_file_kwargs", "existing_approximant", "existing_metafile",
        "same_parameters", "injection_data", "result_files", "file_kwargs",
        "colors", "custom_plotting", "add_to_existing", "priors",
        "no_ligo_skymap", "nsamples_for_skymap", "detectors", "maxL_samples",
        "gwdata", "calibration", "psd", "approximant",
        "multi_threading_for_skymap", "pepredicates_probs", "include_prior",
        "publication", "existing_psd", "existing_calibration", "weights",
        "linestyles", "disable_comparison", "disable_interactive",
        "disable_corner", "publication_kwargs", "multi_process",
        "mcmc_samples", "skymap", "existing_skymap", "corner_params",
        "preliminary_pages",
    )

    def __init__(self, inputs, colors="default"):
        from pesummary.gw.plots.public import _PlotGeneration

        expert_plots = not inputs.disable_expert
        settings = {name: getattr(inputs, name) for name in self._FORWARDED}
        # the remaining keywords do not map one-to-one onto an attribute name
        settings["expert_plots"] = expert_plots
        settings["checkpoint"] = inputs.restart_from_checkpoint
        self.plotting_object = _PlotGeneration(**settings)
        self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID

    def generate_plots(self):
        """Delegate to the `generate_plots` method of the public GW plotting
        object
        """
        self.plotting_object.generate_plots()
class ArgumentParser(_ArgumentParser):
    """Argument parser for `summaryplots`

    Extends the options known to the base parser with the plot specific
    options `--plot`, `--parameters`, `--plot_kwargs` and `--inj`.
    """
    def _pesummary_options(self):
        """Return the base parser's options with the `summaryplots` specific
        entries added (entries with the same name are replaced)
        """
        options = super(ArgumentParser, self)._pesummary_options()
        # NOTE(review): the default "2d_contour" is not one of the listed
        # choices -- confirm which plot should be produced when `--plot` is
        # not given
        plot_option = {
            "help": "name of the publication plot you wish to produce",
            "choices": [
                "1d_histogram", "sample_evolution", "autocorrelation",
                "skymap"
            ],
            "default": "2d_contour"
        }
        parameters_option = {
            "nargs": "+",
            "help": "parameters of the 2d contour plot you wish to make",
        }
        plot_kwargs_option = {
            "help": "Optional plotting kwargs",
            "default": {},
            "nargs": "+",
            "action": DictionaryAction
        }
        injection_option = {
            "help": "Injected value",
        }
        options.update(
            {
                "--plot": plot_option,
                "--parameters": parameters_option,
                "--plot_kwargs": plot_kwargs_option,
                "--inj": injection_option,
            }
        )
        return options
def check_inputs(opts):
    """Check that the inputs are compatible with `summaryplots`

    Missing optional inputs are filled in and `opts` is modified in place:
    `inj` defaults to one NaN per result file, `labels` defaults to the
    integer index of each result file and `burnin` is cast to an integer
    when it is provided.

    Parameters
    ----------
    opts: argparse.Namespace
        Namespace object containing the command line options. The `samples`,
        `labels`, `inj` and `burnin` attributes are read

    Returns
    -------
    opts: argparse.Namespace
        the object that was passed in, after modification

    Raises
    ------
    InputError
        if the number of labels or the number of injected values does not
        match the number of result files
    """
    def _mismatch(what):
        # only imported when a problem actually has to be reported
        from pesummary.utils.exceptions import InputError
        return InputError(
            "Please provide {} for each result file".format(what)
        )

    n_files = len(opts.samples)
    if opts.inj is None:
        opts.inj = [float("nan")] * n_files
    elif isinstance(opts.inj, (str, int, float)):
        # `--inj` is declared without `nargs` in this file, so a value given
        # on the command line presumably arrives as one string. Wrap it in a
        # list so the length check below counts injected values, not
        # characters, and so it has the same type as the NaN default
        opts.inj = [float(opts.inj)]
    if opts.labels is None:
        opts.labels = list(range(n_files))
    if n_files != len(opts.labels):
        raise _mismatch("a label")
    if n_files != len(opts.inj):
        raise _mismatch("the injected value")
    if opts.burnin is not None:
        opts.burnin = int(opts.burnin)
    return opts
def read_input_file(path_to_file):
    """Open a result file with the PESummary GW reader

    Parameters
    ----------
    path_to_file: str
        path to the results file

    Returns
    -------
    the object returned by `pesummary.gw.file.read.read` for this path
    """
    from pesummary.gw.file.read import read

    return read(path_to_file)
def oned_histogram_plot(opts):
    """Make a 1d histogram plot for every requested parameter of every
    result file

    Parameters
    ----------
    opts: argparse.Namespace
        Namespace object containing the command line options. The `samples`,
        `parameters`, `webdir`, `labels`, `burnin`, `inj` and `kde_plot`
        attributes are read
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        for name in opts.parameters:
            positional = (
                opts.webdir, opts.labels[index], name,
                # samples before index `opts.burnin` are dropped
                posterior.samples_dict[name][opts.burnin:],
                latex_labels[name], opts.inj[index],
            )
            _PlotGeneration._oned_histogram_plot(
                *positional, kde=opts.kde_plot
            )
def sample_evolution_plot(opts):
    """Make a sample evolution plot for every requested parameter of every
    result file

    Parameters
    ----------
    opts: argparse.Namespace
        Namespace object containing the command line options. The `samples`,
        `parameters`, `webdir`, `labels`, `burnin` and `inj` attributes are
        read
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        for name in opts.parameters:
            positional = (
                opts.webdir, opts.labels[index], name,
                # samples before index `opts.burnin` are dropped
                posterior.samples_dict[name][opts.burnin:],
                latex_labels[name], opts.inj[index],
            )
            _PlotGeneration._sample_evolution_plot(*positional)
def autocorrelation_plot(opts):
    """Make an autocorrelation plot for every requested parameter of every
    result file

    Parameters
    ----------
    opts: argparse.Namespace
        Namespace object containing the command line options. The `samples`,
        `parameters`, `webdir`, `labels` and `burnin` attributes are read
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        for name in opts.parameters:
            positional = (
                opts.webdir, opts.labels[index], name,
                # samples before index `opts.burnin` are dropped
                posterior.samples_dict[name][opts.burnin:],
            )
            _PlotGeneration._autocorrelation_plot(*positional)
def skymap_plot(opts):
    """Make a skymap plot for every result file from its "ra" and "dec"
    samples

    Parameters
    ----------
    opts: argparse.Namespace
        Namespace object containing the command line options. The `samples`,
        `webdir`, `labels` and `burnin` attributes are read
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        positional = (
            opts.webdir,
            # samples before index `opts.burnin` are dropped
            posterior.samples_dict["ra"][opts.burnin:],
            posterior.samples_dict["dec"][opts.burnin:],
            opts.labels[index],
        )
        _PlotGeneration._skymap_plot(*positional)
def main(args=None):
    """The main interface for `summaryplots`

    Parses the command line, checks the inputs, creates the output directory
    and calls the plotting function that matches the `--plot` option.

    Parameters
    ----------
    args: list, optional
        command line arguments, passed to the parser's `parse_args` method

    Raises
    ------
    InputError
        if the requested plot is not one of the plots this executable can
        produce
    """
    latex_labels.update(GWlatex_labels)
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        [
            "--webdir", "--samples", "--labels", "--plot", "--parameters",
            "--kde_plot", "--burnin", "--disable_comparison", "--plot_kwargs",
            "--inj"
        ]
    )
    opts = parser.parse_args(args=args)
    opts = check_inputs(opts)
    func_map = {
        "1d_histogram": oned_histogram_plot,
        "sample_evolution": sample_evolution_plot,
        "autocorrelation": autocorrelation_plot,
        "skymap": skymap_plot
    }
    # The `--plot` default declared by ArgumentParser ("2d_contour") has no
    # entry in `func_map`. Report that clearly, before the output directory
    # is created, instead of failing with a bare KeyError
    if opts.plot not in func_map:
        from pesummary.utils.exceptions import InputError
        raise InputError(
            "Unable to produce the plot '{}'. Please use the `--plot` option "
            "to choose one of: {}".format(opts.plot, ", ".join(func_map))
        )
    make_dir(opts.webdir)
    func_map[opts.plot](opts)
# Run the command line interface only when this file is executed as a script
if __name__ == "__main__":
    main()