#!/usr/bin/env python
"""
Module containing the tools for plotting of results
"""
import os
from bilby.core.utils import check_directory_exists_and_if_not_mkdir
from bilby.gw.result import CBCResult
from bilby.gw.source import (
binary_black_hole_frequency_sequence,
binary_black_hole_roq,
binary_neutron_star_frequency_sequence,
binary_neutron_star_roq,
lal_binary_black_hole,
lal_binary_neutron_star,
)
from .bilbyargparser import BilbyArgParser
from .utils import DataDump, get_command_line_arguments, logger, parse_args
# fmt: off
import matplotlib # isort:skip
matplotlib.use("agg") # noqa
import matplotlib.pyplot as plt # noqa isort:skip
# fmt: on
[docs]
def create_parser():
    """Build the command-line parser used by the plotting entry points.

    Returns
    -------
    parser: BilbyArgParser
        Parser carrying the result-file, plot-selection, output-directory
        and output-format options
    """
    parser = BilbyArgParser(ignore_unknown_config_file_keys=True)
    parser.add("--result", type=str, required=True, help="The result file")

    # Boolean switches selecting which plots to produce (order preserved so
    # that the --help output is unchanged).
    plot_switches = (
        ("--calibration", "Generate calibration plot"),
        ("--corner", "Generate corner plots"),
        ("--marginal", "Generate marginal plots"),
        ("--skymap", "Generate skymap"),
        ("--waveform", "Generate waveform"),
    )
    for flag, description in plot_switches:
        parser.add(flag, action="store_true", help=description)

    parser.add(
        "--outdir",
        type=str,
        required=False,
        help="The directory to save the plots in",
    )
    format_help = (
        "Format for making bilby_pipe plots, can be [png, pdf, html]. "
        "If specified format is not supported, will default to png."
    )
    parser.add("--format", type=str, default="png", help=format_help)
    return parser
[docs]
def _parse_and_load():
    """Parse the command line, then load the result file and optional data dump.

    Returns
    -------
    args: argparse.Namespace
        The parsed command-line arguments
    result: CBCResult
        The loaded result, with ``outdir`` set to the plotting directory
    data_dump: DataDump or None
        The data dump named by ``result.meta_data["data_dump"]`` when that
        path exists on disk, otherwise None

    Raises
    ------
    ValueError
        If the result-file extension is not one of json, hdf5 or pkl
    """
    args, _unknown_args = parse_args(get_command_line_arguments(), create_parser())
    logger.info(f"Generating plots for results file {args.result}")

    extension = os.path.splitext(args.result)[-1][1:]
    if extension == "json":
        result = CBCResult.from_json(args.result)
    elif extension == "hdf5":
        result = CBCResult.from_hdf5(args.result)
    elif extension == "pkl":
        result = CBCResult.from_pickle(args.result)
    else:
        raise ValueError(f"Result format {extension} not understood.")

    # Guard against a stored value of None, which os.path.exists rejects
    # with a TypeError rather than returning False.
    data_dump_path = (
        result.meta_data["data_dump"] if "data_dump" in result.meta_data else None
    )
    if data_dump_path is not None and os.path.exists(data_dump_path):
        data_dump = DataDump.from_pickle(data_dump_path)
        logger.info(f"Loaded data from {result.meta_data['data_dump']}")
    else:
        data_dump = None
        logger.info("Failed to load data dump file")

    # The parser always defines ``outdir`` (defaulting to None), so a
    # hasattr() test can never reach the result.outdir fallback. Test the
    # values instead; ``webdir`` is not declared by create_parser, hence
    # getattr with a default.
    webdir = getattr(args, "webdir", None)
    requested_outdir = getattr(args, "outdir", None)
    if webdir:
        outdir = os.path.join(webdir, "bilby")
    elif requested_outdir:
        outdir = requested_outdir
    else:
        outdir = result.outdir
    logger.info(f"Plots will be made in {outdir}")
    check_directory_exists_and_if_not_mkdir(outdir)
    result.outdir = outdir
    return args, result, data_dump
[docs]
def plot_calibration():
    """Plot the calibration posterior in the format requested via ``--format``.

    The requested format is matched case-insensitively against the file types
    the active matplotlib canvas can write. If there is no match, the plot is
    saved as png.
    """
    args, result, _ = _parse_and_load()
    logger.info("Generating calibration posterior")

    # A temporary figure is needed only to ask its canvas which formats it
    # supports. Close it straight away so no blank figure is left open in
    # pyplot (plt.gcf() would create one and leave it behind).
    probe_figure = plt.figure()
    try:
        allowed_formats = set(probe_figure.canvas.get_supported_filetypes())
    finally:
        plt.close(probe_figure)

    requested_format = str(args.format).lower()
    if requested_format in allowed_formats:
        _format = requested_format
    else:
        logger.info(
            f"Requested format '{args.format}' not recognised. Falling back to png."
        )
        _format = "png"
    result.plot_calibration_posterior(format=_format)
[docs]
def plot_corner():
    """Produce the intrinsic- and extrinsic-parameter corner plots.

    Each plot is written as a png into ``result.outdir``, named after the
    result label.
    """
    result = _parse_and_load()[1]

    # (log message, parameters to plot, filename suffix)
    corner_specs = (
        (
            "Generating intrinsic parameter corner",
            [
                "mass_1_source",
                "mass_2_source",
                "chirp_mass_source",
                "mass_ratio",
                "chi_eff",
                "chi_p",
            ],
            "_intrinsic_corner.png",
        ),
        (
            "Generating extrinsic parameter corner",
            [
                "luminosity_distance",
                "redshift",
                "theta_jn",
                "ra",
                "dec",
                "geocent_time",
            ],
            "_extrinsic_corner.png",
        ),
    )
    for message, parameters, suffix in corner_specs:
        logger.info(message)
        result.plot_corner(
            parameters,
            filename=f"{result.outdir}/{result.label}{suffix}",
        )
[docs]
def plot_marginal():
    """Plot the one-dimensional marginal posteriors, overlaid with their priors."""
    loaded = _parse_and_load()
    result = loaded[1]
    logger.info("Plotting 1d posteriors")
    result.plot_marginals(priors=True)
[docs]
def plot_skymap(maxpts=5000):
    """Generate a skymap from the loaded result.

    Any exception raised while plotting is logged at WARNING level rather than
    propagated, so ``main`` can go on to produce the other requested plots.

    Parameters
    ----------
    maxpts: int
        Forwarded as ``maxpts`` to ``result.plot_skymap`` (default 5000, the
        previously hard-coded value).
    """
    _, result, _ = _parse_and_load()
    logger.info("Generating skymap")
    try:
        result.plot_skymap(maxpts=maxpts)
    except Exception as e:
        # Deliberate best-effort: surface the failure at WARNING level (not
        # INFO, where it was easy to miss) but do not abort.
        logger.warning(f"Unable to generate skymap: error {e}")
[docs]
def main():
    """Top-level interface for bilby_pipe.

    Parses the command line and runs each plotting routine whose flag is set.
    Every ``plot_*`` routine parses the arguments and loads the result itself.
    Only the arguments are parsed here, so the result file and data dump are
    not read an extra time just to be discarded.
    """
    args, _ = parse_args(get_command_line_arguments(), create_parser())
    if args.skymap:
        plot_skymap()
    if args.marginal:
        plot_marginal()
    if args.corner:
        plot_corner()
    if args.calibration:
        plot_calibration()
    if args.waveform:
        plot_waveform()