Coverage for pesummary/core/file/mcmc.py: 94.3%
35 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import logger
4import numpy as np
5import copy
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]
# Column names that may hold a possible step number for an MCMC chain.
# ``_number_of_negative_steps`` checks them in this order and uses the
# first one present in the samples.
STEP_NUMBER_PARAMS = [
    "cycle",
]
# Names of the burnin functions defined in this module.
algorithms = [
    "burnin_by_step_number",
    "burnin_by_first_n",
]
def _number_of_negative_steps(samples, logger_level="debug"):
    """Return the number of leading samples with step number <= 0 per chain

    For every chain this is the index of the first sample whose step number
    is strictly positive. The samples are only read, so no copy is made.

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    dict
        Mapping of chain name to the number of samples preceding the first
        positive step number. A chain that never has a positive step number
        maps to 0. If none of the columns in ``STEP_NUMBER_PARAMS`` is found,
        every chain maps to None.
    """
    try:
        # assumes ``parameters`` maps chain name -> parameter names; only the
        # parameters shared by every chain are considered
        per_chain = [set(_params) for _params in samples.parameters.values()]
    except AttributeError:
        # ``parameters`` has no ``values``: treat it as one shared list
        parameters = list(samples.parameters)
    else:
        # set.intersection() with no arguments raises TypeError, so guard
        # against an object that holds no chains
        parameters = set.intersection(*per_chain) if per_chain else set()
    step_params = [
        alternative for alternative in STEP_NUMBER_PARAMS
        if alternative in parameters
    ]
    if not step_params:
        logger.warning(
            "Unable to find a step number in the MCMCSamplesDict object. "
            "Aborting discard"
        )
        return {key: None for key in samples.keys()}
    # The first match, in STEP_NUMBER_PARAMS order, is used
    step_param = step_params[0]
    if len(step_params) > 1:
        getattr(logger, logger_level)(
            "Multiple columns found with possible step numbers. Using "
            "{}".format(step_param)
        )
    n_samples = {}
    for key in samples.keys():
        positive = np.flatnonzero(np.asarray(samples[key][step_param]) > 0)
        # Test this chain's own indices, not the number of chains: a chain
        # that never reaches a positive step would otherwise raise
        # IndexError. Fall back to 0 (discard nothing) in that case.
        n_samples[key] = int(positive[0]) if len(positive) else 0
    return n_samples
def burnin_by_step_number(samples, logger_level="debug"):
    """Discard, as burnin, the samples preceding the first positive step

    The per-chain counts come from ``_number_of_negative_steps``. The input is
    deep-copied first and the discard is applied to the copy, so the
    caller's object is not modified.

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    The result of ``discard_samples`` on the deep copy (presumably the
    trimmed MCMCSamplesDict -- confirm against MCMCSamplesDict).
    """
    trimmed = copy.deepcopy(samples)
    to_discard = _number_of_negative_steps(trimmed, logger_level=logger_level)
    # NOTE(review): when no step-number column exists every count is None;
    # confirm discard_samples treats None as "discard nothing".
    log = getattr(logger, logger_level)
    summary = ", ".join(
        "{} samples from {}".format(count, chain)
        for chain, count in to_discard.items()
    )
    log("Removing the first {} as burnin".format(summary))
    return trimmed.discard_samples(to_discard)
def burnin_by_first_n(samples, N, step_number=False, logger_level="debug"):
    """Discard the first N samples of every chain as burnin

    The input is deep-copied first and the discard is applied to the copy,
    so the caller's object is not modified.

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    N: int
        Number of samples to discard as burnin
    step_number: Bool, optional
        If True, N is added to the number of samples preceding the first
        positive step number (from ``_number_of_negative_steps``). Chains for
        which that count is None discard exactly N samples.
        NOTE(review): this equals "discard all samples with step number < N"
        only for particular step-number spacings -- confirm intent.
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    The result of ``discard_samples`` on the deep copy (presumably the
    trimmed MCMCSamplesDict -- confirm against MCMCSamplesDict).
    """
    trimmed = copy.deepcopy(samples)
    if step_number:
        offsets = _number_of_negative_steps(trimmed, logger_level=logger_level)
        to_discard = {
            chain: N if offset is None else offset + N
            for chain, offset in offsets.items()
        }
    else:
        to_discard = dict.fromkeys(trimmed.keys(), N)
    log = getattr(logger, logger_level)
    summary = ", ".join(
        "{} samples from {}".format(count, chain)
        for chain, count in to_discard.items()
    )
    log("Removing the first {} as burnin".format(summary))
    return trimmed.discard_samples(to_discard)