Coverage for pesummary/utils/utils.py: 70.9%

498 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import sys 

5import logging 

6import contextlib 

7import time 

8import copy 

9import shutil 

10 

11import numpy as np 

12from scipy.integrate import cumtrapz 

13from scipy.interpolate import interp1d 

14from scipy import stats 

15import h5py 

16from pesummary import conf 

17 

18__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

19 

20try: 

21 from coloredlogs import ColoredFormatter as LogFormatter 

22except ImportError: 

23 LogFormatter = logging.Formatter 

24 

25CACHE_DIR = os.path.join( 

26 os.getenv( 

27 "XDG_CACHE_HOME", 

28 os.path.expanduser(os.path.join("~", ".cache")), 

29 ), 

30 "pesummary", 

31) 

32STYLE_CACHE = os.path.join(CACHE_DIR, "style") 

33LOG_CACHE = os.path.join(CACHE_DIR, "log") 

34 

35 

def resample_posterior_distribution(posterior, nsamples):
    """Randomly draw nsamples from the posterior distribution

    Parameters
    ----------
    posterior: ndlist
        nd list of posterior samples. If you only want to resample one
        posterior distribution then posterior=[[1., 2., 3., 4.]]. For multiple
        posterior distributions then posterior=[[1., 2., 3., 4.], [1., 2., 3.]]
    nsamples: int
        number of samples that you wish to randomly draw from the distribution
    """
    if len(posterior) == 1:
        # single distribution: draw through the inverse CDF built from a
        # 50-bin histogram of the samples
        counts, edges = np.histogram(posterior, bins=50)
        counts = np.array([0] + list(counts))
        cdf = cumtrapz(counts, edges, initial=0)
        cdf /= cdf[-1]
        inverse_cdf = interp1d(cdf, edges)
        return inverse_cdf(np.random.rand(nsamples))
    # multiple distributions: keep the same randomly chosen indices for
    # every distribution so samples stay aligned across analyses
    stacked = np.array(list(posterior))
    chosen = np.random.choice(len(stacked[0]), nsamples, replace=False)
    return [distribution[chosen] for distribution in stacked]

61 

62 

def check_file_exists_and_rename(file_name):
    """Check to see if a file exists and if it does then rename the file

    Parameters
    ----------
    file_name: str
        proposed file name to store data
    """
    if os.path.isfile(file_name):
        # find a backup name that is not already taken
        old_file = "{}_old".format(file_name)
        while os.path.isfile(old_file):
            old_file += "_old"
        logger.warning(
            "The file '{}' already exists. Renaming the existing file to "
            "{} and saving the data to the requested file name".format(
                file_name, old_file
            )
        )
        # shutil is already imported at module level; the previous
        # function-local re-import was redundant
        shutil.move(file_name, old_file)

84 

85 

def check_condition(condition, error_message):
    """Raise an Exception carrying ``error_message`` when ``condition``
    is truthy; otherwise do nothing.
    """
    if not condition:
        return
    raise Exception(error_message)

91 

92 

def rename_group_or_dataset_in_hf5_file(base_file, group=None, dataset=None):
    """Rename a group or dataset in an hdf5 file

    Parameters
    ----------
    base_file: str
        path to the hdf5 file that you wish to modify
    group: list, optional
        a list containing the path to the group that you would like to change
        as the first argument and the new name of the group as the second
        argument
    dataset: list, optional
        a list containing the name of the dataset that you would like to change
        as the first argument and the new name of the dataset as the second
        argument
    """
    condition = not os.path.isfile(base_file)
    check_condition(condition, "The file %s does not exist" % (base_file))
    # use a context manager so the file handle is closed even if the rename
    # raises part-way through (the original left the file open on error)
    with h5py.File(base_file, "a") as f:
        if group:
            f[group[1]] = f[group[0]]
            del f[group[0]]
        elif dataset:
            f[dataset[1]] = f[dataset[0]]
            del f[dataset[0]]

117 

118 

def make_dir(path):
    """Create the directory ``path`` if it does not already exist.

    Parameters
    ----------
    path: str
        directory to create; a leading ``~`` is expanded to the user's home
    """
    # os.makedirs(..., exist_ok=True) already handles the pre-existing case,
    # so the previous isdir check was redundant (and race-prone)
    os.makedirs(os.path.expanduser(path), exist_ok=True)

124 

125 

def guess_url(web_dir, host, user):
    """Guess the base url from the host name

    Parameters
    ----------
    web_dir: str
        path to the web directory where you want the data to be saved
    host: str
        the host name of the machine where the python interpreter is currently
        executing
    user: str
        the user that is current executing the python interpreter
    """
    # known LIGO data-grid hosts, checked in order; first match wins
    known_hosts = [
        (("raven", "arcca"), "https://geo2.arcca.cf.ac.uk/~{}"),
        (("ligo-wa",), "https://ldas-jobs.ligo-wa.caltech.edu/~{}"),
        (("ligo-la",), "https://ldas-jobs.ligo-la.caltech.edu/~{}"),
        (("cit", "caltech"), "https://ldas-jobs.ligo.caltech.edu/~{}"),
        (("uwm", "nemo"), "https://ldas-jobs.phys.uwm.edu/~{}"),
        (("phy.syr.edu",), "https://sugar-jobs.phy.syr.edu/~{}"),
        (("vulcan",), "https://galahad.aei.mpg.de/~{}"),
        (("atlas",), "https://atlas1.atlas.aei.uni-hannover.de/~{}"),
        (("iucaa",), "https://ldas-jobs.gw.iucaa.in/~{}"),
        (("alice",), "https://dumpty.alice.icts.res.in/~{}"),
        (("hawk",), "https://ligo.gravity.cf.ac.uk/~{}"),
    ]
    if "public_html" not in web_dir:
        # not on a data grid: assume web_dir itself is the url path
        return "https://{}".format(web_dir)
    path = web_dir.split("public_html")[1]
    for substrings, template in known_hosts:
        if any(sub in host for sub in substrings):
            return template.format(user) + path
    # unrecognised host: fall back to a conventional ~user url
    return "https://{}/~{}".format(host, user) + path

172 

173 

def map_parameter_names(dictionary, mapping):
    """Modify keys in dictionary to use different names according to a map

    Parameters
    ----------
    dictionary: dict
        dict whose keys you wish to rename
    mapping: dict
        dictionary mapping existing keys to new names.

    Returns
    -------
    standard_dict: dict
        dict object with new parameter names
    """
    # keys absent from the mapping are kept unchanged
    return {
        mapping.get(name, name): value for name, value in dictionary.items()
    }

194 

195 

def command_line_arguments():
    """Return the command line arguments (``sys.argv`` minus the program name)
    """
    return list(sys.argv)[1:]

200 

201 

def command_line_dict():
    """Return a dictionary of command line arguments

    Parses ``sys.argv`` with pesummary's full argument parser and returns the
    resulting namespace as a plain dict.
    """
    from pesummary.gw.cli.parser import ArgumentParser
    parser = ArgumentParser()
    # register every option pesummary understands before parsing
    parser.add_all_known_options_to_parser()
    opts = parser.parse_args()
    return vars(opts)

210 

211 

def gw_results_file(opts):
    """Determine if a GW results file is passed
    """
    from pesummary.gw.cli.parser import ArgumentParser

    attrs, defaults = ArgumentParser().gw_options
    # a GW file is passed when any GW-specific option is set to a truthy,
    # non-default value
    for attr, default in zip(attrs, defaults):
        value = getattr(opts, attr, None)
        if value and value != default:
            return True
    return False

225 

226 

def functions(opts, gw=False):
    """Return a dictionary of functions that are either specific to GW results
    files or core.

    Parameters
    ----------
    opts: argparse.Namespace
        namespace of command line options, inspected by ``gw_results_file``
    gw: Bool, optional
        if True, always return the GW specific functions
    """
    from pesummary.core.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as Input
    )
    from pesummary.gw.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as GWInput
    )
    from pesummary.core.file.meta_file import MetaFile
    from pesummary.gw.file.meta_file import GWMetaFile
    from pesummary.core.finish import FinishingTouches
    from pesummary.gw.finish import GWFinishingTouches

    # evaluate the GW condition once rather than three times; gw_results_file
    # rebuilds an ArgumentParser on every call
    use_gw = gw_results_file(opts) or gw
    return {
        "input": GWInput if use_gw else Input,
        "MetaFile": GWMetaFile if use_gw else MetaFile,
        "FinishingTouches": GWFinishingTouches if use_gw else FinishingTouches,
    }

248 

249 

250def _logger_format(): 

251 return '%(asctime)s %(name)s %(levelname)-8s: %(message)s' 

252 

253 

def setup_logger():
    """Set up the logger output.

    Configures the 'PESummary' logger with two handlers: a console handler at
    INFO (or DEBUG when ``-v``/``--verbose`` appears on the command line) and
    a file handler at DEBUG writing to a fresh temporary directory created
    under ``LOG_CACHE``.

    Returns
    -------
    tuple
        the configured ``logging.Logger`` and the path to the log file
    """
    import tempfile

    def get_console_handler(stream_level="INFO"):
        # console output at the user-requested verbosity
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level=getattr(logging, stream_level))
        console_handler.setFormatter(FORMATTER)
        return console_handler

    def get_file_handler(log_file):
        # the file always captures everything at DEBUG
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setLevel(level=logging.DEBUG)
        file_handler.setFormatter(FORMATTER)
        return file_handler

    make_dir(LOG_CACHE)
    # a unique directory per run so concurrent runs do not clobber each
    # other's log files
    dirpath = tempfile.mkdtemp(dir=LOG_CACHE)
    stream_level = 'INFO'
    if "-v" in sys.argv or "--verbose" in sys.argv:
        stream_level = 'DEBUG'

    # FORMATTER is read by the handler factories above via closure
    FORMATTER = LogFormatter(_logger_format(), datefmt='%Y-%m-%d %H:%M:%S')
    LOG_FILE = '%s/pesummary.log' % (dirpath)
    logger = logging.getLogger('PESummary')
    # avoid duplicate messages from the root logger's handlers
    logger.propagate = False
    logger.setLevel(level=logging.DEBUG)
    logger.addHandler(get_console_handler(stream_level=stream_level))
    logger.addHandler(get_file_handler(LOG_FILE))
    return logger, LOG_FILE

285 

286 

def remove_tmp_directories():
    """Remove the temporary directories created by PESummary
    """
    import shutil
    from glob import glob

    # clear everything under the relative .tmp/pesummary working area
    for entry in glob(".tmp/pesummary/*"):
        if os.path.isdir(entry):
            shutil.rmtree(entry)
        elif os.path.isfile(entry):
            os.remove(entry)

300 

301 

def _add_existing_data(namespace):
    """Add existing data to namespace object

    Merges data previously loaded from an existing PESummary metafile
    (``namespace.existing_*`` attributes) into the current run's namespace:
    labels, samples, weights, injection data, file versions/kwargs, configs,
    priors, approximants, PSDs, calibration, skymaps and derived quantities.
    Only attributes the namespace already has are touched, and entries that
    are already present are never overwritten. Finally recomputes
    ``same_parameters``/``same_samples`` across all labels and returns the
    namespace.
    """
    for num, i in enumerate(namespace.existing_labels):
        # only append labels that are not already present
        if hasattr(namespace, "labels") and i not in namespace.labels:
            namespace.labels.append(i)
        if hasattr(namespace, "samples") and i not in list(namespace.samples.keys()):
            namespace.samples[i] = namespace.existing_samples[i]
        if hasattr(namespace, "weights") and i not in list(namespace.weights.keys()):
            # existing file may carry no weights at all
            if namespace.existing_weights is None:
                namespace.weights[i] = None
            else:
                namespace.weights[i] = namespace.existing_weights[i]
        if hasattr(namespace, "injection_data"):
            if i not in list(namespace.injection_data.keys()):
                namespace.injection_data[i] = namespace.existing_injection_data[i]
        if hasattr(namespace, "file_versions"):
            if i not in list(namespace.file_versions.keys()):
                namespace.file_versions[i] = namespace.existing_file_version[i]
        if hasattr(namespace, "file_kwargs"):
            if i not in list(namespace.file_kwargs.keys()):
                namespace.file_kwargs[i] = namespace.existing_file_kwargs[i]
        if hasattr(namespace, "config"):
            if namespace.existing_config[num] not in namespace.config:
                namespace.config.append(namespace.existing_config[num])
            # NOTE(review): when existing_config[num] is already in config AND
            # is None, a second None is appended — presumably to keep config
            # index-aligned with labels; confirm before changing
            elif namespace.existing_config[num] is None:
                namespace.config.append(None)
        if hasattr(namespace, "priors"):
            if hasattr(namespace, "existing_priors"):
                # merge per-category prior dicts, keeping current entries
                for key, item in namespace.existing_priors.items():
                    if key in namespace.priors.keys():
                        for label in item.keys():
                            if label not in namespace.priors[key].keys():
                                namespace.priors[key][label] = item[label]
                    else:
                        namespace.priors.update({key: item})
        if hasattr(namespace, "approximant") and namespace.approximant is not None:
            if i not in list(namespace.approximant.keys()):
                if i in list(namespace.existing_approximant.keys()):
                    namespace.approximant[i] = namespace.existing_approximant[i]
        if hasattr(namespace, "psds") and namespace.psds is not None:
            if i not in list(namespace.psds.keys()):
                if i in list(namespace.existing_psd.keys()):
                    namespace.psds[i] = namespace.existing_psd[i]
                else:
                    namespace.psds[i] = {}
        if hasattr(namespace, "calibration") and namespace.calibration is not None:
            if i not in list(namespace.calibration.keys()):
                if i in list(namespace.existing_calibration.keys()):
                    namespace.calibration[i] = namespace.existing_calibration[i]
                else:
                    namespace.calibration[i] = {}
        if hasattr(namespace, "skymap") and namespace.skymap is not None:
            if i not in list(namespace.skymap.keys()):
                if i in list(namespace.existing_skymap.keys()):
                    namespace.skymap[i] = namespace.existing_skymap[i]
                else:
                    namespace.skymap[i] = None
        if hasattr(namespace, "maxL_samples"):
            if i not in list(namespace.maxL_samples.keys()):
                # derive max-likelihood values from the (possibly just merged)
                # samples for this label
                namespace.maxL_samples[i] = {
                    key: val.maxL for key, val in namespace.samples[i].items()
                }
        if hasattr(namespace, "pepredicates_probs"):
            if i not in list(namespace.pepredicates_probs.keys()):
                from pesummary.gw.classification import PEPredicates
                try:
                    namespace.pepredicates_probs[i] = PEPredicates(
                        namespace.existing_samples[i]
                    ).dual_classification()
                except Exception:
                    # classification is best-effort; store None on any failure
                    namespace.pepredicates_probs[i] = None
        if hasattr(namespace, "pastro_probs"):
            if i not in list(namespace.pastro_probs.keys()):
                from pesummary.gw.classification import PAstro
                try:
                    namespace.pastro_probs[i] = PAstro(
                        namespace.existing_samples[i]
                    ).dual_classification()
                except Exception:
                    # best-effort, as above
                    namespace.pastro_probs[i] = None
    if hasattr(namespace, "result_files"):
        # pad result_files with the existing metafile path so its length
        # matches the number of labels
        number = len(namespace.labels)
        while len(namespace.result_files) < number:
            namespace.result_files.append(namespace.existing_metafile)
    # recompute the parameters common to every analysis
    parameters = [list(namespace.samples[i].keys()) for i in namespace.labels]
    namespace.same_parameters = list(
        set.intersection(*[set(l) for l in parameters])
    )
    namespace.same_samples = {
        param: {
            i: namespace.samples[i][param] for i in namespace.labels
        } for param in namespace.same_parameters
    }
    return namespace

397 

398 

def customwarn(message, category, filename, lineno, file=None, line=None):
    """Write a formatted warning to stdout; drop-in replacement for
    ``warnings.showwarning`` (same signature).
    """
    import sys
    import warnings

    formatted = warnings.formatwarning(
        "%s" % (message), category, filename, lineno
    )
    sys.stdout.write(formatted)

408 

409 

def determine_gps_time_and_window(maxL_samples, labels):
    """Determine the gps time and window to use in the spectrogram and
    omegascan plots
    """
    times = [maxL_samples[label]["geocent_time"] for label in labels]
    gps_time = np.mean(times)
    spread = np.max(times) - np.min(times)
    # never use a window narrower than 4 seconds
    window = 4. if spread < 4. else spread * 1.5
    return gps_time, window

424 

425 

def number_of_columns_for_legend(labels):
    """Determine the number of columns to use in a legend

    Parameters
    ----------
    labels: list
        list of labels in the legend
    """
    # pad the longest label by 5 characters, then fit as many columns as
    # possible into a 50-character budget
    longest = np.max([len(label) for label in labels]) + 5.
    return 1 if longest > 50. else int(50. / longest)

439 

440 

class RedirectLogger(object):
    """Class to redirect the output from other codes to the `pesummary`
    logger

    Parameters
    ----------
    code: str
        name of the code whose output is being redirected; included in every
        log message
    level: str, optional
        the level to display the messages, e.g. 'DEBUG', 'INFO'
    """
    def __init__(self, code, level="Debug"):
        self.logger = logging.getLogger('PESummary')
        # normalise the level name: the previous getattr(logging, level)
        # raised AttributeError for the default 'Debug' because the logging
        # level constants are upper-case ('DEBUG', 'INFO', ...)
        self.level = getattr(logging, level.upper())
        self._redirector = contextlib.redirect_stdout(self)
        self.code = code

    def isatty(self):
        """File-object protocol stub; this stream is never a terminal."""
        pass

    def write(self, msg):
        """Write the message to stdout

        Parameters
        ----------
        msg: str
            the message you wish to be printed to stdout
        """
        # skip the bare newlines print() emits as separate writes
        if msg and not msg.isspace():
            self.logger.log(self.level, "[from %s] %s" % (self.code, msg))

    def flush(self):
        """File-object protocol stub; nothing is buffered here."""
        pass

    def __enter__(self):
        self._redirector.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._redirector.__exit__(exc_type, exc_value, traceback)

479 

480 

def draw_conditioned_prior_samples(
    samples_dict, prior_samples_dict, conditioned, xlow, xhigh, N=100,
    nsamples=1000
):
    """Return a prior_dict that is conditioned on certain parameters

    Parameters
    ----------
    samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the posterior samples
    prior_samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the prior samples
    conditioned: list
        list of parameters that you wish to condition your prior on
    xlow: dict
        dictionary of lower bounds for each parameter
    xhigh: dict
        dictionary of upper bounds for each parameter
    N: int, optional
        number of points to use within the grid. Default 100
    nsamples: int, optional
        number of samples to draw. Default 1000
    """
    for param in conditioned:
        # indices of prior samples accepted by rejection sampling against
        # the posterior of this parameter
        kept = _draw_conditioned_prior_samples(
            prior_samples_dict[param], samples_dict[param], xlow[param],
            xhigh[param], xN=N, N=nsamples
        )
        # filter every parameter with the same indices to keep rows aligned
        for name, values in prior_samples_dict.items():
            prior_samples_dict[name] = values[kept]
    return prior_samples_dict

513 

514 

def _draw_conditioned_prior_samples(
    prior_samples, posterior_samples, xlow, xhigh, xN=1000, N=1000
):
    """Return a list of indices for the conditioned prior via rejection
    sampling. The conditioned prior will then be `prior_samples[indicies]`.
    Code from Michael Puerrer.

    Parameters
    ----------
    prior_samples: np.ndarray
        array of prior samples that you wish to condition
    posterior_samples: np.ndarray
        array of posterior samples that you wish to condition on
    xlow: float
        lower bound for grid to be used
    xhigh: float
        upper bound for grid to be used
    xN: int, optional
        Number of points to use within the grid
    N: int, optional
        Number of samples to generate
    """
    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    prior_KDE = ReflectionBoundedKDE(prior_samples)
    posterior_KDE = ReflectionBoundedKDE(posterior_samples)

    # envelope constant M for rejection sampling, estimated on a grid over
    # [xlow, xhigh] where the posterior KDE is non-zero, with a 10% margin
    x = np.linspace(xlow, xhigh, xN)
    idx_nz = np.nonzero(posterior_KDE(x))
    pdf_ratio = prior_KDE(x)[idx_nz] / posterior_KDE(x)[idx_nz]
    M = 1.1 / min(pdf_ratio[np.where(pdf_ratio < 1)])

    indicies = []
    i = 0
    while i < N:
        # propose a prior sample and accept it with probability
        # posterior / (M * prior); record the index of the accepted sample
        x_i = np.random.choice(prior_samples)
        idx_i = np.argmin(np.abs(prior_samples - x_i))
        u = np.random.uniform()
        if u < posterior_KDE(x_i) / (M * prior_KDE(x_i)):
            indicies.append(idx_i)
            i += 1
    return indicies

557 

558 

def unzip(zip_file, outdir=None, overwrite=False):
    """Extract the data from a gzipped file and save in outdir.

    Parameters
    ----------
    zip_file: str
        path to the file you wish to unzip
    outdir: str, optional
        path to the directory where you wish to save the unzipped file. Default
        None which means that the unzipped file is stored in CACHE_DIR
    overwrite: Bool, optional
        If True, overwrite a file that has the same name

    Returns
    -------
    str
        path to the extracted file
    """
    import gzip
    import shutil
    from pathlib import Path

    file_name = Path(zip_file).stem
    if outdir is None:
        outdir = CACHE_DIR
    out_file = os.path.join(outdir, file_name)
    if os.path.isfile(out_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists. Not overwriting".format(out_file)
        )
    # single with-statement; also stop shadowing the builtin 'input'
    with gzip.open(zip_file, 'rb') as source, open(out_file, 'wb') as target:
        shutil.copyfileobj(source, target)
    return out_file

589 

590 

def iterator(
    iterable, desc=None, logger=None, tqdm=False, total=None, file=None,
    bar_format=None
):
    """Return either a tqdm iterator, if tqdm installed, or the iterable itself

    Parameters
    ----------
    iterable: func
        iterable that you wish to iterate over
    desc: str, optional
        description for the tqdm bar
    logger: logging.Logger, optional
        logger the tqdm progress bar interacts with
    tqdm: Bool, optional
        If True, a tqdm object is used. Otherwise simply returns the iterator.
    total: float, optional
        total length of iterable
    file: str, optional
        path to file that you wish to write the output to
    bar_format: str, optional
        format for the tqdm progress bar. Default pesummary's standard format
    """
    if not tqdm:
        return iterable
    # NOTE: the previous implementation imported pesummary's tqdm wrapper
    # unconditionally, which shadowed the 'tqdm' boolean argument (making the
    # flag ineffective) — import lazily and under a different name
    try:
        from pesummary.utils.tqdm import tqdm as _tqdm
    except ImportError:
        return iterable
    # previously a caller-supplied bar_format was silently ignored and None
    # was always passed through; honour it now
    FORMAT = bar_format
    if FORMAT is None:
        FORMAT = (
            '{desc} | {percentage:3.0f}% | {n_fmt}/{total_fmt} | {elapsed}'
        )
    return _tqdm(
        iterable, total=total, logger=logger, desc=desc, file=file,
        bar_format=FORMAT,
    )

632 

633 

def _check_latex_install(force_tex=False):
    """Set matplotlib's ``text.usetex`` rcParam according to whether a
    working latex installation is available.

    Parameters
    ----------
    force_tex: Bool, optional
        if True, enable latex rendering whenever latex compiles successfully
    """
    from matplotlib import rcParams
    # distutils was removed in Python 3.12; shutil.which is the stdlib
    # replacement for distutils.spawn.find_executable
    from shutil import which

    original = rcParams['text.usetex']
    if which("latex") is not None:
        try:
            from matplotlib.texmanager import TexManager

            # compile a trivial snippet to confirm latex actually works
            texmanager = TexManager()
            texmanager.make_dvi(r"$mass_{1}$", 12)
            if force_tex:
                original = True
            rcParams["text.usetex"] = original
        except RuntimeError:
            rcParams["text.usetex"] = False
    else:
        rcParams["text.usetex"] = False

652 

653 

def smart_round(parameters, return_latex=False, return_latex_row=False):
    """Round a parameter according to the uncertainty. If more than one parameter
    and uncertainty is passed, each parameter is rounded according to the
    lowest uncertainty

    Parameters
    ----------
    parameters: list/np.ndarray
        list containing the median, upper bound and lower bound for a given
        parameter (or a 2d list of such triples)
    return_latex: Bool, optional
        if True, return as a latex string
    return_latex_row: Bool, optional
        if True, return the rounded data as a single row in latex format

    Examples
    --------
    >>> data = [1.234, 0.2, 0.1]
    >>> smart_round(data)
    [ 1.2  0.2  0.1]
    >>> data = [
    ...     [6.093, 0.059, 0.055],
    ...     [6.104, 0.057, 0.052],
    ...     [6.08, 0.056, 0.052]
    ... ]
    >>> smart_round(data)
    [[ 6.09  0.06  0.06]
     [ 6.1   0.06  0.05]
     [ 6.08  0.06  0.05]]
    >>> smart_round(data, return_latex=True)
    6.09^{+0.06}_{-0.06}
    6.10^{+0.06}_{-0.05}
    6.08^{+0.06}_{-0.05}
    """
    rounded = copy.deepcopy(np.atleast_2d(parameters))
    # number of decimal places implied by the smallest magnitude entry
    smallest = np.min(np.abs(parameters))
    ndigits = int(-1 * np.floor(np.log10(smallest)))
    for row in range(len(rounded)):
        rounded[row] = [np.round(entry, ndigits) for entry in rounded[row]]
    if return_latex or return_latex_row:
        fmt = "%.{}f".format(ndigits) if ndigits > 0 else "%.f"
        template = "{0}^{{+{0}}}_{{-{0}}}".format(fmt)
        entries = [template % (v[0], v[1], v[2]) for v in rounded]
        if return_latex:
            for entry in entries:
                print(entry)
        else:
            print(" & ".join(entries))
        return ""
    if np.array(parameters).ndim == 1:
        return rounded[0]
    return rounded

723 

724 

def safe_round(a, decimals=0, **kwargs):
    """Try and round an array to the given number of decimals. If an exception
    is raised, return the original array

    Parameters
    ----------
    a: np.ndarray
        array you wish to round
    decimals: int
        the number of decimals you wish to round too
    **kwargs: dict
        all kwargs are passed to numpy.round
    """
    try:
        result = np.round(a, decimals=decimals, **kwargs)
    except Exception:
        # unroundable input (e.g. strings): hand it back untouched
        result = a
    return result

742 

743 

def gelman_rubin(samples, decimal=5):
    """Return an approximation to the Gelman-Rubin statistic (see Gelman, A. and
    Rubin, D. B., Statistical Science, Vol 7, No. 4, pp. 457--511 (1992))

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples for a given parameter, one for each chain
    decimal: int
        number of decimal places to keep when rounding

    Examples
    --------
    >>> from pesummary.utils.utils import gelman_rubin
    >>> samples = [[1, 1.5, 1.2, 1.4, 1.6, 1.2], [1.5, 1.3, 1.4, 1.7]]
    >>> gelman_rubin(samples, decimal=5)
    1.2972
    """
    chain_means = [np.mean(chain) for chain in samples]
    chain_variances = [np.var(chain) for chain in samples]
    # between-chain and within-chain variance components
    between = np.var(chain_means)
    within = np.mean(chain_variances)
    pooled = within + between
    nchains = len(samples)
    vhat = pooled + between / nchains
    return np.round(vhat / within, decimal)

770 

771 

def kolmogorov_smirnov_test(samples, decimal=5):
    """Return the KS p value between two PDFs

    Parameters
    ----------
    samples: 2d list
        2d list containing the 2 PDFs that you wish to compare
    decimal: int
        number of decimal places to keep when rounding
    """
    first, second = samples
    pvalue = stats.ks_2samp(first, second)[1]
    return np.round(pvalue, decimal)

783 

784 

def jensen_shannon_divergence(*args, **kwargs):
    """Deprecated alias for ``jensen_shannon_divergence_from_samples``."""
    import warnings
    _message = (
        "The jensen_shannon_divergence function has changed its name to "
        "jensen_shannon_divergence_from_samples. jensen_shannon_divergence "
        "may not be supported in future releases. Please update"
    )
    warnings.warn(_message)
    return jensen_shannon_divergence_from_samples(*args, **kwargs)

793 

794 

def jensen_shannon_divergence_from_samples(
    samples, kde=stats.gaussian_kde, decimal=5, base=np.e, **kwargs
):
    """Calculate the JS divergence between two sets of samples

    Parameters
    ----------
    samples: list
        2d list containing the samples drawn from two pdfs
    kde: func
        function to use when calculating the kde of the samples
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
    kwargs: dict
        all kwargs are passed to the kde function
    """
    # evaluate both KDEs on a common grid, then compare the resulting pdfs
    pdfs = samples_to_kde(samples, kde=kde, **kwargs)
    return jensen_shannon_divergence_from_pdfs(pdfs, decimal=decimal, base=base)

816 

817 

def jensen_shannon_divergence_from_pdfs(pdfs, decimal=5, base=np.e):
    """Calculate the JS divergence between two distributions

    Parameters
    ----------
    pdfs: list
        list of length 2 containing the distributions you wish to compare
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
    """
    # a NaN pdf means a KDE could not be constructed; propagate NaN
    if any(np.isnan(_).any() for _ in pdfs):
        return float("nan")
    a, b = pdfs
    # copy into float arrays: the previous in-place '/=' on np.asarray output
    # raised for integer input and silently modified the caller's arrays
    a = np.array(a, dtype=float)
    b = np.array(b, dtype=float)
    a /= a.sum()
    b /= b.sum()
    m = 1. / 2 * (a + b)
    kl_forward = stats.entropy(a, qk=m, base=base)
    kl_backward = stats.entropy(b, qk=m, base=base)
    return np.round(kl_forward / 2. + kl_backward / 2., decimal)

842 

843 

def samples_to_kde(samples, kde=stats.gaussian_kde, **kwargs):
    """Generate KDE for a set of samples

    Parameters
    ----------
    samples: list
        list containing the samples to create a KDE for. samples can also
        be a 2d list containing samples from multiple analyses.
    kde: func
        function to use when calculating the kde of the samples
    """
    # a flat list of numbers is a single analysis; wrap it so the rest of
    # the function always works on a list of analyses
    single = not isinstance(samples[0], (np.ndarray, list, tuple))
    analyses = [samples] if single else samples
    kernels = []
    for data in analyses:
        try:
            kernels.append(kde(data, **kwargs))
        except np.linalg.LinAlgError:
            # degenerate samples (e.g. all identical); mark with None
            kernels.append(None)
    # evaluate every kernel on one common 100-point grid
    grid = np.linspace(
        np.min([np.min(data) for data in analyses]),
        np.max([np.max(data) for data in analyses]),
        100
    )
    pdfs = [
        kernel(grid) if kernel is not None else float('nan')
        for kernel in kernels
    ]
    return pdfs[0] if single else pdfs

876 

877 

def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    make_dir(STYLE_CACHE)
    cached_copy = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    shutil.copyfile(style_file, cached_copy)

891 

892 

def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    style_file = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    if os.path.isfile(style_file):
        return os.path.join(style_file)
    # no cached style file: fall back to pesummary's default
    from pesummary import conf

    return conf.style_file

902 

903 

def get_matplotlib_backend(parallel=False):
    """Return the matplotlib backend to use for the plotting modules

    Parameters
    ----------
    parallel: Bool, optional
        if True, backend is always set to 'Agg' for the multiprocessing module
    """
    # a display is available when DISPLAY is set or we are inside IPython
    if "DISPLAY" in os.environ:
        display = True
    else:
        try:
            __IPYTHON__
            display = True
        except NameError:
            display = False
    return "TKAgg" if display and not parallel else "Agg"

928 

929 

930def _default_filename(default_filename, label=None): 

931 """Return a default filename 

932 

933 Parameters 

934 ---------- 

935 default_filename: str, optional 

936 the default filename to use if a filename is not provided. default_filename 

937 must be a formattable string with one empty argument for a label 

938 label: str, optional 

939 The label of the analysis. This is used in the filename 

940 """ 

941 if not label: 

942 filename = default_filename.format(round(time.time())) 

943 else: 

944 filename = default_filename.format(label) 

945 return filename 

946 

947 

def check_filename(
    default_filename="pesummary_{}.dat", outdir="./", label=None, filename=None,
    overwrite=False, delete_existing=False
):
    """Check to see if a file exists. If no filename is provided, a default
    filename is checked

    Parameters
    ----------
    default_filename: str, optional
        the default filename to use if a filename is not provided. default_filename
        must be a formattable string with one empty argument for a label
    outdir: str, optional
        directory to write the dat file
    label: str, optional
        The label of the analysis. This is used in the filename if a filename
        if not specified
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    delete_existing: Bool, optional
        If True (with overwrite), an existing file of the same name is removed
    """
    name = filename if filename else _default_filename(
        default_filename, label=label
    )
    full_path = os.path.join(outdir, name)
    if os.path.isfile(full_path):
        if not overwrite:
            raise FileExistsError(
                "The file '{}' already exists in the directory {}".format(
                    name, outdir
                )
            )
        if delete_existing:
            os.remove(full_path)
    return full_path

982 

983 

def string_match(string, substring):
    """Return True if a string matches a substring. This substring may include
    wildcards

    Parameters
    ----------
    string: str
        string you wish to match
    substring: str
        string you wish to match against

    Returns
    -------
    bool: True if substring matches at the start of string, False otherwise
    """
    import re

    try:
        # substring is first tried as a regular expression, anchored at the
        # start of the string (re.match semantics)
        return re.match(substring, string) is not None
    except re.error:
        # not a valid regular expression -- treat it as a shell-style
        # wildcard pattern and retry with its regex translation.
        # NOTE: re.error is the public name; the previously used
        # sre_constants module is private and deprecated since Python 3.12
        import fnmatch
        return string_match(string, fnmatch.translate(substring))

1006 

1007 

def glob_directory(base):
    """Return a list of files matching base

    Parameters
    ----------
    base: str
        string you wish to match e.g. "./", "./*.py"
    """
    import glob
    # a bare directory is expanded to match everything inside it
    pattern = base if "*" in base else os.path.join(base, "*")
    return glob.glob(pattern)

1020 

1021 

def list_match(list_to_match, substring, return_true=True, return_false=False):
    """Match a list of strings to a substring. This substring may include
    wildcards

    Parameters
    ----------
    list_to_match: list
        list of string you wish to match
    substring: str, list
        string you wish to match against or a list of string you wish to match
        against
    return_true: Bool, optional
        if True, return a sublist containing only the parameters that match the
        substring. Default True
    return_false: Bool, optional
        if True, return a sublist containing only the parameters that do NOT
        match the substring. Takes precedence over return_true. Default False

    Returns
    -------
    np.ndarray: the matching (or non-matching) entries, or the boolean match
    mask itself when both return_true and return_false are False
    """
    match = np.ones(len(list_to_match), dtype=bool)
    if isinstance(substring, str):
        substring = [substring]

    # an entry must match every substring to remain selected; &= expresses
    # the logical AND previously spelled as an in-place multiply
    for _substring in substring:
        match &= np.array(
            [string_match(item, _substring) for item in list_to_match],
            dtype=bool
        )
    if return_false:
        return np.array(list_to_match)[~match]
    elif return_true:
        return np.array(list_to_match)[match]
    return match

1051 

1052 

class Empty(object):
    """Define an empty class which simply returns the input

    ``Empty(x)`` evaluates to ``x`` itself rather than an ``Empty`` instance;
    at least one positional argument must therefore be supplied.
    """
    def __new__(cls, *args):
        # __new__ receives the class, conventionally named cls (it was
        # previously misnamed self). Returning a non-Empty object means
        # __init__ is never invoked.
        return args[0]

1058 

1059 

def history_dictionary(program=None, creator=conf.user, command_line=None):
    """Create a dictionary containing useful information about the origin of
    a PESummary data product

    Parameters
    ----------
    program: str, optional
        program used to generate the PESummary data product
    creator: str, optional
        The user who created the PESummary data product
    command_line: str, optional
        The command line which was run to generate the PESummary data product
    """
    from astropy.time import Time

    # record when and by whom this data product was made
    history = {
        "gps_creation_time": Time.now().gps,
        "creator": creator,
    }
    if command_line is None:
        # fall back to the command line of the running process
        history["command_line"] = " ".join(sys.argv)
    else:
        history["command_line"] = (
            "Generated by running the following script: {}".format(
                command_line
            )
        )
    if program is not None:
        history["program"] = program
    return history

1090 

1091 

def mute_logger():
    """Mute the PESummary logger

    The level is raised above CRITICAL so that no record, however severe,
    is emitted.
    """
    logging.getLogger('PESummary').setLevel(logging.CRITICAL + 10)

1098 

1099 

def unmute_logger():
    """Unmute the PESummary logger

    The level is restored to INFO, the default verbosity.
    """
    logging.getLogger('PESummary').setLevel(logging.INFO)

1106 

1107 

# Configure the package-wide logger as a module import side effect; LOG_FILE
# records the path returned by setup_logger (defined earlier in this module)
_, LOG_FILE = setup_logger()
# Module-level handle to the 'PESummary' logger configured above
logger = logging.getLogger('PESummary')