Coverage for pesummary/cli/summaryclassification.py: 88.6%
79 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import os
6import pesummary
7from pesummary.core.cli.inputs import _Input
8from pesummary.gw.file.read import read as GWRead
9from pesummary.gw.classification import EMBright, PAstro
10from pesummary.utils.utils import make_dir, logger
11from pesummary.utils.exceptions import InputError
12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# The module docstring doubles as the command line `description` (see main)
__doc__ = """This executable is used to compute the source classification
probabilities for each result file and save them to JSON files"""
class ArgumentParser(_ArgumentParser):
    """Command line parser for the `summaryclassification` executable.

    Extends the core pesummary parser with the options needed to compute
    and plot source classification probabilities.
    """

    def _pesummary_options(self):
        """Return the parent's options extended with classification options

        Returns
        -------
        options: dict
            mapping of command line flag to a dictionary describing that
            flag, with the classification-specific flags added
        """
        classification_options = {
            "--plot": {
                "choices": ["bar"],
                "default": "bar",
                "help": "Name of the plot you wish to make",
            },
            "--pastro_category_file": {
                "default": None,
                "help": (
                    "path to yml file containing summary data for each "
                    "category (BBH, BNS, NSBH). This includes e.g. rates, "
                    "mass bounds etc. This is used when computing PAstro"
                ),
            },
            "--terrestrial_probability": {
                "default": None,
                "help": (
                    "Terrestrial probability for the candidate you are "
                    "analysing. This is used when computing PAstro"
                ),
            },
            "--catch_terrestrial_probability_error": {
                "default": False,
                "action": "store_true",
                "help": (
                    "Catch the ValueError raised when no terrestrial "
                    "probability is provided when computing PAstro"
                ),
                "key": "gw",
            },
        }
        options = super()._pesummary_options()
        options.update(classification_options)
        return options
def generate_probabilities(
    result_files, classification_file, terrestrial_probability,
    catch_terrestrial_probability_error
):
    """Generate the classification probabilities for each result file

    For every result file the EMBright classification is computed and then
    updated with the PAstro classification.

    Parameters
    ----------
    result_files: list
        list of result files
    classification_file: str
        path to a yml file containing summary data for each category.
        Passed to PAstro as `category_data`
    terrestrial_probability: str/float
        terrestrial probability for the candidate, passed straight to PAstro
        (may be None; presumably a string when read from the command line --
        confirm PAstro's handling)
    catch_terrestrial_probability_error: bool
        passed straight to PAstro

    Returns
    -------
    classifications: list
        list of dictionaries, one per result file, containing the EMBright
        probabilities updated with the PAstro probabilities
    """
    classifications = []
    for result_file in result_files:
        if _Input.is_pesummary_metafile(result_file):
            f = GWRead(result_file)
            # Only the first analysis stored in the metafile is classified
            label = f.labels[0]
            probabilities = EMBright(f.samples_dict[label]).classification()
            pastro = PAstro(
                f.samples_dict[label],
                category_data=classification_file,
                terrestrial_probability=terrestrial_probability,
                catch_terrestrial_probability_error=(
                    catch_terrestrial_probability_error
                ),
            ).classification()
        else:
            probabilities = EMBright.classification_from_file(result_file)
            pastro = PAstro.classification_from_file(
                result_file,
                category_data=classification_file,
                terrestrial_probability=terrestrial_probability,
                catch_terrestrial_probability_error=(
                    catch_terrestrial_probability_error
                ),
            )
        # PAstro probabilities take precedence for any shared keys
        probabilities.update(pastro)
        classifications.append(probabilities)
    return classifications
def save_classifications(savedir, classifications, labels):
    """Save the classification probabilities for each result file to JSON

    One file named ``{label}_pe_classification.json`` is written to
    `savedir` for every entry in `classifications`.

    Parameters
    ----------
    savedir: str
        directory to save the JSON files
    classifications: list
        list of dictionaries of classification probabilities, one per
        result file
    labels: list
        list of labels used to name each file. Must contain at least as many
        entries as `classifications`
    """
    import json

    for num, probabilities in enumerate(classifications):
        # Format only the filename so braces in `savedir` are left untouched
        filename = "{}_pe_classification.json".format(labels[num])
        # Write each dictionary exactly once, even when it is empty
        with open(os.path.join(savedir, filename), "w") as f:
            json.dump(probabilities, f)
def make_plots(
    result_files, webdir=None, labels=None, plot_type="bar",
    probs=None
):
    """Save plots of the classification probabilities

    Parameters
    ----------
    result_files: list
        list of result files. Only used to determine how many plots to make
        (one per result file)
    webdir: str, optional
        path to save the files. Default "./"
    labels: list, optional
        list of strings to identify each result file. If None, the index of
        each result file is used
    plot_type: str, optional
        The plot type that you wish to make. Only "bar" is supported
    probs: list, optional
        list of dictionaries of classification probabilities, one per result
        file (as returned by `generate_probabilities`). Required for a bar
        plot

    Raises
    ------
    ValueError
        if `plot_type` is not recognised, or if `probs` is not provided
        when there are plots to make
    """
    if plot_type != "bar":
        raise ValueError(f"Unknown plot type: {plot_type}")
    if webdir is None:
        webdir = "./"
    if not result_files:
        return
    if probs is None:
        raise ValueError(
            "Classification probabilities must be provided to make a bar plot"
        )
    from pesummary.gw.plots.plot import _classification_plot

    for num, _ in enumerate(result_files):
        label = num if labels is None else labels[num]
        fig = _classification_plot(probs[num])
        fig.savefig(
            os.path.join(webdir, "{}_pastro_bar.png".format(label))
        )
def main(args=None):
    """Top level interface for `summaryclassification`

    Parameters
    ----------
    args: list, optional
        command line arguments, passed to `parse_known_args`
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        [
            "--webdir", "--samples", "--labels", "--plot",
            "--pastro_category_file", "--terrestrial_probability",
            "--catch_terrestrial_probability_error"
        ]
    )
    opts, _ = parser.parse_known_args(args=args)
    if opts.webdir:
        make_dir(opts.webdir)
    else:
        logger.warning(
            "No webdir given so plots will not be generated and "
            "classifications will be shown in stdout rather than saved to file"
        )
    classifications = generate_probabilities(
        opts.samples, opts.pastro_category_file, opts.terrestrial_probability,
        opts.catch_terrestrial_probability_error
    )
    if not opts.webdir:
        # Labels are only needed to name output files, so stop here
        print(classifications)
        return
    if opts.labels is None:
        # Fall back to the first label stored in each result file
        opts.labels = []
        for i in opts.samples:
            f = GWRead(i)
            if hasattr(f, "labels"):
                opts.labels.append(f.labels[0])
            else:
                raise InputError("Please provide a label for each result file")
    if len(opts.labels) != len(opts.samples):
        raise InputError("Please provide a label for each result file")
    save_classifications(opts.webdir, classifications, opts.labels)
    probs = classifications if opts.plot == "bar" else None
    make_plots(
        opts.samples, webdir=opts.webdir, labels=opts.labels,
        plot_type=opts.plot, probs=probs
    )
# Run the command line interface when this module is executed as a script
if __name__ == "__main__":
    main()