Coverage for pesummary/core/plots/corner.py: 52.9%

174 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4from scipy.stats import gaussian_kde 

5from matplotlib.colors import LinearSegmentedColormap, colorConverter 

6 

7__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

8 

9 

def _set_xlim(new_fig, ax, new_xlim):
    """Set the x-axis limits of ``ax``.

    Parameters
    ----------
    new_fig : bool
        If true, ``new_xlim`` replaces the current limits outright. Otherwise
        the limits are only ever widened: the result is the union of the
        current limits and ``new_xlim``.
    ax : matplotlib.Axes
        Axis whose limits are updated.
    new_xlim : sequence
        Two-element ``[lower, upper]`` pair.

    Returns
    -------
    Whatever ``ax.set_xlim`` returns.
    """
    if not new_fig:
        current = ax.get_xlim()
        lower = min(current[0], new_xlim[0])
        upper = max(current[1], new_xlim[1])
        new_xlim = [lower, upper]
    return ax.set_xlim(new_xlim)

15 

16 

def _set_ylim(new_fig, ax, new_ylim):
    """Set the y-axis limits of ``ax``.

    Parameters
    ----------
    new_fig : bool
        If true, ``new_ylim`` replaces the current limits outright. Otherwise
        the limits are only ever widened: the result is the union of the
        current limits and ``new_ylim``.
    ax : matplotlib.Axes
        Axis whose limits are updated.
    new_ylim : sequence
        Two-element ``[lower, upper]`` pair.

    Returns
    -------
    Whatever ``ax.set_ylim`` returns.
    """
    if not new_fig:
        current = ax.get_ylim()
        lower = min(current[0], new_ylim[0])
        upper = max(current[1], new_ylim[1])
        new_ylim = [lower, upper]
    return ax.set_ylim(new_ylim)

22 

23 

def hist2d(
    x, y, bins=20, range=None, weights=None, levels=None, smooth=None, ax=None,
    color=None, quiet=False, plot_datapoints=True, plot_density=True,
    plot_contours=True, no_fill_contours=False, fill_contours=False,
    contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
    pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs=None,
    density_cmap=None, label=None, grid=True, **kwargs
):
    """Extension of the corner.hist2d function. Allows the user to specify the
    kde used when estimating the 2d probability density

    The density is obtained by evaluating ``kde`` on the grid of bin edges
    returned by ``numpy.histogram2d``; the histogram counts themselves are
    discarded. None of the dictionaries passed in are modified.

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like
        Passed to ``numpy.histogram2d``; the resulting bin edges define the
        grid on which the KDE is evaluated.
    range : array_like, optional
        ``[[xmin, xmax], [ymin, ymax]]``. Defaults to the extrema of the
        samples. Used both for the grid and for the final axis limits.
    weights : array_like, optional
        Passed to ``numpy.histogram2d`` only. The KDE does not receive the
        weights, so they do not affect the plotted density.
    levels : array_like
        The contour levels to draw. A level ``l`` is drawn at a density of
        ``(1 - l) * max(density)``. The levels may be given in any order.
    smooth : float, optional
        If given, the evaluated density is passed through
        ``scipy.ndimage.gaussian_filter`` with this value.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram. Required.
    color : str, optional
        Colour used for the data points, the density colormap and, unless
        overridden, the contour lines. Defaults to ``"k"``.
    quiet : bool
        Currently unused by this implementation.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    no_fill_contours : bool
        Currently unused by this implementation.
    fill_contours : bool
        Currently unused by this implementation.
    contour_kwargs : dict
        Only the ``"colors"`` entry is read; it may be a single colour string
        or a sequence with one entry per contour level.
    contourf_kwargs : dict
        Currently unused by this implementation.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points.
    pcolor_kwargs : dict
        Any additional keyword arguments to pass to the `pcolor` method when
        adding the density colormap. ``shading`` is always set to ``"auto"``.
    new_fig : bool
        If true the axis limits are set to ``range``; otherwise the existing
        limits are widened to include ``range``.
    kde: func, optional
        KDE you wish to use to work out the 2d probability density. Defaults
        to ``scipy.stats.gaussian_kde``.
    kde_kwargs: dict, optional
        kwargs passed directly to kde. The entries ``xlow``, ``xhigh``,
        ``ylow``, ``yhigh`` and ``transform`` are additionally read here when
        post-processing the contour vertices.
    density_cmap : str or colormap, optional
        Colormap for the density. A string is looked up by name; by default
        a map fading from ``color`` to transparent white is built.
    label : str, optional
        Legend label attached to the first contour line that is drawn.
    grid : bool
        Currently unused by this implementation.
    **kwargs
        Only ``linestyles`` is read (a single style or one per contour level).

    Raises
    ------
    ValueError
        If ``ax`` is not given, if ``numpy.histogram2d`` rejects the samples
        or range, or if fewer colours/linestyles than contour levels are
        supplied.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if kde is None:
        kde = gaussian_kde

    if ax is None:
        raise ValueError("Please provide an axis to plot")

    # Normalise the optional dictionaries *before* their first use; a caller
    # passing ``kde_kwargs=None`` would otherwise fail on ``**kde_kwargs``.
    if kde_kwargs is None:
        kde_kwargs = {}
    if contour_kwargs is None:
        contour_kwargs = {}

    # Set the default range based on the data range if not provided.
    if range is None:
        range = [[x.min(), x.max()], [y.min(), y.max()]]

    # Set up the default plotting arguments.
    if color is None:
        color = "k"

    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # This is the color map for the density plot, over-plotted to indicate the
    # density of the points near the center.
    if density_cmap is None:
        density_cmap = LinearSegmentedColormap.from_list(
            "density_cmap", [color, (1, 1, 1, 0)]
        )
    elif isinstance(density_cmap, str):
        # NOTE(review): ``cm.get_cmap`` is deprecated in recent matplotlib
        # releases -- confirm against the supported version range.
        from matplotlib import cm

        density_cmap = cm.get_cmap(density_cmap)

    flat_x = x.flatten()
    flat_y = y.flatten()

    # The 2D histogram is only used for its bin edges, which form the grid on
    # which the KDE is evaluated.
    try:
        _, x_edges, y_edges = np.histogram2d(
            flat_x,
            flat_y,
            bins=bins,
            range=list(map(np.sort, range)),
            weights=weights,
        )
    except ValueError as err:
        raise ValueError(
            "It looks like at least one of your sample columns "
            "have no dynamic range. You could try using the "
            "'range' argument."
        ) from err

    kernel = kde(np.vstack([flat_x, flat_y]), **kde_kwargs)
    X, Y = np.meshgrid(x_edges, y_edges)
    H = kernel(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

    if smooth is not None:
        if kde_kwargs.get("transform", None) is not None:
            from pesummary.utils.utils import logger
            logger.warning(
                "Smoothing PDF. This may give unwanted effects especially near "
                "any boundaries"
            )
        try:
            from scipy.ndimage import gaussian_filter
        except ImportError:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    if plot_datapoints:
        # Defaults first, caller-supplied values win; built on a fresh dict so
        # the caller's ``data_kwargs`` is left untouched.
        point_style = {"color": color, "ms": 2.0, "mec": "none", "alpha": 0.1}
        point_style.update(data_kwargs or {})
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **point_style)

    # Compute the contour paths invisibly (alpha=0); they are re-drawn below
    # with ``ax.plot`` after clipping. matplotlib requires increasing contour
    # heights, so sort them: this is a no-op for decreasing ``levels`` and
    # makes increasing ``levels`` (including the default) usable.
    heights = np.sort((1 - np.asarray(levels)) * np.max(H))
    cs = ax.contour(X, Y, H, levels=heights, alpha=0.)

    # NOTE(review): ``ContourSet.collections`` is deprecated in recent
    # matplotlib releases -- confirm against the supported version range.
    has_transform = kde_kwargs.get("transform", None) is not None
    contour_set = []
    for _collection in cs.collections:
        _contour_set = []
        for _path in _collection.get_paths():
            # ``vertices`` is a (2, npoints) view, so the clipping below
            # writes through to the path's own vertex array.
            vertices = _path.vertices.T
            for idx, axis in enumerate(["x", "y"]):
                low = kde_kwargs.get("{}low".format(axis), -np.inf)
                high = kde_kwargs.get("{}high".format(axis), np.inf)
                if has_transform:
                    # The return value is not used by this function; the call
                    # is retained in case the transform acts in place.
                    # NOTE(review): confirm whether the transformed vertices
                    # were meant to be plotted instead.
                    kde_kwargs["transform"](vertices)
                    continue
                row = vertices[idx]
                if low is not None:
                    row[row < low] = low
                if high is not None:
                    row[row > high] = high
            _contour_set.append(vertices)
        contour_set.append(_contour_set)

    # Plot the density map. This can't be plotted at the same time as the
    # contour fills.
    if plot_density:
        density_style = dict(pcolor_kwargs or {})
        density_style["shading"] = "auto"
        ax.pcolor(X, Y, np.max(H) - H, cmap=density_cmap, **density_style)

    # Plot the contour edge colors.
    if plot_contours:
        # ``get`` rather than ``pop``: the same ``contour_kwargs`` dict may be
        # shared between several calls and must keep its "colors" entry.
        colors = contour_kwargs.get("colors", color)
        linestyles = kwargs.pop("linestyles", "-")
        n_levels = len(contour_set)
        styles = [colors, linestyles]
        for num, (prop, default) in enumerate(zip(styles, ['k', '-'])):
            if prop is None:
                styles[num] = [default] * n_levels
            elif isinstance(prop, str):
                styles[num] = [prop] * n_levels
            elif len(prop) < n_levels:
                raise ValueError(
                    "Please provide a color/linestyle for each contour"
                )
        # Attach the label to the first path actually drawn so the legend
        # entry is not lost when the first level happens to have no path.
        pending_label = label
        for idx, _paths in enumerate(contour_set):
            for vertices in _paths:
                ax.plot(
                    *vertices, color=styles[0][idx], label=pending_label,
                    linestyle=styles[1][idx]
                )
                pending_label = None

    _set_xlim(new_fig, ax, range[0])
    _set_ylim(new_fig, ax, range[1])

234 

235 

def corner(
    samples, parameters, bins=20, *,
    # Original corner parameters
    range=None, axes_scale="linear", weights=None, color='k',
    hist_bin_factor=1, smooth=None, smooth1d=None, labels=None,
    label_kwargs=None, titles=None, show_titles=False,
    title_quantiles=None, title_fmt=".2f", title_kwargs=None,
    truths=None, truth_color="#4682b4", scale_hist=False,
    quantiles=None, verbose=False, fig=None, max_n_ticks=5,
    top_ticks=False, use_math_text=False, reverse=False,
    labelpad=0.0, hist_kwargs=None,
    # Arviz parameters
    group="posterior", var_names=None, filter_vars=None,
    coords=None, divergences=False, divergences_kwargs=None,
    labeller=None,
    # New parameters
    kde=None, kde_kwargs=None, kde_2d=None, kde_2d_kwargs=None,
    N=100, **hist2d_kwargs,
):
    """Wrapper for corner.corner which adds additional functionality
    to plot custom KDEs along the leading diagonal and custom 2D
    KDEs in the 2D panels

    None of the dictionaries passed in are modified.

    Parameters
    ----------
    samples : array_like[nsamples, ndim]
        Samples; column ``k`` is indexed as ``samples[:, k]`` and is taken to
        correspond to ``parameters[k]``.
    parameters : list
        Parameter names, one per column. Its length fixes the shape of the
        axes grid.
    bins : int
        Forwarded to ``hist2d`` for the 2D KDE panels only; it is not
        forwarded to ``corner.corner``.
    range, axes_scale, ..., labeller
        All keyword-only arguments from ``range`` to ``labeller`` are
        forwarded unchanged to ``corner.corner``, except that
        ``hist_kwargs["linewidth"]`` is set to 0 when ``kde`` is given.
    kde : func, optional
        1D KDE, called as ``kde(samples[:, k], **kwargs)``; the returned
        callable is evaluated on ``N`` points and drawn on the diagonal.
    kde_kwargs : dict, optional
        kwargs for ``kde``. An entry keyed by a parameter name holds kwargs
        for that parameter only; all other entries are shared and take
        precedence over the per-parameter ones.
    kde_2d : func, optional
        2D KDE passed to ``hist2d`` for every off-diagonal panel.
    kde_2d_kwargs : dict, optional
        kwargs for ``kde_2d``, structured like ``kde_kwargs``. ``low`` and
        ``high`` entries are renamed to ``xlow``/``xhigh`` or
        ``ylow``/``yhigh`` according to the axis the parameter is drawn on.
    N : int
        Number of points at which each 1D KDE is evaluated.
    **hist2d_kwargs
        Forwarded to ``corner.corner`` and, when ``kde_2d`` is given, to
        ``hist2d``.

    Returns
    -------
    The figure returned by ``corner.corner``.
    """
    from corner import corner as _corner

    # Work on private copies so neither a shared default nor the caller's
    # dictionaries are modified by the bookkeeping below.
    hist_kwargs = dict(hist_kwargs or {})
    if kde_kwargs is None:
        kde_kwargs = {}
    if kde_2d_kwargs is None:
        kde_2d_kwargs = {}

    if kde is not None:
        hist_kwargs["linewidth"] = 0.
    if kde_2d is not None:
        linewidths = [1.]
        hist2d_kwargs = hist2d_kwargs.copy()
        if hist2d_kwargs.get("plot_contours", False):
            hidden = dict(hist2d_kwargs.get("contour_kwargs") or {})
            linewidths = hidden.get("linewidths", None)
            hidden["linewidths"] = 0.
            hist2d_kwargs["contour_kwargs"] = hidden
        # Remember what was requested, then switch the histogram-based 2D
        # drawing off in corner.corner; hist2d re-draws it from the KDE.
        plot_density = hist2d_kwargs.get("plot_density", True)
        fill_contours = hist2d_kwargs.get("fill_contours", False)
        plot_contours = hist2d_kwargs.get("plot_contours", True)
        if plot_density:
            hist2d_kwargs["plot_density"] = False
        if fill_contours:
            hist2d_kwargs["fill_contours"] = False
        hist2d_kwargs["plot_contours"] = False

    fig = _corner(
        samples, range=range, axes_scale=axes_scale, weights=weights,
        color=color, hist_bin_factor=hist_bin_factor, smooth=smooth,
        smooth1d=smooth1d, labels=labels, label_kwargs=label_kwargs,
        titles=titles, show_titles=show_titles, title_quantiles=title_quantiles,
        title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths,
        truth_color=truth_color, scale_hist=scale_hist,
        quantiles=quantiles, verbose=verbose, fig=fig,
        max_n_ticks=max_n_ticks, top_ticks=top_ticks,
        use_math_text=use_math_text, reverse=reverse,
        labelpad=labelpad, hist_kwargs=hist_kwargs,
        # Arviz parameters
        group=group, var_names=var_names, filter_vars=filter_vars,
        coords=coords, divergences=divergences,
        divergences_kwargs=divergences_kwargs, labeller=labeller,
        **hist2d_kwargs
    )
    if kde is None and kde_2d is None:
        return fig

    n_params = len(parameters)
    axs = np.array(fig.get_axes(), dtype=object).reshape(n_params, n_params)

    if kde is not None:
        shared_1d = {
            key: item for key, item in kde_kwargs.items()
            if key not in parameters
        }
        for num, param in enumerate(parameters):
            # Per-parameter kwargs first, shared kwargs override them.
            _kwargs = dict(kde_kwargs.get(param, {}))
            _kwargs.update(shared_1d)
            column = samples[:, num]
            _kde = kde(column, **_kwargs)
            xs = np.linspace(np.min(column), np.max(column), N)
            axs[num, num].plot(xs, _kde(xs), color=color)

    if kde_2d is not None:
        _hist2d_kwargs = hist2d_kwargs.copy()
        _contour_kwargs = dict(hist2d_kwargs.get("contour_kwargs") or {})
        _contour_kwargs["linewidths"] = linewidths
        levels = hist2d_kwargs.get("levels")
        if levels is None:
            # Same default as hist2d, so omitting ``levels`` no longer fails.
            levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)
        _hist2d_kwargs.update(
            {
                "plot_contours": plot_contours,
                "plot_density": plot_density,
                "fill_contours": fill_contours,
                # hist2d draws level ``l`` at ``(1 - l) * max``; reversing
                # increasing levels yields increasing contour heights.
                "levels": levels[::-1],
                "contour_kwargs": _contour_kwargs
            }
        )
        for i, x in enumerate(parameters):
            for j, y in enumerate(parameters):
                if j >= i:
                    continue
                # Fresh copy per panel: the renaming below pops "low"/"high",
                # which on the caller's dicts would leave later panels with
                # bounds attached to the wrong axis (or missing altogether).
                panel_kwargs = {
                    key: dict(item) if isinstance(item, dict) else item
                    for key, item in kde_2d_kwargs.items()
                }
                _kde_2d_kwargs = {}
                # ``x`` (row parameter) is drawn on the panel's y-axis and
                # ``y`` (column parameter) on its x-axis.
                for param, axis in ((x, "y"), (y, "x")):
                    axis_kwargs = panel_kwargs.get(param, panel_kwargs)
                    for bound in ("low", "high"):
                        if bound in axis_kwargs:
                            axis_kwargs[axis + bound] = axis_kwargs.pop(bound)
                    _kde_2d_kwargs.update(axis_kwargs)
                for key, item in panel_kwargs.items():
                    if key not in parameters:
                        _kde_2d_kwargs[key] = item
                hist2d(
                    samples[:, j], samples[:, i],
                    ax=axs[i, j], color=color,
                    kde=kde_2d, kde_kwargs=_kde_2d_kwargs,
                    bins=bins, **_hist2d_kwargs
                )
    return fig

353 ) 

354 return fig