Source code for bilby_pipe.input
#!/usr/bin/env python
"""
Module containing the main input class
"""
import glob
import inspect
import itertools
import json
import os
import shutil
import subprocess
from importlib import import_module
import numpy as np
import pandas as pd
from gwosc.datasets import event_gps
import bilby
from . import utils
from .utils import (
SAMPLER_SETTINGS,
BilbyPipeError,
BilbyPipeInternalError,
check_if_psd_is_from_built_in,
convert_string_to_dict,
convert_string_to_list,
convert_string_to_tuple,
get_colored_string,
get_function_from_string_path,
get_time_prior,
logger,
pretty_print_dictionary,
)
[docs]
class Input(object):
"""Superclass of input handlers"""
def __init__(self, args, unknown_args, print_msg=True):
if print_msg:
logger.info(f"Command line arguments: {args}")
logger.info(f"Unknown command line arguments: {unknown_args}")
self.conda_env = getattr(self.known_args, "conda_env", None)
@property
@property
@idx.setter
def idx(self, idx):
self._idx = idx
@property
[docs]
def known_detectors(self):
dirs = os.path.join(os.path.dirname(bilby.gw.detector.__file__), "detectors")
known_files = glob.glob(os.path.join(dirs, "*"))
return [os.path.basename(kf).split(".")[0] for kf in known_files]
@property
[docs]
def detectors(self):
"""A list of the detectors to include, e.g., ['H1', 'L1']"""
return self._detectors
@detectors.setter
def detectors(self, detectors):
self._detectors = utils.convert_detectors_input(detectors)
self._check_detectors_against_known_detectors()
[docs]
def _check_detectors_against_known_detectors(self):
for element in self.detectors:
if element not in self.known_detectors:
msg = (
'Argument detectors contains "{}" not in the known '
"detectors list: {} ".format(element, self.known_detectors)
+ ". This will likely fail at the data generation step"
)
logger.warning(get_colored_string(msg))
@staticmethod
[docs]
def _split_string_by_space(string):
"""Converts "H1 L1" to ["H1", "L1"]"""
return string.split(" ")
@staticmethod
[docs]
def _convert_string_to_list(string):
"""Converts various strings to a list"""
string = string.replace(",", " ")
string = string.replace("[", "")
string = string.replace("]", "")
string = string.replace('"', "")
string = string.replace("'", "")
string_list = string.split()
return string_list
@property
[docs]
def outdir(self):
"""The path to the directory where output will be stored"""
utils.check_directory_exists_and_if_not_mkdir(self._outdir)
return self._outdir
@outdir.setter
def outdir(self, outdir):
if outdir == ".":
raise BilbyPipeError("Unable to use '.' as an outdir")
self._outdir = os.path.relpath(outdir)
@property
[docs]
def submit_directory(self):
"""The path to the directory where submit output will be stored"""
path = os.path.join(self._outdir, "submit")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def log_directory(self):
"""The top-level directory for the log directories"""
utils.check_directory_exists_and_if_not_mkdir(self._log_directory)
return self._log_directory
@log_directory.setter
def log_directory(self, log_directory):
if log_directory is None:
self._log_directory = self._outdir
else:
self._log_directory = log_directory
@property
[docs]
def data_generation_log_directory(self):
"""The path to the directory where generation logs will be stored"""
path = os.path.join(self.log_directory, "log_data_generation")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def data_analysis_log_directory(self):
"""The path to the directory where analysis logs will be stored"""
path = os.path.join(self.log_directory, "log_data_analysis")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def summary_log_directory(self):
"""The path to the directory where pesummary logs will be stored"""
path = os.path.join(self.log_directory, "log_results_page")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def data_directory(self):
"""The path to the directory where data output will be stored"""
path = os.path.join(self._outdir, "data")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def result_directory(self):
"""The path to the directory where result output will be stored"""
path = os.path.join(self._outdir, "result")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def final_result_directory(self):
"""The path to the directory where final result output will be stored"""
path = os.path.join(self._outdir, "final_result")
utils.check_directory_exists_and_if_not_mkdir(path)
return path
@property
[docs]
def webdir(self):
utils.check_directory_exists_and_if_not_mkdir(self._webdir)
return self._webdir
@webdir.setter
def webdir(self, webdir):
if webdir is None:
self._webdir = os.path.join(self.outdir, "results_page")
else:
self._webdir = webdir
@property
[docs]
def gps_file(self):
"""The gps file containing the list of gps times"""
return self._gps_file
@gps_file.setter
def gps_file(self, gps_file):
"""Set and parse the gps_file"""
if gps_file is None:
self._gps_file = None
return
elif os.path.isfile(gps_file):
self._gps_file = os.path.relpath(gps_file)
else:
raise FileNotFoundError(f"Input file gps_file={gps_file} does not exist")
self._parse_gps_file()
[docs]
def _parse_gps_file(self):
gpstimes = self.read_gps_file(self.gps_file)
n = len(gpstimes)
logger.debug(f"{n} start times found in gps_file={self.gps_file}")
self.gpstimes = gpstimes
@staticmethod
[docs]
def read_gps_file(gps_file):
gpstimes = np.loadtxt(gps_file, ndmin=2, delimiter=",")
if gpstimes.ndim > 1:
logger.debug(f"Reading column 0 from gps_file={gps_file}")
gpstimes = gpstimes[:, 0]
return gpstimes
@property
@gps_tuple.setter
def gps_tuple(self, gps_tuple):
"""Set and parse the gps_tuple"""
if gps_tuple is None:
self._gps_tuple = None
return
else:
self._gps_tuple = gps_tuple
self.gpstimes = self.parse_gps_tuple(gps_tuple)
@staticmethod
[docs]
def parse_gps_tuple(gps_tuple):
start, dt, N = convert_string_to_tuple(gps_tuple)
start_times = np.linspace(start, start + (N - 1) * dt, N)
return start_times
@property
[docs]
def timeslide_file(self):
"""Timeslide file.
Timeslide file containing the list of timeslides to apply to each
detector's start time.
"""
return self._timeslide_file
@timeslide_file.setter
def timeslide_file(self, timeslide_file):
"""Set the timeslide_file.
At setting, will check the file exists, read the contents,
save the timeslide value for each of the detectors.
"""
if timeslide_file is None:
self._timeslide_file = None
return
elif os.path.isfile(timeslide_file):
self._timeslide_file = os.path.relpath(timeslide_file)
else:
raise FileNotFoundError(
f"Input file timeslide_file={timeslide_file} not understood"
)
if hasattr(self, "_timeslide_file"):
self._parse_timeslide_file()
else:
logger.debug("No _parse_timeslide_file method present")
[docs]
def read_timeslide_file(self):
"""Read timeslide file.
Each row of file is an array, hence ndmin = 2
[ [timshift1,...], [], [] ...]
"""
timeslides_list = np.loadtxt(self.timeslide_file, ndmin=2)
return timeslides_list
[docs]
def _parse_timeslide_file(self):
"""Parse the timeslide file and check for correctness.
Sets the attribute "timeslides" if timeslide file correctly formatted
and passed to Inputs()
"""
timeslides_list = self.read_timeslide_file()
number_rows, number_columns = timeslides_list.shape
if number_columns != len(self.detectors):
raise BilbyPipeError(
"The timeslide file must have one column for each of the detectors. "
"Number Cols: {}, Number Detectors: {}".format(
number_columns, len(self.detectors)
)
)
if number_rows != len(self.gpstimes):
raise BilbyPipeError(
"The timeslide file must have one row for each gps time. "
"Number Rows: {}, Number Gps Times: {}".format(
number_rows, len(self.gpstimes)
)
)
times = np.hsplit(timeslides_list, len(self.detectors))
self.timeslides = {}
for i in range(len(self.detectors)):
self.timeslides.update({self.detectors[i]: times[i].flatten()})
logger.debug(
f"{number_rows} timeslides found in timeslide_file={self.timeslide_file}"
)
[docs]
def get_timeslide_dict(self, idx):
"""Return a specific timeslide value from the timeslide file.
Given an index, the dict of {detector: timeslide value} is created for
the specific index and returned.
"""
if not hasattr(self, "timeslides"):
raise BilbyPipeError("Timeslide file must be provided.")
if any(len(t) <= idx for t in self.timeslides.values()):
raise BilbyPipeError(
f"Timeslide index={idx} > number of timeslides available."
)
timeslide_val = {
det: timeslide[idx] for det, timeslide in self.timeslides.items()
}
logger.debug(f"Timeslide value: {timeslide_val}")
return timeslide_val
@property
[docs]
def bilby_frequency_domain_source_model(self):
"""
The bilby function to pass to the waveform_generator
This can be a function defined in an external package.
"""
if self.frequency_domain_source_model in bilby.gw.source.__dict__.keys():
model = self._frequency_domain_source_model
logger.debug(f"Using the {model} source model")
return bilby.gw.source.__dict__[model]
elif "." in self.frequency_domain_source_model:
return get_function_from_string_path(self._frequency_domain_source_model)
else:
raise BilbyPipeError(
f"No source model {self._frequency_domain_source_model} found."
)
@property
@reference_frequency.setter
def reference_frequency(self, reference_frequency):
self._reference_frequency = float(reference_frequency)
@property
@mode_array.setter
def mode_array(self, mode_array):
# Pre sanitize the mode array
if mode_array == [None]:
mode_array = None
if isinstance(mode_array, list):
# Hack because configargparse splits the mode_array
mode_array = ",".join(mode_array)
if mode_array is not None:
self._mode_array = convert_string_to_list(mode_array)
else:
logger.debug("mode_array not set")
self._mode_array = None
# Ensure it is a list of lists
if np.array(self._mode_array).ndim == 1:
self._mode_array = [self._mode_array]
if np.array(self._mode_array).ndim == 2:
for mode in self._mode_array:
if len(mode) != 2:
raise BilbyPipeError(f"mode_array {self._mode_array} is invalid")
if np.array(self._mode_array).ndim > 2:
raise BilbyPipeError(f"mode_array {self._mode_array} is invalid")
[docs]
def get_default_waveform_arguments(self):
wfa = dict(
reference_frequency=self.reference_frequency,
waveform_approximant=self.waveform_approximant,
minimum_frequency=self.minimum_frequency,
maximum_frequency=self.maximum_frequency,
catch_waveform_errors=self.catch_waveform_errors,
pn_spin_order=self.pn_spin_order,
pn_tidal_order=self.pn_tidal_order,
pn_phase_order=self.pn_phase_order,
pn_amplitude_order=self.pn_amplitude_order,
mode_array=self.mode_array,
)
if self.waveform_arguments_dict is not None:
wfa.update(convert_string_to_dict(self.waveform_arguments_dict))
logger.debug(f"Default waveform_arguments: {pretty_print_dictionary(wfa)}")
return wfa
[docs]
def get_injection_waveform_arguments(self):
"""Get the dict of the waveform arguments needed for creating injections.
Defaults the injection-waveform-approximant to waveform-approximant, if
no injection-waveform-approximant provided. Note that the default
waveform-approximant is `IMRPhenomPv2`.
"""
if self.injection_waveform_approximant is None:
self.injection_waveform_approximant = self.waveform_approximant
waveform_arguments = self.get_default_waveform_arguments()
if self.injection_waveform_arguments is not None:
waveform_arguments.update(
convert_string_to_dict(self.injection_waveform_arguments)
)
waveform_arguments["waveform_approximant"] = self.injection_waveform_approximant
waveform_arguments["numerical_relativity_file"] = self.numerical_relativity_file
return waveform_arguments
@property
[docs]
def bilby_roq_frequency_domain_source_model(self):
if "binary_neutron_star" in self.frequency_domain_source_model:
logger.debug("Using the binary_neutron_star_roq source model")
return bilby.gw.source.binary_neutron_star_roq
elif "binary_black_hole" in self.frequency_domain_source_model:
logger.debug("Using the binary_black_hole_roq source model")
return bilby.gw.source.binary_black_hole_roq
else:
raise BilbyPipeError("Unable to determine roq_source from source model")
@property
[docs]
def bilby_relative_binning_frequency_domain_source_model(self):
if "binary_neutron_star" in self.frequency_domain_source_model:
logger.debug("Using the binary_neutron_star_relative_binning source model")
return bilby.gw.source.lal_binary_neutron_star_relative_binning
elif "binary_black_hole" in self.frequency_domain_source_model:
logger.debug("Using the binary_black_hole_relative_binning source model")
return bilby.gw.source.lal_binary_black_hole_relative_binning
else:
raise BilbyPipeError(
"Unable to determine relative binning source from source model"
)
@property
[docs]
def bilby_multiband_frequency_domain_source_model(self):
if "binary_neutron_star" in self.frequency_domain_source_model:
logger.info("Using the binary_neutron_star_frequency_sequence source model")
return bilby.gw.source.binary_neutron_star_frequency_sequence
elif "binary_black_hole" in self.frequency_domain_source_model:
logger.info("Using the binary_black_hole_frequency_sequence source model")
return bilby.gw.source.binary_black_hole_frequency_sequence
else:
raise BilbyPipeError(
"Unable to determine multiband_source from source model"
)
@property
[docs]
def frequency_domain_source_model(self):
"""String of which frequency domain source model to use"""
return self._frequency_domain_source_model
@frequency_domain_source_model.setter
def frequency_domain_source_model(self, frequency_domain_source_model):
self._frequency_domain_source_model = frequency_domain_source_model
@property
@trigger_time.setter
def trigger_time(self, trigger_time):
# Convert trigger time
if trigger_time is None:
logger.debug("No trigger time given")
elif isinstance(trigger_time, str) and "GW" in trigger_time:
logger.debug(f"Using gwosc to find trigger time for event {trigger_time}")
trigger_time = event_gps(trigger_time)
else:
trigger_time = float(trigger_time)
self._trigger_time = trigger_time
if trigger_time is not None:
logger.debug(f"Setting trigger time {trigger_time}")
@property
[docs]
def start_time(self):
if hasattr(self, "_start_time"):
self._verify_start_time(self._start_time)
return self._start_time
try:
self._start_time = (
self.trigger_time + self.post_trigger_duration - self.duration
)
return self._start_time
except AttributeError:
logger.warning("Unable to calculate default segment start time")
return None
[docs]
def _verify_start_time(self, start_time):
try:
inferred_start_time = (
self.trigger_time + self.post_trigger_duration - self.duration
)
except AttributeError:
logger.warning("Unable to verify start-time consistency")
return
if inferred_start_time != start_time:
raise BilbyPipeError("Unexpected behaviour encountered with start time")
@start_time.setter
def start_time(self, start_time):
self._verify_start_time(start_time)
self._start_time = start_time
if start_time is not None:
logger.debug(f"Setting segment start time {start_time}")
@property
@duration.setter
def duration(self, duration):
self._duration = duration
if duration is not None:
logger.debug(f"Setting segment duration {duration}s")
@property
[docs]
def injection_numbers(self):
if hasattr(self, "_injection_numbers"):
return self._injection_numbers
else:
raise BilbyPipeInternalError("Injection numbers requested, but not yet set")
@injection_numbers.setter
def injection_numbers(self, injection_numbers):
if (
injection_numbers is None
or len(injection_numbers) == 0
or injection_numbers == "None"
or injection_numbers[0] == "None"
or injection_numbers[0] is None
):
self._injection_numbers = None
elif all(
i is not None
and not isinstance(i, float)
and utils.check_if_represents_int(i)
or ":" in str(i)
for i in injection_numbers
):
list_of_lists = []
for i in injection_numbers:
if utils.check_if_represents_int(i):
list_of_lists.append([int(i)])
else:
list_of_lists.append(utils.convert_string_slice_syntax(i))
self._injection_numbers = list(set(itertools.chain(*list_of_lists)))
else:
raise BilbyPipeError(f"Invalid injection numbers {injection_numbers}")
@property
@injection_df.setter
def injection_df(self, injection_df):
if isinstance(injection_df, pd.DataFrame) is False:
raise BilbyPipeError("Setting injection df with non-pandas DataFrame")
elif self.injection_numbers is not None:
logger.debug(
f"Truncating injection injection df to rows {self.injection_numbers}"
)
try:
self._injection_df = injection_df.iloc[self.injection_numbers]
except IndexError:
raise BilbyPipeError(
"Your injection_numbers are incompatible with the injection set"
)
else:
self._injection_df = injection_df
@property
@injection_file.setter
def injection_file(self, injection_file):
if injection_file is None:
logger.debug("No injection file set")
self._injection_file = None
elif os.path.isfile(injection_file):
self._injection_file = os.path.relpath(injection_file)
self.injection_df = self.read_injection_file(injection_file)
self.total_number_of_injections = len(self.injection_df)
self.injection = True
else:
raise FileNotFoundError(f"Injection file {injection_file} not found")
@property
@injection_dict.setter
def injection_dict(self, injection_dict):
if injection_dict is None:
self._injection_dict = None
return
elif isinstance(injection_dict, str):
self._injection_dict = convert_string_to_dict(injection_dict)
elif isinstance(injection_dict, dict):
self._injection_dict = injection_dict
else:
raise BilbyPipeError("injection-dict can not be coerced to a dict")
self.injection_df = pd.DataFrame(self._injection_dict, index=[0])
self.total_number_of_injections = 1
self.injection = True
@staticmethod
[docs]
def read_injection_file(injection_file):
if "json" in injection_file:
return Input.read_json_injection_file(injection_file)
elif "dat" in injection_file:
return Input.read_dat_injection_file(injection_file)
@staticmethod
[docs]
def read_json_injection_file(injection_file):
with open(injection_file, "r") as file:
injection_dict = json.load(
file, object_hook=bilby.core.utils.decode_bilby_json
)
injection_df = injection_dict["injections"]
try:
injection_df = pd.DataFrame(injection_df)
except ValueError:
# If injection_df is a dictionary of single elements, set the index-array in pandas
injection_df = pd.DataFrame(injection_df, index=[0])
return injection_df
@staticmethod
[docs]
def read_dat_injection_file(injection_file):
return pd.read_csv(injection_file, delim_whitespace=True)
@property
@spline_calibration_envelope_dict.setter
def spline_calibration_envelope_dict(self, spline_calibration_envelope_dict):
if spline_calibration_envelope_dict is not None:
self._spline_calibration_envelope_dict = convert_string_to_dict(
spline_calibration_envelope_dict, "spline-calibration-envelope-dict"
)
else:
logger.debug("spline_calibration_envelope_dict")
self._spline_calibration_envelope_dict = None
@property
[docs]
def spline_calibration_amplitude_uncertainty_dict(self):
return self._spline_calibration_amplitude_uncertainty_dict
@spline_calibration_amplitude_uncertainty_dict.setter
def spline_calibration_amplitude_uncertainty_dict(
self, spline_calibration_amplitude_uncertainty_dict
):
if spline_calibration_amplitude_uncertainty_dict is not None:
self._spline_calibration_amplitude_uncertainty_dict = (
convert_string_to_dict(
spline_calibration_amplitude_uncertainty_dict,
"spline-calibration-amplitude-uncertainty-dict",
)
)
else:
logger.debug("spline_calibration_amplitude_uncertainty_dict")
self._spline_calibration_amplitude_uncertainty_dict = None
@property
[docs]
def spline_calibration_phase_uncertainty_dict(self):
return self._spline_calibration_phase_uncertainty_dict
@spline_calibration_phase_uncertainty_dict.setter
def spline_calibration_phase_uncertainty_dict(
self, spline_calibration_phase_uncertainty_dict
):
if spline_calibration_phase_uncertainty_dict is not None:
self._spline_calibration_phase_uncertainty_dict = convert_string_to_dict(
spline_calibration_phase_uncertainty_dict,
"spline-calibration-phase-uncertainty-dict",
)
else:
logger.debug("spline_calibration_phase_uncertainty_dict")
self._spline_calibration_phase_uncertainty_dict = None
@property
[docs]
def minimum_frequency(self):
"""The minimum frequency
If a per-detector dictionary is given, this will return the minimum
frequency value. To access the dictionary,
see self.minimum_frequency_dict
"""
return self._minimum_frequency
@minimum_frequency.setter
def minimum_frequency(self, minimum_frequency):
if minimum_frequency is None:
self._minimum_frequency = None
self.minimum_frequency_dict = {det: None for det in self.detectors}
else:
try:
self._minimum_frequency = float(minimum_frequency)
self.minimum_frequency_dict = {
det: float(minimum_frequency) for det in self.detectors
}
except ValueError:
self.minimum_frequency_dict = convert_string_to_dict(
minimum_frequency, "minimum-frequency"
)
self._minimum_frequency = np.min(
[xx for xx in self._minimum_frequency_dict.values()]
).item()
@property
@minimum_frequency_dict.setter
def minimum_frequency_dict(self, minimum_frequency_dict):
self.test_frequency_dict(frequency_dict=minimum_frequency_dict, label="minimum")
self._minimum_frequency_dict = minimum_frequency_dict
@property
[docs]
def maximum_frequency(self):
"""The maximum frequency
If a per-detector dictionary is given, this will return the maximum
frequency value. To access the dictionary,
see self.maximum_frequency_dict
"""
return self._maximum_frequency
@maximum_frequency.setter
def maximum_frequency(self, maximum_frequency):
if maximum_frequency is None:
self._maximum_frequency = self.sampling_frequency / 2
self.maximum_frequency_dict = {
det: self._maximum_frequency for det in self.detectors
}
logger.debug(
"No maximum frequency given. "
"Setting to sampling frequency / 2 = {}".format(self._maximum_frequency)
)
else:
try:
self._maximum_frequency = float(maximum_frequency)
self.maximum_frequency_dict = {
det: float(maximum_frequency) for det in self.detectors
}
except ValueError:
self.maximum_frequency_dict = convert_string_to_dict(
maximum_frequency, "maximum-frequency"
)
self._maximum_frequency = np.max(
[xx for xx in self._maximum_frequency_dict.values()]
).item()
@property
@maximum_frequency_dict.setter
def maximum_frequency_dict(self, maximum_frequency_dict):
self.test_frequency_dict(frequency_dict=maximum_frequency_dict, label="maximum")
self._maximum_frequency_dict = maximum_frequency_dict
[docs]
def test_frequency_dict(self, frequency_dict, label=""):
for det in self.detectors:
if det not in frequency_dict.keys():
raise BilbyPipeError(
f"Input {label} frequency required for detector {det}"
)
return frequency_dict
@property
@staticmethod
[docs]
def get_default_prior_files():
"""Returns a dictionary of the default priors"""
prior_files_glob = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "data_files/*prior"
)
filenames = glob.glob(prior_files_glob)
return {os.path.basename(ff).rstrip(".prior"): ff for ff in filenames}
[docs]
def get_distance_file_lookup_table(self, prior_file_str):
lookup_file_source = self.get_source_distance_file_lookup_table(prior_file_str)
fname = os.path.basename(lookup_file_source)
lookup_file_dest = os.path.join(self.outdir, "." + fname)
if os.path.isfile(lookup_file_source):
shutil.copyfile(lookup_file_source, lookup_file_dest)
return lookup_file_dest
[docs]
def get_source_distance_file_lookup_table(self, prior_file_str):
direc = os.path.dirname(self.default_prior_files[prior_file_str])
if self.phase_marginalization:
fname = f"{prior_file_str}_distance_marginalization_lookup_phase.npz"
else:
fname = f"{prior_file_str}_distance_marginalization_lookup.npz"
return os.path.join(direc, fname)
@property
@prior_file.setter
def prior_file(self, prior_file):
if prior_file is None:
self._prior_file = None
elif os.path.isfile(prior_file):
self._prior_file = prior_file
elif os.path.isfile(os.path.basename(prior_file)):
# Allows for the prior-file to be moved to the local directory (file-transfer mechanism)
self._prior_file = os.path.basename(prior_file)
elif prior_file in self.default_prior_files:
self._prior_file = self.default_prior_files[prior_file]
self.distance_marginalization_lookup_table = (
self.get_distance_file_lookup_table(prior_file)
)
else:
raise FileNotFoundError(f"No prior file {prior_file} available")
logger.debug(f"Setting prior-file to {self._prior_file}")
@property
[docs]
def prior_dict(self):
"""The input prior_dict from the ini (if given)
Note, this is not the bilby prior (see self.priors for that), this is
a key-val dictionary where the val's are strings which are converting
into bilby priors in `_get_prior
"""
return self._prior_dict
@prior_dict.setter
def prior_dict(self, prior_dict):
if isinstance(prior_dict, dict):
prior_dict = prior_dict
elif isinstance(prior_dict, str):
prior_dict = utils.convert_prior_string_input(prior_dict)
elif prior_dict is None:
self._prior_dict = None
return
else:
raise BilbyPipeError(f"prior_dict={prior_dict} not understood")
self._prior_dict = {
self._convert_prior_dict_key(key): val for key, val in prior_dict.items()
}
@staticmethod
[docs]
def _convert_prior_dict_key(key):
"""Converts the prior dict key to standard form
In the ini read, mass_1 -> mass-1, this corrects for that
"""
if "-" in key:
key_replaced = key.replace("-", "_")
logger.debug(f"Converting prior-dict key {key} to {key_replaced}")
key = key_replaced
return key
@property
[docs]
def distance_marginalization_lookup_table(self):
return self._distance_marginalization_lookup_table
@distance_marginalization_lookup_table.setter
def distance_marginalization_lookup_table(
self, distance_marginalization_lookup_table
):
if distance_marginalization_lookup_table is None:
if hasattr(self, "_distance_marginalization_lookup_table"):
pass
else:
self._distance_marginalization_lookup_table = None
else:
if hasattr(self, "_distance_marginalization_lookup_table"):
logger.debug("Overwriting distance_marginalization_lookup_table")
self._distance_marginalization_lookup_table = (
distance_marginalization_lookup_table
)
@property
@default_prior.setter
def default_prior(self, default_prior):
self._default_prior = default_prior
@property
[docs]
def combined_default_prior_dicts(self):
d = bilby.core.prior.__dict__.copy()
d.update(bilby.gw.prior.__dict__)
return d
@property
[docs]
def create_time_prior(self):
cond_a = getattr(self, "trigger_time", None) is not None
cond_b = getattr(self, "deltaT", None) is not None
if cond_a and cond_b:
logger.debug(
"Setting geocent time prior using trigger-time={} and deltaT={}".format(
self.trigger_time, self.deltaT
)
)
if self.time_reference == "geocent":
latex_label = "$t_c$"
else:
latex_label = f"$t_{self.time_reference[0]}$"
time_prior = get_time_prior(
time=self.trigger_time,
uncertainty=self.deltaT / 2.0,
name=self.time_parameter,
latex_label=latex_label,
)
else:
raise BilbyPipeError("Unable to set geocent_time prior from trigger_time")
return time_prior
@property
[docs]
def priors(self):
"""Read in and compose the prior at run-time"""
if getattr(self, "_priors", None) is None:
self._priors = self._get_priors()
error = self.enforce_signal_duration
if error:
self._priors.validate_prior(
duration=self.duration,
minimum_frequency=self.minimum_frequency,
error=True,
warning=False,
)
return self._priors
@priors.setter
def priors(self, priors):
self._priors = priors
[docs]
def _get_priors(self, add_time=True):
"""Construct the priors
Parameters
----------
add_time: bool
If True, the time prior is constructed from either the
prior file or the trigger time. If False (used for the overview
page where a single time-prior doesn't make sense), this isn't
added to the prior
Returns
-------
prior: bilby.core.prior.PriorDict
The generated prior
"""
if self.default_prior in self.combined_default_prior_dicts.keys():
prior_class = self.combined_default_prior_dicts[self.default_prior]
elif "." in self.default_prior:
prior_class = get_function_from_string_path(self.default_prior)
else:
raise ValueError("Unable to set prior: default_prior unavailable")
if self.prior_dict is not None:
priors = prior_class(dictionary=self.prior_dict)
else:
priors = prior_class(filename=self.prior_file)
priors = self._update_default_prior_to_sky_frame_parameters(priors)
if self.time_parameter in priors:
logger.debug(f"Using {self.time_parameter} prior from prior_file")
elif add_time:
priors[self.time_parameter] = self.create_time_prior()
else:
logger.debug("No time prior available or requested")
if self.calibration_model is not None:
priors.update(self.calibration_prior)
return priors
[docs]
def _get_default_sky_priors(self):
return bilby.core.prior.PriorDict(
dict(
dec=bilby.core.prior.Cosine(name="dec"),
ra=bilby.core.prior.Uniform(
name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"
),
)
)
[docs]
def _priors_contains_default_sky_prior(self, priors):
sky_priors = self._get_default_sky_priors()
for key in sky_priors:
if sky_priors[key] != priors.get(key, None):
return False
return True
[docs]
def _update_default_prior_to_sky_frame_parameters(self, priors):
if (
self._priors_contains_default_sky_prior(priors)
and self.reference_frame != "sky"
):
if "ra" in priors:
del priors["ra"]
if "dec" in priors:
del priors["dec"]
if "azimuth" not in priors:
priors["azimuth"] = bilby.core.prior.Uniform(
minimum=0,
maximum=2 * np.pi,
latex_label="$\\epsilon$",
boundary="periodic",
)
if "zenith" not in priors:
priors["zenith"] = bilby.core.prior.Sine(latex_label="$\\kappa$")
return priors
@property
[docs]
def calibration_prior(self):
if self.calibration_model is None:
return None
if getattr(self, "_calibration_prior", None) is not None:
return self._calibration_prior
self._calibration_prior = bilby.core.prior.PriorDict()
if self.calibration_model.lower() == "cubicspline":
for det in self.detectors:
if (
self.spline_calibration_envelope_dict is not None
and det in self.spline_calibration_envelope_dict
):
logger.info(
"Creating calibration prior for {} from {}".format(
det, self.spline_calibration_envelope_dict[det]
)
)
self._calibration_prior.update(
bilby.gw.prior.CalibrationPriorDict.from_envelope_file(
self.spline_calibration_envelope_dict[det],
minimum_frequency=self.minimum_frequency_dict[det],
maximum_frequency=self.maximum_frequency_dict[det],
n_nodes=self.spline_calibration_nodes,
label=det,
boundary=self.calibration_prior_boundary,
)
)
elif (
det in self.spline_calibration_amplitude_uncertainty_dict
and det in self.spline_calibration_phase_uncertainty_dict
):
logger.debug(
"Creating calibration prior for {} from "
"provided constant uncertainty values.".format(det)
)
self._calibration_prior.update(
bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline(
amplitude_sigma=self.spline_calibration_amplitude_uncertainty_dict[
det
],
phase_sigma=self.spline_calibration_phase_uncertainty_dict[
det
],
minimum_frequency=self.minimum_frequency_dict[det],
maximum_frequency=self.maximum_frequency_dict[det],
n_nodes=self.spline_calibration_nodes,
label=det,
)
)
else:
logger.warning(f"No calibration information for {det}")
elif self.calibration_model.lower() == "precomputed":
for det in self.detectors:
self._calibration_prior[
f"recalib_index_{det}"
] = bilby.core.prior.Categorical(1000)
return self._calibration_prior
@property
@calibration_model.setter
def calibration_model(self, calibration_model):
if calibration_model is not None:
logger.debug(f"Setting calibration_model={calibration_model}")
self._calibration_model = calibration_model
else:
logger.debug(
"No calibration_model model provided, calibration "
"marginalization will not be used"
)
self._calibration_model = None
@property
@calibration_lookup_table.setter
def calibration_lookup_table(self, lookup):
if isinstance(lookup, str):
lookup = convert_string_to_dict(lookup)
self._calibration_lookup_table = lookup
@property
[docs]
def likelihood(self):
self.search_priors = self.priors.copy()
likelihood_kwargs = dict(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
priors=self.search_priors,
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
distance_marginalization_lookup_table=self.distance_marginalization_lookup_table,
time_marginalization=self.time_marginalization,
reference_frame=self.reference_frame,
time_reference=self.time_reference,
calibration_marginalization=self.calibration_marginalization,
calibration_lookup_table=self.calibration_lookup_table,
number_of_response_curves=self.number_of_response_curves,
)
relative_binning_kwargs = dict(
fiducial_parameters=self.fiducial_parameters,
update_fiducial_parameters=self.update_fiducial_parameters,
epsilon=self.epsilon,
)
if getattr(self, "likelihood_lookup_table", None) is not None:
logger.debug("Using internally loaded likelihood_lookup_table")
likelihood_kwargs["distance_marginalization_lookup_table"] = getattr(
self, "likelihood_lookup_table"
)
if self.likelihood_type in ["GravitationalWaveTransient", "zero"]:
Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
likelihood_kwargs.update(jitter_time=self.jitter_time)
elif self.likelihood_type == "ROQGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient
likelihood_kwargs.update(
self.roq_likelihood_kwargs, jitter_time=self.jitter_time
)
elif self.likelihood_type == "RelativeBinningGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient
likelihood_kwargs.update(relative_binning_kwargs)
elif self.likelihood_type == "MBGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.MBGravitationalWaveTransient
likelihood_kwargs.update(self.multiband_likelihood_kwargs)
likelihood_kwargs.update(self.extra_likelihood_kwargs)
elif "." in self.likelihood_type:
split_path = self.likelihood_type.split(".")
module = ".".join(split_path[:-1])
likelihood_class = split_path[-1]
Likelihood = getattr(import_module(module), likelihood_class)
likelihood_kwargs.update(self.extra_likelihood_kwargs)
if "roq" in self.likelihood_type.lower():
likelihood_kwargs.update(self.roq_likelihood_kwargs)
elif "relative" in self.likelihood_type.lower():
likelihood_kwargs.update(relative_binning_kwargs)
if "multiband" in self.likelihood_type.lower():
likelihood_kwargs.update(self.multiband_likelihood_kwargs)
else:
raise ValueError("Unknown Likelihood class {}")
likelihood_kwargs = {
key: likelihood_kwargs[key]
for key in likelihood_kwargs
if key in inspect.getfullargspec(Likelihood.__init__).args
}
logger.debug(
f"Initialise likelihood {Likelihood} with kwargs: \n{likelihood_kwargs}"
)
if likelihood_kwargs.get("update_fiducial_parameters", False):
for key in self.calibration_prior:
self.search_priors[key] = self.calibration_prior[key].rescale(0.5)
likelihood = Likelihood(**likelihood_kwargs)
if likelihood_kwargs.get("update_fiducial_parameters", False):
for key in self.calibration_prior:
self.search_priors[key] = self.calibration_prior[key]
# If requested, use a zero likelihood: for testing purposes
if self.likelihood_type == "zero":
logger.debug("Using a ZeroLikelihood")
likelihood = bilby.core.likelihood.ZeroLikelihood(likelihood)
return likelihood
@property
@extra_likelihood_kwargs.setter
def extra_likelihood_kwargs(self, likelihood_kwargs):
if isinstance(likelihood_kwargs, str):
likelihood_kwargs = utils.convert_string_to_dict(likelihood_kwargs)
elif likelihood_kwargs is None:
likelihood_kwargs = dict()
elif not isinstance(likelihood_kwargs, dict):
raise TypeError(
f"Type {type(likelihood_kwargs)} not understood for likelihood kwargs."
)
for key in list(likelihood_kwargs.keys()):
likelihood_kwargs[key.replace("-", "_")] = likelihood_kwargs.pop(key)
forbidden_keys = [
"interferometers",
"waveform_generator",
"priors",
"distance_marginalization",
"time_marginalization",
"phase_marginalization",
"jitter_time",
"distance_marginalization_lookup_table",
"reference_frame",
"time_reference",
]
if "roq" in self.likelihood_type.lower():
forbidden_keys += ["weights", "roq_params", "roq_scale_factor"]
for key in forbidden_keys:
if key in likelihood_kwargs:
raise KeyError(
"{} should be passed through named argument not likelihood_kwargs".format(
key
)
)
self._extra_likelihood_kwargs = likelihood_kwargs
@property
[docs]
def roq_likelihood_kwargs(self):
if hasattr(self, "likelihood_roq_params"):
params = self.likelihood_roq_params
elif self.roq_folder is not None:
params = np.genfromtxt(self.roq_folder + "/params.dat", names=True)
else:
params = None
if hasattr(self, "likelihood_roq_weights"):
weights = self.likelihood_roq_weights
else:
weights = self.meta_data["weight_file"]
logger.debug(f"Loading ROQ weights from {weights}")
return dict(
weights=weights, roq_params=params, roq_scale_factor=self.roq_scale_factor
)
@property
[docs]
def multiband_likelihood_kwargs(self):
if hasattr(self, "likelihood_multiband_weights"):
weights = self.likelihood_multiband_weights
else:
weights = self.meta_data["weight_file"]
logger.info(f"Loading multiband weights from {weights}")
return dict(weights=weights)
@property
[docs]
def parameter_conversion(self):
cf = self.conversion_function
_lookups = dict(noconvert=None)
if isinstance(cf, str) and cf.lower() in _lookups:
logger.info(f"Using conversion function {cf} from lookups")
return _lookups[cf.lower()]
elif cf is not None:
logger.info(f"Using user-specified conversion_function {cf}")
return get_function_from_string_path(cf)
elif "binary_neutron_star" in self._frequency_domain_source_model:
logger.info(
"Using conversion_function convert_to_lal_binary_neutron_star_parameters"
)
return bilby.gw.conversion.convert_to_lal_binary_neutron_star_parameters
elif "binary_black_hole" in self._frequency_domain_source_model:
logger.info(
"Using conversion_function convert_to_lal_binary_black_hole_parameters"
)
return bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters
else:
logger.info("No conversion function")
return None
@property
[docs]
def waveform_generator(self):
waveform_arguments = self.get_default_waveform_arguments()
if "ROQ" in self.likelihood_type:
if self.roq_folder is not None:
logger.info(
"Using {} likelihood with roq-folder={}".format(
self.likelihood_type, self.roq_folder
)
)
freq_nodes_linear = np.load(self.roq_folder + "/fnodes_linear.npy")
freq_nodes_quadratic = np.load(
self.roq_folder + "/fnodes_quadratic.npy"
)
freq_nodes_linear *= self.roq_scale_factor
freq_nodes_quadratic *= self.roq_scale_factor
waveform_arguments["frequency_nodes_linear"] = freq_nodes_linear
waveform_arguments["frequency_nodes_quadratic"] = freq_nodes_quadratic
fdsm = self.bilby_roq_frequency_domain_source_model
elif "relative" in self.likelihood_type.lower():
fdsm = self.bilby_relative_binning_frequency_domain_source_model
elif self.is_likelihood_multiband:
fdsm = self.bilby_multiband_frequency_domain_source_model
else:
fdsm = self.bilby_frequency_domain_source_model
waveform_generator = self.waveform_generator_class(
frequency_domain_source_model=fdsm,
sampling_frequency=self.interferometers.sampling_frequency,
duration=self.interferometers.duration,
parameter_conversion=self.parameter_conversion,
start_time=self.interferometers.start_time,
waveform_arguments=waveform_arguments,
)
return waveform_generator
@property
@waveform_generator_class.setter
def waveform_generator_class(self, class_name):
if "." in class_name:
module = ".".join(class_name.split(".")[:-1])
class_name = class_name.split(".")[-1]
else:
module = "bilby.gw.waveform_generator"
wfg_class = getattr(import_module(module), class_name, None)
if wfg_class is not None:
self._waveform_generator_class = wfg_class
else:
raise BilbyPipeError(
f"Cannot import waveform generator class {module}.{class_name}"
)
@property
[docs]
def parameter_generation(self):
gf = self.generation_function
_lookups = dict(noconvert=None)
if isinstance(gf, str) and gf.lower() in _lookups:
logger.info(f"Using generation function {gf} from lookups")
return _lookups[gf.lower()]
elif gf is not None:
logger.info(f"Using user-specified generation {gf}")
return get_function_from_string_path(self.generation_function)
elif "binary_neutron_star" in self._frequency_domain_source_model:
logger.info("Using generation_function generate_all_bns_parameters")
return bilby.gw.conversion.generate_all_bns_parameters
elif "binary_black_hole" in self._frequency_domain_source_model:
logger.info("Using generation_function generate_all_bbh_parameters")
return bilby.gw.conversion.generate_all_bbh_parameters
else:
logger.info("No conversion function")
return None
@property
@summarypages_arguments.setter
def summarypages_arguments(self, summarypages_arguments):
if summarypages_arguments is None:
self._summarypages_arguments = None
return
string = summarypages_arguments
if "{" in string and "}" in string:
self._summarypages_arguments = convert_string_to_dict(string)
else:
self._summarypages_arguments = summarypages_arguments
@property
@postprocessing_arguments.setter
def postprocessing_arguments(self, postprocessing_arguments):
if postprocessing_arguments in [None, "None"]:
self._postprocessing_arguments = None
elif postprocessing_arguments == [None]:
self._postprocessing_arguments = None
elif isinstance(postprocessing_arguments, str):
self._postprocessing_arguments = postprocessing_arguments.split(" ")
else:
self._postprocessing_arguments = postprocessing_arguments
@property
@sampler.setter
def sampler(self, sampler):
"""Setter for the sampler"""
if not isinstance(sampler, str):
raise BilbyPipeError("Sampler must be a single string")
elif sampler in bilby.core.sampler.IMPLEMENTED_SAMPLERS:
self._sampler = sampler
else:
raise BilbyPipeError(f"Requested sampler {sampler} not implemented")
@property
@sampler_kwargs.setter
def sampler_kwargs(self, sampler_kwargs):
# Set up the default choices
if self.sampler == "dynesty":
self._sampler_kwargs = SAMPLER_SETTINGS["DynestyDefault"]
elif self.sampler == "bilby_mcmc":
self._sampler_kwargs = SAMPLER_SETTINGS["BilbyMCMCDefault"]
else:
self._sampler_kwargs = dict()
if sampler_kwargs is not None:
if sampler_kwargs.lower() in ["default", "dynestydefault"]:
self._sampler_kwargs.update(SAMPLER_SETTINGS["DynestyDefault"])
elif sampler_kwargs.lower() == "bilbymcmcdefault":
self._sampler_kwargs.update(SAMPLER_SETTINGS["BilbyMCMCDefault"])
elif sampler_kwargs.lower() == "fasttest":
self._sampler_kwargs.update(SAMPLER_SETTINGS["FastTest"])
else:
self._sampler_kwargs.update(
convert_string_to_dict(sampler_kwargs, "sampler-kwargs")
)
else:
self._sampler_kwargs = dict()
self.update_sampler_kwargs_conditional_on_request_cpus()
[docs]
def update_sampler_kwargs_conditional_on_request_cpus(self):
"""If the user adds request-cpu >1, update npool in the sampler kwargs"""
# Only run if request_cpus > 1
if self.request_cpus > 1:
self._sampler_kwargs.update(npool=self.request_cpus)
[docs]
def pretty_print_prior(self):
try:
prior = self._get_priors(add_time=False)
except Exception as e:
raise BilbyPipeError(
get_colored_string(f"Unable to parse prior, exception raised {e}")
)
pp = pretty_print_dictionary(prior)
logger.info(f"Input prior = {pp}")
@property
@conda_env.setter
def conda_env(self, conda_env):
if conda_env is None:
self._conda_env = None
self._conda_path = None
elif os.path.isdir(conda_env):
# conda_env is a path
self._conda_path = conda_env
self._conda_env = conda_env.rstrip("/").split("/")
elif (
_conda_path := self._determine_conda_path_from_env(conda_env)
) is not None:
self._conda_env = conda_env
self._conda_path = _conda_path
else:
raise BilbyPipeError(
f"conda_env={conda_env} not recognised as an environment"
)
@staticmethod
[docs]
def _determine_conda_path_from_env(conda_env):
run_out = subprocess.run(["conda", "env", "list"], capture_output=True)
elist = run_out.stdout.decode().split("\n")
for line in elist:
line_split = line.split()
if (len(line_split) > 0) and (conda_env == line_split[0]):
_conda_path = line_split[-1]
if line_split[1] != "*":
logger.warning(
get_colored_string(
"You are running bilby_pipe from a different environment"
" than the run-time environment. This is not recommended"
" and could cause unexpected behaviour"
)
)
return _conda_path
@property
@psd_dict.setter
def psd_dict(self, psd_dict):
if psd_dict is not None:
self._psd_dict = convert_string_to_dict(psd_dict, "psd-dict")
self._validate_psd_dict()
else:
logger.debug("psd-dict set to None")
self._psd_dict = None
[docs]
def _validate_psd_dict(self):
# Check all detectors are listed
for det in self.detectors:
if det not in self.psd_dict:
raise BilbyPipeError(f"Detector {det} not listed in the psd-dict")
# Check all PSD files exist
for det, psd_file in self.psd_dict.items():
if os.path.exists(psd_file):
return
elif check_if_psd_is_from_built_in(psd_file):
return
else:
raise BilbyPipeError(
f"PSD file {psd_file} for detector {det} does not exist"
)
@property
@additional_transfer_paths.setter
def additional_transfer_paths(self, paths):
if isinstance(paths, list):
self._additional_transfer_paths = paths
if paths is None or paths == [None]:
self._additional_transfer_paths = list()
@property
[docs]
def is_likelihood_multiband(self):
return (
self.likelihood_type == "MBGravitationalWaveTransient"
or "multiband" in self.likelihood_type.lower()
)
@property
[docs]
def reweighting_configuration(self):
return getattr(self, "_reweighting_configuration", None)
@reweighting_configuration.setter
def reweighting_configuration(self, conf):
if isinstance(conf, str):
if os.path.exists(conf):
with open(conf, "r") as ff:
conf = json.load(ff)
else:
try:
conf = convert_string_to_dict(conf)
except (ValueError, SyntaxError):
logger.error("Cannot parse reweighting configuration")
raise
self._reweighting_configuration = conf