Coverage for pesummary/core/plots/corner.py: 50.3%

183 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-09 22:34 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import copy 

4import numpy as np 

5from scipy.stats import gaussian_kde 

6from matplotlib.colors import LinearSegmentedColormap, colorConverter 

7 

8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

9 

10 

11def _set_xlim(new_fig, ax, new_xlim): 

12 if new_fig: 

13 return ax.set_xlim(new_xlim) 

14 xlim = ax.get_xlim() 

15 return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])]) 

16 

17 

18def _set_ylim(new_fig, ax, new_ylim): 

19 if new_fig: 

20 return ax.set_ylim(new_ylim) 

21 ylim = ax.get_ylim() 

22 return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])]) 

23 

24 

def _plot_datapoints(ax, x, y, weights, color, data_kwargs):
    """Draw the individual samples on ``ax`` as small, faint markers.

    Parameters
    ----------
    ax : matplotlib.Axes
        Axes to draw on.
    x, y : np.ndarray
        Sample coordinates.
    weights : array_like or None
        Optional per-sample weights. When given, each marker's alpha is scaled
        by ``weight / max(weights)`` and the samples are plotted one at a time
        in order of increasing weight.
    color : str
        Marker colour used unless ``data_kwargs`` sets ``"color"``.
    data_kwargs : dict or None
        Extra keyword arguments for ``ax.plot``. The caller's dict is not
        modified.
    """
    # Work on a copy so a dict shared between calls/panels is never mutated.
    data_kwargs = dict(data_kwargs or {})
    data_kwargs.setdefault("color", color)
    data_kwargs.setdefault("ms", 2.0)
    data_kwargs.setdefault("mec", "none")
    data_kwargs.setdefault("alpha", 0.1)
    if weights is None:
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)
        return
    # Float copy: in-place division of a list or an integer array would fail.
    norm_weights = np.array(weights, dtype=float)
    norm_weights /= np.max(norm_weights)
    base_alpha = data_kwargs.pop("alpha")
    order = np.argsort(norm_weights)
    # Iterate the sorted weights alongside the sorted points so every marker
    # gets the alpha of its own sample.
    for xx, yy, ww in zip(x[order], y[order], norm_weights[order]):
        ax.plot(
            xx, yy, "o", zorder=-1, rasterized=True, alpha=base_alpha * ww,
            **data_kwargs
        )


def _contour_paths(cs, kde_kwargs):
    """Extract the contour lines of ``cs`` as ``(2, N)`` vertex arrays.

    Returns a list with one entry per contour level. Each entry is a list of
    ``[xs, ys]`` arrays, one per line segment in ``cs.allsegs`` (copies, so the
    contour set itself is untouched). Unless ``kde_kwargs`` contains a
    ``transform``, vertices are clipped to the ``xlow``/``xhigh``/``ylow``/
    ``yhigh`` values found in ``kde_kwargs``. A missing bound means no
    clipping, and so does a bound given as ``None``.
    """
    clip = kde_kwargs.get("transform", None) is None
    contour_set = []
    for level_segments in cs.allsegs:
        level_paths = []
        for segment in level_segments:
            path = np.array(segment, dtype=float).reshape(-1, 2).T
            if clip:
                for idx, axis in enumerate(["x", "y"]):
                    low = kde_kwargs.get("{}low".format(axis), -np.inf)
                    high = kde_kwargs.get("{}high".format(axis), np.inf)
                    row = path[idx]  # view: edits below update ``path``
                    if low is not None:
                        row[row < low] = low
                    if high is not None:
                        row[row > high] = high
            level_paths.append(path)
        contour_set.append(level_paths)
    return contour_set


def _plot_contour_lines(ax, contour_set, colors, linestyles, label):
    """Plot every path in ``contour_set`` with per-level colour and linestyle.

    ``colors``/``linestyles`` may be ``None`` (defaults ``"k"``/``"-"``), a
    single string applied to every level, or a sequence with at least one
    entry per level (``ValueError`` otherwise). ``label`` is attached only to
    the first path actually drawn.
    """
    n_levels = len(contour_set)
    resolved = []
    for prop, default in ((colors, "k"), (linestyles, "-")):
        if prop is None:
            prop = [default] * n_levels
        elif isinstance(prop, str):
            prop = [prop] * n_levels
        elif len(prop) < n_levels:
            raise ValueError(
                "Please provide a color/linestyle for each contour"
            )
        resolved.append(prop)
    level_colors, level_linestyles = resolved
    labelled = False
    for idx, level_paths in enumerate(contour_set):
        for path in level_paths:
            ax.plot(
                *path, color=level_colors[idx],
                label=None if labelled else label,
                linestyle=level_linestyles[idx]
            )
            labelled = True


def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    The density is obtained by evaluating ``kde`` on the bin edges returned by
    ``np.histogram2d`` (a grid of ``bins + 1`` edges per axis for integer
    ``bins``), not by histogramming the samples.

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like
        Passed to ``np.histogram2d`` to build the evaluation grid.
    range : array_like, optional
        ``[[xmin, xmax], [ymin, ymax]]`` for the grid and the axis limits.
        Defaults to the data range.
    weights : array_like, optional
        Per-sample weights, passed to ``np.histogram2d`` and to ``kde``, and
        used to scale the data-point alpha.
    levels : array_like
        The contour levels to draw, as enclosed probabilities in ``(0, 1)``.
        Defaults to the 0.5, 1, 1.5 and 2 sigma levels.
    smooth : float, optional
        If given, the evaluated density grid is smoothed with
        ``scipy.ndimage.gaussian_filter`` using this sigma (in grid cells).
    ax : matplotlib.Axes
        An axes instance on which to add the 2-D histogram. Required.
    color : str, optional
        Base colour for points, contours and the default density map.
        Defaults to ``"k"``.
    quiet, no_fill_contours, fill_contours, contourf_kwargs, grid : optional
        Accepted for compatibility with ``corner.hist2d``; not used here.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    contour_kwargs : dict
        Only ``"colors"`` and ``"linestyles"`` are read, to style the contour
        lines. A ``linestyles`` keyword passed directly through ``**kwargs``
        takes precedence. The dict is not modified.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points. The dict is not modified.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap (``shading`` is always forced to
        ``"auto"``). The dict is not modified.
    new_fig : bool
        If True the axis limits are set to ``range``; otherwise the existing
        limits are widened to include ``range``.
    kde : callable, optional
        Called as ``kde(values, weights=weights, **kde_kwargs)`` with
        ``values`` of shape ``(2, nsamples)``; the result must be callable on a
        ``(2, npoints)`` array. Defaults to ``scipy.stats.gaussian_kde``.
    kde_kwargs : dict, optional
        Keyword arguments passed directly to ``kde``. Its ``xlow``/``xhigh``/
        ``ylow``/``yhigh`` entries are also used to clip the contour lines,
        unless a ``transform`` entry is present.
    density_cmap : Colormap or str, optional
        Colormap, or registered colormap name, for the density image. Defaults
        to a map from ``color`` to transparent white.
    label : str, optional
        Legend label attached to the first contour line drawn.

    Raises
    ------
    ValueError
        If ``ax`` is not given, or if ``np.histogram2d`` rejects the data/range.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde
    # Normalise optional dicts up front: ``kde_kwargs`` is needed as soon as
    # the KDE is built, so a ``None`` check after that point is too late.
    if kde_kwargs is None:
        kde_kwargs = {}
    if contour_kwargs is None:
        contour_kwargs = {}

    if ax is None:
        raise ValueError("Please provide an axis to plot")
    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]
    if color is None:
        color = "k"
    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # Colour map for the density image, fading from ``color`` to transparent.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        try:
            # ``cm.get_cmap`` was removed in matplotlib 3.9; the registry is
            # the supported lookup.
            from matplotlib import colormaps
        except ImportError:  # older matplotlib without the registry
            from matplotlib import cm
            density_cmap = cm.get_cmap(density_cmap)
        else:
            density_cmap = colormaps[density_cmap]

    # np.histogram2d is only used for its bin edges (and to validate range).
    try:
        _, X, Y = np.histogram2d(
            x.flatten(),
            y.flatten(),
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError as err:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        ) from err

    values = np.vstack([x.flatten(), y.flatten()])
    kernel = kde(values, weights=weights, **kde_kwargs)
    X, Y = np.meshgrid(X, Y)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        _plot_datapoints(ax, x, y, weights, color, data_kwargs)

    # Plot the density map.
    if plot_density:
        _pcolor_kwargs = dict(pcolor_kwargs or {})
        _pcolor_kwargs["shading"] = "auto"
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **_pcolor_kwargs)

    if plot_contours:
        # An enclosed probability ``p`` maps to the density threshold
        # ``(1 - p) * max(H)``. matplotlib requires increasing contour levels,
        # so sort; already-increasing thresholds are unaffected.
        thresholds = np.sort(
            (1.0 - np.asarray(levels, dtype=float)) * np.max(H)
        )
        # Invisible contour: used only to obtain the line vertices, which are
        # then clipped and re-drawn with ax.plot.
        cs = ax.contour(X, Y, H, levels=thresholds, alpha=0.)
        contour_set = _contour_paths(cs, kde_kwargs)
        colors = contour_kwargs.get("colors", color)
        linestyles = kwargs.get(
            "linestyles", contour_kwargs.get("linestyles", "-")
        )
        _plot_contour_lines(ax, contour_set, colors, linestyles, label)

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])

244 

245 

246def corner( 

247 samples, parameters, bins=20, *, 

248 # Original corner parameters 

249 range=None, axes_scale="linear", weights=None, color='k', 

250 hist_bin_factor=1, smooth=None, smooth1d=None, labels=None, 

251 label_kwargs=None, titles=None, show_titles=False, 

252 title_quantiles=None, title_fmt=".2f", title_kwargs=None, 

253 truths=None, truth_color="#4682b4", scale_hist=False, 

254 quantiles=None, verbose=False, fig=None, max_n_ticks=5, 

255 top_ticks=False, use_math_text=False, reverse=False, 

256 labelpad=0.0, hist_kwargs={}, 

257 # Arviz parameters 

258 group="posterior", var_names=None, filter_vars=None, 

259 coords=None, divergences=False, divergences_kwargs=None, 

260 labeller=None, 

261 # New parameters 

262 kde=None, kde_kwargs={}, kde_2d=None, kde_2d_kwargs={}, 

263 N=100, **hist2d_kwargs, 

264): 

265 """Wrapper for corner.corner which adds additional functionality 

266 to plot custom KDEs along the leading diagonal and custom 2D 

267 KDEs in the 2D panels 

268 """ 

269 from corner import corner 

270 if kde is not None: 

271 hist_kwargs["linewidth"] = 0. 

272 if kde_2d is not None: 

273 linewidths = [1.] 

274 hist2d_kwargs = hist2d_kwargs.copy() 

275 if hist2d_kwargs.get("plot_contours", False): 

276 if "contour_kwargs" not in hist2d_kwargs.keys(): 

277 hist2d_kwargs["contour_kwargs"] = {} 

278 linewidths = hist2d_kwargs["contour_kwargs"].get("linewidths", None) 

279 hist2d_kwargs["contour_kwargs"]["linewidths"] = 0. 

280 plot_density = hist2d_kwargs.get("plot_density", True) 

281 fill_contours = hist2d_kwargs.get("fill_contours", False) 

282 plot_contours = hist2d_kwargs.get("plot_contours", True) 

283 if plot_density: 

284 hist2d_kwargs["plot_density"] = False 

285 if fill_contours: 

286 hist2d_kwargs["fill_contours"] = False 

287 hist2d_kwargs["plot_contours"] = False 

288 

289 fig = corner( 

290 samples, range=range, axes_scale=axes_scale, weights=weights, 

291 color=color, hist_bin_factor=hist_bin_factor, smooth=smooth, 

292 smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs, 

293 titles=titles, show_titles=show_titles, title_quantiles=title_quantiles, 

294 title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths, 

295 truth_color=truth_color, scale_hist=scale_hist, 

296 quantiles=quantiles, verbose=verbose, fig=fig, 

297 max_n_ticks=max_n_ticks, top_ticks=top_ticks, 

298 use_math_text=use_math_text, reverse=reverse, 

299 labelpad=labelpad, hist_kwargs=hist_kwargs, 

300 # Arviz parameters 

301 group=group, var_names=var_names, filter_vars=filter_vars, 

302 coords=coords, divergences=divergences, 

303 divergences_kwargs=divergences_kwargs, labeller=labeller, 

304 **hist2d_kwargs 

305 ) 

306 if kde is None and kde_2d is None: 

307 return fig 

308 axs = np.array(fig.get_axes(), dtype=object).reshape( 

309 len(parameters), len(parameters) 

310 ) 

311 if kde is not None: 

312 for num, param in enumerate(parameters): 

313 if param in kde_kwargs.keys(): 

314 _kwargs = kde_kwargs[param] 

315 else: 

316 _kwargs = {} 

317 for key, item in kde_kwargs.items(): 

318 if key not in parameters: 

319 _kwargs[key] = item 

320 _kde = kde(samples[:,num], **_kwargs) 

321 xs = np.linspace(np.min(samples[:,num]), np.max(samples[:,num]), N) 

322 axs[num, num].plot( 

323 xs, _kde(xs), color=color 

324 ) 

325 if kde_2d is not None: 

326 _hist2d_kwargs = hist2d_kwargs.copy() 

327 _contour_kwargs = hist2d_kwargs.pop("contour_kwargs", {}) 

328 _contour_kwargs["linewidths"] = linewidths 

329 _hist2d_kwargs.update( 

330 { 

331 "plot_contours": plot_contours, 

332 "plot_density": plot_density, 

333 "fill_contours": fill_contours, 

334 "levels": hist2d_kwargs.pop("levels")[::-1], 

335 "contour_kwargs": _contour_kwargs 

336 } 

337 ) 

338 for i, x in enumerate(parameters): 

339 for j, y in enumerate(parameters): 

340 if j >= i: 

341 continue 

342 _kde_2d_kwargs = {} 

343 _xkwargs = kde_2d_kwargs.get(x, kde_2d_kwargs) 

344 if "low" in _xkwargs.keys(): 

345 _xkwargs["ylow"] = _xkwargs.pop("low") 

346 if "high" in _xkwargs.keys(): 

347 _xkwargs["yhigh"] = _xkwargs.pop("high") 

348 _kde_2d_kwargs.update(_xkwargs) 

349 _ykwargs = kde_2d_kwargs.get(y, kde_2d_kwargs) 

350 if "low" in _ykwargs.keys(): 

351 _ykwargs["xlow"] = _ykwargs.pop("low") 

352 if "high" in _ykwargs.keys(): 

353 _ykwargs["xhigh"] = _ykwargs.pop("high") 

354 _kde_2d_kwargs.update(_ykwargs) 

355 for key, item in kde_2d_kwargs.items(): 

356 if key not in parameters: 

357 _kde_2d_kwargs[key] = item 

358 hist2d( 

359 samples[:,j], samples[:,i], 

360 ax=axs[i, j], color=color, 

361 kde=kde_2d, kde_kwargs=_kde_2d_kwargs, 

362 bins=bins, **_hist2d_kwargs 

363 ) 

364 return fig