Coverage for pesummary/tests/write_test.py: 98.0%

101 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5import numpy as np 

6import pytest 

7 

8from pesummary.io import write, read 

9import tempfile 

10 

# Reserve a unique, hidden directory name inside the current working
# directory. The directory itself is deleted again straight away: the test
# classes below create it on demand and remove it in their teardown methods.
# Calling ``cleanup()`` explicitly (instead of discarding the
# TemporaryDirectory object and relying on garbage collection to remove it)
# is deterministic on every interpreter and avoids an implicit-cleanup
# ResourceWarning.
_reserved_tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".")
tmpdir = _reserved_tmpdir.name
_reserved_tmpdir.cleanup()
del _reserved_tmpdir

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

14 

15 

class Base(object):
    """Base class containing helpers shared by the write tests.

    Subclasses call :meth:`write` to generate random samples and save them
    with :func:`pesummary.io.write`, then :meth:`check_samples` to read the
    file back with :func:`pesummary.io.read` and compare the result.
    """
    def _generate_samples(self, nsamples=100):
        """Return an ``(nsamples, 2)`` array of random samples.

        Column 0 is drawn uniformly from [5, 10) and column 1 from [2, 100).
        ``np.random.uniform`` documents its result as undefined when
        ``high < low``, so the bounds are always given as ``(low, high)``.

        Parameters
        ----------
        nsamples: int, optional
            number of samples to draw for each parameter. Default 100
        """
        return np.array([
            np.random.uniform(5, 10, nsamples),
            np.random.uniform(2, 100, nsamples)
        ]).T

    def write(self, file_format, filename, **kwargs):
        """Generate random samples and write them to ``tmpdir``.

        The generated parameters and samples are also stored on
        ``self.parameters`` and ``self.samples``.

        Parameters
        ----------
        file_format: str
            format passed to :func:`pesummary.io.write`
        filename: str
            name of the file to create inside ``tmpdir``
        **kwargs: dict
            additional keyword arguments passed to :func:`pesummary.io.write`

        Returns
        -------
        parameters: list
            the parameter names, ``["a", "b"]``
        samples: np.ndarray
            array of shape ``(100, 2)`` with one column per parameter
        """
        self.parameters = ["a", "b"]
        self.samples = self._generate_samples()
        write(
            self.parameters, self.samples, file_format=file_format,
            filename=filename, outdir=tmpdir, **kwargs
        )
        return self.parameters, self.samples

    def check_samples(self, filename, parameters, samples, pesummary=False):
        """Read ``filename`` back and compare it with the expected samples.

        Parameters
        ----------
        filename: str
            path of the file to read with :func:`pesummary.io.read`
        parameters: list
            parameter names to compare
        samples: array-like
            expected samples, indexed so that ``samples[i]`` holds the values
            for ``parameters[i]`` (callers pass the transpose of the array
            returned by :meth:`write`)
        pesummary: bool, optional
            if True, compare against the samples stored under the label
            ``"label"`` of the file's ``samples_dict``
        """
        posterior_samples = read(filename).samples_dict
        if pesummary:
            posterior_samples = posterior_samples["label"]
        for num, param in enumerate(parameters):
            np.testing.assert_almost_equal(
                samples[num], posterior_samples[param]
            )

44 

45 

class TestWrite(Base):
    """Tests for the :func:`pesummary.io.write` function.

    Each test writes randomly generated samples in one file format and then
    checks that reading the file back returns the same samples.
    """
    def setup_method(self):
        """Create the shared output directory if it is missing
        """
        if os.path.isdir(tmpdir):
            return
        os.mkdir(tmpdir)

    def teardown_method(self):
        """Delete the shared output directory and everything inside it
        """
        if not os.path.isdir(tmpdir):
            return
        shutil.rmtree(tmpdir)

    def _round_trip(self, file_format, filename, pesummary=False, **kwargs):
        """Write samples as ``file_format`` to ``tmpdir/filename`` and check
        that they can be read back unchanged
        """
        names, values = self.write(file_format, filename, **kwargs)
        target = "{}/{}".format(tmpdir, filename)
        self.check_samples(target, names, values.T, pesummary=pesummary)

    def test_dat(self):
        """Check that samples survive a round trip through a dat file
        """
        self._round_trip("dat", "pesummary.dat")

    def test_json(self):
        """Check that samples survive a round trip through a json file
        """
        self._round_trip("json", "pesummary.json")

    def test_hdf5(self):
        """Check that samples survive a round trip through a hdf5 file
        """
        self._round_trip("h5", "pesummary.h5")

    def test_bilby(self):
        """Check that samples survive a round trip through bilby json and
        bilby hdf5 files
        """
        self._round_trip("bilby", "bilby.json")
        self._round_trip("bilby", "bilby.h5", extension="hdf5")

    def test_lalinference(self):
        """Check that samples survive a round trip through a lalinference file
        """
        self._round_trip("lalinference", "lalinference.hdf5")

    def test_sql(self):
        """Check that samples survive a round trip through an sql database
        """
        self._round_trip("sql", "sql.db")

    def test_numpy(self):
        """Check that samples survive a round trip through a npy file
        """
        self._round_trip("numpy", "numpy.npy")

    def test_pesummary(self):
        """Check that samples survive a round trip through a pesummary file
        """
        self._round_trip(
            "pesummary", "pesummary.hdf5", pesummary=True, label="label"
        )

112 

113 

class TestWritePESummary(object):
    """Test the `.write` method of the
    `pesummary.gw.file.formats.pesummary.PESummary` class
    """
    @pytest.fixture(scope='class', autouse=True)
    def setup_method(self):
        """Read the GW190814 data release file once for the whole class.

        The file is taken from ``pesummary.core.fetch.download_dir`` if it is
        already there; otherwise it is downloaded into ``tmpdir`` with curl.
        The parsed file and its ``samples_dict`` are stored on the class
        (``result`` and ``posterior``) so every test instance can use them.

        Raises
        ------
        subprocess.CalledProcessError
            if curl exits with a non-zero status while downloading
        """
        import subprocess
        from pesummary.core.fetch import download_dir
        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            # curl does not create missing directories, and ``tmpdir`` may
            # not exist at this point (teardown methods remove it), so make
            # sure it is there before downloading into it
            os.makedirs(tmpdir, exist_ok=True)
            downloaded_file = os.path.join(
                tmpdir, "GW190814_posterior_samples.h5"
            )
            subprocess.run(
                [
                    "curl",
                    "https://dcc.ligo.org/public/0168/P2000183/008/"
                    "GW190814_posterior_samples.h5",
                    "-o", downloaded_file
                ],
                check=True
            )

        type(self).result = read(downloaded_file)
        type(self).posterior = type(self).result.samples_dict

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def _write(self, file_format, extension, pesummary=False, **kwargs):
        """Write two analyses from the data release file and check them.

        The analyses ``C01:IMRPhenomHM`` and ``C01:IMRPhenomPv3HM`` are
        written to ``test.<extension>`` and ``test2.<extension>`` in
        ``tmpdir`` respectively.

        Parameters
        ----------
        file_format: str
            format passed to ``PESummary.write``
        extension: str
            file extension used for the output files
        pesummary: bool, optional
            if True, only ``test.<extension>`` is checked: it must contain
            exactly the ``C01:IMRPhenomHM`` analysis, including its ``mass_1``
            samples and H1 PSD. Otherwise both output files are read back,
            their ``mass_1`` samples compared with the originals, and the
            files deleted
        **kwargs: dict
            additional keyword arguments passed to ``PESummary.write``
        """
        os.makedirs(tmpdir, exist_ok=True)
        filenames = {
            "C01:IMRPhenomHM": "test.{}".format(extension),
            "C01:IMRPhenomPv3HM": "test2.{}".format(extension)
        }
        self.result.write(
            labels=["C01:IMRPhenomHM", "C01:IMRPhenomPv3HM"],
            file_format=file_format, filenames=filenames, outdir=tmpdir,
            **kwargs
        )
        if pesummary:
            path = os.path.join(tmpdir, filenames["C01:IMRPhenomHM"])
            assert os.path.isfile(path)
            one = read(path)
            assert sorted(one.labels) == ["C01:IMRPhenomHM"]
            np.testing.assert_almost_equal(
                one.samples_dict["C01:IMRPhenomHM"]["mass_1"],
                self.posterior["C01:IMRPhenomHM"]["mass_1"]
            )
            np.testing.assert_almost_equal(
                one.psd["C01:IMRPhenomHM"]["H1"],
                self.result.psd["C01:IMRPhenomHM"]["H1"]
            )
            return

        for label, filename in filenames.items():
            path = os.path.join(tmpdir, filename)
            assert os.path.isfile(path)
            np.testing.assert_almost_equal(
                read(path).samples_dict["mass_1"],
                self.posterior[label]["mass_1"]
            )
            # delete the output now rather than waiting for teardown_method
            os.remove(path)

    def test_write_dat(self):
        """Test write to dat
        """
        self._write("dat", "dat")

    def test_write_numpy(self):
        """Test write to numpy
        """
        self._write("numpy", "npy")

    def test_write_json(self):
        """Test write to json
        """
        self._write("json", "json")

    def test_write_hdf5(self):
        """Test write to hdf5
        """
        self._write("hdf5", "h5")

    def test_write_bilby(self):
        """Test write to bilby (json)
        """
        self._write("bilby", "json")

    def test_write_pesummary(self):
        """Test write to pesummary (h5)
        """
        self._write("pesummary", "h5", pesummary=True)

    def test_write_lalinference(self):
        """Test write to lalinference, as both hdf5 and dat
        """
        self._write("lalinference", "h5")
        self._write("lalinference", "dat", dat=True)