Coverage for pesummary/utils/decorators.py: 77.4%
168 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import functools
4import copy
5import numpy as np
6import os
7from pesummary.utils.utils import logger
# Author credit for this module
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def open_config(index=0):
    """Open a configuration file. The function first looks for a config file
    stored as the keyword argument 'config'. If no kwarg found, one must specify
    the argument index which corresponds to the config file. Default is the 0th
    argument.

    The decorated function receives a :class:`configparser.ConfigParser`
    (with case-sensitive option names) in place of the path/string. Two extra
    attributes are attached to it:

    - ``error``: ``False`` when reading succeeded. Otherwise it is the
      exception raised while reading an existing file, or the string
      ``"No such file or directory"`` when the input was neither an existing
      file nor a parsable configuration string.
    - ``path_to_file``: the path that was read. It is only set when the input
      is an existing file.

    Existing files without a section header are read as if they began with a
    ``[config]`` section.

    Parameters
    ----------
    index: int, optional
        index of the positional argument holding the config file (or config
        string). If ``None``, only the 'config' keyword argument is used.
        When that keyword is not supplied, the function is called unchanged.

    Examples
    --------
    @open_config(index=0)
    def open(config):
        print(list(config['condor'].keys()))

    @open_config(index=2)
    def open(parameters, samples, config):
        print(list(config['condor'].keys()))

    @open_config(index=None)
    def open(parameters, samples, config=config):
        print(list(config['condor'].keys()))
    """
    import configparser

    def _safe_read(config, config_file):
        """Populate ``config`` from ``config_file`` (a path or the contents
        of a configuration file), recording failures on ``config.error``.
        """
        config.error = False
        if not os.path.isfile(config_file):
            # Not a file on disk: interpret the input as configuration text
            try:
                config.read_string(config_file)
            except Exception:
                config.error = "No such file or directory"
            return None

        config.path_to_file = config_file
        try:
            return config.read(config_file)
        except configparser.MissingSectionHeaderError:
            # Header-less files are read under a synthetic "[config]" section
            with open(config_file, "r") as f:
                _config = "[config]\n" + f.read()
            return config.read_string(_config)
        except Exception as e:
            config.error = e
            return None

    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            config = configparser.ConfigParser()
            # Keep option names case-sensitive (ConfigParser lower-cases them
            # by default)
            config.optionxform = str
            if kwargs.get("config", None) is not None:
                _safe_read(config, kwargs.get("config"))
                kwargs.update({"config": config})
            elif index is not None:
                args = list(copy.deepcopy(args))
                _safe_read(config, args[index])
                args[index] = config
            # index is None and no 'config' kwarg: nothing to open, so call
            # the function as-is (previously this crashed on args[None])
            return func(*args, **kwargs)
        return wrapper_function
    return decorator
def bound_samples(minimum=-np.inf, maximum=np.inf, logger_level="debug"):
    """Bound samples to be within a specified range. If any samples lie
    outside of this range, we set these invalid samples to equal the value at
    the boundary.

    The wrapped function's output is passed through ``np.atleast_1d``, so the
    decorated function always returns an array. Arrays of any dimensionality
    are bounded element-wise.

    Parameters
    ----------
    minimum: float
        lower boundary. Default -np.inf
    maximum: float
        upper boundary. Default np.inf
    logger_level: str
        level to use for any logger messages

    Examples
    --------
    @bound_samples(minimum=-1., maximum=1., logger_level="info")
    def random_samples():
        return np.random.uniform(-2, 2, 10000)

    >>> random_samples()
    PESummary INFO : 2576/10000 (25.76%) samples lie outside of the specified
    range for the function random_samples (< -1.0). Truncating these samples to
    -1.0.
    PESummary INFO : 2495/10000 (24.95%) samples lie outside of the specified
    range for the function random_samples (> 1.0). Truncating these samples to
    1.0.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            value = np.atleast_1d(func(*args, **kwargs))
            # Boolean masks (rather than np.argwhere indices) select exactly
            # the offending elements for any array shape. Both are computed
            # before ``value`` is modified.
            checks = (
                (value < minimum, minimum, "<"),
                (value > maximum, maximum, ">"),
            )
            for invalid, bound, symbol in checks:
                n_invalid = int(np.count_nonzero(invalid))
                if n_invalid:
                    getattr(logger, logger_level)(
                        "{}/{} ({}%) samples lie outside of the specified "
                        "range for the function {} ({} {}). Truncating these "
                        "samples to {}.".format(
                            n_invalid, value.size,
                            np.round(n_invalid / value.size * 100, 2),
                            func.__name__, symbol, bound, bound
                        )
                    )
                    value[invalid] = bound
            return value
        return wrapper_function
    return decorator
def no_latex_plot(func):
    """Turn off latex plotting for a given function

    ``matplotlib.rcParams["text.usetex"]`` is set to ``False`` while ``func``
    runs. It is restored to its previous value afterwards, even if ``func``
    raises.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        from matplotlib import rcParams

        original_tex = rcParams["text.usetex"]
        rcParams["text.usetex"] = False
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the global setting even when plotting fails so that
            # later plots are not silently affected
            rcParams["text.usetex"] = original_tex
    return wrapper_function
def try_latex_plot(func):
    """Try to make a latex plot, if RuntimeError raised, turn latex off
    and try again

    ``matplotlib.rcParams["text.usetex"]`` is restored to its original value
    once the function returns or raises, including when the latex-free retry
    itself fails.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        from matplotlib import rcParams

        original_tex = rcParams["text.usetex"]
        try:
            try:
                return func(*args, **kwargs)
            except RuntimeError:
                logger.debug("Unable to use latex. Turning off for this plot")
                rcParams["text.usetex"] = False
                return func(*args, **kwargs)
        finally:
            # Previously a failing retry left usetex switched off globally
            rcParams["text.usetex"] = original_tex
    return wrapper_function
def tmp_directory(func):
    """Run ``func`` from inside a freshly created temporary directory.

    The temporary directory is created inside the current working directory.
    The working directory is switched back before the temporary directory is
    removed, whether ``func`` returns normally or raises.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        import os
        import tempfile

        starting_dir = os.getcwd()
        with tempfile.TemporaryDirectory(dir="./") as scratch_dir:
            os.chdir(scratch_dir)
            try:
                return func(*args, **kwargs)
            finally:
                # Leave the scratch directory before it is cleaned up
                os.chdir(starting_dir)
    return wrapper_function
def array_input(ignore_args=None, ignore_kwargs=None, force_return_array=False):
    """Convert the input into an np.ndarray and return either a float or a
    np.ndarray depending on what was input.

    Positional and keyword arguments which are Python ``int``/``float``
    scalars are wrapped into length-1 arrays. Lists and arrays are converted
    with ``np.array``. Other arguments are passed through (positional ones as
    deep copies). If any *positional* argument was a scalar and
    ``force_return_array`` is False, the output is reduced back to a float,
    or to an array of first elements for multi-valued output. Dictionary
    outputs are returned untouched.

    Parameters
    ----------
    ignore_args: list, optional
        indices of positional arguments which should not be converted
    ignore_kwargs: list, optional
        names of keyword arguments which should not be converted
    force_return_array: bool, optional
        if True, always return an np.ndarray. Default False

    Examples
    --------
    >>> @array_input()
    ... def total_mass(mass_1, mass_2):
    ...     total_mass = mass_1 + mass_2
    ...     return total_mass
    ...
    >>> print(total_mass(30, 10))
    40.0
    >>> print(total_mass([30, 3], [10, 1]))
    [40  4]
    """
    def _reduce_to_scalar_like(value):
        """Reduce an output array back to float-like form for scalar input."""
        if value.ndim == 0:
            # func returned a scalar (e.g. np.sum); len() is undefined here
            return float(value)
        if len(value) > 1:
            # e.g. a tuple of length-1 arrays -> array of their first entries
            return np.array([element[0] for element in value])
        if value.ndim == 2:
            return value[0]
        # Single-element output. .item() avoids NumPy's deprecated implicit
        # conversion of ndim > 0 arrays to float
        return float(value.item())

    def _array_input(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            new_args = []
            return_float = False
            for num, arg in enumerate(args):
                if ignore_args is not None and num in ignore_args:
                    new_args.append(copy.deepcopy(arg))
                elif isinstance(arg, (float, int)):
                    new_args.append(np.array([arg]))
                    return_float = True
                elif isinstance(arg, (list, np.ndarray)):
                    # np.array already copies; no need to deep-copy first
                    new_args.append(np.array(arg))
                else:
                    new_args.append(copy.deepcopy(arg))

            new_kwargs = kwargs.copy()
            for key, item in kwargs.items():
                if ignore_kwargs is not None and key in ignore_kwargs:
                    continue
                if isinstance(item, (float, int)):
                    new_kwargs[key] = np.array([item])
                elif isinstance(item, (list, np.ndarray)):
                    new_kwargs[key] = np.array(item)

            output = func(*new_args, **new_kwargs)
            if isinstance(output, dict):
                return output
            try:
                value = np.array(output)
            except ValueError:
                # Ragged output (e.g. arrays of different lengths)
                value = np.array(output, dtype=object)
            if return_float and not force_return_array:
                return _reduce_to_scalar_like(value)
            return value
        return wrapper_function
    return _array_input
def docstring_subfunction(*args):
    """Edit the docstring of a function to show the docstrings of subfunctions

    Parameters
    ----------
    *args: str or list
        the first argument is either a dotted path such as
        ``"package.module.function"`` or a list/tuple of dotted paths. The
        docstring of each named function is appended to the decorated
        function's docstring under a "Subfunctions" heading. Further
        arguments are ignored.
    """
    def wrapper_function(func):
        import importlib

        subfunctions = args[0]
        if isinstance(subfunctions, str):
            subfunctions = [subfunctions]
        # A function without a docstring (or any function under python -OO)
        # has __doc__ = None; treat it as empty instead of crashing
        docstring = (func.__doc__ or "") + "\n\nSubfunctions:\n"
        for subfunction in subfunctions:
            module_name, _, function_name = subfunction.rpartition(".")
            module = importlib.import_module(module_name)
            docstring += "\n{}\n{}\n{}".format(
                subfunction, "-" * len(subfunction),
                getattr(module, function_name).__doc__
            )
        func.__doc__ = docstring
        return func
    return wrapper_function
def set_docstring(docstring):
    """Build a decorator that overwrites a function's ``__doc__``.

    The decorated function itself is returned, not a wrapper.
    """
    def _assign_doc(target):
        setattr(target, "__doc__", docstring)
        return target
    return _assign_doc
def deprecation(warning):
    """Emit ``warning`` (a ``UserWarning``) every time the decorated
    function is called.

    ``stacklevel=2`` attributes the warning to the caller's line rather than
    to this decorator, so users can see which call needs updating.

    Parameters
    ----------
    warning: str
        message to emit
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            import warnings

            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper_function
    return decorator