Coverage for pesummary/tests/write_test.py: 63.4%
101 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import numpy as np
6import pytest
8from pesummary.io import write, read
9import tempfile
# Reserve a unique, hidden directory name under the current working
# directory for the files written by these tests. The directory itself is
# removed straight away: the tests create it on demand before writing and
# delete it again in their teardown. Cleaning up explicitly, rather than
# discarding the TemporaryDirectory object and relying on its finalizer,
# avoids the "Implicitly cleaning up" ResourceWarning.
_reserved = tempfile.TemporaryDirectory(prefix=".", dir=".")
tmpdir = _reserved.name
_reserved.cleanup()
del _reserved

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class Base(object):
    """Shared helpers for the ``pesummary.io.write`` round-trip tests.

    ``write`` generates a small random two-parameter sample set and saves it
    to ``tmpdir`` in a requested format. ``check_samples`` reads a file back
    with ``pesummary.io.read`` and compares it with the expected samples.
    """
    def write(self, file_format, filename, **kwargs):
        """Generate random samples and write them to ``tmpdir``.

        Parameters
        ----------
        file_format: str
            format passed to ``pesummary.io.write``
        filename: str
            name of the output file inside ``tmpdir``
        **kwargs: dict
            additional keyword arguments forwarded to ``pesummary.io.write``

        Returns
        -------
        parameters: list
            the parameter names ``["a", "b"]``
        samples: np.ndarray
            the written samples, shape (100, 2): one row per sample and one
            column per parameter
        """
        self.parameters = ["a", "b"]
        # numpy documents low > high as officially undefined, so pass the
        # bounds in (low, high) order. Each call draws 100 values, one
        # column per parameter.
        self.samples = np.array([
            np.random.uniform(5, 10, 100),
            np.random.uniform(2, 100, 100)
        ]).T
        write(
            self.parameters, self.samples, file_format=file_format,
            filename=filename, outdir=tmpdir, **kwargs
        )
        return self.parameters, self.samples

    def check_samples(self, filename, parameters, samples, pesummary=False):
        """Read ``filename`` and compare it with the expected samples.

        Parameters
        ----------
        filename: str
            path of the file to read with ``pesummary.io.read``
        parameters: list
            names of the parameters to compare
        samples: np.ndarray
            expected samples, indexed so that ``samples[i]`` holds the values
            of ``parameters[i]`` (callers in this file pass the transpose of
            the array returned by ``write``)
        pesummary: bool, optional
            if True, compare the samples stored under the ``"label"``
            analysis of the file's ``samples_dict``

        Raises
        ------
        AssertionError
            if any parameter's samples differ from the expected values
        """
        posterior_samples = read(filename).samples_dict
        if pesummary:
            posterior_samples = posterior_samples["label"]
        for num, param in enumerate(parameters):
            np.testing.assert_almost_equal(
                samples[num], posterior_samples[param]
            )
class TestWrite(Base):
    """Round-trip tests for ``pesummary.io.write``.

    Every test writes random samples in a single format to ``tmpdir`` and
    checks that reading the file back reproduces them.
    """
    def setup_method(self):
        """Make sure ``tmpdir`` exists before each test
        """
        missing = not os.path.isdir(tmpdir)
        if missing:
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Delete ``tmpdir`` and everything written into it after each test
        """
        present = os.path.isdir(tmpdir)
        if present:
            shutil.rmtree(tmpdir)

    def _round_trip(self, file_format, filename, pesummary=False, **kwargs):
        """Write samples as ``file_format`` to ``tmpdir/filename`` and check
        that they read back unchanged.

        ``pesummary`` is forwarded to ``check_samples``; any other keyword
        arguments are forwarded to ``write``.
        """
        names, table = self.write(file_format, filename, **kwargs)
        target = "{}/{}".format(tmpdir, filename)
        self.check_samples(target, names, table.T, pesummary=pesummary)

    def test_dat(self):
        """Samples survive a round trip through a dat file
        """
        self._round_trip("dat", "pesummary.dat")

    def test_json(self):
        """Samples survive a round trip through a json file
        """
        self._round_trip("json", "pesummary.json")

    def test_hdf5(self):
        """Samples survive a round trip through an hdf5 file
        """
        self._round_trip("h5", "pesummary.h5")

    def test_bilby(self):
        """Samples survive a round trip through bilby json and hdf5 files
        """
        self._round_trip("bilby", "bilby.json")
        self._round_trip("bilby", "bilby.hdf5", extension="hdf5")

    def test_lalinference(self):
        """Samples survive a round trip through a lalinference file
        """
        self._round_trip("lalinference", "lalinference.hdf5")

    def test_sql(self):
        """Samples survive a round trip through an sql database
        """
        self._round_trip("sql", "sql.db")

    def test_numpy(self):
        """Samples survive a round trip through a npy file
        """
        self._round_trip("numpy", "numpy.npy")

    def test_pesummary(self):
        """Samples survive a round trip through a pesummary file
        """
        self._round_trip(
            "pesummary", "pesummary.hdf5", pesummary=True, label="label"
        )
class TestWritePESummary(object):
    """Test the ``.write`` method of the
    ``pesummary.gw.file.formats.pesummary.PESummary`` class.

    The GW190814 posterior samples are read once per class. They are taken
    from ``pesummary.core.fetch.download_dir`` when already present there,
    and otherwise downloaded into ``tmpdir``. Each test writes the
    ``C01:IMRPhenomHM`` and ``C01:IMRPhenomPv3HM`` analyses in one format and
    reads the output back.
    """
    @pytest.fixture(scope='class', autouse=True)
    def setup_method(self):
        """Read the GW190814 posterior samples once for the whole class.

        The parsed file is stored as the class attribute ``result`` and its
        ``samples_dict`` as ``posterior``.

        Raises
        ------
        subprocess.CalledProcessError
            if the file is not cached and the download fails
        """
        import subprocess
        from pesummary.core.fetch import download_dir

        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            # tmpdir may not exist here (other tests remove it in their
            # teardown), and curl does not create missing parent directories.
            os.makedirs(tmpdir, exist_ok=True)
            downloaded_file = os.path.join(
                tmpdir, "GW190814_posterior_samples.h5"
            )
            # -f: fail on HTTP errors instead of saving the error page;
            # -L: follow redirects. check=True reports a failed download
            # here rather than as a confusing read error later.
            subprocess.run(
                [
                    "curl", "-fL",
                    "https://dcc.ligo.org/public/0168/P2000183/008/"
                    "GW190814_posterior_samples.h5",
                    "-o", downloaded_file,
                ],
                check=True,
            )
        type(self).result = read(downloaded_file)
        type(self).posterior = type(self).result.samples_dict

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def _write(self, file_format, extension, pesummary=False, **kwargs):
        """Write both analyses with ``file_format`` and check the output.

        Parameters
        ----------
        file_format: str
            format passed to ``PESummary.write``
        extension: str
            extension of the output files ``test.<extension>`` (for
            ``C01:IMRPhenomHM``) and ``test2.<extension>`` (for
            ``C01:IMRPhenomPv3HM``)
        pesummary: bool, optional
            if True, only ``test.<extension>`` is checked. Its labels,
            ``mass_1`` samples and H1 PSD are compared with the original
            ``C01:IMRPhenomHM`` analysis. Otherwise both files are read
            back, compared on ``mass_1`` and then removed.
        **kwargs: dict
            additional keyword arguments passed to ``PESummary.write``
        """
        os.makedirs(tmpdir, exist_ok=True)
        filename = {
            "C01:IMRPhenomHM": "test.{}".format(extension),
            "C01:IMRPhenomPv3HM": "test2.{}".format(extension)
        }
        first = os.path.join(tmpdir, filename["C01:IMRPhenomHM"])
        second = os.path.join(tmpdir, filename["C01:IMRPhenomPv3HM"])
        self.result.write(
            labels=["C01:IMRPhenomHM", "C01:IMRPhenomPv3HM"],
            file_format=file_format, filenames=filename, outdir=tmpdir,
            **kwargs
        )
        if not pesummary:
            assert os.path.isfile(first)
            assert os.path.isfile(second)
            one = read(first)
            two = read(second)
            np.testing.assert_almost_equal(
                one.samples_dict["mass_1"],
                self.posterior["C01:IMRPhenomHM"]["mass_1"]
            )
            np.testing.assert_almost_equal(
                two.samples_dict["mass_1"],
                self.posterior["C01:IMRPhenomPv3HM"]["mass_1"]
            )
            # Remove the outputs so a second call with another format (or
            # options) in the same test starts from a clean directory.
            os.remove(first)
            os.remove(second)
        else:
            assert os.path.isfile(first)
            one = read(first)
            assert sorted(one.labels) == sorted(["C01:IMRPhenomHM"])
            np.testing.assert_almost_equal(
                one.samples_dict["C01:IMRPhenomHM"]["mass_1"],
                self.posterior["C01:IMRPhenomHM"]["mass_1"]
            )
            np.testing.assert_almost_equal(
                one.psd["C01:IMRPhenomHM"]["H1"],
                self.result.psd["C01:IMRPhenomHM"]["H1"]
            )

    def test_write_dat(self):
        """Test write to dat
        """
        self._write("dat", "dat")

    def test_write_numpy(self):
        """Test write to numpy
        """
        self._write("numpy", "npy")

    def test_write_json(self):
        """Test write to json
        """
        self._write("json", "json")

    def test_write_hdf5(self):
        """Test write to hdf5
        """
        self._write("hdf5", "h5")

    def test_write_bilby(self):
        """Test write to a bilby json file
        """
        self._write("bilby", "json")

    def test_write_pesummary(self):
        """Test write to a pesummary file
        """
        self._write("pesummary", "h5", pesummary=True)

    def test_write_lalinference(self):
        """Test write to lalinference, both hdf5 and dat (``dat=True``)
        """
        self._write("lalinference", "h5")
        self._write("lalinference", "dat", dat=True)