Coverage for pesummary/core/plots/main.py: 74.7%
439 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import numpy as np
6import os
7import importlib
8from multiprocessing import Pool
10from pesummary.core.plots.latex_labels import latex_labels
11from pesummary.utils.utils import (
12 logger, get_matplotlib_backend, make_dir, get_matplotlib_style_file
13)
14from pesummary.core.plots import plot as core
15from pesummary.core.plots import interactive
17import matplotlib
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Configure matplotlib once, at import time, for every plot made by this
# module. The backend is the one ``get_matplotlib_backend`` returns for
# parallel use (plots below are produced inside a multiprocessing Pool), and
# the package style file is then applied globally.
matplotlib.use(get_matplotlib_backend(parallel=True))
matplotlib.style.use(get_matplotlib_style_file())
24class _PlotGeneration(object):
25 """Super class to handle the plot generation for a given set of result
26 files
28 Parameters
29 ----------
30 savedir: str
31 the directory to store the plots
32 webdir: str
33 the web directory of the run
34 labels: list
35 list of labels used to distinguish the result files
36 samples: dict
37 dictionary of posterior samples stored in the result files
38 kde_plot: Bool
39 if True, kde plots are generated instead of histograms, Default False
40 existing_labels: list
41 list of labels stored in an existing metafile
42 existing_injection_data: dict
43 dictionary of injection data stored in an existing metafile
44 existing_samples: dict
45 dictionary of posterior samples stored in an existing metafile
46 same_parameters: list
47 list of paramerers that are common in all result files
48 injection_data: dict
49 dictionary of injection data for each result file
50 result_files: list
51 list of result files passed
52 colors: list
53 colors that you wish to use to distinguish different result files
54 disable_comparison: bool, optional
55 whether to make comparison plots, default is True.
56 if disable_comparison is False and len(labels) == 1, no comparsion plots
57 will be generated
58 disable_interactive: bool, optional
59 whether to make interactive plots, default is False
60 disable_corner: bool, optional
61 whether to make the corner plot, default is False
62 """
63 def __init__(
64 self, savedir=None, webdir=None, labels=None, samples=None,
65 kde_plot=False, existing_labels=None, existing_injection_data=None,
66 existing_samples=None, existing_weights=None, same_parameters=None,
67 injection_data=None, colors=None, custom_plotting=None,
68 add_to_existing=False, priors={}, include_prior=False, weights=None,
69 disable_comparison=False, linestyles=None, disable_interactive=False,
70 multi_process=1, mcmc_samples=False, disable_corner=False,
71 corner_params=None, expert_plots=True, checkpoint=False, key_data=None
72 ):
73 self.package = "core"
74 self.webdir = webdir
75 make_dir(self.webdir)
76 make_dir(os.path.join(self.webdir, "plots", "corner"))
77 self.savedir = savedir
78 self.labels = labels
79 self.mcmc_samples = mcmc_samples
80 self.samples = samples
81 self.kde_plot = kde_plot
82 self.existing_labels = existing_labels
83 self.existing_injection_data = existing_injection_data
84 self.existing_samples = existing_samples
85 self.existing_weights = existing_weights
86 self.same_parameters = same_parameters
87 self.injection_data = injection_data
88 self.colors = colors
89 self.custom_plotting = custom_plotting
90 self.add_to_existing = add_to_existing
91 self.priors = priors
92 self.include_prior = include_prior
93 self.linestyles = linestyles
94 self.make_interactive = not disable_interactive
95 self.make_corner = not disable_corner
96 self.corner_params = corner_params
97 self.expert_plots = expert_plots
98 if self.mcmc_samples and self.expert_plots:
99 logger.warning("Unable to generate expert plots for mcmc samples")
100 self.expert_plots = False
101 self.checkpoint = checkpoint
102 self.key_data = key_data
103 self.multi_process = multi_process
104 self.pool = self.setup_pool()
105 self.preliminary_pages = {label: False for label in self.labels}
106 self.preliminary_comparison_pages = False
107 self.make_comparison = (
108 not disable_comparison and self._total_number_of_labels > 1
109 )
110 self.weights = (
111 weights if weights is not None else {i: None for i in self.labels}
112 )
114 if self.same_parameters is not None and not self.mcmc_samples:
115 self.same_samples = {
116 param: {
117 key: item[param] for key, item in self.samples.items()
118 } for param in self.same_parameters
119 }
120 else:
121 self.same_samples = None
123 for i in self.samples.keys():
124 try:
125 self.check_latex_labels(
126 self.samples[i].keys(remove_debug=False)
127 )
128 except TypeError:
129 try:
130 self.check_latex_labels(self.samples[i].keys())
131 except TypeError:
132 pass
134 self.plot_type_dictionary = {
135 "oned_histogram": self.oned_histogram_plot,
136 "sample_evolution": self.sample_evolution_plot,
137 "autocorrelation": self.autocorrelation_plot,
138 "oned_cdf": self.oned_cdf_plot,
139 "custom": self.custom_plot
140 }
141 if self.make_corner:
142 self.plot_type_dictionary.update({"corner": self.corner_plot})
143 if self.expert_plots:
144 self.plot_type_dictionary.update({"expert": self.expert_plot})
145 if self.make_comparison:
146 self.plot_type_dictionary.update(dict(
147 oned_histogram_comparison=self.oned_histogram_comparison_plot,
148 oned_cdf_comparison=self.oned_cdf_comparison_plot,
149 box_plot_comparison=self.box_plot_comparison_plot,
150 ))
151 if self.make_interactive:
152 self.plot_type_dictionary.update(
153 dict(
154 interactive_corner=self.interactive_corner_plot
155 )
156 )
157 if self.make_comparison:
158 self.plot_type_dictionary.update(
159 dict(
160 interactive_ridgeline=self.interactive_ridgeline_plot
161 )
162 )
164 @staticmethod
165 def save(fig, name, preliminary=False, close=True, format="png"):
166 """Save a figure to disk.
168 Parameters
169 ----------
170 fig: matplotlib.pyplot.figure
171 Matplotlib figure that you wish to save
172 name: str
173 Name of the file that you wish to write it too
174 close: Bool, optional
175 Close the figure after it has been saved
176 format: str, optional
177 Format used to save the image
178 """
179 n = len(format)
180 if ".%s" % (format) != name[-n - 1:]:
181 name += ".%s" % (format)
182 if preliminary:
183 fig.text(
184 0.5, 0.5, 'Preliminary', fontsize=90, color='gray', alpha=0.1,
185 ha='center', va='center', rotation=30
186 )
187 fig.tight_layout()
188 fig.savefig(name, format=format)
189 if close:
190 fig.close()
192 @property
193 def _total_number_of_labels(self):
194 _number_of_labels = 0
195 for item in [self.labels, self.existing_labels]:
196 if isinstance(item, list):
197 _number_of_labels += len(item)
198 return _number_of_labels
200 @staticmethod
201 def check_latex_labels(parameters):
202 """Check to see if there is a latex label for all parameters. If not,
203 then create one
205 Parameters
206 ----------
207 parameters: list
208 list of parameters
209 """
210 for i in parameters:
211 if i not in list(latex_labels.keys()):
212 latex_labels[i] = i.replace("_", " ")
214 @property
215 def savedir(self):
216 return self._savedir
218 @savedir.setter
219 def savedir(self, savedir):
220 self._savedir = savedir
221 if savedir is None:
222 self._savedir = self.webdir + "/plots/"
224 def setup_pool(self):
225 """Setup a pool of processes to speed up plot generation
226 """
227 pool = Pool(processes=self.multi_process)
228 return pool
230 def generate_plots(self):
231 """Generate all plots for all result files
232 """
233 for i in self.labels:
234 logger.debug("Starting to generate plots for {}".format(i))
235 self._generate_plots(i)
236 if self.make_interactive:
237 logger.debug(
238 "Starting to generate interactive plots for {}".format(i)
239 )
240 self._generate_interactive_plots(i)
241 if self.add_to_existing:
242 self.add_existing_data()
243 if self.make_comparison:
244 logger.debug("Starting to generate comparison plots")
245 self._generate_comparison_plots()
247 def check_key_data_in_dict(self, label, param):
248 """Check to see if there is key data for a given param
250 Parameters
251 ----------
252 label: str
253 the label used to distinguish a given run
254 param: str
255 name of the parameter you wish to return prior samples for
256 """
257 if self.key_data is None:
258 return None
259 elif label not in self.key_data.keys():
260 return None
261 elif param not in self.key_data[label].keys():
262 return None
263 return self.key_data[label][param]
265 def check_prior_samples_in_dict(self, label, param):
266 """Check to see if there are prior samples for a given param
268 Parameters
269 ----------
270 label: str
271 the label used to distinguish a given run
272 param: str
273 name of the parameter you wish to return prior samples for
274 """
275 cond1 = "samples" in self.priors.keys()
276 if cond1 and label in self.priors["samples"].keys():
277 cond1 = self.priors["samples"][label] != []
278 if cond1 and param in self.priors["samples"][label].keys():
279 return self.priors["samples"][label][param]
280 return None
281 return None
283 def add_existing_data(self):
284 """
285 """
286 from pesummary.utils.utils import _add_existing_data
288 self = _add_existing_data(self)
290 def _generate_plots(self, label):
291 """Generate all plots for a a given result file
292 """
293 if self.make_corner:
294 self.try_to_make_a_plot("corner", label=label)
295 self.try_to_make_a_plot("oned_histogram", label=label)
296 self.try_to_make_a_plot("sample_evolution", label=label)
297 self.try_to_make_a_plot("autocorrelation", label=label)
298 self.try_to_make_a_plot("oned_cdf", label=label)
299 if self.expert_plots:
300 self.try_to_make_a_plot("expert", label=label)
301 if self.custom_plotting:
302 self.try_to_make_a_plot("custom", label=label)
304 def _generate_interactive_plots(self, label):
305 """Generate all interactive plots and save them to an html file ready
306 to be imported later
307 """
308 self.try_to_make_a_plot("interactive_corner", label=label)
309 if self.make_comparison:
310 self.try_to_make_a_plot("interactive_ridgeline")
312 def _generate_comparison_plots(self):
313 """Generate all comparison plots
314 """
315 self.try_to_make_a_plot("oned_histogram_comparison")
316 self.try_to_make_a_plot("oned_cdf_comparison")
317 self.try_to_make_a_plot("box_plot_comparison")
319 def try_to_make_a_plot(self, plot_type, label=None):
320 """Wrapper function to _try_to_make_a_plot
322 Parameters
323 ----------
324 plot_type: str
325 String to describe the plot that you wish to try and make
326 label: str
327 The label of the results file that you wish to plot
328 """
329 self._try_to_make_a_plot(
330 [label], self.plot_type_dictionary[plot_type],
331 "Failed to generate %s plot because {}" % (plot_type)
332 )
334 @staticmethod
335 def _try_to_make_a_plot(arguments, function, message):
336 """Try to make a plot. If it fails return an error message and continue
337 plotting
339 Parameters
340 ----------
341 arguments: list
342 list of arguments that you wish to pass to function
343 function: func
344 function that you wish to execute
345 message: str
346 the error message that you wish to be printed.
347 """
348 try:
349 function(*arguments)
350 except RuntimeError:
351 try:
352 from matplotlib import rcParams
354 original = rcParams['text.usetex']
355 rcParams['text.usetex'] = False
356 function(*arguments)
357 rcParams['text.usetex'] = original
358 except Exception as e:
359 logger.info(message.format(e))
360 except Exception as e:
361 logger.info(message.format(e))
362 finally:
363 from matplotlib import pyplot
365 pyplot.close()
367 def corner_plot(self, label):
368 """Generate a corner plot for a given result file
370 Parameters
371 ----------
372 label: str
373 the label for the results file that you wish to plot
374 """
375 if self.mcmc_samples:
376 samples = self.samples[label].combine
377 else:
378 samples = self.samples[label]
379 self._corner_plot(
380 self.savedir, label, samples, latex_labels, self.webdir,
381 self.corner_params, self.preliminary_pages[label], self.checkpoint
382 )
384 @staticmethod
385 def _corner_plot(
386 savedir, label, samples, latex_labels, webdir, params, preliminary=False,
387 checkpoint=False
388 ):
389 """Generate a corner plot for a given set of samples
391 Parameters
392 ----------
393 savedir: str
394 the directory you wish to save the plot in
395 label: str
396 the label corresponding to the results file
397 samples: dict
398 dictionary containing PESummary.utils.array.Array objects that
399 contain samples for each parameter
400 latex_labels: str
401 latex labels for each parameter in samples
402 webdir: str
403 the directory where the `js` directory is located
404 preliminary: Bool, optional
405 if True, add a preliminary watermark to the plot
406 """
407 import warnings
409 with warnings.catch_warnings():
410 warnings.simplefilter("ignore")
411 filename = os.path.join(
412 savedir, "corner", "{}_all_density_plots.png".format(label)
413 )
414 if os.path.isfile(filename) and checkpoint:
415 return
416 fig, params, data = core._make_corner_plot(
417 samples, latex_labels, corner_parameters=params
418 )
419 _PlotGeneration.save(
420 fig, filename, preliminary=preliminary
421 )
422 combine_corner = open(
423 os.path.join(webdir, "js", "combine_corner.js")
424 )
425 combine_corner = combine_corner.readlines()
426 params = [str(i) for i in params]
427 ind = [
428 linenumber for linenumber, line in enumerate(combine_corner)
429 if "var list = {}" in line
430 ][0]
431 combine_corner.insert(
432 ind + 1, " list['{}'] = {};\n".format(label, params)
433 )
434 new_file = open(
435 os.path.join(webdir, "js", "combine_corner.js"), "w"
436 )
437 new_file.writelines(combine_corner)
438 new_file.close()
439 combine_corner = open(
440 os.path.join(webdir, "js", "combine_corner.js")
441 )
442 combine_corner = combine_corner.readlines()
443 params = [str(i) for i in params]
444 ind = [
445 linenumber for linenumber, line in enumerate(combine_corner)
446 if "var data = {}" in line
447 ][0]
448 combine_corner.insert(
449 ind + 1, " data['{}'] = {};\n".format(label, data)
450 )
451 new_file = open(
452 os.path.join(webdir, "js", "combine_corner.js"), "w"
453 )
454 new_file.writelines(combine_corner)
455 new_file.close()
457 def _mcmc_iterator(self, label, function):
458 """If the data is a set of mcmc chains, return a 2d list of samples
459 to plot. Otherwise return a list of posterior samples
460 """
461 if self.mcmc_samples:
462 function += "_mcmc"
463 return self.same_parameters, self.samples[label], getattr(
464 self, function
465 )
466 return self.samples[label].keys(), self.samples[label], getattr(
467 self, function
468 )
470 def oned_histogram_plot(self, label):
471 """Generate oned histogram plots for all parameters in the result file
473 Parameters
474 ----------
475 label: str
476 the label for the results file that you wish to plot
477 """
478 error_message = (
479 "Failed to generate oned_histogram plot for %s because {}"
480 )
482 iterator, samples, function = self._mcmc_iterator(
483 label, "_oned_histogram_plot"
484 )
486 prior = lambda param: self.check_prior_samples_in_dict(
487 label, param
488 ) if self.include_prior else None
489 key_data = lambda label, param: self.check_key_data_in_dict(label, param)
491 arguments = [
492 (
493 [
494 self.savedir, label, param, samples[param],
495 latex_labels[param], self.injection_data[label][param],
496 self.kde_plot, prior(param), self.weights[label],
497 self.package, self.preliminary_pages[label],
498 self.checkpoint, key_data(label, param)
499 ], function, error_message % (param)
500 ) for param in iterator
501 ]
502 self.pool.starmap(self._try_to_make_a_plot, arguments)
504 def oned_histogram_comparison_plot(self, label):
505 """Generate oned comparison histogram plots for all parameters that are
506 common to all result files
508 Parameters
509 ----------
510 label: str
511 the label for the results file that you wish to plot
512 """
513 error_message = (
514 "Failed to generate a comparison histogram plot for %s because {}"
515 )
516 for param in self.same_parameters:
517 injection = [
518 value[param] for value in self.injection_data.values()
519 ]
520 arguments = [
521 self.savedir, param, self.same_samples[param],
522 latex_labels[param], self.colors, injection, self.kde_plot,
523 self.linestyles, self.package,
524 self.preliminary_comparison_pages, self.checkpoint, None
525 ]
526 self._try_to_make_a_plot(
527 arguments, self._oned_histogram_comparison_plot,
528 error_message % (param)
529 )
530 continue
532 @staticmethod
533 def _oned_histogram_comparison_plot(
534 savedir, parameter, samples, latex_label, colors, injection, kde=False,
535 linestyles=None, package="core", preliminary=False, checkpoint=False,
536 filename=None
537 ):
538 """Generate a oned comparison histogram plot for a given parameter
540 Parameters
541 ----------i
542 savedir: str
543 the directory you wish to save the plot in
544 parameter: str
545 the name of the parameter that you wish to make a oned comparison
546 histogram for
547 samples: dict
548 dictionary of pesummary.utils.array.Array objects containing the
549 samples that correspond to parameter for each result file. The key
550 should be the corresponding label
551 latex_label: str
552 the latex label for parameter
553 colors: list
554 list of colors to be used to distinguish different result files
555 injection: list
556 list of injected values, one for each analysis
557 kde: Bool, optional
558 if True, kde plots will be generated rather than 1d histograms
559 linestyles: list, optional
560 list of linestyles used to distinguish different result files
561 preliminary: Bool, optional
562 if True, add a preliminary watermark to the plot
563 """
564 import math
565 module = importlib.import_module(
566 "pesummary.{}.plots.plot".format(package)
567 )
568 if filename is None:
569 filename = os.path.join(
570 savedir, "combined_1d_posterior_{}.png".format(parameter)
571 )
572 if os.path.isfile(filename) and checkpoint:
573 return
574 hist = not kde
575 for num, inj in enumerate(injection):
576 if math.isnan(inj):
577 injection[num] = None
578 same_samples = [val for key, val in samples.items()]
579 fig = module._1d_comparison_histogram_plot(
580 parameter, same_samples, colors, latex_label,
581 list(samples.keys()), inj_value=injection, kde=kde,
582 linestyles=linestyles, hist=hist
583 )
584 _PlotGeneration.save(
585 fig, filename, preliminary=preliminary
586 )
588 @staticmethod
589 def _oned_histogram_plot(
590 savedir, label, parameter, samples, latex_label, injection, kde=False,
591 prior=None, weights=None, package="core", preliminary=False,
592 checkpoint=False, key_data=None
593 ):
594 """Generate a oned histogram plot for a given set of samples
596 Parameters
597 ----------
598 savedir: str
599 the directory you wish to save the plot in
600 label: str
601 the label corresponding to the results file
602 parameter: str
603 the name of the parameter that you wish to plot
604 samples: PESummary.utils.array.Array
605 array containing the samples corresponding to parameter
606 latex_label: str
607 the latex label corresponding to parameter
608 injection: float
609 the injected value
610 kde: Bool, optional
611 if True, kde plots will be generated rather than 1d histograms
612 prior: PESummary.utils.array.Array, optional
613 the prior samples for param
614 weights: PESummary.utils.utilsrray, optional
615 the weights for each samples. If None, assumed to be 1
616 preliminary: Bool, optional
617 if True, add a preliminary watermark to the plot
618 """
619 import math
620 module = importlib.import_module(
621 "pesummary.{}.plots.plot".format(package)
622 )
624 if math.isnan(injection):
625 injection = None
626 hist = not kde
628 filename = os.path.join(
629 savedir, "{}_1d_posterior_{}.png".format(label, parameter)
630 )
631 if os.path.isfile(filename) and checkpoint:
632 return
633 fig = module._1d_histogram_plot(
634 parameter, samples, latex_label, inj_value=injection, kde=kde,
635 hist=hist, prior=prior, weights=weights, key_data=key_data
636 )
637 _PlotGeneration.save(
638 fig, filename, preliminary=preliminary
639 )
641 @staticmethod
642 def _oned_histogram_plot_mcmc(
643 savedir, label, parameter, samples, latex_label, injection, kde=False,
644 prior=None, weights=None, package="core", preliminary=False,
645 checkpoint=False, key_data=None
646 ):
647 """Generate a oned histogram plot for a given set of samples for
648 multiple mcmc chains
650 Parameters
651 ----------
652 savedir: str
653 the directory you wish to save the plot in
654 label: str
655 the label corresponding to the results file
656 parameter: str
657 the name of the parameter that you wish to plot
658 samples: dict
659 dictionary of PESummary.utils.array.Array objects containing the
660 samples corresponding to parameter for multiple mcmc chains
661 latex_label: str
662 the latex label corresponding to parameter
663 injection: float
664 the injected value
665 kde: Bool, optional
666 if True, kde plots will be generated rather than 1d histograms
667 prior: PESummary.utils.array.Array, optional
668 the prior samples for param
669 weights: PESummary.utils.array.Array, optional
670 the weights for each samples. If None, assumed to be 1
671 preliminary: Bool, optional
672 if True, add a preliminary watermark to the plot
673 """
674 import math
675 from pesummary.utils.array import Array
677 module = importlib.import_module(
678 "pesummary.{}.plots.plot".format(package)
679 )
681 if math.isnan(injection):
682 injection = None
683 same_samples = [val for key, val in samples.items()]
684 filename = os.path.join(
685 savedir, "{}_1d_posterior_{}.png".format(label, parameter)
686 )
687 if os.path.isfile(filename) and checkpoint:
688 pass
689 else:
690 fig = module._1d_histogram_plot_mcmc(
691 parameter, same_samples, latex_label, inj_value=injection,
692 kde=kde, prior=prior, weights=weights
693 )
694 _PlotGeneration.save(
695 fig, filename, preliminary=preliminary
696 )
697 filename = os.path.join(
698 savedir, "{}_1d_posterior_{}_combined.png".format(label, parameter)
699 )
700 if os.path.isfile(filename) and checkpoint:
701 pass
702 else:
703 fig = module._1d_histogram_plot(
704 parameter, Array(np.concatenate(same_samples)), latex_label,
705 inj_value=injection, kde=kde, prior=prior, weights=weights
706 )
707 _PlotGeneration.save(
708 fig, filename, preliminary=preliminary
709 )
711 def sample_evolution_plot(self, label):
712 """Generate sample evolution plots for all parameters in the result file
714 Parameters
715 ----------
716 label: str
717 the label for the results file that you wish to plot
718 """
719 error_message = (
720 "Failed to generate a sample evolution plot for %s because {}"
721 )
722 iterator, samples, function = self._mcmc_iterator(
723 label, "_sample_evolution_plot"
724 )
725 arguments = [
726 (
727 [
728 self.savedir, label, param, samples[param],
729 latex_labels[param], self.injection_data[label][param],
730 self.preliminary_pages[label], self.checkpoint
731 ], function, error_message % (param)
732 ) for param in iterator
733 ]
734 self.pool.starmap(self._try_to_make_a_plot, arguments)
736 def expert_plot(self, label):
737 """Generate expert plots for diagnostics
739 Parameters
740 ----------
741 label: str
742 the label for the results file that you wish to plot
743 """
744 error_message = (
745 "Failed to generate log_likelihood-%s 2d contour plot because {}"
746 )
747 iterator, samples, function = self._mcmc_iterator(
748 label, "_2d_contour_plot"
749 )
750 _debug = self.samples[label].debug_keys()
751 arguments = [
752 (
753 [
754 self.savedir, label, param, "log_likelihood", samples[param],
755 samples["log_likelihood"], latex_labels[param],
756 latex_labels["log_likelihood"],
757 self.preliminary_pages[label], [
758 samples[param][np.argmax(samples["log_likelihood"])],
759 np.max(samples["log_likelihood"]),
760 ], self.checkpoint
761 ], function, error_message % (param)
762 ) for param in iterator + _debug
763 ]
764 self.pool.starmap(self._try_to_make_a_plot, arguments)
765 _reweight_keys = [
766 param for param in self.samples[label].debug_keys() if
767 "_non_reweighted" in param
768 ]
769 if len(_reweight_keys):
770 error_message = (
771 "Failed to generate %s-%s 2d contour plot because {}"
772 )
773 _base_param = lambda p: p.split("_non_reweighted")[0][1:]
774 arguments = [
775 (
776 [
777 self.savedir, label, _base_param(param), param,
778 samples[_base_param(param)], samples[param],
779 latex_labels[_base_param(param)], latex_labels[param],
780 self.preliminary_pages[label], None, self.checkpoint
781 ], function, error_message % (_base_param(param), param)
782 ) for param in _reweight_keys
783 ]
784 self.pool.starmap(self._try_to_make_a_plot, arguments)
785 error_message = (
786 "Failed to generate a histogram plot comparing %s and %s "
787 "because {}"
788 )
789 arguments = [
790 (
791 [
792 self.savedir, _base_param(param), {
793 "reweighted": samples[_base_param(param)],
794 "non-reweighted": samples[param]
795 }, latex_labels[_base_param(param)], ['b', 'r'],
796 [np.nan, np.nan], True, None, self.package,
797 self.preliminary_comparison_pages, self.checkpoint,
798 os.path.join(
799 self.savedir, "{}_1d_posterior_{}_{}.png".format(
800 label, _base_param(param), param
801 )
802 )
803 ], self._oned_histogram_comparison_plot,
804 error_message % (_base_param(param), param)
805 ) for param in _reweight_keys
806 ]
807 self.pool.starmap(self._try_to_make_a_plot, arguments)
809 error_message = (
810 "Failed to generate log_likelihood-%s sample_evolution plot "
811 "because {}"
812 )
813 iterator, samples, function = self._mcmc_iterator(
814 label, "_colored_sample_evolution_plot"
815 )
816 arguments = [
817 (
818 [
819 self.savedir, label, param, "log_likelihood", samples[param],
820 samples["log_likelihood"], latex_labels[param],
821 latex_labels["log_likelihood"],
822 self.preliminary_pages[label], self.checkpoint
823 ], function, error_message % (param)
824 ) for param in iterator
825 ]
826 self.pool.starmap(self._try_to_make_a_plot, arguments)
827 error_message = (
828 "Failed to generate bootstrapped oned_histogram plot for %s "
829 "because {}"
830 )
831 iterator, samples, function = self._mcmc_iterator(
832 label, "_oned_histogram_bootstrap_plot"
833 )
834 arguments = [
835 (
836 [
837 self.savedir, label, param, samples[param],
838 latex_labels[param], self.preliminary_pages[label],
839 self.package, self.checkpoint
840 ], function, error_message % (param)
841 ) for param in iterator
842 ]
843 self.pool.starmap(self._try_to_make_a_plot, arguments)
845 @staticmethod
846 def _oned_histogram_bootstrap_plot(
847 savedir, label, parameter, samples, latex_label, preliminary=False,
848 package="core", checkpoint=False, nsamples=1000, ntests=100, **kwargs
849 ):
850 """Generate a bootstrapped oned histogram plot for a given set of
851 samples
853 Parameters
854 ----------
855 savedir: str
856 the directory you wish to save the plot in
857 label: str
858 the label corresponding to the results file
859 parameter: str
860 the name of the parameter that you wish to plot
861 samples: PESummary.utils.array.Array
862 array containing the samples corresponding to parameter
863 latex_label: str
864 the latex label corresponding to parameter
865 preliminary: Bool, optional
866 if True, add a preliminary watermark to the plot
867 """
868 module = importlib.import_module(
869 "pesummary.{}.plots.plot".format(package)
870 )
872 filename = os.path.join(
873 savedir, "{}_1d_posterior_{}_bootstrap.png".format(label, parameter)
874 )
875 if os.path.isfile(filename) and checkpoint:
876 return
877 fig = module._1d_histogram_plot_bootstrap(
878 parameter, samples, latex_label, nsamples=nsamples, ntests=ntests,
879 **kwargs
880 )
881 _PlotGeneration.save(
882 fig, filename, preliminary=preliminary
883 )
885 @staticmethod
886 def _2d_contour_plot(
887 savedir, label, parameter_x, parameter_y, samples_x, samples_y,
888 latex_label_x, latex_label_y, preliminary=False, truth=None,
889 checkpoint=False
890 ):
891 """Generate a 2d contour plot for a given set of samples
893 Parameters
894 ----------
895 savedir: str
896 the directory you wish to save the plot in
897 label: str
898 the label corresponding to the results file
899 samples_x: PESummary.utils.array.Array
900 array containing the samples for the x axis
901 samples_y: PESummary.utils.array.Array
902 array containing the samples for the y axis
903 latex_label_x: str
904 the latex label for the x axis
905 latex_label_y: str
906 the latex label for the y axis
907 preliminary: Bool, optional
908 if True, add a preliminary watermark to the plot
909 """
910 from pesummary.core.plots.publication import twod_contour_plot
912 filename = os.path.join(
913 savedir, "{}_2d_contour_{}_{}.png".format(
914 label, parameter_x, parameter_y
915 )
916 )
917 if os.path.isfile(filename) and checkpoint:
918 return
919 fig = twod_contour_plot(
920 samples_x, samples_y, levels=[0.9, 0.5], xlabel=latex_label_x,
921 ylabel=latex_label_y, bins=50, truth=truth
922 )
923 _PlotGeneration.save(
924 fig, filename, preliminary=preliminary
925 )
927 @staticmethod
928 def _colored_sample_evolution_plot(
929 savedir, label, parameter_x, parameter_y, samples_x, samples_y,
930 latex_label_x, latex_label_y, preliminary=False, checkpoint=False
931 ):
932 """Generate a 2d contour plot for a given set of samples
934 Parameters
935 ----------
936 savedir: str
937 the directory you wish to save the plot in
938 label: str
939 the label corresponding to the results file
940 samples_x: PESummary.utils.array.Array
941 array containing the samples for the x axis
942 samples_y: PESummary.utils.array.Array
943 array containing the samples for the y axis
944 latex_label_x: str
945 the latex label for the x axis
946 latex_label_y: str
947 the latex label for the y axis
948 preliminary: Bool, optional
949 if True, add a preliminary watermark to the plot
950 """
951 filename = os.path.join(
952 savedir, "{}_sample_evolution_{}_{}_colored.png".format(
953 label, parameter_x, parameter_y
954 )
955 )
956 if os.path.isfile(filename) and checkpoint:
957 return
958 fig = core._sample_evolution_plot(
959 parameter_x, samples_x, latex_label_x, z=samples_y,
960 z_label=latex_label_y
961 )
962 _PlotGeneration.save(
963 fig, filename, preliminary=preliminary
964 )
966 @staticmethod
967 def _sample_evolution_plot(
968 savedir, label, parameter, samples, latex_label, injection,
969 preliminary=False, checkpoint=False
970 ):
971 """Generate a sample evolution plot for a given set of samples
973 Parameters
974 ----------
975 savedir: str
976 the directory you wish to save the plot in
977 label: str
978 the label corresponding to the results file
979 parameter: str
980 the name of the parameter that you wish to plot
981 samples: PESummary.utils.array.Array
982 array containing the samples corresponding to parameter
983 latex_label: str
984 the latex label corresponding to parameter
985 injection: float
986 the injected value
987 preliminary: Bool, optional
988 if True, add a preliminary watermark to the plot
989 """
990 filename = os.path.join(
991 savedir, "{}_sample_evolution_{}.png".format(label, parameter)
992 )
993 if os.path.isfile(filename) and checkpoint:
994 return
995 fig = core._sample_evolution_plot(
996 parameter, samples, latex_label, injection
997 )
998 _PlotGeneration.save(
999 fig, filename, preliminary=preliminary
1000 )
@staticmethod
def _sample_evolution_plot_mcmc(
    savedir, label, parameter, samples, latex_label, injection,
    preliminary=False, checkpoint=False
):
    """Produce a sample evolution plot showing every mcmc chain

    The figure is written to
    ``<savedir>/<label>_sample_evolution_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label identifying the result file
    parameter: str
        name of the parameter being plotted
    samples: dict
        dictionary of pesummary.utils.array.Array objects, one per chain,
        holding the samples for `parameter`; only the values are used
    latex_label: str
        latex label for `parameter`
    injection: float
        injected value for `parameter`
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "{}_sample_evolution_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    chains = list(samples.values())
    figure = core._sample_evolution_plot_mcmc(
        parameter, chains, latex_label, injection
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def autocorrelation_plot(self, label):
    """Produce an autocorrelation plot for every parameter of a result file

    ``self._mcmc_iterator`` supplies the parameters to loop over, the
    samples and the plotting function to use. One job per parameter is
    dispatched through ``self.pool.starmap`` to ``self._try_to_make_a_plot``
    together with a failure message (a ``{}`` placeholder is left in it,
    presumably for the failure reason — confirm in ``_try_to_make_a_plot``).

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    template = (
        "Failed to generate an autocorrelation plot for %s because {}"
    )
    parameters, samples, function = self._mcmc_iterator(
        label, "_autocorrelation_plot"
    )
    jobs = []
    for name in parameters:
        plot_args = [
            self.savedir, label, name, samples[name],
            self.preliminary_pages[label], self.checkpoint
        ]
        jobs.append((plot_args, function, template % (name)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _autocorrelation_plot(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Produce an autocorrelation plot for a single parameter

    The figure is written to
    ``<savedir>/<label>_autocorrelation_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label identifying the result file
    parameter: str
        name of the parameter being plotted
    samples: PESummary.utils.array.Array
        samples for `parameter`
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    figure = core._autocorrelation_plot(parameter, samples)
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
@staticmethod
def _autocorrelation_plot_mcmc(
    savedir, label, parameter, samples, preliminary=False, checkpoint=False
):
    """Produce an autocorrelation plot with one curve per mcmc chain

    The figure is written to
    ``<savedir>/<label>_autocorrelation_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    label: str
        label identifying the result file
    parameter: str
        name of the parameter being plotted
    samples: dict
        dictionary of PESummary.utils.array.Array objects, one per mcmc
        chain, holding the samples for `parameter`; only the values are used
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "{}_autocorrelation_{}.png".format(label, parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    chains = list(samples.values())
    figure = core._autocorrelation_plot_mcmc(parameter, chains)
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def oned_cdf_plot(self, label):
    """Produce a one-dimensional CDF plot for every parameter of a result file

    ``self._mcmc_iterator`` supplies the parameters to loop over, the
    samples and the plotting function to use. One job per parameter is
    dispatched through ``self.pool.starmap`` to ``self._try_to_make_a_plot``
    together with a failure message.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    template = (
        "Failed to generate a CDF plot for %s because {}"
    )
    parameters, samples, function = self._mcmc_iterator(
        label, "_oned_cdf_plot"
    )
    jobs = []
    for name in parameters:
        plot_args = [
            self.savedir, label, name, samples[name],
            latex_labels[name], self.preliminary_pages[label],
            self.checkpoint
        ]
        jobs.append((plot_args, function, template % (name)))
    self.pool.starmap(self._try_to_make_a_plot, jobs)
@staticmethod
def _oned_cdf_plot(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples

    The figure is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: PESummary.utils.array.Array
        array containing the samples corresponding to parameter
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    # Pass directory and file name as separate components (as every other
    # plotting method does). The previous `savedir + name` only produced a
    # valid path when savedir already ended with a separator; for that case
    # the result is unchanged.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    fig = core._1d_cdf_plot(parameter, samples, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
@staticmethod
def _oned_cdf_plot_mcmc(
    savedir, label, parameter, samples, latex_label, preliminary=False,
    checkpoint=False
):
    """Generate a oned CDF plot for a given set of samples, one for each
    mcmc chain

    The figure is written to ``<savedir>/<label>_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of PESummary.utils.array.Array objects containing the
        samples corresponding to parameter for each mcmc chain; only the
        values are used
    latex_label: str
        the latex label corresponding to parameter
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    # Pass directory and file name as separate components (as every other
    # plotting method does). The previous `savedir + name` only produced a
    # valid path when savedir already ended with a separator; for that case
    # the result is unchanged.
    filename = os.path.join(
        savedir, "{}_cdf_{}.png".format(label, parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    chains = list(samples.values())
    fig = core._1d_cdf_plot_mcmc(parameter, chains, latex_label)
    _PlotGeneration.save(
        fig, filename, preliminary=preliminary
    )
def interactive_ridgeline_plot(self, label):
    """Produce an interactive ridgeline plot for every parameter that is
    common to all result files

    Each parameter in ``self.same_parameters`` is handed to
    ``self._try_to_make_a_plot`` in turn, along with a failure message.

    Parameters
    ----------
    label: str
        not used by this method
    """
    template = (
        "Failed to generate an interactive ridgeline plot for %s because {}"
    )
    for name in self.same_parameters:
        plot_args = [
            self.savedir, name, self.same_samples[name],
            latex_labels[name], self.colors, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._interactive_ridgeline_plot, template % (name)
        )
@staticmethod
def _interactive_ridgeline_plot(
    savedir, parameter, samples, latex_label, colors, checkpoint=False
):
    """Generate an interactive ridgeline plot comparing the samples of a
    single parameter across result files, written as an html file

    The output is written to
    ``<savedir>/interactive_ridgeline_<parameter>.html``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    parameter: str
        the name of the parameter that you wish to plot
    samples: dict
        dictionary of samples for `parameter`; the values are plotted and
        the keys are passed to ``interactive.ridgeline`` as the labels
    latex_label: str
        the latex label for parameter, used as the x-axis label
    colors: list
        list of colors used to distinguish the different result files
    checkpoint: Bool, optional
        if True and the html file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "interactive_ridgeline_{}.html".format(parameter)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    interactive.ridgeline(
        list(samples.values()), list(samples.keys()), xlabel=latex_label,
        colors=colors, write_to_html_file=filename
    )
def interactive_corner_plot(self, label):
    """Produce an interactive corner plot for a single result file

    Delegates to ``self._interactive_corner_plot`` with the samples stored
    for `label` and the module-level ``latex_labels`` mapping.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    arguments = (
        self.savedir, label, self.samples[label], latex_labels,
        self.checkpoint
    )
    self._interactive_corner_plot(*arguments)
@staticmethod
def _interactive_corner_plot(
    savedir, label, samples, latex_labels, checkpoint=False
):
    """Generate an interactive corner plot for a given set of samples

    The output is written to ``<savedir>/corner/<label>_interactive.html``.

    Parameters
    ----------
    savedir: str
        the directory you wish to save the plot in
    label: str
        the label corresponding to the results file
    samples: dict
        dictionary containing PESummary.utils.array.Array objects that
        contain samples for each parameter
    latex_labels: dict
        mapping from parameter name to its latex label; it must contain an
        entry for every parameter in samples
    checkpoint: Bool, optional
        if True and the html file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "corner", "{}_interactive.html".format(label)
    )
    if os.path.isfile(filename) and checkpoint:
        return
    parameters = list(samples.keys())
    data = [samples[parameter] for parameter in parameters]
    # use a separate name instead of rebinding the `latex_labels` argument
    # (a mapping) to a list of a different type
    axis_labels = [latex_labels[parameter] for parameter in parameters]
    interactive.corner(
        data, axis_labels, write_to_html_file=filename
    )
def oned_cdf_comparison_plot(self, label):
    """Produce a comparison CDF plot for every parameter that is common to
    all result files

    Each parameter in ``self.same_parameters`` is handed to
    ``self._try_to_make_a_plot`` in turn, along with a failure message.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (not used by
        this method)
    """
    template = (
        "Failed to generate a comparison CDF plot for %s because {}"
    )
    for name in self.same_parameters:
        plot_args = [
            self.savedir, name, self.same_samples[name],
            latex_labels[name], self.colors, self.linestyles,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._oned_cdf_comparison_plot, template % (name)
        )
@staticmethod
def _oned_cdf_comparison_plot(
    savedir, parameter, samples, latex_label, colors, linestyles=None,
    preliminary=False, checkpoint=False
):
    """Produce a comparison CDF plot for a single parameter

    The figure is written to ``<savedir>/combined_cdf_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    parameter: str
        name of the parameter to compare across result files
    samples: dict
        dictionary of pesummary.utils.array.Array objects holding the
        samples for `parameter` in each result file. The keys are passed to
        the plotting routine as the labels, so they should be the result
        file labels
    latex_label: str
        latex label for `parameter`
    colors: list
        colors used to distinguish the different result files
    linestyles: list, optional
        linestyles used to distinguish the different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "combined_cdf_{}.png".format(parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    labels = list(samples.keys())
    ordered_samples = [samples[name] for name in labels]
    figure = core._1d_cdf_comparison_plot(
        parameter, ordered_samples, colors, latex_label, labels, linestyles
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def box_plot_comparison_plot(self, label):
    """Produce a comparison box plot for every parameter that is common to
    all result files

    Each parameter in ``self.same_parameters`` is handed to
    ``self._try_to_make_a_plot`` in turn, along with a failure message.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot (not used by
        this method)
    """
    template = (
        "Failed to generate a comparison box plot for %s because {}"
    )
    for name in self.same_parameters:
        plot_args = [
            self.savedir, name, self.same_samples[name],
            latex_labels[name], self.colors,
            self.preliminary_comparison_pages, self.checkpoint
        ]
        self._try_to_make_a_plot(
            plot_args, self._box_plot_comparison_plot, template % (name)
        )
@staticmethod
def _box_plot_comparison_plot(
    savedir, parameter, samples, latex_label, colors, preliminary=False,
    checkpoint=False
):
    """Produce a comparison box plot for a single parameter

    The figure is written to ``<savedir>/combined_boxplot_<parameter>.png``.

    Parameters
    ----------
    savedir: str
        directory in which the plot is written
    parameter: str
        name of the parameter to compare across result files
    samples: dict
        dictionary of pesummary.utils.array.Array objects holding the
        samples for `parameter` in each result file. The keys are passed to
        the plotting routine as the labels, so they should be the result
        file labels
    latex_label: str
        latex label for `parameter`
    colors: list
        colors used to distinguish the different result files
    preliminary: Bool, optional
        if True, add a preliminary watermark to the plot
    checkpoint: Bool, optional
        if True and the output file already exists, return early without
        regenerating it
    """
    filename = os.path.join(
        savedir, "combined_boxplot_{}.png".format(parameter)
    )
    if checkpoint and os.path.isfile(filename):
        return
    values = list(samples.values())
    labels = list(samples.keys())
    figure = core._comparison_box_plot(
        parameter, values, colors, latex_label, labels
    )
    _PlotGeneration.save(figure, filename, preliminary=preliminary)
def custom_plot(self, label):
    """Generate custom plots according to the passed python file

    ``self.custom_plotting`` is indexed as ``(directory, module_name)``.
    When ``directory`` is not the empty string it is added to ``sys.path``
    (at most once) so that ``module_name`` can be imported. Every name
    listed in the module's ``__single_plots__`` attribute is looked up on
    the module and called with the list of parameter names and the samples
    for `label`. Each returned figure is saved to
    ``<savedir>/<label>_custom_plotting_<num>``, where ``num`` is the
    position of the function in ``__single_plots__``.

    Parameters
    ----------
    label: str
        the label for the results file that you wish to plot
    """
    import sys

    directory = self.custom_plotting[0]
    module_name = self.custom_plotting[1]
    # This method may run more than once (e.g. for several labels). Only
    # extend sys.path when the directory is not already on it, so repeated
    # calls do not accumulate duplicate entries.
    if directory != "" and directory not in sys.path:
        sys.path.append(directory)
    # importlib is imported at module level
    mod = importlib.import_module(module_name)

    # resolve every plotting function up front so that a missing name fails
    # before any figure is produced
    methods = [getattr(mod, name) for name in mod.__single_plots__]
    for num, plot_function in enumerate(methods):
        # a fresh parameter list is built for each call, so one custom
        # function cannot affect the list seen by the next
        fig = plot_function(
            list(self.samples[label].keys()), self.samples[label]
        )
        # NOTE(review): no file extension is added here; the output format
        # presumably falls back to save()/matplotlib defaults -- confirm
        _PlotGeneration.save(
            fig, os.path.join(
                self.savedir, "{}_custom_plotting_{}".format(label, num)
            ), preliminary=self.preliminary_pages[label]
        )