Coverage for pesummary/cli/summarypublication.py: 59.6%
178 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import pesummary
6from pesummary.gw.file.read import read as GWRead
7from pesummary.gw.plots.latex_labels import GWlatex_labels
8from pesummary.gw.plots import publication as pub
9from pesummary.core.plots import population as pop
10from pesummary.core.plots.latex_labels import latex_labels
11from pesummary.utils.utils import make_dir, logger, _check_latex_install
12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
13from pesummary.core.cli.actions import DictionaryAction
14import seaborn
15import numpy as np
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# The module docstring is also used as the command line description in main()
__doc__ = """This executable is used to generate publication quality plots given
result files"""
# Called at import time for its side effect; presumably checks for a usable
# LaTeX installation for matplotlib text rendering -- confirm in
# pesummary.utils.utils._check_latex_install
_check_latex_install()
class ArgumentParser(_ArgumentParser):
    """Command line parser for `summarypublication`

    Extends the core pesummary parser with the options that select and
    configure the publication plot to produce.
    """
    def _pesummary_options(self):
        """Return the parent parser's options extended with the
        publication-plot specific options
        """
        options = super()._pesummary_options()
        plot_choices = [
            "2d_contour", "violin", "spin_disk",
            "population_scatter", "population_scatter_error",
        ]
        publication_options = {}
        publication_options["--plot"] = {
            "help": "name of the publication plot you wish to produce",
            "default": "2d_contour",
            "choices": plot_choices,
        }
        publication_options["--parameters"] = {
            "nargs": "+",
            "help": "parameters of the 2d contour plot you wish to make",
        }
        publication_options["--publication_kwargs"] = {
            "help": "Optional kwargs for publication plots",
            "nargs": "+",
            "default": {},
            "action": DictionaryAction,
        }
        publication_options["--levels"] = {
            "default": [0.9],
            "nargs": "+",
            "help": "Contour levels you wish to plot",
            "type": float,
        }
        options.update(publication_options)
        return options
def draw_specific_samples(param, parameters, samples):
    """Return samples for a given parameter

    Parameters
    ----------
    param: str
        parameter that you wish to get samples for
    parameters: nd list
        one list of parameter names per result file
    samples: nd list
        one list of samples per result file. Each sample is a row with one
        entry per name in the matching ``parameters`` list

    Returns
    -------
    list
        one list per result file containing only the samples for ``param``

    Raises
    ------
    ValueError
        if ``param`` is not stored in one of the result files
    """
    # Every result file may order its parameters differently, so the column
    # index is looked up separately for each file and applied only to that
    # file's rows (indexing a row with the whole list of indices is invalid)
    indices = [names.index(param) for names in parameters]
    return [
        [row[index] for row in rows]
        for index, rows in zip(indices, samples)
    ]
def default_2d_contour_plot():
    """Return the parameter pairs that are plotted as 2d contours by default

    Returns
    -------
    list
        list of ``[x_parameter, y_parameter]`` pairs
    """
    return [
        ["mass_ratio", "chi_eff"],
        ["mass_1", "mass_2"],
        ["luminosity_distance", "chirp_mass_source"],
        ["mass_1_source", "mass_2_source"],
        ["theta_jn", "luminosity_distance"],
        ["network_optimal_snr", "chirp_mass_source"],
    ]
def default_violin_plot():
    """Return the parameters that are plotted as violins by default

    Returns
    -------
    list
        names of the default violin plot parameters
    """
    return ["chi_eff", "chi_p", "mass_ratio", "luminosity_distance"]
def read_samples(result_files):
    """Read and return a list of parameters and samples stored in the result
    files

    Files that cannot be read are logged and represented by ``[None]`` in
    both returned lists, so both outputs stay aligned index-for-index with
    ``result_files``.

    Parameters
    ----------
    result_files: list
        list of result files

    Returns
    -------
    parameters: list
        one list of parameter names per result file
    samples: list
        the samples stored in each result file
    """
    parameters = []
    samples = []
    for i in result_files:
        try:
            f = GWRead(i)
            if isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
                # only the first entry stored in a PESummary file is used
                file_parameters = f.parameters[0]
                file_samples = f.samples[0]
            else:
                f.generate_all_posterior_samples()
                file_parameters = f.parameters
                file_samples = f.samples
        except Exception:
            logger.warning(
                "Failed to read '{}'. Data will not be added to the "
                "plots".format(i)
            )
            file_parameters, file_samples = [None], [None]
        # Append only once both values are known: appending the parameters
        # before the samples were read meant a failure part way through left
        # the two lists with different lengths
        parameters.append(file_parameters)
        samples.append(file_samples)
    return parameters, samples
def get_colors_and_linestyles(opts):
    """Return the colors and linestyles to use for each result file

    Colors are taken from ``opts.colors`` when given; otherwise one color per
    entry in ``opts.samples`` is drawn from the seaborn palette
    ``opts.palette``. Linestyles are taken from ``opts.linestyles`` when
    given. Otherwise, files sharing a color cycle through ``-``, ``--``,
    ``:`` and ``-.`` (in order of appearance) so they stay distinguishable.

    Parameters
    ----------
    opts: argparse.Namespace
        namespace providing ``colors``, ``palette``, ``samples`` and
        ``linestyles``

    Returns
    -------
    colors: list
    linestyles: list
    """
    if opts.colors is None:
        colors = seaborn.color_palette(
            palette=opts.palette, n_colors=len(opts.samples)
        ).as_hex()
    else:
        colors = opts.colors
    if opts.linestyles is not None:
        return colors, opts.linestyles
    style_cycle = ["-", "--", ":", "-."]
    times_seen = {}
    linestyles = []
    for color in colors:
        occurrence = times_seen.get(color, 0)
        linestyles.append(style_cycle[occurrence % len(style_cycle)])
        times_seen[color] = occurrence + 1
    return colors, linestyles
def _subset_by_index(values, keep):
    """Return the entries of ``values`` whose position is in ``keep``;
    ``None`` is passed through unchanged"""
    if values is None:
        return None
    return [value for num, value in enumerate(values) if num in keep]


def _set_axis_limits(ax, kwargs):
    """Override the limits of ``ax`` with any ``xlow``, ``xhigh``, ``ylow``
    and ``yhigh`` entries in ``kwargs``

    Bounds that are not provided keep their current value. Values are
    converted with ``float`` because they may arrive as strings from the
    command line.
    """
    # read both sets of limits before changing either, as the original
    # implementation did
    current = {"x": ax.get_xlim(), "y": ax.get_ylim()}
    setters = {"x": ax.set_xlim, "y": ax.set_ylim}
    for axis in ("x", "y"):
        low_key, high_key = "{}low".format(axis), "{}high".format(axis)
        if low_key not in kwargs and high_key not in kwargs:
            continue
        current_low, current_high = current[axis]
        setters[axis]([
            float(kwargs[low_key]) if low_key in kwargs else current_low,
            float(kwargs[high_key]) if high_key in kwargs else current_high,
        ])


def make_2d_contour_plot(opts):
    """Make a 2d contour plot for each requested pair of parameters

    If ``opts.parameters`` is given it must contain exactly two names and a
    single plot is made; otherwise the pairs from
    ``default_2d_contour_plot`` are used. For each pair, result files that
    do not store both parameters are left out of that plot only. Plots are
    saved in ``opts.webdir``.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    Exception
        if ``opts.parameters`` is given but does not contain two names
    """
    if opts.parameters and len(opts.parameters) != 2:
        raise Exception("Please pass 2 variables that you wish to plot")
    if opts.parameters:
        default = [opts.parameters]
    else:
        default = default_2d_contour_plot()
    colors, linestyles = get_colors_and_linestyles(opts)
    parameters, samples = read_samples(opts.samples)
    # values given through --publication_kwargs may be strings
    gridsize = int(float(opts.publication_kwargs.get("gridsize", 100)))
    for pair in default:
        # Decide which files to use for *this* pair only. Filtering the
        # shared lists in place dropped files from every later plot,
        # misreported file names and left colors/linestyles misaligned
        keep = {
            num for num, names in enumerate(parameters)
            if all(param in names for param in pair)
        }
        if len(keep) != len(parameters):
            files = [
                opts.samples[num] for num in range(len(parameters))
                if num not in keep
            ]
            logger.warning(
                "Removing {} from 2d contour plot because the parameters {} are "
                "not in the result file".format(
                    " and ".join(files), " and ".join(pair)
                )
            )
        if not keep:
            logger.warning(
                "Unable to make a 2d contour plot for {} because no result "
                "file contains these parameters".format(" and ".join(pair))
            )
            continue
        twod_samples = []
        for num in sorted(keep):
            ind1 = parameters[num].index(pair[0])
            ind2 = parameters[num].index(pair[1])
            twod_samples.append([
                [row[ind1] for row in samples[num]],
                [row[ind2] for row in samples[num]],
            ])
        fig, ax = pub.twod_contour_plots(
            pair, twod_samples, _subset_by_index(opts.labels, keep),
            latex_labels, colors=_subset_by_index(colors, keep),
            linestyles=_subset_by_index(linestyles, keep), gridsize=gridsize,
            levels=opts.levels, return_ax=True
        )
        _set_axis_limits(ax, opts.publication_kwargs)
        fig.savefig("%s/2d_contour_plot_%s" % (opts.webdir, "_and_".join(pair)))
        fig.close()
def make_violin_plot(opts):
    """Make a violin plot for each requested parameter

    If ``opts.parameters`` is given it must contain a single name;
    otherwise the parameters from ``default_violin_plot`` are used. For
    each parameter, result files that do not store it are left out of that
    plot only. A failure while plotting one parameter is logged and the
    next parameter is attempted. Plots are saved in ``opts.webdir``.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    Exception
        if ``opts.parameters`` is given but does not contain one name
    """
    if opts.parameters and len(opts.parameters) != 1:
        raise Exception("Please pass a single variable that you wish to plot")
    if opts.parameters:
        default = opts.parameters
    else:
        default = default_violin_plot()
    parameters, samples = read_samples(opts.samples)
    for i in default:
        # Work on per-parameter selections. Reassigning ``samples`` to the
        # extracted column (as before) broke every later parameter
        keep = [num for num, names in enumerate(parameters) if i in names]
        if len(keep) != len(parameters):
            files = [
                opts.samples[num] for num in range(len(parameters))
                if num not in keep
            ]
            logger.warning(
                "Removing {} from violin plot because the parameter {} does "
                "not exist in the result file".format(
                    " and ".join(files), i
                )
            )
        labels = opts.labels
        if labels is not None:
            labels = [label for num, label in enumerate(labels) if num in keep]
        try:
            param_samples = []
            for num in keep:
                ind = parameters[num].index(i)
                param_samples.append([row[ind] for row in samples[num]])
            fig = pub.violin_plots(i, param_samples, labels, latex_labels)
            fig.savefig("%s/violin_plot_%s.png" % (opts.webdir, i))
            fig.close()
        except Exception as e:
            logger.info(
                "Failed to generate a violin plot for %s because %s" % (i, e)
            )
def make_spin_disk_plot(opts):
    """Make one spin disk plot per result file

    Only result files storing ``a_1``, ``a_2``, ``cos_tilt_1`` and
    ``cos_tilt_2`` are plotted; others are skipped with a log message, as
    are files whose plot fails. Plots are saved in ``opts.webdir``.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments
    """
    colors, _ = get_colors_and_linestyles(opts)
    parameters, samples = read_samples(opts.samples)
    required_parameters = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
    for num, names in enumerate(parameters):
        has_all = all(name in names for name in required_parameters)
        if not has_all:
            logger.info("Failed to generate spin disk plots for %s because "
                        "%s are not in the result file" % (
                            opts.labels[num],
                            " and ".join(required_parameters)))
            continue
        try:
            columns = [names.index(name) for name in required_parameters]
            file_samples = samples[num]
            spin_samples = [
                [row[column] for row in file_samples] for column in columns
            ]
            label = opts.labels[num]
            fig = pub.spin_distribution_plots(
                required_parameters, spin_samples, label, colors[num]
            )
            fig.savefig("%s/spin_disk_plot_%s.png" % (opts.webdir, label))
            fig.close()
        except Exception as e:
            logger.warning(
                "Failed to generate a spin disk plot for %s because %s" % (
                    opts.labels[num], e
                )
            )
def make_population_scatter_plot(opts):
    """Make a scatter plot showing a population of runs

    Each result file contributes one point placed at the median of the two
    parameters in ``opts.parameters``. When ``opts.plot`` contains
    ``"error"``, error bars from the median to the 5th and 95th percentiles
    are added. Results that do not store both parameters are left out. The
    plot is saved in ``opts.webdir``.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments

    Raises
    ------
    Exception
        if ``opts.parameters`` does not contain exactly two names
    """
    if not opts.parameters or len(opts.parameters) != 2:
        raise Exception("Please pass 2 variables that you wish to plot")
    if len(opts.samples) < 2:
        # previously this returned silently, leaving no plot and no message
        logger.warning(
            "A population scatter plot requires more than one result file. "
            "No plot will be generated"
        )
        return
    x_param, y_param = opts.parameters
    parameters, samples = read_samples(opts.samples)
    plotting_data = {}
    xerr, yerr = None, None
    if "error" in opts.plot:
        xerr, yerr = {}, {}
    for num, label in enumerate(opts.labels):
        if not all(i in parameters[num] for i in opts.parameters):
            logger.warning(
                "'{}' does not include samples for '{}' and/or '{}'. This "
                "analysis will not be added to the plot".format(
                    label, x_param, y_param
                )
            )
            continue
        # extract each parameter's samples once and reuse them below
        columns = {}
        for param in opts.parameters:
            ind = parameters[num].index(param)
            columns[param] = [row[ind] for row in samples[num]]
        plotting_data[label] = {
            param: np.median(values) for param, values in columns.items()
        }
        if xerr is not None:
            xerr[label] = {
                x_param: [
                    np.abs(
                        plotting_data[label][x_param]
                        - np.percentile(columns[x_param], j)
                    ) for j in [5, 95]
                ]
            }
        if yerr is not None:
            yerr[label] = {
                y_param: [
                    np.abs(
                        plotting_data[label][y_param]
                        - np.percentile(columns[y_param], j)
                    ) for j in [5, 95]
                ]
            }
    fig = pop.scatter_plot(
        opts.parameters, plotting_data, latex_labels, xerr=xerr, yerr=yerr
    )
    fig.savefig("{}/event_scatter_plot_{}.png".format(
        opts.webdir, "_and_".join(opts.parameters)
    ))
    fig.close()
def main(args=None):
    """Top level interface for `summarypublication`

    Parse the command line, create the output directory and run the plotting
    function selected with ``--plot``.

    Parameters
    ----------
    args: list, optional
        command line arguments; passed straight to ``parser.parse_args``
    """
    latex_labels.update(GWlatex_labels)
    known_options = [
        "--webdir", "--samples", "--labels", "--plot", "--parameters",
        "--publication_kwargs", "--colors", "--palette", "--linestyles",
        "--levels",
    ]
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(known_options)
    opts = parser.parse_args(args=args)
    make_dir(opts.webdir)
    plotters = {
        "2d_contour": make_2d_contour_plot,
        "violin": make_violin_plot,
        "spin_disk": make_spin_disk_plot,
        "population_scatter": make_population_scatter_plot,
        "population_scatter_error": make_population_scatter_plot,
    }
    plot_function = plotters[opts.plot]
    plot_function(opts)


if __name__ == "__main__":
    main()