# Source code for bilby_pipe.main

#!/usr/bin/env python
"""
bilby_pipe is a command line tool for taking user input (as command line
arguments or an ini file) and creating DAG files for submitting bilby parameter
estimation jobs. To get started, write an ini file `config.ini` and run

$ bilby_pipe config.ini

Instructions for how to submit the job are printed in a log message. You can
also specify extra arguments from the command line, e.g.

$ bilby_pipe config.ini --submit

will build and submit the job.
"""
import importlib
import json
import os

import numpy as np
import pandas as pd

from .create_injections import create_injection_file
from .input import Input
from .job_creation import generate_dag
from .parser import create_parser
from .utils import (
    BilbyPipeError,
    get_colored_string,
    get_command_line_arguments,
    get_outdir_name,
    log_version_information,
    logger,
    parse_args,
    request_memory_generation_lookup,
    tcolors,
)


class MainInput(Input):
    """An object to hold all the inputs to bilby_pipe.

    Collects the parsed command-line/ini arguments, validates a subset of
    them through property setters, and (optionally) runs consistency checks
    including injection-file creation.
    """

    def __init__(self, args, unknown_args, perform_checks=True):
        """
        Parameters
        ----------
        args: argparse.Namespace
            The parsed arguments from the command line / ini file.
        unknown_args: list
            Arguments that were not recognised by the parser.
        perform_checks: bool
            If True (default), run the source-model, calibration-boundary,
            CPU-parallelisation and (if requested) injection checks.
        """
        super().__init__(args, unknown_args, print_msg=False)

        self.known_args = args
        self.unknown_args = unknown_args
        self.ini = args.ini
        self.submit = args.submit
        self.condor_job_priority = args.condor_job_priority
        self.create_summary = args.create_summary

        self.outdir = args.outdir
        self.label = args.label
        self.log_directory = args.log_directory
        self.accounting = args.accounting
        self.accounting_user = args.accounting_user
        self.sampler = args.sampler
        self.detectors = args.detectors
        self.coherence_test = args.coherence_test
        self.n_parallel = args.n_parallel
        self.transfer_files = args.transfer_files
        self.additional_transfer_paths = args.additional_transfer_paths
        self.osg = args.osg
        self.desired_sites = args.desired_sites
        self.analysis_executable = args.analysis_executable
        self.analysis_executable_parser = args.analysis_executable_parser
        self.result_format = args.result_format
        self.final_result = args.final_result
        self.final_result_nsamples = args.final_result_nsamples

        self.webdir = args.webdir
        self.email = args.email
        self.notification = args.notification
        self.queue = args.queue
        self.existing_dir = args.existing_dir

        self.scheduler = args.scheduler
        self.scheduler_args = args.scheduler_args
        self.scheduler_module = args.scheduler_module
        self.scheduler_env = args.scheduler_env
        self.scheduler_analysis_time = args.scheduler_analysis_time
        self.disable_hdf5_locking = args.disable_hdf5_locking

        self.waveform_approximant = args.waveform_approximant
        self.time_reference = args.time_reference
        self.reference_frame = args.reference_frame
        self.likelihood_type = args.likelihood_type
        self.duration = args.duration
        self.phase_marginalization = args.phase_marginalization
        self.prior_file = args.prior_file
        self.prior_dict = args.prior_dict
        self.default_prior = args.default_prior
        self.minimum_frequency = args.minimum_frequency
        self.enforce_signal_duration = args.enforce_signal_duration

        self.run_local = args.local
        self.local_generation = args.local_generation
        self.local_plot = args.local_plot

        self.post_trigger_duration = args.post_trigger_duration

        self.ignore_gwpy_data_quality_check = args.ignore_gwpy_data_quality_check
        self.trigger_time = args.trigger_time
        self.deltaT = args.deltaT
        self.gps_tuple = args.gps_tuple
        self.gps_file = args.gps_file
        self.timeslide_file = args.timeslide_file
        self.gaussian_noise = args.gaussian_noise
        self.zero_noise = args.zero_noise
        self.n_simulation = args.n_simulation

        self.injection = args.injection
        self.injection_numbers = args.injection_numbers
        self.injection_file = args.injection_file
        self.injection_dict = args.injection_dict
        self.injection_waveform_arguments = args.injection_waveform_arguments
        self.injection_waveform_approximant = args.injection_waveform_approximant
        self.generation_seed = args.generation_seed

        self.request_disk = args.request_disk
        self.request_memory = args.request_memory
        self.request_memory_generation = args.request_memory_generation
        self.request_cpus = args.request_cpus
        self.sampler_kwargs = args.sampler_kwargs
        self.mpi_samplers = ["pymultinest"]
        self.use_mpi = (self.sampler in self.mpi_samplers) and (self.request_cpus > 1)

        # Set plotting options when need the plot node
        self.plot_node_needed = False
        for plot_attr in [
            "calibration",
            "corner",
            "marginal",
            "skymap",
            "waveform",
        ]:
            attr = f"plot_{plot_attr}"
            setattr(self, attr, getattr(args, attr))
            if getattr(self, attr):
                self.plot_node_needed = True

        # Set all other plotting options
        for plot_attr in [
            "trace",
            "data",
            "injection",
            "spectrogram",
            "format",
        ]:
            attr = f"plot_{plot_attr}"
            setattr(self, attr, getattr(args, attr))

        self.postprocessing_executable = args.postprocessing_executable
        self.postprocessing_arguments = args.postprocessing_arguments
        self.single_postprocessing_executable = args.single_postprocessing_executable
        self.single_postprocessing_arguments = args.single_postprocessing_arguments

        self.summarypages_arguments = args.summarypages_arguments

        self.psd_dict = args.psd_dict

        if perform_checks:
            self.check_source_model(args)
            self.check_calibration_prior_boundary(args)
            self.check_cpu_parallelisation()
            if self.injection:
                self.check_injection()

        self.extra_lines = []
        self.requirements = []

    @property
    def ini(self):
        return self._ini

    @ini.setter
    def ini(self, ini):
        # Fail early if the ini file does not exist; store it as a relpath
        if os.path.isfile(ini) is False:
            raise FileNotFoundError(f"No ini file {ini} found")
        self._ini = os.path.relpath(ini)

    @property
    def notification(self):
        return self._notification

    @notification.setter
    def notification(self, notification):
        # Only HTCondor-valid notification settings are accepted
        valid_settings = ["Always", "Complete", "Error", "Never"]
        if notification in valid_settings:
            self._notification = notification
        else:
            raise BilbyPipeError(
                "'{}' is not a valid notification setting. "
                "Valid settings are {}.".format(notification, valid_settings)
            )

    @property
    def initialdir(self):
        # The directory from which jobs are submitted
        return os.getcwd()

    @property
    def n_simulation(self):
        return self._n_simulation

    @n_simulation.setter
    def n_simulation(self, n_simulation):
        logger.debug(f"Setting n_simulation={n_simulation}")
        if isinstance(n_simulation, int) and n_simulation >= 0:
            self._n_simulation = n_simulation
        elif n_simulation is None:
            # None means "no simulations requested"
            self._n_simulation = 0
        else:
            raise BilbyPipeError(f"Input n_simulation={n_simulation} not understood")

    @property
    def analysis_executable(self):
        return self._analysis_executable

    @analysis_executable.setter
    def analysis_executable(self, analysis_executable):
        if analysis_executable:
            self._analysis_executable = analysis_executable
        else:
            # Default to the standard bilby_pipe analysis entry point
            self._analysis_executable = "bilby_pipe_analysis"

    @property
    def request_disk(self):
        return self._request_disk

    @request_disk.setter
    def request_disk(self, request_disk):
        # Store both the scheduler string ("<x>GB") and the numeric value
        self._request_disk = f"{request_disk}GB"
        self._request_disk_in_GB = float(request_disk)
        logger.debug(f"Setting analysis request_disk={self._request_disk}")

    @property
    def request_memory(self):
        return self._request_memory

    @request_memory.setter
    def request_memory(self, request_memory):
        self._request_memory = f"{request_memory}GB"
        self._request_memory_in_GB = request_memory
        logger.debug(f"Setting analysis request_memory={self._request_memory}")

    @property
    def request_memory_generation(self):
        return self._request_memory_generation

    @request_memory_generation.setter
    def request_memory_generation(self, request_memory_generation):
        if request_memory_generation is None:
            # Look up a sensible default from the duration and likelihood type
            roq = "roq" in self.likelihood_type.lower()
            request_memory_generation = request_memory_generation_lookup(
                self.duration, roq=roq
            )
        logger.debug(f"Setting request_memory_generation={request_memory_generation}GB")
        self._request_memory_generation = f"{request_memory_generation}GB"

    @property
    def request_cpus(self):
        return self._request_cpus

    @request_cpus.setter
    def request_cpus(self, request_cpus):
        logger.debug(f"Setting analysis request_cpus = {request_cpus}")
        self._request_cpus = request_cpus

    @property
    def use_mpi(self):
        return self._use_mpi

    @use_mpi.setter
    def use_mpi(self, use_mpi):
        if use_mpi:
            logger.debug(f"Turning on MPI for {self.sampler}")
        self._use_mpi = use_mpi

    @staticmethod
    def check_source_model(args):
        """Check the source model consistency with the approximant"""
        if "tidal" in args.waveform_approximant.lower():
            if "neutron_star" not in args.frequency_domain_source_model.lower():
                msg = [
                    tcolors.WARNING,
                    "You appear to be using a tidal waveform with the",
                    f"{args.frequency_domain_source_model} source model.",
                    "You may want to use `frequency-domain-source-model=",
                    "lal_binary_neutron_star`.",
                    tcolors.END,
                ]
                logger.warning(" ".join(msg))

    @staticmethod
    def check_calibration_prior_boundary(args):
        """Warn if the calibration prior boundary is not the recommended one."""
        # List of recommendations: print warning if these are not adhered to
        recs = dict(bilby_mcmc=None, dynesty="reflective")
        suggested_boundary = recs.get(args.sampler, args.calibration_prior_boundary)
        if args.calibration_prior_boundary != suggested_boundary:
            msg = (
                "You have requested a calibration prior boundary "
                f"{args.calibration_prior_boundary}, but {suggested_boundary} "
                "is recommended."
            )
            logger.warning(get_colored_string(msg))

    def check_cpu_parallelisation(self):
        """Warn when request-cpus and the sampler's npool disagree."""
        request_cpus = self.request_cpus
        npool = self.sampler_kwargs.get("npool", request_cpus)
        if request_cpus != npool:
            msg = (
                f"request-cpus={request_cpus}, but sampler_kwargs[npool]={npool}:"
                "this may cause inefficient performance"
            )
            logger.warning(get_colored_string(msg))

    def check_injection(self):
        """Check injection behaviour

        If injections are requested, either use the injection-dict,
        injection-file, or create an injection-file
        """
        default_injection_file_name = "{}/{}_injection_file.dat".format(
            self.data_directory, self.label
        )
        if self.injection_dict is not None:
            logger.debug(
                "Using injection dict from ini file {}".format(
                    json.dumps(self.injection_dict, indent=2)
                )
            )
        elif self.injection_file is not None:
            logger.debug(f"Using injection file {self.injection_file}")
        elif os.path.isfile(default_injection_file_name):
            # This is done to avoid overwriting the injection file
            logger.debug(f"Using injection file {default_injection_file_name}")
            self.injection_file = default_injection_file_name
        else:
            logger.debug("No injection file found, generating one now")
            if self.gps_file is not None or self.gps_tuple is not None:
                # One injection per gps time; n_simulation (if set) must match
                if self.n_simulation > 0 and self.n_simulation != len(self.gpstimes):
                    raise BilbyPipeError(
                        "gps_file/gps_tuple option and n_simulation are not matched"
                    )
                gpstimes = self.gpstimes
                n_injection = len(gpstimes)
            else:
                gpstimes = None
                n_injection = self.n_simulation

            if self.trigger_time is None:
                trigger_time_injections = 0
            else:
                trigger_time_injections = self.trigger_time

            create_injection_file(
                filename=default_injection_file_name,
                prior_file=self.prior_file,
                prior_dict=self.prior_dict,
                n_injection=n_injection,
                trigger_time=trigger_time_injections,
                deltaT=self.deltaT,
                gpstimes=gpstimes,
                duration=self.duration,
                post_trigger_duration=self.post_trigger_duration,
                generation_seed=self.generation_seed,
                extension="dat",
                default_prior=self.default_prior,
            )
            self.injection_file = default_injection_file_name

        # Check the gps_file has the sample length as number of simulation
        if self.gps_file is not None:
            if len(self.gpstimes) != len(self.injection_df):
                raise BilbyPipeError("Injection file length does not match gps_file")

        if self.n_simulation > 0:
            if self.n_simulation != len(self.injection_df):
                raise BilbyPipeError(
                    "n-simulation does not match the number of injections: "
                    "please check your ini file"
                )
        elif self.n_simulation == 0 and self.gps_file is None:
            # No explicit count given: infer it from the injection file
            self.n_simulation = len(self.injection_df)
            logger.debug(
                f"Setting n_simulation={self.n_simulation} to match injections"
            )
def write_complete_config_file(parser, args, inputs, input_cls=MainInput):
    """Write the complete ini file and verify it round-trips.

    Normalises the parsed arguments (absolute paths, stringified lists),
    writes them to ``inputs.complete_ini_file``, then re-parses that file
    and compares the resulting inputs against the originals. Raises
    BilbyPipeError if the round-trip differs.

    Parameters
    ----------
    parser:
        The bilby_pipe argument parser used to write and re-parse the file.
    args: argparse.Namespace
        The parsed arguments to serialise.
    inputs: Input
        The inputs object built from ``args``.
    input_cls:
        Class used to rebuild inputs from the re-parsed file (default
        MainInput).
    """
    # Normalise a *copy* of the arg dict; mutate args itself as we go
    for key, val in vars(args).copy().items():
        if key == "label":
            continue
        if isinstance(val, str) and (os.path.isfile(val) or os.path.isdir(val)):
            # Paths are written out as absolute paths
            setattr(args, key, os.path.abspath(val))
        if isinstance(val, list):
            if not val:
                setattr(args, key, "[]")
            elif isinstance(val[0], str):
                setattr(args, key, f"[{', '.join(val)}]")
    args.sampler_kwargs = str(inputs.sampler_kwargs)
    args.submit = False

    parser.write_to_file(
        filename=inputs.complete_ini_file,
        args=args,
        overwrite=False,
        include_description=False,
    )

    # Verify that the written complete config is identical to the source config
    complete_args = parser.parse([inputs.complete_ini_file])
    complete_inputs = input_cls(complete_args, "", perform_checks=False)
    ignore_keys = ["scheduler_module", "submit"]

    differences = []
    for key, val in inputs.__dict__.items():
        if key in ignore_keys or key not in complete_args:
            continue
        other = complete_inputs.__dict__[key]
        if isinstance(val, pd.DataFrame) and all(val == other):
            continue
        if isinstance(val, np.ndarray) and all(np.array(val) == np.array(other)):
            continue
        if isinstance(val, str) and os.path.isfile(val):
            # Check if it is relpath vs abspath
            if os.path.abspath(val) == os.path.abspath(other):
                continue
        if val == other:
            continue
        differences.append(key)

    if differences:
        for key in differences:
            print(key, f"{inputs.__dict__[key]} -> {complete_inputs.__dict__[key]}")
        raise BilbyPipeError(
            "The written config file {} differs from the source {} in {}".format(
                inputs.ini, inputs.complete_ini_file, differences
            )
        )
    else:
        logger.info(f"To see full configuration, check {inputs.complete_ini_file}")
def main():
    """Top-level interface for bilby_pipe"""
    parser = create_parser(top_level=True)
    args, unknown_args = parse_args(get_command_line_arguments(), parser)

    if args.analysis_executable_parser is not None:
        # Alternative parser requested, reload args
        module, _, function = args.analysis_executable_parser.rpartition(".")
        parser = getattr(importlib.import_module(module), function)()
        args, unknown_args = parse_args(get_command_line_arguments(), parser)

    # Check and sort outdir
    args.outdir = args.outdir.replace("'", "").replace('"', "")
    if args.overwrite_outdir is False:
        args.outdir = get_outdir_name(args.outdir)

    log_version_information()
    inputs = MainInput(args, unknown_args)
    write_complete_config_file(parser, args, inputs)
    generate_dag(inputs)

    if len(unknown_args) > 0:
        msg = [tcolors.WARNING, f"Unrecognized arguments {unknown_args}", tcolors.END]
        logger.warning(" ".join(msg))