Coverage for pesummary/core/cli/inputs.py: 80.6%
1264 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import socket
5from glob import glob
6from pathlib import Path
7from getpass import getuser
9import math
10import numpy as np
11from pesummary.core.file.read import read as Read
12from pesummary.utils.exceptions import InputError
13from pesummary.utils.decorators import deprecation
14from pesummary.utils.samples_dict import SamplesDict, MCMCSamplesDict
15from pesummary.utils.utils import (
16 guess_url, logger, make_dir, make_cache_style_file, list_match
17)
18from pesummary import conf
20__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
23class _Input(object):
24 """Super class to handle the command line arguments
25 """
26 @staticmethod
27 def is_pesummary_metafile(proposed_file):
28 """Determine if a file is a PESummary metafile or not
30 Parameters
31 ----------
32 proposed_file: str
33 path to the file
34 """
35 extension = proposed_file.split(".")[-1]
36 if extension == "h5" or extension == "hdf5" or extension == "hdf":
37 from pesummary.core.file.read import (
38 is_pesummary_hdf5_file, is_pesummary_hdf5_file_deprecated
39 )
41 result = any(
42 func(proposed_file) for func in [
43 is_pesummary_hdf5_file,
44 is_pesummary_hdf5_file_deprecated
45 ]
46 )
47 return result
48 elif extension == "json":
49 from pesummary.core.file.read import (
50 is_pesummary_json_file, is_pesummary_json_file_deprecated
51 )
53 result = any(
54 func(proposed_file) for func in [
55 is_pesummary_json_file,
56 is_pesummary_json_file_deprecated
57 ]
58 )
59 return result
60 else:
61 return False
    @staticmethod
    def grab_data_from_metafile(
        existing_file, webdir, compare=None, read_function=Read,
        _replace_with_pesummary_kwargs={}, nsamples=None,
        disable_injection=False, keep_nan_likelihood_samples=False,
        reweight_samples=False, **kwargs
    ):
        """Grab data from an existing PESummary metafile

        Parameters
        ----------
        existing_file: str
            path to the existing metafile
        webdir: str
            the directory to store the existing configuration file
        compare: list, optional
            list of labels for events stored in an existing metafile that you
            wish to compare
        read_function: func, optional
            PESummary function to use to read in the existing file
        _replace_with_pesummary_kwargs: dict, optional
            dictionary of kwargs that you wish to replace with the data stored
            in the PESummary file. Each value is a template string that is
            formatted with `file`, `ind` and `label` and then evaluated
        nsamples: int, optional
            Number of samples to use. Default all available samples
        disable_injection: Bool, optional
            if True, any injection data stored in the file is ignored and
            every parameter is given a NaN injection value
        keep_nan_likelihood_samples: Bool, optional
            negated and passed to the read function as
            `remove_nan_likelihood_samples`
        reweight_samples: str/Bool, optional
            if truthy, passed to the `reweight_samples` method of the opened
            file
        kwargs: dict
            All kwargs are passed to the `generate_all_posterior_samples`
            method

        Returns
        -------
        dict
            dictionary with the keys "samples", "injection_data",
            "file_version", "file_kwargs", "prior", "config", "labels",
            "weights", "indicies", "mcmc_samples", "open_file" and
            "descriptions"

        Raises
        ------
        InputError
            if a label in `compare` is not stored in the metafile
        """
        f = read_function(
            existing_file,
            remove_nan_likelihood_samples=not keep_nan_likelihood_samples
        )
        for ind, label in enumerate(f.labels):
            # NOTE(review): the copy is taken from kwargs as it stands at this
            # iteration, so the copy made for a later label also contains the
            # per-label entries added for earlier labels -- confirm intended
            kwargs[label] = kwargs.copy()
            for key, item in _replace_with_pesummary_kwargs.items():
                try:
                    # NOTE(review): eval of a formatted template string; the
                    # templates must come from a trusted source
                    kwargs[label][key] = eval(
                        item.format(file="f", ind=ind, label=label)
                    )
                except TypeError:
                    # retry with everything from "['{label}']" onwards removed
                    _item = item.split("['{label}']")[0]
                    kwargs[label][key] = eval(
                        _item.format(file="f", ind=ind, label=label)
                    )
                except (AttributeError, KeyError, NameError):
                    # the requested quantity is not available; leave key unset
                    pass

        if not f.mcmc_samples:
            labels = f.labels
        else:
            labels = list(f.samples_dict.keys())
        indicies = np.arange(len(labels))

        if compare:
            # restrict to the requested labels, remembering their positions
            indicies = []
            for i in compare:
                if i not in labels:
                    raise InputError(
                        "Label '%s' does not exist in the metafile. The list "
                        "of available labels are %s" % (i, labels)
                    )
                indicies.append(labels.index(i))
            labels = compare

        if nsamples is not None:
            f.downsample(nsamples, labels=labels)
        if not f.mcmc_samples:
            f.generate_all_posterior_samples(labels=labels, **kwargs)
        if reweight_samples:
            f.reweight_samples(reweight_samples, labels=labels, **kwargs)

        parameters = f.parameters
        if not f.mcmc_samples:
            samples = [np.array(i).T for i in f.samples]
            DataFrame = {
                label: SamplesDict(parameters[ind], samples[ind])
                for label, ind in zip(labels, indicies)
            }
            # NOTE(review): _parameters is assigned in both branches but is
            # not referenced again in this function
            _parameters = lambda label: DataFrame[label].keys()
        else:
            # all chains are stored under the first label of the file
            DataFrame = {
                f.labels[0]: MCMCSamplesDict(
                    {
                        label: f.samples_dict[label] for label in labels
                    }
                )
            }
            labels = f.labels
            indicies = np.arange(len(labels))
            _parameters = lambda label: DataFrame[f.labels[0]].parameters
        if not disable_injection and f.injection_parameters != []:
            inj_values = f.injection_dict
            # pad the stored injection data (in place) with NaN for every
            # parameter that has samples but no injected value
            for label in labels:
                for param in DataFrame[label].keys():
                    if param not in f.injection_dict[label].keys():
                        f.injection_dict[label][param] = float("nan")
        else:
            inj_values = {
                i: {
                    param: float("nan") for param in DataFrame[i].parameters
                } for i in labels
            }
        for i in inj_values.keys():
            for param in inj_values[i].keys():
                # convert the string "nan" to a float NaN and decode bytes
                if inj_values[i][param] == "nan":
                    inj_values[i][param] = float("nan")
                if isinstance(inj_values[i][param], bytes):
                    inj_values[i][param] = inj_values[i][param].decode("utf-8")

        if hasattr(f, "priors") and f.priors is not None and f.priors != {}:
            priors = f.priors
        else:
            priors = {label: {} for label in labels}

        config = []
        if f.config is not None and not all(i is None for i in f.config):
            config = []
            for i in labels:
                # write the stored config for this label into webdir/config
                config_dir = os.path.join(webdir, "config")
                filename = f.write_config_to_file(
                    i, outdir=config_dir, _raise=False,
                    filename="{}_config.ini".format(i)
                )
                _config = os.path.join(config_dir, filename)
                if filename is not None and os.path.isfile(_config):
                    config.append(_config)
                else:
                    config.append(None)
        else:
            for i in labels:
                config.append(None)

        if f.weights is not None:
            weights = {i: f.weights[i] for i in labels}
        else:
            weights = {i: None for i in labels}

        return {
            "samples": DataFrame,
            "injection_data": inj_values,
            "file_version": {
                i: j for i, j in zip(
                    labels, [f.input_version[ind] for ind in indicies]
                )
            },
            "file_kwargs": {
                i: j for i, j in zip(
                    labels, [f.extra_kwargs[ind] for ind in indicies]
                )
            },
            "prior": priors,
            "config": config,
            "labels": labels,
            "weights": weights,
            "indicies": indicies,
            "mcmc_samples": f.mcmc_samples,
            "open_file": f,
            "descriptions": f.description
        }
224 @staticmethod
225 def grab_data_from_file(
226 file, label, webdir, config=None, injection=None, read_function=Read,
227 file_format=None, nsamples=None, disable_prior_sampling=False,
228 nsamples_for_prior=None, path_to_samples=None,
229 keep_nan_likelihood_samples=False, reweight_samples=False,
230 multi_process=None, **kwargs
231 ):
232 """Grab data from a result file containing posterior samples
234 Parameters
235 ----------
236 file: str
237 path to the result file
238 label: str
239 label that you wish to use for the result file
240 config: str, optional
241 path to a configuration file used in the analysis
242 injection: str, optional
243 path to an injection file used in the analysis
244 read_function: func, optional
245 PESummary function to use to read in the file
246 file_format, str, optional
247 the file format you wish to use when loading. Default None.
248 If None, the read function loops through all possible options
249 kwargs: dict
250 Dictionary of keyword arguments fed to the
251 `generate_all_posterior_samples` method
252 """
253 f = read_function(
254 file, file_format=file_format, disable_prior=disable_prior_sampling,
255 nsamples_for_prior=nsamples_for_prior, path_to_samples=path_to_samples,
256 remove_nan_likelihood_samples=not keep_nan_likelihood_samples, multi_process=multi_process, **kwargs
257 )
258 if config is not None:
259 f.add_fixed_parameters_from_config_file(config)
261 if nsamples is not None:
262 f.downsample(nsamples)
263 f.generate_all_posterior_samples(**kwargs)
264 if injection:
265 f.add_injection_parameters_from_file(
266 injection, conversion_kwargs=kwargs
267 )
268 if reweight_samples:
269 f.reweight_samples(reweight_samples)
270 parameters = f.parameters
271 samples = np.array(f.samples).T
272 DataFrame = {label: SamplesDict(parameters, samples)}
273 kwargs = f.extra_kwargs
274 if hasattr(f, "injection_parameters"):
275 injection = f.injection_parameters
276 if injection is not None:
277 for i in parameters:
278 if i not in list(injection.keys()):
279 injection[i] = float("nan")
280 else:
281 injection = {i: j for i, j in zip(
282 parameters, [float("nan")] * len(parameters))}
283 else:
284 injection = {i: j for i, j in zip(
285 parameters, [float("nan")] * len(parameters))}
286 version = f.input_version
287 if hasattr(f, "priors") and f.priors is not None:
288 priors = {key: {label: item} for key, item in f.priors.items()}
289 else:
290 priors = {label: []}
291 if hasattr(f, "weights") and f.weights is not None:
292 weights = f.weights
293 else:
294 weights = None
295 data = {
296 "samples": DataFrame,
297 "injection_data": {label: injection},
298 "file_version": {label: version},
299 "file_kwargs": {label: kwargs},
300 "prior": priors,
301 "weights": {label: weights},
302 "open_file": f,
303 "descriptions": {label: f.description}
304 }
305 if hasattr(f, "config") and f.config is not None:
306 if config is None:
307 config_dir = os.path.join(webdir, "config")
308 filename = "{}_config.ini".format(label)
309 logger.debug(
310 "Successfully extracted config data from the provided "
311 "input file. Saving the data to the file '{}'".format(
312 os.path.join(config_dir, filename)
313 )
314 )
315 _filename = f.write(
316 filename=filename, outdir=config_dir, file_format="ini",
317 _raise=False
318 )
319 data["config"] = _filename
320 else:
321 logger.info(
322 "Ignoring config data extracted from the input file and "
323 "using the config file provided"
324 )
325 return data
327 @property
328 def result_files(self):
329 return self._result_files
331 @result_files.setter
332 def result_files(self, result_files):
333 self._result_files = result_files
334 if self._result_files is not None:
335 for num, ff in enumerate(self._result_files):
336 func = None
337 if not os.path.isfile(ff) and "@" in ff:
338 from pesummary.io.read import _fetch_from_remote_server
339 func = _fetch_from_remote_server
340 elif not os.path.isfile(ff) and "https://" in ff:
341 from pesummary.io.read import _fetch_from_url
342 func = _fetch_from_url
343 elif not os.path.isfile(ff) and "*" in ff:
344 from pesummary.utils.utils import glob_directory
345 func = glob_directory
346 if func is not None:
347 _data = func(ff)
348 if isinstance(_data, (np.ndarray, list)) and len(_data) > 0:
349 self._result_files[num] = _data[0]
350 if len(_data) > 1:
351 _ = [
352 self._result_files.insert(num + 1, d) for d in
353 _data[1:][::-1]
354 ]
355 elif isinstance(_data, np.ndarray):
356 raise InputError(
357 "Unable to find any files matching '{}'".format(ff)
358 )
359 else:
360 self._result_files[num] = _data
362 @property
363 def seed(self):
364 return self._seed
366 @seed.setter
367 def seed(self, seed):
368 np.random.seed(seed)
369 self._seed = seed
371 @property
372 def existing(self):
373 return self._existing
375 @existing.setter
376 def existing(self, existing):
377 self._existing = existing
378 if existing is not None:
379 self._existing = os.path.abspath(existing)
381 @property
382 def existing_metafile(self):
383 return self._existing_metafile
385 @existing_metafile.setter
386 def existing_metafile(self, existing_metafile):
387 from glob import glob
389 self._existing_metafile = existing_metafile
390 if self._existing_metafile is None:
391 return
392 if not os.path.isdir(os.path.join(self.existing, "samples")):
393 raise InputError("Please provide a valid existing directory")
394 _dir = os.path.join(self.existing, "samples")
395 files = glob(os.path.join(_dir, "posterior_samples*"))
396 dir_content = glob(os.path.join(_dir, "*.h5"))
397 dir_content.extend(glob(os.path.join(_dir, "*.json")))
398 dir_content.extend(glob(os.path.join(_dir, "*.hdf5")))
399 if len(files) == 0 and len(dir_content):
400 files = dir_content
401 logger.warning(
402 "Unable to find a 'posterior_samples*' file in the existing "
403 "directory. Using '{}' as the existing metafile".format(
404 dir_content[0]
405 )
406 )
407 elif len(files) == 0:
408 raise InputError(
409 "Unable to find an existing metafile in the existing webdir"
410 )
411 elif len(files) > 1:
412 raise InputError(
413 "Multiple metafiles in the existing directory. Please either "
414 "run the `summarycombine_metafile` executable to combine the "
415 "meta files or simple remove the unwanted meta file"
416 )
417 self._existing_metafile = os.path.join(
418 self.existing, "samples", files[0]
419 )
421 @property
422 def style_file(self):
423 return self._style_file
425 @style_file.setter
426 def style_file(self, style_file):
427 default = conf.style_file
428 if style_file is not None and not os.path.isfile(style_file):
429 logger.warning(
430 "The file '{}' does not exist. Resorting to default".format(
431 style_file
432 )
433 )
434 style_file = default
435 elif style_file is not None and os.path.isfile(style_file):
436 logger.info(
437 "Using the file '{}' as the matplotlib style file".format(
438 style_file
439 )
440 )
441 elif style_file is None:
442 logger.debug(
443 "Using the default matplotlib style file"
444 )
445 style_file = default
446 make_cache_style_file(style_file)
447 self._style_file = style_file
449 @property
450 def filename(self):
451 return self._filename
453 @filename.setter
454 def filename(self, filename):
455 self._filename = filename
456 if filename is not None:
457 if "/" in filename:
458 logger.warning("")
459 filename = filename.split("/")[-1]
460 if os.path.isfile(os.path.join(self.webdir, "samples", filename)):
461 logger.warning(
462 "A file with filename '{}' already exists in the samples "
463 "directory '{}'. This will be overwritten"
464 )
466 @property
467 def user(self):
468 return self._user
470 @user.setter
471 def user(self, user):
472 try:
473 self._user = getuser()
474 logger.info(
475 conf.overwrite.format("user", conf.user, self._user)
476 )
477 except KeyError as e:
478 logger.info(
479 "Failed to grab user information because {}. Default will be "
480 "used".format(e)
481 )
482 self._user = user
484 @property
485 def host(self):
486 return socket.getfqdn()
488 @property
489 def webdir(self):
490 return self._webdir
492 @webdir.setter
493 def webdir(self, webdir):
494 cond1 = webdir is None or webdir == "None" or webdir == "none"
495 cond2 = (
496 self.existing is None or self.existing == "None"
497 or self.existing == "none"
498 )
499 if cond1 and cond2:
500 raise InputError(
501 "Please provide a web directory to store the webpages. If "
502 "you wish to add to an existing webpage, then pass the "
503 "existing web directory under the '--existing_webdir' command "
504 "line argument. If this is a new set of webpages, then pass "
505 "the web directory under the '--webdir' argument"
506 )
507 elif webdir is None and self.existing is not None:
508 if not os.path.isdir(self.existing):
509 raise InputError(
510 "The directory {} does not exist".format(self.existing)
511 )
512 entries = glob(self.existing + "/*")
513 if os.path.join(self.existing, "home.html") not in entries:
514 raise InputError(
515 "Please give the base directory of an existing output"
516 )
517 self._webdir = self.existing
518 else:
519 if not os.path.isdir(webdir):
520 logger.debug(
521 "Given web directory does not exist. Creating it now"
522 )
523 make_dir(webdir)
524 self._webdir = os.path.abspath(webdir)
526 @property
527 def baseurl(self):
528 return self._baseurl
530 @baseurl.setter
531 def baseurl(self, baseurl):
532 self._baseurl = baseurl
533 if baseurl is None:
534 self._baseurl = guess_url(self.webdir, self.host, self.user)
536 @property
537 def mcmc_samples(self):
538 return self._mcmc_samples
540 @mcmc_samples.setter
541 def mcmc_samples(self, mcmc_samples):
542 self._mcmc_samples = mcmc_samples
543 if self._mcmc_samples:
544 logger.info(
545 "Treating all samples as seperate mcmc chains for the same "
546 "analysis."
547 )
549 @property
550 def labels(self):
551 return self._labels
553 @labels.setter
554 def labels(self, labels):
555 if self.result_files is not None:
556 if any(self.is_pesummary_metafile(s) for s in self.result_files):
557 logger.warning(
558 "labels argument is ignored when a pesummary metafile is "
559 "input. Stored analyses will use their stored labels. If "
560 "you wish to change the labels, please use `summarymodify`"
561 )
562 labels = self.default_labels()
563 if not hasattr(self, "._labels"):
564 if labels is None:
565 labels = self.default_labels()
566 elif self.mcmc_samples and len(labels) != 1:
567 raise InputError(
568 "Please provide a single label that corresponds to all "
569 "mcmc samples"
570 )
571 elif len(np.unique(labels)) != len(labels):
572 raise InputError(
573 "Please provide unique labels for each result file"
574 )
575 for num, i in enumerate(labels):
576 if "." in i:
577 logger.warning(
578 "Replacing the label {} by {} to make it compatible "
579 "with the html pages".format(i, i.replace(".", "_"))
580 )
581 labels[num] = i.replace(".", "_")
582 if self.add_to_existing:
583 for i in labels:
584 if i in self.existing_labels:
585 raise InputError(
586 "The label '%s' already exists in the existing "
587 "metafile. Please pass another unique label"
588 )
590 if len(self.result_files) != len(labels) and not self.mcmc_samples:
591 import copy
592 _new_labels = copy.deepcopy(labels)
593 idx = 1
594 while len(_new_labels) < len(self.result_files):
595 _new_labels.extend(
596 [_label + str(idx) for _label in labels]
597 )
598 idx += 1
599 _new_labels = _new_labels[:len(self.result_files)]
600 logger.info(
601 "You have passed {} result files and {} labels. Setting "
602 "labels = {}".format(
603 len(self.result_files), len(labels), _new_labels
604 )
605 )
606 labels = _new_labels
607 self._labels = labels
609 @property
610 def config(self):
611 return self._config
613 @config.setter
614 def config(self, config):
615 if config and len(config) != len(self.labels):
616 raise InputError(
617 "Please provide a configuration file for each label"
618 )
619 if config is None and not self.meta_file:
620 self._config = [None] * len(self.labels)
621 elif self.meta_file:
622 self._config = [None] * len(self.labels)
623 else:
624 self._config = config
625 for num, ff in enumerate(self._config):
626 if isinstance(ff, str) and ff.lower() == "none":
627 self._config[num] = None
629 @property
630 def injection_file(self):
631 return self._injection_file
633 @injection_file.setter
634 def injection_file(self, injection_file):
635 if injection_file and len(injection_file) != len(self.labels):
636 if len(injection_file) == 1:
637 logger.info(
638 "Only one injection file passed. Assuming the same "
639 "injection for all {} result files".format(len(self.labels))
640 )
641 else:
642 raise InputError(
643 "You have passed {} for {} result files. Please provide an "
644 "injection file for each result file".format(
645 len(self.injection_file), len(self.labels)
646 )
647 )
648 if injection_file is None:
649 injection_file = [None] * len(self.labels)
650 self._injection_file = injection_file
652 @property
653 def injection_data(self):
654 return self._injection_data
656 @property
657 def file_version(self):
658 return self._file_version
660 @property
661 def file_kwargs(self):
662 return self._file_kwargs
664 @property
665 def kde_plot(self):
666 return self._kde_plot
668 @kde_plot.setter
669 def kde_plot(self, kde_plot):
670 self._kde_plot = kde_plot
671 if kde_plot != conf.kde_plot:
672 logger.info(
673 conf.overwrite.format("kde_plot", conf.kde_plot, kde_plot)
674 )
676 @property
677 def file_format(self):
678 return self._file_format
680 @file_format.setter
681 def file_format(self, file_format):
682 if file_format is None:
683 self._file_format = [None] * len(self.labels)
684 elif len(file_format) == 1 and len(file_format) != len(self.labels):
685 logger.warning(
686 "Only one file format specified. Assuming all files are of "
687 "this format"
688 )
689 self._file_format = [file_format[0]] * len(self.labels)
690 elif len(file_format) != len(self.labels):
691 raise InputError(
692 "Please provide a file format for each result file. If you "
693 "wish to specify the file format for the second result file "
694 "and not for any of the others, for example, simply pass 'None "
695 "{format} None'"
696 )
697 else:
698 for num, ff in enumerate(file_format):
699 if ff.lower() == "none":
700 file_format[num] = None
701 self._file_format = file_format
703 @property
704 def samples(self):
705 return self._samples
707 @samples.setter
708 def samples(self, samples):
709 if isinstance(samples, dict):
710 return samples
711 self._set_samples(samples)
713 def _set_samples(
714 self, samples,
715 ignore_keys=["prior", "weights", "labels", "indicies", "open_file"]
716 ):
717 """Extract the samples and store them as attributes of self
719 Parameters
720 ----------
721 samples: list
722 A list containing the paths to result files
723 ignore_keys: list, optional
724 A list containing properties of the read file that you do not want to be
725 stored as attributes of self
726 """
727 if not samples:
728 raise InputError("Please provide a results file")
729 _samples_generator = (self.is_pesummary_metafile(s) for s in samples)
730 if any(_samples_generator) and not all(_samples_generator):
731 raise InputError(
732 "It seems that you have passed a combination of pesummary "
733 "metafiles and non-pesummary metafiles. This is currently "
734 "not supported."
735 )
736 labels, labels_dict = None, {}
737 weights_dict = {}
738 if self.mcmc_samples:
739 nsamples = 0.
740 for num, i in enumerate(samples):
741 idx = num
742 if not self.mcmc_samples:
743 if not self.is_pesummary_metafile(samples[num]):
744 logger.info("Assigning {} to {}".format(self.labels[num], i))
745 else:
746 num = 0
747 if not os.path.isfile(i):
748 raise InputError("File %s does not exist" % (i))
749 if self.is_pesummary_metafile(samples[num]):
750 data = self.grab_data_from_input(
751 i, self.labels[num], config=None, injection=None
752 )
753 self.mcmc_samples = data["mcmc_samples"]
754 else:
755 data = self.grab_data_from_input(
756 i, self.labels[num], config=self.config[num],
757 injection=self.injection_file[num],
758 file_format=self.file_format[num]
759 )
760 if "config" in data.keys():
761 msg = (
762 "Overwriting the provided config file for '{}' with "
763 "the config information stored in the input "
764 "file".format(self.labels[num])
765 )
766 if self.config[num] is None:
767 logger.debug(msg)
768 else:
769 logger.info(msg)
770 self.config[num] = data.pop("config")
771 if self.mcmc_samples:
772 data["samples"] = {
773 "{}_mcmc_chain_{}".format(key, idx): item for key, item
774 in data["samples"].items()
775 }
776 for key, item in data.items():
777 if key not in ignore_keys:
778 if idx == 0:
779 setattr(self, "_{}".format(key), item)
780 else:
781 x = getattr(self, "_{}".format(key))
782 if isinstance(x, dict):
783 x.update(item)
784 elif isinstance(x, list):
785 x += item
786 setattr(self, "_{}".format(key), x)
787 if self.mcmc_samples:
788 try:
789 nsamples += data["file_kwargs"][self.labels[num]]["sampler"][
790 "nsamples"
791 ]
792 except UnboundLocalError:
793 pass
794 if "labels" in data.keys():
795 stored_labels = data["labels"]
796 else:
797 stored_labels = [self.labels[num]]
798 if "weights" in data.keys():
799 weights_dict.update(data["weights"])
800 if "prior" in data.keys():
801 for label in stored_labels:
802 pp = data["prior"]
803 if pp != {} and label in pp.keys() and pp[label] == []:
804 if len(self.priors):
805 if label not in self.priors["samples"].keys():
806 self.add_to_prior_dict(
807 "samples/{}".format(label), []
808 )
809 else:
810 self.add_to_prior_dict(
811 "samples/{}".format(label), []
812 )
813 elif pp != {} and label not in pp.keys():
814 for key in pp.keys():
815 if key in self.priors.keys():
816 if label in self.priors[key].keys():
817 logger.warning(
818 "Replacing the prior file for {} "
819 "with the prior file stored in "
820 "the result file".format(
821 label
822 )
823 )
824 if pp[key] == {}:
825 self.add_to_prior_dict(
826 "{}/{}".format(key, label), []
827 )
828 elif label not in pp[key].keys():
829 self.add_to_prior_dict(
830 "{}/{}".format(key, label), {}
831 )
832 else:
833 self.add_to_prior_dict(
834 "{}/{}".format(key, label), pp[key][label]
835 )
836 else:
837 self.add_to_prior_dict(
838 "samples/{}".format(label), []
839 )
840 if "labels" in data.keys():
841 _duplicated = [
842 _ for _ in data["labels"] if num != 0 and _ in labels
843 ]
844 if num == 0:
845 labels = data["labels"]
846 elif len(_duplicated):
847 raise InputError(
848 "The labels stored in the supplied files are not "
849 "unique. The label{}: '{}' appear{} in two or more "
850 "files. Please provide unique labels for each "
851 "analysis.".format(
852 "s" if len(_duplicated) > 1 else "",
853 ", ".join(_duplicated),
854 "" if len(_duplicated) > 1 else "s"
855 )
856 )
857 else:
858 labels += data["labels"]
859 labels_dict[num] = data["labels"]
860 if self.mcmc_samples:
861 try:
862 self.file_kwargs[self.labels[0]]["sampler"].update(
863 {"nsamples": nsamples, "nchains": len(self.result_files)}
864 )
865 except (KeyError, UnboundLocalError):
866 pass
867 _labels = list(self._samples.keys())
868 if not isinstance(self._samples[_labels[0]], MCMCSamplesDict):
869 self._samples = MCMCSamplesDict(self._samples)
870 else:
871 self._samples = self._samples[_labels[0]]
872 if labels is not None:
873 self._labels = labels
874 if len(labels) != len(self.result_files):
875 result_files = []
876 for num, f in enumerate(samples):
877 for ii in np.arange(len(labels_dict[num])):
878 result_files.append(self.result_files[num])
879 self.result_files = result_files
880 self.weights = {i: None for i in self.labels}
881 if weights_dict != {}:
882 self.weights = weights_dict
884 @property
885 def burnin_method(self):
886 return self._burnin_method
888 @burnin_method.setter
889 def burnin_method(self, burnin_method):
890 self._burnin_method = burnin_method
891 if not self.mcmc_samples and burnin_method is not None:
892 logger.info(
893 "The {} method will not be used to remove samples as "
894 "burnin as this can only be used for mcmc chains.".format(
895 burnin_method
896 )
897 )
898 self._burnin_method = None
899 elif self.mcmc_samples and burnin_method is None:
900 logger.info(
901 "No burnin method provided. Using {} as default".format(
902 conf.burnin_method
903 )
904 )
905 self._burnin_method = conf.burnin_method
906 elif self.mcmc_samples:
907 from pesummary.core.file import mcmc
909 if burnin_method not in mcmc.algorithms:
910 logger.warning(
911 "Unrecognised burnin method: {}. Resorting to the default: "
912 "{}".format(burnin_method, conf.burnin_method)
913 )
914 self._burnin_method = conf.burnin_method
915 if self._burnin_method is not None:
916 for label in self.labels:
917 self.file_kwargs[label]["sampler"]["burnin_method"] = (
918 self._burnin_method
919 )
921 @property
922 def burnin(self):
923 return self._burnin
925 @burnin.setter
926 def burnin(self, burnin):
927 _name = "nsamples_removed_from_burnin"
928 if burnin is not None:
929 samples_lengths = [
930 self.samples[key].number_of_samples for key in
931 self.samples.keys()
932 ]
933 if not all(int(burnin) < i for i in samples_lengths):
934 raise InputError(
935 "The chosen burnin is larger than the number of samples. "
936 "Please choose a value less than {}".format(
937 np.max(samples_lengths)
938 )
939 )
940 logger.info(
941 conf.overwrite.format("burnin", conf.burnin, burnin)
942 )
943 burnin = int(burnin)
944 else:
945 burnin = conf.burnin
946 if self.burnin_method is not None:
947 arguments, kwargs = [], {}
948 if burnin != 0 and self.burnin_method == "burnin_by_step_number":
949 logger.warning(
950 "The first {} samples have been requested to be removed "
951 "as burnin, but the burnin method has been chosen to be "
952 "burnin_by_step_number. Changing method to "
953 "burnin_by_first_n with keyword argument step_number="
954 "True such that all samples with step number < {} are "
955 "removed".format(burnin, burnin)
956 )
957 self.burnin_method = "burnin_by_first_n"
958 arguments = [burnin]
959 kwargs = {"step_number": True}
960 elif self.burnin_method == "burnin_by_first_n":
961 arguments = [burnin]
962 initial = self.samples.total_number_of_samples
963 self._samples = self.samples.burnin(
964 *arguments, algorithm=self.burnin_method, **kwargs
965 )
966 diff = initial - self.samples.total_number_of_samples
967 self.file_kwargs[self.labels[0]]["sampler"][_name] = diff
968 self.file_kwargs[self.labels[0]]["sampler"]["nsamples"] = \
969 self._samples.total_number_of_samples
970 else:
971 for label in self.samples:
972 self.samples[label] = self.samples[label].discard_samples(
973 burnin
974 )
975 if burnin != conf.burnin:
976 self.file_kwargs[label]["sampler"][_name] = burnin
978 @property
979 def nsamples(self):
980 return self._nsamples
982 @nsamples.setter
983 def nsamples(self, nsamples):
984 self._nsamples = nsamples
985 if nsamples is not None:
986 logger.info(
987 "{} samples will be used for each result file".format(nsamples)
988 )
989 self._nsamples = int(nsamples)
991 @property
992 def reweight_samples(self):
993 return self._reweight_samples
995 @reweight_samples.setter
996 def reweight_samples(self, reweight_samples):
997 from pesummary.core.reweight import options
998 self._reweight_samples = self._check_reweight_samples(
999 reweight_samples, options
1000 )
1002 def _check_reweight_samples(self, reweight_samples, options):
1003 if reweight_samples and reweight_samples not in options.keys():
1004 logger.warning(
1005 "Unknown reweight function: '{}'. Not reweighting posterior "
1006 "and/or prior samples".format(reweight_samples)
1007 )
1008 return False
1009 return reweight_samples
1011 @property
1012 def path_to_samples(self):
1013 return self._path_to_samples
1015 @path_to_samples.setter
1016 def path_to_samples(self, path_to_samples):
1017 self._path_to_samples = path_to_samples
1018 if path_to_samples is None:
1019 self._path_to_samples = {label: None for label in self.labels}
1020 elif len(path_to_samples) != len(self.labels):
1021 raise InputError(
1022 "Please provide a path for all result files passed. If "
1023 "two result files are passed, and only one requires the "
1024 "path_to_samples arguement, please pass --path_to_samples "
1025 "None path/to/samples"
1026 )
1027 else:
1028 _paths = {}
1029 for num, path in enumerate(path_to_samples):
1030 _label = self.labels[num]
1031 if path.lower() == "none":
1032 _paths[_label] = None
1033 else:
1034 _paths[_label] = path
1035 self._path_to_samples = _paths
1037 @property
1038 def priors(self):
1039 return self._priors
1041 @priors.setter
1042 def priors(self, priors):
1043 self._priors = self.grab_priors_from_inputs(priors)
@property
def custom_plotting(self):
    """Either None (no custom plotting) or the list
    `[directory, module_name]` of a successfully imported plotting module
    """
    return self._custom_plotting

@custom_plotting.setter
def custom_plotting(self, custom_plotting):
    """Import the custom plotting module and check it declares plots

    Parameters
    ----------
    custom_plotting: str or None
        path to a python file. The file must define `__single_plots__`
        and/or `__comparison_plots__` (the historical misspelling
        `__comparion_plots__` is still accepted). If the module cannot be
        imported, or declares no plots, a warning is logged and the
        property is reset to None
    """
    self._custom_plotting = custom_plotting
    if custom_plotting is None:
        return
    import importlib

    path_to_python_file = os.path.dirname(custom_plotting)
    python_file = os.path.splitext(os.path.basename(custom_plotting))[0]
    if path_to_python_file != "":
        import sys

        # make the module importable by its bare name
        sys.path.append(path_to_python_file)
    try:
        mod = importlib.import_module(python_file)
    except ModuleNotFoundError as e:
        logger.warning(
            "Failed to import {} because {}. No custom plotting will "
            "be done".format(python_file, e)
        )
        # do not leave the raw path behind: nothing can be plotted from it
        self._custom_plotting = None
        return
    methods = list(getattr(mod, "__single_plots__", list()))
    # accept the documented spelling as well as the legacy misspelt one
    for attribute in ("__comparison_plots__", "__comparion_plots__"):
        methods += list(getattr(mod, attribute, list()))
    if len(methods) > 0:
        self._custom_plotting = [path_to_python_file, python_file]
    else:
        logger.warning(
            "No __single_plots__ or __comparison_plots__ in {}. "
            "If you wish to use custom plotting, then please "
            "add the variable :__single_plots__ and/or "
            "__comparison_plots__ in future. No custom plotting "
            "will be done".format(custom_plotting)
        )
        self._custom_plotting = None
1081 @property
1082 def external_hdf5_links(self):
1083 return self._external_hdf5_links
1085 @external_hdf5_links.setter
1086 def external_hdf5_links(self, external_hdf5_links):
1087 self._external_hdf5_links = external_hdf5_links
1088 if not self.hdf5 and self.external_hdf5_links:
1089 logger.warning(
1090 "You can only apply external hdf5 links when saving the meta "
1091 "file in hdf5 format. Turning external hdf5 links off."
1092 )
1093 self._external_hdf5_links = False
@property
def hdf5_compression(self):
    """Compression setting to use for the hdf5 meta file, or None
    """
    return self._hdf5_compression

@hdf5_compression.setter
def hdf5_compression(self, hdf5_compression):
    """Store the compression setting, clearing it when not saving to hdf5

    Parameters
    ----------
    hdf5_compression: object or None
        requested compression. Requires `self.hdf5` to already be set
    """
    self._hdf5_compression = hdf5_compression
    if self.hdf5 or hdf5_compression is None:
        return
    logger.warning(
        "You can only apply compression when saving the meta "
        "file in hdf5 format. Turning compression off."
    )
    self._hdf5_compression = None
@property
def existing_plot(self):
    """Dictionary, keyed by label, of the plot file(s) copied into
    `<webdir>/plots`, or None if no existing plot is available
    """
    return self._existing_plot

@existing_plot.setter
def existing_plot(self, existing_plot):
    """Copy the provided existing plots into `<webdir>/plots` and store
    the copied locations

    Parameters
    ----------
    existing_plot: dict, list or None
        if a list, the same list of plots is assigned to every label in
        `self.labels`. Otherwise it is iterated as a dictionary keyed by
        label whose values are a single path or a list of paths. Paths
        which do not exist are skipped with a warning
    """
    self._existing_plot = existing_plot
    if self._existing_plot is not None:
        from pathlib import Path
        import shutil
        if isinstance(self._existing_plot, list):
            logger.warning(
                "Assigning {} to all labels".format(
                    ", ".join(self._existing_plot)
                )
            )
            self._existing_plot = {
                label: self._existing_plot for label in self.labels
            }
        _does_not_exist = (
            "The plot {} does not exist. Not adding plot to summarypages."
        )
        # labels which end up with no valid plot; removed after the loop so
        # the dictionary does not change size while it is being iterated
        keys_to_remove = []
        for key, _plot in self._existing_plot.items():
            if isinstance(_plot, list):
                allowed = []
                for _subplot in _plot:
                    if not os.path.isfile(_subplot):
                        logger.warning(_does_not_exist.format(_subplot))
                    else:
                        _filename = os.path.join(
                            self.webdir, "plots", Path(_subplot).name
                        )
                        try:
                            shutil.copyfile(_subplot, _filename)
                        except shutil.SameFileError:
                            # the plot is already at the destination
                            pass
                        allowed.append(_filename)
                if not len(allowed):
                    keys_to_remove.append(key)
                elif len(allowed) == 1:
                    # a single surviving plot is stored as a plain path
                    self._existing_plot[key] = allowed[0]
                else:
                    self._existing_plot[key] = allowed
            else:
                if not os.path.isfile(_plot):
                    logger.warning(_does_not_exist.format(_plot))
                    keys_to_remove.append(key)
                else:
                    _filename = os.path.join(
                        self.webdir, "plots", Path(_plot).name
                    )
                    try:
                        shutil.copyfile(_plot, _filename)
                    except shutil.SameFileError:
                        # unlike the list case above, a single plot that is
                        # already at the destination is copied again under
                        # a label-prefixed file name
                        _filename = os.path.join(
                            self.webdir, "plots", key + "_" + Path(_plot).name
                        )
                        shutil.copyfile(_plot, _filename)
                    self._existing_plot[key] = _filename
        for key in keys_to_remove:
            del self._existing_plot[key]
        if not len(self._existing_plot):
            self._existing_plot = None
def add_to_prior_dict(self, path, data):
    """Store prior samples at a (possibly nested) location of the prior
    dictionary, creating intermediate dictionaries when needed

    Parameters
    ----------
    path: str
        the location where you wish to store the prior. Nested locations
        are given with '/' as the separator, e.g. 'a/b'
    data: np.ndarray
        the prior samples

    Examples
    --------
    >>> self._priors = {"samples": {}}
    >>> self.add_to_prior_dict("samples/label", [1, 2, 3])
    >>> self._priors
    {"samples": {"label": [1, 2, 3]}}
    """
    keys = path.split("/") if "/" in path else [path]
    *parents, leaf = keys
    branch = self._priors
    for key in parents:
        # descend, adding an empty level where none exists yet
        if key not in branch.keys():
            branch[key] = {}
        branch = branch[key]
    branch[leaf] = data
def grab_priors_from_inputs(self, priors, read_func=None, read_kwargs=None):
    """Read the prior files and return the prior dictionary

    Parameters
    ----------
    priors: list or None
        paths to the prior files. Either a single file (used for every
        label) or one entry per label in `self.labels`. An entry equal to
        'none' (any case) means 'no prior file for this label'
    read_func: func, optional
        function used to read a prior file. Defaults to
        `pesummary.core.file.read.read`
    read_kwargs: dict, optional
        extra keyword arguments for `read_func`. If it contains a label as
        a key, the value stored under that label is used for that label.
        Only used when one prior file is given per label

    Returns
    -------
    prior_dict: dict
        empty dictionary when `priors` is None, otherwise a dictionary
        with the keys 'samples' and 'analytic', each keyed by label

    Raises
    ------
    InputError
        if a prior file does not exist, or the number of prior files is
        neither 1 nor the number of labels
    """
    if read_func is None:
        from pesummary.core.file.read import read as Read
        read_func = Read
    if read_kwargs is None:
        read_kwargs = {}
    if priors is None:
        return {}

    def _is_placeholder(entry):
        return entry.lower() == "none"

    def _store(label, data):
        prior_dict["samples"][label] = data.samples_dict
        # not every read object exposes analytic priors
        analytic = getattr(data, "analytic", None)
        if analytic is not None:
            prior_dict["analytic"][label] = analytic

    prior_dict = {"samples": {}, "analytic": {}}
    for i in priors:
        # 'none' is a placeholder, not a file, so it must not be checked
        if not _is_placeholder(i) and not os.path.isfile(i):
            raise InputError("The file {} does not exist".format(i))
    if len(priors) != len(self.labels) and len(priors) == 1:
        logger.warning(
            "You have only specified a single prior file for {} result "
            "files. Assuming the same prior file for all result "
            "files".format(len(self.labels))
        )
        if _is_placeholder(priors[0]):
            return prior_dict
        data = read_func(
            priors[0], nsamples=self.nsamples_for_prior
        )
        for label in self.labels:
            _store(label, data)
    elif len(priors) != len(self.labels):
        raise InputError(
            "Please provide a prior file for each result file"
        )
    else:
        for num, i in enumerate(priors):
            if _is_placeholder(i):
                continue
            label = self.labels[num]
            logger.info(
                "Assigning {} to {}".format(label, i)
            )
            if label in read_kwargs.keys():
                grab_data_kwargs = read_kwargs[label]
            else:
                grab_data_kwargs = read_kwargs
            data = read_func(
                i, nsamples=self.nsamples_for_prior,
                **grab_data_kwargs
            )
            _store(label, data)
    return prior_dict
@property
def grab_data_kwargs(self):
    """Dictionary, keyed by label, of the extra keyword arguments used
    when reading each result file. Each entry holds `regenerate`
    """
    kwargs = {}
    for label in self.labels:
        kwargs[label] = {"regenerate": self.regenerate}
    return kwargs
def grab_data_from_input(
    self, file, label, config=None, injection=None, file_format=None
):
    """Read a result file with either `grab_data_from_metafile` (when the
    file is a PESummary metafile) or `grab_data_from_file`, and record the
    opened file in `self._open_result_files`

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    file_format, str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options

    Returns
    -------
    data: dict
        the dictionary returned by the underlying read function. It must
        contain the key 'open_file'
    """
    per_label = self.grab_data_kwargs
    # use the label specific kwargs when available, else the whole mapping
    if label in per_label.keys():
        grab_data_kwargs = per_label[label]
    else:
        grab_data_kwargs = per_label

    is_metafile = self.is_pesummary_metafile(file)
    if not is_metafile:
        data = self.grab_data_from_file(
            file, label, self.webdir,
            config=config,
            injection=injection,
            file_format=file_format,
            nsamples=self.nsamples,
            disable_prior_sampling=self.disable_prior_sampling,
            nsamples_for_prior=self.nsamples_for_prior,
            path_to_samples=self.path_to_samples[label],
            reweight_samples=self.reweight_samples,
            keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
            num_processes=self.multi_process,
            **grab_data_kwargs
        )
    else:
        data = self.grab_data_from_metafile(
            file, self.webdir,
            compare=self.compare_results,
            nsamples=self.nsamples,
            reweight_samples=self.reweight_samples,
            disable_injection=self.disable_injection,
            keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
            **grab_data_kwargs
        )
    self._open_result_files.update({file: data["open_file"]})
    return data
@property
def email(self):
    """Email address provided on the command line, or None
    """
    return self._email

@email.setter
def email(self, email):
    """Store the email address after a minimal sanity check

    Raises
    ------
    InputError
        if an address is given which does not contain '@'
    """
    looks_invalid = email is not None and "@" not in email
    if looks_invalid:
        raise InputError("Please provide a valid email address")
    self._email = email
@property
def dump(self):
    """Value of the `dump` command line option
    """
    return self._dump

@dump.setter
def dump(self, dump):
    """Store the `dump` option unchanged
    """
    setattr(self, "_dump", dump)
@property
def palette(self):
    """Name of the color palette to use for plotting
    """
    return self._palette

@palette.setter
def palette(self, palette):
    """Validate the palette and, when it differs from the configured
    default, overwrite `conf.palette` with it

    Parameters
    ----------
    palette: str
        name of the palette

    Raises
    ------
    InputError
        if `color_palette` does not recognise the palette
    """
    self._palette = palette
    # compare by value: two equal strings are not guaranteed to be the
    # same object, so an identity check is unreliable here
    if palette != conf.palette:
        from pesummary.core.plots.palette import (
            color_palette, AVAILABLE_PALETTES
        )
        try:
            color_palette(palette, n_colors=1)
        except ValueError as e:
            raise InputError(
                "Unrecognised palette. Please choose from one of the "
                "following {}".format(
                    ", ".join(AVAILABLE_PALETTES)
                )
            ) from e
        logger.info(
            conf.overwrite.format("palette", conf.palette, palette)
        )
        conf.palette = palette
@property
def include_prior(self):
    """Whether the prior should be included on the plots
    """
    return self._include_prior

@include_prior.setter
def include_prior(self, include_prior):
    """Store the option and, when it differs from the configured default,
    log the change and overwrite `conf.include_prior`
    """
    self._include_prior = include_prior
    if include_prior != conf.include_prior:
        # the overwrite message is only useful if it is actually emitted
        logger.info(
            conf.overwrite.format("prior", conf.include_prior, include_prior)
        )
        conf.include_prior = include_prior
@property
def colors(self):
    """List of colors, one per analysis (new labels plus, when adding to
    an existing webpage, the existing labels)
    """
    return self._colors

@colors.setter
def colors(self, colors):
    """Store the colors to use for each analysis

    Parameters
    ----------
    colors: list or None
        user provided colors. If more colors than analyses are given the
        list is truncated; if exactly the right number is given it is used
        as is; if None or too few are given, default colors are taken from
        the `conf.palette` color palette
    """
    number = len(self.labels)
    if self.existing:
        number += len(self.existing_labels)
    if colors is not None:
        if len(colors) > number:
            logger.info(
                "You have passed {} colors for {} result files. Setting "
                "colors = {}".format(
                    len(colors), number, colors[:number]
                )
            )
            self._colors = colors[:number]
            return
        if len(colors) == number:
            # a correctly sized list must not be replaced by the defaults
            self._colors = colors
            return
        logger.warning(
            "Number of colors does not match the number of labels. "
            "Using default colors"
        )
    from pesummary.core.plots.palette import color_palette
    self._colors = color_palette(
        palette=conf.palette, n_colors=number
    ).as_hex()
@property
def linestyles(self):
    """List of linestyles, one per entry of `self.colors`
    """
    return self._linestyles

@linestyles.setter
def linestyles(self, linestyles):
    """Store the linestyles to use for each analysis

    Parameters
    ----------
    linestyles: list or None
        user provided linestyles. If more linestyles than colors are given
        the list is truncated; if exactly the right number is given it is
        used as is; if None or too few are given, defaults are generated
        such that analyses sharing a color cycle through
        '-', '--', ':' and '-.'
    """
    if linestyles is not None:
        if len(linestyles) > len(self.colors):
            logger.info(
                "You have passed {} linestyles for {} result files. "
                "Setting linestyles = {}".format(
                    len(linestyles), len(self.colors),
                    linestyles[:len(self.colors)]
                )
            )
            self._linestyles = linestyles[:len(self.colors)]
            return
        if len(linestyles) == len(self.colors):
            # a correctly sized list must not be replaced by the defaults
            self._linestyles = linestyles
            return
        logger.warning(
            "Number of linestyles does not match the number of "
            "labels. Using default linestyles"
        )
    available_linestyles = ["-", "--", ":", "-."]
    linestyles = ["-"] * len(self.colors)
    unique_colors = np.unique(self.colors)
    for color in unique_colors:
        indicies = [num for num, i in enumerate(self.colors) if i == color]
        # analyses with the same color get successive linestyles
        for idx, j in enumerate(indicies):
            linestyles[j] = available_linestyles[
                np.mod(idx, len(available_linestyles))
            ]
    self._linestyles = linestyles
@property
def disable_corner(self):
    """Whether the corner plot should be skipped
    """
    return self._disable_corner

@disable_corner.setter
def disable_corner(self, disable_corner):
    """Store the flag and warn the user when corner plots are disabled
    """
    self._disable_corner = disable_corner
    if not disable_corner:
        return
    logger.warning(
        "No corner plot will be produced. This will reduce overall "
        "runtime but does mean that the interactive corner plot feature "
        "on the webpages will no longer work"
    )
@property
def add_to_corner(self):
    """Parameters to show in the corner plot, as returned by
    `_set_corner_params`
    """
    return self._add_to_corner

@add_to_corner.setter
def add_to_corner(self, add_to_corner):
    """Check the requested parameters with `_set_corner_params` and store
    the result
    """
    checked = self._set_corner_params(add_to_corner)
    self._add_to_corner = checked
def _set_corner_params(self, corner_params):
    """Check the requested corner parameters against the posterior samples
    of every label

    Parameters
    ----------
    corner_params: list or None
        parameters requested for the corner plot

    Returns
    -------
    corner_params: list or None
        the input list, or None when the class is named 'Input' and none
        of the requested parameters are in the samples for some label
    """
    cls = self.__class__.__name__
    if corner_params is None:
        if cls == "Input":
            logger.debug(
                "Using all parameters stored in the result file for the "
                "corner plots. This may take some time."
            )
        return corner_params
    for label in self.labels:
        available = self.samples[label].keys()
        _not_included = [
            param for param in corner_params if param not in available
        ]
        nothing_matches = len(_not_included) == len(corner_params)
        if nothing_matches and cls == "Input":
            logger.warning(
                "None of the chosen corner parameters are "
                "included in the posterior table for '{}'. Using "
                "all available parameters for the corner plot".format(
                    label
                )
            )
            return None
        if len(_not_included):
            logger.warning(
                "The following parameters are not included in the "
                "posterior table for '{}': {}. Not adding to corner "
                "plot".format(label, ", ".join(_not_included))
            )
    return corner_params
@property
def pe_algorithm(self):
    """List of algorithm names provided on the command line, or None
    """
    return self._pe_algorithm

@pe_algorithm.setter
def pe_algorithm(self, pe_algorithm):
    """Store the algorithms and write each one into
    `self.file_kwargs[label]["sampler"]["pe_algorithm"]`

    Parameters
    ----------
    pe_algorithm: list or None
        one algorithm per label in `self.labels`

    Raises
    ------
    ValueError
        if the number of algorithms differs from the number of labels
    """
    self._pe_algorithm = pe_algorithm
    if pe_algorithm is None:
        return
    if len(pe_algorithm) != len(self.labels):
        raise ValueError("Please provide an algorithm for each result file")
    for num, label in enumerate(self.labels):
        _algorithm = pe_algorithm[num]
        sampler = self.file_kwargs[label]["sampler"]
        if "pe_algorithm" in sampler.keys():
            _stored = sampler["pe_algorithm"]
            # only warn when the command line value changes what was stored
            if _stored != _algorithm:
                logger.warning(
                    "Overwriting the pe_algorithm extracted from the file "
                    "'{}': {} with the algorithm provided from the command "
                    "line: {}".format(
                        self.result_files[num], _stored, _algorithm
                    )
                )
        sampler["pe_algorithm"] = _algorithm
@property
def notes(self):
    """Contents of the notes file, or None when no notes are available
    """
    return self._notes

@notes.setter
def notes(self, notes):
    """Read the notes file and store its contents

    Parameters
    ----------
    notes: str or None
        path to a text file containing the notes

    Raises
    ------
    InputError
        if a path is given which is not an existing file
    """
    self._notes = notes
    if notes is None:
        return
    if not os.path.isfile(notes):
        raise InputError(
            "No such file or directory called {}".format(notes)
        )
    try:
        with open(notes, "r") as f:
            self._notes = f.read()
    except FileNotFoundError:
        logger.warning(
            "No such file or directory called {}. Custom notes will "
            "not be added to the summarypages".format(notes)
        )
        # never expose the file path as if it were the notes themselves
        self._notes = None
    except IOError:
        logger.warning(
            "Failed to read {}. Unable to put notes on "
            "summarypages".format(notes)
        )
        self._notes = None
@property
def descriptions(self):
    """Dictionary, keyed by label, of the description of each analysis,
    or None when no descriptions are available
    """
    return self._descriptions

@descriptions.setter
def descriptions(self, descriptions):
    """Store a description for each analysis

    Parameters
    ----------
    descriptions: dict or list
        either a dictionary keyed by label, or a list whose first element
        is a dictionary keyed by label or the path to a json file holding
        such a dictionary. An empty (or None) input keeps any descriptions
        already stored and otherwise stores None. Labels without a
        description are given 'No description found', entries for unknown
        labels are dropped, and entries which are not str or bytes are
        dropped with a warning
    """
    import json
    if descriptions is None or not len(descriptions):
        if not hasattr(self, "_descriptions"):
            self._descriptions = None
        return

    if isinstance(descriptions, dict):
        data = descriptions
    else:
        descriptions = descriptions[0]
        _is_file = not isinstance(descriptions, dict)
        if hasattr(self, "_descriptions"):
            logger.warning(
                "Ignoring descriptions found in result file and using "
                "descriptions in '{}'".format(descriptions)
            )
        self._descriptions = None
        if not _is_file:
            # a dictionary passed inside a list is used directly
            data = descriptions
        elif not os.path.isfile(descriptions):
            logger.warning(
                "No such file called {}. Unable to add descriptions".format(
                    descriptions
                )
            )
            return
        else:
            try:
                with open(descriptions, "r") as f:
                    data = json.load(f)
            except json.decoder.JSONDecodeError:
                logger.warning(
                    "Unable to open file '{}'. Not storing descriptions".format(
                        descriptions
                    )
                )
                return
    not_included = [
        label for label in self.labels if label not in data.keys()
    ]
    if len(not_included):
        logger.debug(
            "No description found for '{}'. Using default "
            "description".format(", ".join(not_included))
        )
        for label in not_included:
            data[label] = "No description found"
    if len(data.keys()) > len(self.labels):
        logger.warning(
            "Descriptions file contains descriptions for analyses other "
            "than {}. Ignoring other descriptions".format(
                ", ".join(self.labels)
            )
        )
        other = [
            analysis for analysis in data.keys() if analysis not in
            self.labels
        ]
        for analysis in other:
            _ = data.pop(analysis)
    _remove = []
    for key, desc in data.items():
        if not isinstance(desc, (str, bytes)):
            logger.warning(
                "Unknown description '{}' for '{}'. The description should "
                "be a string or bytes object".format(desc, key)
            )
            _remove.append(key)
    # removal is deferred so the dictionary is not modified while iterating
    for analysis in _remove:
        _ = data.pop(analysis)
    self._descriptions = data
@property
def preferred(self):
    """Label of the preferred analysis, or None
    """
    return self._preferred

@preferred.setter
def preferred(self, preferred):
    """Store the preferred analysis and flag every label in
    `self.file_kwargs[label]["other"]["preferred"]` with 'True' or 'False'

    Parameters
    ----------
    preferred: str or None
        label of the preferred analysis. Labels not in `self.labels` are
        rejected with a warning. If None and there is exactly one label,
        that label is taken as preferred
    """
    def _flag(label, value):
        # create the 'other' entry on first use, then record the flag
        try:
            self.file_kwargs[label]["other"].update({"preferred": value})
        except KeyError:
            self.file_kwargs[label].update({"other": {"preferred": value}})

    if preferred is None:
        chosen = self.labels[0] if len(self.labels) == 1 else None
    elif preferred in self.labels:
        logger.debug(
            "Setting '{}' as the preferred analysis".format(preferred)
        )
        chosen = preferred
    else:
        logger.warning(
            "'{}' not in list of labels. Unable to stored as the "
            "preferred analysis".format(preferred)
        )
        chosen = None
    self._preferred = chosen
    if chosen is not None:
        _flag(chosen, "True")
    for _label in self.labels:
        if chosen is None or _label != chosen:
            _flag(_label, "False")
@property
def public(self):
    """Value of the `public` command line option
    """
    return self._public

@public.setter
def public(self, public):
    """Store the option and log when it differs from `conf.public`
    """
    self._public = public
    if public == conf.public:
        return
    message = conf.overwrite.format("public", conf.public, public)
    logger.info(message)
@property
def multi_process(self):
    """Number of processes to use, as an integer
    """
    return self._multi_process

@multi_process.setter
def multi_process(self, multi_process):
    """Store the number of processes as an integer

    Parameters
    ----------
    multi_process: int, str or None
        requested number of processes. When None, the configured default
        `conf.multi_process` is used
    """
    if multi_process is None:
        # int(None) would raise; fall back to the configured default
        self._multi_process = int(conf.multi_process)
        return
    self._multi_process = int(multi_process)
    if self._multi_process != int(conf.multi_process):
        logger.info(
            conf.overwrite.format(
                "multi_process", conf.multi_process, multi_process
            )
        )
@property
def publication_kwargs(self):
    """Keyword arguments provided for the publication plots
    """
    return self._publication_kwargs

@publication_kwargs.setter
def publication_kwargs(self, publication_kwargs):
    """Store the publication kwargs and warn about unsupported entries

    Parameters
    ----------
    publication_kwargs: dict or None
        keyword arguments for the publication plots. A warning is logged
        when any key is not one of the allowed kwargs. The stored value is
        never modified
    """
    self._publication_kwargs = publication_kwargs
    if not publication_kwargs:
        return
    allowed_kwargs = ["gridsize"]
    # warn as soon as one key is unsupported, not only when all of them are
    if any(i not in allowed_kwargs for i in publication_kwargs.keys()):
        logger.warning(
            "Currently the only allowed publication kwargs are {}. "
            "Ignoring other inputs.".format(
                ", ".join(allowed_kwargs)
            )
        )
@property
def ignore_parameters(self):
    """Parameter patterns that should be removed from the samples, or None
    """
    return self._ignore_parameters

@ignore_parameters.setter
def ignore_parameters(self, ignore_parameters):
    """Store the patterns and remove every matching parameter from the
    samples of each label

    Parameters
    ----------
    ignore_parameters: list or None
        patterns passed to `list_match` together with the parameter names
        of each label. Matches are popped from `self.samples[label]`
    """
    self._ignore_parameters = ignore_parameters
    if ignore_parameters is None:
        return
    for num, label in enumerate(self.labels):
        posterior = self.samples[label]
        removed_parameters = list_match(
            list(posterior.keys()), ignore_parameters
        )
        if len(removed_parameters):
            logger.warning(
                "Removing parameters: {} from {}".format(
                    ", ".join(removed_parameters),
                    self.result_files[num]
                )
            )
            for ignore in removed_parameters:
                posterior.pop(ignore)
        else:
            logger.warning(
                "Failed to remove any parameters from {}".format(
                    self.result_files[num]
                )
            )
@staticmethod
def _make_directories(webdir, dirs):
    """Create each directory in `dirs` below `webdir` unless it already
    exists

    Parameters
    ----------
    webdir: str
        base directory
    dirs: list
        directory names, relative to `webdir`
    """
    targets = (os.path.join(webdir, name) for name in dirs)
    for target in targets:
        if os.path.isdir(target):
            continue
        make_dir(target)
def make_directories(self):
    """Create every directory listed in `self.default_directories` below
    `self.webdir`
    """
    webdir, dirs = self.webdir, self.default_directories
    self._make_directories(webdir, dirs)
@staticmethod
def _copy_files(paths):
    """Copy each source file to its destination

    Parameters
    ----------
    paths: nd list
        list of [source, destination] pairs. The first element of each
        pair is the file to copy and the second is where to put it

    Examples
    --------
    >>> paths = [
    ...     ["config/config.ini", "webdir/config.ini"],
    ...     ["samples/samples.h5", "webdir/samples.h5"]
    ... ]
    """
    from shutil import copyfile

    for pair in paths:
        source, destination = pair[0], pair[1]
        copyfile(source, destination)
def copy_files(self):
    """Copy every [source, destination] pair listed in
    `self.default_files_to_copy`
    """
    pairs = self.default_files_to_copy
    self._copy_files(pairs)
def default_labels(self):
    """Return a list of default labels of the form
    '<rounded unix time>_<result file name without extension>'.

    One label is made per result file, or a single label from the first
    result file when `self.mcmc_samples` is set. Repeated labels are made
    unique by appending '_0', '_1', ... and, when adding to an existing
    webpage, labels which clash with `self.existing_labels` have their
    position in the list appended

    Raises
    ------
    InputError
        if no result files have been provided
    """
    from time import time

    def _default_label(file_name):
        return "%s_%s" % (round(time()), file_name)

    if self.result_files is None or len(self.result_files) == 0:
        raise InputError("Please provide a results file")
    if self.mcmc_samples:
        files = [self.result_files[0]]
    else:
        files = self.result_files
    label_list = []
    for path in files:
        stem = os.path.splitext(os.path.basename(path))[0]
        label_list.append(_default_label(stem))

    # count how often each label occurs before anything is renamed
    occurrences = {}
    for label in label_list:
        occurrences[label] = occurrences.get(label, 0) + 1
    for label, count in occurrences.items():
        if count < 2:
            continue
        for suffix in range(count):
            # index() always finds the first copy not yet renamed
            position = label_list.index(label)
            label_list[position] += "_%s" % (suffix)
    if self.add_to_existing:
        for num, label in enumerate(label_list):
            if label in self.existing_labels:
                position = label_list.index(label)
                label_list[position] += "_%s" % (num)
    return label_list
@staticmethod
def get_package_information():
    """Return a dictionary describing the installed packages

    Returns
    -------
    info: dict
        dictionary with keys 'packages' (a numpy record array with one
        row per package, sorted by name, each field stored as a 20 byte
        string), 'environment' (a list holding the package directory) and
        'manager' (the package manager)
    """
    from pesummary._version_helper import PackageInformation
    from operator import itemgetter

    _package = PackageInformation()
    package_info = _package.package_info
    package_dir = _package.package_dir
    # conda listings carry a build string, pip listings do not
    headings = ("name", "version")
    if "build_string" in package_info[0]:
        headings = ("name", "version", "channel", "build_string")
    rows = []
    for pkg in sorted(package_info, key=itemgetter("name")):
        rows.append(tuple(pkg[col.lower()] for col in headings))
    dtype = [(col, "S20") for col in headings]
    packages = np.array(rows, dtype=dtype).view(np.recarray)
    return {
        "packages": packages, "environment": [package_dir],
        "manager": _package.package_manager
    }
def grab_key_data_from_result_files(self):
    """Grab the key data (taken from the `key_data` attribute of each set
    of samples) for all parameters in each result file and add the
    injected value of each parameter under the key 'injected'

    Returns
    -------
    key_data: dict
        nested dictionary of the form `key_data[label][parameter]`. When
        the stored injection is a non-empty list or array, its first
        element is used as the injected value; otherwise the stored value
        is used unchanged
    """
    key_data = {
        key: samples.key_data for key, samples in self.samples.items()
    }
    for key, val in self.samples.items():
        for j in val.keys():
            _inj = self.injection_data[key][j]
            # check the type first: math.isnan raises TypeError for a list
            _is_sequence = isinstance(_inj, list) or (
                isinstance(_inj, np.ndarray) and _inj.ndim > 0
            )
            if _is_sequence and len(_inj):
                _inj = _inj[0]
            key_data[key][j]["injected"] = _inj
    return key_data
class BaseInput(_Input):
    """Class to handle and store base command line arguments

    Parameters
    ----------
    opts: object
        parsed command line arguments. The attributes read here are
        `restart_from_checkpoint`, `seed`, `samples`, `user`, `existing`,
        `webdir`, `email`, `pe_algorithm` and `multi_process`
    ignore_copy: Bool, optional
        if True, `copy_files` is not called during initialisation
    checkpoint: object, optional
        previously saved input object. When given, all of its attributes
        are copied onto this instance and no other option is processed
    gw: Bool, optional
        stored as `self.gw`
    """
    def __init__(self, opts, ignore_copy=False, checkpoint=None, gw=False):
        self.opts = opts
        self.gw = gw
        self.restart_from_checkpoint = self.opts.restart_from_checkpoint
        if checkpoint is not None:
            # restore the saved state attribute by attribute and stop:
            # nothing below is re-evaluated when resuming
            for key, item in vars(checkpoint).items():
                setattr(self, key, item)
            logger.info(
                "Loaded command line arguments: {}".format(self.opts)
            )
            self.restart_from_checkpoint = True
            self._restarted_from_checkpoint = True
            return
        self.seed = self.opts.seed
        self.result_files = self.opts.samples
        self.user = self.opts.user
        self.existing = self.opts.existing
        self.add_to_existing = False
        if self.existing is not None:
            self.add_to_existing = True
            self.existing_metafile = True
        self.webdir = self.opts.webdir
        self._restarted_from_checkpoint = False
        self.resume_file_dir = conf.checkpoint_dir(self.webdir)
        self.resume_file = conf.resume_file
        self._resume_file_path = os.path.join(
            self.resume_file_dir, self.resume_file
        )
        # the directories must exist before anything is copied or written
        self.make_directories()
        self.email = self.opts.email
        self.pe_algorithm = self.opts.pe_algorithm
        self.multi_process = self.opts.multi_process
        self.package_information = self.get_package_information()
        if not ignore_copy:
            self.copy_files()
        self.write_current_state()

    @property
    def default_directories(self):
        """List of directories to create below the web directory. A new
        list is returned on every access
        """
        return ["checkpoint"]

    @property
    def default_files_to_copy(self):
        """List of [source, destination] pairs to copy. Empty for the base
        class; a new list is returned on every access
        """
        return []

    def write_current_state(self):
        """Write the current state of the input class to file

        The instance is written with `pesummary.io.write` in pickle format
        to `self.resume_file` inside `self.resume_file_dir`, overwriting
        any existing file
        """
        from pesummary.io import write
        write(
            self, outdir=self.resume_file_dir, file_format="pickle",
            filename=self.resume_file, overwrite=True
        )
        logger.debug(
            "Written checkpoint file: {}".format(self._resume_file_path)
        )
class SamplesInput(BaseInput):
    """Class to handle and store sample specific command line arguments

    Parameters
    ----------
    *args: tuple
        all args passed to BaseInput
    extra_options: list, optional
        names of additional attributes of `opts` to copy onto the instance
    **kwargs: dict
        all kwargs passed to BaseInput
    """
    def __init__(self, *args, extra_options=None, **kwargs):
        """Process the sample specific options. The assignments below go
        through property setters where they exist, so their order matters
        """
        super(SamplesInput, self).__init__(*args, **kwargs)
        if self.result_files is not None:
            # one entry per result file; filled in by grab_data_from_input
            self._open_result_files = {path: None for path in self.result_files}
        self.meta_file = False
        if self.result_files is not None and len(self.result_files) == 1:
            self.meta_file = self.is_pesummary_metafile(self.result_files[0])
        self.compare_results = self.opts.compare_results
        self.disable_injection = self.opts.disable_injection
        if self.existing is not None:
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_samples = self.existing_data["samples"]
            self.existing_injection_data = self.existing_data["injection_data"]
            self.existing_file_version = self.existing_data["file_version"]
            self.existing_file_kwargs = self.existing_data["file_kwargs"]
            self.existing_priors = self.existing_data["prior"]
            self.existing_config = self.existing_data["config"]
            self.existing_labels = self.existing_data["labels"]
            self.existing_weights = self.existing_data["weights"]
        else:
            self.existing_metafile = None
            self.existing_labels = None
            self.existing_weights = None
            self.existing_samples = None
            self.existing_file_version = None
            self.existing_file_kwargs = None
            self.existing_priors = None
            self.existing_config = None
            self.existing_injection_data = None
        self.mcmc_samples = self.opts.mcmc_samples
        self.labels = self.opts.labels
        self.weights = {i: None for i in self.labels}
        self.config = self.opts.config
        self.injection_file = self.opts.inj_file
        self.regenerate = self.opts.regenerate
        if extra_options is not None:
            for opt in extra_options:
                setattr(self, opt, getattr(self.opts, opt))
        self.nsamples_for_prior = self.opts.nsamples_for_prior
        self.priors = self.opts.prior_file
        self.disable_prior_sampling = self.opts.disable_prior_sampling
        self.path_to_samples = self.opts.path_to_samples
        self.file_format = self.opts.file_format
        self.nsamples = self.opts.nsamples
        self.keep_nan_likelihood_samples = self.opts.keep_nan_likelihood_samples
        self.reweight_samples = self.opts.reweight_samples
        # NOTE(review): presumably a `samples` setter defined outside this
        # view turns the result file paths into samples per label -- confirm
        self.samples = self.opts.samples
        self.ignore_parameters = self.opts.ignore_parameters
        self.burnin_method = self.opts.burnin_method
        self.burnin = self.opts.burnin
        # the value assigned here is ignored by the setter (see below)
        self.same_parameters = []
        if self.mcmc_samples:
            self._samples = {label: self.samples.T for label in self.labels}
        self.write_current_state()

    @property
    def analytic_prior_dict(self):
        """Dictionary, keyed by label, of the analytic priors formatted as
        one 'name = value' line per entry. The value is None for labels
        without analytic priors
        """
        return {
            label: "\n".join(
                [
                    "{} = {}".format(key, value) for key, value in
                    self.priors["analytic"][label].items()
                ]
            ) if "analytic" in self.priors.keys() and label in
            self.priors["analytic"].keys() else None for label in self.labels
        }

    @property
    def same_parameters(self):
        """List of parameters common to the samples of every label
        """
        return self._same_parameters

    @same_parameters.setter
    def same_parameters(self, same_parameters):
        """Recompute the common parameters from `self.samples`. The value
        passed in is not used
        """
        self._same_parameters = self.intersect_samples_dict(self.samples)

    def intersect_samples_dict(self, samples):
        """Return the parameters which are present for every key of
        `samples`

        Parameters
        ----------
        samples: dict
            dictionary of dictionary-like objects. Must not be empty

        Returns
        -------
        params: list
            the common parameter names, in no particular order
        """
        parameters = [
            list(samples[key].keys()) for key in samples.keys()
        ]
        params = list(set.intersection(*[set(l) for l in parameters]))
        return params
class PlottingInput(SamplesInput):
    """Class to handle and store plotting specific command line arguments

    Parameters
    ----------
    *args: tuple
        all args passed to SamplesInput
    **kwargs: dict
        all kwargs passed to SamplesInput
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        self.style_file = self.opts.style_file
        self.publication = self.opts.publication
        self.publication_kwargs = self.opts.publication_kwargs
        self.kde_plot = self.opts.kde_plot
        self.custom_plotting = self.opts.custom_plotting
        self.add_to_corner = self.opts.add_to_corner
        # corner_params reuses the value already checked by add_to_corner
        self.corner_params = self.add_to_corner
        self.palette = self.opts.palette
        self.include_prior = self.opts.include_prior
        # colors must be set before linestyles, whose setter reads them
        self.colors = self.opts.colors
        self.linestyles = self.opts.linestyles
        self.disable_corner = self.opts.disable_corner
        self.disable_comparison = self.opts.disable_comparison
        self.disable_interactive = self.opts.disable_interactive
        self.disable_expert = not self.opts.enable_expert
        self.multi_threading_for_plots = self.multi_process
        self.write_current_state()

    @property
    def default_directories(self):
        """Directories of the parent class plus the plotting directories
        """
        dirs = super(PlottingInput, self).default_directories
        dirs += ["plots", "plots/corner", "plots/publication", "samples"]
        return dirs
class WebpageInput(SamplesInput):
    """Class to handle and store webpage specific command line arguments

    Parameters
    ----------
    *args: tuple
        all args passed to SamplesInput
    **kwargs: dict
        all kwargs passed to SamplesInput
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        self.baseurl = self.opts.baseurl
        self.existing_plot = self.opts.existing_plot
        self.pe_algorithm = self.opts.pe_algorithm
        self.notes = self.opts.notes
        self.dump = self.opts.dump
        # hdf5 must be set before external_hdf5_links, whose setter reads it
        self.hdf5 = not self.opts.save_to_json
        self.external_hdf5_links = self.opts.external_hdf5_links
        self.file_kwargs["webpage_url"] = self.baseurl + "/home.html"
        self.write_current_state()

    @property
    def default_directories(self):
        """Directories of the parent class plus the webpage directories
        """
        dirs = super(WebpageInput, self).default_directories
        dirs += ["js", "html", "css"]
        return dirs

    @property
    def default_files_to_copy(self):
        """Files of the parent class plus every `*.js` and `*.css` file
        shipped in the `js` and `css` directories of `pesummary.core`, each
        paired with its destination below `self.webdir`
        """
        from pesummary import core
        files_to_copy = super(WebpageInput, self).default_files_to_copy
        path = core.__path__[0]
        scripts = glob(os.path.join(path, "js", "*.js"))
        for i in scripts:
            files_to_copy.append(
                [i, os.path.join(self.webdir, "js", os.path.basename(i))]
            )
        scripts = glob(os.path.join(path, "css", "*.css"))
        for i in scripts:
            files_to_copy.append(
                [i, os.path.join(self.webdir, "css", os.path.basename(i))]
            )
        return files_to_copy
class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Store the command line options needed for both the plotting and the
    webpage stages

    Combines ``PlottingInput`` and ``WebpageInput`` through multiple
    inheritance and calls ``copy_files`` once initialisation has finished.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.copy_files()

    @property
    def default_directories(self):
        """Directories as resolved through the parent classes, unchanged
        """
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy as resolved through the parent classes, unchanged
        """
        return super().default_files_to_copy
class MetaFileInput(SamplesInput):
    """Store the command line options that are specific to metafile creation
    """
    def __init__(self, *args, **kwargs):
        # always hand ignore_copy=True to the parent initialiser; the files
        # are copied explicitly below once the parent has finished
        kwargs["ignore_copy"] = True
        super().__init__(*args, **kwargs)
        self.copy_files()
        self.filename = self.opts.filename
        # hdf5 output is used unless the save_to_json option is set
        self.hdf5 = not self.opts.save_to_json
        self.hdf5_compression = self.opts.hdf5_compression
        self.external_hdf5_links = self.opts.external_hdf5_links
        self.descriptions = self.opts.descriptions
        self.preferred = self.opts.preferred
        # defined outside this class; presumably checkpoints the inputs
        self.write_current_state()

    @property
    def default_directories(self):
        """Directories reported by the parent class plus ``samples`` and
        ``config``
        """
        directories = super().default_directories
        directories += ["samples", "config"]
        return directories

    @property
    def default_files_to_copy(self):
        """Files to copy into the web directory

        Returns
        -------
        copies: list
            the ``[source, destination]`` pairs reported by the parent class,
            followed by one pair per configuration file that is not already
            inside ``self.webdir`` and one pair per result file
        """
        copies = super().default_files_to_copy
        if any(config is not None for config in self.config):
            for idx, config in enumerate(self.config):
                # skip missing configs and those whose path already contains
                # the web directory
                if config is None or self.webdir in config:
                    continue
                target = "_".join([self.labels[idx], "config.ini"])
                copies.append(
                    [config, os.path.join(self.webdir, "config", target)]
                )
        for idx, result_file in enumerate(self.result_files):
            # chains are named by their index when mcmc_samples is set,
            # otherwise the result file is prefixed with its label
            target = (
                "chain_{}_{}".format(idx, Path(result_file).name)
                if self.mcmc_samples
                else "{}_{}".format(self.labels[idx], Path(result_file).name)
            )
            copies.append(
                [result_file, os.path.join(self.webdir, "samples", target)]
            )
        return copies
class WebpagePlusPlottingPlusMetaFileInput(
    MetaFileInput, WebpagePlusPlottingInput
):
    """Store the command line options needed for the webpage, plotting and
    metafile stages together

    Combines ``MetaFileInput`` and ``WebpagePlusPlottingInput`` through
    multiple inheritance without adding any behaviour of its own.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Directories as resolved through the parent classes, unchanged
        """
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy as resolved through the parent classes, unchanged
        """
        return super().default_files_to_copy
@deprecation(
    "The Input class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class Input(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated name for ``WebpagePlusPlottingPlusMetaFileInput``

    Adds nothing to its parent class. The ``deprecation`` decorator is given
    the message pointing users at the replacement classes.
    """
def load_current_state(resume_file):
    """Read checkpoint information back from a resume file

    Parameters
    ----------
    resume_file: str
        path to a checkpoint file

    Returns
    -------
    state:
        the object returned by ``pesummary.io.read`` when called on
        ``resume_file`` with ``checkpoint=True``, or ``None`` when
        ``resume_file`` is not an existing file
    """
    from pesummary.io import read

    if os.path.isfile(resume_file):
        logger.info("Reading checkpoint file: {}".format(resume_file))
        return read(resume_file, checkpoint=True)
    # nothing to resume from: report it and let the caller start afresh
    logger.info(
        "Unable to find resume file. Not restarting from checkpoint"
    )
    return None