Coverage for pesummary/tests/plot_test.py: 98.3%

233 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5 

6import argparse 

7 

8from pesummary.core.plots import plot 

9from pesummary.gw.plots import plot as gwplot 

10from pesummary.utils.array import Array 

11from subprocess import CalledProcessError 

12 

13import numpy as np 

14import matplotlib 

15from matplotlib import rcParams 

16import pytest 

17import tempfile 

18 

# Name of a hidden (".tmp...") scratch directory under the current working
# directory, shared by the plotting tests that write files.
# NOTE(review): the TemporaryDirectory object itself is discarded straight
# away and only its generated name is kept, so (under CPython) its finalizer
# removes the directory immediately; TestPlot.setup_method recreates it with
# os.makedirs before each test. Confirm this "borrow a unique name" pattern is
# intentional -- nothing here deletes the directory at the end of the session.
tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".").name

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Disable external LaTeX rendering for every figure produced by these tests.
rcParams["text.usetex"] = False

23 

class TestPlot(object):
    """Tests for the plotting functions in `pesummary.core.plots.plot` and
    `pesummary.gw.plots.plot`
    """
    # LaTeX labels for the 21 parameters used by the corner plot tests
    _corner_latex_labels = {
        "luminosity_distance": r"$d_{L}$",
        "dec": r"$\delta$",
        "a_2": r"$a_{2}$", "a_1": r"$a_{1}$",
        "geocent_time": r"$t$", "phi_jl": r"$\phi_{JL}$",
        "psi": r"$\Psi$", "ra": r"$\alpha$", "phase": r"$\psi$",
        "mass_2": r"$m_{2}$", "mass_1": r"$m_{1}$",
        "phi_12": r"$\phi_{12}$", "tilt_2": r"$t_{2}$",
        "iota": r"$\iota$", "tilt_1": r"$t_{1}$",
        "chi_p": r"$\chi_{p}$", "chirp_mass": r"$\mathcal{M}$",
        "mass_ratio": r"$q$", "symmetric_mass_ratio": r"$\eta$",
        "total_mass": r"$M$", "chi_eff": r"$\chi_{eff}$",
    }

    def setup_method(self):
        """Start every test with an empty scratch directory"""
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)
        os.makedirs(tmpdir)

    def teardown_method(self):
        """Remove the scratch directory so that it is not left behind in the
        working directory once the tests have finished
        """
        shutil.rmtree(tmpdir, ignore_errors=True)

    @staticmethod
    def _read_psd_column(file, column):
        """Return one column of a whitespace separated psd data file as a
        list of floats. Blank lines are skipped

        Parameters
        ----------
        file: str
            path to the psd data file
        column: int
            index of the column to return (0: frequencies, 1: strains)
        """
        # use a context manager so the file handle is always closed
        with open(file) as f:
            return [
                float(line.split()[column]) for line in f if line.strip()
            ]

    def _grab_frequencies_from_psd_data_file(self, file):
        """Return the frequencies stored in the psd data files

        Parameters
        ----------
        file: str
            path to the psd data file
        """
        return self._read_psd_column(file, 0)

    def _grab_strains_from_psd_data_file(self, file):
        """Return the strains stored in the psd data files

        Parameters
        ----------
        file: str
            path to the psd data file
        """
        return self._read_psd_column(file, 1)

    @staticmethod
    def _maxL_params(**overrides):
        """Return a new dictionary of maximum likelihood parameters for the
        waveform tests. A fresh dictionary is built on every call so tests
        can modify it without affecting one another

        Parameters
        ----------
        **overrides: dict
            parameter values which replace the defaults
        """
        params = {
            "approximant": "IMRPhenomPv2", "mass_1": 10., "mass_2": 5.,
            "theta_jn": 1., "phi_jl": 0., "tilt_1": 0., "tilt_2": 0.,
            "phi_12": 0., "a_1": 0.5, "a_2": 0., "phase": 0.,
            "ra": 1., "dec": 1., "psi": 0., "geocent_time": 0.,
            "luminosity_distance": 100,
        }
        params.update(overrides)
        return params

    def _corner_inputs(self):
        """Return a copy of the corner plot latex labels, the list of
        parameters and a dictionary mapping each parameter to 21 uniform
        random samples
        """
        latex_labels = dict(self._corner_latex_labels)
        params = list(latex_labels.keys())
        samples = {param: np.random.random(21).tolist() for param in params}
        return latex_labels, params, samples

    def test_autocorrelation_plot(self):
        rcParams["text.usetex"] = False
        fig = plot._autocorrelation_plot("mass_1", Array([10, 20, 30, 40]))
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples", [("mass_1",
        [Array([10, 20, 30, 40]), Array([10, 20, 30, 40])]), ])
    def test_autocorrelation_plot_mcmc(self, param, samples):
        fig = plot._autocorrelation_plot_mcmc(param, samples)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label", [("mass_1",
        Array([10, 20, 30, 40]), r"$m_{1}$"), ])
    def test_sample_evolution_plot(self, param, samples, latex_label):
        fig = plot._sample_evolution_plot(param, samples, latex_label)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label", [("mass_1",
        [Array([10, 20, 30, 40]), Array([10, 20, 30, 40])], r"$m_{1}$"), ])
    def test_sample_evolution_plot_mcmc(self, param, samples, latex_label):
        # exercise the mcmc sample evolution plot (not the autocorrelation)
        fig = plot._sample_evolution_plot_mcmc(param, samples, latex_label)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label", [("mass_1",
        Array([10, 20, 30, 40]), r"$m_{1}$"), ])
    def test_1d_cdf_plot(self, param, samples, latex_label):
        fig = plot._1d_cdf_plot(param, samples, latex_label)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label", [("mass_1",
        [Array([10, 20, 30, 40]), Array([10, 20, 30, 40])], r"$m_{1}$"), ])
    def test_1d_cdf_plot_mcmc(self, param, samples, latex_label):
        fig = plot._1d_cdf_plot_mcmc(param, samples, latex_label)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, colors, latex_label, labels",
        [("mass1", [[10, 20, 30, 40], [1, 2, 3, 4]],
          ["b", "r"], r"$m_{1}$", "approx1"), ])
    def test_1d_cdf_comparison_plot(self, param, samples, colors,
                                    latex_label, labels):
        fig = plot._1d_cdf_comparison_plot(param, samples, colors,
                                           latex_label, labels)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label", [("mass1",
        Array([10, 20, 30, 40]), r"$m_{1}$"), ])
    def test_1d_histogram_plot(self, param, samples, latex_label):
        for module in [plot, gwplot]:
            func = getattr(module, "_1d_histogram_plot")
            fig = func(param, samples, latex_label)
            assert isinstance(fig, matplotlib.figure.Figure)
            fig = func(param, samples, latex_label, kde=True)
            assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, latex_label",
        [("mass1", [[10, 20, 30, 40], [1, 2, 3, 4]], r"$m_{1}$"), ])
    def test_1d_histogram_plot_mcmc(self, param, samples, latex_label):
        for module in [plot, gwplot]:
            func = getattr(module, "_1d_histogram_plot_mcmc")
            fig = func(param, samples, latex_label)
            assert isinstance(fig, matplotlib.figure.Figure)
            # second call mirrors the other histogram tests and adds a KDE
            fig = func(param, samples, latex_label, kde=True)
            assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, colors, latex_label, labels",
        [("mass1", [[10, 20, 30, 40], [1, 2, 3, 4]],
          ["b", "r"], r"$m_{1}$", "approx1"), ])
    def test_1d_comparison_histogram_plot(self, param, samples, colors,
                                          latex_label, labels):
        for module in [plot, gwplot]:
            func = getattr(module, "_1d_comparison_histogram_plot")
            fig = func(param, samples, colors, latex_label, labels)
            assert isinstance(fig, matplotlib.figure.Figure)
            fig = func(param, samples, colors, latex_label, labels, kde=True)
            assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("param, samples, colors, latex_label, labels",
        [("mass1", [[10, 20, 30, 40], [1, 2, 3, 4]],
          ["b", "r"], r"$m_{1}$", ["approx1", "approx2"]), ])
    def test_comparison_box_plot(self, param, samples, colors,
                                 latex_label, labels):
        fig = plot._comparison_box_plot(param, samples, colors, latex_label,
                                        labels)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_waveform_plot(self):
        fig = gwplot._waveform_plot(["H1"], self._maxL_params())
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_timedomain_waveform_plot(self):
        fig = gwplot._time_domain_waveform(["H1"], self._maxL_params())
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_waveform_comparison_plot(self):
        # two independent dictionaries so that only the second waveform
        # has mass_1 = 7
        maxL_params = [self._maxL_params(), self._maxL_params(mass_1=7.)]
        fig = gwplot._waveform_comparison_plot(maxL_params, ["b", "r"],
                                               ["IMRPhenomPv2"] * 2)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_time_domain_waveform_comparison_plot(self):
        # two independent dictionaries so that only the second waveform
        # has mass_1 = 7
        maxL_params = [self._maxL_params(), self._maxL_params(mass_1=7.)]
        fig = gwplot._time_domain_waveform_comparison_plot(
            maxL_params, ["b", "r"], ["IMRPhenomPv2"] * 2
        )
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("ra, dec", [([1, 2, 3, 4], [1, 1, 1, 1]), ])
    def test_sky_map_plot(self, ra, dec):
        fig = gwplot._default_skymap_plot(ra, dec)
        assert isinstance(fig, matplotlib.figure.Figure)

    @pytest.mark.parametrize("ra, dec, approx, colors", [(
        [[1, 2, 3, 4], [1, 2, 2, 1]], [[1, 1, 2, 1], [1, 1, 1, 1]],
        ["approx1", "approx2"], ["b", "r"]), ])
    def test_sky_map_comparison_plot(self, ra, dec, approx, colors):
        fig = gwplot._sky_map_comparison_plot(ra, dec, approx, colors)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_corner_plot(self):
        latex_labels, _, samples = self._corner_inputs()
        fig, _, _ = gwplot._make_corner_plot(samples, latex_labels)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_source_corner_plot(self):
        latex_labels, _, samples = self._corner_inputs()
        fig = gwplot._make_source_corner_plot(samples, latex_labels)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_extrinsic_corner_plot(self):
        latex_labels, _, samples = self._corner_inputs()
        fig = gwplot._make_extrinsic_corner_plot(samples, latex_labels)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_comparison_corner_plot(self):
        latex_labels, params, samples = self._corner_inputs()
        # compare a set of samples against itself
        comparison = {"one": samples, "two": samples}
        fig = gwplot._make_comparison_corner_plot(
            comparison, latex_labels, corner_parameters=params
        )
        assert isinstance(fig, matplotlib.figure.Figure)
        fig.close()

    def test_sensitivity_plot(self):
        maxL_params = {
            "approximant": "IMRPhenomPv2", "mass_1": 10., "mass_2": 5.,
            "iota": 1., "phi_jl": 0., "tilt_1": 0., "tilt_2": 0.,
            "phi_12": 0., "a_1": 0.5, "a_2": 0., "phase": 0.,
            "ra": 1., "dec": 1., "psi": 0., "geocent_time": 0.,
            "luminosity_distance": 100,
        }
        fig = gwplot._sky_sensitivity(["H1", "L1"], 1.0, maxL_params)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_psd_plot(self):
        psd_file = os.path.join(tmpdir, "psd.dat")
        # writelines does not add line separators, so include them explicitly
        with open(psd_file, "w") as f:
            f.writelines(["0.5 100\n", "1.0 150\n", "5.0 200\n"])
        frequencies = [self._grab_frequencies_from_psd_data_file(psd_file)]
        strains = [self._grab_strains_from_psd_data_file(psd_file)]
        fig = gwplot._psd_plot(frequencies, strains, labels=["H1"])
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_calibration_plot(self):
        frequencies = np.arange(20, 100, 0.2)
        ifos = ["H1"]
        calibration = [[
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            [2000.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
        ]]
        fig = gwplot._calibration_envelope_plot(frequencies, calibration, ifos)
        assert isinstance(fig, matplotlib.figure.Figure)

    def test_classification_plot(self):
        classifications = {"BBH": 0.95, "NSBH": 0.05}
        fig = gwplot._classification_plot(classifications)
        assert isinstance(fig, matplotlib.figure.Figure)

315 

316 

class TestPopulation(object):
    """Tests for the `pesummary.core.plots.population` module
    """
    def test_scatter_plot(self):
        from pesummary.core.plots.population import scatter_plot

        params = ["a", "b"]
        runs = {"one": {"a": 10, "b": 20}, "two": {"a": 15, "b": 5}}
        labels = {"a": "a", "b": "b"}
        # first without error bars, then using the samples as x/y errors
        for extra in ({}, {"xerr": runs, "yerr": runs}):
            figure = scatter_plot(params, runs, labels, **extra)
            assert isinstance(figure, matplotlib.figure.Figure)

333 

334 

class TestDetchar(object):
    """Tests for the `pesummary.gw.plots.detchar` module
    """
    @staticmethod
    def _noise_strain():
        """Return {"H1": TimeSeriesBase} holding 200 draws from a standard
        normal distribution, with x0=0 and dx=1
        """
        from gwpy.timeseries.core import TimeSeriesBase

        data = np.random.normal(size=200)
        return {"H1": TimeSeriesBase(data, x0=0, dx=1)}

    def test_spectrogram(self):
        from pesummary.gw.plots.detchar import spectrogram

        figures = spectrogram(self._noise_strain())
        assert isinstance(figures["H1"], matplotlib.figure.Figure)

    def test_omegascan(self):
        from pesummary.gw.plots.detchar import omegascan

        figures = omegascan(self._noise_strain(), 0)
        assert isinstance(figures["H1"], matplotlib.figure.Figure)

353 

354 

class TestPublication(object):
    """Tests for the publication-quality plots in
    `pesummary.gw.plots.publication` and `pesummary.core.plots.publication`
    """
    @staticmethod
    def _triangle_samples():
        """Return (x, y): two lists of 1000 normal draws centred on 10, with
        widths (2, 3) for x and (2, 2.5) for y
        """
        x = [np.random.normal(10, width, 1000) for width in [2, 3]]
        y = [np.random.normal(10, width, 1000) for width in [2, 2.5]]
        return x, y

    def test_twod_contour_plots(self):
        from pesummary.gw.plots.publication import twod_contour_plots

        columns = np.array(
            [np.random.uniform(0., 3000, 1000) for _ in range(2)]
        )
        figure = twod_contour_plots(
            ["a", "b"], [columns], ["a", "b"], {"a": "a", "b": "b"}
        )
        assert isinstance(figure, matplotlib.figure.Figure)

    def test_violin(self):
        from pesummary.gw.plots.publication import violin_plots
        from pesummary.core.plots.seaborn.violin import split_dataframe

        names = ["a", "b"]
        left = [np.random.uniform(0., 3000, 1000) for _ in range(2)]
        figure = violin_plots("a", left, names, {"a": "a", "b": "b"})
        assert isinstance(figure, matplotlib.figure.Figure)

        right = [np.random.uniform(0., 3000, 1000) for _ in range(2)]
        frame = split_dataframe(left, right, names)
        figure = violin_plots(
            "a", frame, names, {"a": "a", "b": "b"},
            cut=0, x="label", y="data", hue="side", split=True
        )
        assert isinstance(figure, matplotlib.figure.Figure)

    def test_spin_distribution_plots(self):
        from pesummary.gw.plots.publication import spin_distribution_plots

        names = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"]
        # spin magnitudes are drawn on [0, 1), cosine tilts on [-1, 1)
        draws = [np.random.uniform(low, 1, 1000) for low in (0, 0, -1, -1)]
        figure = spin_distribution_plots(names, draws, "test", "r")
        assert isinstance(figure, matplotlib.figure.Figure)

    def test_triangle(self):
        from pesummary.core.plots.publication import triangle_plot

        x, y = self._triangle_samples()
        figure, _, _, _ = triangle_plot(
            x, y, fill_alpha=0.2, xlabel=r"$x$", ylabel=r"$y$",
            linestyles=["-", "--"], percentiles=[5, 95]
        )
        assert isinstance(figure, matplotlib.figure.Figure)

    def test_reverse_triangle(self):
        from pesummary.core.plots.publication import reverse_triangle_plot

        x, y = self._triangle_samples()
        figure, _, _, _ = reverse_triangle_plot(
            x, y, fill_alpha=0.2, xlabel=r"$x$", ylabel=r"$y$",
            linestyles=["-", "--"], percentiles=[5, 95]
        )
        assert isinstance(figure, matplotlib.figure.Figure)