Coverage for pesummary/utils/utils.py: 59.4%
498 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import sys
5import logging
6import contextlib
7import time
8import copy
9import shutil
11import numpy as np
12from scipy.integrate import cumulative_trapezoid as cumtrapz
13from scipy.interpolate import interp1d
14from scipy import stats
15import h5py
16from pesummary import conf
18__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Prefer colourised log output when the optional ``coloredlogs`` package is
# available; otherwise fall back to the plain stdlib formatter.
try:
    from coloredlogs import ColoredFormatter as LogFormatter
except ImportError:
    LogFormatter = logging.Formatter
# Base cache directory for pesummary, honouring $XDG_CACHE_HOME and
# defaulting to ~/.cache/pesummary.
CACHE_DIR = os.path.join(
    os.getenv(
        "XDG_CACHE_HOME",
        os.path.expanduser(os.path.join("~", ".cache")),
    ),
    "pesummary",
)
# Subdirectory caching a user supplied matplotlib style file
STYLE_CACHE = os.path.join(CACHE_DIR, "style")
# Subdirectory storing pesummary log files
LOG_CACHE = os.path.join(CACHE_DIR, "log")
def resample_posterior_distribution(posterior, nsamples):
    """Randomly draw nsamples from the posterior distribution

    Parameters
    ----------
    posterior: ndlist
        nd list of posterior samples. If you only want to resample one
        posterior distribution then posterior=[[1., 2., 3., 4.]]. For multiple
        posterior distributions then posterior=[[1., 2., 3., 4.], [1., 2., 3.]]
    nsamples: int
        number of samples that you wish to randomly draw from the distribution
    """
    if len(posterior) == 1:
        # build an empirical inverse CDF from a 50-bin histogram and draw
        # from it via inverse-transform sampling
        counts, edges = np.histogram(posterior, bins=50)
        weights = np.concatenate(([0], counts))
        cdf = cumtrapz(weights, edges, initial=0)
        cdf = cdf / cdf[-1]
        inverse_cdf = interp1d(cdf, edges)
        return inverse_cdf(np.random.rand(nsamples))
    # multiple distributions: keep the same randomly chosen rows for each
    posterior = np.array(list(posterior))
    keep = np.random.choice(len(posterior[0]), nsamples, replace=False)
    return [dist[keep] for dist in posterior]
def check_file_exists_and_rename(file_name):
    """Check to see if a file exists and if it does then rename the file

    The existing file is moved to ``<file_name>_old`` (with further ``_old``
    suffixes appended until an unused name is found) so that the requested
    name is free for new data.

    Parameters
    ----------
    file_name: str
        proposed file name to store data
    """
    if os.path.isfile(file_name):
        # ``shutil`` is already imported at module level; the previous local
        # re-import was redundant
        old_file = "{}_old".format(file_name)
        while os.path.isfile(old_file):
            old_file += "_old"
        logger.warning(
            "The file '{}' already exists. Renaming the existing file to "
            "{} and saving the data to the requested file name".format(
                file_name, old_file
            )
        )
        shutil.move(file_name, old_file)
def check_condition(condition, error_message):
    """Raise an exception if the condition is not satisfied
    """
    if not condition:
        return
    raise Exception(error_message)
def rename_group_or_dataset_in_hf5_file(base_file, group=None, dataset=None):
    """Rename a group or dataset in an hdf5 file

    Parameters
    ----------
    base_file: str
        path to the hdf5 file you wish to modify
    group: list, optional
        a list containing the path to the group that you would like to change
        as the first argument and the new name of the group as the second
        argument
    dataset: list, optional
        a list containing the name of the dataset that you would like to change
        as the first argument and the new name of the dataset as the second
        argument
    """
    condition = not os.path.isfile(base_file)
    check_condition(condition, "The file %s does not exist" % (base_file))
    # context manager guarantees the handle is closed even if the copy or
    # delete below raises (the previous explicit close leaked on error)
    with h5py.File(base_file, "a") as f:
        if group:
            f[group[1]] = f[group[0]]
            del f[group[0]]
        elif dataset:
            f[dataset[1]] = f[dataset[0]]
            del f[dataset[0]]
def make_dir(path):
    """Create the directory ``path`` (and any parents) if it does not exist"""
    # exist_ok makes the pre-check unnecessary
    os.makedirs(os.path.expanduser(path), exist_ok=True)
def guess_url(web_dir, host, user):
    """Guess the base url from the host name

    Parameters
    ----------
    web_dir: str
        path to the web directory where you want the data to be saved
    host: str
        the host name of the machine where the python interpreter is currently
        executing
    user: str
        the user that is current executing the python interpreter
    """
    # Known LIGO Data Grid hosts checked in order; the first entry whose
    # keyword appears in ``host`` wins.
    known_hosts = [
        (("raven", "arcca"), "https://geo2.arcca.cf.ac.uk/~{}"),
        (("ligo-wa",), "https://ldas-jobs.ligo-wa.caltech.edu/~{}"),
        (("ligo-la",), "https://ldas-jobs.ligo-la.caltech.edu/~{}"),
        (("cit", "caltech"), "https://ldas-jobs.ligo.caltech.edu/~{}"),
        (("uwm", "nemo"), "https://ldas-jobs.phys.uwm.edu/~{}"),
        (("phy.syr.edu",), "https://sugar-jobs.phy.syr.edu/~{}"),
        (("vulcan",), "https://galahad.aei.mpg.de/~{}"),
        (("atlas",), "https://atlas1.atlas.aei.uni-hannover.de/~{}"),
        (("iucaa",), "https://ldas-jobs.gw.iucaa.in/~{}"),
        (("alice",), "https://dumpty.alice.icts.res.in/~{}"),
        (("hawk",), "https://ligo.gravity.cf.ac.uk/~{}"),
    ]
    if "public_html" not in web_dir:
        # not on the LIGO Data Grid: serve directly from the web directory
        return "https://{}".format(web_dir)
    path = web_dir.split("public_html")[1]
    for keywords, template in known_hosts:
        if any(keyword in host for keyword in keywords):
            return template.format(user) + path
    return "https://{}/~{}".format(host, user) + path
def map_parameter_names(dictionary, mapping):
    """Modify keys in dictionary to use different names according to a map

    Parameters
    ----------
    dictionary: dict
        dict object whose keys should be renamed
    mapping: dict
        dictionary mapping existing keys to new names.

    Returns
    -------
    standard_dict: dict
        dict object with new parameter names
    """
    # keys absent from the mapping are kept unchanged
    return {mapping.get(key, key): item for key, item in dictionary.items()}
def command_line_arguments():
    """Return the command line arguments passed to the current process"""
    # everything after the program name
    return sys.argv[1:]
def command_line_dict():
    """Return a dictionary of command line arguments
    """
    # imported lazily — presumably to avoid import cycles / startup cost;
    # confirm before moving to module level
    from pesummary.gw.cli.parser import ArgumentParser
    parser = ArgumentParser()
    parser.add_all_known_options_to_parser()
    opts = parser.parse_args()
    # vars() converts the argparse Namespace into a plain dict
    return vars(opts)
def gw_results_file(opts):
    """Determine if a GW results file is passed

    Parameters
    ----------
    opts: argparse.Namespace
        namespace containing the command line arguments

    Returns
    -------
    bool
        True if any GW specific option is set to a truthy, non-default value
    """
    from pesummary.gw.cli.parser import ArgumentParser

    attrs, defaults = ArgumentParser().gw_options

    def _is_set(attr, default):
        # evaluate getattr once (previously evaluated three times per option)
        value = getattr(opts, attr, None)
        return bool(value) and value != default

    # return the boolean directly instead of `if cond: return True`
    return any(_is_set(attr, default) for attr, default in zip(attrs, defaults))
def functions(opts, gw=False):
    """Return a dictionary of functions that are either specific to GW results
    files or core.

    Parameters
    ----------
    opts: argparse.Namespace
        namespace containing the command line arguments
    gw: Bool, optional
        if True, always return the GW specific classes

    Returns
    -------
    dict
        dictionary mapping "input", "MetaFile" and "FinishingTouches" to the
        appropriate core or GW classes
    """
    from pesummary.core.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as Input
    )
    from pesummary.gw.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as GWInput
    )
    from pesummary.core.file.meta_file import MetaFile
    from pesummary.gw.file.meta_file import GWMetaFile
    from pesummary.core.finish import FinishingTouches
    from pesummary.gw.finish import GWFinishingTouches

    # evaluate once instead of three times: gw_results_file inspects every
    # GW option on ``opts``
    use_gw = gw_results_file(opts) or gw
    return {
        "input": GWInput if use_gw else Input,
        "MetaFile": GWMetaFile if use_gw else MetaFile,
        "FinishingTouches": GWFinishingTouches if use_gw else FinishingTouches,
    }
250def _logger_format():
251 return '%(asctime)s %(name)s %(levelname)-8s: %(message)s'
def setup_logger():
    """Set up the logger output.

    Creates a temporary directory inside the log cache, then attaches a
    console handler (INFO, or DEBUG when ``-v``/``--verbose`` is on the
    command line) and a DEBUG-level file handler to the 'PESummary' logger.

    Returns
    -------
    tuple
        the configured ``logging.Logger`` and the path to the log file
    """
    import tempfile

    def get_console_handler(stream_level="INFO"):
        # stream handler at the requested level; FORMATTER is closed over
        # from the enclosing scope (assigned below, before this is called)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level=getattr(logging, stream_level))
        console_handler.setFormatter(FORMATTER)
        return console_handler

    def get_file_handler(log_file):
        # file handler always records full DEBUG output
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setLevel(level=logging.DEBUG)
        file_handler.setFormatter(FORMATTER)
        return file_handler

    make_dir(LOG_CACHE)
    # unique per-run directory so parallel runs do not clobber each other's logs
    dirpath = tempfile.mkdtemp(dir=LOG_CACHE)
    stream_level = 'INFO'
    if "-v" in sys.argv or "--verbose" in sys.argv:
        stream_level = 'DEBUG'

    FORMATTER = LogFormatter(_logger_format(), datefmt='%Y-%m-%d %H:%M:%S')
    LOG_FILE = '%s/pesummary.log' % (dirpath)
    logger = logging.getLogger('PESummary')
    # do not bubble records up to the root logger
    logger.propagate = False
    logger.setLevel(level=logging.DEBUG)
    logger.addHandler(get_console_handler(stream_level=stream_level))
    logger.addHandler(get_file_handler(LOG_FILE))
    return logger, LOG_FILE
def remove_tmp_directories():
    """Remove the temporary directories created by PESummary
    """
    import shutil
    from glob import glob

    # remove everything below .tmp/pesummary, directories and files alike
    for entry in glob(".tmp/pesummary/*"):
        if os.path.isdir(entry):
            shutil.rmtree(entry)
        elif os.path.isfile(entry):
            os.remove(entry)
def _add_existing_data(namespace):
    """Add existing data to namespace object

    For every label found in ``namespace.existing_labels``, merge the
    corresponding "existing_*" entries into the current attributes
    (samples, weights, injections, psds, calibration, ...), only filling in
    labels that are not already present. ``namespace`` is modified in place
    and also returned.
    """
    for num, i in enumerate(namespace.existing_labels):
        # -- per-label bookkeeping: only add data for labels we do not
        # -- already have
        if hasattr(namespace, "labels") and i not in namespace.labels:
            namespace.labels.append(i)
        if hasattr(namespace, "samples") and i not in list(namespace.samples.keys()):
            namespace.samples[i] = namespace.existing_samples[i]
        if hasattr(namespace, "weights") and i not in list(namespace.weights.keys()):
            if namespace.existing_weights is None:
                namespace.weights[i] = None
            else:
                namespace.weights[i] = namespace.existing_weights[i]
        if hasattr(namespace, "injection_data"):
            if i not in list(namespace.injection_data.keys()):
                namespace.injection_data[i] = namespace.existing_injection_data[i]
        if hasattr(namespace, "file_versions"):
            if i not in list(namespace.file_versions.keys()):
                namespace.file_versions[i] = namespace.existing_file_version[i]
        if hasattr(namespace, "file_kwargs"):
            if i not in list(namespace.file_kwargs.keys()):
                namespace.file_kwargs[i] = namespace.existing_file_kwargs[i]
        if hasattr(namespace, "config"):
            # keep one config entry per existing analysis (duplicates of a
            # None placeholder are still appended)
            if namespace.existing_config[num] not in namespace.config:
                namespace.config.append(namespace.existing_config[num])
            elif namespace.existing_config[num] is None:
                namespace.config.append(None)
        if hasattr(namespace, "priors"):
            # NOTE(review): this merge is label-independent and therefore runs
            # once per existing label; the inner "not already present" checks
            # make the repeats no-ops — confirm this is intentional
            if hasattr(namespace, "existing_priors"):
                for key, item in namespace.existing_priors.items():
                    if key in namespace.priors.keys():
                        for label in item.keys():
                            if label not in namespace.priors[key].keys():
                                namespace.priors[key][label] = item[label]
                    else:
                        namespace.priors.update({key: item})
        if hasattr(namespace, "approximant") and namespace.approximant is not None:
            if i not in list(namespace.approximant.keys()):
                if i in list(namespace.existing_approximant.keys()):
                    namespace.approximant[i] = namespace.existing_approximant[i]
        if hasattr(namespace, "psds") and namespace.psds is not None:
            if i not in list(namespace.psds.keys()):
                if i in list(namespace.existing_psd.keys()):
                    namespace.psds[i] = namespace.existing_psd[i]
                else:
                    namespace.psds[i] = {}
        if hasattr(namespace, "calibration") and namespace.calibration is not None:
            if i not in list(namespace.calibration.keys()):
                if i in list(namespace.existing_calibration.keys()):
                    namespace.calibration[i] = namespace.existing_calibration[i]
                else:
                    namespace.calibration[i] = {}
        if hasattr(namespace, "skymap") and namespace.skymap is not None:
            if i not in list(namespace.skymap.keys()):
                if i in list(namespace.existing_skymap.keys()):
                    namespace.skymap[i] = namespace.existing_skymap[i]
                else:
                    namespace.skymap[i] = None
        if hasattr(namespace, "maxL_samples"):
            if i not in list(namespace.maxL_samples.keys()):
                # pick the maximum-likelihood value of each parameter
                namespace.maxL_samples[i] = {
                    key: val.maxL for key, val in namespace.samples[i].items()
                }
        if hasattr(namespace, "pepredicates_probs"):
            if i not in list(namespace.pepredicates_probs.keys()):
                from pesummary.gw.classification import PEPredicates
                # classification is best-effort: any failure stores None
                try:
                    namespace.pepredicates_probs[i] = PEPredicates(
                        namespace.existing_samples[i]
                    ).dual_classification()
                except Exception:
                    namespace.pepredicates_probs[i] = None
        if hasattr(namespace, "pastro_probs"):
            if i not in list(namespace.pastro_probs.keys()):
                from pesummary.gw.classification import PAstro
                try:
                    namespace.pastro_probs[i] = PAstro(
                        namespace.existing_samples[i]
                    ).dual_classification()
                except Exception:
                    namespace.pastro_probs[i] = None
        if hasattr(namespace, "result_files"):
            # pad result_files with the existing metafile until there is one
            # entry per label
            number = len(namespace.labels)
            while len(namespace.result_files) < number:
                namespace.result_files.append(namespace.existing_metafile)
    # parameters common to every analysis, and the corresponding samples
    parameters = [list(namespace.samples[i].keys()) for i in namespace.labels]
    namespace.same_parameters = list(
        set.intersection(*[set(l) for l in parameters])
    )
    namespace.same_samples = {
        param: {
            i: namespace.samples[i][param] for i in namespace.labels
        } for param in namespace.same_parameters
    }
    return namespace
def customwarn(message, category, filename, lineno, file=None, line=None):
    """Write a warning to stdout

    The signature matches ``warnings.showwarning`` so this can be installed
    as a drop-in replacement.
    """
    import warnings

    # ``sys`` is imported at module level; the previous local re-import was
    # redundant
    sys.stdout.write(
        warnings.formatwarning("%s" % (message), category, filename, lineno)
    )
def determine_gps_time_and_window(maxL_samples, labels):
    """Determine the gps time and window to use in the spectrogram and
    omegascan plots

    Parameters
    ----------
    maxL_samples: dict
        dictionary keyed by label containing a "geocent_time" entry
    labels: list
        labels of the analyses to include
    """
    times = [maxL_samples[label]["geocent_time"] for label in labels]
    gps_time = np.mean(times)
    spread = np.max(times) - np.min(times)
    # never zoom in closer than a 4 second window
    window = 4. if spread < 4. else spread * 1.5
    return gps_time, window
def number_of_columns_for_legend(labels):
    """Determine the number of columns to use in a legend

    Parameters
    ----------
    labels: list
        list of labels in the legend
    """
    # pad the longest label by 5 characters, then fit as many columns as
    # possible into a 50 character budget
    longest = np.max([len(label) for label in labels]) + 5.
    return 1 if longest > 50. else int(50. / longest)
class RedirectLogger(object):
    """Class to redirect the output from other codes to the `pesummary`
    logger

    Parameters
    ----------
    code: str
        name of the code whose output is being captured; included in each
        log message
    level: str, optional
        the level to display the messages. Case-insensitive, e.g. "DEBUG" or
        "Info". Default "DEBUG"
    """
    def __init__(self, code, level="DEBUG"):
        # The previous default "Debug" raised AttributeError because the
        # logging module only exposes upper-case level names; normalise the
        # name so mixed-case values keep working.
        self.logger = logging.getLogger('PESummary')
        self.level = getattr(logging, level.upper())
        self._redirector = contextlib.redirect_stdout(self)
        self.code = code

    def isatty(self):
        # file-like API: we are never a terminal
        pass

    def write(self, msg):
        """Write the message to stdout

        Parameters
        ----------
        msg: str
            the message you wish to be printed to stdout
        """
        # skip empty / whitespace-only writes (e.g. bare newlines)
        if msg and not msg.isspace():
            self.logger.log(self.level, "[from %s] %s" % (self.code, msg))

    def flush(self):
        # file-like API: nothing buffered
        pass

    def __enter__(self):
        self._redirector.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._redirector.__exit__(exc_type, exc_value, traceback)
def draw_conditioned_prior_samples(
    samples_dict, prior_samples_dict, conditioned, xlow, xhigh, N=100,
    nsamples=1000
):
    """Return a prior_dict that is conditioned on certain parameters

    Parameters
    ----------
    samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the posterior samples
    prior_samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the prior samples
    conditioned: list
        list of parameters that you wish to condition your prior on
    xlow: dict
        dictionary of lower bounds for each parameter
    xhigh: dict
        dictionary of upper bounds for each parameter
    N: int, optional
        number of points to use within the grid. Default 100
    nsamples: int, optional
        number of samples to draw. Default 1000
    """
    for param in conditioned:
        # rejection-sample row indices for this parameter...
        indices = _draw_conditioned_prior_samples(
            prior_samples_dict[param], samples_dict[param], xlow[param],
            xhigh[param], xN=N, N=nsamples
        )
        # ...then down-select every prior array with the same indices so the
        # dictionary stays row-aligned.
        # NOTE(review): prior_samples_dict is modified in place as well as
        # returned
        for key, val in prior_samples_dict.items():
            prior_samples_dict[key] = val[indices]
    return prior_samples_dict
def _draw_conditioned_prior_samples(
    prior_samples, posterior_samples, xlow, xhigh, xN=1000, N=1000
):
    """Return a list of indices for the conditioned prior via rejection
    sampling. The conditioned prior will then be `prior_samples[indices]`.
    Code from Michael Puerrer.

    Parameters
    ----------
    prior_samples: np.ndarray
        array of prior samples that you wish to condition
    posterior_samples: np.ndarray
        array of posterior samples that you wish to condition on
    xlow: float
        lower bound for grid to be used
    xhigh: float
        upper bound for grid to be used
    xN: int, optional
        Number of points to use within the grid
    N: int, optional
        Number of samples to generate
    """
    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    prior_kde = ReflectionBoundedKDE(prior_samples)
    posterior_kde = ReflectionBoundedKDE(posterior_samples)

    # estimate the rejection-sampling bound M from the prior/posterior
    # density ratio evaluated on a grid (ignoring zero-posterior points)
    grid = np.linspace(xlow, xhigh, xN)
    nonzero = np.nonzero(posterior_kde(grid))
    ratio = prior_kde(grid)[nonzero] / posterior_kde(grid)[nonzero]
    M = 1.1 / min(ratio[np.where(ratio < 1)])

    indices = []
    accepted = 0
    while accepted < N:
        candidate = np.random.choice(prior_samples)
        candidate_idx = np.argmin(np.abs(prior_samples - candidate))
        u = np.random.uniform()
        if u < posterior_kde(candidate) / (M * prior_kde(candidate)):
            indices.append(candidate_idx)
            accepted += 1
    return indices
def unzip(zip_file, outdir=None, overwrite=False):
    """Extract the data from a gzipped file and save in outdir.

    Parameters
    ----------
    zip_file: str
        path to the file you wish to unzip
    outdir: str, optional
        path to the directory where you wish to save the unzipped file. Default
        None which means that the unzipped file is stored in CACHE_DIR
    overwrite: Bool, optional
        If True, overwrite a file that has the same name

    Returns
    -------
    str
        path to the unzipped file

    Raises
    ------
    FileExistsError
        if the output file already exists and ``overwrite`` is False
    """
    import gzip
    from pathlib import Path

    # output name is the input name minus the final (".gz") suffix
    file_name = Path(zip_file).stem
    if outdir is None:
        outdir = CACHE_DIR
    out_file = os.path.join(outdir, file_name)
    if os.path.isfile(out_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists. Not overwriting".format(out_file)
        )
    # do not shadow the ``input`` builtin (previous variable names);
    # ``shutil`` is already imported at module level
    with gzip.open(zip_file, 'rb') as source, open(out_file, 'wb') as target:
        shutil.copyfileobj(source, target)
    return out_file
def iterator(
    iterable, desc=None, logger=None, tqdm=False, total=None, file=None,
    bar_format=None
):
    """Return either a tqdm iterator, if tqdm installed, or the iterable itself

    Parameters
    ----------
    iterable: func
        iterable that you wish to iterate over
    desc: str, optional
        description for the tqdm bar
    logger: logging.Logger, optional
        logger the tqdm progress bar interacts with
    tqdm: Bool, optional
        If True, a tqdm object is used. Otherwise simply returns the iterator.
    total: float, optional
        total length of iterable
    file: str, optional
        path to file that you wish to write the output to
    bar_format: str, optional
        format of the tqdm progress bar. Default shows the description,
        percentage, counts and elapsed time
    """
    if not tqdm:
        return iterable
    try:
        # import under an alias: importing as ``tqdm`` previously clobbered
        # the ``tqdm`` boolean argument, making the flag useless
        from pesummary.utils.tqdm import tqdm as _tqdm
    except ImportError:
        return iterable
    if bar_format is None:
        # previously a user-supplied ``bar_format`` was silently ignored
        bar_format = (
            '{desc} | {percentage:3.0f}% | {n_fmt}/{total_fmt} | {elapsed}'
        )
    return _tqdm(
        iterable, total=total, logger=logger, desc=desc, file=file,
        bar_format=bar_format,
    )
def _check_latex_install(force_tex=False):
    """Enable matplotlib LaTeX rendering only when a working install exists

    Parameters
    ----------
    force_tex: Bool, optional
        if True, enable LaTeX rendering whenever a working LaTeX install is
        found, regardless of the current rcParams setting
    """
    from matplotlib import rcParams

    original = rcParams['text.usetex']
    # shutil.which replaces distutils.spawn.find_executable: distutils was
    # removed from the standard library in Python 3.12
    if shutil.which("latex") is not None:
        try:
            from matplotlib.texmanager import TexManager

            # compile a small snippet to confirm the install actually works
            texmanager = TexManager()
            texmanager.make_dvi(r"$mass_{1}$", 12)
            if force_tex:
                original = True
            rcParams["text.usetex"] = original
        except RuntimeError:
            rcParams["text.usetex"] = False
    else:
        rcParams["text.usetex"] = False
def smart_round(parameters, return_latex=False, return_latex_row=False):
    """Round a parameter according to the uncertainty. If more than one parameter
    and uncertainty is passed, each parameter is rounded according to the
    lowest uncertainty

    Parameters
    ----------
    parameters: list/np.ndarray
        list containing the median, upper bound and lower bound for a given
        parameter, or a 2d list of such rows
    return_latex: Bool, optional
        if True, print each row as a latex string and return ""
    return_latex_row: Bool, optional
        if True, print the rounded data as a single latex table row and
        return ""

    Examples
    --------
    >>> data = [1.234, 0.2, 0.1]
    >>> smart_round(data)
    [ 1.2  0.2  0.1]
    >>> data = [
    ...     [6.093, 0.059, 0.055],
    ...     [6.104, 0.057, 0.052],
    ...     [6.08, 0.056, 0.052]
    ... ]
    >>> smart_round(data)
    [[ 6.09  0.06  0.06]
     [ 6.1   0.06  0.05]
     [ 6.08  0.06  0.05]]
    >>> smart_round(data, return_latex=True)
    6.09^{+0.06}_{-0.06}
    6.10^{+0.06}_{-0.05}
    6.08^{+0.06}_{-0.05}
    >>> data = [
    ...     [743.25, 43.6, 53.2],
    ...     [8712.5, 21.5, 35.2],
    ...     [196.46, 65.2, 12.5]
    ... ]
    >>> smart_round(data, return_latex_row=True)
    740^{+40}_{-50} & 8710^{+20}_{-40} & 200^{+70}_{-10}
    """
    data = copy.deepcopy(np.atleast_2d(parameters))
    # precision implied by the smallest-magnitude entry across all rows
    smallest = np.min(np.abs(parameters))
    ndigits = int(-1 * np.floor(np.log10(smallest)))
    for row, values in enumerate(data):
        data[row] = [np.round(entry, ndigits) for entry in values]
    if return_latex or return_latex_row:
        fmt = "%.{}f".format(ndigits) if ndigits > 0 else "%.f"
        template = "{0}^{{+{0}}}_{{-{0}}}".format(fmt)
        rows = [template % (value[0], value[1], value[2]) for value in data]
        if return_latex:
            for entry in rows:
                print(entry)
        else:
            print(" & ".join(rows))
        return ""
    if np.array(parameters).ndim == 1:
        return data[0]
    return data
def safe_round(a, decimals=0, **kwargs):
    """Try and round an array to the given number of decimals. If an exception
    is raised, return the original array

    Parameters
    ----------
    a: np.ndarray
        array you wish to round
    decimals: int
        the number of decimals you wish to round too
    **kwargs: dict
        all kwargs are passed to numpy.round
    """
    try:
        result = np.round(a, decimals=decimals, **kwargs)
    except Exception:
        # unroundable input (e.g. non-numeric) is returned untouched
        result = a
    return result
def gelman_rubin(samples, decimal=5):
    """Return an approximation to the Gelman-Rubin statistic (see Gelman, A. and
    Rubin, D. B., Statistical Science, Vol 7, No. 4, pp. 457--511 (1992))

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples for a given parameter, one for each chain
    decimal: int
        number of decimal places to keep when rounding

    Examples
    --------
    >>> from pesummary.utils.utils import gelman_rubin
    >>> samples = [[1, 1.5, 1.2, 1.4, 1.6, 1.2], [1.5, 1.3, 1.4, 1.7]]
    >>> gelman_rubin(samples, decimal=5)
    1.2972
    """
    chain_means = [np.mean(chain) for chain in samples]
    chain_variances = [np.var(chain) for chain in samples]
    between = np.var(chain_means)       # between-chain variance of the means
    within = np.mean(chain_variances)   # mean within-chain variance
    nchains = len(samples)
    # pooled variance estimate, inflated by the between-chain term
    Vhat = within + between + between / nchains
    return np.round(Vhat / within, decimal)
def kolmogorov_smirnov_test(samples, decimal=5):
    """Return the KS p value between two PDFs

    Parameters
    ----------
    samples: 2d list
        2d list containing the 2 PDFs that you wish to compare
    decimal: int
        number of decimal places to keep when rounding
    """
    first, second = samples
    pvalue = stats.ks_2samp(first, second)[1]
    return np.round(pvalue, decimal)
def jensen_shannon_divergence(*args, **kwargs):
    """Deprecated alias of ``jensen_shannon_divergence_from_samples``"""
    import warnings

    warnings.warn(
        "The jensen_shannon_divergence function has changed its name to "
        "jensen_shannon_divergence_from_samples. jensen_shannon_divergence "
        "may not be supported in future releases. Please update"
    )
    return jensen_shannon_divergence_from_samples(*args, **kwargs)
def jensen_shannon_divergence_from_samples(
    samples, kde=stats.gaussian_kde, decimal=5, base=np.e, **kwargs
):
    """Calculate the JS divergence between two sets of samples

    Parameters
    ----------
    samples: list
        2d list containing the samples drawn from two pdfs
    kde: func
        function to use when calculating the kde of the samples
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
    kwargs: dict
        all kwargs are passed to the kde function
    """
    # evaluate both KDEs on a common grid, then compare the resulting pdfs
    densities = samples_to_kde(samples, kde=kde, **kwargs)
    return jensen_shannon_divergence_from_pdfs(
        densities, decimal=decimal, base=base
    )
def jensen_shannon_divergence_from_pdfs(pdfs, decimal=5, base=np.e):
    """Calculate the JS divergence between two distributions

    Parameters
    ----------
    pdfs: list
        list of length 2 containing the distributions you wish to compare
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e

    Returns
    -------
    float
        the JS divergence, or nan when either input contains nans
    """
    if any(np.isnan(_).any() for _ in pdfs):
        return float("nan")
    a, b = pdfs
    # Normalise copies of the inputs. The previous in-place ``/=`` mutated
    # the caller's arrays and raised TypeError for integer input.
    a = np.asarray(a, dtype=float)
    a = a / a.sum()
    b = np.asarray(b, dtype=float)
    b = b / b.sum()
    m = 1. / 2 * (a + b)
    kl_forward = stats.entropy(a, qk=m, base=base)
    kl_backward = stats.entropy(b, qk=m, base=base)
    return np.round(kl_forward / 2. + kl_backward / 2., decimal)
def samples_to_kde(samples, kde=stats.gaussian_kde, **kwargs):
    """Generate KDE for a set of samples

    Parameters
    ----------
    samples: list
        list containing the samples to create a KDE for. samples can also
        be a 2d list containing samples from multiple analyses.
    kde: func
        function to use when calculating the kde of the samples
    """
    # normalise to a list of analyses, remembering whether one was supplied
    single = not isinstance(samples[0], (np.ndarray, list, tuple))
    analyses = [samples] if single else samples
    kernels = []
    for data in analyses:
        try:
            kernels.append(kde(data, **kwargs))
        except np.linalg.LinAlgError:
            # degenerate samples (e.g. zero variance): mark with None
            kernels.append(None)
    # evaluate every kernel on a shared 100-point grid spanning all analyses
    grid = np.linspace(
        np.min([np.min(data) for data in analyses]),
        np.max([np.max(data) for data in analyses]),
        100
    )
    pdfs = [
        kernel(grid) if kernel is not None else float('nan')
        for kernel in kernels
    ]
    return pdfs[0] if single else pdfs
def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    make_dir(STYLE_CACHE)
    cached_copy = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    shutil.copyfile(style_file, cached_copy)
def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use

    Returns the cached user style file when one exists, otherwise the default
    style file shipped with pesummary.
    """
    style_file = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    if not os.path.isfile(style_file):
        # fall back to the packaged default; ``conf`` is already imported at
        # module level so the previous local re-import was redundant
        return conf.style_file
    # previously wrapped in a no-op ``os.path.join(style_file)``
    return style_file
def get_matplotlib_backend(parallel=False):
    """Return the matplotlib backend to use for the plotting modules

    Parameters
    ----------
    parallel: Bool, optional
        if True, backend is always set to 'Agg' for the multiprocessing module
    """
    display = "DISPLAY" in os.environ
    if not display:
        # an interactive IPython session can still render without $DISPLAY
        try:
            __IPYTHON__
        except NameError:
            display = False
        else:
            display = True
    return "TKAgg" if display and not parallel else "Agg"
930def _default_filename(default_filename, label=None):
931 """Return a default filename
933 Parameters
934 ----------
935 default_filename: str, optional
936 the default filename to use if a filename is not provided. default_filename
937 must be a formattable string with one empty argument for a label
938 label: str, optional
939 The label of the analysis. This is used in the filename
940 """
941 if not label:
942 filename = default_filename.format(round(time.time()))
943 else:
944 filename = default_filename.format(label)
945 return filename
def check_filename(
    default_filename="pesummary_{}.dat", outdir="./", label=None, filename=None,
    overwrite=False, delete_existing=False
):
    """Check to see if a file exists. If no filename is provided, a default
    filename is checked

    Parameters
    ----------
    default_filename: str, optional
        the default filename to use if a filename is not provided. default_filename
        must be a formattable string with one empty argument for a label
    outdir: str, optional
        directory to write the dat file
    label: str, optional
        The label of the analysis. This is used in the filename if a filename
        if not specified
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    delete_existing: Bool, optional
        If True, an existing file of the same name is removed before returning
    """
    if not filename:
        filename = _default_filename(default_filename, label=label)
    path = os.path.join(outdir, filename)
    exists = os.path.isfile(path)
    if exists and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists in the directory {}".format(
                filename, outdir
            )
        )
    if exists and delete_existing:
        os.remove(path)
    return path
def string_match(string, substring):
    """Return True if a string matches a substring. This substring may include
    wildcards

    Parameters
    ----------
    string: str
        string you wish to match
    substring: str
        string you wish to match against
    """
    import re

    try:
        # ``re.error`` is the public name for the exception previously caught
        # as ``sre_constants.error`` (sre_constants is deprecated)
        return re.match(re.compile(substring), string) is not None
    except re.error:
        # not a valid regex: translate shell-style wildcards and retry
        import fnmatch
        return string_match(string, fnmatch.translate(substring))
def glob_directory(base):
    """Return a list of files matching base

    Parameters
    ----------
    base: str
        string you wish to match e.g. "./", "./*.py"
    """
    from glob import glob

    # a bare directory means "everything inside it"
    pattern = base if "*" in base else os.path.join(base, "*")
    return glob(pattern)
def list_match(list_to_match, substring, return_true=True, return_false=False):
    """Match a list of strings to a substring. This substring may include
    wildcards

    Parameters
    ----------
    list_to_match: list
        list of string you wish to match
    substring: str, list
        string you wish to match against or a list of string you wish to match
        against
    return_true: Bool, optional
        if True, return a sublist containing only the parameters that match the
        substring. Default True
    return_false: Bool, optional
        if True, return a sublist containing only the parameters that do NOT
        match every substring. Default False
    """
    patterns = [substring] if isinstance(substring, str) else substring
    # an entry matches only if it matches every supplied pattern
    match = np.ones(len(list_to_match), dtype=bool)
    for pattern in patterns:
        match &= np.array(
            [string_match(item, pattern) for item in list_to_match],
            dtype=bool
        )
    items = np.array(list_to_match)
    if return_false:
        return items[~match]
    if return_true:
        return items[match]
    return match
class Empty(object):
    """Define an empty class which simply returns the input
    """
    def __new__(cls, *args):
        # the first parameter of __new__ is the class itself; name it ``cls``
        # per convention (was misleadingly named ``self``)
        return args[0]
def history_dictionary(program=None, creator=conf.user, command_line=None):
    """Create a dictionary containing useful information about the origin of
    a PESummary data product

    Parameters
    ----------
    program: str, optional
        program used to generate the PESummary data product
    creator: str, optional
        The user who created the PESummary data product
    command_line: str, optional
        The command line which was run to generate the PESummary data product
    """
    from astropy.time import Time

    history = {
        "gps_creation_time": Time.now().gps,
        "creator": creator,
    }
    if command_line is None:
        # fall back to the command line of the current process
        history["command_line"] = " ".join(sys.argv)
    else:
        history["command_line"] = (
            "Generated by running the following script: {}".format(
                command_line
            )
        )
    if program is not None:
        history["program"] = program
    return history
def mute_logger():
    """Mute the PESummary logger
    """
    # a level above CRITICAL suppresses every standard logging call
    logging.getLogger('PESummary').setLevel(logging.CRITICAL + 10)
def unmute_logger():
    """Unmute the PESummary logger
    """
    # restore the default INFO level
    logging.getLogger('PESummary').setLevel(logging.INFO)
# Configure logging at import time and expose the log-file path plus a
# module-level logger for the rest of the package.
_, LOG_FILE = setup_logger()
logger = logging.getLogger('PESummary')