Coverage for pesummary/tests/utils_test.py: 37.0%
710 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import shutil
5import h5py
6import numpy as np
7import copy
9import pesummary
10from pesummary.io import write
11import pesummary.cli as cli
12from pesummary.utils import utils
13from pesummary.utils.tqdm import tqdm
14from pesummary.utils.dict import Dict
15from pesummary.utils.list import List
16from pesummary.utils.pdf import DiscretePDF, DiscretePDF2D, DiscretePDF2Dplus1D
17from pesummary.utils.array import _2DArray
18from pesummary.utils.samples_dict import (
19 Array, SamplesDict, MCMCSamplesDict, MultiAnalysisSamplesDict
20)
21from pesummary.utils.probability_dict import ProbabilityDict, ProbabilityDict2D
22from pesummary._version_helper import GitInformation, PackageInformation
23from pesummary._version_helper import get_version_information
25import pytest
26from testfixtures import LogCapture
27import tempfile
# Scratch directory used by every test class below. Use mkdtemp rather
# than ``tempfile.TemporaryDirectory(...).name``: the discarded
# TemporaryDirectory object is finalized by the garbage collector, which
# deletes the directory at an unpredictable time and emits a
# ResourceWarning. mkdtemp simply creates the directory and returns its
# path; the test classes create/remove it themselves in setup/teardown.
tmpdir = tempfile.mkdtemp(prefix=".", dir=".")
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# repository root when running under GitLab CI, otherwise the cwd
DEFAULT_DIRECTORY = os.getenv("CI_PROJECT_DIR", os.getcwd())
class TestGitInformation(object):
    """Tests for the GitInformation helper class
    """
    def setup_method(self):
        """Create a fresh GitInformation instance for each test
        """
        self.git = GitInformation(directory=DEFAULT_DIRECTORY)

    def test_last_commit_info(self):
        """Check that last_commit_info is a pair of strings
        """
        info = self.git.last_commit_info
        assert len(info) == 2
        assert all(isinstance(entry, str) for entry in info)

    def test_last_version(self):
        """Check that last_version is a string
        """
        assert isinstance(self.git.last_version, str)

    def test_status(self):
        """Check that status is a string
        """
        assert isinstance(self.git.status, str)

    def test_builder(self):
        """Check that builder is a string
        """
        assert isinstance(self.git.builder, str)

    def test_build_date(self):
        """Check that build_date is a string
        """
        assert isinstance(self.git.build_date, str)
class TestPackageInformation(object):
    """Tests for the PackageInformation helper class
    """
    def setup_method(self):
        """Create a fresh PackageInformation instance for each test
        """
        self.package = PackageInformation()

    def test_package_info(self):
        """Check the layout of the package_info property
        """
        info = self.package.package_info
        assert isinstance(info, list)
        first = info[0]
        assert "name" in first
        assert "version" in first
        # the channel entry is only populated for conda installations
        if "build_string" in first:
            assert "channel" in first
class TestUtils(object):
    """Test the helper functions in pesummary.utils.utils
    """
    def setup_method(self):
        """Create the temporary directory used by the tests
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_check_condition(self):
        """Test the check_condition method
        """
        with pytest.raises(Exception) as info:
            condition = True
            utils.check_condition(condition, "error")
        assert str(info.value) == "error"

    def test_rename_group_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method with a group
        """
        # context managers guarantee the hdf5 handles are closed even when
        # an assertion fails part-way through
        with h5py.File("{}/rename_group.h5".format(tmpdir), "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            "{}/rename_group.h5".format(tmpdir),
            group=["group", "replaced"])
        # the old group name is replaced and its contents are unchanged
        with h5py.File("{}/rename_group.h5".format(tmpdir), "r") as f:
            assert list(f.keys()) == ["replaced"]
            assert list(f["replaced"].keys()) == ["example"]
            assert len(f["replaced/example"]) == 1
            assert f["replaced/example"][0] == 10

    def test_rename_dataset_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method with a dataset
        """
        with h5py.File("{}/rename_dataset.h5".format(tmpdir), "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            "{}/rename_dataset.h5".format(tmpdir),
            dataset=["group/example", "group/replaced"])
        with h5py.File("{}/rename_dataset.h5".format(tmpdir), "r") as f:
            assert list(f.keys()) == ["group"]
            assert list(f["group"].keys()) == ["replaced"]
            assert len(f["group/replaced"]) == 1
            assert f["group/replaced"][0] == 10

    def test_rename_unknown_hf5_file(self):
        """Test that renaming inside a non-existent file raises an Exception
        """
        with pytest.raises(Exception) as info:
            utils.rename_group_or_dataset_in_hf5_file(
                "{}/unknown.h5".format(tmpdir),
                group=["None", "replaced"])
        assert "does not exist" in str(info.value)

    def test_directory_creation(self):
        """Test that make_dir creates a new directory
        """
        directory = '{}/test_dir'.format(tmpdir)
        # truthiness rather than `== False`/`== True` (PEP 8 / E712)
        assert not os.path.isdir(directory)
        utils.make_dir(directory)
        assert os.path.isdir(directory)

    def test_url_guess(self):
        """Test the guess_url method against the known cluster hostnames
        """
        host = ["raven", "cit", "ligo-wa", "uwm", "phy.syr.edu", "vulcan",
                "atlas", "iucaa", "alice"]
        expected = ["https://geo2.arcca.cf.ac.uk/~albert.einstein/test",
                    "https://ldas-jobs.ligo.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.ligo-wa.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.phys.uwm.edu/~albert.einstein/test",
                    "https://sugar-jobs.phy.syr.edu/~albert.einstein/test",
                    "https://galahad.aei.mpg.de/~albert.einstein/test",
                    "https://atlas1.atlas.aei.uni-hannover.de/~albert.einstein/test",
                    "https://ldas-jobs.gw.iucaa.in/~albert.einstein/test",
                    "https://dumpty.alice.icts.res.in/~albert.einstein/test"]
        user = "albert.einstein"
        webdir = '/home/albert.einstein/public_html/test'
        for cluster, url in zip(host, expected):
            assert utils.guess_url(webdir, cluster, user) == url

    def test_make_dir(self):
        """Test the make_dir method
        """
        assert not os.path.isdir(os.path.join(tmpdir, "test"))
        utils.make_dir(os.path.join(tmpdir, "test"))
        assert os.path.isdir(os.path.join(tmpdir, "test"))
        # make_dir on an existing directory must not wipe its contents
        with open(os.path.join(tmpdir, "test", "test.dat"), "w") as f:
            f.writelines(["test"])
        utils.make_dir(os.path.join(tmpdir, "test"))
        assert os.path.isfile(os.path.join(tmpdir, "test", "test.dat"))

    def test_resample_posterior_distribution(self):
        """Test the resample_posterior_distribution method
        """
        data = np.random.normal(1, 0.1, 1000)
        resampled = utils.resample_posterior_distribution([data], 500)
        # the resampled set keeps the statistics of the original samples
        assert len(resampled) == 500
        assert np.round(np.mean(resampled), 1) == 1.
        assert np.round(np.std(resampled), 1) == 0.1

    def test_gw_results_file(self):
        """Test the gw_results_file method
        """
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        assert utils.gw_results_file(opts)
        opts = namespace({"webdir": tmpdir})
        assert not utils.gw_results_file(opts)

    def test_functions(self):
        """Test the functions method
        """
        from .base import namespace

        # gw options map to the gw-specific input/metafile classes
        opts = namespace({"gw": True, "psd": True})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.gw.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.gw.file.meta_file.GWMetaFile
        # otherwise the core classes are used
        opts = namespace({"webdir": tmpdir})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.core.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.core.file.meta_file.MetaFile

    def test_get_version_information(self):
        """Test the get_version_information method
        """
        assert isinstance(get_version_information(), str)
class TestGelmanRubin(object):
    """Test the Gelman Rubin calculation
    """
    def test_same_as_lalinference(self):
        """Test the Gelman rubin output from pesummary is the same as
        the one coded in LALInference
        """
        from lalinference.bayespputils import Posterior
        from pesummary.utils.utils import gelman_rubin

        # columns expected by LALInference: two parameters, a likelihood
        # column and the chain identifier
        header = ["a", "b", "logL", "chain"]
        for _ in np.arange(100):
            # 10 rows of 3 random floats plus a random chain id in {1, 2}
            samples = np.array(
                [
                    np.random.uniform(np.random.random(), 0.1, 3).tolist() +
                    [np.random.randint(1, 3)] for _ in range(10)
                ]
            )
            obj = Posterior([header, np.array(samples)])
            # reference value computed by LALInference for parameter "a"
            R = obj.gelman_rubin("a")
            chains = np.unique(obj["chain"].samples)
            chain_index = obj.names.index("chain")
            param_index = obj.names.index("a")
            data, _ = obj.samples()
            # split the flattened sample table back into per-chain arrays
            # of parameter "a" so pesummary's gelman_rubin can be called
            chainData=[
                data[data[:,chain_index] == chain, param_index] for chain in
                chains
            ]
            # both implementations should agree to 7 decimal places
            np.testing.assert_almost_equal(
                gelman_rubin(chainData, decimal=10), R, 7
            )

    def test_same_samples(self):
        """Test that when passed two identical chains (perfect convergence),
        the Gelman Rubin is 1
        """
        from pesummary.core.plots.plot import gelman_rubin

        samples = np.random.uniform(1, 0.5, 10)
        R = gelman_rubin([samples, samples])
        assert R == 1
class TestSamplesDict(object):
    """Test the SamplesDict class
    """
    def setup_method(self):
        # two parameters with 100 posterior samples each; the samples are
        # also written to disk so the `from_file` classmethod can be tested
        self.parameters = ["a", "b"]
        self.samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100)
        ]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        write(
            self.parameters, np.array(self.samples).T, outdir=tmpdir,
            filename="test.dat", file_format="dat"
        )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Test that the two ways to initialize the SamplesDict class are
        equivalent
        """
        # parameter list + sample list vs a {parameter: samples} mapping
        base = SamplesDict(self.parameters, self.samples)
        other = SamplesDict(
            {
                param: sample for param, sample in zip(
                    self.parameters, self.samples
                )
            }
        )
        assert base.parameters == other.parameters
        assert sorted(base.parameters) == sorted(self.parameters)
        np.testing.assert_almost_equal(base.samples, other.samples)
        assert sorted(list(base.keys())) == sorted(list(other.keys()))
        np.testing.assert_almost_equal(base.samples, self.samples)
        # reading the samples back from the file written in setup_method
        # should give the same data again
        class_method = SamplesDict.from_file(
            "{}/test.dat".format(tmpdir), add_zero_likelihood=False
        )
        np.testing.assert_almost_equal(class_method.samples, self.samples)

    def test_properties(self):
        """Test that the properties of the SamplesDict class are correct
        """
        import pandas as pd

        dataset = SamplesDict(self.parameters, self.samples)
        # summary statistics should match numpy applied per parameter
        assert sorted(dataset.minimum.keys()) == sorted(self.parameters)
        assert dataset.minimum["a"] == np.min(self.samples[0])
        assert dataset.minimum["b"] == np.min(self.samples[1])
        assert dataset.median["a"] == np.median(self.samples[0])
        assert dataset.median["b"] == np.median(self.samples[1])
        assert dataset.mean["a"] == np.mean(self.samples[0])
        assert dataset.mean["b"] == np.mean(self.samples[1])
        assert dataset.number_of_samples == len(self.samples[1])
        assert len(dataset.downsample(10)["a"]) == 10
        # downsample modifies the stored samples, so re-create the dataset
        # before testing discard_samples
        dataset = SamplesDict(self.parameters, self.samples)
        assert len(dataset.discard_samples(10)["a"]) == len(self.samples[0]) - 10
        # discard_samples removes the first N samples of every parameter
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["a"], self.samples[0][10:]
        )
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["b"], self.samples[1][10:]
        )
        p = dataset.to_pandas()
        assert isinstance(p, pd.core.frame.DataFrame)
        # pop removes the parameter from the dictionary
        remove = dataset.pop("a")
        assert list(dataset.keys()) == ["b"]

    def test_core_plots(self):
        """Test that the core plotting methods of the SamplesDict class work as
        expected
        """
        import matplotlib.figure

        dataset = SamplesDict(self.parameters, self.samples)
        fig = dataset.plot(self.parameters[0], type="hist")
        assert isinstance(fig, matplotlib.figure.Figure)
        fig = dataset.plot(self.parameters[0], type="marginalized_posterior")
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_standardize_parameter_names(self):
        """Test the standardize_parameter_names method
        """
        dictionary = {"mass1": [1,2,3,4], "mass2": [4,3,2,1], "zz": [1,1,1,1]}
        dictionary_copy = dictionary.copy()
        mydict = SamplesDict(dictionary_copy)
        standard_dict = mydict.standardize_parameter_names()
        # check standard parameter names are in the new dictionary
        assert sorted(list(standard_dict.keys())) == ["mass_1", "mass_2", "zz"]
        # check that the dictionary items remains the same
        _mapping = {"mass1": "mass_1", "mass2": "mass_2", "zz": "zz"}
        for key, item in dictionary.items():
            np.testing.assert_almost_equal(item, standard_dict[_mapping[key]])
        # check old dictionary remains unchanged
        assert sorted(list(mydict.keys())) == ["mass1", "mass2", "zz"]
        for key, item in mydict.items():
            np.testing.assert_almost_equal(item, dictionary[key])
        # try custom mapping
        mapping = {"mass1": "custom_m1", "mass2": "custom_m2", "zz": "custom_zz"}
        new_dict = mydict.standardize_parameter_names(mapping=mapping)
        assert sorted(list(new_dict.keys())) == [
            "custom_m1", "custom_m2", "custom_zz"
        ]
        for old, new in mapping.items():
            np.testing.assert_almost_equal(dictionary[old], new_dict[new])

    def test_waveforms(self):
        """Test the waveform generation
        """
        from pesummary.core.fetch import download_dir
        # skip silently when pycbc is unavailable (or broken at import time)
        try:
            from pycbc.waveform import get_fd_waveform, get_td_waveform
        except (ValueError, ImportError):
            return

        # fetch the GW190814 posterior samples, re-using a previously
        # downloaded copy when one exists
        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            from pesummary.gw.fetch import fetch_open_samples
            f = fetch_open_samples(
                "GW190814", read_file=True, outdir=download_dir, unpack=True,
                path="GW190814.h5", catalog="GWTC-2"
            )
        else:
            from pesummary.io import read
            f = read(downloaded_file)
        samples = f.samples_dict["C01:IMRPhenomPv3HM"]
        # compare the pesummary frequency-domain waveform for a single
        # posterior sample against pycbc called with the same intrinsic and
        # extrinsic parameters
        ind = 0
        data = samples.fd_waveform("IMRPhenomPv3HM", 1./256, 20., 1024., ind=ind)
        hp_pycbc, hc_pycbc = get_fd_waveform(
            approximant="IMRPhenomPv3HM", mass1=samples["mass_1"][ind],
            mass2=samples["mass_2"][ind], spin1x=samples["spin_1x"][ind],
            spin1y=samples["spin_1y"][ind], spin1z=samples["spin_1z"][ind],
            spin2x=samples["spin_2x"][ind], spin2y=samples["spin_2y"][ind],
            spin2z=samples["spin_2z"][ind], inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., f_final=1024.,
            delta_f=1./256, f_ref=20.
        )
        np.testing.assert_almost_equal(
            data["h_plus"].frequencies.value, hp_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_cross"].frequencies.value, hc_pycbc.sample_frequencies
        )
        # strains are ~1e-25, so rescale before comparing at default
        # decimal precision
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
        # repeat the comparison for a time-domain approximant
        data = samples.td_waveform("SEOBNRv4PHM", 1./4096, 20., ind=ind)
        hp_pycbc, hc_pycbc = get_td_waveform(
            approximant="SEOBNRv4PHM", mass1=samples["mass_1"][ind],
            mass2=samples["mass_2"][ind], spin1x=samples["spin_1x"][ind],
            spin1y=samples["spin_1y"][ind], spin1z=samples["spin_1z"][ind],
            spin2x=samples["spin_2x"][ind], spin2y=samples["spin_2y"][ind],
            spin2z=samples["spin_2z"][ind], inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., delta_t=1./4096,
            f_ref=20.
        )
        np.testing.assert_almost_equal(
            data["h_plus"].times.value, hp_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_cross"].times.value, hc_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
class TestMultiAnalysisSamplesDict(object):
    """Test the MultiAnalysisSamplesDict class
    """
    def setup_method(self):
        # two analyses ("one"/"two"), each with two parameters and 100
        # samples; also written to disk for the `from_files` classmethod
        self.parameters = ["a", "b"]
        self.samples = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)],
        ]
        self.labels = ["one", "two"]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        for num, _samples in enumerate(self.samples):
            write(
                self.parameters, np.array(_samples).T, outdir=tmpdir,
                filename="test_{}.dat".format(num + 1), file_format="dat"
            )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Test the different ways to initalize the class
        """
        dataframe = MultiAnalysisSamplesDict(
            self.parameters, self.samples, labels=["one", "two"]
        )
        assert sorted(list(dataframe.keys())) == sorted(self.labels)
        assert sorted(list(dataframe["one"])) == sorted(["a", "b"])
        assert sorted(list(dataframe["two"])) == sorted(["a", "b"])
        np.testing.assert_almost_equal(
            dataframe["one"]["a"], self.samples[0][0]
        )
        np.testing.assert_almost_equal(
            dataframe["one"]["b"], self.samples[0][1]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["a"], self.samples[1][0]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["b"], self.samples[1][1]
        )
        # a nested dictionary and the from_files classmethod should both
        # produce the same content as the list-based constructor
        _other = MCMCSamplesDict({
            label: {
                param: self.samples[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num, label in enumerate(self.labels)
        })
        class_method = MultiAnalysisSamplesDict.from_files(
            {
                'one': "{}/test_1.dat".format(tmpdir),
                'two': "{}/test_2.dat".format(tmpdir)
            }, add_zero_likelihood=False
        )
        for other in [_other, class_method]:
            assert sorted(other.keys()) == sorted(dataframe.keys())
            assert sorted(other["one"].keys()) == sorted(
                dataframe["one"].keys()
            )
            np.testing.assert_almost_equal(
                other["one"]["a"], dataframe["one"]["a"]
            )
            np.testing.assert_almost_equal(
                other["one"]["b"], dataframe["one"]["b"]
            )
            np.testing.assert_almost_equal(
                other["two"]["a"], dataframe["two"]["a"]
            )
            np.testing.assert_almost_equal(
                other["two"]["b"], dataframe["two"]["b"]
            )

    def test_different_samples_for_different_analyses(self):
        """Test that nothing breaks when different samples have different parameters
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert sorted(dataframe["one"].keys()) == sorted(data["one"].keys())
        assert sorted(dataframe["two"].keys()) == sorted(data["two"].keys())
        np.testing.assert_almost_equal(
            dataframe["one"]["a"], data["one"]["a"]
        )
        np.testing.assert_almost_equal(
            dataframe["one"]["b"], data["one"]["b"]
        )
        np.testing.assert_almost_equal(
            dataframe["two"]["a"], data["two"]["a"]
        )
        # transposing is ambiguous when analyses have different parameter
        # sets, so it must raise a ValueError
        with pytest.raises(ValueError):
            transpose = dataframe.T

    def test_adding_to_table(self):
        """Test that a new analysis can be added to an existing table
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert "three" not in dataframe.parameters.keys()
        new_data = {"a": np.random.uniform(10, 0.5, 100)}
        dataframe["three"] = new_data
        # the new analysis shows up in the samples and in the book-keeping
        # properties
        np.testing.assert_almost_equal(dataframe["three"]["a"], new_data["a"])
        assert dataframe.parameters["three"] == ["a"]
        assert "three" in dataframe.number_of_samples.keys()

    def test_combine(self):
        """Test that the .combine method is working as expected
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(100, 0.5, 100),
                "b": np.random.uniform(50, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        # test that when weights are 0 and 1, we only get the second set of
        # samples
        combine = dataframe.combine(weights={"one": 0., "two": 1.})
        assert "a" in combine.keys()
        assert "b" in combine.keys()
        assert all(ss in data["two"]["a"] for ss in combine["a"])
        assert all(ss in data["two"]["b"] for ss in combine["b"])
        # test that when weights are equal, the first half of samples are from
        # one and the second half are from two
        combine = dataframe.combine(labels=["one", "two"], weights=[0.5, 0.5])
        nsamples = len(combine["a"])
        half = int(nsamples / 2)
        assert all(ss in data["one"]["a"] for ss in combine["a"][:half])
        assert all(ss in data["two"]["a"] for ss in combine["a"][half:])
        assert all(ss in data["one"]["b"] for ss in combine["b"][:half])
        assert all(ss in data["two"]["b"] for ss in combine["b"][half:])
        # test that the samples maintain order
        for n in np.random.choice(half, size=10, replace=False):
            ind = np.argwhere(data["one"]["a"] == combine["a"][n])
            assert data["one"]["b"][ind] == combine["b"][n]
        # test that when use_all is provided, all samples are included
        combine = dataframe.combine(use_all=True)
        assert len(combine["a"]) == len(data["one"]["a"]) + len(data["two"]["a"])
        # test shuffle
        combine = dataframe.combine(
            labels=["one", "two"], weights=[0.5, 0.5], shuffle=True
        )
        # after shuffling, each (a, b) pair must still come from the same
        # original row of whichever analysis supplied it
        for n in np.random.choice(half, size=10, replace=False):
            if combine["a"][n] in data["one"]["a"]:
                ind = np.argwhere(data["one"]["a"] == combine["a"][n])
                assert data["one"]["b"][ind] == combine["b"][n]
            else:
                ind = np.argwhere(data["two"]["a"] == combine["a"][n])
                assert data["two"]["b"][ind] == combine["b"][n]
        # no sample should appear twice
        assert len(set(combine["a"])) == len(combine["a"])
        assert len(set(combine["b"])) == len(combine["b"])
class TestMCMCSamplesDict(object):
    """Test the MCMCSamplesDict class
    """
    def setup_method(self):
        """Create two chains, each with two parameters and 100 samples
        """
        self.parameters = ["a", "b"]
        self.chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)]
        ]

    def test_initalize(self):
        """Test the different ways to initalize the class
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        assert sorted(list(dataframe.keys())) == sorted(
            ["chain_{}".format(num) for num in range(len(self.chains))]
        )
        assert sorted(list(dataframe["chain_0"].keys())) == sorted(["a", "b"])
        assert sorted(list(dataframe["chain_1"].keys())) == sorted(["a", "b"])
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], self.chains[0][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], self.chains[0][1]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], self.chains[1][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], self.chains[1][1]
        )
        # initializing from a nested dictionary should give the same result
        other = MCMCSamplesDict({
            "chain_{}".format(num): {
                param: self.chains[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num in range(len(self.chains))
        })
        assert sorted(other.keys()) == sorted(dataframe.keys())
        assert sorted(other["chain_0"].keys()) == sorted(
            dataframe["chain_0"].keys()
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["a"], dataframe["chain_0"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["b"], dataframe["chain_0"]["b"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["a"], dataframe["chain_1"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["b"], dataframe["chain_1"]["b"]
        )

    def test_unequal_chain_length(self):
        """Test that when inverted, the chains keep their unequal chain
        length
        """
        chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 200), np.random.uniform(80, 10, 200)]
        ]
        dataframe = MCMCSamplesDict(self.parameters, chains)
        transpose = dataframe.T
        assert len(transpose["a"]["chain_0"]) == 100
        assert len(transpose["a"]["chain_1"]) == 200
        assert dataframe.number_of_samples == {
            "chain_0": 100, "chain_1": 200
        }
        assert dataframe.minimum_number_of_samples == 100
        # sample counts must survive the transpose
        assert transpose.number_of_samples == dataframe.number_of_samples
        assert transpose.minimum_number_of_samples == \
            dataframe.minimum_number_of_samples
        # combining concatenates every chain
        combined = dataframe.combine
        assert combined.number_of_samples == 300
        my_combine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(ss in my_combine for ss in combined["a"])

    def test_properties(self):
        """Test that the properties of the MCMCSamplesDict class are correct
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        # the transpose swaps the chain/parameter nesting order
        transpose = dataframe.T
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], transpose["a"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], transpose["b"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], transpose["a"]["chain_1"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], transpose["b"]["chain_1"]
        )
        # averaging must be independent of the nesting order
        average = dataframe.average
        transpose_average = transpose.average
        for param in self.parameters:
            np.testing.assert_almost_equal(
                average[param], transpose_average[param]
            )
        assert dataframe.total_number_of_samples == 200
        assert dataframe.total_number_of_samples == \
            transpose.total_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 200
        mycombine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(s in combined["a"] for s in mycombine)
        # transposing twice must give back the original structure
        transpose_copy = transpose.T
        assert sorted(list(transpose_copy.keys())) == sorted(list(dataframe.keys()))
        for level1 in dataframe.keys():
            assert sorted(list(transpose_copy[level1].keys())) == sorted(
                list(dataframe[level1].keys())
            )
            for level2 in dataframe[level1].keys():
                np.testing.assert_almost_equal(
                    transpose_copy[level1][level2], dataframe[level1][level2]
                )

    def test_key_data(self):
        """Test that the key data is correct
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        key_data = dataframe.key_data
        combined = dataframe.combine
        # BUGFIX: the original used `for param, in key_data.keys():`, which
        # tuple-unpacks every key and only worked by accident for
        # single-character parameter names
        for param in key_data.keys():
            np.testing.assert_almost_equal(
                key_data[param]["mean"], np.mean(combined[param])
            )
            np.testing.assert_almost_equal(
                key_data[param]["median"], np.median(combined[param])
            )

    def test_burnin_removal(self):
        """Test that the different methods for removing the samples as burnin
        as expected
        """
        uniform = np.random.uniform
        parameters = ["a", "b", "cycle"]
        chains = [
            [uniform(10, 0.5, 100), uniform(100, 10, 100), uniform(1, 0.8, 100)],
            [uniform(5, 0.5, 100), uniform(80, 10, 100), uniform(1, 0.8, 100)],
            [uniform(1, 0.8, 100), uniform(90, 10, 100), uniform(1, 0.8, 100)]
        ]
        # keep all samples with a positive step number
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(algorithm="burnin_by_step_number")
        idxs = np.argwhere(chains[0][2] > 0)
        assert len(burnin["chain_0"]["a"]) == len(idxs)
        # remove the first N samples
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(10, algorithm="burnin_by_first_n")
        assert len(burnin["chain_0"]["a"]) == 90
        # remove the first N samples after applying the step-number cut
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(
            10, algorithm="burnin_by_first_n", step_number=True
        )
        assert len(burnin["chain_0"]["a"]) == len(idxs) - 10
class TestList(object):
    """Test the List class
    """
    def test_added(self):
        """Check that `added` tracks every element inserted after creation
        while `original` keeps the initial contents
        """
        base = ["a", "b", "c", "d", "e"]
        mylist = List(base)
        assert not len(mylist.added)
        mylist.append("f")
        assert len(mylist.added)
        assert mylist.added == ["f"]
        mylist.extend(["g", "h"])
        assert sorted(mylist.added) == sorted(["f", "g", "h"])
        assert sorted(mylist.original) == sorted(base)
        mylist.insert(2, "z")
        assert sorted(mylist.added) == sorted(["f", "g", "h", "z"])
        assert sorted(mylist) == sorted(base + ["f", "g", "h", "z"])
        # the + and += operators should be tracked as well
        mylist = List(base)
        mylist = mylist + ["f", "g", "h"]
        assert sorted(mylist.added) == sorted(["f", "g", "h"])
        mylist += ["i"]
        assert sorted(mylist.added) == sorted(["f", "g", "h", "i"])

    def test_removed(self):
        """Check that `removed` tracks every element deleted after creation
        """
        base = ["a", "b", "c", "d", "e"]
        mylist = List(base)
        assert not len(mylist.removed)
        mylist.remove("e")
        assert sorted(mylist) == sorted(["a", "b", "c", "d"])
        assert mylist.removed == ["e"]
        assert not len(sorted(mylist.added))
        # removing an element that was added after creation is also tracked
        mylist.extend(["f", "g"])
        mylist.remove("f")
        assert mylist.removed == ["e", "f"]
        assert mylist.added == ["g"]
        mylist.pop(0)
        assert sorted(mylist.removed) == sorted(["e", "f", "a"])
def test_2DArray():
    """Test the pesummary.utils.array._2DArray class
    """
    def _check_statistics(array, reference):
        # summary statistics of each wrapped array must match numpy applied
        # to the raw samples
        assert array.standard_deviation == np.std(reference)
        assert array.minimum == np.min(reference)
        assert array.maximum == np.max(reference)
        stats = array.key_data
        assert stats["5th percentile"] == np.percentile(reference, 5)
        assert stats["95th percentile"] == np.percentile(reference, 95)
        assert stats["median"] == np.median(reference)
        assert stats["mean"] == np.mean(reference)

    samples = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    arrays = _2DArray(samples)
    for array, reference in zip(arrays, samples):
        np.testing.assert_almost_equal(array, reference)
    for array, reference in zip(arrays, samples):
        _check_statistics(array, reference)

    # repeat with likelihood/prior provided so maxL and maxP are defined
    samples = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    likelihood = np.random.uniform(0, 1, 1000)
    prior = np.random.uniform(0, 1, 1000)
    arrays = _2DArray(samples, likelihood=likelihood, prior=prior)
    for array, reference in zip(arrays, samples):
        np.testing.assert_almost_equal(array, reference)
    for array, reference in zip(arrays, samples):
        _check_statistics(array, reference)
        assert array.maxL == array[np.argmax(likelihood)]
        assert array.maxP == array[np.argmax(likelihood + prior)]
        stats = array.key_data
        assert stats["maxL"] == array.maxL
        assert stats["maxP"] == array.maxP
class TestArray(object):
    """Test the Array class
    """
    def test_properties(self):
        """Test the average, standard_deviation and credible_interval
        behaviour of the Array class against numpy
        """
        samples = np.random.uniform(100, 10, 100)
        array = Array(samples)
        assert array.average(type="mean") == np.mean(samples)
        assert array.average(type="median") == np.median(samples)
        assert array.standard_deviation == np.std(samples)
        np.testing.assert_almost_equal(
            array.credible_interval(percentile=[5, 95]),
            [np.percentile(array, 5), np.percentile(array, 95)]
        )

    def test_weighted_percentile(self):
        """Test that the weighted percentile matches numpy when the weights
        are integer repetition counts
        """
        x = np.random.normal(100, 20, 10000)
        weights = np.array([np.random.randint(100) for _ in range(10000)])
        array = Array(x, weights=weights)
        # locals renamed from `numpy`/`pesummary`, which shadowed the
        # imported `pesummary` module (and the conventional numpy name)
        expected = np.percentile(np.repeat(x, weights), 90)
        result = array.credible_interval(percentile=90)
        np.testing.assert_almost_equal(expected, result, 6)

    def test_key_data(self):
        """Test that key_data summarises the samples correctly
        """
        samples = np.random.normal(100, 20, 10000)
        array = Array(samples)
        key_data = array.key_data
        np.testing.assert_almost_equal(key_data["mean"], np.mean(samples))
        np.testing.assert_almost_equal(key_data["median"], np.median(samples))
        np.testing.assert_almost_equal(key_data["std"], np.std(samples))
        np.testing.assert_almost_equal(
            key_data["5th percentile"], np.percentile(samples, 5)
        )
        np.testing.assert_almost_equal(
            key_data["95th percentile"], np.percentile(samples, 95)
        )
        # maxL/maxP are undefined when no likelihood/prior are supplied
        assert key_data["maxL"] is None
        assert key_data["maxP"] is None
class TestTQDM(object):
    """Test the pesummary.utils.tqdm.tqdm class
    """
    def setup_method(self):
        # iterable used by every test; create the scratch directory if needed
        self._range = range(100)
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_basic_iterator(self):
        """Test that the core functionality of the tqdm class remains
        """
        for value in tqdm(self._range):
            _ = value * value

    def test_interaction_with_logger(self):
        """Test that tqdm interacts nicely with logger
        """
        from pesummary.utils.utils import logger, LOG_FILE

        path = "{}/test.dat".format(tmpdir)
        with open(path, "w") as f:
            for value in tqdm(self._range, logger=logger, file=f):
                _ = value * value

        with open(path, "r") as f:
            final_line = f.readlines()[-1]
        # the progress bar should have been routed through the logger
        assert "PESummary" in final_line
        assert "INFO" in final_line
def test_jensen_shannon_divergence():
    """Test that the `jensen_shannon_divergence` method returns the same
    values as the scipy function, and that the bounded-KDE variant returns
    a valid divergence
    """
    from scipy.spatial.distance import jensenshannon
    from scipy import stats

    samples = [
        np.random.uniform(5, 4, 100),
        np.random.uniform(5, 4, 100)
    ]
    x = np.linspace(np.min(samples), np.max(samples), 100)
    kde = [stats.gaussian_kde(i)(x) for i in samples]
    _scipy = jensenshannon(*kde)**2
    _pesummary = utils.jensen_shannon_divergence(samples, decimal=9)
    np.testing.assert_almost_equal(_scipy, _pesummary)

    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    _pesummary = utils.jensen_shannon_divergence(
        samples, decimal=9, kde=ReflectionBoundedKDE, xlow=4.5, xhigh=5.5
    )
    # previously this result was computed but never checked; any JS
    # divergence must be finite and non-negative
    assert np.isfinite(_pesummary)
    assert _pesummary >= 0.
def test_make_cache_style_file():
    """Test that the `make_cache_style_file` works as expected

    The temporary 'test.sty' file is removed afterwards so the test does not
    leave artifacts in the working directory (previously it was never
    cleaned up).
    """
    from pesummary.utils.utils import make_cache_style_file, CACHE_DIR
    sty = os.path.expanduser("{}/style/matplotlib_rcparams.sty".format(CACHE_DIR))
    try:
        with open("test.sty", "w") as f:
            f.writelines(["test : 10"])
        make_cache_style_file("test.sty")
        # the style file should have been copied into the cache directory
        assert os.path.isfile(sty)
        with open(sty, "r") as f:
            lines = f.readlines()
        assert len(lines) == 1
        assert lines[0] == "test : 10"
    finally:
        # always remove the scratch style file, even if an assertion failed
        if os.path.isfile("test.sty"):
            os.remove("test.sty")
def test_logger():
    """Check that the PESummary logger emits INFO and WARNING records
    """
    with LogCapture() as captured:
        utils.logger.propagate = True
        utils.logger.info("info")
        utils.logger.warning("warning")
        captured.check(
            ("PESummary", "INFO", "info"),
            ("PESummary", "WARNING", "warning"),
        )
def test_string_match():
    """function to test the string_match function
    """
    param = "mass_1"
    # a parameter always matches itself
    assert utils.string_match(param, param)
    # but never a different parameter name
    assert not utils.string_match(param, "mass_2")
    # a single wildcard matches the remainder of the string
    assert utils.string_match(param, param[0] + "*")
    assert not utils.string_match(param, "z+")
    assert utils.string_match(param, param[:-1] + "?")
    # several wildcards may be combined
    assert utils.string_match(param, "*" + param[1] + "*")
    assert utils.string_match(param, "*" + param[3:-1] + "?")
    assert utils.string_match(param, "?" + param[1:-1] + "?")
    assert not utils.string_match(param, "?" + param[:-1] + "?")
    # '^' anchors the match to the start of the string
    assert utils.string_match(param, "^" + param[:3])
    # a negative lookahead rejects strings with that prefix
    assert not utils.string_match(param, "^(?!" + param[:3] + ").+")
def test_map_parameter_names():
    """function to test the map_parameter_names function
    """
    original = {
        key: np.random.uniform(1, 100, size=1000) for key in ("a", "b", "c")
    }
    mapping = {"a": "z", "b": "y", "c": "x"}
    working = original.copy()
    remapped = utils.map_parameter_names(working, mapping)
    # the returned dictionary only contains the mapped names
    assert sorted(remapped.keys()) == ["x", "y", "z"]
    # renaming must not alter the samples themselves
    for old_key, new_key in mapping.items():
        np.testing.assert_almost_equal(original[old_key], remapped[new_key])
    # the input dictionary is left untouched
    for key, samples in original.items():
        np.testing.assert_almost_equal(samples, working[key])
def test_list_match():
    """function to test the list_match function
    """
    params = [
        "mass_1", "mass_2", "chirp_mass", "total_mass", "mass_ratio", "a_1",
        "luminosity_distance", "x"
    ]

    def matched(substring, **kwargs):
        # helper returning the sorted selection for a given pattern
        return sorted(utils.list_match(params, substring, **kwargs))

    # every mass parameter is selected
    assert matched("*mass*") == sorted(p for p in params if "mass" in p)
    # only parameters containing an underscore are selected
    assert matched("*_*") == sorted(p for p in params if "_" in p)
    # no parameter matches at all
    assert not len(utils.list_match(params, "z+"))
    # only parameters which do not start with 'mass'
    assert matched("^(?!mass).+") == sorted(
        p for p in params if p[:4] != "mass"
    )
    # only parameters which start with 'l'
    assert matched("^l") == sorted(p for p in params if p[0] == "l")
    # return_false inverts the selection
    assert matched("^l", return_false=True) == sorted(
        p for p in params if p[0] != "l"
    )
    # a list of substrings requires all of them to match
    assert matched(["^m", "*2"]) == sorted(["mass_2"])
class TestDict(object):
    """Class to test the NestedDict object
    """
    def test_initiate(self):
        """Initiate the Dict class
        """
        from pesummary.gw.file.psd import PSD

        # both construction signatures (mapping, and parameters+data) must
        # produce an identical object
        signatures = [
            ({"a": [[10, 20], [10, 20]]},),
            (["a"], [[[10, 20], [10, 20]]]),
        ]
        for args in signatures:
            x = Dict(
                *args, value_class=PSD, value_columns=["value", "value2"]
            )
            assert list(x.keys()) == ["a"]
            np.testing.assert_almost_equal(x["a"], [[10, 20], [10, 20]])
            assert isinstance(x["a"], PSD)
            np.testing.assert_almost_equal(x["a"].value, [10, 10])
            np.testing.assert_almost_equal(x["a"].value2, [20, 20])

    def test_specify_columns(self):
        """Test that x[["a", "b"]] works as expected
        """
        x = Dict({"a": [10], "b": [20], "c": [30], "d": [40]}, value_class=list)
        subset = x[["a", "b"]]
        assert sorted(list(subset.keys())) == ["a", "b"]
        for key in subset.keys():
            assert subset[key] == x[key]
        # asking for keys which do not exist should fail
        with pytest.raises(Exception):
            _ = x[["e", "f"]]
class TestProbabilityDict(object):
    """Test the ProbabilityDict class
    """
    def setup_method(self):
        """Setup the ProbabilityDict class
        """
        self.mydict = {
            "a": [[1,2,3,4], [0.1, 0.2, 0.3, 0.4]],
            "b": [[1,3,5,7], [0.1, 0.3, 0.2, 0.4]]
        }

    def test_initiate(self):
        """Test the different ways to initiate this class
        """
        probs = ProbabilityDict(self.mydict)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)
        # initiating from separate parameter/data lists must give the same
        # object as initiating from a mapping
        parameters = list(self.mydict.keys())
        data = [self.mydict[key] for key in parameters]
        probs = ProbabilityDict(parameters, data)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)

    def test_rvs(self):
        """Test the .rvs method
        """
        probs = ProbabilityDict(self.mydict)
        # draw 10 samples for all parameters
        samples = probs.rvs(size=10, interpolate=False)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        # BUG FIX: this previously read `len(samples[param] == 10)` which
        # only checked that the boolean comparison array was non-empty,
        # never that 10 samples were actually drawn
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())
        # without interpolation each draw must be one of the input grid points
        for param in self.mydict.keys():
            assert all(p in self.mydict[param][0] for p in samples[param])
        # draw 10 samples for only parameter 'a'
        samples = probs.rvs(size=10, parameters=["a"])
        assert list(samples.keys()) == ["a"]
        assert len(samples["a"]) == 10
        # interpolate first and then draw samples
        samples = probs.rvs(size=10, interpolate=True)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        # same misplaced-parenthesis fix as above
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())

    def test_plotting(self):
        """Test the .plot method
        """
        probs = ProbabilityDict(self.mydict)
        # Test that a histogram plot is generated when required
        fig = probs.plot("a", type="hist")
        fig = probs.plot("a", type="marginalized_posterior")
        # Test that a PDF is generated when required
        fig = probs.plot("a", type="pdf")
class TestProbabilityDict2D(object):
    """Class to test the ProbabilityDict2D class
    """
    def setup_method(self):
        """Setup the TestProbabilityDict2D class
        """
        grid = np.random.uniform(0, 1, 100).reshape(10, 10)
        # normalise the 2d probability grid so it sums to one
        self.mydict = {
            "x_y": [
                np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10),
                grid / np.sum(grid)
            ]
        }

    def test_plotting(self):
        """Test the .plot method
        """
        pdf = ProbabilityDict2D(self.mydict)
        # a 2d KDE plot can be generated
        fig = pdf.plot("x_y", type="2d_kde")
        # a triangle plot can be generated
        fig, _, _, _ = pdf.plot("x_y", type="triangle")
        # a reverse triangle plot can be generated
        fig, _, _, _ = pdf.plot("x_y", type="reverse_triangle")
        # an unknown plot type must raise an Exception
        with pytest.raises(Exception):
            fig = pdf.plot("x_y", type="does_not_exist")
class TestDiscretePDF2D(object):
    """Class to test the DiscretePDF2D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2D class
        """
        self.x = [1, 2]
        self.y = [1, 2]
        self.probs = [[0.25, 0.25], [0.25, 0.25]]

    def test_initiate(self):
        """Test initialising the DiscretePDF2D class
        """
        # a flat (1d) probability array must be rejected with a ValueError
        with pytest.raises(ValueError):
            _ = DiscretePDF2D(self.x, self.y, [1, 2, 3, 4, 5, 6])
        # while a genuine 2d probability array is accepted
        _ = DiscretePDF2D(self.x, self.y, self.probs)

    def test_marginalize(self):
        """Test that marginalizing produces the combined 2D+1D object
        """
        pdf = DiscretePDF2D(self.x, self.y, self.probs)
        marginalized = pdf.marginalize()
        assert isinstance(marginalized, DiscretePDF2Dplus1D)
        np.testing.assert_almost_equal(pdf.probs, marginalized.probs_xy.probs)
class TestDiscretePDF2Dplus1D(object):
    """Class to test the DiscretePDF2Dplus1D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2Dplus1D class
        """
        self.x = [1, 2]
        self.y = [1, 2]
        self.probs = [[0.5, 0.5], [0.5, 0.5], [[0.25, 0.25], [0.25, 0.25]]]

    def test_initiate_raise(self):
        """Test that the class raises ValueErrors when required
        """
        invalid_probs = [
            # only 2 probability arrays supplied
            self.probs[:-1],
            # the final probability array is not 2d
            self.probs[:-1] + self.probs[:1],
            # only a single 1d probability array supplied
            self.probs[1:] + self.probs[-1:],
        ]
        for probs in invalid_probs:
            with pytest.raises(ValueError):
                _ = DiscretePDF2Dplus1D(self.x, self.y, probs)

    def test_initiate(self):
        """Test that we can initiate the class
        """
        pdf = DiscretePDF2Dplus1D(self.x, self.y, self.probs)
        # the two 1d marginals and the joint 2d PDF are wrapped in the
        # appropriate classes
        assert isinstance(pdf.probs[0], DiscretePDF)
        assert isinstance(pdf.probs[1], DiscretePDF)
        assert isinstance(pdf.probs[2], DiscretePDF2D)
class TestKDEList(object):
    """Test the KDEList class
    """
    def setup_method(self):
        """Setup the KDEList class
        """
        from scipy.stats import gaussian_kde
        means = np.random.uniform(0, 5, size=10)
        stds = np.random.uniform(0, 2, size=10)
        self.inputs = np.array(
            [
                np.random.normal(mean, std, size=10000) for mean, std
                in zip(means, stds)
            ]
        )
        self.pts = np.linspace(-10, 10, 100)
        # reference KDEs computed directly with scipy
        self.true_kdes = [gaussian_kde(_)(self.pts) for _ in self.inputs]

    def test_call(self):
        """Test the KDEList call method
        """
        from pesummary.utils.kde_list import KDEList
        from scipy.stats import gaussian_kde
        new = KDEList(self.inputs, kde=gaussian_kde, pts=self.pts)
        # test on a single CPU
        new_single = new(multi_process=1)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_single[num])
        # test on multiple CPUs
        new_multiple = new(multi_process=2)
        for num in range(len(self.inputs)):
            # BUG FIX: this loop previously re-checked `new_single`, so the
            # multi-process result was never actually verified
            np.testing.assert_almost_equal(self.true_kdes[num], new_multiple[num])
        # test the `idx` kwarg
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new(idx=num))
        # split the calculation into 2 chunks and 2 CPUs
        new_chunk1 = new(
            idx=np.arange(0, len(self.inputs) // 2, 1), multi_process=2
        ).tolist()
        new_chunk2 = new(
            idx=np.arange(len(self.inputs) // 2, len(self.inputs), 1), multi_process=2
        ).tolist()
        new_combined = new_chunk1 + new_chunk2
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_combined[num])
        # assert that different KDEs are produced for different pts
        with pytest.raises(AssertionError):
            new_diff = new(pts=np.linspace(-5, 5, 100), multi_process=1)
            for num in range(len(self.inputs)):
                np.testing.assert_almost_equal(self.true_kdes[num], new_diff[num])
def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    # ensure the cache directory exists before copying the style file in
    make_dir(CACHE_DIR)
    destination = os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")
    shutil.copyfile(style_file, destination)
def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    # the cached style file always lives at a fixed name inside CACHE_DIR
    style_path = os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")
    return style_path