Coverage for pesummary/core/plots/interactive.py: 65.1%

43 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import plotly.graph_objects as go 

4import plotly 

5 

6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

7 

8 

def write_to_html(fig, filename):
    """Write a plotly.graph_objects.Figure to a html file

    The file contains a ``<script>`` tag loading MathJax (so that LaTeX in
    labels can be rendered) followed by the ``<div>`` produced by
    ``plotly.offline.plot``. That div also loads plotly.js from the plotly
    CDN.

    Parameters
    ----------
    fig: plotly.graph_objects.Figure object
        figure containing the plot that you wish to save to html
    filename: str
        name of the file that you wish to write the figure to
    """
    # ``include_plotlyjs="cdn"`` makes plotly emit a <script> tag pointing at
    # the plotly.js build that matches the installed plotly.py. The previously
    # hard-coded ``plotly-latest.min.js`` URL is frozen at plotly.js v1.58.5
    # and may not render figures generated by newer plotly.py releases.
    div = plotly.offline.plot(fig, include_plotlyjs="cdn", output_type="div")
    # MathJax is placed before the plot div so it is available when the
    # figure is rendered.
    mathjax = (
        "<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/"
        "MathJax.js?config=TeX-MML-AM_SVG'></script>\n"
    )
    # Explicit encoding: labels may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(mathjax + div)

28 

29 

def histogram2d(
    x, y, xlabel='x', ylabel='y', contour=False, contour_color='Blues',
    marker_color='rgba(248,148,6,1)', dimensions=None,
    write_to_html_file="interactive_2d_histogram.html", showgrid=False,
    showlegend=False
):
    """Build an interactive 2d histogram plot

    The central panel shows a scatter of ``(x, y)``, optionally with a 2d
    histogram contour underneath. Marginal histograms of ``x`` and ``y`` are
    drawn along the top and right edges respectively.

    Parameters
    ----------
    x: np.ndarray
        An array containing the x coordinates of the points to be histogrammed
    y: np.ndarray
        An array containing the y coordinates of the points to be histogrammed
    xlabel: str
        The label for the x coordinates
    ylabel: str
        The label for the y coordinates
    contour: Bool
        Whether or not to show contours on the scatter plot
    contour_color: str
        Name of the plotly colorscale to use for the contour colors. The
        colorscale is reversed
    marker_color: str
        Color to use for the markers and the marginal histograms
    dimensions: dict, optional
        A dictionary with keys 'width' and 'height' giving the size of the
        figure. Missing keys default to width=900 and height=900
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is returned instead of being written
    showgrid: Bool
        Whether or not to show a grid on the plot
    showlegend: Bool
        Whether or not to add a legend to the plot

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure when ``write_to_html_file`` is None, otherwise None (the
        figure is written to disk instead)
    """
    # ``None`` rather than a dict literal avoids sharing one mutable default
    # object between calls; partial dictionaries are completed from defaults
    dimensions = {'width': 900, 'height': 900, **(dimensions or {})}
    fig = go.Figure()
    if contour:
        fig.add_trace(
            go.Histogram2dContour(
                x=x, y=y, colorscale=contour_color, reversescale=True,
                xaxis='x', yaxis='y', histnorm="probability density"
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x, y=y, xaxis='x', yaxis='y', mode='markers',
            marker=dict(color=marker_color, size=3)
        )
    )
    # Marginal of y: plotted on the second x axis (right-hand strip) while
    # sharing the main y axis
    fig.add_trace(
        go.Histogram(
            y=y, xaxis='x2', marker=dict(color=marker_color),
            histnorm="probability density"
        )
    )
    # Marginal of x: plotted on the second y axis (top strip) while sharing
    # the main x axis
    fig.add_trace(
        go.Histogram(
            x=x, yaxis='y2', marker=dict(color=marker_color),
            histnorm="probability density"
        )
    )

    fig.update_layout(
        autosize=False,
        # main panel occupies [0, 0.85] of each direction; the marginals
        # occupy the remaining [0.85, 1] strips
        xaxis=dict(
            zeroline=False, domain=[0, 0.85], showgrid=showgrid
        ),
        yaxis=dict(
            zeroline=False, domain=[0, 0.85], showgrid=showgrid
        ),
        xaxis2=dict(
            zeroline=False, domain=[0.85, 1], showgrid=showgrid
        ),
        yaxis2=dict(
            zeroline=False, domain=[0.85, 1], showgrid=showgrid
        ),
        height=dimensions["height"],
        width=dimensions["width"],
        bargap=0,
        hovermode='closest',
        showlegend=showlegend,
        font=dict(
            size=10
        ),
        xaxis_title=xlabel,
        yaxis_title=ylabel,
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig

119 

120 

def ridgeline(
    data, labels, xlabel='x', palette='colorblind', colors=None, width=3,
    write_to_html_file="interactive_ridgeline.html", showlegend=False,
    dimensions=None
):
    """Build an interactive ridgeline plot

    Each entry of ``data`` is drawn as a horizontal, one-sided violin with
    no individual points shown.

    Parameters
    ----------
    data: list, np.ndarray
        Sequence of sample arrays, one per distribution. ``data[i]`` is
        plotted with label ``labels[i]`` and color ``colors[i]``
    labels: list
        List of labels corresponding to each set of samples
    xlabel: str
        The label for the x coordinates
    palette: str
        Name of the seaborn color palette used to generate one color per
        distribution. Only used when ``colors`` is None
    colors: list, optional
        List of colors to use for the different posterior distributions.
        If None, ``len(data)`` colors are taken from ``palette``
    width: float
        Width of the violin plots
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is returned instead of being written
    showlegend: Bool
        Whether or not to add a legend to the plot
    dimensions: dict, optional
        A dictionary with keys 'width' and 'height' giving the size of the
        figure. Missing keys default to width=1100 and height=700

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure when ``write_to_html_file`` is None, otherwise None (the
        figure is written to disk instead)

    Notes
    -----
    ``data``, ``labels`` and ``colors`` are iterated together with ``zip``;
    if their lengths differ, the surplus entries of the longer sequences are
    silently ignored.
    """
    # ``None`` rather than a dict literal avoids sharing one mutable default
    # object between calls; partial dictionaries are completed from defaults
    dimensions = {'width': 1100, 'height': 700, **(dimensions or {})}
    fig = go.Figure()
    if colors is None:
        # seaborn is only needed for the default colors, so import lazily
        import seaborn

        colors = seaborn.color_palette(
            palette=palette, n_colors=len(data)
        ).as_hex()

    for dd, label, color in zip(data, labels, colors):
        fig.add_trace(go.Violin(x=dd, line_color=color, name=label))

    # horizontal, one-sided violins without sample points give the
    # "ridgeline" look
    fig.update_traces(
        orientation='h', side='positive', width=width, points=False
    )
    fig.update_layout(
        xaxis_showgrid=False, xaxis_zeroline=False, xaxis_title=xlabel,
        width=dimensions["width"], height=dimensions["height"],
        font=dict(size=18), showlegend=showlegend
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig

175 

176 

def corner(
    data, labels, dimensions=None, show_diagonal=False, colors=None,
    show_upper_half=False, write_to_html_file="interactive_corner.html"
):
    """Build an interactive corner plot

    The plot is a plotly scatter-plot matrix (``go.Splom``) with box
    selection enabled, so that selected samples are highlighted across
    every panel.

    Parameters
    ----------
    data: list, np.ndarray
        The samples you wish to produce a corner plot for. The zeroth axis
        indexes the dimensions of the space: ``data[i]`` holds the samples
        for the dimension named ``labels[i]``
    labels: list, np.ndarray
        A list of names for each dimension
    dimensions: dict, optional
        A dictionary with keys 'width' and 'height' giving the size of the
        figure. Missing keys default to width=900 and height=900
    show_diagonal: Bool
        Whether or not to show the diagonal scatter plots
    colors: dict, optional
        A dictionary of colors for the individual samples with keys
        'selected' and 'not_selected', giving the colors to be used when
        the markers are selected and not selected respectively. Missing keys
        default to 'rgba(248,148,6,1)' (selected) and 'rgba(0,0,0,1)'
        (not selected)
    show_upper_half: Bool
        Whether or not to show the upper half of scatter plots
    write_to_html_file: str, optional
        Name of the html file you wish to write the figure to. If None, the
        figure is returned instead of being written

    Returns
    -------
    fig: plotly.graph_objects.Figure or None
        The figure when ``write_to_html_file`` is None, otherwise None (the
        figure is written to disk instead)
    """
    # ``None`` rather than dict literals avoids sharing mutable default
    # objects between calls; partial dictionaries are completed from defaults
    dimensions = {'width': 900, 'height': 900, **(dimensions or {})}
    colors = {
        'selected': 'rgba(248,148,6,1)', 'not_selected': 'rgba(0,0,0,1)',
        **(colors or {})
    }
    # one Splom dimension per (label, samples) pair
    data_structure = [
        dict(label=label, values=values)
        for label, values in zip(labels, data)
    ]
    fig = go.Figure(
        data=go.Splom(
            dimensions=data_structure,
            marker=dict(
                color=colors["not_selected"], showscale=False,
                line_color='white', line_width=0.5,
                size=3
            ),
            selected=dict(marker=dict(color=colors["selected"])),
            diagonal_visible=show_diagonal,
            showupperhalf=show_upper_half,
        )
    )
    fig.update_layout(
        # default drag action is box selection so that the 'selected'
        # marker color can be applied interactively
        dragmode='select',
        width=dimensions["width"],
        height=dimensions["height"],
        hovermode='closest',
        font=dict(
            size=10
        )
    )
    if write_to_html_file is not None:
        write_to_html(fig, write_to_html_file)
        return
    return fig