Coverage for pesummary/gw/plots/publication.py: 82.8%
227 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3from pesummary.utils.utils import logger, number_of_columns_for_legend
4import seaborn
5from pesummary.core.plots.figure import figure
6from pesummary.core.plots.seaborn import violin
7from pesummary.utils.bounded_2d_kde import Bounded_2d_kde
8from pesummary.gw.plots.bounds import default_bounds
9from pesummary.gw.plots.cmap import colormap_with_fixed_hue
10from pesummary.gw.conversions import mchirp_from_m1_m2, q_from_m1_m2
11import numpy as np
12import copy
# Maintainers of this module
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
    "Michael Puerrer <michael.puerrer@ligo.org>",
]
def chirp_mass_and_q_from_mass1_mass2(pts):
    """Map component-mass samples onto chirp mass and mass ratio

    Parameters
    ----------
    pts: numpy.array
        array whose two rows hold the mass1 and mass2 samples (exactly two
        rows are expected; unpacking fails otherwise)

    Returns
    -------
    numpy.ndarray
        ``np.vstack`` of the chirp mass samples followed by the mass ratio
        samples
    """
    mass_1, mass_2 = np.atleast_2d(pts)
    chirp_mass = mchirp_from_m1_m2(mass_1, mass_2)
    mass_ratio = q_from_m1_m2(mass_1, mass_2)
    return np.vstack([chirp_mass, mass_ratio])
def _bounds_for_parameter(parameter, T=True):
    """Return ``(transform, low, high)`` KDE bounds for a single parameter

    Parameters
    ----------
    parameter: str
        name of the parameter to look up in ``default_bounds``
    T: Bool, optional
        if True, a string upper bound mentioning "mass_1" switches on the
        chirp mass / mass ratio transform and sets the upper bound to 1
    """
    transform = low = high = None
    if parameter not in default_bounds:
        return transform, low, high
    bounds = default_bounds[parameter]
    if "low" in bounds:
        low = bounds["low"]
    if "high" in bounds:
        upper = bounds["high"]
        if not isinstance(upper, str):
            high = upper
        elif T and "mass_1" in upper:
            # A bound expressed relative to mass_1 cannot be used directly;
            # work in (chirp mass, mass ratio) space where q <= 1 instead
            transform = chirp_mass_and_q_from_mass1_mass2
            high = 1.
        # any other string bound leaves ``high`` as None
    return transform, low, high


def _return_bounds(parameters, T=True):
    """Return bounds for KDE

    Parameters
    ----------
    parameters: list
        list of parameters being plotted; the first two entries are used as
        the x and y parameters
    T: Bool, optional
        if True, modify the parameter bounds if a transform is required

    Returns
    -------
    tuple
        ``(transform, xlow, xhigh, ylow, yhigh)``; any entry may be None
    """
    x_transform, xlow, xhigh = _bounds_for_parameter(parameters[0], T=T)
    y_transform, ylow, yhigh = _bounds_for_parameter(parameters[1], T=T)
    # the y parameter's transform takes precedence when both request one
    transform = x_transform if y_transform is None else y_transform
    return transform, xlow, xhigh, ylow, yhigh
def twod_contour_plots(
    parameters, samples, labels, latex_labels, colors=None, linestyles=None,
    return_ax=False, plot_datapoints=False, smooth=None, latex_friendly=False,
    levels=[0.9], legend_kwargs={
        "bbox_to_anchor": (0., 1.02, 1., .102), "loc": 3, "handlelength": 3,
        "mode": "expand", "borderaxespad": 0., "handleheight": 1.75
    }, **kwargs
):
    """Generate 2d contour plots for a set of samples for given parameters

    Parameters
    ----------
    parameters: list
        names of the two parameters that you wish to plot. The first is used
        for the x axis and the second for the y axis
    samples: nd list
        list of samples for each analysis. ``samples[i][0]`` and
        ``samples[i][1]`` are the samples of the first and second parameter
        for the i-th analysis
    labels: list
        list of labels corresponding to each set of samples
    latex_labels: dict
        dictionary of latex labels, keyed by parameter name
    colors: list, optional
        colors to use for each set of samples. Default is the seaborn
        "pastel" palette with one color per set of samples
    linestyles: list, optional
        linestyles to use for each set of samples. Default is "-" for all
    return_ax: Bool, optional
        if True, return the axis in addition to the figure
    plot_datapoints: Bool, optional
        passed to pesummary.core.plots.publication.comparison_twod_contour_plot
    smooth: optional
        passed to pesummary.core.plots.publication.comparison_twod_contour_plot
    latex_friendly: Bool, optional
        not used by this function
    levels: list, optional
        contour levels passed to
        pesummary.core.plots.publication.comparison_twod_contour_plot
    legend_kwargs: dict, optional
        kwargs passed to ``ax.legend``. The dictionary is copied and never
        modified in place
    **kwargs: dict, optional
        all other kwargs passed to
        pesummary.core.plots.publication.comparison_twod_contour_plot. Any
        "kde" / "kde_kwargs" entries are replaced by a bounded 2d KDE

    Returns
    -------
    fig
        the figure returned by comparison_twod_contour_plot, or
        ``(fig, ax)`` if ``return_ax`` is True
    """
    from pesummary.core.plots.publication import (
        comparison_twod_contour_plot as core
    )
    from matplotlib.patches import Polygon

    logger.debug("Generating 2d contour plots for %s" % ("_and_".join(parameters)))
    if colors is None:
        palette = seaborn.color_palette(palette="pastel", n_colors=len(samples))
    else:
        palette = colors
    if linestyles is None:
        linestyles = ["-"] * len(samples)
    # NOTE(review): this figure is superseded by the one returned from
    # ``core`` below; kept in case ``core`` relies on the current figure
    fig, ax1 = figure(gca=True)
    transform, xlow, xhigh, ylow, yhigh = _return_bounds(parameters)
    kwargs.update(
        {
            "kde": Bounded_2d_kde, "kde_kwargs": {
                "transform": transform, "xlow": xlow, "xhigh": xhigh,
                "ylow": ylow, "yhigh": yhigh
            }
        }
    )
    # Pass ``palette`` (not ``colors``) so the pastel default computed above
    # is actually used when no colors are supplied
    fig = core(
        [i[0] for i in samples], [i[1] for i in samples], colors=palette,
        labels=labels, xlabel=latex_labels[parameters[0]], smooth=smooth,
        ylabel=latex_labels[parameters[1]], linestyles=linestyles,
        plot_datapoints=plot_datapoints, levels=levels, **kwargs
    )
    ax1 = fig.gca()
    if all("mass_1" in i or "mass_2" in i for i in parameters):
        # Shade the triangle above the line y = x (up to 1000); presumably
        # the excluded mass_2 > mass_1 region -- TODO confirm axis ordering
        reg = Polygon([[0, 0], [0, 1000], [1000, 1000]], color='gray', alpha=0.75)
        ax1.add_patch(reg)
    # Copy rather than update: ``legend_kwargs`` may be the shared default
    # dict or a dict owned by the caller
    legend_kwargs = dict(
        legend_kwargs, ncol=number_of_columns_for_legend(labels)
    )
    legend = ax1.legend(**legend_kwargs)
    for leg in legend.get_lines():
        leg.set_linewidth(legend_kwargs.get("handleheight", 1.))
    fig.tight_layout()
    if return_ax:
        return fig, ax1
    return fig
def _setup_triangle_plot(parameters, kwargs):
    """Inject bounded KDE settings into the kwargs of a triangle plot

    Parameters
    ----------
    parameters: list
        list of parameters being plotted
    kwargs: dict
        kwargs to be passed to pesummary.gw.plots.publication.triangle_plot
        or pesummary.gw.plots.publication.reverse_triangle_plot. This dict
        is modified in place and also returned

    Returns
    -------
    kwargs: dict
        the same dict with "kde_2d", "kde_2d_kwargs", "kde" and "kde_kwargs"
        set (existing values for these keys are overwritten)

    Raises
    ------
    ValueError
        if ``parameters`` is empty
    """
    from pesummary.utils.bounded_1d_kde import bounded_1d_kde

    if len(parameters) == 0:
        raise ValueError("Please provide a list of parameters")

    # 2d KDE: bounds may require a transform of the samples
    transform, xlow, xhigh, ylow, yhigh = _return_bounds(parameters)
    kwargs["kde_2d"] = Bounded_2d_kde
    kwargs["kde_2d_kwargs"] = {
        "transform": transform, "xlow": xlow, "xhigh": xhigh,
        "ylow": ylow, "yhigh": yhigh
    }
    kwargs["kde"] = bounded_1d_kde

    # 1d marginals: use the untransformed bounds
    _, x_low, x_high, y_low, y_high = _return_bounds(parameters, T=False)
    kwargs["kde_kwargs"] = {
        "x_axis": {"xlow": x_low, "xhigh": x_high},
        "y_axis": {"xlow": y_low, "xhigh": y_high},
    }
    return kwargs
def triangle_plot(*args, parameters=[], **kwargs):
    """Generate a triangular plot made of 3 axis. One central axis showing the
    2d marginalized posterior and two smaller axes showing the marginalized 1d
    posterior distribution (above and to the right of central axis)

    Bounded KDE settings derived from ``parameters`` are added to ``kwargs``
    before delegating to the core implementation.

    Parameters
    ----------
    *args: tuple
        all args passed to pesummary.core.plots.publication.triangle_plot
    parameters: list
        list of parameters being plotted; must be non-empty
    kwargs: dict, optional
        all kwargs passed to pesummary.core.plots.publication.triangle_plot
    """
    from pesummary.core.plots.publication import (
        triangle_plot as _core_triangle_plot
    )
    return _core_triangle_plot(
        *args, **_setup_triangle_plot(parameters, kwargs)
    )
def reverse_triangle_plot(*args, parameters=[], **kwargs):
    """Generate a triangular plot made of 3 axis. One central axis showing the
    2d marginalized posterior and two smaller axes showing the marginalized 1d
    posterior distribution (below and to the left of central axis). Only two
    axes are plotted, each below the 1d marginalized posterior distribution

    Bounded KDE settings derived from ``parameters`` are added to ``kwargs``
    before delegating to the core implementation.

    Parameters
    ----------
    *args: tuple
        all args passed to
        pesummary.core.plots.publication.reverse_triangle_plot
    parameters: list
        list of parameters being plotted; must be non-empty
    kwargs: dict, optional
        all kwargs passed to
        pesummary.core.plots.publication.reverse_triangle_plot
    """
    from pesummary.core.plots.publication import (
        reverse_triangle_plot as _core_reverse_triangle_plot
    )
    return _core_reverse_triangle_plot(
        *args, **_setup_triangle_plot(parameters, kwargs)
    )
def violin_plots(
    parameter, samples, labels, latex_labels, inj_values=None, cut=0,
    _default_kwargs={"palette": "pastel", "inner": "line", "outer": "percent: 90"},
    latex_friendly=True, **kwargs
):
    """Generate violin plots for a set of samples of a single parameter

    Parameters
    ----------
    parameter: str
        the name of the parameter that you wish to plot
    samples: nd list
        list of samples, one entry per violin
    labels: list
        list of labels corresponding to each set of samples
    latex_labels: dict
        dictionary of latex labels, keyed by parameter name
    inj_values: list, optional
        list of injected values for each set of samples
    cut: float, optional
        passed to ``violinplot`` as ``cut``
    _default_kwargs: dict, optional
        default kwargs passed to ``violinplot``. Entries in ``kwargs``
        override these. This dict is never modified
    latex_friendly: Bool, optional
        if True, escape underscores in the labels as ``\\_``
    **kwargs: dict, optional
        extra kwargs passed to ``violinplot``

    Returns
    -------
    fig
        the figure created by pesummary.core.plots.figure.figure
    """
    logger.debug("Generating violin plots for %s" % (parameter))
    fig, ax1 = figure(gca=True)
    # Merge into a fresh dict: updating ``_default_kwargs`` in place would
    # mutate the shared default and leak kwargs into every later call
    plot_kwargs = {**_default_kwargs, **kwargs}
    ax1 = violin.violinplot(
        data=samples, cut=cut, ax=ax1, scale="width", inj=inj_values, **plot_kwargs
    )
    if latex_friendly:
        # Build a new list so the caller's labels are untouched (and tuples
        # work); "\\_" is the valid spelling of a literal backslash-underscore
        labels = [label.replace("_", "\\_") for label in labels]
    ax1.set_xticklabels(labels)
    for tick_label in ax1.get_xmajorticklabels():
        tick_label.set_rotation(30)
    ax1.set_ylabel(latex_labels[parameter])
    fig.tight_layout()
    return fig
def _subsampled_spin_kde(magnitude, cos_tilt):
    """Build a bounded 2d KDE from a random half of (magnitude, cos tilt) samples"""
    pts = np.array([magnitude, cos_tilt])
    n_samples = pts.shape[1]
    selected_indices = np.random.choice(n_samples, n_samples // 2, replace=False)
    kde_sel = np.zeros(n_samples, dtype=bool)
    kde_sel[selected_indices] = True
    return Bounded_2d_kde(pts[:, kde_sel], xlow=0, xhigh=.99, ylow=-1, yhigh=1)


def _spin_disk_collection(angles, radii, values, dr, dcost, cmap, vmin, vmax):
    """Return a PatchCollection of annular wedges coloured by ``values``

    Each wedge starts at angle ``angles[i]`` (degrees), spans one step of
    ``dcost`` in cosine, and covers the radial range
    ``[radii[i], radii[i] + dr]``.
    """
    from matplotlib.patches import Wedge
    from matplotlib.collections import PatchCollection

    patches = []
    colors = []
    for x, y, h in zip(angles.ravel(), radii.ravel(), values.ravel()):
        cosx = np.cos((x - 90) * np.pi / 180)
        # clip guards against round-off pushing the argument just past 1,
        # which would make arccos return NaN
        xp = np.arccos(np.clip(cosx + dcost, -1., 1.)) * 180. / np.pi + 90.
        patches.append(Wedge((0., 0.), y + dr, xp, x, width=dr))
        colors.append(h)
    collection = PatchCollection(patches, cmap=cmap, edgecolors='face', zorder=10)
    collection.set_clim(vmin, vmax)
    collection.set_array(np.array(colors))
    return collection


def spin_distribution_plots(
    parameters, samples, label, color=None, cmap=None, annotate=False,
    show_label=True, colorbar=False, vmin=0.,
    vmax=np.log(1.0 + np.exp(1.) * 3.024)
):
    """Generate spin distribution plots for a set of parameters and samples

    Two half-disks are drawn in a 1x2 grid, the first for the primary spin
    and the second for the secondary spin. Each is coloured by
    log(1 + e * p), where p is a bounded 2d KDE of the spin magnitude and
    cosine of the tilt angle, built from a random half of the samples.

    Parameters
    ----------
    parameters: list
        list of parameters. Must contain "a_1", "a_2", "cos_tilt_1" and
        "cos_tilt_2"
    samples: nd list
        list of samples for each parameter, in the same order as
        ``parameters``
    label: str
        the label corresponding to the set of samples. Drawn as a title when
        not None
    color: str, optional
        color to use for plotting
    cmap: str, optional
        cmap to use for plotting. cmap is preferentially chosen over color
    annotate: Bool, optional
        if True, label the magnitude and tilt directions
    show_label: Bool, optional
        if True, add labels indicating which side of the spin disk corresponds
        to which binary component
    colorbar: Bool, optional
        if True, add a horizontal colorbar near the bottom of the figure
    vmin: float, optional
        lower color limit applied to both disks
    vmax: float, optional
        upper color limit applied to both disks. The default corresponds to
        a KDE density of 3.024 under the log(1 + e * p) scaling

    Returns
    -------
    fig
        the figure created by pesummary.core.plots.figure.figure

    Raises
    ------
    ValueError
        if neither color nor cmap is provided, or if one of the required spin
        parameters is missing from ``parameters``
    """
    logger.debug("Generating spin distribution plots for %s" % (label))
    from matplotlib.projections import PolarAxes
    from matplotlib.transforms import Affine2D
    from matplotlib import patheffects as PathEffects
    from matplotlib.transforms import ScaledTranslation

    from mpl_toolkits.axisartist.grid_finder import MaxNLocator
    import mpl_toolkits.axisartist.floating_axes as floating_axes
    import mpl_toolkits.axisartist.angle_helper as angle_helper

    if color is not None and cmap is None:
        cmap = colormap_with_fixed_hue(color)
    elif color is None and cmap is None:
        raise ValueError(
            "Please provide either a single color or a cmap to use for plotting"
        )

    # Look up every required parameter before any random subsampling
    spin1 = samples[parameters.index("a_1")]
    spin2 = samples[parameters.index("a_2")]
    costheta1 = samples[parameters.index("cos_tilt_1")]
    costheta2 = samples[parameters.index("cos_tilt_2")]

    spin1_kde = _subsampled_spin_kde(spin1, costheta1)
    spin2_kde = _subsampled_spin_kde(spin2, costheta2)

    # 24 x 24 grid of (magnitude, cos tilt) cells; the KDE is evaluated at
    # each cell centre
    rs = np.linspace(0, .99, 25)
    dr = np.abs(rs[1] - rs[0])
    costs = np.linspace(-1, 1, 25)
    dcost = np.abs(costs[1] - costs[0])
    COSTS, RS = np.meshgrid(costs[:-1], rs[:-1])
    X = np.arccos(COSTS) * 180 / np.pi + 90.
    Y = RS

    scale = np.exp(1.0)
    cell_centres = np.vstack([RS.ravel() + dr / 2, COSTS.ravel() + dcost / 2])
    H1 = np.log(1.0 + scale * spin1_kde(cell_centres))
    H2 = np.log(1.0 + scale * spin2_kde(cell_centres))

    def _disk_grid_helper():
        # Rotate by 90 deg and convert degrees to radians before the polar
        # transform; the disk covers angles 0-180 and radii 0-0.99
        transform = (
            Affine2D().translate(90, 0) + Affine2D().scale(np.pi / 180., 1.)
            + PolarAxes.PolarTransform()
        )
        helper = floating_axes.GridHelperCurveLinear(
            transform, extremes=(0, 180, 0, .99),
            grid_locator1=angle_helper.LocatorD(7),
            grid_locator2=MaxNLocator(5),
            tick_formatter1=angle_helper.FormatterDMS(),
            tick_formatter2=None)
        return transform, helper

    fig = figure(figsize=(6, 6), gca=False)

    # Spin 1: first subplot of the 1x2 grid
    _, grid_helper = _disk_grid_helper()
    ax1 = floating_axes.FloatingSubplot(fig, 121, grid_helper=grid_helper)
    fig.add_subplot(ax1)

    ax1.axis["bottom"].toggle(all=False)
    ax1.axis["top"].toggle(all=True)
    ax1.axis["top"].major_ticks.set_tick_out(True)

    ax1.axis["top"].set_axis_direction("top")
    ax1.axis["top"].set_ticklabel_direction('+')

    ax1.axis["left"].major_ticks.set_tick_out(True)
    ax1.axis["left"].set_axis_direction('right')
    # shift the radial tick labels 7 points to the right
    offset_transform = ScaledTranslation(7.0 / 72., 0 / 72., fig.dpi_scale_trans)
    ax1.axis["left"].major_ticklabels.set(figure=fig,
                                          transform=offset_transform)

    ax1.add_collection(
        _spin_disk_collection(X, Y, H1, dr, dcost, cmap, vmin, vmax)
    )

    # Spin 2: second subplot, with the angular axis inverted
    tr, grid_helper = _disk_grid_helper()
    ax2 = floating_axes.FloatingSubplot(fig, 122, grid_helper=grid_helper)
    ax2.invert_xaxis()
    fig.add_subplot(ax2)

    # Label angles on the outside
    ax2.axis["bottom"].toggle(all=False)
    ax2.axis["top"].toggle(all=True)
    ax2.axis["top"].set_axis_direction("top")
    ax2.axis["top"].major_ticks.set_tick_out(True)

    # Remove radial labels
    ax2.axis["left"].major_ticks.set_tick_out(True)
    ax2.axis["left"].toggle(ticklabels=False)
    ax2.axis["left"].major_ticklabels.set_visible(False)
    # Also have radial ticks for the lower half of the right semidisk
    ax2.axis["right"].major_ticks.set_tick_out(True)

    p = _spin_disk_collection(X, Y, H2, dr, dcost, cmap, vmin, vmax)
    ax2.add_collection(p)

    # Event name top, spin labels bottom
    if label is not None:
        ax2.text(0.16, 1.25, label, fontsize=18, horizontalalignment='center')
    if show_label:
        ax2.text(1.25, -1.15, r'$c{S}_{1}/(Gm_1^2)$', fontsize=14)
        ax2.text(-.5, -1.15, r'$c{S}_{2}/(Gm_2^2)$', fontsize=14)
    if annotate:
        pos_scale = 1.0
        aux_ax2 = ax2.get_aux_axes(tr)
        aux_ax2.text(
            50 * pos_scale, 0.35 * pos_scale, r'$\mathrm{magnitude}$', fontsize=20,
            zorder=10
        )
        aux_ax2.text(
            45 * pos_scale, 1.2 * pos_scale, r'$\mathrm{tilt}$', fontsize=20, zorder=10
        )
        txt = aux_ax2.annotate(
            "", xy=(55, 1.158 * pos_scale), xycoords='data',
            xytext=(35, 1.158 * pos_scale), textcoords='data',
            arrowprops=dict(
                arrowstyle="->", color="k", shrinkA=2, shrinkB=2, patchA=None,
                patchB=None, connectionstyle='arc3,rad=-0.16'
            )
        )
        txt.arrow_patch.set_path_effects(
            [PathEffects.Stroke(linewidth=2, foreground="w"), PathEffects.Normal()]
        )
        txt = aux_ax2.annotate(
            "", xy=(35, 0.55 * pos_scale), xycoords='data',
            xytext=(150, 0. * pos_scale), textcoords='data',
            arrowprops=dict(
                arrowstyle="->", color="k", shrinkA=2, shrinkB=2, patchA=None,
                patchB=None
            ), zorder=100
        )
        txt.arrow_patch.set_path_effects(
            [
                PathEffects.Stroke(linewidth=0.3, foreground="k"),
                PathEffects.Normal()
            ]
        )
    fig.subplots_adjust(wspace=0.295)
    if colorbar:
        ax3 = fig.add_axes([0.22, 0.05, 0.55, 0.02])
        fig.colorbar(
            p, cax=ax3, orientation="horizontal", pad=0.2, shrink=0.5,
            label='posterior probability per pixel'
        )
    return fig