Coverage for pesummary/gw/file/psd.py: 77.8%

126 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import numpy as np 

5from pesummary import conf 

6from pesummary.utils.utils import logger, check_file_exists_and_rename 

7from pesummary.utils.dict import Dict 

8 

9__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

10 

11 

class PSDDict(Dict):
    """Dictionary of PSD objects keyed by detector name

    Parameters
    ----------
    detectors: list
        list of detectors
    data: nd list
        list of psd samples for each detector. First column is frequencies,
        second column is strains

    Attributes
    ----------
    detectors: list
        list of detectors stored in the dictionary

    Methods
    -------
    read:
        Build a PSDDict from a set of PSD files
    plot:
        Generate a plot based on the psd samples stored
    to_pycbc:
        Convert dictionary of PSD objects to a dictionary of
        pycbc.frequencyseries objects
    interpolate:
        Interpolate every stored PSD to a new delta_f

    Examples
    --------
    >>> from pesummary.gw.file.psd import PSDDict
    >>> detectors = ["H1", "V1"]
    >>> psd_data = [
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]],
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]]
    ... ]
    >>> psd_dict = PSDDict(detectors, psd_data)
    >>> psd_data = {
    ...     "H1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]],
    ...     "V1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]]
    ... }
    >>> psd_dict = PSDDict(psd_data)
    """
    def __init__(self, *args):
        # every value is stored as a PSD whose two columns are exposed as
        # 'frequencies' and 'strains'
        columns = ["frequencies", "strains"]
        super().__init__(*args, value_class=PSD, value_columns=columns)

    @property
    def detectors(self):
        """List of the detector names (the dictionary keys)
        """
        return [*self.keys()]

    @classmethod
    def read(cls, files=None, detectors=None, common_string=None):
        """Build a PSDDict by reading one PSD file per detector

        Parameters
        ----------
        files: list/dict, optional
            Either a list of files or a dictionary of files to read,
            keyed by detector. If a list of files is provided, a list of
            corresponding detectors must also be provided
        detectors: list, optional
            List of detectors to use when loading files. Required when
            files is a list or when common_string is used
        common_string: str, optional
            Common string for PSD files. The string must be formattable and
            take one argument which is the detector. For example
            common_string='./{}_psd.dat'. Only used if files is not provided

        Raises
        ------
        ValueError
            if the number of detectors does not match the number of files,
            or if the combination of arguments does not identify a file for
            each detector
        """
        if files is None:
            # no explicit files: build the paths from the template
            if common_string is None or detectors is None:
                raise ValueError(
                    "Please provide either a list of files to read or "
                    "a common string and a list of detectors to load."
                )
            files = {ifo: common_string.format(ifo) for ifo in detectors}
        elif isinstance(files, list) and detectors is not None:
            if len(detectors) != len(files):
                raise ValueError(
                    "Please provide a detector for each file"
                )
            files = dict(zip(detectors, files))
        elif not isinstance(files, dict):
            raise ValueError(
                "Please provide either a dictionary of files, or a list "
                "files and a list of detectors for which they correspond."
            )
        return PSDDict(
            {ifo: PSD.read(path, IFO=ifo) for ifo, path in files.items()}
        )

    def plot(self, **kwargs):
        """Plot the PSD data stored for every detector

        Parameters
        ----------
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._psd_plot
        """
        from pesummary.gw.plots.plot import _psd_plot

        labels = self.detectors
        stored = [self[ifo] for ifo in labels]
        return _psd_plot(
            [entry.frequencies for entry in stored],
            [entry.strains for entry in stored],
            labels=labels, **kwargs
        )

    def to_pycbc(self, *args, **kwargs):
        """Call PSD.to_pycbc() on every stored PSD

        Parameters
        ----------
        *args: tuple
            all args passed to PSD.to_pycbc()
        **kwargs: dict, optional
            all kwargs passed to PSD.to_pycbc()
        """
        return PSDDict(
            {
                ifo: spectrum.to_pycbc(*args, **kwargs)
                for ifo, spectrum in self.items()
            }
        )

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate every stored PSD to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            passed to PSD.interpolate() for each detector
        delta_f: float
            Frequency resolution of the frequency series in Hertz.
        """
        return PSDDict(
            {
                ifo: spectrum.interpolate(low_freq_cutoff, delta_f)
                for ifo, spectrum in self.items()
            }
        )

158 

159 

class PSD(np.ndarray):
    """Class to handle PSD data

    Parameters
    ----------
    input_array: array_like
        2d data with exactly two columns. The first column is read as the
        frequencies and the second column is passed on as the PSD values

    Attributes
    ----------
    delta_f: float
        difference between the first two entries of the first column
    f_high: float
        last entry of the first column
    frequencies: np.ndarray
        the first column of the data
    low_frequency: float
        first entry of `frequencies`
    """
    def __new__(cls, input_array):
        obj = np.asarray(input_array).view(cls)
        # check the dimensionality before indexing the shape so that scalar
        # or 1d input is reported with the intended ValueError rather than
        # an IndexError from `obj.shape[1]`
        if obj.ndim != 2 or obj.shape[1] != 2:
            raise ValueError(
                "Invalid input data. See the docs for instructions"
            )
        # NOTE: these instance attributes shadow the static methods of the
        # same name; on an instance `delta_f`, `f_high` and `frequencies`
        # are values, on the class they are the helper functions below
        obj.delta_f = cls.delta_f(obj)
        obj.f_high = cls.f_high(obj)
        obj.frequencies = cls.frequencies(obj)
        return obj

    @property
    def low_frequency(self):
        """First stored frequency
        """
        return self.frequencies[0]

    @staticmethod
    def delta_f(array):
        """Return the spacing between the first two frequencies in `array`
        """
        return array.T[0][1] - array.T[0][0]

    @staticmethod
    def f_high(array):
        """Return the last frequency in `array`
        """
        return array.T[0][-1]

    @staticmethod
    def frequencies(array):
        """Return the first column (the frequencies) of `array`
        """
        return array.T[0]

    @classmethod
    def read(cls, path_to_file, **kwargs):
        """Read in a file and initialize the PSD class

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        **kwargs: dict
            all kwargs are passed to the read methods

        Raises
        ------
        FileNotFoundError
            if `path_to_file` is not an existing file
        NotImplementedError
            if the file extension is not one of the supported formats
        """
        from pesummary.core.file.formats.base_read import Read

        mapping = {
            "dat": PSD.read_from_dat,
            "txt": PSD.read_from_dat,
            "xml": PSD.read_from_xml,
        }
        if not os.path.isfile(path_to_file):
            raise FileNotFoundError(
                "The file '{}' does not exist".format(path_to_file)
            )
        extension = Read.extension_from_path(path_to_file)
        # gzipped xml files are routed to the xml reader regardless of the
        # extension reported above
        if ".xml.gz" in path_to_file:
            return cls(mapping["xml"](path_to_file, **kwargs))
        elif extension not in mapping:
            raise NotImplementedError(
                "Unable to read in a PSD with format '{}'. The allowed formats "
                "are: {}".format(extension, ", ".join(mapping))
            )
        return cls(mapping[extension](path_to_file, **kwargs))

    @staticmethod
    def read_from_dat(path_to_file, IFO=None, **kwargs):
        """Read in a dat file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            not used by this reader. Naming it here keeps it out of the
            kwargs that are forwarded to numpy.genfromtxt
        **kwargs: dict
            all kwargs are passed to the numpy.genfromtxt method
        """
        try:
            data = np.genfromtxt(path_to_file, **kwargs)
            return data
        except ValueError:
            # retry ignoring the last two lines of the file
            # NOTE(review): presumably some files end with lines that
            # cannot be parsed as data -- confirm
            data = np.genfromtxt(path_to_file, skip_footer=2, **kwargs)
            return data

    @staticmethod
    def read_from_xml(path_to_file, IFO=None, **kwargs):
        """Read in an xml file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            name of the dataset that you wish to load
        **kwargs: dict
            all kwargs are passed to the
            gwpy.frequencyseries.FrequencySeries.read method

        Returns
        -------
        data: np.ndarray
            2d array with the frequencies in the first column and the
            values of the series in the second column
        """
        from gwpy.frequencyseries import FrequencySeries

        data = FrequencySeries.read(path_to_file, name=IFO, **kwargs)
        frequencies = np.array(data.frequencies)
        strains = np.array(data)
        return np.vstack([frequencies, strains]).T

    def save_to_file(self, file_name, comments="#", delimiter=conf.delimiter):
        """Save the PSD data to file

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        comments: str, optional
            String that will be prepended to the header and footer strings, to
            mark them as comments. Default is '#'.
        delimiter: str, optional
            String or character separating columns.
        """
        check_file_exists_and_rename(file_name)
        header = ["Frequency", "Strain"]
        np.savetxt(
            file_name, self, delimiter=delimiter, comments=comments,
            header=delimiter.join(header)
        )

    def __array_finalize__(self, obj):
        # called by numpy whenever a new PSD is created (explicit
        # construction, view casting or slicing); carry the cached
        # attributes over from the source object when it has them
        if obj is None:
            return
        self.delta_f = getattr(obj, "delta_f", None)
        self.f_high = getattr(obj, "f_high", None)
        self.frequencies = getattr(obj, "frequencies", None)

    def to_pycbc(
        self, low_freq_cutoff, f_high=None, length=None, delta_f=None,
        f_high_override=False
    ):
        """Convert the PSD object to an interpolated pycbc.types.FrequencySeries

        Parameters
        ----------
        low_freq_cutoff : float
            Frequencies below this value are set to zero.
        f_high : float, optional
            Final frequency used to compute `length` when `length` is not
            provided. Default is the maximum frequency stored.
        length : int, optional
            Length of the frequency series in samples. Default is
            int(f_high / delta_f) + 1
        delta_f : float, optional
            Frequency resolution of the frequency series in Hertz. Default
            is the delta_f of the stored data.
        f_high_override: Bool, optional
            Override the final frequency if it is above the maximum stored.
            Default False
        """
        from pycbc.psd.read import from_numpy_arrays

        if delta_f is None:
            delta_f = self.delta_f
        if f_high is None:
            f_high = self.f_high
        elif f_high > self.f_high:
            msg = (
                "Specified value of final frequency: {} is above the maximum "
                "frequency stored: {}. ".format(f_high, self.f_high)
            )
            if f_high_override:
                msg += "Overwriting the final frequency"
                f_high = self.f_high
            else:
                # the requested value is kept; the user is only warned
                msg += (
                    "This will result in an interpolation error. Either change "
                    "the final frequency specified or set the 'f_high_override' "
                    "kwarg to True"
                )
            logger.warning(msg)
        if length is None:
            length = int(f_high / delta_f) + 1
        pycbc_psd = from_numpy_arrays(
            self.T[0], self.T[1], length, delta_f, low_freq_cutoff
        )
        return pycbc_psd

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate PSD to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Only frequencies at or above this value are kept in the
            returned PSD.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.
        """
        from pesummary.gw.pycbc import interpolate_psd
        # interpolate a copy so the stored data is never handed to the helper
        psd = interpolate_psd(self.copy(), low_freq_cutoff, delta_f)
        frequencies, strains = psd.sample_frequencies, psd
        inds = np.where(frequencies >= low_freq_cutoff)
        return PSD(np.vstack([frequencies[inds], strains[inds]]).T)