Coverage for pesummary/tests/classification_test.py: 22.2%
90 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# License under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import numpy as np
6from ligo.em_bright.em_bright import source_classification_pe
7from .base import make_result_file
8from pesummary.io import read
9from pesummary.utils.decorators import no_latex_plot
10from pesummary.gw.classification import Classify, PEPredicates, PAstro
12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class _Base:
    """Base testing class

    Shared fixtures for the classification tests: ``setup_method`` writes a
    fresh set of result files into ``.outdir`` before every test and
    ``teardown_method`` removes that directory afterwards.
    """
    def setup_method(self):
        """Setup the class

        Creates ``.outdir`` (if needed) and populates it with:
        ``test.dat`` (from ``make_result_file``), ``test_converted.dat`` and
        ``test_lalinference.hdf5`` (both written from the samples read from
        ``test.dat`` after regenerating the source-frame masses).
        """
        os.makedirs(".outdir", exist_ok=True)
        make_result_file(gw=True, extension="dat", outdir=".outdir/")
        f = read(".outdir/test.dat")
        # regenerate the mass_1_source, mass_2_source posteriors because these
        # are randomly chosen and do not correspond to
        # mass_1 / (1. + z) and mass_2 / (1. + z)
        f.generate_all_posterior_samples(
            regenerate=["mass_1_source", "mass_2_source"]
        )
        f.write(
            filename="test_converted.dat", file_format="dat", outdir=".outdir",
            overwrite=True
        )
        f.write(
            filename="test_lalinference.hdf5", file_format="hdf5",
            outdir=".outdir", overwrite=True
        )

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(".outdir"):
            shutil.rmtree(".outdir")

    @no_latex_plot
    def test_plotting(self, cls=PEPredicates):
        """Test the .plot method

        Every entry in ``cls(samples).available_plots`` must produce a
        ``matplotlib.figure.Figure``.

        Parameters
        ----------
        cls: class, optional
            classification class to test. Default PEPredicates
        """
        import matplotlib.figure
        import matplotlib.pyplot as plt
        samples = read(".outdir/test_converted.dat").samples_dict
        _cls = cls(samples)
        for plot in _cls.available_plots:
            fig = _cls.plot(type=plot)
            assert isinstance(fig, matplotlib.figure.Figure)
            # close each figure once checked so that looping over every plot
            # type (in every subclass) does not accumulate open figures
            plt.close(fig)
class TestPEPredicates(_Base):
    """Test the pesummary.gw.classification.PEPredicates class and
    pesummary.gw.classification.Classify class
    """
    def test_classification(self):
        """Check that every pesummary classification route reproduces the
        probabilities returned by pepredicates.predicate_table

        The four routes compared are: ``PEPredicates(...).classification``,
        ``PEPredicates.classification_from_file``,
        ``Classify.classification_from_file`` and the ``"default"`` entry of
        ``Classify.dual_classification_from_file``.
        """
        from pepredicates import (
            predicate_table, BNS_p, NSBH_p, BBH_p, MG_p
        )
        from pandas import DataFrame

        converted = ".outdir/test_converted.dat"
        samples = read(converted).samples_dict
        candidates = [
            PEPredicates(samples).classification(),
            PEPredicates.classification_from_file(converted),
            Classify.classification_from_file(converted),
            Classify.dual_classification_from_file(converted)["default"],
        ]

        # reference values computed directly with pepredicates on the
        # source-frame component masses
        source_masses = DataFrame.from_dict({
            "m1_source": samples["mass_1_source"],
            "m2_source": samples["mass_2_source"],
        })
        predicates = {
            "BNS": BNS_p, "BBH": BBH_p, "MassGap": MG_p, "NSBH": NSBH_p
        }
        expected = predicate_table(predicates, source_masses)

        for key in expected.keys():
            for result in candidates:
                np.testing.assert_almost_equal(expected[key], result[key])
class TestPAstro(_Base):
    """Test the pesummary.gw.classification.PAstro class and
    pesummary.gw.classification.Classify class
    """
    def test_classification(self):
        """Test the base classification method agrees with the
        ligo.em_bright.source_classification_pe function

        Also checks that ``PAstro.classification_from_file``,
        ``Classify.classification_from_file`` and the ``"default"`` entry of
        ``Classify.dual_classification_from_file`` agree with
        ``PAstro(samples).classification()`` to 5 decimal places.
        """
        p_astro = source_classification_pe(
            ".outdir/test_lalinference.hdf5"
        )
        samples = read(".outdir/test_converted.dat").samples_dict
        pesummary = PAstro(samples).classification()
        np.testing.assert_almost_equal(p_astro[0], pesummary["HasNS"], 5)
        np.testing.assert_almost_equal(p_astro[1], pesummary["HasRemnant"], 5)
        pesummary2 = PAstro.classification_from_file(
            ".outdir/test_converted.dat"
        )
        pesummary3 = Classify.classification_from_file(
            ".outdir/test_converted.dat"
        )
        pesummary4 = Classify.dual_classification_from_file(
            ".outdir/test_converted.dat"
        )["default"]
        for key, val in pesummary2.items():
            np.testing.assert_almost_equal(pesummary[key], val, 5)
            np.testing.assert_almost_equal(pesummary[key], pesummary3[key], 5)
            np.testing.assert_almost_equal(pesummary[key], pesummary4[key], 5)

    def test_reweight_classification(self):
        """Test that the population reweighted classification method agrees
        with the ligo.em_bright.source_classification_pe function.

        The samples are reweighted with
        ``pepredicates.rewt_approx_massdist_redshift``. If the reweighted set
        is too small for em_bright, the input files are regenerated via
        ``setup_method`` and the reweighting is retried. There is a bounded
        number of attempts, so the test fails instead of hanging.
        """
        from pepredicates import rewt_approx_massdist_redshift
        from pesummary.gw.conversions import mchirp_from_m1_m2, q_from_m1_m2
        from pandas import DataFrame

        max_attempts = 100
        np.random.seed(1234)
        for _ in range(max_attempts):
            samples = read(".outdir/test_converted.dat").samples_dict
            df = DataFrame.from_dict(
                dict(
                    m1_source=samples["mass_1_source"],
                    m2_source=samples["mass_2_source"],
                    a1=samples["a_1"],
                    a2=samples["a_2"],
                    dist=samples["luminosity_distance"],
                    redshift=samples["redshift"]
                )
            )
            df["mc_source"] = mchirp_from_m1_m2(df["m1_source"], df["m2_source"])
            df["q"] = q_from_m1_m2(df["m1_source"], df["m2_source"])
            _reweighted_samples = rewt_approx_massdist_redshift(df)
            # ligo.em_bright.source_classification_pe fails if there is only
            # one sample, and an empty reweighted draw is equally unusable
            if len(_reweighted_samples["m1_source"]) > 1:
                break
            # regenerate the randomly drawn input files and try again
            self.setup_method()
        else:
            raise AssertionError(
                "population reweighting produced fewer than two samples "
                f"after {max_attempts} attempts"
            )
        _reweighted_samples.to_csv(
            ".outdir/test_reweighted.dat", sep=" ", index=False
        )
        reweighted_file = read(".outdir/test_reweighted.dat")
        reweighted_file.write(
            filename="test_reweighted.hdf5", file_format="hdf5", outdir=".outdir",
            overwrite=True
        )
        p_astro = source_classification_pe(
            ".outdir/test_reweighted.hdf5"
        )
        _samples, pesummary = PAstro(samples).classification(
            population=True, return_samples=True, seed=12345
        )
        np.testing.assert_almost_equal(p_astro[0], pesummary["HasNS"], 5)
        np.testing.assert_almost_equal(p_astro[1], pesummary["HasRemnant"], 5)
        pesummary2 = PAstro.classification_from_file(
            ".outdir/test_converted.dat", population=True
        )
        pesummary3 = Classify.classification_from_file(
            ".outdir/test_converted.dat", population=True
        )
        pesummary4 = Classify.dual_classification_from_file(
            ".outdir/test_converted.dat"
        )["population"]
        for key, val in pesummary2.items():
            np.testing.assert_almost_equal(pesummary[key], val, 5)
            np.testing.assert_almost_equal(pesummary[key], pesummary3[key], 5)
            np.testing.assert_almost_equal(pesummary[key], pesummary4[key], 5)

    def test_plotting(self):
        """Test the .plot method for the PAstro class
        """
        super().test_plotting(cls=PAstro)