Coverage for pesummary/core/file/mcmc.py: 94.3%

35 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import logger 

4import numpy as np 

5import copy 

6 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Candidate column names that may hold the MCMC step (iteration) number.
# `_number_of_negative_steps` checks them in this order and uses the first one
# found among the samples' parameters.
STEP_NUMBER_PARAMS = ["cycle"]
# Names of the burn-in removal functions defined in this module
# (presumably looked up by name by callers elsewhere -- confirm).
algorithms = ["burnin_by_step_number", "burnin_by_first_n"]

10 

11 

def _number_of_negative_steps(samples, logger_level="debug"):
    """Return, for each chain, the number of leading burn-in samples as
    identified by the chain's step-number column

    For every chain the returned value is the index of the first sample whose
    step number is strictly positive; every sample before it is treated as
    burn-in.

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    dict
        dictionary keyed by chain. If no step-number column (see
        STEP_NUMBER_PARAMS) is found, every value is None. Otherwise each
        value is an int: the index of the chain's first sample with a
        positive step number, or 0 if the chain has no such sample.
    """
    # NOTE(review): samples with step == 0 are counted as burn-in too because
    # the test below is `> 0`; confirm whether `>= 0` was intended.
    # `samples` is only read here, so no defensive deep copy is made.
    try:
        per_chain = [set(params) for params in samples.parameters.values()]
    except AttributeError:
        # `parameters` is a flat sequence rather than a per-chain mapping
        parameters = set(samples.parameters)
    else:
        # Only columns shared by every chain are usable. Guard the empty case:
        # `set.intersection()` with no arguments raises TypeError.
        parameters = set.intersection(*per_chain) if per_chain else set()

    step_param = [
        alternative for alternative in STEP_NUMBER_PARAMS
        if alternative in parameters
    ]
    if not step_param:
        logger.warning(
            "Unable to find a step number in the MCMCSamplesDict object. "
            "Aborting discard"
        )
        return {key: None for key in samples.keys()}
    if len(step_param) > 1:
        getattr(logger, logger_level)(
            "Multiple columns found with possible step numbers. Using "
            "{}".format(step_param[0])
        )
    step_param = step_param[0]

    n_samples = {}
    for key in samples.keys():
        positive = np.flatnonzero(np.asarray(samples[key][step_param]) > 0)
        # Check this chain's own index array (not the number of chains) so a
        # chain without any positive step falls back to 0 instead of raising
        # IndexError.
        n_samples[key] = int(positive[0]) if positive.size else 0
    return n_samples

56 

57 

def burnin_by_step_number(samples, logger_level="debug"):
    """Discard, as burnin, every sample preceding the first sample with a
    positive step number

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    The result of ``discard_samples`` called on a deep copy of ``samples``;
    the object passed in is not modified. If no step-number column can be
    found, 0 samples are discarded from every chain.
    """
    _samples = copy.deepcopy(samples)
    # The helper returns None per chain when it cannot find a step column
    # (after warning "Aborting discard"). Map that to 0 so nothing is removed
    # and the log message reports a number rather than "None".
    n_samples = {
        key: 0 if val is None else val for key, val in
        _number_of_negative_steps(_samples, logger_level=logger_level).items()
    }
    getattr(logger, logger_level)(
        "Removing the first {} as burnin".format(
            ", ".join(
                ["{} samples from {}".format(val, key) for key, val in n_samples.items()]
            )
        )
    )
    return _samples.discard_samples(n_samples)

78 

79 

def burnin_by_first_n(samples, N, step_number=False, logger_level="debug"):
    """Discard the first N samples of every chain as burnin

    Parameters
    ----------
    samples: pesummary.utils.samples_dict.MCMCSamplesDict
        MCMCSamplesDict object containing the samples for multiple mcmc chains
    N: int
        Number of samples to discard as burnin
    step_number: Bool, optional
        If True, N is counted from the end of the burn-in identified by the
        step-number column: the count returned by `_number_of_negative_steps`
        for each chain is added to N. Chains for which no step-number column
        was found discard exactly N samples.
    logger_level: str, optional
        logger level to use when printing information to stdout. Default debug

    Returns
    -------
    The result of ``discard_samples`` called on a deep copy of ``samples``;
    the object passed in is not modified.
    """
    chains = copy.deepcopy(samples)
    if step_number:
        offsets = _number_of_negative_steps(chains, logger_level=logger_level)
        to_discard = {
            chain: N if offset is None else offset + N
            for chain, offset in offsets.items()
        }
    else:
        to_discard = dict.fromkeys(chains.keys(), N)

    summary = ", ".join(
        "{} samples from {}".format(count, chain)
        for chain, count in to_discard.items()
    )
    getattr(logger, logger_level)("Removing the first {} as burnin".format(summary))
    return chains.discard_samples(to_discard)