Coverage for pesummary/core/plots/corner.py: 50.8%
183 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-01-15 17:49 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import copy
4import numpy as np
5from scipy.stats import gaussian_kde
6from matplotlib.colors import LinearSegmentedColormap, colorConverter
8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
11def _set_xlim(new_fig, ax, new_xlim):
12 if new_fig:
13 return ax.set_xlim(new_xlim)
14 xlim = ax.get_xlim()
15 return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])
18def _set_ylim(new_fig, ax, new_ylim):
19 if new_fig:
20 return ax.set_ylim(new_ylim)
21 ylim = ax.get_ylim()
22 return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
def _hist2d_contour_paths(contours, kde_kwargs):
    """Return the vertices of every contour level, clipped to the KDE bounds.

    Parameters
    ----------
    contours : matplotlib.contour.ContourSet
        Result of ``ax.contour``. ``contours.get_paths()`` is assumed to
        return one path per contour level (the matplotlib >= 3.8 layout).
    kde_kwargs : dict
        Keyword arguments given to the KDE. When it has no ``transform``
        entry, ``xlow``/``xhigh``/``ylow``/``yhigh`` (``None`` meaning
        unbounded) are used to clip the contour vertices.

    Returns
    -------
    list of list of np.ndarray
        One list per contour level. Each element is a ``(2, n)`` array of
        x and y vertex coordinates for one polygon of that level.
    """
    clip = kde_kwargs.get("transform", None) is None
    contour_set = []
    for path in contours.get_paths():
        polygons = []
        for vertices in path.to_polygons():
            xy = vertices.T
            if clip:
                for idx, axis in enumerate(("x", "y")):
                    low = kde_kwargs.get("{}low".format(axis), -np.inf)
                    high = kde_kwargs.get("{}high".format(axis), np.inf)
                    xy[idx] = np.clip(
                        xy[idx],
                        -np.inf if low is None else low,
                        np.inf if high is None else high,
                    )
            polygons.append(xy)
        contour_set.append(polygons)
    return contour_set


def _hist2d_per_contour_style(prop, default, ncontours, is_single=None):
    """Expand a colour/linestyle specification to one entry per contour.

    Parameters
    ----------
    prop : str, sequence or None
        ``None`` selects ``default`` for every contour. A string, or any
        value accepted by ``is_single``, is repeated for every contour.
        Anything else must be a sequence with at least ``ncontours``
        entries, used in order.
    default : str
        Fallback style used when ``prop`` is ``None``.
    ncontours : int
        Number of contour levels to style.
    is_single : callable, optional
        Predicate identifying non-string values that describe a single
        style (e.g. an RGB tuple for colours).

    Returns
    -------
    list
        At least ``ncontours`` style entries.

    Raises
    ------
    ValueError
        If a sequence with fewer than ``ncontours`` entries is given.
    """
    if prop is None:
        return [default] * ncontours
    if isinstance(prop, str) or (is_single is not None and is_single(prop)):
        return [prop] * ncontours
    if len(prop) < ncontours:
        raise ValueError("Please provide a color/linestyle for each contour")
    return list(prop)


def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    The density is obtained by evaluating ``kde`` on the edges of a
    ``np.histogram2d`` binning of the samples.

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like, optional
        Passed to ``np.histogram2d`` to define the evaluation grid.
    range : array_like, optional
        ``[[xmin, xmax], [ymin, ymax]]``. Defaults to the data range. It is
        also used to set the axis limits.
    weights : array_like[nsamples,], optional
        Sample weights. They are passed to ``np.histogram2d`` and to the KDE,
        and they scale the alpha of each plotted data point.
    levels : array_like
        The contour levels to draw. Defaults to
        ``1 - exp(-0.5 * [0.5, 1, 1.5, 2] ** 2)``.
    smooth : float, optional
        Standard deviation of a Gaussian filter applied to the density.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram. Required.
    color : str or tuple, optional
        Base colour (default ``"k"``).
    quiet : bool
        Accepted for compatibility with corner.hist2d; not used here.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    no_fill_contours : bool
        Accepted for compatibility with corner.hist2d; not used here.
    fill_contours : bool
        Accepted for compatibility with corner.hist2d; not used here.
    contour_kwargs : dict
        Only the ``colors`` entry is read: a single colour or one colour per
        contour (outermost first). The dict is not modified.
    contourf_kwargs : dict
        Accepted for compatibility with corner.hist2d; not used here.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points. The dict is not modified.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap. ``shading`` is always set to ``"auto"``.
    new_fig : bool
        If False, the existing axis limits are only widened, never shrunk.
    kde: func, optional
        KDE you wish to use to work out the 2d probability density. It is
        called as ``kde(values, weights=weights, **kde_kwargs)`` and the
        result is evaluated on the grid. Default ``scipy.stats.gaussian_kde``.
    kde_kwargs: dict, optional
        kwargs passed directly to kde. Without a ``transform`` entry,
        ``xlow``/``xhigh``/``ylow``/``yhigh`` also clip the drawn contours.
    density_cmap : str or Colormap, optional
        Colormap for the density plot. Defaults to a fade from ``color`` to
        transparent.
    label : str, optional
        Legend label attached to the first contour line drawn.
    grid : bool
        Accepted for compatibility; not used here.
    **kwargs
        Only ``linestyles`` is read: one linestyle or one per contour.

    Raises
    ------
    ValueError
        If ``ax`` is None, if a sample column has no dynamic range, or if
        too few colours/linestyles are given for the contours.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde

    if ax is None:
        raise ValueError("Please provide an axis to plot")
    # Normalise/copy the option dicts before first use. The caller's dicts
    # are never mutated, because corner() reuses them for every panel.
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]

    # Set up the default plotting arguments.
    if color is None:
        color = "k"

    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # This is the color map for the density plot, over-plotted to indicate the
    # density of the points near the center.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        try:
            # Registry lookup; ``cm.get_cmap`` was removed in matplotlib 3.9.
            from matplotlib import colormaps
            density_cmap = colormaps[density_cmap]
        except ImportError:  # matplotlib < 3.5 has no registry
            from matplotlib import cm
            density_cmap = cm.get_cmap(density_cmap)

    # The 2D histogram is only used to obtain the evaluation grid (bin edges).
    try:
        _, X, Y = np.histogram2d(
            x.flatten(),
            y.flatten(),
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError as err:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        ) from err

    values = np.vstack([x.flatten(), y.flatten()])
    kernel = kde(values, weights=weights, **kde_kwargs)
    X, Y = np.meshgrid(X, Y)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        data_kwargs = {} if data_kwargs is None else dict(data_kwargs)
        data_kwargs.setdefault("color", color)
        data_kwargs.setdefault("ms", 2.0)
        data_kwargs.setdefault("mec", "none")
        data_kwargs.setdefault("alpha", 0.1)
        if weights is None:
            ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)
        else:
            # Work on a float copy: in-place division of a list, or of an
            # integer array, would fail, and the caller's weights stay intact.
            scaled = np.array(weights, dtype=float)
            scaled /= np.max(scaled)
            order = np.argsort(scaled)
            # Draw from lowest to highest weight. Each point's alpha is
            # scaled by its *own* weight (taken in the same sorted order).
            for xx, yy, ww in zip(x[order], y[order], scaled[order]):
                _data_kwargs = data_kwargs.copy()
                _data_kwargs["alpha"] *= ww
                ax.plot(xx, yy, "o", zorder=-1, rasterized=True, **_data_kwargs)

    # An invisible contour (alpha=0) is drawn only to extract the level paths.
    # ``ax.contour`` requires increasing levels, so the density thresholds are
    # sorted: the lowest threshold (outermost contour) comes first.
    thresholds = np.sort((1 - np.asarray(levels, dtype=float)) * np.max(H))
    cs = ax.contour(X, Y, H, levels=thresholds, alpha=0.)
    contour_set = _hist2d_contour_paths(cs, kde_kwargs)

    # Plot the density map. This can't be plotted at the same time as the
    # contour fills.
    if plot_density:
        pcolor_kwargs = {} if pcolor_kwargs is None else dict(pcolor_kwargs)
        # X, Y and H share one shape, which requires "auto" shading.
        pcolor_kwargs["shading"] = "auto"
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **pcolor_kwargs)

    # Plot the contour edge colors.
    if plot_contours:
        from matplotlib.colors import is_color_like

        ncontours = len(contour_set)
        colors = _hist2d_per_contour_style(
            contour_kwargs.pop("colors", color), "k", ncontours,
            is_single=is_color_like,
        )
        linestyles = _hist2d_per_contour_style(
            kwargs.pop("linestyles", "-"), "-", ncontours
        )
        pending_label = label
        for polygons, line_color, line_style in zip(
            contour_set, colors, linestyles
        ):
            for xy in polygons:
                ax.plot(
                    *xy, color=line_color, label=pending_label,
                    linestyle=line_style
                )
                # Only the first line actually drawn carries the legend label.
                pending_label = None

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])
def _corner_kde_2d_axis_kwargs(param_kwargs, axis):
    """Rename generic ``low``/``high`` KDE bounds to axis-specific names.

    Parameters
    ----------
    param_kwargs : dict
        2D-KDE keyword arguments associated with one parameter. Not modified.
    axis : str
        ``"x"`` or ``"y"``: the panel axis on which the parameter is drawn.

    Returns
    -------
    dict
        A copy of ``param_kwargs`` with ``low``/``high`` renamed to
        ``{axis}low``/``{axis}high``.
    """
    renamed = dict(param_kwargs)
    for bound in ("low", "high"):
        if bound in renamed:
            renamed[axis + bound] = renamed.pop(bound)
    return renamed


def corner(
    samples, parameters, bins=20, *,
    # Original corner parameters
    range=None, axes_scale="linear", weights=None, color='k',
    hist_bin_factor=1, smooth=None, smooth1d=None, labels=None,
    label_kwargs=None, titles=None, show_titles=False,
    title_quantiles=None, title_fmt=".2f", title_kwargs=None,
    truths=None, truth_color="#4682b4", scale_hist=False,
    quantiles=None, verbose=False, fig=None, max_n_ticks=5,
    top_ticks=False, use_math_text=False, reverse=False,
    labelpad=0.0, hist_kwargs=None,
    # Arviz parameters
    group="posterior", var_names=None, filter_vars=None,
    coords=None, divergences=False, divergences_kwargs=None,
    labeller=None,
    # New parameters
    kde=None, kde_kwargs=None, kde_2d=None, kde_2d_kwargs=None,
    N=100, **hist2d_kwargs,
):
    """Wrapper for corner.corner which adds additional functionality
    to plot custom KDEs along the leading diagonal and custom 2D
    KDEs in the 2D panels

    Parameters
    ----------
    samples : array_like[nsamples, nparameters]
        The samples. When a KDE is requested, column ``k`` is taken to be
        ``parameters[k]``.
    parameters : list
        Parameter names, one per column of ``samples``.
    bins : int
        Passed to corner.corner and to ``hist2d`` for the 2D KDE panels.
    kde : callable, optional
        1D KDE, called as ``kde(column, **kwargs)``. The result is evaluated
        at ``N`` points spanning the column and drawn on the diagonal, and
        corner's own 1D histograms are hidden.
    kde_kwargs : dict, optional
        Entries keyed by a parameter name are that parameter's kwargs. All
        other entries are used for every parameter and take precedence.
        Not modified.
    kde_2d : callable, optional
        2D KDE passed to ``hist2d`` for every lower-triangle panel. corner's
        own density and contours are suppressed.
    kde_2d_kwargs : dict, optional
        Entries keyed by a parameter name are that parameter's kwargs. Their
        ``low``/``high`` become ``xlow``/``xhigh`` when the parameter is on
        the panel's x-axis, or ``ylow``/``yhigh`` when it is on the y-axis.
        A parameter without an entry uses the non-parameter entries instead.
        Non-parameter entries other than ``low``/``high`` are passed to every
        panel. Not modified.
    N : int
        Number of points at which the 1D KDE is evaluated.
    **hist2d_kwargs
        Forwarded to corner.corner and, for the KDE panels, to ``hist2d``.
        ``levels`` defaults to corner's ``1 - exp(-0.5 * [0.5, 1, 1.5, 2]**2)``.

    All other arguments are forwarded unchanged to corner.corner.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by corner.corner with any KDE overlays.

    Notes
    -----
    NOTE(review): ``weights``, ``range`` and ``reverse`` are not applied to
    the KDE overlays (the axes are assumed to be in the non-reversed order).
    """
    from corner import corner as _corner_corner

    # Never mutate the caller's dicts (or a shared default).
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    kde_kwargs = {} if kde_kwargs is None else kde_kwargs
    kde_2d_kwargs = {} if kde_2d_kwargs is None else kde_2d_kwargs
    if kde is not None:
        # Hide corner's own 1D histograms; the KDE is drawn in their place.
        hist_kwargs["linewidth"] = 0.

    # kwargs forwarded to corner.corner. A separate copy is kept so that the
    # caller's values are still available for the KDE panels below.
    corner_hist2d_kwargs = dict(hist2d_kwargs)
    if kde_2d is not None:
        plot_density = hist2d_kwargs.get("plot_density", True)
        fill_contours = hist2d_kwargs.get("fill_contours", False)
        plot_contours = hist2d_kwargs.get("plot_contours", True)
        user_contour_kwargs = dict(hist2d_kwargs.get("contour_kwargs") or {})
        linewidths = [1.]
        if hist2d_kwargs.get("plot_contours", False):
            linewidths = user_contour_kwargs.get("linewidths", None)
            corner_hist2d_kwargs["contour_kwargs"] = dict(
                user_contour_kwargs, linewidths=0.
            )
        # corner.corner must not draw its own density/contours; they are
        # redrawn from the 2D KDE afterwards.
        if plot_density:
            corner_hist2d_kwargs["plot_density"] = False
        if fill_contours:
            corner_hist2d_kwargs["fill_contours"] = False
        corner_hist2d_kwargs["plot_contours"] = False

    fig = _corner_corner(
        samples, range=range, axes_scale=axes_scale, weights=weights,
        color=color, hist_bin_factor=hist_bin_factor, smooth=smooth,
        smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs,
        titles=titles, show_titles=show_titles, title_quantiles=title_quantiles,
        title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths,
        truth_color=truth_color, scale_hist=scale_hist,
        quantiles=quantiles, verbose=verbose, fig=fig,
        max_n_ticks=max_n_ticks, top_ticks=top_ticks,
        use_math_text=use_math_text, reverse=reverse,
        labelpad=labelpad, hist_kwargs=hist_kwargs, bins=bins,
        # Arviz parameters
        group=group, var_names=var_names, filter_vars=filter_vars,
        coords=coords, divergences=divergences,
        divergences_kwargs=divergences_kwargs, labeller=labeller,
        **corner_hist2d_kwargs
    )
    if kde is None and kde_2d is None:
        return fig

    samples = np.asarray(samples)
    nparams = len(parameters)
    axs = np.array(fig.get_axes(), dtype=object).reshape(nparams, nparams)

    if kde is not None:
        shared_1d = {
            key: item for key, item in kde_kwargs.items()
            if key not in parameters
        }
        for num, param in enumerate(parameters):
            _kwargs = dict(kde_kwargs.get(param, {}))
            _kwargs.update(shared_1d)
            column = samples[:, num]
            _kde = kde(column, **_kwargs)
            xs = np.linspace(np.min(column), np.max(column), N)
            axs[num, num].plot(xs, _kde(xs), color=color)

    if kde_2d is not None:
        # Non-parameter entries: the fallback for parameters without their
        # own entry, and (minus low/high) passed verbatim to every panel.
        shared_2d = {
            key: item for key, item in kde_2d_kwargs.items()
            if key not in parameters
        }
        passthrough_2d = {
            key: item for key, item in shared_2d.items()
            if key not in ("low", "high")
        }
        levels = hist2d_kwargs.get("levels")
        if levels is None:
            # corner.corner's default "sigma" levels.
            levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)
        panel_contour_kwargs = dict(user_contour_kwargs)
        panel_contour_kwargs["linewidths"] = linewidths
        panel_hist2d_kwargs = dict(hist2d_kwargs)
        panel_hist2d_kwargs.update(
            {
                "plot_contours": plot_contours,
                "plot_density": plot_density,
                "fill_contours": fill_contours,
                # Reversed so hist2d's density thresholds are increasing.
                "levels": np.asarray(levels, dtype=float)[::-1],
                "contour_kwargs": panel_contour_kwargs,
            }
        )
        for i, row_param in enumerate(parameters):
            for j, col_param in enumerate(parameters):
                if j >= i:
                    break
                # The row parameter is on the y-axis, the column parameter on x.
                _kde_2d_kwargs = _corner_kde_2d_axis_kwargs(
                    kde_2d_kwargs.get(row_param, shared_2d), "y"
                )
                _kde_2d_kwargs.update(
                    _corner_kde_2d_axis_kwargs(
                        kde_2d_kwargs.get(col_param, shared_2d), "x"
                    )
                )
                _kde_2d_kwargs.update(passthrough_2d)
                hist2d(
                    samples[:, j], samples[:, i],
                    ax=axs[i, j], color=color,
                    kde=kde_2d, kde_kwargs=_kde_2d_kwargs,
                    bins=bins, **panel_hist2d_kwargs
                )
    return fig