Coverage for pesummary/cli/summaryplots.py: 51.9%
104 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5from pesummary.utils.utils import logger, make_dir
6from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
7from pesummary.gw.plots.latex_labels import GWlatex_labels
8from pesummary.core.plots.latex_labels import latex_labels
9from pesummary.gw.plots.main import _PlotGeneration
10from pesummary.core.cli.actions import DictionaryAction
11from pesummary.gw.file.read import read
13__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class PlotGeneration(object):
    """Wrapper class for _CorePlotGeneration, _GWPlotGeneration and
    _PublicGWPlotGeneration

    The plots are generated as soon as the class is initialised.

    Parameters
    ----------
    inputs: object
        processed command line options. NOTE(review): documented upstream as
        an argparse.Namespace, but the wrapped classes also call methods on it
        (e.g. ``grab_key_data_from_result_files``) -- confirm the exact type
    colors: list, optional
        colors that you wish to use to distinguish different result files
    gw: bool, optional
        if True, generate the GW specific plots. ``inputs.public`` then
        selects between the public and the full GW plotting classes. If
        False, only the Core plots are generated

    Attributes
    ----------
    ligo_skymap_PID:
        only set when ``gw`` is True; taken from the GW plotting object after
        its plots have been generated
    """
    def __init__(self, inputs, colors="default", gw=False):
        self.inputs = inputs
        self.colors = colors
        self.gw = gw
        self.generate_plots()

    def generate_plots(self):
        """Generate all plots for all result files passed
        """
        logger.info("Starting to generate plots")
        # ``inputs.public`` is only consulted for GW runs
        if not self.gw:
            plotter = _CorePlotGeneration(self.inputs, colors=self.colors)
        elif self.inputs.public:
            plotter = _PublicGWPlotGeneration(self.inputs, colors=self.colors)
        else:
            plotter = _GWPlotGeneration(self.inputs, colors=self.colors)
        plotter.generate_plots()
        if self.gw:
            # Only the GW plotting classes define ``ligo_skymap_PID``
            self.ligo_skymap_PID = plotter.ligo_skymap_PID
        logger.info("Finished generating plots")
49class _CorePlotGeneration(object):
50 """Class to generate all plots associated with the Core module
52 Parameters
53 ----------
54 inputs: argparse.Namespace
55 Namespace object containing the command line options
56 colors: list, optional
57 colors that you wish to use to distinguish different result files
58 """
59 def __init__(self, inputs, colors="default"):
60 from pesummary.core.plots.main import _PlotGeneration
61 key_data = inputs.grab_key_data_from_result_files()
62 expert_plots = not inputs.disable_expert
63 self.plotting_object = _PlotGeneration(
64 webdir=inputs.webdir, labels=inputs.labels,
65 samples=inputs.samples, kde_plot=inputs.kde_plot,
66 existing_labels=inputs.existing_labels,
67 existing_injection_data=inputs.existing_injection_data,
68 existing_samples=inputs.existing_samples,
69 same_parameters=inputs.same_parameters,
70 injection_data=inputs.injection_data,
71 colors=inputs.colors, custom_plotting=inputs.custom_plotting,
72 add_to_existing=inputs.add_to_existing, priors=inputs.priors,
73 include_prior=inputs.include_prior, weights=inputs.weights,
74 disable_comparison=inputs.disable_comparison,
75 linestyles=inputs.linestyles,
76 disable_interactive=inputs.disable_interactive,
77 disable_corner=inputs.disable_corner,
78 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples,
79 corner_params=inputs.corner_params, expert_plots=expert_plots,
80 checkpoint=inputs.restart_from_checkpoint, key_data=key_data
81 )
83 def generate_plots(self):
84 """Generate all plots within the Core module
85 """
86 self.plotting_object.generate_plots()
89class _GWPlotGeneration(object):
90 """Class to generate all plots associated with the GW module
92 Parameters
93 ----------
94 inputs: argparse.Namespace
95 Namespace object containing the command line options
96 colors: list, optional
97 colors that you wish to use to distinguish different result files
98 """
99 def __init__(self, inputs, colors="default"):
100 from pesummary.gw.plots.main import _PlotGeneration
101 key_data = inputs.grab_key_data_from_result_files()
102 expert_plots = not inputs.disable_expert
103 self.plotting_object = _PlotGeneration(
104 webdir=inputs.webdir, labels=inputs.labels,
105 samples=inputs.samples, kde_plot=inputs.kde_plot,
106 existing_labels=inputs.existing_labels,
107 existing_injection_data=inputs.existing_injection_data,
108 existing_samples=inputs.existing_samples,
109 existing_file_kwargs=inputs.existing_file_kwargs,
110 existing_approximant=inputs.existing_approximant,
111 existing_metafile=inputs.existing_metafile,
112 same_parameters=inputs.same_parameters,
113 injection_data=inputs.injection_data,
114 result_files=inputs.result_files,
115 file_kwargs=inputs.file_kwargs,
116 colors=inputs.colors, custom_plotting=inputs.custom_plotting,
117 add_to_existing=inputs.add_to_existing, priors=inputs.priors,
118 no_ligo_skymap=inputs.no_ligo_skymap,
119 nsamples_for_skymap=inputs.nsamples_for_skymap,
120 detectors=inputs.detectors, maxL_samples=inputs.maxL_samples,
121 gwdata=inputs.gwdata, calibration=inputs.calibration,
122 psd=inputs.psd, approximant=inputs.approximant,
123 multi_threading_for_skymap=inputs.multi_threading_for_skymap,
124 pepredicates_probs=inputs.pepredicates_probs,
125 include_prior=inputs.include_prior, publication=inputs.publication,
126 existing_psd=inputs.existing_psd,
127 existing_calibration=inputs.existing_calibration, weights=inputs.weights,
128 linestyles=inputs.linestyles,
129 disable_comparison=inputs.disable_comparison,
130 disable_interactive=inputs.disable_interactive,
131 disable_corner=inputs.disable_corner,
132 publication_kwargs=inputs.publication_kwargs,
133 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples,
134 skymap=inputs.skymap, existing_skymap=inputs.existing_skymap,
135 corner_params=inputs.corner_params,
136 preliminary_pages=inputs.preliminary_pages, expert_plots=expert_plots,
137 checkpoint=inputs.restart_from_checkpoint, key_data=key_data
138 )
139 self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID
141 def generate_plots(self):
142 """Generate all plots within the GW module
143 """
144 self.plotting_object.generate_plots()
147class _PublicGWPlotGeneration(object):
148 """Class to generate all plots associated with the GW module
150 Parameters
151 ----------
152 inputs: argparse.Namespace
153 Namespace object containing the command line options
154 colors: list, optional
155 colors that you wish to use to distinguish different result files
156 """
157 def __init__(self, inputs, colors="default"):
158 from pesummary.gw.plots.public import _PlotGeneration
159 expert_plots = not inputs.disable_expert
160 self.plotting_object = _PlotGeneration(
161 webdir=inputs.webdir, labels=inputs.labels,
162 samples=inputs.samples, kde_plot=inputs.kde_plot,
163 existing_labels=inputs.existing_labels,
164 existing_injection_data=inputs.existing_injection_data,
165 existing_samples=inputs.existing_samples,
166 existing_file_kwargs=inputs.existing_file_kwargs,
167 existing_approximant=inputs.existing_approximant,
168 existing_metafile=inputs.existing_metafile,
169 same_parameters=inputs.same_parameters,
170 injection_data=inputs.injection_data,
171 result_files=inputs.result_files,
172 file_kwargs=inputs.file_kwargs,
173 colors=inputs.colors, custom_plotting=inputs.custom_plotting,
174 add_to_existing=inputs.add_to_existing, priors=inputs.priors,
175 no_ligo_skymap=inputs.no_ligo_skymap,
176 nsamples_for_skymap=inputs.nsamples_for_skymap,
177 detectors=inputs.detectors, maxL_samples=inputs.maxL_samples,
178 gwdata=inputs.gwdata, calibration=inputs.calibration,
179 psd=inputs.psd, approximant=inputs.approximant,
180 multi_threading_for_skymap=inputs.multi_threading_for_skymap,
181 pepredicates_probs=inputs.pepredicates_probs,
182 include_prior=inputs.include_prior, publication=inputs.publication,
183 existing_psd=inputs.existing_psd,
184 existing_calibration=inputs.existing_calibration, weights=inputs.weights,
185 linestyles=inputs.linestyles,
186 disable_comparison=inputs.disable_comparison,
187 disable_interactive=inputs.disable_interactive,
188 disable_corner=inputs.disable_corner,
189 publication_kwargs=inputs.publication_kwargs,
190 multi_process=inputs.multi_process, mcmc_samples=inputs.mcmc_samples,
191 skymap=inputs.skymap, existing_skymap=inputs.existing_skymap,
192 corner_params=inputs.corner_params,
193 preliminary_pages=inputs.preliminary_pages, expert_plots=expert_plots,
194 checkpoint=inputs.restart_from_checkpoint
195 )
196 self.ligo_skymap_PID = self.plotting_object.ligo_skymap_PID
198 def generate_plots(self):
199 """Generate all plots within the GW module
200 """
201 self.plotting_object.generate_plots()
class ArgumentParser(_ArgumentParser):
    """Command line parser for `summaryplots`

    Extends the PESummary base parser with the options needed to make one
    type of plot for one or more result files.
    """
    def _pesummary_options(self):
        """Return the known options, extended with the `summaryplots` ones

        Returns
        -------
        options: dict
            mapping of option name to the settings describing it (presumably
            forwarded to ``add_argument`` by the base parser -- confirm)
        """
        options = super(ArgumentParser, self)._pesummary_options()
        options.update(
            {
                "--plot": {
                    "help": "name of the plot you wish to produce",
                    "choices": [
                        "1d_histogram", "sample_evolution", "autocorrelation",
                        "skymap"
                    ],
                    # argparse does not check defaults against ``choices``;
                    # ``main`` dispatches on this value, so it must be one of
                    # the choices above
                    "default": "1d_histogram"
                },
                "--parameters": {
                    "nargs": "+",
                    "help": "parameters you wish to plot",
                },
                "--plot_kwargs": {
                    "help": "Optional plotting kwargs",
                    "default": {},
                    "nargs": "+",
                    "action": DictionaryAction
                },
                "--inj": {
                    # One injected value per result file (matched by position
                    # against --samples in ``check_inputs``)
                    "help": "Injected value for each result file",
                    "nargs": "+",
                    "type": float,
                },
            }
        )
        return options
def check_inputs(opts):
    """Check that the inputs are compatible with `summaryplots`

    Missing injected values default to NaN and missing labels default to the
    integer index of each result file. Injected values are coerced to floats
    (a single value is promoted to a one-element list), and the burnin is
    coerced to an int.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options. Must provide ``samples``, ``labels``,
        ``inj`` and ``burnin``

    Returns
    -------
    opts: argparse.Namespace
        the same object, updated in place

    Raises
    ------
    InputError
        if the number of labels or injected values does not match the number
        of result files, or if an injected value cannot be read as a float
    """
    def _input_error(message):
        # Imported lazily: the exception class is only needed on failure
        from pesummary.utils.exceptions import InputError
        return InputError(message)

    base = "Please provide {} for each result file"
    nfiles = len(opts.samples)
    if opts.inj is None:
        opts.inj = [float("nan")] * nfiles
    else:
        # A bare value (e.g. from an option parsed without ``nargs``) would
        # otherwise be measured and indexed character by character
        if isinstance(opts.inj, (str, int, float)):
            opts.inj = [opts.inj]
        try:
            opts.inj = [float(value) for value in opts.inj]
        except (TypeError, ValueError) as err:
            raise _input_error(
                "Unable to interpret the injected values {} as floats".format(
                    opts.inj
                )
            ) from err
    if opts.labels is None:
        opts.labels = list(range(nfiles))
    if len(opts.labels) != nfiles:
        raise _input_error(base.format("a label"))
    if len(opts.inj) != nfiles:
        raise _input_error(base.format("the injected value"))
    if opts.burnin is not None:
        opts.burnin = int(opts.burnin)
    return opts
def read_input_file(path_to_file):
    """Load a result file with PESummary's GW reader

    Parameters
    ----------
    path_to_file: str
        path to the results file

    Returns
    -------
    whatever ``pesummary.gw.file.read.read`` returns for this file
    """
    from pesummary.gw.file.read import read as _read_result_file
    return _read_result_file(path_to_file)
def oned_histogram_plot(opts):
    """Make a 1d histogram plot for each requested parameter of each result
    file

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options, as returned by `check_inputs`. Reads
        ``samples``, ``parameters``, ``labels``, ``inj``, ``burnin``,
        ``webdir`` and ``kde_plot``. Samples before ``burnin`` are dropped
        (all are kept when it is None)
    """
    for num, samples in enumerate(opts.samples):
        data = read(samples)
        for parameter in opts.parameters:
            _PlotGeneration._oned_histogram_plot(
                opts.webdir, opts.labels[num], parameter,
                data.samples_dict[parameter][opts.burnin:],
                # Fall back to the raw parameter name rather than failing
                # with a KeyError for parameters without a latex label
                latex_labels.get(parameter, parameter), opts.inj[num],
                kde=opts.kde_plot
            )
def sample_evolution_plot(opts):
    """Make a sample evolution plot for each requested parameter of each
    result file

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options, as returned by `check_inputs`. Reads
        ``samples``, ``parameters``, ``labels``, ``inj``, ``burnin`` and
        ``webdir``. Samples before ``burnin`` are dropped (all are kept when
        it is None)
    """
    for num, samples in enumerate(opts.samples):
        data = read(samples)
        for parameter in opts.parameters:
            _PlotGeneration._sample_evolution_plot(
                opts.webdir, opts.labels[num], parameter,
                data.samples_dict[parameter][opts.burnin:],
                # Fall back to the raw parameter name rather than failing
                # with a KeyError for parameters without a latex label
                latex_labels.get(parameter, parameter), opts.inj[num]
            )
def autocorrelation_plot(opts):
    """Make an autocorrelation plot for each requested parameter of each
    result file

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options, as returned by `check_inputs`. Reads
        ``samples``, ``parameters``, ``labels``, ``burnin`` and ``webdir``.
        Samples before ``burnin`` are dropped (all are kept when it is None)
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        for param in opts.parameters:
            trace = posterior.samples_dict[param][opts.burnin:]
            _PlotGeneration._autocorrelation_plot(
                opts.webdir, opts.labels[index], param, trace
            )
def skymap_plot(opts):
    """Make a skymap plot for each result file from its ``ra`` and ``dec``
    samples

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line options, as returned by `check_inputs`. Reads
        ``samples``, ``labels``, ``burnin`` and ``webdir``. Samples before
        ``burnin`` are dropped (all are kept when it is None)
    """
    for index, result_file in enumerate(opts.samples):
        posterior = read(result_file)
        right_ascension = posterior.samples_dict["ra"][opts.burnin:]
        declination = posterior.samples_dict["dec"][opts.burnin:]
        _PlotGeneration._skymap_plot(
            opts.webdir, right_ascension, declination, opts.labels[index]
        )
def main(args=None):
    """The main interface for `summaryplots`

    Parameters
    ----------
    args: list, optional
        command line arguments passed to ``parser.parse_args``. When None,
        presumably falls back to ``sys.argv`` as in argparse
    """
    latex_labels.update(GWlatex_labels)
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        [
            "--webdir", "--samples", "--labels", "--plot", "--parameters",
            "--kde_plot", "--burnin", "--disable_comparison", "--plot_kwargs",
            "--inj"
        ]
    )
    opts = parser.parse_args(args=args)
    func_map = {
        "1d_histogram": oned_histogram_plot,
        "sample_evolution": sample_evolution_plot,
        "autocorrelation": autocorrelation_plot,
        "skymap": skymap_plot
    }
    # argparse only validates explicitly supplied values against ``choices``,
    # so a default without a handler must be caught here rather than surfacing
    # as a KeyError after the output directory has been created
    if opts.plot not in func_map:
        parser.error(
            "unknown plot '{}'. Please choose from: {}".format(
                opts.plot, ", ".join(sorted(func_map))
            )
        )
    # Every plot except the skymap iterates over ``opts.parameters``
    if opts.plot != "skymap" and not opts.parameters:
        parser.error(
            "please provide --parameters for the '{}' plot".format(opts.plot)
        )
    opts = check_inputs(opts)
    make_dir(opts.webdir)
    func_map[opts.plot](opts)


if __name__ == "__main__":
    main()