Coverage for pesummary/core/plots/main.py: 74.6%
445 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import numpy as np
6import os
7import importlib
8from multiprocessing import Pool
10from pesummary.core.plots.latex_labels import latex_labels
11from pesummary.utils.utils import (
12 logger, get_matplotlib_backend, make_dir, get_matplotlib_style_file
13)
14from pesummary.core.plots import plot as core
15from pesummary.core.plots import interactive
17import matplotlib
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Select the matplotlib backend at import time, before any plotting in this
# module happens. ``parallel=True`` presumably requests a backend that is
# safe inside the multiprocessing Pool used below — confirm in
# pesummary.utils.utils.get_matplotlib_backend.
matplotlib.use(get_matplotlib_backend(parallel=True))
# Apply the project's matplotlib style sheet to every figure made here.
matplotlib.style.use(get_matplotlib_style_file())
24class _PlotGeneration(object):
25 """Super class to handle the plot generation for a given set of result
26 files
28 Parameters
29 ----------
30 savedir: str
31 the directory to store the plots
32 webdir: str
33 the web directory of the run
34 labels: list
35 list of labels used to distinguish the result files
36 samples: dict
37 dictionary of posterior samples stored in the result files
38 kde_plot: Bool
39 if True, kde plots are generated instead of histograms, Default False
40 existing_labels: list
41 list of labels stored in an existing metafile
42 existing_injection_data: dict
43 dictionary of injection data stored in an existing metafile
44 existing_samples: dict
45 dictionary of posterior samples stored in an existing metafile
46 same_parameters: list
47 list of paramerers that are common in all result files
48 injection_data: dict
49 dictionary of injection data for each result file
50 result_files: list
51 list of result files passed
52 colors: list
53 colors that you wish to use to distinguish different result files
54 disable_comparison: bool, optional
55 whether to make comparison plots, default is True.
56 if disable_comparison is False and len(labels) == 1, no comparsion plots
57 will be generated
58 disable_interactive: bool, optional
59 whether to make interactive plots, default is False
60 disable_corner: bool, optional
61 whether to make the corner plot, default is False
62 """
def __init__(
    self, savedir=None, webdir=None, labels=None, samples=None,
    kde_plot=False, existing_labels=None, existing_injection_data=None,
    existing_samples=None, existing_weights=None, same_parameters=None,
    injection_data=None, colors=None, custom_plotting=None,
    add_to_existing=False, priors=None, include_prior=False, weights=None,
    disable_comparison=False, linestyles=None, disable_interactive=False,
    multi_process=1, mcmc_samples=False, disable_corner=False,
    corner_params=None, expert_plots=True, checkpoint=False, key_data=None
):
    """Store the plotting configuration and prepare for plot generation.

    Creates ``webdir`` and ``webdir/plots/corner``, starts a process pool
    of ``multi_process`` workers, registers a fallback latex label for
    every parameter and builds ``plot_type_dictionary``, which maps each
    enabled plot type to the method that makes it. Parameters are
    described in the class docstring. ``priors`` defaults to a new, empty
    dictionary for every instance.
    """
    self.package = "core"
    self.webdir = webdir
    make_dir(self.webdir)
    make_dir(os.path.join(self.webdir, "plots", "corner"))
    # assigned after webdir: the savedir setter falls back to
    # ``<webdir>/plots/`` when savedir is None
    self.savedir = savedir
    self.labels = labels
    self.mcmc_samples = mcmc_samples
    self.samples = samples
    self.kde_plot = kde_plot
    self.existing_labels = existing_labels
    self.existing_injection_data = existing_injection_data
    self.existing_samples = existing_samples
    self.existing_weights = existing_weights
    self.same_parameters = same_parameters
    self.injection_data = injection_data
    self.colors = colors
    self.custom_plotting = custom_plotting
    self.add_to_existing = add_to_existing
    # a new dict per instance; the previous ``priors={}`` default was a
    # single dict object shared by every instance built without priors
    self.priors = priors if priors is not None else {}
    self.include_prior = include_prior
    self.linestyles = linestyles
    self.make_interactive = not disable_interactive
    self.make_corner = not disable_corner
    self.corner_params = corner_params
    self.expert_plots = expert_plots
    if self.mcmc_samples and self.expert_plots:
        logger.warning("Unable to generate expert plots for mcmc samples")
        self.expert_plots = False
    self.checkpoint = checkpoint
    self.key_data = key_data
    self.multi_process = multi_process
    self.pool = self.setup_pool()
    self.preliminary_pages = {label: False for label in self.labels}
    self.preliminary_comparison_pages = False
    # comparison plots need more than one analysis in total, counting the
    # labels stored in an existing metafile
    self.make_comparison = (
        not disable_comparison and self._total_number_of_labels > 1
    )
    self.weights = (
        weights if weights is not None else {i: None for i in self.labels}
    )

    # {param: {label: samples[label][param]}}; only built for non-mcmc runs
    if self.same_parameters is not None and not self.mcmc_samples:
        self.same_samples = {
            param: {
                key: item[param] for key, item in self.samples.items()
            } for param in self.same_parameters
        }
    else:
        self.same_samples = None

    for i in self.samples.keys():
        # ``keys(remove_debug=False)`` presumably includes debug parameters;
        # fall back to plain ``keys()`` when the container rejects the
        # keyword with a TypeError
        try:
            self.check_latex_labels(
                self.samples[i].keys(remove_debug=False)
            )
        except TypeError:
            try:
                self.check_latex_labels(self.samples[i].keys())
            except TypeError:
                pass

    self.plot_type_dictionary = {
        "oned_histogram": self.oned_histogram_plot,
        "sample_evolution": self.sample_evolution_plot,
        "autocorrelation": self.autocorrelation_plot,
        "oned_cdf": self.oned_cdf_plot,
        "custom": self.custom_plot
    }
    if self.make_corner:
        self.plot_type_dictionary.update({"corner": self.corner_plot})
    if self.expert_plots:
        self.plot_type_dictionary.update({"expert": self.expert_plot})
    if self.make_comparison:
        self.plot_type_dictionary.update(dict(
            oned_histogram_comparison=self.oned_histogram_comparison_plot,
            oned_cdf_comparison=self.oned_cdf_comparison_plot,
            box_plot_comparison=self.box_plot_comparison_plot,
        ))
    if self.make_interactive:
        self.plot_type_dictionary.update(
            dict(
                interactive_corner=self.interactive_corner_plot
            )
        )
        if self.make_comparison:
            self.plot_type_dictionary.update(
                dict(
                    interactive_ridgeline=self.interactive_ridgeline_plot
                )
            )
164 @staticmethod
165 def save(fig, name, preliminary=False, close=True, format="png"):
166 """Save a figure to disk.
168 Parameters
169 ----------
170 fig: matplotlib.pyplot.figure
171 Matplotlib figure that you wish to save
172 name: str
173 Name of the file that you wish to write it too
174 close: Bool, optional
175 Close the figure after it has been saved
176 format: str, optional
177 Format used to save the image
178 """
179 n = len(format)
180 if ".%s" % (format) != name[-n - 1:]:
181 name += ".%s" % (format)
182 if preliminary:
183 fig.text(
184 0.5, 0.5, 'Preliminary', fontsize=90, color='gray', alpha=0.1,
185 ha='center', va='center', rotation=30
186 )
187 fig.tight_layout()
188 fig.savefig(name, format=format)
189 if close:
190 fig.close()
192 @property
193 def _total_number_of_labels(self):
194 _number_of_labels = 0
195 for item in [self.labels, self.existing_labels]:
196 if isinstance(item, list):
197 _number_of_labels += len(item)
198 return _number_of_labels
@staticmethod
def check_latex_labels(parameters):
    """Make sure every parameter has an entry in the module-level
    ``latex_labels`` dictionary.

    A missing parameter gets a fallback label: its name with underscores
    replaced by spaces. ``latex_labels`` is updated in place and nothing
    is returned.

    Parameters
    ----------
    parameters: iterable
        names of the parameters to check
    """
    for param in parameters:
        # direct dict membership is O(1); the previous
        # ``list(latex_labels.keys())`` rebuilt a list of every known label
        # for each parameter checked
        if param not in latex_labels:
            latex_labels[param] = param.replace("_", " ")
@property
def savedir(self):
    """str: directory the plots are written to"""
    return self._savedir

@savedir.setter
def savedir(self, value):
    # keep the requested directory, falling back to ``<webdir>/plots/``
    # when none was given
    self._savedir = value
    if value is None:
        self._savedir = self.webdir + "/plots/"
def setup_pool(self):
    """Create and return the ``multiprocessing`` pool used to run plotting
    jobs in parallel, sized by ``self.multi_process``."""
    return Pool(processes=self.multi_process)
def generate_plots(self):
    """Produce every enabled plot.

    First the per-label plots (and, when enabled, the interactive plots)
    for each label. Then existing metafile data is merged in when
    ``add_to_existing`` is set, and finally the comparison plots are made
    when ``make_comparison`` is set.
    """
    for label in self.labels:
        logger.debug("Starting to generate plots for {}".format(label))
        self._generate_plots(label)
        if not self.make_interactive:
            continue
        logger.debug(
            "Starting to generate interactive plots for {}".format(label)
        )
        self._generate_interactive_plots(label)
    if self.add_to_existing:
        self.add_existing_data()
    if not self.make_comparison:
        return
    logger.debug("Starting to generate comparison plots")
    self._generate_comparison_plots()
def check_key_data_in_dict(self, label, param):
    """Return ``self.key_data[label][param]``, or None when there is no
    key data at all, none for ``label``, or none for ``param``.

    Parameters
    ----------
    label: str
        the label used to distinguish a given run
    param: str
        name of the parameter you wish to return key data for
    """
    key_data = self.key_data
    if key_data is not None and label in key_data.keys():
        per_label = key_data[label]
        if param in per_label.keys():
            return per_label[param]
    return None
def check_prior_samples_in_dict(self, label, param):
    """Return ``self.priors["samples"][label][param]`` when it exists.

    None is returned when ``priors`` has no "samples" entry, when
    ``label`` is missing, when the entry for ``label`` is an empty list,
    or when ``param`` is missing.

    Parameters
    ----------
    label: str
        the label used to distinguish a given run
    param: str
        name of the parameter you wish to return prior samples for
    """
    if "samples" not in self.priors.keys():
        return None
    prior_samples = self.priors["samples"]
    if label not in prior_samples.keys():
        return None
    label_samples = prior_samples[label]
    # an empty list (which has no ``.keys()``) stands for "no prior samples"
    if label_samples != [] and param in label_samples.keys():
        return label_samples[param]
    return None
def add_existing_data(self):
    """Fold the data from an existing metafile into this object.

    Delegates to ``pesummary.utils.utils._add_existing_data``.
    NOTE(review): the helper's return value is not used (the original
    only rebound the local name ``self``), so this relies on the helper
    updating the instance in place — confirm.
    """
    from pesummary.utils.utils import _add_existing_data

    _add_existing_data(self)
290 def _generate_plots(self, label):
291 """Generate all plots for a a given result file
292 """
293 if self.make_corner:
294 self.try_to_make_a_plot("corner", label=label)
295 self.try_to_make_a_plot("oned_histogram", label=label)
296 self.try_to_make_a_plot("sample_evolution", label=label)
297 self.try_to_make_a_plot("autocorrelation", label=label)
298 self.try_to_make_a_plot("oned_cdf", label=label)
299 if self.expert_plots:
300 self.try_to_make_a_plot("expert", label=label)
301 if self.custom_plotting:
302 self.try_to_make_a_plot("custom", label=label)
304 def _generate_interactive_plots(self, label):
305 """Generate all interactive plots and save them to an html file ready
306 to be imported later
307 """
308 self.try_to_make_a_plot("interactive_corner", label=label)
309 if self.make_comparison:
310 self.try_to_make_a_plot("interactive_ridgeline")
312 def _generate_comparison_plots(self):
313 """Generate all comparison plots
314 """
315 self.try_to_make_a_plot("oned_histogram_comparison")
316 self.try_to_make_a_plot("oned_cdf_comparison")
317 self.try_to_make_a_plot("box_plot_comparison")
def try_to_make_a_plot(self, plot_type, label=None):
    """Look up the method registered for ``plot_type`` and run it through
    ``_try_to_make_a_plot`` with ``label`` as its only argument.

    Parameters
    ----------
    plot_type: str
        key into ``plot_type_dictionary`` naming the plot to make
    label: str, optional
        the label of the results file that you wish to plot
    """
    plotter = self.plot_type_dictionary[plot_type]
    failure_template = "Failed to generate %s plot because {}" % (plot_type)
    self._try_to_make_a_plot([label], plotter, failure_template)
@staticmethod
def _try_to_make_a_plot(arguments, function, message):
    """Call ``function(*arguments)``, logging rather than raising on error.

    If the call raises a RuntimeError, it is retried once with
    matplotlib's ``text.usetex`` rcParam switched off. The previous value
    of that rcParam is always restored afterwards, whether or not the
    retry succeeds. Any other exception, or a failed retry, is logged at
    INFO level as ``message.format(error)``. The current pyplot figure is
    closed in every case.

    Parameters
    ----------
    arguments: list
        positional arguments passed to ``function``
    function: func
        the plotting function to execute
    message: str
        error message template with one ``{}`` placeholder for the error
    """
    try:
        function(*arguments)
    except RuntimeError:
        # the retry turns LaTeX text rendering off, so the RuntimeError is
        # presumably a LaTeX failure
        try:
            from matplotlib import rcParams

            original = rcParams['text.usetex']
            rcParams['text.usetex'] = False
            try:
                function(*arguments)
            finally:
                # previously this was only reset after a successful retry,
                # so a failed retry left usetex disabled for every later
                # plot made in this process
                rcParams['text.usetex'] = original
        except Exception as e:
            logger.info(message.format(e))
    except Exception as e:
        logger.info(message.format(e))
    finally:
        from matplotlib import pyplot

        pyplot.close()
def corner_plot(self, label):
    """Make the corner plot for the result file ``label``.

    For mcmc runs the ``combine`` attribute of the samples is plotted
    (presumably all chains merged into one set — confirm).

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    label_samples = self.samples[label]
    if self.mcmc_samples:
        label_samples = label_samples.combine
    self._corner_plot(
        self.savedir, label, label_samples, latex_labels, self.webdir,
        self.corner_params, self.preliminary_pages[label], self.checkpoint
    )
384 @staticmethod
385 def _corner_plot(
386 savedir, label, samples, latex_labels, webdir, params, preliminary=False,
387 checkpoint=False
388 ):
389 """Generate a corner plot for a given set of samples
391 Parameters
392 ----------
393 savedir: str
394 the directory you wish to save the plot in
395 label: str
396 the label corresponding to the results file
397 samples: dict
398 dictionary containing PESummary.utils.array.Array objects that
399 contain samples for each parameter
400 latex_labels: str
401 latex labels for each parameter in samples
402 webdir: str
403 the directory where the `js` directory is located
404 preliminary: Bool, optional
405 if True, add a preliminary watermark to the plot
406 """
407 import warnings
409 with warnings.catch_warnings():
410 warnings.simplefilter("ignore")
411 filename = os.path.join(
412 savedir, "corner", "{}_all_density_plots.png".format(label)
413 )
414 if os.path.isfile(filename) and checkpoint:
415 return
416 fig, params, data = core._make_corner_plot(
417 samples, latex_labels, corner_parameters=params
418 )
419 _PlotGeneration.save(
420 fig, filename, preliminary=preliminary
421 )
422 combine_corner = open(
423 os.path.join(webdir, "js", "combine_corner.js")
424 )
425 combine_corner = combine_corner.readlines()
426 params = [str(i) for i in params]
427 ind = [
428 linenumber for linenumber, line in enumerate(combine_corner)
429 if "var list = {}" in line
430 ][0]
431 combine_corner.insert(
432 ind + 1, " list['{}'] = {};\n".format(label, params)
433 )
434 new_file = open(
435 os.path.join(webdir, "js", "combine_corner.js"), "w"
436 )
437 new_file.writelines(combine_corner)
438 new_file.close()
439 combine_corner = open(
440 os.path.join(webdir, "js", "combine_corner.js")
441 )
442 combine_corner = combine_corner.readlines()
443 params = [str(i) for i in params]
444 ind = [
445 linenumber for linenumber, line in enumerate(combine_corner)
446 if "var data = {}" in line
447 ][0]
448 combine_corner.insert(
449 ind + 1, " data['{}'] = {};\n".format(label, data)
450 )
451 new_file = open(
452 os.path.join(webdir, "js", "combine_corner.js"), "w"
453 )
454 new_file.writelines(combine_corner)
455 new_file.close()
457 def _mcmc_iterator(self, label, function):
458 """If the data is a set of mcmc chains, return a 2d list of samples
459 to plot. Otherwise return a list of posterior samples
460 """
461 if self.mcmc_samples:
462 function += "_mcmc"
463 return self.same_parameters, self.samples[label], getattr(
464 self, function
465 )
466 return self.samples[label].keys(), self.samples[label], getattr(
467 self, function
468 )
def oned_histogram_plot(self, label):
    """Queue a 1d histogram of every parameter of ``label`` on the
    process pool.

    Each job runs ``_oned_histogram_plot`` (``_oned_histogram_plot_mcmc``
    for mcmc samples) through ``_try_to_make_a_plot``. Prior samples are
    only looked up when ``include_prior`` is set.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    params, samples, plotter = self._mcmc_iterator(
        label, "_oned_histogram_plot"
    )
    template = "Failed to generate oned_histogram plot for %s because {}"

    def _prior_for(param):
        if self.include_prior:
            return self.check_prior_samples_in_dict(label, param)
        return None

    jobs = []
    for param in params:
        plot_args = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.injection_data[label][param],
            self.kde_plot, _prior_for(param), self.weights[label],
            self.package, self.preliminary_pages[label],
            self.checkpoint, self.check_key_data_in_dict(label, param)
        ]
        jobs.append((plot_args, plotter, template % (param)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
def oned_histogram_comparison_plot(self, label):
    """Make a comparison histogram for every parameter shared by all
    result files.

    Parameters
    ----------
    label: str
        accepted for interface compatibility with the other plotting
        methods; it is not used
    """
    template = (
        "Failed to generate a comparison histogram plot for %s because {}"
    )
    for param in self.same_parameters:
        # NOTE(review): injections follow the ordering of injection_data,
        # while samples and weights follow same_samples; the two are
        # assumed to list the analyses in the same order — confirm
        injection = [
            value[param] for value in self.injection_data.values()
        ]
        param_samples = self.same_samples[param]
        weights = None
        if self.weights is not None:
            weights = [
                self.weights.get(key, None) for key in param_samples.keys()
            ]
        plot_args = [
            self.savedir, param, param_samples,
            latex_labels[param], self.colors, injection, weights,
            self.kde_plot, self.linestyles, self.package,
            self.preliminary_comparison_pages, self.checkpoint, None
        ]
        self._try_to_make_a_plot(
            plot_args, self._oned_histogram_comparison_plot,
            template % (param)
        )
@staticmethod
def _oned_histogram_comparison_plot(
    savedir, parameter, samples, latex_label, colors, injection, weights,
    kde=False, linestyles=None, package="core", preliminary=False,
    checkpoint=False, filename=None,
):
    """Generate a 1d comparison histogram plot for a given parameter

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to make a oned comparison
        histogram for
    samples: dict
        dictionary of pesummary.utils.array.Array objects containing the
        samples that correspond to parameter for each result file. The key
        should be the corresponding label
    latex_label: str
        the latex label for parameter
    colors: list
        list of colors to be used to distinguish different result files
    injection: list
        list of injected values, one for each analysis. NaN or None means
        "no injection". The list is not modified
    weights: list
        weights for each analysis, passed to the plotting function
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    linestyles: list, optional
        list of linestyles used to distinguish different result files
    package: str, optional
        plotting functions are taken from ``pesummary.<package>.plots.plot``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    filename: str, optional
        output path. Default ``<savedir>/combined_1d_posterior_<parameter>.png``
    """
    import math
    module = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )
    if filename is None:
        filename = os.path.join(
            savedir, "combined_1d_posterior_{}.png".format(parameter)
        )
    if os.path.isfile(filename) and checkpoint:
        return
    hist = not kde
    # Build a new list instead of overwriting the caller's entries.
    # _try_to_make_a_plot re-runs this function with the same argument list
    # after a RuntimeError; the old in-place NaN -> None replacement made
    # math.isnan(None) raise TypeError on that retry.
    injection = [
        None if inj is None or math.isnan(inj) else inj for inj in injection
    ]
    same_samples = list(samples.values())
    fig = module._1d_comparison_histogram_plot(
        parameter, same_samples, colors, latex_label,
        list(samples.keys()), inj_value=injection, kde=kde,
        linestyles=linestyles, hist=hist, weights=weights
    )
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
@staticmethod
def _oned_histogram_plot(
    savedir, label, parameter, samples, latex_label, injection, kde=False,
    prior=None, weights=None, package="core", preliminary=False,
    checkpoint=False, key_data=None
):
    """Make a 1d histogram (or kde) plot for one parameter of one analysis.

    The figure is written to ``<savedir>/<label>_1d_posterior_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value; NaN means there is no injection
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    prior: PESummary.utils.array.Array, optional
        the prior samples for param
    weights: PESummary.utils.array.Array, optional
        the weights for each samples. If None, assumed to be 1
    package: str, optional
        plotting functions are taken from ``pesummary.<package>.plots.plot``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    key_data: dict, optional
        forwarded to the plotting function as ``key_data``
    """
    import math

    plotting = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )
    injection = None if math.isnan(injection) else injection
    output = os.path.join(
        savedir, "{}_1d_posterior_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = plotting._1d_histogram_plot(
        parameter, samples, latex_label, inj_value=injection, kde=kde,
        hist=not kde, prior=prior, weights=weights, key_data=key_data
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _oned_histogram_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection, kde=False,
    prior=None, weights=None, package="core", preliminary=False,
    checkpoint=False, key_data=None
):
    """Make 1d histogram plots for one parameter sampled by several mcmc
    chains.

    Two figures are written: one showing each chain, at
    ``<savedir>/<label>_1d_posterior_<parameter>.png``, and one for all
    chains concatenated, at
    ``<savedir>/<label>_1d_posterior_<parameter>_combined.png``. With
    ``checkpoint`` set, a figure whose file already exists is skipped.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for multiple mcmc chains
    latex_label: str
        the latex label corresponding to parameter
    injection: float
        the injected value; NaN means there is no injection
    kde: Bool, optional
        if True, kde plots will be generated rather than 1d histograms
    prior: PESummary.utils.array.Array, optional
        the prior samples for param
    weights: PESummary.utils.array.Array, optional
        the weights for each samples. If None, assumed to be 1
    package: str, optional
        plotting functions are taken from ``pesummary.<package>.plots.plot``
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, skip any plot whose file already exists
    key_data: dict, optional
        accepted for signature compatibility with ``_oned_histogram_plot``;
        it is not used
    """
    import math
    from pesummary.utils.array import Array

    plotting = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )
    if math.isnan(injection):
        injection = None
    chains = list(samples.values())

    per_chain_file = os.path.join(
        savedir, "{}_1d_posterior_{}.png".format(label, parameter)
    )
    if not (os.path.isfile(per_chain_file) and checkpoint):
        figure = plotting._1d_histogram_plot_mcmc(
            parameter, chains, latex_label, inj_value=injection,
            kde=kde, prior=prior, weights=weights
        )
        _PlotGeneration.save(figure, per_chain_file, preliminary=preliminary)

    combined_file = os.path.join(
        savedir, "{}_1d_posterior_{}_combined.png".format(label, parameter)
    )
    if not (os.path.isfile(combined_file) and checkpoint):
        figure = plotting._1d_histogram_plot(
            parameter, Array(np.concatenate(chains)), latex_label,
            inj_value=injection, kde=kde, prior=prior, weights=weights
        )
        _PlotGeneration.save(figure, combined_file, preliminary=preliminary)
def sample_evolution_plot(self, label):
    """Queue a sample evolution plot of every parameter of ``label`` on
    the process pool.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    params, samples, plotter = self._mcmc_iterator(
        label, "_sample_evolution_plot"
    )
    template = "Failed to generate a sample evolution plot for %s because {}"
    jobs = []
    for param in params:
        plot_args = [
            self.savedir, label, param, samples[param],
            latex_labels[param], self.injection_data[label][param],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, plotter, template % (param)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
def expert_plot(self, label):
    """Queue the expert diagnostic plots for ``label`` on the process pool.

    Four groups of plots are made:

    1. a 2d contour plot of every parameter, debug parameters included,
       against ``log_likelihood``, marking the maximum-likelihood sample;
    2. for each debug parameter containing ``_non_reweighted``, a 2d
       contour plot and a comparison histogram against the corresponding
       reweighted parameter;
    3. sample evolution plots of every parameter coloured by
       ``log_likelihood``;
    4. bootstrapped 1d histograms of every parameter.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    error_message = (
        "Failed to generate log_likelihood-%s 2d contour plot because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_2d_contour_plot"
    )
    _debug = self.samples[label].debug_keys()
    arguments = [
        (
            [
                self.savedir, label, param, "log_likelihood", samples[param],
                samples["log_likelihood"], latex_labels[param],
                latex_labels["log_likelihood"],
                self.preliminary_pages[label], [
                    # truth marker: the maximum-likelihood sample
                    samples[param][np.argmax(samples["log_likelihood"])],
                    np.max(samples["log_likelihood"]),
                ], self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator + _debug
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
    _reweight_keys = [
        param for param in _debug if "_non_reweighted" in param
    ]
    if len(_reweight_keys):
        error_message = (
            "Failed to generate %s-%s 2d contour plot because {}"
        )

        def _base_param(p):
            # drop "_non_reweighted..." and the first character (presumably
            # the leading underscore of debug keys) to get the reweighted
            # parameter name
            return p.split("_non_reweighted")[0][1:]

        arguments = [
            (
                [
                    self.savedir, label, _base_param(param), param,
                    samples[_base_param(param)], samples[param],
                    latex_labels[_base_param(param)], latex_labels[param],
                    self.preliminary_pages[label], None, self.checkpoint
                ], function, error_message % (_base_param(param), param)
            ) for param in _reweight_keys
        ]
        self.pool.starmap(self._try_to_make_a_plot, arguments)
        error_message = (
            "Failed to generate a histogram plot comparing %s and %s "
            "because {}"
        )
        arguments = [
            (
                [
                    self.savedir, _base_param(param), {
                        "reweighted": samples[_base_param(param)],
                        "non-reweighted": samples[param]
                    }, latex_labels[_base_param(param)], ['b', 'r'],
                    [np.nan, np.nan], True, None, self.package,
                    # this is a per-label plot (the filename contains the
                    # label), so use the label's preliminary flag like every
                    # other expert plot; preliminary_comparison_pages was
                    # used here before
                    self.preliminary_pages[label], self.checkpoint,
                    os.path.join(
                        self.savedir, "{}_1d_posterior_{}_{}.png".format(
                            label, _base_param(param), param
                        )
                    )
                ], self._oned_histogram_comparison_plot,
                error_message % (_base_param(param), param)
            ) for param in _reweight_keys
        ]
        self.pool.starmap(self._try_to_make_a_plot, arguments)

    error_message = (
        "Failed to generate log_likelihood-%s sample_evolution plot "
        "because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_colored_sample_evolution_plot"
    )
    arguments = [
        (
            [
                self.savedir, label, param, "log_likelihood", samples[param],
                samples["log_likelihood"], latex_labels[param],
                latex_labels["log_likelihood"],
                self.preliminary_pages[label], self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
    error_message = (
        "Failed to generate bootstrapped oned_histogram plot for %s "
        "because {}"
    )
    iterator, samples, function = self._mcmc_iterator(
        label, "_oned_histogram_bootstrap_plot"
    )
    arguments = [
        (
            [
                self.savedir, label, param, samples[param],
                latex_labels[param], self.preliminary_pages[label],
                self.package, self.checkpoint
            ], function, error_message % (param)
        ) for param in iterator
    ]
    self.pool.starmap(self._try_to_make_a_plot, arguments)
@staticmethod
def _oned_histogram_bootstrap_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    package="core", checkpoint=False, nsamples=1000, ntests=100, **kwargs
):
    """Make a bootstrapped 1d histogram plot of ``samples``.

    The figure comes from ``_1d_histogram_plot_bootstrap`` in
    ``pesummary.<package>.plots.plot`` and is written to
    ``<savedir>/<label>_1d_posterior_<parameter>_bootstrap.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    package: str, optional
        plotting functions are taken from ``pesummary.<package>.plots.plot``
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    nsamples: int, optional
        forwarded to the plotting function as ``nsamples``
    ntests: int, optional
        forwarded to the plotting function as ``ntests``
    **kwargs: dict, optional
        any other keyword arguments are forwarded to the plotting function
    """
    plotting = importlib.import_module(
        "pesummary.{}.plots.plot".format(package)
    )
    output = os.path.join(
        savedir, "{}_1d_posterior_{}_bootstrap.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = plotting._1d_histogram_plot_bootstrap(
        parameter, samples, latex_label, nsamples=nsamples, ntests=ntests,
        **kwargs
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
@staticmethod
def _2d_contour_plot(
    savedir, label, parameter_x, parameter_y, samples_x, samples_y,
    latex_label_x, latex_label_y, preliminary=False, truth=None,
    checkpoint=False
):
    """Make a 2d contour plot of ``samples_y`` against ``samples_x``.

    Uses ``pesummary.core.plots.publication.twod_contour_plot`` with
    ``levels=[0.9, 0.5]`` and 50 bins. The figure is written to
    ``<savedir>/<label>_2d_contour_<parameter_x>_<parameter_y>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter_x: str
        name of the parameter on the x axis (used in the filename)
    parameter_y: str
        name of the parameter on the y axis (used in the filename)
    samples_x: PESummary.utils.array.Array
        array containing the samples for the x axis
    samples_y: PESummary.utils.array.Array
        array containing the samples for the y axis
    latex_label_x: str
        the latex label for the x axis
    latex_label_y: str
        the latex label for the y axis
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    truth: list, optional
        forwarded to ``twod_contour_plot`` as ``truth``
    checkpoint: Bool, optional
        if True and the plot already exists, do nothing
    """
    from pesummary.core.plots.publication import twod_contour_plot

    output = os.path.join(
        savedir, "{}_2d_contour_{}_{}.png".format(
            label, parameter_x, parameter_y
        )
    )
    if checkpoint and os.path.isfile(output):
        return
    figure = twod_contour_plot(
        samples_x, samples_y, levels=[0.9, 0.5], xlabel=latex_label_x,
        ylabel=latex_label_y, bins=50, truth=truth
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
935 @staticmethod
936 def _colored_sample_evolution_plot(
937 savedir, label, parameter_x, parameter_y, samples_x, samples_y,
938 latex_label_x, latex_label_y, preliminary=False, checkpoint=False
939 ):
940 """Generate a 2d contour plot for a given set of samples
942 Parameters
943 ----------
944 savedir: str
945 the directory you wish to save the plot in
946 label: str
947 the label corresponding to the results file
948 samples_x: PESummary.utils.array.Array
949 array containing the samples for the x axis
950 samples_y: PESummary.utils.array.Array
951 array containing the samples for the y axis
952 latex_label_x: str
953 the latex label for the x axis
954 latex_label_y: str
955 the latex label for the y axis
956 preliminary: Bool, optional
957 if True, add a preliminary watermark to the plot
958 """
959 filename = os.path.join(
960 savedir, "{}_sample_evolution_{}_{}_colored.png".format(
961 label, parameter_x, parameter_y
962 )
963 )
964 if os.path.isfile(filename) and checkpoint:
965 return
966 fig = core._sample_evolution_plot(
967 parameter_x, samples_x, latex_label_x, z=samples_y,
968 z_label=latex_label_y
969 )
970 _PlotGeneration.save(
971 fig, filename, preliminary=preliminary
972 )
974 @staticmethod
975 def _sample_evolution_plot(
976 savedir, label, parameter, samples, latex_label, injection,
977 preliminary=False, checkpoint=False
978 ):
979 """Generate a sample evolution plot for a given set of samples
981 Parameters
982 ----------
983 savedir: str
984 the directory you wish to save the plot in
985 label: str
986 the label corresponding to the results file
987 parameter: str
988 the name of the parameter that you wish to plot
989 samples: PESummary.utils.array.Array
990 array containing the samples corresponding to parameter
991 latex_label: str
992 the latex label corresponding to parameter
993 injection: float
994 the injected value
995 preliminary: Bool, optional
996 if True, add a preliminary watermark to the plot
997 """
998 filename = os.path.join(
999 savedir, "{}_sample_evolution_{}.png".format(label, parameter)
1000 )
1001 if os.path.isfile(filename) and checkpoint:
1002 return
1003 fig = core._sample_evolution_plot(
1004 parameter, samples, latex_label, injection
1005 )
1006 _PlotGeneration.save(
1007 fig, filename, preliminary=preliminary
1008 )
@staticmethod
def _sample_evolution_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection,
    preliminary=False, checkpoint=False
):
    """Build and save a sample evolution plot showing every mcmc chain

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label of the result file the chains belong to
    parameter: str
        name of the parameter being plotted
    samples: dict
        mapping of chain -> pesummary.utils.array.Array holding the samples
        of ``parameter`` for that chain
    latex_label: str
        latex label of ``parameter``
    injection: float
        the injected value of ``parameter``
    preliminary: Bool, optional
        if True, a preliminary watermark is added to the saved plot
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    output = os.path.join(
        savedir, "{}_sample_evolution_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(output):
        return
    # one entry per chain, in the dictionary's iteration order
    chains = list(samples.values())
    figure = core._sample_evolution_plot_mcmc(
        parameter, chains, latex_label, injection
    )
    _PlotGeneration.save(figure, output, preliminary=preliminary)
def autocorrelation_plot(self, label):
    """Make an autocorrelation plot for every parameter of one result file

    The parameters, samples and per-parameter plotting function are
    obtained from ``self._mcmc_iterator``. One job per parameter is then
    dispatched through ``self.pool.starmap`` to ``self._try_to_make_a_plot``,
    together with an error-message template naming that parameter.

    Parameters
    ----------
    label: str
        label of the result file to plot
    """
    template = (
        "Failed to generate an autocorrelation plot for %s because {}"
    )
    parameters, samples, plotter = self._mcmc_iterator(
        label, "_autocorrelation_plot"
    )
    jobs = []
    for name in parameters:
        plot_args = [
            self.savedir, label, name, samples[name],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, plotter, template % (name)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _autocorrelation_plot(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Produce and save the autocorrelation plot of one parameter

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label of the result file the samples belong to
    parameter: str
        name of the parameter being plotted
    samples: pesummary.utils.array.Array
        samples of ``parameter``
    preliminary: Bool, optional
        if True, a preliminary watermark is added to the saved plot
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    target = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if not (checkpoint and os.path.isfile(target)):
        _PlotGeneration.save(
            core._autocorrelation_plot(parameter, samples), target,
            preliminary=preliminary
        )
@staticmethod
def _autocorrelation_plot_mcmc(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Produce and save an autocorrelation plot with one curve per mcmc chain

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label of the result file the chains belong to
    parameter: str
        name of the parameter being plotted
    samples: dict
        dictionary of pesummary.utils.array.Array objects, one per mcmc
        chain, holding the samples of ``parameter``
    preliminary: Bool, optional
        if True, a preliminary watermark is added to the saved plot
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    target = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(target):
        return
    per_chain = list(samples.values())
    figure = core._autocorrelation_plot_mcmc(parameter, per_chain)
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def oned_cdf_plot(self, label):
    """Create one-dimensional CDF plots for each parameter of a result file

    The parameters, samples and per-parameter plotting function are
    obtained from ``self._mcmc_iterator``. One job per parameter (including
    its latex label from ``latex_labels``) is dispatched through
    ``self.pool.starmap`` to ``self._try_to_make_a_plot``.

    Parameters
    ----------
    label: str
        label of the result file to plot
    """
    template = "Failed to generate a CDF plot for %s because {}"
    parameters, samples, plotter = self._mcmc_iterator(
        label, "_oned_cdf_plot"
    )
    jobs = []
    for name in parameters:
        plot_args = [
            self.savedir, label, name, samples[name], latex_labels[name],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, plotter, template % (name)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _oned_cdf_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples

    The plot is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do not regenerate the plot when the output file already
        exists
    """
    # savedir and the file name are passed as separate components so that a
    # separator is inserted when savedir has no trailing slash. The previous
    # ``savedir + name`` form produced ``<savedir><name>``, a file beside
    # (not inside) savedir.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    fig = core._1d_cdf_plot(parameter, samples, latex_label)
    _PlotGeneration.save(fig, filename, preliminary=preliminary)
@staticmethod
def _oned_cdf_plot_mcmc(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples, one for each
    mcmc chain

    The plot is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for each mcmc chain
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True, do not regenerate the plot when the output file already
        exists
    """
    # Join rather than concatenate: ``savedir + name`` drops the separator
    # when savedir has no trailing slash and writes beside savedir.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    # one array per chain, in the dictionary's iteration order
    same_samples = list(samples.values())
    fig = core._1d_cdf_plot_mcmc(parameter, same_samples, latex_label)
    _PlotGeneration.save(fig, filename, preliminary=preliminary)
def interactive_ridgeline_plot(self, label):
    """Generate an interactive ridgeline plot for each parameter listed in
    ``self.same_parameters`` (the parameters common to all result files)

    Each plot is attempted through ``self._try_to_make_a_plot`` with an
    error-message template naming the parameter.

    Parameters
    ----------
    label: str
        label of a result file; not used by this method
    """
    template = (
        "Failed to generate an interactive ridgeline plot for %s because {}"
    )
    for name in self.same_parameters:
        self._try_to_make_a_plot(
            [
                self.savedir, name, self.same_samples[name],
                latex_labels[name], self.colors, self.checkpoint
            ],
            self._interactive_ridgeline_plot,
            template % (name)
        )
@staticmethod
def _interactive_ridgeline_plot(
    savedir, parameter, samples, latex_label, colors, checkpoint=False
):
    """Generate an interactive ridgeline plot for a given parameter,
    comparing the samples stored for each result file

    The plot is written as html to
    ``<savedir>/interactive_ridgeline_<parameter>.html``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of samples for parameter, one entry per result file.
        The values and the list of keys are both passed to
        ``interactive.ridgeline`` (presumably the keys label each ridge)
    latex_label: str
        the latex label for parameter, passed as ``xlabel``
    colors: list
        list of colors used to distinguish the different result files
    checkpoint: Bool, optional
        if True, do not regenerate the plot when the output file already
        exists
    """
    filename = os.path.join(
        savedir, "interactive_ridgeline_{}.html".format(parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    same_samples = list(samples.values())
    interactive.ridgeline(
        same_samples, list(samples.keys()), xlabel=latex_label,
        colors=colors, write_to_html_file=filename
    )
def interactive_corner_plot(self, label):
    """Write the interactive corner plot for one result file

    The module-level ``latex_labels`` mapping and ``self.checkpoint`` are
    forwarded to ``self._interactive_corner_plot``.

    Parameters
    ----------
    label: str
        label of the result file whose samples are plotted
    """
    label_samples = self.samples[label]
    self._interactive_corner_plot(
        self.savedir, label, label_samples, latex_labels, self.checkpoint
    )
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Build an interactive html corner plot from every parameter in samples

    The plot is written to ``<savedir>/corner/<label>_interactive.html``.

    Parameters
    ----------
    savedir: str
        directory under which the ``corner`` sub-directory lives
    label: str
        label of the result file the samples belong to
    samples: dict
        mapping of parameter name -> pesummary.utils.array.Array of samples
    latex_labels: dict
        mapping of parameter name -> latex label; must contain every key of
        ``samples``
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    output = os.path.join(
        savedir, "corner", "{}_interactive.html".format(label)
    )
    if checkpoint and os.path.isfile(output):
        return
    names = samples.keys()
    columns = [samples[name] for name in names]
    axis_labels = [latex_labels[name] for name in names]
    interactive.corner(columns, axis_labels, write_to_html_file=output)
def oned_cdf_comparison_plot(self, label):
    """Make a comparison CDF plot for every parameter shared by all result
    files (``self.same_parameters``)

    When ``self.weights`` is set, a list of per-result-file weights is
    built in the same order as the sample dictionary (``None`` for labels
    without weights). Each plot is attempted through
    ``self._try_to_make_a_plot``.

    Parameters
    ----------
    label: str
        label of a result file; not used by this method
    """
    template = (
        "Failed to generate a comparison CDF plot for %s because {}"
    )
    for name in self.same_parameters:
        per_label = self.same_samples[name]
        if self.weights is None:
            weights = None
        else:
            weights = [self.weights.get(key) for key in per_label.keys()]
        plot_args = [
            self.savedir, name, per_label, latex_labels[name], self.colors,
            self.linestyles, weights, self.preliminary_comparison_pages,
            self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._oned_cdf_comparison_plot, template % (name)
        )
@staticmethod
def _oned_cdf_comparison_plot(
    savedir, parameter, samples, latex_label, colors, linestyles=None,
    weights=None, preliminary=False, checkpoint=False
):
    """Overlay the CDFs of one parameter from several result files

    The plot is written to ``<savedir>/combined_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    parameter: str
        name of the parameter being compared
    samples: dict
        mapping of result file label -> pesummary.utils.array.Array of
        samples for ``parameter``
    latex_label: str
        latex label of ``parameter``
    colors: list
        colors used to distinguish the result files
    linestyles: list, optional
        linestyles used to distinguish the result files
    weights: list, optional
        weights forwarded unchanged to ``core._1d_cdf_comparison_plot``;
        presumably one entry per result file, in the order of ``samples``
    preliminary: Bool, optional
        if True, a preliminary watermark is added to the saved plot
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    target = os.path.join(savedir, "combined_cdf_{}.png".format(parameter))
    if checkpoint and os.path.isfile(target):
        return
    labels = list(samples.keys())
    ordered = [samples[name] for name in labels]
    figure = core._1d_cdf_comparison_plot(
        parameter, ordered, colors, latex_label, labels, linestyles,
        weights=weights
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def box_plot_comparison_plot(self, label):
    """Make a comparison box plot for every parameter shared by all result
    files (``self.same_parameters``)

    Each plot is attempted through ``self._try_to_make_a_plot`` with an
    error-message template naming the parameter.

    Parameters
    ----------
    label: str
        label of a result file; not used by this method
    """
    template = (
        "Failed to generate a comparison box plot for %s because {}"
    )
    for name in self.same_parameters:
        plot_args = [
            self.savedir, name, self.same_samples[name], latex_labels[name],
            self.colors, self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._box_plot_comparison_plot, template % (name)
        )
@staticmethod
def _box_plot_comparison_plot(
    savedir, parameter, samples, latex_label, colors, preliminary=False,
    checkpoint=False
):
    """Draw side-by-side box plots of one parameter across result files

    The plot is written to ``<savedir>/combined_boxplot_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    parameter: str
        name of the parameter being compared
    samples: dict
        mapping of result file label -> pesummary.utils.array.Array of
        samples for ``parameter``
    latex_label: str
        latex label of ``parameter``
    colors: list
        colors used to distinguish the result files
    preliminary: Bool, optional
        if True, a preliminary watermark is added to the saved plot
    checkpoint: Bool, optional
        if True and the output file already exists, nothing is regenerated
    """
    target = os.path.join(
        savedir, "combined_boxplot_{}.png".format(parameter)
    )
    if checkpoint and os.path.isfile(target):
        return
    figure = core._comparison_box_plot(
        parameter, list(samples.values()), colors, latex_label,
        list(samples.keys())
    )
    _PlotGeneration.save(figure, target, preliminary=preliminary)
def custom_plot(self, label):
    """Generate custom plots according to the passed python file

    ``self.custom_plotting[0]`` is a directory to make importable (an empty
    string leaves ``sys.path`` untouched). ``self.custom_plotting[1]`` is
    the name of the module imported with ``importlib.import_module``.
    Every name listed in the module's ``__single_plots__`` is looked up on
    the module and called with ``(list of parameter names, samples)`` for
    this label. The returned figure is saved to
    ``<savedir>/<label>_custom_plotting_<index>`` via
    ``_PlotGeneration.save``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    import importlib
    import sys

    directory = self.custom_plotting[0]
    # This method takes a label, so it is presumably run once per result
    # file. Only add the directory when it is not already on sys.path so
    # that repeated calls do not accumulate duplicate entries.
    if directory != "" and directory not in sys.path:
        sys.path.append(directory)
    mod = importlib.import_module(self.custom_plotting[1])

    methods = [getattr(mod, name) for name in mod.__single_plots__]
    for num, method in enumerate(methods):
        fig = method(
            list(self.samples[label].keys()), self.samples[label]
        )
        # NOTE(review): no file extension is added here; the final file
        # name depends on how _PlotGeneration.save handles that -- confirm.
        _PlotGeneration.save(
            fig, os.path.join(
                self.savedir, "{}_custom_plotting_{}".format(label, num)
            ), preliminary=self.preliminary_pages[label]
        )