Coverage for pesummary/utils/decorators.py: 79.8%

168 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import functools 

4import copy 

5import numpy as np 

6import os 

7from pesummary.utils.utils import logger 

8 

9__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

10 

11 

def open_config(index=0):
    """Decorator factory which swaps a configuration file for a parsed
    `configparser.ConfigParser` before calling the decorated function.

    The configuration is taken from the keyword argument 'config' when it is
    supplied and is not None. Otherwise it is taken from the positional
    argument at position `index`, and the decorated function is called with
    a deep copy of the positional arguments in which that entry has been
    replaced by the parser.

    The input may be a path to an existing file or a string holding the
    contents of a configuration. A file without any section header is read
    as if all of its options were stored under a '[config]' section. Option
    names keep their case.

    Reading never raises for a bad input. Instead the parser passed to the
    function has an `error` attribute: False on success, otherwise the
    exception that was raised while reading the file, or the string
    "No such file or directory" when the input is not an existing file and
    could not be parsed as a string. When the input is an existing file the
    parser also has a `path_to_file` attribute.

    Parameters
    ----------
    index: int, optional
        position of the config file in the positional arguments. Only used
        when the 'config' keyword argument is not given. Default 0

    Examples
    --------
    @open_config(index=0)
    def open(config):
        print(list(config['condor'].keys()))

    @open_config(index=2)
    def open(parameters, samples, config):
        print(list(config['condor'].keys()))

    @open_config(index=None)
    def open(parameters, samples, config=config):
        print(list(config['condor'].keys()))
    """
    import configparser

    def _safe_read(config, config_file):
        # Fill `config` from `config_file`; failures are recorded on
        # `config.error` rather than raised
        config.error = False
        if os.path.isfile(config_file):
            config.path_to_file = config_file
            try:
                return config.read(config_file)
            except configparser.MissingSectionHeaderError:
                # no section header in the file: read it again with all of
                # the options placed under a [config] section
                with open(config_file, "r") as handle:
                    contents = handle.read()
                return config.read_string('[config]\n' + contents)
            except Exception as exc:
                config.error = exc
                return None

        # not an existing file: try the input as the configuration itself
        try:
            return config.read_string(config_file)
        except Exception:
            config.error = "No such file or directory"
            return None

    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            parser = configparser.ConfigParser()
            # keep option names exactly as they are written
            parser.optionxform = str
            supplied = kwargs.get("config", None)
            if supplied is None:
                positional = list(copy.deepcopy(args))
                _safe_read(parser, positional[index])
                positional[index] = parser
                return func(*positional, **kwargs)
            _safe_read(parser, supplied)
            kwargs["config"] = parser
            return func(*args, **kwargs)
        return wrapper_function
    return decorator

70 

71 

def bound_samples(minimum=-np.inf, maximum=np.inf, logger_level="debug"):
    """Bound samples to be within a specified range. If any samples lie
    outside of this range, we set these invalid samples to equal the value at
    the boundary.

    The output of the decorated function is passed through `np.atleast_1d`,
    so the decorated function always returns an array. Samples are selected
    with boolean masks, so arrays with any number of dimensions are handled
    element by element.

    Parameters
    ----------
    minimum: float
        lower boundary. Default -np.inf
    maximum: float
        upper boundary. Default np.inf
    logger_level: str
        level to use for any logger messages

    Examples
    --------
    @bound_samples(minimum=-1., maximum=1., logger_level="info")
    def random_samples():
        return np.random.uniform(-2, 2, 10000)

    >>> random_samples()
    PESummary INFO    : 2576/10000 (25.76%) samples lie outside of the specified
    range for the function random_samples (< -1.0). Truncating these samples to
    -1.0.
    PESummary INFO    : 2495/10000 (24.95%) samples lie outside of the specified
    range for the function random_samples (> 1.0). Truncating these samples to
    1.0.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            value = np.atleast_1d(func(*args, **kwargs))
            # Both masks are built before anything is modified so that the
            # reported counts describe the samples returned by func. The
            # comparison symbol is stored with its bound so that the message
            # stays correct even when minimum == maximum.
            checks = (
                (value < minimum, "<", minimum),
                (value > maximum, ">", maximum),
            )
            for outside, symbol, bound in checks:
                n_outside = int(np.count_nonzero(outside))
                if not n_outside:
                    continue
                getattr(logger, logger_level)(
                    "{}/{} ({}%) samples lie outside of the specified "
                    "range for the function {} ({} {}). Truncating these "
                    "samples to {}.".format(
                        n_outside, value.size,
                        np.round(n_outside / value.size * 100, 2),
                        func.__name__, symbol, bound, bound
                    )
                )
                # a boolean mask selects exactly the offending elements
                # whatever the dimensionality of value
                value[outside] = bound
            return value
        return wrapper_function
    return decorator

123 

124 

def no_latex_plot(func):
    """Turn off latex plotting for a given function

    The matplotlib "text.usetex" setting is set to False while the function
    runs and is put back to its previous value afterwards, including when
    the function raises an exception.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        from matplotlib import rcParams

        original_tex = rcParams["text.usetex"]
        rcParams["text.usetex"] = False
        try:
            return func(*args, **kwargs)
        finally:
            # restore the global setting even if func failed, otherwise
            # latex would stay disabled for every later plot
            rcParams["text.usetex"] = original_tex
    return wrapper_function

138 

139 

def try_latex_plot(func):
    """Try to make a latex plot, if RuntimeError raised, turn latex off
    and try again

    The matplotlib "text.usetex" setting is put back to the value it had
    before the call once the function has finished, including when the
    second attempt raises an exception.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        from matplotlib import rcParams

        original_tex = rcParams["text.usetex"]
        try:
            try:
                return func(*args, **kwargs)
            except RuntimeError:
                logger.debug("Unable to use latex. Turning off for this plot")
                rcParams["text.usetex"] = False
                return func(*args, **kwargs)
        finally:
            # restore the global setting even if the retry failed, otherwise
            # latex would stay disabled for every later plot
            rcParams["text.usetex"] = original_tex
    return wrapper_function

158 

159 

def tmp_directory(func):
    """Run the decorated function from inside a newly made temporary
    directory

    The temporary directory is made inside the current working directory.
    Once the function has returned or raised, the working directory is
    changed back to where it was and the temporary directory is cleaned up.
    """
    @functools.wraps(func)
    def wrapper_function(*args, **kwargs):
        import os
        import tempfile

        start_dir = os.getcwd()
        with tempfile.TemporaryDirectory(dir="./") as scratch:
            os.chdir(scratch)
            try:
                return func(*args, **kwargs)
            finally:
                # step back out before the temporary directory is removed
                os.chdir(start_dir)
    return wrapper_function

179 

180 

def array_input(ignore_args=None, ignore_kwargs=None, force_return_array=False):
    """Convert the input into an np.ndarray and return either a float or a
    np.ndarray depending on what was input.

    Positional and keyword arguments which are a float, an int, a list or a
    np.ndarray are converted to np.ndarray before the decorated function is
    called; all other arguments are passed on as they are (positional ones
    as deep copies). If the decorated function returns a dict, it is returned
    untouched. Otherwise the output is converted to an np.ndarray and, when
    at least one positional argument was a float or an int, reduced back to
    a scalar form:

    * an output of length one is returned as a float (or as its only row
      when it is two dimensional)
    * an output of length greater than one is returned as an array holding
      the first element of each entry

    Keyword arguments that are a float or an int are converted but do not
    trigger this reduction.

    Parameters
    ----------
    ignore_args: list, optional
        positions of the positional arguments which should not be converted
    ignore_kwargs: list, optional
        names of the keyword arguments which should not be converted
    force_return_array: Bool, optional
        if True, always return the output as an np.ndarray. Default False

    Examples
    --------
    >>> @array_input()
    ... def total_mass(mass_1, mass_2):
    ...     total_mass = mass_1 + mass_2
    ...     return total_mass
    ...
    >>> print(total_mass(30, 10))
    40.0
    >>> print(total_mass([30, 3], [10, 1]))
    [40  4]
    """
    def _array_input(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            new_args = list(copy.deepcopy(args))
            new_kwargs = kwargs.copy()
            return_float = False
            for num, arg in enumerate(args):
                if ignore_args is not None and num in ignore_args:
                    continue
                if isinstance(arg, (float, int)):
                    new_args[num] = np.array([arg])
                    return_float = True
                elif isinstance(arg, (list, np.ndarray)):
                    new_args[num] = np.array(arg)
            for key, item in kwargs.items():
                if ignore_kwargs is not None and key in ignore_kwargs:
                    continue
                if isinstance(item, (float, int)):
                    new_kwargs[key] = np.array([item])
                elif isinstance(item, (list, np.ndarray)):
                    new_kwargs[key] = np.array(item)
            output = func(*new_args, **new_kwargs)
            if isinstance(output, dict):
                return output
            try:
                value = np.array(output)
            except ValueError:
                # ragged output cannot form a regular array
                value = np.array(output, dtype=object)
            if not return_float or force_return_array:
                return value
            if value.ndim == 0:
                # func reduced its input to a scalar; len() is undefined for
                # a 0-d array so handle it before the length checks
                return float(value)
            if len(value) > 1:
                return np.array([entry[0] for entry in value])
            if value.ndim == 2:
                return value[0]
            if value.size == 1:
                # extract the element first: float() on an array with
                # ndim > 0 is deprecated in recent numpy releases
                return float(value.item())
            # unsupported shape: float() raises the same TypeError as before
            return float(value)
        return wrapper_function
    return _array_input

239 

240 

def docstring_subfunction(*args):
    """Edit the docstring of a function to show the docstrings of subfunctions

    A "Subfunctions:" section is appended to the docstring of the decorated
    function. For every subfunction it lists the dotted name, an underline
    and the docstring of that subfunction. A decorated function or a
    subfunction without a docstring contributes an empty string.

    Parameters
    ----------
    *args: tuple
        only the first argument is used. Either a single dotted name of the
        form 'package.module.function', or a list (or tuple) of such names.
        The module part is imported with importlib and the function is looked
        up as an attribute of that module
    """
    def wrapper_function(func):
        import importlib

        names = args[0]
        if not isinstance(names, (list, tuple)):
            names = [names]
        # __doc__ is None for functions without a docstring and for every
        # function when python is run with -OO
        pieces = [func.__doc__ or "", "\n\nSubfunctions:\n"]
        for name in names:
            module_name, _, attribute = name.rpartition(".")
            module = importlib.import_module(module_name)
            subfunction_doc = getattr(module, attribute).__doc__ or ""
            pieces.append(
                "\n{}\n{}\n{}".format(name, "-" * len(name), subfunction_doc)
            )
        func.__doc__ = "".join(pieces)
        return func
    return wrapper_function

273 

274 

def set_docstring(docstring):
    """Return a decorator which overwrites the docstring of the decorated
    function

    Parameters
    ----------
    docstring:
        value assigned to the `__doc__` attribute of the decorated function.
        The function itself is returned unchanged otherwise
    """
    def _assign(func):
        setattr(func, "__doc__", docstring)
        return func
    return _assign

280 

281 

def deprecation(warning):
    """Return a decorator which issues a warning every time the decorated
    function is called

    Parameters
    ----------
    warning:
        message passed to `warnings.warn` before the decorated function is
        run. The warning category is left at the `warnings.warn` default
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            import warnings

            # stacklevel=2 attributes the warning to the code which called
            # the decorated function rather than to this wrapper
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper_function
    return decorator