Coverage for pesummary/core/plots/main.py: 74.7%
446 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import numpy as np
6import os
7import importlib
8from multiprocessing import Pool
10from pesummary.core.plots.latex_labels import latex_labels
11from pesummary.utils.utils import (
12 logger, get_matplotlib_backend, make_dir, get_matplotlib_style_file
13)
14from pesummary.core.plots import plot as core
15from pesummary.core.plots import interactive
17import matplotlib
18import matplotlib.style
20__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
21matplotlib.use(get_matplotlib_backend(parallel=True))
22matplotlib.style.use(get_matplotlib_style_file())
25class _PlotGeneration(object):
26 """Super class to handle the plot generation for a given set of result
27 files
29 Parameters
30 ----------
31 savedir: str
32 the directory to store the plots
33 webdir: str
34 the web directory of the run
35 labels: list
36 list of labels used to distinguish the result files
37 samples: dict
38 dictionary of posterior samples stored in the result files
39 kde_plot: Bool
40 if True, kde plots are generated instead of histograms, Default False
41 existing_labels: list
42 list of labels stored in an existing metafile
43 existing_injection_data: dict
44 dictionary of injection data stored in an existing metafile
45 existing_samples: dict
46 dictionary of posterior samples stored in an existing metafile
47 same_parameters: list
48 list of parameters that are common in all result files
49 injection_data: dict
50 dictionary of injection data for each result file
51 result_files: list
52 list of result files passed
53 colors: list
54 colors that you wish to use to distinguish different result files
55 disable_comparison: bool, optional
56 whether to disable the comparison plots, default is False.
57 if disable_comparison is False but only one label is available in total, no comparison plots
58 will be generated
59 disable_interactive: bool, optional
60 whether to disable the interactive plots, default is False
61 disable_corner: bool, optional
62 whether to disable the corner plot, default is False
63 """
def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_samples=None, existing_weights=None, same_parameters=None,
    injection_data=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, include_prior=False, weights=None,
    disable_comparison=False, linestyles=None, disable_interactive=False,
    multi_process=1, mcmc_samples=False, disable_corner=False,
    corner_params=None, expert_plots=True, checkpoint=False, key_data=None
):
    """Store the plotting options, create the output directories and the
    process pool, and register the plot types that will be generated

    Parameters
    ----------
    priors: dict, optional
        dictionary of prior data. If None (default) a new empty dictionary
        is used for this instance, so that no dictionary is shared between
        instances
    weights: dict, optional
        dictionary of weights keyed by label. If None, a dictionary mapping
        every label to None is used
    multi_process: int, optional
        number of processes used for the plotting pool, default 1

    All other arguments are stored unchanged on the instance; see the class
    documentation for their meaning.
    """
    self.package = "core"
    # webdir must be set before savedir: the savedir setter falls back to
    # a path built from webdir when savedir is None
    self.webdir = webdir
    make_dir(self.webdir)
    make_dir(os.path.join(self.webdir, "plots", "corner"))
    self.savedir = savedir
    self.labels = labels
    self.mcmc_samples = mcmc_samples
    self.samples = samples
    self.kde_plot = kde_plot
    self.existing_labels = existing_labels
    self.existing_injection_data = existing_injection_data
    self.existing_samples = existing_samples
    self.existing_weights = existing_weights
    self.same_parameters = same_parameters
    self.injection_data = injection_data
    self.colors = colors
    self.custom_plotting = custom_plotting
    self.add_to_existing = add_to_existing
    # a fresh dictionary per instance avoids the shared mutable default
    self.priors = {} if priors is None else priors
    self.include_prior = include_prior
    self.linestyles = linestyles
    self.make_interactive = not disable_interactive
    self.make_corner = not disable_corner
    self.corner_params = corner_params
    self.expert_plots = expert_plots
    if self.mcmc_samples and self.expert_plots:
        logger.warning("Unable to generate expert plots for mcmc samples")
        self.expert_plots = False
    self.checkpoint = checkpoint
    self.key_data = key_data
    self.multi_process = multi_process
    self.pool = self.setup_pool()
    self.preliminary_pages = {label: False for label in self.labels}
    self.preliminary_comparison_pages = False
    # comparison plots need at least two analyses (new plus existing)
    self.make_comparison = (
        not disable_comparison and self._total_number_of_labels > 1
    )
    self.weights = (
        weights if weights is not None else {i: None for i in self.labels}
    )

    if self.same_parameters is not None and not self.mcmc_samples:
        # re-key the samples as {parameter: {label: samples}}
        self.same_samples = {
            param: {
                key: item[param] for key, item in self.samples.items()
            } for param in self.same_parameters
        }
    else:
        self.same_samples = None

    for i in self.samples.keys():
        try:
            self.check_latex_labels(
                self.samples[i].keys(remove_debug=False)
            )
        except TypeError:
            # keys() of this container does not accept ``remove_debug``
            # (e.g. a plain dict), so fall back to the bare call
            try:
                self.check_latex_labels(self.samples[i].keys())
            except TypeError:
                # best effort: leave the latex labels untouched
                pass

    self.plot_type_dictionary = {
        "oned_histogram": self.oned_histogram_plot,
        "sample_evolution": self.sample_evolution_plot,
        "autocorrelation": self.autocorrelation_plot,
        "oned_cdf": self.oned_cdf_plot,
        "custom": self.custom_plot
    }
    if self.make_corner:
        self.plot_type_dictionary.update({"corner": self.corner_plot})
    if self.expert_plots:
        self.plot_type_dictionary.update({"expert": self.expert_plot})
    if self.make_comparison:
        self.plot_type_dictionary.update(dict(
            oned_histogram_comparison=self.oned_histogram_comparison_plot,
            oned_cdf_comparison=self.oned_cdf_comparison_plot,
            box_plot_comparison=self.box_plot_comparison_plot,
        ))
    if self.make_interactive:
        self.plot_type_dictionary.update(
            dict(
                interactive_corner=self.interactive_corner_plot
            )
        )
        if self.make_comparison:
            self.plot_type_dictionary.update(
                dict(
                    interactive_ridgeline=self.interactive_ridgeline_plot
                )
            )
@staticmethod
def save(fig, name, preliminary=False, close=True, format="png"):
    """Write a figure to disk, optionally watermarking and closing it

    Parameters
    ----------
    fig: figure object
        figure that you wish to save. It must provide ``text``,
        ``tight_layout``, ``savefig`` and ``close`` methods
    name: str
        name of the file to write to. The extension ``.<format>`` is
        appended when name does not already end with it
    preliminary: Bool, optional
        if True, add a 'Preliminary' watermark to the figure
    close: Bool, optional
        close the figure after it has been saved
    format: str, optional
        format used to save the image
    """
    extension = ".%s" % (format)
    if not name.endswith(extension):
        name = name + extension
    if preliminary:
        fig.text(
            0.5, 0.5, 'Preliminary', fontsize=90, color='gray', alpha=0.1,
            ha='center', va='center', rotation=30
        )
    fig.tight_layout()
    fig.savefig(name, format=format)
    if not close:
        return
    fig.close()
193 @property
194 def _total_number_of_labels(self):
195 _number_of_labels = 0
196 for item in [self.labels, self.existing_labels]:
197 if isinstance(item, list):
198 _number_of_labels += len(item)
199 return _number_of_labels
@staticmethod
def check_latex_labels(parameters):
    """Make sure every parameter has an entry in ``latex_labels``. A missing
    entry is created from the parameter name with underscores replaced by
    spaces

    Parameters
    ----------
    parameters: list
        list of parameters
    """
    for name in parameters:
        # test membership on the mapping itself rather than rebuilding a
        # list of all keys for every parameter
        if name not in latex_labels:
            latex_labels[name] = name.replace("_", " ")
@property
def savedir(self):
    """Directory in which the plots are stored"""
    return self._savedir

@savedir.setter
def savedir(self, savedir):
    # when no directory is given, fall back to '<webdir>/plots/'
    self._savedir = (
        savedir if savedir is not None else self.webdir + "/plots/"
    )
def setup_pool(self):
    """Create the pool of ``multi_process`` worker processes used to speed
    up plot generation
    """
    return Pool(processes=self.multi_process)
def generate_plots(self):
    """Generate the plots for every label, then add the existing data and
    generate the comparison plots when those options are enabled
    """
    for label in self.labels:
        logger.debug("Starting to generate plots for {}".format(label))
        self._generate_plots(label)
        if not self.make_interactive:
            continue
        logger.debug(
            "Starting to generate interactive plots for {}".format(label)
        )
        self._generate_interactive_plots(label)
    if self.add_to_existing:
        self.add_existing_data()
    if not self.make_comparison:
        return
    logger.debug("Starting to generate comparison plots")
    self._generate_comparison_plots()
def check_key_data_in_dict(self, label, param):
    """Return the key data stored for a given label and parameter, or None
    if there is none

    Parameters
    ----------
    label: str
        the label used to distinguish a given run
    param: str
        name of the parameter you wish to return key data for
    """
    store = self.key_data
    if store is None or label not in store.keys():
        return None
    per_label = store[label]
    if param not in per_label.keys():
        return None
    return per_label[param]
def check_prior_samples_in_dict(self, label, param):
    """Return the prior samples stored for a given label and parameter, or
    None if there are none

    Parameters
    ----------
    label: str
        the label used to distinguish a given run
    param: str
        name of the parameter you wish to return prior samples for
    """
    if "samples" not in self.priors.keys():
        return None
    if label not in self.priors["samples"].keys():
        return None
    label_priors = self.priors["samples"][label]
    # an empty list marks a label without prior samples
    if label_priors != [] and param in label_priors.keys():
        return label_priors[param]
    return None
def add_existing_data(self):
    """Pass this object to ``pesummary.utils.utils._add_existing_data``

    NOTE(review): presumably the helper merges the data stored in the
    existing metafile into this object in place -- confirm against the
    helper.
    """
    from pesummary.utils.utils import _add_existing_data

    # the returned object was only ever bound to the local name ``self``,
    # so it is not kept here
    _add_existing_data(self)
291 def _generate_plots(self, label):
292 """Generate all plots for a a given result file
293 """
294 if self.make_corner:
295 self.try_to_make_a_plot("corner", label=label)
296 self.try_to_make_a_plot("oned_histogram", label=label)
297 self.try_to_make_a_plot("sample_evolution", label=label)
298 self.try_to_make_a_plot("autocorrelation", label=label)
299 self.try_to_make_a_plot("oned_cdf", label=label)
300 if self.expert_plots:
301 self.try_to_make_a_plot("expert", label=label)
302 if self.custom_plotting:
303 self.try_to_make_a_plot("custom", label=label)
305 def _generate_interactive_plots(self, label):
306 """Generate all interactive plots and save them to an html file ready
307 to be imported later
308 """
309 self.try_to_make_a_plot("interactive_corner", label=label)
310 if self.make_comparison:
311 self.try_to_make_a_plot("interactive_ridgeline")
313 def _generate_comparison_plots(self):
314 """Generate all comparison plots
315 """
316 self.try_to_make_a_plot("oned_histogram_comparison")
317 self.try_to_make_a_plot("oned_cdf_comparison")
318 self.try_to_make_a_plot("box_plot_comparison")
def try_to_make_a_plot(self, plot_type, label=None):
    """Look up the plotting function registered for plot_type and run it
    through _try_to_make_a_plot

    Parameters
    ----------
    plot_type: str
        String to describe the plot that you wish to try and make. Must be
        a key of ``plot_type_dictionary``
    label: str
        The label of the results file that you wish to plot
    """
    plotter = self.plot_type_dictionary[plot_type]
    failure = "Failed to generate %s plot because {}" % (plot_type)
    self._try_to_make_a_plot([label], plotter, failure)
@staticmethod
def _try_to_make_a_plot(arguments, function, message):
    """Try to make a plot. If it fails, log the error message and continue
    plotting

    If the function raises a RuntimeError it is called a second time with
    ``rcParams['text.usetex']`` set to False. The original value of the
    setting is always restored afterwards, even if the second attempt
    fails. The current pyplot figure is closed in every case.

    Parameters
    ----------
    arguments: list
        list of arguments that you wish to pass to function
    function: func
        function that you wish to execute
    message: str
        the error message that you wish to be logged. It is formatted with
        the exception that was raised
    """
    try:
        function(*arguments)
    except RuntimeError:
        # NOTE(review): presumably the RuntimeError comes from a failed
        # LaTeX render -- retry once without usetex
        from matplotlib import rcParams

        original = rcParams['text.usetex']
        try:
            rcParams['text.usetex'] = False
            function(*arguments)
        except Exception as e:
            logger.info(message.format(e))
        finally:
            # do not leak the disabled setting into later plots
            rcParams['text.usetex'] = original
    except Exception as e:
        logger.info(message.format(e))
    finally:
        from matplotlib import pyplot

        pyplot.close()
def corner_plot(self, label):
    """Generate a corner plot for a given result file. For mcmc samples the
    ``combine`` attribute of the samples is plotted

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    samples = self.samples[label]
    if self.mcmc_samples:
        samples = samples.combine
    self._corner_plot(
        self.savedir, label, samples, latex_labels, self.webdir,
        self.corner_params, self.preliminary_pages[label], self.checkpoint
    )
@staticmethod
def _corner_plot(
    savedir, label, samples, latex_labels, webdir, params, preliminary=False,
    checkpoint=False
):
    """Generate a corner plot for a given set of samples and register it in
    ``<webdir>/js/combine_corner.js``

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.array.Array objects that
        contain samples for each parameter
    latex_labels: str
        latex labels for each parameter in samples
    webdir: str
        the directory where the `js` directory is located
    params: list
        passed as ``corner_parameters`` to the corner plot function
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    """
    import warnings

    def _insert_after_marker(path, marker, new_line):
        """Insert new_line after the first line of path containing marker"""
        with open(path) as handle:
            lines = handle.readlines()
        for position, line in enumerate(lines):
            if marker in line:
                break
        else:
            raise IndexError(
                "Unable to find '{}' in {}".format(marker, path)
            )
        lines.insert(position + 1, new_line)
        with open(path, "w") as handle:
            handle.writelines(lines)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        filename = os.path.join(
            savedir, "corner", "{}_all_density_plots.png".format(label)
        )
        if os.path.isfile(filename) and checkpoint:
            return
        fig, params, data = core._make_corner_plot(
            samples, latex_labels, corner_parameters=params
        )
        _PlotGeneration.save(
            fig, filename, preliminary=preliminary
        )
        params = [str(i) for i in params]
        javascript = os.path.join(webdir, "js", "combine_corner.js")
        # the file is rewritten after each insertion so that the first
        # entry is kept even if the second marker is missing
        _insert_after_marker(
            javascript, "var list = {}",
            " list['{}'] = {};\n".format(label, params)
        )
        _insert_after_marker(
            javascript, "var data = {}",
            " data['{}'] = {};\n".format(label, data)
        )
458 def _mcmc_iterator(self, label, function):
459 """If the data is a set of mcmc chains, return a 2d list of samples
460 to plot. Otherwise return a list of posterior samples
461 """
462 if self.mcmc_samples:
463 function += "_mcmc"
464 return self.same_parameters, self.samples[label], getattr(
465 self, function
466 )
467 return self.samples[label].keys(), self.samples[label], getattr(
468 self, function
469 )
def oned_histogram_plot(self, label):
    """Generate oned histogram plots for all parameters in the result file.
    One job per parameter is submitted to the process pool

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = (
        "Failed to generate oned_histogram plot for %s because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_oned_histogram_plot"
    )

    def _prior_for(param):
        # prior samples are only looked up when include_prior is set
        if self.include_prior:
            return self.check_prior_samples_in_dict(label, param)
        return None

    jobs = []
    for param in iterator:
        plot_args = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.injection_data[label][param],
            self.kde_plot, _prior_for(param), self.weights[label],
            self.package, self.preliminary_pages[label],
            self.checkpoint, self.check_key_data_in_dict(label, param)
        ]
        jobs.append((plot_args, function, error_message % (param)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
def oned_histogram_comparison_plot(self, label):
    """Generate oned comparison histogram plots for all parameters that are
    common to all result files

    Parameters
    ----------
    label: str
        not used by this method; accepted so that it can be called like
        the other plotting methods
    """
    error_message = (
        "Failed to generate a comparison histogram plot for %s because {}"
    )
    for param in self.same_parameters:
        injection = [
            per_label[param] for per_label in self.injection_data.values()
        ]
        weights = None
        if self.weights is not None:
            # one entry per analysis, in the order of the samples
            weights = [
                self.weights.get(key, None) for key in
                self.same_samples[param].keys()
            ]
        plot_args = [
            self.savedir, param, self.same_samples[param],
            latex_labels[param], self.colors, injection, weights,
            self.kde_plot, self.linestyles, self.package,
            self.preliminary_comparison_pages, self.checkpoint, None
        ]
        self._try_to_make_a_plot(
            plot_args, self._oned_histogram_comparison_plot,
            error_message % (param)
        )
@staticmethod
def _oned_histogram_comparison_plot(
    savedir, parameter, samples, latex_label, colors, injection, weights,
    kde=False, linestyles=None, package="core", preliminary=False,
    checkpoint=False, filename=None,
):
    """Generate a oned comparison histogram plot for a given parameter

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a oned comparison
        histogram for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    injection: list
        list of injected values, one for each analysis. NaN entries are
        passed to the plotting function as None. The list itself is not
        modified
    weights: list
        weights passed on to the plotting function
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    linestyles: list, optional
        list of linestyles used to distinguish different result files
    package: str, optional
        name of the pesummary package whose ``plots.plot`` module provides
        the plotting function, default "core"
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    filename: str, optional
        file to write the plot to. If None,
        ``combined_1d_posterior_<parameter>.png`` in savedir is used
    """
    import math
    module = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )
    if filename is None:
        filename = os.path.join(
            savedir, "combined_1d_posterior_{}.png".format(parameter)
        )
    if os.path.isfile(filename) and checkpoint:
        return
    # build a new list rather than editing the caller's list in place, and
    # accept entries that are already None
    injection = [
        None if inj is None or math.isnan(inj) else inj
        for inj in injection
    ]
    same_samples = [val for key, val in samples.items()]
    fig = module._1d_comparison_histogram_plot(
        parameter, same_samples, colors, latex_label,
        list(samples.keys()), inj_value=injection, kde=kde,
        linestyles=linestyles, hist=not kde, weights=weights
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
@staticmethod
def _oned_histogram_plot(
    savedir, label, parameter, samples, latex_label, injection, kde=False,
    prior=None, weights=None, package="core", preliminary=False,
    checkpoint=False, key_data=None
):
    """Generate a oned histogram plot for a given set of samples and save
    it as ``<label>_1d_posterior_<parameter>.png`` in savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value. NaN is passed on as None
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    prior: PESummary.utils.array.Array, optional
        the prior samples for param
    weights: PESummary.utils.array.Array, optional
        the weights for each samples. If None, assumed to be 1
    package: str, optional
        name of the pesummary package whose ``plots.plot`` module provides
        the plotting function, default "core"
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    key_data: optional
        passed on to the plotting function as ``key_data``
    """
    import math
    plot_module = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )

    inj_value = None if math.isnan(injection) else injection
    output = os.path.join(
        savedir, "{}_1d_posterior_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = plot_module._1d_histogram_plot(
        parameter, samples, latex_label, inj_value=inj_value, kde=kde,
        hist=not kde, prior=prior, weights=weights, key_data=key_data
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _oned_histogram_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection, kde=False,
    prior=None, weights=None, package="core", preliminary=False,
    checkpoint=False, key_data=None
):
    """Generate two oned histogram plots for multiple mcmc chains: one
    showing the individual chains and one for the concatenated chains

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plots in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for multiple mcmc chains
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value. NaN is passed on as None
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    prior: PESummary.utils.array.Array, optional
        the prior samples for param
    weights: PESummary.utils.array.Array, optional
        the weights for each samples. If None, assumed to be 1
    package: str, optional
        name of the pesummary package whose ``plots.plot`` module provides
        the plotting functions, default "core"
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plots
    checkpoint: Bool, optional
        if True, skip every plot that already exists on disk
    key_data: optional
        accepted for a signature matching _oned_histogram_plot; not used
    """
    import math
    from pesummary.utils.array import Array

    plot_module = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )

    if math.isnan(injection):
        injection = None
    chains = [chain for _, chain in samples.items()]

    # one histogram per chain
    chains_file = os.path.join(
        savedir, "{}_1d_posterior_{}.png".format(label, parameter)
    )
    if not (os.path.isfile(chains_file) and checkpoint):
        figure = plot_module._1d_histogram_plot_mcmc(
            parameter, chains, latex_label, inj_value=injection,
            kde=kde, prior=prior, weights=weights
        )
        _PlotGeneration.save(figure, chains_file, preliminary=preliminary)

    # a single histogram of all chains joined together
    combined_file = os.path.join(
        savedir, "{}_1d_posterior_{}_combined.png".format(label, parameter)
    )
    if not (os.path.isfile(combined_file) and checkpoint):
        figure = plot_module._1d_histogram_plot(
            parameter, Array(np.concatenate(chains)), latex_label,
            inj_value=injection, kde=kde, prior=prior, weights=weights
        )
        _PlotGeneration.save(
            figure, combined_file, preliminary=preliminary
        )
def sample_evolution_plot(self, label):
    """Generate sample evolution plots for all parameters in the result
    file. One job per parameter is submitted to the process pool

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = (
        "Failed to generate a sample evolution plot for %s because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_sample_evolution_plot"
    )
    jobs = []
    for param in iterator:
        plot_args = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.injection_data[label][param],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, function, error_message % (param)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
def expert_plot(self, label):
    """Generate expert plots for diagnostics

    The following jobs are submitted to the process pool:

    * a 2d contour plot of every parameter (including the debug keys)
      against log_likelihood
    * for every debug key containing '_non_reweighted', a 2d contour plot
      and a comparison histogram against its base parameter
    * a sample evolution plot of every parameter colored by log_likelihood
    * a bootstrapped oned histogram plot of every parameter

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = (
        "Failed to generate log_likelihood-%s 2d contour plot because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_2d_contour_plot"
    )
    _debug = self.samples[label].debug_keys()
    arguments = [
        (
            [
                self.savedir, label, param, "log_likelihood", samples[param],
                samples["log_likelihood"], latex_labels[param],
                latex_labels["log_likelihood"],
                self.preliminary_pages[label], [
                    # mark the maximum likelihood sample as the truth
                    samples[param][np.argmax(samples["log_likelihood"])],
                    np.max(samples["log_likelihood"]),
                ], self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator + _debug
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
    _reweight_keys = [
        param for param in self.samples[label].debug_keys() if
        "_non_reweighted" in param
    ]
    if len(_reweight_keys):
        error_message = (
            "Failed to generate %s-%s 2d contour plot because {}"
        )

        def _base_param(p):
            # text before '_non_reweighted' with the first character
            # dropped. NOTE(review): presumably the first character is a
            # leading underscore marking a debug key -- confirm
            return p.split("_non_reweighted")[0][1:]

        arguments = [
            (
                [
                    self.savedir, label, _base_param(param), param,
                    samples[_base_param(param)], samples[param],
                    latex_labels[_base_param(param)], latex_labels[param],
                    self.preliminary_pages[label], None, self.checkpoint
                ], function, error_message % (_base_param(param), param)
            ) for param in _reweight_keys
        ]
        self.pool.starmap(self._try_to_make_a_plot, arguments)
        error_message = (
            "Failed to generate a histogram plot comparing %s and %s "
            "because {}"
        )
        # positional order of _oned_histogram_comparison_plot: savedir,
        # parameter, samples, latex_label, colors, injection, weights, kde,
        # linestyles, package, preliminary, checkpoint, filename. The
        # weights entry (None) must be present, otherwise every following
        # argument is shifted by one position
        arguments = [
            (
                [
                    self.savedir, _base_param(param), {
                        "reweighted": samples[_base_param(param)],
                        "non-reweighted": samples[param]
                    }, latex_labels[_base_param(param)], ['b', 'r'],
                    [np.nan, np.nan], None, True, None, self.package,
                    self.preliminary_comparison_pages, self.checkpoint,
                    os.path.join(
                        self.savedir, "{}_1d_posterior_{}_{}.png".format(
                            label, _base_param(param), param
                        )
                    )
                ], self._oned_histogram_comparison_plot,
                error_message % (_base_param(param), param)
            ) for param in _reweight_keys
        ]
        self.pool.starmap(self._try_to_make_a_plot, arguments)

    error_message = (
        "Failed to generate log_likelihood-%s sample_evolution plot "
        "because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_colored_sample_evolution_plot"
    )
    arguments = [
        (
            [
                self.savedir, label, param, "log_likelihood", samples[param],
                samples["log_likelihood"], latex_labels[param],
                latex_labels["log_likelihood"],
                self.preliminary_pages[label], self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
    error_message = (
        "Failed to generate bootstrapped oned_histogram plot for %s "
        "because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_oned_histogram_bootstrap_plot"
    )
    arguments = [
        (
            [
                self.savedir, label, param, samples[param],
                latex_labels[param], self.preliminary_pages[label],
                self.package, self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
@staticmethod
def _oned_histogram_bootstrap_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    package="core", checkpoint=False, nsamples=1000, ntests=100, **kwargs
):
    """Generate a bootstrapped oned histogram plot for a given set of
    samples and save it as ``<label>_1d_posterior_<parameter>_bootstrap.png``
    in savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    package: str, optional
        name of the pesummary package whose ``plots.plot`` module provides
        the plotting function, default "core"
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    nsamples: int, optional
        passed on to the plotting function as ``nsamples``, default 1000
    ntests: int, optional
        passed on to the plotting function as ``ntests``, default 100
    **kwargs: dict, optional
        all additional kwargs are passed on to the plotting function
    """
    plot_module = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )

    output = os.path.join(
        savedir, "{}_1d_posterior_{}_bootstrap.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = plot_module._1d_histogram_plot_bootstrap(
        parameter, samples, latex_label, nsamples=nsamples, ntests=ntests,
        **kwargs
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _2d_contour_plot(
    savedir, label, parameter_x, parameter_y, samples_x, samples_y,
    latex_label_x, latex_label_y, preliminary=False, truth=None,
    checkpoint=False
):
    """Generate a 2d contour plot for a given set of samples and save it as
    ``<label>_2d_contour_<parameter_x>_<parameter_y>.png`` in savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter_x: str
        name of the parameter on the x axis, used in the file name
    parameter_y: str
        name of the parameter on the y axis, used in the file name
    samples_x: PESummary.utils.array.Array
        array containing the samples for the x axis
    samples_y: PESummary.utils.array.Array
        array containing the samples for the y axis
    latex_label_x: str
        the latex label for the x axis
    latex_label_y: str
        the latex label for the y axis
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    truth: list, optional
        passed on to the plotting function as ``truth``
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    """
    from pesummary.core.plots.publication import twod_contour_plot

    basename = "{}_2d_contour_{}_{}.png".format(
        label, parameter_x, parameter_y
    )
    output = os.path.join(savedir, basename)
    if checkpoint and os.path.isfile(output):
        return
    figure = twod_contour_plot(
        samples_x, samples_y, levels=[0.9, 0.5], xlabel=latex_label_x,
        ylabel=latex_label_y, bins=50, truth=truth
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _colored_sample_evolution_plot(
    savedir, label, parameter_x, parameter_y, samples_x, samples_y,
    latex_label_x, latex_label_y, preliminary=False, checkpoint=False
):
    """Generate a sample evolution plot of one parameter colored by the
    samples of a second parameter, and save it as
    ``<label>_sample_evolution_<parameter_x>_<parameter_y>_colored.png`` in
    savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter_x: str
        name of the parameter whose evolution is plotted
    parameter_y: str
        name of the parameter passed as ``z``; used in the file name
    samples_x: PESummary.utils.array.Array
        array containing the samples for parameter_x
    samples_y: PESummary.utils.array.Array
        array containing the samples passed as ``z``
    latex_label_x: str
        the latex label for parameter_x
    latex_label_y: str
        the latex label passed as ``z_label``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    """
    basename = "{}_sample_evolution_{}_{}_colored.png".format(
        label, parameter_x, parameter_y
    )
    output = os.path.join(savedir, basename)
    if checkpoint and os.path.isfile(output):
        return
    figure = core._sample_evolution_plot(
        parameter_x, samples_x, latex_label_x, z=samples_y,
        z_label=latex_label_y
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _sample_evolution_plot(
    savedir, label, parameter, samples, latex_label, injection,
    preliminary=False, checkpoint=False
):
    """Generate a sample evolution plot for a given set of samples and save
    it as ``<label>_sample_evolution_<parameter>.png`` in savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    """
    output = os.path.join(
        savedir, "{}_sample_evolution_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = core._sample_evolution_plot(
        parameter, samples, latex_label, injection
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _sample_evolution_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection,
    preliminary=False, checkpoint=False
):
    """Generate a sample evolution plot for a given set of mcmc chains and
    save it as ``<label>_sample_evolution_<parameter>.png`` in savedir

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary containing pesummary.utils.array.Array objects containing
        the samples corresponding to parameter for each chain
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do nothing when the plot already exists on disk
    """
    output = os.path.join(
        savedir, "{}_sample_evolution_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    chains = [chain for _, chain in samples.items()]
    figure = core._sample_evolution_plot_mcmc(
        parameter, chains, latex_label, injection
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def autocorrelation_plot(self, label):
    """Make autocorrelation plots for every parameter of one result file

    The parameters to loop over, their samples and the plotting function
    are all provided by `self._mcmc_iterator`. One job per parameter is
    then dispatched through `self.pool.starmap`, each wrapped in
    `self._try_to_make_a_plot`

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    parameters, samples, plot_function = self._mcmc_iterator(
        label, "_autocorrelation_plot"
    )
    # '%s' is filled with the parameter name below; the '{}' placeholder is
    # left in the message that is handed to _try_to_make_a_plot
    template = "Failed to generate an autocorrelation plot for %s because {}"
    jobs = []
    for param in parameters:
        plot_arguments = [
            self.savedir, label, param, samples[param],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_arguments, plot_function, template % param))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _autocorrelation_plot(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Make and save the autocorrelation plot for a single parameter

    The figure is produced by `core._autocorrelation_plot` and written to
    `<savedir>/<label>_autocorrelation_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    plot_name = "{}_autocorrelation_{}.png".format(label, parameter)
    filename = os.path.join(savedir, plot_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    figure = core._autocorrelation_plot(parameter, samples)
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
@staticmethod
def _autocorrelation_plot_mcmc(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Make and save the autocorrelation plot for a set of mcmc chains

    The figure is produced by `core._autocorrelation_plot_mcmc` and written
    to `<savedir>/<label>_autocorrelation_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for each mcmc chain. Only the
        values are used, in the dictionary's iteration order
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    plot_name = "{}_autocorrelation_{}.png".format(label, parameter)
    filename = os.path.join(savedir, plot_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    chains = [chain for _, chain in samples.items()]
    figure = core._autocorrelation_plot_mcmc(parameter, chains)
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def oned_cdf_plot(self, label):
    """Make oned CDF plots for every parameter of one result file

    The parameters to loop over, their samples and the plotting function
    are all provided by `self._mcmc_iterator`. One job per parameter is
    then dispatched through `self.pool.starmap`, each wrapped in
    `self._try_to_make_a_plot`

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    parameters, samples, plot_function = self._mcmc_iterator(
        label, "_oned_cdf_plot"
    )
    # '%s' is filled with the parameter name below; the '{}' placeholder is
    # left in the message that is handed to _try_to_make_a_plot
    template = "Failed to generate a CDF plot for %s because {}"
    jobs = []
    for param in parameters:
        plot_arguments = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.preliminary_pages[label],
            self.checkpoint
        ]
        jobs.append((plot_arguments, plot_function, template % param))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _oned_cdf_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples

    The figure is produced by `core._1d_cdf_plot` and written to
    `<savedir>/<label>_cdf_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    # pass the directory and the file name as separate path components.
    # Concatenating them only gives the right path when savedir already
    # ends with a separator; otherwise the plot lands beside the directory
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    fig = core._1d_cdf_plot(parameter, samples, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
@staticmethod
def _oned_cdf_plot_mcmc(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples, one for each
    mcmc chain

    The figure is produced by `core._1d_cdf_plot_mcmc` and written to
    `<savedir>/<label>_cdf_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for each mcmc chain
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    # pass the directory and the file name as separate path components.
    # Concatenating them only gives the right path when savedir already
    # ends with a separator; otherwise the plot lands beside the directory
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    # the plotting routine takes one entry per chain; the chain names
    # themselves are not passed on
    same_samples = [val for key, val in samples.items()]
    fig = core._1d_cdf_plot_mcmc(parameter, same_samples, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def interactive_ridgeline_plot(self, label):
    """Make an interactive ridgeline plot for each parameter that is common
    to all result files

    Every plot is attempted through `self._try_to_make_a_plot`, which is
    handed the failure message to use for that parameter

    Parameters
    ----------
    label: str
        not used by this method
    """
    # '%s' is filled with the parameter name below; the '{}' placeholder is
    # left in the message that is handed to _try_to_make_a_plot
    template = (
        "Failed to generate an interactive ridgeline plot for %s because {}"
    )
    for param in self.same_parameters:
        plot_arguments = [
            self.savedir, param, self.same_samples[param],
            latex_labels[param], self.colors, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_arguments, self._interactive_ridgeline_plot,
            template % param
        )
@staticmethod
def _interactive_ridgeline_plot(
    savedir, parameter, samples, latex_label, colors, checkpoint=False
):
    """Make an interactive ridgeline plot for a given parameter

    The plot is produced by `interactive.ridgeline`, which is asked to
    write it to `<savedir>/interactive_ridgeline_<parameter>.html`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of samples corresponding to parameter. The values are
        plotted and the keys are passed alongside them, both in the
        dictionary's iteration order
    latex_label: str
        the latex label for parameter, used as the x axis label
    colors: list
        colors passed on to the plotting routine
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    page_name = "interactive_ridgeline_{}.html".format(parameter)
    filename = os.path.join(savedir, page_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    distributions = [dist for _, dist in samples.items()]
    names = list(samples.keys())
    _ = interactive.ridgeline(
        distributions, names, xlabel=latex_label, colors=colors,
        write_to_html_file=filename
    )
def interactive_corner_plot(self, label):
    """Make the interactive corner plot for one result file

    Thin wrapper which forwards the stored samples for `label`, the latex
    labels and the checkpoint flag to `self._interactive_corner_plot`

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    posterior = self.samples[label]
    self._interactive_corner_plot(
        self.savedir, label, posterior, latex_labels, self.checkpoint
    )
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Make an interactive corner plot for a given set of samples

    The plot is produced by `interactive.corner`, which is asked to write
    it to `<savedir>/corner/<label>_interactive.html`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.array.Array objects that
        contain samples for each parameter
    latex_labels: dict
        latex labels, looked up by parameter name for every parameter in
        samples
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    page_name = "{}_interactive.html".format(label)
    filename = os.path.join(savedir, "corner", page_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    # build the data columns and their axis labels in the same order
    columns, axis_labels = [], []
    for name in samples.keys():
        columns.append(samples[name])
        axis_labels.append(latex_labels[name])
    _ = interactive.corner(
        columns, axis_labels, write_to_html_file=filename
    )
def oned_cdf_comparison_plot(self, label):
    """Make oned comparison CDF plots for each parameter that is common to
    all result files

    Every plot is attempted through `self._try_to_make_a_plot`, which is
    handed the failure message to use for that parameter

    Parameters
    ----------
    label: str
        not used by this method
    """
    # '%s' is filled with the parameter name below; the '{}' placeholder is
    # left in the message that is handed to _try_to_make_a_plot
    template = "Failed to generate a comparison CDF plot for %s because {}"
    for param in self.same_parameters:
        if self.weights is None:
            weights = None
        else:
            # one entry per key of the samples for this parameter, in the
            # same order; keys missing from self.weights map to None
            weights = [
                self.weights.get(key, None)
                for key in self.same_samples[param].keys()
            ]
        plot_arguments = [
            self.savedir, param, self.same_samples[param],
            latex_labels[param], self.colors, self.linestyles,
            weights, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_arguments, self._oned_cdf_comparison_plot,
            template % param
        )
@staticmethod
def _oned_cdf_comparison_plot(
    savedir, parameter, samples, latex_label, colors, linestyles=None,
    weights=None, preliminary=False, checkpoint=False
):
    """Make and save a oned comparison CDF plot for a given parameter

    The figure is produced by `core._1d_cdf_comparison_plot` and written to
    `<savedir>/combined_cdf_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a oned comparison
        CDF plot for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    linestyles: list, optional
        list of linestyles used to distinguish different result files
    weights: list, optional
        weights forwarded unchanged to the plotting routine
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    plot_name = "combined_cdf_{}.png".format(parameter)
    filename = os.path.join(savedir, plot_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    # labels and samples are passed as two parallel lists
    labels = list(samples.keys())
    distributions = [samples[name] for name in labels]
    figure = core._1d_cdf_comparison_plot(
        parameter, distributions, colors, latex_label, labels, linestyles,
        weights=weights
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def box_plot_comparison_plot(self, label):
    """Make comparison box plots for each parameter that is common to all
    result files

    Every plot is attempted through `self._try_to_make_a_plot`, which is
    handed the failure message to use for that parameter

    Parameters
    ----------
    label: str
        not used by this method
    """
    # '%s' is filled with the parameter name below; the '{}' placeholder is
    # left in the message that is handed to _try_to_make_a_plot
    template = "Failed to generate a comparison box plot for %s because {}"
    for param in self.same_parameters:
        plot_arguments = [
            self.savedir, param, self.same_samples[param],
            latex_labels[param], self.colors,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_arguments, self._box_plot_comparison_plot,
            template % param
        )
@staticmethod
def _box_plot_comparison_plot(
    savedir, parameter, samples, latex_label, colors, preliminary=False,
    checkpoint=False
):
    """Make and save a comparison box plot for a given parameter

    The figure is produced by `core._comparison_box_plot` and written to
    `<savedir>/combined_boxplot_<parameter>.png`

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a comparison box
        plot for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return without
        regenerating the plot
    """
    plot_name = "combined_boxplot_{}.png".format(parameter)
    filename = os.path.join(savedir, plot_name)
    already_saved = os.path.isfile(filename)
    if checkpoint and already_saved:
        return
    # samples and labels are passed as two parallel lists
    distributions = [dist for _, dist in samples.items()]
    labels = list(samples.keys())
    figure = core._comparison_box_plot(
        parameter, distributions, colors, latex_label, labels
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def custom_plot(self, label):
    """Make the custom plots defined in a user supplied python module

    `self.custom_plotting` is read as a pair: element 0 is a directory that
    is appended to `sys.path` (unless it is an empty string) and element 1
    is the name of the module to import. Every function named in the
    module's `__single_plots__` attribute is called with the list of
    parameter names and the samples for `label`, and the figure it returns
    is saved as `<savedir>/<label>_custom_plotting_<index>`

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    import importlib

    module_dir = self.custom_plotting[0]
    if module_dir != "":
        import sys

        # make the user's module importable by name
        sys.path.append(module_dir)
    module = importlib.import_module(self.custom_plotting[1])

    plotters = [getattr(module, name) for name in module.__single_plots__]
    for index, plotter in enumerate(plotters):
        parameters = list(self.samples[label].keys())
        figure = plotter(parameters, self.samples[label])
        output = os.path.join(
            self.savedir, "{}_custom_plotting_{}".format(label, index)
        )
        _PlotGeneration.save(
            figure, output, preliminary=self.preliminary_pages[label]
        )