# Licensed under an MIT style license -- see LICENSE.md

import copy
import numpy as np
from pesummary.utils.utils import resample_posterior_distribution, logger
from pesummary.utils.decorators import docstring_subfunction
from pesummary.utils.array import Array, _2DArray
from pesummary.utils.dict import Dict
from pesummary.utils.parameters import Parameters
from pesummary.core.plots.latex_labels import latex_labels
from pesummary.gw.plots.latex_labels import GWlatex_labels
from pesummary import conf
import importlib

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

latex_labels.update(GWlatex_labels)


class SamplesDict(Dict):
    """Class to store the samples from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter
    autoscale: Bool, optional
        If True, the posterior samples for each parameter are scaled to the
        same length

    Attributes
    ----------
    maxL: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the maximum likelihood sample keyed by
        the parameter
    minimum: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the minimum sample for each parameter
    maximum: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the maximum sample for each parameter
    median: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the median of each marginalized
        posterior distribution
    mean: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the mean of each marginalized posterior
        distribution
    key_data: dict
        dictionary containing the key data associated with each array
    number_of_samples: int
        Number of samples stored in the SamplesDict object
    latex_labels: dict
        Dictionary of latex labels for each parameter
    available_plots: list
        list of plots which the user may use to display the contained
        posterior samples

    Methods
    -------
    from_file:
        Initialize the SamplesDict class with the contents of a file
    to_pandas:
        Convert the SamplesDict object to a pandas DataFrame
    to_structured_array:
        Convert the SamplesDict object to a numpy structured array
    pop:
        Remove an entry from the SamplesDict object
    standardize_parameter_names:
        Modify keys in SamplesDict to use standard PESummary names
    downsample:
        Downsample the samples stored in the SamplesDict object. See the
        pesummary.utils.utils.resample_posterior_distribution method
    discard_samples:
        Remove the first N samples from each distribution
    plot:
        Generate a plot based on the posterior samples stored
    generate_all_posterior_samples:
        Convert the posterior samples in the SamplesDict object according to
        a conversion function
    debug_keys: list
        list of keys with an '_' as their first character
    reweight:
        Reweight the posterior samples according to a new prior
    write:
        Save the stored posterior samples to file

    Examples
    --------
    How to initialize the SamplesDict class

    >>> from pesummary.utils.samples_dict import SamplesDict
    >>> data = {
    ...     "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...     "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ... }
    >>> dataset = SamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...     [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ... ]
    >>> dataset = SamplesDict(parameters, samples)
    >>> fig = dataset.plot("a", type="hist", bins=30)
    >>> fig.show()
    """
    def __init__(self, *args, logger_warn="warn", autoscale=True, **kwargs):
        super(SamplesDict, self).__init__(
            *args, value_class=Array, make_dict_kwargs={"autoscale": autoscale},
            logger_warn=logger_warn, latex_labels=latex_labels, **kwargs
        )

    def __getitem__(self, key):
        """Return an object representing the specialization of SamplesDict
        by type arguments found in key.
        """
        if isinstance(key, slice):
            return SamplesDict(
                self.parameters, np.array(
                    [i[key.start:key.stop:key.step] for i in self.samples]
                )
            )
        elif isinstance(key, (list, np.ndarray)):
            return SamplesDict(
                self.parameters, np.array([i[key] for i in self.samples])
            )
        elif key[0] == "_":
            return self.samples[self.parameters.index(key)]
        return super(SamplesDict, self).__getitem__(key)

    def __setitem__(self, key, value):
        _value = value
        if not isinstance(value, Array):
            _value = Array(value)
        super(SamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.parameters:
                self.parameters.append(key)
                try:
                    cond = (
                        np.array(self.samples).ndim == 1 and isinstance(
                            self.samples[0], (float, int, np.number)
                        )
                    )
                except Exception:
                    cond = False
                if cond and isinstance(self.samples, np.ndarray):
                    self.samples = np.append(self.samples, value)
                elif cond and isinstance(self.samples, list):
                    self.samples.append(value)
                else:
                    self.samples = np.vstack([self.samples, value])
                self._update_latex_labels()
        except (AttributeError, TypeError):
            pass

    def __str__(self):
        """Print a summary of the information stored in the dictionary
        """
        def format_string(string, row):
            """Format a list into a table

            Parameters
            ----------
            string: str
                existing table
            row: list
                the row you wish to be written to a table
            """
            string += "{:<8}".format(row[0])
            for i in range(1, len(row)):
                if isinstance(row[i], str):
                    string += "{:<15}".format(row[i])
                elif isinstance(row[i], (float, int, np.int64, np.int32)):
                    string += "{:<15.6f}".format(row[i])
            string += "\n"
            return string

        string = ""
        string = format_string(string, ["idx"] + list(self.keys()))

        if self.number_of_samples < 8:
            for i in range(self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        else:
            for i in range(4):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
            for i in range(2):
                string = format_string(string, ["."] * (len(self.keys()) + 1))
            for i in range(self.number_of_samples - 2, self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        return string

    @classmethod
    def from_file(cls, filename, **kwargs):
        """Initialize the SamplesDict class with the contents of a result file

        Parameters
        ----------
        filename: str
            path to the result file you wish to load.
        **kwargs: dict
            all kwargs are passed to the pesummary.io.read function
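
        Examples
        --------
        A minimal sketch; the file name below is hypothetical, and any file
        format supported by pesummary.io.read should work:

        >>> from pesummary.utils.samples_dict import SamplesDict
        >>> samples = SamplesDict.from_file("posterior_samples.dat")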
209 """
210 from pesummary.io import read
212 return read(filename, **kwargs).samples_dict

    @property
    def key_data(self):
        return {param: value.key_data for param, value in self.items()}

    @property
    def maxL(self):
        return SamplesDict(
            self.parameters, [[item.maxL] for key, item in self.items()]
        )

    @property
    def minimum(self):
        return SamplesDict(
            self.parameters, [[item.minimum] for key, item in self.items()]
        )

    @property
    def maximum(self):
        return SamplesDict(
            self.parameters, [[item.maximum] for key, item in self.items()]
        )

    @property
    def median(self):
        return SamplesDict(
            self.parameters,
            [[item.average(type="median")] for key, item in self.items()]
        )

    @property
    def mean(self):
        return SamplesDict(
            self.parameters,
            [[item.average(type="mean")] for key, item in self.items()]
        )

    @property
    def number_of_samples(self):
        return len(self[self.parameters[0]])

    @property
    def plotting_map(self):
        existing = super(SamplesDict, self).plotting_map
        modified = existing.copy()
        modified.update(
            {
                "marginalized_posterior": self._marginalized_posterior,
                "skymap": self._skymap,
                "hist": self._marginalized_posterior,
                "corner": self._corner,
                "spin_disk": self._spin_disk,
                "2d_kde": self._2d_kde,
                "triangle": self._triangle,
                "reverse_triangle": self._reverse_triangle,
            }
        )
        return modified

    def standardize_parameter_names(self, mapping=None):
        """Modify keys in SamplesDict to use standard PESummary names

        Parameters
        ----------
        mapping: dict, optional
            dictionary mapping existing keys to standard PESummary names.
            Default pesummary.gw.file.standard_names.standard_names

        Returns
        -------
        standard_dict: SamplesDict
            SamplesDict object with standard PESummary parameter names
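
        Examples
        --------
        A minimal sketch; the custom mapping below is purely illustrative:

        >>> samples = SamplesDict({"mtotal": [60.2, 65.3, 58.9]})
        >>> standard = samples.standardize_parameter_names(
        ...     mapping={"mtotal": "total_mass"}
        ... )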
285 """
286 from pesummary.utils.utils import map_parameter_names
287 if mapping is None:
288 from pesummary.gw.file.standard_names import standard_names
289 mapping = standard_names
290 return SamplesDict(map_parameter_names(self, mapping))
292 def debug_keys(self, *args, **kwargs):
293 _keys = self.keys()
294 _total = self.keys(remove_debug=False)
295 return Parameters([key for key in _total if key not in _keys])

    def keys(self, *args, remove_debug=True, **kwargs):
        original = super(SamplesDict, self).keys(*args, **kwargs)
        if remove_debug:
            return Parameters([key for key in original if key[0] != "_"])
        return Parameters(original)

    def write(self, **kwargs):
        """Save the stored posterior samples to file

        Parameters
        ----------
        **kwargs: dict, optional
            all additional kwargs passed to the pesummary.io.write function
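
        Examples
        --------
        A minimal sketch, assuming pesummary.io.write accepts `file_format`
        and `filename` kwargs:

        >>> samples.write(file_format="dat", filename="posterior_samples.dat")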
310 """
311 from pesummary.io import write
312 write(self.parameters, self.samples.T, **kwargs)
314 def items(self, *args, remove_debug=True, **kwargs):
315 items = super(SamplesDict, self).items(*args, **kwargs)
316 if remove_debug:
317 return [item for item in items if item[0][0] != "_"]
318 return items

    def to_pandas(self, **kwargs):
        """Convert a SamplesDict object to a pandas dataframe
322 """
323 from pandas import DataFrame
325 return DataFrame(self, **kwargs)
327 def to_structured_array(self, **kwargs):
328 """Convert a SamplesDict object to a structured numpy array
329 """
330 return self.to_pandas(**kwargs).to_records(
331 index=False, column_dtypes=float
332 )

    def pop(self, parameter):
        """Delete a parameter from the SamplesDict

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to remove from the SamplesDict
        """
        if parameter not in self.parameters:
            logger.info(
                "{} not in SamplesDict. Unable to remove {}".format(
                    parameter, parameter
                )
            )
            return
        ind = self.parameters.index(parameter)
        self.parameters.remove(parameter)
        samples = self.samples
        self.samples = np.delete(samples, ind, axis=0)
        return super(SamplesDict, self).pop(parameter)

    def downsample(self, number):
        """Downsample the samples stored in the SamplesDict class

        Parameters
        ----------
        number: int
            Number of samples you wish to downsample to
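
        Examples
        --------
        For example, keep a random subset of 3 samples:

        >>> samples = SamplesDict({"a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6]})
        >>> samples = samples.downsample(3)
        >>> samples.number_of_samples
        3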
362 """
363 self.samples = resample_posterior_distribution(self.samples, number)
364 self.make_dictionary()
365 return self
367 def discard_samples(self, number):
368 """Remove the first n samples
370 Parameters
371 ----------
372 number: int
373 Number of samples that you wish to remove
374 """
375 self.make_dictionary(discard_samples=number)
376 return self

    def make_dictionary(self, discard_samples=None, autoscale=True):
        """Add the parameters and samples to the class
        """
        lengths = [len(i) for i in self.samples]
        if len(np.unique(lengths)) > 1 and autoscale:
            nsamples = np.min(lengths)
            getattr(logger, self.logger_warn)(
                "Unequal number of samples for each parameter. "
                "Restricting all posterior samples to have {} "
                "samples".format(nsamples)
            )
            self.samples = [
                dataset[:nsamples] for dataset in self.samples
            ]
        if "log_likelihood" in self.parameters:
            likelihoods = self.samples[self.parameters.index("log_likelihood")]
            likelihoods = likelihoods[discard_samples:]
        else:
            likelihoods = None
        if "log_prior" in self.parameters:
            priors = self.samples[self.parameters.index("log_prior")]
            priors = priors[discard_samples:]
        else:
            priors = None
        if any(i in self.parameters for i in ["weights", "weight"]):
            ind = (
                self.parameters.index("weights") if "weights" in self.parameters
                else self.parameters.index("weight")
            )
            weights = self.samples[ind][discard_samples:]
        else:
            weights = None
        _2d_array = _2DArray(
            np.array(self.samples)[:, discard_samples:], likelihood=likelihoods,
            prior=priors, weights=weights
        )
        for key, val in zip(self.parameters, _2d_array):
            self[key] = val

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._ligo_skymap_plot',
        'pesummary.gw.plots.publication.spin_distribution_plots',
        'pesummary.core.plots.plot._make_corner_plot',
        'pesummary.gw.plots.plot._make_corner_plot'
    ])
    def plot(self, *args, type="marginalized_posterior", **kwargs):
        """Generate a plot for the posterior samples stored in SamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        **kwargs: dict
            all additional kwargs are passed to the plotting function
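
        Examples
        --------
        For example, histogram the samples stored for 'a' (mirroring the
        class-level example above):

        >>> fig = dataset.plot("a", type="hist", bins=30)
        >>> fig.show()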
436 """
437 return super(SamplesDict, self).plot(*args, type=type, **kwargs)

    def generate_all_posterior_samples(self, function=None, **kwargs):
        """Convert samples stored in the SamplesDict according to a conversion
        function

        Parameters
        ----------
        function: func, optional
            function to use when converting posterior samples. Must take a
            dictionary as input and return a dictionary of converted posterior
            samples. Default `pesummary.gw.conversions.convert`
        **kwargs: dict, optional
            All additional kwargs passed to function
        """
        if function is None:
            from pesummary.gw.conversions import convert
            function = convert
        _samples = self.copy()
        _keys = list(_samples.keys())
        kwargs.update({"return_dict": True})
        out = function(_samples, **kwargs)
        if kwargs.get("return_kwargs", False):
            converted_samples, extra_kwargs = out
        else:
            converted_samples, extra_kwargs = out, None
        for key, item in converted_samples.items():
            if key not in _keys:
                self[key] = item
        return extra_kwargs

    def reweight(
        self, function, ignore_debug_params=["recalib", "spcal"], **kwargs
    ):
        """Reweight the posterior samples according to a new prior

        Parameters
        ----------
        function: func/str
            function to use when resampling
        ignore_debug_params: list, optional
            params to ignore when storing unweighted posterior distributions.
            Default any param with ['recalib', 'spcal'] in their name
        """
        from pesummary.gw.reweight import options
        if isinstance(function, str) and function in options.keys():
            function = options[function]
        elif isinstance(function, str):
            raise ValueError(
                "Unknown function '{}'. Please provide a function for "
                "reweighting or select one of the following: {}".format(
                    function, ", ".join(list(options.keys()))
                )
            )
        _samples = SamplesDict(self.copy())
        new_samples = function(_samples, **kwargs)
        _samples.downsample(new_samples.number_of_samples)
        for key, item in new_samples.items():
            if not any(param in key for param in ignore_debug_params):
                _samples["_{}_non_reweighted".format(key)] = _samples[key]
            _samples[key] = item
        return SamplesDict(_samples)

    def _marginalized_posterior(self, parameter, module="core", **kwargs):
        """Wrapper for the `pesummary.core.plots.plot._1d_histogram_plot` or
        `pesummary.gw.plots.plot._1d_histogram_plot`

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to plot
        module: str, optional
            module you wish to use for the plotting
        **kwargs: dict
            all additional kwargs are passed to the `_1d_histogram_plot`
            function
        """
        module = importlib.import_module(
            "pesummary.{}.plots.plot".format(module)
        )
        return getattr(module, "_1d_histogram_plot")(
            parameter, self[parameter], self.latex_labels[parameter],
            weights=self[parameter].weights, **kwargs
        )

    def _skymap(self, **kwargs):
        """Wrapper for the `pesummary.gw.plots.plot._ligo_skymap_plot`
        function

        Parameters
        ----------
        **kwargs: dict
            All kwargs are passed to the `_ligo_skymap_plot` function
        """
        from pesummary.gw.plots.plot import _ligo_skymap_plot

        if "luminosity_distance" in self.keys():
            dist = self["luminosity_distance"]
        else:
            dist = None

        return _ligo_skymap_plot(self["ra"], self["dec"], dist=dist, **kwargs)

    def _spin_disk(self, **kwargs):
        """Wrapper for the `pesummary.gw.plots.publication.spin_distribution_plots`
        function
        """
        from pesummary.gw.plots.publication import spin_distribution_plots

        required = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
        if not all(param in self.keys() for param in required):
            raise ValueError(
                "The spin disk plot requires samples for the following "
                "parameters: {}".format(", ".join(required))
            )
        samples = [self[param] for param in required]
        return spin_distribution_plots(required, samples, None, **kwargs)

    def _corner(self, module="core", parameters=None, **kwargs):
        """Wrapper for the `pesummary.core.plots.plot._make_corner_plot` or
        `pesummary.gw.plots.plot._make_corner_plot` function

        Parameters
        ----------
        module: str, optional
            module you wish to use for the plotting
        **kwargs: dict
            all additional kwargs are passed to the `_make_corner_plot`
            function
        """
        module = importlib.import_module(
            "pesummary.{}.plots.plot".format(module)
        )
        return getattr(module, "_make_corner_plot")(
            self, self.latex_labels, corner_parameters=parameters, **kwargs
        )[0]

    def _2d_kde(self, parameters, module="core", **kwargs):
        """Wrapper for the `pesummary.gw.plots.publication.twod_contour_plot` or
        `pesummary.core.plots.publication.twod_contour_plot` function

        Parameters
        ----------
        parameters: list
            list of length 2 giving the parameters you wish to plot
        module: str, optional
            module you wish to use for the plotting
        **kwargs: dict, optional
            all additional kwargs are passed to the `twod_contour_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        if module == "gw":
            return getattr(_module, "twod_contour_plots")(
                parameters, [[self[parameters[0]], self[parameters[1]]]],
                [None], {
                    parameters[0]: self.latex_labels[parameters[0]],
                    parameters[1]: self.latex_labels[parameters[1]]
                }, **kwargs
            )
        return getattr(_module, "twod_contour_plot")(
            self[parameters[0]], self[parameters[1]],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], **kwargs
        )

    def _triangle(self, parameters, module="core", **kwargs):
        """Wrapper for the `pesummary.core.plots.publication.triangle_plot`
        function

        Parameters
        ----------
        parameters: list
            list of parameters you wish to study
        **kwargs: dict
            all additional kwargs are passed to the `triangle_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        if module == "gw":
            kwargs["parameters"] = parameters
        return getattr(_module, "triangle_plot")(
            [self[parameters[0]]], [self[parameters[1]]],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], **kwargs
        )

    def _reverse_triangle(self, parameters, module="core", **kwargs):
        """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot`
        function

        Parameters
        ----------
        parameters: list
            list of parameters you wish to study
        **kwargs: dict
            all additional kwargs are passed to the `reverse_triangle_plot`
            function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        if module == "gw":
            kwargs["parameters"] = parameters
        return getattr(_module, "reverse_triangle_plot")(
            [self[parameters[0]]], [self[parameters[1]]],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], **kwargs
        )

    def classification(self, dual=True, population=False):
        """Return the classification probabilities

        Parameters
        ----------
        dual: Bool, optional
            if True, return classification probabilities generated from the
            raw samples ('default') and samples reweighted to a
            population-informed prior ('population'). Default True.
        population: Bool, optional
            if True, reweight the samples to a population-informed prior and
            then calculate classification probabilities. Default False. Only
            used when dual=False
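
        Examples
        --------
        A minimal sketch; the stored samples must contain the parameters
        required by pesummary.gw.classification.Classify:

        >>> probs = samples.classification(dual=False)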
661 """
662 from pesummary.gw.classification import Classify
663 if dual:
664 probs = Classify(self).dual_classification()
665 else:
666 probs = Classify(self).classification(population=population)
667 return probs

    def _waveform_args(self, f_ref=20., ind=0, longAscNodes=0., eccentricity=0.):
        """Arguments to be passed to waveform generation

        Parameters
        ----------
        f_ref: float, optional
            reference frequency to use when converting spherical spins to
            cartesian spins
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        """
        from lal import MSUN_SI, PC_SI

        _samples = {key: value[ind] for key, value in self.items()}
        required = [
            "mass_1", "mass_2", "luminosity_distance"
        ]
        if not all(param in _samples.keys() for param in required):
            raise ValueError(
                "Unable to generate a waveform. Please add samples for "
                + ", ".join(required)
            )
        waveform_args = [
            _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI
        ]
        spin_angles = [
            "theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2",
            "phase"
        ]
        spin_angles_condition = all(
            spin in _samples.keys() for spin in spin_angles
        )
        cartesian_spins = [
            "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
        ]
        cartesian_spins_condition = any(
            spin in _samples.keys() for spin in cartesian_spins
        )
        if spin_angles_condition and not cartesian_spins_condition:
            from pesummary.gw.conversions import component_spins
            data = component_spins(
                _samples["theta_jn"], _samples["phi_jl"], _samples["tilt_1"],
                _samples["tilt_2"], _samples["phi_12"], _samples["a_1"],
                _samples["a_2"], _samples["mass_1"], _samples["mass_2"],
                f_ref, _samples["phase"]
            )
            iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = data.T
            spins = [spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z]
        else:
            iota = _samples["iota"]
            spins = [
                _samples[param] if param in _samples.keys() else 0. for param in
                ["spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"]
            ]
        waveform_args += spins
        phase = _samples["phase"] if "phase" in _samples.keys() else 0.
        waveform_args += [
            _samples["luminosity_distance"] * PC_SI * 10**6, iota, phase
        ]
        waveform_args += [longAscNodes, eccentricity, 0.]
        return waveform_args, _samples

    def antenna_response(self, ifo):
        """Return the antenna response for a given detector

        Parameters
        ----------
        ifo: str
            name of the detector you wish to return the antenna response for
        """
        from pesummary.gw.waveform import antenna_response
        return antenna_response(self, ifo)

    def _project_waveform(self, ifo, hp, hc, ra, dec, psi, time):
        """Project a waveform onto a given detector

        Parameters
        ----------
        ifo: str
            name of the detector you wish to project the waveform onto
        hp: np.ndarray
            plus gravitational wave polarization
        hc: np.ndarray
            cross gravitational wave polarization
        ra: float
            right ascension to be passed to antenna response function
        dec: float
            declination to be passed to antenna response function
        psi: float
            polarization to be passed to antenna response function
        time: float
            time to be passed to antenna response function
        """
        import importlib

        mod = importlib.import_module("pesummary.gw.plots.plot")
        func = getattr(mod, "__antenna_response")
        antenna = func(ifo, ra, dec, psi, time)
        ht = hp * antenna[0] + hc * antenna[1]
        return ht

    def fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs):
        """Generate a gravitational wave in the frequency domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.frequencyseries.FrequencySeries
            object
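
        Examples
        --------
        A minimal sketch; the approximant name and frequency settings below
        are illustrative only:

        >>> wvf = samples.fd_waveform("IMRPhenomPv2", 1. / 256, 20., 1024., ind=0)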
800 """
801 from pesummary.gw.waveform import fd_waveform
802 return fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs)

    def td_waveform(
        self, approximant, delta_t, f_low, **kwargs
    ):
        """Generate a gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between time samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.timeseries.TimeSeries object
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        from pesummary.gw.waveform import td_waveform
        return td_waveform(
            self, approximant, delta_t, f_low, **kwargs
        )

    def _maxL_waveform(self, func, *args, **kwargs):
        """Return the maximum likelihood waveform in a given domain

        Parameters
        ----------
        func: function
            function you wish to use when generating the maximum likelihood
            waveform
        *args: tuple
            all args passed to func
        **kwargs: dict
            all kwargs passed to func
        """
        ind = np.argmax(self["log_likelihood"])
        kwargs["ind"] = ind
        return func(*args, **kwargs)

    def maxL_td_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between time samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        return self._maxL_waveform(self.td_waveform, *args, **kwargs)

    def maxL_fd_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the frequency
        domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        """
        return self._maxL_waveform(self.fd_waveform, *args, **kwargs)


class _MultiDimensionalSamplesDict(Dict):
    """Class to store multiple SamplesDict objects

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    label_prefix: str, optional
        prefix to use when distinguishing different analyses. The label is then
        '{label_prefix}_{num}' where num is the result file index. Default
        is 'dataset'
    transpose: Bool, optional
        True if the input is a transposed dictionary
    labels: list, optional
        the labels to use to distinguish different analyses. If provided
        label_prefix is ignored

    Attributes
    ----------
    T: pesummary.utils.samples_dict._MultiDimensionalSamplesDict
        Transposed _MultiDimensionalSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the _MultiDimensionalSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the _MultiDimensionalSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis

    Methods
    -------
    samples:
        Return a list of samples stored in the _MultiDimensionalSamplesDict
        object for a given parameter
    """
    def __init__(
        self, *args, label_prefix="dataset", transpose=False, labels=None
    ):
        if labels is not None and len(np.unique(labels)) != len(labels):
            raise ValueError(
                "Please provide a unique set of labels for each analysis"
            )
        invalid_label_number_error = "Please provide a label for each analysis"
        self.labels = labels
        self.name = _MultiDimensionalSamplesDict
        self.transpose = transpose
        if len(args) == 1 and isinstance(args[0], dict):
            if transpose:
                parameters = list(args[0].keys())
                _labels = list(args[0][parameters[0]].keys())
                outer_iterator, inner_iterator = parameters, _labels
            else:
                _labels = list(args[0].keys())
                parameters = {
                    label: list(args[0][label].keys()) for label in _labels
                }
                outer_iterator, inner_iterator = _labels, parameters
            if labels is None:
                self.labels = _labels
            for num, dataset in enumerate(outer_iterator):
                if isinstance(inner_iterator, dict):
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]]
                        )
                    except ValueError:  # numpy deprecation error
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]],
                            dtype=object
                        )
                else:
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator]
                        )
                    except ValueError:  # numpy deprecation error
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator],
                            dtype=object
                        )
                if transpose:
                    desc = parameters[num]
                    self[desc] = SamplesDict(
                        self.labels, samples, logger_warn="debug",
                        autoscale=False
                    )
                else:
                    if self.labels is not None:
                        desc = self.labels[num]
                    else:
                        desc = "{}_{}".format(label_prefix, num)
                    self[desc] = SamplesDict(parameters[self.labels[num]], samples)
        else:
            parameters, samples = args
            if labels is not None and len(labels) != len(samples):
                raise ValueError(invalid_label_number_error)
            for num, dataset in enumerate(samples):
                if labels is not None:
                    desc = labels[num]
                else:
                    desc = "{}_{}".format(label_prefix, num)
                self[desc] = SamplesDict(parameters, dataset)
            if self.labels is None:
                self.labels = [
                    "{}_{}".format(label_prefix, num) for num, _ in
                    enumerate(samples)
                ]
        self.parameters = parameters
        self._update_latex_labels()

    def _update_latex_labels(self):
        """Update the stored latex labels
        """
        _parameters = [
            list(value.keys()) for value in self.values()
        ]
        _parameters = [item for sublist in _parameters for item in sublist]
        self._latex_labels = {
            param: latex_labels[param] if param in latex_labels.keys() else
            param for param in self.total_list_of_parameters + _parameters
        }

    def __setitem__(self, key, value):
        _value = value
        if not isinstance(value, SamplesDict):
            _value = SamplesDict(value)
        super(_MultiDimensionalSamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.labels:
                parameters = list(value.keys())
                try:
                    samples = np.array([value[param] for param in parameters])
                except ValueError:  # numpy deprecation error
                    samples = np.array(
                        [value[param] for param in parameters], dtype=object
                    )
                self.parameters[key] = parameters
                self.labels.append(key)
                self._update_latex_labels()
        except (AttributeError, TypeError):
            pass

    @property
    def T(self):
        _transpose = not self.transpose
        if not self.transpose:
            _params = sorted([param for param in self[self.labels[0]].keys()])
            if not all(sorted(self[l].keys()) == _params for l in self.labels):
                raise ValueError(
                    "Unable to transpose as not all samples have the same "
                    "parameters"
                )
            transpose_dict = {
                param: {
                    label: dataset[param] for label, dataset in self.items()
                } for param in self[self.labels[0]].keys()
            }
        else:
            transpose_dict = {
                label: {
                    param: self[param][label] for param in self.keys()
                } for label in self.labels
            }
        return self.name(transpose_dict, transpose=_transpose)

    def _combine(
        self, labels=None, use_all=False, weights=None, shuffle=False,
        logger_level="debug"
    ):
        """Combine samples from a select number of analyses into a single
        SamplesDict object.

        Parameters
        ----------
        labels: list, optional
            analyses you wish to combine. Default use all labels stored in the
            dictionary
        use_all: Bool, optional
            if True, use all of the samples (do not weight). Default False
        weights: dict, optional
            dictionary of weights for each of the posteriors. Keys must be the
            labels you wish to combine and values are the weights you wish to
            assign to the posterior
        shuffle: Bool, optional
            shuffle the combined samples
        logger_level: str, optional
            logger level you wish to use. Default debug.
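
        Examples
        --------
        A minimal sketch; the labels and weights below are hypothetical:

        >>> combined = samples._combine(
        ...     labels=["one", "two"], weights={"one": 0.25, "two": 0.75}
        ... )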
1111 """
1112 try:
1113 _logger = getattr(logger, logger_level)
1114 except AttributeError:
1115 raise ValueError(
1116 "Unknown logger level. Please choose either 'info' or 'debug'"
1117 )
1118 if labels is None:
1119 _provided_labels = False
1120 labels = self.labels
1121 else:
1122 _provided_labels = True
1123 if not all(label in self.labels for label in labels):
1124 raise ValueError(
1125 "Not all of the provided labels exist in the dictionary. "
1126 "The list of available labels are: {}".format(
1127 ", ".join(self.labels)
1128 )
1129 )
1130 _logger("Combining the following analyses: {}".format(labels))
1131 if use_all and weights is not None:
1132 raise ValueError(
1133 "Unable to use all samples and provide weights"
1134 )
1135 elif not use_all and weights is None:
1136 weights = {label: 1. for label in labels}
1137 elif not use_all and weights is not None:
1138 if len(weights) < len(labels):
1139 raise ValueError(
1140 "Please provide weights for each set of samples: {}".format(
1141 len(labels)
1142 )
1143 )
1144 if not _provided_labels and not isinstance(weights, dict):
1145 raise ValueError(
1146 "Weights must be provided as a dictionary keyed by the "
1147 "analysis label. The available labels are: {}".format(
1148 ", ".join(labels)
1149 )
1150 )
1151 elif not isinstance(weights, dict):
1152 weights = {
1153 label: weight for label, weight in zip(labels, weights)
1154 }
1155 if not all(label in labels for label in weights.keys()):
1156 for label in labels:
1157 if label not in weights.keys():
1158 weights[label] = 1.
1159 logger.warning(
1160 "No weight given for '{}'. Assigning a weight of "
1161 "1".format(label)
1162 )
1163 sum_weights = np.sum([_weight for _weight in weights.values()])
1164 weights = {
1165 key: item / sum_weights for key, item in weights.items()
1166 }
1167 if weights is not None:
1168 _logger(
1169 "Using the following weights for each file, {}".format(
1170 " ".join(
1171 ["{}: {}".format(k, v) for k, v in weights.items()]
1172 )
1173 )
1174 )
1175 _lengths = np.array(
1176 [self.number_of_samples[key] for key in labels]
1177 )
1178 if use_all:
1179 draw = _lengths
1180 else:
1181 draw = np.zeros(len(labels), dtype=int)
1182 _weights = np.array([weights[key] for key in labels])
1183 inds = np.argwhere(_weights > 0.)
1184 # The next 4 lines are inspired from the 'cbcBayesCombinePosteriors'
1185 # executable provided by LALSuite. Credit should go to the
1186 # authors of that code.
1187 initial = _weights[inds] * float(sum(_lengths[inds]))
1188 min_index = np.argmin(_lengths[inds] / initial)
1189 size = _lengths[inds][min_index] / _weights[inds][min_index]
1190 draw[inds] = np.around(_weights[inds] * size).astype(int)
1191 _logger(
1192 "Randomly drawing the following number of samples from each file, "
1193 "{}".format(
1194 " ".join(
1195 [
1196 "{}: {}/{}".format(l, draw[n], _lengths[n]) for n, l in
1197 enumerate(labels)
1198 ]
1199 )
1200 )
1201 )
1203 if self.transpose:
1204 _data = self.T
1205 else:
1206 _data = copy.deepcopy(self)
1207 for num, label in enumerate(labels):
1208 if draw[num] > 0:
1209 _data[label].downsample(draw[num])
1210 else:
1211 _data[label] = {
1212 param: np.array([]) for param in _data[label].keys()
1213 }
1214 try:
1215 intersection = set.intersection(
1216 *[
1217 set(_params) for _key, _params in _data.parameters.items() if
1218 _key in labels
1219 ]
1220 )
1221 except AttributeError:
1222 intersection = _data.parameters
1223 logger.debug(
1224 "Only including the parameters: {} as they are common to all "
1225 "analyses".format(", ".join(list(intersection)))
1226 )
1227 data = {
1228 param: np.concatenate([_data[key][param] for key in labels]) for
1229 param in intersection
1230 }
1231 if shuffle:
1232 inds = np.random.choice(
1233 np.sum(draw), size=np.sum(draw), replace=False
1234 )
1235 data = {
1236 param: value[inds] for param, value in data.items()
1237 }
1238 return SamplesDict(data, logger_warn="debug")

    @property
    def nsamples(self):
        if self.transpose:
            parameters = list(self.keys())
            return len(self[parameters[0]])
        return len(self)

    @property
    def number_of_samples(self):
        if self.transpose:
            return {
                label: len(self[iterator][label]) for iterator, label in zip(
                    self.keys(), self.labels
                )
            }
        return {
            label: self[iterator].number_of_samples for iterator, label in zip(
                self.keys(), self.labels
            )
        }

    @property
    def total_number_of_samples(self):
        return np.sum([length for length in self.number_of_samples.values()])

    @property
    def minimum_number_of_samples(self):
        return np.min([length for length in self.number_of_samples.values()])

    @property
    def total_list_of_parameters(self):
        if isinstance(self.parameters, dict):
            _parameters = [item for item in self.parameters.values()]
            _flat_parameters = [
                item for sublist in _parameters for item in sublist
            ]
        elif isinstance(self.parameters, list):
            if np.array(self.parameters).ndim > 1:
                _flat_parameters = [
                    item for sublist in self.parameters for item in sublist
                ]
            else:
                _flat_parameters = self.parameters
        return list(set(_flat_parameters))

    def samples(self, parameter):
        if self.transpose:
            samples = [self[parameter][label] for label in self.labels]
        else:
            samples = [self[label][parameter] for label in self.labels]
        return samples


class MCMCSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store the mcmc chains from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MCMCSamplesDict
        Transposed MCMCSamplesDict object keyed by parameters rather than
        chain
    average: pesummary.utils.samples_dict.SamplesDict
        The mean of each sample across multiple chains. If the chains are of
        different lengths, all chains are resized to the minimum number of
        samples
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine all samples from all chains into a single SamplesDict object
    nchains: int
        Total number of chains stored in the MCMCSamplesDict object
    number_of_samples: dict
        Number of samples stored in the MCMCSamplesDict for each chain
    total_number_of_samples: int
        Total number of samples stored across the multiple chains
    minimum_number_of_samples: int
        The number of samples in the smallest chain

    Methods
    -------
    discard_samples:
        Discard the first N samples for each chain
    burnin:
        Remove the first N samples as burnin. For different algorithms
        see pesummary.core.file.mcmc.algorithms
    gelman_rubin: float
        Return the Gelman-Rubin statistic between the chains for a given
        parameter. See pesummary.utils.utils.gelman_rubin
    samples:
        Return a list of samples stored in the MCMCSamplesDict object for a
        given parameter

    Examples
    --------
    Initializing the MCMCSamplesDict class

    >>> from pesummary.utils.samples_dict import MCMCSamplesDict
    >>> data = {
    ...     "chain_0": {
    ...         "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     },
    ...     "chain_1": {
    ...         "a": [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         "b": [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     }
    ... }
    >>> dataset = MCMCSamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [
    ...         [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     ], [
    ...         [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     ]
    ... ]
    >>> dataset = MCMCSamplesDict(parameters, samples)
    """
    def __init__(self, *args, transpose=False):
        single_chain_error = (
            "This class requires more than one mcmc chain to be passed. "
            "As only one dataset is available, please use the SamplesDict "
            "class."
        )
        super(MCMCSamplesDict, self).__init__(
            *args, transpose=transpose, label_prefix="chain"
        )
        self.name = MCMCSamplesDict
        if len(self.labels) == 1:
            raise ValueError(single_chain_error)
        self.chains = self.labels
        self.nchains = self.nsamples

    @property
    def average(self):
        if self.transpose:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[param][key][:self.minimum_number_of_samples] for
                        key in self[param].keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        else:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[key][param][:self.minimum_number_of_samples] for
                        key in self.keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        return data

    @property
    def key_data(self):
        data = {}
        for param, value in self.combine.items():
            data[param] = value.key_data
        return data

    @property
    def combine(self):
        return self._combine(use_all=True, weights=None)

    def discard_samples(self, number):
        """Remove the first n samples

        Parameters
        ----------
        number: int/dict
            Number of samples that you wish to remove across all chains or a
            dictionary containing the number of samples to remove per chain
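
        Examples
        --------
        Remove the first 2 samples from every chain, or a different number
        per chain:

        >>> dataset = dataset.discard_samples(2)
        >>> dataset = dataset.discard_samples({"chain_0": 2, "chain_1": 3})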
1423 """
1424 if isinstance(number, int):
1425 number = {chain: number for chain in self.keys()}
1426 for chain in self.keys():
1427 self[chain].discard_samples(number[chain])
1428 return self

    def burnin(self, *args, algorithm="burnin_by_step_number", **kwargs):
        """Remove the first N samples as burnin

        Parameters
        ----------
        algorithm: str, optional
            The algorithm you wish to use to remove samples as burnin. Default
            is 'burnin_by_step_number'. See
            `pesummary.core.file.mcmc.algorithms` for list of available
            algorithms
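
        Examples
        --------
        A minimal sketch, assuming the chains store whatever the chosen
        algorithm requires:

        >>> dataset = dataset.burnin(algorithm="burnin_by_step_number")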
1440 """
1441 from pesummary.core.file import mcmc
1443 if algorithm not in mcmc.algorithms:
1444 raise ValueError(
1445 "{} is not a valid algorithm for removing samples as "
1446 "burnin".format(algorithm)
1447 )
1448 arguments = [self] + [i for i in args]
1449 return getattr(mcmc, algorithm)(*arguments, **kwargs)

    def gelman_rubin(self, parameter, decimal=5):
        """Return the Gelman-Rubin statistic between chains for a given
        parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the Gelman-Rubin
            statistic for
        decimal: int
            number of decimal places to keep when rounding
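
        Examples
        --------
        Using the chains from the class-level example above:

        >>> R = dataset.gelman_rubin("a", decimal=3)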
1462 """
1463 from pesummary.utils.utils import gelman_rubin as _gelman_rubin
1465 return _gelman_rubin(self.samples(parameter), decimal=decimal)


class MultiAnalysisSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store samples from multiple analyses

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    labels: list, optional
        the labels to use to distinguish different analyses.
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MultiAnalysisSamplesDict
        Transposed MultiAnalysisSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the MultiAnalysisSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the MultiAnalysisSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis
    available_plots: list
        list of plots which the user may use to display the contained
        posterior samples

    Methods
    -------
    from_files:
        Initialize the MultiAnalysisSamplesDict class with the contents of
        multiple files
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine samples from a select number of analyses into a single
        SamplesDict object.
    js_divergence: float
        Return the JS divergence between two posterior distributions for a
        given parameter. See pesummary.utils.utils.jensen_shannon_divergence
    ks_statistic: float
        Return the KS statistic between two posterior distributions for a
        given parameter. See pesummary.utils.utils.kolmogorov_smirnov_test
    samples:
        Return a list of samples stored in the MultiAnalysisSamplesDict
        object for a given parameter
    write:
        Save the stored posterior samples to file
    """
    def __init__(self, *args, labels=None, transpose=False):
        if labels is None and not isinstance(args[0], dict):
            raise ValueError(
                "Please provide a unique label for each analysis"
            )
        super(MultiAnalysisSamplesDict, self).__init__(
            *args, labels=labels, transpose=transpose
        )
        self.name = MultiAnalysisSamplesDict

    @classmethod
    def from_files(cls, filenames, **kwargs):
        """Initialize the MultiAnalysisSamplesDict class with the contents of
        multiple result files

        Parameters
        ----------
        filenames: dict
            dictionary containing the path to the result file you wish to load
            as the item and a label associated with each result file as the key.
            If you are providing one or more PESummary metafiles, the key
            is ignored and labels stored in the metafile are used.
        **kwargs: dict
            all kwargs are passed to the pesummary.io.read function
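
        Examples
        --------
        A minimal sketch; the labels and file names below are hypothetical:

        >>> samples = MultiAnalysisSamplesDict.from_files(
        ...     {"one": "analysis_one.dat", "two": "analysis_two.dat"}
        ... )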
1545 """
1546 from pesummary.io import read
1548 samples = {}
1549 for label, filename in filenames.items():
1550 _kwargs = kwargs
1551 if label in kwargs.keys():
1552 _kwargs = kwargs[label]
1553 _file = read(filename, **_kwargs)
1554 _samples = _file.samples_dict
1555 if isinstance(_samples, MultiAnalysisSamplesDict):
1556 _stored_labels = _samples.keys()
1557 cond1 = any(
1558 _label in filenames.keys() for _label in _stored_labels if
1559 _label != label
1560 )
1561 cond2 = any(
1562 _label in samples.keys() for _label in _stored_labels
1563 )
1564 if cond1 or cond2:
1565 raise ValueError(
1566 "The file '{}' contains the labels: {}. The "
1567 "dictionary already contains the labels: {}. Please "
1568 "provide unique labels for each dataset".format(
1569 filename, ", ".join(_stored_labels),
1570 ", ".join(samples.keys())
1571 )
1572 )
1573 samples.update(_samples)
1574 else:
1575 if label in samples.keys():
1576 raise ValueError(
1577 "The label '{}' has alreadt been used. Please select "
1578 "another label".format(label)
1579 )
1580 samples[label] = _samples
1581 return cls(samples)

    @property
    def plotting_map(self):
        return {
            "hist": self._marginalized_posterior,
            "corner": self._corner,
            "triangle": self._triangle,
            "reverse_triangle": self._reverse_triangle,
            "violin": self._violin,
            "2d_kde": self._2d_kde
        }

    @property
    def available_plots(self):
        return list(self.plotting_map.keys())

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_comparison_histogram_plot',
        'pesummary.gw.plots.plot._1d_comparison_histogram_plot',
        'pesummary.core.plots.publication.triangle_plot',
        'pesummary.core.plots.publication.reverse_triangle_plot'
    ])
    def plot(
        self, *args, type="hist", labels="all", colors=None, latex_friendly=True,
        **kwargs
    ):
        """Generate a plot for the posterior samples stored in
        MultiAnalysisSamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        labels: list
            list of analyses that you wish to include in the plot
        colors: list
            list of colors to use for each analysis
        latex_friendly: Bool, optional
            if True, make the labels latex friendly. Default True
        **kwargs: dict
            all additional kwargs are passed to the plotting function
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )

        self._update_latex_labels()
        if labels == "all":
            labels = self.labels
        elif isinstance(labels, list):
            for label in labels:
                if label not in self.labels:
                    raise ValueError(
                        "'{}' is not a stored analysis. The available analyses "
                        "are: '{}'".format(label, ", ".join(self.labels))
                    )
        else:
            raise ValueError(
                "Please provide a list of analyses that you wish to plot"
            )
        if colors is None:
            colors = list(conf.colorcycle)
            while len(colors) < len(labels):
                colors += colors

        kwargs["labels"] = labels
        kwargs["colors"] = colors
        kwargs["latex_friendly"] = latex_friendly
        return self.plotting_map[type](*args, **kwargs)

    def _marginalized_posterior(
        self, parameter, module="core", labels="all", colors=None, **kwargs
    ):
        """Wrapper for the
        `pesummary.core.plots.plot._1d_comparison_histogram_plot` or
        `pesummary.gw.plots.plot._1d_comparison_histogram_plot`

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to plot
        module: str, optional
            module you wish to use for the plotting
        labels: list
            list of analyses that you wish to include in the plot
        colors: list
            list of colors to use for each analysis
        **kwargs: dict
            all additional kwargs are passed to the
            `_1d_comparison_histogram_plot` function
        """
        module = importlib.import_module(
            "pesummary.{}.plots.plot".format(module)
        )
        return getattr(module, "_1d_comparison_histogram_plot")(
            parameter, [self[label][parameter] for label in labels],
            colors, self.latex_labels[parameter], labels, **kwargs
        )

    def _base_triangle(self, parameters, labels="all"):
        """Check that the parameters are valid for the different triangle
        plots available

        Parameters
        ----------
        parameters: list
            list of parameters you wish to study
        labels: list
            list of analyses that you wish to include in the plot
        """
        samples = [self[label] for label in labels]
        if len(parameters) > 2:
            raise ValueError("Function is only 2d")
        condition = set(
            label for num, label in enumerate(labels) for param in parameters if
            param not in samples[num].keys()
        )
        if len(condition):
            raise ValueError(
                "{} and {} are not available for the following "
                "analyses: {}".format(
                    parameters[0], parameters[1], ", ".join(condition)
                )
            )
        return samples
1714 def _triangle(self, parameters, labels="all", module="core", **kwargs):
1715 """Wrapper for the `pesummary.core.plots.publication.triangle_plot`
1716 function
1718 Parameters
1719 ----------
1720 parameters: list
1721 list of parameters they wish to study
1722 labels: list
1723 list of analyses that you wish to include in the plot
1724 **kwargs: dict
1725 all additional kwargs are passed to the `triangle_plot` function
1726 """
1727 _module = importlib.import_module(
1728 "pesummary.{}.plots.publication".format(module)
1729 )
1730 samples = self._base_triangle(parameters, labels=labels)
1731 if module == "gw":
1732 kwargs["parameters"] = parameters
1733 return getattr(_module, "triangle_plot")(
1734 [_samples[parameters[0]] for _samples in samples],
1735 [_samples[parameters[1]] for _samples in samples],
1736 xlabel=self.latex_labels[parameters[0]],
1737 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs
1738 )
1740 def _reverse_triangle(self, parameters, labels="all", module="core", **kwargs):
1741 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot`
1742 function
1744 Parameters
1745 ----------
1746 parameters: list
1747 list of parameters they wish to study
1748 labels: list
1749 list of analyses that you wish to include in the plot
1750 **kwargs: dict
1751 all additional kwargs are passed to the `triangle_plot` function
1752 """
1753 _module = importlib.import_module(
1754 "pesummary.{}.plots.publication".format(module)
1755 )
1756 samples = self._base_triangle(parameters, labels=labels)
1757 if module == "gw":
1758 kwargs["parameters"] = parameters
1759 return getattr(_module, "reverse_triangle_plot")(
1760 [_samples[parameters[0]] for _samples in samples],
1761 [_samples[parameters[1]] for _samples in samples],
1762 xlabel=self.latex_labels[parameters[0]],
1763 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs
1764 )
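# Usage sketch: same call pattern as `_triangle`, but the data are passed
# to `reverse_triangle_plot` instead (hypothetical data):
#
# >>> fig = samples._reverse_triangle(["a", "b"], labels=["one", "two"])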
1766 def _violin(
1767 self, parameter, labels="all", priors=None, latex_labels=GWlatex_labels,
1768 **kwargs
1769 ):
1770 """Wrapper for the `pesummary.gw.plots.publication.violin_plots`
1771 function
1773 Parameters
1774 ----------
1775 parameter: str
1776 name of the parameter you wish to generate a violin plot for
1777 labels: list
1778 list of analyses that you wish to include in the plot
1779 priors: MultiAnalysisSamplesDict, optional
1780 prior samples for each analysis. If provided, the right-hand side
1781 of each violin will show the prior
1782 latex_labels: dict, optional
1783 dictionary containing the latex label associated with parameter
1784 **kwargs: dict
1785 all additional kwargs are passed to the `violin_plots` function
1786 """
1787 from pesummary.gw.plots.publication import violin_plots
1789 _labels = [label for label in labels if parameter in self[label].keys()]
1790 if not len(_labels):
1791 raise ValueError(
1792 "{} is not in any of the posterior samples tables. Please "
1793 "choose another parameter to plot".format(parameter)
1794 )
1795 elif len(_labels) != len(labels):
1796 no = list(set(labels) - set(_labels))
1797 logger.warning(
1798 "Unable to generate a violin plot for {} because {} is not "
1799 "in their posterior samples table".format(
1800 " or ".join(no), parameter
1801 )
1802 )
1803 samples = [self[label][parameter] for label in _labels]
1804 if priors is not None and not all(
1805 label in priors.keys() for label in _labels
1806 ):
1807 raise ValueError("Please provide prior samples for all labels")
1808 elif priors is not None and not all(
1809 parameter in priors[label].keys() for label in _labels
1810 ):
1811 raise ValueError(
1812 "Please provide prior samples for {} for all labels".format(
1813 parameter
1814 )
1815 )
1816 elif priors is not None:
1817 from pesummary.core.plots.seaborn.violin import split_dataframe
1819 priors = [priors[label][parameter] for label in _labels]
1820 samples = split_dataframe(samples, priors, _labels)
1821 palette = kwargs.get("palette", None)
1822 left, right = "color: white", "pastel"
1823 if palette is not None and not isinstance(palette, dict):
1824 right = palette
1825 elif palette is not None and all(
1826 side in palette.keys() for side in ["left", "right"]
1827 ):
1828 left, right = palette["left"], palette["right"]
1829 kwargs.update(
1830 {
1831 "split": True, "x": "label", "y": "data", "hue": "side",
1832 "palette": {"right": right, "left": left}
1833 }
1834 )
1835 return violin_plots(
1836 parameter, samples, _labels, latex_labels, **kwargs
1837 )
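# Usage sketch: when prior samples are supplied, each violin is split and
# the right-hand side shows the prior; a dict palette colours the two
# halves separately. `prior_dict` is a hypothetical MultiAnalysisSamplesDict
# of prior samples containing "a" for both analyses, and the palette values
# mirror the defaults in the code above:
#
# >>> fig = samples._violin(
# ...     "a", labels=["one", "two"], priors=prior_dict,
# ...     palette={"left": "color: white", "right": "pastel"}
# ... )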
1839 def _corner(self, module="core", labels="all", parameters=None, **kwargs):
1840 """Wrapper for the `pesummary.core.plots.plot._make_comparison_corner_plot`
1841 or `pesummary.gw.plots.plot._make_comparison_corner_plot` function
1843 Parameters
1844 ----------
1845 module: str, optional
1846 module you wish to use for the plotting
1847 labels: list
1848 list of analyses that you wish to include in the plot
1849 **kwargs: dict
1850 all additional kwargs are passed to the `_make_comparison_corner_plot`
1851 function
1852 """
1853 module = importlib.import_module(
1854 "pesummary.{}.plots.plot".format(module)
1855 )
1856 _samples = {label: self[label] for label in labels}
1857 _parameters = None
1858 if parameters is not None:
1859 _parameters = [
1860 param for param in parameters if all(
1861 param in posterior for posterior in _samples.values()
1862 )
1863 ]
1864 if not len(_parameters):
1865 raise ValueError(
1866 "None of the chosen parameters are in all of the posterior "
1867 "samples tables. Please choose other parameters to plot"
1868 )
1869 return getattr(module, "_make_comparison_corner_plot")(
1870 _samples, self.latex_labels, corner_parameters=_parameters, **kwargs
1871 )
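# Usage sketch: restrict the comparison corner plot to parameters that are
# present in every analysis; any requested parameter missing from one of
# the posteriors is silently dropped, and a ValueError is raised only if
# none survive (hypothetical data):
#
# >>> fig = samples._corner(
# ...     module="core", labels=["one", "two"], parameters=["a", "b"]
# ... )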
1873 def _2d_kde(
1874 self, parameters, module="core", labels="all", plot_density=None,
1875 **kwargs
1876 ):
1877 """Wrapper for the
1878 `pesummary.gw.plots.publication.comparison_twod_contour_plot` or
1879 `pesummary.core.plots.publication.comparison_twod_contour_plot` function
1881 Parameters
1882 ----------
1883 parameters: list
1884 list of length 2 giving the parameters you wish to plot
1885 module: str, optional
1886 module you wish to use for the plotting
1887 labels: list
1888 list of analyses that you wish to include in the plot
plot_density: str/list/Bool, optional
analyses for which the density should be shaded. If True, shade the
density for all analyses
1889 **kwargs: dict, optional
1890 all additional kwargs are passed to the
1891 `comparison_twod_contour_plot` function
1892 """
1893 _module = importlib.import_module(
1894 "pesummary.{}.plots.publication".format(module)
1895 )
1896 samples = self._base_triangle(parameters, labels=labels)
1897 if plot_density is not None:
1898 if isinstance(plot_density, str):
1899 plot_density = [plot_density]
1900 elif isinstance(plot_density, bool) and plot_density:
1901 plot_density = labels
1902 for i in plot_density:
1903 if i not in labels:
1904 raise ValueError(
1905 "Unable to plot the density for '{}'. Please choose "
1906 "from: {}".format(plot_density, ", ".join(labels))
1907 )
1908 if module == "gw":
1909 return getattr(_module, "twod_contour_plots")(
1910 parameters, [
1911 [self[label][param] for param in parameters] for label in
1912 labels
1913 ], labels, {
1914 parameters[0]: self.latex_labels[parameters[0]],
1915 parameters[1]: self.latex_labels[parameters[1]]
1916 }, plot_density=plot_density, **kwargs
1917 )
1918 return getattr(_module, "comparison_twod_contour_plot")(
1919 [_samples[parameters[0]] for _samples in samples],
1920 [_samples[parameters[1]] for _samples in samples],
1921 xlabel=self.latex_labels[parameters[0]],
1922 ylabel=self.latex_labels[parameters[1]], labels=labels,
1923 plot_density=plot_density, **kwargs
1924 )
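# Usage sketch: `plot_density=True` shades the density for every analysis,
# while a single label (or list of labels) restricts the shading to those
# analyses (hypothetical data):
#
# >>> fig = samples._2d_kde(
# ...     ["a", "b"], labels=["one", "two"], plot_density="one"
# ... )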
1926 def combine(self, **kwargs):
1927 """Combine samples from a select number of analyses into a single
1928 SamplesDict object.
1930 Parameters
1931 ----------
1932 labels: list, optional
1933 analyses you wish to combine. By default, all labels stored in the
1934 dictionary are used
1935 use_all: Bool, optional
1936 if True, use all of the samples (do not weight). Default False
1937 weights: dict, optional
1938 dictionary of weights for each of the posteriors. Keys must be the
1939 labels you wish to combine and values are the weights you wish to
1940 assign to the posterior
1941 logger_level: str, optional
1942 logger level you wish to use. Default debug.
1943 """
1944 return self._combine(**kwargs)
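# Usage sketch: weight the hypothetical analysis "one" twice as heavily as
# "two" when drawing the combined posterior (weights dict keyed by label,
# as documented above):
#
# >>> combined = samples.combine(
# ...     labels=["one", "two"], weights={"one": 2.0, "two": 1.0}
# ... )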
1946 def write(self, labels=None, **kwargs):
1947 """Save the stored posterior samples to file
1949 Parameters
1950 ----------
1951 labels: list, optional
1952 list of analyses that you wish to save to file. By default, all
1953 analyses are saved to file
1954 **kwargs: dict, optional
1955 all additional kwargs passed to the pesummary.io.write function
1956 """
1957 if labels is None:
1958 labels = self.labels
1959 else:
1960 for label in labels:
1961 if label not in self.labels:
1962 raise ValueError(
1963 "Unable to find analysis: '{}'. The available "
1964 "analyses are: {}".format(
1965 label, ", ".join(self.labels)
1966 )
1967 )
1968 for label in labels:
1969 self[label].write(**kwargs)
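# Usage sketch: all kwargs are forwarded to `pesummary.io.write`; the
# `file_format` and `outdir` kwarg names below are assumptions about that
# function's interface, not guarantees:
#
# >>> samples.write(labels=["one"], file_format="dat", outdir="./")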
1971 def js_divergence(self, parameter, decimal=5):
1972 """Return the JS divergence between the posterior samples for
1973 a given parameter
1975 Parameters
1976 ----------
1977 parameter: str
1978 name of the parameter you wish to return the JS divergence for
1980 decimal: int
1981 number of decimal places to keep when rounding
1982 """
1983 from pesummary.utils.utils import jensen_shannon_divergence
1985 return jensen_shannon_divergence(
1986 self.samples(parameter), decimal=decimal
1987 )
1989 def ks_statistic(self, parameter, decimal=5):
1990 """Return the KS statistic between the posterior samples for
1991 a given parameter
1993 Parameters
1994 ----------
1995 parameter: str
1996 name of the parameter you wish to return the KS statistic for
1998 decimal: int
1999 number of decimal places to keep when rounding
2000 """
2001 from pesummary.utils.utils import kolmogorov_smirnov_test
2003 return kolmogorov_smirnov_test(
2004 self.samples(parameter), decimal=decimal
2005 )
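# Usage sketch: both statistics compare the per-analysis marginalized
# posteriors for a single shared parameter (hypothetical data):
#
# >>> samples.js_divergence("a", decimal=5)
# >>> samples.ks_statistic("a", decimal=5)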