# Coverage for pesummary/tests/utils_test.py: 41.4% (729 statements)
# Report generated by coverage.py v7.4.4 at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import h5py
6import numpy as np
7import copy
9import pesummary
10from pesummary.io import write
11import pesummary.cli as cli
12from pesummary.utils import utils
13from pesummary.utils.tqdm import tqdm
14from pesummary.utils.dict import Dict
15from pesummary.utils.list import List
16from pesummary.utils.pdf import DiscretePDF, DiscretePDF2D, DiscretePDF2Dplus1D
17from pesummary.utils.array import _2DArray
18from pesummary.utils.samples_dict import (
19 Array, SamplesDict, MCMCSamplesDict, MultiAnalysisSamplesDict
20)
21from pesummary.utils.probability_dict import ProbabilityDict, ProbabilityDict2D
22from pesummary._version_helper import GitInformation, PackageInformation
23from pesummary._version_helper import get_version_information
25import pytest
26from testfixtures import LogCapture
27import tempfile
# Name of a hidden scratch directory inside the CWD. Only the *name* is kept:
# the tests create/remove the directory themselves in setup/teardown, so the
# TemporaryDirectory object's own cleanup is irrelevant here.
tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".").name

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Directory whose git repository is inspected by TestGitInformation:
# CI_PROJECT_DIR when running under GitLab CI, otherwise the current directory.
DEFAULT_DIRECTORY = os.getenv("CI_PROJECT_DIR", os.getcwd())
class TestGitInformation(object):
    """Tests for the GitInformation helper class."""

    def setup_method(self):
        """Create a fresh GitInformation instance for every test."""
        self.git = GitInformation(directory=DEFAULT_DIRECTORY)

    def test_last_commit_info(self):
        """Check that last_commit_info is a pair of strings."""
        commit_info = self.git.last_commit_info
        assert len(commit_info) == 2
        assert all(isinstance(entry, str) for entry in commit_info)

    def test_last_version(self):
        """Check that last_version is a string."""
        assert isinstance(self.git.last_version, str)

    def test_status(self):
        """Check that status is a string."""
        assert isinstance(self.git.status, str)

    def test_builder(self):
        """Check that builder is a string."""
        assert isinstance(self.git.builder, str)

    def test_build_date(self):
        """Check that build_date is a string."""
        assert isinstance(self.git.build_date, str)
class TestPackageInformation(object):
    """Tests for the PackageInformation helper class."""

    def setup_method(self):
        """Create a fresh PackageInformation instance for every test."""
        self.package = PackageInformation()

    def test_package_info(self):
        """Check that package_info is a list of package records."""
        packages = self.package.package_info
        assert isinstance(packages, list)
        first = packages[0]
        assert "name" in first
        assert "version" in first
        # build_string/channel are only reported for conda installations
        if "build_string" in first:
            assert "channel" in first
class TestUtils(object):
    """Test the functions in pesummary.utils.utils
    """
    def setup_method(self):
        """Setup the TestUtils class by creating the scratch directory
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_check_condition(self):
        """Test that check_condition raises an Exception carrying the
        supplied message when the condition is True
        """
        with pytest.raises(Exception) as info:
            condition = True
            utils.check_condition(condition, "error")
        assert str(info.value) == "error"

    def test_rename_group_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method when renaming
        a group
        """
        path = "{}/rename_group.h5".format(tmpdir)
        # context managers guarantee the file is closed even if an assertion
        # fails (the original code left the file open in that case)
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, group=["group", "replaced"]
        )
        # open explicitly read-only; relying on h5py's default mode is
        # deprecated behaviour
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["replaced"]
            assert list(f["replaced"].keys()) == ["example"]
            assert len(f["replaced/example"]) == 1
            assert f["replaced/example"][0] == 10

    def test_rename_dataset_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method when renaming
        a dataset
        """
        path = "{}/rename_dataset.h5".format(tmpdir)
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, dataset=["group/example", "group/replaced"]
        )
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["group"]
            assert list(f["group"].keys()) == ["replaced"]
            assert len(f["group/replaced"]) == 1
            assert f["group/replaced"][0] == 10

    def test_rename_unknown_hf5_file(self):
        """Test that renaming inside a file which does not exist raises an
        Exception mentioning the missing file
        """
        with pytest.raises(Exception) as info:
            utils.rename_group_or_dataset_in_hf5_file(
                "{}/unknown.h5".format(tmpdir),
                group=["None", "replaced"])
        assert "does not exist" in str(info.value)

    def test_directory_creation(self):
        """Test that make_dir creates a directory which did not exist
        """
        directory = '{}/test_dir'.format(tmpdir)
        # truthiness assertions instead of `== False` / `== True` (PEP8 E712)
        assert not os.path.isdir(directory)
        utils.make_dir(directory)
        assert os.path.isdir(directory)

    def test_url_guess(self):
        """Test that guess_url maps each known cluster host to the expected
        public URL
        """
        host = ["raven", "cit", "ligo-wa", "uwm", "phy.syr.edu", "vulcan",
                "atlas", "iucaa", "alice"]
        expected = ["https://geo2.arcca.cf.ac.uk/~albert.einstein/test",
                    "https://ldas-jobs.ligo.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.ligo-wa.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.phys.uwm.edu/~albert.einstein/test",
                    "https://sugar-jobs.phy.syr.edu/~albert.einstein/test",
                    "https://galahad.aei.mpg.de/~albert.einstein/test",
                    "https://atlas1.atlas.aei.uni-hannover.de/~albert.einstein/test",
                    "https://ldas-jobs.gw.iucaa.in/~albert.einstein/test",
                    "https://dumpty.alice.icts.res.in/~albert.einstein/test"]
        user = "albert.einstein"
        webdir = '/home/albert.einstein/public_html/test'
        for _host, _expected in zip(host, expected):
            url = utils.guess_url(webdir, _host, user)
            assert url == _expected

    def test_make_dir(self):
        """Test that make_dir is a no-op (and does not clobber existing
        files) when the directory already exists
        """
        assert not os.path.isdir(os.path.join(tmpdir, "test"))
        utils.make_dir(os.path.join(tmpdir, "test"))
        assert os.path.isdir(os.path.join(tmpdir, "test"))
        with open(os.path.join(tmpdir, "test", "test.dat"), "w") as f:
            f.writelines(["test"])
        utils.make_dir(os.path.join(tmpdir, "test"))
        # the existing file must survive a second make_dir call
        assert os.path.isfile(os.path.join(tmpdir, "test", "test.dat"))

    def test_resample_posterior_distribution(self):
        """Test the resample_posterior_distribution method
        """
        data = np.random.normal(1, 0.1, 1000)
        resampled = utils.resample_posterior_distribution([data], 500)
        assert len(resampled) == 500
        # resampling should preserve the distribution's moments
        assert np.round(np.mean(resampled), 1) == 1.
        assert np.round(np.std(resampled), 1) == 0.1

    def test_gw_results_file(self):
        """Test the gw_results_file method
        """
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        assert utils.gw_results_file(opts)
        opts = namespace({"webdir": tmpdir})
        assert not utils.gw_results_file(opts)

    def test_functions(self):
        """Test that the functions method returns the GW classes when GW
        options are passed and the core classes otherwise
        """
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.gw.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.gw.file.meta_file.GWMetaFile
        opts = namespace({"webdir": tmpdir})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.core.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.core.file.meta_file.MetaFile

    def test_get_version_information(self):
        """Test the get_version_information method
        """
        assert isinstance(get_version_information(), str)
class TestGelmanRubin(object):
    """Tests for the Gelman--Rubin convergence statistic."""

    def test_same_as_lalinference(self):
        """Check that pesummary's gelman_rubin agrees with the LALInference
        implementation on randomly generated multi-chain samples."""
        from lalinference.bayespputils import Posterior
        from pesummary.utils.utils import gelman_rubin

        header = ["a", "b", "logL", "chain"]
        for _ in np.arange(100):
            # 10 rows of 3 random values plus a random chain id in {1, 2}
            samples = np.array(
                [
                    np.random.uniform(np.random.random(), 0.1, 3).tolist() +
                    [np.random.randint(1, 3)] for _ in range(10)
                ]
            )
            posterior = Posterior([header, np.array(samples)])
            lal_R = posterior.gelman_rubin("a")
            chains = np.unique(posterior["chain"].samples)
            chain_column = posterior.names.index("chain")
            param_column = posterior.names.index("a")
            data, _ = posterior.samples()
            # split parameter "a" into one array per chain
            per_chain = [
                data[data[:, chain_column] == chain, param_column]
                for chain in chains
            ]
            np.testing.assert_almost_equal(
                gelman_rubin(per_chain, decimal=10), lal_R, 7
            )

    def test_same_samples(self):
        """Two identical chains (perfect convergence) must give R == 1."""
        from pesummary.core.plots.plot import gelman_rubin

        samples = np.random.uniform(1, 0.5, 10)
        assert gelman_rubin([samples, samples]) == 1
class TestSamplesDict(object):
    """Test the SamplesDict class
    """
    def setup_method(self):
        """Create posterior samples for two parameters and write them to a
        .dat file so that the ``from_file`` constructor can also be tested
        """
        self.parameters = ["a", "b"]
        self.samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100)
        ]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        write(
            self.parameters, np.array(self.samples).T, outdir=tmpdir,
            filename="test.dat", file_format="dat"
        )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Test that the two ways to initialize the SamplesDict class are
        equivalent
        """
        # initialise from a parameter list + samples array, and from a dict
        base = SamplesDict(self.parameters, self.samples)
        other = SamplesDict(
            {
                param: sample for param, sample in zip(
                    self.parameters, self.samples
                )
            }
        )
        assert base.parameters == other.parameters
        assert sorted(base.parameters) == sorted(self.parameters)
        np.testing.assert_almost_equal(base.samples, other.samples)
        assert sorted(list(base.keys())) == sorted(list(other.keys()))
        np.testing.assert_almost_equal(base.samples, self.samples)
        # the from_file classmethod should load identical samples from disk
        class_method = SamplesDict.from_file(
            "{}/test.dat".format(tmpdir), add_zero_likelihood=False
        )
        np.testing.assert_almost_equal(class_method.samples, self.samples)

    def test_complex_columns(self):
        """Test that complex columns are deconstructed correctly
        """
        parameters = ["a", "b", "a_j", "b_j"]
        # the last two columns are complex-valued
        samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100),
            np.random.uniform(10, 0.5, 100) + 10j,
            np.random.uniform(200, 10, 100) + 2j
        ]
        dataset1 = SamplesDict(
            parameters, samples, deconstruct_complex_columns=True
        )
        dataset2 = SamplesDict(
            {p: s for p, s in zip(parameters, samples)},
            deconstruct_complex_columns=True
        )
        for dd in [dataset1, dataset2]:
            for num, param in enumerate(["a_j", "b_j"]):
                # each complex column is split into `_abs` and `_angle`
                # columns, and the original key keeps only the real part
                assert f"{param}_abs" in dd.parameters
                assert f"{param}_angle" in dd.parameters
                np.testing.assert_almost_equal(
                    dd[f"{param}_abs"], np.abs(samples[2 + num])
                )
                np.testing.assert_almost_equal(
                    dd[f"{param}_angle"], np.angle(samples[2 + num])
                )
                np.testing.assert_almost_equal(
                    dd[f"{param}"], np.real(samples[2 + num])
                )
            # real-valued columns are left untouched
            np.testing.assert_almost_equal(dd["a"], samples[0])
            np.testing.assert_almost_equal(dd["b"], samples[1])
        # when deconstruction is disabled, complex columns are kept as-is
        dataset1 = SamplesDict(
            parameters, samples, deconstruct_complex_columns=False
        )
        dataset2 = SamplesDict(
            {p: s for p, s in zip(parameters, samples)},
            deconstruct_complex_columns=False
        )
        for dd in [dataset1, dataset2]:
            for param, ss in zip(parameters, samples):
                np.testing.assert_almost_equal(
                    dd[param], ss
                )

    def test_properties(self):
        """Test that the properties of the SamplesDict class are correct
        """
        import pandas as pd

        dataset = SamplesDict(self.parameters, self.samples)
        assert sorted(dataset.minimum.keys()) == sorted(self.parameters)
        assert dataset.minimum["a"] == np.min(self.samples[0])
        assert dataset.minimum["b"] == np.min(self.samples[1])
        assert dataset.median["a"] == np.median(self.samples[0])
        assert dataset.median["b"] == np.median(self.samples[1])
        assert dataset.mean["a"] == np.mean(self.samples[0])
        assert dataset.mean["b"] == np.mean(self.samples[1])
        assert dataset.number_of_samples == len(self.samples[1])
        assert len(dataset.downsample(10)["a"]) == 10
        # re-create the dataset: downsample appears to modify it in place
        # (NOTE(review): confirm against SamplesDict.downsample)
        dataset = SamplesDict(self.parameters, self.samples)
        assert len(dataset.discard_samples(10)["a"]) == len(self.samples[0]) - 10
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["a"], self.samples[0][10:]
        )
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["b"], self.samples[1][10:]
        )
        p = dataset.to_pandas()
        assert isinstance(p, pd.core.frame.DataFrame)
        # pop removes the parameter from the dataset
        remove = dataset.pop("a")
        assert list(dataset.keys()) == ["b"]

    def test_core_plots(self):
        """Test that the core plotting methods of the SamplesDict class work as
        expected
        """
        import matplotlib.figure

        dataset = SamplesDict(self.parameters, self.samples)
        fig = dataset.plot(self.parameters[0], type="hist")
        assert isinstance(fig, matplotlib.figure.Figure)
        fig = dataset.plot(self.parameters[0], type="marginalized_posterior")
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_standardize_parameter_names(self):
        """Test the standardize_parameter_names method
        """
        dictionary = {"mass1": [1,2,3,4], "mass2": [4,3,2,1], "zz": [1,1,1,1]}
        dictionary_copy = dictionary.copy()
        mydict = SamplesDict(dictionary_copy)
        standard_dict = mydict.standardize_parameter_names()
        # check standard parameter names are in the new dictionary
        assert sorted(list(standard_dict.keys())) == ["mass_1", "mass_2", "zz"]
        # check that the dictionary items remains the same
        _mapping = {"mass1": "mass_1", "mass2": "mass_2", "zz": "zz"}
        for key, item in dictionary.items():
            np.testing.assert_almost_equal(item, standard_dict[_mapping[key]])
        # check old dictionary remains unchanged
        assert sorted(list(mydict.keys())) == ["mass1", "mass2", "zz"]
        for key, item in mydict.items():
            np.testing.assert_almost_equal(item, dictionary[key])
        # try custom mapping
        mapping = {"mass1": "custom_m1", "mass2": "custom_m2", "zz": "custom_zz"}
        new_dict = mydict.standardize_parameter_names(mapping=mapping)
        assert sorted(list(new_dict.keys())) == [
            "custom_m1", "custom_m2", "custom_zz"
        ]
        for old, new in mapping.items():
            np.testing.assert_almost_equal(dictionary[old], new_dict[new])

    def test_waveforms(self):
        """Test the waveform generation
        """
        from pesummary.core.fetch import download_dir
        # silently skip when pycbc is unavailable in this environment
        try:
            from pycbc.waveform import get_fd_waveform, get_td_waveform
        except (ValueError, ImportError):
            return

        # use cached GW190814 samples if present, otherwise download them
        # (network access required on first run)
        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            from pesummary.gw.fetch import fetch_open_samples
            f = fetch_open_samples(
                "GW190814", read_file=True, outdir=download_dir, unpack=True,
                path="GW190814.h5", catalog="GWTC-2"
            )
        else:
            from pesummary.io import read
            f = read(downloaded_file)
        samples = f.samples_dict["C01:IMRPhenomPv3HM"]
        ind = 0
        # frequency-domain waveform: compare against pycbc built from the
        # same single posterior sample
        data = samples.fd_waveform("IMRPhenomPv3HM", 1./256, 20., 1024., ind=ind)
        hp_pycbc, hc_pycbc = get_fd_waveform(
            approximant="IMRPhenomPv3HM", mass1=samples["mass_1"][ind],
            mass2=samples["mass_2"][ind], spin1x=samples["spin_1x"][ind],
            spin1y=samples["spin_1y"][ind], spin1z=samples["spin_1z"][ind],
            spin2x=samples["spin_2x"][ind], spin2y=samples["spin_2y"][ind],
            spin2z=samples["spin_2z"][ind], inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., f_final=1024.,
            delta_f=1./256, f_ref=20.
        )
        np.testing.assert_almost_equal(
            data["h_plus"].frequencies.value, hp_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_cross"].frequencies.value, hc_pycbc.sample_frequencies
        )
        # strain values are ~1e-25, so scale up before an absolute-precision
        # almost-equal comparison
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
        # time-domain waveform: same comparison via get_td_waveform
        data = samples.td_waveform("SEOBNRv4PHM", 1./4096, 20., ind=ind)
        hp_pycbc, hc_pycbc = get_td_waveform(
            approximant="SEOBNRv4PHM", mass1=samples["mass_1"][ind],
            mass2=samples["mass_2"][ind], spin1x=samples["spin_1x"][ind],
            spin1y=samples["spin_1y"][ind], spin1z=samples["spin_1z"][ind],
            spin2x=samples["spin_2x"][ind], spin2y=samples["spin_2y"][ind],
            spin2z=samples["spin_2z"][ind], inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., delta_t=1./4096,
            f_ref=20.
        )
        np.testing.assert_almost_equal(
            data["h_plus"].times.value, hp_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_cross"].times.value, hc_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
class TestMultiAnalysisSamplesDict(object):
    """Test the MultiAnalysisSamplesDict class
    """
    def setup_method(self):
        """Create samples for two analyses ('one' and 'two') and write each
        to its own .dat file so the ``from_files`` constructor can be tested
        """
        self.parameters = ["a", "b"]
        self.samples = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)],
        ]
        self.labels = ["one", "two"]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        for num, _samples in enumerate(self.samples):
            write(
                self.parameters, np.array(_samples).T, outdir=tmpdir,
                filename="test_{}.dat".format(num + 1), file_format="dat"
            )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Test the different ways to initalize the class
        """
        dataframe = MultiAnalysisSamplesDict(
            self.parameters, self.samples, labels=["one", "two"]
        )
        assert sorted(list(dataframe.keys())) == sorted(self.labels)
        assert sorted(list(dataframe["one"])) == sorted(["a", "b"])
        assert sorted(list(dataframe["two"])) == sorted(["a", "b"])
        np.testing.assert_almost_equal(
            dataframe["one"]["a"], self.samples[0][0]
        )
        np.testing.assert_almost_equal(
            dataframe["one"]["b"], self.samples[0][1]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["a"], self.samples[1][0]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["b"], self.samples[1][1]
        )
        # NOTE(review): this builds an MCMCSamplesDict rather than a
        # MultiAnalysisSamplesDict — presumably both share dict behaviour
        # for this comparison, but confirm this is intentional
        _other = MCMCSamplesDict({
            label: {
                param: self.samples[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num, label in enumerate(self.labels)
        })
        class_method = MultiAnalysisSamplesDict.from_files(
            {
                'one': "{}/test_1.dat".format(tmpdir),
                'two': "{}/test_2.dat".format(tmpdir)
            }, add_zero_likelihood=False
        )
        # both alternative constructions must match the direct initialisation
        for other in [_other, class_method]:
            assert sorted(other.keys()) == sorted(dataframe.keys())
            assert sorted(other["one"].keys()) == sorted(
                dataframe["one"].keys()
            )
            np.testing.assert_almost_equal(
                other["one"]["a"], dataframe["one"]["a"]
            )
            np.testing.assert_almost_equal(
                other["one"]["b"], dataframe["one"]["b"]
            )
            np.testing.assert_almost_equal(
                other["two"]["a"], dataframe["two"]["a"]
            )
            np.testing.assert_almost_equal(
                other["two"]["b"], dataframe["two"]["b"]
            )

    def test_different_samples_for_different_analyses(self):
        """Test that nothing breaks when different samples have different parameters
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert sorted(dataframe["one"].keys()) == sorted(data["one"].keys())
        assert sorted(dataframe["two"].keys()) == sorted(data["two"].keys())
        np.testing.assert_almost_equal(
            dataframe["one"]["a"], data["one"]["a"]
        )
        np.testing.assert_almost_equal(
            dataframe["one"]["b"], data["one"]["b"]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["a"], data["two"]["a"]
        )
        # transposing is ambiguous when analyses have different parameters,
        # so it must raise
        with pytest.raises(ValueError):
            transpose = dataframe.T

    def test_adding_to_table(self):
        """Test that a new analysis can be added to an existing table and
        that the bookkeeping properties are updated accordingly
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert "three" not in dataframe.parameters.keys()
        new_data = {"a": np.random.uniform(10, 0.5, 100)}
        dataframe["three"] = new_data
        np.testing.assert_almost_equal(dataframe["three"]["a"], new_data["a"])
        assert dataframe.parameters["three"] == ["a"]
        assert "three" in dataframe.number_of_samples.keys()

    def test_combine(self):
        """Test that the .combine method is working as expected
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(100, 0.5, 100),
                "b": np.random.uniform(50, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        # test that when weights are 0 and 1, we only get the second set of
        # samples
        combine = dataframe.combine(weights={"one": 0., "two": 1.})
        assert "a" in combine.keys()
        assert "b" in combine.keys()
        assert all(ss in data["two"]["a"] for ss in combine["a"])
        assert all(ss in data["two"]["b"] for ss in combine["b"])
        # test that when weights are equal, the first half of samples are from
        # one and the second half are from two
        combine = dataframe.combine(labels=["one", "two"], weights=[0.5, 0.5])
        nsamples = len(combine["a"])
        half = int(nsamples / 2)
        assert all(ss in data["one"]["a"] for ss in combine["a"][:half])
        assert all(ss in data["two"]["a"] for ss in combine["a"][half:])
        assert all(ss in data["one"]["b"] for ss in combine["b"][:half])
        assert all(ss in data["two"]["b"] for ss in combine["b"][half:])
        # test that the samples maintain order
        for n in np.random.choice(half, size=10, replace=False):
            ind = np.argwhere(data["one"]["a"] == combine["a"][n])
            assert data["one"]["b"][ind] == combine["b"][n]
        # test that when use_all is provided, all samples are included
        combine = dataframe.combine(use_all=True)
        assert len(combine["a"]) == len(data["one"]["a"]) + len(data["two"]["a"])
        # test shuffle
        combine = dataframe.combine(
            labels=["one", "two"], weights=[0.5, 0.5], shuffle=True
        )
        for n in np.random.choice(half, size=10, replace=False):
            if combine["a"][n] in data["one"]["a"]:
                ind = np.argwhere(data["one"]["a"] == combine["a"][n])
                assert data["one"]["b"][ind] == combine["b"][n]
            else:
                ind = np.argwhere(data["two"]["a"] == combine["a"][n])
                assert data["two"]["b"][ind] == combine["b"][n]
        # shuffling must not duplicate samples
        assert len(set(combine["a"])) == len(combine["a"])
        assert len(set(combine["b"])) == len(combine["b"])
class TestMCMCSamplesDict(object):
    """Test the MCMCSamplesDict class
    """
    def setup_method(self):
        """Create two MCMC chains of samples for two parameters
        """
        self.parameters = ["a", "b"]
        self.chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)]
        ]

    def test_initalize(self):
        """Test the different ways to initalize the class
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        # chains are auto-labelled chain_0, chain_1, ...
        assert sorted(list(dataframe.keys())) == sorted(
            ["chain_{}".format(num) for num in range(len(self.chains))]
        )
        assert sorted(list(dataframe["chain_0"].keys())) == sorted(["a", "b"])
        assert sorted(list(dataframe["chain_1"].keys())) == sorted(["a", "b"])
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], self.chains[0][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], self.chains[0][1]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], self.chains[1][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], self.chains[1][1]
        )
        # the dict-based construction must be equivalent
        other = MCMCSamplesDict({
            "chain_{}".format(num): {
                param: self.chains[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num in range(len(self.chains))
        })
        assert sorted(other.keys()) == sorted(dataframe.keys())
        assert sorted(other["chain_0"].keys()) == sorted(
            dataframe["chain_0"].keys()
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["a"], dataframe["chain_0"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["b"], dataframe["chain_0"]["b"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["a"], dataframe["chain_1"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["b"], dataframe["chain_1"]["b"]
        )

    def test_unequal_chain_length(self):
        """Test that when inverted, the chains keep their unequal chain
        length
        """
        chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 200), np.random.uniform(80, 10, 200)]
        ]
        dataframe = MCMCSamplesDict(self.parameters, chains)
        transpose = dataframe.T
        assert len(transpose["a"]["chain_0"]) == 100
        assert len(transpose["a"]["chain_1"]) == 200
        assert dataframe.number_of_samples == {
            "chain_0": 100, "chain_1": 200
        }
        assert dataframe.minimum_number_of_samples == 100
        assert transpose.number_of_samples == dataframe.number_of_samples
        assert transpose.minimum_number_of_samples == \
            dataframe.minimum_number_of_samples
        # combining concatenates the chains, so all 300 samples survive
        combined = dataframe.combine
        assert combined.number_of_samples == 300
        my_combine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(ss in my_combine for ss in combined["a"])

    def test_properties(self):
        """Test that the properties of the MCMCSamplesDict class are correct
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        # transpose swaps the chain/parameter nesting order
        transpose = dataframe.T
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], transpose["a"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], transpose["b"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], transpose["a"]["chain_1"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], transpose["b"]["chain_1"]
        )
        # averaging must not depend on the nesting order
        average = dataframe.average
        transpose_average = transpose.average
        for param in self.parameters:
            np.testing.assert_almost_equal(
                average[param], transpose_average[param]
            )
        assert dataframe.total_number_of_samples == 200
        assert dataframe.total_number_of_samples == \
            transpose.total_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 200
        mycombine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(s in combined["a"] for s in mycombine)
        # transposing twice must round-trip back to the original layout
        transpose_copy = transpose.T
        assert sorted(list(transpose_copy.keys())) == sorted(list(dataframe.keys()))
        for level1 in dataframe.keys():
            assert sorted(list(transpose_copy[level1].keys())) == sorted(
                list(dataframe[level1].keys())
            )
            for level2 in dataframe[level1].keys():
                np.testing.assert_almost_equal(
                    transpose_copy[level1][level2], dataframe[level1][level2]
                )

    def test_key_data(self):
        """Test that the key data is correct
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        key_data = dataframe.key_data
        combined = dataframe.combine
        # BUG FIX: was `for param, in key_data.keys():`, which unpacked each
        # key as a 1-tuple and only worked because the parameter names happen
        # to be single characters
        for param in key_data.keys():
            np.testing.assert_almost_equal(
                key_data[param]["mean"], np.mean(combined[param])
            )
            np.testing.assert_almost_equal(
                key_data[param]["median"], np.median(combined[param])
            )

    def test_burnin_removal(self):
        """Test that the different methods for removing the samples as burnin
        as expected
        """
        uniform = np.random.uniform
        parameters = ["a", "b", "cycle"]
        chains = [
            [uniform(10, 0.5, 100), uniform(100, 10, 100), uniform(1, 0.8, 100)],
            [uniform(5, 0.5, 100), uniform(80, 10, 100), uniform(1, 0.8, 100)],
            [uniform(1, 0.8, 100), uniform(90, 10, 100), uniform(1, 0.8, 100)]
        ]
        dataframe = MCMCSamplesDict(parameters, chains)
        # burnin_by_step_number keeps samples with a positive cycle number
        burnin = dataframe.burnin(algorithm="burnin_by_step_number")
        idxs = np.argwhere(chains[0][2] > 0)
        assert len(burnin["chain_0"]["a"]) == len(idxs)
        # burnin_by_first_n drops the first n samples
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(10, algorithm="burnin_by_first_n")
        assert len(burnin["chain_0"]["a"]) == 90
        # with step_number=True the step-number cut is applied first
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(
            10, algorithm="burnin_by_first_n", step_number=True
        )
        assert len(burnin["chain_0"]["a"]) == len(idxs) - 10
class TestList(object):
    """Tests for the pesummary List class."""

    def test_added(self):
        """Check that .added tracks every element added after construction."""
        base = ["a", "b", "c", "d", "e"]
        mylist = List(base)
        assert not len(mylist.added)
        mylist.append("f")
        assert len(mylist.added)
        assert mylist.added == ["f"]
        mylist.extend(["g", "h"])
        assert sorted(mylist.added) == sorted(["f", "g", "h"])
        # the original contents are remembered separately
        assert sorted(mylist.original) == sorted(base)
        mylist.insert(2, "z")
        assert sorted(mylist.added) == sorted(["f", "g", "h", "z"])
        assert sorted(mylist) == sorted(base + ["f", "g", "h", "z"])
        # `+` and `+=` must also be tracked
        mylist = List(base)
        mylist = mylist + ["f", "g", "h"]
        assert sorted(mylist.added) == sorted(["f", "g", "h"])
        mylist += ["i"]
        assert sorted(mylist.added) == sorted(["f", "g", "h", "i"])

    def test_removed(self):
        """Check that .removed tracks every element removed."""
        base = ["a", "b", "c", "d", "e"]
        mylist = List(base)
        assert not len(mylist.removed)
        mylist.remove("e")
        assert sorted(mylist) == sorted(["a", "b", "c", "d"])
        assert mylist.removed == ["e"]
        assert not len(sorted(mylist.added))
        # removing an element that was added after construction is also
        # recorded, and it leaves .added
        mylist.extend(["f", "g"])
        mylist.remove("f")
        assert mylist.removed == ["e", "f"]
        assert mylist.added == ["g"]
        mylist.pop(0)
        assert sorted(mylist.removed) == sorted(["e", "f", "a"])
def test_2DArray():
    """Test the pesummary.utils.array._2DArray class."""
    # ten columns of 1000 normally-distributed samples each
    data = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    wrapped = _2DArray(data)
    for idx, column in enumerate(wrapped):
        np.testing.assert_almost_equal(column, data[idx])
    for idx, column in enumerate(wrapped):
        assert column.standard_deviation == np.std(data[idx])
        assert column.minimum == np.min(data[idx])
        assert column.maximum == np.max(data[idx])
        stats = column.key_data
        assert stats["5th percentile"] == np.percentile(data[idx], 5)
        assert stats["95th percentile"] == np.percentile(data[idx], 95)
        assert stats["median"] == np.median(data[idx])
        assert stats["mean"] == np.mean(data[idx])

    # repeat with likelihood/prior supplied so maxL/maxP are defined
    data = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    likelihood = np.random.uniform(0, 1, 1000)
    prior = np.random.uniform(0, 1, 1000)
    wrapped = _2DArray(data, likelihood=likelihood, prior=prior)
    for idx, column in enumerate(wrapped):
        np.testing.assert_almost_equal(column, data[idx])
    for idx, column in enumerate(wrapped):
        assert column.standard_deviation == np.std(data[idx])
        assert column.minimum == np.min(data[idx])
        assert column.maximum == np.max(data[idx])
        assert column.maxL == column[np.argmax(likelihood)]
        assert column.maxP == column[np.argmax(likelihood + prior)]
        stats = column.key_data
        assert stats["5th percentile"] == np.percentile(data[idx], 5)
        assert stats["95th percentile"] == np.percentile(data[idx], 95)
        assert stats["median"] == np.median(data[idx])
        assert stats["mean"] == np.mean(data[idx])
        assert stats["maxL"] == column.maxL
        assert stats["maxP"] == column.maxP
class TestArray(object):
    """Test the Array class
    """
    def test_properties(self):
        """Test the average, standard_deviation and credible_interval
        properties/methods of the Array class against numpy
        """
        samples = np.random.uniform(100, 10, 100)
        array = Array(samples)
        assert array.average(type="mean") == np.mean(samples)
        assert array.average(type="median") == np.median(samples)
        assert array.standard_deviation == np.std(samples)
        np.testing.assert_almost_equal(
            array.credible_interval(percentile=[5, 95]),
            [np.percentile(array, 5), np.percentile(array, 95)]
        )

    def test_weighted_percentile(self):
        """Test that the weighted credible_interval agrees with numpy when
        the integer weights are expanded via np.repeat
        """
        x = np.random.normal(100, 20, 10000)
        weights = np.array([np.random.randint(100) for _ in range(10000)])
        array = Array(x, weights=weights)
        # locals renamed from 'numpy'/'pesummary' so they no longer shadow
        # the modules imported at the top of this file
        expected = np.percentile(np.repeat(x, weights), 90)
        result = array.credible_interval(percentile=90)
        np.testing.assert_almost_equal(expected, result, 6)

    def test_key_data(self):
        """Test the key_data summary dictionary of the Array class
        """
        samples = np.random.normal(100, 20, 10000)
        array = Array(samples)
        key_data = array.key_data
        np.testing.assert_almost_equal(key_data["mean"], np.mean(samples))
        np.testing.assert_almost_equal(key_data["median"], np.median(samples))
        np.testing.assert_almost_equal(key_data["std"], np.std(samples))
        np.testing.assert_almost_equal(
            key_data["5th percentile"], np.percentile(samples, 5)
        )
        np.testing.assert_almost_equal(
            key_data["95th percentile"], np.percentile(samples, 95)
        )
        # no likelihood/prior supplied, so maxL/maxP are undefined
        assert key_data["maxL"] is None
        assert key_data["maxP"] is None
class TestTQDM(object):
    """Test the pesummary.utils.tqdm.tqdm class
    """
    def setup_method(self):
        """Create the iterable and temporary directory used by the tests
        """
        self._range = range(100)
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_basic_iterator(self):
        """Test that the core functionality of the tqdm class remains
        """
        for value in tqdm(self._range):
            _ = value * value

    def test_interaction_with_logger(self):
        """Test that tqdm interacts nicely with logger
        """
        from pesummary.utils.utils import logger, LOG_FILE

        filename = "{}/test.dat".format(tmpdir)
        with open(filename, "w") as opened:
            for value in tqdm(self._range, logger=logger, file=opened):
                _ = value * value
        with open(filename, "r") as opened:
            final_line = opened.readlines()[-1]
        # the progress output should be routed through the PESummary logger
        assert "PESummary" in final_line
        assert "INFO" in final_line
def test_jensen_shannon_divergence():
    """Test that the `jensen_shannon_divergence` method returns the same
    values as the scipy function
    """
    from scipy.spatial.distance import jensenshannon
    from scipy import stats

    samples = [
        np.random.uniform(5, 4, 100),
        np.random.uniform(5, 4, 100)
    ]
    grid = np.linspace(np.min(samples), np.max(samples), 100)
    kdes = [stats.gaussian_kde(posterior)(grid) for posterior in samples]
    # scipy's jensenshannon returns the distance, i.e. sqrt of the divergence
    scipy_estimate = jensenshannon(*kdes)**2
    pesummary_estimate = utils.jensen_shannon_divergence(samples, decimal=9)
    np.testing.assert_almost_equal(scipy_estimate, pesummary_estimate)

    # the calculation should also run with a bounded KDE and bounds kwargs
    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    pesummary_estimate = utils.jensen_shannon_divergence(
        samples, decimal=9, kde=ReflectionBoundedKDE, xlow=4.5, xhigh=5.5
    )
def test_make_cache_style_file():
    """Test that the `make_cache_style_file` works as expected
    """
    from pesummary.utils.utils import make_cache_style_file, CACHE_DIR
    sty = os.path.expanduser("{}/style/matplotlib_rcparams.sty".format(CACHE_DIR))
    try:
        with open("test.sty", "w") as f:
            f.writelines(["test : 10"])
        make_cache_style_file("test.sty")
    finally:
        # remove the temporary style file so the test does not leave an
        # artifact in the current working directory
        if os.path.isfile("test.sty"):
            os.remove("test.sty")
    assert os.path.isfile(sty)
    with open(sty, "r") as f:
        lines = f.readlines()
    assert len(lines) == 1
    assert lines[0] == "test : 10"
def test_logger():
    """Test that the PESummary logger emits records with the expected
    logger name, level and message
    """
    # save the propagate flag so the global logger state is restored
    # afterwards; the original test mutated it permanently
    previous_propagate = utils.logger.propagate
    try:
        # 'l' renamed to 'captured' (ambiguous single-letter name, E741)
        with LogCapture() as captured:
            utils.logger.propagate = True
            utils.logger.info("info")
            utils.logger.warning("warning")
        captured.check(
            ("PESummary", "INFO", "info"),
            ("PESummary", "WARNING", "warning"),
        )
    finally:
        utils.logger.propagate = previous_propagate
def test_string_match():
    """function to test the string_match function
    """
    param = "mass_1"
    match = utils.string_match
    # a string always matches itself
    assert match(param, param)
    # a different string does not match
    assert not match(param, "mass_2")
    # single wildcard patterns
    assert match(param, param[0] + "*")
    assert not match(param, "z+")
    assert match(param, param[:-1] + "?")
    # multiple wildcard patterns
    assert match(param, "*" + param[1] + "*")
    assert match(param, "*" + param[3:-1] + "?")
    assert match(param, "?" + param[1:-1] + "?")
    assert not match(param, "?" + param[:-1] + "?")
    # 'starts with' pattern
    assert match(param, "^" + param[:3])
    # 'does not start with' pattern (negative lookahead)
    assert not match(param, "^(?!{}).+".format(param[:3]))
def test_map_parameter_names():
    """function to test the map_parameter_names function
    """
    dictionary = {
        "a": np.random.uniform(1, 100, size=1000),
        "b": np.random.uniform(1, 100, size=1000),
        "c": np.random.uniform(1, 100, size=1000)
    }
    mapping = {"a": "z", "b": "y", "c": "x"}
    # deep copy rather than dict.copy(): a shallow copy shares the
    # underlying arrays, so the 'unchanged' check below could never
    # detect in-place mutation of the sample values
    dictionary_copy = copy.deepcopy(dictionary)
    new_dictionary = utils.map_parameter_names(dictionary_copy, mapping)
    # check that the new keys are in the new dictionary
    assert sorted(list(new_dictionary.keys())) == ["x", "y", "z"]
    # check that the new dictionary has the same items as before
    for old, new in mapping.items():
        np.testing.assert_almost_equal(dictionary[old], new_dictionary[new])
    # check that the old dictionary remains unchanged
    for key, item in dictionary.items():
        np.testing.assert_almost_equal(item, dictionary_copy[key])
def test_list_match():
    """function to test the list_match function
    """
    params = [
        "mass_1", "mass_2", "chirp_mass", "total_mass", "mass_ratio", "a_1",
        "luminosity_distance", "x"
    ]

    def check(pattern, expected, **kwargs):
        # helper comparing list_match output with an explicitly built list
        assert sorted(utils.list_match(params, pattern, **kwargs)) == sorted(expected)

    # all mass parameters are returned
    check("*mass*", [p for p in params if "mass" in p])
    # only parameters containing an underscore are returned
    check("*_*", [p for p in params if "_" in p])
    # no parameter matches 'z+'
    assert not len(utils.list_match(params, "z+"))
    # only parameters which do not start with 'mass'
    check("^(?!mass).+", [p for p in params if p[:4] != "mass"])
    # only parameters which start with 'l'
    check("^l", [p for p in params if p[0] == "l"])
    # return_false=True inverts the selection
    check("^l", [p for p in params if p[0] != "l"], return_false=True)
    # multiple substrings must all match
    check(["^m", "*2"], ["mass_2"])
class TestDict(object):
    """Class to test the Dict object
    """
    def test_initiate(self):
        """Test that the Dict class can be initiated either from a
        dictionary or from parameter/sample lists
        """
        from pesummary.gw.file.psd import PSD

        # initiate from a dictionary
        x = Dict(
            {"a": [[10, 20], [10, 20]]}, value_class=PSD,
            value_columns=["value", "value2"]
        )
        assert list(x.keys()) == ["a"]
        np.testing.assert_almost_equal(x["a"], [[10, 20], [10, 20]])
        assert isinstance(x["a"], PSD)
        np.testing.assert_almost_equal(x["a"].value, [10, 10])
        np.testing.assert_almost_equal(x["a"].value2, [20, 20])

        # initiate from parameter and sample lists
        x = Dict(
            ["a"], [[[10, 20], [10, 20]]], value_class=PSD,
            value_columns=["value", "value2"]
        )
        assert list(x.keys()) == ["a"]
        np.testing.assert_almost_equal(x["a"], [[10, 20], [10, 20]])
        assert isinstance(x["a"], PSD)
        np.testing.assert_almost_equal(x["a"].value, [10, 10])
        np.testing.assert_almost_equal(x["a"].value2, [20, 20])

    def test_specify_columns(self):
        """Test that x[["a", "b"]] works as expected
        """
        x = Dict({"a": [10], "b": [20], "c": [30], "d": [40]}, value_class=list)
        y = x[["a", "b"]]
        assert sorted(list(y.keys())) == ["a", "b"]
        for key in y.keys():
            assert y[key] == x[key]
        # unknown keys raise; no need to bind the (never produced) result
        with pytest.raises(Exception):
            x[["e", "f"]]
class TestProbabilityDict(object):
    """Test the ProbabilityDict class
    """
    def setup_method(self):
        """Setup the ProbabilityDict class
        """
        # each entry is [grid values, discrete probabilities]
        self.mydict = {
            "a": [[1, 2, 3, 4], [0.1, 0.2, 0.3, 0.4]],
            "b": [[1, 3, 5, 7], [0.1, 0.3, 0.2, 0.4]]
        }

    def test_initiate(self):
        """Test the different ways to initiate this class
        """
        probs = ProbabilityDict(self.mydict)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)
        parameters = list(self.mydict.keys())
        data = [self.mydict[key] for key in parameters]
        probs = ProbabilityDict(parameters, data)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)

    def test_rvs(self):
        """Test the .rvs method
        """
        probs = ProbabilityDict(self.mydict)
        # draw 10 samples for all parameters
        samples = probs.rvs(size=10, interpolate=False)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        # fixed: was len(samples[param] == 10), which measured the length
        # of a boolean array and made the assertion vacuous
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())
        # without interpolation every draw must lie on the discrete grid
        for param in self.mydict.keys():
            assert all(p in self.mydict[param][0] for p in samples[param])
        # draw 10 samples for only parameter 'a'
        samples = probs.rvs(size=10, parameters=["a"])
        assert list(samples.keys()) == ["a"]
        assert len(samples["a"]) == 10
        # interpolate first and then draw samples
        samples = probs.rvs(size=10, interpolate=True)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())

    def test_plotting(self):
        """Test the .plot method
        """
        probs = ProbabilityDict(self.mydict)
        # Test that a histogram plot is generated when required
        fig = probs.plot("a", type="hist")
        fig = probs.plot("a", type="marginalized_posterior")
        # Test that a PDF is generated when required
        fig = probs.plot("a", type="pdf")
class TestProbabilityDict2D(object):
    """Class to test the ProbabilityDict2D class
    """
    def setup_method(self):
        """Setup the TestProbabilityDict2D class
        """
        # normalised 10x10 grid of joint probabilities for parameters x, y
        raw = np.random.uniform(0, 1, 100).reshape(10, 10)
        self.mydict = {
            "x_y": [
                np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10),
                raw / np.sum(raw)
            ]
        }

    def test_plotting(self):
        """Test the .plot method
        """
        mydict = ProbabilityDict2D(self.mydict)
        # a 2d KDE plot can be generated
        fig = mydict.plot("x_y", type="2d_kde")
        # a triangle plot can be generated
        fig, _, _, _ = mydict.plot("x_y", type="triangle")
        # a reverse triangle plot can be generated
        fig, _, _, _ = mydict.plot("x_y", type="reverse_triangle")
        # an unknown plot type raises an Exception
        with pytest.raises(Exception):
            fig = mydict.plot("x_y", type="does_not_exist")
class TestDiscretePDF2D(object):
    """Class to test the DiscretePDF2D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2D class
        """
        self.x = [1, 2]
        self.y = [1, 2]
        self.probs = [[0.25, 0.25], [0.25, 0.25]]

    def test_initiate(self):
        """Test that initiation requires a 2d probability array
        """
        # a flat (1d) probability array must raise a ValueError
        with pytest.raises(ValueError):
            _ = DiscretePDF2D(self.x, self.y, [1, 2, 3, 4, 5, 6])
        # a valid 2d array initiates without error
        obj = DiscretePDF2D(self.x, self.y, self.probs)

    def test_marginalize(self):
        """Test that .marginalize returns a DiscretePDF2Dplus1D which
        retains the original joint probabilities
        """
        pdf = DiscretePDF2D(self.x, self.y, self.probs)
        marginalized = pdf.marginalize()
        assert isinstance(marginalized, DiscretePDF2Dplus1D)
        np.testing.assert_almost_equal(pdf.probs, marginalized.probs_xy.probs)
class TestDiscretePDF2Dplus1D(object):
    """Class to test the DiscretePDF2Dplus1D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2Dplus1D class
        """
        self.x = [1, 2]
        self.y = [1, 2]
        # two 1d marginals followed by the 2d joint probability array
        self.probs = [[0.5, 0.5], [0.5, 0.5], [[0.25, 0.25], [0.25, 0.25]]]

    def test_initiate_raise(self):
        """Test that the class raises ValueErrors when required
        """
        # fewer than 3 probability arrays
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(self.x, self.y, self.probs[:-1])
        # no 2d probability array supplied
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(
                self.x, self.y, self.probs[:-1] + self.probs[:1]
            )
        # only a single 1d probability array supplied
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(
                self.x, self.y, self.probs[1:] + self.probs[-1:]
            )

    def test_initiate(self):
        """Test that we can initiate the class
        """
        pdf = DiscretePDF2Dplus1D(self.x, self.y, self.probs)
        # the two marginals are stored as DiscretePDF objects and the
        # joint distribution as a DiscretePDF2D
        assert isinstance(pdf.probs[0], DiscretePDF)
        assert isinstance(pdf.probs[1], DiscretePDF)
        assert isinstance(pdf.probs[2], DiscretePDF2D)
class TestKDEList(object):
    """Test the KDEList class
    """
    def setup_method(self):
        """Setup the KDEList class
        """
        from scipy.stats import gaussian_kde
        means = np.random.uniform(0, 5, size=10)
        stds = np.random.uniform(0, 2, size=10)
        self.inputs = np.array(
            [
                np.random.normal(mean, std, size=10000) for mean, std
                in zip(means, stds)
            ]
        )
        self.pts = np.linspace(-10, 10, 100)
        # reference KDEs evaluated directly with scipy
        self.true_kdes = [gaussian_kde(_)(self.pts) for _ in self.inputs]

    def test_call(self):
        """Test the KDEList call method
        """
        from pesummary.utils.kde_list import KDEList
        from scipy.stats import gaussian_kde
        new = KDEList(self.inputs, kde=gaussian_kde, pts=self.pts)
        # test on a single CPU
        new_single = new(multi_process=1)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_single[num])
        # test on multiple CPUs; fixed: this loop previously compared
        # new_single again, leaving the multi-process result unchecked
        new_multiple = new(multi_process=2)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_multiple[num])
        # test the `idx` kwarg
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new(idx=num))
        # split the calculation into 2 chunks and 2 CPUs
        new_chunk1 = new(
            idx=np.arange(0, len(self.inputs) // 2, 1), multi_process=2
        ).tolist()
        new_chunk2 = new(
            idx=np.arange(len(self.inputs) // 2, len(self.inputs), 1), multi_process=2
        ).tolist()
        new_combined = new_chunk1 + new_chunk2
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_combined[num])
        # assert that different KDEs are produced for different pts
        with pytest.raises(AssertionError):
            new_diff = new(pts=np.linspace(-5, 5, 100), multi_process=1)
            for num in range(len(self.inputs)):
                np.testing.assert_almost_equal(self.true_kdes[num], new_diff[num])
def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    # CACHE_DIR and make_dir are not imported at module level in this
    # file, so calling this function previously raised a NameError;
    # import them from pesummary.utils.utils (the same module the test
    # above imports CACHE_DIR from)
    from pesummary.utils.utils import CACHE_DIR, make_dir
    make_dir(CACHE_DIR)
    shutil.copyfile(
        style_file, os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")
    )
def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    # CACHE_DIR is not imported at module level in this file; import it
    # locally so the function does not raise a NameError when called
    from pesummary.utils.utils import CACHE_DIR
    return os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")