Coverage for pesummary/tests/read_test.py: 100.0%
1027 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import numpy as np
7from .base import make_result_file, testing_dir
8import pesummary
9from pesummary.gw.file.read import read as GWRead
10from pesummary.core.file.read import read as Read
11from pesummary.io import read, write
12import glob
13import tempfile
# Unique, not-yet-existing directory name inside the current working
# directory. Each test class creates it in ``setup_method`` and removes it in
# ``teardown_method``, so only the name is needed here. The directory is
# created and immediately removed explicitly; previously
# ``TemporaryDirectory(...).name`` achieved the same result only because the
# unreferenced TemporaryDirectory object was garbage-collected at once, which
# also emitted an "Implicitly cleaning up" ResourceWarning.
tmpdir = tempfile.mkdtemp(prefix=".", dir=".")
os.rmdir(tmpdir)
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class BaseRead(object):
    """Base class to test the core functions in the Read and GWRead functions

    Subclasses set ``self.result`` (the object returned by the read
    function), ``self.parameters`` (the parameter names used to generate the
    file) and ``self.samples`` (rows of samples, one column per entry in
    ``self.parameters``) in ``setup_method``.
    """
    def test_parameters(self, true, pesummary=False):
        """Test the parameter property

        Parameters
        ----------
        true: list
            parameters expected to be stored in the result file
        pesummary: bool, optional
            if True, ``self.result.parameters`` holds one list per analysis
            and only the first analysis is checked
        """
        if pesummary:
            parameters = self.result.parameters[0]
        else:
            parameters = self.result.parameters
        assert all(i in parameters for i in true)
        assert all(i in true for i in parameters)

    def test_samples(self, true, pesummary=False):
        """Test the samples property

        ``true`` is unused; the stored samples are compared against
        ``self.samples``. If ``pesummary`` is True only the first analysis
        is checked.
        """
        if pesummary:
            samples = self.result.samples[0]
            parameters = self.result.parameters[0]
        else:
            samples = self.result.samples
            parameters = self.result.parameters
        # the generated test files are expected to hold 1000 samples of 18
        # parameters
        assert len(samples) == 1000
        assert len(samples[0]) == 18
        # reorder the generated samples into the column order used by the
        # result file before comparing; a whole-array comparison covers every
        # column, so no separate per-column check is needed
        idxs = [self.parameters.index(i) for i in parameters]
        np.testing.assert_almost_equal(
            np.array(samples), np.array(self.samples)[:, idxs]
        )

    def test_samples_dict(self, true):
        """Test the samples_dict property

        Parameters
        ----------
        true: list
            ``[parameters, samples]`` where ``samples`` are rows with one
            column per entry in ``parameters``
        """
        parameters, samples = true[0], true[1]
        for num, param in enumerate(parameters):
            specific_samples = [row[num] for row in samples]
            np.testing.assert_almost_equal(
                self.result.samples_dict[param], specific_samples
            )

    def test_version(self, true=None):
        """Test the version property; ``None`` means no version is stored
        """
        if true is None:
            assert self.result.input_version == "No version information found"
        else:
            assert self.result.input_version == true

    def test_extra_kwargs(self, true=None):
        """Test the extra_kwargs property

        ``sorted`` of a dict yields its keys, so when ``true`` is a dict only
        the top-level keys are compared.
        """
        if true is None:
            assert self.result.extra_kwargs == {
                "sampler": {"nsamples": 1000}, "meta_data": {}
            }
        else:
            assert sorted(self.result.extra_kwargs) == sorted(true)

    def test_injection_parameters(self, true, pesummary=False):
        """Test the injection_parameters property

        Parameters
        ----------
        true: dict or None
            expected injection value for each parameter (NaN meaning "not
            injected"), or None if no injection is expected
        pesummary: bool, optional
            if True only the parameter names are checked, not the values
        """
        if true is None:
            assert self.result.injection_parameters is None
            return
        import math

        assert all(i in true for i in self.parameters)
        assert all(i in self.parameters for i in true)
        if pesummary:
            return
        for key, value in true.items():
            if math.isnan(value):
                # NaN != NaN, so missing injections are compared with isnan
                assert math.isnan(self.result.injection_parameters[key])
            else:
                assert value == self.result.injection_parameters[key]

    def test_to_dat(self):
        """Test the to_dat method
        """
        self.result.to_dat(outdir=tmpdir, label="label")
        filename = os.path.join(tmpdir, "pesummary_label.dat")
        assert os.path.isfile(filename)
        data = np.genfromtxt(filename, names=True)
        assert all(i in self.parameters for i in data.dtype.names)
        assert all(i in data.dtype.names for i in self.parameters)
        for param in self.parameters:
            # assert_almost_equal raises on mismatch. It is deliberately not
            # wrapped in an ``assert`` statement, which ``python -O`` would
            # strip together with the comparison it contains.
            np.testing.assert_almost_equal(
                data[param], self.result.samples_dict[param], 8
            )

    def test_file_format_read(self, path, file_format, _class, function=Read):
        """Test that when the file_format is specified, that correct class is used
        """
        result = function(path, file_format=file_format)
        assert isinstance(result, _class)

    def test_downsample(self):
        """Test the .downsample method. This includes testing that the
        .downsample method downsamples to the specified number of samples,
        that it only takes samples that are currently in the posterior
        table and that it maintains concurrent samples.
        """
        old_samples_dict = self.result.samples_dict
        nsamples = 50
        self.result.downsample(nsamples)
        new_samples_dict = self.result.samples_dict
        assert new_samples_dict.number_of_samples == nsamples
        for param in self.parameters:
            assert all(
                samp in old_samples_dict[param] for samp in
                new_samples_dict[param]
            )
        # convert each original column to a list once, rather than once per
        # (sample, parameter) pair inside the loop below
        old_columns = {
            param: old_samples_dict[param].tolist() for param in self.parameters
        }
        for num in range(nsamples):
            samp_inds = [
                old_columns[param].index(new_samples_dict[param][num])
                for param in self.parameters
            ]
            # every parameter of a retained sample must come from the same
            # row of the original posterior table
            assert len(set(samp_inds)) == 1
class GWBaseRead(BaseRead):
    """Base class to test the GWRead specific functions
    """
    def test_parameters(self, true, pesummary=False):
        """Test the parameter property, both as read and after generating
        all derived posterior samples
        """
        super(GWBaseRead, self).test_parameters(true, pesummary=pesummary)
        from .base import gw_parameters

        full_parameters = gw_parameters()
        self.result.generate_all_posterior_samples()
        assert all(i in self.result.parameters for i in full_parameters)
        assert all(i in full_parameters for i in self.result.parameters)

    def test_injection_parameters(self, true):
        """Test the injection_parameters property, both as read and after
        adding injection parameters from ``main_injection.xml``

        Parameters
        ----------
        true: dict or None
            expected injection parameters as read from the result file
        """
        import math

        super(GWBaseRead, self).test_injection_parameters(true)
        self.result.add_injection_parameters_from_file(
            os.path.join(testing_dir, "main_injection.xml"), conversion=False
        )
        # values are expected to match main_injection.xml; parameters that
        # are not stored in that file are expected to be NaN
        expected = {
            'dec': [1.949725], 'geocent_time': [1186741861], 'spin_2x': [0.],
            'spin_2y': [0.], 'spin_2z': [0.], 'luminosity_distance': [139.7643],
            'ra': [-1.261573], 'spin_1y': [0.], 'spin_1x': [0.], 'spin_1z': [0.],
            'psi': [1.75], 'phase': [0.], 'iota': [1.0471976],
            'mass_1': [53.333332], 'mass_2': [26.666668],
            'symmetric_mass_ratio': [0.22222222], 'a_1': float('nan'),
            'a_2': float('nan'), 'tilt_1': float('nan'), 'tilt_2': float('nan'),
            'phi_jl': float('nan'), 'phi_12': float('nan'),
            'theta_jn': float('nan'), 'redshift': float('nan'),
            'mass_1_source': float('nan'), 'mass_2_source': float('nan'),
            'log_likelihood': float('nan')
        }
        assert all(i in expected for i in self.parameters)
        for key, value in expected.items():
            if not isinstance(value, list) and math.isnan(value):
                assert math.isnan(self.result.injection_parameters[key])
            else:
                np.testing.assert_almost_equal(
                    value, self.result.injection_parameters[key], 5
                )

    def test_calibration_data_in_results_file(self):
        """Test the calibration_data_in_results_file property
        """
        pass

    def test_add_injection_parameters_from_file(self):
        """Test the add_injection_parameters_from_file method
        """
        pass

    def test_add_fixed_parameters_from_config_file(self):
        """Test the add_fixed_parameters_from_config_file method
        """
        pass

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        from pesummary.gw.file.standard_names import lalinference_map

        self.result.to_lalinference(
            dat=True, outdir=tmpdir, filename="lalinference_label.dat"
        )
        filename = os.path.join(tmpdir, "lalinference_label.dat")
        assert os.path.isfile(filename)
        data = np.genfromtxt(filename, names=True)
        for param in data.dtype.names:
            # columns not named as in pesummary are translated back through
            # the lalinference naming map
            if param not in self.result.parameters:
                pesummary_param = lalinference_map[param]
            else:
                pesummary_param = param
            # not wrapped in ``assert`` so the check survives ``python -O``
            np.testing.assert_almost_equal(
                data[param], self.result.samples_dict[pesummary_param], 8
            )

    def test_file_format_read(self, path, file_format, _class):
        """Test that when the file_format is specified, that correct class is used
        """
        super(GWBaseRead, self).test_file_format_read(
            path, file_format, _class, function=GWRead
        )
class TestCoreJsonFile(BaseRead):
    """Check that a JSON result file is read correctly by the core Read function
    """
    def setup_method(self):
        """Write a fresh JSON result file into ``tmpdir`` and read it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="json", gw=False
        )
        self.path = os.path.join(tmpdir, "test.json")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The default single-analysis class should be chosen for this file
        """
        expected = pesummary.core.file.formats.default.SingleAnalysisDefault
        assert isinstance(self.result, expected)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """No version information is stored in this file
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Only the default extra_kwargs are expected
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have a NaN injection value
        """
        expected = {name: float("nan") for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="json" should select the default class
        """
        from pesummary.core.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(self.path, "json", SingleAnalysisDefault)

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()
class TestCoreHDF5File(BaseRead):
    """Check that an HDF5 result file is read correctly by the core Read function
    """
    def setup_method(self):
        """Write a fresh HDF5 result file into ``tmpdir`` and read it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=False
        )
        self.path = os.path.join(tmpdir, "test.h5")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The default single-analysis class should be chosen for this file
        """
        expected = pesummary.core.file.formats.default.SingleAnalysisDefault
        assert isinstance(self.result, expected)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """No version information is stored in this file
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Only the default extra_kwargs are expected
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have a NaN injection value
        """
        expected = {name: float("nan") for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="hdf5" should select the default class
        """
        from pesummary.core.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(self.path, "hdf5", SingleAnalysisDefault)

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()
class TestCoreCSVFile(BaseRead):
    """Check that a csv result file is read correctly by the core Read function
    """
    def setup_method(self):
        """Write a fresh csv result file into ``tmpdir`` and read it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="csv", gw=False
        )
        self.path = os.path.join(tmpdir, "test.csv")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The default single-analysis class should be chosen for this file
        """
        expected = pesummary.core.file.formats.default.SingleAnalysisDefault
        assert isinstance(self.result, expected)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """No version information is stored in this file
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Only the default extra_kwargs are expected
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have a NaN injection value
        """
        expected = {name: float("nan") for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="csv" should select the default class
        """
        from pesummary.core.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(self.path, "csv", SingleAnalysisDefault)

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()
class TestCoreNumpyFile(BaseRead):
    """Check that a numpy result file is read correctly by the core Read function
    """
    def setup_method(self):
        """Write a fresh ``.npy`` result file into ``tmpdir`` and read it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="npy", gw=False
        )
        self.path = os.path.join(tmpdir, "test.npy")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The default single-analysis class should be chosen for this file
        """
        expected = pesummary.core.file.formats.default.SingleAnalysisDefault
        assert isinstance(self.result, expected)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """No version information is stored in this file
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Only the default extra_kwargs are expected
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have a NaN injection value
        """
        expected = {name: float("nan") for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="numpy" should select the default class
        """
        from pesummary.core.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(self.path, "numpy", SingleAnalysisDefault)

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()
class TestCoreDatFile(BaseRead):
    """Check that a dat result file is read correctly by the core Read function
    """
    def setup_method(self):
        """Write a fresh dat result file into ``tmpdir`` and read it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="dat", gw=False
        )
        self.path = os.path.join(tmpdir, "test.dat")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The default single-analysis class should be chosen for this file
        """
        expected = pesummary.core.file.formats.default.SingleAnalysisDefault
        assert isinstance(self.result, expected)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """No version information is stored in this file
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Only the default extra_kwargs are expected
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have a NaN injection value
        """
        expected = {name: float("nan") for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="dat" should select the default class
        """
        from pesummary.core.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(self.path, "dat", SingleAnalysisDefault)

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()
class BilbyFile(BaseRead):
    """Shared checks for reading bilby result files with the core Read function
    """
    def test_class_name(self):
        """The bilby-specific class should be chosen for this file
        """
        assert isinstance(self.result, pesummary.core.file.formats.bilby.Bilby)

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict([self.parameters, self.samples])

    def test_version(self):
        """The bilby version string stored in the file should be reported
        """
        super().test_version("bilby=0.5.3:")

    def test_extra_kwargs(self):
        """The sampler, meta_data and other sections should be present
        """
        sampler = {
            "log_bayes_factor": 0.5,
            "log_noise_evidence": 0.1,
            "log_evidence": 0.2,
            "log_evidence_err": 0.1,
        }
        expected = {
            "sampler": sampler,
            "meta_data": {'time_marginalization': True},
            "other": {"likelihood": {"time_marginalization": "True"}},
        }
        super().test_extra_kwargs(expected)

    def test_injection_parameters(self, true):
        """Injection values should match ``true``
        """
        super().test_injection_parameters(true)

    def test_file_format_read(self, path, file_format):
        """Passing ``file_format`` explicitly should select the Bilby class
        """
        from pesummary.core.file.formats.bilby import Bilby

        super().test_file_format_read(path, file_format, Bilby)

    def test_priors(self, read_function=Read):
        """Prior samples should be arrays, be skippable with
        ``disable_prior`` and honour ``nsamples_for_prior``
        """
        prior_samples = self.result.priors["samples"]
        assert all(isinstance(prior, np.ndarray) for prior in prior_samples.values())
        without_priors = read_function(self.path, disable_prior=True)
        assert not len(without_priors.priors["samples"])
        resampled = read_function(self.path, nsamples_for_prior=200)
        first = list(resampled.priors["samples"].keys())[0]
        assert len(resampled.priors["samples"][first]) == 200
class DingoFile(BaseRead):
    """Base class to test loading in a dingo file with the core Read function"""

    def test_class_name(self):
        """Test the class used to load in this file"""
        assert isinstance(self.result, pesummary.core.file.formats.dingo.Dingo)

    def test_parameters(self):
        """Test the parameter property of the dingo class"""
        super(DingoFile, self).test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the dingo class

        Only the number of samples and parameters is checked. pesummary is
        given unweighted samples, so the dingo result is resampled according
        to its weights. The resulting samples are therefore not in the same
        order as (and not necessarily identical to) the samples used to
        create the dingo result, so the two cannot easily be compared.
        """
        assert len(self.result.samples) == 1000
        assert len(self.result.samples[0]) == 18

    def test_samples_dict(self):
        """Not tested: the stored samples are resampled according to the
        dingo weights, so they cannot be compared with the input samples
        (see ``test_samples``)
        """
        pass

    def test_version(self):
        """Test the version property of the dingo class"""
        import dingo

        true = f"dingo={dingo.__version__}"
        super(DingoFile, self).test_version(true)

    def test_extra_kwargs(self):
        """Test a subset of the extra_kwargs property of the dingo class"""
        # extra_kwargs may contain a huge amount of information regarding the
        # training_settings or dataset_settings. It is unrealistic to have a
        # full expected value for this test, so only the sampler entry is
        # checked
        assert self.result.extra_kwargs["sampler"] in (
            {"pe_algorithm": "dingo-is", "nsamples": 1000},
            {"pe_algorithm": "dingo", "nsamples": 1000},
        )

    def test_injection_parameters(self, true=None, pesummary=False):
        """Test the injection_parameters property

        Injections are not tested, as reading them is not implemented in the
        conda version of dingo these tests were written against (0.8.4).
        ``self.result.injection_parameters`` is overwritten with ``None``, so
        with the default ``true=None`` the base-class check only asserts that
        the value is ``None``.
        """
        self.result.injection_parameters = None
        return super().test_injection_parameters(true, pesummary)

    def test_file_format_read(self, path, file_format):
        """Test that when the file_format is specified, that correct class is used"""
        from pesummary.core.file.formats.dingo import Dingo

        super(DingoFile, self).test_file_format_read(path, file_format, Dingo)

    def test_priors(self, read_function=Read):
        """Test that the priors are correctly extracted from the dingo result
        file, that they can be disabled, and that ``nsamples_for_prior``
        controls the number of prior samples
        """
        for prior in self.result.priors["samples"].values():
            assert isinstance(prior, np.ndarray)
        f = read_function(self.path, disable_prior=True)
        assert not len(f.priors["samples"])
        f = read_function(self.path, nsamples_for_prior=200)
        params = list(f.priors["samples"].keys())
        assert len(f.priors["samples"][params[0]]) == 200
class TestCoreJsonBilbyFile(BilbyFile):
    """Check that a bilby JSON result file is read correctly by the core Read
    function
    """
    def setup_method(self):
        """Write a fresh bilby JSON result file into ``tmpdir`` and read it
        back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="json", gw=False, bilby=True
        )
        self.path = os.path.join(tmpdir, "test.json")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The Bilby class should be chosen for this file
        """
        super().test_class_name()

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters()

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples()

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict()

    def test_version(self):
        """The stored bilby version should be reported
        """
        super().test_version()

    def test_extra_kwargs(self):
        """The bilby-specific extra_kwargs should be present
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have an injection value of 1
        """
        expected = {name: 1. for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="bilby" should select the Bilby class
        """
        super().test_file_format_read(self.path, "bilby")

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()

    def test_priors(self):
        """Priors stored in the bilby result file should be read correctly
        """
        super().test_priors()
class TestCoreHDF5BilbyFile(BilbyFile):
    """Check that a bilby HDF5 result file is read correctly by the core Read
    function
    """
    def setup_method(self):
        """Write a fresh bilby HDF5 result file into ``tmpdir`` and read it
        back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=False, bilby=True
        )
        self.path = os.path.join(tmpdir, "test.hdf5")
        self.result = Read(self.path)

    def teardown_method(self):
        """Delete ``tmpdir`` together with everything written into it
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """The Bilby class should be chosen for this file
        """
        super().test_class_name()

    def test_parameters(self):
        """Stored parameters should match the generated ones
        """
        super().test_parameters()

    def test_samples(self):
        """Stored samples should match the generated ones
        """
        super().test_samples()

    def test_samples_dict(self):
        """Each samples_dict column should match the generated samples
        """
        super().test_samples_dict()

    def test_version(self):
        """The stored bilby version should be reported
        """
        super().test_version()

    def test_extra_kwargs(self):
        """The bilby-specific extra_kwargs should be present
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Every parameter should have an injection value of 1
        """
        expected = {name: 1. for name in self.parameters}
        super().test_injection_parameters(expected)

    def test_to_dat(self):
        """Round-trip the samples through the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Passing file_format="bilby" should select the Bilby class
        """
        super().test_file_format_read(self.path, "bilby")

    def test_downsample(self):
        """Downsampling should keep complete, original posterior rows
        """
        super().test_downsample()

    def test_priors(self):
        """Priors stored in the bilby result file should be read correctly
        """
        super().test_priors(read_function=Read)
class TestCoreHDF5DingoFile(DingoFile):
    """Class to test loading in a dingo hdf5 file with the core Read function"""

    def setup_method(self):
        """Setup the TestCoreHDF5DingoFile class by writing a dingo hdf5
        result file into ``tmpdir`` and reading it back
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=True, dingo=True
        )
        self.path = os.path.join(tmpdir, "test.hdf5")
        self.result = Read(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class"""
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test the class used to load in this file"""
        super(TestCoreHDF5DingoFile, self).test_class_name()

    def test_parameters(self):
        """Test the parameter property of the dingo class"""
        super(TestCoreHDF5DingoFile, self).test_parameters()

    def test_samples(self):
        """Test the samples property of the dingo class"""
        super(TestCoreHDF5DingoFile, self).test_samples()

    def test_samples_dict(self):
        """Test the samples_dict property of the dingo class"""
        super(TestCoreHDF5DingoFile, self).test_samples_dict()

    def test_version(self):
        """Test the version property of the dingo class"""
        super(TestCoreHDF5DingoFile, self).test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the dingo class"""
        super(TestCoreHDF5DingoFile, self).test_extra_kwargs()

    # test_injection_parameters is inherited from DingoFile, which does not
    # check injection values because reading injections is not implemented
    # in the dingo version these tests target

    def test_to_dat(self):
        """Test the to_dat method"""
        super(TestCoreHDF5DingoFile, self).test_to_dat()

    def test_file_format_read(self):
        """Test that when the file_format is specified, that correct class is used"""
        super(TestCoreHDF5DingoFile, self).test_file_format_read(self.path, "dingo")

    def test_priors(self):
        """Test that the priors are correctly extracted from the dingo result
        file
        """
        super(TestCoreHDF5DingoFile, self).test_priors(read_function=Read)
class PESummaryFile(BaseRead):
    """Base class to test loading in a PESummary file with the core Read
    function.

    Subclasses must provide a ``setup_method`` which sets ``self.result``
    (the object returned by Read), ``self.parameters`` (a flat list of
    parameter names) and ``self.samples`` (rows of samples ordered like
    ``self.parameters``)
    """

    def test_class_name(self):
        """Test that the file is loaded with the core PESummary class"""
        assert isinstance(self.result, pesummary.core.file.formats.pesummary.PESummary)

    def test_parameters(self):
        """Test the parameter property of the PESummary class"""
        super().test_parameters(self.parameters, pesummary=True)

    def test_samples(self):
        """Test the samples property of the PESummary class"""
        super().test_samples(self.samples, pesummary=True)

    def test_version(self):
        """Test the version property of the PESummary class"""
        true = ["No version information found"]
        super().test_version(true)

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the PESummary class
        """
        true = [{"sampler": {"log_evidence": 0.5}, "meta_data": {}}]
        super().test_extra_kwargs(true)

    def test_samples_dict(self):
        """Test that the samples_dict property contains a single analysis,
        stored under the label 'label', whose samples match those written to
        the file
        """
        assert list(self.result.samples_dict.keys()) == ["label"]
        analysis = self.result.samples_dict["label"]
        for num, param in enumerate(self.parameters):
            specific_samples = [row[num] for row in self.samples]
            np.testing.assert_almost_equal(analysis[param], specific_samples)

    def test_to_bilby(self):
        """Test that the to_bilby method returns an object which can be saved
        as a valid bilby json file
        """
        from pesummary.core.file.read import is_bilby_json_file

        filename = os.path.join(tmpdir, "bilby.json")
        bilby_object = self.result.to_bilby(save=False)["label"]
        bilby_object.save_to_file(filename=filename)
        assert is_bilby_json_file(filename)

    def test_to_dat(self):
        """Test that the to_dat method writes a dat file containing exactly
        the parameters stored in the result file
        """
        filename = os.path.join(tmpdir, "pesummary_label.dat")
        self.result.to_dat(
            outdir=tmpdir, filenames={"label": "pesummary_label.dat"}
        )
        assert os.path.isfile(filename)
        data = np.genfromtxt(filename, names=True)
        names = list(data.dtype.names)
        assert all(i in self.parameters for i in names)
        assert all(i in names for i in self.parameters)

    def test_downsample(self):
        """Test that the .downsample method keeps ``nsamples`` rows for each
        analysis, and that every retained row is an unmodified row of the
        original posterior table
        """
        nsamples = 50
        # Take independent copies of the samples before downsampling. The
        # comparison below then stays meaningful even if downsample modifies
        # the stored samples in place. ``self.result.parameters`` is a
        # per-analysis list of parameter names, whereas ``self.parameters``
        # is a flat list of names.
        original = self.result.samples_dict
        old_samples_dict = {
            label: {
                param: np.array(original[label][param])
                for param in self.result.parameters[num]
            }
            for num, label in enumerate(self.result.labels)
        }
        self.result.downsample(nsamples)
        new_samples_dict = self.result.samples_dict
        for num, label in enumerate(self.result.labels):
            old = old_samples_dict[label]
            new = new_samples_dict[label]
            parameters = self.result.parameters[num]
            assert new.number_of_samples == nsamples
            for param in parameters:
                assert all(samp in old[param] for samp in new[param])
            # Each downsampled row must map back to a single row of the
            # original table. This means the parameters were not shuffled
            # independently of one another.
            old_lists = {param: old[param].tolist() for param in parameters}
            for idx in range(nsamples):
                samp_inds = [
                    old_lists[param].index(new[param][idx])
                    for param in parameters
                ]
                assert len(set(samp_inds)) == 1
class TestCoreJsonPESummaryFile(PESummaryFile):
    """Class to test loading in a PESummary json file with the core Read
    function
    """
    def setup_method(self):
        """Write a PESummary json file to the temporary directory and read it
        in with the core Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="json", gw=False, pesummary=True
        )
        self.result = Read(os.path.join(tmpdir, "test.json"))

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test the class used to load in this file
        """
        super().test_class_name()

    def test_parameters(self):
        """Test the parameter property of the PESummary class
        """
        super().test_parameters()

    def test_samples(self):
        """Test the samples property of the PESummary class
        """
        super().test_samples()

    def test_samples_dict(self):
        """Test the samples_dict property
        """
        super().test_samples_dict()

    def test_version(self):
        """Test the version property of the PESummary class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the PESummary class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true, pesummary=True)

    def test_to_bilby(self):
        """Test the to_bilby method
        """
        super().test_to_bilby()

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Deliberately a no-op. It replaces the inherited file_format test,
        which is not run for PESummary json files
        """
        pass

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestCoreHDF5PESummaryFile(PESummaryFile):
    """Class to test loading in a PESummary hdf5 file with the core Read
    function
    """
    def setup_method(self):
        """Write a PESummary hdf5 file to the temporary directory and read it
        in with the core Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=False, pesummary=True
        )
        # make_result_file is expected to write the hdf5 PESummary file as
        # "test.h5"
        self.result = Read(os.path.join(tmpdir, "test.h5"))

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test the class used to load in this file
        """
        super().test_class_name()

    def test_parameters(self):
        """Test the parameter property of the PESummary class
        """
        super().test_parameters()

    def test_samples(self):
        """Test the samples property of the PESummary class
        """
        super().test_samples()

    def test_samples_dict(self):
        """Test the samples_dict property
        """
        super().test_samples_dict()

    def test_version(self):
        """Test the version property of the PESummary class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the PESummary class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true, pesummary=True)

    def test_to_bilby(self):
        """Test the to_bilby method
        """
        super().test_to_bilby()

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_file_format_read(self):
        """Deliberately a no-op. It replaces the inherited file_format test,
        which is not run for PESummary hdf5 files
        """
        pass

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestGWCSVFile(GWBaseRead):
    """Class to test loading in a csv file with the gw Read function
    """
    def setup_method(self):
        """Write a csv result file to the temporary directory and read it in
        with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="csv", gw=True
        )
        self.path = os.path.join(tmpdir, "test.csv")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the csv file is loaded with the SingleAnalysisDefault
        class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.default.SingleAnalysisDefault
        )

    def test_parameters(self):
        """Test the parameter property of the default class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the default class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the default class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the default class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the default class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="csv" is specified, the
        SingleAnalysisDefault class is used
        """
        from pesummary.gw.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(
            self.path, "csv", SingleAnalysisDefault
        )
class TestGWNumpyFile(GWBaseRead):
    """Class to test loading in a npy file with the gw Read function
    """
    def setup_method(self):
        """Write a npy result file to the temporary directory and read it in
        with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="npy", gw=True
        )
        self.path = os.path.join(tmpdir, "test.npy")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the npy file is loaded with the SingleAnalysisDefault
        class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.default.SingleAnalysisDefault
        )

    def test_parameters(self):
        """Test the parameter property of the default class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the default class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the default class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the default class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the default class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="numpy" is specified, the
        SingleAnalysisDefault class is used
        """
        from pesummary.gw.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(
            self.path, "numpy", SingleAnalysisDefault
        )
class TestGWDatFile(GWBaseRead):
    """Class to test loading in a dat file with the gw Read function
    """
    def setup_method(self):
        """Write a dat result file to the temporary directory and read it in
        with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="dat", gw=True
        )
        self.path = os.path.join(tmpdir, "test.dat")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the dat file is loaded with the SingleAnalysisDefault
        class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.default.SingleAnalysisDefault
        )

    def test_parameters(self):
        """Test the parameter property of the default class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the default class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the default class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the default class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the default class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="dat" is specified, the
        SingleAnalysisDefault class is used
        """
        from pesummary.gw.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(
            self.path, "dat", SingleAnalysisDefault
        )

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestGWHDF5File(GWBaseRead):
    """Class to test loading in an HDF5 file with the gw Read function
    """
    def setup_method(self):
        """Write an hdf5 result file to the temporary directory and read it
        in with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=True
        )
        # make_result_file is expected to write this hdf5 file as "test.h5"
        self.path = os.path.join(tmpdir, "test.h5")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the hdf5 file is loaded with the SingleAnalysisDefault
        class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.default.SingleAnalysisDefault
        )

    def test_parameters(self):
        """Test the parameter property of the default class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the default class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the default class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the default class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the default class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="hdf5" is specified, the
        SingleAnalysisDefault class is used
        """
        from pesummary.gw.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(
            self.path, "hdf5", SingleAnalysisDefault
        )

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestGWJsonFile(GWBaseRead):
    """Class to test loading in a json file with the gw Read function
    """
    def setup_method(self):
        """Write a json result file to the temporary directory and read it in
        with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="json", gw=True
        )
        self.path = os.path.join(tmpdir, "test.json")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the json file is loaded with the SingleAnalysisDefault
        class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.default.SingleAnalysisDefault
        )

    def test_parameters(self):
        """Test the parameter property of the default class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the default class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the default class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the default class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the default class
        """
        super().test_extra_kwargs()

    def test_injection_parameters(self):
        """Test that every injection parameter is NaN for this file
        """
        true = {par: float("nan") for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="json" is specified, the
        SingleAnalysisDefault class is used
        """
        from pesummary.gw.file.formats.default import SingleAnalysisDefault

        super().test_file_format_read(
            self.path, "json", SingleAnalysisDefault
        )

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestGWJsonBilbyFile(GWBaseRead):
    """Class to test loading in a bilby json file with the gw Read function
    """
    def setup_method(self):
        """Write a bilby json result file to the temporary directory and read
        it in with the gw Read function (priors disabled)
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="json", gw=True, bilby=True
        )
        self.path = os.path.join(tmpdir, "test.json")
        self.result = GWRead(self.path, disable_prior=True)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_class_name(self):
        """Test that the file is loaded with the gw Bilby class
        """
        assert isinstance(self.result, pesummary.gw.file.formats.bilby.Bilby)

    def test_parameters(self):
        """Test the parameter property of the bilby class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the bilby class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the bilby class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the bilby class
        """
        true = "bilby=0.5.3:"
        super().test_version(true)

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the bilby class
        """
        true = {"sampler": {
            "log_bayes_factor": 0.5,
            "log_noise_evidence": 0.1,
            "log_evidence": 0.2,
            "log_evidence_err": 0.1},
            "meta_data": {"time_marginalization": True},
            "other": {"likelihood": {"time_marginalization": "True"}}
        }
        super().test_extra_kwargs(true)

    def test_injection_parameters(self):
        """Test that every injection parameter is 1. for this file
        """
        true = {par: 1. for par in self.parameters}
        super().test_injection_parameters(true)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="bilby" is specified, the Bilby class is
        used
        """
        from pesummary.gw.file.formats.bilby import Bilby

        super().test_file_format_read(self.path, "bilby", Bilby)

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()

    def test_priors(self, read_function=GWRead):
        """Test that the priors are correctly extracted from the bilby result
        file

        Parameters
        ----------
        read_function: func, optional
            function used to read in the bilby result file. Default GWRead
        """
        # Use the supplied read function consistently. Previously the first
        # read always used GWRead, regardless of ``read_function``.
        self.result = read_function(self.path)
        assert "final_mass_source_non_evolved" not in self.result.parameters
        for prior in self.result.priors["samples"].values():
            assert isinstance(prior, np.ndarray)
        assert "final_mass_source_non_evolved" in self.result.priors["samples"].keys()
        f = read_function(self.path, disable_prior_conversion=True)
        assert "final_mass_source_non_evolved" not in f.priors["samples"].keys()
        f = read_function(self.path, disable_prior=True)
        assert not len(f.priors["samples"])
        f = read_function(self.path, nsamples_for_prior=200)
        params = list(f.priors["samples"].keys())
        assert len(f.priors["samples"][params[0]]) == 200
class TestGWLALInferenceFile(GWBaseRead):
    """Class to test loading in a LALInference file with the gw Read function
    """
    def setup_method(self):
        """Write a LALInference hdf5 result file to the temporary directory and
        read it in with the gw Read function
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters, self.samples = make_result_file(
            outdir=tmpdir, extension="hdf5", gw=True, lalinference=True
        )
        self.path = os.path.join(tmpdir, "test.hdf5")
        self.result = GWRead(self.path)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_hdf5_dataset_to_list(self):
        """Test that converting the structured hdf5 posterior dataset to a list
        through a numpy ``view`` gives the same values as reading every field
        of every row individually
        """
        import h5py

        path_to_samples = "lalinference/lalinference_nest/posterior_samples"
        # Open read-only and close the file before teardown removes it. The
        # original handle was never closed and relied on h5py's default mode.
        with h5py.File(self.path, "r") as f:
            dataset = f[path_to_samples]
            parameters = dataset.dtype.names
            old = [
                [float(row[idx]) for idx in range(len(parameters))]
                for row in dataset
            ]
            new = np.array(dataset).view((float, len(parameters))).tolist()
        assert len(old) == len(new)
        for old_row, new_row in zip(old, new):
            np.testing.assert_almost_equal(old_row, new_row)

    def test_class_name(self):
        """Test that the file is loaded with the LALInference class
        """
        assert isinstance(
            self.result, pesummary.gw.file.formats.lalinference.LALInference)

    def test_parameters(self):
        """Test the parameter property of the LALInference class
        """
        super().test_parameters(self.parameters)

    def test_samples(self):
        """Test the samples property of the LALInference class
        """
        super().test_samples(self.samples)

    def test_samples_dict(self):
        """Test the samples_dict property of the LALInference class
        """
        true = [self.parameters, self.samples]
        super().test_samples_dict(true)

    def test_version(self):
        """Test the version property of the LALInference class
        """
        super().test_version()

    def test_extra_kwargs(self):
        """Test the extra_kwargs property of the LALInference class
        """
        true = {"sampler": {"nsamples": 1000}, "meta_data": {}, "other": {}}
        super().test_extra_kwargs(true=true)

    def test_injection_parameters(self):
        """Test the injection_parameters property (expected to be None)
        """
        super().test_injection_parameters(None)

    def test_to_dat(self):
        """Test the to_dat method
        """
        super().test_to_dat()

    def test_to_lalinference_dat(self):
        """Test the to_lalinference dat=True method
        """
        super().test_to_lalinference_dat()

    def test_file_format_read(self):
        """Test that when file_format="lalinference" is specified, the
        LALInference class is used
        """
        from pesummary.gw.file.formats.lalinference import LALInference

        super().test_file_format_read(
            self.path, "lalinference", LALInference
        )

    def test_downsample(self):
        """Test that the posterior table is correctly downsampled
        """
        super().test_downsample()
class TestPublicPycbc(object):
    """Test that data files produced by Nitz et al.
    (https://github.com/gwastro/2-ogc) can be read in correctly.
    """
    def setup_method(self):
        """Create the temporary directory used to store downloaded files
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def _pycbc_check(self, filename):
        """Test a public pycbc posterior samples file

        Parameters
        ----------
        filename: str
            url of pycbc posterior samples file you wish to download, read and
            test
        """
        from pesummary.core.fetch import download_and_read_file
        from pesummary.gw.file.standard_names import standard_names
        import h5py

        self.file = download_and_read_file(
            filename, read_file=False, outdir=tmpdir
        )
        self.result = GWRead(self.file, path_to_samples="samples", psi_mod_pi=False)
        samples = self.result.samples_dict
        # The context manager guarantees the file is closed even when an
        # assertion fails, before teardown removes the directory.
        with h5py.File(self.file, "r") as fp:
            fp_samples = fp["samples"]
            for param in fp_samples.keys():
                np.testing.assert_almost_equal(
                    fp_samples[param], samples[standard_names.get(param, param)]
                )

    def test_2_OGC(self):
        """Test the samples released as part of the 2-OGC catalog
        """
        self._pycbc_check(
            "https://github.com/gwastro/2-ogc/raw/master/posterior_samples/"
            "H1L1V1-EXTRACT_POSTERIOR_150914_09H_50M_45UTC-0-1.hdf"
        )

    def test_3_OGC(self):
        """Test the samples released as part of the 3-OGC catalog
        """
        self._pycbc_check(
            "https://github.com/gwastro/3-ogc/raw/master/posterior/"
            "GW150914_095045-PYCBC-POSTERIOR-XPHM.hdf"
        )
class TestPublicPrincetonO1O2(object):
    """Test that data files produced by Venumadhav et al.
    (https://github.com/jroulet/O2_samples) can be read in correctly
    """
    def setup_method(self):
        """Download the public GW150914 samples into the module temporary
        directory and read them in with the princeton file format
        """
        from pesummary.core.fetch import download_and_read_file

        # Use the module level temporary directory, like the other classes.
        # Previously ".outdir" was used, and teardown would delete any
        # pre-existing ".outdir" in the working directory.
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.file = download_and_read_file(
            "https://github.com/jroulet/O2_samples/raw/master/GW150914.npy",
            read_file=False, outdir=tmpdir
        )
        self.result = GWRead(self.file, file_format="princeton")

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_samples_dict(self):
        """Test that each column of the princeton numpy file matches the
        corresponding entry in samples_dict, after mapping the princeton
        column names to pesummary standard names
        """
        data = np.load(self.file)
        samples = self.result.samples_dict
        # princeton column name -> pesummary standard name
        name_map = {
            "mchirp": "chirp_mass", "eta": "symmetric_mass_ratio",
            "s1z": "spin_1z", "s2z": "spin_2z", "RA": "ra", "DEC": "dec",
            "psi": "psi", "iota": "iota", "vphi": "phase", "tc": "geocent_time",
            "DL": "luminosity_distance"
        }
        # order of the columns stored in the princeton file
        columns = [
            'mchirp', 'eta', 's1z', 's2z', 'RA', 'DEC', 'psi', 'iota', 'vphi',
            'tc', 'DL'
        ]
        for num, column in enumerate(columns):
            np.testing.assert_almost_equal(data.T[num], samples[name_map[column]])
class TestMultiAnalysis(object):
    """Class to test that a file which contains multiple analyses can be read
    in appropriately
    """
    def setup_method(self):
        """Write two analyses to a single sql database in the temporary
        directory and read the database back in
        """
        from pesummary.utils.samples_dict import MultiAnalysisSamplesDict

        # ``write`` is pesummary.io.write, imported at module level
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.data = MultiAnalysisSamplesDict(
            {"label1": {
                "mass_1": np.random.uniform(20, 100, 10),
                "mass_2": np.random.uniform(5, 20, 10),
            }, "label2": {
                "mass_1": np.random.uniform(20, 100, 10),
                "mass_2": np.random.uniform(5, 20, 10)
            }}
        )
        write(
            self.data, file_format="sql", filename="multi_analysis.db",
            outdir=tmpdir, overwrite=True, delete_existing=True
        )
        self.result = read(
            os.path.join(tmpdir, "multi_analysis.db"),
            add_zero_likelihood=False, remove_row_column="ROW"
        )
        self.samples_dict = self.result.samples_dict

    def teardown_method(self):
        """Remove all files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_multi_analysis_db(self):
        """Test that an sql database with more than one set of samples can
        be read in appropriately, and that generate_all_posterior_samples
        adds total_mass for every analysis
        """
        assert sorted(self.samples_dict.keys()) == sorted(self.data.keys())
        for key in self.samples_dict.keys():
            assert sorted(self.samples_dict[key].keys()) == sorted(
                self.data[key].keys()
            )
            for param in self.samples_dict[key].keys():
                np.testing.assert_almost_equal(
                    self.samples_dict[key][param], self.data[key][param]
                )
        self.result.generate_all_posterior_samples()
        self.samples_dict = self.result.samples_dict
        for key in self.samples_dict.keys():
            assert "total_mass" in self.samples_dict[key].keys()
            np.testing.assert_almost_equal(
                self.data[key]["mass_1"] + self.data[key]["mass_2"],
                self.samples_dict[key]["total_mass"]
            )
class TestSingleAnalysisChangeFormat(object):
    """Test that when changing file format through the 'write' method, the
    samples are conserved
    """
    def setup_method(self):
        """Setup the TestSingleAnalysisChangeFormat class: write a single
        analysis to a '.dat' file in `tmpdir` and read it back in
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters = ["log_likelihood", "mass_1", "mass_2"]
        self.samples = np.array(
            [
                np.random.uniform(20, 100, 1000),
                np.random.uniform(5, 10, 1000), np.random.uniform(0, 1, 1000)
            ]
        ).T
        write(
            self.parameters, self.samples, outdir=tmpdir, filename="test.dat",
            overwrite=True
        )
        self.result = read(os.path.join(tmpdir, "test.dat"))

    def teardown_method(self):
        """Remove all files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def _check_samples(self, result):
        """Assert that `result` contains exactly `self.parameters` and that
        each parameter's column matches `self.samples`, independent of the
        column order used in the file
        """
        assert sorted(result.parameters) == sorted(self.parameters)
        read_samples = np.array(result.samples)
        for num, param in enumerate(self.parameters):
            idx = result.parameters.index(param)
            np.testing.assert_almost_equal(
                read_samples[:, idx], self.samples[:, num]
            )

    def save_and_check(
        self, file_format, bilby=False, pesummary=False, lalinference=False
    ):
        """Save the result file in `file_format`, read it back in and check
        the contents

        Parameters
        ----------
        file_format: str
            format passed to `self.result.write`
        bilby: bool, optional
            if True, save to 'test_bilby.json'
        pesummary: bool, optional
            if True, save to 'test_pesummary.h5' and check that the first
            analysis in the file keeps the original parameter order
        lalinference: bool, optional
            if True, save to 'test_pesummary.h5'
        """
        if bilby:
            filename = "test_bilby.json"
        elif pesummary or lalinference:
            filename = "test_pesummary.h5"
        else:
            filename = "test.{}".format(file_format)
        self.result.write(
            file_format=file_format, outdir=tmpdir, filename=filename
        )
        result = read(os.path.join(tmpdir, filename), disable_prior=True)
        if pesummary:
            assert result.parameters[0] == self.parameters
            np.testing.assert_almost_equal(result.samples[0], self.samples)
        else:
            self._check_samples(result)

    def test_to_bilby(self):
        """Test saving to bilby format
        """
        self.save_and_check("bilby", bilby=True)

    def test_to_hdf5(self):
        """Test saving to hdf5
        """
        self.save_and_check("hdf5")

    def test_to_json(self):
        """Test saving to json
        """
        self.save_and_check("json")

    def test_to_sql(self):
        """Test saving to sql
        """
        self.save_and_check("sql")

    def test_to_pesummary(self):
        """Test saving to the pesummary format
        """
        self.save_and_check("pesummary", pesummary=True)

    def test_to_lalinference(self):
        """Test saving to the lalinference format
        """
        self.save_and_check("lalinference", lalinference=True)
class TestMultipleAnalysisChangeFormat(object):
    """Test that when changing file format through the 'write' method, the
    samples are conserved
    """
    def setup_method(self):
        """Setup the TestMultipleAnalysisChangeFormat class: write two
        analyses to a single sql database in `tmpdir` and read it back in
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        self.parameters = [
            ["log_likelihood", "mass_1", "mass_2"],
            ["chirp_mass", "log_likelihood", "total_mass"]
        ]
        self.samples = np.array(
            [np.array(
                [
                    np.random.uniform(20, 100, 1000),
                    np.random.uniform(5, 10, 1000),
                    np.random.uniform(0, 1, 1000)
                ]
            ).T, np.array(
                [
                    np.random.uniform(20, 100, 1000),
                    np.random.uniform(5, 10, 1000),
                    np.random.uniform(0, 1, 1000)
                ]
            ).T]
        )
        write(
            self.parameters, self.samples, outdir=tmpdir, filename="test.db",
            overwrite=True, file_format="sql"
        )
        self.result = read(os.path.join(tmpdir, "test.db"))

    def teardown_method(self):
        """Remove all files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    @staticmethod
    def _check_samples(result, parameters, samples):
        """Assert that `result` contains exactly `parameters` and that each
        parameter's column matches the corresponding column of `samples`,
        independent of the column order used in the file
        """
        assert sorted(result.parameters) == sorted(parameters)
        read_samples = np.array(result.samples)
        for num, param in enumerate(parameters):
            idx = result.parameters.index(param)
            np.testing.assert_almost_equal(
                read_samples[:, idx], samples[:, num]
            )

    def save_and_check(
        self, file_format, bilby=False, pesummary=False, lalinference=False,
        multiple_files=True
    ):
        """Save the result file in `file_format`, read it back in and check
        the contents

        Parameters
        ----------
        file_format: str
            format passed to `self.result.write`
        bilby: bool, optional
            if True, save to 'test_bilby.json'
        pesummary: bool, optional
            if True, save to 'test_pesummary.h5'
        lalinference: bool, optional
            if True, save to 'test_pesummary.h5'
        multiple_files: bool, optional
            if True, expect one file per analysis named '<stem>_*<ext>' (the
            files are matched to analyses in sorted filename order); if False,
            expect a single file holding every analysis in the original order
        """
        if bilby:
            filename = "test_bilby.json"
        elif pesummary or lalinference:
            filename = "test_pesummary.h5"
        else:
            filename = "test.{}".format(file_format)
        self.result.write(
            file_format=file_format, outdir=tmpdir, filename=filename
        )
        if multiple_files:
            stem, extension = os.path.splitext(filename)
            files = sorted(
                glob.glob(os.path.join(tmpdir, "{}_*{}".format(stem, extension)))
            )
            assert len(files) == len(self.parameters)
            for num, _file in enumerate(files):
                result = read(_file, disable_prior=True)
                self._check_samples(
                    result, self.parameters[num], self.samples[num]
                )
        else:
            result = read(os.path.join(tmpdir, filename), disable_prior=True)
            # check the number of analyses explicitly so a missing analysis
            # is reported as such rather than only via a shape mismatch
            assert len(result.parameters) == len(self.parameters)
            for read_params, true_params in zip(
                result.parameters, self.parameters
            ):
                assert read_params == true_params
            np.testing.assert_almost_equal(
                np.array(result.samples), self.samples
            )

    def test_to_bilby(self):
        """Test saving to bilby
        """
        self.save_and_check("bilby", bilby=True)

    def test_to_dat(self):
        """Test saving to dat
        """
        self.save_and_check("dat")

    def test_to_hdf5(self):
        """Test saving to hdf5
        """
        self.save_and_check("hdf5")

    def test_to_json(self):
        """Test saving to json
        """
        self.save_and_check("json")

    def test_to_sql(self):
        """Test saving to sql
        """
        self.save_and_check("sql", multiple_files=False)

    def test_to_pesummary(self):
        """Test saving to the pesummary format
        """
        self.save_and_check("pesummary", pesummary=True, multiple_files=False)

    def test_to_lalinference(self):
        """Test saving to the lalinference format
        """
        self.save_and_check("lalinference", lalinference=True)
def test_remove_nan_likelihoods():
    """Test that samples with 'nan' log_likelihoods are removed from the
    posterior table

    This is checked for a single-analysis '.dat' file (with and without
    `remove_nan_likelihood_samples`) and for a multi-analysis pesummary file.
    `tmpdir` is removed afterwards even if an assertion fails.
    """
    from pesummary.utils.samples_dict import MultiAnalysisSamplesDict

    if not os.path.isdir(tmpdir):
        os.mkdir(tmpdir)
    try:
        # single analysis: 100 of the 1000 log_likelihood samples are nan
        parameters = ["a", "b", "log_likelihood"]
        likelihoods = np.random.uniform(0, 1, 1000)
        inds = np.random.choice(len(likelihoods), size=100, replace=False)
        likelihoods[inds] = float('nan')
        # np.random.uniform requires low <= high (high < low is undefined)
        samples = np.array([
            np.random.uniform(5, 10, 1000), np.random.uniform(5, 10, 1000),
            likelihoods
        ]).T
        write(parameters, samples, filename="test.dat", outdir=tmpdir)
        dat_file = os.path.join(tmpdir, "test.dat")

        # without removal every sample, including the nan ones, is kept
        f = read(dat_file, remove_nan_likelihood_samples=False)
        read_samples = f.samples_dict
        for param in parameters:
            assert len(read_samples[param]) == 1000
        for num, param in enumerate(parameters):
            np.testing.assert_almost_equal(read_samples[param], samples.T[num])

        # with removal only the 900 finite-likelihood samples remain
        f = read(dat_file, remove_nan_likelihood_samples=True)
        read_samples = f.samples_dict
        for param in parameters:
            assert len(read_samples[param]) == 900
        nan_mask = np.isnan(likelihoods)
        for num, param in enumerate(parameters):
            np.testing.assert_almost_equal(
                read_samples[param], samples.T[num][~nan_mask]
            )

        # two analyses with 100 and 500 nan log_likelihood samples respectively
        likelihoods = np.random.uniform(0, 1, 2000).reshape(2, 1000)
        inds = np.random.choice(1000, size=100, replace=False)
        likelihoods[0][inds] = float('nan')
        inds = np.random.choice(1000, size=500, replace=False)
        likelihoods[1][inds] = float('nan')
        samples = {
            "one": {
                "a": np.random.uniform(1, 5, 1000),
                "b": np.random.uniform(1, 2, 1000),
                "log_likelihood": likelihoods[0]
            }, "two": {
                "c": np.random.uniform(1, 5, 1000),
                "d": np.random.uniform(1, 2, 1000),
                "log_likelihood": likelihoods[1]
            }
        }
        data = MultiAnalysisSamplesDict(samples)
        write(
            data, file_format="pesummary", filename="multi.h5", outdir=tmpdir,
        )
        f = read(
            os.path.join(tmpdir, "multi.h5"), remove_nan_likelihood_samples=True
        )
        _samples_dict = f.samples_dict
        expected_length = {"one": ("a", 900), "two": ("c", 500)}
        for num, label in enumerate(["one", "two"]):
            nan_mask = np.isnan(likelihoods[num])
            param, length = expected_length[label]
            assert len(_samples_dict[label][param]) == length
            for param in samples[label].keys():
                np.testing.assert_almost_equal(
                    _samples_dict[label][param], samples[label][param][~nan_mask]
                )
    finally:
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)
def test_add_log_likelihood():
    """Test that zero log likelihood samples are added when the posterior table
    does not include likelihood samples

    Both a single-analysis '.dat' file and a multi-analysis pesummary file are
    written without a 'log_likelihood' column. After reading them back, each
    analysis must contain a 'log_likelihood' column of zeros and the original
    samples must be unchanged. `tmpdir` is removed afterwards even if an
    assertion fails.
    """
    from pesummary.utils.samples_dict import MultiAnalysisSamplesDict

    if not os.path.isdir(tmpdir):
        os.mkdir(tmpdir)
    try:
        # single analysis
        parameters = ["a", "b"]
        # np.random.uniform requires low <= high (high < low is undefined)
        samples = np.array([
            np.random.uniform(5, 10, 1000), np.random.uniform(5, 10, 1000)
        ]).T
        write(parameters, samples, filename="test.dat", outdir=tmpdir)
        f = read(os.path.join(tmpdir, "test.dat"))
        _samples_dict = f.samples_dict
        assert sorted(f.parameters) == ["a", "b", "log_likelihood"]
        np.testing.assert_almost_equal(
            _samples_dict["log_likelihood"], np.zeros(1000)
        )
        for num, param in enumerate(parameters):
            np.testing.assert_almost_equal(_samples_dict[param], samples.T[num])

        # multiple analyses stored in a single pesummary file
        multi_samples = {
            "one": {
                "a": np.random.uniform(1, 5, 1000),
                "b": np.random.uniform(1, 2, 1000)
            }, "two": {
                "c": np.random.uniform(1, 5, 1000),
                "d": np.random.uniform(1, 2, 1000)
            }
        }
        data = MultiAnalysisSamplesDict(multi_samples)
        write(
            data, file_format="pesummary", filename="multi.h5", outdir=tmpdir,
        )
        f = read(os.path.join(tmpdir, "multi.h5"))
        _samples_dict = f.samples_dict
        for label, analysis in multi_samples.items():
            np.testing.assert_almost_equal(
                _samples_dict[label]["log_likelihood"], np.zeros(1000)
            )
            for param, values in analysis.items():
                np.testing.assert_almost_equal(
                    _samples_dict[label][param], values
                )
    finally:
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)