Coverage for pesummary/core/plots/main.py: 74.9%
450 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import numpy as np
6import os
7import importlib
8from contextlib import contextmanager
9from multiprocessing import Pool
11from pesummary.core.plots.latex_labels import latex_labels
12from pesummary.utils.utils import (
13 logger, get_matplotlib_backend, make_dir, get_matplotlib_style_file
14)
15from pesummary.core.plots import plot as core
16from pesummary.core.plots import interactive
18import matplotlib
19import matplotlib.style
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Configure matplotlib at import time, before any figure is created in this
# module: select the backend pesummary's ``get_matplotlib_backend`` returns for
# parallel use (plots below are produced inside a ``multiprocessing.Pool``),
# then apply pesummary's default style sheet.
matplotlib.use(get_matplotlib_backend(parallel=True))
matplotlib.style.use(get_matplotlib_style_file())
class _PlotGeneration(object):
    """Super class to handle the plot generation for a given set of result
    files

    Parameters
    ----------
    savedir: str
        the directory to store the plots. If None, ``webdir + "/plots/"`` is
        used
    webdir: str
        the web directory of the run. ``webdir`` and ``webdir/plots/corner``
        are created on construction
    labels: list
        list of labels used to distinguish the result files
    samples: dict
        dictionary of posterior samples stored in the result files, keyed by
        label
    kde_plot: Bool
        if True, kde plots are generated instead of histograms, Default False
    existing_labels: list
        list of labels stored in an existing metafile. These also count
        towards the number of labels used to decide whether comparison plots
        are made
    existing_injection_data: dict
        dictionary of injection data stored in an existing metafile
    existing_samples: dict
        dictionary of posterior samples stored in an existing metafile
    existing_weights: dict
        weights stored in an existing metafile
    same_parameters: list
        list of parameters that are common in all result files
    injection_data: dict
        dictionary of injection data for each result file
    result_files: list
        list of result files passed. NOTE(review): not a parameter of this
        class's ``__init__``
    colors: list
        colors that you wish to use to distinguish different result files
    custom_plotting: optional
        if truthy, the "custom" plot is also generated for each label
    add_to_existing: bool, optional
        if True, data from the existing metafile is merged in (via
        ``add_existing_data``) after the per-label plots and before the
        comparison plots. Default False
    priors: dict, optional
        prior information; prior samples are looked up as
        ``priors["samples"][label][param]``
    include_prior: bool, optional
        if True, prior samples are passed to the 1d histogram plots. Default
        False
    weights: dict, optional
        weights keyed by label. Default None for every label
    disable_comparison: bool, optional
        whether to skip the comparison plots, default is False. Comparison
        plots are only generated when disable_comparison is False and there
        is more than one label in total (labels plus existing_labels)
    linestyles: list, optional
        linestyles used to distinguish different result files in comparison
        plots
    disable_interactive: bool, optional
        whether to skip the interactive plots, default is False
    multi_process: int, optional
        number of processes in the pool used to generate the plots. Default 1
    mcmc_samples: bool, optional
        if True, the samples are treated as multiple mcmc chains and expert
        plots are disabled. Default False
    disable_corner: bool, optional
        whether to skip the corner plot, default is False
    corner_params: list, optional
        parameters to include in the corner plot
    expert_plots: bool, optional
        whether to generate expert diagnostic plots, default True
    checkpoint: bool, optional
        if True, plots whose output file already exists are not regenerated.
        Default False
    key_data: dict, optional
        data keyed by label and then by parameter which is passed to the 1d
        histogram plots
    """
65 def __init__(
66 self, savedir=None, webdir=None, labels=None, samples=None,
67 kde_plot=False, existing_labels=None, existing_injection_data=None,
68 existing_samples=None, existing_weights=None, same_parameters=None,
69 injection_data=None, colors=None, custom_plotting=None,
70 add_to_existing=False, priors={}, include_prior=False, weights=None,
71 disable_comparison=False, linestyles=None, disable_interactive=False,
72 multi_process=1, mcmc_samples=False, disable_corner=False,
73 corner_params=None, expert_plots=True, checkpoint=False, key_data=None
74 ):
75 self.package = "core"
76 self.webdir = webdir
77 make_dir(self.webdir)
78 make_dir(os.path.join(self.webdir, "plots", "corner"))
79 self.savedir = savedir
80 self.labels = labels
81 self.mcmc_samples = mcmc_samples
82 self.samples = samples
83 self.kde_plot = kde_plot
84 self.existing_labels = existing_labels
85 self.existing_injection_data = existing_injection_data
86 self.existing_samples = existing_samples
87 self.existing_weights = existing_weights
88 self.same_parameters = same_parameters
89 self.injection_data = injection_data
90 self.colors = colors
91 self.custom_plotting = custom_plotting
92 self.add_to_existing = add_to_existing
93 self.priors = priors
94 self.include_prior = include_prior
95 self.linestyles = linestyles
96 self.make_interactive = not disable_interactive
97 self.make_corner = not disable_corner
98 self.corner_params = corner_params
99 self.expert_plots = expert_plots
100 if self.mcmc_samples and self.expert_plots:
101 logger.warning("Unable to generate expert plots for mcmc samples")
102 self.expert_plots = False
103 self.checkpoint = checkpoint
104 self.key_data = key_data
105 self.multi_process = multi_process
106 self.pool = None
107 self.preliminary_pages = {label: False for label in self.labels}
108 self.preliminary_comparison_pages = False
109 self.make_comparison = (
110 not disable_comparison and self._total_number_of_labels > 1
111 )
112 self.weights = (
113 weights if weights is not None else {i: None for i in self.labels}
114 )
116 if self.same_parameters is not None and not self.mcmc_samples:
117 self.same_samples = {
118 param: {
119 key: item[param] for key, item in self.samples.items()
120 } for param in self.same_parameters
121 }
122 else:
123 self.same_samples = None
125 for i in self.samples.keys():
126 try:
127 self.check_latex_labels(
128 self.samples[i].keys(remove_debug=False)
129 )
130 except TypeError:
131 try:
132 self.check_latex_labels(self.samples[i].keys())
133 except TypeError:
134 pass
136 self.plot_type_dictionary = {
137 "oned_histogram": self.oned_histogram_plot,
138 "sample_evolution": self.sample_evolution_plot,
139 "autocorrelation": self.autocorrelation_plot,
140 "oned_cdf": self.oned_cdf_plot,
141 "custom": self.custom_plot
142 }
143 if self.make_corner:
144 self.plot_type_dictionary.update({"corner": self.corner_plot})
145 if self.expert_plots:
146 self.plot_type_dictionary.update({"expert": self.expert_plot})
147 if self.make_comparison:
148 self.plot_type_dictionary.update(dict(
149 oned_histogram_comparison=self.oned_histogram_comparison_plot,
150 oned_cdf_comparison=self.oned_cdf_comparison_plot,
151 box_plot_comparison=self.box_plot_comparison_plot,
152 ))
153 if self.make_interactive:
154 self.plot_type_dictionary.update(
155 dict(
156 interactive_corner=self.interactive_corner_plot
157 )
158 )
159 if self.make_comparison:
160 self.plot_type_dictionary.update(
161 dict(
162 interactive_ridgeline=self.interactive_ridgeline_plot
163 )
164 )
166 @staticmethod
167 def save(fig, name, preliminary=False, close=True, format="png"):
168 """Save a figure to disk.
170 Parameters
171 ----------
172 fig: matplotlib.pyplot.figure
173 Matplotlib figure that you wish to save
174 name: str
175 Name of the file that you wish to write it too
176 close: Bool, optional
177 Close the figure after it has been saved
178 format: str, optional
179 Format used to save the image
180 """
181 n = len(format)
182 if ".%s" % (format) != name[-n - 1:]:
183 name += ".%s" % (format)
184 if preliminary:
185 fig.text(
186 0.5, 0.5, 'Preliminary', fontsize=90, color='gray', alpha=0.1,
187 ha='center', va='center', rotation=30
188 )
189 fig.tight_layout()
190 fig.savefig(name, format=format)
191 if close:
192 fig.close()
194 @property
195 def _total_number_of_labels(self):
196 _number_of_labels = 0
197 for item in [self.labels, self.existing_labels]:
198 if isinstance(item, list):
199 _number_of_labels += len(item)
200 return _number_of_labels
202 @staticmethod
203 def check_latex_labels(parameters):
204 """Check to see if there is a latex label for all parameters. If not,
205 then create one
207 Parameters
208 ----------
209 parameters: list
210 list of parameters
211 """
212 for i in parameters:
213 if i not in list(latex_labels.keys()):
214 latex_labels[i] = i.replace("_", " ")
216 @property
217 def savedir(self):
218 return self._savedir
220 @savedir.setter
221 def savedir(self, savedir):
222 self._savedir = savedir
223 if savedir is None:
224 self._savedir = self.webdir + "/plots/"
226 @contextmanager
227 def setup_pool(self):
228 """Setup a pool of processes to speed up plot generation
229 """
230 with Pool(processes=self.multi_process) as pool:
231 yield pool
232 self.pool = None
234 def generate_plots(self):
235 """Generate all plots for all result files
236 """
237 for i in self.labels:
238 with self.setup_pool() as self.pool:
239 logger.debug("Starting to generate plots for {}".format(i))
240 self._generate_plots(i)
241 if self.make_interactive:
242 logger.debug(
243 "Starting to generate interactive plots for {}".format(i)
244 )
245 self._generate_interactive_plots(i)
246 if self.add_to_existing:
247 self.add_existing_data()
248 if self.make_comparison:
249 logger.debug("Starting to generate comparison plots")
250 self._generate_comparison_plots()
252 def check_key_data_in_dict(self, label, param):
253 """Check to see if there is key data for a given param
255 Parameters
256 ----------
257 label: str
258 the label used to distinguish a given run
259 param: str
260 name of the parameter you wish to return prior samples for
261 """
262 if self.key_data is None:
263 return None
264 elif label not in self.key_data.keys():
265 return None
266 elif param not in self.key_data[label].keys():
267 return None
268 return self.key_data[label][param]
270 def check_prior_samples_in_dict(self, label, param):
271 """Check to see if there are prior samples for a given param
273 Parameters
274 ----------
275 label: str
276 the label used to distinguish a given run
277 param: str
278 name of the parameter you wish to return prior samples for
279 """
280 cond1 = "samples" in self.priors.keys()
281 if cond1 and label in self.priors["samples"].keys():
282 cond1 = self.priors["samples"][label] != []
283 if cond1 and param in self.priors["samples"][label].keys():
284 return self.priors["samples"][label][param]
285 return None
286 return None
288 def add_existing_data(self):
289 """
290 """
291 from pesummary.utils.utils import _add_existing_data
293 self = _add_existing_data(self)
295 def _generate_plots(self, label):
296 """Generate all plots for a a given result file
297 """
298 if self.make_corner:
299 self.try_to_make_a_plot("corner", label=label)
300 self.try_to_make_a_plot("oned_histogram", label=label)
301 self.try_to_make_a_plot("sample_evolution", label=label)
302 self.try_to_make_a_plot("autocorrelation", label=label)
303 self.try_to_make_a_plot("oned_cdf", label=label)
304 if self.expert_plots:
305 self.try_to_make_a_plot("expert", label=label)
306 if self.custom_plotting:
307 self.try_to_make_a_plot("custom", label=label)
309 def _generate_interactive_plots(self, label):
310 """Generate all interactive plots and save them to an html file ready
311 to be imported later
312 """
313 self.try_to_make_a_plot("interactive_corner", label=label)
314 if self.make_comparison:
315 self.try_to_make_a_plot("interactive_ridgeline")
317 def _generate_comparison_plots(self):
318 """Generate all comparison plots
319 """
320 self.try_to_make_a_plot("oned_histogram_comparison")
321 self.try_to_make_a_plot("oned_cdf_comparison")
322 self.try_to_make_a_plot("box_plot_comparison")
324 def try_to_make_a_plot(self, plot_type, label=None):
325 """Wrapper function to _try_to_make_a_plot
327 Parameters
328 ----------
329 plot_type: str
330 String to describe the plot that you wish to try and make
331 label: str
332 The label of the results file that you wish to plot
333 """
334 self._try_to_make_a_plot(
335 [label], self.plot_type_dictionary[plot_type],
336 "Failed to generate %s plot because {}" % (plot_type)
337 )
339 @staticmethod
340 def _try_to_make_a_plot(arguments, function, message):
341 """Try to make a plot. If it fails return an error message and continue
342 plotting
344 Parameters
345 ----------
346 arguments: list
347 list of arguments that you wish to pass to function
348 function: func
349 function that you wish to execute
350 message: str
351 the error message that you wish to be printed.
352 """
353 try:
354 function(*arguments)
355 except RuntimeError:
356 try:
357 from matplotlib import rcParams
359 original = rcParams['text.usetex']
360 rcParams['text.usetex'] = False
361 function(*arguments)
362 rcParams['text.usetex'] = original
363 except Exception as e:
364 logger.info(message.format(e))
365 except Exception as e:
366 logger.info(message.format(e))
367 finally:
368 from matplotlib import pyplot
370 pyplot.close()
372 def corner_plot(self, label):
373 """Generate a corner plot for a given result file
375 Parameters
376 ----------
377 label: str
378 the label for the results file that you wish to plot
379 """
380 if self.mcmc_samples:
381 samples = self.samples[label].combine
382 else:
383 samples = self.samples[label]
384 self._corner_plot(
385 self.savedir, label, samples, latex_labels, self.webdir,
386 self.corner_params, self.preliminary_pages[label], self.checkpoint
387 )
389 @staticmethod
390 def _corner_plot(
391 savedir, label, samples, latex_labels, webdir, params, preliminary=False,
392 checkpoint=False
393 ):
394 """Generate a corner plot for a given set of samples
396 Parameters
397 ----------
398 savedir: str
399 the directory you wish to save the plot in
400 label: str
401 the label corresponding to the results file
402 samples: dict
403 dictionary containing PESummary.utils.array.Array objects that
404 contain samples for each parameter
405 latex_labels: str
406 latex labels for each parameter in samples
407 webdir: str
408 the directory where the `js` directory is located
409 preliminary: Bool, optional
410 if True, add a preliminary watermark to the plot
411 """
412 import warnings
414 with warnings.catch_warnings():
415 warnings.simplefilter("ignore")
416 filename = os.path.join(
417 savedir, "corner", "{}_all_density_plots.png".format(label)
418 )
419 if os.path.isfile(filename) and checkpoint:
420 return
421 fig, params, data = core._make_corner_plot(
422 samples, latex_labels, corner_parameters=params
423 )
424 _PlotGeneration.save(
425 fig, filename, preliminary=preliminary
426 )
427 combine_corner = open(
428 os.path.join(webdir, "js", "combine_corner.js")
429 )
430 combine_corner = combine_corner.readlines()
431 params = [str(i) for i in params]
432 ind = [
433 linenumber for linenumber, line in enumerate(combine_corner)
434 if "var list = {}" in line
435 ][0]
436 combine_corner.insert(
437 ind + 1, " list['{}'] = {};\n".format(label, params)
438 )
439 new_file = open(
440 os.path.join(webdir, "js", "combine_corner.js"), "w"
441 )
442 new_file.writelines(combine_corner)
443 new_file.close()
444 combine_corner = open(
445 os.path.join(webdir, "js", "combine_corner.js")
446 )
447 combine_corner = combine_corner.readlines()
448 params = [str(i) for i in params]
449 ind = [
450 linenumber for linenumber, line in enumerate(combine_corner)
451 if "var data = {}" in line
452 ][0]
453 combine_corner.insert(
454 ind + 1, " data['{}'] = {};\n".format(label, data)
455 )
456 new_file = open(
457 os.path.join(webdir, "js", "combine_corner.js"), "w"
458 )
459 new_file.writelines(combine_corner)
460 new_file.close()
462 def _mcmc_iterator(self, label, function):
463 """If the data is a set of mcmc chains, return a 2d list of samples
464 to plot. Otherwise return a list of posterior samples
465 """
466 if self.mcmc_samples:
467 function += "_mcmc"
468 return self.same_parameters, self.samples[label], getattr(
469 self, function
470 )
471 return self.samples[label].keys(), self.samples[label], getattr(
472 self, function
473 )
475 def oned_histogram_plot(self, label):
476 """Generate oned histogram plots for all parameters in the result file
478 Parameters
479 ----------
480 label: str
481 the label for the results file that you wish to plot
482 """
483 error_message = (
484 "Failed to generate oned_histogram plot for %s because {}"
485 )
487 iterator, samples, function = self._mcmc_iterator(
488 label, "_oned_histogram_plot"
489 )
491 prior = lambda param: self.check_prior_samples_in_dict(
492 label, param
493 ) if self.include_prior else None
494 key_data = lambda label, param: self.check_key_data_in_dict(label, param)
496 arguments = [
497 (
498 [
499 self.savedir, label, param, samples[param],
500 latex_labels[param], self.injection_data[label][param],
501 self.kde_plot, prior(param), self.weights[label],
502 self.package, self.preliminary_pages[label],
503 self.checkpoint, key_data(label, param)
504 ], function, error_message % (param)
505 ) for param in iterator
506 ]
507 self.pool.starmap(self._try_to_make_a_plot, arguments)
509 def oned_histogram_comparison_plot(self, label):
510 """Generate oned comparison histogram plots for all parameters that are
511 common to all result files
513 Parameters
514 ----------
515 label: str
516 the label for the results file that you wish to plot
517 """
518 error_message = (
519 "Failed to generate a comparison histogram plot for %s because {}"
520 )
521 for param in self.same_parameters:
522 injection = [
523 value[param] for value in self.injection_data.values()
524 ]
525 if self.weights is not None:
526 weights = [
527 self.weights.get(key, None) for key in
528 self.same_samples[param].keys()
529 ]
530 else:
531 weights = None
533 arguments = [
534 self.savedir, param, self.same_samples[param],
535 latex_labels[param], self.colors, injection, weights,
536 self.kde_plot, self.linestyles, self.package,
537 self.preliminary_comparison_pages, self.checkpoint, None
538 ]
539 self._try_to_make_a_plot(
540 arguments, self._oned_histogram_comparison_plot,
541 error_message % (param)
542 )
543 continue
545 @staticmethod
546 def _oned_histogram_comparison_plot(
547 savedir, parameter, samples, latex_label, colors, injection, weights,
548 kde=False, linestyles=None, package="core", preliminary=False,
549 checkpoint=False, filename=None,
550 ):
551 """Generate a oned comparison histogram plot for a given parameter
553 Parameters
554 ----------i
555 savedir: str
556 the directory you wish to save the plot in
557 parameter: str
558 the name of the parameter that you wish to make a oned comparison
559 histogram for
560 samples: dict
561 dictionary of pesummary.utils.array.Array objects containing the
562 samples that correspond to parameter for each result file. The key
563 should be the corresponding label
564 latex_label: str
565 the latex label for parameter
566 colors: list
567 list of colors to be used to distinguish different result files
568 injection: list
569 list of injected values, one for each analysis
570 kde: Bool, optional
571 if True, kde plots will be generated rather than 1d histograms
572 linestyles: list, optional
573 list of linestyles used to distinguish different result files
574 preliminary: Bool, optional
575 if True, add a preliminary watermark to the plot
576 """
577 import math
578 module = importlib.import_module(
579 "pesummary.{}.plots.plot".format(package)
580 )
581 if filename is None:
582 filename = os.path.join(
583 savedir, "combined_1d_posterior_{}.png".format(parameter)
584 )
585 if os.path.isfile(filename) and checkpoint:
586 return
587 hist = not kde
588 for num, inj in enumerate(injection):
589 if math.isnan(inj):
590 injection[num] = None
591 same_samples = [val for key, val in samples.items()]
592 fig = module._1d_comparison_histogram_plot(
593 parameter, same_samples, colors, latex_label,
594 list(samples.keys()), inj_value=injection, kde=kde,
595 linestyles=linestyles, hist=hist, weights=weights
596 )
597 _PlotGeneration.save(
598 fig, filename, preliminary=preliminary
599 )
601 @staticmethod
602 def _oned_histogram_plot(
603 savedir, label, parameter, samples, latex_label, injection, kde=False,
604 prior=None, weights=None, package="core", preliminary=False,
605 checkpoint=False, key_data=None
606 ):
607 """Generate a oned histogram plot for a given set of samples
609 Parameters
610 ----------
611 savedir: str
612 the directory you wish to save the plot in
613 label: str
614 the label corresponding to the results file
615 parameter: str
616 the name of the parameter that you wish to plot
617 samples: PESummary.utils.array.Array
618 array containing the samples corresponding to parameter
619 latex_label: str
620 the latex label corresponding to parameter
621 injection: float
622 the injected value
623 kde: Bool, optional
624 if True, kde plots will be generated rather than 1d histograms
625 prior: PESummary.utils.array.Array, optional
626 the prior samples for param
627 weights: PESummary.utils.utilsrray, optional
628 the weights for each samples. If None, assumed to be 1
629 preliminary: Bool, optional
630 if True, add a preliminary watermark to the plot
631 """
632 import math
633 module = importlib.import_module(
634 "pesummary.{}.plots.plot".format(package)
635 )
637 if math.isnan(injection):
638 injection = None
639 hist = not kde
641 filename = os.path.join(
642 savedir, "{}_1d_posterior_{}.png".format(label, parameter)
643 )
644 if os.path.isfile(filename) and checkpoint:
645 return
646 fig = module._1d_histogram_plot(
647 parameter, samples, latex_label, inj_value=injection, kde=kde,
648 hist=hist, prior=prior, weights=weights, key_data=key_data
649 )
650 _PlotGeneration.save(
651 fig, filename, preliminary=preliminary
652 )
654 @staticmethod
655 def _oned_histogram_plot_mcmc(
656 savedir, label, parameter, samples, latex_label, injection, kde=False,
657 prior=None, weights=None, package="core", preliminary=False,
658 checkpoint=False, key_data=None
659 ):
660 """Generate a oned histogram plot for a given set of samples for
661 multiple mcmc chains
663 Parameters
664 ----------
665 savedir: str
666 the directory you wish to save the plot in
667 label: str
668 the label corresponding to the results file
669 parameter: str
670 the name of the parameter that you wish to plot
671 samples: dict
672 dictionary of PESummary.utils.array.Array objects containing the
673 samples corresponding to parameter for multiple mcmc chains
674 latex_label: str
675 the latex label corresponding to parameter
676 injection: float
677 the injected value
678 kde: Bool, optional
679 if True, kde plots will be generated rather than 1d histograms
680 prior: PESummary.utils.array.Array, optional
681 the prior samples for param
682 weights: PESummary.utils.array.Array, optional
683 the weights for each samples. If None, assumed to be 1
684 preliminary: Bool, optional
685 if True, add a preliminary watermark to the plot
686 """
687 import math
688 from pesummary.utils.array import Array
690 module = importlib.import_module(
691 "pesummary.{}.plots.plot".format(package)
692 )
694 if math.isnan(injection):
695 injection = None
696 same_samples = [val for key, val in samples.items()]
697 filename = os.path.join(
698 savedir, "{}_1d_posterior_{}.png".format(label, parameter)
699 )
700 if os.path.isfile(filename) and checkpoint:
701 pass
702 else:
703 fig = module._1d_histogram_plot_mcmc(
704 parameter, same_samples, latex_label, inj_value=injection,
705 kde=kde, prior=prior, weights=weights
706 )
707 _PlotGeneration.save(
708 fig, filename, preliminary=preliminary
709 )
710 filename = os.path.join(
711 savedir, "{}_1d_posterior_{}_combined.png".format(label, parameter)
712 )
713 if os.path.isfile(filename) and checkpoint:
714 pass
715 else:
716 fig = module._1d_histogram_plot(
717 parameter, Array(np.concatenate(same_samples)), latex_label,
718 inj_value=injection, kde=kde, prior=prior, weights=weights
719 )
720 _PlotGeneration.save(
721 fig, filename, preliminary=preliminary
722 )
724 def sample_evolution_plot(self, label):
725 """Generate sample evolution plots for all parameters in the result file
727 Parameters
728 ----------
729 label: str
730 the label for the results file that you wish to plot
731 """
732 error_message = (
733 "Failed to generate a sample evolution plot for %s because {}"
734 )
735 iterator, samples, function = self._mcmc_iterator(
736 label, "_sample_evolution_plot"
737 )
738 arguments = [
739 (
740 [
741 self.savedir, label, param, samples[param],
742 latex_labels[param], self.injection_data[label][param],
743 self.preliminary_pages[label], self.checkpoint
744 ], function, error_message % (param)
745 ) for param in iterator
746 ]
747 self.pool.starmap(self._try_to_make_a_plot, arguments)
749 def expert_plot(self, label):
750 """Generate expert plots for diagnostics
752 Parameters
753 ----------
754 label: str
755 the label for the results file that you wish to plot
756 """
757 error_message = (
758 "Failed to generate log_likelihood-%s 2d contour plot because {}"
759 )
760 iterator, samples, function = self._mcmc_iterator(
761 label, "_2d_contour_plot"
762 )
763 _debug = self.samples[label].debug_keys()
764 arguments = [
765 (
766 [
767 self.savedir, label, param, "log_likelihood", samples[param],
768 samples["log_likelihood"], latex_labels[param],
769 latex_labels["log_likelihood"],
770 self.preliminary_pages[label], [
771 samples[param][np.argmax(samples["log_likelihood"])],
772 np.max(samples["log_likelihood"]),
773 ], self.checkpoint
774 ], function, error_message % (param)
775 ) for param in iterator + _debug
776 ]
777 self.pool.starmap(self._try_to_make_a_plot, arguments)
778 _reweight_keys = [
779 param for param in self.samples[label].debug_keys() if
780 "_non_reweighted" in param
781 ]
782 if len(_reweight_keys):
783 error_message = (
784 "Failed to generate %s-%s 2d contour plot because {}"
785 )
786 _base_param = lambda p: p.split("_non_reweighted")[0][1:]
787 arguments = [
788 (
789 [
790 self.savedir, label, _base_param(param), param,
791 samples[_base_param(param)], samples[param],
792 latex_labels[_base_param(param)], latex_labels[param],
793 self.preliminary_pages[label], None, self.checkpoint
794 ], function, error_message % (_base_param(param), param)
795 ) for param in _reweight_keys
796 ]
797 self.pool.starmap(self._try_to_make_a_plot, arguments)
798 error_message = (
799 "Failed to generate a histogram plot comparing %s and %s "
800 "because {}"
801 )
802 arguments = [
803 (
804 [
805 self.savedir, _base_param(param), {
806 "reweighted": samples[_base_param(param)],
807 "non-reweighted": samples[param]
808 }, latex_labels[_base_param(param)], ['b', 'r'],
809 [np.nan, np.nan], True, None, self.package,
810 self.preliminary_comparison_pages, self.checkpoint,
811 os.path.join(
812 self.savedir, "{}_1d_posterior_{}_{}.png".format(
813 label, _base_param(param), param
814 )
815 )
816 ], self._oned_histogram_comparison_plot,
817 error_message % (_base_param(param), param)
818 ) for param in _reweight_keys
819 ]
820 self.pool.starmap(self._try_to_make_a_plot, arguments)
822 error_message = (
823 "Failed to generate log_likelihood-%s sample_evolution plot "
824 "because {}"
825 )
826 iterator, samples, function = self._mcmc_iterator(
827 label, "_colored_sample_evolution_plot"
828 )
829 arguments = [
830 (
831 [
832 self.savedir, label, param, "log_likelihood", samples[param],
833 samples["log_likelihood"], latex_labels[param],
834 latex_labels["log_likelihood"],
835 self.preliminary_pages[label], self.checkpoint
836 ], function, error_message % (param)
837 ) for param in iterator
838 ]
839 self.pool.starmap(self._try_to_make_a_plot, arguments)
840 error_message = (
841 "Failed to generate bootstrapped oned_histogram plot for %s "
842 "because {}"
843 )
844 iterator, samples, function = self._mcmc_iterator(
845 label, "_oned_histogram_bootstrap_plot"
846 )
847 arguments = [
848 (
849 [
850 self.savedir, label, param, samples[param],
851 latex_labels[param], self.preliminary_pages[label],
852 self.package, self.checkpoint
853 ], function, error_message % (param)
854 ) for param in iterator
855 ]
856 self.pool.starmap(self._try_to_make_a_plot, arguments)
858 @staticmethod
859 def _oned_histogram_bootstrap_plot(
860 savedir, label, parameter, samples, latex_label, preliminary=False,
861 package="core", checkpoint=False, nsamples=1000, ntests=100, **kwargs
862 ):
863 """Generate a bootstrapped oned histogram plot for a given set of
864 samples
866 Parameters
867 ----------
868 savedir: str
869 the directory you wish to save the plot in
870 label: str
871 the label corresponding to the results file
872 parameter: str
873 the name of the parameter that you wish to plot
874 samples: PESummary.utils.array.Array
875 array containing the samples corresponding to parameter
876 latex_label: str
877 the latex label corresponding to parameter
878 preliminary: Bool, optional
879 if True, add a preliminary watermark to the plot
880 """
881 module = importlib.import_module(
882 "pesummary.{}.plots.plot".format(package)
883 )
885 filename = os.path.join(
886 savedir, "{}_1d_posterior_{}_bootstrap.png".format(label, parameter)
887 )
888 if os.path.isfile(filename) and checkpoint:
889 return
890 fig = module._1d_histogram_plot_bootstrap(
891 parameter, samples, latex_label, nsamples=nsamples, ntests=ntests,
892 **kwargs
893 )
894 _PlotGeneration.save(
895 fig, filename, preliminary=preliminary
896 )
898 @staticmethod
899 def _2d_contour_plot(
900 savedir, label, parameter_x, parameter_y, samples_x, samples_y,
901 latex_label_x, latex_label_y, preliminary=False, truth=None,
902 checkpoint=False
903 ):
904 """Generate a 2d contour plot for a given set of samples
906 Parameters
907 ----------
908 savedir: str
909 the directory you wish to save the plot in
910 label: str
911 the label corresponding to the results file
912 samples_x: PESummary.utils.array.Array
913 array containing the samples for the x axis
914 samples_y: PESummary.utils.array.Array
915 array containing the samples for the y axis
916 latex_label_x: str
917 the latex label for the x axis
918 latex_label_y: str
919 the latex label for the y axis
920 preliminary: Bool, optional
921 if True, add a preliminary watermark to the plot
922 """
923 from pesummary.core.plots.publication import twod_contour_plot
925 filename = os.path.join(
926 savedir, "{}_2d_contour_{}_{}.png".format(
927 label, parameter_x, parameter_y
928 )
929 )
930 if os.path.isfile(filename) and checkpoint:
931 return
932 fig = twod_contour_plot(
933 samples_x, samples_y, levels=[0.9, 0.5], xlabel=latex_label_x,
934 ylabel=latex_label_y, bins=50, truth=truth
935 )
936 _PlotGeneration.save(
937 fig, filename, preliminary=preliminary
938 )
940 @staticmethod
941 def _colored_sample_evolution_plot(
942 savedir, label, parameter_x, parameter_y, samples_x, samples_y,
943 latex_label_x, latex_label_y, preliminary=False, checkpoint=False
944 ):
945 """Generate a 2d contour plot for a given set of samples
947 Parameters
948 ----------
949 savedir: str
950 the directory you wish to save the plot in
951 label: str
952 the label corresponding to the results file
953 samples_x: PESummary.utils.array.Array
954 array containing the samples for the x axis
955 samples_y: PESummary.utils.array.Array
956 array containing the samples for the y axis
957 latex_label_x: str
958 the latex label for the x axis
959 latex_label_y: str
960 the latex label for the y axis
961 preliminary: Bool, optional
962 if True, add a preliminary watermark to the plot
963 """
964 filename = os.path.join(
965 savedir, "{}_sample_evolution_{}_{}_colored.png".format(
966 label, parameter_x, parameter_y
967 )
968 )
969 if os.path.isfile(filename) and checkpoint:
970 return
971 fig = core._sample_evolution_plot(
972 parameter_x, samples_x, latex_label_x, z=samples_y,
973 z_label=latex_label_y
974 )
975 _PlotGeneration.save(
976 fig, filename, preliminary=preliminary
977 )
979 @staticmethod
980 def _sample_evolution_plot(
981 savedir, label, parameter, samples, latex_label, injection,
982 preliminary=False, checkpoint=False
983 ):
984 """Generate a sample evolution plot for a given set of samples
986 Parameters
987 ----------
988 savedir: str
989 the directory you wish to save the plot in
990 label: str
991 the label corresponding to the results file
992 parameter: str
993 the name of the parameter that you wish to plot
994 samples: PESummary.utils.array.Array
995 array containing the samples corresponding to parameter
996 latex_label: str
997 the latex label corresponding to parameter
998 injection: float
999 the injected value
1000 preliminary: Bool, optional
1001 if True, add a preliminary watermark to the plot
1002 """
1003 filename = os.path.join(
1004 savedir, "{}_sample_evolution_{}.png".format(label, parameter)
1005 )
1006 if os.path.isfile(filename) and checkpoint:
1007 return
1008 fig = core._sample_evolution_plot(
1009 parameter, samples, latex_label, injection
1010 )
1011 _PlotGeneration.save(
1012 fig, filename, preliminary=preliminary
1013 )
@staticmethod
def _sample_evolution_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection,
    preliminary=False, checkpoint=False
):
    """Make and save a sample evolution plot for a set of mcmc chains

    The figure is written to
    ``<savedir>/<label>_sample_evolution_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved
    label: str
        label identifying the results file
    parameter: str
        name of the parameter to plot
    samples: dict
        dictionary of pesummary.utils.array.Array objects, one per chain,
        holding the samples for ``parameter``. Only the values are used, in
        the dictionary's iteration order
    latex_label: str
        latex label for ``parameter``
    injection: float
        the injected value
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, skip regenerating it
    """
    output = os.path.join(
        savedir, "{}_sample_evolution_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    chains = list(samples.values())
    figure = core._sample_evolution_plot_mcmc(
        parameter, chains, latex_label, injection
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def autocorrelation_plot(self, label):
    """Generate autocorrelation plots for all parameters in the result file

    ``self._mcmc_iterator`` supplies the parameters to loop over, the
    samples and the plotting function to use. One job per parameter is then
    dispatched through ``self.pool.starmap`` via
    ``self._try_to_make_a_plot``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = (
        "Failed to generate an autocorrelation plot for %s because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_autocorrelation_plot"
    )
    jobs = []
    for param in iterator:
        plot_args = [
            self.savedir, label, param, samples[param],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, function, error_message % param))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _autocorrelation_plot(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Make and save an autocorrelation plot for a single parameter

    The figure is written to
    ``<savedir>/<label>_autocorrelation_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved
    label: str
        label identifying the results file
    parameter: str
        name of the parameter to plot
    samples: PESummary.utils.array.Array
        samples for ``parameter``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, skip regenerating it
    """
    target = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if not (checkpoint and os.path.isfile(target)):
        fig = core._autocorrelation_plot(parameter, samples)
        _PlotGeneration.save(fig, target, preliminary=preliminary)
@staticmethod
def _autocorrelation_plot_mcmc(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Make and save an autocorrelation plot for a set of mcmc chains

    The figure is written to
    ``<savedir>/<label>_autocorrelation_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is saved
    label: str
        label identifying the results file
    parameter: str
        name of the parameter to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects, one per mcmc
        chain, holding the samples for ``parameter``. Only the values are
        used, in the dictionary's iteration order
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, skip regenerating it
    """
    target = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if not (checkpoint and os.path.isfile(target)):
        chains = list(samples.values())
        fig = core._autocorrelation_plot_mcmc(parameter, chains)
        _PlotGeneration.save(fig, target, preliminary=preliminary)
def oned_cdf_plot(self, label):
    """Generate oned CDF plots for all parameters in the result file

    ``self._mcmc_iterator`` supplies the parameters to loop over, the
    samples and the plotting function to use. One job per parameter is then
    dispatched through ``self.pool.starmap`` via
    ``self._try_to_make_a_plot``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = "Failed to generate a CDF plot for %s because {}"
    iterator, samples, function = self._mcmc_iterator(
        label, "_oned_cdf_plot"
    )

    def _job(param):
        # (positional arguments for `function`, function, error message)
        plot_args = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.preliminary_pages[label],
            self.checkpoint
        ]
        return plot_args, function, error_message % param

    jobs = [_job(param) for param in iterator]
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _oned_cdf_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples

    The figure is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, do not regenerate it
    """
    # Join as separate path components (not `savedir + name`) so the plot
    # lands inside savedir even when savedir has no trailing separator,
    # matching how every other plot in this module builds its filename.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    fig = core._1d_cdf_plot(parameter, samples, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
@staticmethod
def _oned_cdf_plot_mcmc(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples, one for each
    mcmc chain

    The figure is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for each mcmc chain. Only the
        values are used, in the dictionary's iteration order
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, do not regenerate it
    """
    # Join as separate path components (not `savedir + name`) so the plot
    # lands inside savedir even when savedir has no trailing separator,
    # matching how every other plot in this module builds its filename.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    same_samples = list(samples.values())
    fig = core._1d_cdf_plot_mcmc(parameter, same_samples, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def interactive_ridgeline_plot(self, label):
    """Generate an interactive ridgeline plot for every parameter that is
    common to all result files

    Each plot is attempted through ``self._try_to_make_a_plot``, which
    receives a per-parameter error message.

    Parameters
    ----------
    label: str
        the label for the results file. It is not used here: a plot is made
        for every entry in ``self.same_parameters``
    """
    error_message = (
        "Failed to generate an interactive ridgeline plot for %s because {}"
    )
    for param in self.same_parameters:
        self._try_to_make_a_plot(
            [
                self.savedir, param, self.same_samples[param],
                latex_labels[param], self.colors, self.checkpoint
            ],
            self._interactive_ridgeline_plot,
            error_message % param,
        )
@staticmethod
def _interactive_ridgeline_plot(
    savedir, parameter, samples, latex_label, colors, checkpoint=False
):
    """Generate an interactive ridgeline plot for a single parameter across
    result files and write it to an html file

    The output is written to
    ``<savedir>/interactive_ridgeline_<parameter>.html``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of samples for ``parameter``, one entry per result file.
        The values and the keys are both passed to ``interactive.ridgeline``,
        presumably as the data and its labels
    latex_label: str
        the latex label for parameter, passed as the x axis label
    colors: list
        list of colors to be used to distinguish different result files
    checkpoint: Bool, optional
        if True and the html file already exists, do not regenerate it
    """
    html_file = os.path.join(
        savedir, "interactive_ridgeline_{}.html".format(parameter)
    )
    if checkpoint and os.path.isfile(html_file):
        return
    keys = list(samples.keys())
    interactive.ridgeline(
        list(samples.values()), keys, xlabel=latex_label, colors=colors,
        write_to_html_file=html_file
    )
def interactive_corner_plot(self, label):
    """Generate an interactive corner plot for a given result file

    This delegates to ``_interactive_corner_plot``. It uses the samples
    stored for ``label`` and the module-level ``latex_labels`` mapping.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    label_samples = self.samples[label]
    self._interactive_corner_plot(
        self.savedir, label, label_samples, latex_labels, self.checkpoint
    )
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Generate an interactive corner plot for a given set of samples

    The output is written to ``<savedir>/corner/<label>_interactive.html``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.array.Array objects that
        contain samples for each parameter
    latex_labels: dict
        mapping from parameter name to latex label. It must contain every
        key of ``samples``
    checkpoint: Bool, optional
        if True and the html file already exists, do not regenerate it
    """
    html_file = os.path.join(
        savedir, "corner", "{}_interactive.html".format(label)
    )
    if checkpoint and os.path.isfile(html_file):
        return
    names = list(samples.keys())
    data = [samples[name] for name in names]
    axis_labels = [latex_labels[name] for name in names]
    interactive.corner(data, axis_labels, write_to_html_file=html_file)
def oned_cdf_comparison_plot(self, label):
    """Generate oned comparison CDF plots for all parameters that are
    common to all result files

    When ``self.weights`` is set, a weight entry is looked up for each
    result file (``None`` if missing), in the order of that parameter's
    samples.

    Parameters
    ----------
    label: str
        the label for the results file. It is not used here: a plot is made
        for every entry in ``self.same_parameters``
    """
    error_message = (
        "Failed to generate a comparison CDF plot for %s because {}"
    )
    for param in self.same_parameters:
        param_samples = self.same_samples[param]
        weights = (
            None if self.weights is None
            else [self.weights.get(key, None) for key in param_samples.keys()]
        )
        plot_args = [
            self.savedir, param, param_samples, latex_labels[param],
            self.colors, self.linestyles, weights,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._oned_cdf_comparison_plot, error_message % param
        )
@staticmethod
def _oned_cdf_comparison_plot(
    savedir, parameter, samples, latex_label, colors, linestyles=None,
    weights=None, preliminary=False, checkpoint=False
):
    """Generate a oned comparison CDF plot for a given parameter

    The figure is written to ``<savedir>/combined_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a comparison CDF
        for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    linestyles: list, optional
        list of linestyles used to distinguish different result files
    weights: list, optional
        passed straight through to ``core._1d_cdf_comparison_plot``;
        presumably one entry per result file in the order of ``samples``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, do not regenerate it
    """
    output = os.path.join(savedir, "combined_cdf_{}.png".format(parameter))
    if checkpoint and os.path.isfile(output):
        return
    labels = list(samples.keys())
    fig = core._1d_cdf_comparison_plot(
        parameter, [samples[key] for key in labels], colors, latex_label,
        labels, linestyles, weights=weights
    )
    _PlotGeneration.save(fig, output, preliminary=preliminary)
def box_plot_comparison_plot(self, label):
    """Generate comparison box plots for all parameters that are
    common to all result files

    Parameters
    ----------
    label: str
        the label for the results file. It is not used here: a plot is made
        for every entry in ``self.same_parameters``
    """
    error_message = (
        "Failed to generate a comparison box plot for %s because {}"
    )
    for param in self.same_parameters:
        self._try_to_make_a_plot(
            [
                self.savedir, param, self.same_samples[param],
                latex_labels[param], self.colors,
                self.preliminary_comparison_pages, self.checkpoint
            ],
            self._box_plot_comparison_plot,
            error_message % param,
        )
@staticmethod
def _box_plot_comparison_plot(
    savedir, parameter, samples, latex_label, colors, preliminary=False,
    checkpoint=False
):
    """Generate a comparison box plot for a given parameter

    The figure is written to ``<savedir>/combined_boxplot_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a comparison box
        plot for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists on disk, do not regenerate it
    """
    output = os.path.join(
        savedir, "combined_boxplot_{}.png".format(parameter)
    )
    if not (checkpoint and os.path.isfile(output)):
        fig = core._comparison_box_plot(
            parameter, list(samples.values()), colors, latex_label,
            list(samples.keys())
        )
        _PlotGeneration.save(fig, output, preliminary=preliminary)
def custom_plot(self, label):
    """Generate custom plots according to the passed python file

    The module named by ``self.custom_plotting[1]`` is imported. First, if
    ``self.custom_plotting[0]`` is not an empty string, that directory is
    added to ``sys.path``. Every function named in the module's
    ``__single_plots__`` attribute is then called with
    ``(list of parameter names, samples)`` for ``label``. Each returned
    figure is saved in ``self.savedir`` as
    ``<label>_custom_plotting_<index>``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    import sys

    directory = self.custom_plotting[0]
    # Repeated calls (e.g. one per label) would otherwise append the same
    # directory to sys.path every time, so only add it once.
    if directory != "" and directory not in sys.path:
        sys.path.append(directory)
    mod = importlib.import_module(self.custom_plotting[1])

    methods = [getattr(mod, name) for name in mod.__single_plots__]
    for num, method in enumerate(methods):
        fig = method(
            list(self.samples[label].keys()), self.samples[label]
        )
        _PlotGeneration.save(
            fig, os.path.join(
                self.savedir, "{}_custom_plotting_{}".format(label, num)
            ), preliminary=self.preliminary_pages[label]
        )