Coverage for pesummary/core/plots/corner.py: 50.8%

183 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-01-15 17:49 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import copy 

4import numpy as np 

5from scipy.stats import gaussian_kde 

6from matplotlib.colors import LinearSegmentedColormap, colorConverter 

7 

8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

9 

10 

def _set_xlim(new_fig, ax, new_xlim):
    """Apply ``new_xlim`` to the x-axis of ``ax``.

    For a fresh figure the limits are applied verbatim. Otherwise the current
    limits are only ever widened, so that they span both the existing range
    and ``new_xlim``. Returns whatever ``ax.set_xlim`` returns.
    """
    if not new_fig:
        lower, upper = ax.get_xlim()
        new_xlim = [min(lower, new_xlim[0]), max(upper, new_xlim[1])]
    return ax.set_xlim(new_xlim)

16 

17 

def _set_ylim(new_fig, ax, new_ylim):
    """Apply ``new_ylim`` to the y-axis of ``ax``.

    For a fresh figure the limits are applied verbatim. Otherwise the current
    limits are only ever widened, so that they span both the existing range
    and ``new_ylim``. Returns whatever ``ax.set_ylim`` returns.
    """
    if not new_fig:
        lower, upper = ax.get_ylim()
        new_ylim = [min(lower, new_ylim[0]), max(upper, new_ylim[1])]
    return ax.set_ylim(new_ylim)

23 

24 

def _get_cmap(name):
    """Look up a registered matplotlib colormap by ``name``.

    ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9. The
    ``matplotlib.colormaps`` registry (3.5+) replaces it, so the old function
    is only used when the registry is unavailable.
    """
    try:
        from matplotlib import colormaps
    except ImportError:  # matplotlib < 3.5
        from matplotlib import cm
        return cm.get_cmap(name)
    return colormaps[name]


def _plot_datapoints(ax, x, y, weights, color, data_kwargs):
    """Draw the raw samples as points, fading each by its relative weight.

    ``data_kwargs`` is copied, never modified, because the same dict may be
    shared by many calls.
    """
    data_kwargs = dict(data_kwargs or {})
    data_kwargs.setdefault("color", color)
    data_kwargs.setdefault("ms", 2.0)
    data_kwargs.setdefault("mec", "none")
    data_kwargs.setdefault("alpha", 0.1)
    if weights is None:
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)
        return
    # Normalise on a float copy: in-place division of a list or an integer
    # array raises.
    scaled = np.ravel(np.asarray(weights, dtype=float))
    scaled = scaled / np.max(scaled)
    xs, ys = np.ravel(x), np.ravel(y)
    # Lightest points first so heavier points are drawn on top. Each point's
    # alpha is scaled by *its own* weight, not by the weight at its position
    # in the sorted order.
    for idx in np.argsort(scaled):
        point_kwargs = dict(data_kwargs)
        point_kwargs["alpha"] *= scaled[idx]
        ax.plot(
            xs[idx], ys[idx], "o", zorder=-1, rasterized=True, **point_kwargs
        )


def _contour_polygons(cs, kde_kwargs):
    """Extract the polygons of every path in the contour set ``cs``.

    Returns a list with one entry per path from ``cs.get_paths()``. Each entry
    is a list of ``(2, N)`` arrays (x row, y row). When ``kde_kwargs`` has no
    ``transform``, the coordinates are clipped in place to the optional
    ``xlow``/``xhigh``/``ylow``/``yhigh`` bounds.
    """
    transform = kde_kwargs.get("transform", None)
    contour_set = []
    for path in cs.get_paths():
        polygons = []
        for data in path.to_polygons():
            transpose = data.T
            for idx, axis in enumerate(["x", "y"]):
                if transform is None:
                    row = transpose[idx]
                    low = kde_kwargs.get("{}low".format(axis), -np.inf)
                    high = kde_kwargs.get("{}high".format(axis), np.inf)
                    if low is not None:
                        row[row < low] = low
                    if high is not None:
                        row[row > high] = high
                else:
                    # NOTE(review): the transformed coordinates are discarded,
                    # so nothing is clipped when a transform is supplied. The
                    # call is kept to preserve existing behaviour; confirm the
                    # intended use of ``transform``.
                    transform(transpose)
            polygons.append(transpose)
        contour_set.append(polygons)
    return contour_set


def _per_contour(prop, default, ncontours, is_single):
    """Expand a style value to one entry per contour.

    ``None`` becomes ``default`` repeated, and a single value (as judged by
    ``is_single``) is repeated. A sequence must provide at least
    ``ncontours`` entries.
    """
    if prop is None:
        return [default] * ncontours
    if is_single(prop):
        return [prop] * ncontours
    if len(prop) < ncontours:
        raise ValueError(
            "Please provide a color/linestyle for each contour"
        )
    return prop


def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like
        Passed to ``np.histogram2d``. The resulting bin edges form the grid
        on which the KDE is evaluated.
    range : list, optional
        ``[[xmin, xmax], [ymin, ymax]]``. Defaults to the data range. Used
        for the evaluation grid and the axis limits.
    weights : array_like, optional
        Sample weights. They are passed to the kde and scale the alpha of
        the plotted data points.
    quiet : bool
        Accepted for compatibility with corner.hist2d; currently unused.
    levels : array_like
        The contour levels to draw. A contour is drawn where the density
        equals ``(1 - level) * max(density)``.
    smooth : float, optional
        If given, the evaluated density is smoothed with
        ``scipy.ndimage.gaussian_filter`` using this sigma.
    ax : matplotlib.Axes
        An axes instance on which to add the 2-D histogram.
    color : str, optional
        Colour used for the data points, density map and contours.
        Defaults to ``"k"``.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    no_fill_contours, fill_contours : bool
        Accepted for compatibility with corner.hist2d; filled contours are
        not currently implemented, so both are ignored.
    contour_kwargs : dict
        Only ``colors`` is read: a single colour or one colour per contour.
        The dict is not modified.
    contourf_kwargs : dict
        Accepted for compatibility with corner.hist2d; currently unused.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points. The dict is not modified.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap. ``shading`` is always set to ``"auto"``.
    new_fig : bool
        If True the axis limits are set to ``range``; otherwise the existing
        limits are widened to include it.
    kde: func, optional
        KDE you wish to use to work out the 2d probability density. It is
        called as ``kde(values, weights=weights, **kde_kwargs)`` with
        ``values`` of shape (2, nsamples), and the result is evaluated on a
        (2, M) array of grid points. Defaults to scipy's gaussian_kde.
    kde_kwargs: dict, optional
        kwargs passed directly to kde. ``transform`` and
        ``xlow``/``xhigh``/``ylow``/``yhigh`` are also read here to clip the
        contour lines.
    density_cmap : str or Colormap, optional
        Colormap for the density map. A string is looked up by name.
    label : str, optional
        Legend label attached to the first contour line drawn.
    grid : bool
        Accepted for compatibility; currently unused.
    **kwargs :
        ``linestyles`` (a single linestyle string, or one per contour) is
        read from here.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde

    if ax is None:
        raise ValueError("Please provide an axis to plot")
    # Previously ``kde_kwargs`` was only normalised after it had already
    # been unpacked, so ``None`` crashed. Normalise it up front.
    if kde_kwargs is None:
        kde_kwargs = dict()
    # Copy: ``colors`` is popped below and the same dict may be shared by
    # many calls (e.g. one per corner-plot panel).
    contour_kwargs = dict(contour_kwargs or {})

    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]

    if color is None:
        color = "k"

    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # Colour map for the density plot, fading from ``color`` to transparent.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        density_cmap = _get_cmap(density_cmap)

    # Only the bin edges are used: they form the KDE evaluation grid.
    try:
        _, X, Y = np.histogram2d(
            x.flatten(),
            y.flatten(),
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        )

    values = np.vstack([x.flatten(), y.flatten()])
    kernel = kde(values, weights=weights, **kde_kwargs)
    X, Y = np.meshgrid(X, Y)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        _plot_datapoints(ax, x, y, weights, color, data_kwargs)

    # Invisible contour pass used only to extract the contour paths.
    # matplotlib requires increasing levels, and ``1 - levels`` is
    # decreasing for the default (increasing) ``levels``, so sort them.
    cs = ax.contour(
        X, Y, H, levels=np.sort((1 - np.asarray(levels)) * np.max(H)),
        alpha=0.
    )
    contour_set = _contour_polygons(cs, kde_kwargs)

    # Plot the density map.
    if plot_density:
        pcolor_kwargs = dict(pcolor_kwargs or {})
        pcolor_kwargs["shading"] = "auto"
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **pcolor_kwargs)

    # Plot the contour edges as lines so they can be clipped to the bounds.
    if plot_contours:
        from matplotlib.colors import is_color_like

        ncontours = len(contour_set)
        colors = _per_contour(
            contour_kwargs.pop("colors", color), "k", ncontours,
            lambda prop: isinstance(prop, str) or is_color_like(prop)
        )
        linestyles = _per_contour(
            kwargs.pop("linestyles", "-"), "-", ncontours,
            lambda prop: isinstance(prop, str)
        )
        _label = label
        for idx, polygons in enumerate(contour_set):
            for polygon in polygons:
                ax.plot(
                    *polygon, color=colors[idx], label=_label,
                    linestyle=linestyles[idx]
                )
                # Attach the legend label to the first line actually drawn.
                _label = None

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])

243 

244 

def _kde_2d_panel_kwargs(kde_2d_kwargs, parameters, xparam, yparam):
    """Build the kde kwargs for the panel with ``xparam`` on the x-axis and
    ``yparam`` on the y-axis, without modifying ``kde_2d_kwargs``.

    For each of the two parameters, its own sub-dict is used if present;
    otherwise the global dict is used, but only for bounds. In either case
    ``low``/``high`` are renamed to ``xlow``/``xhigh`` or ``ylow``/``yhigh``
    for that parameter's axis. The y-axis parameter is merged first, so the
    x-axis parameter wins on conflicts. Finally, every non-parameter,
    non-bound key of ``kde_2d_kwargs`` is applied on top.
    """
    bounds = ("low", "high")
    panel = {}
    for param, axis in ((yparam, "y"), (xparam, "x")):
        has_own = param in kde_2d_kwargs
        source = kde_2d_kwargs[param] if has_own else kde_2d_kwargs
        for key, item in source.items():
            if key in bounds:
                panel[axis + key] = item
            elif has_own:
                panel[key] = item
    panel.update(
        {
            key: item for key, item in kde_2d_kwargs.items()
            if key not in parameters and key not in bounds
        }
    )
    return panel


def corner(
    samples, parameters, bins=20, *,
    # Original corner parameters
    range=None, axes_scale="linear", weights=None, color='k',
    hist_bin_factor=1, smooth=None, smooth1d=None, labels=None,
    label_kwargs=None, titles=None, show_titles=False,
    title_quantiles=None, title_fmt=".2f", title_kwargs=None,
    truths=None, truth_color="#4682b4", scale_hist=False,
    quantiles=None, verbose=False, fig=None, max_n_ticks=5,
    top_ticks=False, use_math_text=False, reverse=False,
    labelpad=0.0, hist_kwargs=None,
    # Arviz parameters
    group="posterior", var_names=None, filter_vars=None,
    coords=None, divergences=False, divergences_kwargs=None,
    labeller=None,
    # New parameters
    kde=None, kde_kwargs=None, kde_2d=None, kde_2d_kwargs=None,
    N=100, **hist2d_kwargs,
):
    """Wrapper for corner.corner which adds additional functionality
    to plot custom KDEs along the leading diagonal and custom 2D
    KDEs in the 2D panels

    Parameters
    ----------
    samples: np.ndarray
        2D array of samples; column ``num`` holds the samples for
        ``parameters[num]``
    parameters: list
        names of the parameters, one per column of ``samples``. The figure
        produced by corner.corner is assumed to contain
        ``len(parameters)**2`` axes
    kde: func, optional
        1D KDE, called as ``kde(column, **kwargs)``. The result is evaluated
        on ``N`` points spanning the column and drawn on the diagonal. When
        given, the histogram outline drawn by corner is hidden (linewidth 0)
    kde_kwargs: dict, optional
        kwargs for ``kde``. A key equal to a parameter name maps to a dict
        used for that parameter only; otherwise all non-parameter keys are
        used
    kde_2d: func, optional
        2D KDE passed to ``hist2d`` for every lower-triangle panel
    kde_2d_kwargs: dict, optional
        kwargs for ``kde_2d``. A per-parameter dict may contain
        ``low``/``high``, which become ``xlow``/``xhigh`` or
        ``ylow``/``yhigh`` depending on the axis that parameter is plotted
        on. Non-parameter keys apply to every panel. Never modified
    N: int, optional
        number of points on which the 1D KDE is evaluated
    **hist2d_kwargs:
        forwarded to corner.corner and, when ``kde_2d`` is given, to
        ``hist2d``

    All other arguments are forwarded unchanged to corner.corner.

    Returns
    -------
    fig:
        the figure returned by corner.corner
    """
    from corner import corner as _base_corner

    # Normalise/copy dict arguments: the old mutable ``{}`` defaults were
    # modified below, which leaked settings (e.g. a zero histogram
    # linewidth) into every subsequent call.
    hist_kwargs = dict(hist_kwargs or {})
    kde_kwargs = kde_kwargs or {}
    kde_2d_kwargs = kde_2d_kwargs or {}
    hist2d_kwargs = dict(hist2d_kwargs)

    if kde is not None:
        hist_kwargs["linewidth"] = 0.
    if kde_2d is not None:
        linewidths = [1.]
        if hist2d_kwargs.get("plot_contours", False):
            # Copy the nested dict so the caller's contour_kwargs is untouched.
            base_contour_kwargs = dict(hist2d_kwargs.get("contour_kwargs") or {})
            linewidths = base_contour_kwargs.get("linewidths", None)
            base_contour_kwargs["linewidths"] = 0.
            hist2d_kwargs["contour_kwargs"] = base_contour_kwargs
        plot_density = hist2d_kwargs.get("plot_density", True)
        fill_contours = hist2d_kwargs.get("fill_contours", False)
        plot_contours = hist2d_kwargs.get("plot_contours", True)
        # corner.corner must not draw its own 2D densities/contours; the
        # custom 2D KDE is drawn below instead.
        if plot_density:
            hist2d_kwargs["plot_density"] = False
        if fill_contours:
            hist2d_kwargs["fill_contours"] = False
        hist2d_kwargs["plot_contours"] = False

    fig = _base_corner(
        samples, range=range, axes_scale=axes_scale, weights=weights,
        color=color, hist_bin_factor=hist_bin_factor, smooth=smooth,
        smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs,
        titles=titles, show_titles=show_titles, title_quantiles=title_quantiles,
        title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths,
        truth_color=truth_color, scale_hist=scale_hist,
        quantiles=quantiles, verbose=verbose, fig=fig,
        max_n_ticks=max_n_ticks, top_ticks=top_ticks,
        use_math_text=use_math_text, reverse=reverse,
        labelpad=labelpad, hist_kwargs=hist_kwargs, bins=bins,
        # Arviz parameters
        group=group, var_names=var_names, filter_vars=filter_vars,
        coords=coords, divergences=divergences,
        divergences_kwargs=divergences_kwargs, labeller=labeller,
        **hist2d_kwargs
    )
    if kde is None and kde_2d is None:
        return fig
    axs = np.array(fig.get_axes(), dtype=object).reshape(
        len(parameters), len(parameters)
    )
    if kde is not None:
        for num, param in enumerate(parameters):
            if param in kde_kwargs:
                _kwargs = kde_kwargs[param]
            else:
                _kwargs = {
                    key: item for key, item in kde_kwargs.items()
                    if key not in parameters
                }
            column = samples[:, num]
            _kde = kde(column, **_kwargs)
            xs = np.linspace(np.min(column), np.max(column), N)
            axs[num, num].plot(xs, _kde(xs), color=color)
    if kde_2d is not None:
        contour_kwargs = dict(hist2d_kwargs.pop("contour_kwargs", None) or {})
        contour_kwargs["linewidths"] = linewidths
        levels = hist2d_kwargs.get("levels", None)
        if levels is None:
            # Same default "sigma" levels as corner.hist2d. The previous
            # ``pop("levels")`` raised KeyError when no levels were given.
            levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)
        panel_kwargs = dict(hist2d_kwargs)
        panel_kwargs.update(
            {
                "plot_contours": plot_contours,
                "plot_density": plot_density,
                "fill_contours": fill_contours,
                "levels": levels[::-1],
            }
        )
        # Lower triangle only: row i is the y-axis parameter, column j < i
        # is the x-axis parameter.
        for i, yparam in enumerate(parameters):
            for j, xparam in enumerate(parameters[:i]):
                hist2d(
                    samples[:, j], samples[:, i],
                    ax=axs[i, j], color=color,
                    kde=kde_2d,
                    kde_kwargs=_kde_2d_panel_kwargs(
                        kde_2d_kwargs, parameters, xparam, yparam
                    ),
                    bins=bins,
                    # Fresh copy per panel: hist2d may pop "colors" from it.
                    contour_kwargs=dict(contour_kwargs),
                    **panel_kwargs
                )
    return fig