Source code for bilby_pipe.data_analysis
#!/usr/bin/env python
""" Script to perform data analysis """
import os
import signal
import sys
import time
import numpy as np
import bilby
from bilby_pipe.input import Input
from bilby_pipe.main import parse_args
from bilby_pipe.parser import _get_items_in_group, create_parser
from bilby_pipe.utils import (
CHECKPOINT_EXIT_CODE,
BilbyPipeError,
DataDump,
log_version_information,
logger,
)
# fmt: off
import matplotlib # isort:skip
matplotlib.use("agg")
# fmt: on
[docs]
def sighandler(signum, frame):
logger.info("Performing periodic eviction")
sys.exit(CHECKPOINT_EXIT_CODE)
[docs]
class DataAnalysisInput(Input):
"""Handles user-input for the data analysis script
Parameters
----------
parser: BilbyArgParser, optional
The parser containing the command line / ini file inputs
args_list: list, optional
A list of the arguments to parse. Defaults to `sys.argv[1:]`
"""
def __init__(self, args, unknown_args, test=False):
super().__init__(args, unknown_args)
# Generic initialisation
# Admin arguments
# Naming arguments
# Data dump file to run on
# Choices for running
self.sampling_seed = args.sampling_seed
# Frequencies
# Waveform, source model and likelihood
# ROQ
# Calibration
[docs]
self.spline_calibration_amplitude_uncertainty_dict = (
args.spline_calibration_amplitude_uncertainty_dict
)
[docs]
self.spline_calibration_phase_uncertainty_dict = (
args.spline_calibration_phase_uncertainty_dict
)
# Marginalization
if test is False:
self._load_data_dump()
# Reweighting
self.reweighting_configuration = self.meta_data.get(
"reweighting_configuration", None
)
self.reweight_nested_samples = args.reweight_nested_samples
# heterodyning - this relies on the data dump
self.fiducial_parameters = self.meta_data.get("fiducial_parameters", None)
self.update_fiducial_parameters = False
self.epsilon = args.epsilon
@property
@sampling_seed.setter
def sampling_seed(self, sampling_seed):
if sampling_seed is None:
sampling_seed = np.random.randint(1, 1e6)
self._sampling_seed = sampling_seed
np.random.seed(sampling_seed)
logger.info(f"Sampling seed set to {sampling_seed}")
if not any(
[
k in self.sampler_kwargs
for k in bilby.core.sampler.Sampler.sampling_seed_equiv_kwargs
]
):
self.sampler_kwargs["sampling_seed"] = self._sampling_seed
@property
[docs]
def interferometers(self):
try:
return self._interferometers
except AttributeError:
ifos = self.data_dump.interferometers
names = [ifo.name for ifo in ifos]
logger.info(f"Found data for detectors = {names}")
ifos_to_use = [ifo for ifo in ifos if ifo.name in self.detectors]
names_to_use = [ifo.name for ifo in ifos_to_use]
logger.info(f"Using data for detectors = {names_to_use}")
self._interferometers = bilby.gw.detector.InterferometerList(ifos_to_use)
self.print_detector_information(self._interferometers)
return self._interferometers
@staticmethod
[docs]
def print_detector_information(interferometers):
for ifo in interferometers:
logger.info(
"{}: sampling-frequency={}, segment-start-time={}, duration={}".format(
ifo.name,
ifo.strain_data.sampling_frequency,
ifo.strain_data.start_time,
ifo.strain_data.duration,
)
)
@property
[docs]
def data_dump(self):
if hasattr(self, "_data_dump"):
return self._data_dump
else:
raise BilbyPipeError("Data dump not loaded")
[docs]
def _load_data_dump(self):
filename = self.data_dump_file
self.meta_data["data_dump"] = filename
logger.debug("Data dump not previously loaded")
if os.path.isfile(filename):
pass
elif os.path.isfile(os.path.basename(filename)):
filename = os.path.basename(filename)
else:
raise FileNotFoundError(
"No dump data {} file found. Most likely the generation "
"step failed".format(filename)
)
self._data_dump = DataDump.from_pickle(filename)
self.meta_data.update(self._data_dump.meta_data)
return self._data_dump
@property
[docs]
def result_class(self):
"""The bilby result class to store results in"""
try:
return bilby.gw.result.CompactBinaryCoalescenceResult
except AttributeError:
logger.debug("Unable to use CBC specific result class")
return None
@property
[docs]
def result_directory(self):
result_dir = os.path.join(self.outdir, "result")
return os.path.relpath(result_dir)
[docs]
def get_likelihood_and_priors(self):
"""Read in the likelihood and prior from the data dump
This reads in the data dump values and reconstructs the likelihood and
priors. Note, care must be taken to use the "search_priors" which differ
from the true prior when using marginalization
Returns
-------
likelihood, priors
The bilby likelihood and priors
"""
priors = self.data_dump.priors_class(self.data_dump.priors_dict)
self.priors = priors
self.likelihood_lookup_table = self.data_dump.likelihood_lookup_table
self.likelihood_roq_weights = self.data_dump.likelihood_roq_weights
self.likelihood_roq_params = self.data_dump.likelihood_roq_params
self.likelihood_multiband_weights = self.data_dump.likelihood_multiband_weights
likelihood = self.likelihood
priors = self.search_priors
return likelihood, priors
[docs]
def run_sampler(self):
if self.scheduler.lower() == "condor" and not self.run_local:
signal.signal(signal.SIGALRM, handler=sighandler)
signal.alarm(self.periodic_restart_time)
likelihood, priors = self.get_likelihood_and_priors()
if self.reweighting_configuration is not None:
conversion_function = None
else:
conversion_function = self.parameter_generation
self.result = bilby.run_sampler(
likelihood=likelihood,
priors=priors,
sampler=self.sampler,
label=self.label,
outdir=self.result_directory,
conversion_function=conversion_function,
injection_parameters=self.meta_data["injection_parameters"],
meta_data=self.meta_data,
result_class=self.result_class,
exit_code=CHECKPOINT_EXIT_CODE,
save=self.result_format,
**self.sampler_kwargs,
)
[docs]
def get_likelihood_and_priors_for_reweighting(self):
data = self.reweighting_configuration
need_likelihood = False
likelihood = None
priors = None
self.search_priors = self.priors.copy()
target_arguments = _get_items_in_group("likelihood")
target_arguments += _get_items_in_group("waveform")
if "prior-file" in data:
priors = self.priors.copy()
priors.update(bilby.core.prior.PriorDict(data["prior-file"]))
self.search_priors = priors
for key, value in data.items():
key = key.replace("-", "_")
if key in target_arguments:
logger.info(f"Setting {key} to {value} for reweighting")
setattr(self, key, value)
need_likelihood = True
if need_likelihood:
likelihood = self.likelihood
likelihood.parameters.update(self.search_priors.sample())
likelihood.outdir = self.outdir
likelihood.label = self.label
return likelihood, priors
[docs]
def reweight_result(self):
old_priors = self.result.priors
self.search_priors = old_priors
reweight_nest = self.reweight_nested_samples
if reweight_nest and self.sampler in ["dynesty", "nessai"]:
self.result.nested_samples["log_prior"] = old_priors.ln_prob(
{
key: self.result.nested_samples[key].values
for key in self.result.search_parameter_keys
},
axis=0,
)
for key in old_priors:
if isinstance(old_priors[key], bilby.core.prior.DeltaFunction):
self.result.nested_samples[key] = old_priors[key].sample()
else:
reweight_nest = False
old_likelihood = self.likelihood
likelihood, priors = self.get_likelihood_and_priors_for_reweighting()
if likelihood is None:
old_likelihood = None
if priors is None:
old_priors = None
target_keys = list(self.search_priors.keys()) + ["log_prior", "log_likelihood"]
for key in list(self.result.posterior):
if key not in target_keys:
del self.result.posterior[key]
if likelihood is not None:
n_evaluations = min(100, len(self.result.posterior))
t_start = time.time()
for i in range(n_evaluations):
likelihood.parameters = {
key: self.result.posterior.iloc[i][key]
for key in self.result.posterior
}
likelihood.log_likelihood()
time_per_likelihood = (time.time() - t_start) / n_evaluations
logger.debug(f"{time_per_likelihood:.2f} s per likelihood evaluation")
n_checkpoint = int(300 / time_per_likelihood)
else:
n_checkpoint = 3000
logger.debug(f"Checkpointing every {n_checkpoint} samples")
reweighted = bilby.core.result.reweight(
result=self.result,
label=self.label,
new_likelihood=likelihood,
new_prior=priors,
old_likelihood=old_likelihood,
old_prior=old_priors,
conversion_function=self.parameter_generation,
npool=self.request_cpus,
verbose_output=False,
resume_file=f"{self.result_directory}/{self.label}_reweight_resume.pkl",
n_checkpoint=n_checkpoint,
use_nested_samples=reweight_nest,
)
reweighted.save_to_file(
extension=self.result_format, overwrite=True, outdir=self.result_directory
)
[docs]
def create_analysis_parser():
"""Data analysis parser creation"""
return create_parser(top_level=False)
[docs]
def main():
"""Data analysis main logic"""
args, unknown_args = parse_args(sys.argv[1:], create_analysis_parser())
log_version_information()
analysis = DataAnalysisInput(args, unknown_args)
analysis.run_sampler()
if analysis.reweighting_configuration is not None:
analysis.reweight_result()
logger.info("Run completed")
[docs]
def reweight():
parser = create_analysis_parser()
parser.add_argument("--result-file", type=str, default=None)
args, unknown_args = parse_args(sys.argv[1:], parser)
log_version_information()
analysis = DataAnalysisInput(args, unknown_args)
analysis.reweighting_configuration = args.reweighting_configuration
_ = analysis.get_likelihood_and_priors()
outdir = f"{analysis.outdir}/result"
if args.result_file is None:
filename = f"{outdir}/{analysis.label}_result.{analysis.result_format}"
else:
filename = args.result_file
result = bilby.core.result.read_in_result(filename)
result.outdir = outdir
analysis.result = result
analysis.label += "_reweighted"
if analysis.reweighting_configuration is not None:
logger.info("Running reweighting")
analysis.reweight_result()
else:
raise ValueError("Reweighting configuration not passed")