Coverage for pesummary/gw/file/psd.py: 77.8%

126 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import numpy as np 

5from pesummary import conf 

6from pesummary.utils.utils import logger, check_file_exists_and_rename 

7from pesummary.utils.dict import Dict 

8 

# Authorship metadata for this module
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]

10 

11 

class PSDDict(Dict):
    """Class to handle a dictionary of PSDs

    Parameters
    ----------
    detectors: list
        list of detectors
    data: nd list
        list of psd samples for each detector. First column is frequencies,
        second column is strains. Alternatively a single dictionary of psd
        samples keyed by detector may be passed (see Examples)

    Attributes
    ----------
    detectors: list
        list of detectors stored in the dictionary

    Methods
    -------
    plot:
        Generate a plot based on the psd samples stored
    to_pycbc:
        Convert dictionary of PSD objects to a dictionary of
        pycbc.frequencyseries objects
    interpolate:
        Interpolate each PSD in the dictionary to a new delta_f

    Examples
    --------
    >>> from pesummary.gw.file.psd import PSDDict
    >>> detectors = ["H1", "V1"]
    >>> psd_data = [
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]],
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]]
    ... ]
    >>> psd_dict = PSDDict(detectors, psd_data)
    >>> psd_data = {
    ...     "H1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]],
    ...     "V1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]]
    ... }
    >>> psd_dict = PSDDict(psd_data)
    """
    def __init__(self, *args):
        # value_class / value_columns are forwarded to pesummary's Dict,
        # which presumably wraps each value in a PSD and labels its two
        # columns "frequencies" and "strains" -- see pesummary.utils.dict
        super(PSDDict, self).__init__(
            *args, value_class=PSD, value_columns=["frequencies", "strains"],
            deconstruct_complex_columns=False
        )

    @property
    def detectors(self):
        """list: the keys (detector names) stored in the dictionary"""
        return list(self.keys())

    @classmethod
    def read(cls, files=None, detectors=None, common_string=None):
        """Initiate PSDDict with a set of PSD files

        Parameters
        ----------
        files: list/tuple/dict, optional
            Either a list of files or a dictionary of files keyed by
            detector. If a list (or tuple) of files is provided, a list of
            corresponding detectors must also be provided
        detectors: list, optional
            List of detectors to use when loading files. Used if files is
            a list/tuple, or if files is not provided and common_string is
            provided
        common_string: str, optional
            Common string for PSD files. The string must be formattable and
            take one argument which is the detector. For example
            common_string='./{}_psd.dat'. Used if files is not provided

        Returns
        -------
        PSDDict
            dictionary of PSD objects keyed by detector (an instance of
            ``cls`` when called on a subclass)

        Raises
        ------
        ValueError
            If the number of detectors does not match the number of files,
            if files is neither a dict nor a list/tuple accompanied by
            detectors, or if neither files nor both common_string and
            detectors are provided
        """
        if files is not None:
            if isinstance(files, (list, tuple)) and detectors is not None:
                if len(detectors) != len(files):
                    raise ValueError(
                        "Please provide a detector for each file"
                    )
                files = dict(zip(detectors, files))
            elif not isinstance(files, dict):
                raise ValueError(
                    "Please provide either a dictionary of files, or a list "
                    "files and a list of detectors for which they correspond."
                )
        elif common_string is not None and detectors is not None:
            files = {det: common_string.format(det) for det in detectors}
        else:
            raise ValueError(
                "Please provide either a list of files to read or "
                "a common string and a list of detectors to load."
            )
        # The detector name is forwarded as IFO; PSD.read_from_xml uses it
        # as the name of the dataset to load from an xml file
        psd = {key: PSD.read(item, IFO=key) for key, item in files.items()}
        return cls(psd)

    def plot(self, **kwargs):
        """Generate a plot to display the PSD data stored in PSDDict

        Parameters
        ----------
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._psd_plot

        Returns
        -------
        whatever pesummary.gw.plots.plot._psd_plot returns
        """
        from pesummary.gw.plots.plot import _psd_plot

        _detectors = self.detectors
        frequencies = [self[IFO].frequencies for IFO in _detectors]
        # NOTE(review): ``strains`` is not defined on PSD in this file;
        # presumably it is attached via value_columns by pesummary's Dict
        # -- confirm
        strains = [self[IFO].strains for IFO in _detectors]
        return _psd_plot(frequencies, strains, labels=_detectors, **kwargs)

    def to_pycbc(self, *args, **kwargs):
        """Transform dictionary to pycbc.frequencyseries objects

        Parameters
        ----------
        *args: tuple
            all args passed to PSD.to_pycbc()
        **kwargs: dict, optional
            all kwargs passed to PSD.to_pycbc()

        Returns
        -------
        PSDDict
            built from the value returned by PSD.to_pycbc() for each
            detector
        """
        psd = {}
        for key, item in self.items():
            psd[key] = item.to_pycbc(*args, **kwargs)
        return PSDDict(psd)

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate a dictionary of PSDs to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Frequencies below this value are set to zero.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.

        Returns
        -------
        PSDDict
            built from the value returned by PSD.interpolate() for each
            detector
        """
        psd = {}
        for key, item in self.items():
            psd[key] = item.interpolate(low_freq_cutoff, delta_f)
        return PSDDict(psd)

159 

160 

class PSD(np.ndarray):
    """Class to handle PSD data

    The data is stored as a 2D array with two columns: the first column is
    the frequencies and the second column is the strains. ``delta_f``,
    ``f_high`` and ``frequencies`` are cached as instance attributes when
    the object is created.

    Parameters
    ----------
    input_array: array_like
        array of shape (N, 2) with N >= 2. First column is frequencies,
        second column is strains

    Raises
    ------
    ValueError
        If ``input_array`` is not a 2D array with exactly two columns and
        at least two rows

    Examples
    --------
    >>> from pesummary.gw.file.psd import PSD
    >>> psd = PSD([[0.0, 0.25], [0.125, 0.25], [0.25, 0.25]])
    >>> float(psd.delta_f)
    0.125
    """
    def __new__(cls, input_array):
        obj = np.asarray(input_array).view(cls)
        # Validate the shape before indexing: 1D/0D input would otherwise
        # raise IndexError on ``shape[1]``, and at least two rows are
        # needed to infer the frequency spacing in ``delta_f``
        if obj.ndim != 2 or obj.shape[1] != 2 or obj.shape[0] < 2:
            raise ValueError(
                "Invalid input data. See the docs for instructions"
            )
        # NOTE: these instance attributes share their names with the static
        # methods below and shadow them on the instance; the static
        # versions remain reachable through the class (e.g. PSD.delta_f)
        obj.delta_f = cls.delta_f(obj)
        obj.f_high = cls.f_high(obj)
        obj.frequencies = cls.frequencies(obj)
        return obj

    @property
    def low_frequency(self):
        """The first stored frequency"""
        return self.frequencies[0]

    @staticmethod
    def delta_f(array):
        """Return the difference between the first two frequencies

        Only the first two samples are used, so this assumes the
        frequencies are uniformly spaced.
        """
        return array.T[0][1] - array.T[0][0]

    @staticmethod
    def f_high(array):
        """Return the last frequency in the first column of ``array``"""
        return array.T[0][-1]

    @staticmethod
    def frequencies(array):
        """Return the first column (frequencies) of ``array``"""
        return array.T[0]

    @classmethod
    def read(cls, path_to_file, **kwargs):
        """Read in a file and initialize the PSD class

        Files ending in '.xml.gz' are read with the xml reader; otherwise the
        reader is chosen from the file extension ('dat', 'txt' or 'xml').

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        **kwargs: dict
            all kwargs are passed to the read methods

        Raises
        ------
        FileNotFoundError
            If ``path_to_file`` does not exist
        NotImplementedError
            If the file extension is not supported
        """
        from pesummary.core.file.formats.base_read import Read

        mapping = {
            "dat": cls.read_from_dat,
            "txt": cls.read_from_dat,
            "xml": cls.read_from_xml,
        }
        if not os.path.isfile(path_to_file):
            raise FileNotFoundError(
                "The file '{}' does not exist".format(path_to_file)
            )
        extension = Read.extension_from_path(path_to_file)
        # Check the suffix rather than a substring so that '.xml.gz'
        # appearing elsewhere in the path (e.g. in a directory name) does
        # not select the xml reader
        if str(path_to_file).endswith(".xml.gz"):
            return cls(mapping["xml"](path_to_file, **kwargs))
        elif extension not in mapping.keys():
            raise NotImplementedError(
                "Unable to read in a PSD with format '{}'. The allowed formats "
                "are: {}".format(extension, ", ".join(list(mapping.keys())))
            )
        return cls(mapping[extension](path_to_file, **kwargs))

    @staticmethod
    def read_from_dat(path_to_file, IFO=None, **kwargs):
        """Read in a dat file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            ignored. Accepted so that PSD.read can forward IFO to any reader
        **kwargs: dict
            all kwargs are passed to the numpy.genfromtxt method
        """
        try:
            return np.genfromtxt(path_to_file, **kwargs)
        except ValueError:
            # Retry ignoring the last two lines of the file; presumably some
            # PSD files carry a short trailing footer -- confirm
            return np.genfromtxt(path_to_file, skip_footer=2, **kwargs)

    @staticmethod
    def read_from_xml(path_to_file, IFO=None, **kwargs):
        """Read in an xml file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            name of the dataset that you wish to load
        **kwargs: dict
            all kwargs are passed to the
            gwpy.frequencyseries.FrequencySeries.read method

        Returns
        -------
        np.ndarray
            array of shape (N, 2): frequencies in the first column, strains
            in the second
        """
        from gwpy.frequencyseries import FrequencySeries

        data = FrequencySeries.read(path_to_file, name=IFO, **kwargs)
        frequencies = np.array(data.frequencies)
        strains = np.array(data)
        return np.vstack([frequencies, strains]).T

    def save_to_file(self, file_name, comments="#", delimiter=None):
        """Save the PSD data to file

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        comments: str, optional
            String that will be prepended to the header and footer strings, to
            mark them as comments. Default is '#'.
        delimiter: str, optional
            String or character separating columns. Default is
            ``pesummary.conf.delimiter``, looked up when this method is
            called
        """
        # Resolve the default lazily so that changes to conf.delimiter made
        # after import are respected
        if delimiter is None:
            delimiter = conf.delimiter
        check_file_exists_and_rename(file_name)
        header = ["Frequency", "Strain"]
        np.savetxt(
            file_name, self, delimiter=delimiter, comments=comments,
            header=delimiter.join(header)
        )

    def __array_finalize__(self, obj):
        # numpy calls this whenever a new PSD is created from another array
        # (view casting, slicing, ufunc output). The cached attributes are
        # copied from the parent. NOTE: for slices they therefore still
        # describe the parent array, not the slice
        if obj is None:
            return
        self.delta_f = getattr(obj, "delta_f", None)
        self.f_high = getattr(obj, "f_high", None)
        self.frequencies = getattr(obj, "frequencies", None)

    def to_pycbc(
        self, low_freq_cutoff, f_high=None, length=None, delta_f=None,
        f_high_override=False
    ):
        """Convert the PSD object to an interpolated pycbc.types.FrequencySeries

        Parameters
        ----------
        low_freq_cutoff : float
            Frequencies below this value are set to zero.
        f_high : float, optional
            Final frequency of the frequency series. Default is the maximum
            frequency stored
        length : int, optional
            Length of the frequency series in samples. Default is
            ``int(f_high / delta_f) + 1``
        delta_f : float, optional
            Frequency resolution of the frequency series in Hertz. Default
            is the frequency resolution of the stored data
        f_high_override: Bool, optional
            If ``f_high`` is above the maximum frequency stored, replace it
            with the maximum frequency stored. Default False, in which case
            only a warning is logged

        Returns
        -------
        the object returned by pycbc.psd.read.from_numpy_arrays
        """
        from pycbc.psd.read import from_numpy_arrays

        if delta_f is None:
            delta_f = self.delta_f
        if f_high is None:
            f_high = self.f_high
        elif f_high > self.f_high:
            msg = (
                "Specified value of final frequency: {} is above the maximum "
                "frequency stored: {}. ".format(f_high, self.f_high)
            )
            if f_high_override:
                msg += "Overwriting the final frequency"
                f_high = self.f_high
            else:
                msg += (
                    "This will result in an interpolation error. Either change "
                    "the final frequency specified or set the 'f_high_override' "
                    "kwarg to True"
                )
            logger.warning(msg)
        if length is None:
            length = int(f_high / delta_f) + 1
        pycbc_psd = from_numpy_arrays(
            self.T[0], self.T[1], length, delta_f, low_freq_cutoff
        )
        return pycbc_psd

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate PSD to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Frequencies below this value are set to zero.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.

        Returns
        -------
        PSD
            the interpolated PSD restricted to frequencies greater than or
            equal to ``low_freq_cutoff``
        """
        from pesummary.gw.pycbc import interpolate_psd
        psd = interpolate_psd(self.copy(), low_freq_cutoff, delta_f)
        frequencies, strains = psd.sample_frequencies, psd
        # Drop the samples below the cutoff
        inds = np.where(frequencies >= low_freq_cutoff)
        return PSD(np.vstack([frequencies[inds], strains[inds]]).T)