Coverage for pesummary/utils/utils.py: 66.9% (501 statements). Generated by coverage.py v7.4.4 at 2025-11-05 13:38 +0000.

# Licensed under an MIT style license -- see LICENSE.md

import os
import sys
import logging
import contextlib
import time
import copy
import shutil

import numpy as np
from scipy.integrate import cumulative_trapezoid as cumtrapz
from scipy.interpolate import interp1d
from scipy import stats
import h5py
from pesummary import conf

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

try:
    from coloredlogs import ColoredFormatter as LogFormatter
except ImportError:
    LogFormatter = logging.Formatter

CACHE_DIR = os.path.join(
    os.getenv(
        "XDG_CACHE_HOME",
        os.path.expanduser(os.path.join("~", ".cache")),
    ),
    "pesummary",
)
STYLE_CACHE = os.path.join(CACHE_DIR, "style")
LOG_CACHE = os.path.join(CACHE_DIR, "log")

LATEX = shutil.which("latex")


def resample_posterior_distribution(posterior, nsamples):
    """Randomly draw nsamples from the posterior distribution

    Parameters
    ----------
    posterior: ndlist
        nd list of posterior samples. If you only want to resample one
        posterior distribution then posterior=[[1., 2., 3., 4.]]. For multiple
        posterior distributions then posterior=[[1., 2., 3., 4.], [1., 2., 3.]]
    nsamples: int
        number of samples that you wish to randomly draw from the distribution
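
    Examples
    --------
    Illustrative call; the draws themselves are random, but the number of
    returned samples is deterministic:

    >>> samples = resample_posterior_distribution([[1., 2., 3., 4.]], 2)
    >>> len(samples)
    2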
    """
    if len(posterior) == 1:
        n, bins = np.histogram(posterior, bins=50)
        n = np.array([0] + [i for i in n])
        cdf = cumtrapz(n, bins, initial=0)
        cdf /= cdf[-1]
        icdf = interp1d(cdf, bins)
        samples = icdf(np.random.rand(nsamples))
    else:
        posterior = np.array([i for i in posterior])
        keep_idxs = np.random.choice(
            len(posterior[0]), nsamples, replace=False)
        samples = [i[keep_idxs] for i in posterior]
    return samples


def check_file_exists_and_rename(file_name):
    """Check to see if a file exists and if it does then rename the file

    Parameters
    ----------
    file_name: str
        proposed file name to store data
    """
    if os.path.isfile(file_name):
        old_file = "{}_old".format(file_name)
        while os.path.isfile(old_file):
            old_file += "_old"
        logger.warning(
            "The file '{}' already exists. Renaming the existing file to "
            "'{}' and saving the data to the requested file name".format(
                file_name, old_file
            )
        )
        shutil.move(file_name, old_file)


def check_condition(condition, error_message):
    """Raise an exception with the given message if the condition is satisfied
    """
    if condition:
        raise Exception(error_message)


def rename_group_or_dataset_in_hf5_file(base_file, group=None, dataset=None):
    """Rename a group or dataset in an hdf5 file

    Parameters
    ----------
    base_file: str
        path to the hdf5 file containing the group or dataset
    group: list, optional
        a list containing the path to the group that you would like to change
        as the first argument and the new name of the group as the second
        argument
    dataset: list, optional
        a list containing the name of the dataset that you would like to
        change as the first argument and the new name of the dataset as the
        second argument
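
    Examples
    --------
    Illustrative call (hypothetical file and group names); renames the group
    "posterior" to "posterior_samples" in place:

    >>> rename_group_or_dataset_in_hf5_file(
    ...     "samples.h5", group=["posterior", "posterior_samples"]
    ... )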
    """
    condition = not os.path.isfile(base_file)
    check_condition(condition, "The file %s does not exist" % (base_file))
    f = h5py.File(base_file, "a")
    if group:
        f[group[1]] = f[group[0]]
        del f[group[0]]
    elif dataset:
        f[dataset[1]] = f[dataset[0]]
        del f[dataset[0]]
    f.close()


def make_dir(path):
    """Make a directory if it does not already exist"""
    if os.path.isdir(os.path.expanduser(path)):
        pass
    else:
        os.makedirs(os.path.expanduser(path), exist_ok=True)


def guess_url(web_dir, host, user):
    """Guess the base url from the host name

    Parameters
    ----------
    web_dir: str
        path to the web directory where you want the data to be saved
    host: str
        the host name of the machine where the python interpreter is currently
        executing
    user: str
        the user that is currently executing the python interpreter
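
    Examples
    --------
    Illustrative example with made-up user and directory names:

    >>> guess_url(
    ...     "/home/albert.einstein/public_html/project",
    ...     "ldas-jobs.ligo.caltech.edu", "albert.einstein"
    ... )
    'https://ldas-jobs.ligo.caltech.edu/~albert.einstein/project'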
    """
    ligo_data_grid = False
    if 'public_html' in web_dir:
        ligo_data_grid = True
    if ligo_data_grid:
        path = web_dir.split("public_html")[1]
        if "raven" in host or "arcca" in host:
            url = "https://geo2.arcca.cf.ac.uk/~{}".format(user)
        elif 'ligo-wa' in host:
            url = "https://ldas-jobs.ligo-wa.caltech.edu/~{}".format(user)
        elif 'ligo-la' in host:
            url = "https://ldas-jobs.ligo-la.caltech.edu/~{}".format(user)
        elif "cit" in host or "caltech" in host:
            url = "https://ldas-jobs.ligo.caltech.edu/~{}".format(user)
        elif 'uwm' in host or 'nemo' in host:
            url = "https://ldas-jobs.phys.uwm.edu/~{}".format(user)
        elif 'phy.syr.edu' in host:
            url = "https://sugar-jobs.phy.syr.edu/~{}".format(user)
        elif 'vulcan' in host:
            url = "https://galahad.aei.mpg.de/~{}".format(user)
        elif 'atlas' in host:
            url = "https://atlas1.atlas.aei.uni-hannover.de/~{}".format(user)
        elif 'iucaa' in host:
            url = "https://ldas-jobs.gw.iucaa.in/~{}".format(user)
        elif 'alice' in host:
            url = "https://dumpty.alice.icts.res.in/~{}".format(user)
        elif 'hawk' in host:
            url = "https://ligo.gravity.cf.ac.uk/~{}".format(user)
        else:
            url = "https://{}/~{}".format(host, user)
        url += path
    else:
        url = "https://{}".format(web_dir)
    return url


def map_parameter_names(dictionary, mapping):
    """Modify keys in dictionary to use different names according to a map

    Parameters
    ----------
    dictionary: dict
        dictionary of data whose keys you wish to rename
    mapping: dict
        dictionary mapping existing keys to new names.

    Returns
    -------
    standard_dict: dict
        dict object with new parameter names
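
    Examples
    --------
    Illustrative example with made-up parameter names:

    >>> map_parameter_names({"m1": 10.0, "m2": 5.0}, {"m1": "mass_1"})
    {'mass_1': 10.0, 'm2': 5.0}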
    """
    standard_dict = {}
    for key, item in dictionary.items():
        if key not in mapping.keys():
            standard_dict[key] = item
            continue
        standard_dict[mapping[key]] = item
    return standard_dict


def command_line_arguments():
    """Return the command line arguments
    """
    return sys.argv[1:]


def command_line_dict():
    """Return a dictionary of command line arguments
    """
    from pesummary.gw.cli.parser import ArgumentParser

    parser = ArgumentParser()
    parser.add_all_known_options_to_parser()
    opts = parser.parse_args()
    return vars(opts)


def gw_results_file(opts):
    """Determine if a GW results file is passed
    """
    from pesummary.gw.cli.parser import ArgumentParser

    attrs, defaults = ArgumentParser().gw_options
    condition = any(
        hasattr(opts, attr) and getattr(opts, attr) and getattr(opts, attr)
        != default for attr, default in zip(attrs, defaults)
    )
    if condition:
        return True
    return False


def functions(opts, gw=False):
    """Return a dictionary of functions that are either specific to GW results
    files or core.
    """
    from pesummary.core.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as Input
    )
    from pesummary.gw.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as GWInput
    )
    from pesummary.core.file.meta_file import MetaFile
    from pesummary.gw.file.meta_file import GWMetaFile
    from pesummary.core.finish import FinishingTouches
    from pesummary.gw.finish import GWFinishingTouches

    dictionary = {}
    dictionary["input"] = GWInput if gw_results_file(opts) or gw else Input
    dictionary["MetaFile"] = GWMetaFile if gw_results_file(opts) or gw else MetaFile
    dictionary["FinishingTouches"] = \
        GWFinishingTouches if gw_results_file(opts) or gw else FinishingTouches
    return dictionary


def _logger_format():
    return '%(asctime)s %(name)s %(levelname)-8s: %(message)s'


def setup_logger():
    """Set up the logger output.
    """
    import tempfile

    def get_console_handler(stream_level="INFO"):
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level=getattr(logging, stream_level))
        console_handler.setFormatter(FORMATTER)
        return console_handler

    def get_file_handler(log_file):
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setLevel(level=logging.DEBUG)
        file_handler.setFormatter(FORMATTER)
        return file_handler

    make_dir(LOG_CACHE)
    dirpath = tempfile.mkdtemp(dir=LOG_CACHE)
    stream_level = 'INFO'
    if "-v" in sys.argv or "--verbose" in sys.argv:
        stream_level = 'DEBUG'

    FORMATTER = LogFormatter(_logger_format(), datefmt='%Y-%m-%d %H:%M:%S')
    LOG_FILE = '%s/pesummary.log' % (dirpath)
    logger = logging.getLogger('PESummary')
    logger.propagate = False
    logger.setLevel(level=logging.DEBUG)
    logger.addHandler(get_console_handler(stream_level=stream_level))
    logger.addHandler(get_file_handler(LOG_FILE))
    return logger, LOG_FILE


def remove_tmp_directories():
    """Remove the temporary directories created by PESummary
    """
    from glob import glob

    directories = glob(".tmp/pesummary/*")
    for i in directories:
        if os.path.isdir(i):
            shutil.rmtree(i)
        elif os.path.isfile(i):
            os.remove(i)


def _add_existing_data(namespace):
    """Add existing data to namespace object
    """
    for num, i in enumerate(namespace.existing_labels):
        if hasattr(namespace, "labels") and i not in namespace.labels:
            namespace.labels.append(i)
        if hasattr(namespace, "samples") and i not in list(namespace.samples.keys()):
            namespace.samples[i] = namespace.existing_samples[i]
        if hasattr(namespace, "weights") and i not in list(namespace.weights.keys()):
            if namespace.existing_weights is None:
                namespace.weights[i] = None
            else:
                namespace.weights[i] = namespace.existing_weights[i]
        if hasattr(namespace, "injection_data"):
            if i not in list(namespace.injection_data.keys()):
                namespace.injection_data[i] = namespace.existing_injection_data[i]
        if hasattr(namespace, "file_versions"):
            if i not in list(namespace.file_versions.keys()):
                namespace.file_versions[i] = namespace.existing_file_version[i]
        if hasattr(namespace, "file_kwargs"):
            if i not in list(namespace.file_kwargs.keys()):
                namespace.file_kwargs[i] = namespace.existing_file_kwargs[i]
        if hasattr(namespace, "config"):
            if namespace.existing_config[num] not in namespace.config:
                namespace.config.append(namespace.existing_config[num])
            elif namespace.existing_config[num] is None:
                namespace.config.append(None)
        if hasattr(namespace, "priors"):
            if hasattr(namespace, "existing_priors"):
                for key, item in namespace.existing_priors.items():
                    if key in namespace.priors.keys():
                        for label in item.keys():
                            if label not in namespace.priors[key].keys():
                                namespace.priors[key][label] = item[label]
                    else:
                        namespace.priors.update({key: item})
        if hasattr(namespace, "approximant") and namespace.approximant is not None:
            if i not in list(namespace.approximant.keys()):
                if i in list(namespace.existing_approximant.keys()):
                    namespace.approximant[i] = namespace.existing_approximant[i]
        if hasattr(namespace, "psds") and namespace.psds is not None:
            if i not in list(namespace.psds.keys()):
                if i in list(namespace.existing_psd.keys()):
                    namespace.psds[i] = namespace.existing_psd[i]
                else:
                    namespace.psds[i] = {}
        if hasattr(namespace, "calibration") and namespace.calibration is not None:
            if i not in list(namespace.calibration.keys()):
                if i in list(namespace.existing_calibration.keys()):
                    namespace.calibration[i] = namespace.existing_calibration[i]
                else:
                    namespace.calibration[i] = {}
        if hasattr(namespace, "skymap") and namespace.skymap is not None:
            if i not in list(namespace.skymap.keys()):
                if i in list(namespace.existing_skymap.keys()):
                    namespace.skymap[i] = namespace.existing_skymap[i]
                else:
                    namespace.skymap[i] = None
        if hasattr(namespace, "maxL_samples"):
            if i not in list(namespace.maxL_samples.keys()):
                namespace.maxL_samples[i] = {
                    key: val.maxL for key, val in namespace.samples[i].items()
                }
        if hasattr(namespace, "pastro_probs"):
            if i not in list(namespace.pastro_probs.keys()):
                from pesummary.gw.classification import PAstro
                try:
                    namespace.pastro_probs[i] = {"default": PAstro(
                        namespace.existing_samples[i],
                    ).classification()}
                except Exception:
                    namespace.pastro_probs[i] = {"default": PAstro.defaults}
        if hasattr(namespace, "embright_probs"):
            if i not in list(namespace.embright_probs.keys()):
                from pesummary.gw.classification import EMBright
                try:
                    namespace.embright_probs[i] = {"default": EMBright(
                        namespace.existing_samples[i]
                    ).classification()}
                except Exception:
                    namespace.embright_probs[i] = {"default": EMBright.defaults}
    if hasattr(namespace, "result_files"):
        number = len(namespace.labels)
        while len(namespace.result_files) < number:
            namespace.result_files.append(namespace.existing_metafile)
    parameters = [list(namespace.samples[i].keys()) for i in namespace.labels]
    namespace.same_parameters = list(
        set.intersection(*[set(l) for l in parameters])
    )
    namespace.same_samples = {
        param: {
            i: namespace.samples[i][param] for i in namespace.labels
        } for param in namespace.same_parameters
    }
    return namespace


def customwarn(message, category, filename, lineno, file=None, line=None):
    """Custom warning function which writes formatted warnings to stdout
    """
    import sys
    import warnings

    sys.stdout.write(
        warnings.formatwarning("%s" % (message), category, filename, lineno)
    )


def determine_gps_time_and_window(maxL_samples, labels):
    """Determine the gps time and window to use in the spectrogram and
    omegascan plots
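
    Examples
    --------
    Illustrative example with made-up analysis labels and times:

    >>> gps, window = determine_gps_time_and_window(
    ...     {"one": {"geocent_time": 10.0}, "two": {"geocent_time": 12.0}},
    ...     ["one", "two"]
    ... )
    >>> float(gps), float(window)
    (11.0, 4.0)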
    """
    times = [
        maxL_samples[label]["geocent_time"] for label in labels
    ]
    gps_time = np.mean(times)
    time_range = np.max(times) - np.min(times)
    if time_range < 4.:
        window = 4.
    else:
        window = time_range * 1.5
    return gps_time, window


def number_of_columns_for_legend(labels):
    """Determine the number of columns to use in a legend

    Parameters
    ----------
    labels: list
        list of labels in the legend
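
    Examples
    --------
    Illustrative example; two short labels fit in five columns:

    >>> number_of_columns_for_legend(["EXP1", "EXP2"])
    5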
    """
    max_length = np.max([len(i) for i in labels]) + 5.
    if max_length > 50.:
        return 1
    else:
        return int(50. / max_length)


class RedirectLogger(object):
    """Class to redirect the output from other codes to the `pesummary`
    logger

    Parameters
    ----------
    code: str
        name of the code whose output is being redirected
    level: str, optional
        the level to display the messages, e.g. "DEBUG"
    """
    def __init__(self, code, level="DEBUG"):
        self.logger = logging.getLogger('PESummary')
        self.level = getattr(logging, level)
        self._redirector = contextlib.redirect_stdout(self)
        self.code = code

    def isatty(self):
        pass

    def write(self, msg):
        """Write the message to stdout

        Parameters
        ----------
        msg: str
            the message you wish to be printed to stdout
        """
        if msg and not msg.isspace():
            self.logger.log(self.level, "[from %s] %s" % (self.code, msg))

    def flush(self):
        pass

    def __enter__(self):
        self._redirector.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._redirector.__exit__(exc_type, exc_value, traceback)


def draw_conditioned_prior_samples(
    samples_dict, prior_samples_dict, conditioned, xlow, xhigh, N=100,
    nsamples=1000
):
    """Return a prior_dict that is conditioned on certain parameters

    Parameters
    ----------
    samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the posterior samples
    prior_samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the prior samples
    conditioned: list
        list of parameters that you wish to condition your prior on
    xlow: dict
        dictionary of lower bounds for each parameter
    xhigh: dict
        dictionary of upper bounds for each parameter
    N: int, optional
        number of points to use within the grid. Default 100
    nsamples: int, optional
        number of samples to draw. Default 1000
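
    Examples
    --------
    Illustrative call (hypothetical SamplesDict objects `posterior` and
    `prior`, each containing a "mass_ratio" entry):

    >>> conditioned_prior = draw_conditioned_prior_samples(
    ...     posterior, prior, ["mass_ratio"], {"mass_ratio": 0.},
    ...     {"mass_ratio": 1.}
    ... )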
    """
    for param in conditioned:
        indices = _draw_conditioned_prior_samples(
            prior_samples_dict[param], samples_dict[param], xlow[param],
            xhigh[param], xN=N, N=nsamples
        )
        for key, val in prior_samples_dict.items():
            prior_samples_dict[key] = val[indices]
    return prior_samples_dict


def _draw_conditioned_prior_samples(
    prior_samples, posterior_samples, xlow, xhigh, xN=1000, N=1000
):
    """Return a list of indices for the conditioned prior via rejection
    sampling. The conditioned prior will then be `prior_samples[indices]`.
    Code from Michael Puerrer.

    Parameters
    ----------
    prior_samples: np.ndarray
        array of prior samples that you wish to condition
    posterior_samples: np.ndarray
        array of posterior samples that you wish to condition on
    xlow: float
        lower bound for grid to be used
    xhigh: float
        upper bound for grid to be used
    xN: int, optional
        Number of points to use within the grid
    N: int, optional
        Number of samples to generate
    """
    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    prior_KDE = ReflectionBoundedKDE(prior_samples)
    posterior_KDE = ReflectionBoundedKDE(posterior_samples)

    x = np.linspace(xlow, xhigh, xN)
    idx_nz = np.nonzero(posterior_KDE(x))
    pdf_ratio = prior_KDE(x)[idx_nz] / posterior_KDE(x)[idx_nz]
    M = 1.1 / min(pdf_ratio[np.where(pdf_ratio < 1)])

    indices = []
    i = 0
    while i < N:
        x_i = np.random.choice(prior_samples)
        idx_i = np.argmin(np.abs(prior_samples - x_i))
        u = np.random.uniform()
        if u < posterior_KDE(x_i) / (M * prior_KDE(x_i)):
            indices.append(idx_i)
            i += 1
    return indices


def unzip(zip_file, outdir=None, overwrite=False):
    """Extract the data from a gzipped file and save in outdir.

    Parameters
    ----------
    zip_file: str
        path to the file you wish to unzip
    outdir: str, optional
        path to the directory where you wish to save the unzipped file.
        Default None which means that the unzipped file is stored in CACHE_DIR
    overwrite: Bool, optional
        If True, overwrite a file that has the same name
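
    Examples
    --------
    Illustrative call (hypothetical file name); extracts ./samples.dat.gz
    to ./samples.dat, assuming the target file does not already exist:

    >>> unzip("./samples.dat.gz", outdir="./")
    './samples.dat'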
    """
    import gzip
    from pathlib import Path

    f = Path(zip_file)
    file_name = f.stem
    if outdir is None:
        outdir = CACHE_DIR
    out_file = os.path.join(outdir, file_name)
    if os.path.isfile(out_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists. Not overwriting".format(out_file)
        )
    with gzip.open(zip_file, 'rb') as input:
        with open(out_file, 'wb') as output:
            shutil.copyfileobj(input, output)
    return out_file


def iterator(
    iterable, desc=None, logger=None, tqdm=False, total=None, file=None,
    bar_format=None
):
    """Return either a tqdm iterator, if tqdm is installed, or the iterable
    unchanged

    Parameters
    ----------
    iterable: func
        iterable that you wish to iterate over
    desc: str, optional
        description for the tqdm bar
    logger: logging.Logger, optional
        logger that the tqdm progress bar interacts with
    tqdm: Bool, optional
        If True, a tqdm object is used. Otherwise simply returns the iterator.
    total: float, optional
        total length of iterable
    file: str, optional
        path to file that you wish to write the output to
    bar_format: str, optional
        format of the tqdm progress bar
    """
    if tqdm:
        try:
            # import inside the try block so that a missing tqdm simply
            # falls back to returning the plain iterable
            from pesummary.utils.tqdm import tqdm
            FORMAT, DESC = bar_format, desc
            if bar_format is None:
                FORMAT = (
                    '{desc} | {percentage:3.0f}% | {n_fmt}/{total_fmt} | {elapsed}'
                )
            return tqdm(
                iterable, total=total, logger=logger, desc=DESC, file=file,
                bar_format=FORMAT,
            )
        except ImportError:
            return iterable
    else:
        return iterable


def _check_latex_install(force_tex=False):
    from matplotlib import rcParams
    from matplotlib.texmanager import TexManager

    # If the LaTeX executable is not found, disable usetex quickly
    if LATEX is None:
        rcParams["text.usetex"] = False
        return False

    # Otherwise, try and render something
    texmanager = TexManager()
    try:
        texmanager.make_dvi(r"$mass_{1}$", 12)
    except RuntimeError:
        # It failed, disable usetex
        rcParams["text.usetex"] = False
        return False

    # It works! Enable usetex if forced
    if force_tex:
        rcParams["text.usetex"] = True

    return True


def smart_round(parameters, return_latex=False, return_latex_row=False):
    """Round a parameter according to the uncertainty. If more than one
    parameter and uncertainty is passed, each parameter is rounded according
    to the lowest uncertainty

    Parameters
    ----------
    parameters: list/np.ndarray
        list containing the median, upper bound and lower bound for a given
        parameter
    return_latex: Bool, optional
        if True, return as a latex string
    return_latex_row: Bool, optional
        if True, return the rounded data as a single row in latex format

    Examples
    --------
    >>> data = [1.234, 0.2, 0.1]
    >>> smart_round(data)
    [ 1.2  0.2  0.1]
    >>> data = [
    ...     [6.093, 0.059, 0.055],
    ...     [6.104, 0.057, 0.052],
    ...     [6.08, 0.056, 0.052]
    ... ]
    >>> smart_round(data)
    [[ 6.09  0.06  0.06]
     [ 6.1   0.06  0.05]
     [ 6.08  0.06  0.05]]
    >>> smart_round(data, return_latex=True)
    6.09^{+0.06}_{-0.06}
    6.10^{+0.06}_{-0.05}
    6.08^{+0.06}_{-0.05}
    >>> data = [
    ...     [743.25, 43.6, 53.2],
    ...     [8712.5, 21.5, 35.2],
    ...     [196.46, 65.2, 12.5]
    ... ]
    >>> smart_round(data, return_latex_row=True)
    740^{+40}_{-50} & 8710^{+20}_{-40} & 200^{+70}_{-10}
    >>> data = [
    ...     [743.25, 43.6, 53.2],
    ...     [8712.5, 21.5, 35.2],
    ...     [196.46, 65.2, 8.2]
    ... ]
    >>> smart_round(data, return_latex_row=True)
    743^{+44}_{-53} & 8712^{+22}_{-35} & 196^{+65}_{-8}
    """
    rounded = copy.deepcopy(np.atleast_2d(parameters))
    lowest_uncertainty = np.min(np.abs(parameters))
    rounding = int(-1 * np.floor(np.log10(lowest_uncertainty)))
    for num, _ in enumerate(rounded):
        rounded[num] = [np.round(value, rounding) for value in rounded[num]]
    if return_latex or return_latex_row:
        if rounding > 0:
            _format = "%.{}f".format(rounding)
        else:
            _format = "%.f"
        string = "{0}^{{+{0}}}_{{-{0}}}".format(_format)
        latex = [string % (value[0], value[1], value[2]) for value in rounded]
        if return_latex:
            for ll in latex:
                print(ll)
        else:
            print(" & ".join(latex))
        return ""
    elif np.array(parameters).ndim == 1:
        return rounded[0]
    else:
        return rounded


def safe_round(a, decimals=0, **kwargs):
    """Try and round an array to the given number of decimals. If an exception
    is raised, return the original array

    Parameters
    ----------
    a: np.ndarray
        array you wish to round
    decimals: int
        the number of decimals you wish to round to
    **kwargs: dict
        all kwargs are passed to numpy.round
    """
    try:
        return np.round(a, decimals=decimals, **kwargs)
    except Exception:
        return a


def gelman_rubin(samples, decimal=5):
    """Return an approximation to the Gelman-Rubin statistic (see Gelman, A. and
    Rubin, D. B., Statistical Science, Vol 7, No. 4, pp. 457--511 (1992))

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples for a given parameter, one for each chain
    decimal: int
        number of decimal places to keep when rounding

    Examples
    --------
    >>> from pesummary.utils.utils import gelman_rubin
    >>> samples = [[1, 1.5, 1.2, 1.4, 1.6, 1.2], [1.5, 1.3, 1.4, 1.7]]
    >>> gelman_rubin(samples, decimal=5)
    1.2972
    """
    means = [np.mean(data) for data in samples]
    variances = [np.var(data) for data in samples]
    BoverN = np.var(means)
    W = np.mean(variances)
    sigma = W + BoverN
    m = len(samples)
    Vhat = sigma + BoverN / m
    return np.round(Vhat / W, decimal)


def kolmogorov_smirnov_test(samples, decimal=5):
    """Return the KS p value between two sets of samples

    Parameters
    ----------
    samples: 2d list
        2d list containing the two sets of samples that you wish to compare
    decimal: int
        number of decimal places to keep when rounding
    """
    return np.round(stats.ks_2samp(*samples)[1], decimal)


def jensen_shannon_divergence(*args, **kwargs):
    import warnings
    warnings.warn(
        "The jensen_shannon_divergence function has changed its name to "
        "jensen_shannon_divergence_from_samples. jensen_shannon_divergence "
        "may not be supported in future releases. Please update"
    )
    return jensen_shannon_divergence_from_samples(*args, **kwargs)


def jensen_shannon_divergence_from_samples(
    samples, kde=stats.gaussian_kde, decimal=5, base=np.e, **kwargs
):
    """Calculate the JS divergence between two sets of samples

    Parameters
    ----------
    samples: list
        2d list containing the samples drawn from two pdfs
    kde: func
        function to use when calculating the kde of the samples
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
    kwargs: dict
        all kwargs are passed to the kde function
    """
    pdfs = samples_to_kde(samples, kde=kde, **kwargs)
    return jensen_shannon_divergence_from_pdfs(pdfs, decimal=decimal, base=base)


def jensen_shannon_divergence_from_pdfs(pdfs, decimal=5, base=np.e):
    """Calculate the JS divergence between two distributions

    Parameters
    ----------
    pdfs: list
        list of length 2 containing the distributions you wish to compare
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
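
    Examples
    --------
    Illustrative example; identical distributions give zero divergence:

    >>> p = [0.25, 0.25, 0.25, 0.25]
    >>> float(jensen_shannon_divergence_from_pdfs([p, p]))
    0.0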
    """
    if any(np.isnan(_).any() for _ in pdfs):
        return float("nan")
    a, b = pdfs
    # cast to float arrays so the in-place normalisation cannot fail on ints
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a /= a.sum()
    b /= b.sum()
    m = 1. / 2 * (a + b)
    kl_forward = stats.entropy(a, qk=m, base=base)
    kl_backward = stats.entropy(b, qk=m, base=base)
    return np.round(kl_forward / 2. + kl_backward / 2., decimal)


def samples_to_kde(samples, kde=stats.gaussian_kde, **kwargs):
    """Generate a KDE for a set of samples

    Parameters
    ----------
    samples: list
        list containing the samples to create a KDE for. samples can also
        be a 2d list containing samples from multiple analyses.
    kde: func
        function to use when calculating the kde of the samples
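
    Examples
    --------
    Illustrative example; the KDE is evaluated on a 100-point grid spanning
    the samples, so a single analysis returns 100 PDF values:

    >>> import numpy as np
    >>> pdf = samples_to_kde(np.random.normal(size=1000))
    >>> pdf.shape
    (100,)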
    """
    _SINGLE_ANALYSIS = False
    if not isinstance(samples[0], (np.ndarray, list, tuple)):
        _SINGLE_ANALYSIS = True
        _samples = [samples]
    else:
        _samples = samples
    kernel = []
    for i in _samples:
        try:
            kernel.append(kde(i, **kwargs))
        except np.linalg.LinAlgError:
            kernel.append(None)
    x = np.linspace(
        np.min([np.min(i) for i in _samples]),
        np.max([np.max(i) for i in _samples]),
        100
    )
    pdfs = [k(x) if k is not None else float('nan') for k in kernel]
    if _SINGLE_ANALYSIS:
        return pdfs[0]
    return pdfs


def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    make_dir(STYLE_CACHE)
    shutil.copyfile(
        style_file, os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    )


def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    style_file = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    if not os.path.isfile(style_file):
        from pesummary import conf

        return conf.style_file
    return style_file


def get_matplotlib_backend(parallel=False):
    """Return the matplotlib backend to use for the plotting modules

    Parameters
    ----------
    parallel: Bool, optional
        if True, backend is always set to 'Agg' for the multiprocessing module
    """
    try:
        os.environ["DISPLAY"]
    except KeyError:
        try:
            __IPYTHON__
        except NameError:
            DISPLAY = False
        else:
            DISPLAY = True
    else:
        DISPLAY = True
    if DISPLAY and not parallel:
        backend = "TkAgg"
    else:
        backend = "Agg"
    return backend


def _default_filename(default_filename, label=None):
    """Return a default filename

    Parameters
    ----------
    default_filename: str
        the default filename to use if a filename is not provided.
        default_filename must be a formattable string with one empty argument
        for a label
    label: str, optional
        The label of the analysis. This is used in the filename. If no label
        is provided, the current time is used instead
    """
    if not label:
        filename = default_filename.format(round(time.time()))
    else:
        filename = default_filename.format(label)
    return filename


def check_filename(
    default_filename="pesummary_{}.dat", outdir="./", label=None, filename=None,
    overwrite=False, delete_existing=False
):
    """Check to see if a file exists. If no filename is provided, a default
    filename is checked

    Parameters
    ----------
    default_filename: str, optional
        the default filename to use if a filename is not provided.
        default_filename must be a formattable string with one empty argument
        for a label
    outdir: str, optional
        directory to write the dat file
    label: str, optional
        The label of the analysis. This is used in the filename if a filename
        is not specified
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    delete_existing: Bool, optional
        If True, an existing file of the same name will be deleted
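
    Examples
    --------
    Illustrative example (hypothetical label), assuming no file of the same
    name already exists in the current directory:

    >>> check_filename(label="GW150914")
    './pesummary_GW150914.dat'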
    """
    if not filename:
        filename = _default_filename(default_filename, label=label)
    _file = os.path.join(outdir, filename)
    if os.path.isfile(_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists in the directory {}".format(
                filename, outdir
            )
        )
    if os.path.isfile(_file) and delete_existing:
        os.remove(_file)
    return _file


def string_match(string, substring):
    """Return True if a string matches a substring. This substring may include
    wildcards

    Parameters
    ----------
    string: str
        string you wish to match
    substring: str
        string you wish to match against
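
    Examples
    --------
    Illustrative examples; matching is anchored at the start of the string
    unless wildcards are used:

    >>> string_match("mass_1", "mass")
    True
    >>> string_match("total_mass", "mass")
    False
    >>> string_match("total_mass", "*mass*")
    True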
    """
    import re

    try:
        match = re.match(re.compile(substring), string)
        if match:
            return True
        return False
    except re.error:
        # not a valid regex; treat the substring as a Unix shell-style
        # wildcard pattern instead
        import fnmatch
        return string_match(string, fnmatch.translate(substring))


def glob_directory(base):
    """Return a list of files matching base

    Parameters
    ----------
    base: str
        string you wish to match e.g. "./", "./*.py"
    """
    import glob
    if "*" not in base:
        base = os.path.join(base, "*")
    return glob.glob(base)


def list_match(list_to_match, substring, return_true=True, return_false=False):
    """Match a list of strings to a substring. This substring may include
    wildcards

    Parameters
    ----------
    list_to_match: list
        list of strings you wish to match
    substring: str, list
        string you wish to match against or a list of strings you wish to
        match against
    return_true: Bool, optional
        if True, return a sublist containing only the parameters that match
        the substring. Default True
    return_false: Bool, optional
        if True, return a sublist containing only the parameters that do not
        match the substring. Default False
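
    Examples
    --------
    Illustrative example with made-up parameter names:

    >>> list_match(["mass_1", "mass_2", "chirp_mass"], "mass").tolist()
    ['mass_1', 'mass_2']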
    """
    match = np.ones(len(list_to_match), dtype=bool)
    if isinstance(substring, str):
        substring = [substring]
    for _substring in substring:
        match *= np.array(
            [string_match(item, _substring) for item in list_to_match],
            dtype=bool
        )
    if return_false:
        return np.array(list_to_match)[~match]
    elif return_true:
        return np.array(list_to_match)[match]
    return match


class Empty(object):
    """Define an empty class which simply returns the input
    """
    def __new__(cls, *args):
        return args[0]


def history_dictionary(program=None, creator=conf.user, command_line=None):
    """Create a dictionary containing useful information about the origin of
    a PESummary data product

    Parameters
    ----------
    program: str, optional
        program used to generate the PESummary data product
    creator: str, optional
        The user who created the PESummary data product
    command_line: str, optional
        The command line which was run to generate the PESummary data product
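
    Examples
    --------
    Illustrative example; the creation time and creator depend on when and
    where this is run, so only the keys are shown:

    >>> _dict = history_dictionary(program="summarypages")
    >>> sorted(_dict.keys())
    ['command_line', 'creator', 'gps_creation_time', 'program']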
    """
    from astropy.time import Time

    _dict = {
        "gps_creation_time": Time.now().gps,
        "creator": creator,
    }
    if command_line is not None:
        _dict["command_line"] = (
            "Generated by running the following script: {}".format(
                command_line
            )
        )
    else:
        _dict["command_line"] = " ".join(sys.argv)
    if program is not None:
        _dict["program"] = program
    return _dict


def mute_logger():
    """Mute the PESummary logger
    """
    _logger = logging.getLogger('PESummary')
    _logger.setLevel(logging.CRITICAL + 10)
    return


def unmute_logger():
    """Unmute the PESummary logger
    """
    _logger = logging.getLogger('PESummary')
    _logger.setLevel(logging.INFO)
    return


# import error message
import_error_msg = (
    "Unable to import '{}'. You will not be able to use some of the inbuilt "
    "functions."
)

# silence matplotlib warnings
logging.getLogger('matplotlib.font_manager').setLevel(logging.CRITICAL + 10)

# set up the pesummary logger
_, LOG_FILE = setup_logger()
logger = logging.getLogger('PESummary')