Coverage for pesummary/core/file/formats/base_read.py: 56.5%
414 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import numpy as np
5import h5py
6from pesummary.utils.parameters import MultiAnalysisParameters, Parameters
7from pesummary.utils.samples_dict import (
8 MultiAnalysisSamplesDict, SamplesDict, MCMCSamplesDict, Array
9)
10from pesummary.utils.utils import logger
12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def _downsample(samples, number, extra_kwargs=None):
    """Downsample a posterior table

    Parameters
    ----------
    samples: 2d list
        list of posterior samples where the columns correspond to a given
        parameter
    number: int
        number of posterior samples you wish to downsample to
    extra_kwargs: dict, optional
        dictionary of kwargs to modify. The input dictionary is never
        mutated; a deep copy with an updated 'nsamples' entry is returned

    Returns
    -------
    samples: 2d list
        the downsampled posterior table. Returned on its own when
        extra_kwargs is None
    extra_kwargs: dict
        copy of the input kwargs with ["sampler"]["nsamples"] set to number.
        Only returned when extra_kwargs is not None

    Raises
    ------
    ValueError
        if more samples are requested than are stored in the table
    """
    import copy

    # transpose so that each row holds the samples for a single parameter
    _samples = np.array(samples).T
    nsamples = len(_samples[0])
    if number > nsamples:
        raise ValueError(
            "Failed to downsample the posterior samples to {} because "
            "there are only {} samples stored in the file.".format(
                number, nsamples
            )
        )
    # imported after validation so an invalid request fails without needing
    # the resampling machinery
    from pesummary.utils.utils import resample_posterior_distribution

    _samples = np.array(resample_posterior_distribution(_samples, number))
    if extra_kwargs is None:
        return _samples.T.tolist()
    _extra_kwargs = copy.deepcopy(extra_kwargs)
    # kwargs read from a file are not guaranteed to have a 'sampler' section
    _extra_kwargs.setdefault("sampler", {})["nsamples"] = number
    return _samples.T.tolist(), _extra_kwargs
class Read(object):
    """Base class to read in a results file

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the result file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file keyed by parameters
    input_version: str
        version of the result file passed.
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file
    pe_algorithm: str
        name of the algorithm used to generate the posterior samples

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    """
    def __init__(
        self, path_to_results_file, remove_nan_likelihood_samples=True, **kwargs
    ):
        self.path_to_results_file = path_to_results_file
        self.mcmc_samples = False
        self.remove_nan_likelihood_samples = remove_nan_likelihood_samples
        self.extension = self.extension_from_path(self.path_to_results_file)
        self.converted_parameters = []

    @classmethod
    def load_file(cls, path, **kwargs):
        """Initialize the class with a file

        Parameters
        ----------
        path: str
            path to the result file you wish to load
        **kwargs: dict, optional
            all kwargs passed to the class

        Raises
        ------
        FileNotFoundError
            if path does not point to an existing file
        """
        if not os.path.isfile(path):
            raise FileNotFoundError("%s does not exist" % (path))
        return cls(path, **kwargs)

    @staticmethod
    def load_from_function(function, path_to_file, **kwargs):
        """Load a file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your file
        path_to_file: str
            path to the file that you wish to load
        kwargs: dict
            all kwargs are passed to the function
        """
        return function(path_to_file, **kwargs)

    @staticmethod
    def check_for_nan_likelihoods(parameters, samples, remove=False):
        """Check to see if there are any samples with log_likelihood='nan' in
        the posterior table and remove if requested

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter
        remove: Bool, optional
            if True, remove samples with log_likelihood='nan' from samples

        Returns
        -------
        parameters, samples: tuple
            the unmodified parameters and the (possibly filtered) samples
        """
        import math
        if "log_likelihood" not in parameters:
            return parameters, samples
        ind = parameters.index("log_likelihood")
        # transpose so the log_likelihood column can be selected by index
        likelihoods = np.array(samples).T[ind]
        inds = np.array(
            [math.isnan(_) for _ in likelihoods], dtype=bool
        )
        n_nan = sum(inds)
        if not n_nan:
            return parameters, samples
        msg = (
            "Posterior table contains {} samples with 'nan' log likelihood. "
        )
        if remove:
            msg += "Removing samples from posterior table."
            samples = np.array(samples)[~inds].tolist()
        else:
            msg += "This may cause problems when analysing posterior samples."
        logger.warning(msg.format(n_nan))
        return parameters, samples

    @staticmethod
    def check_for_weights(parameters, samples):
        """Check to see if the samples are weighted

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter

        Returns
        -------
        weights: Array or None
            the column named 'weights' (preferred) or 'weight', or None if
            neither name is in parameters
        """
        for name in ("weights", "weight"):
            if name in parameters:
                return Array(np.array(samples).T[parameters.index(name)])
        return None

    @property
    def pe_algorithm(self):
        """Name of the algorithm stored in extra_kwargs, or None if the
        'sampler' or 'pe_algorithm' entry is missing
        """
        try:
            return self.extra_kwargs["sampler"]["pe_algorithm"]
        except KeyError:
            return None

    def __repr__(self):
        return self.summary()

    def _parameter_summary(self, parameters, parameters_to_show=4):
        """Return a summary of the parameter stored

        Parameters
        ----------
        parameters: list
            list of parameters to create a summary for
        parameters_to_show: int, optional
            number of parameters to show. Default 4.
        """
        params = parameters
        if len(parameters) > parameters_to_show:
            # abbreviate to the first two and last two parameter names
            params = parameters[:2] + ["..."] + parameters[-2:]
        return ", ".join(params)

    def summary(
        self, parameters_to_show=4, show_parameters=True, show_nsamples=True
    ):
        """Return a summary of the contents of the file

        Parameters
        ----------
        parameters_to_show: int, optional
            number of parameters to show. Default 4
        show_parameters: Bool, optional
            if True print a list of the parameters stored
        show_nsamples: Bool, optional
            if True print how many samples are stored in the file
        """
        string = ""
        if self.path_to_results_file is not None:
            string += "file: {}\n".format(self.path_to_results_file)
        string += "cls: {}.{}\n".format(
            self.__class__.__module__, self.__class__.__name__
        )
        if show_nsamples:
            string += "nsamples: {}\n".format(len(self.samples))
        if show_parameters:
            string += "parameters: {}".format(
                self._parameter_summary(
                    self.parameters, parameters_to_show=parameters_to_show
                )
            )
        return string

    # maps the attribute set on the instance by `load` to the key it is read
    # from in the loaded data
    attrs = {
        "input_version": "version", "extra_kwargs": "kwargs",
        "priors": "prior", "analytic": "analytic", "labels": "labels",
        "config": "config", "weights": "weights", "history": "history",
        "description": "description"
    }

    def _load(self, function, **kwargs):
        """Extract the data from a file using a given function

        Parameters
        ----------
        function: func
            function you wish to use to extract the data
        **kwargs: dict, optional
            optional kwargs to pass to the load function
        """
        return self.load_from_function(
            function, self.path_to_results_file, **kwargs
        )

    def load(self, function, _data=None, **kwargs):
        """Load a results file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your results file
        _data: dict, optional
            already loaded data. If provided, function is not called
        **kwargs: dict, optional
            all kwargs passed to function
        """
        self.data = _data
        if _data is None:
            self.data = self._load(function, **kwargs)
        # a list of parameter lists indicates more than one analysis
        if isinstance(self.data["parameters"][0], list):
            _cls = MultiAnalysisParameters
        else:
            _cls = Parameters
        self.parameters = _cls(self.data["parameters"])
        self.converted_parameters = []
        self.samples = self.data["samples"]
        self.parameters, self.samples = self.check_for_nan_likelihoods(
            self.parameters, self.samples,
            remove=self.remove_nan_likelihood_samples
        )
        if "mcmc_samples" in self.data.keys():
            self.mcmc_samples = self.data["mcmc_samples"]
        if "injection" in self.data.keys():
            # keys may be stored as bytes; decode them to str
            if isinstance(self.data["injection"], dict):
                self.injection_parameters = {
                    key.decode("utf-8") if isinstance(key, bytes) else key: val
                    for key, val in self.data["injection"].items()
                }
            elif isinstance(self.data["injection"], list):
                self.injection_parameters = [
                    {
                        key.decode("utf-8") if isinstance(key, bytes) else
                        key: val for key, val in i.items()
                    } for i in self.data["injection"]
                ]
            else:
                self.injection_parameters = self.data["injection"]
        for new_attr, existing_attr in self.attrs.items():
            if existing_attr in self.data.keys():
                setattr(self, new_attr, self.data[existing_attr])
            else:
                setattr(self, new_attr, None)
        # fall back to the defaults provided by the subclass
        if self.input_version is None:
            self.input_version = self._default_version
        if self.extra_kwargs is None:
            self.extra_kwargs = self._default_kwargs
        if self.description is None:
            self.description = self._default_description
        if self.weights is None:
            self.weights = self.check_for_weights(self.parameters, self.samples)

    @staticmethod
    def extension_from_path(path):
        """Return the extension of the file from the file path

        Parameters
        ----------
        path: str
            path to the results file
        """
        extension = path.split(".")[-1]
        return extension

    @staticmethod
    def guess_path_to_samples(path):
        """Guess the path to the posterior samples stored in an hdf5 object

        Parameters
        ----------
        path: str
            path to the results file

        Raises
        ------
        ValueError
            if zero, or more than one, candidate posterior sample tables are
            found in the file
        """
        paths = []

        def _find_name(name, item):
            # always return None so that visititems keeps walking the file
            c1 = "posterior_samples" in name or "posterior" in name
            c2 = isinstance(item, (h5py._hl.dataset.Dataset, np.ndarray))
            try:
                c3 = isinstance(item, h5py._hl.group.Group) and isinstance(
                    item[0], (float, int, np.number)
                )
            except (TypeError, AttributeError):
                c3 = False
            c4 = (
                isinstance(item, h5py._hl.group.Group) and "parameter_names" in
                item.keys() and "samples" in item.keys()
            )
            if c1 and (c3 or c4 or c2):
                paths.append(name)

        # context manager guarantees the file is closed even if the walk fails
        with h5py.File(path, 'r') as f:
            f.visititems(_find_name)
        if len(paths) == 1:
            return paths[0]
        elif len(paths) > 1:
            raise ValueError(
                "Found multiple posterior sample tables in '{}': {}. Not sure "
                "which to load.".format(
                    path, ", ".join(paths)
                )
            )
        else:
            raise ValueError(
                "Unable to find a posterior samples table in '{}'".format(path)
            )

    def generate_all_posterior_samples(self, **kwargs):
        """Empty function
        """
        pass

    def add_fixed_parameters_from_config_file(self, config_file):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples. This base implementation does nothing

        Parameters
        ----------
        config_file: str
            path to the configuration file
        """
        pass

    def add_injection_parameters_from_file(self, injection_file, **kwargs):
        """Populate the 'injection_parameters' property with data from a file

        Parameters
        ----------
        injection_file: str
            path to injection file
        **kwargs: dict, optional
            all kwargs passed to _grab_injection_parameters_from_file
        """
        self.injection_parameters = self._grab_injection_parameters_from_file(
            injection_file, **kwargs
        )

    def _grab_injection_parameters_from_file(
        self, path, cls=None, add_nans=True, **kwargs
    ):
        """Extract data from an injection file

        Parameters
        ----------
        path: str
            path to injection file
        cls: class, optional
            class to read in injection file. The class must have a read class
            method and a samples_dict property. Default None which means that
            the pesummary.core.file.injection.Injection class is used
        add_nans: Bool, optional
            if True, add a 'nan' entry for every parameter in self.parameters
            which is not present in the injection file. Default True
        **kwargs: dict, optional
            all kwargs passed to cls.read
        """
        if cls is None:
            from pesummary.core.file.injection import Injection
            cls = Injection
        data = cls.read(path, **kwargs).samples_dict
        if add_nans:
            for i in self.parameters:
                if i not in data.keys():
                    data[i] = float("nan")
        return data

    def write(
        self, package="core", file_format="dat", extra_kwargs=None,
        file_versions=None, **kwargs
    ):
        """Save the data to file

        Parameters
        ----------
        package: str, optional
            package you wish to use when writing the data
        file_format: str, optional
            format to write. Default 'dat'. If 'ini', the stored config is
            written instead of the samples
        extra_kwargs: dict, optional
            kwargs to store in the file. Default self.extra_kwargs
        file_versions: str, optional
            version to store in the file. Default self.input_version
        kwargs: dict, optional
            all additional kwargs are passed to the pesummary.io.write function
        """
        from pesummary.io import write

        if file_format == "pesummary" and np.array(self.parameters).ndim > 1:
            args = [self.samples_dict]
        else:
            args = [self.parameters, self.samples]
        if extra_kwargs is None:
            extra_kwargs = self.extra_kwargs
        if file_versions is None:
            file_versions = self.input_version
        if file_format == "ini":
            kwargs["file_format"] = "ini"
            return write(getattr(self, "config", None), **kwargs)
        else:
            return write(
                *args, package=package, file_versions=file_versions,
                file_kwargs=extra_kwargs, file_format=file_format, **kwargs
            )

    def downsample(self, number):
        """Downsample the posterior samples stored in the result file

        Parameters
        ----------
        number: int
            number of posterior samples you wish to downsample to
        """
        self.samples, self.extra_kwargs = _downsample(
            self.samples, number, extra_kwargs=self.extra_kwargs
        )

    @staticmethod
    def latex_table(samples, parameter_dict=None, labels=None):
        """Return a latex table displaying the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        labels: list, optional
            column headings, one for each entry in samples
        """
        table = (
            "\\begin{table}[hptb]\n\\begin{ruledtabular}\n\\begin{tabular}"
            "{l %s}\n" % ("c " * len(samples))
        )
        if labels:
            table += (
                " & " + " & ".join(labels)
            )
        # every backslash is escaped explicitly: '\ ' is an invalid escape
        # sequence and raises a SyntaxWarning on recent python versions
        table += "\\\\ \n\\hline \\\\ \n"
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not adding to "
                        "table".format(param)
                    )
                    data.pop(param)

        for param, desc in data.items():
            table += "{}".format(desc)
            for samples_dict in samples:
                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                table += (
                    " & $%s^{+%s}_{-%s}$" % (
                        np.round(median, 2),
                        np.round(confidence[1] - median, 2),
                        np.round(median - confidence[0], 2)
                    )
                )
            table += "\\\\ \n"
        table += (
            "\\end{tabular}\n\\end{ruledtabular}\n\\caption{}\n\\end{table}"
        )
        return table

    @staticmethod
    def latex_macros(
        samples, parameter_dict=None, labels=None, rounding="smart"
    ):
        """Return a set of latex macros for the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        labels: list, optional
            one label for each entry in samples. If provided, the label is
            appended to the macro name
        rounding: int, optional
            decimal place for rounding. Default uses the
            `pesummary.utils.utils.smart_round` function to round according to
            the uncertainty
        """
        macros = ""
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not generating "
                        "macro".format(param)
                    )
                    data.pop(param)
        for param, desc in data.items():
            for num, samples_dict in enumerate(samples):
                if labels:
                    description = "{}{}".format(desc, labels[num])
                else:
                    description = desc

                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                if rounding == "smart":
                    from pesummary.utils.utils import smart_round

                    median, upper, low = smart_round([
                        median, confidence[1] - median, median - confidence[0]
                    ])
                else:
                    # NOTE: the offsets are measured from the rounded median
                    median = np.round(median, rounding)
                    low = np.round(median - confidence[0], rounding)
                    upper = np.round(confidence[1] - median, rounding)
                macros += (
                    "\\def\\%s{$%s_{-%s}^{+%s}$}\n" % (
                        description, median, low, upper
                    )
                )
                macros += (
                    "\\def\\%smedian{$%s$}\n" % (description, median)
                )
                macros += (
                    "\\def\\%supper{$%s$}\n" % (
                        description, np.round(median + upper, 9)
                    )
                )
                macros += (
                    "\\def\\%slower{$%s$}\n" % (
                        description, np.round(median - low, 9)
                    )
                )
        return macros
class SingleAnalysisRead(Read):
    """Base class to read in a results file which contains a single analyses

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file
    input_version: str
        version of the result file passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    reweight_samples:
        reweight the posterior and/or samples according to a new prior
    """
    def __init__(self, *args, **kwargs):
        super(SingleAnalysisRead, self).__init__(*args, **kwargs)

    @property
    def samples_dict(self):
        """Posterior samples keyed by parameter. An MCMCSamplesDict is built
        when mcmc_samples is set, otherwise a SamplesDict
        """
        if self.mcmc_samples:
            return MCMCSamplesDict(
                self.parameters, [np.array(i).T for i in self.samples]
            )
        return SamplesDict(self.parameters, np.array(self.samples).T)

    @property
    def _default_version(self):
        return "No version information found"

    @property
    def _default_kwargs(self):
        _kwargs = {"sampler": {}, "meta_data": {}}
        _kwargs["sampler"]["nsamples"] = len(self.data["samples"])
        return _kwargs

    @property
    def _default_description(self):
        return "No description found"

    def _add_fixed_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        # NOTE(review): the result is stored under the integer keys 0 and 1 of
        # self.data; self.parameters and self.samples are not updated here.
        # Confirm that this is intended.
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def _add_marginalized_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add marginalized parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        # NOTE(review): see _add_fixed_parameters_from_config_file; the result
        # is stored under the integer keys 0 and 1 of self.data
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def to_latex_table(self, parameter_dict=None, save_to_file=None):
        """Make a latex table displaying the data in the result file.

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout

        Raises
        ------
        FileExistsError
            if save_to_file already exists
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        table = self.latex_table([self.samples_dict], parameter_dict)
        if save_to_file is None:
            print(table)
        else:
            # an existing file has already been rejected above
            with open(save_to_file, "w") as f:
                f.writelines([table])

    def generate_latex_macros(
        self, parameter_dict=None, save_to_file=None, rounding="smart"
    ):
        """Generate a list of latex macros for each parameter in the result
        file

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        rounding: int, optional
            number of decimal points to round the latex macros. Default
            'smart' which rounds according to the uncertainty

        Raises
        ------
        FileExistsError
            if save_to_file already exists
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        macros = self.latex_macros(
            [self.samples_dict], parameter_dict, rounding=rounding
        )
        if save_to_file is None:
            print(macros)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([macros])

    def to_dat(self, **kwargs):
        """Save the PESummary results file object to a dat file

        Parameters
        ----------
        kwargs: dict
            all kwargs passed to the pesummary.core.file.formats.dat.write_dat
            function
        """
        return self.write(file_format="dat", **kwargs)

    def reweight_samples(self, function, **kwargs):
        """Reweight the posterior and/or prior samples according to a new prior

        Parameters
        ----------
        function: func/str
            passed to the `reweight` method of the stored samples
        **kwargs: dict, optional
            all kwargs passed to the `reweight` method

        Raises
        ------
        ValueError
            if the file contains MCMC chains
        """
        if self.mcmc_samples:
            # raise (rather than return) so that the failure cannot be missed
            raise ValueError("Cannot currently reweight MCMC chains")
        _samples = self.samples_dict
        new_samples = _samples.reweight(function, **kwargs)
        self.parameters = Parameters(new_samples.parameters)
        self.samples = np.array(new_samples.samples).T
        self.extra_kwargs["sampler"].update(
            {
                "nsamples": new_samples.number_of_samples,
                "nsamples_before_reweighting": _samples.number_of_samples
            }
        )
        self.extra_kwargs["meta_data"]["reweighting"] = function
        # only reweight the prior samples when there are some stored
        if not hasattr(self, "priors"):
            return
        if (self.priors is None) or ("samples" not in self.priors.keys()):
            return
        prior_samples = self.priors["samples"]
        if not len(prior_samples):
            return
        new_prior_samples = prior_samples.reweight(function, **kwargs)
        self.priors["samples"] = new_prior_samples
class MultiAnalysisRead(Read):
    """Base class to read in a results file which contains multiple analyses

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: 2d list
        list of parameters for each analysis
    samples: 3d list
        list of samples stored in the result file for each analysis
    samples_dict: dict
        dictionary of samples stored in the result file keyed by analysis label
    input_version: str
        version of the result files passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file

    Methods
    -------
    samples_dict_for_label: dict
        dictionary of samples for a specific analysis
    reduced_samples_dict: dict
        dictionary of samples for one or more analyses
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    reweight_samples:
        reweight the posterior and/or samples according to a new prior
    """
    def __init__(self, *args, **kwargs):
        super(MultiAnalysisRead, self).__init__(*args, **kwargs)

    @staticmethod
    def check_for_nan_likelihoods(parameters, samples, remove=False):
        """Apply Read.check_for_nan_likelihoods to each analysis in turn. The
        inputs are deep copied and the copies are returned

        Parameters
        ----------
        parameters: 2d list
            list of parameters for each analysis
        samples: 3d list
            list of samples for each analysis
        remove: Bool, optional
            if True, remove samples with log_likelihood='nan' from samples
        """
        import copy
        _parameters = copy.deepcopy(parameters)
        _samples = copy.deepcopy(samples)
        for num, (params, samps) in enumerate(zip(_parameters, _samples)):
            _parameters[num], _samples[num] = Read.check_for_nan_likelihoods(
                params, samps, remove=remove
            )
        return _parameters, _samples

    def samples_dict_for_label(self, label):
        """Return the posterior samples for a specific label

        Parameters
        ----------
        label: str
            label you wish to get posterior samples for

        Returns
        -------
        outdict: SamplesDict
            Returns a SamplesDict containing the requested posterior samples

        Raises
        ------
        ValueError
            if label is not one of the stored labels
        """
        if label not in self.labels:
            raise ValueError("Unrecognised label: '{}'".format(label))
        idx = self.labels.index(label)
        return SamplesDict(self.parameters[idx], np.array(self.samples[idx]).T)

    def reduced_samples_dict(self, labels):
        """Return the posterior samples for one or more labels

        Parameters
        ----------
        labels: str, list
            label(s) you wish to get posterior samples for

        Returns
        -------
        outdict: MultiAnalysisSamplesDict
            Returns a MultiAnalysisSamplesDict containing the requested
            posterior samples

        Raises
        ------
        ValueError
            if any of labels is not one of the stored labels
        """
        if not isinstance(labels, list):
            labels = [labels]
        not_allowed = [_label for _label in labels if _label not in self.labels]
        if len(not_allowed):
            raise ValueError(
                "Unrecognised label(s) '{}'. The list of available labels are "
                "{}.".format(", ".join(not_allowed), ", ".join(self.labels))
            )
        return MultiAnalysisSamplesDict(
            {
                label: self.samples_dict_for_label(label) for label in labels
            }
        )

    @property
    def samples_dict(self):
        """Posterior samples keyed by analysis label, or an MCMCSamplesDict
        built from the first analysis when mcmc_samples is set
        """
        if self.mcmc_samples:
            outdict = MCMCSamplesDict(
                self.parameters[0], [np.array(i).T for i in self.samples[0]]
            )
        else:
            outdict = self.reduced_samples_dict(self.labels)
        return outdict

    @property
    def _default_version(self):
        return ["No version information found"] * len(self.parameters)

    @property
    def _default_kwargs(self):
        # build an independent dictionary for each analysis; multiplying a
        # one-element list would make every analysis share the same dictionary
        _kwargs = [
            {"sampler": {}, "meta_data": {}}
            for _ in range(len(self.parameters))
        ]
        for num, ss in enumerate(self.data["samples"]):
            _kwargs[num]["sampler"]["nsamples"] = len(ss)
        return _kwargs

    @property
    def _default_description(self):
        return {label: "No description found" for label in self.labels}

    def write(self, package="core", file_format="dat", **kwargs):
        """Save the data to file

        Parameters
        ----------
        package: str, optional
            package you wish to use when writing the data
        file_format: str, optional
            format to write. Default 'dat'
        kwargs: dict, optional
            all additional kwargs are passed to the pesummary.io.write function
        """
        return super(MultiAnalysisRead, self).write(
            package=package, file_format=file_format,
            extra_kwargs=self.kwargs_dict, file_versions=self.version_dict,
            **kwargs
        )

    @property
    def kwargs_dict(self):
        """extra_kwargs keyed by analysis label"""
        return {
            label: kwarg for label, kwarg in zip(self.labels, self.extra_kwargs)
        }

    @property
    def version_dict(self):
        """input_version keyed by analysis label"""
        return {
            label: version for label, version in zip(self.labels, self.input_version)
        }

    def summary(self, *args, parameters_to_show=4, **kwargs):
        """Return a summary of the contents of the file

        Parameters
        ----------
        parameters_to_show: int, optional
            number of parameters to show. Default 4
        """
        string = super(MultiAnalysisRead, self).summary(
            show_parameters=False, show_nsamples=False
        )
        string += "analyses: {}\n\n".format(", ".join(self.labels))
        for num, label in enumerate(self.labels):
            string += "{}\n".format(label)
            string += "-" * len(label) + "\n"
            string += "description: {}\n".format(self.description[label])
            string += "nsamples: {}\n".format(len(self.samples[num]))
            string += "parameters: {}\n\n".format(
                self._parameter_summary(
                    self.parameters[num], parameters_to_show=parameters_to_show
                )
            )
        # drop the final two characters (the trailing blank line)
        return string[:-2]

    def downsample(self, number, labels=None):
        """Downsample the posterior samples stored in the result file

        Parameters
        ----------
        number: int
            number of posterior samples you wish to downsample to
        labels: list, optional
            list of analyses you wish to downsample. Default downsample all
            analyses
        """
        for num, ss in enumerate(self.samples):
            if labels is not None and self.labels[num] not in labels:
                continue
            self.samples[num], self.extra_kwargs[num] = _downsample(
                ss, number, extra_kwargs=self.extra_kwargs[num]
            )

    def _labels_to_list(self, labels):
        """Validate the requested labels and return them as a list

        Parameters
        ----------
        labels: str, list
            'all', a single label or a list of labels

        Raises
        ------
        ValueError
            if a requested label is not one of the stored labels
        """
        if isinstance(labels, str):
            if labels == "all":
                return list(self.labels)
            if labels not in self.labels:
                raise ValueError("The label %s does not exist." % (labels))
            return [labels]
        if isinstance(labels, list):
            for ll in labels:
                if ll not in list(self.labels):
                    raise ValueError("The label %s does not exist." % (ll))
        return labels

    def to_latex_table(self, labels="all", parameter_dict=None, save_to_file=None):
        """Make a latex table displaying the data in the result file.

        Parameters
        ----------
        labels: list, optional
            list of labels that you want to include in the table
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout

        Raises
        ------
        FileExistsError
            if save_to_file already exists
        ValueError
            if a requested label is not one of the stored labels
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )
        labels = self._labels_to_list(labels)

        # build the samples dictionary once rather than once per label
        _samples_dict = self.samples_dict
        table = self.latex_table(
            [_samples_dict[label] for label in labels], parameter_dict,
            labels=labels
        )
        if save_to_file is None:
            print(table)
        else:
            # an existing file has already been rejected above
            with open(save_to_file, "w") as f:
                f.writelines([table])

    def generate_latex_macros(
        self, labels="all", parameter_dict=None, save_to_file=None,
        rounding=2
    ):
        """Generate a list of latex macros for each parameter in the result
        file

        Parameters
        ----------
        labels: list, optional
            list of labels that you want to include in the table
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        rounding: int, optional
            number of decimal points to round the latex macros

        Raises
        ------
        FileExistsError
            if save_to_file already exists
        ValueError
            if a requested label is not one of the stored labels
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )
        labels = self._labels_to_list(labels)

        # build the samples dictionary once rather than once per label
        _samples_dict = self.samples_dict
        macros = self.latex_macros(
            [_samples_dict[i] for i in labels], parameter_dict,
            labels=labels, rounding=rounding
        )
        if save_to_file is None:
            print(macros)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([macros])

    def reweight_samples(self, function, labels=None, **kwargs):
        """Reweight the posterior and/or prior samples according to a new prior

        Parameters
        ----------
        function: func/str
            passed to the `reweight` method of the stored samples
        labels: list, optional
            list of analyses you wish to reweight. Default reweight all
            analyses
        **kwargs: dict, optional
            all kwargs passed to the `reweight` method
        """
        _samples_dict = self.samples_dict
        for idx, label in enumerate(self.labels):
            if labels is not None and label not in labels:
                continue
            new_samples = _samples_dict[label].reweight(function, **kwargs)
            self.parameters[idx] = Parameters(new_samples.parameters)
            self.samples[idx] = np.array(new_samples.samples).T
            self.extra_kwargs[idx]["sampler"].update(
                {
                    "nsamples": new_samples.number_of_samples,
                    "nsamples_before_reweighting": (
                        _samples_dict[label].number_of_samples
                    )
                }
            )
            self.extra_kwargs[idx]["meta_data"]["reweighting"] = function
            # `load` sets priors to None when the file has no prior section
            if getattr(self, "priors", None) is None:
                continue
            if "samples" not in self.priors.keys():
                continue
            prior_samples = self.priors["samples"][label]
            if not len(prior_samples):
                continue
            new_prior_samples = prior_samples.reweight(function, **kwargs)
            self.priors["samples"][label] = new_prior_samples