Coverage for pesummary/utils/dict.py: 84.7%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4import copy 

5 

6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

7 

8 

def paths_to_key(key, dictionary, current_path=None):
    """Yield every path to a key stored in a nested dictionary

    Parameters
    ----------
    key: str
        the key that you would like to find
    dictionary: dict
        the nested dictionary that has the key stored somewhere within it
    current_path: list, optional
        list of keys leading to ``dictionary``. Used internally when
        recursing; leave as None when calling directly

    Yields
    ------
    list
        the list of keys leading to each occurrence of ``key``. Once ``key``
        is found, the value stored under it is not searched any further
    """
    if current_path is None:
        current_path = []

    for k, v in dictionary.items():
        if k == key:
            yield current_path + [key]
        elif isinstance(v, dict):
            yield from paths_to_key(key, v, current_path + [k])

32 

33 

def convert_value_to_string(dictionary):
    """Convert every non-dictionary value in a nested dictionary to a string

    Nested dictionaries are walked recursively and every other value is
    replaced, in place, by ``str(value)``.

    Parameters
    ----------
    dictionary: dict
        nested dictionary whose values should be converted

    Returns
    -------
    dict
        the same (now modified) dictionary object that was passed in
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_value_to_string(value)
        else:
            # reassigning an existing key does not change the dictionary's
            # size, so this is safe while iterating over items()
            dictionary[key] = str(value)
    return dictionary

48 

49 

def convert_list_to_item(dictionary):
    """Collapse single-element sequences in a nested dictionary to their item

    Nested dictionaries are walked recursively. Any list, numpy array or
    pesummary Array holding exactly one element is replaced, in place, by
    that element; a lone ``bytes`` element is decoded as UTF-8 first.

    Parameters
    ----------
    dictionary: dict
        nested dictionary with nested lists

    Returns
    -------
    dict
        the same (now modified) dictionary object that was passed in
    """
    from pesummary.utils.array import Array

    sequence_types = (list, np.ndarray, Array)
    for name, entry in dictionary.items():
        if isinstance(entry, dict):
            convert_list_to_item(entry)
            continue
        if not isinstance(entry, sequence_types) or len(entry) != 1:
            continue
        item = entry[0]
        if isinstance(item, bytes):
            item = item.decode("utf-8")
        dictionary.update({name: item})
    return dictionary

70 

71 

def load_recursively(key, dictionary):
    """Yield an entry in a nested dictionary for a key of format 'a/b/c/d'

    Parameters
    ----------
    key: str, int, float or list
        key of format 'a/b/c/d', a single numeric key, or a list of keys
        such as ['a', 'b', 'c', 'd']
    dictionary: dict
        the dictionary that has the key stored

    Yields
    ------
    the entry stored under the final key. If that entry is a dictionary,
    it is first passed through convert_list_to_item (which modifies it in
    place)

    Raises
    ------
    KeyError
        if an intermediate key is missing
    """
    if isinstance(key, str):
        # 'a/b/c' -> ['a', 'b', 'c']; a key without '/' becomes ['a']
        key = key.split("/")
    elif isinstance(key, (int, float)):
        # numeric keys previously hit `"/" in key` and raised TypeError
        key = [key]
    # NOTE(review): the final component is looked up at the current level
    # before descending, so e.g. 'a/b' resolves to dictionary['b'] when that
    # exists -- confirm this shortcut is intended
    if key[-1] in dictionary.keys():
        try:
            converted_dictionary = convert_list_to_item(
                dictionary[key[-1]]
            )
            yield converted_dictionary
        except AttributeError:
            # the entry is not dictionary-like (e.g. has no .items()), so
            # return it unchanged
            yield dictionary[key[-1]]
    else:
        old, new = key[0], key[1:]
        yield from load_recursively(new, dictionary[old])

98 

99 

def edit_dictionary(dictionary, path, value):
    """Return a copy of a nested dictionary with one entry replaced

    Only the containers along ``path`` are (shallow) copied, so the input
    dictionary and its nested containers are left unmodified, while values
    not on ``path`` are shared between the input and the returned copy.

    Parameters
    ----------
    dictionary: dict
        the nested dictionary that you would like to edit
    path: list
        the path to the key that you would like to edit
    value:
        the replacement

    Returns
    -------
    dict
        the edited copy of ``dictionary``
    """
    edit = dictionary.copy()
    level = edit
    for key in path[:-1]:
        # copy each container on the path before descending; a plain
        # top-level copy would let the assignment below write into the
        # caller's nested objects
        level[key] = copy.copy(level[key])
        level = level[key]
    level[path[-1]] = value
    return edit

118 

119 

class Dict(dict):
    """Base nested dictionary class.

    Parameters
    ----------
    *args: dict or (list, samples)
        either a single dictionary mapping parameter names to samples, or
        a list of parameter names followed by the corresponding samples
    value_class: func, optional
        Class you wish to use for the nested dictionary. Default np.array
    value_columns: list, optional
        Names for each column in value_class to be stored as properties
    make_dict_kwargs: dict, optional
        kwargs passed to the make_dictionary method
    logger_warn: str, optional
        name of the logger method used to emit warnings. Default 'warn'
    latex_labels: dict, optional
        dictionary mapping parameter names to latex labels. Parameters
        without an entry use their own name as the label
    extra_kwargs: dict, optional
        additional information stored in the extra_kwargs attribute
    deconstruct_complex_columns: bool, optional
        if True, any columns containing complex values will be deconstructed
        into their real (np.real), amplitude (np.abs) and angle (np.angle)
        components. Default True
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property
    """
    def __init__(
        self, *args, value_class=np.array, value_columns=None, _init=True,
        make_dict_kwargs=None, logger_warn="warn", latex_labels=None,
        extra_kwargs=None, deconstruct_complex_columns=True, **kwargs
    ):
        super(Dict, self).__init__()
        if not _init:
            return
        from .parameters import Parameters
        from .utils import logger

        # build fresh containers per instance: shared mutable defaults would
        # let e.g. a mutation of self.extra_kwargs leak into other instances
        if make_dict_kwargs is None:
            make_dict_kwargs = {}
        if latex_labels is None:
            latex_labels = {}
        if extra_kwargs is None:
            extra_kwargs = {}
        self.logger_warn = logger_warn
        self.all_latex_labels = latex_labels
        if isinstance(args[0], dict):
            if args[0].__class__.__name__ == "SamplesDict":
                self.parameters = list(args[0].keys(remove_debug=False))
                _iterator = args[0].items(remove_debug=False)
            else:
                self.parameters = list(args[0].keys())
                _iterator = args[0].items()
            _samples = [args[0][param] for param in self.parameters]
            self.samples = _samples
        else:
            parameters, self.samples = args
            # copy so that appending deconstructed complex parameters below
            # does not mutate the caller's list
            self.parameters = list(parameters)
            _iterator = zip(self.parameters, self.samples)
        try:
            _samples = copy.deepcopy(self.samples)
            # NOTE(review): the result of this call is unused. It is kept
            # because it raises for samples numpy cannot convert to an array
            # (e.g. ragged rows on recent numpy versions), which skips the
            # complex handling below
            np.iscomplex(_samples)
            if deconstruct_complex_columns:
                # if a fraction are complex and others are not convert
                # the complex parameters. A shallow snapshot is enough to
                # iterate over: rows are only read, and the rows appended to
                # _samples below must not be visited
                for num, ss in enumerate(list(_samples)):
                    if np.iscomplex(ss).any():
                        getattr(logger, self.logger_warn)(
                            f"Deconstructing {self.parameters[num]} as it contains "
                            f"complex numbers. To disable this pass "
                            f"deconstruct_complex_columns=False."
                        )
                        _param = self.parameters[num] + "_abs"
                        _ss = np.abs(ss)
                        _samples.append(_ss)
                        self.parameters.append(_param)

                        _param = self.parameters[num] + "_angle"
                        _ss = np.angle(ss)
                        _samples.append(_ss)
                        self.parameters.append(_param)

                        _ss = np.real(ss)
                        _samples[num] = _ss
                self.samples = np.array(_samples)
                _iterator = zip(self.parameters, self.samples)
            else:
                self.samples = np.array(_samples)
        except Exception:
            # best effort: on failure the samples are used as provided
            # (NOTE: self.parameters may already have been extended with
            # _abs/_angle entries before the failure)
            pass

        try:
            self.make_dictionary(**make_dict_kwargs)
        except (TypeError, IndexError):
            # fall back to building the dictionary directly from the
            # (parameter, samples) pairs
            for key, item in _iterator:
                try:
                    self[key] = value_class(item)
                except Exception:
                    self[key] = value_class(*item)

        if value_columns is not None:
            for key in self.keys():
                if len(value_columns) == self[key].shape[1]:
                    for num, col in enumerate(value_columns):
                        setattr(self[key], col, np.array(self[key].T[num]))
        for key, item in kwargs.items():
            setattr(self, key, item)
        self._update_latex_labels()
        self.extra_kwargs = extra_kwargs
        self.parameters = Parameters(self.parameters)

    def __getitem__(self, key):
        """Return the value stored under ``key``

        If ``key`` is a list, a deep copy of this object restricted to the
        requested keys is returned instead; a warning is issued when only
        some of the requested keys are available.

        Raises
        ------
        KeyError
            if a string key, or every key in a list, is not available
        """
        if isinstance(key, list):
            allowed = [_key for _key in key if _key in self.keys()]
            remove = [_key for _key in self.keys() if _key not in allowed]
            if len(allowed):
                if len(allowed) != len(key):
                    import warnings
                    warnings.warn(
                        "Only returning a dict with keys: {} as not all keys "
                        "are in the {} class".format(
                            ", ".join(map(str, allowed)),
                            self.__class__.__name__
                        )
                    )
                _self = copy.deepcopy(self)
                for _key in remove:
                    _self.pop(_key)
                return _self
            raise KeyError(
                "The keys: {} are not available in {}. The list of "
                "available keys are: {}".format(
                    ", ".join(map(str, key)), self.__class__.__name__,
                    ", ".join(map(str, self.keys()))
                )
            )
        elif isinstance(key, str):
            if key not in self.keys():
                raise KeyError(
                    "{} not in {}. The list of available keys are {}".format(
                        key, self.__class__.__name__,
                        ", ".join(map(str, self.keys()))
                    )
                )
        return super(Dict, self).__getitem__(key)

    @property
    def latex_labels(self):
        """Dictionary mapping each parameter to its latex label"""
        return self._latex_labels

    @property
    def plotting_map(self):
        """Mapping of plot name to plotting function (empty for this class)"""
        return {}

    @property
    def available_plots(self):
        """List of plot names accepted by the plot method"""
        return list(self.plotting_map.keys())

    def _update_latex_labels(self):
        """Update the stored latex labels

        Parameters without an entry in all_latex_labels use their own name.
        """
        self._latex_labels = {
            param: self.all_latex_labels[param] if param in
            self.all_latex_labels.keys() else param for param in self.parameters
        }

    def plot(self, *args, type="", **kwargs):
        """Generate a plot for data stored in Dict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        **kwargs: dict
            all additional kwargs are passed to the plotting function

        Raises
        ------
        NotImplementedError
            if ``type`` is not one of available_plots
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )
        return self.plotting_map[type](*args, **kwargs)

    def make_dictionary(self, *args, **kwargs):
        """Add the parameters and samples to the class

        The base implementation always raises TypeError, which makes
        __init__ fall back to filling the dictionary from the
        (parameter, samples) pairs. Presumably overridden by subclasses.
        """
        raise TypeError