Coverage for pesummary/core/cli/inputs.py: 80.1%
1264 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import socket
5from glob import glob
6from pathlib import Path
7from getpass import getuser
9import math
10import numpy as np
11from pesummary.core.file.read import read as Read
12from pesummary.utils.exceptions import InputError
13from pesummary.utils.decorators import deprecation
14from pesummary.utils.samples_dict import SamplesDict, MCMCSamplesDict
15from pesummary.utils.utils import (
16 guess_url, logger, make_dir, make_cache_style_file, list_match
17)
18from pesummary import conf
20__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
23class _Input(object):
24 """Super class to handle the command line arguments
25 """
26 @staticmethod
27 def is_pesummary_metafile(proposed_file):
28 """Determine if a file is a PESummary metafile or not
30 Parameters
31 ----------
32 proposed_file: str
33 path to the file
34 """
35 extension = proposed_file.split(".")[-1]
36 if extension == "h5" or extension == "hdf5" or extension == "hdf":
37 from pesummary.core.file.read import (
38 is_pesummary_hdf5_file, is_pesummary_hdf5_file_deprecated
39 )
41 result = any(
42 func(proposed_file) for func in [
43 is_pesummary_hdf5_file,
44 is_pesummary_hdf5_file_deprecated
45 ]
46 )
47 return result
48 elif extension == "json":
49 from pesummary.core.file.read import (
50 is_pesummary_json_file, is_pesummary_json_file_deprecated
51 )
53 result = any(
54 func(proposed_file) for func in [
55 is_pesummary_json_file,
56 is_pesummary_json_file_deprecated
57 ]
58 )
59 return result
60 else:
61 return False
63 @staticmethod
64 def grab_data_from_metafile(
65 existing_file, webdir, compare=None, read_function=Read,
66 _replace_with_pesummary_kwargs={}, nsamples=None,
67 disable_injection=False, keep_nan_likelihood_samples=False,
68 reweight_samples=False, **kwargs
69 ):
70 """Grab data from an existing PESummary metafile
72 Parameters
73 ----------
74 existing_file: str
75 path to the existing metafile
76 webdir: str
77 the directory to store the existing configuration file
78 compare: list, optional
79 list of labels for events stored in an existing metafile that you
80 wish to compare
81 read_function: func, optional
82 PESummary function to use to read in the existing file
83 _replace_with_pesummary_kwargs: dict, optional
84 dictionary of kwargs that you wish to replace with the data stored
85 in the PESummary file
86 nsamples: int, optional
87 Number of samples to use. Default all available samples
88 kwargs: dict
89 All kwargs are passed to the `generate_all_posterior_samples`
90 method
91 """
92 f = read_function(
93 existing_file,
94 remove_nan_likelihood_samples=not keep_nan_likelihood_samples
95 )
96 for ind, label in enumerate(f.labels):
97 kwargs[label] = kwargs.copy()
98 for key, item in _replace_with_pesummary_kwargs.items():
99 try:
100 kwargs[label][key] = eval(
101 item.format(file="f", ind=ind, label=label)
102 )
103 except TypeError:
104 _item = item.split("['{label}']")[0]
105 kwargs[label][key] = eval(
106 _item.format(file="f", ind=ind, label=label)
107 )
108 except (AttributeError, KeyError, NameError):
109 pass
111 if not f.mcmc_samples:
112 labels = f.labels
113 else:
114 labels = list(f.samples_dict.keys())
115 indicies = np.arange(len(labels))
117 if compare:
118 indicies = []
119 for i in compare:
120 if i not in labels:
121 raise InputError(
122 "Label '%s' does not exist in the metafile. The list "
123 "of available labels are %s" % (i, labels)
124 )
125 indicies.append(labels.index(i))
126 labels = compare
128 if nsamples is not None:
129 f.downsample(nsamples, labels=labels)
130 if not f.mcmc_samples:
131 f.generate_all_posterior_samples(labels=labels, **kwargs)
132 if reweight_samples:
133 f.reweight_samples(reweight_samples, labels=labels, **kwargs)
135 parameters = f.parameters
136 if not f.mcmc_samples:
137 samples = [np.array(i).T for i in f.samples]
138 DataFrame = {
139 label: SamplesDict(parameters[ind], samples[ind])
140 for label, ind in zip(labels, indicies)
141 }
142 _parameters = lambda label: DataFrame[label].keys()
143 else:
144 DataFrame = {
145 f.labels[0]: MCMCSamplesDict(
146 {
147 label: f.samples_dict[label] for label in labels
148 }
149 )
150 }
151 labels = f.labels
152 indicies = np.arange(len(labels))
153 _parameters = lambda label: DataFrame[f.labels[0]].parameters
154 if not disable_injection and f.injection_parameters != []:
155 inj_values = f.injection_dict
156 for label in labels:
157 for param in DataFrame[label].keys():
158 if param not in f.injection_dict[label].keys():
159 f.injection_dict[label][param] = float("nan")
160 else:
161 inj_values = {
162 i: {
163 param: float("nan") for param in DataFrame[i].parameters
164 } for i in labels
165 }
166 for i in inj_values.keys():
167 for param in inj_values[i].keys():
168 if inj_values[i][param] == "nan":
169 inj_values[i][param] = float("nan")
170 if isinstance(inj_values[i][param], bytes):
171 inj_values[i][param] = inj_values[i][param].decode("utf-8")
173 if hasattr(f, "priors") and f.priors is not None and f.priors != {}:
174 priors = f.priors
175 else:
176 priors = {label: {} for label in labels}
178 config = []
179 if f.config is not None and not all(i is None for i in f.config):
180 config = []
181 for i in labels:
182 config_dir = os.path.join(webdir, "config")
183 filename = f.write_config_to_file(
184 i, outdir=config_dir, _raise=False,
185 filename="{}_config.ini".format(i)
186 )
187 _config = os.path.join(config_dir, filename)
188 if filename is not None and os.path.isfile(_config):
189 config.append(_config)
190 else:
191 config.append(None)
192 else:
193 for i in labels:
194 config.append(None)
196 if f.weights is not None:
197 weights = {i: f.weights[i] for i in labels}
198 else:
199 weights = {i: None for i in labels}
201 return {
202 "samples": DataFrame,
203 "injection_data": inj_values,
204 "file_version": {
205 i: j for i, j in zip(
206 labels, [f.input_version[ind] for ind in indicies]
207 )
208 },
209 "file_kwargs": {
210 i: j for i, j in zip(
211 labels, [f.extra_kwargs[ind] for ind in indicies]
212 )
213 },
214 "prior": priors,
215 "config": config,
216 "labels": labels,
217 "weights": weights,
218 "indicies": indicies,
219 "mcmc_samples": f.mcmc_samples,
220 "open_file": f,
221 "descriptions": f.description
222 }
224 @staticmethod
225 def grab_data_from_file(
226 file, label, webdir, config=None, injection=None, read_function=Read,
227 file_format=None, nsamples=None, disable_prior_sampling=False,
228 nsamples_for_prior=None, path_to_samples=None,
229 keep_nan_likelihood_samples=False, reweight_samples=False,
230 **kwargs
231 ):
232 """Grab data from a result file containing posterior samples
234 Parameters
235 ----------
236 file: str
237 path to the result file
238 label: str
239 label that you wish to use for the result file
240 config: str, optional
241 path to a configuration file used in the analysis
242 injection: str, optional
243 path to an injection file used in the analysis
244 read_function: func, optional
245 PESummary function to use to read in the file
246 file_format, str, optional
247 the file format you wish to use when loading. Default None.
248 If None, the read function loops through all possible options
249 kwargs: dict
250 Dictionary of keyword arguments fed to the
251 `generate_all_posterior_samples` method
252 """
253 f = read_function(
254 file, file_format=file_format, disable_prior=disable_prior_sampling,
255 nsamples_for_prior=nsamples_for_prior, path_to_samples=path_to_samples,
256 remove_nan_likelihood_samples=not keep_nan_likelihood_samples
257 )
258 if config is not None:
259 f.add_fixed_parameters_from_config_file(config)
261 if nsamples is not None:
262 f.downsample(nsamples)
263 f.generate_all_posterior_samples(**kwargs)
264 if injection:
265 f.add_injection_parameters_from_file(
266 injection, conversion_kwargs=kwargs
267 )
268 if reweight_samples:
269 f.reweight_samples(reweight_samples)
270 parameters = f.parameters
271 samples = np.array(f.samples).T
272 DataFrame = {label: SamplesDict(parameters, samples)}
273 kwargs = f.extra_kwargs
274 if hasattr(f, "injection_parameters"):
275 injection = f.injection_parameters
276 if injection is not None:
277 for i in parameters:
278 if i not in list(injection.keys()):
279 injection[i] = float("nan")
280 else:
281 injection = {i: j for i, j in zip(
282 parameters, [float("nan")] * len(parameters))}
283 else:
284 injection = {i: j for i, j in zip(
285 parameters, [float("nan")] * len(parameters))}
286 version = f.input_version
287 if hasattr(f, "priors") and f.priors is not None:
288 priors = {key: {label: item} for key, item in f.priors.items()}
289 else:
290 priors = {label: []}
291 if hasattr(f, "weights") and f.weights is not None:
292 weights = f.weights
293 else:
294 weights = None
295 data = {
296 "samples": DataFrame,
297 "injection_data": {label: injection},
298 "file_version": {label: version},
299 "file_kwargs": {label: kwargs},
300 "prior": priors,
301 "weights": {label: weights},
302 "open_file": f,
303 "descriptions": {label: f.description}
304 }
305 if hasattr(f, "config") and f.config is not None:
306 if config is None:
307 config_dir = os.path.join(webdir, "config")
308 filename = "{}_config.ini".format(label)
309 logger.debug(
310 "Successfully extracted config data from the provided "
311 "input file. Saving the data to the file '{}'".format(
312 os.path.join(config_dir, filename)
313 )
314 )
315 _filename = f.write(
316 filename=filename, outdir=config_dir, file_format="ini",
317 _raise=False
318 )
319 data["config"] = _filename
320 else:
321 logger.info(
322 "Ignoring config data extracted from the input file and "
323 "using the config file provided"
324 )
325 return data
327 @property
328 def result_files(self):
329 return self._result_files
331 @result_files.setter
332 def result_files(self, result_files):
333 self._result_files = result_files
334 if self._result_files is not None:
335 for num, ff in enumerate(self._result_files):
336 func = None
337 if not os.path.isfile(ff) and "@" in ff:
338 from pesummary.io.read import _fetch_from_remote_server
339 func = _fetch_from_remote_server
340 elif not os.path.isfile(ff) and "https://" in ff:
341 from pesummary.io.read import _fetch_from_url
342 func = _fetch_from_url
343 elif not os.path.isfile(ff) and "*" in ff:
344 from pesummary.utils.utils import glob_directory
345 func = glob_directory
346 if func is not None:
347 _data = func(ff)
348 if isinstance(_data, (np.ndarray, list)) and len(_data) > 0:
349 self._result_files[num] = _data[0]
350 if len(_data) > 1:
351 _ = [
352 self._result_files.insert(num + 1, d) for d in
353 _data[1:][::-1]
354 ]
355 elif isinstance(_data, np.ndarray):
356 raise InputError(
357 "Unable to find any files matching '{}'".format(ff)
358 )
359 else:
360 self._result_files[num] = _data
362 @property
363 def seed(self):
364 return self._seed
366 @seed.setter
367 def seed(self, seed):
368 np.random.seed(seed)
369 self._seed = seed
371 @property
372 def existing(self):
373 return self._existing
375 @existing.setter
376 def existing(self, existing):
377 self._existing = existing
378 if existing is not None:
379 self._existing = os.path.abspath(existing)
381 @property
382 def existing_metafile(self):
383 return self._existing_metafile
385 @existing_metafile.setter
386 def existing_metafile(self, existing_metafile):
387 from glob import glob
389 self._existing_metafile = existing_metafile
390 if self._existing_metafile is None:
391 return
392 if not os.path.isdir(os.path.join(self.existing, "samples")):
393 raise InputError("Please provide a valid existing directory")
394 _dir = os.path.join(self.existing, "samples")
395 files = glob(os.path.join(_dir, "posterior_samples*"))
396 dir_content = glob(os.path.join(_dir, "*.h5"))
397 dir_content.extend(glob(os.path.join(_dir, "*.json")))
398 dir_content.extend(glob(os.path.join(_dir, "*.hdf5")))
399 if len(files) == 0 and len(dir_content):
400 files = dir_content
401 logger.warning(
402 "Unable to find a 'posterior_samples*' file in the existing "
403 "directory. Using '{}' as the existing metafile".format(
404 dir_content[0]
405 )
406 )
407 elif len(files) == 0:
408 raise InputError(
409 "Unable to find an existing metafile in the existing webdir"
410 )
411 elif len(files) > 1:
412 raise InputError(
413 "Multiple metafiles in the existing directory. Please either "
414 "run the `summarycombine_metafile` executable to combine the "
415 "meta files or simple remove the unwanted meta file"
416 )
417 self._existing_metafile = os.path.join(
418 self.existing, "samples", files[0]
419 )
421 @property
422 def style_file(self):
423 return self._style_file
425 @style_file.setter
426 def style_file(self, style_file):
427 default = conf.style_file
428 if style_file is not None and not os.path.isfile(style_file):
429 logger.warning(
430 "The file '{}' does not exist. Resorting to default".format(
431 style_file
432 )
433 )
434 style_file = default
435 elif style_file is not None and os.path.isfile(style_file):
436 logger.info(
437 "Using the file '{}' as the matplotlib style file".format(
438 style_file
439 )
440 )
441 elif style_file is None:
442 logger.debug(
443 "Using the default matplotlib style file"
444 )
445 style_file = default
446 make_cache_style_file(style_file)
447 self._style_file = style_file
449 @property
450 def filename(self):
451 return self._filename
453 @filename.setter
454 def filename(self, filename):
455 self._filename = filename
456 if filename is not None:
457 if "/" in filename:
458 logger.warning("")
459 filename = filename.split("/")[-1]
460 if os.path.isfile(os.path.join(self.webdir, "samples", filename)):
461 logger.warning(
462 "A file with filename '{}' already exists in the samples "
463 "directory '{}'. This will be overwritten"
464 )
466 @property
467 def user(self):
468 return self._user
470 @user.setter
471 def user(self, user):
472 try:
473 self._user = getuser()
474 logger.info(
475 conf.overwrite.format("user", conf.user, self._user)
476 )
477 except KeyError as e:
478 logger.info(
479 "Failed to grab user information because {}. Default will be "
480 "used".format(e)
481 )
482 self._user = user
484 @property
485 def host(self):
486 return socket.getfqdn()
488 @property
489 def webdir(self):
490 return self._webdir
492 @webdir.setter
493 def webdir(self, webdir):
494 cond1 = webdir is None or webdir == "None" or webdir == "none"
495 cond2 = (
496 self.existing is None or self.existing == "None"
497 or self.existing == "none"
498 )
499 if cond1 and cond2:
500 raise InputError(
501 "Please provide a web directory to store the webpages. If "
502 "you wish to add to an existing webpage, then pass the "
503 "existing web directory under the '--existing_webdir' command "
504 "line argument. If this is a new set of webpages, then pass "
505 "the web directory under the '--webdir' argument"
506 )
507 elif webdir is None and self.existing is not None:
508 if not os.path.isdir(self.existing):
509 raise InputError(
510 "The directory {} does not exist".format(self.existing)
511 )
512 entries = glob(self.existing + "/*")
513 if os.path.join(self.existing, "home.html") not in entries:
514 raise InputError(
515 "Please give the base directory of an existing output"
516 )
517 self._webdir = self.existing
518 else:
519 if not os.path.isdir(webdir):
520 logger.debug(
521 "Given web directory does not exist. Creating it now"
522 )
523 make_dir(webdir)
524 self._webdir = os.path.abspath(webdir)
526 @property
527 def baseurl(self):
528 return self._baseurl
530 @baseurl.setter
531 def baseurl(self, baseurl):
532 self._baseurl = baseurl
533 if baseurl is None:
534 self._baseurl = guess_url(self.webdir, self.host, self.user)
536 @property
537 def mcmc_samples(self):
538 return self._mcmc_samples
540 @mcmc_samples.setter
541 def mcmc_samples(self, mcmc_samples):
542 self._mcmc_samples = mcmc_samples
543 if self._mcmc_samples:
544 logger.info(
545 "Treating all samples as seperate mcmc chains for the same "
546 "analysis."
547 )
549 @property
550 def labels(self):
551 return self._labels
553 @labels.setter
554 def labels(self, labels):
555 if self.result_files is not None:
556 if any(self.is_pesummary_metafile(s) for s in self.result_files):
557 logger.warning(
558 "labels argument is ignored when a pesummary metafile is "
559 "input. Stored analyses will use their stored labels. If "
560 "you wish to change the labels, please use `summarymodify`"
561 )
562 labels = self.default_labels()
563 if not hasattr(self, "._labels"):
564 if labels is None:
565 labels = self.default_labels()
566 elif self.mcmc_samples and len(labels) != 1:
567 raise InputError(
568 "Please provide a single label that corresponds to all "
569 "mcmc samples"
570 )
571 elif len(np.unique(labels)) != len(labels):
572 raise InputError(
573 "Please provide unique labels for each result file"
574 )
575 for num, i in enumerate(labels):
576 if "." in i:
577 logger.warning(
578 "Replacing the label {} by {} to make it compatible "
579 "with the html pages".format(i, i.replace(".", "_"))
580 )
581 labels[num] = i.replace(".", "_")
582 if self.add_to_existing:
583 for i in labels:
584 if i in self.existing_labels:
585 raise InputError(
586 "The label '%s' already exists in the existing "
587 "metafile. Please pass another unique label"
588 )
590 if len(self.result_files) != len(labels) and not self.mcmc_samples:
591 import copy
592 _new_labels = copy.deepcopy(labels)
593 idx = 1
594 while len(_new_labels) < len(self.result_files):
595 _new_labels.extend(
596 [_label + str(idx) for _label in labels]
597 )
598 idx += 1
599 _new_labels = _new_labels[:len(self.result_files)]
600 logger.info(
601 "You have passed {} result files and {} labels. Setting "
602 "labels = {}".format(
603 len(self.result_files), len(labels), _new_labels
604 )
605 )
606 labels = _new_labels
607 self._labels = labels
609 @property
610 def config(self):
611 return self._config
613 @config.setter
614 def config(self, config):
615 if config and len(config) != len(self.labels):
616 raise InputError(
617 "Please provide a configuration file for each label"
618 )
619 if config is None and not self.meta_file:
620 self._config = [None] * len(self.labels)
621 elif self.meta_file:
622 self._config = [None] * len(self.labels)
623 else:
624 self._config = config
625 for num, ff in enumerate(self._config):
626 if isinstance(ff, str) and ff.lower() == "none":
627 self._config[num] = None
629 @property
630 def injection_file(self):
631 return self._injection_file
633 @injection_file.setter
634 def injection_file(self, injection_file):
635 if injection_file and len(injection_file) != len(self.labels):
636 if len(injection_file) == 1:
637 logger.info(
638 "Only one injection file passed. Assuming the same "
639 "injection for all {} result files".format(len(self.labels))
640 )
641 else:
642 raise InputError(
643 "You have passed {} for {} result files. Please provide an "
644 "injection file for each result file".format(
645 len(self.injection_file), len(self.labels)
646 )
647 )
648 if injection_file is None:
649 injection_file = [None] * len(self.labels)
650 self._injection_file = injection_file
652 @property
653 def injection_data(self):
654 return self._injection_data
656 @property
657 def file_version(self):
658 return self._file_version
660 @property
661 def file_kwargs(self):
662 return self._file_kwargs
664 @property
665 def kde_plot(self):
666 return self._kde_plot
668 @kde_plot.setter
669 def kde_plot(self, kde_plot):
670 self._kde_plot = kde_plot
671 if kde_plot != conf.kde_plot:
672 logger.info(
673 conf.overwrite.format("kde_plot", conf.kde_plot, kde_plot)
674 )
676 @property
677 def file_format(self):
678 return self._file_format
680 @file_format.setter
681 def file_format(self, file_format):
682 if file_format is None:
683 self._file_format = [None] * len(self.labels)
684 elif len(file_format) == 1 and len(file_format) != len(self.labels):
685 logger.warning(
686 "Only one file format specified. Assuming all files are of "
687 "this format"
688 )
689 self._file_format = [file_format[0]] * len(self.labels)
690 elif len(file_format) != len(self.labels):
691 raise InputError(
692 "Please provide a file format for each result file. If you "
693 "wish to specify the file format for the second result file "
694 "and not for any of the others, for example, simply pass 'None "
695 "{format} None'"
696 )
697 else:
698 for num, ff in enumerate(file_format):
699 if ff.lower() == "none":
700 file_format[num] = None
701 self._file_format = file_format
703 @property
704 def samples(self):
705 return self._samples
707 @samples.setter
708 def samples(self, samples):
709 if isinstance(samples, dict):
710 return samples
711 self._set_samples(samples)
713 def _set_samples(
714 self, samples,
715 ignore_keys=["prior", "weights", "labels", "indicies", "open_file"]
716 ):
717 """Extract the samples and store them as attributes of self
719 Parameters
720 ----------
721 samples: list
722 A list containing the paths to result files
723 ignore_keys: list, optional
724 A list containing properties of the read file that you do not want to be
725 stored as attributes of self
726 """
727 if not samples:
728 raise InputError("Please provide a results file")
729 _samples_generator = (self.is_pesummary_metafile(s) for s in samples)
730 if any(_samples_generator) and not all(_samples_generator):
731 raise InputError(
732 "It seems that you have passed a combination of pesummary "
733 "metafiles and non-pesummary metafiles. This is currently "
734 "not supported."
735 )
736 labels, labels_dict = None, {}
737 weights_dict = {}
738 if self.mcmc_samples:
739 nsamples = 0.
740 for num, i in enumerate(samples):
741 idx = num
742 if not self.mcmc_samples:
743 if not self.is_pesummary_metafile(samples[num]):
744 logger.info("Assigning {} to {}".format(self.labels[num], i))
745 else:
746 num = 0
747 if not os.path.isfile(i):
748 raise InputError("File %s does not exist" % (i))
749 if self.is_pesummary_metafile(samples[num]):
750 data = self.grab_data_from_input(
751 i, self.labels[num], config=None, injection=None
752 )
753 self.mcmc_samples = data["mcmc_samples"]
754 else:
755 data = self.grab_data_from_input(
756 i, self.labels[num], config=self.config[num],
757 injection=self.injection_file[num],
758 file_format=self.file_format[num]
759 )
760 if "config" in data.keys():
761 msg = (
762 "Overwriting the provided config file for '{}' with "
763 "the config information stored in the input "
764 "file".format(self.labels[num])
765 )
766 if self.config[num] is None:
767 logger.debug(msg)
768 else:
769 logger.info(msg)
770 self.config[num] = data.pop("config")
771 if self.mcmc_samples:
772 data["samples"] = {
773 "{}_mcmc_chain_{}".format(key, idx): item for key, item
774 in data["samples"].items()
775 }
776 for key, item in data.items():
777 if key not in ignore_keys:
778 if idx == 0:
779 setattr(self, "_{}".format(key), item)
780 else:
781 x = getattr(self, "_{}".format(key))
782 if isinstance(x, dict):
783 x.update(item)
784 elif isinstance(x, list):
785 x += item
786 setattr(self, "_{}".format(key), x)
787 if self.mcmc_samples:
788 try:
789 nsamples += data["file_kwargs"][self.labels[num]]["sampler"][
790 "nsamples"
791 ]
792 except UnboundLocalError:
793 pass
794 if "labels" in data.keys():
795 stored_labels = data["labels"]
796 else:
797 stored_labels = [self.labels[num]]
798 if "weights" in data.keys():
799 weights_dict.update(data["weights"])
800 if "prior" in data.keys():
801 for label in stored_labels:
802 pp = data["prior"]
803 if pp != {} and label in pp.keys() and pp[label] == []:
804 if len(self.priors):
805 if label not in self.priors["samples"].keys():
806 self.add_to_prior_dict(
807 "samples/{}".format(label), []
808 )
809 else:
810 self.add_to_prior_dict(
811 "samples/{}".format(label), []
812 )
813 elif pp != {} and label not in pp.keys():
814 for key in pp.keys():
815 if key in self.priors.keys():
816 if label in self.priors[key].keys():
817 logger.warning(
818 "Replacing the prior file for {} "
819 "with the prior file stored in "
820 "the result file".format(
821 label
822 )
823 )
824 if pp[key] == {}:
825 self.add_to_prior_dict(
826 "{}/{}".format(key, label), []
827 )
828 elif label not in pp[key].keys():
829 self.add_to_prior_dict(
830 "{}/{}".format(key, label), {}
831 )
832 else:
833 self.add_to_prior_dict(
834 "{}/{}".format(key, label), pp[key][label]
835 )
836 else:
837 self.add_to_prior_dict(
838 "samples/{}".format(label), []
839 )
840 if "labels" in data.keys():
841 _duplicated = [
842 _ for _ in data["labels"] if num != 0 and _ in labels
843 ]
844 if num == 0:
845 labels = data["labels"]
846 elif len(_duplicated):
847 raise InputError(
848 "The labels stored in the supplied files are not "
849 "unique. The label{}: '{}' appear{} in two or more "
850 "files. Please provide unique labels for each "
851 "analysis.".format(
852 "s" if len(_duplicated) > 1 else "",
853 ", ".join(_duplicated),
854 "" if len(_duplicated) > 1 else "s"
855 )
856 )
857 else:
858 labels += data["labels"]
859 labels_dict[num] = data["labels"]
860 if self.mcmc_samples:
861 try:
862 self.file_kwargs[self.labels[0]]["sampler"].update(
863 {"nsamples": nsamples, "nchains": len(self.result_files)}
864 )
865 except (KeyError, UnboundLocalError):
866 pass
867 _labels = list(self._samples.keys())
868 if not isinstance(self._samples[_labels[0]], MCMCSamplesDict):
869 self._samples = MCMCSamplesDict(self._samples)
870 else:
871 self._samples = self._samples[_labels[0]]
872 if labels is not None:
873 self._labels = labels
874 if len(labels) != len(self.result_files):
875 result_files = []
876 for num, f in enumerate(samples):
877 for ii in np.arange(len(labels_dict[num])):
878 result_files.append(self.result_files[num])
879 self.result_files = result_files
880 self.weights = {i: None for i in self.labels}
881 if weights_dict != {}:
882 self.weights = weights_dict
884 @property
885 def burnin_method(self):
886 return self._burnin_method
888 @burnin_method.setter
889 def burnin_method(self, burnin_method):
890 self._burnin_method = burnin_method
891 if not self.mcmc_samples and burnin_method is not None:
892 logger.info(
893 "The {} method will not be used to remove samples as "
894 "burnin as this can only be used for mcmc chains.".format(
895 burnin_method
896 )
897 )
898 self._burnin_method = None
899 elif self.mcmc_samples and burnin_method is None:
900 logger.info(
901 "No burnin method provided. Using {} as default".format(
902 conf.burnin_method
903 )
904 )
905 self._burnin_method = conf.burnin_method
906 elif self.mcmc_samples:
907 from pesummary.core.file import mcmc
909 if burnin_method not in mcmc.algorithms:
910 logger.warning(
911 "Unrecognised burnin method: {}. Resorting to the default: "
912 "{}".format(burnin_method, conf.burnin_method)
913 )
914 self._burnin_method = conf.burnin_method
915 if self._burnin_method is not None:
916 for label in self.labels:
917 self.file_kwargs[label]["sampler"]["burnin_method"] = (
918 self._burnin_method
919 )
921 @property
922 def burnin(self):
923 return self._burnin
925 @burnin.setter
926 def burnin(self, burnin):
927 _name = "nsamples_removed_from_burnin"
928 if burnin is not None:
929 samples_lengths = [
930 self.samples[key].number_of_samples for key in
931 self.samples.keys()
932 ]
933 if not all(int(burnin) < i for i in samples_lengths):
934 raise InputError(
935 "The chosen burnin is larger than the number of samples. "
936 "Please choose a value less than {}".format(
937 np.max(samples_lengths)
938 )
939 )
940 logger.info(
941 conf.overwrite.format("burnin", conf.burnin, burnin)
942 )
943 burnin = int(burnin)
944 else:
945 burnin = conf.burnin
946 if self.burnin_method is not None:
947 arguments, kwargs = [], {}
948 if burnin != 0 and self.burnin_method == "burnin_by_step_number":
949 logger.warning(
950 "The first {} samples have been requested to be removed "
951 "as burnin, but the burnin method has been chosen to be "
952 "burnin_by_step_number. Changing method to "
953 "burnin_by_first_n with keyword argument step_number="
954 "True such that all samples with step number < {} are "
955 "removed".format(burnin, burnin)
956 )
957 self.burnin_method = "burnin_by_first_n"
958 arguments = [burnin]
959 kwargs = {"step_number": True}
960 elif self.burnin_method == "burnin_by_first_n":
961 arguments = [burnin]
962 initial = self.samples.total_number_of_samples
963 self._samples = self.samples.burnin(
964 *arguments, algorithm=self.burnin_method, **kwargs
965 )
966 diff = initial - self.samples.total_number_of_samples
967 self.file_kwargs[self.labels[0]]["sampler"][_name] = diff
968 self.file_kwargs[self.labels[0]]["sampler"]["nsamples"] = \
969 self._samples.total_number_of_samples
970 else:
971 for label in self.samples:
972 self.samples[label] = self.samples[label].discard_samples(
973 burnin
974 )
975 if burnin != conf.burnin:
976 self.file_kwargs[label]["sampler"][_name] = burnin
978 @property
979 def nsamples(self):
980 return self._nsamples
982 @nsamples.setter
983 def nsamples(self, nsamples):
984 self._nsamples = nsamples
985 if nsamples is not None:
986 logger.info(
987 "{} samples will be used for each result file".format(nsamples)
988 )
989 self._nsamples = int(nsamples)
991 @property
992 def reweight_samples(self):
993 return self._reweight_samples
995 @reweight_samples.setter
996 def reweight_samples(self, reweight_samples):
997 from pesummary.core.reweight import options
998 self._reweight_samples = self._check_reweight_samples(
999 reweight_samples, options
1000 )
1002 def _check_reweight_samples(self, reweight_samples, options):
1003 if reweight_samples and reweight_samples not in options.keys():
1004 logger.warning(
1005 "Unknown reweight function: '{}'. Not reweighting posterior "
1006 "and/or prior samples".format(reweight_samples)
1007 )
1008 return False
1009 return reweight_samples
1011 @property
1012 def path_to_samples(self):
1013 return self._path_to_samples
1015 @path_to_samples.setter
1016 def path_to_samples(self, path_to_samples):
1017 self._path_to_samples = path_to_samples
1018 if path_to_samples is None:
1019 self._path_to_samples = {label: None for label in self.labels}
1020 elif len(path_to_samples) != len(self.labels):
1021 raise InputError(
1022 "Please provide a path for all result files passed. If "
1023 "two result files are passed, and only one requires the "
1024 "path_to_samples arguement, please pass --path_to_samples "
1025 "None path/to/samples"
1026 )
1027 else:
1028 _paths = {}
1029 for num, path in enumerate(path_to_samples):
1030 _label = self.labels[num]
1031 if path.lower() == "none":
1032 _paths[_label] = None
1033 else:
1034 _paths[_label] = path
1035 self._path_to_samples = _paths
1037 @property
1038 def priors(self):
1039 return self._priors
1041 @priors.setter
1042 def priors(self, priors):
1043 self._priors = self.grab_priors_from_inputs(priors)
1045 @property
1046 def custom_plotting(self):
1047 return self._custom_plotting
1049 @custom_plotting.setter
1050 def custom_plotting(self, custom_plotting):
1051 self._custom_plotting = custom_plotting
1052 if custom_plotting is not None:
1053 import importlib
1055 path_to_python_file = os.path.dirname(custom_plotting)
1056 python_file = os.path.splitext(os.path.basename(custom_plotting))[0]
1057 if path_to_python_file != "":
1058 import sys
1060 sys.path.append(path_to_python_file)
1061 try:
1062 mod = importlib.import_module(python_file)
1063 methods = getattr(mod, "__single_plots__", list()).copy()
1064 methods += getattr(mod, "__comparion_plots__", list()).copy()
1065 if len(methods) > 0:
1066 self._custom_plotting = [path_to_python_file, python_file]
1067 else:
1068 logger.warning(
1069 "No __single_plots__ or __comparison_plots__ in {}. "
1070 "If you wish to use custom plotting, then please "
1071 "add the variable :__single_plots__ and/or "
1072 "__comparison_plots__ in future. No custom plotting "
1073 "will be done"
1074 )
1075 except ModuleNotFoundError as e:
1076 logger.warning(
1077 "Failed to import {} because {}. No custom plotting will "
1078 "be done".format(python_file, e)
1079 )
1081 @property
1082 def external_hdf5_links(self):
1083 return self._external_hdf5_links
1085 @external_hdf5_links.setter
1086 def external_hdf5_links(self, external_hdf5_links):
1087 self._external_hdf5_links = external_hdf5_links
1088 if not self.hdf5 and self.external_hdf5_links:
1089 logger.warning(
1090 "You can only apply external hdf5 links when saving the meta "
1091 "file in hdf5 format. Turning external hdf5 links off."
1092 )
1093 self._external_hdf5_links = False
1095 @property
1096 def hdf5_compression(self):
1097 return self._hdf5_compression
1099 @hdf5_compression.setter
1100 def hdf5_compression(self, hdf5_compression):
1101 self._hdf5_compression = hdf5_compression
1102 if not self.hdf5 and hdf5_compression is not None:
1103 logger.warning(
1104 "You can only apply compression when saving the meta "
1105 "file in hdf5 format. Turning compression off."
1106 )
1107 self._hdf5_compression = None
1109 @property
1110 def existing_plot(self):
1111 return self._existing_plot
1113 @existing_plot.setter
1114 def existing_plot(self, existing_plot):
1115 self._existing_plot = existing_plot
1116 if self._existing_plot is not None:
1117 from pathlib import Path
1118 import shutil
1119 if isinstance(self._existing_plot, list):
1120 logger.warning(
1121 "Assigning {} to all labels".format(
1122 ", ".join(self._existing_plot)
1123 )
1124 )
1125 self._existing_plot = {
1126 label: self._existing_plot for label in self.labels
1127 }
1128 _does_not_exist = (
1129 "The plot {} does not exist. Not adding plot to summarypages."
1130 )
1131 keys_to_remove = []
1132 for key, _plot in self._existing_plot.items():
1133 if isinstance(_plot, list):
1134 allowed = []
1135 for _subplot in _plot:
1136 if not os.path.isfile(_subplot):
1137 logger.warning(_does_not_exist.format(_subplot))
1138 else:
1139 _filename = os.path.join(
1140 self.webdir, "plots", Path(_subplot).name
1141 )
1142 try:
1143 shutil.copyfile(_subplot, _filename)
1144 except shutil.SameFileError:
1145 pass
1146 allowed.append(_filename)
1147 if not len(allowed):
1148 keys_to_remove.append(key)
1149 elif len(allowed) == 1:
1150 self._existing_plot[key] = allowed[0]
1151 else:
1152 self._existing_plot[key] = allowed
1153 else:
1154 if not os.path.isfile(_plot):
1155 logger.warning(_does_not_exist.format(_plot))
1156 keys_to_remove.append(key)
1157 else:
1158 _filename = os.path.join(
1159 self.webdir, "plots", Path(_plot).name
1160 )
1161 try:
1162 shutil.copyfile(_plot, _filename)
1163 except shutil.SameFileError:
1164 _filename = os.path.join(
1165 self.webdir, "plots", key + "_" + Path(_plot).name
1166 )
1167 shutil.copyfile(_plot, _filename)
1168 self._existing_plot[key] = _filename
1169 for key in keys_to_remove:
1170 del self._existing_plot[key]
1171 if not len(self._existing_plot):
1172 self._existing_plot = None
1174 def add_to_prior_dict(self, path, data):
1175 """Add priors to the prior dictionary
1177 Parameters
1178 ----------
1179 path: str
1180 the location where you wish to store the prior. If this is inside
1181 a nested dictionary, then please pass the path as 'a/b'
1182 data: np.ndarray
1183 the prior samples
1184 """
1185 from functools import reduce
1187 def build_tree(dictionary, path):
1188 """Build a dictionary tree from a list of keys
1190 Parameters
1191 ----------
1192 dictionary: dict
1193 existing dictionary that you wish to add to
1194 path: list
1195 list of keys specifying location
1197 Examples
1198 --------
1199 >>> dictionary = {"label": {"mass_1": [1,2,3,4,5,6]}}
1200 >>> path = ["label", "mass_2"]
1201 >>> build_tree(dictionary, path)
1202 {"label": {"mass_1": [1,2,3,4,5,6], "mass_2": {}}}
1203 """
1204 if path != [] and path[0] not in dictionary.keys():
1205 dictionary[path[0]] = {}
1206 if path != []:
1207 build_tree(dictionary[path[0]], path[1:])
1208 return dictionary
1210 def get_nested_dictionary(dictionary, path):
1211 """Return a nested dictionary from a list specifying path
1213 Parameters
1214 ----------
1215 dictionary: dict
1216 existing dictionary that you wish to extract information from
1217 path: list
1218 list of keys specifying location
1220 Examples
1221 --------
1222 >>> dictionary = {"label": {"mass_1": [1,2,3,4,5,6]}}
1223 >>> path = ["label", "mass_1"]
1224 >>> get_nested_dictionary(dictionary, path)
1225 [1,2,3,4,5,6]
1226 """
1227 return reduce(dict.get, path, dictionary)
1229 if "/" in path:
1230 path = path.split("/")
1231 else:
1232 path = [path]
1233 tree = build_tree(self._priors, path)
1234 nested_dictionary = get_nested_dictionary(self._priors, path[:-1])
1235 nested_dictionary[path[-1]] = data
1237 def grab_priors_from_inputs(self, priors, read_func=None, read_kwargs={}):
1238 """
1239 """
1240 if read_func is None:
1241 from pesummary.core.file.read import read as Read
1242 read_func = Read
1244 prior_dict = {}
1245 if priors is not None:
1246 prior_dict = {"samples": {}, "analytic": {}}
1247 for i in priors:
1248 if not os.path.isfile(i):
1249 raise InputError("The file {} does not exist".format(i))
1250 if len(priors) != len(self.labels) and len(priors) == 1:
1251 logger.warning(
1252 "You have only specified a single prior file for {} result "
1253 "files. Assuming the same prior file for all result "
1254 "files".format(len(self.labels))
1255 )
1256 data = read_func(
1257 priors[0], nsamples=self.nsamples_for_prior
1258 )
1259 for i in self.labels:
1260 prior_dict["samples"][i] = data.samples_dict
1261 try:
1262 if data.analytic is not None:
1263 prior_dict["analytic"][i] = data.analytic
1264 except AttributeError:
1265 continue
1266 elif len(priors) != len(self.labels):
1267 raise InputError(
1268 "Please provide a prior file for each result file"
1269 )
1270 else:
1271 for num, i in enumerate(priors):
1272 if i.lower() == "none":
1273 continue
1274 logger.info(
1275 "Assigning {} to {}".format(self.labels[num], i)
1276 )
1277 if self.labels[num] in read_kwargs.keys():
1278 grab_data_kwargs = read_kwargs[self.labels[num]]
1279 else:
1280 grab_data_kwargs = read_kwargs
1281 data = read_func(
1282 priors[num], nsamples=self.nsamples_for_prior,
1283 **grab_data_kwargs
1284 )
1285 prior_dict["samples"][self.labels[num]] = data.samples_dict
1286 try:
1287 if data.analytic is not None:
1288 prior_dict["analytic"][self.labels[num]] = data.analytic
1289 except AttributeError:
1290 continue
1291 return prior_dict
1293 @property
1294 def grab_data_kwargs(self):
1295 return {
1296 label: dict(regenerate=self.regenerate) for label in self.labels
1297 }
1299 def grab_data_from_input(
1300 self, file, label, config=None, injection=None, file_format=None
1301 ):
1302 """Wrapper function for the grab_data_from_metafile and
1303 grab_data_from_file functions
1305 Parameters
1306 ----------
1307 file: str
1308 path to the result file
1309 label: str
1310 label that you wish to use for the result file
1311 config: str, optional
1312 path to a configuration file used in the analysis
1313 injection: str, optional
1314 path to an injection file used in the analysis
1315 file_format, str, optional
1316 the file format you wish to use when loading. Default None.
1317 If None, the read function loops through all possible options
1318 mcmc: Bool, optional
1319 if True, the result file is an mcmc chain
1320 """
1321 if label in self.grab_data_kwargs.keys():
1322 grab_data_kwargs = self.grab_data_kwargs[label]
1323 else:
1324 grab_data_kwargs = self.grab_data_kwargs
1326 if self.is_pesummary_metafile(file):
1327 data = self.grab_data_from_metafile(
1328 file, self.webdir, compare=self.compare_results,
1329 nsamples=self.nsamples, reweight_samples=self.reweight_samples,
1330 disable_injection=self.disable_injection,
1331 keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
1332 **grab_data_kwargs
1333 )
1334 else:
1335 data = self.grab_data_from_file(
1336 file, label, self.webdir, config=config, injection=injection,
1337 file_format=file_format, nsamples=self.nsamples,
1338 disable_prior_sampling=self.disable_prior_sampling,
1339 nsamples_for_prior=self.nsamples_for_prior,
1340 path_to_samples=self.path_to_samples[label],
1341 reweight_samples=self.reweight_samples,
1342 keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
1343 **grab_data_kwargs
1344 )
1345 self._open_result_files.update({file: data["open_file"]})
1346 return data
1348 @property
1349 def email(self):
1350 return self._email
1352 @email.setter
1353 def email(self, email):
1354 if email is not None and "@" not in email:
1355 raise InputError("Please provide a valid email address")
1356 self._email = email
1358 @property
1359 def dump(self):
1360 return self._dump
1362 @dump.setter
1363 def dump(self, dump):
1364 self._dump = dump
1366 @property
1367 def palette(self):
1368 return self._palette
1370 @palette.setter
1371 def palette(self, palette):
1372 self._palette = palette
1373 if palette is not conf.palette:
1374 import seaborn
1376 try:
1377 seaborn.color_palette(palette, n_colors=1)
1378 logger.info(
1379 conf.overwrite.format("palette", conf.palette, palette)
1380 )
1381 conf.palette = palette
1382 except ValueError as e:
1383 raise InputError(
1384 "Unrecognised palette. Please choose from one of the "
1385 "following {}".format(
1386 ", ".join(seaborn.palettes.SEABORN_PALETTES.keys())
1387 )
1388 )
1390 @property
1391 def include_prior(self):
1392 return self._include_prior
1394 @include_prior.setter
1395 def include_prior(self, include_prior):
1396 self._include_prior = include_prior
1397 if include_prior != conf.include_prior:
1398 conf.overwrite.format("prior", conf.include_prior, include_prior)
1399 conf.include_prior = include_prior
1401 @property
1402 def colors(self):
1403 return self._colors
1405 @colors.setter
1406 def colors(self, colors):
1407 if colors is not None:
1408 number = len(self.labels)
1409 if self.existing:
1410 number += len(self.existing_labels)
1411 if len(colors) != number and len(colors) > number:
1412 logger.info(
1413 "You have passed {} colors for {} result files. Setting "
1414 "colors = {}".format(
1415 len(colors), number, colors[:number]
1416 )
1417 )
1418 self._colors = colors[:number]
1419 return
1420 elif len(colors) != number:
1421 logger.warning(
1422 "Number of colors does not match the number of labels. "
1423 "Using default colors"
1424 )
1425 import seaborn
1427 number = len(self.labels)
1428 if self.existing:
1429 number += len(self.existing_labels)
1430 colors = seaborn.color_palette(
1431 palette=conf.palette, n_colors=number
1432 ).as_hex()
1433 self._colors = colors
1435 @property
1436 def linestyles(self):
1437 return self._linestyles
1439 @linestyles.setter
1440 def linestyles(self, linestyles):
1441 if linestyles is not None:
1442 if len(linestyles) != len(self.colors):
1443 if len(linestyles) > len(self.colors):
1444 logger.info(
1445 "You have passed {} linestyles for {} result files. "
1446 "Setting linestyles = {}".format(
1447 len(linestyles), len(self.colors),
1448 linestyles[:len(self.colors)]
1449 )
1450 )
1451 self._linestyles = linestyles[:len(self.colors)]
1452 return
1453 else:
1454 logger.warning(
1455 "Number of linestyles does not match the number of "
1456 "labels. Using default linestyles"
1457 )
1458 available_linestyles = ["-", "--", ":", "-."]
1459 linestyles = ["-"] * len(self.colors)
1460 unique_colors = np.unique(self.colors)
1461 for color in unique_colors:
1462 indicies = [num for num, i in enumerate(self.colors) if i == color]
1463 for idx, j in enumerate(indicies):
1464 linestyles[j] = available_linestyles[
1465 np.mod(idx, len(available_linestyles))
1466 ]
1467 self._linestyles = linestyles
1469 @property
1470 def disable_corner(self):
1471 return self._disable_corner
1473 @disable_corner.setter
1474 def disable_corner(self, disable_corner):
1475 self._disable_corner = disable_corner
1476 if disable_corner:
1477 logger.warning(
1478 "No corner plot will be produced. This will reduce overall "
1479 "runtime but does mean that the interactive corner plot feature "
1480 "on the webpages will no longer work"
1481 )
1483 @property
1484 def add_to_corner(self):
1485 return self._add_to_corner
1487 @add_to_corner.setter
1488 def add_to_corner(self, add_to_corner):
1489 self._add_to_corner = self._set_corner_params(add_to_corner)
1491 def _set_corner_params(self, corner_params):
1492 cls = self.__class__.__name__
1493 if corner_params is not None:
1494 for label in self.labels:
1495 _not_included = [
1496 param for param in corner_params if param not in
1497 self.samples[label].keys()
1498 ]
1499 if len(_not_included) == len(corner_params) and cls == "Input":
1500 logger.warning(
1501 "None of the chosen corner parameters are "
1502 "included in the posterior table for '{}'. Using "
1503 "all available parameters for the corner plot".format(
1504 label
1505 )
1506 )
1507 corner_params = None
1508 break
1509 elif len(_not_included):
1510 logger.warning(
1511 "The following parameters are not included in the "
1512 "posterior table for '{}': {}. Not adding to corner "
1513 "plot".format(label, ", ".join(_not_included))
1514 )
1515 elif cls == "Input":
1516 logger.debug(
1517 "Using all parameters stored in the result file for the "
1518 "corner plots. This may take some time."
1519 )
1520 return corner_params
1522 @property
1523 def pe_algorithm(self):
1524 return self._pe_algorithm
1526 @pe_algorithm.setter
1527 def pe_algorithm(self, pe_algorithm):
1528 self._pe_algorithm = pe_algorithm
1529 if pe_algorithm is None:
1530 return
1531 if len(pe_algorithm) != len(self.labels):
1532 raise ValueError("Please provide an algorithm for each result file")
1533 for num, (label, _algorithm) in enumerate(zip(self.labels, pe_algorithm)):
1534 if "pe_algorithm" in self.file_kwargs[label]["sampler"].keys():
1535 _stored = self.file_kwargs[label]["sampler"]["pe_algorithm"]
1536 if _stored != _algorithm:
1537 logger.warning(
1538 "Overwriting the pe_algorithm extracted from the file "
1539 "'{}': {} with the algorithm provided from the command "
1540 "line: {}".format(
1541 self.result_files[num], _stored, _algorithm
1542 )
1543 )
1544 self.file_kwargs[label]["sampler"]["pe_algorithm"] = _algorithm
1546 @property
1547 def notes(self):
1548 return self._notes
1550 @notes.setter
1551 def notes(self, notes):
1552 self._notes = notes
1553 if notes is not None:
1554 if not os.path.isfile(notes):
1555 raise InputError(
1556 "No such file or directory called {}".format(notes)
1557 )
1558 try:
1559 with open(notes, "r") as f:
1560 self._notes = f.read()
1561 except FileNotFoundError:
1562 logger.warning(
1563 "No such file or directory called {}. Custom notes will "
1564 "not be added to the summarypages".format(notes)
1565 )
1566 except IOError as e:
1567 logger.warning(
1568 "Failed to read {}. Unable to put notes on "
1569 "summarypages".format(notes)
1570 )
1572 @property
1573 def descriptions(self):
1574 return self._descriptions
1576 @descriptions.setter
1577 def descriptions(self, descriptions):
1578 import json
1579 if hasattr(self, "_descriptions") and not len(descriptions):
1580 return
1581 elif not len(descriptions):
1582 self._descriptions = None
1583 return
1585 if len(descriptions) and isinstance(descriptions, dict):
1586 data = descriptions
1587 elif len(descriptions):
1588 descriptions = descriptions[0]
1589 _is_file = not isinstance(descriptions, dict)
1590 if hasattr(self, "_descriptions"):
1591 logger.warning(
1592 "Ignoring descriptions found in result file and using "
1593 "descriptions in '{}'".format(descriptions)
1594 )
1595 self._descriptions = None
1596 if _is_file and not os.path.isfile(descriptions):
1597 logger.warning(
1598 "No such file called {}. Unable to add descriptions".format(
1599 descriptions
1600 )
1601 )
1602 return
1603 if _is_file:
1604 try:
1605 with open(descriptions, "r") as f:
1606 data = json.load(f)
1607 except json.decoder.JSONDecodeError:
1608 logger.warning(
1609 "Unable to open file '{}'. Not storing descriptions".format(
1610 descriptions
1611 )
1612 )
1613 return
1614 if not all(label in data.keys() for label in self.labels):
1615 not_included = [
1616 label for label in self.labels if label not in data.keys()
1617 ]
1618 logger.debug(
1619 "No description found for '{}'. Using default "
1620 "description".format(", ".join(not_included))
1621 )
1622 for label in not_included:
1623 data[label] = "No description found"
1624 if len(data.keys()) > len(self.labels):
1625 logger.warning(
1626 "Descriptions file contains descriptions for analyses other "
1627 "than {}. Ignoring other descriptions".format(
1628 ", ".join(self.labels)
1629 )
1630 )
1631 other = [
1632 analysis for analysis in data.keys() if analysis not in
1633 self.labels
1634 ]
1635 for analysis in other:
1636 _ = data.pop(analysis)
1637 _remove = []
1638 for key, desc in data.items():
1639 if not isinstance(desc, (str, bytes)):
1640 logger.warning(
1641 "Unknown description '{}' for '{}'. The description should "
1642 "be a string or bytes object"
1643 )
1644 _remove.append(key)
1645 if len(_remove):
1646 for analysis in _remove:
1647 _ = data.pop(analysis)
1648 self._descriptions = data
1650 @property
1651 def preferred(self):
1652 return self._preferred
1654 @preferred.setter
1655 def preferred(self, preferred):
1656 if preferred is not None and preferred not in self.labels:
1657 logger.warning(
1658 "'{}' not in list of labels. Unable to stored as the "
1659 "preferred analysis".format(preferred)
1660 )
1661 self._preferred = None
1662 elif preferred is not None:
1663 logger.debug(
1664 "Setting '{}' as the preferred analysis".format(preferred)
1665 )
1666 self._preferred = preferred
1667 elif len(self.labels) == 1:
1668 self._preferred = self.labels[0]
1669 else:
1670 self._preferred = None
1671 if self._preferred is not None:
1672 try:
1673 self.file_kwargs[self._preferred]["other"].update(
1674 {"preferred": "True"}
1675 )
1676 except KeyError:
1677 self.file_kwargs[self._preferred].update(
1678 {"other": {"preferred": "True"}}
1679 )
1680 for _label in self.labels:
1681 if self._preferred is not None and _label == self._preferred:
1682 continue
1683 try:
1684 self.file_kwargs[_label]["other"].update(
1685 {"preferred": "False"}
1686 )
1687 except KeyError:
1688 self.file_kwargs[_label].update(
1689 {"other": {"preferred": "False"}}
1690 )
1691 return
1693 @property
1694 def public(self):
1695 return self._public
1697 @public.setter
1698 def public(self, public):
1699 self._public = public
1700 if public != conf.public:
1701 logger.info(
1702 conf.overwrite.format("public", conf.public, public)
1703 )
1705 @property
1706 def multi_process(self):
1707 return self._multi_process
1709 @multi_process.setter
1710 def multi_process(self, multi_process):
1711 self._multi_process = int(multi_process)
1712 if multi_process is not None and int(multi_process) != int(conf.multi_process):
1713 logger.info(
1714 conf.overwrite.format(
1715 "multi_process", conf.multi_process, multi_process
1716 )
1717 )
1719 @property
1720 def publication_kwargs(self):
1721 return self._publication_kwargs
1723 @publication_kwargs.setter
1724 def publication_kwargs(self, publication_kwargs):
1725 self._publication_kwargs = publication_kwargs
1726 if publication_kwargs != {}:
1727 allowed_kwargs = ["gridsize"]
1728 if not any(i in publication_kwargs.keys() for i in allowed_kwargs):
1729 logger.warning(
1730 "Currently the only allowed publication kwargs are {}. "
1731 "Ignoring other inputs.".format(
1732 ", ".join(allowed_kwargs)
1733 )
1734 )
1736 @property
1737 def ignore_parameters(self):
1738 return self._ignore_parameters
1740 @ignore_parameters.setter
1741 def ignore_parameters(self, ignore_parameters):
1742 self._ignore_parameters = ignore_parameters
1743 if ignore_parameters is not None:
1744 for num, label in enumerate(self.labels):
1745 removed_parameters = list_match(
1746 list(self.samples[label].keys()), ignore_parameters
1747 )
1748 if not len(removed_parameters):
1749 logger.warning(
1750 "Failed to remove any parameters from {}".format(
1751 self.result_files[num]
1752 )
1753 )
1754 else:
1755 logger.warning(
1756 "Removing parameters: {} from {}".format(
1757 ", ".join(removed_parameters),
1758 self.result_files[num]
1759 )
1760 )
1761 for ignore in removed_parameters:
1762 self.samples[label].pop(ignore)
1764 @staticmethod
1765 def _make_directories(webdir, dirs):
1766 """Make the directories to store the information
1767 """
1768 for i in dirs:
1769 if not os.path.isdir(os.path.join(webdir, i)):
1770 make_dir(os.path.join(webdir, i))
1772 def make_directories(self):
1773 """Make the directories to store the information
1774 """
1775 self._make_directories(self.webdir, self.default_directories)
1777 @staticmethod
1778 def _copy_files(paths):
1779 """Copy the relevant file to the web directory
1781 Parameters
1782 ----------
1783 paths: nd list
1784 list of files you wish to copy. First element is the path of the
1785 file to copy and second element is the location of where you
1786 wish the file to be put
1788 Examples
1789 --------
1790 >>> paths = [
1791 ... ["config/config.ini", "webdir/config.ini"],
1792 ... ["samples/samples.h5", "webdir/samples.h5"]
1793 ... ]
1794 """
1795 import shutil
1797 for ff in paths:
1798 shutil.copyfile(ff[0], ff[1])
1800 def copy_files(self):
1801 """Copy the relevant file to the web directory
1802 """
1803 self._copy_files(self.default_files_to_copy)
1805 def default_labels(self):
1806 """Return a list of default labels.
1807 """
1808 from time import time
1810 def _default_label(file_name):
1811 return "%s_%s" % (round(time()), file_name)
1813 label_list = []
1814 if self.result_files is None or len(self.result_files) == 0:
1815 raise InputError("Please provide a results file")
1816 elif self.mcmc_samples:
1817 f = self.result_files[0]
1818 file_name = os.path.splitext(os.path.basename(f))[0]
1819 label_list.append(_default_label(file_name))
1820 else:
1821 for num, i in enumerate(self.result_files):
1822 file_name = os.path.splitext(os.path.basename(i))[0]
1823 label_list.append(_default_label(file_name))
1825 duplicates = dict(set(
1826 (x, label_list.count(x)) for x in
1827 filter(lambda rec: label_list.count(rec) > 1, label_list)))
1829 for i in duplicates.keys():
1830 for j in range(duplicates[i]):
1831 ind = label_list.index(i)
1832 label_list[ind] += "_%s" % (j)
1833 if self.add_to_existing:
1834 for num, i in enumerate(label_list):
1835 if i in self.existing_labels:
1836 ind = label_list.index(i)
1837 label_list[ind] += "_%s" % (num)
1838 return label_list
1840 @staticmethod
1841 def get_package_information():
1842 """Return a dictionary of parameter information
1843 """
1844 from pesummary._version_helper import PackageInformation
1845 from operator import itemgetter
1847 _package = PackageInformation()
1848 package_info = _package.package_info
1849 package_dir = _package.package_dir
1850 if "build_string" in package_info[0]: # conda list
1851 headings = ("name", "version", "channel", "build_string")
1852 else: # pip list installed
1853 headings = ("name", "version")
1854 packages = np.array([
1855 tuple(pkg[col.lower()] for col in headings) for pkg in
1856 sorted(package_info, key=itemgetter("name"))
1857 ], dtype=[(col, "S20") for col in headings]).view(np.recarray)
1858 return {
1859 "packages": packages, "environment": [package_dir],
1860 "manager": _package.package_manager
1861 }
1863 def grab_key_data_from_result_files(self):
1864 """Grab the mean, median, maxL and standard deviation for all
1865 parameters for all each result file
1866 """
1867 key_data = {
1868 key: samples.key_data for key, samples in self.samples.items()
1869 }
1870 for key, val in self.samples.items():
1871 for j in val.keys():
1872 _inj = self.injection_data[key][j]
1873 key_data[key][j]["injected"] = (
1874 _inj[0] if not math.isnan(_inj) and isinstance(
1875 _inj, (list, np.ndarray)
1876 ) else _inj
1877 )
1878 return key_data
1881class BaseInput(_Input):
1882 """Class to handle and store base command line arguments
1883 """
1884 def __init__(self, opts, ignore_copy=False, checkpoint=None, gw=False):
1885 self.opts = opts
1886 self.gw = gw
1887 self.restart_from_checkpoint = self.opts.restart_from_checkpoint
1888 if checkpoint is not None:
1889 for key, item in vars(checkpoint).items():
1890 setattr(self, key, item)
1891 logger.info(
1892 "Loaded command line arguments: {}".format(self.opts)
1893 )
1894 self.restart_from_checkpoint = True
1895 self._restarted_from_checkpoint = True
1896 return
1897 self.seed = self.opts.seed
1898 self.result_files = self.opts.samples
1899 self.user = self.opts.user
1900 self.existing = self.opts.existing
1901 self.add_to_existing = False
1902 if self.existing is not None:
1903 self.add_to_existing = True
1904 self.existing_metafile = True
1905 self.webdir = self.opts.webdir
1906 self._restarted_from_checkpoint = False
1907 self.resume_file_dir = conf.checkpoint_dir(self.webdir)
1908 self.resume_file = conf.resume_file
1909 self._resume_file_path = os.path.join(
1910 self.resume_file_dir, self.resume_file
1911 )
1912 self.make_directories()
1913 self.email = self.opts.email
1914 self.pe_algorithm = self.opts.pe_algorithm
1915 self.multi_process = self.opts.multi_process
1916 self.package_information = self.get_package_information()
1917 if not ignore_copy:
1918 self.copy_files()
1919 self.write_current_state()
1921 @property
1922 def default_directories(self):
1923 return ["checkpoint"]
1925 @property
1926 def default_files_to_copy(self):
1927 return []
1929 def write_current_state(self):
1930 """Write the current state of the input class to file
1931 """
1932 from pesummary.io import write
1933 write(
1934 self, outdir=self.resume_file_dir, file_format="pickle",
1935 filename=self.resume_file, overwrite=True
1936 )
1937 logger.debug(
1938 "Written checkpoint file: {}".format(self._resume_file_path)
1939 )
1942class SamplesInput(BaseInput):
1943 """Class to handle and store sample specific command line arguments
1944 """
1945 def __init__(self, *args, extra_options=None, **kwargs):
1946 """
1947 """
1948 super(SamplesInput, self).__init__(*args, **kwargs)
1949 if self.result_files is not None:
1950 self._open_result_files = {path: None for path in self.result_files}
1951 self.meta_file = False
1952 if self.result_files is not None and len(self.result_files) == 1:
1953 self.meta_file = self.is_pesummary_metafile(self.result_files[0])
1954 self.compare_results = self.opts.compare_results
1955 self.disable_injection = self.opts.disable_injection
1956 if self.existing is not None:
1957 self.existing_data = self.grab_data_from_metafile(
1958 self.existing_metafile, self.existing,
1959 compare=self.compare_results
1960 )
1961 self.existing_samples = self.existing_data["samples"]
1962 self.existing_injection_data = self.existing_data["injection_data"]
1963 self.existing_file_version = self.existing_data["file_version"]
1964 self.existing_file_kwargs = self.existing_data["file_kwargs"]
1965 self.existing_priors = self.existing_data["prior"]
1966 self.existing_config = self.existing_data["config"]
1967 self.existing_labels = self.existing_data["labels"]
1968 self.existing_weights = self.existing_data["weights"]
1969 else:
1970 self.existing_metafile = None
1971 self.existing_labels = None
1972 self.existing_weights = None
1973 self.existing_samples = None
1974 self.existing_file_version = None
1975 self.existing_file_kwargs = None
1976 self.existing_priors = None
1977 self.existing_config = None
1978 self.existing_injection_data = None
1979 self.mcmc_samples = self.opts.mcmc_samples
1980 self.labels = self.opts.labels
1981 self.weights = {i: None for i in self.labels}
1982 self.config = self.opts.config
1983 self.injection_file = self.opts.inj_file
1984 self.regenerate = self.opts.regenerate
1985 if extra_options is not None:
1986 for opt in extra_options:
1987 setattr(self, opt, getattr(self.opts, opt))
1988 self.nsamples_for_prior = self.opts.nsamples_for_prior
1989 self.priors = self.opts.prior_file
1990 self.disable_prior_sampling = self.opts.disable_prior_sampling
1991 self.path_to_samples = self.opts.path_to_samples
1992 self.file_format = self.opts.file_format
1993 self.nsamples = self.opts.nsamples
1994 self.keep_nan_likelihood_samples = self.opts.keep_nan_likelihood_samples
1995 self.reweight_samples = self.opts.reweight_samples
1996 self.samples = self.opts.samples
1997 self.ignore_parameters = self.opts.ignore_parameters
1998 self.burnin_method = self.opts.burnin_method
1999 self.burnin = self.opts.burnin
2000 self.same_parameters = []
2001 if self.mcmc_samples:
2002 self._samples = {label: self.samples.T for label in self.labels}
2003 self.write_current_state()
2005 @property
2006 def analytic_prior_dict(self):
2007 return {
2008 label: "\n".join(
2009 [
2010 "{} = {}".format(key, value) for key, value in
2011 self.priors["analytic"][label].items()
2012 ]
2013 ) if "analytic" in self.priors.keys() and label in
2014 self.priors["analytic"].keys() else None for label in self.labels
2015 }
2017 @property
2018 def same_parameters(self):
2019 return self._same_parameters
2021 @same_parameters.setter
2022 def same_parameters(self, same_parameters):
2023 self._same_parameters = self.intersect_samples_dict(self.samples)
2025 def intersect_samples_dict(self, samples):
2026 parameters = [
2027 list(samples[key].keys()) for key in samples.keys()
2028 ]
2029 params = list(set.intersection(*[set(l) for l in parameters]))
2030 return params
2033class PlottingInput(SamplesInput):
2034 """Class to handle and store plotting specific command line arguments
2035 """
2036 def __init__(self, *args, **kwargs):
2037 super(PlottingInput, self).__init__(*args, **kwargs)
2038 self.style_file = self.opts.style_file
2039 self.publication = self.opts.publication
2040 self.publication_kwargs = self.opts.publication_kwargs
2041 self.kde_plot = self.opts.kde_plot
2042 self.custom_plotting = self.opts.custom_plotting
2043 self.add_to_corner = self.opts.add_to_corner
2044 self.corner_params = self.add_to_corner
2045 self.palette = self.opts.palette
2046 self.include_prior = self.opts.include_prior
2047 self.colors = self.opts.colors
2048 self.linestyles = self.opts.linestyles
2049 self.disable_corner = self.opts.disable_corner
2050 self.disable_comparison = self.opts.disable_comparison
2051 self.disable_interactive = self.opts.disable_interactive
2052 self.disable_expert = not self.opts.enable_expert
2053 self.multi_threading_for_plots = self.multi_process
2054 self.write_current_state()
2056 @property
2057 def default_directories(self):
2058 dirs = super(PlottingInput, self).default_directories
2059 dirs += ["plots", "plots/corner", "plots/publication"]
2060 return dirs
2063class WebpageInput(SamplesInput):
2064 """Class to handle and store webpage specific command line arguments
2065 """
2066 def __init__(self, *args, **kwargs):
2067 super(WebpageInput, self).__init__(*args, **kwargs)
2068 self.baseurl = self.opts.baseurl
2069 self.existing_plot = self.opts.existing_plot
2070 self.pe_algorithm = self.opts.pe_algorithm
2071 self.notes = self.opts.notes
2072 self.dump = self.opts.dump
2073 self.hdf5 = not self.opts.save_to_json
2074 self.external_hdf5_links = self.opts.external_hdf5_links
2075 self.file_kwargs["webpage_url"] = self.baseurl + "/home.html"
2076 self.write_current_state()
2078 @property
2079 def default_directories(self):
2080 dirs = super(WebpageInput, self).default_directories
2081 dirs += ["js", "html", "css"]
2082 return dirs
2084 @property
2085 def default_files_to_copy(self):
2086 from pesummary import core
2087 files_to_copy = super(WebpageInput, self).default_files_to_copy
2088 path = core.__path__[0]
2089 scripts = glob(os.path.join(path, "js", "*.js"))
2090 for i in scripts:
2091 files_to_copy.append(
2092 [i, os.path.join(self.webdir, "js", os.path.basename(i))]
2093 )
2094 scripts = glob(os.path.join(path, "css", "*.css"))
2095 for i in scripts:
2096 files_to_copy.append(
2097 [i, os.path.join(self.webdir, "css", os.path.basename(i))]
2098 )
2099 return files_to_copy
2102class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
2103 """Class to handle and store webpage and plotting specific command line
2104 arguments
2105 """
2106 def __init__(self, *args, **kwargs):
2107 super(WebpagePlusPlottingInput, self).__init__(*args, **kwargs)
2108 self.copy_files()
2110 @property
2111 def default_directories(self):
2112 return super(WebpagePlusPlottingInput, self).default_directories
2114 @property
2115 def default_files_to_copy(self):
2116 return super(WebpagePlusPlottingInput, self).default_files_to_copy
2119class MetaFileInput(SamplesInput):
2120 """Class to handle and store metafile specific command line arguments
2121 """
2122 def __init__(self, *args, **kwargs):
2123 kwargs.update({"ignore_copy": True})
2124 super(MetaFileInput, self).__init__(*args, **kwargs)
2125 self.copy_files()
2126 self.filename = self.opts.filename
2127 self.hdf5 = not self.opts.save_to_json
2128 self.hdf5_compression = self.opts.hdf5_compression
2129 self.external_hdf5_links = self.opts.external_hdf5_links
2130 self.descriptions = self.opts.descriptions
2131 self.preferred = self.opts.preferred
2132 self.write_current_state()
2134 @property
2135 def default_directories(self):
2136 dirs = super(MetaFileInput, self).default_directories
2137 dirs += ["samples", "config"]
2138 return dirs
2140 @property
2141 def default_files_to_copy(self):
2142 files_to_copy = super(MetaFileInput, self).default_files_to_copy
2143 if not all(i is None for i in self.config):
2144 for num, i in enumerate(self.config):
2145 if i is not None and self.webdir not in i:
2146 filename = "_".join(
2147 [self.labels[num], "config.ini"]
2148 )
2149 files_to_copy.append(
2150 [i, os.path.join(self.webdir, "config", filename)]
2151 )
2152 for num, _file in enumerate(self.result_files):
2153 if not self.mcmc_samples:
2154 filename = "{}_{}".format(self.labels[num], Path(_file).name)
2155 else:
2156 filename = "chain_{}_{}".format(num, Path(_file).name)
2157 files_to_copy.append(
2158 [_file, os.path.join(self.webdir, "samples", filename)]
2159 )
2160 return files_to_copy
2163class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
2164 """Class to handle and store webpage, plotting and metafile specific command
2165 line arguments
2166 """
2167 def __init__(self, *args, **kwargs):
2168 super(WebpagePlusPlottingPlusMetaFileInput, self).__init__(
2169 *args, **kwargs
2170 )
2172 @property
2173 def default_directories(self):
2174 return super(WebpagePlusPlottingPlusMetaFileInput, self).default_directories
2176 @property
2177 def default_files_to_copy(self):
2178 return super(WebpagePlusPlottingPlusMetaFileInput, self).default_files_to_copy
2181@deprecation(
2182 "The Input class is deprecated. Please use either the BaseInput, "
2183 "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
2184 "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
2185)
2186class Input(WebpagePlusPlottingPlusMetaFileInput):
2187 pass
2190def load_current_state(resume_file):
2191 """Load a pickle file containing checkpoint information
2193 Parameters
2194 ----------
2195 resume_file: str
2196 path to a checkpoint file
2197 """
2198 from pesummary.io import read
2199 if not os.path.isfile(resume_file):
2200 logger.info(
2201 "Unable to find resume file. Not restarting from checkpoint"
2202 )
2203 return
2204 logger.info(
2205 "Reading checkpoint file: {}".format(resume_file)
2206 )
2207 state = read(resume_file, checkpoint=True)
2208 return state