Coverage for pesummary/tests/classification_test.py: 100.0%
52 statements
coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# License under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import numpy as np
6from ligo.em_bright.em_bright import source_classification_pe
7from .base import make_result_file, testing_dir
8from pesummary.io import read
9from pesummary.utils.decorators import no_latex_plot
10from pesummary.gw.classification import Classify, EMBright, PAstro
12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class _Base:
    """Base testing class shared by the classification tests.

    Each test gets a freshly populated ``.outdir`` directory, created in
    ``setup_method`` and removed again in ``teardown_method``.
    """
    def setup_method(self):
        """Create ``.outdir`` and write the result files used by the tests.

        A mock GW result file is written with ``make_result_file`` and read
        back from ``.outdir/test.dat``. After regenerating the source-frame
        masses, the samples are written to ``.outdir/test_converted.dat``
        (``dat`` format) and ``.outdir/test_lalinference.hdf5`` (``hdf5``
        format).
        """
        # exist_ok avoids the separate isdir check before creating the dir
        os.makedirs(".outdir", exist_ok=True)
        make_result_file(gw=True, extension="dat", outdir=".outdir/")
        f = read(".outdir/test.dat")
        # regenerate the mass_1_source, mass_2_source posteriors because these
        # are randomly chosen and do not correspond to
        # mass_1 / (1. + z) and mass_2 / (1. + z)
        f.generate_all_posterior_samples(
            regenerate=["mass_1_source", "mass_2_source"]
        )
        f.write(
            filename="test_converted.dat", file_format="dat", outdir=".outdir",
            overwrite=True
        )
        f.write(
            filename="test_lalinference.hdf5", file_format="hdf5",
            outdir=".outdir", overwrite=True
        )

    def teardown_method(self):
        """Remove ``.outdir`` and every file written into it."""
        if os.path.isdir(".outdir"):
            shutil.rmtree(".outdir")

    @no_latex_plot
    def test_plotting(self, cls=EMBright, **kwargs):
        """Check that every plot in ``available_plots`` returns a Figure.

        Parameters
        ----------
        cls: class, optional
            classification class to test. Default EMBright
        **kwargs: dict, optional
            additional keyword arguments passed to ``cls`` when it is
            constructed from the samples
        """
        import matplotlib.figure
        import matplotlib.pyplot as plt
        samples = read(".outdir/test_converted.dat").samples_dict
        _cls = cls(samples, **kwargs)
        ptable = _cls.classification()
        for plot in _cls.available_plots:
            fig = _cls.plot(ptable, type=plot)
            assert isinstance(fig, matplotlib.figure.Figure)
            # release the figure so figures do not accumulate across plots
            # and across tests
            plt.close(fig)
class TestPAstro(_Base):
    """Test the pesummary.gw.classification.PAstro class and
    pesummary.gw.classification.Classify class
    """
    def test_classification(self):
        """Test the base classification method agrees with the
        pepredicates.predicate_table function

        No comparison is implemented yet. The test therefore reports itself
        as skipped instead of silently passing.
        """
        import pytest
        pytest.skip(
            "comparison against pepredicates.predicate_table is not "
            "implemented"
        )

    def test_plotting(self):
        """Test the .plot method of PAstro.

        Constructing the plots without a terrestrial probability must raise
        a ValueError. With ``terrestrial_probability=0.`` every available
        plot must be produced.
        """
        import pytest
        kwargs = {"category_data": f"{testing_dir}/rates.yml"}
        # catch ValueError when no terrestrial probability is given
        with pytest.raises(ValueError):
            super().test_plotting(cls=PAstro, **kwargs)
        super().test_plotting(
            cls=PAstro, terrestrial_probability=0., **kwargs
        )
class TestEMBright(_Base):
    """Test the pesummary.gw.classification.EMBright class and
    pesummary.gw.classification.Classify class
    """
    def test_classification(self):
        """Test the base classification method agrees with the
        ligo.em_bright.source_classification_pe function

        Also checks that EMBright.classification_from_file and
        Classify.classification_from_file agree with
        EMBright(samples).classification() to 5 decimal places.
        """
        # reference values from ligo.em_bright: element 0 is compared with
        # HasNS and element 1 with HasRemnant below
        reference = source_classification_pe(
            ".outdir/test_lalinference.hdf5"
        )
        samples = read(".outdir/test_converted.dat").samples_dict
        from_samples = EMBright(samples).classification()
        np.testing.assert_almost_equal(reference[0], from_samples["HasNS"], 5)
        np.testing.assert_almost_equal(
            reference[1], from_samples["HasRemnant"], 5
        )
        from_file = EMBright.classification_from_file(
            ".outdir/test_converted.dat"
        )
        combined = Classify.classification_from_file(
            ".outdir/test_converted.dat",
            category_data=f"{testing_dir}/rates.yml",
            terrestrial_probability=0.
        )
        # guard against the consistency loop below passing vacuously
        assert len(from_file) > 0
        for key, val in from_file.items():
            np.testing.assert_almost_equal(from_samples[key], val, 5)
            np.testing.assert_almost_equal(
                from_samples[key], combined[key], 5
            )

    def test_plotting(self):
        """Test the .plot method of EMBright"""
        super().test_plotting(cls=EMBright)