Coverage for pesummary/cli/summaryclassification.py: 77.2%

101 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1#! /usr/bin/env python 

2 

3# Licensed under an MIT style license -- see LICENSE.md 

4 

5import os 

6import pesummary 

7from pesummary.core.cli.inputs import _Input 

8from pesummary.gw.file.read import read as GWRead 

9from pesummary.gw.classification import PEPredicates, PAstro 

10from pesummary.utils.utils import make_dir, logger 

11from pesummary.utils.exceptions import InputError 

12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser 

13 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Also used as the description of the command line parser in `main`
__doc__ = """This executable is used to generate JSON files containing the
source classification probabilities"""

17 

18 

class ArgumentParser(_ArgumentParser):
    """Command line parser for the `summaryclassification` executable.

    Adds the ``--prior`` and ``--plot`` options on top of the options
    provided by the base pesummary parser.
    """
    def _pesummary_options(self):
        """Return the base pesummary options extended with ``--prior`` and
        ``--plot``
        """
        extended = super()._pesummary_options()
        # Which prior(s) the source classification probabilities use
        extended["--prior"] = {
            "choices": ["population", "default", "both"],
            "default": "both",
            "help": (
                "Prior to use when calculating source classification "
                "probabilities"
            ),
        }
        # Which kind of plot `make_plots` should produce
        extended["--plot"] = {
            "choices": ["bar", "mass_1_mass_2"],
            "default": "bar",
            "help": "Name of the plot you wish to make",
        }
        return extended

40 

41 

def generate_probabilities(result_files, prior="both", seed=123456789):
    """Generate the classification probabilities

    Parameters
    ----------
    result_files: list
        list of paths to result files
    prior: str, optional
        prior you wish to reweight your samples to. 'both' computes the
        probabilities for both the 'default' and 'population' priors,
        'population' uses the population prior and any other value uses
        the default prior. Default 'both'
    seed: int, optional
        seed forwarded to the PEPredicates and PAstro classification
        methods. Default 123456789

    Returns
    -------
    classifications: list
        one dictionary per result file containing the PEPredicates
        probabilities updated with the PAstro probabilities. When
        prior='both' each dictionary is keyed by 'default' and 'population'

    Notes
    -----
    For PESummary metafiles only the first analysis (``f.labels[0]``) is
    classified.
    """
    if prior == "both":
        method = "dual_classification"
        kwargs = {}
    else:
        method = "classification"
        kwargs = {"population": prior == "population"}
    from_file_method = "{}_from_file".format(method)

    classifications = []
    for result_file in result_files:
        if not _Input.is_pesummary_metafile(result_file):
            probs = getattr(PEPredicates, from_file_method)(
                result_file, seed=seed, **kwargs
            )
            em_bright = getattr(PAstro, from_file_method)(
                result_file, seed=seed, **kwargs
            )
        else:
            f = GWRead(result_file)
            label = f.labels[0]
            probs = getattr(PEPredicates(f.samples_dict[label]), method)(
                seed=seed, **kwargs
            )
            em_bright = getattr(PAstro(f.samples_dict[label]), method)(
                seed=seed, **kwargs
            )
        if prior == "both":
            # Both results are indexed by prior name; merge the PAstro
            # probabilities into the matching PEPredicates entry
            for key in ("default", "population"):
                probs[key].update(em_bright[key])
        else:
            probs.update(em_bright)
        classifications.append(probs)
    return classifications

85 

86 

def save_classifications(savedir, classifications, labels):
    """Save the classification probabilities to JSON files

    One file is written for every (result file, prior) pair, named
    ``{label}_{prior}_prior_pe_classification.json`` inside ``savedir``.

    Parameters
    ----------
    savedir: str
        existing directory to write the JSON files to
    classifications: list
        one dictionary per result file, mapping a prior name (e.g. 'default'
        or 'population') to the classification probabilities computed with
        that prior
    labels: list
        labels used to name the output files; labels[i] corresponds to
        classifications[i]
    """
    import json

    for num, probs_per_prior in enumerate(classifications):
        for prior, probs in probs_per_prior.items():
            # Only the filename is passed through str.format so that braces
            # in ``savedir`` are not mistaken for format placeholders
            filename = "{}_{}_prior_pe_classification.json".format(
                labels[num], prior
            )
            with open(os.path.join(savedir, filename), "w") as f:
                json.dump(probs, f)

106 

107 

def make_plots(
    result_files, webdir=None, labels=None, prior=None, plot_type="bar",
    probs=None
):
    """Save the classification plots for each result file

    Parameters
    ----------
    result_files: list
        list of result files
    webdir: str, optional
        directory to save the plots to. Default './'
    labels: list, optional
        list of strings used to name the plots of each result file. If None,
        the index of the result file is used
    prior: str, optional
        'default', 'population' or 'both'. Plots are made for the requested
        prior(s); any other value (including the default None) produces no
        plots
    plot_type: str, optional
        'bar' for a bar chart of the classification probabilities or
        'mass_1_mass_2' for the PEPredicates mass plot. Default 'bar'
    probs: list, optional
        classification probabilities for each result file, as returned by
        `generate_probabilities` for the same prior. Required when
        plot_type='bar'
    """
    if webdir is None:
        webdir = "./"

    # Prior names to plot, in the order the plots are written
    priors = [
        name for name in ("default", "population")
        if prior == name or prior == "both"
    ]

    if plot_type == "bar":
        # Import once, and only when bar plots are requested
        from pesummary.gw.plots.plot import _classification_plot

    # NOTE(review): figures are never closed after saving; if these are
    # pyplot-managed figures they accumulate across iterations -- consider
    # closing each one once it has been saved
    for num, result_file in enumerate(result_files):
        label = num if labels is None else labels[num]
        if plot_type == "bar":
            for name in priors:
                # With prior='both' each entry is keyed by prior name;
                # otherwise it already holds the single requested prior
                _probs = probs[num][name] if prior == "both" else probs[num]
                fig = _classification_plot(_probs)
                fig.savefig(
                    os.path.join(
                        webdir,
                        "{}_{}_pepredicates_bar.png".format(label, name)
                    )
                )
        elif plot_type == "mass_1_mass_2":
            # Only the mass plots need the posterior samples, so the result
            # file is read here rather than for every plot type
            f = GWRead(result_file)
            if not isinstance(
                f, pesummary.gw.file.formats.pesummary.PESummary
            ):
                f.generate_all_posterior_samples()
            for name in priors:
                if name == "default":
                    fig = PEPredicates.plot(
                        f.samples, f.parameters, population_prior=False
                    )
                else:
                    fig = PEPredicates.plot(f.samples, f.parameters)
                fig.savefig(
                    os.path.join(
                        webdir, "{}_{}_pepredicates.png".format(label, name)
                    )
                )

180 

181 

def main(args=None):
    """Top level interface for `summaryclassification`

    Parameters
    ----------
    args: list, optional
        command line arguments passed to ``parser.parse_known_args``.
        Unrecognised arguments are ignored
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        ["--webdir", "--samples", "--labels", "--prior", "--plot", "--seed"]
    )
    opts, _ = parser.parse_known_args(args=args)
    if opts.webdir:
        make_dir(opts.webdir)
    else:
        logger.warning(
            "No webdir given so plots will not be generated and "
            "classifications will be shown in stdout rather than saved to file"
        )
    classifications = generate_probabilities(
        opts.samples, prior=opts.prior, seed=opts.seed
    )
    if not opts.webdir:
        print(classifications)
        return

    # Labels are only needed to name the output files, so only infer them
    # (which re-reads every result file) once files are going to be written
    if opts.labels is None:
        opts.labels = []
        for result_file in opts.samples:
            f = GWRead(result_file)
            if not hasattr(f, "labels"):
                raise InputError("Please provide a label for each result file")
            opts.labels.append(f.labels[0])

    if opts.prior != "both":
        # save_classifications expects every entry to be keyed by prior name
        _classifications = [{opts.prior: c} for c in classifications]
    else:
        _classifications = classifications
    save_classifications(opts.webdir, _classifications, opts.labels)

    # Only the bar plot is drawn from the precomputed probabilities
    probs = classifications if opts.plot == "bar" else None
    make_plots(
        opts.samples, webdir=opts.webdir, labels=opts.labels, prior=opts.prior,
        plot_type=opts.plot, probs=probs
    )

225 

226 

if __name__ == "__main__":
    # Run the command line interface only when executed as a script
    main()