Coverage for pesummary/tests/psd_test.py: 100.0%

69 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.gw.file.psd import PSDDict, PSD 

4import numpy as np 

5import os 

6import shutil 

7import tempfile 

8 

# Reserve a unique, hidden directory name inside the current working
# directory. TemporaryDirectory creates the directory straight away; only the
# name is wanted here (setup_method creates the directory and teardown_method
# removes it), so clean it up explicitly. Dropping the handle and relying on
# garbage collection instead emits a ResourceWarning ("Implicitly cleaning
# up") and is non-deterministic on interpreters without reference counting.
_reserved_tmpdir = tempfile.TemporaryDirectory(prefix=".", dir=".")
tmpdir = _reserved_tmpdir.name
_reserved_tmpdir.cleanup()
del _reserved_tmpdir

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]



class TestPSDDict(object):
    """Test that the PSDDict works as expected
    """
    def setup_method(self):
        """Setup the testing class

        Builds a two-detector fixture and makes sure the scratch directory
        ``tmpdir`` exists before each test.
        """
        # Each detector maps to rows of [frequency, strain]
        self.psd_data = {
            "H1": [[0.00000e+00, 2.50000e-01],
                   [1.25000e-01, 2.50000e-01],
                   [2.50000e-01, 2.50000e-01]],
            "V1": [[0.00000e+00, 2.50000e-01],
                   [1.25000e-01, 2.50000e-01],
                   [2.50000e-01, 2.50000e-01]]
        }
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove all files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def _assert_matches_fixture(self, psd_dict):
        """Check that ``psd_dict`` was built from ``self.psd_data``

        Parameters
        ----------
        psd_dict: PSDDict
            object to check
        """
        assert sorted(psd_dict.detectors) == ["H1", "V1"]
        assert isinstance(psd_dict["H1"], PSD)
        np.testing.assert_almost_equal(
            psd_dict["H1"].frequencies, [0, 0.125, 0.25]
        )
        np.testing.assert_almost_equal(
            psd_dict["V1"].strains, [0.25, 0.25, 0.25]
        )

    @staticmethod
    def _assert_same_psds(result, expected):
        """Check that two PSDDict objects hold the same detectors and that
        each detector's frequencies and strains agree

        Parameters
        ----------
        result: PSDDict
            object under test
        expected: PSDDict
            reference object
        """
        # Compare the detector sets first: only looping over ``result`` would
        # let an empty result pass vacuously
        assert sorted(result.detectors) == sorted(expected.detectors)
        for ifo, psd in result.items():
            np.testing.assert_almost_equal(
                psd.frequencies, expected[ifo].frequencies
            )
            np.testing.assert_almost_equal(psd.strains, expected[ifo].strains)

    def test_initiate(self):
        """Test that the PSDDict class can be initialized correctly, both from
        separate detector/data iterables and from a single dictionary
        """
        psd_dict = PSDDict(self.psd_data.keys(), self.psd_data.values())
        self._assert_matches_fixture(psd_dict)

        psd_dict = PSDDict(self.psd_data)
        self._assert_matches_fixture(psd_dict)

    def test_plot(self):
        """Test the plotting function works correctly
        """
        # Import the submodule explicitly: ``import matplotlib`` on its own
        # does not guarantee that ``matplotlib.figure`` has been loaded
        from matplotlib.figure import Figure
        import matplotlib.pyplot as plt

        psd_dict = PSDDict(self.psd_data)
        fig = psd_dict.plot()
        assert isinstance(fig, Figure)
        # Release the figure so that it does not stay open if ``plot``
        # registered it with pyplot
        plt.close(fig)

    def test_read(self):
        """Test that the PSDDict class can be initialized correctly with the
        read classmethod, using both an explicit list of files and a common
        filename template
        """
        f = PSDDict(self.psd_data)
        for ifo, psd in f.items():
            psd.save_to_file("{}/{}_test.dat".format(tmpdir, ifo))

        from_files = PSDDict.read(
            files=[
                "{}/H1_test.dat".format(tmpdir), "{}/V1_test.dat".format(tmpdir)
            ], detectors=["H1", "V1"]
        )
        self._assert_same_psds(from_files, f)

        # ``{}`` is left in the template for ``read`` to fill per detector
        from_template = PSDDict.read(
            common_string="%s/{}_test.dat" % (tmpdir), detectors=["H1", "V1"]
        )
        self._assert_same_psds(from_template, f)

    def test_interpolate(self):
        """Test that interpolating onto half the frequency spacing keeps the
        low frequency and halves delta_f for every detector
        """
        f = PSDDict(self.psd_data)
        g = f.interpolate(
            f["H1"].low_frequency, f["H1"].delta_f / 2
        )
        for ifo, psd in f.items():
            np.testing.assert_almost_equal(g[ifo].delta_f, psd.delta_f / 2)
            np.testing.assert_almost_equal(
                g[ifo].low_frequency, psd.low_frequency
            )



class TestPSD(object):
    """Test the PSD class
    """
    def setup_method(self):
        """Setup the testing class

        Builds a three-row fixture (frequencies 10 to 10.5 in steps of 0.25,
        constant strain 20) and makes sure the scratch directory ``tmpdir``
        exists before each test.
        """
        self.obj = PSD([[10, 20], [10.25, 20], [10.5, 20]])
        if not os.path.isdir(tmpdir):
            os.mkdir(tmpdir)

    def teardown_method(self):
        """Remove all files and directories created from this class
        """
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)

    def test_save_to_file(self):
        """Test that save_to_file writes a two-column text file that can be
        read back with numpy
        """
        path = "{}/test.dat".format(tmpdir)
        self.obj.save_to_file(path)
        data = np.genfromtxt(path)
        # Column 0 holds the frequencies, column 1 the strains
        np.testing.assert_almost_equal(data.T[0], [10, 10.25, 10.5])
        np.testing.assert_almost_equal(data.T[1], [20, 20, 20])

    def test_invalid_input(self):
        """Test that the appropriate error is raised if the input is wrong
        """
        import pytest

        # A flat list rather than a list of [frequency, strain] rows must be
        # rejected; the constructed object itself is not needed
        with pytest.raises(IndexError):
            PSD([10, 10])

    def test_interpolate(self):
        """Test that interpolating onto half the frequency spacing keeps the
        low frequency and halves delta_f
        """
        g = self.obj.interpolate(
            self.obj.low_frequency,
            self.obj.delta_f / 2
        )
        np.testing.assert_almost_equal(g.delta_f, self.obj.delta_f / 2)
        np.testing.assert_almost_equal(g.low_frequency, self.obj.low_frequency)