Coverage for pesummary/gw/file/psd.py: 56.3%
126 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import numpy as np
5from pesummary import conf
6from pesummary.utils.utils import logger, check_file_exists_and_rename
7from pesummary.utils.dict import Dict
# Module author(s), written as "Name <email>"
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class PSDDict(Dict):
    """Class to handle a dictionary of PSDs

    Parameters
    ----------
    detectors: list
        list of detectors
    data: nd list
        list of psd samples for each detector. First column is frequencies,
        second column is strains

    Attributes
    ----------
    detectors: list
        list of detectors stored in the dictionary

    Methods
    -------
    plot:
        Generate a plot based on the psd samples stored
    to_pycbc:
        Convert dictionary of PSD objects to a dictionary of
        pycbc.frequencyseries objects
    interpolate:
        Interpolate each PSD stored in the dictionary to a new delta_f

    Examples
    --------
    >>> from pesummary.gw.file.psd import PSDDict
    >>> detectors = ["H1", "V1"]
    >>> psd_data = [
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]],
    ...     [[0.00000e+00, 2.50000e-01],
    ...      [1.25000e-01, 2.50000e-01],
    ...      [2.50000e-01, 2.50000e-01]]
    ... ]
    >>> psd_dict = PSDDict(detectors, psd_data)
    >>> psd_data = {
    ...     "H1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]],
    ...     "V1": [[0.00000e+00, 2.50000e-01],
    ...            [1.25000e-01, 2.50000e-01],
    ...            [2.50000e-01, 2.50000e-01]]
    ... }
    >>> psd_dict = PSDDict(psd_data)
    """
    def __init__(self, *args):
        # value_class / value_columns are forwarded to the Dict base class;
        # presumably it uses them to build a PSD for each value with columns
        # "frequencies" and "strains" -- confirm in pesummary.utils.dict
        super(PSDDict, self).__init__(
            *args, value_class=PSD, value_columns=["frequencies", "strains"]
        )

    @property
    def detectors(self):
        """Return the detectors (dictionary keys) as a list"""
        return list(self.keys())

    @classmethod
    def read(cls, files=None, detectors=None, common_string=None):
        """Initiate PSDDict with a set of PSD files

        Parameters
        ----------
        files: list/tuple/dict, optional
            Either a list of files or a dictionary of files (keyed by
            detector) to read. If a list of files is provided, a list of
            corresponding detectors must also be provided
        detectors: list, optional
            List of detectors to use when loading files. Used if files
            is a list, or if files is not provided and common_string is
            provided
        common_string: str, optional
            Common string for PSD files. The string must be formattable and
            take one argument which is the detector. For example
            common_string='./{}_psd.dat'. Used if files is not provided

        Returns
        -------
        PSDDict
            dictionary of PSD objects keyed by detector

        Raises
        ------
        ValueError
            if the number of detectors does not match the number of files,
            or if the arguments do not map detectors to files
        """
        if files is not None:
            if isinstance(files, (list, tuple)) and detectors is not None:
                if len(detectors) != len(files):
                    raise ValueError(
                        "Please provide a detector for each file"
                    )
                files = dict(zip(detectors, files))
            elif not isinstance(files, dict):
                raise ValueError(
                    "Please provide either a dictionary of files, or a list "
                    "files and a list of detectors for which they correspond."
                )
        elif common_string is not None and detectors is not None:
            files = {det: common_string.format(det) for det in detectors}
        else:
            raise ValueError(
                "Please provide either a list of files to read or "
                "a common string and a list of detectors to load."
            )
        # The detector name is forwarded as IFO; PSD.read passes it on to the
        # format-specific reader (the XML reader uses it as the dataset name)
        psd = {key: PSD.read(item, IFO=key) for key, item in files.items()}
        # Use cls so that subclasses of PSDDict get an instance of themselves
        return cls(psd)

    def plot(self, **kwargs):
        """Generate a plot to display the PSD data stored in PSDDict

        Parameters
        ----------
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._psd_plot

        Returns
        -------
        the object returned by pesummary.gw.plots.plot._psd_plot
        """
        from pesummary.gw.plots.plot import _psd_plot

        _detectors = self.detectors
        frequencies = [self[IFO].frequencies for IFO in _detectors]
        strains = [self[IFO].strains for IFO in _detectors]
        return _psd_plot(frequencies, strains, labels=_detectors, **kwargs)

    def to_pycbc(self, *args, **kwargs):
        """Transform dictionary to pycbc.frequencyseries objects

        Parameters
        ----------
        *args: tuple
            all args passed to PSD.to_pycbc()
        **kwargs: dict, optional
            all kwargs passed to PSD.to_pycbc()
        """
        psd = {
            key: item.to_pycbc(*args, **kwargs) for key, item in self.items()
        }
        # NOTE(review): the pycbc objects are wrapped in a PSDDict whose
        # value_class is PSD; confirm the Dict base class does not coerce the
        # (1D) pycbc FrequencySeries back into a two-column PSD
        return PSDDict(psd)

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate a dictionary of PSDs to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Frequencies below this value are set to zero.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.

        Returns
        -------
        PSDDict
            dictionary of interpolated PSD objects keyed by detector
        """
        psd = {
            key: item.interpolate(low_freq_cutoff, delta_f)
            for key, item in self.items()
        }
        return PSDDict(psd)
class PSD(np.ndarray):
    """Class to handle PSD data

    The data is stored as a 2D array with exactly two columns: the first
    column contains the frequencies and the second column the strains.

    Attributes
    ----------
    delta_f: float
        difference between the first two frequencies stored
    f_high: float
        final frequency stored
    frequencies: np.ndarray
        first column of the data
    low_frequency: float
        first frequency stored
    """
    def __new__(cls, input_array):
        obj = np.asarray(input_array).view(cls)
        # Check the dimensionality before indexing shape[1]: a 1D (or 0D)
        # input would otherwise raise an IndexError instead of ValueError
        if obj.ndim != 2 or obj.shape[1] != 2:
            raise ValueError(
                "Invalid input data. See the docs for instructions"
            )
        # These instance attributes shadow the static helpers of the same
        # name, so the helpers are called through the class
        obj.delta_f = cls.delta_f(obj)
        obj.f_high = cls.f_high(obj)
        obj.frequencies = cls.frequencies(obj)
        return obj

    @property
    def low_frequency(self):
        """Return the first frequency stored"""
        return self.frequencies[0]

    @staticmethod
    def delta_f(array):
        """Return the difference between the first two frequencies of array

        Only the first two frequencies are inspected; non-uniform spacing
        is not detected.
        """
        return array.T[0][1] - array.T[0][0]

    @staticmethod
    def f_high(array):
        """Return the final frequency (last entry of the first column)"""
        return array.T[0][-1]

    @staticmethod
    def frequencies(array):
        """Return the first column (frequencies) of array"""
        return array.T[0]

    @classmethod
    def read(cls, path_to_file, **kwargs):
        """Read in a file and initialize the PSD class

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        **kwargs: dict
            all kwargs are passed to the read methods

        Returns
        -------
        PSD
            the data stored in the file

        Raises
        ------
        FileNotFoundError
            if path_to_file does not exist
        NotImplementedError
            if the file extension is not supported
        """
        from pesummary.core.file.formats.base_read import Read

        mapping = {
            "dat": cls.read_from_dat,
            "txt": cls.read_from_dat,
            "xml": cls.read_from_xml,
        }
        if not os.path.isfile(path_to_file):
            raise FileNotFoundError(
                "The file '{}' does not exist".format(path_to_file)
            )
        # Compressed XML is detected by suffix, not by substring, so that a
        # directory with ".xml.gz" in its name does not change the reader
        if str(path_to_file).endswith(".xml.gz"):
            return cls(mapping["xml"](path_to_file, **kwargs))
        extension = Read.extension_from_path(path_to_file)
        if extension not in mapping:
            raise NotImplementedError(
                "Unable to read in a PSD with format '{}'. The allowed formats "
                "are: {}".format(extension, ", ".join(list(mapping.keys())))
            )
        return cls(mapping[extension](path_to_file, **kwargs))

    @staticmethod
    def read_from_dat(path_to_file, IFO=None, **kwargs):
        """Read in a dat file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            not used; accepted so that all read methods share an interface
        **kwargs: dict
            all kwargs are passed to the numpy.genfromtxt method
        """
        try:
            return np.genfromtxt(path_to_file, **kwargs)
        except ValueError:
            # Retry ignoring the last two lines; presumably some PSD files
            # end with a non-tabular footer -- TODO confirm which writers
            return np.genfromtxt(path_to_file, skip_footer=2, **kwargs)

    @staticmethod
    def read_from_xml(path_to_file, IFO=None, **kwargs):
        """Read in an xml file and return a numpy array containing the data

        Parameters
        ----------
        path_to_file: str
            the path to the file you wish to load
        IFO: str, optional
            name of the dataset that you wish to load
        **kwargs: dict
            all kwargs are passed to the
            gwpy.frequencyseries.FrequencySeries.read method
        """
        from gwpy.frequencyseries import FrequencySeries

        data = FrequencySeries.read(path_to_file, name=IFO, **kwargs)
        frequencies = np.array(data.frequencies)
        strains = np.array(data)
        # Stack into the (N, 2) [frequency, strain] layout expected by PSD
        return np.vstack([frequencies, strains]).T

    def save_to_file(self, file_name, comments="#", delimiter=conf.delimiter):
        """Save the PSD data to file

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        comments: str, optional
            String that will be prepended to the header and footer strings, to
            mark them as comments. Default is '#'.
        delimiter: str, optional
            String or character separating columns.
        """
        # NOTE(review): presumably renames any existing file of the same name
        # before writing -- see pesummary.utils.utils
        check_file_exists_and_rename(file_name)
        header = ["Frequency", "Strain"]
        np.savetxt(
            file_name, self, delimiter=delimiter, comments=comments,
            header=delimiter.join(header)
        )

    def __array_finalize__(self, obj):
        # Called by numpy for views/slices/copies; propagate the cached
        # attributes from the source array (None if it has none)
        if obj is None:
            return
        self.delta_f = getattr(obj, "delta_f", None)
        self.f_high = getattr(obj, "f_high", None)
        self.frequencies = getattr(obj, "frequencies", None)

    def to_pycbc(
        self, low_freq_cutoff, f_high=None, length=None, delta_f=None,
        f_high_override=False
    ):
        """Convert the PSD object to an interpolated pycbc.types.FrequencySeries

        Parameters
        ----------
        low_freq_cutoff : float
            Frequencies below this value are set to zero.
        f_high : float, optional
            Final frequency in Hertz. Default is the maximum frequency stored.
        length : int, optional
            Length of the frequency series in samples. Default is
            int(f_high / delta_f) + 1.
        delta_f : float, optional
            Frequency resolution of the frequency series in Hertz. Default is
            the frequency resolution of the stored data.
        f_high_override: Bool, optional
            If True and f_high is above the maximum frequency stored, use the
            maximum frequency stored instead. Otherwise only a warning is
            logged. Default False

        Returns
        -------
        the output of pycbc.psd.read.from_numpy_arrays
        """
        from pycbc.psd.read import from_numpy_arrays

        if delta_f is None:
            delta_f = self.delta_f
        if f_high is None:
            f_high = self.f_high
        elif f_high > self.f_high:
            msg = (
                "Specified value of final frequency: {} is above the maximum "
                "frequency stored: {}. ".format(f_high, self.f_high)
            )
            if f_high_override:
                msg += "Overwriting the final frequency"
                f_high = self.f_high
            else:
                msg += (
                    "This will result in an interpolation error. Either change "
                    "the final frequency specified or set the 'f_high_override' "
                    "kwarg to True"
                )
            logger.warning(msg)
        if length is None:
            length = int(f_high / delta_f) + 1
        pycbc_psd = from_numpy_arrays(
            self.T[0], self.T[1], length, delta_f, low_freq_cutoff
        )
        return pycbc_psd

    def interpolate(self, low_freq_cutoff, delta_f):
        """Interpolate PSD to a new delta_f

        Parameters
        ----------
        low_freq_cutoff: float
            Frequencies below this value are set to zero.
        delta_f : float
            Frequency resolution of the frequency series in Hertz.

        Returns
        -------
        PSD
            the interpolated PSD, restricted to frequencies greater than or
            equal to low_freq_cutoff
        """
        from pesummary.gw.pycbc import interpolate_psd
        psd = interpolate_psd(self.copy(), low_freq_cutoff, delta_f)
        frequencies, strains = psd.sample_frequencies, psd
        # Drop the samples below the cutoff from the returned PSD
        inds = np.where(frequencies >= low_freq_cutoff)
        return PSD(np.vstack([frequencies[inds], strains[inds]]).T)