Coverage for pesummary/tests/classification_test.py: 100.0%

90 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# License under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5import numpy as np 

6from ligo.em_bright.em_bright import source_classification_pe 

7from .base import make_result_file 

8from pesummary.io import read 

9from pesummary.utils.decorators import no_latex_plot 

10from pesummary.gw.classification import Classify, PEPredicates, PAstro 

11 

12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

13 

14 

class _Base(object):
    """Base testing class

    Provides the per-test fixture (result files written to ``.outdir``) and
    a generic plotting check that the classification test classes reuse.
    """
    def setup_method(self):
        """Create ``.outdir`` and write the result files the tests read

        Writes ``test.dat`` through ``make_result_file`` and then the derived
        ``test_converted.dat`` and ``test_lalinference.hdf5`` files.
        """
        # exist_ok avoids the check-then-create race and lets this method be
        # called again while the directory is still present
        os.makedirs(".outdir", exist_ok=True)
        make_result_file(gw=True, extension="dat", outdir=".outdir/")
        f = read(".outdir/test.dat")
        # regenerate the mass_1_source, mass_2_source posteriors because these
        # are randomly chosen and do not correspond to
        # mass_1 / (1. + z) and mass_2 / (1. + z)
        f.generate_all_posterior_samples(
            regenerate=["mass_1_source", "mass_2_source"]
        )
        f.write(
            filename="test_converted.dat", file_format="dat", outdir=".outdir",
            overwrite=True
        )
        f.write(
            filename="test_lalinference.hdf5", file_format="hdf5",
            outdir=".outdir", overwrite=True
        )

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(".outdir"):
            shutil.rmtree(".outdir")

    @no_latex_plot
    def test_plotting(self, cls=PEPredicates):
        """Test the .plot method

        Parameters
        ----------
        cls: class, optional
            classification class to instantiate with the posterior samples
            read from ``.outdir/test_converted.dat``. Default PEPredicates
        """
        import matplotlib.figure
        import matplotlib.pyplot as plt

        samples = read(".outdir/test_converted.dat").samples_dict
        _cls = cls(samples)
        for plot in _cls.available_plots:
            fig = _cls.plot(type=plot)
            assert isinstance(fig, matplotlib.figure.Figure)
            # release the figure so repeated test runs do not accumulate
            # open figures
            plt.close(fig)

56 

57 

class TestPEPredicates(_Base):
    """Checks for pesummary.gw.classification.PEPredicates, including the
    routes to it through pesummary.gw.classification.Classify
    """
    def test_classification(self):
        """Check that every pesummary entry point reproduces the
        probabilities returned by pepredicates.predicate_table
        """
        from pepredicates import (
            predicate_table, BNS_p, NSBH_p, BBH_p, MG_p
        )
        from pandas import DataFrame

        converted = ".outdir/test_converted.dat"
        samples = read(converted).samples_dict
        # the same classification obtained through four different routes
        from_samples = PEPredicates(samples).classification()
        from_file = PEPredicates.classification_from_file(converted)
        from_classify = Classify.classification_from_file(converted)
        from_dual = Classify.dual_classification_from_file(converted)["default"]
        results = (from_samples, from_file, from_classify, from_dual)

        # reference values computed directly with pepredicates
        columns = {
            "m{}_source".format(num): samples["mass_{}_source".format(num)]
            for num in (1, 2)
        }
        df = DataFrame.from_dict(columns)
        predicates = {"BNS": BNS_p, "BBH": BBH_p, "MassGap": MG_p, "NSBH": NSBH_p}
        probs = predicate_table(predicates, df)

        for key in probs.keys():
            for result in results:
                np.testing.assert_almost_equal(probs[key], result[key])

96 

97 

class TestPAstro(_Base):
    """Checks for pesummary.gw.classification.PAstro, including the routes
    to it through pesummary.gw.classification.Classify
    """
    def test_classification(self):
        """Check that the default classification matches
        ligo.em_bright.source_classification_pe and that every pesummary
        entry point returns the same values
        """
        check = np.testing.assert_almost_equal
        converted = ".outdir/test_converted.dat"

        p_astro = source_classification_pe(".outdir/test_lalinference.hdf5")
        samples = read(converted).samples_dict
        pesummary = PAstro(samples).classification()
        check(p_astro[0], pesummary["HasNS"], 5)
        check(p_astro[1], pesummary["HasRemnant"], 5)

        from_file = PAstro.classification_from_file(converted)
        from_classify = Classify.classification_from_file(converted)
        from_dual = Classify.dual_classification_from_file(converted)["default"]
        for key, val in from_file.items():
            check(pesummary[key], val, 5)
            check(pesummary[key], from_classify[key], 5)
            check(pesummary[key], from_dual[key], 5)

    def test_reweight_classification(self):
        """Check that the population reweighted classification matches
        ligo.em_bright.source_classification_pe run on samples reweighted
        with pepredicates.rewt_approx_massdist_redshift, and that every
        pesummary entry point returns the same values
        """
        from pepredicates import rewt_approx_massdist_redshift
        from pesummary.gw.conversions import mchirp_from_m1_m2, q_from_m1_m2
        from pandas import DataFrame

        check = np.testing.assert_almost_equal
        converted = ".outdir/test_converted.dat"

        np.random.seed(1234)
        while True:
            samples = read(converted).samples_dict
            df = DataFrame.from_dict(
                {
                    "m1_source": samples["mass_1_source"],
                    "m2_source": samples["mass_2_source"],
                    "a1": samples["a_1"],
                    "a2": samples["a_2"],
                    "dist": samples["luminosity_distance"],
                    "redshift": samples["redshift"],
                }
            )
            df["mc_source"] = mchirp_from_m1_m2(df["m1_source"], df["m2_source"])
            df["q"] = q_from_m1_m2(df["m1_source"], df["m2_source"])
            _reweighted_samples = rewt_approx_massdist_redshift(df)
            # ligo.em_bright.source_classification_pe fails if there is only
            # one sample
            if len(_reweighted_samples["m1_source"]) != 1:
                break
            # regenerate the input files and try again
            self.setup_method()

        _reweighted_samples.to_csv(
            ".outdir/test_reweighted.dat", sep=" ", index=False
        )
        reweighted_file = read(".outdir/test_reweighted.dat")
        reweighted_file.write(
            filename="test_reweighted.hdf5", file_format="hdf5",
            outdir=".outdir", overwrite=True
        )
        p_astro = source_classification_pe(".outdir/test_reweighted.hdf5")
        # only the probabilities are compared; the returned samples are unused
        _, pesummary = PAstro(samples).classification(
            population=True, return_samples=True, seed=12345
        )
        check(p_astro[0], pesummary["HasNS"], 5)
        check(p_astro[1], pesummary["HasRemnant"], 5)

        from_file = PAstro.classification_from_file(converted, population=True)
        from_classify = Classify.classification_from_file(
            converted, population=True
        )
        from_dual = Classify.dual_classification_from_file(
            converted
        )["population"]
        for key, val in from_file.items():
            check(pesummary[key], val, 5)
            check(pesummary[key], from_classify[key], 5)
            check(pesummary[key], from_dual[key], 5)

    def test_plotting(self):
        """Run the inherited plotting check against the PAstro class
        """
        super().test_plotting(cls=PAstro)