Coverage for pesummary/cli/summaryclassification.py: 88.6%

79 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5import os 

6import pesummary 

7from pesummary.core.cli.inputs import _Input 

8from pesummary.gw.file.read import read as GWRead 

9from pesummary.gw.classification import EMBright, PAstro 

10from pesummary.utils.utils import make_dir, logger 

11from pesummary.utils.exceptions import InputError 

12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser 

13 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# The module docstring doubles as the command line description: main() passes
# it to ArgumentParser(description=__doc__).
__doc__ = """This executable is used to generate json files containing the
source classification probabilities"""

17 

18 

class ArgumentParser(_ArgumentParser):
    """Command line parser for the ``summaryclassification`` executable

    Extends the base pesummary parser with the options used when computing
    and plotting the source classification probabilities.
    """

    def _pesummary_options(self):
        """Return the available command line options

        Returns
        -------
        dict
            the options provided by the base parser, updated with the
            ``--plot``, ``--pastro_category_file``,
            ``--terrestrial_probability`` and
            ``--catch_terrestrial_probability_error`` options. Each value is
            a dictionary describing how that option is constructed
        """
        plot = {
            "choices": ["bar"],
            "default": "bar",
            "help": "Name of the plot you wish to make",
        }
        category_file = {
            "default": None,
            "help": (
                "path to yml file containing summary data for each "
                "category (BBH, BNS, NSBH). This includes e.g. rates, "
                "mass bounds etc. This is used when computing PAstro"
            ),
        }
        terrestrial = {
            "default": None,
            "help": (
                "Terrestrial probability for the candidate you are "
                "analysing. This is used when computing PAstro"
            ),
        }
        catch_error = {
            "default": False,
            "action": "store_true",
            "help": (
                "Catch the ValueError raised when no terrestrial "
                "probability is provided when computing PAstro"
            ),
            "key": "gw",
        }

        options = super()._pesummary_options()
        options["--plot"] = plot
        options["--pastro_category_file"] = category_file
        options["--terrestrial_probability"] = terrestrial
        options["--catch_terrestrial_probability_error"] = catch_error
        return options

56 

57 

def generate_probabilities(
    result_files, classification_file, terrestrial_probability,
    catch_terrestrial_probability_error
):
    """Generate the classification probabilities for each result file

    Parameters
    ----------
    result_files: list
        list of result files
    classification_file: str
        path to the yml file containing summary data for each category.
        Forwarded to PAstro as ``category_data``
    terrestrial_probability:
        terrestrial probability for the candidate. Forwarded to PAstro
    catch_terrestrial_probability_error: bool
        forwarded to PAstro; controls whether the error raised when no
        terrestrial probability is provided is caught

    Returns
    -------
    list
        one dictionary per result file: the EMBright classification
        updated with the PAstro classification (PAstro values win on any
        shared key)
    """
    pastro_kwargs = {
        "category_data": classification_file,
        "terrestrial_probability": terrestrial_probability,
        "catch_terrestrial_probability_error": (
            catch_terrestrial_probability_error
        ),
    }
    classifications = []
    for result_file in result_files:
        if not _Input.is_pesummary_metafile(result_file):
            em_bright = EMBright.classification_from_file(result_file)
            pastro = PAstro.classification_from_file(
                result_file, **pastro_kwargs
            )
        else:
            f = GWRead(result_file)
            # NOTE(review): only the first analysis stored in the metafile
            # is classified -- confirm this is intended for files holding
            # several analyses
            label = f.labels[0]
            em_bright = EMBright(f.samples_dict[label]).classification()
            pastro = PAstro(
                f.samples_dict[label], **pastro_kwargs
            ).classification()
        em_bright.update(pastro)
        classifications.append(em_bright)
    return classifications

102 

103 

def save_classifications(savedir, classifications, labels):
    """Save each classification dictionary to its own json file

    A file named ``{label}_pe_classification.json`` is written to
    ``savedir`` for every entry in ``classifications``.

    Parameters
    ----------
    savedir: str
        directory in which to save the json files
    classifications: list
        list of dictionaries of classification probabilities, one per
        result file
    labels: list
        list of labels; ``labels[num]`` names the file for
        ``classifications[num]``

    Raises
    ------
    IndexError
        if fewer labels than classifications are provided
    """
    import json

    base_path = os.path.join(savedir, "{}_pe_classification.json")
    for num, classification in enumerate(classifications):
        # exactly one write per classification, including empty dictionaries
        with open(base_path.format(labels[num]), "w") as f:
            json.dump(classification, f)

123 

124 

def make_plots(
    result_files, webdir=None, labels=None, plot_type="bar",
    probs=None
):
    """Save the classification plots for each result file

    Parameters
    ----------
    result_files: list
        list of result files. Only their number and order are used: the bar
        plot is made entirely from ``probs``
    webdir: str, optional
        path to save the files. Default "./"
    labels: list, optional
        list of strings to identify each result file. If not provided, the
        index of each result file is used
    plot_type: str, optional
        The plot type that you wish to make. Only "bar" is supported
    probs: list, optional
        list of dictionaries of classification probabilities, one per result
        file. Required when there is at least one result file

    Raises
    ------
    ValueError
        if ``plot_type`` is not supported, or ``probs`` is not provided
    """
    # validate up front so a bad plot type fails before any work is done
    if plot_type != "bar":
        raise ValueError(f"Unknown plot type: {plot_type}")
    if result_files and probs is None:
        raise ValueError(
            "Classification probabilities must be provided to make a bar plot"
        )
    if webdir is None:
        webdir = "./"

    from pesummary.gw.plots.plot import _classification_plot

    for num, _ in enumerate(result_files):
        label = num if labels is None else labels[num]
        fig = _classification_plot(probs[num])
        fig.savefig(os.path.join(webdir, "{}_pastro_bar.png".format(label)))

166 

167 

def main(args=None):
    """Top level interface for `summaryclassification`

    Computes the classification probabilities for each result file passed
    with ``--samples``. If ``--webdir`` is given, the probabilities are saved
    to json files and plotted in that directory; otherwise they are printed
    to stdout.

    Parameters
    ----------
    args: list, optional
        command line arguments, passed to ``parse_known_args``. Default None

    Raises
    ------
    InputError
        if a label cannot be determined for every result file, or the number
        of labels does not match the number of result files
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        [
            "--webdir", "--samples", "--labels", "--plot",
            "--pastro_category_file", "--terrestrial_probability",
            "--catch_terrestrial_probability_error"
        ]
    )
    opts, _ = parser.parse_known_args(args=args)
    if opts.webdir:
        make_dir(opts.webdir)
    else:
        logger.warning(
            "No webdir given so plots will not be generated and "
            "classifications will be shown in stdout rather than saved to file"
        )
    classifications = generate_probabilities(
        opts.samples, opts.pastro_category_file, opts.terrestrial_probability,
        opts.catch_terrestrial_probability_error
    )
    if not opts.webdir:
        print(classifications)
        return

    # labels only name the output files, so resolve them after the
    # stdout-only case has returned
    labels = opts.labels
    if labels is None:
        labels = []
        for result_file in opts.samples:
            f = GWRead(result_file)
            if not hasattr(f, "labels"):
                raise InputError("Please provide a label for each result file")
            labels.append(f.labels[0])
    elif len(labels) != len(opts.samples):
        raise InputError("Please provide a label for each result file")

    save_classifications(opts.webdir, classifications, labels)
    make_plots(
        opts.samples, webdir=opts.webdir, labels=labels,
        plot_type=opts.plot, probs=classifications
    )

212 

213 

# Run the command line interface when this module is executed as a script
if __name__ == "__main__":
    main()