Coverage for pesummary/core/file/formats/base_read.py: 57.4%
427 statements
coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import numpy as np
5import h5py
6from pesummary.utils.parameters import MultiAnalysisParameters, Parameters
7from pesummary.utils.samples_dict import (
8 MultiAnalysisSamplesDict, SamplesDict, MCMCSamplesDict, Array
9)
10from pesummary.utils.utils import logger
12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
15def _downsample(samples, number, extra_kwargs=None):
16 """Downsample a posterior table
18 Parameters
19 ----------
20 samples: 2d list
21 list of posterior samples where the columns correspond to a given
22 parameter
23 number: int
24 number of posterior samples you wish to downsample to
25 extra_kwargs: dict, optional
26 dictionary of kwargs to modify
27 """
28 from pesummary.utils.utils import resample_posterior_distribution
29 import copy
31 _samples = np.array(samples).T
32 if number > len(_samples[0]):
33 raise ValueError(
34 "Failed to downsample the posterior samples to {} because "
35 "there are only {} samples stored in the file.".format(
36 number, len(_samples[0])
37 )
38 )
39 _samples = np.array(resample_posterior_distribution(_samples, number))
40 if extra_kwargs is None:
41 return _samples.T.tolist()
42 _extra_kwargs = copy.deepcopy(extra_kwargs)
43 _extra_kwargs["sampler"]["nsamples"] = number
44 return _samples.T.tolist(), _extra_kwargs
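# Illustrative sketch (not part of the original module): downsampling a
# 1000-row, two-parameter posterior table to 100 rows. Rows are samples and
# columns are parameters, matching the transposes used above.
#
#     >>> import numpy as np
#     >>> table = np.random.uniform(size=(1000, 2)).tolist()
#     >>> reduced = _downsample(table, 100)
#     >>> len(reduced), len(reduced[0])
#     (100, 2)
#
# When `extra_kwargs` is supplied, a deep copy is returned alongside the
# samples with extra_kwargs["sampler"]["nsamples"] updated to the new size.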
47class Read(object):
48 """Base class to read in a results file
50 Parameters
51 ----------
52 path_to_results_file: str
53 path to the results file you wish to load
54 remove_nan_likelihood_samples: Bool, optional
55 if True, remove samples which have log_likelihood='nan'. Default True
57 Attributes
58 ----------
59 parameters: list
60 list of parameters stored in the result file
61 samples: 2d list
62 list of samples stored in the result file
63 samples_dict: dict
64 dictionary of samples stored in the result file keyed by parameters
65 input_version: str
66 version of the result file passed.
67 extra_kwargs: dict
68 dictionary of kwargs that were extracted from the result file
69 pe_algorithm: str
70 name of the algorithm used to generate the posterior samples
72 Methods
73 -------
74 downsample:
75 downsample the posterior samples stored in the result file
76 to_dat:
77 save the posterior samples to a .dat file
78 to_latex_table:
79 convert the posterior samples to a latex table
80 generate_latex_macros:
81 generate a set of latex macros for the stored posterior samples
82 """
83 def __init__(
84 self, path_to_results_file, remove_nan_likelihood_samples=True, **kwargs
85 ):
86 self.path_to_results_file = path_to_results_file
87 self.mcmc_samples = False
88 self.remove_nan_likelihood_samples = remove_nan_likelihood_samples
89 self.extension = self.extension_from_path(self.path_to_results_file)
90 self.converted_parameters = []
92 @classmethod
93 def load_file(cls, path, **kwargs):
94 """Initialize the class with a file
96 Parameters
97 ----------
98 path: str
99 path to the result file you wish to load
100 **kwargs: dict, optional
101 all kwargs passed to the class
102 """
103 if not os.path.isfile(path):
104 raise FileNotFoundError("%s does not exist" % (path))
105 return cls(path, **kwargs)
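# Illustrative sketch (hypothetical path; in practice a concrete subclass with
# a loader is used): `load_file` only checks that the path exists before
# delegating to the constructor, so a missing file fails early.
#
#     >>> f = Read.load_file("posterior_samples.dat")   # doctest: +SKIP
#     >>> f.extension
#     'dat'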
107 @staticmethod
108 def load_from_function(function, path_to_file, **kwargs):
109 """Load a file according to a given function
111 Parameters
112 ----------
113 function: func
114 callable function that will load in your file
115 path_to_file: str
116 path to the file that you wish to load
117 kwargs: dict
118 all kwargs are passed to the function
119 """
120 return function(path_to_file, **kwargs)
122 @staticmethod
123 def check_for_nan_likelihoods(parameters, samples, remove=False):
124 """Check to see if there are any samples with log_likelihood='nan' in
125 the posterior table and remove if requested
127 Parameters
128 ----------
129 parameters: list
130 list of parameters stored in the result file
131 samples: np.ndarray
132 array of samples for each parameter
133 remove: Bool, optional
134 if True, remove samples with log_likelihood='nan' from samples
135 """
136 import math
137 if "log_likelihood" not in parameters:
138 return parameters, samples
139 ind = parameters.index("log_likelihood")
140 likelihoods = np.array(samples).T[ind]
141 inds = np.array(
142 [math.isnan(_) for _ in likelihoods], dtype=bool
143 )
144 if not sum(inds):
145 return parameters, samples
146 msg = (
147 "Posterior table contains {} samples with 'nan' log likelihood. "
148 )
149 if remove:
150 msg += "Removing samples from posterior table."
151 samples = np.array(samples)[~inds].tolist()
152 else:
153 msg += "This may cause problems when analysing posterior samples."
154 logger.warning(msg.format(sum(inds)))
155 return parameters, samples
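# Illustrative sketch (toy data): rows whose log_likelihood is nan are dropped
# when remove=True; otherwise only a warning is logged.
#
#     >>> params = ["a", "log_likelihood"]
#     >>> samps = [[1.0, -0.5], [2.0, float("nan")]]
#     >>> _, cleaned = Read.check_for_nan_likelihoods(params, samps, remove=True)
#     >>> cleaned
#     [[1.0, -0.5]]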
157 @staticmethod
158 def check_for_weights(parameters, samples):
159 """Check to see if the samples are weighted
161 Parameters
162 ----------
163 parameters: list
164 list of parameters stored in the result file
165 samples: np.ndarray
166 array of samples for each parameter
167 """
168 likely_names = ["weights", "weight"]
169 if any(i in parameters for i in likely_names):
170 ind = (
171 parameters.index("weights") if "weights" in parameters else
172 parameters.index("weight")
173 )
174 return Array(np.array(samples).T[ind])
175 return None
177 @property
178 def pe_algorithm(self):
179 try:
180 return self.extra_kwargs["sampler"]["pe_algorithm"]
181 except KeyError:
182 return None
184 def __repr__(self):
185 return self.summary()
187 def _parameter_summary(self, parameters, parameters_to_show=4):
188 """Return a summary of the parameter stored
190 Parameters
191 ----------
192 parameters: list
193 list of parameters to create a summary for
194 parameters_to_show: int, optional
195 number of parameters to show. Default 4.
196 """
197 params = parameters
198 if len(parameters) > parameters_to_show:
199 params = parameters[:2] + ["..."] + parameters[-2:]
200 return ", ".join(params)
202 def summary(
203 self, parameters_to_show=4, show_parameters=True, show_nsamples=True
204 ):
205 """Return a summary of the contents of the file
207 Parameters
208 ----------
209 parameters_to_show: int, optional
210 number of parameters to show. Default 4
211 show_parameters: Bool, optional
212 if True print a list of the parameters stored
213 show_nsamples: Bool, optional
214 if True print how many samples are stored in the file
215 """
216 string = ""
217 if self.path_to_results_file is not None:
218 string += "file: {}\n".format(self.path_to_results_file)
219 string += "cls: {}.{}\n".format(
220 self.__class__.__module__, self.__class__.__name__
221 )
222 if show_nsamples:
223 string += "nsamples: {}\n".format(len(self.samples))
224 if show_parameters:
225 string += "parameters: {}".format(
226 self._parameter_summary(
227 self.parameters, parameters_to_show=parameters_to_show
228 )
229 )
230 return string
232 attrs = {
233 "input_version": "version", "extra_kwargs": "kwargs",
234 "priors": "prior", "analytic": "analytic", "labels": "labels",
235 "config": "config", "weights": "weights", "history": "history",
236 "description": "description"
237 }
239 def _load(self, function, **kwargs):
240 """Extract the data from a file using a given function
242 Parameters
243 ----------
244 function: func
245 function you wish to use to extract the data
246 **kwargs: dict, optional
247 optional kwargs to pass to the load function
248 """
249 return self.load_from_function(
250 function, self.path_to_results_file, **kwargs
251 )
253 def load(self, function, _data=None, **kwargs):
254 """Load a results file according to a given function
256 Parameters
257 ----------
258 function: func
259 callable function that will load in your results file
260 """
261 self.data = _data
262 if _data is None:
263 self.data = self._load(function, **kwargs)
264 if isinstance(self.data["parameters"][0], list):
265 _cls = MultiAnalysisParameters
266 else:
267 _cls = Parameters
268 self.parameters = _cls(self.data["parameters"])
269 self.converted_parameters = []
270 self.samples = self.data["samples"]
271 self.parameters, self.samples = self.check_for_nan_likelihoods(
272 self.parameters, self.samples,
273 remove=self.remove_nan_likelihood_samples
274 )
275 if "mcmc_samples" in self.data.keys():
276 self.mcmc_samples = self.data["mcmc_samples"]
277 if "injection" in self.data.keys():
278 if isinstance(self.data["injection"], dict):
279 self.injection_parameters = {
280 key.decode("utf-8") if isinstance(key, bytes) else key: val
281 for key, val in self.data["injection"].items()
282 }
283 elif isinstance(self.data["injection"], list):
284 self.injection_parameters = [
285 {
286 key.decode("utf-8") if isinstance(key, bytes) else
287 key: val for key, val in i.items()
288 } for i in self.data["injection"]
289 ]
290 else:
291 self.injection_parameters = self.data["injection"]
292 for new_attr, existing_attr in self.attrs.items():
293 if existing_attr in self.data.keys():
294 setattr(self, new_attr, self.data[existing_attr])
295 else:
296 setattr(self, new_attr, None)
297 if self.input_version is None:
298 self.input_version = self._default_version
299 if self.extra_kwargs is None:
300 self.extra_kwargs = self._default_kwargs
301 if self.description is None:
302 self.description = self._default_description
303 if self.weights is None:
304 self.weights = self.check_for_weights(self.parameters, self.samples)
306 @staticmethod
307 def extension_from_path(path):
308 """Return the extension of the file from the file path
310 Parameters
311 ----------
312 path: str
313 path to the results file
314 """
315 extension = path.split(".")[-1]
316 return extension
318 @staticmethod
319 def guess_path_to_samples(path):
320 """Guess the path to the posterior samples stored in an hdf5 object
322 Parameters
323 ----------
324 path: str
325 path to the results file
326 """
327 def _find_name(name, item):
328 c1 = "posterior_samples" in name or "posterior" in name
329 c2 = isinstance(item, (h5py._hl.dataset.Dataset, np.ndarray))
330 _group = isinstance(item, h5py._hl.group.Group)
331 c3, c4 = False, False
332 if _group:
333 try:
334 if isinstance(item[0], (float, int, np.number)):
335 c3 = True
336 except (TypeError, AttributeError):
337 c3 = False
338 try:
339 keys = list(item.keys())
340 if isinstance(item[keys[0]], (h5py._hl.dataset.Dataset, np.ndarray)):
341 c4 = True
342 except (TypeError, IndexError, AttributeError):
343 c4 = False
344 c5 = (
345 _group and "parameter_names" in item.keys() and "samples" in item.keys()
346 )
347 if c1 and c4:
348 paths.append(name)
349 elif c1 and c3:
350 paths.append(name)
351 elif c1 and c5:
352 paths.append(name)
353 elif c1 and c2:
354 if "/".join(name.split("/")[:-1]) not in paths:
355 paths.append(name)
357 f = h5py.File(path, 'r')
358 paths = []
359 f.visititems(_find_name)
360 f.close()
361 if len(paths) == 1:
362 return paths[0]
363 elif len(paths) > 1:
364 raise ValueError(
365 "Found multiple posterior sample tables in '{}': {}. Not sure "
366 "which to load.".format(
367 path, ", ".join(paths)
368 )
369 )
370 else:
371 raise ValueError(
372 "Unable to find a posterior samples table in '{}'".format(path)
373 )
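# Illustrative sketch (hypothetical HDF5 layout): for a file containing a
# single group of posterior datasets, the internal path to that group is
# returned; zero or multiple candidate tables raise ValueError instead of
# guessing.
#
#     >>> Read.guess_path_to_samples("result.hdf5")   # doctest: +SKIP
#     'posterior_samples'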
375 def generate_all_posterior_samples(self, **kwargs):
376 """Empty function
377 """
378 pass
380 def add_fixed_parameters_from_config_file(self, config_file):
381 """Search the conifiguration file and add fixed parameters to the
382 list of parameters and samples
384 Parameters
385 ----------
386 config_file: str
387 path to the configuration file
388 """
389 pass
391 def add_injection_parameters_from_file(self, injection_file, **kwargs):
392 """Populate the 'injection_parameters' property with data from a file
394 Parameters
395 ----------
396 injection_file: str
397 path to injection file
398 """
399 self.injection_parameters = self._grab_injection_parameters_from_file(
400 injection_file, **kwargs
401 )
403 def _grab_injection_parameters_from_file(
404 self, path, cls=None, add_nans=True, **kwargs
405 ):
406 """Extract data from an injection file
408 Parameters
409 ----------
410 path: str
411 path to injection file
412 cls: class, optional
413 class to read in injection file. The class must have a read class
414 method and a samples_dict property. Default None which means that
415 the pesummary.core.file.injection.Injection class is used
416 """
417 if cls is None:
418 from pesummary.core.file.injection import Injection
419 cls = Injection
420 data = cls.read(path, **kwargs).samples_dict
421 for i in self.parameters:
422 if i not in data.keys():
423 data[i] = float("nan")
424 return data
426 def write(
427 self, package="core", file_format="dat", extra_kwargs=None,
428 file_versions=None, **kwargs
429 ):
430 """Save the data to file
432 Parameters
433 ----------
434 package: str, optional
435 package you wish to use when writing the data
436 kwargs: dict, optional
437 all additional kwargs are passed to the pesummary.io.write function
438 """
439 from pesummary.io import write
441 if file_format == "pesummary" and np.array(self.parameters).ndim > 1:
442 args = [self.samples_dict]
443 else:
444 args = [self.parameters, self.samples]
445 if extra_kwargs is None:
446 extra_kwargs = self.extra_kwargs
447 if file_versions is None:
448 file_versions = self.input_version
449 if file_format == "ini":
450 kwargs["file_format"] = "ini"
451 return write(getattr(self, "config", None), **kwargs)
452 else:
453 return write(
454 *args, package=package, file_versions=file_versions,
455 file_kwargs=extra_kwargs, file_format=file_format, **kwargs
456 )
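# Illustrative sketch: `package` and `file_format` are handled here and all
# remaining keyword arguments are forwarded to pesummary.io.write. The exact
# formats available depend on that function, so treat "json" below as an
# assumption.
#
#     >>> f.write(file_format="json")   # doctest: +SKIP
#     >>> f.write(file_format="ini")    # writes the stored config, if any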
458 def downsample(self, number):
459 """Downsample the posterior samples stored in the result file
460 """
461 self.samples, self.extra_kwargs = _downsample(
462 self.samples, number, extra_kwargs=self.extra_kwargs
463 )
465 @staticmethod
466 def latex_table(samples, parameter_dict=None, labels=None):
467 """Return a latex table displaying the passed data.
469 Parameters
470 ----------
471 samples: list
472 list of pesummary.utils.samples_dict.SamplesDict objects
473 parameter_dict: dict, optional
474 dictionary of parameters that you wish to include in the latex
475 table. The keys are the name of the parameters and the items are
476 the descriptive text. If None, all parameters are included
477 """
478 table = (
479 "\\begin{table}[hptb]\n\\begin{ruledtabular}\n\\begin{tabular}"
480 "{l %s}\n" % ("c " * len(samples))
481 )
482 if labels:
483 table += (
484 " & " + " & ".join(labels)
485 )
486 table += "\\\ \n\\hline \\\ \n"
487 data = {i: i for i in samples[0].keys()}
488 if parameter_dict is not None:
489 import copy
491 data = copy.deepcopy(parameter_dict)
492 for param in parameter_dict.keys():
493 if not all(param in samples_dict.keys() for samples_dict in samples):
494 logger.warning(
495 "{} not in list of parameters. Not adding to "
496 "table".format(param)
497 )
498 data.pop(param)
500 for param, desc in data.items():
501 table += "{}".format(desc)
502 for samples_dict in samples:
503 median = samples_dict[param].average(type="median")
504 confidence = samples_dict[param].credible_interval()
505 table += (
506 " & $%s^{+%s}_{-%s}$" % (
507 np.round(median, 2),
508 np.round(confidence[1] - median, 2),
509 np.round(median - confidence[0], 2)
510 )
511 )
512 table += "\\\ \n"
513 table += (
514 "\\end{tabular}\n\\end{ruledtabular}\n\\caption{}\n\\end{table}"
515 )
516 return table
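# Illustrative sketch (toy SamplesDict): one column per analysis, mapping each
# parameter name to descriptive text for the row label.
#
#     >>> sd = SamplesDict({"a": [1.0, 2.0, 3.0, 4.0]})
#     >>> print(Read.latex_table([sd], {"a": "Parameter $a$"}))   # doctest: +SKIP
#     \begin{table}[hptb]
#     ...
#     \end{table}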
518 @staticmethod
519 def latex_macros(
520 samples, parameter_dict=None, labels=None, rounding="smart"
521 ):
522 """Return a latex table displaying the passed data.
524 Parameters
525 ----------
526 samples: list
527 list of pesummary.utils.samples_dict.SamplesDict objects
528 parameter_dict: dict, optional
529 dictionary of parameters that you wish to generate macros for. The
530 keys are the name of the parameters and the items are the latex
531 macros name you wish to use. If None, all parameters are included.
532 rounding: int, optional
533 decimal place for rounding. Default uses the
534 `pesummary.utils.utils.smart_round` function to round according to
535 the uncertainty
536 """
537 macros = ""
538 data = {i: i for i in samples[0].keys()}
539 if parameter_dict is not None:
540 import copy
542 data = copy.deepcopy(parameter_dict)
543 for param in parameter_dict.keys():
544 if not all(param in samples_dict.keys() for samples_dict in samples):
545 logger.warning(
546 "{} not in list of parameters. Not generating "
547 "macro".format(param)
548 )
549 data.pop(param)
550 for param, desc in data.items():
551 for num, samples_dict in enumerate(samples):
552 if labels:
553 description = "{}{}".format(desc, labels[num])
554 else:
555 description = desc
557 median = samples_dict[param].average(type="median")
558 confidence = samples_dict[param].credible_interval()
559 if rounding == "smart":
560 from pesummary.utils.utils import smart_round
562 median, upper, low = smart_round([
563 median, confidence[1] - median, median - confidence[0]
564 ])
565 else:
566 median = np.round(median, rounding)
567 low = np.round(median - confidence[0], rounding)
568 upper = np.round(confidence[1] - median, rounding)
569 macros += (
570 "\\def\\%s{$%s_{-%s}^{+%s}$}\n" % (
571 description, median, low, upper
572 )
573 )
574 macros += (
575 "\\def\\%smedian{$%s$}\n" % (description, median)
576 )
577 macros += (
578 "\\def\\%supper{$%s$}\n" % (
579 description, np.round(median + upper, 9)
580 )
581 )
582 macros += (
583 "\\def\\%slower{$%s$}\n" % (
584 description, np.round(median - low, 9)
585 )
586 )
587 return macros
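# Illustrative sketch: each entry in `parameter_dict` yields four macros per
# analysis. For a hypothetical macro name "massone" with median 30.2 and a
# credible interval of (29.1, 31.5), the output would look like:
#
#     \def\massone{$30.2_{-1.1}^{+1.3}$}
#     \def\massonemedian{$30.2$}
#     \def\massoneupper{$31.5$}
#     \def\massonelower{$29.1$}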
590class SingleAnalysisRead(Read):
591 """Base class to read in a results file which contains a single analyses
593 Parameters
594 ----------
595 path_to_results_file: str
596 path to the results file you wish to load
597 remove_nan_likelihood_samples: Bool, optional
598 if True, remove samples which have log_likelihood='nan'. Default True
600 Attributes
601 ----------
602 parameters: list
603 list of parameters stored in the file
604 samples: 2d list
605 list of samples stored in the result file
606 samples_dict: dict
607 dictionary of samples stored in the result file
608 input_version: str
609 version of the result file passed
610 extra_kwargs: dict
611 dictionary of kwargs that were extracted from the result file
613 Methods
614 -------
615 downsample:
616 downsample the posterior samples stored in the result file
617 to_dat:
618 save the posterior samples to a .dat file
619 to_latex_table:
620 convert the posterior samples to a latex table
621 generate_latex_macros:
622 generate a set of latex macros for the stored posterior samples
623 reweight_samples:
624 reweight the posterior and/or prior samples according to a new prior
625 """
626 def __init__(self, *args, **kwargs):
627 super(SingleAnalysisRead, self).__init__(*args, **kwargs)
629 @property
630 def samples_dict(self):
631 if self.mcmc_samples:
632 return MCMCSamplesDict(
633 self.parameters, [np.array(i).T for i in self.samples]
634 )
635 return SamplesDict(self.parameters, np.array(self.samples).T)
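# Illustrative sketch: for a single analysis the property returns a SamplesDict
# keyed by parameter name, so marginal posteriors can be pulled out directly.
#
#     >>> f.samples_dict["log_likelihood"].average(type="median")   # doctest: +SKIP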
637 @property
638 def _default_version(self):
639 return "No version information found"
641 @property
642 def _default_kwargs(self):
643 _kwargs = {"sampler": {}, "meta_data": {}}
644 _kwargs["sampler"]["nsamples"] = len(self.data["samples"])
645 return _kwargs
647 @property
648 def _default_description(self):
649 return "No description found"
651 def _add_fixed_parameters_from_config_file(self, config_file, function):
652 """Search the conifiguration file and add fixed parameters to the
653 list of parameters and samples
655 Parameters
656 ----------
657 config_file: str
658 path to the configuration file
659 function: func
660 function you wish to use to extract the information from the
661 configuration file
662 """
663 self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)
665 def _add_marginalized_parameters_from_config_file(self, config_file, function):
666 """Search the configuration file and add marginalized parameters to the
667 list of parameters and samples
669 Parameters
670 ----------
671 config_file: str
672 path to the configuration file
673 function: func
674 function you wish to use to extract the information from the
675 configuration file
676 """
677 self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)
679 def to_latex_table(self, parameter_dict=None, save_to_file=None):
680 """Make a latex table displaying the data in the result file.
682 Parameters
683 ----------
684 parameter_dict: dict, optional
685 dictionary of parameters that you wish to include in the latex
686 table. The keys are the name of the parameters and the items are
687 the descriptive text. If None, all parameters are included
688 save_to_file: str, optional
689 name of the file you wish to save the latex table to. If None, print
690 to stdout
691 """
692 import os
694 if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
695 raise FileExistsError(
696 "The file {} already exists.".format(save_to_file)
697 )
699 table = self.latex_table([self.samples_dict], parameter_dict)
700 if save_to_file is None:
701 print(table)
702 elif os.path.isfile("{}".format(save_to_file)):
703 logger.warning(
704 "File {} already exists. Printing to stdout".format(save_to_file)
705 )
706 print(table)
707 else:
708 with open(save_to_file, "w") as f:
709 f.writelines([table])
711 def generate_latex_macros(
712 self, parameter_dict=None, save_to_file=None, rounding="smart"
713 ):
714 """Generate a list of latex macros for each parameter in the result
715 file
717 Parameters
718 ----------
721 parameter_dict: dict, optional
722 dictionary of parameters that you wish to generate macros for. The
723 keys are the name of the parameters and the items are the latex
724 macros name you wish to use. If None, all parameters are included.
725 save_to_file: str, optional
726 name of the file you wish to save the latex table to. If None, print
727 to stdout
728 rounding: int, optional
729 number of decimal places to round the latex macros. Default "smart", which rounds according to the uncertainty
730 """
731 import os
733 if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
734 raise FileExistsError(
735 "The file {} already exists.".format(save_to_file)
736 )
738 macros = self.latex_macros(
739 [self.samples_dict], parameter_dict, rounding=rounding
740 )
741 if save_to_file is None:
742 print(macros)
743 else:
744 with open(save_to_file, "w") as f:
745 f.writelines([macros])
747 def to_dat(self, **kwargs):
748 """Save the PESummary results file object to a dat file
750 Parameters
751 ----------
752 kwargs: dict
753 all kwargs passed to the pesummary.core.file.formats.dat.write_dat
754 function
755 """
756 return self.write(file_format="dat", **kwargs)
758 def reweight_samples(self, function, **kwargs):
759 """Reweight the posterior and/or prior samples according to a new prior
760 """
761 if self.mcmc_samples:
762 raise ValueError("Cannot currently reweight MCMC chains")
763 _samples = self.samples_dict
764 new_samples = _samples.reweight(function, **kwargs)
765 self.parameters = Parameters(new_samples.parameters)
766 self.samples = np.array(new_samples.samples).T
767 self.extra_kwargs["sampler"].update(
768 {
769 "nsamples": new_samples.number_of_samples,
770 "nsamples_before_reweighting": _samples.number_of_samples
771 }
772 )
773 self.extra_kwargs["meta_data"]["reweighting"] = function
774 if not hasattr(self, "priors"):
775 return
776 if (self.priors is None) or ("samples" not in self.priors.keys()):
777 return
778 prior_samples = self.priors["samples"]
779 if not len(prior_samples):
780 return
781 new_prior_samples = prior_samples.reweight(function, **kwargs)
782 self.priors["samples"] = new_prior_samples
785class MultiAnalysisRead(Read):
786 """Base class to read in a results file which contains multiple analyses
788 Parameters
789 ----------
790 path_to_results_file: str
791 path to the results file you wish to load
792 remove_nan_likelihood_samples: Bool, optional
793 if True, remove samples which have log_likelihood='nan'. Default True
795 Attributes
796 ----------
797 parameters: 2d list
798 list of parameters for each analysis
799 samples: 3d list
800 list of samples stored in the result file for each analysis
801 samples_dict: dict
802 dictionary of samples stored in the result file keyed by analysis label
803 input_version: str
804 version of the result files passed
805 extra_kwargs: dict
806 dictionary of kwargs that were extracted from the result file
808 Methods
809 -------
810 samples_dict_for_label: dict
811 dictionary of samples for a specific analysis
812 reduced_samples_dict: dict
813 dictionary of samples for one or more analyses
814 downsample:
815 downsample the posterior samples stored in the result file
816 to_dat:
817 save the posterior samples to a .dat file
818 to_latex_table:
819 convert the posterior samples to a latex table
820 generate_latex_macros:
821 generate a set of latex macros for the stored posterior samples
822 reweight_samples:
823 reweight the posterior and/or prior samples according to a new prior
824 """
825 def __init__(self, *args, **kwargs):
826 super(MultiAnalysisRead, self).__init__(*args, **kwargs)
828 @staticmethod
829 def check_for_nan_likelihoods(parameters, samples, remove=False):
830 import copy
831 _parameters = copy.deepcopy(parameters)
832 _samples = copy.deepcopy(samples)
833 for num, (params, samps) in enumerate(zip(_parameters, _samples)):
834 _parameters[num], _samples[num] = Read.check_for_nan_likelihoods(
835 params, samps, remove=remove
836 )
837 return _parameters, _samples
839 def samples_dict_for_label(self, label):
840 """Return the posterior samples for a specific label
842 Parameters
843 ----------
844 label: str
845 label you wish to get posterior samples for
847 Returns
848 -------
849 outdict: SamplesDict
850 Returns a SamplesDict containing the requested posterior samples
851 """
852 if label not in self.labels:
853 raise ValueError("Unrecognised label: '{}'".format(label))
854 idx = self.labels.index(label)
855 return SamplesDict(self.parameters[idx], np.array(self.samples[idx]).T)
857 def reduced_samples_dict(self, labels):
858 """Return the posterior samples for one or more labels
860 Parameters
861 ----------
862 labels: str, list
863 label(s) you wish to get posterior samples for
865 Returns
866 -------
867 outdict: MultiAnalysisSamplesDict
868 Returns a MultiAnalysisSamplesDict containing the requested
869 posterior samples
870 """
871 if not isinstance(labels, list):
872 labels = [labels]
873 not_allowed = [_label for _label in labels if _label not in self.labels]
874 if len(not_allowed):
875 raise ValueError(
876 "Unrecognised label(s) '{}'. The list of available labels are "
877 "{}.".format(", ".join(not_allowed), ", ".join(self.labels))
878 )
879 return MultiAnalysisSamplesDict(
880 {
881 label: self.samples_dict_for_label(label) for label in labels
882 }
883 )
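# Illustrative sketch (labels hypothetical): pull out one analysis as a
# SamplesDict, or several as a MultiAnalysisSamplesDict.
#
#     >>> f.samples_dict_for_label("EXP1")             # doctest: +SKIP
#     >>> f.reduced_samples_dict(["EXP1", "EXP2"])     # doctest: +SKIP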
885 @property
886 def samples_dict(self):
887 if self.mcmc_samples:
888 outdict = MCMCSamplesDict(
889 self.parameters[0], [np.array(i).T for i in self.samples[0]]
890 )
891 else:
892 outdict = self.reduced_samples_dict(self.labels)
893 return outdict
895 @property
896 def _default_version(self):
897 return ["No version information found"] * len(self.parameters)
899 @property
900 def _default_kwargs(self):
901 _kwargs = [{"sampler": {}, "meta_data": {}}] * len(self.parameters)
902 for num, ss in enumerate(self.data["samples"]):
903 _kwargs[num]["sampler"]["nsamples"] = len(ss)
904 return _kwargs
906 @property
907 def _default_description(self):
908 return {label: "No description found" for label in self.labels}
910 def write(self, package="core", file_format="dat", **kwargs):
911 """Save the data to file
913 Parameters
914 ----------
915 package: str, optional
916 package you wish to use when writing the data
917 kwargs: dict, optional
918 all additional kwargs are passed to the pesummary.io.write function
919 """
920 return super(MultiAnalysisRead, self).write(
921 package=package, file_format=file_format,
922 extra_kwargs=self.kwargs_dict, file_versions=self.version_dict,
923 **kwargs
924 )
926 @property
927 def kwargs_dict(self):
928 return {
929 label: kwarg for label, kwarg in zip(self.labels, self.extra_kwargs)
930 }
932 @property
933 def version_dict(self):
934 return {
935 label: version for label, version in zip(self.labels, self.input_version)
936 }
938 def summary(self, *args, parameters_to_show=4, **kwargs):
939 """Return a summary of the contents of the file
941 Parameters
942 ----------
943 parameters_to_show: int, optional
944 number of parameters to show. Default 4
945 """
946 string = super(MultiAnalysisRead, self).summary(
947 show_parameters=False, show_nsamples=False
948 )
949 string += "analyses: {}\n\n".format(", ".join(self.labels))
950 for num, label in enumerate(self.labels):
951 string += "{}\n".format(label)
952 string += "-" * len(label) + "\n"
953 string += "description: {}\n".format(self.description[label])
954 string += "nsamples: {}\n".format(len(self.samples[num]))
955 string += "parameters: {}\n\n".format(
956 self._parameter_summary(
957 self.parameters[num], parameters_to_show=parameters_to_show
958 )
959 )
960 return string[:-2]
962 def downsample(self, number, labels=None):
963 """Downsample the posterior samples stored in the result file
964 """
965 for num, ss in enumerate(self.samples):
966 if labels is not None and self.labels[num] not in labels:
967 continue
968 self.samples[num], self.extra_kwargs[num] = _downsample(
969 ss, number, extra_kwargs=self.extra_kwargs[num]
970 )
972 def to_latex_table(self, labels="all", parameter_dict=None, save_to_file=None):
973 """Make a latex table displaying the data in the result file.
975 Parameters
976 ----------
977 labels: list, optional
978 list of labels that you want to include in the table
979 parameter_dict: dict, optional
980 dictionary of parameters that you wish to include in the latex
981 table. The keys are the name of the parameters and the items are
982 the descriptive text. If None, all parameters are included
983 save_to_file: str, optional
984 name of the file you wish to save the latex table to. If None, print
985 to stdout
986 """
987 import os
989 if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
990 raise FileExistsError(
991 "The file {} already exists.".format(save_to_file)
992 )
993 if labels != "all" and isinstance(labels, str) and labels not in self.labels:
994 raise ValueError("The label %s does not exist." % (labels))
995 elif labels == "all":
996 labels = list(self.labels)
997 elif isinstance(labels, str):
998 labels = [labels]
999 elif isinstance(labels, list):
1000 for ll in labels:
1001 if ll not in list(self.labels):
1002 raise ValueError("The label %s does not exist." % (ll))
1004 table = self.latex_table(
1005 [self.samples_dict[label] for label in labels], parameter_dict,
1006 labels=labels
1007 )
1008 if save_to_file is None:
1009 print(table)
1010 elif os.path.isfile("{}".format(save_to_file)):
1011 logger.warning(
1012 "File {} already exists. Printing to stdout".format(save_to_file)
1013 )
1014 print(table)
1015 else:
1016 with open(save_to_file, "w") as f:
1017 f.writelines([table])
1019 def generate_latex_macros(
1020 self, labels="all", parameter_dict=None, save_to_file=None,
1021 rounding=2
1022 ):
1023 """Generate a list of latex macros for each parameter in the result
1024 file
1026 Parameters
1027 ----------
1028 labels: list, optional
1029 list of labels that you want to include in the table
1030 parameter_dict: dict, optional
1031 dictionary of parameters that you wish to generate macros for. The
1032 keys are the name of the parameters and the items are the latex
1033 macros name you wish to use. If None, all parameters are included.
1034 save_to_file: str, optional
1035 name of the file you wish to save the latex table to. If None, print
1036 to stdout
1037 rounding: int, optional
1038 number of decimal places to round the latex macros. Default 2
1039 """
1040 import os
1042 if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
1043 raise FileExistsError(
1044 "The file {} already exists.".format(save_to_file)
1045 )
1046 if labels != "all" and isinstance(labels, str) and labels not in self.labels:
1047 raise ValueError("The label %s does not exist." % (labels))
1048 elif labels == "all":
1049 labels = list(self.labels)
1050 elif isinstance(labels, str):
1051 labels = [labels]
1052 elif isinstance(labels, list):
1053 for ll in labels:
1054 if ll not in list(self.labels):
1055 raise ValueError("The label %s does not exist." % (ll))
1057 macros = self.latex_macros(
1058 [self.samples_dict[i] for i in labels], parameter_dict,
1059 labels=labels, rounding=rounding
1060 )
1061 if save_to_file is None:
1062 print(macros)
1063 else:
1064 with open(save_to_file, "w") as f:
1065 f.writelines([macros])
1067 def reweight_samples(self, function, labels=None, **kwargs):
1068 """Reweight the posterior and/or prior samples according to a new prior
1070 Parameters
1071 ----------
1072 labels: list, optional
1073 list of analyses you wish to reweight. Default reweight all
1074 analyses
1075 """
1076 _samples_dict = self.samples_dict
1077 for idx, label in enumerate(self.labels):
1078 if labels is not None and label not in labels:
1079 continue
1080 new_samples = _samples_dict[label].reweight(function, **kwargs)
1081 self.parameters[idx] = Parameters(new_samples.parameters)
1082 self.samples[idx] = np.array(new_samples.samples).T
1083 self.extra_kwargs[idx]["sampler"].update(
1084 {
1085 "nsamples": new_samples.number_of_samples,
1086 "nsamples_before_reweighting": (
1087 _samples_dict[label].number_of_samples
1088 )
1089 }
1090 )
1091 self.extra_kwargs[idx]["meta_data"]["reweighting"] = function
1092 if not hasattr(self, "priors"):
1093 continue
1094 if "samples" not in self.priors.keys():
1095 continue
1096 prior_samples = self.priors["samples"][label]
1097 if not len(prior_samples):
1098 continue
1099 new_prior_samples = prior_samples.reweight(function, **kwargs)
1100 self.priors["samples"][label] = new_prior_samples