Coverage for pesummary/gw/file/formats/pesummary.py: 80.4%
143 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.gw.file.formats.base_read import GWMultiAnalysisRead
4from pesummary.core.file.formats.pesummary import (
5 PESummary as CorePESummary, PESummaryDeprecated as CorePESummaryDeprecated
6)
7from pesummary.utils.dict import load_recursively
8from pesummary.utils.decorators import deprecation
9import numpy as np
11__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def write_pesummary(*args, **kwargs):
    """Write a set of samples to a pesummary file via the GW meta file class

    This is a thin wrapper around
    pesummary.core.file.formats.pesummary.write_pesummary which always
    supplies `cls=_GWMetaFile`.

    Parameters
    ----------
    args: tuple
        either a 2d tuple containing the parameters as first argument and
        samples as the second argument, or a SamplesDict object containing
        the samples
    outdir: str, optional
        directory to write the dat file
    label: str, optional
        The label of the analysis. This is used in the filename if a filename
        is not specified
    config: dict, optional
        configuration file that you wish to save to file
    injection_data: dict, optional
        dictionary containing the injection values that you wish to save to
        file keyed by parameter
    file_kwargs: dict, optional
        any kwargs that you wish to save to file
    mcmc_samples: Bool, optional
        if True, the set of samples provided are from multiple MCMC chains
    hdf5: Bool, optional
        if True, save the pesummary file in hdf5 format
    kwargs: dict
        all other kwargs are passed to the
        pesummary.gw.file.meta_file._GWMetaFile class

    Returns
    -------
    whatever the core `write_pesummary` function returns
    """
    # imported lazily, exactly as the original implementation did
    from pesummary.core.file.formats.pesummary import (
        write_pesummary as _core_write_pesummary
    )
    from pesummary.gw.file.meta_file import _GWMetaFile

    written = _core_write_pesummary(*args, cls=_GWMetaFile, **kwargs)
    return written
class PESummary(GWMultiAnalysisRead, CorePESummary):
    """This class handles the existing posterior_samples.h5 file

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True
    kwargs: dict, optional
        all additional kwargs are forwarded to the parent class initialiser

    Attributes
    ----------
    parameters: list
        list of parameters stored in the result file
    converted_parameters: list
        list of parameters that have been derived from the sampled distributions
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file keyed by parameters
    input_version: str
        version of the result file passed.
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file
    approximant: list
        list of approximants stored in the result file
    labels: list
        list of analyses stored in the result file
    config: list
        list of dictionaries containing the configuration files for each
        analysis
    psd: dict
        dictionary containing the psds stored in the result file keyed by
        the analysis label
    calibration: dict
        dictionary containing the calibration posterior samples keyed by
        the analysis label
    skymap: dict
        dictionary containing the skymap probabilities keyed by the analysis
        label
    gwdata: pesummary.gw.file.strain.StrainDataDict
        dictionary containing the strain data used in the analysis
    prior: dict
        dictionary containing the prior samples for each analysis
    weights: dict
        dictionary of weights for each samples for each analysis
    detectors: list
        list of IFOs used in each analysis
    pe_algorithm: dict
        name of the algorithm used to generate each analysis
    preferred: str
        name of the preferred analysis in the result file

    Methods
    -------
    samples_dict_for_label: dict
        dictionary of samples for a specific analysis
    reduced_samples_dict: dict
        dictionary of samples for one or more analyses
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    to_lalinference:
        convert the posterior samples to a lalinference result file
    to_bilby:
        convert the posterior samples to a bilby result file
    generate_all_posterior_samples:
        generate all posterior distributions that may be derived from
        sampled distributions
    """
    def __init__(self, path_to_results_file, **kwargs):
        # kwargs used to be accepted and then dropped, which made the
        # documented `remove_nan_likelihood_samples` option a no-op. They are
        # now handed to the parent initialiser.
        super(PESummary, self).__init__(
            path_to_results_file=path_to_results_file, **kwargs
        )

    @property
    def load_kwargs(self):
        """Keyword arguments describing how data is grabbed from the file;
        points at this class' `_grab_data_from_dictionary`
        """
        return dict(grab_data_from_dictionary=self._grab_data_from_dictionary)

    @staticmethod
    def _grab_data_from_dictionary(dictionary):
        """Extract the stored data from a dictionary, adding the GW specific
        entries on top of what the core class extracts

        Parameters
        ----------
        dictionary: dict
            dictionary-like object holding the contents of the result file

        Returns
        -------
        stored_data: dict
            the output of CorePESummary._grab_data_from_dictionary (called
            with ignore=["strain"]) extended with the keys 'approximant'
            (list, one entry per label, None when not stored), 'calibration',
            'psd' and 'skymap' (dicts keyed by label, only containing the
            labels which store that quantity) and, when the file has a
            'strain' entry, 'gwdata'
        """
        stored_data = CorePESummary._grab_data_from_dictionary(
            dictionary=dictionary, ignore=["strain"]
        )

        approx_list = []
        psd_dict, cal_dict, skymap_dict = {}, {}, {}
        for label in stored_data["labels"]:
            data, = load_recursively(label, dictionary)
            if "psds" in data:
                psd_dict[label] = data["psds"]
            if "calibration_envelope" in data:
                cal_dict[label] = data["calibration_envelope"]
            if "skymap" in data:
                skymap_dict[label] = data["skymap"]
            # keep one approximant entry per label so the list stays aligned
            # with stored_data["labels"]
            approx_list.append(
                data["approximant"] if "approximant" in data else None
            )
        stored_data["approximant"] = approx_list
        stored_data["calibration"] = cal_dict
        stored_data["psd"] = psd_dict
        stored_data["skymap"] = skymap_dict
        if "strain" in dictionary:
            stored_data["gwdata"] = dictionary["strain"]
        return stored_data

    @property
    def calibration_data_in_results_file(self):
        """Calibration data stored in the result file

        Returns
        -------
        None when `self.calibration` is empty, otherwise a tuple
        `(total, keys)` where `keys[i]` is the list of keys stored under
        `self.calibration[self.labels[i]]` and `total[i]` is the list of the
        corresponding values. Labels without a calibration entry give empty
        lists rather than raising a KeyError.
        """
        if not self.calibration:
            return None
        total, keys = [], []
        for label in self.labels:
            # `calibration` only holds labels which stored an envelope, so a
            # label may legitimately be absent
            if label in self.calibration:
                envelopes = self.calibration[label]
            else:
                envelopes = {}
            ifos = list(envelopes.keys())
            keys.append(ifos)
            total.append([envelopes[ifo] for ifo in ifos])
        return total, keys

    @property
    def detectors(self):
        """List of detectors for each analysis, inferred from parameter names

        For every parameter list in `self.parameters`, each parameter
        containing '_optimal_snr' (other than 'network_optimal_snr')
        contributes the text before '_optimal_snr'. An analysis with no such
        parameter gives `[None]`.
        """
        suffix = "_optimal_snr"
        det_list = []
        for parameters in self.parameters:
            detectors = [
                param.split(suffix)[0] for param in parameters
                if suffix in param and param != "network_optimal_snr"
            ]
            det_list.append(detectors if detectors else [None])
        return det_list

    def write(self, labels="all", **kwargs):
        """Save the data to file

        Parameters
        ----------
        labels: list, optional
            analyses to write; passed unchanged to CorePESummary.write.
            Default "all"
        kwargs: dict, optional
            all additional kwargs are passed to CorePESummary.write. The
            `package` is always "gw" and must not be supplied

        Returns
        -------
        whatever CorePESummary.write returns
        """
        approximant = {}
        for num, label in enumerate(self.labels):
            stored = self.approximant[num]
            # an empty dict means no approximant was stored for this label
            approximant[label] = stored if stored != {} else None
        properties = dict(
            calibration=self.calibration, psd=self.psd,
            approximant=approximant, skymap=self.skymap
        )
        # propagate the result so that wrappers such as `to_lalinference`
        # (and callers passing `_return=True`) receive it instead of None
        return CorePESummary.write(
            self, package="gw", labels=labels, cls_properties=properties,
            **kwargs
        )

    def to_bilby(self, labels="all", **kwargs):
        """Convert a PESummary metafile to a bilby results object

        Parameters
        ----------
        labels: list, optional
            optional list of analyses to convert. Default "all"
        kwargs: dict, optional
            all additional kwargs are passed to CorePESummary.write
        """
        from bilby.gw.result import CompactBinaryCoalescenceResult

        return CorePESummary.write(
            self, labels=labels, package="core", file_format="bilby",
            _return=True, cls=CompactBinaryCoalescenceResult, **kwargs
        )

    def to_lalinference(self, labels="all", **kwargs):
        """Convert the samples stored in a PESummary metafile to a .dat file

        Parameters
        ----------
        labels: list, optional
            optional list of analyses to save to file
        kwargs: dict, optional
            all additional kwargs are passed to the `write` method

        Returns
        -------
        the return value of the `write` method
        """
        return self.write(
            labels=labels, file_format="lalinference", **kwargs
        )
class PESummaryDeprecated(PESummary):
    """This class handles PESummary result files stored in the out-of-date
    file format. Creating an instance goes through the `deprecation`
    decorator with a message saying the format may not be supported in
    future releases.

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    kwargs: dict, optional
        all additional kwargs are passed to PESummary.__init__
    """
    @deprecation(
        "This file format is out-of-date and may not be supported in future "
        "releases."
    )
    def __init__(self, path_to_results_file, **kwargs):
        super(PESummaryDeprecated, self).__init__(path_to_results_file, **kwargs)

    @property
    def load_kwargs(self):
        """Keyword arguments describing how data is grabbed from the file;
        points at PESummaryDeprecated._grab_data_from_dictionary
        """
        return {
            "grab_data_from_dictionary": PESummaryDeprecated._grab_data_from_dictionary
        }

    @staticmethod
    def _grab_data_from_dictionary(dictionary):
        """Extract the stored data from a dictionary in the deprecated layout

        Parameters
        ----------
        dictionary: dict
            dictionary-like object holding the contents of the result file

        Returns
        -------
        data: dict
            the output of CorePESummaryDeprecated._grab_data_from_dictionary
            extended with 'approximant' (list, one entry per label, None when
            not stored), 'calibration' and 'psd' (the loaded top-level
            'calibration_envelope' and 'psds' entries, or None)
        """
        data = CorePESummaryDeprecated._grab_data_from_dictionary(
            dictionary=dictionary
        )
        labels = data["labels"]

        psd, cal = None, None
        # 'psds' and 'calibration_envelope' are file-level entries, so they
        # are loaded once rather than once per label. As before, nothing is
        # loaded when the file contains no analyses.
        if len(labels) > 0:
            if "psds" in dictionary:
                psd, = load_recursively("psds", dictionary)
            if "calibration_envelope" in dictionary:
                cal, = load_recursively("calibration_envelope", dictionary)

        approx_list = []
        has_approximant = "approximant" in dictionary
        for key in labels:
            if has_approximant and key in dictionary["approximant"]:
                approx_list.append(dictionary["approximant"][key])
            else:
                # keep the list aligned with the labels
                approx_list.append(None)

        data["approximant"] = approx_list
        data["calibration"] = cal
        data["psd"] = psd

        return data
class TGRPESummary(PESummary):
    """This class handles TGR PESummary result files

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    kwargs: dict, optional
        all additional kwargs are passed to PESummary.__init__

    Attributes
    ----------
    parameters: list
        list of parameters stored in the result file
    converted_parameters: list
        list of parameters that have been derived from the sampled distributions
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file keyed by parameters
    labels: list
        list of analyses stored in the result file
    file_kwargs: dict
        dictionary of kwargs associated with each label
    imrct_deviation: dict
        dictionary of pesummary.utils.probability_dict.ProbabilityDict2D
        objects, one for each analysis. When only one analysis is found the
        single ProbabilityDict2D object is stored directly
    """
    def __init__(self, path_to_results_file, **kwargs):
        # name this class (not PESummary) in super() so that the full MRO,
        # including PESummary.__init__, is honoured
        super(TGRPESummary, self).__init__(
            path_to_results_file, **kwargs
        )

    def load(self, *args, **kwargs):
        """Load the file through the parent class and then build the
        `imrct_deviation` attribute from the 'imrct_deviation' entry of
        `self.data`

        Parameters
        ----------
        args: tuple
            all args are passed to the parent `load` method
        kwargs: dict
            all kwargs are passed to the parent `load` method
        """
        super(TGRPESummary, self).load(*args, **kwargs)
        self.imrct_deviation = {}
        if "imrct_deviation" not in self.data.keys():
            return
        deviations = self.data["imrct_deviation"]
        if not len(deviations):
            return

        from pesummary.utils.probability_dict import ProbabilityDict2D
        # one entry per analysis: take the 'inspiral' label of each
        # inspiral/postinspiral pair and strip the ':inspiral' suffix
        analysis_label = [
            label.split(":inspiral")[0] for label in self.labels
            if "inspiral" in label and "postinspiral" not in label
        ]
        self.imrct_deviation = {
            label: ProbabilityDict2D(
                {"final_mass_final_spin_deviations": [*deviations[num]]}
            ) for num, label in enumerate(analysis_label)
        }
        if len(analysis_label) == 1:
            # unwrap via the actual key: it is only literally 'inspiral' when
            # the file held a single top-level group, so a hard-coded lookup
            # raised KeyError for a lone '<label>:inspiral' analysis
            self.imrct_deviation = self.imrct_deviation[analysis_label[0]]

    @staticmethod
    def _grab_data_from_dictionary(dictionary):
        """Extract the stored data from a TGR result file dictionary

        Every top-level key other than 'version' and 'history' is treated as
        an analysis group. Groups without an 'imrct' entry contribute nothing
        (although their 'posterior_samples' entry is still required).

        Parameters
        ----------
        dictionary: dict
            dictionary-like object holding the contents of the result file

        Returns
        -------
        dict
            dictionary with keys 'parameters', 'samples', 'injection'
            (always None), 'labels', 'history', 'imrct_deviation' and
            'kwargs'. 'parameters', 'samples' and 'labels' hold an
            'inspiral' and a 'postinspiral' entry per group;
            'imrct_deviation' holds one
            [final_mass_deviation, final_spin_deviation, pdf] entry per group
        """
        groups = [
            key for key in dictionary.keys()
            if key not in ("version", "history")
        ]
        history_dict = dictionary["history"] if "history" in dictionary else None

        # labels are only prefixed with the group name when the file holds
        # more than one group
        prefix_labels = len(groups) > 1

        parameter_list, sample_list, imrct_deviation = [], [], []
        file_kwargs = {}
        labels = []
        for group in groups:
            data, = load_recursively(group, dictionary)
            posterior_samples = data["posterior_samples"]
            if "imrct" not in data:
                continue
            imrct = data["imrct"]
            file_kwargs[group] = (
                imrct["meta_data"] if "meta_data" in imrct else {}
            )
            for analysis in ["inspiral", "postinspiral"]:
                if prefix_labels:
                    labels.append("{}:{}".format(group, analysis))
                else:
                    labels.append(analysis)
                table = posterior_samples[analysis]
                parameters = list(table.dtype.names)
                samples = [np.array(row.tolist()) for row in table]
                # names may have been stored as bytes; the emptiness check
                # avoids an IndexError on a table without columns
                if parameters and isinstance(parameters[0], bytes):
                    parameters = [
                        parameter.decode("utf-8") for parameter in parameters
                    ]
                parameter_list.append(parameters)
                sample_list.append(samples)
            # one deviation entry per group (not per analysis), matching how
            # `load` indexes this list
            imrct_deviation.append(
                [
                    imrct["final_mass_deviation"],
                    imrct["final_spin_deviation"],
                    imrct["pdf"]
                ]
            )
        return {
            "parameters": parameter_list,
            "samples": sample_list,
            "injection": None,
            "labels": labels,
            "history": history_dict,
            "imrct_deviation": imrct_deviation,
            "kwargs": file_kwargs
        }