Coverage for pesummary/utils/dict.py: 84.7%
144 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-05 13:38 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4import copy
6__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def paths_to_key(key, dictionary, current_path=None):
    """Yield every path to a key stored in a nested dictionary

    The dictionary is walked depth-first in iteration order. Once ``key``
    matches at a given level, the search does not descend into that entry's
    value.

    Parameters
    ----------
    key: str
        the key that you would like to find
    dictionary: dict
        the nested dictionary that has the key stored somewhere within it
    current_path: list, optional
        the keys leading from the top-level dictionary to ``dictionary``.
        Default None, meaning ``dictionary`` is the top level

    Yields
    ------
    list
        the keys, in order, that lead from the top-level dictionary to
        ``key``, ending with ``key`` itself
    """
    if current_path is None:
        current_path = []
    for k, v in dictionary.items():
        if k == key:
            yield current_path + [key]
        elif isinstance(v, dict):
            # a fresh list per branch so sibling paths never share state
            yield from paths_to_key(key, v, current_path + [k])
def convert_value_to_string(dictionary):
    """Convert every non-dictionary value in a nested dictionary to a string

    The dictionary is modified in place. Nested dictionaries are recursed
    into, and every other value is replaced by ``str(value)``.

    Parameters
    ----------
    dictionary: dict
        nested dictionary whose leaf values should be converted

    Returns
    -------
    dict
        the same (now modified) ``dictionary`` object
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_value_to_string(value)
        else:
            # Re-assigning an existing key does not change the dict's size,
            # so this is safe while iterating over ``items()``
            dictionary[key] = str(value)
    return dictionary
def convert_list_to_item(dictionary):
    """Convert all nested lists of a single value to an item

    The dictionary is modified in place. Any list or array of length one is
    replaced by its only element, and a single ``bytes`` element is decoded
    as UTF-8. Nested dictionaries are recursed into. All other values,
    including 0-d numpy arrays, are left untouched.

    Parameters
    ----------
    dictionary: dict
        nested dictionary with nested lists

    Returns
    -------
    dict
        the same (now modified) ``dictionary`` object
    """
    from pesummary.utils.array import Array

    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_list_to_item(value)
            continue
        if not isinstance(value, (list, np.ndarray, Array)):
            continue
        # ``len`` raises TypeError for 0-d arrays (they already hold a single
        # scalar), so leave them as they are rather than crashing
        if isinstance(value, np.ndarray) and value.ndim == 0:
            continue
        if len(value) == 1:
            item = value[0]
            if isinstance(item, bytes):
                item = item.decode("utf-8")
            # replacing an existing key does not resize the dict, so this is
            # safe while iterating
            dictionary[key] = item
    return dictionary
def load_recursively(key, dictionary):
    """Return an entry in a nested dictionary for a key of format 'a/b/c/d'

    The path is followed strictly from the top level. Each element selects
    an entry of the dictionary found at the previous step.

    Parameters
    ----------
    key: str, float or sequence
        key of format 'a/b/c/d'. A float is treated as a single key, and a
        list/tuple of keys is treated as an already-split path
    dictionary: dict
        the dictionary that has the key stored

    Yields
    ------
    the entry stored at the end of the path. If that entry supports
    ``convert_list_to_item`` (i.e. it is a dictionary), it is first passed
    through it, which modifies the entry in place

    Raises
    ------
    KeyError
        if any key along the path is missing
    """
    # Only strings are split; checking ``"/" in key`` on a float would raise
    # TypeError even though float keys are accepted
    if isinstance(key, str):
        key = key.split("/")
    elif isinstance(key, float):
        key = [key]
    head, rest = key[0], key[1:]
    if len(rest):
        # descend one level; the remaining keys are looked up below ``head``
        yield from load_recursively(rest, dictionary[head])
        return
    value = dictionary[head]
    try:
        converted = convert_list_to_item(value)
    except AttributeError:
        # ``value`` has no ``items`` (not a dictionary): return it unchanged
        yield value
    else:
        yield converted
def edit_dictionary(dictionary, path, value):
    """Return a copy of a nested dictionary with one entry replaced

    Parameters
    ----------
    dictionary: dict
        the nested dictionary that you would like to edit
    path: list
        the keys leading to the entry that you would like to edit; the last
        element is the key whose value is replaced
    value:
        the replacement

    Returns
    -------
    dict
        a shallow copy of ``dictionary`` with the entry at ``path`` set to
        ``value``

    Notes
    -----
    Only the top level is copied. When ``path`` has more than one element,
    the nested dictionaries are shared with the input, so the input is
    modified as well.
    """
    edited = dictionary.copy()
    # walk down to the container that holds the final key
    container = edited
    for step in path[:-1]:
        container = container[step]
    container[path[-1]] = value
    return edited
class Dict(dict):
    """Base nested dictionary class.

    Parameters
    ----------
    *args: tuple
        either a single dictionary mapping parameter names to samples, or
        two positional arguments: the parameter names and the samples (one
        entry per parameter)
    value_class: func, optional
        Class you wish to use for the nested dictionary
    value_columns: list, optional
        Names for each column in value_class to be stored as properties
    _init: bool, optional
        if False, return an empty dictionary without processing ``args``.
        Default True
    make_dict_kwargs: dict, optional
        kwargs passed to ``make_dictionary``
    logger_warn: str, optional
        name of the logger method used to report complex columns.
        Default "warn"
    latex_labels: dict, optional
        latex labels keyed by parameter. Parameters without an entry use
        their name as the label
    extra_kwargs: dict, optional
        stored as the ``extra_kwargs`` attribute
    deconstruct_complex_columns: bool, optional
        if True, any columns containing complex values will be deconstructed
        into their real (np.real), amplitude (np.abs) and angle (np.angle) components.
        The real part keeps the original name, and the amplitude and angle
        are appended as '<name>_abs' and '<name>_angle'. Default True
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property
    """
    def __init__(
        self, *args, value_class=np.array, value_columns=None, _init=True,
        make_dict_kwargs=None, logger_warn="warn", latex_labels=None,
        extra_kwargs=None, deconstruct_complex_columns=True, **kwargs
    ):
        from .parameters import Parameters
        from .utils import logger
        super(Dict, self).__init__()
        if not _init:
            return
        # ``None`` defaults: a literal ``{}`` default would be one dict shared
        # by every instance (``latex_labels`` and ``extra_kwargs`` are stored
        # on the instance)
        if make_dict_kwargs is None:
            make_dict_kwargs = {}
        if latex_labels is None:
            latex_labels = {}
        if extra_kwargs is None:
            extra_kwargs = {}
        self.logger_warn = logger_warn
        self.all_latex_labels = latex_labels
        if isinstance(args[0], dict):
            if args[0].__class__.__name__ == "SamplesDict":
                # NOTE(review): presumably remove_debug=False keeps debug
                # parameters that SamplesDict would otherwise hide
                self.parameters = list(args[0].keys(remove_debug=False))
                _iterator = args[0].items(remove_debug=False)
            else:
                self.parameters = list(args[0].keys())
                _iterator = args[0].items()
            self.samples = [args[0][param] for param in self.parameters]
        else:
            self.parameters, self.samples = args
            _iterator = zip(self.parameters, self.samples)
        try:
            _samples = copy.deepcopy(self.samples)
            # The result is unused. The call is kept because np.iscomplex may
            # raise for samples that cannot form a single array (e.g. ragged
            # input), which skips the processing below
            np.iscomplex(_samples)
            if deconstruct_complex_columns:
                _parameters = self.parameters
                # iterate over a snapshot because ``_samples`` grows below
                for num, ss in enumerate(list(_samples)):
                    if not np.iscomplex(ss).any():
                        continue
                    getattr(logger, self.logger_warn)(
                        f"Deconstructing {self.parameters[num]} as it contains "
                        f"complex numbers. To disable this pass "
                        f"deconstruct_complex_columns=False."
                    )
                    if _parameters is self.parameters:
                        # copy before appending so the caller's list of
                        # parameters is never modified
                        _parameters = list(self.parameters)
                    _samples.append(np.abs(ss))
                    _parameters.append(self.parameters[num] + "_abs")
                    _samples.append(np.angle(ss))
                    _parameters.append(self.parameters[num] + "_angle")
                    # the original column keeps its name but only its real part
                    _samples[num] = np.real(ss)
                # Assigned together, only once everything above succeeded, so
                # parameters and samples can never get out of sync
                self.samples = np.array(_samples)
                self.parameters = _parameters
                _iterator = zip(self.parameters, self.samples)
            else:
                self.samples = np.array(_samples)
        except Exception:
            # Best effort: on failure self.parameters and self.samples keep
            # the values set above
            pass

        try:
            self.make_dictionary(**make_dict_kwargs)
        except (TypeError, IndexError):
            # fallback used when make_dictionary is not implemented (the base
            # class raises TypeError) or fails with these errors
            for key, item in _iterator:
                try:
                    self[key] = value_class(item)
                except Exception:
                    self[key] = value_class(*item)
        if value_columns is not None:
            for key in self.keys():
                if len(value_columns) == self[key].shape[1]:
                    for num, col in enumerate(value_columns):
                        setattr(self[key], col, np.array(self[key].T[num]))
        for key, item in kwargs.items():
            setattr(self, key, item)
        self._update_latex_labels()
        self.extra_kwargs = extra_kwargs
        self.parameters = Parameters(self.parameters)

    def __getitem__(self, key):
        """Return the entry stored under ``key``.

        If ``key`` is a list, return a deep copy of this object that contains
        only the requested keys that exist, warning when some are missing.
        For any other key type, the lookup is delegated to
        ``dict.__getitem__``.

        Raises
        ------
        KeyError
            if ``key`` is a string that is not stored, or a list containing
            no stored keys
        """
        if isinstance(key, list):
            allowed = [_key for _key in key if _key in self.keys()]
            if len(allowed):
                if len(allowed) != len(key):
                    import warnings
                    warnings.warn(
                        "Only returning a dict with keys: {} as not all keys "
                        "are in the {} class".format(
                            ", ".join(map(str, allowed)),
                            self.__class__.__name__
                        )
                    )
                keep = set(allowed)
                remove = [_key for _key in self.keys() if _key not in keep]
                _self = copy.deepcopy(self)
                for _key in remove:
                    _self.pop(_key)
                return _self
            # map(str, ...) so non-string keys still produce a KeyError
            # rather than a TypeError from str.join
            raise KeyError(
                "The keys: {} are not available in {}. The list of "
                "available keys are: {}".format(
                    ", ".join(map(str, key)), self.__class__.__name__,
                    ", ".join(map(str, self.keys()))
                )
            )
        elif isinstance(key, str):
            if key not in self.keys():
                raise KeyError(
                    "{} not in {}. The list of available keys are {}".format(
                        key, self.__class__.__name__,
                        ", ".join(map(str, self.keys()))
                    )
                )
        return super(Dict, self).__getitem__(key)

    @property
    def latex_labels(self):
        """Latex label for each parameter, set by ``_update_latex_labels``"""
        return self._latex_labels

    @property
    def plotting_map(self):
        """Mapping of plot type to plotting function (none in the base class)"""
        return {}

    @property
    def available_plots(self):
        """Names of the plot types accepted by ``plot``"""
        return list(self.plotting_map.keys())

    def _update_latex_labels(self):
        """Update the stored latex labels

        Each parameter uses its entry in ``all_latex_labels`` when present,
        and its own name otherwise.
        """
        self._latex_labels = {
            param: self.all_latex_labels[param] if param in
            self.all_latex_labels.keys() else param for param in self.parameters
        }

    def plot(self, *args, type="", **kwargs):
        """Generate a plot for data stored in Dict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        **kwargs: dict
            all additional kwargs are passed to the plotting function

        Raises
        ------
        NotImplementedError
            if ``type`` is not a key of ``plotting_map``
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )
        return self.plotting_map[type](*args, **kwargs)

    def make_dictionary(self, *args, **kwargs):
        """Add the parameters and samples to the class

        The base implementation always raises TypeError, which makes
        ``__init__`` fall back to filling the dictionary from the
        parameters and samples with ``value_class``. Presumably overridden
        by subclasses.
        """
        raise TypeError