Coverage for pesummary/cli/summaryclassification.py: 77.2%
101 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1#! /usr/bin/env python
3# Licensed under an MIT style license -- see LICENSE.md
5import os
6import pesummary
7from pesummary.core.cli.inputs import _Input
8from pesummary.gw.file.read import read as GWRead
9from pesummary.gw.classification import PEPredicates, PAstro
10from pesummary.utils.utils import make_dir, logger
11from pesummary.utils.exceptions import InputError
12from pesummary.core.cli.parser import ArgumentParser as _ArgumentParser
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Used as the command line description in `main`
__doc__ = """This executable is used to generate JSON files containing the
source classification probabilities"""
class ArgumentParser(_ArgumentParser):
    """Command line parser for `summaryclassification`

    Extends the pesummary base parser with the ``--prior`` and ``--plot``
    options used by this executable.
    """

    def _pesummary_options(self):
        """Return the parent parser options extended with ``--prior`` and
        ``--plot``
        """
        classification_options = {
            "--prior": {
                "choices": ["population", "default", "both"],
                "default": "both",
                "help": (
                    "Prior to use when calculating source classification "
                    "probabilities"
                ),
            },
            "--plot": {
                "choices": ["bar", "mass_1_mass_2"],
                "default": "bar",
                "help": "Name of the plot you wish to make",
            },
        }
        parent_options = super(ArgumentParser, self)._pesummary_options()
        parent_options.update(classification_options)
        return parent_options
def generate_probabilities(result_files, prior="both", seed=123456789):
    """Generate the classification probabilities

    Parameters
    ----------
    result_files: list
        list of result files
    prior: str, optional
        prior you wish to reweight your samples to. One of 'default',
        'population' or 'both'. Default 'both'
    seed: int, optional
        random seed passed to the PEPredicates and PAstro classifiers.
        Default 123456789

    Returns
    -------
    classifications: list
        list of dictionaries, one per result file. When prior='both' each
        dictionary is keyed by 'default' and 'population'; otherwise it
        holds the probabilities for the requested prior directly

    Raises
    ------
    ValueError
        if `prior` is not one of 'default', 'population' or 'both'
    """
    if prior not in ("default", "population", "both"):
        raise ValueError(
            "Unknown prior '{}'. Please choose from 'default', 'population' "
            "or 'both'".format(prior)
        )
    if prior == "both":
        _func = "dual_classification"
        _kwargs = {}
    else:
        _func = "classification"
        _kwargs = {"population": prior == "population"}

    classifications = []
    for result_file in result_files:
        if _Input.is_pesummary_metafile(result_file):
            # only the first analysis stored in a metafile is classified
            f = GWRead(result_file)
            samples = f.samples_dict[f.labels[0]]
            mydict = getattr(PEPredicates(samples), _func)(
                seed=seed, **_kwargs
            )
            em_bright = getattr(PAstro(samples), _func)(
                seed=seed, **_kwargs
            )
        else:
            mydict = getattr(
                PEPredicates, "{}_from_file".format(_func)
            )(result_file, seed=seed, **_kwargs)
            em_bright = getattr(
                PAstro, "{}_from_file".format(_func)
            )(result_file, seed=seed, **_kwargs)
        # merge the PAstro probabilities into the PEPredicates output
        if prior == "both":
            for key in ("default", "population"):
                mydict[key].update(em_bright[key])
        else:
            mydict.update(em_bright)
        classifications.append(mydict)
    return classifications
def save_classifications(savedir, classifications, labels):
    """Save the classification probabilities for each result file to JSON

    One file, named ``{label}_{prior}_prior_pe_classification.json``, is
    written to `savedir` for every prior stored in each classification.

    Parameters
    ----------
    savedir: str
        directory to save the JSON files to
    classifications: list
        list of dictionaries, one per result file, each mapping a prior
        name to a dictionary of classification probabilities
    labels: list
        list of labels, one per entry in `classifications`, used to name
        the output files

    Raises
    ------
    ValueError
        if fewer labels than classifications are provided
    """
    import json

    # validate up front so that no partial set of files is written
    if len(labels) < len(classifications):
        raise ValueError(
            "Please provide a label for each classification: received {} "
            "labels for {} classifications".format(
                len(labels), len(classifications)
            )
        )
    base_path = os.path.join(savedir, "{}_{}_prior_pe_classification.json")
    for label, classification in zip(labels, classifications):
        for prior, probs in classification.items():
            with open(base_path.format(label, prior), "w") as f:
                json.dump(probs, f)
def _save_bar_plots(probs, label, webdir, prior):
    """Save a bar chart of the classification probabilities for one result
    file, one chart per requested prior
    """
    from pesummary.gw.plots.plot import _classification_plot

    for _prior in ("default", "population"):
        if prior not in (_prior, "both"):
            continue
        # with prior='both' the probabilities are keyed by prior; otherwise
        # they already correspond to the single requested prior
        _probs = probs[_prior] if prior == "both" else probs
        fig = _classification_plot(_probs)
        fig.savefig(
            os.path.join(
                webdir, "{}_{}_pepredicates_bar.png".format(label, _prior)
            )
        )


def _save_mass_plots(result_file, label, webdir, prior):
    """Save the PEPredicates mass_1-mass_2 plot for one result file, one
    plot per requested prior
    """
    f = GWRead(result_file)
    if not isinstance(f, pesummary.gw.file.formats.pesummary.PESummary):
        f.generate_all_posterior_samples()
    if prior in ("default", "both"):
        fig = PEPredicates.plot(
            f.samples, f.parameters, population_prior=False
        )
        fig.savefig(
            os.path.join(webdir, "{}_default_pepredicates.png".format(label))
        )
    if prior in ("population", "both"):
        fig = PEPredicates.plot(f.samples, f.parameters)
        fig.savefig(
            os.path.join(
                webdir, "{}_population_pepredicates.png".format(label)
            )
        )


def make_plots(
    result_files, webdir=None, labels=None, prior=None, plot_type="bar",
    probs=None
):
    """Save the classification plots for each result file

    Parameters
    ----------
    result_files: list
        list of result files
    webdir: str, optional
        path to save the files. Defaults to the current working directory
    labels: list, optional
        list of strings to identify each result file. If not provided, the
        index of each result file is used
    prior: str, optional
        One of 'default', 'population' or 'both'. If 'population' the
        samples are reweighted to a population prior; 'both' makes a plot
        for each prior. Any other value (including None) makes no plots
    plot_type: str, optional
        The plot type that you wish to make: 'bar' or 'mass_1_mass_2'
    probs: list, optional
        list of classification probabilities, one per result file, as
        returned by `generate_probabilities`. Required for bar plots

    Raises
    ------
    ValueError
        if `plot_type` is not recognised, or if a bar plot is requested
        without `probs`
    """
    if plot_type not in ("bar", "mass_1_mass_2"):
        raise ValueError(
            "Unknown plot_type '{}'. Please choose from 'bar' or "
            "'mass_1_mass_2'".format(plot_type)
        )
    if prior not in ("default", "population", "both"):
        # no prior selected so there is nothing to plot
        return
    if plot_type == "bar" and probs is None:
        raise ValueError("'probs' must be provided when plot_type='bar'")
    if webdir is None:
        webdir = "./"

    for num, result_file in enumerate(result_files):
        label = num if labels is None else labels[num]
        if plot_type == "bar":
            # bar charts only need the pre-computed probabilities, so the
            # result file itself is never read
            _save_bar_plots(probs[num], label, webdir, prior)
        else:
            _save_mass_plots(result_file, label, webdir, prior)
def _resolve_labels(samples, labels):
    """Return one label per result file, reading them from the files when
    no labels were provided

    Raises
    ------
    InputError
        if a result file does not store a label and none was provided, or
        if fewer labels than result files were provided
    """
    if labels is None:
        labels = []
        for result_file in samples:
            f = GWRead(result_file)
            if not hasattr(f, "labels"):
                raise InputError("Please provide a label for each result file")
            labels.append(f.labels[0])
    elif len(labels) < len(samples):
        raise InputError("Please provide a label for each result file")
    return labels


def main(args=None):
    """Top level interface for `summaryclassification`

    Parameters
    ----------
    args: list, optional
        command line arguments to parse. If None, the arguments are taken
        from the command line (assumes argparse-style behaviour of the
        pesummary parser -- TODO confirm)
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_known_options_to_parser(
        ["--webdir", "--samples", "--labels", "--prior", "--plot", "--seed"]
    )
    opts, _ = parser.parse_known_args(args=args)
    if opts.webdir:
        make_dir(opts.webdir)
    else:
        logger.warning(
            "No webdir given so plots will not be generated and "
            "classifications will be shown in stdout rather than saved to file"
        )
    # resolve labels before the (potentially expensive) classification so
    # that missing or mismatched labels are reported straight away
    opts.labels = _resolve_labels(opts.samples, opts.labels)
    classifications = generate_probabilities(
        opts.samples, prior=opts.prior, seed=opts.seed
    )
    if not opts.webdir:
        print(classifications)
        return
    # save_classifications expects every classification to be keyed by prior
    if opts.prior != "both":
        _classifications = [{opts.prior: c} for c in classifications]
    else:
        _classifications = classifications
    save_classifications(opts.webdir, _classifications, opts.labels)
    make_plots(
        opts.samples, webdir=opts.webdir, labels=opts.labels, prior=opts.prior,
        plot_type=opts.plot,
        probs=classifications if opts.plot == "bar" else None
    )
# Allow the module to be executed directly as a script
if __name__ == "__main__":
    main()