Coverage for pesummary/utils/samples_dict.py: 62.9%
636 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import copy
4import numpy as np
5from pesummary.utils.utils import resample_posterior_distribution, logger
6from pesummary.utils.decorators import docstring_subfunction
7from pesummary.utils.array import Array, _2DArray
8from pesummary.utils.dict import Dict
9from pesummary.utils.parameters import Parameters
10from pesummary.core.plots.latex_labels import latex_labels
11from pesummary.gw.plots.latex_labels import GWlatex_labels
12from pesummary import conf
13import importlib
15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
17latex_labels.update(GWlatex_labels)
20class SamplesDict(Dict):
21 """Class to store the samples from a single run
23 Parameters
24 ----------
25 parameters: list
26 list of parameters
27 samples: nd list
28 list of samples for each parameter
29 autoscale: Bool, optional
30 If True, the posterior samples for each parameter are scaled to the
31 same length
33 Attributes
34 ----------
35 maxL: pesummary.utils.samples_dict.SamplesDict
36 SamplesDict object containing the maximum likelihood sample keyed by
37 the parameter
38 minimum: pesummary.utils.samples_dict.SamplesDict
39 SamplesDict object containing the minimum sample for each parameter
40 maximum: pesummary.utils.samples_dict.SamplesDict
41 SamplesDict object containing the maximum sample for each parameter
42 median: pesummary.utils.samples_dict.SamplesDict
43 SamplesDict object containining the median of each marginalized
44 posterior distribution
45 mean: pesummary.utils.samples_dict.SamplesDict
46 SamplesDict object containing the mean of each marginalized posterior
47 distribution
48 key_data: dict
49 dictionary containing the key data associated with each array
50 number_of_samples: int
51 Number of samples stored in the SamplesDict object
52 latex_labels: dict
53 Dictionary of latex labels for each parameter
54 available_plots: list
55 list of plots which the user may user to display the contained posterior
56 samples
58 Methods
59 -------
60 from_file:
61 Initialize the SamplesDict class with the contents of a file
62 to_pandas:
63 Convert the SamplesDict object to a pandas DataFrame
64 to_structured_array:
65 Convert the SamplesDict object to a numpy structured array
66 pop:
67 Remove an entry from the SamplesDict object
68 standardize_parameter_names:
69 Modify keys in SamplesDict to use standard PESummary names
70 downsample:
71 Downsample the samples stored in the SamplesDict object. See the
72 pesummary.utils.utils.resample_posterior_distribution method
73 discard_samples:
74 Remove the first N samples from each distribution
75 plot:
76 Generate a plot based on the posterior samples stored
77 generate_all_posterior_samples:
78 Convert the posterior samples in the SamplesDict object according to
79 a conversion function
80 debug_keys: list
81 list of keys with an '_' as their first character
82 reweight:
83 Reweight the posterior samples according to a new prior
84 write:
85 Save the stored posterior samples to file
87 Examples
88 --------
89 How to initialize the SamplesDict class
91 >>> from pesummary.utils.samples_dict import SamplesDict
92 >>> data = {
93 ... "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
94 ... "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
95 ... }
96 >>> dataset = SamplesDict(data)
97 >>> parameters = ["a", "b"]
98 >>> samples = [
99 ... [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
100 ... [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
101 ... ]
102 >>> dataset = SamplesDict(parameters, samples)
103 >>> fig = dataset.plot("a", type="hist", bins=30)
104 >>> fig.show()
105 """
106 def __init__(self, *args, logger_warn="warn", autoscale=True):
107 super(SamplesDict, self).__init__(
108 *args, value_class=Array, make_dict_kwargs={"autoscale": autoscale},
109 logger_warn=logger_warn, latex_labels=latex_labels
110 )
112 def __getitem__(self, key):
113 """Return an object representing the specialization of SamplesDict
114 by type arguments found in key.
115 """
116 if isinstance(key, slice):
117 return SamplesDict(
118 self.parameters, np.array(
119 [i[key.start:key.stop:key.step] for i in self.samples]
120 )
121 )
122 elif isinstance(key, (list, np.ndarray)):
123 return SamplesDict(
124 self.parameters, np.array([i[key] for i in self.samples])
125 )
126 elif key[0] == "_":
127 return self.samples[self.parameters.index(key)]
128 return super(SamplesDict, self).__getitem__(key)
    def __setitem__(self, key, value):
        """Store ``value`` (coerced to an Array) under ``key`` and keep the
        ``parameters``/``samples`` attributes in sync with the mapping.
        """
        _value = value
        if not isinstance(value, Array):
            _value = Array(value)
        super(SamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.parameters:
                # new parameter: register it and append its samples to the
                # stored sample structure
                self.parameters.append(key)
                try:
                    # cond is True when self.samples is a flat 1d sequence of
                    # scalars (one value per parameter) rather than a 2d
                    # parameters-by-samples structure
                    cond = (
                        np.array(self.samples).ndim == 1 and isinstance(
                            self.samples[0], (float, int, np.number)
                        )
                    )
                except Exception:
                    cond = False
                if cond and isinstance(self.samples, np.ndarray):
                    self.samples = np.append(self.samples, value)
                elif cond and isinstance(self.samples, list):
                    self.samples.append(value)
                else:
                    # 2d case: add the new parameter's samples as a row
                    self.samples = np.vstack([self.samples, value])
                self._update_latex_labels()
        except (AttributeError, TypeError):
            # parameters/samples may not exist yet (e.g. while __init__ is
            # still populating the dictionary) -- silently skip the sync
            pass
    def __str__(self):
        """Print a summary of the information stored in the dictionary
        """
        def format_string(string, row):
            """Format a list into a table

            Parameters
            ----------
            string: str
                existing table
            row: list
                the row you wish to be written to a table
            """
            # first column (the sample index) is narrower than the rest
            string += "{:<8}".format(row[0])
            for i in range(1, len(row)):
                if isinstance(row[i], str):
                    string += "{:<15}".format(row[i])
                elif isinstance(row[i], (float, int, np.int64, np.int32)):
                    string += "{:<15.6f}".format(row[i])
            string += "\n"
            return string

        string = ""
        # header row: 'idx' followed by the (non-debug) parameter names
        string = format_string(string, ["idx"] + list(self.keys()))

        if self.number_of_samples < 8:
            # small table: print every sample
            for i in range(self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        else:
            # large table: first 4 rows, two rows of '.' as an ellipsis,
            # then the final 2 rows
            for i in range(4):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
            for i in range(2):
                string = format_string(string, ["."] * (len(self.keys()) + 1))
            for i in range(self.number_of_samples - 2, self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        return string
199 @classmethod
200 def from_file(cls, filename, **kwargs):
201 """Initialize the SamplesDict class with the contents of a result file
203 Parameters
204 ----------
205 filename: str
206 path to the result file you wish to load.
207 **kwargs: dict
208 all kwargs are passed to the pesummary.io.read function
209 """
210 from pesummary.io import read
212 return read(filename, **kwargs).samples_dict
214 @property
215 def key_data(self):
216 return {param: value.key_data for param, value in self.items()}
218 @property
219 def maxL(self):
220 return SamplesDict(
221 self.parameters, [[item.maxL] for key, item in self.items()]
222 )
224 @property
225 def minimum(self):
226 return SamplesDict(
227 self.parameters, [[item.minimum] for key, item in self.items()]
228 )
230 @property
231 def maximum(self):
232 return SamplesDict(
233 self.parameters, [[item.maximum] for key, item in self.items()]
234 )
236 @property
237 def median(self):
238 return SamplesDict(
239 self.parameters,
240 [[item.average(type="median")] for key, item in self.items()]
241 )
243 @property
244 def mean(self):
245 return SamplesDict(
246 self.parameters,
247 [[item.average(type="mean")] for key, item in self.items()]
248 )
250 @property
251 def number_of_samples(self):
252 return len(self[self.parameters[0]])
254 @property
255 def plotting_map(self):
256 existing = super(SamplesDict, self).plotting_map
257 modified = existing.copy()
258 modified.update(
259 {
260 "marginalized_posterior": self._marginalized_posterior,
261 "skymap": self._skymap,
262 "hist": self._marginalized_posterior,
263 "corner": self._corner,
264 "spin_disk": self._spin_disk,
265 "2d_kde": self._2d_kde,
266 "triangle": self._triangle,
267 "reverse_triangle": self._reverse_triangle,
268 }
269 )
270 return modified
272 def standardize_parameter_names(self, mapping=None):
273 """Modify keys in SamplesDict to use standard PESummary names
275 Parameters
276 ----------
277 mapping: dict, optional
278 dictionary mapping existing keys to standard PESummary names.
279 Default pesummary.gw.file.standard_names.standard_names
281 Returns
282 -------
283 standard_dict: SamplesDict
284 SamplesDict object with standard PESummary parameter names
285 """
286 from pesummary.utils.utils import map_parameter_names
287 if mapping is None:
288 from pesummary.gw.file.standard_names import standard_names
289 mapping = standard_names
290 return SamplesDict(map_parameter_names(self, mapping))
292 def debug_keys(self, *args, **kwargs):
293 _keys = self.keys()
294 _total = self.keys(remove_debug=False)
295 return Parameters([key for key in _total if key not in _keys])
297 def keys(self, *args, remove_debug=True, **kwargs):
298 original = super(SamplesDict, self).keys(*args, **kwargs)
299 if remove_debug:
300 return Parameters([key for key in original if key[0] != "_"])
301 return Parameters(original)
303 def write(self, **kwargs):
304 """Save the stored posterior samples to file
306 Parameters
307 ----------
308 **kwargs: dict, optional
309 all additional kwargs passed to the pesummary.io.write function
310 """
311 from pesummary.io import write
312 write(self.parameters, self.samples.T, **kwargs)
314 def items(self, *args, remove_debug=True, **kwargs):
315 items = super(SamplesDict, self).items(*args, **kwargs)
316 if remove_debug:
317 return [item for item in items if item[0][0] != "_"]
318 return items
320 def to_pandas(self, **kwargs):
321 """Convert a SamplesDict object to a pandas dataframe
322 """
323 from pandas import DataFrame
325 return DataFrame(self, **kwargs)
327 def to_structured_array(self, **kwargs):
328 """Convert a SamplesDict object to a structured numpy array
329 """
330 return self.to_pandas(**kwargs).to_records(
331 index=False, column_dtypes=float
332 )
334 def pop(self, parameter):
335 """Delete a parameter from the SamplesDict
337 Parameters
338 ----------
339 parameter: str
340 name of the parameter you wish to remove from the SamplesDict
341 """
342 if parameter not in self.parameters:
343 logger.info(
344 "{} not in SamplesDict. Unable to remove {}".format(
345 parameter, parameter
346 )
347 )
348 return
349 ind = self.parameters.index(parameter)
350 self.parameters.remove(parameter)
351 samples = self.samples
352 self.samples = np.delete(samples, ind, axis=0)
353 return super(SamplesDict, self).pop(parameter)
355 def downsample(self, number):
356 """Downsample the samples stored in the SamplesDict class
358 Parameters
359 ----------
360 number: int
361 Number of samples you wish to downsample to
362 """
363 self.samples = resample_posterior_distribution(self.samples, number)
364 self.make_dictionary()
365 return self
367 def discard_samples(self, number):
368 """Remove the first n samples
370 Parameters
371 ----------
372 number: int
373 Number of samples that you wish to remove
374 """
375 self.make_dictionary(discard_samples=number)
376 return self
    def make_dictionary(self, discard_samples=None, autoscale=True):
        """Add the parameters and samples to the class

        Parameters
        ----------
        discard_samples: int, optional
            number of leading samples to remove from each posterior before
            storing. Default None (keep everything)
        autoscale: Bool, optional
            if True, truncate all posteriors to the length of the shortest
            one when they differ in length. Default True
        """
        lengths = [len(i) for i in self.samples]
        if len(np.unique(lengths)) > 1 and autoscale:
            # unequal posterior lengths: restrict everything to the shortest
            nsamples = np.min(lengths)
            getattr(logger, self.logger_warn)(
                "Unequal number of samples for each parameter. "
                "Restricting all posterior samples to have {} "
                "samples".format(nsamples)
            )
            self.samples = [
                dataset[:nsamples] for dataset in self.samples
            ]
        # pull out the likelihood/prior/weight columns (after discarding) so
        # _2DArray can attach them to every parameter's Array
        if "log_likelihood" in self.parameters:
            likelihoods = self.samples[self.parameters.index("log_likelihood")]
            likelihoods = likelihoods[discard_samples:]
        else:
            likelihoods = None
        if "log_prior" in self.parameters:
            priors = self.samples[self.parameters.index("log_prior")]
            priors = priors[discard_samples:]
        else:
            priors = None
        if any(i in self.parameters for i in ["weights", "weight"]):
            # either spelling is accepted; 'weights' takes precedence
            ind = (
                self.parameters.index("weights") if "weights" in self.parameters
                else self.parameters.index("weight")
            )
            weights = self.samples[ind][discard_samples:]
        else:
            weights = None
        _2d_array = _2DArray(
            np.array(self.samples)[:, discard_samples:], likelihood=likelihoods,
            prior=priors, weights=weights
        )
        for key, val in zip(self.parameters, _2d_array):
            self[key] = val
417 @docstring_subfunction([
418 'pesummary.core.plots.plot._1d_histogram_plot',
419 'pesummary.gw.plots.plot._1d_histogram_plot',
420 'pesummary.gw.plots.plot._ligo_skymap_plot',
421 'pesummary.gw.plots.publication.spin_distribution_plots',
422 'pesummary.core.plots.plot._make_corner_plot',
423 'pesummary.gw.plots.plot._make_corner_plot'
424 ])
425 def plot(self, *args, type="marginalized_posterior", **kwargs):
426 """Generate a plot for the posterior samples stored in SamplesDict
428 Parameters
429 ----------
430 *args: tuple
431 all arguments are passed to the plotting function
432 type: str
433 name of the plot you wish to make
434 **kwargs: dict
435 all additional kwargs are passed to the plotting function
436 """
437 return super(SamplesDict, self).plot(*args, type=type, **kwargs)
439 def generate_all_posterior_samples(self, function=None, **kwargs):
440 """Convert samples stored in the SamplesDict according to a conversion
441 function
443 Parameters
444 ----------
445 function: func, optional
446 function to use when converting posterior samples. Must take a
447 dictionary as input and return a dictionary of converted posterior
448 samples. Default `pesummary.gw.conversions.convert
449 **kwargs: dict, optional
450 All additional kwargs passed to function
451 """
452 if function is None:
453 from pesummary.gw.conversions import convert
454 function = convert
455 _samples = self.copy()
456 _keys = list(_samples.keys())
457 kwargs.update({"return_dict": True})
458 out = function(_samples, **kwargs)
459 if kwargs.get("return_kwargs", False):
460 converted_samples, extra_kwargs = out
461 else:
462 converted_samples, extra_kwargs = out, None
463 for key, item in converted_samples.items():
464 if key not in _keys:
465 self[key] = item
466 return extra_kwargs
468 def reweight(
469 self, function, ignore_debug_params=["recalib", "spcal"], **kwargs
470 ):
471 """Reweight the posterior samples according to a new prior
473 Parameters
474 ----------
475 function: func/str
476 function to use when resampling
477 ignore_debug_params: list, optional
478 params to ignore when storing unweighted posterior distributions.
479 Default any param with ['recalib', 'spcal'] in their name
480 """
481 from pesummary.gw.reweight import options
482 if isinstance(function, str) and function in options.keys():
483 function = options[function]
484 elif isinstance(function, str):
485 raise ValueError(
486 "Unknown function '{}'. Please provide a function for "
487 "reweighting or select one of the following: {}".format(
488 function, ", ".join(list(options.keys()))
489 )
490 )
491 _samples = SamplesDict(self.copy())
492 new_samples = function(_samples, **kwargs)
493 _samples.downsample(new_samples.number_of_samples)
494 for key, item in new_samples.items():
495 if not any(param in key for param in ignore_debug_params):
496 _samples["_{}_non_reweighted".format(key)] = _samples[key]
497 _samples[key] = item
498 return SamplesDict(_samples)
500 def _marginalized_posterior(self, parameter, module="core", **kwargs):
501 """Wrapper for the `pesummary.core.plots.plot._1d_histogram_plot` or
502 `pesummary.gw.plots.plot._1d_histogram_plot`
504 Parameters
505 ----------
506 parameter: str
507 name of the parameter you wish to plot
508 module: str, optional
509 module you wish to use for the plotting
510 **kwargs: dict
511 all additional kwargs are passed to the `_1d_histogram_plot`
512 function
513 """
514 module = importlib.import_module(
515 "pesummary.{}.plots.plot".format(module)
516 )
517 return getattr(module, "_1d_histogram_plot")(
518 parameter, self[parameter], self.latex_labels[parameter],
519 weights=self[parameter].weights, **kwargs
520 )
522 def _skymap(self, **kwargs):
523 """Wrapper for the `pesummary.gw.plots.plot._ligo_skymap_plot`
524 function
526 Parameters
527 ----------
528 **kwargs: dict
529 All kwargs are passed to the `_ligo_skymap_plot` function
530 """
531 from pesummary.gw.plots.plot import _ligo_skymap_plot
533 if "luminosity_distance" in self.keys():
534 dist = self["luminosity_distance"]
535 else:
536 dist = None
538 return _ligo_skymap_plot(self["ra"], self["dec"], dist=dist, **kwargs)
540 def _spin_disk(self, **kwargs):
541 """Wrapper for the `pesummary.gw.plots.publication.spin_distribution_plots`
542 function
543 """
544 from pesummary.gw.plots.publication import spin_distribution_plots
546 required = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
547 if not all(param in self.keys() for param in required):
548 raise ValueError(
549 "The spin disk plot requires samples for the following "
550 "parameters: {}".format(", ".join(required))
551 )
552 samples = [self[param] for param in required]
553 return spin_distribution_plots(required, samples, None, **kwargs)
555 def _corner(self, module="core", parameters=None, **kwargs):
556 """Wrapper for the `pesummary.core.plots.plot._make_corner_plot` or
557 `pesummary.gw.plots.plot._make_corner_plot` function
559 Parameters
560 ----------
561 module: str, optional
562 module you wish to use for the plotting
563 **kwargs: dict
564 all additional kwargs are passed to the `_make_corner_plot`
565 function
566 """
567 module = importlib.import_module(
568 "pesummary.{}.plots.plot".format(module)
569 )
570 _parameters = None
571 if parameters is not None:
572 _parameters = [param for param in parameters if param in self.keys()]
573 if not len(_parameters):
574 raise ValueError(
575 "None of the chosen parameters are in the posterior "
576 "samples table. Please choose other parameters to plot"
577 )
578 return getattr(module, "_make_corner_plot")(
579 self, self.latex_labels, corner_parameters=_parameters, **kwargs
580 )[0]
582 def _2d_kde(self, parameters, module="core", **kwargs):
583 """Wrapper for the `pesummary.gw.plots.publication.twod_contour_plot` or
584 `pesummary.core.plots.publication.twod_contour_plot` function
586 Parameters
587 ----------
588 parameters: list
589 list of length 2 giving the parameters you wish to plot
590 module: str, optional
591 module you wish to use for the plotting
592 **kwargs: dict, optional
593 all additional kwargs are passed to the `twod_contour_plot` function
594 """
595 _module = importlib.import_module(
596 "pesummary.{}.plots.publication".format(module)
597 )
598 if module == "gw":
599 return getattr(_module, "twod_contour_plots")(
600 parameters, [[self[parameters[0]], self[parameters[1]]]],
601 [None], {
602 parameters[0]: self.latex_labels[parameters[0]],
603 parameters[1]: self.latex_labels[parameters[1]]
604 }, **kwargs
605 )
606 return getattr(_module, "twod_contour_plot")(
607 self[parameters[0]], self[parameters[1]],
608 xlabel=self.latex_labels[parameters[0]],
609 ylabel=self.latex_labels[parameters[1]], **kwargs
610 )
612 def _triangle(self, parameters, module="core", **kwargs):
613 """Wrapper for the `pesummary.core.plots.publication.triangle_plot`
614 function
616 Parameters
617 ----------
618 parameters: list
619 list of parameters they wish to study
620 **kwargs: dict
621 all additional kwargs are passed to the `triangle_plot` function
622 """
623 _module = importlib.import_module(
624 "pesummary.{}.plots.publication".format(module)
625 )
626 if module == "gw":
627 kwargs["parameters"] = parameters
628 return getattr(_module, "triangle_plot")(
629 [self[parameters[0]]], [self[parameters[1]]],
630 xlabel=self.latex_labels[parameters[0]],
631 ylabel=self.latex_labels[parameters[1]], **kwargs
632 )
634 def _reverse_triangle(self, parameters, module="core", **kwargs):
635 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot`
636 function
638 Parameters
639 ----------
640 parameters: list
641 list of parameters they wish to study
642 **kwargs: dict
643 all additional kwargs are passed to the `triangle_plot` function
644 """
645 _module = importlib.import_module(
646 "pesummary.{}.plots.publication".format(module)
647 )
648 if module == "gw":
649 kwargs["parameters"] = parameters
650 return getattr(_module, "reverse_triangle_plot")(
651 [self[parameters[0]]], [self[parameters[1]]],
652 xlabel=self.latex_labels[parameters[0]],
653 ylabel=self.latex_labels[parameters[1]], **kwargs
654 )
656 def classification(self, dual=True, population=False):
657 """Return the classification probabilities
659 Parameters
660 ----------
661 dual: Bool, optional
662 if True, return classification probabilities generated from the
663 raw samples ('default') an samples reweighted to a population
664 inferred prior ('population'). Default True.
665 population: Bool, optional
666 if True, reweight the samples to a population informed prior and
667 then calculate classification probabilities. Default False. Only
668 used when dual=False
669 """
670 from pesummary.gw.classification import Classify
671 if dual:
672 probs = Classify(self).dual_classification()
673 else:
674 probs = Classify(self).classification(population=population)
675 return probs
    def _waveform_args(self, f_ref=20., ind=0, longAscNodes=0., eccentricity=0.):
        """Arguments to be passed to waveform generation

        Parameters
        ----------
        f_ref: float, optional
            reference frequency to use when converting spherical spins to
            cartesian spins
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.

        Returns
        -------
        tuple
            (waveform_args, _samples): the positional argument list built for
            waveform generation (masses/distance converted to SI units) and
            the single-sample dictionary it was built from
        """
        from lal import MSUN_SI, PC_SI

        # take the single sample at row `ind` from each stored posterior
        _samples = {key: value[ind] for key, value in self.items()}
        required = [
            "mass_1", "mass_2", "luminosity_distance"
        ]
        if not all(param in _samples.keys() for param in required):
            raise ValueError(
                "Unable to generate a waveform. Please add samples for "
                + ", ".join(required)
            )
        # masses converted from solar masses to kg
        waveform_args = [
            _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI
        ]
        spin_angles = [
            "theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2",
            "phase"
        ]
        spin_angles_condition = all(
            spin in _samples.keys() for spin in spin_angles
        )
        cartesian_spins = [
            "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
        ]
        cartesian_spins_condition = any(
            spin in _samples.keys() for spin in cartesian_spins
        )
        if spin_angles_condition and not cartesian_spins_condition:
            # only spherical spin parameters available: convert to cartesian
            # components at the reference frequency
            from pesummary.gw.conversions import component_spins
            data = component_spins(
                _samples["theta_jn"], _samples["phi_jl"], _samples["tilt_1"],
                _samples["tilt_2"], _samples["phi_12"], _samples["a_1"],
                _samples["a_2"], _samples["mass_1"], _samples["mass_2"],
                f_ref, _samples["phase"]
            )
            iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = data.T
            spins = [spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z]
        else:
            # NOTE(review): this branch assumes an 'iota' sample exists when
            # cartesian spins are provided (KeyError otherwise) -- confirm
            iota = _samples["iota"]
            # any missing cartesian spin component defaults to 0.
            spins = [
                _samples[param] if param in _samples.keys() else 0. for param in
                ["spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"]
            ]
        waveform_args += spins
        phase = _samples["phase"] if "phase" in _samples.keys() else 0.
        waveform_args += [
            # luminosity distance converted from Mpc to m
            _samples["luminosity_distance"] * PC_SI * 10**6, iota, phase
        ]
        waveform_args += [longAscNodes, eccentricity, 0.]
        return waveform_args, _samples
    def antenna_response(self, ifo):
        """Return the antenna response for the stored samples

        Parameters
        ----------
        ifo: str
            name of the detector you wish to compute the antenna response for
        """
        from pesummary.gw.waveform import antenna_response
        return antenna_response(self, ifo)
750 def _project_waveform(self, ifo, hp, hc, ra, dec, psi, time):
751 """Project a waveform onto a given detector
753 Parameters
754 ----------
755 ifo: str
756 name of the detector you wish to project the waveform onto
757 hp: np.ndarray
758 plus gravitational wave polarization
759 hc: np.ndarray
760 cross gravitational wave polarization
761 ra: float
762 right ascension to be passed to antenna response function
763 dec: float
764 declination to be passed to antenna response function
765 psi: float
766 polarization to be passed to antenna response function
767 time: float
768 time to be passed to antenna response function
769 """
770 import importlib
772 mod = importlib.import_module("pesummary.gw.plots.plot")
773 func = getattr(mod, "__antenna_response")
774 antenna = func(ifo, ra, dec, psi, time)
775 ht = hp * antenna[0] + hc * antenna[1]
776 return ht
778 def fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs):
779 """Generate a gravitational wave in the frequency domain
781 Parameters
782 ----------
783 approximant: str
784 name of the approximant to use when generating the waveform
785 delta_f: float
786 spacing between frequency samples
787 f_low: float
788 frequency to start evaluating the waveform
789 f_high: float
790 frequency to stop evaluating the waveform
791 f_ref: float, optional
792 reference frequency
793 project: str, optional
794 name of the detector to project the waveform onto. If None,
795 the plus and cross polarizations are returned. Default None
796 ind: int, optional
797 index for the sample you wish to plot
798 longAscNodes: float, optional
799 longitude of ascending nodes, degenerate with the polarization
800 angle. Default 0.
801 eccentricity: float, optional
802 eccentricity at reference frequency. Default 0.
803 LAL_parameters: dict, optional
804 LAL dictioanry containing accessory parameters. Default None
805 pycbc: Bool, optional
806 return a the waveform as a pycbc.frequencyseries.FrequencySeries
807 object
808 """
809 from pesummary.gw.waveform import fd_waveform
810 return fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs)
812 def td_waveform(
813 self, approximant, delta_t, f_low, **kwargs
814 ):
815 """Generate a gravitational wave in the time domain
817 Parameters
818 ----------
819 approximant: str
820 name of the approximant to use when generating the waveform
821 delta_t: float
822 spacing between frequency samples
823 f_low: float
824 frequency to start evaluating the waveform
825 f_ref: float, optional
826 reference frequency
827 project: str, optional
828 name of the detector to project the waveform onto. If None,
829 the plus and cross polarizations are returned. Default None
830 ind: int, optional
831 index for the sample you wish to plot
832 longAscNodes: float, optional
833 longitude of ascending nodes, degenerate with the polarization
834 angle. Default 0.
835 eccentricity: float, optional
836 eccentricity at reference frequency. Default 0.
837 LAL_parameters: dict, optional
838 LAL dictioanry containing accessory parameters. Default None
839 pycbc: Bool, optional
840 return a the waveform as a pycbc.timeseries.TimeSeries object
841 level: list, optional
842 the symmetric confidence interval of the time domain waveform. Level
843 must be greater than 0 and less than 1
844 """
845 from pesummary.gw.waveform import td_waveform
846 return td_waveform(
847 self, approximant, delta_t, f_low, **kwargs
848 )
850 def _maxL_waveform(self, func, *args, **kwargs):
851 """Return the maximum likelihood waveform in a given domain
853 Parameters
854 ----------
855 func: function
856 function you wish to use when generating the maximum likelihood
857 waveform
858 *args: tuple
859 all args passed to func
860 **kwargs: dict
861 all kwargs passed to func
862 """
863 ind = np.argmax(self["log_likelihood"])
864 kwargs["ind"] = ind
865 return func(*args, **kwargs)
867 def maxL_td_waveform(self, *args, **kwargs):
868 """Generate the maximum likelihood gravitational wave in the time domain
870 Parameters
871 ----------
872 approximant: str
873 name of the approximant to use when generating the waveform
874 delta_t: float
875 spacing between frequency samples
876 f_low: float
877 frequency to start evaluating the waveform
878 f_ref: float, optional
879 reference frequency
880 project: str, optional
881 name of the detector to project the waveform onto. If None,
882 the plus and cross polarizations are returned. Default None
883 longAscNodes: float, optional
884 longitude of ascending nodes, degenerate with the polarization
885 angle. Default 0.
886 eccentricity: float, optional
887 eccentricity at reference frequency. Default 0.
888 LAL_parameters: dict, optional
889 LAL dictioanry containing accessory parameters. Default None
890 level: list, optional
891 the symmetric confidence interval of the time domain waveform. Level
892 must be greater than 0 and less than 1
893 """
894 return self._maxL_waveform(self.td_waveform, *args, **kwargs)
896 def maxL_fd_waveform(self, *args, **kwargs):
897 """Generate the maximum likelihood gravitational wave in the frequency
898 domain
900 Parameters
901 ----------
902 approximant: str
903 name of the approximant to use when generating the waveform
904 delta_f: float
905 spacing between frequency samples
906 f_low: float
907 frequency to start evaluating the waveform
908 f_high: float
909 frequency to stop evaluating the waveform
910 f_ref: float, optional
911 reference frequency
912 project: str, optional
913 name of the detector to project the waveform onto. If None,
914 the plus and cross polarizations are returned. Default None
915 longAscNodes: float, optional
916 longitude of ascending nodes, degenerate with the polarization
917 angle. Default 0.
918 eccentricity: float, optional
919 eccentricity at reference frequency. Default 0.
920 LAL_parameters: dict, optional
921 LAL dictioanry containing accessory parameters. Default None
922 """
923 return self._maxL_waveform(self.fd_waveform, *args, **kwargs)
class _MultiDimensionalSamplesDict(Dict):
    """Class to store multiple SamplesDict objects

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    label_prefix: str, optional
        prefix to use when distinguishing different analyses. The label is then
        '{label_prefix}_{num}' where num is the result file index. Default
        is 'dataset'
    transpose: Bool, optional
        True if the input is a transposed dictionary
    labels: list, optional
        the labels to use to distinguish different analyses. If provided
        label_prefix is ignored

    Attributes
    ----------
    T: pesummary.utils.samples_dict._MultiDimensionalSamplesDict
        Transposed _MultiDimensionalSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the _MultiDimensionalSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the _MultiDimensionalSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis

    Methods
    -------
    samples:
        Return a list of samples stored in the _MultiDimensionalSamplesDict
        object for a given parameter
    """
    def __init__(
        self, *args, label_prefix="dataset", transpose=False, labels=None
    ):
        if labels is not None and len(np.unique(labels)) != len(labels):
            raise ValueError(
                "Please provide a unique set of labels for each analysis"
            )
        invalid_label_number_error = "Please provide a label for each analysis"
        self.labels = labels
        self.name = _MultiDimensionalSamplesDict
        self.transpose = transpose
        if len(args) == 1 and isinstance(args[0], dict):
            # input is a nested dictionary: {label: {param: samples}} or,
            # when transpose=True, {param: {label: samples}}
            if transpose:
                parameters = list(args[0].keys())
                _labels = list(args[0][parameters[0]].keys())
                outer_iterator, inner_iterator = parameters, _labels
            else:
                _labels = list(args[0].keys())
                parameters = {
                    label: list(args[0][label].keys()) for label in _labels
                }
                outer_iterator, inner_iterator = _labels, parameters
            if labels is None:
                self.labels = _labels
            for num, dataset in enumerate(outer_iterator):
                if isinstance(inner_iterator, dict):
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]]
                        )
                    except ValueError:  # numpy deprecation error
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]],
                            dtype=object
                        )
                else:
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator]
                        )
                    except ValueError:  # numpy deprecation error
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator],
                            dtype=object
                        )
                if transpose:
                    desc = parameters[num]
                    self[desc] = SamplesDict(
                        self.labels, samples, logger_warn="debug",
                        autoscale=False
                    )
                else:
                    if self.labels is not None:
                        desc = self.labels[num]
                    else:
                        desc = "{}_{}".format(label_prefix, num)
                    self[desc] = SamplesDict(parameters[self.labels[num]], samples)
        else:
            # input is a (parameters, samples) pair with one samples set
            # per analysis
            parameters, samples = args
            if labels is not None and len(labels) != len(samples):
                raise ValueError(invalid_label_number_error)
            for num, dataset in enumerate(samples):
                if labels is not None:
                    desc = labels[num]
                else:
                    desc = "{}_{}".format(label_prefix, num)
                self[desc] = SamplesDict(parameters, dataset)
            if self.labels is None:
                self.labels = [
                    "{}_{}".format(label_prefix, num) for num, _ in
                    enumerate(samples)
                ]
        self.parameters = parameters
        self._update_latex_labels()

    def _update_latex_labels(self):
        """Update the stored latex labels
        """
        _parameters = [
            list(value.keys()) for value in self.values()
        ]
        _parameters = [item for sublist in _parameters for item in sublist]
        self._latex_labels = {
            param: latex_labels[param] if param in latex_labels.keys() else
            param for param in self.total_list_of_parameters + _parameters
        }

    def __setitem__(self, key, value):
        """Store a new analysis and keep label/parameter metadata in sync

        Parameters
        ----------
        key: str
            label to assign to the analysis
        value: dict, SamplesDict
            posterior samples keyed by parameter
        """
        _value = value
        if not isinstance(value, SamplesDict):
            _value = SamplesDict(value)
        super(_MultiDimensionalSamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.labels:
                self.parameters[key] = list(value.keys())
                self.labels.append(key)
                # refresh the cached latex labels so they cover any new
                # parameters introduced by this analysis. The original code
                # called the dict attribute `self._latex_labels()` which
                # raised a TypeError that was silently swallowed below
                self._update_latex_labels()
        except (AttributeError, TypeError):
            # self.labels does not exist yet (we are inside __init__) or
            # self.parameters is a list rather than a dict
            pass

    @property
    def T(self):
        """Return the transposed dictionary: keyed by parameter when the
        current object is keyed by label, and vice versa. Requires all
        analyses to share the same parameters
        """
        _transpose = not self.transpose
        if not self.transpose:
            _params = sorted([param for param in self[self.labels[0]].keys()])
            if not all(sorted(self[l].keys()) == _params for l in self.labels):
                raise ValueError(
                    "Unable to transpose as not all samples have the same "
                    "parameters"
                )
            transpose_dict = {
                param: {
                    label: dataset[param] for label, dataset in self.items()
                } for param in self[self.labels[0]].keys()
            }
        else:
            transpose_dict = {
                label: {
                    param: self[param][label] for param in self.keys()
                } for label in self.labels
            }
        return self.name(transpose_dict, transpose=_transpose)

    def _combine(
        self, labels=None, use_all=False, weights=None, shuffle=False,
        logger_level="debug"
    ):
        """Combine samples from a select number of analyses into a single
        SamplesDict object.

        Parameters
        ----------
        labels: list, optional
            analyses you wish to combine. Default use all labels stored in the
            dictionary
        use_all: Bool, optional
            if True, use all of the samples (do not weight). Default False
        weights: dict, optional
            dictionary of weights for each of the posteriors. Keys must be the
            labels you wish to combine and values are the weights you wish to
            assign to the posterior
        shuffle: Bool, optional
            shuffle the combined samples
        logger_level: str, optional
            logger level you wish to use. Default debug.
        """
        try:
            _logger = getattr(logger, logger_level)
        except AttributeError:
            raise ValueError(
                "Unknown logger level. Please choose either 'info' or 'debug'"
            )
        if labels is None:
            _provided_labels = False
            labels = self.labels
        else:
            _provided_labels = True
            if not all(label in self.labels for label in labels):
                raise ValueError(
                    "Not all of the provided labels exist in the dictionary. "
                    "The list of available labels are: {}".format(
                        ", ".join(self.labels)
                    )
                )
        _logger("Combining the following analyses: {}".format(labels))
        if use_all and weights is not None:
            raise ValueError(
                "Unable to use all samples and provide weights"
            )
        elif not use_all and weights is None:
            weights = {label: 1. for label in labels}
        elif not use_all and weights is not None:
            if len(weights) < len(labels):
                raise ValueError(
                    "Please provide weights for each set of samples: {}".format(
                        len(labels)
                    )
                )
            if not _provided_labels and not isinstance(weights, dict):
                raise ValueError(
                    "Weights must be provided as a dictionary keyed by the "
                    "analysis label. The available labels are: {}".format(
                        ", ".join(labels)
                    )
                )
            elif not isinstance(weights, dict):
                weights = {
                    label: weight for label, weight in zip(labels, weights)
                }
            if not all(label in labels for label in weights.keys()):
                for label in labels:
                    if label not in weights.keys():
                        weights[label] = 1.
                        logger.warning(
                            "No weight given for '{}'. Assigning a weight of "
                            "1".format(label)
                        )
            # normalise the weights so they sum to unity
            sum_weights = np.sum([_weight for _weight in weights.values()])
            weights = {
                key: item / sum_weights for key, item in weights.items()
            }
        if weights is not None:
            _logger(
                "Using the following weights for each file, {}".format(
                    " ".join(
                        ["{}: {}".format(k, v) for k, v in weights.items()]
                    )
                )
            )
        _lengths = np.array(
            [self.number_of_samples[key] for key in labels]
        )
        if use_all:
            draw = _lengths
        else:
            draw = np.zeros(len(labels), dtype=int)
            _weights = np.array([weights[key] for key in labels])
            inds = np.argwhere(_weights > 0.)
            # The next 4 lines are inspired from the 'cbcBayesCombinePosteriors'
            # executable provided by LALSuite. Credit should go to the
            # authors of that code.
            initial = _weights[inds] * float(sum(_lengths[inds]))
            min_index = np.argmin(_lengths[inds] / initial)
            size = _lengths[inds][min_index] / _weights[inds][min_index]
            draw[inds] = np.around(_weights[inds] * size).astype(int)
        _logger(
            "Randomly drawing the following number of samples from each file, "
            "{}".format(
                " ".join(
                    [
                        "{}: {}/{}".format(l, draw[n], _lengths[n]) for n, l in
                        enumerate(labels)
                    ]
                )
            )
        )

        if self.transpose:
            _data = self.T
        else:
            # deepcopy so downsampling does not modify the stored samples
            _data = copy.deepcopy(self)
        for num, label in enumerate(labels):
            if draw[num] > 0:
                _data[label].downsample(draw[num])
            else:
                _data[label] = {
                    param: np.array([]) for param in _data[label].keys()
                }
        try:
            # only keep the parameters common to all combined analyses
            intersection = set.intersection(
                *[
                    set(_params) for _key, _params in _data.parameters.items() if
                    _key in labels
                ]
            )
        except AttributeError:
            # _data.parameters is a list shared by all analyses
            intersection = _data.parameters
        logger.debug(
            "Only including the parameters: {} as they are common to all "
            "analyses".format(", ".join(list(intersection)))
        )
        data = {
            param: np.concatenate([_data[key][param] for key in labels]) for
            param in intersection
        }
        if shuffle:
            inds = np.random.choice(
                np.sum(draw), size=np.sum(draw), replace=False
            )
            data = {
                param: value[inds] for param, value in data.items()
            }
        return SamplesDict(data, logger_warn="debug")

    @property
    def nsamples(self):
        """Return the number of stored analyses
        """
        if self.transpose:
            parameters = list(self.keys())
            return len(self[parameters[0]])
        return len(self)

    @property
    def number_of_samples(self):
        """Return a dictionary giving the number of posterior samples stored
        for each analysis
        """
        if self.transpose:
            return {
                label: len(self[iterator][label]) for iterator, label in zip(
                    self.keys(), self.labels
                )
            }
        return {
            label: self[iterator].number_of_samples for iterator, label in zip(
                self.keys(), self.labels
            )
        }

    @property
    def total_number_of_samples(self):
        """Return the total number of samples across all analyses
        """
        return np.sum([length for length in self.number_of_samples.values()])

    @property
    def minimum_number_of_samples(self):
        """Return the number of samples in the smallest analysis
        """
        return np.min([length for length in self.number_of_samples.values()])

    @property
    def total_list_of_parameters(self):
        """Return the deduplicated list of parameters across all analyses
        """
        if isinstance(self.parameters, dict):
            _parameters = [item for item in self.parameters.values()]
            _flat_parameters = [
                item for sublist in _parameters for item in sublist
            ]
        elif isinstance(self.parameters, list):
            if np.array(self.parameters).ndim > 1:
                _flat_parameters = [
                    item for sublist in self.parameters for item in sublist
                ]
            else:
                _flat_parameters = self.parameters
        return list(set(_flat_parameters))

    def samples(self, parameter):
        """Return a list of samples for a given parameter, one entry per
        analysis

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return samples for
        """
        if self.transpose:
            samples = [self[parameter][label] for label in self.labels]
        else:
            samples = [self[label][parameter] for label in self.labels]
        return samples
class MCMCSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store the mcmc chains from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MCMCSamplesDict
        Transposed MCMCSamplesDict object keyed by parameters rather than
        chain
    average: pesummary.utils.samples_dict.SamplesDict
        The mean of each sample across multiple chains. If the chains are of
        different lengths, all chains are resized to the minimum number of
        samples
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine all samples from all chains into a single SamplesDict object
    nchains: int
        Total number of chains stored in the MCMCSamplesDict object
    number_of_samples: dict
        Number of samples stored in the MCMCSamplesDict for each chain
    total_number_of_samples: int
        Total number of samples stored across the multiple chains
    minimum_number_of_samples: int
        The number of samples in the smallest chain

    Methods
    -------
    discard_samples:
        Discard the first N samples for each chain
    burnin:
        Remove the first N samples as burnin. For different algorithms
        see pesummary.core.file.mcmc.algorithms
    gelman_rubin: float
        Return the Gelman-Rubin statistic between the chains for a given
        parameter. See pesummary.utils.utils.gelman_rubin
    samples:
        Return a list of samples stored in the MCMCSamplesDict object for a
        given parameter

    Examples
    --------
    Initializing the MCMCSamplesDict class

    >>> from pesummary.utils.samples_dict import MCMCSamplesDict
    >>> data = {
    ...     "chain_0": {
    ...         "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     },
    ...     "chain_1": {
    ...         "a": [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         "b": [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     }
    ... }
    >>> dataset = MCMCSamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [
    ...         [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     ], [
    ...         [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     ]
    ... ]
    >>> dataset = MCMCSamplesDict(parameters, samples)
    """
    def __init__(self, *args, transpose=False):
        single_chain_error = (
            "This class requires more than one mcmc chain to be passed. "
            "As only one dataset is available, please use the SamplesDict "
            "class."
        )
        super(MCMCSamplesDict, self).__init__(
            *args, transpose=transpose, label_prefix="chain"
        )
        self.name = MCMCSamplesDict
        if len(self.labels) == 1:
            raise ValueError(single_chain_error)
        self.chains = self.labels
        self.nchains = self.nsamples

    @property
    def average(self):
        """Return a SamplesDict containing the mean of each parameter across
        chains; every chain is truncated to the length of the shortest chain
        before averaging
        """
        if self.transpose:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[param][key][:self.minimum_number_of_samples] for
                        key in self[param].keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        else:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[key][param][:self.minimum_number_of_samples] for
                        key in self.keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        return data

    @property
    def key_data(self):
        """Return a dictionary of key data for each parameter, computed from
        the chains combined into a single dataset
        """
        data = {}
        for param, value in self.combine.items():
            data[param] = value.key_data
        return data

    @property
    def combine(self):
        """Combine all samples from all chains into a single SamplesDict
        """
        return self._combine(use_all=True, weights=None)

    def discard_samples(self, number):
        """Remove the first n samples

        Parameters
        ----------
        number: int/dict
            Number of samples that you wish to remove across all chains or a
            dictionary containing the number of samples to remove per chain
        """
        if isinstance(number, int):
            number = {chain: number for chain in self.keys()}
        for chain in self.keys():
            self[chain].discard_samples(number[chain])
        return self

    def burnin(self, *args, algorithm="burnin_by_step_number", **kwargs):
        """Remove the first N samples as burnin

        Parameters
        ----------
        algorithm: str, optional
            The algorithm you wish to use to remove samples as burnin. Default
            is 'burnin_by_step_number'. See
            `pesummary.core.file.mcmc.algorithms` for list of available
            algorithms
        """
        from pesummary.core.file import mcmc

        if algorithm not in mcmc.algorithms:
            raise ValueError(
                "{} is not a valid algorithm for removing samples as "
                "burnin".format(algorithm)
            )
        arguments = [self] + [i for i in args]
        return getattr(mcmc, algorithm)(*arguments, **kwargs)

    def gelman_rubin(self, parameter, decimal=5):
        """Return the gelman rubin statistic between chains for a given
        parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the gelman rubin statistic
            for
        decimal: int
            number of decimal places to keep when rounding
        """
        from pesummary.utils.utils import gelman_rubin as _gelman_rubin

        return _gelman_rubin(self.samples(parameter), decimal=decimal)
class MultiAnalysisSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store samples from multiple analyses

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    labels: list, optional
        the labels to use to distinguish different analyses.
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MultiAnalysisSamplesDict
        Transposed MultiAnalysisSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the MultiAnalysisSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the MultiAnalysisSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis
    available_plots: list
        list of plots which the user may use to display the contained posterior
        samples

    Methods
    -------
    from_files:
        Initialize the MultiAnalysisSamplesDict class with the contents of
        multiple files
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine samples from a select number of analyses into a single
        SamplesDict object.
    js_divergence: float
        Return the JS divergence between two posterior distributions for a
        given parameter. See pesummary.utils.utils.jensen_shannon_divergence
    ks_statistic: float
        Return the KS statistic between two posterior distributions for a
        given parameter. See pesummary.utils.utils.kolmogorov_smirnov_test
    samples:
        Return a list of samples stored in the MCMCSamplesDict object for a
        given parameter
    write:
        Save the stored posterior samples to file
    """
    def __init__(self, *args, labels=None, transpose=False):
        if labels is None and not isinstance(args[0], dict):
            raise ValueError(
                "Please provide a unique label for each analysis"
            )
        super(MultiAnalysisSamplesDict, self).__init__(
            *args, labels=labels, transpose=transpose
        )
        self.name = MultiAnalysisSamplesDict

    @classmethod
    def from_files(cls, filenames, **kwargs):
        """Initialize the MultiAnalysisSamplesDict class with the contents of
        multiple result files

        Parameters
        ----------
        filenames: dict
            dictionary containing the path to the result file you wish to load
            as the item and a label associated with each result file as the key.
            If you are providing one or more PESummary metafiles, the key
            is ignored and labels stored in the metafile are used.
        **kwargs: dict
            all kwargs are passed to the pesummary.io.read function
        """
        from pesummary.io import read

        samples = {}
        for label, filename in filenames.items():
            # allow per-label kwargs to be provided as kwargs[label]
            _kwargs = kwargs
            if label in kwargs.keys():
                _kwargs = kwargs[label]
            _file = read(filename, **_kwargs)
            _samples = _file.samples_dict
            if isinstance(_samples, MultiAnalysisSamplesDict):
                # a PESummary metafile: use the labels stored in the file and
                # guard against clashes with labels already loaded
                _stored_labels = _samples.keys()
                cond1 = any(
                    _label in filenames.keys() for _label in _stored_labels if
                    _label != label
                )
                cond2 = any(
                    _label in samples.keys() for _label in _stored_labels
                )
                if cond1 or cond2:
                    raise ValueError(
                        "The file '{}' contains the labels: {}. The "
                        "dictionary already contains the labels: {}. Please "
                        "provide unique labels for each dataset".format(
                            filename, ", ".join(_stored_labels),
                            ", ".join(samples.keys())
                        )
                    )
                samples.update(_samples)
            else:
                if label in samples.keys():
                    # fix: error message previously read 'alreadt'
                    raise ValueError(
                        "The label '{}' has already been used. Please select "
                        "another label".format(label)
                    )
                samples[label] = _samples
        return cls(samples)

    @property
    def plotting_map(self):
        """Mapping from plot name to the bound method that produces it
        """
        return {
            "hist": self._marginalized_posterior,
            "corner": self._corner,
            "triangle": self._triangle,
            "reverse_triangle": self._reverse_triangle,
            "violin": self._violin,
            "2d_kde": self._2d_kde
        }

    @property
    def available_plots(self):
        """List of plot types which may be passed to the plot method
        """
        return list(self.plotting_map.keys())

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_comparison_histogram_plot',
        'pesummary.gw.plots.plot._1d_comparison_histogram_plot',
        'pesummary.core.plots.publication.triangle_plot',
        'pesummary.core.plots.publication.reverse_triangle_plot'
    ])
    def plot(
        self, *args, type="hist", labels="all", colors=None, latex_friendly=True,
        **kwargs
    ):
        """Generate a plot for the posterior samples stored in
        MultiDimensionalSamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        labels: list
            list of analyses that you wish to include in the plot
        colors: list
            list of colors to use for each analysis
        latex_friendly: Bool, optional
            if True, make the labels latex friendly. Default True
        **kwargs: dict
            all additional kwargs are passed to the plotting function
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )

        self._update_latex_labels()
        if labels == "all":
            labels = self.labels
        elif isinstance(labels, list):
            for label in labels:
                if label not in self.labels:
                    raise ValueError(
                        "'{}' is not a stored analysis. The available analyses "
                        "are: '{}'".format(label, ", ".join(self.labels))
                    )
        else:
            raise ValueError(
                "Please provide a list of analyses that you wish to plot"
            )
        if colors is None:
            colors = list(conf.colorcycle)
            # repeat the colour cycle until every analysis has a colour
            while len(colors) < len(labels):
                colors += colors

        kwargs["labels"] = labels
        kwargs["colors"] = colors
        kwargs["latex_friendly"] = latex_friendly
        return self.plotting_map[type](*args, **kwargs)

    def _marginalized_posterior(
        self, parameter, module="core", labels="all", colors=None, **kwargs
    ):
        """Wrapper for the
        `pesummary.core.plots.plot._1d_comparison_histogram_plot` or
        `pesummary.gw.plots.plot._comparison_1d_histogram_plot`

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to plot
        module: str, optional
            module you wish to use for the plotting
        labels: list
            list of analyses that you wish to include in the plot
        colors: list
            list of colors to use for each analysis
        **kwargs: dict
            all additional kwargs are passed to the
            `_1d_comparison_histogram_plot` function
        """
        module = importlib.import_module(
            "pesummary.{}.plots.plot".format(module)
        )
        return getattr(module, "_1d_comparison_histogram_plot")(
            parameter, [self[label][parameter] for label in labels],
            colors, self.latex_labels[parameter], labels, **kwargs
        )

    def _base_triangle(self, parameters, labels="all"):
        """Check that the parameters are valid for the different triangle
        plots available

        Parameters
        ----------
        parameters: list
            list of parameters they wish to study
        labels: list
            list of analyses that you wish to include in the plot
        """
        samples = [self[label] for label in labels]
        if len(parameters) > 2:
            raise ValueError("Function is only 2d")
        # labels for which at least one requested parameter is missing
        condition = set(
            label for num, label in enumerate(labels) for param in parameters if
            param not in samples[num].keys()
        )
        if len(condition):
            raise ValueError(
                "{} and {} are not available for the following "
                " analyses: {}".format(
                    parameters[0], parameters[1], ", ".join(condition)
                )
            )
        return samples

    def _triangle(self, parameters, labels="all", module="core", **kwargs):
        """Wrapper for the `pesummary.core.plots.publication.triangle_plot`
        function

        Parameters
        ----------
        parameters: list
            list of parameters they wish to study
        labels: list
            list of analyses that you wish to include in the plot
        **kwargs: dict
            all additional kwargs are passed to the `triangle_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        samples = self._base_triangle(parameters, labels=labels)
        if module == "gw":
            kwargs["parameters"] = parameters
        return getattr(_module, "triangle_plot")(
            [_samples[parameters[0]] for _samples in samples],
            [_samples[parameters[1]] for _samples in samples],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs
        )

    def _reverse_triangle(self, parameters, labels="all", module="core", **kwargs):
        """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot`
        function

        Parameters
        ----------
        parameters: list
            list of parameters they wish to study
        labels: list
            list of analyses that you wish to include in the plot
        **kwargs: dict
            all additional kwargs are passed to the `triangle_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        samples = self._base_triangle(parameters, labels=labels)
        if module == "gw":
            kwargs["parameters"] = parameters
        return getattr(_module, "reverse_triangle_plot")(
            [_samples[parameters[0]] for _samples in samples],
            [_samples[parameters[1]] for _samples in samples],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs
        )

    def _violin(
        self, parameter, labels="all", priors=None, latex_labels=GWlatex_labels,
        **kwargs
    ):
        """Wrapper for the `pesummary.gw.plots.publication.violin_plots`
        function

        Parameters
        ----------
        parameter: str, optional
            name of the parameter you wish to generate a violin plot for
        labels: list
            list of analyses that you wish to include in the plot
        priors: MultiAnalysisSamplesDict, optional
            prior samples for each analysis. If provided, the right hand side
            of each violin will show the prior
        latex_labels: dict, optional
            dictionary containing the latex label associated with parameter
        **kwargs: dict
            all additional kwargs are passed to the `violin_plots` function
        """
        from pesummary.gw.plots.publication import violin_plots

        # only keep analyses which contain the requested parameter
        _labels = [label for label in labels if parameter in self[label].keys()]
        if not len(_labels):
            raise ValueError(
                "{} is not in any of the posterior samples tables. Please "
                "choose another parameter to plot".format(parameter)
            )
        elif len(_labels) != len(labels):
            no = list(set(labels) - set(_labels))
            logger.warning(
                "Unable to generate a violin plot for {} because {} is not "
                "in their posterior samples table".format(
                    " or ".join(no), parameter
                )
            )
        samples = [self[label][parameter] for label in _labels]
        if priors is not None and not all(
            label in priors.keys() for label in _labels
        ):
            raise ValueError("Please provide prior samples for all labels")
        elif priors is not None and not all(
            parameter in priors[label].keys() for label in _labels
        ):
            raise ValueError(
                "Please provide prior samples for {} for all labels".format(
                    parameter
                )
            )
        elif priors is not None:
            from pesummary.core.plots.seaborn.violin import split_dataframe

            priors = [priors[label][parameter] for label in _labels]
            samples = split_dataframe(samples, priors, _labels)
            palette = kwargs.get("palette", None)
            left, right = "color: white", "pastel"
            if palette is not None and not isinstance(palette, dict):
                right = palette
            elif palette is not None and all(
                side in palette.keys() for side in ["left", "right"]
            ):
                left, right = palette["left"], palette["right"]
            kwargs.update(
                {
                    "split": True, "x": "label", "y": "data", "hue": "side",
                    "palette": {"right": right, "left": left}
                }
            )
        return violin_plots(
            parameter, samples, _labels, latex_labels, **kwargs
        )

    def _corner(self, module="core", labels="all", parameters=None, **kwargs):
        """Wrapper for the `pesummary.core.plots.plot._make_comparison_corner_plot`
        or `pesummary.gw.plots.plot._make_comparison_corner_plot` function

        Parameters
        ----------
        module: str, optional
            module you wish to use for the plotting
        labels: list
            list of analyses that you wish to include in the plot
        **kwargs: dict
            all additional kwargs are passed to the `_make_comparison_corner_plot`
            function
        """
        module = importlib.import_module(
            "pesummary.{}.plots.plot".format(module)
        )
        _samples = {label: self[label] for label in labels}
        _parameters = None
        if parameters is not None:
            # only keep parameters common to every requested analysis
            _parameters = [
                param for param in parameters if all(
                    param in posterior for posterior in _samples.values()
                )
            ]
            if not len(_parameters):
                raise ValueError(
                    "None of the chosen parameters are in all of the posterior "
                    "samples tables. Please choose other parameters to plot"
                )
        return getattr(module, "_make_comparison_corner_plot")(
            _samples, self.latex_labels, corner_parameters=_parameters, **kwargs
        )

    def _2d_kde(
        self, parameters, module="core", labels="all", plot_density=None,
        **kwargs
    ):
        """Wrapper for the
        `pesummary.gw.plots.publication.comparison_twod_contour_plot` or
        `pesummary.core.plots.publication.comparison_twod_contour_plot` function

        Parameters
        ----------
        parameters: list
            list of length 2 giving the parameters you wish to plot
        module: str, optional
            module you wish to use for the plotting
        labels: list
            list of analyses that you wish to include in the plot
        **kwargs: dict, optional
            all additional kwargs are passed to the
            `comparison_twod_contour_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        samples = self._base_triangle(parameters, labels=labels)
        if plot_density is not None:
            if isinstance(plot_density, str):
                plot_density = [plot_density]
            elif isinstance(plot_density, bool) and plot_density:
                plot_density = labels
            for i in plot_density:
                if i not in labels:
                    # fix: report the offending label rather than the whole
                    # plot_density list
                    raise ValueError(
                        "Unable to plot the density for '{}'. Please choose "
                        "from: {}".format(i, ", ".join(labels))
                    )
        if module == "gw":
            return getattr(_module, "twod_contour_plots")(
                parameters, [
                    [self[label][param] for param in parameters] for label in
                    labels
                ], labels, {
                    parameters[0]: self.latex_labels[parameters[0]],
                    parameters[1]: self.latex_labels[parameters[1]]
                }, plot_density=plot_density, **kwargs
            )
        return getattr(_module, "comparison_twod_contour_plot")(
            [_samples[parameters[0]] for _samples in samples],
            [_samples[parameters[1]] for _samples in samples],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], labels=labels,
            plot_density=plot_density, **kwargs
        )

    def combine(self, **kwargs):
        """Combine samples from a select number of analyses into a single
        SamplesDict object.

        Parameters
        ----------
        labels: list, optional
            analyses you wish to combine. Default use all labels stored in the
            dictionary
        use_all: Bool, optional
            if True, use all of the samples (do not weight). Default False
        weights: dict, optional
            dictionary of weights for each of the posteriors. Keys must be the
            labels you wish to combine and values are the weights you wish to
            assign to the posterior
        logger_level: str, optional
            logger level you wish to use. Default debug.
        """
        return self._combine(**kwargs)

    def write(self, labels=None, **kwargs):
        """Save the stored posterior samples to file

        Parameters
        ----------
        labels: list, optional
            list of analyses that you wish to save to file. Default save all
            analyses to file
        **kwargs: dict, optional
            all additional kwargs passed to the pesummary.io.write function
        """
        if labels is None:
            labels = self.labels
        elif not all(label in self.labels for label in labels):
            for label in labels:
                if label not in self.labels:
                    raise ValueError(
                        "Unable to find analysis: '{}'. The list of "
                        "available analyses are: {}".format(
                            label, ", ".join(self.labels)
                        )
                    )
        for label in labels:
            self[label].write(**kwargs)

    def js_divergence(self, parameter, decimal=5):
        """Return the JS divergence between the posterior samples for
        a given parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the JS divergence
            for
        decimal: int
            number of decimal places to keep when rounding
        """
        from pesummary.utils.utils import jensen_shannon_divergence

        return jensen_shannon_divergence(
            self.samples(parameter), decimal=decimal
        )

    def ks_statistic(self, parameter, decimal=5):
        """Return the KS statistic between the posterior samples for
        a given parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the KS statistic
            for
        decimal: int
            number of decimal places to keep when rounding
        """
        from pesummary.utils.utils import kolmogorov_smirnov_test

        return kolmogorov_smirnov_test(
            self.samples(parameter), decimal=decimal
        )