Coverage for pesummary/core/plots/corner.py: 52.9%
174 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-02 08:42 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4from scipy.stats import gaussian_kde
5from matplotlib.colors import LinearSegmentedColormap, colorConverter
7__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
10def _set_xlim(new_fig, ax, new_xlim):
11 if new_fig:
12 return ax.set_xlim(new_xlim)
13 xlim = ax.get_xlim()
14 return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])
17def _set_ylim(new_fig, ax, new_ylim):
18 if new_fig:
19 return ax.set_ylim(new_ylim)
20 ylim = ax.get_ylim()
21 return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
def _get_cmap(name):
    """Return the registered matplotlib colormap called ``name``.

    ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9, so prefer the
    ``matplotlib.colormaps`` registry (matplotlib >= 3.5) and only fall back
    to ``cm.get_cmap`` on older releases.
    """
    try:
        from matplotlib import colormaps
    except ImportError:  # matplotlib < 3.5
        from matplotlib import cm
        return cm.get_cmap(name)
    try:
        return colormaps[name]
    except KeyError as err:
        # cm.get_cmap raised ValueError for unknown names; keep that contract
        raise ValueError(str(err)) from err


def _per_contour(prop, default, ncontours, is_color=False):
    """Return a list holding one line property per contour level.

    ``None`` maps to ``default``. A string, a scalar, a single color (when
    ``is_color`` is True) or a length-1 sequence is repeated for every
    level. Any other sequence is used as-is and must provide at least
    ``ncontours`` entries, otherwise a ValueError is raised.
    """
    if prop is None:
        return [default] * ncontours
    if isinstance(prop, str) or np.isscalar(prop):
        return [prop] * ncontours
    if is_color:
        from matplotlib.colors import is_color_like
        if is_color_like(prop):
            # e.g. an RGB(A) tuple: one color, not one color per contour
            return [prop] * ncontours
    prop = list(prop)
    if len(prop) == 1:
        return prop * ncontours
    if len(prop) < ncontours:
        raise ValueError(
            "Please provide a color/linestyle for each contour"
        )
    return prop


def _extract_contour_paths(contour_set, kde_kwargs):
    """Return copies of every contour segment of ``contour_set`` by level.

    Each segment is returned transposed, i.e. as a ``(2, npoints)`` array
    holding the x and y coordinates, so it can be passed to ``ax.plot(*seg)``.
    When ``kde_kwargs`` has no ``transform``, coordinates are clipped to the
    ``xlow``/``xhigh``/``ylow``/``yhigh`` entries of ``kde_kwargs`` (missing
    entries mean unbounded; explicit ``None`` disables that bound).
    """
    # NOTE(review): when a ``transform`` is supplied the previous
    # implementation evaluated ``kde_kwargs["transform"](vertices)`` and
    # discarded the result; vertices were (and still are) left unclipped.
    clip = kde_kwargs.get("transform", None) is None
    limits = [
        (
            kde_kwargs.get("{}low".format(axis), -np.inf),
            kde_kwargs.get("{}high".format(axis), np.inf),
        )
        for axis in ["x", "y"]
    ]
    paths = []
    # ``allsegs`` yields one list of connected (npoints, 2) vertex arrays per
    # level in every supported matplotlib version, unlike ``collections``
    # which is deprecated (3.8) / removed (3.10) and, from 3.8, merges all
    # segments of a level into a single path.
    for level_segments in contour_set.allsegs:
        level_paths = []
        for segment in level_segments:
            transpose = np.array(segment, dtype=float).T
            if clip:
                for idx, (low, high) in enumerate(limits):
                    if low is not None:
                        transpose[idx][transpose[idx] < low] = low
                    if high is not None:
                        transpose[idx][transpose[idx] > high] = high
            level_paths.append(transpose)
        paths.append(level_paths)
    return paths


def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like, optional
        Passed to ``np.histogram2d``; the resulting bin edges define the grid
        on which the KDE is evaluated.
    range : array_like, optional
        ``[[xmin, xmax], [ymin, ymax]]``. Defaults to the data range.
    weights : array_like, optional
        Passed to ``np.histogram2d`` only; it is not forwarded to ``kde``.
    levels : array_like
        The contour levels to draw. A contour is drawn where the KDE equals
        ``(1 - level) * max(density)``.
    smooth : float, optional
        Standard deviation (in grid cells) of a gaussian filter applied to
        the evaluated density.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram.
    color : str, optional
        Base color for data points, density map and contours. Default "k".
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    contour_kwargs : dict
        ``colors``, ``linestyles`` and ``linewidths`` (a single value or one
        per contour) used for the contour lines. Not modified.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points. Not modified.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap. ``shading`` defaults to "auto".
    new_fig : bool
        If False, the axis limits are only ever expanded to include ``range``.
    kde: func, optional
        KDE you wish to use to work out the 2d probability density. Called as
        ``kde(np.vstack([x, y]), **kde_kwargs)``; the result is evaluated on
        the grid. Defaults to ``scipy.stats.gaussian_kde``.
    kde_kwargs: dict, optional
        kwargs passed directly to kde. The ``xlow``/``xhigh``/``ylow``/
        ``yhigh`` entries are also used to clip the contours.
    density_cmap : str or matplotlib colormap, optional
        Colormap for the density plot.
    label : str, optional
        Legend label attached to the first contour line drawn.
    quiet, no_fill_contours, fill_contours, contourf_kwargs, grid
        Accepted for compatibility with ``corner.hist2d``; currently unused.
    **kwargs
        ``linestyles`` is read from here when not in ``contour_kwargs``.

    Raises
    ------
    ValueError
        If ``ax`` is not provided, the samples have no dynamic range, or too
        few colors/linestyles/linewidths are given for the contours.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde
    # Normalise the optional dicts up front (kde_kwargs is needed before any
    # plotting) and copy the ones we modify: ``corner`` reuses the same
    # contour_kwargs dict for every panel, so it must never be mutated here.
    kde_kwargs = {} if kde_kwargs is None else kde_kwargs
    contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)

    if ax is None:
        raise ValueError("Please provide an axis to plot")
    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]
    if color is None:
        color = "k"
    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # Colormap for the density plot: ``color`` at the densest point, fading
    # to transparent white.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        density_cmap = _get_cmap(density_cmap)

    x_flat, y_flat = x.flatten(), y.flatten()
    # The histogram is only used for its bin edges, which define the grid
    # the KDE is evaluated on.
    try:
        _, X, Y = np.histogram2d(
            x_flat,
            y_flat,
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError as err:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        ) from err

    kernel = kde(np.vstack([x_flat, y_flat]), **kde_kwargs)
    X, Y = np.meshgrid(X, Y)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        data_kwargs = {} if data_kwargs is None else dict(data_kwargs)
        data_kwargs.setdefault("color", color)
        data_kwargs.setdefault("ms", 2.0)
        data_kwargs.setdefault("mec", "none")
        data_kwargs.setdefault("alpha", 0.1)
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)

    # Draw the KDE contours invisibly (alpha=0) purely to obtain their
    # vertices. matplotlib requires strictly increasing levels, so sort the
    # thresholds (ascending ``levels`` previously raised).
    cs = ax.contour(
        X, Y, H, levels=np.sort((1 - np.asarray(levels)) * np.max(H)),
        alpha=0.
    )
    contour_set = _extract_contour_paths(cs, kde_kwargs)

    # Plot the density map (largest value -> first colormap color).
    if plot_density:
        pcolor_kwargs = {} if pcolor_kwargs is None else dict(pcolor_kwargs)
        pcolor_kwargs.setdefault("shading", "auto")
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **pcolor_kwargs)

    # Plot the contour edges with one color/linestyle/linewidth per level.
    if plot_contours:
        ncontours = len(contour_set)
        colors = _per_contour(
            contour_kwargs.pop("colors", color), "k", ncontours, is_color=True
        )
        linestyles = _per_contour(
            contour_kwargs.pop("linestyles", kwargs.pop("linestyles", "-")),
            "-", ncontours
        )
        linewidths = contour_kwargs.pop("linewidths", None)
        if linewidths is not None:
            linewidths = _per_contour(linewidths, None, ncontours)
        labelled = False
        for idx, level_paths in enumerate(contour_set):
            for path in level_paths:
                line_kwargs = {"color": colors[idx], "linestyle": linestyles[idx]}
                if linewidths is not None and linewidths[idx] is not None:
                    line_kwargs["linewidth"] = linewidths[idx]
                # only the first line drawn carries the legend label
                ax.plot(*path, label=None if labelled else label, **line_kwargs)
                labelled = True

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])
def corner(
    samples, parameters, bins=20, *,
    # Original corner parameters
    range=None, axes_scale="linear", weights=None, color='k',
    hist_bin_factor=1, smooth=None, smooth1d=None, labels=None,
    label_kwargs=None, titles=None, show_titles=False,
    title_quantiles=None, title_fmt=".2f", title_kwargs=None,
    truths=None, truth_color="#4682b4", scale_hist=False,
    quantiles=None, verbose=False, fig=None, max_n_ticks=5,
    top_ticks=False, use_math_text=False, reverse=False,
    labelpad=0.0, hist_kwargs=None,
    # Arviz parameters
    group="posterior", var_names=None, filter_vars=None,
    coords=None, divergences=False, divergences_kwargs=None,
    labeller=None,
    # New parameters
    kde=None, kde_kwargs=None, kde_2d=None, kde_2d_kwargs=None,
    N=100, **hist2d_kwargs,
):
    """Wrapper for corner.corner which adds additional functionality
    to plot custom KDEs along the leading diagonal and custom 2D
    KDEs in the 2D panels

    All "Original corner" and "Arviz" arguments are forwarded unchanged to
    ``corner.corner``. None of the dictionaries passed in are modified.

    Parameters
    ----------
    samples: array_like
        samples with one column per entry of ``parameters``; indexed as
        ``samples[:, num]`` when KDEs are drawn
    parameters: list
        names of the parameters, in column order. Used to look up
        per-parameter entries of ``kde_kwargs`` and ``kde_2d_kwargs``
    bins: int or list, optional
        number of bins, or one value per parameter
    hist_kwargs: dict, optional
        passed to corner.corner. When ``kde`` is given the histogram
        ``linewidth`` is set to 0 so only the KDE is visible
    kde: func, optional
        1d KDE. Called as ``kde(column, **kwargs)`` and the returned object
        evaluated on ``N`` evenly spaced points spanning the samples
    kde_kwargs: dict, optional
        kwargs for ``kde``. A key equal to a parameter name holds a dict of
        kwargs for that parameter only; every other key applies to all
        parameters (and takes precedence)
    kde_2d: func, optional
        2d KDE passed to ``hist2d`` for every off-diagonal panel
    kde_2d_kwargs: dict, optional
        kwargs for ``kde_2d``. A key equal to a parameter name holds a dict
        for that parameter whose ``low``/``high`` entries become
        ``xlow``/``xhigh`` or ``ylow``/``yhigh`` depending on the axis the
        parameter is plotted on. Parameters without their own entry use the
        global ``low``/``high``. Every other global key is passed to all panels
    N: int, optional
        number of points used to evaluate the 1d KDE
    **hist2d_kwargs: dict
        passed to corner.corner and, when ``kde_2d`` is given, to ``hist2d``
        (``levels`` defaults to corner's 0.5, 1, 1.5, 2 sigma levels)

    Returns
    -------
    fig: matplotlib.figure.Figure
        the figure returned by corner.corner with the KDEs added
    """
    from corner import corner as _base_corner

    # Work on copies: the old mutable defaults were modified in place, so one
    # call with ``kde`` hid the 1d histograms of every later call.
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    kde_kwargs = {} if kde_kwargs is None else kde_kwargs
    kde_2d_kwargs = {} if kde_2d_kwargs is None else kde_2d_kwargs
    if kde is not None:
        # hide the histogram outline; the KDE is drawn on top instead
        hist_kwargs["linewidth"] = 0.
    if kde_2d is not None:
        hist2d_kwargs = dict(hist2d_kwargs)
        contour_kwargs = dict(hist2d_kwargs.get("contour_kwargs") or {})
        linewidths = contour_kwargs.get("linewidths", [1.])
        plot_density = hist2d_kwargs.get("plot_density", True)
        fill_contours = hist2d_kwargs.get("fill_contours", False)
        plot_contours = hist2d_kwargs.get("plot_contours", True)
        # corner.corner only sets up the axes; the 2d panels are drawn by
        # ``hist2d`` below with the user's requested options
        hist2d_kwargs.update(
            plot_density=False, fill_contours=False, plot_contours=False,
            contour_kwargs=dict(contour_kwargs, linewidths=0.),
        )

    fig = _base_corner(
        samples, range=range, axes_scale=axes_scale, weights=weights,
        color=color, hist_bin_factor=hist_bin_factor, smooth=smooth,
        smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs,
        titles=titles, show_titles=show_titles, title_quantiles=title_quantiles,
        title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths,
        truth_color=truth_color, scale_hist=scale_hist,
        quantiles=quantiles, verbose=verbose, fig=fig,
        max_n_ticks=max_n_ticks, top_ticks=top_ticks,
        use_math_text=use_math_text, reverse=reverse,
        labelpad=labelpad, hist_kwargs=hist_kwargs,
        # Arviz parameters
        group=group, var_names=var_names, filter_vars=filter_vars,
        coords=coords, divergences=divergences,
        divergences_kwargs=divergences_kwargs, labeller=labeller,
        **hist2d_kwargs
    )
    if kde is None and kde_2d is None:
        return fig

    data = np.asarray(samples)
    ndim = len(parameters)
    axs = np.array(fig.get_axes(), dtype=object).reshape(ndim, ndim)

    def _panel(row, col):
        # corner.corner mirrors the grid when reverse=True
        if reverse:
            return axs[ndim - 1 - row, ndim - 1 - col]
        return axs[row, col]

    if kde is not None:
        shared_1d = {
            key: item for key, item in kde_kwargs.items()
            if key not in parameters
        }
        for num, param in enumerate(parameters):
            _kwargs = dict(kde_kwargs.get(param) or {})
            _kwargs.update(shared_1d)
            column = data[:, num]
            _kde = kde(column, **_kwargs)
            xs = np.linspace(np.min(column), np.max(column), N)
            _panel(num, num).plot(xs, _kde(xs), color=color)

    if kde_2d is not None:
        levels = hist2d_kwargs.pop("levels", None)
        if levels is None:
            # corner.corner's default 0.5, 1, 1.5 and 2 sigma levels
            levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)
        hist2d_kwargs.pop("contour_kwargs", None)
        overlay_kwargs = dict(hist2d_kwargs)
        overlay_kwargs.update(
            plot_contours=plot_contours,
            plot_density=plot_density,
            fill_contours=fill_contours,
            levels=np.asarray(levels)[::-1],
        )
        panel_contour_kwargs = dict(contour_kwargs, linewidths=linewidths)
        shared_2d = {
            key: item for key, item in kde_2d_kwargs.items()
            if key not in parameters
        }
        # global bounds are only meaningful once mapped onto an axis
        shared_extra = {
            key: item for key, item in shared_2d.items()
            if key not in ("low", "high")
        }

        def _axis_kwargs(param, axis):
            # copy of the parameter's kwargs with low/high renamed for `axis`
            out = dict(kde_2d_kwargs.get(param, shared_2d) or {})
            for bound in ("low", "high"):
                if bound in out:
                    out[axis + bound] = out.pop(bound)
            return out

        for i, x in enumerate(parameters):
            for j, y in enumerate(parameters):
                if j >= i:
                    continue
                # parameter i is on the y axis, parameter j on the x axis
                _kde_2d_kwargs = _axis_kwargs(x, "y")
                _kde_2d_kwargs.update(_axis_kwargs(y, "x"))
                _kde_2d_kwargs.update(shared_extra)
                panel_bins = bins if np.ndim(bins) == 0 else [bins[j], bins[i]]
                hist2d(
                    data[:, j], data[:, i],
                    ax=_panel(i, j), color=color,
                    kde=kde_2d, kde_kwargs=_kde_2d_kwargs,
                    bins=panel_bins,
                    contour_kwargs=dict(panel_contour_kwargs),
                    **overlay_kwargs
                )
    return fig