Coverage for pesummary/gw/file/psd.py: 77.8%
126 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import numpy as np
5from pesummary import conf
6from pesummary.utils.utils import logger, check_file_exists_and_rename
7from pesummary.utils.dict import Dict
9__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class PSDDict(Dict):
    """Class to handle a dictionary of PSDs

    Parameters
    ----------
    detectors: list
        list of detectors
    data: nd list
        list of psd samples for each detector. First column is frequencies,
        second column is strains

    Attributes
    ----------
    detectors: list
        list of detectors stored in the dictionary

    Methods
    -------
    read:
        Build a PSDDict from a set of PSD files
    plot:
        Generate a plot based on the psd samples stored
    to_pycbc:
        Convert dictionary of PSD objects to a dictionary of
        pycbc.frequencyseries objects
    interpolate:
        Interpolate every stored PSD to a new delta_f

    Examples
    --------
    >>> from pesummary.gw.file.psd import PSDDict
    >>> detectors = ["H1", "V1"]
    >>> psd_data = [
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]],
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]]
    ... ]
    >>> psd_dict = PSDDict(detectors, psd_data)
    >>> psd_data = {
    ...     "H1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]],
    ...     "V1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]]
    ... }
    >>> psd_dict = PSDDict(psd_data)
    """
    def __init__(self, *args):
        # All positional arguments are forwarded untouched to the parent
        # Dict; each stored value is declared to be a PSD with the two
        # named columns.
        super(PSDDict, self).__init__(
            *args, value_class=PSD, value_columns=["frequencies", "strains"],
            deconstruct_complex_columns=False
        )

    @property
    def detectors(self):
        """list: the keys of the dictionary, i.e. the stored detectors"""
        return list(self.keys())

    @classmethod
    def read(cls, files=None, detectors=None, common_string=None):
        """Initiate PSDDict with a set of PSD files

        Parameters
        ----------
        files: list/tuple/dict, optional
            Either a sequence of files or a dictionary of files (keyed by
            detector) to read. If a sequence of files is provided, a list of
            corresponding detectors must also be provided
        common_string: str, optional
            Common string for PSD files. The string must be formattable and
            take one argument which is the detector. For example
            common_string='./{}_psd.dat'. Used if files is not provided
        detectors: list, optional
            List of detectors to use when loading files. Required if files
            is a sequence, or if files is not provided and common_string is
            used instead

        Returns
        -------
        psd_dict: PSDDict
            instance of the class this method was called on, with one PSD
            per detector

        Raises
        ------
        ValueError
            if the number of detectors does not match the number of files,
            or if the combination of arguments is not sufficient to work
            out which file belongs to which detector
        """
        if files is not None:
            if isinstance(files, (list, tuple)) and detectors is not None:
                if len(detectors) != len(files):
                    raise ValueError(
                        "Please provide a detector for each file"
                    )
                files = dict(zip(detectors, files))
            elif not isinstance(files, dict):
                raise ValueError(
                    "Please provide either a dictionary of files, or a list "
                    "files and a list of detectors for which they correspond."
                )
        elif common_string is not None and detectors is not None:
            files = {det: common_string.format(det) for det in detectors}
        else:
            raise ValueError(
                "Please provide either a list of files to read or "
                "a common string and a list of detectors to load."
            )
        # The detector name is forwarded as IFO so that readers which need a
        # dataset name (see PSD.read_from_xml) can pick the right one.
        psd = {key: PSD.read(item, IFO=key) for key, item in files.items()}
        # Use cls rather than a hard-coded PSDDict so that subclasses get an
        # instance of their own type back.
        return cls(psd)

    def plot(self, **kwargs):
        """Generate a plot to display the PSD data stored in PSDDict

        Parameters
        ----------
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._psd_plot
        """
        from pesummary.gw.plots.plot import _psd_plot

        _detectors = self.detectors
        frequencies = [self[IFO].frequencies for IFO in _detectors]
        strains = [self[IFO].strains for IFO in _detectors]
        return _psd_plot(frequencies, strains, labels=_detectors, **kwargs)

    def to_pycbc(self, *args, **kwargs):
        """Transform dictionary to pycbc.frequencyseries objects

        Parameters
        ----------
        *args: tuple
            all args passed to PSD.to_pycbc()
        **kwargs: dict, optional
            all kwargs passed to PSD.to_pycbc()
        """
        psd = {
            key: item.to_pycbc(*args, **kwargs) for key, item in self.items()
        }
        return PSDDict(psd)

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate a dictionary of PSDs to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Low frequency cutoff passed to PSD.interpolate() for every
            detector.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.
        """
        psd = {
            key: item.interpolate(low_freq_cutoff, delta_f)
            for key, item in self.items()
        }
        return PSDDict(psd)
class PSD(np.ndarray):
    """Class to handle PSD data

    Parameters
    ----------
    input_array: array-like
        two column data. The first column is read as the frequencies and
        the second column is passed on as the strains (see to_pycbc and
        save_to_file)

    Attributes
    ----------
    delta_f: float
        difference between the first two stored frequencies
    f_high: float
        last stored frequency
    frequencies: np.ndarray
        first column of the data
    low_frequency: float
        first stored frequency
    """
    def __new__(cls, input_array):
        obj = np.asarray(input_array).view(cls)
        if obj.shape[1] != 2:
            raise ValueError(
                "Invalid input data. See the docs for instructions"
            )
        # The static helpers are looked up on the class; the results are
        # then stored on the instance under the same names, so instance
        # access (e.g. psd.delta_f) returns the value, not the helper.
        obj.delta_f = cls.delta_f(obj)
        obj.f_high = cls.f_high(obj)
        obj.frequencies = cls.frequencies(obj)
        return obj

    @property
    def low_frequency(self):
        """First entry of the stored frequencies"""
        return self.frequencies[0]

    @staticmethod
    def delta_f(array):
        """Return the spacing between the first two frequencies in array"""
        return array.T[0][1] - array.T[0][0]

    @staticmethod
    def f_high(array):
        """Return the last frequency in array"""
        return array.T[0][-1]

    @staticmethod
    def frequencies(array):
        """Return the first column (the frequencies) of array"""
        return array.T[0]

    @classmethod
    def read(cls, path_to_file, **kwargs):
        """Read in a file and initialize the PSD class

        Parameters
        ----------
        path_to_file: str or os.PathLike
            the path to the file you wish to load
        **kwargs: dict
            all kwargs are passed to the read methods

        Raises
        ------
        FileNotFoundError
            if path_to_file does not point to an existing file
        NotImplementedError
            if the file extension is not one of the supported formats
        """
        from pesummary.core.file.formats.base_read import Read

        # Normalise path-like objects (e.g. pathlib.Path) to a plain string:
        # the substring test below is only defined for strings.
        path_to_file = os.fspath(path_to_file)
        # Go through cls so that subclasses overriding a reader are honoured
        mapping = {
            "dat": cls.read_from_dat,
            "txt": cls.read_from_dat,
            "xml": cls.read_from_xml,
        }
        if not os.path.isfile(path_to_file):
            raise FileNotFoundError(
                "The file '{}' does not exist".format(path_to_file)
            )
        extension = Read.extension_from_path(path_to_file)
        # Gzipped xml files are sent to the xml reader regardless of the
        # extension reported above.
        if ".xml.gz" in path_to_file:
            return cls(mapping["xml"](path_to_file, **kwargs))
        elif extension not in mapping.keys():
            raise NotImplementedError(
                "Unable to read in a PSD with format '{}'. The allowed formats "
                "are: {}".format(extension, ", ".join(list(mapping.keys())))
            )
        return cls(mapping[extension](path_to_file, **kwargs))

    @staticmethod
    def read_from_dat(path_to_file, IFO=None, **kwargs):
        """Read in a dat file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            accepted so that all readers share the same call signature; it
            is not used when reading a dat file
        **kwargs: dict
            all kwargs are passed to the numpy.genfromtxt method
        """
        try:
            data = np.genfromtxt(path_to_file, **kwargs)
            return data
        except ValueError:
            # Retry while ignoring the last two lines of the file.
            # NOTE(review): presumably some PSD files end with trailing lines
            # that do not parse as two columns -- confirm.
            data = np.genfromtxt(path_to_file, skip_footer=2, **kwargs)
            return data

    @staticmethod
    def read_from_xml(path_to_file, IFO=None, **kwargs):
        """Read in an xml file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            name of the dataset that you wish to load
        **kwargs: dict
            all kwargs are passed to the
            gwpy.frequencyseries.FrequencySeries.read method
        """
        from gwpy.frequencyseries import FrequencySeries

        data = FrequencySeries.read(path_to_file, name=IFO, **kwargs)
        frequencies = np.array(data.frequencies)
        strains = np.array(data)
        # Stack into the two column (frequency, strain) layout used by PSD
        return np.vstack([frequencies, strains]).T

    def save_to_file(self, file_name, comments="#", delimiter=conf.delimiter):
        """Save the PSD data to file

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        comments: str, optional
            String that will be prepended to the header and footer strings, to
            mark them as comments. Default is '#'.
        delimiter: str, optional
            String or character separating columns.
        """
        check_file_exists_and_rename(file_name)
        header = ["Frequency", "Strain"]
        np.savetxt(
            file_name, self, delimiter=delimiter, comments=comments,
            header=delimiter.join(header)
        )

    def __array_finalize__(self, obj):
        # numpy hook run whenever a new PSD array is created from another
        # array (e.g. by .view() or slicing); it copies the metadata across,
        # defaulting to None when the source has none.
        if obj is None:
            return
        self.delta_f = getattr(obj, "delta_f", None)
        self.f_high = getattr(obj, "f_high", None)
        self.frequencies = getattr(obj, "frequencies", None)

    def to_pycbc(
        self, low_freq_cutoff, f_high=None, length=None, delta_f=None,
        f_high_override=False
    ):
        """Convert the PSD object to an interpolated pycbc.types.FrequencySeries

        Parameters
        ----------
        low_freq_cutoff : float
            Frequencies below this value are set to zero.
        f_high : float, optional
            Final frequency used to compute the length of the series when
            length is not given. Default is the maximum frequency stored.
        length : int, optional
            Length of the frequency series in samples. Default is
            int(f_high / delta_f) + 1.
        delta_f : float, optional
            Frequency resolution of the frequency series in Hertz. Default
            is the frequency spacing stored.
        f_high_override: Bool, optional
            Override the final frequency if it is above the maximum stored.
            Default False
        """
        from pycbc.psd.read import from_numpy_arrays

        if delta_f is None:
            delta_f = self.delta_f
        if f_high is None:
            f_high = self.f_high
        elif f_high > self.f_high:
            msg = (
                "Specified value of final frequency: {} is above the maximum "
                "frequency stored: {}. ".format(f_high, self.f_high)
            )
            if f_high_override:
                msg += "Overwriting the final frequency"
                f_high = self.f_high
            else:
                # Deliberately only a warning: the requested f_high is kept
                # and the call proceeds.
                msg += (
                    "This will result in an interpolation error. Either change "
                    "the final frequency specified or set the 'f_high_override' "
                    "kwarg to True"
                )
            logger.warning(msg)
        if length is None:
            length = int(f_high / delta_f) + 1
        pycbc_psd = from_numpy_arrays(
            self.T[0], self.T[1], length, delta_f, low_freq_cutoff
        )
        return pycbc_psd

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate PSD to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Low frequency cutoff passed to the interpolation. Only
            frequencies greater than or equal to this value are kept in the
            returned PSD.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.

        Returns
        -------
        psd: PSD
            new PSD object built from the interpolated frequencies and
            strains
        """
        from pesummary.gw.pycbc import interpolate_psd

        # A copy is handed over so the stored data cannot be modified
        psd = interpolate_psd(self.copy(), low_freq_cutoff, delta_f)
        frequencies, strains = psd.sample_frequencies, psd
        inds = np.where(frequencies >= low_freq_cutoff)
        return PSD(np.vstack([frequencies[inds], strains[inds]]).T)