Coverage for pesummary/utils/dict.py: 92.7%
123 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4import copy
# Author(s) of this module
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def paths_to_key(key, dictionary, current_path=None):
    """Return the path to a key stored in a nested dictionary

    This is a generator: one path is yielded for every occurrence of
    ``key``. Once ``key`` is matched, its value is not searched any
    further for nested occurrences.

    Parameters
    ----------
    key: str
        the key that you would like to find
    dictionary: dict
        the nested dictionary that has the key stored somewhere within it
    current_path: list, optional
        the list of keys leading to ``dictionary`` from the top-level
        dictionary. Default None, meaning ``dictionary`` is the top level

    Yields
    ------
    list
        the list of keys leading from the top-level dictionary to ``key``
    """
    if current_path is None:
        current_path = []
    for k, v in dictionary.items():
        if k == key:
            yield current_path + [key]
        elif isinstance(v, dict):
            # Descend into nested dictionaries, extending the path with the
            # key we descended through
            yield from paths_to_key(key, v, current_path + [k])
def convert_value_to_string(dictionary):
    """Convert every non-dictionary value in a nested dictionary to a string

    The dictionary is modified in place. Nested dictionaries are walked
    recursively and remain dictionaries; every other value is replaced by
    ``str(value)``.

    Parameters
    ----------
    dictionary: dict
        nested dictionary whose leaf values should be converted

    Returns
    -------
    dict
        the same (now modified) ``dictionary`` object
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_value_to_string(value)
        else:
            # Re-assigning an existing key does not change the size of the
            # dictionary, so this is safe while iterating over ``items()``
            dictionary.update({key: str(value)})
    return dictionary
def convert_list_to_item(dictionary):
    """Convert all nested lists of a single value to an item

    The dictionary is modified in place. Any (nested) value that is a
    ``list``, ``np.ndarray`` or ``pesummary.utils.array.Array`` of length
    one is replaced by its only element; a ``bytes`` element is decoded as
    UTF-8. All other values, including 0-d arrays, are left unchanged.

    Parameters
    ----------
    dictionary: dict
        nested dictionary with nested lists

    Returns
    -------
    dict
        the same (now modified) ``dictionary`` object
    """
    from pesummary.utils.array import Array

    for key, value in dictionary.items():
        if isinstance(value, dict):
            convert_list_to_item(value)
            continue
        if not isinstance(value, (list, np.ndarray, Array)):
            continue
        # 0-d arrays have no length (``len`` raises TypeError) and nothing
        # to unwrap, so leave them as they are
        if getattr(value, "ndim", 1) == 0:
            continue
        if len(value) != 1:
            continue
        item = value[0]
        if isinstance(item, bytes):
            item = item.decode("utf-8")
        # Re-assigning an existing key is safe while iterating
        dictionary.update({key: item})
    return dictionary
def load_recursively(key, dictionary):
    """Return an entry in a nested dictionary for a key of format 'a/b/c/d'

    This is a generator that yields a single value (or raises).

    Parameters
    ----------
    key: str or list
        key of format 'a/b/c/d', or a list of path components. A str
        without '/' is treated as a single component
    dictionary: dict
        the dictionary that has the key stored

    Yields
    ------
    the stored entry after passing it through ``convert_list_to_item``
    (which modifies dictionaries in place); if that raises AttributeError
    (e.g. the entry has no ``.items()``), the raw entry is yielded instead

    Raises
    ------
    KeyError
        if the entry cannot be found
    """
    components = key.split("/") if "/" in key else key
    # NOTE(review): a float key raises TypeError on the ``"/" in key`` test
    # above, so the float case below is never reached for top-level calls.
    if isinstance(components, (str, float)):
        components = [components]
    target = components[-1]
    node = dictionary
    # Descend one level per leading path component until reaching a level
    # that contains the final component.
    # NOTE(review): the final component is looked for at *every* level, so
    # a shallower level that happens to contain it short-circuits the walk
    # (e.g. 'a/b/c' returns dictionary['c'] when that exists). Confirm this
    # is intended.
    while target not in node.keys():
        node = node[components[0]]
        components = components[1:]
    try:
        converted = convert_list_to_item(node[target])
        yield converted
    except AttributeError:
        yield node[target]
def edit_dictionary(dictionary, path, value):
    """Replace an entry in a nested dictionary

    An edited copy of ``dictionary`` is returned and ``dictionary`` itself
    is left unchanged. Only the containers along ``path`` are (shallowly)
    copied; every other value is shared with ``dictionary``.

    Parameters
    ----------
    dictionary: dict
        the nested dictionary that you would like to edit
    path: list
        the path to the key that you would like to edit, e.g. ['a', 'b']
        to replace ``dictionary['a']['b']``
    value:
        the replacement

    Returns
    -------
    dict
        a copy of ``dictionary`` with the entry at ``path`` replaced

    Raises
    ------
    KeyError
        if an intermediate key in ``path`` is not present
    IndexError
        if ``path`` is empty
    """
    edit = dictionary.copy()
    # ``dictionary.copy()`` alone is shallow: writing through the nested
    # containers would also modify the caller's ``dictionary``. Copy each
    # container on the way down so the edit only touches the returned object.
    node = edit
    for step in path[:-1]:
        child = copy.copy(node[step])
        node[step] = child
        node = child
    node[path[-1]] = value
    return edit
class Dict(dict):
    """Base nested dictionary class.

    Parameters
    ----------
    *args: tuple
        either a single dictionary mapping parameter names to samples, or
        two positional arguments ``(parameters, samples)``
    value_class: func, optional
        Class you wish to use for the nested dictionary. Default np.array
    value_columns: list, optional
        Names for each column in value_class to be stored as properties
    make_dict_kwargs: dict, optional
        kwargs passed to ``make_dictionary``
    logger_warn: str, optional
        stored as the ``logger_warn`` attribute. Default "warn"
    latex_labels: dict, optional
        dictionary mapping parameter names to latex labels. Parameters
        without an entry use their own name as the label
    extra_kwargs: dict, optional
        stored as the ``extra_kwargs`` attribute
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property
    """
    def __init__(
        self, *args, value_class=np.array, value_columns=None, _init=True,
        make_dict_kwargs=None, logger_warn="warn", latex_labels=None,
        extra_kwargs=None, **kwargs
    ):
        from .parameters import Parameters
        super().__init__()
        if not _init:
            return
        # Build fresh containers for every instance: ``{}`` defaults in the
        # signature would be a single object shared by every Dict, and both
        # ``latex_labels`` and ``extra_kwargs`` are stored on the instance.
        if make_dict_kwargs is None:
            make_dict_kwargs = {}
        if latex_labels is None:
            latex_labels = {}
        if extra_kwargs is None:
            extra_kwargs = {}
        self.logger_warn = logger_warn
        self.all_latex_labels = latex_labels
        if isinstance(args[0], dict):
            if args[0].__class__.__name__ == "SamplesDict":
                # presumably SamplesDict.keys/items drop debug parameters
                # unless remove_debug=False; keep all of them here
                self.parameters = list(args[0].keys(remove_debug=False))
                _iterator = args[0].items(remove_debug=False)
            else:
                self.parameters = list(args[0].keys())
                _iterator = args[0].items()
            _samples = [args[0][param] for param in self.parameters]
            try:
                self.samples = np.array(_samples)
            except ValueError:
                # e.g. ragged samples which cannot form a single array
                self.samples = _samples
        else:
            self.parameters, self.samples = args
            _iterator = zip(self.parameters, self.samples)
        try:
            self.make_dictionary(**make_dict_kwargs)
        except (TypeError, IndexError):
            # The base ``make_dictionary`` raises TypeError, so fall back to
            # wrapping each entry with ``value_class``
            for key, item in _iterator:
                try:
                    self[key] = value_class(item)
                except Exception:
                    self[key] = value_class(*item)
        if value_columns is not None:
            for key in self.keys():
                if len(value_columns) == self[key].shape[1]:
                    for num, col in enumerate(value_columns):
                        setattr(self[key], col, np.array(self[key].T[num]))
        for key, item in kwargs.items():
            setattr(self, key, item)
        self._update_latex_labels()
        self.extra_kwargs = extra_kwargs
        self.parameters = Parameters(self.parameters)

    def __getitem__(self, key):
        """Return the entry stored under ``key``

        Parameters
        ----------
        key: str or list
            the key to return. If a list is given, a deep copy of this
            object containing only the listed keys that are available is
            returned, with a warning if some listed keys are unavailable

        Raises
        ------
        KeyError
            if ``key`` (or every key in a list) is not available
        """
        if isinstance(key, list):
            allowed = [_key for _key in key if _key in self.keys()]
            allowed_set = set(allowed)
            remove = [_key for _key in self.keys() if _key not in allowed_set]
            if allowed:
                if len(allowed) != len(key):
                    import warnings
                    warnings.warn(
                        "Only returning a dict with keys: {} as not all keys "
                        "are in the {} class".format(
                            ", ".join(map(str, allowed)),
                            self.__class__.__name__
                        )
                    )
                _self = copy.deepcopy(self)
                for _key in remove:
                    _self.pop(_key)
                return _self
            # ``map(str, ...)`` so the message itself cannot fail for
            # non-string keys
            raise KeyError(
                "The keys: {} are not available in {}. The list of "
                "available keys are: {}".format(
                    ", ".join(map(str, key)), self.__class__.__name__,
                    ", ".join(map(str, self.keys()))
                )
            )
        elif isinstance(key, str):
            if key not in self.keys():
                raise KeyError(
                    "{} not in {}. The list of available keys are {}".format(
                        key, self.__class__.__name__,
                        ", ".join(map(str, self.keys()))
                    )
                )
        return super().__getitem__(key)

    @property
    def latex_labels(self):
        """Dictionary mapping each parameter to its latex label"""
        return self._latex_labels

    @property
    def plotting_map(self):
        """Mapping of plot names to plotting functions (none for the base
        class)
        """
        return {}

    @property
    def available_plots(self):
        """List of plot names accepted by ``plot``"""
        return list(self.plotting_map.keys())

    def _update_latex_labels(self):
        """Update the stored latex labels

        Each parameter uses its entry in ``all_latex_labels`` if present,
        otherwise the parameter name itself.
        """
        self._latex_labels = {
            param: self.all_latex_labels[param] if param in
            self.all_latex_labels.keys() else param for param in self.parameters
        }

    def plot(self, *args, type="", **kwargs):
        """Generate a plot for data stored in Dict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        **kwargs: dict
            all additional kwargs are passed to the plotting function

        Raises
        ------
        NotImplementedError
            if ``type`` is not one of ``available_plots``
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )
        return self.plotting_map[type](*args, **kwargs)

    def make_dictionary(self, *args, **kwargs):
        """Add the parameters and samples to the class

        The base class has no custom construction and raises TypeError,
        which ``__init__`` catches to fall back to wrapping each entry with
        ``value_class``.
        """
        raise TypeError