Coverage for pesummary/core/plots/interactive.py: 63.8%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3from pesummary.utils.utils import logger, import_error_msg 

4try: 

5 import plotly.graph_objects as go 

6 import plotly 

7except ImportError: 

8 logger.warning(import_error_msg.format("plotly")) 

9 

# Module authorship metadata
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]

11 

12 

def write_to_html(fig, filename):
    """Write a plotly.graph_objects.Figure to a html file

    The file contains a script tag loading MathJax (so that LaTeX labels can
    be rendered) followed by the plotly ``div`` for the figure. plotly.js is
    loaded from the plotly CDN by a script tag that plotly places inside the
    ``div`` (``include_plotlyjs='cdn'``), so the plotly.js version matches the
    installed plotly python package.

    Parameters
    ----------
    fig: plotly.graph_objects.Figure
        figure containing the plot that you wish to save to html
    filename: str
        name of the file that you wish to write the figure to
    """
    # The previously hard-coded 'plotly-latest.min.js' bundle is frozen at
    # plotly.js v1.x and is no longer updated, so it can fail to render figures
    # produced by newer plotly python releases. 'cdn' emits a versioned
    # <script> tag inside the returned div instead.
    div = plotly.offline.plot(fig, include_plotlyjs='cdn', output_type='div')
    mathjax = (
        "<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/"
        "MathJax.js?config=TeX-MML-AM_SVG'></script>\n"
    )
    # MathJax is placed before the div (and therefore before plotly.js) so it
    # is available when the figure is drawn.
    html = mathjax + div
    # Explicit utf-8 so labels containing non-ascii characters are written the
    # same way on every platform.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(html)

32 

33 

def histogram2d(
    x, y, xlabel='x', ylabel='y', contour=False, contour_color='Blues',
    marker_color='rgba(248,148,6,1)', dimensions=None,
    write_to_html_file="interactive_2d_histogram.html", showgrid=False,
    showlegend=False
):
    """Build an interactive 2d histogram plot

    The figure shows a scatter plot of ``(x, y)`` (optionally overlaid with a
    2d density contour) with marginal histograms of ``x`` above and ``y`` to
    the right.

    Parameters
    ----------
    x: np.ndarray
        An array containing the x coordinates of the points to be histogrammed
    y: np.ndarray
        An array containing the y coordinates of the points to be histogrammed
    xlabel: str
        The label for the x coordinates
    ylabel: str
        The label for the y coordinates
    contour: Bool
        Whether or not to show contours on the scatter plot
    contour_color: str
        Name of the plotly colorscale to use for the contour colors
    marker_color: str
        Color to use for the markers and the marginal histograms
    dimensions: dict, optional
        A dictionary giving the 'width' and 'height' of the figure. Missing
        keys default to 900. Default {'width': 900, 'height': 900}
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is not written and is returned instead
    showgrid: Bool
        Whether or not to show a grid on the plot
    showlegend: Bool
        Whether or not to add a legend to the plot

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure if ``write_to_html_file`` is None, otherwise None
    """
    # Use a fresh dict per call rather than a shared mutable default, and
    # allow callers to override only one of width/height.
    dims = {'width': 900, 'height': 900}
    if dimensions is not None:
        dims.update(dimensions)

    fig = go.Figure()
    if contour:
        fig.add_trace(
            go.Histogram2dContour(
                x=x, y=y, colorscale=contour_color, reversescale=True,
                xaxis='x', yaxis='y', histnorm="probability density"
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x, y=y, xaxis='x', yaxis='y', mode='markers',
            marker=dict(color=marker_color, size=3)
        )
    )
    # Marginal of y: drawn on the right-hand panel (xaxis2), sharing the y axis
    fig.add_trace(
        go.Histogram(
            y=y, xaxis='x2', marker=dict(color=marker_color),
            histnorm="probability density"
        )
    )
    # Marginal of x: drawn on the top panel (yaxis2), sharing the x axis
    fig.add_trace(
        go.Histogram(
            x=x, yaxis='y2', marker=dict(color=marker_color),
            histnorm="probability density"
        )
    )

    fig.update_layout(
        autosize=False,
        xaxis=dict(
            zeroline=False, domain=[0, 0.85], showgrid=showgrid
        ),
        yaxis=dict(
            zeroline=False, domain=[0, 0.85], showgrid=showgrid
        ),
        xaxis2=dict(
            zeroline=False, domain=[0.85, 1], showgrid=showgrid
        ),
        yaxis2=dict(
            zeroline=False, domain=[0.85, 1], showgrid=showgrid
        ),
        height=dims["height"],
        width=dims["width"],
        bargap=0,
        hovermode='closest',
        showlegend=showlegend,
        font=dict(
            size=10
        ),
        xaxis_title=xlabel,
        yaxis_title=ylabel,
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig

123 

124 

def ridgeline(
    data, labels, xlabel='x', palette='colorblind', colors=None, width=3,
    write_to_html_file="interactive_ridgeline.html", showlegend=False,
    dimensions=None
):
    """Build an interactive ridgeline plot

    One horizontal, one-sided violin is drawn for each set of samples.

    Parameters
    ----------
    data: list, np.ndarray
        The samples you wish to produce a ridgeline plot for. ``data[i]`` is
        the set of samples drawn as the i-th ridge
    labels: list
        List of labels corresponding to each set of samples
    xlabel: str
        The label for the x coordinates
    palette: str
        Name of the color palette (passed to
        ``pesummary.core.plots.palette.color_palette``) used to pick colors
        for the different distributions when ``colors`` is None
    colors: list, optional
        List of colors to use for the different distributions
    width: float
        Width of the violin plots
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is not written and is returned instead
    showlegend: Bool
        Whether or not to add a legend to the plot
    dimensions: dict, optional
        A dictionary giving the 'width' and 'height' of the figure. Missing
        keys take the defaults. Default {'width': 1100, 'height': 700}

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure if ``write_to_html_file`` is None, otherwise None

    Notes
    -----
    ``data``, ``labels`` and ``colors`` are combined with ``zip``, so only as
    many ridges are drawn as the shortest of the three.
    """
    # Use a fresh dict per call rather than a shared mutable default, and
    # allow callers to override only one of width/height.
    dims = {'width': 1100, 'height': 700}
    if dimensions is not None:
        dims.update(dimensions)

    fig = go.Figure()
    if colors is None:
        from pesummary.core.plots.palette import color_palette

        colors = color_palette(
            palette=palette, n_colors=len(data)
        ).as_hex()

    for samples, label, color in zip(data, labels, colors):
        fig.add_trace(go.Violin(x=samples, line_color=color, name=label))

    # Horizontal, one-sided violins without individual points give the
    # ridgeline appearance
    fig.update_traces(
        orientation='h', side='positive', width=width, points=False
    )
    fig.update_layout(
        xaxis_showgrid=False, xaxis_zeroline=False, xaxis_title=xlabel,
        width=dims["width"], height=dims["height"],
        font=dict(size=18), showlegend=showlegend
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig

179 

180 

def corner(
    data, labels, dimensions=None, show_diagonal=False, colors=None,
    show_upper_half=False, write_to_html_file="interactive_corner.html"
):
    """Build an interactive corner plot

    The plot is a plotly scatter-plot matrix (``go.Splom``) with box-select
    enabled so that points can be highlighted across all panels.

    Parameters
    ----------
    data: list, np.ndarray
        The samples you wish to produce a corner plot for. ``data[i]`` holds
        the samples for the dimension named ``labels[i]``, i.e. the zeroth
        axis indexes the dimensions of the space and the next axis indexes
        the samples
    labels: list, np.ndarray
        A list of names for each dimension
    dimensions: dict, optional
        A dictionary giving the 'width' and 'height' of the figure. Missing
        keys default to 900. Default {'width': 900, 'height': 900}
    show_diagonal: Bool
        Whether or not to show the diagonal scatter plots
    colors: dict, optional
        A dictionary of colors for the individual samples. The dictionary may
        have keys 'selected' and 'not_selected' to indicate the colors to be
        used when the markers are selected and not selected respectively.
        Missing keys default to 'rgba(248,148,6,1)' (selected) and
        'rgba(0,0,0,1)' (not_selected)
    show_upper_half: Bool
        Whether or not to show the upper half of scatter plots
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is not written and is returned instead

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure if ``write_to_html_file`` is None, otherwise None
    """
    # Fresh dicts per call rather than shared mutable defaults; partial
    # overrides are filled in from the defaults instead of raising KeyError.
    dims = {'width': 900, 'height': 900}
    if dimensions is not None:
        dims.update(dimensions)
    marker_colors = {
        'selected': 'rgba(248,148,6,1)', 'not_selected': 'rgba(0,0,0,1)'
    }
    if colors is not None:
        marker_colors.update(colors)

    # One Splom "dimension" per label, pairing each name with its samples
    data_structure = [
        dict(label=label, values=value) for label, value in zip(labels, data)
    ]
    fig = go.Figure(
        data=go.Splom(
            dimensions=data_structure,
            marker=dict(
                color=marker_colors["not_selected"], showscale=False,
                line_color='white', line_width=0.5,
                size=3
            ),
            selected=dict(marker=dict(color=marker_colors["selected"])),
            diagonal_visible=show_diagonal,
            showupperhalf=show_upper_half,
        )
    )
    fig.update_layout(
        dragmode='select',
        width=dims["width"],
        height=dims["height"],
        hovermode='closest',
        font=dict(
            size=10
        )
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig