Coverage for pesummary/utils/dict.py: 95.9%

123 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4import copy 

5 

6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

7 

8 

def paths_to_key(key, dictionary, current_path=None):
    """Return the path to a key stored in a nested dictionary

    This is a generator: one path is yielded for every occurrence of ``key``.

    Parameters
    ----------
    key: str
        the key that you would like to find
    dictionary: dict
        the nested dictionary that has the key stored somewhere within it
    current_path: list, optional
        the keys leading from the top-level dictionary to ``dictionary``.
        It is prepended to every yielded path and is never modified.
        Default is an empty list

    Yields
    ------
    list
        the keys leading to an occurrence of ``key``, ending with ``key``
        itself. Paths are produced in dictionary iteration order. The value
        stored under a matching key is not searched any deeper
    """
    if current_path is None:
        current_path = []

    for k, v in dictionary.items():
        if k == key:
            yield current_path + [key]
        elif isinstance(v, dict):
            # Build a new list for the child path so the caller's
            # ``current_path`` is never mutated.
            yield from paths_to_key(key, v, current_path + [k])

32 

33 

def convert_value_to_string(dictionary):
    """Convert every non-dictionary value in a nested dictionary to a string

    The dictionary is edited in place. Nested dictionaries are walked
    recursively (they are kept as dictionaries). Every other value is
    replaced by ``str(value)``.

    Parameters
    ----------
    dictionary: dict
        nested dictionary whose leaf values should be converted

    Returns
    -------
    dict
        the same ``dictionary`` object that was passed in, now holding
        string leaf values
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            # nested dictionaries are converted in place, not stringified
            convert_value_to_string(value)
        else:
            # Replacing the value of an existing key does not change the
            # dict size, so this is safe while iterating over ``items()``.
            dictionary.update({key: str(value)})
    return dictionary

48 

49 

def convert_list_to_item(dictionary):
    """Convert all nested lists of a single value to an item

    The dictionary is edited in place. Nested dictionaries are walked
    recursively. Every ``list``, ``numpy.ndarray`` or ``Array`` holding
    exactly one element is replaced by that element; a ``bytes`` element is
    decoded as UTF-8 first. All other values, including 0-d arrays, are
    left untouched.

    Parameters
    ----------
    dictionary: dict
        nested dictionary with nested lists

    Returns
    -------
    dict
        the same ``dictionary`` object that was passed in

    Raises
    ------
    AttributeError
        if ``dictionary`` has no ``items`` method (i.e. it is not a mapping)
    """
    from pesummary.utils.array import Array

    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_list_to_item(value)
            continue
        if not isinstance(value, (list, np.ndarray, Array)):
            continue
        if getattr(value, "ndim", None) == 0:
            # 0-d arrays have no len(). Calling len() on them used to raise
            # TypeError, so leave them as they are.
            continue
        if len(value) == 1:
            item = value[0]
            if isinstance(item, bytes):
                item = item.decode("utf-8")
            dictionary.update({key: item})
    return dictionary

70 

71 

def load_recursively(key, dictionary):
    """Return an entry in a nested dictionary for a key of format 'a/b/c/d'

    This is a generator that yields the located entry once.

    Parameters
    ----------
    key: str, float or list
        key of format 'a/b/c/d'. A single float key, or a list of path
        components (as used by the recursive calls), is also accepted
    dictionary: dict
        the dictionary that has the key stored

    Yields
    ------
    The entry stored under the last component of ``key``. If that entry is
    a dictionary, ``convert_list_to_item`` is applied to it first, which
    edits it in place. Any other entry is yielded unchanged.

    Raises
    ------
    KeyError
        if a component needed to descend into ``dictionary`` is missing
    """
    # Only strings can be split into path components. Testing ``"/" in key``
    # on a float raised TypeError before the float branch below was reached.
    if isinstance(key, str) and "/" in key:
        key = key.split("/")
    if isinstance(key, (str, float)):
        key = [key]
    last = key[-1]
    # NOTE: the last component is looked for at every level before
    # descending. For example, 'a/b/c' resolves to dictionary['c'] if 'c'
    # exists at the top level.
    if last in dictionary.keys():
        entry = dictionary[last]
        try:
            converted = convert_list_to_item(entry)
        except AttributeError:
            # the entry is not a mapping (it has no ``.items()``)
            yield entry
        else:
            yield converted
    else:
        head, rest = key[0], key[1:]
        yield from load_recursively(rest, dictionary[head])

98 

99 

def edit_dictionary(dictionary, path, value):
    """Replace an entry in a nested dictionary

    The input is not modified; an edited copy is returned. Only the
    containers along ``path`` are copied (shallowly), so branches that are
    not edited are shared with the input.

    Parameters
    ----------
    dictionary: dict
        the nested dictionary that you would like to edit
    path: list
        the path to the key that you would like to edit
    value:
        the replacement

    Returns
    -------
    dict
        a copy of ``dictionary`` with the entry at ``path`` set to ``value``

    Raises
    ------
    KeyError
        if an intermediate key in ``path`` does not exist (for dict levels)
    IndexError
        if ``path`` is empty
    """
    edit = dictionary.copy()
    node = edit
    for step in path[:-1]:
        # Copy each container on the way down. Previously only the top level
        # was copied, so editing a nested path also modified ``dictionary``.
        node[step] = copy.copy(node[step])
        node = node[step]
    node[path[-1]] = value
    return edit

118 

119 

class Dict(dict):
    """Base nested dictionary class.

    Parameters
    ----------
    *args: tuple
        either a single dictionary mapping parameter names to samples, or
        two positional arguments ``parameters, samples``
    value_class: func, optional
        Class you wish to use for the nested dictionary. It is only used
        when ``make_dictionary`` raises ``TypeError``/``IndexError``, which
        the base implementation always does. Default ``np.array``
    value_columns: list, optional
        Names for each column in value_class to be stored as properties
    _init: bool, optional
        if False, create an empty dictionary and skip all other setup
    make_dict_kwargs: dict, optional
        kwargs passed to ``make_dictionary``. Default {}
    logger_warn: str, optional
        stored as the ``logger_warn`` attribute. Default "warn"
    latex_labels: dict, optional
        mapping of parameter name to latex label. Parameters without an
        entry use their own name as the label. Default {}
    extra_kwargs: dict, optional
        stored as the ``extra_kwargs`` attribute. Default {}
    **kwargs: dict
        All other kwargs are turned into attributes of the class. Key
        is the name of the attribute
    """
    def __init__(
        self, *args, value_class=np.array, value_columns=None, _init=True,
        make_dict_kwargs=None, logger_warn="warn", latex_labels=None,
        extra_kwargs=None, **kwargs
    ):
        super(Dict, self).__init__()
        if not _init:
            return
        from .parameters import Parameters

        # Use fresh dictionaries per instance. The previous ``{}`` defaults
        # were single objects shared by every default-constructed instance,
        # so e.g. mutating one instance's ``extra_kwargs`` leaked into others.
        if make_dict_kwargs is None:
            make_dict_kwargs = {}
        if latex_labels is None:
            latex_labels = {}
        if extra_kwargs is None:
            extra_kwargs = {}
        self.logger_warn = logger_warn
        self.all_latex_labels = latex_labels
        if isinstance(args[0], dict):
            if args[0].__class__.__name__ == "SamplesDict":
                # request every parameter (remove_debug=False)
                self.parameters = list(args[0].keys(remove_debug=False))
                _iterator = args[0].items(remove_debug=False)
            else:
                self.parameters = list(args[0].keys())
                _iterator = args[0].items()
            _samples = [args[0][param] for param in self.parameters]
            try:
                self.samples = np.array(_samples)
            except ValueError:
                # e.g. ragged samples that cannot form a regular array
                self.samples = _samples
        else:
            self.parameters, self.samples = args
            _iterator = zip(self.parameters, self.samples)
        try:
            self.make_dictionary(**make_dict_kwargs)
        except (TypeError, IndexError):
            # fall back to filling the dictionary directly from the input
            for key, item in _iterator:
                try:
                    self[key] = value_class(item)
                except Exception:
                    # retry with the item unpacked as positional arguments
                    self[key] = value_class(*item)

        if value_columns is not None:
            for key in self.keys():
                if len(value_columns) == self[key].shape[1]:
                    for num, col in enumerate(value_columns):
                        setattr(self[key], col, np.array(self[key].T[num]))
        for key, item in kwargs.items():
            setattr(self, key, item)
        self._update_latex_labels()
        self.extra_kwargs = extra_kwargs
        self.parameters = Parameters(self.parameters)

    def __getitem__(self, key):
        """Return the entry stored under ``key``

        If ``key`` is a list, return a deep copy of this object containing
        only the listed keys that exist. A warning is issued if some listed
        keys are missing.

        Raises
        ------
        KeyError
            if ``key`` is a string that is not stored, or a list in which
            none of the keys are stored
        """
        if isinstance(key, list):
            allowed = [_key for _key in key if _key in self.keys()]
            remove = [_key for _key in self.keys() if _key not in allowed]
            if len(allowed):
                if len(allowed) != len(key):
                    import warnings
                    warnings.warn(
                        "Only returning a dict with keys: {} as not all keys "
                        "are in the {} class".format(
                            ", ".join(allowed), self.__class__.__name__
                        )
                    )
                _self = copy.deepcopy(self)
                for _key in remove:
                    _self.pop(_key)
                return _self
            raise KeyError(
                "The keys: {} are not available in {}. The list of "
                "available keys are: {}".format(
                    ", ".join(key), self.__class__.__name__,
                    ", ".join(self.keys())
                )
            )
        elif isinstance(key, str):
            if key not in self.keys():
                raise KeyError(
                    "{} not in {}. The list of available keys are {}".format(
                        key, self.__class__.__name__, ", ".join(self.keys())
                    )
                )
        return super(Dict, self).__getitem__(key)

    @property
    def latex_labels(self):
        """Mapping of each parameter to its latex label"""
        return self._latex_labels

    @property
    def plotting_map(self):
        """Mapping of plot name to plotting function (none in the base class)"""
        return {}

    @property
    def available_plots(self):
        """List of plot names accepted by ``plot``"""
        return list(self.plotting_map.keys())

    def _update_latex_labels(self):
        """Update the stored latex labels

        Each parameter maps to its entry in ``all_latex_labels``, or to its
        own name when no entry exists.
        """
        self._latex_labels = {
            param: self.all_latex_labels[param] if param in
            self.all_latex_labels.keys() else param for param in self.parameters
        }

    def plot(self, *args, type="", **kwargs):
        """Generate a plot for data stored in Dict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        **kwargs: dict
            all additional kwargs are passed to the plotting function

        Raises
        ------
        NotImplementedError
            if ``type`` is not a key of ``plotting_map``
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )
        return self.plotting_map[type](*args, **kwargs)

    def make_dictionary(self, *args, **kwargs):
        """Add the parameters and samples to the class

        Subclasses may override this. The base implementation raises
        ``TypeError``, which makes ``__init__`` fill the dictionary with
        ``value_class`` instead.
        """
        raise TypeError