Coverage for pesummary/core/plots/corner.py: 50.3%
183 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import copy
4import numpy as np
5from scipy.stats import gaussian_kde
6from matplotlib.colors import LinearSegmentedColormap, colorConverter
# Author credited for this module.
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
11def _set_xlim(new_fig, ax, new_xlim):
12 if new_fig:
13 return ax.set_xlim(new_xlim)
14 xlim = ax.get_xlim()
15 return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])
18def _set_ylim(new_fig, ax, new_ylim):
19 if new_fig:
20 return ax.set_ylim(new_ylim)
21 ylim = ax.get_ylim()
22 return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
def _plot_datapoints(ax, x, y, weights, color, data_kwargs):
    """Scatter the raw samples onto ``ax`` underneath the density/contours.

    ``data_kwargs`` is copied, so the caller's dictionary is never modified.
    Without ``weights`` every point is drawn with one ``ax.plot`` call. With
    ``weights`` each point is drawn individually, in order of increasing
    weight, with its alpha scaled by ``weight / max(weights)``.
    """
    data_kwargs = dict(data_kwargs or {})
    data_kwargs.setdefault("color", color)
    data_kwargs.setdefault("ms", 2.0)
    data_kwargs.setdefault("mec", "none")
    data_kwargs.setdefault("alpha", 0.1)
    if weights is None:
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)
        return
    # Work on a float copy: the caller's weights may be a list or an integer
    # array, neither of which supports in-place true division.
    scaled = np.asarray(weights, dtype=float)
    scaled = scaled / np.max(scaled)
    order = np.argsort(scaled)
    base_alpha = data_kwargs.pop("alpha")
    # The weights must be permuted with the same ``order`` as x and y so
    # that each point's alpha reflects that point's own weight.
    for xx, yy, ww in zip(x[order], y[order], scaled[order]):
        ax.plot(
            xx, yy, "o", zorder=-1, rasterized=True, alpha=base_alpha * ww,
            **data_kwargs
        )


def _contour_paths(cs, kde_kwargs):
    """Return the vertices of every contour line in ``cs``, grouped by level.

    Each path is returned as a ``(2, N)`` array (row 0 = x, row 1 = y) that
    is a copy, so the contour artist itself is left untouched. When
    ``kde_kwargs`` has no ``transform``, vertices are clipped to the optional
    ``xlow``/``xhigh``/``ylow``/``yhigh`` bounds (``None`` means unbounded).
    """
    # NOTE(review): presumably a transform-based KDE respects its own
    # boundaries, which is why no clipping is applied in that case.
    clip = kde_kwargs.get("transform", None) is None
    bounds = []
    for axis in ("x", "y"):
        low = kde_kwargs.get("{}low".format(axis), -np.inf)
        high = kde_kwargs.get("{}high".format(axis), np.inf)
        bounds.append((
            -np.inf if low is None else low,
            np.inf if high is None else high,
        ))
    paths_by_level = []
    # ``allsegs`` splits each level into its connected components and is
    # available across matplotlib versions, unlike ``ContourSet.collections``
    # (deprecated in 3.8, removed in 3.10).
    for segments in cs.allsegs:
        level_paths = []
        for segment in segments:
            vertices = np.array(segment, dtype=float).T
            if clip:
                for idx, (low, high) in enumerate(bounds):
                    vertices[idx] = np.clip(vertices[idx], low, high)
            level_paths.append(vertices)
        paths_by_level.append(level_paths)
    return paths_by_level


def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like
        Passed to ``np.histogram2d``; the resulting bin edges form the grid
        on which the KDE is evaluated.
    range : array_like, optional
        ``[[xmin, xmax], [ymin, ymax]]``. Defaults to the data range.
    weights : array_like, optional
        Per-sample weights passed to ``np.histogram2d`` and to ``kde``, and
        used to scale the alpha of the individual data points.
    levels : array_like
        The contour levels to draw. A contour is drawn where the density
        equals ``(1 - level) * max(density)``. Defaults to the 0.5, 1, 1.5
        and 2 sigma levels.
    smooth : float, optional
        Width of a Gaussian filter applied to the evaluated density.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram. Required.
    color : str, optional
        Base color of the plot. Defaults to ``"k"``.
    quiet : bool
        Accepted for compatibility with corner.hist2d; unused here.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    no_fill_contours, fill_contours : bool
        Accepted for compatibility with corner.hist2d; unused here.
    contour_kwargs : dict
        Only the ``colors`` key is read (a single color or one per level).
        The dictionary is copied, never modified.
    contourf_kwargs : dict
        Accepted for compatibility with corner.hist2d; unused here.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap. ``shading`` defaults to ``"auto"``.
    new_fig : bool
        If True the axis limits are set to ``range``; otherwise the existing
        limits are widened to include ``range``.
    kde: func, optional
        KDE you wish to use to work out the 2d probability density. Called as
        ``kde(values, weights=weights, **kde_kwargs)``; defaults to
        ``scipy.stats.gaussian_kde``.
    kde_kwargs: dict, optional
        kwargs passed directly to kde. ``xlow``/``xhigh``/``ylow``/``yhigh``
        also clip the drawn contours unless ``transform`` is given.
    density_cmap : str or Colormap, optional
        Colormap for the density plot.
    label : str, optional
        Legend label attached to the first drawn contour line.
    grid : bool
        Accepted for compatibility with corner.hist2d; unused here.
    **kwargs
        Only ``linestyles`` is read (a single style or one per level).

    Raises
    ------
    ValueError
        If ``ax`` is not given, or if the samples have no dynamic range.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde

    if ax is None:
        raise ValueError("Please provide an axis to plot")
    # Normalise the optional dictionaries before they are first used.
    if kde_kwargs is None:
        kde_kwargs = dict()
    # Copy so that popping "colors" below never leaks into the caller's dict
    # (the same dict may be reused for several panels).
    contour_kwargs = dict(contour_kwargs or {})

    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]
    if color is None:
        color = "k"
    # Default "sigma" levels: 1 - exp(-s**2 / 2) for s = 0.5, 1, 1.5, 2.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # This is the color map for the density plot, over-plotted to indicate the
    # density of the points near the center.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        try:
            from matplotlib import colormaps
        except ImportError:  # matplotlib < 3.5
            from matplotlib import cm
            density_cmap = cm.get_cmap(density_cmap)
        else:
            # ``cm.get_cmap`` was removed in matplotlib 3.9
            density_cmap = colormaps[density_cmap]

    # The histogram itself is only used for its bin edges, which define the
    # grid on which the KDE is evaluated.
    try:
        _, x_edges, y_edges = np.histogram2d(
            x.flatten(),
            y.flatten(),
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError as err:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        ) from err

    values = np.vstack([x.flatten(), y.flatten()])
    kernel = kde(values, weights=weights, **kde_kwargs)
    X, Y = np.meshgrid(x_edges, y_edges)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError as err:
            raise ImportError("Please install scipy for smoothing") from err
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        _plot_datapoints(ax, x, y, weights, color, data_kwargs)

    # Invisible (alpha=0) contour used only to extract the level curves,
    # which are re-drawn (and optionally clipped) below. matplotlib requires
    # increasing contour levels, so the thresholds are sorted.
    thresholds = np.sort(
        np.atleast_1d((1 - np.asarray(levels, dtype=float)) * np.max(H))
    )
    cs = ax.contour(X, Y, H, levels=thresholds, alpha=0.)
    contour_set = _contour_paths(cs, kde_kwargs)

    # Plot the density map.
    if plot_density:
        pcolor_kwargs = dict(pcolor_kwargs or {})
        pcolor_kwargs.setdefault("shading", "auto")
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **pcolor_kwargs)

    # Plot the contour lines, one color/linestyle per level.
    if plot_contours:
        n_levels = len(contour_set)
        colors = contour_kwargs.pop("colors", color)
        linestyles = kwargs.pop("linestyles", "-")
        per_level = []
        for prop, default in ((colors, "k"), (linestyles, "-")):
            if prop is None:
                per_level.append([default] * n_levels)
            elif isinstance(prop, str):
                per_level.append([prop] * n_levels)
            elif len(prop) < n_levels:
                raise ValueError(
                    "Please provide a color/linestyle for each contour"
                )
            else:
                per_level.append(prop)
        line_colors, line_styles = per_level
        # Attach the legend label to the first line actually drawn, even if
        # the first level produced no path.
        pending_label = label
        for level_idx, paths in enumerate(contour_set):
            for path in paths:
                ax.plot(
                    *path, color=line_colors[level_idx], label=pending_label,
                    linestyle=line_styles[level_idx]
                )
                pending_label = None

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])
def _kde_2d_panel_kwargs(kde_2d_kwargs, parameters, row_param, col_param):
    """Build the 2D KDE kwargs for one panel without modifying the input.

    ``kde_2d_kwargs`` may map a parameter name to a dict of kwargs specific
    to that parameter; every other key is shared by all panels. A
    parameter's ``low``/``high`` bounds become ``ylow``/``yhigh`` for the
    row parameter (plotted on the y-axis) and ``xlow``/``xhigh`` for the
    column parameter (plotted on the x-axis). Parameters without their own
    entry fall back to the shared kwargs. Shared kwargs other than
    ``low``/``high`` take precedence over per-parameter ones.
    """
    shared = {
        key: value for key, value in kde_2d_kwargs.items()
        if key not in parameters
    }

    def _translate(param, axis):
        source = kde_2d_kwargs.get(param, shared)
        return {
            (axis + key if key in ("low", "high") else key): value
            for key, value in source.items()
        }

    panel_kwargs = _translate(row_param, "y")
    panel_kwargs.update(_translate(col_param, "x"))
    panel_kwargs.update({
        key: value for key, value in shared.items()
        if key not in ("low", "high")
    })
    return panel_kwargs


def corner(
    samples, parameters, bins=20, *,
    # Original corner parameters
    range=None, axes_scale="linear", weights=None, color='k',
    hist_bin_factor=1, smooth=None, smooth1d=None, labels=None,
    label_kwargs=None, titles=None, show_titles=False,
    title_quantiles=None, title_fmt=".2f", title_kwargs=None,
    truths=None, truth_color="#4682b4", scale_hist=False,
    quantiles=None, verbose=False, fig=None, max_n_ticks=5,
    top_ticks=False, use_math_text=False, reverse=False,
    labelpad=0.0, hist_kwargs=None,
    # Arviz parameters
    group="posterior", var_names=None, filter_vars=None,
    coords=None, divergences=False, divergences_kwargs=None,
    labeller=None,
    # New parameters
    kde=None, kde_kwargs=None, kde_2d=None, kde_2d_kwargs=None,
    N=100, **hist2d_kwargs,
):
    """Wrapper for corner.corner which adds additional functionality
    to plot custom KDEs along the leading diagonal and custom 2D
    KDEs in the 2D panels

    Parameters
    ----------
    samples : np.ndarray
        2D array of samples, one column per entry of ``parameters``.
    parameters : list
        Names of the columns of ``samples``.
    kde : func, optional
        1D KDE called as ``kde(samples[:, i], **kwargs)``; the returned
        object is evaluated on ``N`` points spanning the samples. When given,
        corner's own 1D histogram outline is hidden (``linewidth=0``).
    kde_kwargs : dict, optional
        Either per-parameter dicts keyed by parameter name, or shared
        kwargs used for parameters without their own entry.
    kde_2d : func, optional
        2D KDE passed to ``hist2d`` for every lower-triangle panel.
    kde_2d_kwargs : dict, optional
        Per-parameter dicts (``low``/``high`` are mapped onto the panel's
        axes) and/or shared kwargs applied to every panel.
    N : int
        Number of points at which the 1D KDEs are evaluated.
    **hist2d_kwargs
        Forwarded to corner.corner and, when ``kde_2d`` is given, to
        ``hist2d``.

    All other arguments are forwarded to corner.corner. None of the
    dictionaries passed in are modified.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by corner.corner with the KDEs overlaid.
    """
    from corner import corner as _corner

    # Work on copies: corner.corner itself writes into ``hist_kwargs``, and
    # a shared default would otherwise carry state between calls.
    hist_kwargs = dict(hist_kwargs or {})
    if kde_kwargs is None:
        kde_kwargs = {}
    if kde_2d_kwargs is None:
        kde_2d_kwargs = {}
    if kde is not None:
        hist_kwargs["linewidth"] = 0.
    if kde_2d is not None:
        user_contour_kwargs = hist2d_kwargs.get("contour_kwargs") or {}
        if hist2d_kwargs.get("plot_contours", False):
            linewidths = user_contour_kwargs.get("linewidths", None)
            hist2d_kwargs["contour_kwargs"] = dict(
                user_contour_kwargs, linewidths=0.
            )
        else:
            linewidths = [1.]
        plot_density = hist2d_kwargs.get("plot_density", True)
        fill_contours = hist2d_kwargs.get("fill_contours", False)
        plot_contours = hist2d_kwargs.get("plot_contours", True)
        levels = hist2d_kwargs.get("levels", None)
        # corner.corner only draws the scaffolding; the 2D density and
        # contours are drawn afterwards from the custom KDE.
        hist2d_kwargs.update(
            plot_density=False, fill_contours=False, plot_contours=False
        )

    fig = _corner(
        samples, range=range, axes_scale=axes_scale, weights=weights,
        color=color, hist_bin_factor=hist_bin_factor, smooth=smooth,
        smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs,
        titles=titles, show_titles=show_titles, title_quantiles=title_quantiles,
        title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths,
        truth_color=truth_color, scale_hist=scale_hist,
        quantiles=quantiles, verbose=verbose, fig=fig,
        max_n_ticks=max_n_ticks, top_ticks=top_ticks,
        use_math_text=use_math_text, reverse=reverse,
        labelpad=labelpad, hist_kwargs=hist_kwargs,
        # Arviz parameters
        group=group, var_names=var_names, filter_vars=filter_vars,
        coords=coords, divergences=divergences,
        divergences_kwargs=divergences_kwargs, labeller=labeller,
        **hist2d_kwargs
    )
    if kde is None and kde_2d is None:
        return fig
    n_params = len(parameters)
    axs = np.array(fig.get_axes(), dtype=object).reshape(n_params, n_params)

    if kde is not None:
        shared_kde_kwargs = {
            key: item for key, item in kde_kwargs.items()
            if key not in parameters
        }
        for num, param in enumerate(parameters):
            column = samples[:, num]
            _kde = kde(column, **kde_kwargs.get(param, shared_kde_kwargs))
            xs = np.linspace(np.min(column), np.max(column), N)
            axs[num, num].plot(xs, _kde(xs), color=color)

    if kde_2d is not None:
        if levels is None:
            # Same default "sigma" levels as hist2d.
            levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)
        overlay_kwargs = dict(hist2d_kwargs)
        overlay_kwargs.pop("contour_kwargs", None)
        overlay_kwargs.update(
            plot_contours=plot_contours,
            plot_density=plot_density,
            fill_contours=fill_contours,
            # hist2d contours at (1 - level) * max(density) and matplotlib
            # needs increasing thresholds, so pass the levels decreasing.
            levels=np.sort(np.asarray(levels, dtype=float))[::-1],
            # corner.corner received ``plot_datapoints`` unchanged and has
            # already scattered the samples; drawing them again would double
            # their opacity.
            plot_datapoints=False,
        )
        overlay_contour_kwargs = dict(user_contour_kwargs, linewidths=linewidths)
        for i, row_param in enumerate(parameters):
            for j, col_param in enumerate(parameters):
                if j >= i:
                    continue
                hist2d(
                    samples[:, j], samples[:, i],
                    ax=axs[i, j], color=color, weights=weights,
                    kde=kde_2d,
                    kde_kwargs=_kde_2d_panel_kwargs(
                        kde_2d_kwargs, parameters, row_param, col_param
                    ),
                    bins=bins,
                    # fresh copy per panel: hist2d may pop keys from it
                    contour_kwargs=dict(overlay_contour_kwargs),
                    **overlay_kwargs
                )
    return fig