Coverage for pesummary/tests/utils_test.py: 41.4%

729 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5import h5py 

6import numpy as np 

7import copy 

8 

9import pesummary 

10from pesummary.io import write 

11import pesummary.cli as cli 

12from pesummary.utils import utils 

13from pesummary.utils.tqdm import tqdm 

14from pesummary.utils.dict import Dict 

15from pesummary.utils.list import List 

16from pesummary.utils.pdf import DiscretePDF, DiscretePDF2D, DiscretePDF2Dplus1D 

17from pesummary.utils.array import _2DArray 

18from pesummary.utils.samples_dict import ( 

19 Array, SamplesDict, MCMCSamplesDict, MultiAnalysisSamplesDict 

20) 

21from pesummary.utils.probability_dict import ProbabilityDict, ProbabilityDict2D 

22from pesummary._version_helper import GitInformation, PackageInformation 

23from pesummary._version_helper import get_version_information 

24 

25import pytest 

26from testfixtures import LogCapture 

27import tempfile 

28 

# Hidden scratch directory name used by every test class below; the tests
# create and remove the directory themselves in setup/teardown.
tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".").name

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

# Prefer the CI checkout location when running under GitLab CI
DEFAULT_DIRECTORY = os.getenv("CI_PROJECT_DIR", os.getcwd())

33 

34 

class TestGitInformation(object):
    """Tests for the GitInformation helper class
    """
    def setup_method(self):
        """Build a GitInformation object pointing at the repository
        """
        self.git = GitInformation(directory=DEFAULT_DIRECTORY)

    def test_last_commit_info(self):
        """last_commit_info holds exactly two strings
        """
        commit_info = self.git.last_commit_info
        assert len(commit_info) == 2
        assert all(isinstance(entry, str) for entry in commit_info)

    def test_last_version(self):
        """last_version is a string
        """
        assert isinstance(self.git.last_version, str)

    def test_status(self):
        """status is a string
        """
        assert isinstance(self.git.status, str)

    def test_builder(self):
        """builder is a string
        """
        assert isinstance(self.git.builder, str)

    def test_build_date(self):
        """build_date is a string
        """
        assert isinstance(self.git.build_date, str)

69 

70 

class TestPackageInformation(object):
    """Tests for the PackageInformation helper class
    """
    def setup_method(self):
        """Build a PackageInformation object
        """
        self.package = PackageInformation()

    def test_package_info(self):
        """package_info is a list of package records, each carrying at
        least a name and a version
        """
        package_info = self.package.package_info
        assert isinstance(package_info, list)
        first = package_info[0]
        assert "name" in first
        assert "version" in first
        # conda-installed packages additionally record their channel
        if "build_string" in first:
            assert "channel" in first

89 

90 

class TestUtils(object):
    """Test the helper functions in pesummary.utils.utils
    """
    def setup_method(self):
        """Create a scratch directory for each test
        """
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_check_condition(self):
        """check_condition raises an Exception carrying the supplied
        message when the condition is True
        """
        with pytest.raises(Exception) as info:
            condition = True
            utils.check_condition(condition, "error")
        assert str(info.value) == "error"

    def test_rename_group_in_hf5_file(self):
        """rename_group_or_dataset_in_hf5_file renames a group while
        preserving its contents
        """
        path = "{}/rename_group.h5".format(tmpdir)
        # context managers close the handle even if an assertion fails
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, group=["group", "replaced"])
        # pass the mode explicitly: opening without one is deprecated in h5py
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["replaced"]
            assert list(f["replaced"].keys()) == ["example"]
            assert len(f["replaced/example"]) == 1
            assert f["replaced/example"][0] == 10

    def test_rename_dataset_in_hf5_file(self):
        """rename_group_or_dataset_in_hf5_file renames a dataset while
        preserving its contents
        """
        path = "{}/rename_dataset.h5".format(tmpdir)
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, dataset=["group/example", "group/replaced"])
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["group"]
            assert list(f["group"].keys()) == ["replaced"]
            assert len(f["group/replaced"]) == 1
            assert f["group/replaced"][0] == 10

    def test_rename_unknown_hf5_file(self):
        """Renaming inside a non-existent file raises an Exception
        """
        with pytest.raises(Exception) as info:
            utils.rename_group_or_dataset_in_hf5_file(
                "{}/unknown.h5".format(tmpdir),
                group=["None", "replaced"])
        assert "does not exist" in str(info.value)

    def test_directory_creation(self):
        """make_dir creates a directory that did not previously exist
        """
        directory = '{}/test_dir'.format(tmpdir)
        # use truthiness rather than comparing to True/False (PEP 8 E712)
        assert not os.path.isdir(directory)
        utils.make_dir(directory)
        assert os.path.isdir(directory)

    def test_url_guess(self):
        """guess_url maps each known cluster host to the expected web URL
        """
        host = ["raven", "cit", "ligo-wa", "uwm", "phy.syr.edu", "vulcan",
                "atlas", "iucaa", "alice"]
        expected = ["https://geo2.arcca.cf.ac.uk/~albert.einstein/test",
                    "https://ldas-jobs.ligo.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.ligo-wa.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.phys.uwm.edu/~albert.einstein/test",
                    "https://sugar-jobs.phy.syr.edu/~albert.einstein/test",
                    "https://galahad.aei.mpg.de/~albert.einstein/test",
                    "https://atlas1.atlas.aei.uni-hannover.de/~albert.einstein/test",
                    "https://ldas-jobs.gw.iucaa.in/~albert.einstein/test",
                    "https://dumpty.alice.icts.res.in/~albert.einstein/test"]
        user = "albert.einstein"
        webdir = '/home/albert.einstein/public_html/test'
        for hostname, expected_url in zip(host, expected):
            url = utils.guess_url(webdir, hostname, user)
            assert url == expected_url

    def test_make_dir(self):
        """make_dir creates a directory and is a no-op (keeping existing
        files) when the directory already exists
        """
        target = os.path.join(tmpdir, "test")
        assert not os.path.isdir(target)
        utils.make_dir(target)
        assert os.path.isdir(target)
        with open(os.path.join(target, "test.dat"), "w") as f:
            f.writelines(["test"])
        utils.make_dir(target)
        # the existing file must survive a second make_dir call
        assert os.path.isfile(os.path.join(target, "test.dat"))

    def test_resample_posterior_distribution(self):
        """Resampling a normal posterior keeps its mean and width
        """
        data = np.random.normal(1, 0.1, 1000)
        resampled = utils.resample_posterior_distribution([data], 500)
        assert len(resampled) == 500
        assert np.round(np.mean(resampled), 1) == 1.
        assert np.round(np.std(resampled), 1) == 0.1

    def test_gw_results_file(self):
        """gw_results_file is True only for GW-flavoured options
        """
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        assert utils.gw_results_file(opts)
        opts = namespace({"webdir": tmpdir})
        assert not utils.gw_results_file(opts)

    def test_functions(self):
        """functions returns the GW classes for GW inputs and the core
        classes otherwise
        """
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.gw.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.gw.file.meta_file.GWMetaFile

        opts = namespace({"webdir": tmpdir})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.core.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.core.file.meta_file.MetaFile

    def test_get_version_information(self):
        """get_version_information returns a string
        """
        assert isinstance(get_version_information(), str)

226 

227 

class TestGelmanRubin(object):
    """Tests for the Gelman-Rubin convergence statistic
    """
    def test_same_as_lalinference(self):
        """pesummary's gelman_rubin matches the LALInference
        implementation on random two-chain data
        """
        from lalinference.bayespputils import Posterior
        from pesummary.utils.utils import gelman_rubin

        header = ["a", "b", "logL", "chain"]
        for _ in range(100):
            rows = [
                np.random.uniform(np.random.random(), 0.1, 3).tolist()
                + [np.random.randint(1, 3)] for _ in range(10)
            ]
            posterior = Posterior([header, np.array(rows)])
            lal_R = posterior.gelman_rubin("a")
            chains = np.unique(posterior["chain"].samples)
            chain_index = posterior.names.index("chain")
            param_index = posterior.names.index("a")
            data, _ = posterior.samples()
            per_chain = [
                data[data[:, chain_index] == chain, param_index]
                for chain in chains
            ]
            np.testing.assert_almost_equal(
                gelman_rubin(per_chain, decimal=10), lal_R, 7
            )

    def test_same_samples(self):
        """Two identical chains (perfect convergence) give R == 1
        """
        from pesummary.core.plots.plot import gelman_rubin

        samples = np.random.uniform(1, 0.5, 10)
        assert gelman_rubin([samples, samples]) == 1

269 

270 

class TestSamplesDict(object):
    """Tests for the SamplesDict class
    """
    def setup_method(self):
        """Generate random samples and write them to a dat file
        """
        self.parameters = ["a", "b"]
        self.samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100)
        ]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        write(
            self.parameters, np.array(self.samples).T, outdir=tmpdir,
            filename="test.dat", file_format="dat"
        )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Every way of constructing a SamplesDict yields the same object
        """
        base = SamplesDict(self.parameters, self.samples)
        other = SamplesDict(dict(zip(self.parameters, self.samples)))
        assert base.parameters == other.parameters
        assert sorted(base.parameters) == sorted(self.parameters)
        np.testing.assert_almost_equal(base.samples, other.samples)
        assert sorted(base.keys()) == sorted(other.keys())
        np.testing.assert_almost_equal(base.samples, self.samples)
        class_method = SamplesDict.from_file(
            "{}/test.dat".format(tmpdir), add_zero_likelihood=False
        )
        np.testing.assert_almost_equal(class_method.samples, self.samples)

    def test_complex_columns(self):
        """Complex columns are deconstructed into abs/angle/real components
        when requested, and left untouched otherwise
        """
        parameters = ["a", "b", "a_j", "b_j"]
        samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100),
            np.random.uniform(10, 0.5, 100) + 10j,
            np.random.uniform(200, 10, 100) + 2j
        ]
        from_list = SamplesDict(
            parameters, samples, deconstruct_complex_columns=True
        )
        from_dict = SamplesDict(
            dict(zip(parameters, samples)), deconstruct_complex_columns=True
        )
        for dataset in (from_list, from_dict):
            for num, param in enumerate(["a_j", "b_j"]):
                complex_column = samples[2 + num]
                assert f"{param}_abs" in dataset.parameters
                assert f"{param}_angle" in dataset.parameters
                np.testing.assert_almost_equal(
                    dataset[f"{param}_abs"], np.abs(complex_column)
                )
                np.testing.assert_almost_equal(
                    dataset[f"{param}_angle"], np.angle(complex_column)
                )
                np.testing.assert_almost_equal(
                    dataset[param], np.real(complex_column)
                )
            np.testing.assert_almost_equal(dataset["a"], samples[0])
            np.testing.assert_almost_equal(dataset["b"], samples[1])
        from_list = SamplesDict(
            parameters, samples, deconstruct_complex_columns=False
        )
        from_dict = SamplesDict(
            dict(zip(parameters, samples)), deconstruct_complex_columns=False
        )
        for dataset in (from_list, from_dict):
            for param, column in zip(parameters, samples):
                np.testing.assert_almost_equal(dataset[param], column)

    def test_properties(self):
        """The statistical properties of the SamplesDict are correct
        """
        import pandas as pd

        dataset = SamplesDict(self.parameters, self.samples)
        assert sorted(dataset.minimum.keys()) == sorted(self.parameters)
        for param, column in zip(self.parameters, self.samples):
            assert dataset.minimum[param] == np.min(column)
            assert dataset.median[param] == np.median(column)
            assert dataset.mean[param] == np.mean(column)
        assert dataset.number_of_samples == len(self.samples[1])
        assert len(dataset.downsample(10)["a"]) == 10
        # rebuild the dict: downsample() above reduced it to 10 samples
        dataset = SamplesDict(self.parameters, self.samples)
        assert len(dataset.discard_samples(10)["a"]) == len(self.samples[0]) - 10
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["a"], self.samples[0][10:]
        )
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["b"], self.samples[1][10:]
        )
        assert isinstance(dataset.to_pandas(), pd.core.frame.DataFrame)
        dataset.pop("a")
        assert list(dataset.keys()) == ["b"]

    def test_core_plots(self):
        """plot() returns a matplotlib figure for the core plot types
        """
        import matplotlib.figure

        dataset = SamplesDict(self.parameters, self.samples)
        for plot_type in ("hist", "marginalized_posterior"):
            fig = dataset.plot(self.parameters[0], type=plot_type)
            assert isinstance(fig, matplotlib.figure.Figure)

    def test_standardize_parameter_names(self):
        """standardize_parameter_names renames keys without touching data
        """
        dictionary = {"mass1": [1, 2, 3, 4], "mass2": [4, 3, 2, 1], "zz": [1, 1, 1, 1]}
        mydict = SamplesDict(dictionary.copy())
        standard_dict = mydict.standardize_parameter_names()
        # the standard parameter names appear in the new dictionary
        assert sorted(standard_dict.keys()) == ["mass_1", "mass_2", "zz"]
        # the data are unchanged, only the keys are renamed
        _mapping = {"mass1": "mass_1", "mass2": "mass_2", "zz": "zz"}
        for key, item in dictionary.items():
            np.testing.assert_almost_equal(item, standard_dict[_mapping[key]])
        # the original dictionary is untouched
        assert sorted(mydict.keys()) == ["mass1", "mass2", "zz"]
        for key, item in mydict.items():
            np.testing.assert_almost_equal(item, dictionary[key])
        # a custom mapping may be supplied
        mapping = {"mass1": "custom_m1", "mass2": "custom_m2", "zz": "custom_zz"}
        new_dict = mydict.standardize_parameter_names(mapping=mapping)
        assert sorted(new_dict.keys()) == [
            "custom_m1", "custom_m2", "custom_zz"
        ]
        for old, new in mapping.items():
            np.testing.assert_almost_equal(dictionary[old], new_dict[new])

    def test_waveforms(self):
        """The fd/td waveforms generated by pesummary match pycbc
        """
        from pesummary.core.fetch import download_dir
        try:
            from pycbc.waveform import get_fd_waveform, get_td_waveform
        except (ValueError, ImportError):
            return

        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            from pesummary.gw.fetch import fetch_open_samples
            f = fetch_open_samples(
                "GW190814", read_file=True, outdir=download_dir, unpack=True,
                path="GW190814.h5", catalog="GWTC-2"
            )
        else:
            from pesummary.io import read
            f = read(downloaded_file)
        samples = f.samples_dict["C01:IMRPhenomPv3HM"]
        ind = 0
        # intrinsic/extrinsic keyword arguments shared by both pycbc calls
        source_kwargs = dict(
            mass1=samples["mass_1"][ind], mass2=samples["mass_2"][ind],
            spin1x=samples["spin_1x"][ind], spin1y=samples["spin_1y"][ind],
            spin1z=samples["spin_1z"][ind], spin2x=samples["spin_2x"][ind],
            spin2y=samples["spin_2y"][ind], spin2z=samples["spin_2z"][ind],
            inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., f_ref=20.
        )
        data = samples.fd_waveform("IMRPhenomPv3HM", 1./256, 20., 1024., ind=ind)
        hp_pycbc, hc_pycbc = get_fd_waveform(
            approximant="IMRPhenomPv3HM", f_final=1024., delta_f=1./256,
            **source_kwargs
        )
        np.testing.assert_almost_equal(
            data["h_plus"].frequencies.value, hp_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_cross"].frequencies.value, hc_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
        data = samples.td_waveform("SEOBNRv4PHM", 1./4096, 20., ind=ind)
        hp_pycbc, hc_pycbc = get_td_waveform(
            approximant="SEOBNRv4PHM", delta_t=1./4096, **source_kwargs
        )
        np.testing.assert_almost_equal(
            data["h_plus"].times.value, hp_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_cross"].times.value, hc_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )

493 

494 

class TestMultiAnalysisSamplesDict(object):
    """Tests for the MultiAnalysisSamplesDict class
    """
    def setup_method(self):
        """Generate two analyses worth of samples and write each to file
        """
        self.parameters = ["a", "b"]
        self.samples = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)],
        ]
        self.labels = ["one", "two"]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        for num, _samples in enumerate(self.samples, start=1):
            write(
                self.parameters, np.array(_samples).T, outdir=tmpdir,
                filename="test_{}.dat".format(num), file_format="dat"
            )

    def teardown_method(self):
        """Remove the files created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Every way of constructing the dict yields the same object
        """
        dataframe = MultiAnalysisSamplesDict(
            self.parameters, self.samples, labels=["one", "two"]
        )
        assert sorted(dataframe.keys()) == sorted(self.labels)
        for num, label in enumerate(self.labels):
            assert sorted(list(dataframe[label])) == sorted(["a", "b"])
            for idx, param in enumerate(self.parameters):
                np.testing.assert_almost_equal(
                    dataframe[label][param], self.samples[num][idx]
                )
        # NOTE(review): mirrors the original test, which builds the
        # comparison object with MCMCSamplesDict here
        _other = MCMCSamplesDict({
            label: dict(zip(self.parameters, self.samples[num]))
            for num, label in enumerate(self.labels)
        })
        class_method = MultiAnalysisSamplesDict.from_files(
            {
                'one': "{}/test_1.dat".format(tmpdir),
                'two': "{}/test_2.dat".format(tmpdir)
            }, add_zero_likelihood=False
        )
        for other in [_other, class_method]:
            assert sorted(other.keys()) == sorted(dataframe.keys())
            assert sorted(other["one"].keys()) == sorted(
                dataframe["one"].keys()
            )
            for label in self.labels:
                for param in self.parameters:
                    np.testing.assert_almost_equal(
                        other[label][param], dataframe[label][param]
                    )

    def test_different_samples_for_different_analyses(self):
        """Analyses with different parameters coexist, but transposing
        such a dict raises a ValueError
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        for label, analysis in data.items():
            assert sorted(dataframe[label].keys()) == sorted(analysis.keys())
            for param, value in analysis.items():
                np.testing.assert_almost_equal(dataframe[label][param], value)
        with pytest.raises(ValueError):
            dataframe.T

    def test_adding_to_table(self):
        """A new analysis can be added after construction
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert "three" not in dataframe.parameters.keys()
        new_data = {"a": np.random.uniform(10, 0.5, 100)}
        dataframe["three"] = new_data
        np.testing.assert_almost_equal(dataframe["three"]["a"], new_data["a"])
        assert dataframe.parameters["three"] == ["a"]
        assert "three" in dataframe.number_of_samples.keys()

    def test_combine(self):
        """Test that the .combine method is working as expected
        """
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(100, 0.5, 100),
                "b": np.random.uniform(50, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        # with weights 0 and 1 only the second analysis contributes
        combined = dataframe.combine(weights={"one": 0., "two": 1.})
        assert "a" in combined.keys()
        assert "b" in combined.keys()
        for param in ("a", "b"):
            assert all(ss in data["two"][param] for ss in combined[param])
        # with equal weights the first half of the samples comes from "one"
        # and the second half from "two"
        combined = dataframe.combine(labels=["one", "two"], weights=[0.5, 0.5])
        half = len(combined["a"]) // 2
        assert all(ss in data["one"]["a"] for ss in combined["a"][:half])
        assert all(ss in data["two"]["a"] for ss in combined["a"][half:])
        assert all(ss in data["one"]["b"] for ss in combined["b"][:half])
        assert all(ss in data["two"]["b"] for ss in combined["b"][half:])
        # the (a, b) pairing within a sample is preserved
        for n in np.random.choice(half, size=10, replace=False):
            ind = np.argwhere(data["one"]["a"] == combined["a"][n])
            assert data["one"]["b"][ind] == combined["b"][n]
        # use_all=True keeps every sample from every analysis
        combined = dataframe.combine(use_all=True)
        assert len(combined["a"]) == len(data["one"]["a"]) + len(data["two"]["a"])
        # shuffling keeps the pairing and introduces no duplicates
        combined = dataframe.combine(
            labels=["one", "two"], weights=[0.5, 0.5], shuffle=True
        )
        for n in np.random.choice(half, size=10, replace=False):
            if combined["a"][n] in data["one"]["a"]:
                source = data["one"]
            else:
                source = data["two"]
            ind = np.argwhere(source["a"] == combined["a"][n])
            assert source["b"][ind] == combined["b"][n]
        assert len(set(combined["a"])) == len(combined["a"])
        assert len(set(combined["b"])) == len(combined["b"])

666 

667 

class TestMCMCSamplesDict(object):
    """Test the MCMCSamplesDict class
    """
    def setup_method(self):
        """Create two chains of random samples for the tests
        """
        self.parameters = ["a", "b"]
        self.chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)]
        ]

    def test_initalize(self):
        """Test the different ways to initalize the class
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        assert sorted(list(dataframe.keys())) == sorted(
            ["chain_{}".format(num) for num in range(len(self.chains))]
        )
        assert sorted(list(dataframe["chain_0"].keys())) == sorted(["a", "b"])
        assert sorted(list(dataframe["chain_1"].keys())) == sorted(["a", "b"])
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], self.chains[0][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], self.chains[0][1]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], self.chains[1][0]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], self.chains[1][1]
        )
        # a nested dictionary must give an identical object
        other = MCMCSamplesDict({
            "chain_{}".format(num): {
                param: self.chains[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num in range(len(self.chains))
        })
        assert sorted(other.keys()) == sorted(dataframe.keys())
        assert sorted(other["chain_0"].keys()) == sorted(
            dataframe["chain_0"].keys()
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["a"], dataframe["chain_0"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_0"]["b"], dataframe["chain_0"]["b"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["a"], dataframe["chain_1"]["a"]
        )
        np.testing.assert_almost_equal(
            other["chain_1"]["b"], dataframe["chain_1"]["b"]
        )

    def test_unequal_chain_length(self):
        """Test that chains of different lengths keep their lengths when
        transposed and that combining concatenates every chain
        """
        chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 200), np.random.uniform(80, 10, 200)]
        ]
        dataframe = MCMCSamplesDict(self.parameters, chains)
        transpose = dataframe.T
        assert len(transpose["a"]["chain_0"]) == 100
        assert len(transpose["a"]["chain_1"]) == 200
        assert dataframe.number_of_samples == {
            "chain_0": 100, "chain_1": 200
        }
        assert dataframe.minimum_number_of_samples == 100
        assert transpose.number_of_samples == dataframe.number_of_samples
        assert transpose.minimum_number_of_samples == \
            dataframe.minimum_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 300
        my_combine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(ss in my_combine for ss in combined["a"])

    def test_properties(self):
        """Test that the properties of the MCMCSamplesDict class are correct
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        transpose = dataframe.T
        # transposing swaps the chain/parameter nesting without touching data
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["a"], transpose["a"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_0"]["b"], transpose["b"]["chain_0"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["a"], transpose["a"]["chain_1"]
        )
        np.testing.assert_almost_equal(
            dataframe["chain_1"]["b"], transpose["b"]["chain_1"]
        )
        # the average is the same whichever nesting it is computed from
        average = dataframe.average
        transpose_average = transpose.average
        for param in self.parameters:
            np.testing.assert_almost_equal(
                average[param], transpose_average[param]
            )
        assert dataframe.total_number_of_samples == 200
        assert dataframe.total_number_of_samples == \
            transpose.total_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 200
        mycombine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(s in combined["a"] for s in mycombine)
        # transposing twice recovers the original structure and data
        transpose_copy = transpose.T
        assert sorted(list(transpose_copy.keys())) == sorted(list(dataframe.keys()))
        for level1 in dataframe.keys():
            assert sorted(list(transpose_copy[level1].keys())) == sorted(
                list(dataframe[level1].keys())
            )
            for level2 in dataframe[level1].keys():
                np.testing.assert_almost_equal(
                    transpose_copy[level1][level2], dataframe[level1][level2]
                )

    def test_key_data(self):
        """Test that the key data matches the statistics of the combined
        chains
        """
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        key_data = dataframe.key_data
        combined = dataframe.combine
        # iterate the keys directly: the original `for param, in ...`
        # tuple-unpacked each key and only worked by accident for
        # single-character parameter names
        for param in key_data.keys():
            np.testing.assert_almost_equal(
                key_data[param]["mean"], np.mean(combined[param])
            )
            np.testing.assert_almost_equal(
                key_data[param]["median"], np.median(combined[param])
            )

    def test_burnin_removal(self):
        """Test that the different algorithms for removing burn-in samples
        behave as expected
        """
        uniform = np.random.uniform
        parameters = ["a", "b", "cycle"]
        chains = [
            [uniform(10, 0.5, 100), uniform(100, 10, 100), uniform(1, 0.8, 100)],
            [uniform(5, 0.5, 100), uniform(80, 10, 100), uniform(1, 0.8, 100)],
            [uniform(1, 0.8, 100), uniform(90, 10, 100), uniform(1, 0.8, 100)]
        ]
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(algorithm="burnin_by_step_number")
        # samples with a positive cycle number survive this algorithm
        idxs = np.argwhere(chains[0][2] > 0)
        assert len(burnin["chain_0"]["a"]) == len(idxs)
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(10, algorithm="burnin_by_first_n")
        assert len(burnin["chain_0"]["a"]) == 90
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(
            10, algorithm="burnin_by_first_n", step_number=True
        )
        assert len(burnin["chain_0"]["a"]) == len(idxs) - 10

829 

830 

class TestList(object):
    """Tests for the List class
    """
    def test_added(self):
        """Elements added via append/extend/insert/+/+= are tracked
        """
        original = ["a", "b", "c", "d", "e"]
        array = List(original)
        assert len(array.added) == 0
        array.append("f")
        assert len(array.added) > 0
        assert array.added == ["f"]
        array.extend(["g", "h"])
        assert sorted(array.added) == sorted(["f", "g", "h"])
        # the original contents are remembered unchanged
        assert sorted(array.original) == sorted(original)
        array.insert(2, "z")
        assert sorted(array.added) == sorted(["f", "g", "h", "z"])
        assert sorted(array) == sorted(original + ["f", "g", "h", "z"])
        # the + and += operators also record additions
        array = List(original)
        array = array + ["f", "g", "h"]
        assert sorted(array.added) == sorted(["f", "g", "h"])
        array += ["i"]
        assert sorted(array.added) == sorted(["f", "g", "h", "i"])

    def test_removed(self):
        """Elements removed via remove/pop are tracked
        """
        original = ["a", "b", "c", "d", "e"]
        array = List(original)
        assert len(array.removed) == 0
        array.remove("e")
        assert sorted(array) == sorted(["a", "b", "c", "d"])
        assert array.removed == ["e"]
        assert len(array.added) == 0
        # removing a previously-added element updates both records
        array.extend(["f", "g"])
        array.remove("f")
        assert array.removed == ["e", "f"]
        assert array.added == ["g"]
        array.pop(0)
        assert sorted(array.removed) == sorted(["e", "f", "a"])

867 

868 

def test_2DArray():
    """Test the pesummary.utils.array._2DArray class
    """
    data = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    wrapped = _2DArray(data)
    # every wrapped column matches the raw samples it was built from
    for entry, column in zip(wrapped, data):
        np.testing.assert_almost_equal(entry, column)
    for entry, column in zip(wrapped, data):
        # summary statistics agree with the numpy equivalents
        assert entry.standard_deviation == np.std(column)
        assert entry.minimum == np.min(column)
        assert entry.maximum == np.max(column)
        summary = entry.key_data
        assert summary["5th percentile"] == np.percentile(column, 5)
        assert summary["95th percentile"] == np.percentile(column, 95)
        assert summary["median"] == np.median(column)
        assert summary["mean"] == np.mean(column)

    # repeat with likelihood/prior supplied so maxL/maxP are defined
    data = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    likelihood = np.random.uniform(0, 1, 1000)
    prior = np.random.uniform(0, 1, 1000)
    wrapped = _2DArray(data, likelihood=likelihood, prior=prior)
    for entry, column in zip(wrapped, data):
        np.testing.assert_almost_equal(entry, column)
    for entry, column in zip(wrapped, data):
        assert entry.standard_deviation == np.std(column)
        assert entry.minimum == np.min(column)
        assert entry.maximum == np.max(column)
        # maxL/maxP pick the sample with the largest likelihood and the
        # largest likelihood + prior respectively
        assert entry.maxL == entry[np.argmax(likelihood)]
        assert entry.maxP == entry[np.argmax(likelihood + prior)]
        summary = entry.key_data
        assert summary["5th percentile"] == np.percentile(column, 5)
        assert summary["95th percentile"] == np.percentile(column, 95)
        assert summary["median"] == np.median(column)
        assert summary["mean"] == np.mean(column)
        assert summary["maxL"] == entry.maxL
        assert summary["maxP"] == entry.maxP

911 

912 

class TestArray(object):
    """Test the Array class
    """
    def test_properties(self):
        """Check the basic statistical properties of an Array
        """
        draws = np.random.uniform(100, 10, 100)
        array = Array(draws)
        # averages should agree with the numpy equivalents
        assert array.average(type="mean") == np.mean(draws)
        assert array.average(type="median") == np.median(draws)
        assert array.standard_deviation == np.std(draws)
        expected = [np.percentile(array, 5), np.percentile(array, 95)]
        np.testing.assert_almost_equal(
            array.credible_interval(percentile=[5, 95]), expected
        )

    def test_weighted_percentile(self):
        """Check that weighted percentiles match numpy on a repeated sample
        """
        draws = np.random.normal(100, 20, 10000)
        weights = np.array([np.random.randint(100) for _ in range(10000)])
        array = Array(draws, weights=weights)
        # integer weights can be emulated exactly by repeating each sample
        expected = np.percentile(np.repeat(draws, weights), 90)
        result = array.credible_interval(percentile=90)
        np.testing.assert_almost_equal(expected, result, 6)

    def test_key_data(self):
        """Check that the `key_data` summary matches numpy
        """
        draws = np.random.normal(100, 20, 10000)
        summary = Array(draws).key_data
        np.testing.assert_almost_equal(summary["mean"], np.mean(draws))
        np.testing.assert_almost_equal(summary["median"], np.median(draws))
        np.testing.assert_almost_equal(summary["std"], np.std(draws))
        np.testing.assert_almost_equal(
            summary["5th percentile"], np.percentile(draws, 5)
        )
        np.testing.assert_almost_equal(
            summary["95th percentile"], np.percentile(draws, 95)
        )
        # no likelihood/prior supplied, so maxL/maxP are undefined
        assert summary["maxL"] is None
        assert summary["maxP"] is None

950 

951 

class TestTQDM(object):
    """Test the pesummary.utils.tqdm.tqdm class
    """
    def setup_method(self):
        """Create the scratch directory used by these tests
        """
        self._range = range(100)
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_basic_iterator(self):
        """Test that the core functionality of the tqdm class remains
        """
        for value in tqdm(self._range):
            _ = value * value

    def test_interaction_with_logger(self):
        """Test that tqdm interacts nicely with logger
        """
        from pesummary.utils.utils import logger, LOG_FILE

        path = "{}/test.dat".format(tmpdir)
        with open(path, "w") as f:
            for value in tqdm(self._range, logger=logger, file=f):
                _ = value * value

        # the final progress line should have been routed through the
        # PESummary logger at INFO level
        with open(path, "r") as f:
            final = f.readlines()[-1]
        assert "PESummary" in final
        assert "INFO" in final

985 

986 

def test_jensen_shannon_divergence():
    """Test that the `jensen_shannon_divergence` method returns the same
    values as the scipy function
    """
    from scipy.spatial.distance import jensenshannon
    from scipy import stats

    samples = [
        np.random.uniform(5, 4, 100),
        np.random.uniform(5, 4, 100)
    ]
    grid = np.linspace(np.min(samples), np.max(samples), 100)
    kdes = [stats.gaussian_kde(posterior)(grid) for posterior in samples]
    # scipy returns the JS *distance*; square it to obtain the divergence
    expected = jensenshannon(*kdes) ** 2
    result = utils.jensen_shannon_divergence(samples, decimal=9)
    np.testing.assert_almost_equal(expected, result)

    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    # also check that a custom bounded KDE can be passed through
    _ = utils.jensen_shannon_divergence(
        samples, decimal=9, kde=ReflectionBoundedKDE, xlow=4.5, xhigh=5.5
    )

1009 

1010 

def test_make_cache_style_file():
    """Test that the `make_cache_style_file` works as expected

    A small style file is written to the current directory, cached via
    `make_cache_style_file` and the cached copy checked for content. The
    scratch style file is removed afterwards so the test leaves no
    residue in the working directory.
    """
    from pesummary.utils.utils import make_cache_style_file, CACHE_DIR
    sty = os.path.expanduser(
        "{}/style/matplotlib_rcparams.sty".format(CACHE_DIR)
    )
    try:
        with open("test.sty", "w") as f:
            f.writelines(["test : 10"])
        make_cache_style_file("test.sty")
    finally:
        # previously this file was left behind in the cwd after the test
        if os.path.isfile("test.sty"):
            os.remove("test.sty")
    assert os.path.isfile(sty)
    with open(sty, "r") as f:
        lines = f.readlines()
    assert len(lines) == 1
    assert lines[0] == "test : 10"

1024 

1025 

def test_logger():
    """Check that the PESummary logger emits records with the expected
    logger name and levels
    """
    with LogCapture() as capture:
        utils.logger.propagate = True
        utils.logger.info("info")
        utils.logger.warning("warning")
        capture.check(
            ("PESummary", "INFO", "info"),
            ("PESummary", "WARNING", "warning"),
        )

1033 

1034 

def test_string_match():
    """function to test the string_match function
    """
    param = "mass_1"
    # a string always matches itself, but not a different string
    assert utils.string_match(param, param)
    assert not utils.string_match(param, "mass_2")
    # single wildcard characters
    assert utils.string_match(param, "{}*".format(param[0]))
    assert not utils.string_match(param, "z+")
    assert utils.string_match(param, "{}?".format(param[:-1]))
    # multiple wildcard characters
    assert utils.string_match(param, "*{}*".format(param[1]))
    assert utils.string_match(param, "*{}?".format(param[3:-1]))
    assert utils.string_match(param, "?{}?".format(param[1:-1]))
    assert not utils.string_match(param, "?{}?".format(param[:-1]))
    # anchored 'starts with' pattern
    assert utils.string_match(param, "^{}".format(param[:3]))
    # negative lookahead: 'does not start with'
    assert not utils.string_match(param, "^(?!{}).+".format(param[:3]))

1056 

1057 

def test_map_parameter_names():
    """function to test the map_parameter_names function
    """
    original = {
        "a": np.random.uniform(1, 100, size=1000),
        "b": np.random.uniform(1, 100, size=1000),
        "c": np.random.uniform(1, 100, size=1000)
    }
    mapping = {"a": "z", "b": "y", "c": "x"}
    working = original.copy()
    remapped = utils.map_parameter_names(working, mapping)
    # the remapped dictionary only contains the new keys
    assert sorted(list(remapped.keys())) == ["x", "y", "z"]
    # the samples themselves are untouched by the renaming
    for before, after in mapping.items():
        np.testing.assert_almost_equal(original[before], remapped[after])
    # the input dictionary is not modified in place
    for key, value in original.items():
        np.testing.assert_almost_equal(value, working[key])

1077 

1078 

def test_list_match():
    """function to test the list_match function
    """
    params = [
        "mass_1", "mass_2", "chirp_mass", "total_mass", "mass_ratio", "a_1",
        "luminosity_distance", "x"
    ]
    # all mass parameters are returned by a '*mass*' pattern
    expected = sorted(p for p in params if "mass" in p)
    assert sorted(utils.list_match(params, "*mass*")) == expected
    # only parameters containing an underscore are returned
    expected = sorted(p for p in params if "_" in p)
    assert sorted(utils.list_match(params, "*_*")) == expected
    # a pattern matching nothing returns an empty list
    assert not len(utils.list_match(params, "z+"))
    # only parameters which do not start with 'mass'
    expected = sorted(p for p in params if p[:4] != "mass")
    assert sorted(utils.list_match(params, "^(?!mass).+")) == expected
    # only parameters which start with 'l'
    expected = sorted(p for p in params if p[0] == "l")
    assert sorted(utils.list_match(params, "^l")) == expected
    # return_false=True inverts the selection
    expected = sorted(p for p in params if p[0] != "l")
    assert sorted(utils.list_match(params, "^l", return_false=True)) == expected
    # multiple substrings are combined
    assert sorted(utils.list_match(params, ["^m", "*2"])) == sorted(["mass_2"])

1110 

1111 

class TestDict(object):
    """Class to test the NestedDict object
    """
    def test_initiate(self):
        """Initiate the Dict class
        """
        from pesummary.gw.file.psd import PSD

        def _confirm(obj):
            # both construction routes must give identical contents
            assert list(obj.keys()) == ["a"]
            np.testing.assert_almost_equal(obj["a"], [[10, 20], [10, 20]])
            assert isinstance(obj["a"], PSD)
            np.testing.assert_almost_equal(obj["a"].value, [10, 10])
            np.testing.assert_almost_equal(obj["a"].value2, [20, 20])

        # initiate from a dictionary
        _confirm(
            Dict(
                {"a": [[10, 20], [10, 20]]}, value_class=PSD,
                value_columns=["value", "value2"]
            )
        )
        # initiate from separate parameter/data lists
        _confirm(
            Dict(
                ["a"], [[[10, 20], [10, 20]]], value_class=PSD,
                value_columns=["value", "value2"]
            )
        )

    def test_specify_columns(self):
        """Test that x[["a", "b"]] works as expected
        """
        full = Dict(
            {"a": [10], "b": [20], "c": [30], "d": [40]}, value_class=list
        )
        subset = full[["a", "b"]]
        # only the requested keys survive, with their values intact
        assert sorted(list(subset.keys())) == ["a", "b"]
        for key in subset.keys():
            assert subset[key] == full[key]
        # asking for unknown keys must raise
        with pytest.raises(Exception):
            _ = full[["e", "f"]]

1150 

1151 

class TestProbabilityDict(object):
    """Test the ProbabilityDict class
    """
    def setup_method(self):
        """Setup the ProbabilityDict class
        """
        # discrete PDFs: [grid points, probabilities] for each parameter
        self.mydict = {
            "a": [[1,2,3,4], [0.1, 0.2, 0.3, 0.4]],
            "b": [[1,3,5,7], [0.1, 0.3, 0.2, 0.4]]
        }

    def test_initiate(self):
        """Test the different ways to initiate this class
        """
        probs = ProbabilityDict(self.mydict)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)
        # initiating from separate parameter/data lists gives the same result
        parameters = list(self.mydict.keys())
        data = [self.mydict[key] for key in parameters]
        probs = ProbabilityDict(parameters, data)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)

    def test_rvs(self):
        """Test the .rvs method
        """
        probs = ProbabilityDict(self.mydict)
        # draw 10 samples for all parameters
        samples = probs.rvs(size=10, interpolate=False)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        # NOTE: was `len(samples[param] == 10)` which is the length of a
        # boolean comparison array and therefore always truthy
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())
        # without interpolation every draw must lie on the discrete grid
        # (debug print loop removed)
        for param in self.mydict.keys():
            assert all(p in self.mydict[param][0] for p in samples[param])
        # draw 10 samples for only parameter 'a'
        samples = probs.rvs(size=10, parameters=["a"])
        assert list(samples.keys()) == ["a"]
        assert len(samples["a"]) == 10
        # interpolate first and then draw samples
        samples = probs.rvs(size=10, interpolate=True)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())

    def test_plotting(self):
        """Test the .plot method
        """
        probs = ProbabilityDict(self.mydict)
        # Test that a histogram plot is generated when required
        fig = probs.plot("a", type="hist")
        fig = probs.plot("a", type="marginalized_posterior")
        # Test that a PDF is generated when required
        fig = probs.plot("a", type="pdf")

1212 

class TestProbabilityDict2D(object):
    """Class to test the ProbabilityDict2D class
    """
    def setup_method(self):
        """Setup the TestProbabilityDict2D class
        """
        # normalised 10x10 grid of 2d probabilities for the pair 'x_y'
        weight = np.random.uniform(0, 1, 100).reshape(10, 10)
        self.mydict = {
            "x_y": [
                np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10),
                weight / np.sum(weight)
            ]
        }

    def test_plotting(self):
        """Test the .plot method
        """
        mydict = ProbabilityDict2D(self.mydict)
        # a 2d KDE plot is generated when specified
        fig = mydict.plot("x_y", type="2d_kde")
        # triangle and reverse-triangle plots are generated when specified
        fig, _, _, _ = mydict.plot("x_y", type="triangle")
        fig, _, _, _ = mydict.plot("x_y", type="reverse_triangle")
        # an unknown plot type raises an Exception
        with pytest.raises(Exception):
            fig = mydict.plot("x_y", type="does_not_exist")

1241 

1242 

class TestDiscretePDF2D(object):
    """Class to test the DiscretePDF2D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2D class
        """
        self.x = [1,2]
        self.y = [1,2]
        self.probs = [[0.25, 0.25], [0.25, 0.25]]

    def test_initiate(self):
        """Check that only a 2d probability array is accepted
        """
        # a flat array is not a valid 2d probability grid
        with pytest.raises(ValueError):
            _ = DiscretePDF2D(self.x, self.y, [1,2,3,4,5,6])
        # a proper 2d grid initiates without error
        _ = DiscretePDF2D(self.x, self.y, self.probs)

    def test_marginalize(self):
        """Check that marginalizing returns a DiscretePDF2Dplus1D which
        keeps the original joint probabilities
        """
        pdf = DiscretePDF2D(self.x, self.y, self.probs)
        marginalized = pdf.marginalize()
        assert isinstance(marginalized, DiscretePDF2Dplus1D)
        np.testing.assert_almost_equal(pdf.probs, marginalized.probs_xy.probs)

1267 

1268 

class TestDiscretePDF2Dplus1D(object):
    """Class to test the DiscretePDF2Dplus1D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2Dplus1D class
        """
        self.x = [1,2]
        self.y = [1,2]
        # [p(x), p(y), p(x, y)]
        self.probs = [[0.5, 0.5], [0.5, 0.5], [[0.25, 0.25], [0.25, 0.25]]]

    def test_initiate_raise(self):
        """Test that the class raises ValueErrors when required
        """
        # exactly 3 probability arrays must be supplied
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(self.x, self.y, self.probs[:-1])
        # one of the supplied arrays must be 2d
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(
                self.x, self.y, self.probs[:-1] + self.probs[:1]
            )
        # two of the supplied arrays must be 1d
        with pytest.raises(ValueError):
            _ = DiscretePDF2Dplus1D(
                self.x, self.y, self.probs[1:] + self.probs[-1:]
            )

    def test_initiate(self):
        """Test that we can initiate the class
        """
        pdf = DiscretePDF2Dplus1D(self.x, self.y, self.probs)
        # the 1d marginals are DiscretePDF, the joint is DiscretePDF2D
        assert isinstance(pdf.probs[0], DiscretePDF)
        assert isinstance(pdf.probs[2], DiscretePDF2D)
        assert isinstance(pdf.probs[1], DiscretePDF)

1303 

1304 

class TestKDEList(object):
    """Test the KDEList class
    """
    def setup_method(self):
        """Setup the KDEList class
        """
        from scipy.stats import gaussian_kde
        means = np.random.uniform(0, 5, size=10)
        stds = np.random.uniform(0, 2, size=10)
        # 10 gaussian sample sets with random means and widths
        self.inputs = np.array(
            [
                np.random.normal(mean, std, size=10000) for mean, std
                in zip(means, stds)
            ]
        )
        self.pts = np.linspace(-10, 10, 100)
        # reference KDEs evaluated directly with scipy
        self.true_kdes = [gaussian_kde(_)(self.pts) for _ in self.inputs]

    def test_call(self):
        """Test the KDEList call method
        """
        from pesummary.utils.kde_list import KDEList
        from scipy.stats import gaussian_kde
        new = KDEList(self.inputs, kde=gaussian_kde, pts=self.pts)
        # test on a single CPU
        new_single = new(multi_process=1)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_single[num])
        # test on multiple CPUs
        new_multiple = new(multi_process=2)
        for num in range(len(self.inputs)):
            # BUGFIX: this loop previously re-checked `new_single`, so the
            # multi-process output was never actually verified
            np.testing.assert_almost_equal(
                self.true_kdes[num], new_multiple[num]
            )
        # test the `idx` kwarg
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new(idx=num))
        # split the calculation into 2 chunks and 2 CPUs
        new_chunk1 = new(
            idx=np.arange(0, len(self.inputs) // 2, 1), multi_process=2
        ).tolist()
        new_chunk2 = new(
            idx=np.arange(len(self.inputs) // 2, len(self.inputs), 1),
            multi_process=2
        ).tolist()
        new_combined = new_chunk1 + new_chunk2
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(
                self.true_kdes[num], new_combined[num]
            )
        # assert that different KDEs are produced for different pts
        with pytest.raises(AssertionError):
            new_diff = new(pts=np.linspace(-5, 5, 100), multi_process=1)
            for num in range(len(self.inputs)):
                np.testing.assert_almost_equal(
                    self.true_kdes[num], new_diff[num]
                )

1355 

1356 

def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    # `make_dir` and `CACHE_DIR` are not defined at module level in this
    # test file, so the original body raised a NameError when called;
    # import them explicitly from the package instead
    from pesummary.utils.utils import make_dir, CACHE_DIR
    make_dir(CACHE_DIR)
    shutil.copyfile(
        style_file, os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")
    )

1370 

1371 

def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    # `CACHE_DIR` is not defined at module level in this test file, so the
    # original body raised a NameError when called; import it explicitly
    from pesummary.utils.utils import CACHE_DIR
    return os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")