Coverage for pesummary/tests/classification_test.py: 100.0%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# License under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5import numpy as np 

6from ligo.em_bright.em_bright import source_classification_pe 

7from .base import make_result_file, testing_dir 

8from pesummary.io import read 

9from pesummary.utils.decorators import no_latex_plot 

10from pesummary.gw.classification import Classify, EMBright, PAstro 

11 

12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

13 

14 

class _Base(object):
    """Base testing class shared by the classification tests.

    Before every test a fresh ``.outdir`` directory is populated with a
    generated GW result file and two re-written copies of it; the directory
    is removed again after every test.
    """
    def setup_method(self):
        """Create ``.outdir`` and write the result files used by the tests.

        Produces ``.outdir/test.dat`` (via ``make_result_file``) and, from
        it, ``.outdir/test_converted.dat`` and
        ``.outdir/test_lalinference.hdf5``.
        """
        # exist_ok replaces the racy "isdir then mkdir" pattern; it still
        # raises if ".outdir" exists but is not a directory
        os.makedirs(".outdir", exist_ok=True)
        make_result_file(gw=True, extension="dat", outdir=".outdir/")
        f = read(".outdir/test.dat")
        # regenerate the mass_1_source, mass_2_source posteriors because these
        # are randomly chosen and do not correspond to
        # mass_1 / (1. + z) and mass_2 / (1. + z)
        f.generate_all_posterior_samples(
            regenerate=["mass_1_source", "mass_2_source"]
        )
        f.write(
            filename="test_converted.dat", file_format="dat", outdir=".outdir",
            overwrite=True
        )
        f.write(
            filename="test_lalinference.hdf5", file_format="hdf5",
            outdir=".outdir", overwrite=True
        )

    def teardown_method(self):
        """Remove the files and directories created from this class
        """
        if os.path.isdir(".outdir"):
            shutil.rmtree(".outdir")

    @no_latex_plot
    def test_plotting(self, cls=EMBright, **kwargs):
        """Check that every plot in ``cls.available_plots`` is a Figure.

        Parameters
        ----------
        cls: type, optional
            classification class to test. Default EMBright
        **kwargs: dict, optional
            extra keyword arguments passed to ``cls`` when it is constructed
        """
        import matplotlib.figure
        import matplotlib.pyplot as plt
        samples = read(".outdir/test_converted.dat").samples_dict
        _cls = cls(samples, **kwargs)
        ptable = _cls.classification()
        for plot in _cls.available_plots:
            fig = _cls.plot(ptable, type=plot)
            assert isinstance(fig, matplotlib.figure.Figure)
            # close each figure once checked so figures do not accumulate
            # across the test session (matplotlib warns when >20 are open)
            plt.close(fig)

57 

58 

class TestPAstro(_Base):
    """Test the pesummary.gw.classification.PAstro class and
    pesummary.gw.classification.Classify class
    """
    def test_classification(self):
        """Placeholder for checking PAstro classifications against a
        reference implementation.

        No reference comparison is implemented yet. The test is therefore
        skipped explicitly (with a reason) instead of passing without
        checking anything.
        """
        import pytest
        pytest.skip(
            "PAstro classification comparison against a reference "
            "implementation is not yet implemented"
        )

    def test_plotting(self):
        """Test the .plot method of PAstro.

        Running the base plotting test without a terrestrial probability
        must raise a ValueError. With ``terrestrial_probability=0.`` every
        available plot must be produced.
        """
        import pytest
        # catch ValueError when no terrestrial probability is given
        with pytest.raises(ValueError):
            super().test_plotting(
                cls=PAstro, category_data=f"{testing_dir}/rates.yml",
            )
        super().test_plotting(
            cls=PAstro, category_data=f"{testing_dir}/rates.yml",
            terrestrial_probability=0.
        )

82 

83 

class TestEMBright(_Base):
    """Tests for pesummary.gw.classification.EMBright and the EM-bright
    output of pesummary.gw.classification.Classify.
    """
    def test_classification(self):
        """Check EMBright probabilities against
        ligo.em_bright.source_classification_pe.

        Also checks that ``EMBright.classification_from_file`` and
        ``Classify.classification_from_file`` both reproduce the in-memory
        EMBright result to 5 decimal places.
        """
        reference = source_classification_pe(
            ".outdir/test_lalinference.hdf5"
        )
        converted_file = ".outdir/test_converted.dat"
        in_memory = EMBright(
            read(converted_file).samples_dict
        ).classification()
        # the reference result is indexed positionally here: element 0 is
        # compared with HasNS and element 1 with HasRemnant
        for position, label in enumerate(("HasNS", "HasRemnant")):
            np.testing.assert_almost_equal(
                reference[position], in_memory[label], 5
            )
        emb_from_file = EMBright.classification_from_file(converted_file)
        combined_from_file = Classify.classification_from_file(
            converted_file,
            category_data=f"{testing_dir}/rates.yml",
            terrestrial_probability=0.
        )
        for label, probability in emb_from_file.items():
            expected = in_memory[label]
            np.testing.assert_almost_equal(expected, probability, 5)
            np.testing.assert_almost_equal(
                expected, combined_from_file[label], 5
            )

    def test_plotting(self):
        """Run the shared plotting checks with the EMBright class."""
        super().test_plotting(cls=EMBright)