Coverage for pesummary/tests/utils_test.py: 98.9%

710 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5import h5py 

6import numpy as np 

7import copy 

8 

9import pesummary 

10from pesummary.io import write 

11import pesummary.cli as cli 

12from pesummary.utils import utils 

13from pesummary.utils.tqdm import tqdm 

14from pesummary.utils.dict import Dict 

15from pesummary.utils.list import List 

16from pesummary.utils.pdf import DiscretePDF, DiscretePDF2D, DiscretePDF2Dplus1D 

17from pesummary.utils.array import _2DArray 

18from pesummary.utils.samples_dict import ( 

19 Array, SamplesDict, MCMCSamplesDict, MultiAnalysisSamplesDict 

20) 

21from pesummary.utils.probability_dict import ProbabilityDict, ProbabilityDict2D 

22from pesummary._version_helper import GitInformation, PackageInformation 

23from pesummary._version_helper import get_version_information 

24 

25import pytest 

26from testfixtures import LogCapture 

27import tempfile 

28 

29tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".").name 

30__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

31 

32DEFAULT_DIRECTORY = os.getenv("CI_PROJECT_DIR", os.getcwd()) 

33 

34 

class TestGitInformation(object):
    """Tests for the GitInformation helper class"""

    def setup_method(self):
        """Create a GitInformation object for each test"""
        self.git = GitInformation(directory=DEFAULT_DIRECTORY)

    def test_last_commit_info(self):
        """Check that last_commit_info returns a pair of strings"""
        commit_info = self.git.last_commit_info
        assert len(commit_info) == 2
        assert all(isinstance(entry, str) for entry in commit_info)

    def test_last_version(self):
        """Check that last_version is a string"""
        assert isinstance(self.git.last_version, str)

    def test_status(self):
        """Check that status is a string"""
        assert isinstance(self.git.status, str)

    def test_builder(self):
        """Check that builder is a string"""
        assert isinstance(self.git.builder, str)

    def test_build_date(self):
        """Check that build_date is a string"""
        assert isinstance(self.git.build_date, str)

69 

70 

class TestPackageInformation(object):
    """Tests for the PackageInformation helper class"""

    def setup_method(self):
        """Create a PackageInformation object for each test"""
        self.package = PackageInformation()

    def test_package_info(self):
        """Check the layout of the package_info property"""
        info = self.package.package_info
        assert isinstance(info, list)
        first = info[0]
        assert "name" in first
        assert "version" in first
        # a channel entry is only present for conda-managed packages
        if "build_string" in first:
            assert "channel" in first

89 

90 

class TestUtils(object):
    """Tests for pesummary.utils.utils"""

    def setup_method(self):
        """Create a scratch directory for each test"""
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files created from this class"""
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_check_condition(self):
        """Test that check_condition raises with the supplied message"""
        with pytest.raises(Exception) as info:
            utils.check_condition(True, "error")
        assert str(info.value) == "error"

    def test_rename_group_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method on a group"""
        path = "{}/rename_group.h5".format(tmpdir)
        # use context managers so the file handle is always closed
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, group=["group", "replaced"]
        )
        # open read-only with an explicit mode; relying on the implicit
        # default mode is deprecated in h5py
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["replaced"]
            assert list(f["replaced"].keys()) == ["example"]
            assert len(f["replaced/example"]) == 1
            assert f["replaced/example"][0] == 10

    def test_rename_dataset_in_hf5_file(self):
        """Test the rename_group_or_dataset_in_hf5_file method on a dataset"""
        path = "{}/rename_dataset.h5".format(tmpdir)
        with h5py.File(path, "w") as f:
            group = f.create_group("group")
            group.create_dataset("example", data=np.array([10]))
        utils.rename_group_or_dataset_in_hf5_file(
            path, dataset=["group/example", "group/replaced"]
        )
        with h5py.File(path, "r") as f:
            assert list(f.keys()) == ["group"]
            assert list(f["group"].keys()) == ["replaced"]
            assert len(f["group/replaced"]) == 1
            assert f["group/replaced"][0] == 10

    def test_rename_unknown_hf5_file(self):
        """Test that renaming inside a missing file raises an Exception"""
        with pytest.raises(Exception) as info:
            utils.rename_group_or_dataset_in_hf5_file(
                "{}/unknown.h5".format(tmpdir),
                group=["None", "replaced"]
            )
        assert "does not exist" in str(info.value)

    def test_directory_creation(self):
        """Test that make_dir creates a new directory"""
        directory = '{}/test_dir'.format(tmpdir)
        # assert truthiness directly rather than comparing `== False`/`== True`
        assert not os.path.isdir(directory)
        utils.make_dir(directory)
        assert os.path.isdir(directory)

    def test_url_guess(self):
        """Test that guess_url maps known hosts to the expected URLs"""
        host = ["raven", "cit", "ligo-wa", "uwm", "phy.syr.edu", "vulcan",
                "atlas", "iucaa", "alice"]
        expected = ["https://geo2.arcca.cf.ac.uk/~albert.einstein/test",
                    "https://ldas-jobs.ligo.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.ligo-wa.caltech.edu/~albert.einstein/test",
                    "https://ldas-jobs.phys.uwm.edu/~albert.einstein/test",
                    "https://sugar-jobs.phy.syr.edu/~albert.einstein/test",
                    "https://galahad.aei.mpg.de/~albert.einstein/test",
                    "https://atlas1.atlas.aei.uni-hannover.de/~albert.einstein/test",
                    "https://ldas-jobs.gw.iucaa.in/~albert.einstein/test",
                    "https://dumpty.alice.icts.res.in/~albert.einstein/test"]
        user = "albert.einstein"
        webdir = '/home/albert.einstein/public_html/test'
        for hostname, url in zip(host, expected):
            assert utils.guess_url(webdir, hostname, user) == url

    def test_make_dir(self):
        """Test that make_dir is a no-op on an existing, populated directory"""
        target = os.path.join(tmpdir, "test")
        assert not os.path.isdir(target)
        utils.make_dir(target)
        assert os.path.isdir(target)
        with open(os.path.join(target, "test.dat"), "w") as f:
            f.writelines(["test"])
        utils.make_dir(target)
        # the existing file must survive a second make_dir call
        assert os.path.isfile(os.path.join(target, "test.dat"))

    def test_resample_posterior_distribution(self):
        """Test the resample_posterior_distribution method"""
        data = np.random.normal(1, 0.1, 1000)
        resampled = utils.resample_posterior_distribution([data], 500)
        assert len(resampled) == 500
        assert np.round(np.mean(resampled), 1) == 1.
        assert np.round(np.std(resampled), 1) == 0.1

    def test_gw_results_file(self):
        """Test the gw_results_file method"""
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        assert utils.gw_results_file(opts)
        opts = namespace({"webdir": tmpdir})
        assert not utils.gw_results_file(opts)

    def test_functions(self):
        """Test the functions method"""
        from .base import namespace

        opts = namespace({"gw": True, "psd": True})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.gw.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.gw.file.meta_file.GWMetaFile

        opts = namespace({"webdir": tmpdir})
        funcs = utils.functions(opts)
        assert funcs["input"] == pesummary.core.cli.inputs.WebpagePlusPlottingPlusMetaFileInput
        assert funcs["MetaFile"] == pesummary.core.file.meta_file.MetaFile

    def test_get_version_information(self):
        """Test the get_version_information method"""
        assert isinstance(get_version_information(), str)

226 

227 

class TestGelmanRubin(object):
    """Test the Gelman Rubin calculation"""

    def test_same_as_lalinference(self):
        """Check that pesummary's Gelman-Rubin statistic agrees with the
        implementation in LALInference
        """
        from lalinference.bayespputils import Posterior
        from pesummary.utils.utils import gelman_rubin

        header = ["a", "b", "logL", "chain"]
        for _ in range(100):
            samples = np.array([
                np.random.uniform(np.random.random(), 0.1, 3).tolist()
                + [np.random.randint(1, 3)] for _ in range(10)
            ])
            posterior = Posterior([header, np.array(samples)])
            expected = posterior.gelman_rubin("a")
            chains = np.unique(posterior["chain"].samples)
            chain_column = posterior.names.index("chain")
            param_column = posterior.names.index("a")
            data, _ = posterior.samples()
            per_chain = [
                data[data[:, chain_column] == chain, param_column]
                for chain in chains
            ]
            np.testing.assert_almost_equal(
                gelman_rubin(per_chain, decimal=10), expected, 7
            )

    def test_same_samples(self):
        """Test that two identical chains (perfect convergence) give a
        Gelman-Rubin statistic of exactly 1
        """
        from pesummary.core.plots.plot import gelman_rubin

        samples = np.random.uniform(1, 0.5, 10)
        assert gelman_rubin([samples, samples]) == 1

269 

270 

class TestSamplesDict(object):
    """Test the SamplesDict class"""

    def setup_method(self):
        """Create input samples and write them to a temporary dat file"""
        self.parameters = ["a", "b"]
        self.samples = [
            np.random.uniform(10, 0.5, 100), np.random.uniform(200, 10, 100)
        ]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        write(
            self.parameters, np.array(self.samples).T, outdir=tmpdir,
            filename="test.dat", file_format="dat"
        )

    def teardown_method(self):
        """Remove the files created from this class"""
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Check that every way of building a SamplesDict is equivalent"""
        from_lists = SamplesDict(self.parameters, self.samples)
        from_mapping = SamplesDict(dict(zip(self.parameters, self.samples)))
        assert from_lists.parameters == from_mapping.parameters
        assert sorted(from_lists.parameters) == sorted(self.parameters)
        np.testing.assert_almost_equal(from_lists.samples, from_mapping.samples)
        assert sorted(list(from_lists.keys())) == sorted(list(from_mapping.keys()))
        np.testing.assert_almost_equal(from_lists.samples, self.samples)
        from_file = SamplesDict.from_file(
            "{}/test.dat".format(tmpdir), add_zero_likelihood=False
        )
        np.testing.assert_almost_equal(from_file.samples, self.samples)

    def test_properties(self):
        """Check the statistical properties of the SamplesDict class"""
        import pandas as pd

        dataset = SamplesDict(self.parameters, self.samples)
        assert sorted(dataset.minimum.keys()) == sorted(self.parameters)
        for idx, param in enumerate(self.parameters):
            assert dataset.minimum[param] == np.min(self.samples[idx])
            assert dataset.median[param] == np.median(self.samples[idx])
            assert dataset.mean[param] == np.mean(self.samples[idx])
        assert dataset.number_of_samples == len(self.samples[1])
        assert len(dataset.downsample(10)["a"]) == 10
        # rebuild so the downsample above cannot affect the discard checks
        dataset = SamplesDict(self.parameters, self.samples)
        assert len(dataset.discard_samples(10)["a"]) == len(self.samples[0]) - 10
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["a"], self.samples[0][10:]
        )
        np.testing.assert_almost_equal(
            dataset.discard_samples(10)["b"], self.samples[1][10:]
        )
        assert isinstance(dataset.to_pandas(), pd.core.frame.DataFrame)
        dataset.pop("a")
        assert list(dataset.keys()) == ["b"]

    def test_core_plots(self):
        """Check that the core plotting methods return matplotlib figures"""
        import matplotlib.figure

        dataset = SamplesDict(self.parameters, self.samples)
        for plot_type in ["hist", "marginalized_posterior"]:
            fig = dataset.plot(self.parameters[0], type=plot_type)
            assert isinstance(fig, matplotlib.figure.Figure)

    def test_standardize_parameter_names(self):
        """Test the standardize_parameter_names method"""
        dictionary = {"mass1": [1,2,3,4], "mass2": [4,3,2,1], "zz": [1,1,1,1]}
        mydict = SamplesDict(dictionary.copy())
        standard_dict = mydict.standardize_parameter_names()
        # standard parameter names appear in the new dictionary
        assert sorted(list(standard_dict.keys())) == ["mass_1", "mass_2", "zz"]
        # the dictionary items remain the same
        default_mapping = {"mass1": "mass_1", "mass2": "mass_2", "zz": "zz"}
        for key, item in dictionary.items():
            np.testing.assert_almost_equal(
                item, standard_dict[default_mapping[key]]
            )
        # the old dictionary is left unchanged
        assert sorted(list(mydict.keys())) == ["mass1", "mass2", "zz"]
        for key, item in mydict.items():
            np.testing.assert_almost_equal(item, dictionary[key])
        # a custom mapping is honoured
        custom = {"mass1": "custom_m1", "mass2": "custom_m2", "zz": "custom_zz"}
        new_dict = mydict.standardize_parameter_names(mapping=custom)
        assert sorted(list(new_dict.keys())) == [
            "custom_m1", "custom_m2", "custom_zz"
        ]
        for old, new in custom.items():
            np.testing.assert_almost_equal(dictionary[old], new_dict[new])

    def test_waveforms(self):
        """Test the waveform generation"""
        from pesummary.core.fetch import download_dir
        try:
            from pycbc.waveform import get_fd_waveform, get_td_waveform
        except (ValueError, ImportError):
            # pycbc is optional; nothing to test without it
            return

        downloaded_file = os.path.join(
            download_dir, "GW190814_posterior_samples.h5"
        )
        if not os.path.isfile(downloaded_file):
            from pesummary.gw.fetch import fetch_open_samples
            f = fetch_open_samples(
                "GW190814", read_file=True, outdir=download_dir, unpack=True,
                path="GW190814.h5", catalog="GWTC-2"
            )
        else:
            from pesummary.io import read
            f = read(downloaded_file)
        samples = f.samples_dict["C01:IMRPhenomPv3HM"]
        ind = 0
        # source parameters shared by both pycbc reference calls
        source_kwargs = dict(
            mass1=samples["mass_1"][ind], mass2=samples["mass_2"][ind],
            spin1x=samples["spin_1x"][ind], spin1y=samples["spin_1y"][ind],
            spin1z=samples["spin_1z"][ind], spin2x=samples["spin_2x"][ind],
            spin2y=samples["spin_2y"][ind], spin2z=samples["spin_2z"][ind],
            inclination=samples["iota"][ind],
            distance=samples["luminosity_distance"][ind],
            coa_phase=samples["phase"][ind], f_lower=20., f_ref=20.
        )
        data = samples.fd_waveform("IMRPhenomPv3HM", 1./256, 20., 1024., ind=ind)
        hp_pycbc, hc_pycbc = get_fd_waveform(
            approximant="IMRPhenomPv3HM", f_final=1024., delta_f=1./256,
            **source_kwargs
        )
        np.testing.assert_almost_equal(
            data["h_plus"].frequencies.value, hp_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_cross"].frequencies.value, hc_pycbc.sample_frequencies
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )
        data = samples.td_waveform("SEOBNRv4PHM", 1./4096, 20., ind=ind)
        hp_pycbc, hc_pycbc = get_td_waveform(
            approximant="SEOBNRv4PHM", delta_t=1./4096, **source_kwargs
        )
        np.testing.assert_almost_equal(
            data["h_plus"].times.value, hp_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_cross"].times.value, hc_pycbc.sample_times
        )
        np.testing.assert_almost_equal(
            data["h_plus"].value * 10**25, hp_pycbc._data * 10**25
        )
        np.testing.assert_almost_equal(
            data["h_cross"].value * 10**25, hc_pycbc._data * 10**25
        )

449 

450 

class TestMultiAnalysisSamplesDict(object):
    """Test the MultiAnalysisSamplesDict class"""

    def setup_method(self):
        """Create samples for two analyses and write them to dat files"""
        self.parameters = ["a", "b"]
        self.samples = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)],
        ]
        self.labels = ["one", "two"]
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)
        for num, analysis_samples in enumerate(self.samples):
            write(
                self.parameters, np.array(analysis_samples).T, outdir=tmpdir,
                filename="test_{}.dat".format(num + 1), file_format="dat"
            )

    def teardown_method(self):
        """Remove the files created from this class"""
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_initalize(self):
        """Test the different ways to initalize the class"""
        dataframe = MultiAnalysisSamplesDict(
            self.parameters, self.samples, labels=["one", "two"]
        )
        assert sorted(list(dataframe.keys())) == sorted(self.labels)
        for num, label in enumerate(self.labels):
            assert sorted(list(dataframe[label])) == sorted(["a", "b"])
            for idx, param in enumerate(self.parameters):
                np.testing.assert_almost_equal(
                    dataframe[label][param], self.samples[num][idx]
                )
        _other = MCMCSamplesDict({
            label: dict(zip(self.parameters, self.samples[num]))
            for num, label in enumerate(self.labels)
        })
        class_method = MultiAnalysisSamplesDict.from_files(
            {
                'one': "{}/test_1.dat".format(tmpdir),
                'two': "{}/test_2.dat".format(tmpdir)
            }, add_zero_likelihood=False
        )
        for other in [_other, class_method]:
            assert sorted(other.keys()) == sorted(dataframe.keys())
            assert sorted(other["one"].keys()) == sorted(
                dataframe["one"].keys()
            )
            for label in self.labels:
                for param in self.parameters:
                    np.testing.assert_almost_equal(
                        other[label][param], dataframe[label][param]
                    )

    def test_different_samples_for_different_analyses(self):
        """Test that nothing breaks when analyses have different parameters"""
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        for label, expected in data.items():
            assert sorted(dataframe[label].keys()) == sorted(expected.keys())
            for param, samples in expected.items():
                np.testing.assert_almost_equal(
                    dataframe[label][param], samples
                )
        # transposing is ill-defined when parameters differ between analyses
        with pytest.raises(ValueError):
            dataframe.T

    def test_adding_to_table(self):
        """Test that a new analysis can be added after construction"""
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(10, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        assert "three" not in dataframe.parameters.keys()
        new_data = {"a": np.random.uniform(10, 0.5, 100)}
        dataframe["three"] = new_data
        np.testing.assert_almost_equal(dataframe["three"]["a"], new_data["a"])
        assert dataframe.parameters["three"] == ["a"]
        assert "three" in dataframe.number_of_samples.keys()

    def test_combine(self):
        """Test that the .combine method is working as expected"""
        data = {
            "one": {
                "a": np.random.uniform(10, 0.5, 100),
                "b": np.random.uniform(5, 0.5, 100)
            }, "two": {
                "a": np.random.uniform(100, 0.5, 100),
                "b": np.random.uniform(50, 0.5, 100)
            }
        }
        dataframe = MultiAnalysisSamplesDict(data)
        # with weights 0 and 1, only the second analysis contributes
        combined = dataframe.combine(weights={"one": 0., "two": 1.})
        assert "a" in combined.keys()
        assert "b" in combined.keys()
        assert all(ss in data["two"]["a"] for ss in combined["a"])
        assert all(ss in data["two"]["b"] for ss in combined["b"])
        # with equal weights, the first half of samples comes from "one"
        # and the second half from "two"
        combined = dataframe.combine(labels=["one", "two"], weights=[0.5, 0.5])
        nsamples = len(combined["a"])
        half = int(nsamples / 2)
        assert all(ss in data["one"]["a"] for ss in combined["a"][:half])
        assert all(ss in data["two"]["a"] for ss in combined["a"][half:])
        assert all(ss in data["one"]["b"] for ss in combined["b"][:half])
        assert all(ss in data["two"]["b"] for ss in combined["b"][half:])
        # samples keep their pairing across parameters
        for n in np.random.choice(half, size=10, replace=False):
            ind = np.argwhere(data["one"]["a"] == combined["a"][n])
            assert data["one"]["b"][ind] == combined["b"][n]
        # when use_all is provided, every sample is included
        combined = dataframe.combine(use_all=True)
        assert len(combined["a"]) == len(data["one"]["a"]) + len(data["two"]["a"])
        # shuffling keeps pairs together and introduces no duplicates
        combined = dataframe.combine(
            labels=["one", "two"], weights=[0.5, 0.5], shuffle=True
        )
        for n in np.random.choice(half, size=10, replace=False):
            if combined["a"][n] in data["one"]["a"]:
                ind = np.argwhere(data["one"]["a"] == combined["a"][n])
                assert data["one"]["b"][ind] == combined["b"][n]
            else:
                ind = np.argwhere(data["two"]["a"] == combined["a"][n])
                assert data["two"]["b"][ind] == combined["b"][n]
        assert len(set(combined["a"])) == len(combined["a"])
        assert len(set(combined["b"])) == len(combined["b"])

622 

623 

class TestMCMCSamplesDict(object):
    """Test the MCMCSamplesDict class"""

    def setup_method(self):
        """Create two chains of samples for the tests"""
        self.parameters = ["a", "b"]
        self.chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 100), np.random.uniform(80, 10, 100)]
        ]

    def test_initalize(self):
        """Test the different ways to initalize the class"""
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        assert sorted(list(dataframe.keys())) == sorted(
            ["chain_{}".format(num) for num in range(len(self.chains))]
        )
        for num in range(len(self.chains)):
            chain = "chain_{}".format(num)
            assert sorted(list(dataframe[chain].keys())) == sorted(["a", "b"])
            for idx, param in enumerate(self.parameters):
                np.testing.assert_almost_equal(
                    dataframe[chain][param], self.chains[num][idx]
                )
        other = MCMCSamplesDict({
            "chain_{}".format(num): {
                param: self.chains[num][idx] for idx, param in enumerate(
                    self.parameters
                )
            } for num in range(len(self.chains))
        })
        assert sorted(other.keys()) == sorted(dataframe.keys())
        assert sorted(other["chain_0"].keys()) == sorted(
            dataframe["chain_0"].keys()
        )
        for num in range(len(self.chains)):
            chain = "chain_{}".format(num)
            for param in self.parameters:
                np.testing.assert_almost_equal(
                    other[chain][param], dataframe[chain][param]
                )

    def test_unequal_chain_length(self):
        """Test that when inverted, the chains keep their unequal chain
        length
        """
        chains = [
            [np.random.uniform(10, 0.5, 100), np.random.uniform(100, 10, 100)],
            [np.random.uniform(5, 0.5, 200), np.random.uniform(80, 10, 200)]
        ]
        dataframe = MCMCSamplesDict(self.parameters, chains)
        transpose = dataframe.T
        assert len(transpose["a"]["chain_0"]) == 100
        assert len(transpose["a"]["chain_1"]) == 200
        assert dataframe.number_of_samples == {
            "chain_0": 100, "chain_1": 200
        }
        assert dataframe.minimum_number_of_samples == 100
        assert transpose.number_of_samples == dataframe.number_of_samples
        assert transpose.minimum_number_of_samples == \
            dataframe.minimum_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 300
        my_combine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(ss in my_combine for ss in combined["a"])

    def test_properties(self):
        """Test that the properties of the MCMCSamplesDict class are correct"""
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        transpose = dataframe.T
        for chain in ["chain_0", "chain_1"]:
            for param in self.parameters:
                np.testing.assert_almost_equal(
                    dataframe[chain][param], transpose[param][chain]
                )
        average = dataframe.average
        transpose_average = transpose.average
        for param in self.parameters:
            np.testing.assert_almost_equal(
                average[param], transpose_average[param]
            )
        assert dataframe.total_number_of_samples == 200
        assert dataframe.total_number_of_samples == \
            transpose.total_number_of_samples
        combined = dataframe.combine
        assert combined.number_of_samples == 200
        mycombine = np.concatenate(
            [dataframe["chain_0"]["a"], dataframe["chain_1"]["a"]]
        )
        assert all(s in combined["a"] for s in mycombine)
        # transposing twice must give back the original structure
        transpose_copy = transpose.T
        assert sorted(list(transpose_copy.keys())) == sorted(list(dataframe.keys()))
        for level1 in dataframe.keys():
            assert sorted(list(transpose_copy[level1].keys())) == sorted(
                list(dataframe[level1].keys())
            )
            for level2 in dataframe[level1].keys():
                np.testing.assert_almost_equal(
                    transpose_copy[level1][level2], dataframe[level1][level2]
                )

    def test_key_data(self):
        """Test that the key data matches statistics of the combined chains"""
        dataframe = MCMCSamplesDict(self.parameters, self.chains)
        key_data = dataframe.key_data
        combined = dataframe.combine
        # NOTE: this loop previously read `for param, in key_data.keys():`
        # which unpacks each key as a 1-tuple and only works by accident
        # for single-character parameter names
        for param in key_data.keys():
            np.testing.assert_almost_equal(
                key_data[param]["mean"], np.mean(combined[param])
            )
            np.testing.assert_almost_equal(
                key_data[param]["median"], np.median(combined[param])
            )

    def test_burnin_removal(self):
        """Test the different algorithms for removing burnin samples"""
        uniform = np.random.uniform
        parameters = ["a", "b", "cycle"]
        chains = [
            [uniform(10, 0.5, 100), uniform(100, 10, 100), uniform(1, 0.8, 100)],
            [uniform(5, 0.5, 100), uniform(80, 10, 100), uniform(1, 0.8, 100)],
            [uniform(1, 0.8, 100), uniform(90, 10, 100), uniform(1, 0.8, 100)]
        ]
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(algorithm="burnin_by_step_number")
        idxs = np.argwhere(chains[0][2] > 0)
        assert len(burnin["chain_0"]["a"]) == len(idxs)
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(10, algorithm="burnin_by_first_n")
        assert len(burnin["chain_0"]["a"]) == 90
        dataframe = MCMCSamplesDict(parameters, chains)
        burnin = dataframe.burnin(
            10, algorithm="burnin_by_first_n", step_number=True
        )
        assert len(burnin["chain_0"]["a"]) == len(idxs) - 10

785 

786 

class TestList(object):
    """Test the List class"""

    def test_added(self):
        """Check that List tracks every element added after construction"""
        original = ["a", "b", "c", "d", "e"]
        array = List(original)
        assert not len(array.added)
        array.append("f")
        assert len(array.added)
        assert array.added == ["f"]
        array.extend(["g", "h"])
        assert sorted(array.added) == sorted(["f", "g", "h"])
        assert sorted(array.original) == sorted(original)
        array.insert(2, "z")
        assert sorted(array.added) == sorted(["f", "g", "h", "z"])
        assert sorted(array) == sorted(original + ["f", "g", "h", "z"])
        # additions via + and += are tracked as well
        array = List(original)
        array = array + ["f", "g", "h"]
        assert sorted(array.added) == sorted(["f", "g", "h"])
        array += ["i"]
        assert sorted(array.added) == sorted(["f", "g", "h", "i"])

    def test_removed(self):
        """Check that List tracks every element removed after construction"""
        original = ["a", "b", "c", "d", "e"]
        array = List(original)
        assert not len(array.removed)
        array.remove("e")
        assert sorted(array) == sorted(["a", "b", "c", "d"])
        assert array.removed == ["e"]
        assert not len(array.added)
        # removing a previously-added element updates both records
        array.extend(["f", "g"])
        array.remove("f")
        assert array.removed == ["e", "f"]
        assert array.added == ["g"]
        array.pop(0)
        assert sorted(array.removed) == sorted(["e", "f", "a"])

823 

824 

def test_2DArray():
    """Test the pesummary.utils.array._2DArray class"""
    def _check_statistics(array, reference, likelihood=None, prior=None):
        # assertion block shared by both halves of the test
        assert array.standard_deviation == np.std(reference)
        assert array.minimum == np.min(reference)
        assert array.maximum == np.max(reference)
        key_data = array.key_data
        assert key_data["5th percentile"] == np.percentile(reference, 5)
        assert key_data["95th percentile"] == np.percentile(reference, 95)
        assert key_data["median"] == np.median(reference)
        assert key_data["mean"] == np.mean(reference)
        if likelihood is not None:
            assert array.maxL == array[np.argmax(likelihood)]
            assert array.maxP == array[np.argmax(likelihood + prior)]
            assert key_data["maxL"] == array.maxL
            assert key_data["maxP"] == array.maxP

    samples = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    arrays = _2DArray(samples)
    for num, array in enumerate(arrays):
        np.testing.assert_almost_equal(array, samples[num])
    for num, array in enumerate(arrays):
        _check_statistics(array, samples[num])

    # repeat with likelihood and prior weights supplied
    samples = [
        np.random.normal(np.random.randint(100), 0.2, size=1000)
        for _ in range(10)
    ]
    likelihood = np.random.uniform(0, 1, 1000)
    prior = np.random.uniform(0, 1, 1000)
    arrays = _2DArray(samples, likelihood=likelihood, prior=prior)
    for num, array in enumerate(arrays):
        np.testing.assert_almost_equal(array, samples[num])
    for num, array in enumerate(arrays):
        _check_statistics(array, samples[num], likelihood=likelihood, prior=prior)

867 

868 

class TestArray(object):
    """Test the Array class
    """
    def test_properties(self):
        """Check the summary-statistic properties against numpy."""
        draws = np.random.uniform(100, 10, 100)
        arr = Array(draws)
        assert arr.average(type="mean") == np.mean(draws)
        assert arr.average(type="median") == np.median(draws)
        assert arr.standard_deviation == np.std(draws)
        expected = [np.percentile(arr, 5), np.percentile(arr, 95)]
        np.testing.assert_almost_equal(
            arr.credible_interval(percentile=[5, 95]), expected
        )

    def test_weighted_percentile(self):
        """Check the weighted percentile matches numpy on repeated samples."""
        draws = np.random.normal(100, 20, 10000)
        weights = np.array([np.random.randint(100) for _ in range(10000)])
        arr = Array(draws, weights=weights)
        # integer weights let us emulate weighting by repeating each sample
        expected = np.percentile(np.repeat(draws, weights), 90)
        result = arr.credible_interval(percentile=90)
        np.testing.assert_almost_equal(expected, result, 6)

    def test_key_data(self):
        """Check the key_data summary dictionary."""
        draws = np.random.normal(100, 20, 10000)
        arr = Array(draws)
        key_data = arr.key_data
        checks = {
            "mean": np.mean(draws),
            "median": np.median(draws),
            "std": np.std(draws),
            "5th percentile": np.percentile(draws, 5),
            "95th percentile": np.percentile(draws, 95),
        }
        for key, expected in checks.items():
            np.testing.assert_almost_equal(key_data[key], expected)
        # no likelihood/prior supplied, so maxL/maxP are undefined
        assert key_data["maxL"] is None
        assert key_data["maxP"] is None

906 

907 

class TestTQDM(object):
    """Test the pesummary.utils.tqdm.tqdm class
    """
    def setup_method(self):
        # small iterable plus a scratch directory for log output
        self._range = range(100)
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_basic_iterator(self):
        """Test that the core functionality of the tqdm class remains
        """
        for value in tqdm(self._range):
            _ = value * value

    def test_interaction_with_logger(self):
        """Test that tqdm interacts nicely with logger
        """
        from pesummary.utils.utils import logger, LOG_FILE

        path = "{}/test.dat".format(tmpdir)
        with open(path, "w") as opened:
            for value in tqdm(self._range, logger=logger, file=opened):
                _ = value * value

        # the final line written should follow the PESummary log format
        with open(path, "r") as opened:
            final = opened.readlines()[-1]
        assert "PESummary" in final
        assert "INFO" in final

942 

def test_jensen_shannon_divergence():
    """Test that the `jensen_shannon_divergence` method returns the same
    values as the scipy function
    """
    from scipy.spatial.distance import jensenshannon
    from scipy import stats

    draws = [
        np.random.uniform(5, 4, 100),
        np.random.uniform(5, 4, 100)
    ]
    grid = np.linspace(np.min(draws), np.max(draws), 100)
    kdes = [stats.gaussian_kde(d)(grid) for d in draws]
    # scipy returns the distance; square it to get the divergence
    expected = jensenshannon(*kdes) ** 2
    result = utils.jensen_shannon_divergence(draws, decimal=9)
    np.testing.assert_almost_equal(expected, result)

    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    # smoke test: a bounded KDE can be passed through the kde kwarg
    _ = utils.jensen_shannon_divergence(
        draws, decimal=9, kde=ReflectionBoundedKDE, xlow=4.5, xhigh=5.5
    )

965 

966 

def test_make_cache_style_file():
    """Test that the `make_cache_style_file` works as expected

    Writes a one-line style file, caches it, and checks the cached copy
    matches. The temporary ``test.sty`` is always removed afterwards so
    the test no longer leaks an artifact into the working directory.
    """
    from pesummary.utils.utils import make_cache_style_file, CACHE_DIR
    sty = os.path.expanduser("{}/style/matplotlib_rcparams.sty".format(CACHE_DIR))
    try:
        with open("test.sty", "w") as f:
            f.writelines(["test : 10"])
        make_cache_style_file("test.sty")
    finally:
        # clean up the scratch file even if make_cache_style_file raises
        if os.path.isfile("test.sty"):
            os.remove("test.sty")
    assert os.path.isfile(sty)
    with open(sty, "r") as f:
        lines = f.readlines()
    assert len(lines) == 1
    assert lines[0] == "test : 10"

980 

981 

def test_logger():
    """Check the PESummary logger emits INFO and WARNING records."""
    with LogCapture() as capture:
        utils.logger.propagate = True
        utils.logger.info("info")
        utils.logger.warning("warning")
        capture.check(
            ("PESummary", "INFO", "info"),
            ("PESummary", "WARNING", "warning"),
        )

989 

990 

def test_string_match():
    """function to test the string_match function
    """
    target = "mass_1"
    # a string matches itself
    assert utils.string_match(target, target)
    # a different string does not match
    assert not utils.string_match(target, "mass_2")
    # a single wildcard matches
    assert utils.string_match(target, target[0] + "*")
    assert not utils.string_match(target, "z+")
    assert utils.string_match(target, target[:-1] + "?")
    # multiple wildcards match
    assert utils.string_match(target, "*" + target[1] + "*")
    assert utils.string_match(target, "*" + target[3:-1] + "?")
    assert utils.string_match(target, "?" + target[1:-1] + "?")
    assert not utils.string_match(target, "?" + target[:-1] + "?")
    # anchored 'starts with' matches
    assert utils.string_match(target, "^" + target[:3])
    # anchored 'does not start with' rejects
    assert not utils.string_match(target, "^(?!" + target[:3] + ").+")

1012 

1013 

def test_map_parameter_names():
    """function to test the map_parameter_names function
    """
    original = {
        key: np.random.uniform(1, 100, size=1000) for key in ("a", "b", "c")
    }
    mapping = {"a": "z", "b": "y", "c": "x"}
    working = original.copy()
    remapped = utils.map_parameter_names(working, mapping)
    # the returned dictionary carries only the new names
    assert sorted(list(remapped.keys())) == ["x", "y", "z"]
    # the data is carried over unchanged under the new names
    for old_key, new_key in mapping.items():
        np.testing.assert_almost_equal(original[old_key], remapped[new_key])
    # the input dictionary is left untouched
    for key, value in original.items():
        np.testing.assert_almost_equal(value, working[key])

1033 

1034 

def test_list_match():
    """function to test the list_match function
    """
    params = [
        "mass_1", "mass_2", "chirp_mass", "total_mass", "mass_ratio", "a_1",
        "luminosity_distance", "x"
    ]
    # all mass parameters are returned
    expected = sorted(p for p in params if "mass" in p)
    assert sorted(utils.list_match(params, "*mass*")) == expected
    # only parameters containing an underscore are returned
    expected = sorted(p for p in params if "_" in p)
    assert sorted(utils.list_match(params, "*_*")) == expected
    # a pattern with no match returns nothing
    assert not len(utils.list_match(params, "z+"))
    # only parameters which do not start with 'mass'
    expected = sorted(p for p in params if p[:4] != "mass")
    assert sorted(utils.list_match(params, "^(?!mass).+")) == expected
    # only parameters starting with 'l'
    expected = sorted(p for p in params if p[0] == "l")
    assert sorted(utils.list_match(params, "^l")) == expected
    # return_false inverts the selection
    expected = sorted(p for p in params if p[0] != "l")
    assert sorted(utils.list_match(params, "^l", return_false=True)) == expected
    # multiple substrings are combined
    assert sorted(utils.list_match(params, ["^m", "*2"])) == ["mass_2"]

1066 

1067 

class TestDict(object):
    """Class to test the NestedDict object
    """
    def test_initiate(self):
        """Initiate the Dict class
        """
        from pesummary.gw.file.psd import PSD

        expected = [[10, 20], [10, 20]]
        # both supported constructor signatures: a mapping, and
        # parallel (keys, data) sequences
        for args in (({"a": expected},), (["a"], [expected])):
            x = Dict(
                *args, value_class=PSD, value_columns=["value", "value2"]
            )
            assert list(x.keys()) == ["a"]
            np.testing.assert_almost_equal(x["a"], expected)
            assert isinstance(x["a"], PSD)
            # named columns expose each component of the 2d data
            np.testing.assert_almost_equal(x["a"].value, [10, 10])
            np.testing.assert_almost_equal(x["a"].value2, [20, 20])

    def test_specify_columns(self):
        """Test that x[["a", "b"]] works as expected
        """
        full = Dict(
            {"a": [10], "b": [20], "c": [30], "d": [40]}, value_class=list
        )
        subset = full[["a", "b"]]
        assert sorted(list(subset.keys())) == ["a", "b"]
        for key in subset.keys():
            assert subset[key] == full[key]
        # asking for keys that do not exist must fail
        with pytest.raises(Exception):
            _ = full[["e", "f"]]

1106 

1107 

class TestProbabilityDict(object):
    """Test the ProbabilityDict class
    """
    def setup_method(self):
        """Setup the ProbabilityDict class
        """
        # each entry is [grid of x values, probability at each x]
        self.mydict = {
            "a": [[1,2,3,4], [0.1, 0.2, 0.3, 0.4]],
            "b": [[1,3,5,7], [0.1, 0.3, 0.2, 0.4]]
        }

    def test_initiate(self):
        """Test the different ways to initiate this class
        """
        # initiate from a mapping
        probs = ProbabilityDict(self.mydict)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)
        # initiate from parallel (parameters, data) sequences
        parameters = list(self.mydict.keys())
        data = [self.mydict[key] for key in parameters]
        probs = ProbabilityDict(parameters, data)
        assert all(param in probs.keys() for param in self.mydict.keys())
        for param, data in self.mydict.items():
            np.testing.assert_almost_equal(data[0], probs[param].x)
            np.testing.assert_almost_equal(data[1], probs[param].probs)

    def test_rvs(self):
        """Test the .rvs method
        """
        probs = ProbabilityDict(self.mydict)
        # draw 10 samples for all parameters
        samples = probs.rvs(size=10, interpolate=False)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        # NOTE: fixed misplaced parenthesis — the original asserted
        # len(samples[param] == 10), the length of a boolean array
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())
        for param in self.mydict.keys():
            # without interpolation every draw must be a grid point
            assert all(p in self.mydict[param][0] for p in samples[param])
        # draw 10 samples for only parameter 'a'
        samples = probs.rvs(size=10, parameters=["a"])
        assert list(samples.keys()) == ["a"]
        assert len(samples["a"]) == 10
        # interpolate first and then draw samples
        samples = probs.rvs(size=10, interpolate=True)
        assert isinstance(samples, SamplesDict)
        assert all(param in samples.keys() for param in self.mydict.keys())
        assert all(len(samples[param]) == 10 for param in self.mydict.keys())

    def test_plotting(self):
        """Test the .plot method
        """
        probs = ProbabilityDict(self.mydict)
        # Test that a histogram plot is generated when required
        fig = probs.plot("a", type="hist")
        fig = probs.plot("a", type="marginalized_posterior")
        # Test that a PDF is generated when required
        fig = probs.plot("a", type="pdf")

1167 

1168 

class TestProbabilityDict2D(object):
    """Class to test the ProbabilityDict2D class
    """
    def setup_method(self):
        """Setup the TestProbabilityDict2D class
        """
        raw = np.random.uniform(0, 1, 100).reshape(10, 10)
        # entry is [x grid, y grid, normalised 2d probability array]
        self.mydict = {
            "x_y": [
                np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10),
                raw / np.sum(raw)
            ]
        }

    def test_plotting(self):
        """Test the .plot method
        """
        mydict = ProbabilityDict2D(self.mydict)
        # a 2d KDE plot is generated when specified
        fig = mydict.plot("x_y", type="2d_kde")
        # a triangle plot is generated when specified
        fig, _, _, _ = mydict.plot("x_y", type="triangle")
        # a reverse triangle plot is generated when specified
        fig, _, _, _ = mydict.plot("x_y", type="reverse_triangle")
        # an unknown plot type raises an Exception
        with pytest.raises(Exception):
            fig = mydict.plot("x_y", type="does_not_exist")

1197 

1198 

class TestDiscretePDF2D(object):
    """Class to test the DiscretePDF2D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2D class
        """
        self.x = [1,2]
        self.y = [1,2]
        self.probs = [[0.25, 0.25], [0.25, 0.25]]

    def test_initiate(self):
        """Check construction succeeds for 2d probabilities only
        """
        # a flat (1d) probability array must be rejected
        with pytest.raises(ValueError):
            _ = DiscretePDF2D(self.x, self.y, [1,2,3,4,5,6])
        # a proper 2d array is accepted
        _ = DiscretePDF2D(self.x, self.y, self.probs)

    def test_marginalize(self):
        """Check that marginalize returns a DiscretePDF2Dplus1D
        """
        pdf = DiscretePDF2D(self.x, self.y, self.probs)
        marginalized = pdf.marginalize()
        assert isinstance(marginalized, DiscretePDF2Dplus1D)
        # the joint probabilities are preserved through marginalization
        np.testing.assert_almost_equal(pdf.probs, marginalized.probs_xy.probs)

1223 

1224 

class TestDiscretePDF2Dplus1D(object):
    """Class to test the DiscretePDF2Dplus1D class
    """
    def setup_method(self):
        """Setup the TestDiscretePDF2Dplus1D class
        """
        self.x = [1,2]
        self.y = [1,2]
        # [p(x), p(y), p(x, y)]
        self.probs = [[0.5, 0.5], [0.5, 0.5], [[0.25, 0.25], [0.25, 0.25]]]

    def test_initiate_raise(self):
        """Test that the class raises ValueErrors when required
        """
        invalid = [
            # only 2 probability arrays given
            self.probs[:-1],
            # three arrays but no 2d probability array
            self.probs[:-1] + self.probs[:1],
            # only one 1d probability array given
            self.probs[1:] + self.probs[-1:],
        ]
        for probs in invalid:
            with pytest.raises(ValueError):
                _ = DiscretePDF2Dplus1D(self.x, self.y, probs)

    def test_initiate(self):
        """Test that we can initiate the class
        """
        pdf = DiscretePDF2Dplus1D(self.x, self.y, self.probs)
        # the two 1d marginals are DiscretePDFs, the joint a DiscretePDF2D
        assert isinstance(pdf.probs[0], DiscretePDF)
        assert isinstance(pdf.probs[1], DiscretePDF)
        assert isinstance(pdf.probs[2], DiscretePDF2D)

1259 

1260 

class TestKDEList(object):
    """Test the KDEList class
    """
    def setup_method(self):
        """Setup the KDEList class

        Build 10 Gaussian sample sets and their reference KDEs evaluated
        on a common grid, to compare against KDEList output.
        """
        from scipy.stats import gaussian_kde
        means = np.random.uniform(0, 5, size=10)
        stds = np.random.uniform(0, 2, size=10)
        self.inputs = np.array(
            [
                np.random.normal(mean, std, size=10000) for mean, std
                in zip(means, stds)
            ]
        )
        self.pts = np.linspace(-10, 10, 100)
        self.true_kdes = [gaussian_kde(_)(self.pts) for _ in self.inputs]

    def test_call(self):
        """Test the KDEList call method
        """
        from pesummary.utils.kde_list import KDEList
        from scipy.stats import gaussian_kde
        new = KDEList(self.inputs, kde=gaussian_kde, pts=self.pts)
        # test on a single CPU
        new_single = new(multi_process=1)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_single[num])
        # test on multiple CPUs
        # (fixed: the original loop compared new_single here, so the
        # multi-process result was never actually checked)
        new_multiple = new(multi_process=2)
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_multiple[num])
        # test the `idx` kwarg
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new(idx=num))
        # split the calculation into 2 chunks and 2 CPUs
        new_chunk1 = new(
            idx=np.arange(0, len(self.inputs) // 2, 1), multi_process=2
        ).tolist()
        new_chunk2 = new(
            idx=np.arange(len(self.inputs) // 2, len(self.inputs), 1), multi_process=2
        ).tolist()
        new_combined = new_chunk1 + new_chunk2
        for num in range(len(self.inputs)):
            np.testing.assert_almost_equal(self.true_kdes[num], new_combined[num])
        # assert that different KDEs are produced for different pts
        with pytest.raises(AssertionError):
            new_diff = new(pts=np.linspace(-5, 5, 100), multi_process=1)
            for num in range(len(self.inputs)):
                np.testing.assert_almost_equal(self.true_kdes[num], new_diff[num])

1311 

1312 

def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    # CACHE_DIR and make_dir are defined in pesummary.utils.utils and were
    # not imported at module level here, so the bare names raised a
    # NameError when this helper was called; import them locally instead
    from pesummary.utils.utils import CACHE_DIR, make_dir
    make_dir(CACHE_DIR)
    shutil.copyfile(
        style_file, os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")
    )

1326 

1327 

def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    # CACHE_DIR is defined in pesummary.utils.utils and was not imported at
    # module level here (NameError when called); import it locally
    from pesummary.utils.utils import CACHE_DIR
    return os.path.join(CACHE_DIR, "matplotlib_rcparams.sty")