Source code for bilby_pipe.data_generation
#!/usr/bin/env python
""" Script to perform data generation steps """
import glob
import os
import sys
import gwpy
import lal
import numpy as np
import bilby
from bilby.gw.detector import PowerSpectralDensity, calibration
from bilby_pipe.input import Input
from bilby_pipe.main import parse_args
from bilby_pipe.parser import create_parser
from bilby_pipe.plotting_utils import plot_whitened_data, strain_spectrogram_plot
from bilby_pipe.utils import (
BilbyPipeError,
DataDump,
convert_string_to_dict,
get_geocent_time_with_uncertainty,
get_version_information,
is_a_power_of_2,
log_function_call,
log_version_information,
logger,
)
# fmt: off
import matplotlib # isort:skip
matplotlib.use("agg")
# fmt: on
try:
import nds2 # noqa
except ImportError:
logger.warning(
"You do not have nds2 (python-nds2-client) installed. You may "
" experience problems accessing interferometer data."
)
try:
import LDAStools.frameCPP # noqa
except ImportError:
logger.warning(
"You do not have LDAStools.frameCPP (python-ldas-tools-framecpp) "
"installed. You may experience problems accessing interferometer data."
)
class DataGenerationInput(Input):
"""Handles user-input for the data generation script
Parameters
----------
args: Namespace
The parsed command line / ini file arguments
unknown_args: list
A list of any unknown command line arguments
create_data: bool
If False, no data is generated (used for testing)
"""
def __init__(self, args, unknown_args, create_data=True):
super().__init__(args, unknown_args)
# Generic initialisation
self.meta_data = dict(
command_line_args=args.__dict__,
unknown_command_line_args=unknown_args,
injection_parameters=None,
bilby_version=bilby.__version__,
bilby_pipe_version=get_version_information(),
)
self.injection_parameters = None
# Admin arguments
# Run index arguments
self.generation_seed = args.generation_seed
# Naming arguments
# Prior arguments
# Data arguments
self.channel_dict = args.channel_dict
if args.timeslide_dict is not None:
self.timeslide_dict = convert_string_to_dict(args.timeslide_dict)
logger.info(f"Read-in timeslide dict directly: {self.timeslide_dict}")
elif args.timeslide_file is not None:
self.gps_file = args.gps_file
self.timeslide_file = args.timeslide_file
self.timeslide_dict = self.get_timeslide_dict(self.idx)
# Data duration arguments
# Frequencies
self.sampling_frequency = args.sampling_frequency
# Waveform, source model and likelihood
# PSD
if self.psd_dict is None:
self.psd_length = args.psd_length
self.psd_fractional_overlap = args.psd_fractional_overlap
self.psd_start_time = args.psd_start_time
self.psd_method = args.psd_method
# Calibration
self.spline_calibration_amplitude_uncertainty_dict = (
args.spline_calibration_amplitude_uncertainty_dict
)
self.spline_calibration_phase_uncertainty_dict = (
args.spline_calibration_phase_uncertainty_dict
)
# ROQ
# Heterodyning
if args.fiducial_parameters is not None:
self.fiducial_parameters = convert_string_to_dict(args.fiducial_parameters)
else:
self.fiducial_parameters = args.fiducial_parameters
if self.fiducial_parameters is None:
self.update_fiducial_parameters = True
# Marginalization
self.phase_marginalization = args.phase_marginalization
# Plotting
# Reweighting
if create_data:
self.create_data(args)
def create_data(self, args):
"""Function to iterate through data generation method
Note, the data methods are mutually exclusive and only one can be given to
the parser.
Parameters
----------
args: Namespace
Input arguments
Raises
------
BilbyPipeError:
If no data is generated
"""
self.data_set = False
self.injection = args.injection
self.injection_numbers = args.injection_numbers
self.injection_file = args.injection_file
self.injection_dict = args.injection_dict
self.gaussian_noise = args.gaussian_noise
self.injection_waveform_arguments = args.injection_waveform_arguments
# The following are all mutually exclusive methods to set the data
if self.gaussian_noise or self.zero_noise:
if args.injection_file is not None:
logger.debug("Using provided injection file")
self._set_interferometers_from_injection_in_gaussian_noise()
elif args.injection_dict is not None:
logger.debug("Using provided injection dict")
self._set_interferometers_from_injection_in_gaussian_noise()
elif args.injection is False:
self._set_interferometers_from_gaussian_noise()
else:
raise BilbyPipeError("Unable to set data: no injection file or injection dict given")
else:
self._set_interferometers_from_data()
if self.data_set is False:
raise BilbyPipeError("Unable to set data")
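# Illustrative usage sketch (not part of the original module): the same
# sequence main() runs at the bottom of this file, but with
# create_data=False so no data is read. "config.ini" is a hypothetical,
# complete ini file.
#
#     parser = create_generation_parser()
#     args, unknown_args = parse_args(["config.ini"], parser)
#     inputs = DataGenerationInput(args, unknown_args, create_data=False)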
@property
def generation_seed(self):
    return self._generation_seed
@generation_seed.setter
def generation_seed(self, generation_seed):
"""Sets the generation seed.
If no generation seed has been provided, a random seed between 1 and 1e6 is
selected.
If a seed is provided, it is used as the base seed and all generation jobs will
have their seeds set to base_seed + job_idx.
NOTE: self.idx must not be None
Parameters
----------
generation_seed: int or None
"""
if generation_seed is None:
generation_seed = np.random.randint(1, 1e6)
else:
assert self.idx is not None
generation_seed = generation_seed + self.idx
self._generation_seed = generation_seed
bilby.core.utils.random.seed(generation_seed)
logger.info(f"Generation seed set to {generation_seed}")
@property
def injection_parameters(self):
    return self._injection_parameters
@injection_parameters.setter
def injection_parameters(self, injection_parameters):
self._injection_parameters = injection_parameters
if self.calibration_prior is not None:
for key in self.calibration_prior:
if key not in injection_parameters:
if "frequency" in key:
injection_parameters[key] = self.calibration_prior[key].peak
else:
injection_parameters[key] = 0
self.meta_data["injection_parameters"] = injection_parameters
@property
def parameter_conversion(self):
if "binary_neutron_star" in self.frequency_domain_source_model:
return bilby.gw.conversion.convert_to_lal_binary_neutron_star_parameters
elif "binary_black_hole" in self.frequency_domain_source_model:
return bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters
else:
return None
@property
def channel_dict(self):
    return self._channel_dict
@channel_dict.setter
def channel_dict(self, channel_dict):
if channel_dict is not None:
self._channel_dict = convert_string_to_dict(channel_dict, "channel-dict")
else:
logger.debug("channel-dict set to None")
self._channel_dict = None
def get_channel_type(self, det):
"""Help method to read the channel_dict and print useful messages"""
if self.channel_dict is None:
raise BilbyPipeError("No channel-dict argument provided")
if det in self.channel_dict:
return self.channel_dict[det]
else:
raise BilbyPipeError(f"Detector {det} not given in the channel-dict")
@property
def sampling_frequency(self):
    return self._sampling_frequency
@sampling_frequency.setter
def sampling_frequency(self, sampling_frequency):
if is_a_power_of_2(sampling_frequency) is False:
logger.warning(
"Sampling frequency {} not a power of 2, this can cause problems".format(
sampling_frequency
)
)
self._sampling_frequency = sampling_frequency
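# Example of the power-of-2 guard above: sampling_frequency=4096 passes
# silently, while e.g. 4000 Hz triggers the warning; FFT lengths derived
# from non-power-of-2 rates can be slow or ill-behaved.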
def _set_interferometers_from_gaussian_noise(self):
"""Method to generate the interferometers data from Gaussian noise"""
ifos = bilby.gw.detector.InterferometerList(self.detectors)
if self.psd_dict is not None:
for ifo in ifos:
if ifo.name in self.psd_dict.keys():
self._set_psd_from_file(ifo)
if self.zero_noise:
logger.info("Setting strain data from zero noise")
ifos.set_strain_data_from_zero_noise(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
start_time=self.start_time,
)
else:
logger.info("Simulating strain data from psd-colored noise")
ifos.set_strain_data_from_power_spectral_densities(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
start_time=self.start_time,
)
self.interferometers = ifos
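# Minimal bare-bilby sketch of the coloured-noise branch above, assuming
# bilby's default detector PSDs (the GPS time is hypothetical):
#
#     import bilby
#     ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
#     ifos.set_strain_data_from_power_spectral_densities(
#         sampling_frequency=4096, duration=4, start_time=1126259460
#     )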
def _set_interferometers_from_injection_in_gaussian_noise(self):
"""Method to generate the interferometers data from an injection in Gaussian noise"""
self.injection_parameters = self.injection_df.iloc[self.idx].to_dict()
logger.info("Injecting waveform with ")
for prop in [
"minimum_frequency",
"maximum_frequency",
"trigger_time",
"start_time",
"duration",
]:
logger.info(f"{prop} = {getattr(self, prop)}")
self._set_interferometers_from_gaussian_noise()
waveform_arguments = self.get_injection_waveform_arguments()
logger.info(f"Using waveform arguments: {waveform_arguments}")
waveform_generator = self.waveform_generator_class(
duration=self.duration,
start_time=self.start_time,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=self.bilby_frequency_domain_source_model,
parameter_conversion=self.parameter_conversion,
waveform_arguments=waveform_arguments,
)
self.interferometers.inject_signal(
waveform_generator=waveform_generator,
parameters=self.injection_parameters,
raise_error=self.enforce_signal_duration,
)
def inject_signal_into_time_domain_data(self, data, ifo):
"""Method to inject a signal into time-domain interferometer data
Parameters of the injection are obtained from the `injection_parameters` or
the injection file (if injection_parameters has not been set).
The geocent_time of the injection is set to be trigger_time +/- deltaT/2 if
the geocent_time is not provided in the injection parameters.
Parameters
----------
data: gwpy.timeseries.TimeSeries
The data into which to inject the signal
ifo: bilby.gw.detector.Interferometer
The interferometer to which the data relates
Returns
-------
data_and_signal: gwpy.timeseries.TimeSeries
The data with the signal added
"""
# Get the injection parameters
if self.injection_parameters is not None:
parameters = self.injection_parameters
else:
parameters = self.injection_df.iloc[self.idx].to_dict()
self.injection_parameters = parameters
# Set the geocent time if none is provided
if "geocent_time" not in parameters:
parameters["geocent_time"] = get_geocent_time_with_uncertainty(
geocent_time=self.trigger_time, uncertainty=self.deltaT / 2.0
)
waveform_arguments = self.get_injection_waveform_arguments()
waveform_generator = self.waveform_generator_class(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=self.bilby_frequency_domain_source_model,
parameter_conversion=self.parameter_conversion,
waveform_arguments=waveform_arguments,
)
if self.plot_injection:
outdir = self.data_directory
label = self.label
else:
outdir = None
label = None
logger.info(f"Injecting with {self.injection_waveform_approximant}")
(
signal_and_data,
meta_data,
) = bilby.gw.detector.inject_signal_into_gwpy_timeseries(
data=data,
waveform_generator=waveform_generator,
parameters=parameters,
det=ifo.name,
power_spectral_density=ifo.power_spectral_density,
outdir=outdir,
label=label,
)
ifo.meta_data = meta_data
if self.plot_spectrogram:
# Save spectrograms of the data before and after injection
plot_kwargs = dict(
det=ifo.name,
data_directory=self.data_directory,
trigger_time=self.trigger_time,
duration=self.duration,
post_trigger_duration=self.post_trigger_duration,
label=self.label,
)
strain_spectrogram_plot(
data=data, extra_label="before_injection", **plot_kwargs
)
strain_spectrogram_plot(
data=signal_and_data, extra_label="with_injection", **plot_kwargs
)
return signal_and_data
def _set_psd_from_file(self, ifo):
psd_file = self.psd_dict[ifo.name]
logger.info(f"Setting {ifo.name} PSD from file {psd_file}")
ifo.power_spectral_density = (
PowerSpectralDensity.from_power_spectral_density_file(psd_file=psd_file)
)
def _set_interferometers_from_data(self):
"""
Method to generate the interferometer data from real detector data
This sets the PSD before the analysis data so that the SNR of injected
signals is correct.
"""
end_time = self.start_time + self.duration
roll_off = self.tukey_roll_off
if 2 * roll_off > self.duration:
raise ValueError("2 * tukey-roll-off is longer than segment duration.")
ifo_list = []
for det in self.detectors:
ifo = bilby.gw.detector.get_empty_interferometer(det)
ifo.strain_data.roll_off = roll_off
if self.psd_dict is not None and det in self.psd_dict:
psd_data = None
self._set_psd_from_file(ifo)
else:
logger.info(f"Setting PSD for {det} from data")
psd_data = self.__get_psd_data(det)
psd = self.__generate_psd(psd_data, roll_off)
ifo.power_spectral_density = PowerSpectralDensity(
frequency_array=psd.frequencies.value, psd_array=psd.value
)
logger.info(f"Getting analysis-segment data for {det}")
data = self._get_data(
det, self.get_channel_type(det), self.start_time, end_time
)
if self.injection:
data = self.inject_signal_into_time_domain_data(data, ifo)
ifo.strain_data.set_from_gwpy_timeseries(data)
if self.plot_spectrogram:
self.__plot_ifo_spectrogram(
det, strain_data=data, psd_strain_data=psd_data
)
ifo_list.append(ifo)
self.interferometers = bilby.gw.detector.InterferometerList(ifo_list)
def __get_psd_data(self, det):
# psd_start_time is given relative to the segment start time
# so here we calculate the actual start time
actual_psd_start_time = self.start_time + self.psd_start_time
actual_psd_end_time = actual_psd_start_time + self.psd_duration
logger.info(f"Getting psd-segment data for {det}")
psd_data = self._get_data(
det, self.get_channel_type(det), actual_psd_start_time, actual_psd_end_time
)
return psd_data
def __generate_psd(self, psd_data, roll_off):
"""Create the psd from strain data."""
psd_alpha = 2 * roll_off / self.duration
overlap = self.psd_fractional_overlap * self.duration
logger.info(
"PSD settings: window=Tukey, Tukey-alpha={} roll-off={},"
" overlap={}, method={}".format(
psd_alpha, roll_off, overlap, self.psd_method
)
)
psd = psd_data.psd(
fftlength=self.duration,
overlap=overlap,
window=("tukey", psd_alpha),
method=self.psd_method,
)
return psd
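# Worked example of the PSD settings above, assuming duration=4 s,
# tukey_roll_off=0.4 s and psd_fractional_overlap=0.5:
#     psd_alpha = 2 * 0.4 / 4 = 0.2
#     overlap   = 0.5 * 4     = 2.0 s
# i.e. each 4 s FFT segment carries a Tukey window with a 20% tapered
# fraction, and successive segments overlap by 2 s.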
def __plot_ifo_spectrogram(self, det, strain_data, psd_strain_data=None):
"""Method to plot an IFO's data.
Parameters
----------
det: str
The detector name corresponding to the key in data-dict
strain_data, psd_strain_data: gwpy.timeseries.TimeSeries
The timeseries strain data of a detector.
Returns
-------
None
File by the name `<outdir>/data/<det>_<label>_D{duration}_data.png`
is saved
"""
if psd_strain_data is None:
logger.info("Unable to plot the IFO data without the PSD data")
return
else:
plot_psd = True
plot_kwargs = dict(
det=det,
data_directory=self.data_directory,
trigger_time=self.trigger_time,
duration=self.duration,
post_trigger_duration=self.post_trigger_duration,
label=self.label,
)
time = [strain_data.t0.value, strain_data.t0.value + strain_data.duration.value]
psd_time = [
psd_strain_data.t0.value,
psd_strain_data.t0.value + psd_strain_data.duration.value,
]
# plot PSD
if plot_psd:
strain_spectrogram_plot(
data=psd_strain_data,
extra_label=f"D{int(psd_time[1] - psd_time[0])}",
**plot_kwargs,
)
# plot psd_strain_data+strain_data and zoom into strain_data segment
data_with_psd = psd_strain_data.append(strain_data, inplace=False)
strain_spectrogram_plot(
data=data_with_psd, extra_label=f"D{int(time[1] - time[0])}", **plot_kwargs
)
def _get_data(self, det, channel_type, start_time, end_time, resample=True):
"""Read in data using gwpy
If the channel_type is "GWOSC", open data is obtained. Otherwise, we
try to read in the data first using "read" if a data_dict exists and
then using "get".
Parameters
----------
det: str
The detector name, e.g. 'H1'
channel_type: str
The full channel name is formed from <det>:<channel_type>, see
bilby_pipe --help for more information.
start_time, end_time: float
GPS start and end time of segment
resample: bool
If True (the default), the data is resampled to self.sampling_frequency
Returns
-------
data: gwpy.timeseries.TimeSeries
The loaded data
Raises
------
BilbyPipeError:
If there is an issue obtaining the data or with the data itself
"""
timeslide_val = None
if hasattr(self, "timeslide_dict"):
timeslide_val = self.timeslide_dict[det]
start_time = start_time + timeslide_val
end_time = end_time + timeslide_val
logger.info(
"Applying timeshift of {}. Time range {} - {} => {} - {}".format(
timeslide_val,
start_time - timeslide_val,
end_time - timeslide_val,
start_time,
end_time,
)
)
if self.ignore_gwpy_data_quality_check is False:
data_is_good = self._is_gwpy_data_good(start_time, end_time, det)
if not data_is_good:
raise BilbyPipeError("Data quality is not good.")
data = None
if data is None and channel_type == "GWOSC":
data = self._gwpy_fetch_open_data(det, start_time, end_time)
channel = f"{det}:{channel_type}"
if data is not None:
pass
elif self.data_dict is not None:
data = self._gwpy_read(det, channel, start_time, end_time)
else:
data = self._gwpy_get(channel, start_time, end_time)
if data is None:
raise BilbyPipeError("Failed to obtain data")
if np.all(data.value == 0):
raise BilbyPipeError("Obtained data is all zeros")
if data.unit.to_string() != "":
from astropy.units import dimensionless_unscaled
logger.warning(
f"Loaded {det} data has unit '{data.unit.to_string()}', "
"overwriting to dimensionless"
)
data.override_unit(dimensionless_unscaled)
if resample and data.sample_rate.value == self.sampling_frequency:
logger.info("Sample rate matches data no resampling")
elif resample:
message = "Resampling data to sampling_frequency {} using {}"
if self.resampling_method == "gwpy":
logger.info(
message.format(self.sampling_frequency, self.resampling_method)
)
data = data.resample(self.sampling_frequency)
elif self.resampling_method == "lal":
logger.info(
message.format(self.sampling_frequency, self.resampling_method)
)
try:
lal_timeseries = data.to_lal()
lal.ResampleREAL8TimeSeries(
lal_timeseries, float(1 / self.sampling_frequency)
)
except Exception as e:
raise BilbyPipeError(
"The lal resampling method has failed with exception {} "
"You may wish to try a different resampling method".format(e)
)
data = gwpy.timeseries.TimeSeries(
lal_timeseries.data.data,
epoch=lal_timeseries.epoch,
dt=lal_timeseries.deltaT,
)
else:
logger.warning(
"Resampling method {} not understood, should be "
"'gwpy' or 'lal'.".format(self.resampling_method)
)
else:
logger.info("No data resampling requested")
if timeslide_val:
# to match up the time axis for the interferometer network
data.shift(-timeslide_val)
return data
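# Timeslide sketch (hypothetical values): with
# timeslide_dict = {"H1": 0, "L1": 10}, the L1 segment is fetched from
# [start_time + 10, end_time + 10] and afterwards shifted back by 10 s,
# so the detectors share a common time axis while the L1 strain comes
# from a displaced stretch of data.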
@staticmethod
def _is_gwpy_data_good(start_time, end_time, det):
"""Check if start-end time is a period when the IFO has quality data.
Check passes if the IFO has quality data during the time period provided.
Note: we are using the ITF_SCIENCE channel to check the quality for V1, and DMT-SCIENCE for other detectors.
https://labcit.ligo.caltech.edu/~jzweizig/talks/LSC-2009-06-03/DMT-DQ_Stat-2009-06-03.pdf
This method is slow as it queries GWpy.
Parameters
----------
start_time, end_time: float
GPS start and end time of required data.
det: str
The string key that represents the detector ("H1", "L1", etc)
Returns
-------
True: if data is good (IFO has quality data during entire duration).
False: if data is bad (IFO does not have quality data during entire duration).
None: if the data quality check failed
"""
# Create data quality flag
channel_num = 1
quality_flag = (
f"{det}:ITF_SCIENCE:{channel_num}"
if det == "V1"
else f"{det}:DMT-SCIENCE:{channel_num}"
)
logger.info(f"Checking data quality {quality_flag} {start_time}-{end_time}")
try:
flag = gwpy.segments.DataQualityFlag.query(
quality_flag, gwpy.time.to_gps(start_time), gwpy.time.to_gps(end_time)
)
# compare active duration from quality flag and total duration
total_duration = end_time - start_time
active_duration = float(flag.livetime)
inactive_duration = total_duration - active_duration
# data is not good if there is any period when the IFO is inactive
if inactive_duration > 0:
data_is_good = False
logger.warning(
"Data quality check: FAILED. \n"
"{det} does not have quality data for "
"{inactive_duration}s out of {total_duration}s".format(
det=det,
inactive_duration=inactive_duration,
total_duration=total_duration,
)
)
else:
data_is_good = True
logger.info("Data quality check: PASSED.")
except Exception as e:
logger.warning(f"Error in Data Quality Check: {e}.")
data_is_good = None
return data_is_good
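# Standalone sketch of the quality query above (requires network access;
# the GPS times are hypothetical):
#
#     import gwpy.segments
#     flag = gwpy.segments.DataQualityFlag.query(
#         "H1:DMT-SCIENCE:1", 1126259446, 1126259478
#     )
#     print(flag.livetime)  # seconds of good data within the span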
def _gwpy_read(self, det, channel, start_time, end_time, dtype="float64"):
"""Wrapper function to gwpy.timeseries.TimeSeries.read()
Parameters
----------
det: str
The detector name corresponding to the key in data-dict
channel: str
The name of the channel to read, e.g. 'L1:GDS-CALIB_STRAIN'
start_time, end_time: float
GPS start and end time of required data
dtype: str or np.dtype
Data type requested
Returns
-------
data: TimeSeries
If successful, the data, otherwise None is returned
"""
logger.debug("data-dict provided, attempt read of data")
if det not in self.data_dict:
logger.info(f"Detector {det} not found in data-dict")
return None
else:
source = self.data_dict[det]
if isinstance(source, str):
format_ext = os.path.splitext(source)[1]
elif isinstance(source, list):
format_ext = os.path.splitext(source[0])[1]
else:
raise TypeError(
f"Frames should be either string or list, not {type(source)}"
)
# If the source contains a glob-path, e.g. *gwf, glob it first
if "*" in source:
logger.info(f"Globbing {source}")
source = glob.glob(source)
logger.info(f"Setting source={source}")
if "gwf" in format_ext:
kwargs = dict(source=source, channel=channel, format="gwf.lalframe")
type_kwargs = dict(dtype=dtype, subok=True, copy=False)
elif "hdf5" in format_ext:
kwargs = dict(source=source, start=start_time, end=end_time, format="hdf5")
elif "txt" in format_ext:
kwargs = dict(source=source)
else:
# Generic best try
kwargs = dict(
source=source, channel=channel, start=start_time, end=end_time
)
if self.data_format is not None:
kwargs["format"] = self.data_format
try:
if "gwf" in format_ext:
_call = log_function_call(
"gwpy.timeseries.TimeSeries.read", kwargs, log=False
)
log_function_call(f"{_call}.astype", type_kwargs)
data = gwpy.timeseries.TimeSeries.read(**kwargs).astype(**type_kwargs)
else:
log_function_call("gwpy.timeseries.TimeSeries.read", kwargs)
data = gwpy.timeseries.TimeSeries.read(**kwargs)
data = data.crop(start=start_time, end=end_time)
if data.duration.value < end_time - start_time:
logger.warning(
"Unable to read in requested {}s duration of data from {}"
" only {}s available: returning None".format(
end_time - start_time, source, data.duration.value
)
)
data = None
elif data.duration.value > end_time - start_time:
logger.info(
"Read in {}s of data from {}, but {}s requested, truncating".format(
data.duration.value, source, end_time - start_time
)
)
data = data[data.times.value >= start_time]
data = data[data.times.value < end_time]
return data
except ValueError as e:
logger.info(f"Reading of data failed with error {e}")
return None
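# Sketch of the gwf branch above, with a hypothetical frame file and
# channel name:
#
#     data = gwpy.timeseries.TimeSeries.read(
#         source="H-H1_EXAMPLE-1126259446-32.gwf",
#         channel="H1:GDS-CALIB_STRAIN",
#         format="gwf.lalframe",
#     ).astype(dtype="float64", subok=True, copy=False)
#     data = data.crop(start=1126259446, end=1126259478)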
def _gwpy_get(self, channel, start_time, end_time, dtype="float64"):
"""Wrapper function to gwpy.timeseries.TimeSeries.get()
Parameters
----------
channel: str
The name of the channel to read, e.g. 'L1:GDS-CALIB_STRAIN'
start_time, end_time: float
GPS start and end time of required data
dtype: str or np.dtype
Data type requested
Returns
-------
data: TimeSeries
If successful, the data, otherwise None is returned
"""
logger.debug("Attempt to locate data")
kwargs = dict(
channel=channel,
start=start_time,
end=end_time,
verbose=False,
allow_tape=self.allow_tape,
)
type_kwargs = dict(
dtype=dtype,
subok=True,
copy=False,
)
if self.data_format:
kwargs["format"] = self.data_format
_call = log_function_call("gwpy.timeseries.TimeSeries.get", kwargs, log=False)
log_function_call(f"{_call}.astype", type_kwargs)
try:
data = gwpy.timeseries.TimeSeries.get(**kwargs).astype(**type_kwargs)
return data
except RuntimeError as e:
logger.info(
f"Unable to read data for channel {channel}. "
"You may need to set the GWDATAFIND_SERVER environment variable."
)
logger.debug(f"Error message {e}")
except ImportError:
logger.info("Unable to read data as NDS2 is not installed")
except TypeError:
logger.debug("Problem reading data try again without kwargs")
kwargs = dict(
channel=channel,
start=start_time,
end=end_time,
verbose=False,
)
_call = log_function_call(
"gwpy.timeseries.TimeSeries.get", kwargs, log=False
)
log_function_call(f"{_call}.astype", type_kwargs)
data = gwpy.timeseries.TimeSeries.get(**kwargs).astype(**type_kwargs)
return data
def _gwpy_fetch_open_data(self, det, start_time, end_time):
"""Wrapper function to gwpy.timeseries.TimeSeries.fetch_open_data()
Parameters
----------
det: str
The detector name, e.g. 'H1'
start_time, end_time: float
GPS start and end time of required data
Returns
-------
data: TimeSeries
If successful, the data, otherwise None is returned
"""
logger.info("Attempting to download data from GWOSC")
kwargs = dict(ifo=det, start=start_time, end=end_time)
log_function_call("gwpy.timeseries.TimeSeries.fetch_open_data", kwargs)
data = gwpy.timeseries.TimeSeries.fetch_open_data(**kwargs)
return data
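# Open-data sketch: the same call fetching the public strain around
# GW150914 from GWOSC:
#
#     data = gwpy.timeseries.TimeSeries.fetch_open_data(
#         "H1", 1126259446, 1126259478
#     )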
@property
def interferometers(self):
"""A bilby.gw.detector.InterferometerList"""
try:
return self._interferometers
except AttributeError:
raise ValueError(
"interferometers unset, did you provide a set-data method?"
)
def add_calibration_model_to_interferometers(self, ifo):
if self.calibration_model in [None, "Recalibrate"]:
ifo.calibration_model = bilby.gw.calibration.Recalibrate()
elif self.calibration_model == "CubicSpline":
ifo.calibration_model = bilby.gw.calibration.CubicSpline(
prefix=f"recalib_{ifo.name}_",
minimum_frequency=ifo.minimum_frequency,
maximum_frequency=ifo.maximum_frequency,
n_points=self.spline_calibration_nodes,
)
elif self.calibration_model == "Precomputed":
model = bilby.gw.calibration.Precomputed
det = ifo.name
if (
self.spline_calibration_envelope_dict is not None
and det in self.spline_calibration_envelope_dict
):
ifo.calibration_model = model.from_envelope_file(
self.spline_calibration_envelope_dict[det],
frequency_array=ifo.frequency_array[ifo.frequency_mask],
n_nodes=self.spline_calibration_nodes,
label=det,
n_curves=1000,
)
elif (
det in self.spline_calibration_amplitude_uncertainty_dict
and det in self.spline_calibration_phase_uncertainty_dict
):
ifo.calibration_model = model.constant_uncertainty_spline(
amplitude_sigma=self.spline_calibration_amplitude_uncertainty_dict[
det
],
phase_sigma=self.spline_calibration_phase_uncertainty_dict[det],
frequency_array=ifo.frequency_array[ifo.frequency_mask],
n_nodes=self.spline_calibration_nodes,
label=det,
n_curves=1000,
)
else:
raise BilbyPipeError(
f"calibration model {self.calibration_model} not implemented"
)
@interferometers.setter
def interferometers(self, interferometers):
for ifo in interferometers:
if isinstance(ifo, bilby.gw.detector.Interferometer) is False:
raise BilbyPipeError(f"ifo={ifo} is not a bilby Interferometer")
if self.minimum_frequency is not None:
ifo.minimum_frequency = self.minimum_frequency_dict[ifo.name]
if self.maximum_frequency is not None:
ifo.maximum_frequency = self.maximum_frequency_dict[ifo.name]
if self.calibration_model is not None:
self.add_calibration_model_to_interferometers(ifo)
self._interferometers = interferometers
self.data_set = True
if self.plot_data:
interferometers.plot_data(outdir=self.data_directory, label=self.label)
plot_whitened_data(
interferometers=interferometers,
data_directory=self.data_directory,
label=self.label,
)
def build_calibration_lookups_if_needed(self):
"""
Build lookup files that are needed for incorporating calibration uncertainty.
These are needed if:
- :code:`calibration_marginalization` is used either during sampling or
post-processing
- the calibration model is :code:`Precomputed`
"""
sampling_calibration = self.calibration_model
sampling_marginalization = self.calibration_marginalization
calibration_lookup = self.calibration_lookup_table
n_response = self.number_of_response_curves
if self.reweighting_configuration is not None:
data = self.reweighting_configuration
if "calibration-model" in data:
self.calibration_model = data["calibration-model"]
for ifo in self.interferometers:
self.add_calibration_model_to_interferometers(ifo)
if "calibration-marginalization" in data:
self.calibration_marginalization = data["calibration-marginalization"]
if "calibration-lookup-table" in data:
self.calibration_lookup_table = data["calibration-lookup-table"]
if "number-of-response-curves" in data:
self.number_of_response_curves = data["number-of-response-curves"]
if (
self.calibration_marginalization
or self.calibration_model == "Precomputed"
or self.calibration_lookup_table
):
calibration.build_calibration_lookup(
interferometers=self.interferometers,
lookup_files=self.calibration_lookup_table,
priors=self.calibration_prior,
number_of_response_curves=self.number_of_response_curves,
)
self.calibration_model = sampling_calibration
for ifo in self.interferometers:
self.add_calibration_model_to_interferometers(ifo)
self.calibration_marginalization = sampling_marginalization
self.calibration_lookup_table = calibration_lookup
self.number_of_response_curves = n_response
def save_data_dump(self):
"""Method to dump the saved data to disk for later analysis"""
self.build_calibration_lookups_if_needed()
likelihood = self.likelihood
if self.distance_marginalization:
likelihood_lookup_table = dict(
np.load(likelihood.cached_lookup_table_filename)
)
else:
likelihood_lookup_table = None
if hasattr(likelihood, "fiducial_parameters"):
self.meta_data["fiducial_parameters"] = likelihood.fiducial_parameters
if self.is_likelihood_multiband:
likelihood_roq_weights = None
likelihood_multiband_weights = likelihood.weights
else:
likelihood_roq_weights = getattr(likelihood, "weights", None)
likelihood_multiband_weights = None
self.meta_data["reweighting_configuration"] = self.reweighting_configuration
data_dump = DataDump(
outdir=self.data_directory,
label=self.label,
idx=self.idx,
trigger_time=self.trigger_time,
interferometers=self.interferometers,
meta_data=self.meta_data,
likelihood_lookup_table=likelihood_lookup_table,
likelihood_roq_weights=likelihood_roq_weights,
likelihood_roq_params=getattr(likelihood, "roq_params", None),
likelihood_multiband_weights=likelihood_multiband_weights,
priors_dict=dict(self.priors),
priors_class=self.priors.__class__,
)
data_dump.to_pickle()
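# Sketch of how a downstream analysis stage reloads the dump written
# above, using DataDump.from_pickle from bilby_pipe.utils (the pickle
# path is hypothetical):
#
#     dump = DataDump.from_pickle("outdir/data/label_data_dump.pickle")
#     ifos = dump.interferometers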
def save_roq_weights(self):
if self.likelihood_type != "ROQGravitationalWaveTransient":
raise ValueError(
"ROQ weights can only be saved for ROQGravitationalWaveTransient"
)
if getattr(self, "likelihood_roq_weights", None) is None:
self.setup_roq_weights()
weight_file = os.path.join(
self.data_directory, f"{self.label}_roq_weights.{self.roq_weight_format}"
)
self.meta_data["weight_file"] = weight_file
self.likelihood.save_weights(weight_file, format=self.roq_weight_format)
def save_multiband_weights(self):
if not self.is_likelihood_multiband:
raise ValueError(
"Multiband weights can only be saved for "
"MultibandGravitationalWaveTransient"
)
weight_file = os.path.join(
self.data_directory, f"{self.label}_multiband_weights.hdf5"
)
self.meta_data["weight_file"] = weight_file
self.likelihood.save_weights(weight_file)
def create_generation_parser():
"""Data generation parser creation"""
return create_parser(top_level=False)
def main():
"""Data generation main logic"""
args, unknown_args = parse_args(sys.argv[1:], create_generation_parser())
log_version_information()
data = DataGenerationInput(args, unknown_args)
data.save_data_dump()
logger.info("Completed data generation")