Coverage for pesummary/utils/utils.py: 66.9%

501 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import sys 

5import logging 

6import contextlib 

7import time 

8import copy 

9import shutil 

10 

11import numpy as np 

12from scipy.integrate import cumulative_trapezoid as cumtrapz 

13from scipy.interpolate import interp1d 

14from scipy import stats 

15import h5py 

16from pesummary import conf 

17 

18__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

19 

# Prefer coloredlogs' formatter for colourised console output; fall back to
# the stdlib formatter when coloredlogs is not installed.
try:
    from coloredlogs import ColoredFormatter as LogFormatter
except ImportError:
    LogFormatter = logging.Formatter

# Per-user cache directory following the XDG base-directory convention,
# defaulting to ~/.cache/pesummary when XDG_CACHE_HOME is unset.
CACHE_DIR = os.path.join(
    os.getenv(
        "XDG_CACHE_HOME",
        os.path.expanduser(os.path.join("~", ".cache")),
    ),
    "pesummary",
)
# Sub-directories of the cache used for matplotlib style files and log files
STYLE_CACHE = os.path.join(CACHE_DIR, "style")
LOG_CACHE = os.path.join(CACHE_DIR, "log")

# Full path to the latex executable, or None when latex is not on the PATH
LATEX = shutil.which("latex")

36 

37 

def resample_posterior_distribution(posterior, nsamples):
    """Randomly draw nsamples from the posterior distribution

    Parameters
    ----------
    posterior: ndlist
        nd list of posterior samples. If you only want to resample one
        posterior distribution then posterior=[[1., 2., 3., 4.]]. For multiple
        posterior distributions then posterior=[[1., 2., 3., 4.], [1., 2., 3.]]
    nsamples: int
        number of samples that you wish to randomly draw from the distribution
    """
    if len(posterior) == 1:
        # single distribution: draw via inverse-CDF sampling from a
        # 50-bin histogram of the samples
        counts, edges = np.histogram(posterior, bins=50)
        counts = np.array([0] + list(counts))
        cdf = cumtrapz(counts, edges, initial=0)
        cdf /= cdf[-1]
        inverse_cdf = interp1d(cdf, edges)
        return inverse_cdf(np.random.rand(nsamples))
    # multiple distributions: thin every chain with one common set of
    # randomly chosen indices (without replacement)
    stacked = np.array(list(posterior))
    chosen = np.random.choice(len(stacked[0]), nsamples, replace=False)
    return [chain[chosen] for chain in stacked]

63 

64 

def check_file_exists_and_rename(file_name):
    """Check to see if a file exists and if it does then rename the file

    Parameters
    ----------
    file_name: str
        proposed file name to store data
    """
    if os.path.isfile(file_name):
        # find an unused backup name by appending '_old' until it is unique
        old_file = "{}_old".format(file_name)
        while os.path.isfile(old_file):
            old_file += "_old"
        logger.warning(
            "The file '{}' already exists. Renaming the existing file to "
            "{} and saving the data to the requested file name".format(
                file_name, old_file
            )
        )
        # shutil is already imported at module level; the previous local
        # 'import shutil' was redundant
        shutil.move(file_name, old_file)

86 

87 

def check_condition(condition, error_message):
    """Raise an exception if the condition is not satisfied

    Parameters
    ----------
    condition: bool
        when truthy, an Exception carrying error_message is raised
    error_message: str
        message attached to the raised Exception
    """
    if not condition:
        return
    raise Exception(error_message)

93 

94 

def rename_group_or_dataset_in_hf5_file(base_file, group=None, dataset=None):
    """Rename a group or dataset in an hdf5 file

    Parameters
    ----------
    base_file: str
        path to the hdf5 file you wish to modify
    group: list, optional
        a list containing the path to the group that you would like to change
        as the first argument and the new name of the group as the second
        argument
    dataset: list, optional
        a list containing the name of the dataset that you would like to change
        as the first argument and the new name of the dataset as the second
        argument

    Raises
    ------
    Exception
        if base_file does not exist
    """
    condition = not os.path.isfile(base_file)
    check_condition(condition, "The file %s does not exist" % (base_file))
    # context manager guarantees the file is closed even if the rename
    # raises (the previous explicit f.close() leaked the handle on error)
    with h5py.File(base_file, "a") as f:
        if group:
            # h5py renames by re-linking under the new name then deleting
            # the old link
            f[group[1]] = f[group[0]]
            del f[group[0]]
        elif dataset:
            f[dataset[1]] = f[dataset[0]]
            del f[dataset[0]]

119 

120 

def make_dir(path):
    """Create a directory (and any missing parents) if it does not exist

    Parameters
    ----------
    path: str
        path to the directory you wish to create; a leading '~' is expanded
    """
    # exist_ok=True makes the previous isdir() pre-check redundant and
    # removes the race between the check and the creation
    os.makedirs(os.path.expanduser(path), exist_ok=True)

126 

127 

def guess_url(web_dir, host, user):
    """Guess the base url from the host name

    Parameters
    ----------
    web_dir: str
        path to the web directory where you want the data to be saved
    host: str
        the host name of the machine where the python interpreter is currently
        executing
    user: str
        the user that is current executing the python interpreter
    """
    # ordered (host fragments, url template) table; first match wins, so
    # more specific fragments (e.g. 'ligo-wa') must precede generic ones
    # (e.g. 'caltech')
    known_hosts = [
        (("raven", "arcca"), "https://geo2.arcca.cf.ac.uk/~{}"),
        (("ligo-wa",), "https://ldas-jobs.ligo-wa.caltech.edu/~{}"),
        (("ligo-la",), "https://ldas-jobs.ligo-la.caltech.edu/~{}"),
        (("cit", "caltech"), "https://ldas-jobs.ligo.caltech.edu/~{}"),
        (("uwm", "nemo"), "https://ldas-jobs.phys.uwm.edu/~{}"),
        (("phy.syr.edu",), "https://sugar-jobs.phy.syr.edu/~{}"),
        (("vulcan",), "https://galahad.aei.mpg.de/~{}"),
        (("atlas",), "https://atlas1.atlas.aei.uni-hannover.de/~{}"),
        (("iucaa",), "https://ldas-jobs.gw.iucaa.in/~{}"),
        (("alice",), "https://dumpty.alice.icts.res.in/~{}"),
        (("hawk",), "https://ligo.gravity.cf.ac.uk/~{}"),
    ]
    if 'public_html' not in web_dir:
        # not served from a user's public_html directory
        return "https://{}".format(web_dir)
    path = web_dir.split("public_html")[1]
    for fragments, template in known_hosts:
        if any(fragment in host for fragment in fragments):
            return template.format(user) + path
    # unknown LIGO Data Grid host: build the url from the host name itself
    return "https://{}/~{}".format(host, user) + path

174 

175 

def map_parameter_names(dictionary, mapping):
    """Modify keys in dictionary to use different names according to a map

    Parameters
    ----------
    dictionary: dict
        dictionary whose keys you wish to rename
    mapping: dict
        dictionary mapping existing keys to new names.

    Returns
    -------
    standard_dict: dict
        dict object with new parameter names
    """
    # keys absent from the mapping are kept unchanged
    return {mapping.get(key, key): item for key, item in dictionary.items()}

196 

197 

def command_line_arguments():
    """Return the arguments passed on the command line, excluding the
    program name
    """
    _, *arguments = sys.argv
    return arguments

202 

203 

def command_line_dict():
    """Return a dictionary of command line arguments
    """
    from pesummary.gw.cli.parser import ArgumentParser

    # register every option pesummary knows about before parsing
    _parser = ArgumentParser()
    _parser.add_all_known_options_to_parser()
    return vars(_parser.parse_args())

212 

213 

def gw_results_file(opts):
    """Determine if a GW results file is passed
    """
    from pesummary.gw.cli.parser import ArgumentParser

    attrs, defaults = ArgumentParser().gw_options
    # a GW option counts as passed when it is present on opts, truthy, and
    # differs from its default value
    for attr, default in zip(attrs, defaults):
        value = getattr(opts, attr, None)
        if value and value != default:
            return True
    return False

227 

228 

def functions(opts, gw=False):
    """Return a dictionary of functions that are either specific to GW results
    files or core.
    """
    from pesummary.core.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as Input
    )
    from pesummary.gw.cli.inputs import (
        WebpagePlusPlottingPlusMetaFileInput as GWInput
    )
    from pesummary.core.file.meta_file import MetaFile
    from pesummary.gw.file.meta_file import GWMetaFile
    from pesummary.core.finish import FinishingTouches
    from pesummary.gw.finish import GWFinishingTouches

    # choose the GW variants when a GW results file was passed or gw=True
    use_gw = gw_results_file(opts) or gw
    return {
        "input": GWInput if use_gw else Input,
        "MetaFile": GWMetaFile if use_gw else MetaFile,
        "FinishingTouches": GWFinishingTouches if use_gw else FinishingTouches,
    }

250 

251 

252def _logger_format(): 

253 return '%(asctime)s %(name)s %(levelname)-8s: %(message)s' 

254 

255 

def setup_logger():
    """Set up the 'PESummary' logger and return it together with the path
    to the log file it writes to
    """
    import tempfile

    make_dir(LOG_CACHE)
    # each run logs to its own fresh temporary directory inside the cache
    run_dir = tempfile.mkdtemp(dir=LOG_CACHE)
    log_file = '%s/pesummary.log' % (run_dir)
    # '-v'/'--verbose' on the command line promotes console output to DEBUG
    console_level = 'DEBUG' if (
        "-v" in sys.argv or "--verbose" in sys.argv
    ) else 'INFO'

    formatter = LogFormatter(_logger_format(), datefmt='%Y-%m-%d %H:%M:%S')
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level=getattr(logging, console_level))
    stream_handler.setFormatter(formatter)
    # the log file always captures everything, whatever the console level
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setLevel(level=logging.DEBUG)
    file_handler.setFormatter(formatter)

    logger = logging.getLogger('PESummary')
    logger.propagate = False
    logger.setLevel(level=logging.DEBUG)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger, log_file

287 

288 

def remove_tmp_directories():
    """Remove the temporary directories created by PESummary
    """
    import shutil
    from glob import glob

    # entries may be files or directories; handle each accordingly
    for entry in glob(".tmp/pesummary/*"):
        if os.path.isdir(entry):
            shutil.rmtree(entry)
        elif os.path.isfile(entry):
            os.remove(entry)

302 

303 

def _add_existing_data(namespace):
    """Merge data from an existing pesummary run into ``namespace``.

    For every label in ``namespace.existing_labels``, copy the corresponding
    existing samples/metadata into the matching attribute of ``namespace``
    unless an entry for that label is already present. Only attributes that
    actually exist on ``namespace`` are touched. Finally the
    ``same_parameters``/``same_samples`` summaries are recomputed across all
    labels and the updated namespace is returned.

    Parameters
    ----------
    namespace: object
        namespace-like object holding the current run's data alongside
        ``existing_*`` attributes loaded from a previous run
    """
    for num, i in enumerate(namespace.existing_labels):
        # register the existing label if the current run does not have it
        if hasattr(namespace, "labels") and i not in namespace.labels:
            namespace.labels.append(i)
        # copy over posterior samples for labels not in the current run
        if hasattr(namespace, "samples") and i not in list(namespace.samples.keys()):
            namespace.samples[i] = namespace.existing_samples[i]
        if hasattr(namespace, "weights") and i not in list(namespace.weights.keys()):
            # a previous run without weights stores None for every label
            if namespace.existing_weights is None:
                namespace.weights[i] = None
            else:
                namespace.weights[i] = namespace.existing_weights[i]
        if hasattr(namespace, "injection_data"):
            if i not in list(namespace.injection_data.keys()):
                namespace.injection_data[i] = namespace.existing_injection_data[i]
        if hasattr(namespace, "file_versions"):
            if i not in list(namespace.file_versions.keys()):
                namespace.file_versions[i] = namespace.existing_file_version[i]
        if hasattr(namespace, "file_kwargs"):
            if i not in list(namespace.file_kwargs.keys()):
                namespace.file_kwargs[i] = namespace.existing_file_kwargs[i]
        if hasattr(namespace, "config"):
            # NOTE(review): when existing_config[num] is None and None is
            # already present in config, a second None is still appended —
            # this looks intentional (one placeholder per label) but is
            # worth confirming against the callers
            if namespace.existing_config[num] not in namespace.config:
                namespace.config.append(namespace.existing_config[num])
            elif namespace.existing_config[num] is None:
                namespace.config.append(None)
        if hasattr(namespace, "priors"):
            # merge prior dictionaries label-by-label, never overwriting a
            # prior already present for a given label
            if hasattr(namespace, "existing_priors"):
                for key, item in namespace.existing_priors.items():
                    if key in namespace.priors.keys():
                        for label in item.keys():
                            if label not in namespace.priors[key].keys():
                                namespace.priors[key][label] = item[label]
                    else:
                        namespace.priors.update({key: item})
        if hasattr(namespace, "approximant") and namespace.approximant is not None:
            if i not in list(namespace.approximant.keys()):
                if i in list(namespace.existing_approximant.keys()):
                    namespace.approximant[i] = namespace.existing_approximant[i]
        if hasattr(namespace, "psds") and namespace.psds is not None:
            # labels with no stored PSD fall back to an empty dict
            if i not in list(namespace.psds.keys()):
                if i in list(namespace.existing_psd.keys()):
                    namespace.psds[i] = namespace.existing_psd[i]
                else:
                    namespace.psds[i] = {}
        if hasattr(namespace, "calibration") and namespace.calibration is not None:
            if i not in list(namespace.calibration.keys()):
                if i in list(namespace.existing_calibration.keys()):
                    namespace.calibration[i] = namespace.existing_calibration[i]
                else:
                    namespace.calibration[i] = {}
        if hasattr(namespace, "skymap") and namespace.skymap is not None:
            if i not in list(namespace.skymap.keys()):
                if i in list(namespace.existing_skymap.keys()):
                    namespace.skymap[i] = namespace.existing_skymap[i]
                else:
                    namespace.skymap[i] = None
        if hasattr(namespace, "maxL_samples"):
            # assumes each entry of samples[i] exposes a `.maxL` attribute
            # (pesummary sample-array type) — TODO confirm
            if i not in list(namespace.maxL_samples.keys()):
                namespace.maxL_samples[i] = {
                    key: val.maxL for key, val in namespace.samples[i].items()
                }
        if hasattr(namespace, "pastro_probs"):
            if i not in list(namespace.pastro_probs.keys()):
                from pesummary.gw.classification import PAstro
                # best-effort: fall back to the class defaults when the
                # classification cannot be computed from the samples
                try:
                    namespace.pastro_probs[i] = {"default": PAstro(
                        namespace.existing_samples[i],
                    ).classification()}
                except Exception:
                    namespace.pastro_probs[i] = {"default": PAstro.defaults}
        if hasattr(namespace, "embright_probs"):
            if i not in list(namespace.embright_probs.keys()):
                from pesummary.gw.classification import EMBright
                try:
                    namespace.embright_probs[i] = {"default": EMBright(
                        namespace.existing_samples[i]
                    ).classification()}
                except Exception:
                    namespace.embright_probs[i] = {"default": EMBright.defaults}
    if hasattr(namespace, "result_files"):
        # pad the result-file list with the existing metafile until there is
        # one entry per label
        number = len(namespace.labels)
        while len(namespace.result_files) < number:
            namespace.result_files.append(namespace.existing_metafile)
    # recompute the set of parameters common to every label, and gather the
    # per-label samples for each of those shared parameters
    parameters = [list(namespace.samples[i].keys()) for i in namespace.labels]
    namespace.same_parameters = list(
        set.intersection(*[set(l) for l in parameters])
    )
    namespace.same_samples = {
        param: {
            i: namespace.samples[i][param] for i in namespace.labels
        } for param in namespace.same_parameters
    }
    return namespace

399 

400 

def customwarn(message, category, filename, lineno, file=None, line=None):
    """Print a formatted warning to stdout; drop-in replacement for
    `warnings.showwarning` (hence the unused `file`/`line` parameters)
    """
    import sys
    import warnings

    formatted = warnings.formatwarning(
        "%s" % (message), category, filename, lineno
    )
    sys.stdout.write(formatted)

410 

411 

def determine_gps_time_and_window(maxL_samples, labels):
    """Determine the gps time and window to use in the spectrogram and
    omegascan plots
    """
    times = [maxL_samples[label]["geocent_time"] for label in labels]
    gps_time = np.mean(times)
    # the window spans the spread of the maxL times, with a 4s floor
    spread = np.max(times) - np.min(times)
    window = 4. if spread < 4. else spread * 1.5
    return gps_time, window

426 

427 

def number_of_columns_for_legend(labels):
    """Determine the number of columns to use in a legend

    Parameters
    ----------
    labels: list
        list of labels in the legend
    """
    # budget of ~50 characters per legend row, with 5 characters of
    # padding added to the widest label
    widest = np.max([len(label) for label in labels]) + 5.
    return 1 if widest > 50. else int(50. / widest)

441 

442 

class RedirectLogger(object):
    """Context manager that redirects stdout from other codes to the
    `pesummary` logger

    Parameters
    ----------
    code: str
        name of the code whose output is being captured; included in every
        log message
    level: str, optional
        the level to display the messages, e.g. 'DEBUG' or 'INFO'.
        Case-insensitive. Default 'Debug'
    """
    def __init__(self, code, level="Debug"):
        self.logger = logging.getLogger('PESummary')
        # logging level names are upper-case attributes of the logging
        # module ('DEBUG', 'INFO', ...); upper-case the input so the
        # historical default of 'Debug' resolves instead of raising
        # AttributeError
        self.level = getattr(logging, level.upper())
        self._redirector = contextlib.redirect_stdout(self)
        self.code = code

    def isatty(self):
        """Return None: this pseudo-stream is never a terminal"""
        pass

    def write(self, msg):
        """Forward the message to the logger

        Parameters
        ----------
        msg: str
            the message you wish to be logged
        """
        # skip pure-whitespace writes, e.g. the trailing newline print() emits
        if msg and not msg.isspace():
            self.logger.log(self.level, "[from %s] %s" % (self.code, msg))

    def flush(self):
        """No-op: messages are forwarded immediately on write"""
        pass

    def __enter__(self):
        self._redirector.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._redirector.__exit__(exc_type, exc_value, traceback)

481 

482 

def draw_conditioned_prior_samples(
    samples_dict, prior_samples_dict, conditioned, xlow, xhigh, N=100,
    nsamples=1000
):
    """Return a prior_dict that is conditioned on certain parameters

    Parameters
    ----------
    samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the posterior samples
    prior_samples_dict: pesummary.utils.samples_dict.SamplesDict
        SamplesDict containing the prior samples
    conditioned: list
        list of parameters that you wish to condition your prior on
    xlow: dict
        dictionary of lower bounds for each parameter
    xhigh: dict
        dictionary of upper bounds for each parameter
    N: int, optional
        number of points to use within the grid. Default 100
    nsamples: int, optional
        number of samples to draw. Default 1000
    """
    for parameter in conditioned:
        # indices of prior samples accepted by rejection sampling against
        # the posterior of this parameter
        accepted = _draw_conditioned_prior_samples(
            prior_samples_dict[parameter], samples_dict[parameter],
            xlow[parameter], xhigh[parameter], xN=N, N=nsamples
        )
        # thin every parameter consistently with the accepted indices so
        # the conditioned prior stays internally correlated
        for name, values in prior_samples_dict.items():
            prior_samples_dict[name] = values[accepted]
    return prior_samples_dict

515 

516 

def _draw_conditioned_prior_samples(
    prior_samples, posterior_samples, xlow, xhigh, xN=1000, N=1000
):
    """Return a list of indices for the conditioned prior via rejection
    sampling. The conditioned prior will then be `prior_samples[indicies]`.
    Code from Michael Puerrer.

    Parameters
    ----------
    prior_samples: np.ndarray
        array of prior samples that you wish to condition
    posterior_samples: np.ndarray
        array of posterior samples that you wish to condition on
    xlow: float
        lower bound for grid to be used
    xhigh: float
        upper bound for grid to be used
    xN: int, optional
        Number of points to use within the grid
    N: int, optional
        Number of samples to generate
    """
    from pesummary.utils.bounded_1d_kde import ReflectionBoundedKDE

    prior_kde = ReflectionBoundedKDE(prior_samples)
    posterior_kde = ReflectionBoundedKDE(posterior_samples)

    # estimate the rejection-sampling bound M from the prior/posterior
    # density ratio on a grid, restricted to where the posterior is non-zero
    grid = np.linspace(xlow, xhigh, xN)
    nonzero = np.nonzero(posterior_kde(grid))
    ratio = prior_kde(grid)[nonzero] / posterior_kde(grid)[nonzero]
    M = 1.1 / min(ratio[np.where(ratio < 1)])

    accepted = []
    while len(accepted) < N:
        # propose a prior sample and accept it with probability
        # posterior/(M * prior); store the index of the nearest sample
        candidate = np.random.choice(prior_samples)
        index = np.argmin(np.abs(prior_samples - candidate))
        acceptance = posterior_kde(candidate) / (M * prior_kde(candidate))
        if np.random.uniform() < acceptance:
            accepted.append(index)
    return accepted

559 

560 

def unzip(zip_file, outdir=None, overwrite=False):
    """Extract the data from a gzipped file and save in outdir.

    Parameters
    ----------
    zip_file: str
        path to the file you wish to unzip
    outdir: str, optional
        path to the directory where you wish to save the unzipped file. Default
        None which means that the unzipped file is stored in CACHE_DIR
    overwrite: Bool, optional
        If True, overwrite a file that has the same name

    Returns
    -------
    str
        path to the unzipped file

    Raises
    ------
    FileExistsError
        if the output file already exists and overwrite is False
    """
    import gzip
    import shutil
    from pathlib import Path

    # output name is the input name minus its final suffix,
    # e.g. 'samples.dat.gz' -> 'samples.dat'
    file_name = Path(zip_file).stem
    if outdir is None:
        outdir = CACHE_DIR
    out_file = os.path.join(outdir, file_name)
    if os.path.isfile(out_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists. Not overwriting".format(out_file)
        )
    # renamed handles avoid shadowing the builtin 'input'; nested context
    # managers in a single statement keep both closed on error
    with gzip.open(zip_file, 'rb') as source, open(out_file, 'wb') as target:
        shutil.copyfileobj(source, target)
    return out_file

591 

592 

def iterator(
    iterable, desc=None, logger=None, tqdm=False, total=None, file=None,
    bar_format=None
):
    """Return either a tqdm iterator, if tqdm installed, or the iterable itself

    Parameters
    ----------
    iterable: func
        iterable that you wish to iterate over
    desc: str, optional
        description for the tqdm bar
    logger: logging.Logger, optional
        logger the tqdm progress bar interacts with
    tqdm: Bool, optional
        If True, a tqdm object is used. Otherwise simply returns the iterator.
    total: float, optional
        total length of iterable
    file: str, optional
        path to file that you wish to write the output to
    bar_format: str, optional
        format to use for the tqdm progress bar. Default a pesummary-style
        bar showing percentage, counts and elapsed time
    """
    if not tqdm:
        return iterable
    try:
        # import inside the branch (and under an alias) so the 'tqdm'
        # boolean parameter is not shadowed by the imported class, and a
        # missing tqdm package degrades gracefully to the plain iterable
        from pesummary.utils.tqdm import tqdm as _tqdm
        # honour a user-supplied bar_format; previously it was discarded
        FORMAT = bar_format if bar_format is not None else (
            '{desc} | {percentage:3.0f}% | {n_fmt}/{total_fmt} | {elapsed}'
        )
        return _tqdm(
            iterable, total=total, logger=logger, desc=desc, file=file,
            bar_format=FORMAT,
        )
    except ImportError:
        return iterable

634 

635 

def _check_latex_install(force_tex=False):
    """Check whether matplotlib can render with LaTeX and update
    rcParams['text.usetex'] accordingly. Returns True when LaTeX works.
    """
    from matplotlib import rcParams
    from matplotlib.texmanager import TexManager

    # no latex executable on the PATH: disable usetex immediately
    if LATEX is None:
        rcParams["text.usetex"] = False
        return False

    # latex exists: confirm it can actually render a simple expression
    try:
        TexManager().make_dvi(r"$mass_{1}$", 12)
    except RuntimeError:
        rcParams["text.usetex"] = False
        return False

    # rendering works; only force usetex on when explicitly requested
    if force_tex:
        rcParams["text.usetex"] = True
    return True

659 

660 

def smart_round(parameters, return_latex=False, return_latex_row=False):
    """Round a parameter according to the uncertainty. If more than one parameter
    and uncertainty is passed, each parameter is rounded according to the
    lowest uncertainty

    Parameters
    ----------
    parameters: list/np.ndarray
        list containing the median, upper bound and lower bound for a given parameter
    return_latex: Bool, optional
        if True, return as a latex string
    return_latex_row: Bool, optional
        if True, return the rounded data as a single row in latex format

    Examples
    --------
    >>> data = [1.234, 0.2, 0.1]
    >>> smart_round(data)
    [ 1.2  0.2  0.1]
    >>> data = [
    ...     [6.093, 0.059, 0.055],
    ...     [6.104, 0.057, 0.052],
    ...     [6.08, 0.056, 0.052]
    ... ]
    >>> smart_round(data)
    [[ 6.09  0.06  0.06]
     [ 6.1   0.06  0.05]
     [ 6.08  0.06  0.05]]
    >>> smart_round(data, return_latex=True)
    6.09^{+0.06}_{-0.06}
    6.10^{+0.06}_{-0.05}
    6.08^{+0.06}_{-0.05}
    >>> data = [
    ...     [743.25, 43.6, 53.2],
    ...     [8712.5, 21.5, 35.2],
    ...     [196.46, 65.2, 12.5]
    ... ]
    >>> smart_round(data, return_latex_row=True)
    740^{+40}_{-50} & 8710^{+20}_{-40} & 200^{+70}_{-10}
    >>> data = [
    ...     [743.25, 43.6, 53.2],
    ...     [8712.5, 21.5, 35.2],
    ...     [196.46, 65.2, 8.2]
    ... ]
    >>> smart_round(data, return_latex_row=True)
    743^{+44}_{-53} & 8712^{+22}_{-35} & 196^{+65}_{-8}
    """
    data = copy.deepcopy(np.atleast_2d(parameters))
    # number of decimal places implied by the smallest-magnitude entry
    precision = int(-1 * np.floor(np.log10(np.min(np.abs(parameters)))))
    for row_index, _ in enumerate(data):
        data[row_index] = [np.round(entry, precision) for entry in data[row_index]]
    if return_latex or return_latex_row:
        value_format = "%.{}f".format(precision) if precision > 0 else "%.f"
        template = "{0}^{{+{0}}}_{{-{0}}}".format(value_format)
        entries = [template % (row[0], row[1], row[2]) for row in data]
        if return_latex:
            for entry in entries:
                print(entry)
        else:
            print(" & ".join(entries))
        return ""
    if np.array(parameters).ndim == 1:
        return data[0]
    return data

730 

731 

def safe_round(a, decimals=0, **kwargs):
    """Try and round an array to the given number of decimals. If an exception
    is raised, return the original array

    Parameters
    ----------
    a: np.ndarray
        array you wish to round
    decimals: int
        the number of decimals you wish to round too
    **kwargs: dict
        all kwargs are passed to numpy.round
    """
    try:
        rounded = np.round(a, decimals=decimals, **kwargs)
    except Exception:
        # un-roundable input (e.g. strings, object arrays): hand it back
        return a
    return rounded

749 

750 

def gelman_rubin(samples, decimal=5):
    """Return an approximation to the Gelman-Rubin statistic (see Gelman, A. and
    Rubin, D. B., Statistical Science, Vol 7, No. 4, pp. 457--511 (1992))

    Parameters
    ----------
    samples: np.ndarray
        2d array of samples for a given parameter, one for each chain
    decimal: int
        number of decimal places to keep when rounding

    Examples
    --------
    >>> from pesummary.utils.utils import gelman_rubin
    >>> samples = [[1, 1.5, 1.2, 1.4, 1.6, 1.2], [1.5, 1.3, 1.4, 1.7]]
    >>> gelman_rubin(samples, decimal=5)
    1.2972
    """
    chain_means = [np.mean(chain) for chain in samples]
    chain_variances = [np.var(chain) for chain in samples]
    # between-chain variance (B/n) and mean within-chain variance (W)
    between = np.var(chain_means)
    within = np.mean(chain_variances)
    n_chains = len(samples)
    # pooled posterior variance estimate: W + B/n + (B/n)/m
    Vhat = within + between + between / n_chains
    return np.round(Vhat / within, decimal)

777 

778 

def kolmogorov_smirnov_test(samples, decimal=5):
    """Return the KS p value between two PDFs

    Parameters
    ----------
    samples: 2d list
        2d list containing the 2 PDFs that you wish to compare
    decimal: int
        number of decimal places to keep when rounding
    """
    first, second = samples
    pvalue = stats.ks_2samp(first, second)[1]
    return np.round(pvalue, decimal)

790 

791 

def jensen_shannon_divergence(*args, **kwargs):
    """Deprecated alias kept for backwards compatibility; forwards all
    arguments to `jensen_shannon_divergence_from_samples`
    """
    import warnings

    warnings.warn(
        "The jensen_shannon_divergence function has changed its name to "
        "jensen_shannon_divergence_from_samples. jensen_shannon_divergence "
        "may not be supported in future releases. Please update"
    )
    return jensen_shannon_divergence_from_samples(*args, **kwargs)

800 

801 

def jensen_shannon_divergence_from_samples(
    samples, kde=stats.gaussian_kde, decimal=5, base=np.e, **kwargs
):
    """Calculate the JS divergence between two sets of samples

    Parameters
    ----------
    samples: list
        2d list containing the samples drawn from two pdfs
    kde: func
        function to use when calculating the kde of the samples
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e
    kwargs: dict
        all kwargs are passed to the kde function
    """
    # evaluate both KDEs on a shared grid, then compare the resulting pdfs
    estimated_pdfs = samples_to_kde(samples, kde=kde, **kwargs)
    return jensen_shannon_divergence_from_pdfs(
        estimated_pdfs, decimal=decimal, base=base
    )

823 

824 

def jensen_shannon_divergence_from_pdfs(pdfs, decimal=5, base=np.e):
    """Calculate the JS divergence between two distributions

    Parameters
    ----------
    pdfs: list
        list of length 2 containing the distributions you wish to compare
    decimal: int, float
        number of decimal places to round the JS divergence to
    base: float, optional
        optional base to use for the scipy.stats.entropy function. Default
        np.e

    Returns
    -------
    float
        the rounded JS divergence, or NaN when either input contains NaNs
    """
    # a failed KDE is represented by NaNs upstream; propagate as NaN
    if any(np.isnan(_).any() for _ in pdfs):
        return float("nan")
    a, b = pdfs
    # normalise *copies*: the previous in-place `a /= a.sum()` mutated the
    # caller's array (np.asarray aliases float input) and raised on
    # integer arrays
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a / a.sum()
    b = b / b.sum()
    m = 1. / 2 * (a + b)
    # JSD = mean of the two KL divergences against the mixture m
    kl_forward = stats.entropy(a, qk=m, base=base)
    kl_backward = stats.entropy(b, qk=m, base=base)
    return np.round(kl_forward / 2. + kl_backward / 2., decimal)

849 

850 

def samples_to_kde(samples, kde=stats.gaussian_kde, **kwargs):
    """Generate KDE for a set of samples

    Parameters
    ----------
    samples: list
        list containing the samples to create a KDE for. samples can also
        be a 2d list containing samples from multiple analyses.
    kde: func
        function to use when calculating the kde of the samples
    """
    # promote a flat list of samples to a one-analysis 2d list
    single_analysis = not isinstance(samples[0], (np.ndarray, list, tuple))
    analyses = [samples] if single_analysis else samples
    kernels = []
    for analysis in analyses:
        try:
            kernels.append(kde(analysis, **kwargs))
        except np.linalg.LinAlgError:
            # degenerate data (e.g. singular covariance): report NaN below
            kernels.append(None)
    # evaluate every kernel on a common 100-point grid spanning all analyses
    grid = np.linspace(
        np.min([np.min(analysis) for analysis in analyses]),
        np.max([np.max(analysis) for analysis in analyses]),
        100
    )
    pdfs = [
        kernel(grid) if kernel is not None else float('nan')
        for kernel in kernels
    ]
    return pdfs[0] if single_analysis else pdfs

883 

884 

def make_cache_style_file(style_file):
    """Make a cache directory which stores the style file you wish to use
    when plotting

    Parameters
    ----------
    style_file: str
        path to the style file that you wish to use when plotting
    """
    make_dir(STYLE_CACHE)
    # the cached copy always uses the same canonical name
    cached_copy = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    shutil.copyfile(style_file, cached_copy)

898 

899 

def get_matplotlib_style_file():
    """Return the path to the matplotlib style file that you wish to use
    """
    cached = os.path.join(STYLE_CACHE, "matplotlib_rcparams.sty")
    if os.path.isfile(cached):
        return cached
    # no cached style file: fall back to the one shipped with pesummary
    from pesummary import conf

    return conf.style_file

909 

910 

def get_matplotlib_backend(parallel=False):
    """Return the matplotlib backend to use for the plotting modules

    Parameters
    ----------
    parallel: Bool, optional
        if True, backend is always set to 'Agg' for the multiprocessing module
    """
    # a display is considered available when $DISPLAY is set or we are
    # running inside IPython
    if "DISPLAY" in os.environ:
        display = True
    else:
        try:
            __IPYTHON__
        except NameError:
            display = False
        else:
            display = True
    return "TKAgg" if display and not parallel else "Agg"

935 

936 

937def _default_filename(default_filename, label=None): 

938 """Return a default filename 

939 

940 Parameters 

941 ---------- 

942 default_filename: str, optional 

943 the default filename to use if a filename is not provided. default_filename 

944 must be a formattable string with one empty argument for a label 

945 label: str, optional 

946 The label of the analysis. This is used in the filename 

947 """ 

948 if not label: 

949 filename = default_filename.format(round(time.time())) 

950 else: 

951 filename = default_filename.format(label) 

952 return filename 

953 

954 

def check_filename(
    default_filename="pesummary_{}.dat", outdir="./", label=None, filename=None,
    overwrite=False, delete_existing=False
):
    """Check to see if a file exists. If no filename is provided, a default
    filename is checked

    Parameters
    ----------
    default_filename: str, optional
        the default filename to use if a filename is not provided. default_filename
        must be a formattable string with one empty argument for a label
    outdir: str, optional
        directory to write the dat file
    label: str, optional
        The label of the analysis. This is used in the filename if a filename
        if not specified
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    delete_existing: Bool, optional
        If True, an existing file of the same name is removed before the path
        is returned. Only reachable when overwrite is also True, since
        otherwise an existing file raises first

    Returns
    -------
    str
        full path to the (possibly not yet existing) file

    Raises
    ------
    FileExistsError
        if the file already exists and overwrite is False
    """
    if not filename:
        filename = _default_filename(default_filename, label=label)
    _file = os.path.join(outdir, filename)
    if os.path.isfile(_file) and not overwrite:
        raise FileExistsError(
            "The file '{}' already exists in the directory {}".format(
                filename, outdir
            )
        )
    if os.path.isfile(_file) and delete_existing:
        os.remove(_file)
    return _file

989 

990 

def string_match(string, substring):
    """Return True if a string matches a substring. This substring may include
    wildcards

    Parameters
    ----------
    string: str
        string you wish to match
    substring: str
        string you wish to match against. Interpreted as a regular expression
        first; if it is not a valid regex (e.g. a bare glob pattern such as
        '[abc'), it is translated from shell-style wildcards with fnmatch
        and retried
    """
    import re

    try:
        # re.error is the public name for the exception previously reached
        # via the deprecated sre_constants module
        return bool(re.match(substring, string))
    except re.error:
        import fnmatch
        return string_match(string, fnmatch.translate(substring))

1013 

1014 

def glob_directory(base):
    """Return a list of files matching base

    Parameters
    ----------
    base: str
        string you wish to match e.g. "./", "./*.py"
    """
    from glob import glob

    # a bare directory is expanded to match everything inside it
    pattern = base if "*" in base else os.path.join(base, "*")
    return glob(pattern)

1027 

1028 

def list_match(list_to_match, substring, return_true=True, return_false=False):
    """Match a list of strings to a substring. This substring may include
    wildcards

    Parameters
    ----------
    list_to_match: list
        list of string you wish to match
    substring: str, list
        string you wish to match against or a list of string you wish to match
        against
    return_true: Bool, optional
        if True, return a sublist containing only the parameters that match the
        substring. Default True
    return_false: Bool, optional
        if True, return a sublist containing only the parameters that do NOT
        match the substring. Takes precedence over return_true. Default False

    Returns
    -------
    np.ndarray: the matching (or, with return_false=True, non-matching)
        entries; if both return_true and return_false are False, a boolean
        mask of shape (len(list_to_match),)
    """
    if isinstance(substring, str):
        substring = [substring]
    # an entry survives only if it matches every supplied substring
    match = np.ones(len(list_to_match), dtype=bool)
    for _substring in substring:
        match &= np.array(
            [string_match(item, _substring) for item in list_to_match],
            dtype=bool
        )
    if return_false:
        return np.array(list_to_match)[~match]
    elif return_true:
        return np.array(list_to_match)[match]
    return match

1058 

1059 

class Empty(object):
    """Define an empty class which simply returns the input

    ``Empty(obj)`` is a no-op factory: it returns ``obj`` itself rather than
    constructing a new instance.
    """
    def __new__(cls, *args):
        # __new__ is implicitly a staticmethod whose first argument is the
        # class itself, so name it cls (was misleadingly named self)
        return args[0]

1065 

1066 

def history_dictionary(program=None, creator=conf.user, command_line=None):
    """Create a dictionary containing useful information about the origin of
    a PESummary data product

    Parameters
    ----------
    program: str, optional
        program used to generate the PESummary data product
    creator: str, optional
        The user who created the PESummary data product
    command_line: str, optional
        The command line which was run to generate the PESummary data product
    """
    from astropy.time import Time

    history = {
        "gps_creation_time": Time.now().gps,
        "creator": creator,
    }
    if command_line is None:
        # fall back to the command line of the current process
        history["command_line"] = " ".join(sys.argv)
    else:
        history["command_line"] = (
            "Generated by running the following script: {}".format(
                command_line
            )
        )
    if program is not None:
        history["program"] = program
    return history

1097 

1098 

def mute_logger():
    """Mute the PESummary logger

    Raises the logger level above CRITICAL so that no records are emitted.
    """
    logging.getLogger('PESummary').setLevel(logging.CRITICAL + 10)

1105 

1106 

def unmute_logger():
    """Unmute the PESummary logger

    Restores the logger level to INFO so that records are emitted again.
    """
    logging.getLogger('PESummary').setLevel(logging.INFO)

1113 

# import error message
# Template warning shown when an optional dependency is unavailable;
# formatted with the name of the missing package.
# NOTE(review): the wording says "install" although the name suggests it is
# used for import failures -- confirm intended phrasing (runtime string left
# untouched here).
import_error_msg = (
    "Unable to install '{}'. You will not be able to use some of the inbuilt "
    "functions."
)


# silence matplotlib warnings
# raise the font_manager logger above CRITICAL so matplotlib's font-cache
# chatter never reaches the user
logging.getLogger('matplotlib.font_manager').setLevel(logging.CRITICAL + 10)
# setup pesummary logger
# setup_logger is presumably defined earlier in this module (not visible in
# this chunk); it returns a 2-tuple whose second element is bound to
# LOG_FILE -- likely the path of the log file, TODO confirm
_, LOG_FILE = setup_logger()
logger = logging.getLogger('PESummary')