Coverage for pesummary/core/cli/inputs.py: 80.6%

1264 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-01-15 17:49 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import socket 

5from glob import glob 

6from pathlib import Path 

7from getpass import getuser 

8 

9import math 

10import numpy as np 

11from pesummary.core.file.read import read as Read 

12from pesummary.utils.exceptions import InputError 

13from pesummary.utils.decorators import deprecation 

14from pesummary.utils.samples_dict import SamplesDict, MCMCSamplesDict 

15from pesummary.utils.utils import ( 

16 guess_url, logger, make_dir, make_cache_style_file, list_match 

17) 

18from pesummary import conf 

19 

# Author of this module
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

21 

22 

23class _Input(object): 

24 """Super class to handle the command line arguments 

25 """ 

26 @staticmethod 

27 def is_pesummary_metafile(proposed_file): 

28 """Determine if a file is a PESummary metafile or not 

29 

30 Parameters 

31 ---------- 

32 proposed_file: str 

33 path to the file 

34 """ 

35 extension = proposed_file.split(".")[-1] 

36 if extension == "h5" or extension == "hdf5" or extension == "hdf": 

37 from pesummary.core.file.read import ( 

38 is_pesummary_hdf5_file, is_pesummary_hdf5_file_deprecated 

39 ) 

40 

41 result = any( 

42 func(proposed_file) for func in [ 

43 is_pesummary_hdf5_file, 

44 is_pesummary_hdf5_file_deprecated 

45 ] 

46 ) 

47 return result 

48 elif extension == "json": 

49 from pesummary.core.file.read import ( 

50 is_pesummary_json_file, is_pesummary_json_file_deprecated 

51 ) 

52 

53 result = any( 

54 func(proposed_file) for func in [ 

55 is_pesummary_json_file, 

56 is_pesummary_json_file_deprecated 

57 ] 

58 ) 

59 return result 

60 else: 

61 return False 

62 

63 @staticmethod 

64 def grab_data_from_metafile( 

65 existing_file, webdir, compare=None, read_function=Read, 

66 _replace_with_pesummary_kwargs={}, nsamples=None, 

67 disable_injection=False, keep_nan_likelihood_samples=False, 

68 reweight_samples=False, **kwargs 

69 ): 

70 """Grab data from an existing PESummary metafile 

71 

72 Parameters 

73 ---------- 

74 existing_file: str 

75 path to the existing metafile 

76 webdir: str 

77 the directory to store the existing configuration file 

78 compare: list, optional 

79 list of labels for events stored in an existing metafile that you 

80 wish to compare 

81 read_function: func, optional 

82 PESummary function to use to read in the existing file 

83 _replace_with_pesummary_kwargs: dict, optional 

84 dictionary of kwargs that you wish to replace with the data stored 

85 in the PESummary file 

86 nsamples: int, optional 

87 Number of samples to use. Default all available samples 

88 kwargs: dict 

89 All kwargs are passed to the `generate_all_posterior_samples` 

90 method 

91 """ 

92 f = read_function( 

93 existing_file, 

94 remove_nan_likelihood_samples=not keep_nan_likelihood_samples 

95 ) 

96 for ind, label in enumerate(f.labels): 

97 kwargs[label] = kwargs.copy() 

98 for key, item in _replace_with_pesummary_kwargs.items(): 

99 try: 

100 kwargs[label][key] = eval( 

101 item.format(file="f", ind=ind, label=label) 

102 ) 

103 except TypeError: 

104 _item = item.split("['{label}']")[0] 

105 kwargs[label][key] = eval( 

106 _item.format(file="f", ind=ind, label=label) 

107 ) 

108 except (AttributeError, KeyError, NameError): 

109 pass 

110 

111 if not f.mcmc_samples: 

112 labels = f.labels 

113 else: 

114 labels = list(f.samples_dict.keys()) 

115 indicies = np.arange(len(labels)) 

116 

117 if compare: 

118 indicies = [] 

119 for i in compare: 

120 if i not in labels: 

121 raise InputError( 

122 "Label '%s' does not exist in the metafile. The list " 

123 "of available labels are %s" % (i, labels) 

124 ) 

125 indicies.append(labels.index(i)) 

126 labels = compare 

127 

128 if nsamples is not None: 

129 f.downsample(nsamples, labels=labels) 

130 if not f.mcmc_samples: 

131 f.generate_all_posterior_samples(labels=labels, **kwargs) 

132 if reweight_samples: 

133 f.reweight_samples(reweight_samples, labels=labels, **kwargs) 

134 

135 parameters = f.parameters 

136 if not f.mcmc_samples: 

137 samples = [np.array(i).T for i in f.samples] 

138 DataFrame = { 

139 label: SamplesDict(parameters[ind], samples[ind]) 

140 for label, ind in zip(labels, indicies) 

141 } 

142 _parameters = lambda label: DataFrame[label].keys() 

143 else: 

144 DataFrame = { 

145 f.labels[0]: MCMCSamplesDict( 

146 { 

147 label: f.samples_dict[label] for label in labels 

148 } 

149 ) 

150 } 

151 labels = f.labels 

152 indicies = np.arange(len(labels)) 

153 _parameters = lambda label: DataFrame[f.labels[0]].parameters 

154 if not disable_injection and f.injection_parameters != []: 

155 inj_values = f.injection_dict 

156 for label in labels: 

157 for param in DataFrame[label].keys(): 

158 if param not in f.injection_dict[label].keys(): 

159 f.injection_dict[label][param] = float("nan") 

160 else: 

161 inj_values = { 

162 i: { 

163 param: float("nan") for param in DataFrame[i].parameters 

164 } for i in labels 

165 } 

166 for i in inj_values.keys(): 

167 for param in inj_values[i].keys(): 

168 if inj_values[i][param] == "nan": 

169 inj_values[i][param] = float("nan") 

170 if isinstance(inj_values[i][param], bytes): 

171 inj_values[i][param] = inj_values[i][param].decode("utf-8") 

172 

173 if hasattr(f, "priors") and f.priors is not None and f.priors != {}: 

174 priors = f.priors 

175 else: 

176 priors = {label: {} for label in labels} 

177 

178 config = [] 

179 if f.config is not None and not all(i is None for i in f.config): 

180 config = [] 

181 for i in labels: 

182 config_dir = os.path.join(webdir, "config") 

183 filename = f.write_config_to_file( 

184 i, outdir=config_dir, _raise=False, 

185 filename="{}_config.ini".format(i) 

186 ) 

187 _config = os.path.join(config_dir, filename) 

188 if filename is not None and os.path.isfile(_config): 

189 config.append(_config) 

190 else: 

191 config.append(None) 

192 else: 

193 for i in labels: 

194 config.append(None) 

195 

196 if f.weights is not None: 

197 weights = {i: f.weights[i] for i in labels} 

198 else: 

199 weights = {i: None for i in labels} 

200 

201 return { 

202 "samples": DataFrame, 

203 "injection_data": inj_values, 

204 "file_version": { 

205 i: j for i, j in zip( 

206 labels, [f.input_version[ind] for ind in indicies] 

207 ) 

208 }, 

209 "file_kwargs": { 

210 i: j for i, j in zip( 

211 labels, [f.extra_kwargs[ind] for ind in indicies] 

212 ) 

213 }, 

214 "prior": priors, 

215 "config": config, 

216 "labels": labels, 

217 "weights": weights, 

218 "indicies": indicies, 

219 "mcmc_samples": f.mcmc_samples, 

220 "open_file": f, 

221 "descriptions": f.description 

222 } 

223 

224 @staticmethod 

225 def grab_data_from_file( 

226 file, label, webdir, config=None, injection=None, read_function=Read, 

227 file_format=None, nsamples=None, disable_prior_sampling=False, 

228 nsamples_for_prior=None, path_to_samples=None, 

229 keep_nan_likelihood_samples=False, reweight_samples=False, 

230 multi_process=None, **kwargs 

231 ): 

232 """Grab data from a result file containing posterior samples 

233 

234 Parameters 

235 ---------- 

236 file: str 

237 path to the result file 

238 label: str 

239 label that you wish to use for the result file 

240 config: str, optional 

241 path to a configuration file used in the analysis 

242 injection: str, optional 

243 path to an injection file used in the analysis 

244 read_function: func, optional 

245 PESummary function to use to read in the file 

246 file_format, str, optional 

247 the file format you wish to use when loading. Default None. 

248 If None, the read function loops through all possible options 

249 kwargs: dict 

250 Dictionary of keyword arguments fed to the 

251 `generate_all_posterior_samples` method 

252 """ 

253 f = read_function( 

254 file, file_format=file_format, disable_prior=disable_prior_sampling, 

255 nsamples_for_prior=nsamples_for_prior, path_to_samples=path_to_samples, 

256 remove_nan_likelihood_samples=not keep_nan_likelihood_samples, multi_process=multi_process, **kwargs 

257 ) 

258 if config is not None: 

259 f.add_fixed_parameters_from_config_file(config) 

260 

261 if nsamples is not None: 

262 f.downsample(nsamples) 

263 f.generate_all_posterior_samples(**kwargs) 

264 if injection: 

265 f.add_injection_parameters_from_file( 

266 injection, conversion_kwargs=kwargs 

267 ) 

268 if reweight_samples: 

269 f.reweight_samples(reweight_samples) 

270 parameters = f.parameters 

271 samples = np.array(f.samples).T 

272 DataFrame = {label: SamplesDict(parameters, samples)} 

273 kwargs = f.extra_kwargs 

274 if hasattr(f, "injection_parameters"): 

275 injection = f.injection_parameters 

276 if injection is not None: 

277 for i in parameters: 

278 if i not in list(injection.keys()): 

279 injection[i] = float("nan") 

280 else: 

281 injection = {i: j for i, j in zip( 

282 parameters, [float("nan")] * len(parameters))} 

283 else: 

284 injection = {i: j for i, j in zip( 

285 parameters, [float("nan")] * len(parameters))} 

286 version = f.input_version 

287 if hasattr(f, "priors") and f.priors is not None: 

288 priors = {key: {label: item} for key, item in f.priors.items()} 

289 else: 

290 priors = {label: []} 

291 if hasattr(f, "weights") and f.weights is not None: 

292 weights = f.weights 

293 else: 

294 weights = None 

295 data = { 

296 "samples": DataFrame, 

297 "injection_data": {label: injection}, 

298 "file_version": {label: version}, 

299 "file_kwargs": {label: kwargs}, 

300 "prior": priors, 

301 "weights": {label: weights}, 

302 "open_file": f, 

303 "descriptions": {label: f.description} 

304 } 

305 if hasattr(f, "config") and f.config is not None: 

306 if config is None: 

307 config_dir = os.path.join(webdir, "config") 

308 filename = "{}_config.ini".format(label) 

309 logger.debug( 

310 "Successfully extracted config data from the provided " 

311 "input file. Saving the data to the file '{}'".format( 

312 os.path.join(config_dir, filename) 

313 ) 

314 ) 

315 _filename = f.write( 

316 filename=filename, outdir=config_dir, file_format="ini", 

317 _raise=False 

318 ) 

319 data["config"] = _filename 

320 else: 

321 logger.info( 

322 "Ignoring config data extracted from the input file and " 

323 "using the config file provided" 

324 ) 

325 return data 

326 

327 @property 

328 def result_files(self): 

329 return self._result_files 

330 

331 @result_files.setter 

332 def result_files(self, result_files): 

333 self._result_files = result_files 

334 if self._result_files is not None: 

335 for num, ff in enumerate(self._result_files): 

336 func = None 

337 if not os.path.isfile(ff) and "@" in ff: 

338 from pesummary.io.read import _fetch_from_remote_server 

339 func = _fetch_from_remote_server 

340 elif not os.path.isfile(ff) and "https://" in ff: 

341 from pesummary.io.read import _fetch_from_url 

342 func = _fetch_from_url 

343 elif not os.path.isfile(ff) and "*" in ff: 

344 from pesummary.utils.utils import glob_directory 

345 func = glob_directory 

346 if func is not None: 

347 _data = func(ff) 

348 if isinstance(_data, (np.ndarray, list)) and len(_data) > 0: 

349 self._result_files[num] = _data[0] 

350 if len(_data) > 1: 

351 _ = [ 

352 self._result_files.insert(num + 1, d) for d in 

353 _data[1:][::-1] 

354 ] 

355 elif isinstance(_data, np.ndarray): 

356 raise InputError( 

357 "Unable to find any files matching '{}'".format(ff) 

358 ) 

359 else: 

360 self._result_files[num] = _data 

361 

362 @property 

363 def seed(self): 

364 return self._seed 

365 

366 @seed.setter 

367 def seed(self, seed): 

368 np.random.seed(seed) 

369 self._seed = seed 

370 

371 @property 

372 def existing(self): 

373 return self._existing 

374 

375 @existing.setter 

376 def existing(self, existing): 

377 self._existing = existing 

378 if existing is not None: 

379 self._existing = os.path.abspath(existing) 

380 

381 @property 

382 def existing_metafile(self): 

383 return self._existing_metafile 

384 

385 @existing_metafile.setter 

386 def existing_metafile(self, existing_metafile): 

387 from glob import glob 

388 

389 self._existing_metafile = existing_metafile 

390 if self._existing_metafile is None: 

391 return 

392 if not os.path.isdir(os.path.join(self.existing, "samples")): 

393 raise InputError("Please provide a valid existing directory") 

394 _dir = os.path.join(self.existing, "samples") 

395 files = glob(os.path.join(_dir, "posterior_samples*")) 

396 dir_content = glob(os.path.join(_dir, "*.h5")) 

397 dir_content.extend(glob(os.path.join(_dir, "*.json"))) 

398 dir_content.extend(glob(os.path.join(_dir, "*.hdf5"))) 

399 if len(files) == 0 and len(dir_content): 

400 files = dir_content 

401 logger.warning( 

402 "Unable to find a 'posterior_samples*' file in the existing " 

403 "directory. Using '{}' as the existing metafile".format( 

404 dir_content[0] 

405 ) 

406 ) 

407 elif len(files) == 0: 

408 raise InputError( 

409 "Unable to find an existing metafile in the existing webdir" 

410 ) 

411 elif len(files) > 1: 

412 raise InputError( 

413 "Multiple metafiles in the existing directory. Please either " 

414 "run the `summarycombine_metafile` executable to combine the " 

415 "meta files or simple remove the unwanted meta file" 

416 ) 

417 self._existing_metafile = os.path.join( 

418 self.existing, "samples", files[0] 

419 ) 

420 

421 @property 

422 def style_file(self): 

423 return self._style_file 

424 

425 @style_file.setter 

426 def style_file(self, style_file): 

427 default = conf.style_file 

428 if style_file is not None and not os.path.isfile(style_file): 

429 logger.warning( 

430 "The file '{}' does not exist. Resorting to default".format( 

431 style_file 

432 ) 

433 ) 

434 style_file = default 

435 elif style_file is not None and os.path.isfile(style_file): 

436 logger.info( 

437 "Using the file '{}' as the matplotlib style file".format( 

438 style_file 

439 ) 

440 ) 

441 elif style_file is None: 

442 logger.debug( 

443 "Using the default matplotlib style file" 

444 ) 

445 style_file = default 

446 make_cache_style_file(style_file) 

447 self._style_file = style_file 

448 

449 @property 

450 def filename(self): 

451 return self._filename 

452 

453 @filename.setter 

454 def filename(self, filename): 

455 self._filename = filename 

456 if filename is not None: 

457 if "/" in filename: 

458 logger.warning("") 

459 filename = filename.split("/")[-1] 

460 if os.path.isfile(os.path.join(self.webdir, "samples", filename)): 

461 logger.warning( 

462 "A file with filename '{}' already exists in the samples " 

463 "directory '{}'. This will be overwritten" 

464 ) 

465 

466 @property 

467 def user(self): 

468 return self._user 

469 

470 @user.setter 

471 def user(self, user): 

472 try: 

473 self._user = getuser() 

474 logger.info( 

475 conf.overwrite.format("user", conf.user, self._user) 

476 ) 

477 except KeyError as e: 

478 logger.info( 

479 "Failed to grab user information because {}. Default will be " 

480 "used".format(e) 

481 ) 

482 self._user = user 

483 

484 @property 

485 def host(self): 

486 return socket.getfqdn() 

487 

488 @property 

489 def webdir(self): 

490 return self._webdir 

491 

492 @webdir.setter 

493 def webdir(self, webdir): 

494 cond1 = webdir is None or webdir == "None" or webdir == "none" 

495 cond2 = ( 

496 self.existing is None or self.existing == "None" 

497 or self.existing == "none" 

498 ) 

499 if cond1 and cond2: 

500 raise InputError( 

501 "Please provide a web directory to store the webpages. If " 

502 "you wish to add to an existing webpage, then pass the " 

503 "existing web directory under the '--existing_webdir' command " 

504 "line argument. If this is a new set of webpages, then pass " 

505 "the web directory under the '--webdir' argument" 

506 ) 

507 elif webdir is None and self.existing is not None: 

508 if not os.path.isdir(self.existing): 

509 raise InputError( 

510 "The directory {} does not exist".format(self.existing) 

511 ) 

512 entries = glob(self.existing + "/*") 

513 if os.path.join(self.existing, "home.html") not in entries: 

514 raise InputError( 

515 "Please give the base directory of an existing output" 

516 ) 

517 self._webdir = self.existing 

518 else: 

519 if not os.path.isdir(webdir): 

520 logger.debug( 

521 "Given web directory does not exist. Creating it now" 

522 ) 

523 make_dir(webdir) 

524 self._webdir = os.path.abspath(webdir) 

525 

526 @property 

527 def baseurl(self): 

528 return self._baseurl 

529 

530 @baseurl.setter 

531 def baseurl(self, baseurl): 

532 self._baseurl = baseurl 

533 if baseurl is None: 

534 self._baseurl = guess_url(self.webdir, self.host, self.user) 

535 

536 @property 

537 def mcmc_samples(self): 

538 return self._mcmc_samples 

539 

540 @mcmc_samples.setter 

541 def mcmc_samples(self, mcmc_samples): 

542 self._mcmc_samples = mcmc_samples 

543 if self._mcmc_samples: 

544 logger.info( 

545 "Treating all samples as seperate mcmc chains for the same " 

546 "analysis." 

547 ) 

548 

549 @property 

550 def labels(self): 

551 return self._labels 

552 

553 @labels.setter 

554 def labels(self, labels): 

555 if self.result_files is not None: 

556 if any(self.is_pesummary_metafile(s) for s in self.result_files): 

557 logger.warning( 

558 "labels argument is ignored when a pesummary metafile is " 

559 "input. Stored analyses will use their stored labels. If " 

560 "you wish to change the labels, please use `summarymodify`" 

561 ) 

562 labels = self.default_labels() 

563 if not hasattr(self, "._labels"): 

564 if labels is None: 

565 labels = self.default_labels() 

566 elif self.mcmc_samples and len(labels) != 1: 

567 raise InputError( 

568 "Please provide a single label that corresponds to all " 

569 "mcmc samples" 

570 ) 

571 elif len(np.unique(labels)) != len(labels): 

572 raise InputError( 

573 "Please provide unique labels for each result file" 

574 ) 

575 for num, i in enumerate(labels): 

576 if "." in i: 

577 logger.warning( 

578 "Replacing the label {} by {} to make it compatible " 

579 "with the html pages".format(i, i.replace(".", "_")) 

580 ) 

581 labels[num] = i.replace(".", "_") 

582 if self.add_to_existing: 

583 for i in labels: 

584 if i in self.existing_labels: 

585 raise InputError( 

586 "The label '%s' already exists in the existing " 

587 "metafile. Please pass another unique label" 

588 ) 

589 

590 if len(self.result_files) != len(labels) and not self.mcmc_samples: 

591 import copy 

592 _new_labels = copy.deepcopy(labels) 

593 idx = 1 

594 while len(_new_labels) < len(self.result_files): 

595 _new_labels.extend( 

596 [_label + str(idx) for _label in labels] 

597 ) 

598 idx += 1 

599 _new_labels = _new_labels[:len(self.result_files)] 

600 logger.info( 

601 "You have passed {} result files and {} labels. Setting " 

602 "labels = {}".format( 

603 len(self.result_files), len(labels), _new_labels 

604 ) 

605 ) 

606 labels = _new_labels 

607 self._labels = labels 

608 

609 @property 

610 def config(self): 

611 return self._config 

612 

613 @config.setter 

614 def config(self, config): 

615 if config and len(config) != len(self.labels): 

616 raise InputError( 

617 "Please provide a configuration file for each label" 

618 ) 

619 if config is None and not self.meta_file: 

620 self._config = [None] * len(self.labels) 

621 elif self.meta_file: 

622 self._config = [None] * len(self.labels) 

623 else: 

624 self._config = config 

625 for num, ff in enumerate(self._config): 

626 if isinstance(ff, str) and ff.lower() == "none": 

627 self._config[num] = None 

628 

629 @property 

630 def injection_file(self): 

631 return self._injection_file 

632 

633 @injection_file.setter 

634 def injection_file(self, injection_file): 

635 if injection_file and len(injection_file) != len(self.labels): 

636 if len(injection_file) == 1: 

637 logger.info( 

638 "Only one injection file passed. Assuming the same " 

639 "injection for all {} result files".format(len(self.labels)) 

640 ) 

641 else: 

642 raise InputError( 

643 "You have passed {} for {} result files. Please provide an " 

644 "injection file for each result file".format( 

645 len(self.injection_file), len(self.labels) 

646 ) 

647 ) 

648 if injection_file is None: 

649 injection_file = [None] * len(self.labels) 

650 self._injection_file = injection_file 

651 

652 @property 

653 def injection_data(self): 

654 return self._injection_data 

655 

656 @property 

657 def file_version(self): 

658 return self._file_version 

659 

660 @property 

661 def file_kwargs(self): 

662 return self._file_kwargs 

663 

664 @property 

665 def kde_plot(self): 

666 return self._kde_plot 

667 

668 @kde_plot.setter 

669 def kde_plot(self, kde_plot): 

670 self._kde_plot = kde_plot 

671 if kde_plot != conf.kde_plot: 

672 logger.info( 

673 conf.overwrite.format("kde_plot", conf.kde_plot, kde_plot) 

674 ) 

675 

676 @property 

677 def file_format(self): 

678 return self._file_format 

679 

680 @file_format.setter 

681 def file_format(self, file_format): 

682 if file_format is None: 

683 self._file_format = [None] * len(self.labels) 

684 elif len(file_format) == 1 and len(file_format) != len(self.labels): 

685 logger.warning( 

686 "Only one file format specified. Assuming all files are of " 

687 "this format" 

688 ) 

689 self._file_format = [file_format[0]] * len(self.labels) 

690 elif len(file_format) != len(self.labels): 

691 raise InputError( 

692 "Please provide a file format for each result file. If you " 

693 "wish to specify the file format for the second result file " 

694 "and not for any of the others, for example, simply pass 'None " 

695 "{format} None'" 

696 ) 

697 else: 

698 for num, ff in enumerate(file_format): 

699 if ff.lower() == "none": 

700 file_format[num] = None 

701 self._file_format = file_format 

702 

703 @property 

704 def samples(self): 

705 return self._samples 

706 

707 @samples.setter 

708 def samples(self, samples): 

709 if isinstance(samples, dict): 

710 return samples 

711 self._set_samples(samples) 

712 

713 def _set_samples( 

714 self, samples, 

715 ignore_keys=["prior", "weights", "labels", "indicies", "open_file"] 

716 ): 

717 """Extract the samples and store them as attributes of self 

718 

719 Parameters 

720 ---------- 

721 samples: list 

722 A list containing the paths to result files 

723 ignore_keys: list, optional 

724 A list containing properties of the read file that you do not want to be 

725 stored as attributes of self 

726 """ 

727 if not samples: 

728 raise InputError("Please provide a results file") 

729 _samples_generator = (self.is_pesummary_metafile(s) for s in samples) 

730 if any(_samples_generator) and not all(_samples_generator): 

731 raise InputError( 

732 "It seems that you have passed a combination of pesummary " 

733 "metafiles and non-pesummary metafiles. This is currently " 

734 "not supported." 

735 ) 

736 labels, labels_dict = None, {} 

737 weights_dict = {} 

738 if self.mcmc_samples: 

739 nsamples = 0. 

740 for num, i in enumerate(samples): 

741 idx = num 

742 if not self.mcmc_samples: 

743 if not self.is_pesummary_metafile(samples[num]): 

744 logger.info("Assigning {} to {}".format(self.labels[num], i)) 

745 else: 

746 num = 0 

747 if not os.path.isfile(i): 

748 raise InputError("File %s does not exist" % (i)) 

749 if self.is_pesummary_metafile(samples[num]): 

750 data = self.grab_data_from_input( 

751 i, self.labels[num], config=None, injection=None 

752 ) 

753 self.mcmc_samples = data["mcmc_samples"] 

754 else: 

755 data = self.grab_data_from_input( 

756 i, self.labels[num], config=self.config[num], 

757 injection=self.injection_file[num], 

758 file_format=self.file_format[num] 

759 ) 

760 if "config" in data.keys(): 

761 msg = ( 

762 "Overwriting the provided config file for '{}' with " 

763 "the config information stored in the input " 

764 "file".format(self.labels[num]) 

765 ) 

766 if self.config[num] is None: 

767 logger.debug(msg) 

768 else: 

769 logger.info(msg) 

770 self.config[num] = data.pop("config") 

771 if self.mcmc_samples: 

772 data["samples"] = { 

773 "{}_mcmc_chain_{}".format(key, idx): item for key, item 

774 in data["samples"].items() 

775 } 

776 for key, item in data.items(): 

777 if key not in ignore_keys: 

778 if idx == 0: 

779 setattr(self, "_{}".format(key), item) 

780 else: 

781 x = getattr(self, "_{}".format(key)) 

782 if isinstance(x, dict): 

783 x.update(item) 

784 elif isinstance(x, list): 

785 x += item 

786 setattr(self, "_{}".format(key), x) 

787 if self.mcmc_samples: 

788 try: 

789 nsamples += data["file_kwargs"][self.labels[num]]["sampler"][ 

790 "nsamples" 

791 ] 

792 except UnboundLocalError: 

793 pass 

794 if "labels" in data.keys(): 

795 stored_labels = data["labels"] 

796 else: 

797 stored_labels = [self.labels[num]] 

798 if "weights" in data.keys(): 

799 weights_dict.update(data["weights"]) 

800 if "prior" in data.keys(): 

801 for label in stored_labels: 

802 pp = data["prior"] 

803 if pp != {} and label in pp.keys() and pp[label] == []: 

804 if len(self.priors): 

805 if label not in self.priors["samples"].keys(): 

806 self.add_to_prior_dict( 

807 "samples/{}".format(label), [] 

808 ) 

809 else: 

810 self.add_to_prior_dict( 

811 "samples/{}".format(label), [] 

812 ) 

813 elif pp != {} and label not in pp.keys(): 

814 for key in pp.keys(): 

815 if key in self.priors.keys(): 

816 if label in self.priors[key].keys(): 

817 logger.warning( 

818 "Replacing the prior file for {} " 

819 "with the prior file stored in " 

820 "the result file".format( 

821 label 

822 ) 

823 ) 

824 if pp[key] == {}: 

825 self.add_to_prior_dict( 

826 "{}/{}".format(key, label), [] 

827 ) 

828 elif label not in pp[key].keys(): 

829 self.add_to_prior_dict( 

830 "{}/{}".format(key, label), {} 

831 ) 

832 else: 

833 self.add_to_prior_dict( 

834 "{}/{}".format(key, label), pp[key][label] 

835 ) 

836 else: 

837 self.add_to_prior_dict( 

838 "samples/{}".format(label), [] 

839 ) 

840 if "labels" in data.keys(): 

841 _duplicated = [ 

842 _ for _ in data["labels"] if num != 0 and _ in labels 

843 ] 

844 if num == 0: 

845 labels = data["labels"] 

846 elif len(_duplicated): 

847 raise InputError( 

848 "The labels stored in the supplied files are not " 

849 "unique. The label{}: '{}' appear{} in two or more " 

850 "files. Please provide unique labels for each " 

851 "analysis.".format( 

852 "s" if len(_duplicated) > 1 else "", 

853 ", ".join(_duplicated), 

854 "" if len(_duplicated) > 1 else "s" 

855 ) 

856 ) 

857 else: 

858 labels += data["labels"] 

859 labels_dict[num] = data["labels"] 

860 if self.mcmc_samples: 

861 try: 

862 self.file_kwargs[self.labels[0]]["sampler"].update( 

863 {"nsamples": nsamples, "nchains": len(self.result_files)} 

864 ) 

865 except (KeyError, UnboundLocalError): 

866 pass 

867 _labels = list(self._samples.keys()) 

868 if not isinstance(self._samples[_labels[0]], MCMCSamplesDict): 

869 self._samples = MCMCSamplesDict(self._samples) 

870 else: 

871 self._samples = self._samples[_labels[0]] 

872 if labels is not None: 

873 self._labels = labels 

874 if len(labels) != len(self.result_files): 

875 result_files = [] 

876 for num, f in enumerate(samples): 

877 for ii in np.arange(len(labels_dict[num])): 

878 result_files.append(self.result_files[num]) 

879 self.result_files = result_files 

880 self.weights = {i: None for i in self.labels} 

881 if weights_dict != {}: 

882 self.weights = weights_dict 

883 

884 @property 

885 def burnin_method(self): 

886 return self._burnin_method 

887 

888 @burnin_method.setter 

889 def burnin_method(self, burnin_method): 

890 self._burnin_method = burnin_method 

891 if not self.mcmc_samples and burnin_method is not None: 

892 logger.info( 

893 "The {} method will not be used to remove samples as " 

894 "burnin as this can only be used for mcmc chains.".format( 

895 burnin_method 

896 ) 

897 ) 

898 self._burnin_method = None 

899 elif self.mcmc_samples and burnin_method is None: 

900 logger.info( 

901 "No burnin method provided. Using {} as default".format( 

902 conf.burnin_method 

903 ) 

904 ) 

905 self._burnin_method = conf.burnin_method 

906 elif self.mcmc_samples: 

907 from pesummary.core.file import mcmc 

908 

909 if burnin_method not in mcmc.algorithms: 

910 logger.warning( 

911 "Unrecognised burnin method: {}. Resorting to the default: " 

912 "{}".format(burnin_method, conf.burnin_method) 

913 ) 

914 self._burnin_method = conf.burnin_method 

915 if self._burnin_method is not None: 

916 for label in self.labels: 

917 self.file_kwargs[label]["sampler"]["burnin_method"] = ( 

918 self._burnin_method 

919 ) 

920 

921 @property 

922 def burnin(self): 

923 return self._burnin 

924 

925 @burnin.setter 

926 def burnin(self, burnin): 

927 _name = "nsamples_removed_from_burnin" 

928 if burnin is not None: 

929 samples_lengths = [ 

930 self.samples[key].number_of_samples for key in 

931 self.samples.keys() 

932 ] 

933 if not all(int(burnin) < i for i in samples_lengths): 

934 raise InputError( 

935 "The chosen burnin is larger than the number of samples. " 

936 "Please choose a value less than {}".format( 

937 np.max(samples_lengths) 

938 ) 

939 ) 

940 logger.info( 

941 conf.overwrite.format("burnin", conf.burnin, burnin) 

942 ) 

943 burnin = int(burnin) 

944 else: 

945 burnin = conf.burnin 

946 if self.burnin_method is not None: 

947 arguments, kwargs = [], {} 

948 if burnin != 0 and self.burnin_method == "burnin_by_step_number": 

949 logger.warning( 

950 "The first {} samples have been requested to be removed " 

951 "as burnin, but the burnin method has been chosen to be " 

952 "burnin_by_step_number. Changing method to " 

953 "burnin_by_first_n with keyword argument step_number=" 

954 "True such that all samples with step number < {} are " 

955 "removed".format(burnin, burnin) 

956 ) 

957 self.burnin_method = "burnin_by_first_n" 

958 arguments = [burnin] 

959 kwargs = {"step_number": True} 

960 elif self.burnin_method == "burnin_by_first_n": 

961 arguments = [burnin] 

962 initial = self.samples.total_number_of_samples 

963 self._samples = self.samples.burnin( 

964 *arguments, algorithm=self.burnin_method, **kwargs 

965 ) 

966 diff = initial - self.samples.total_number_of_samples 

967 self.file_kwargs[self.labels[0]]["sampler"][_name] = diff 

968 self.file_kwargs[self.labels[0]]["sampler"]["nsamples"] = \ 

969 self._samples.total_number_of_samples 

970 else: 

971 for label in self.samples: 

972 self.samples[label] = self.samples[label].discard_samples( 

973 burnin 

974 ) 

975 if burnin != conf.burnin: 

976 self.file_kwargs[label]["sampler"][_name] = burnin 

977 

978 @property 

979 def nsamples(self): 

980 return self._nsamples 

981 

982 @nsamples.setter 

983 def nsamples(self, nsamples): 

984 self._nsamples = nsamples 

985 if nsamples is not None: 

986 logger.info( 

987 "{} samples will be used for each result file".format(nsamples) 

988 ) 

989 self._nsamples = int(nsamples) 

990 

991 @property 

992 def reweight_samples(self): 

993 return self._reweight_samples 

994 

995 @reweight_samples.setter 

996 def reweight_samples(self, reweight_samples): 

997 from pesummary.core.reweight import options 

998 self._reweight_samples = self._check_reweight_samples( 

999 reweight_samples, options 

1000 ) 

1001 

1002 def _check_reweight_samples(self, reweight_samples, options): 

1003 if reweight_samples and reweight_samples not in options.keys(): 

1004 logger.warning( 

1005 "Unknown reweight function: '{}'. Not reweighting posterior " 

1006 "and/or prior samples".format(reweight_samples) 

1007 ) 

1008 return False 

1009 return reweight_samples 

1010 

1011 @property 

1012 def path_to_samples(self): 

1013 return self._path_to_samples 

1014 

1015 @path_to_samples.setter 

1016 def path_to_samples(self, path_to_samples): 

1017 self._path_to_samples = path_to_samples 

1018 if path_to_samples is None: 

1019 self._path_to_samples = {label: None for label in self.labels} 

1020 elif len(path_to_samples) != len(self.labels): 

1021 raise InputError( 

1022 "Please provide a path for all result files passed. If " 

1023 "two result files are passed, and only one requires the " 

1024 "path_to_samples arguement, please pass --path_to_samples " 

1025 "None path/to/samples" 

1026 ) 

1027 else: 

1028 _paths = {} 

1029 for num, path in enumerate(path_to_samples): 

1030 _label = self.labels[num] 

1031 if path.lower() == "none": 

1032 _paths[_label] = None 

1033 else: 

1034 _paths[_label] = path 

1035 self._path_to_samples = _paths 

1036 

1037 @property 

1038 def priors(self): 

1039 return self._priors 

1040 

1041 @priors.setter 

1042 def priors(self, priors): 

1043 self._priors = self.grab_priors_from_inputs(priors) 

1044 

1045 @property 

1046 def custom_plotting(self): 

1047 return self._custom_plotting 

1048 

1049 @custom_plotting.setter 

1050 def custom_plotting(self, custom_plotting): 

1051 self._custom_plotting = custom_plotting 

1052 if custom_plotting is not None: 

1053 import importlib 

1054 

1055 path_to_python_file = os.path.dirname(custom_plotting) 

1056 python_file = os.path.splitext(os.path.basename(custom_plotting))[0] 

1057 if path_to_python_file != "": 

1058 import sys 

1059 

1060 sys.path.append(path_to_python_file) 

1061 try: 

1062 mod = importlib.import_module(python_file) 

1063 methods = getattr(mod, "__single_plots__", list()).copy() 

1064 methods += getattr(mod, "__comparion_plots__", list()).copy() 

1065 if len(methods) > 0: 

1066 self._custom_plotting = [path_to_python_file, python_file] 

1067 else: 

1068 logger.warning( 

1069 "No __single_plots__ or __comparison_plots__ in {}. " 

1070 "If you wish to use custom plotting, then please " 

1071 "add the variable :__single_plots__ and/or " 

1072 "__comparison_plots__ in future. No custom plotting " 

1073 "will be done" 

1074 ) 

1075 except ModuleNotFoundError as e: 

1076 logger.warning( 

1077 "Failed to import {} because {}. No custom plotting will " 

1078 "be done".format(python_file, e) 

1079 ) 

1080 

1081 @property 

1082 def external_hdf5_links(self): 

1083 return self._external_hdf5_links 

1084 

1085 @external_hdf5_links.setter 

1086 def external_hdf5_links(self, external_hdf5_links): 

1087 self._external_hdf5_links = external_hdf5_links 

1088 if not self.hdf5 and self.external_hdf5_links: 

1089 logger.warning( 

1090 "You can only apply external hdf5 links when saving the meta " 

1091 "file in hdf5 format. Turning external hdf5 links off." 

1092 ) 

1093 self._external_hdf5_links = False 

1094 

1095 @property 

1096 def hdf5_compression(self): 

1097 return self._hdf5_compression 

1098 

1099 @hdf5_compression.setter 

1100 def hdf5_compression(self, hdf5_compression): 

1101 self._hdf5_compression = hdf5_compression 

1102 if not self.hdf5 and hdf5_compression is not None: 

1103 logger.warning( 

1104 "You can only apply compression when saving the meta " 

1105 "file in hdf5 format. Turning compression off." 

1106 ) 

1107 self._hdf5_compression = None 

1108 

    @property
    def existing_plot(self):
        """Mapping of label -> copied plot path (or list of paths), or
        ``None`` when no usable plots were supplied."""
        return self._existing_plot

    @existing_plot.setter
    def existing_plot(self, existing_plot):
        """Copy pre-made plots into ``<webdir>/plots`` and record them.

        ``existing_plot`` may be ``None``, a list of paths (assigned to every
        label), or a dict of label -> path / list of paths. Missing files are
        dropped with a warning; labels left with no plots are removed, and
        the attribute becomes ``None`` if nothing survives.
        """
        self._existing_plot = existing_plot
        if self._existing_plot is not None:
            from pathlib import Path
            import shutil
            if isinstance(self._existing_plot, list):
                logger.warning(
                    "Assigning {} to all labels".format(
                        ", ".join(self._existing_plot)
                    )
                )
                # every label shares the same list object; entries are only
                # rebound below, never mutated in place
                self._existing_plot = {
                    label: self._existing_plot for label in self.labels
                }
            _does_not_exist = (
                "The plot {} does not exist. Not adding plot to summarypages."
            )
            keys_to_remove = []
            for key, _plot in self._existing_plot.items():
                if isinstance(_plot, list):
                    allowed = []
                    for _subplot in _plot:
                        if not os.path.isfile(_subplot):
                            logger.warning(_does_not_exist.format(_subplot))
                        else:
                            _filename = os.path.join(
                                self.webdir, "plots", Path(_subplot).name
                            )
                            try:
                                shutil.copyfile(_subplot, _filename)
                            except shutil.SameFileError:
                                # source already lives at the destination
                                pass
                            allowed.append(_filename)
                    # collapse to a single path when only one plot survived
                    if not len(allowed):
                        keys_to_remove.append(key)
                    elif len(allowed) == 1:
                        self._existing_plot[key] = allowed[0]
                    else:
                        self._existing_plot[key] = allowed
                else:
                    if not os.path.isfile(_plot):
                        logger.warning(_does_not_exist.format(_plot))
                        keys_to_remove.append(key)
                    else:
                        _filename = os.path.join(
                            self.webdir, "plots", Path(_plot).name
                        )
                        try:
                            shutil.copyfile(_plot, _filename)
                        except shutil.SameFileError:
                            # source already at the destination: write a
                            # label-prefixed copy instead (presumably so each
                            # label gets its own file - confirm intent)
                            _filename = os.path.join(
                                self.webdir, "plots", key + "_" + Path(_plot).name
                            )
                            shutil.copyfile(_plot, _filename)
                        self._existing_plot[key] = _filename
            # deletions are deferred so the dict is not resized mid-iteration
            for key in keys_to_remove:
                del self._existing_plot[key]
            if not len(self._existing_plot):
                self._existing_plot = None

1173 

1174 def add_to_prior_dict(self, path, data): 

1175 """Add priors to the prior dictionary 

1176 

1177 Parameters 

1178 ---------- 

1179 path: str 

1180 the location where you wish to store the prior. If this is inside 

1181 a nested dictionary, then please pass the path as 'a/b' 

1182 data: np.ndarray 

1183 the prior samples 

1184 """ 

1185 from functools import reduce 

1186 

1187 def build_tree(dictionary, path): 

1188 """Build a dictionary tree from a list of keys 

1189 

1190 Parameters 

1191 ---------- 

1192 dictionary: dict 

1193 existing dictionary that you wish to add to 

1194 path: list 

1195 list of keys specifying location 

1196 

1197 Examples 

1198 -------- 

1199 >>> dictionary = {"label": {"mass_1": [1,2,3,4,5,6]}} 

1200 >>> path = ["label", "mass_2"] 

1201 >>> build_tree(dictionary, path) 

1202 {"label": {"mass_1": [1,2,3,4,5,6], "mass_2": {}}} 

1203 """ 

1204 if path != [] and path[0] not in dictionary.keys(): 

1205 dictionary[path[0]] = {} 

1206 if path != []: 

1207 build_tree(dictionary[path[0]], path[1:]) 

1208 return dictionary 

1209 

1210 def get_nested_dictionary(dictionary, path): 

1211 """Return a nested dictionary from a list specifying path 

1212 

1213 Parameters 

1214 ---------- 

1215 dictionary: dict 

1216 existing dictionary that you wish to extract information from 

1217 path: list 

1218 list of keys specifying location 

1219 

1220 Examples 

1221 -------- 

1222 >>> dictionary = {"label": {"mass_1": [1,2,3,4,5,6]}} 

1223 >>> path = ["label", "mass_1"] 

1224 >>> get_nested_dictionary(dictionary, path) 

1225 [1,2,3,4,5,6] 

1226 """ 

1227 return reduce(dict.get, path, dictionary) 

1228 

1229 if "/" in path: 

1230 path = path.split("/") 

1231 else: 

1232 path = [path] 

1233 tree = build_tree(self._priors, path) 

1234 nested_dictionary = get_nested_dictionary(self._priors, path[:-1]) 

1235 nested_dictionary[path[-1]] = data 

1236 

1237 def grab_priors_from_inputs(self, priors, read_func=None, read_kwargs={}): 

1238 """ 

1239 """ 

1240 if read_func is None: 

1241 from pesummary.core.file.read import read as Read 

1242 read_func = Read 

1243 

1244 prior_dict = {} 

1245 if priors is not None: 

1246 prior_dict = {"samples": {}, "analytic": {}} 

1247 for i in priors: 

1248 if not os.path.isfile(i): 

1249 raise InputError("The file {} does not exist".format(i)) 

1250 if len(priors) != len(self.labels) and len(priors) == 1: 

1251 logger.warning( 

1252 "You have only specified a single prior file for {} result " 

1253 "files. Assuming the same prior file for all result " 

1254 "files".format(len(self.labels)) 

1255 ) 

1256 data = read_func( 

1257 priors[0], nsamples=self.nsamples_for_prior 

1258 ) 

1259 for i in self.labels: 

1260 prior_dict["samples"][i] = data.samples_dict 

1261 try: 

1262 if data.analytic is not None: 

1263 prior_dict["analytic"][i] = data.analytic 

1264 except AttributeError: 

1265 continue 

1266 elif len(priors) != len(self.labels): 

1267 raise InputError( 

1268 "Please provide a prior file for each result file" 

1269 ) 

1270 else: 

1271 for num, i in enumerate(priors): 

1272 if i.lower() == "none": 

1273 continue 

1274 logger.info( 

1275 "Assigning {} to {}".format(self.labels[num], i) 

1276 ) 

1277 if self.labels[num] in read_kwargs.keys(): 

1278 grab_data_kwargs = read_kwargs[self.labels[num]] 

1279 else: 

1280 grab_data_kwargs = read_kwargs 

1281 data = read_func( 

1282 priors[num], nsamples=self.nsamples_for_prior, 

1283 **grab_data_kwargs 

1284 ) 

1285 prior_dict["samples"][self.labels[num]] = data.samples_dict 

1286 try: 

1287 if data.analytic is not None: 

1288 prior_dict["analytic"][self.labels[num]] = data.analytic 

1289 except AttributeError: 

1290 continue 

1291 return prior_dict 

1292 

1293 @property 

1294 def grab_data_kwargs(self): 

1295 return { 

1296 label: dict(regenerate=self.regenerate) for label in self.labels 

1297 } 

1298 

1299 def grab_data_from_input( 

1300 self, file, label, config=None, injection=None, file_format=None 

1301 ): 

1302 """Wrapper function for the grab_data_from_metafile and 

1303 grab_data_from_file functions 

1304 

1305 Parameters 

1306 ---------- 

1307 file: str 

1308 path to the result file 

1309 label: str 

1310 label that you wish to use for the result file 

1311 config: str, optional 

1312 path to a configuration file used in the analysis 

1313 injection: str, optional 

1314 path to an injection file used in the analysis 

1315 file_format, str, optional 

1316 the file format you wish to use when loading. Default None. 

1317 If None, the read function loops through all possible options 

1318 mcmc: Bool, optional 

1319 if True, the result file is an mcmc chain 

1320 """ 

1321 if label in self.grab_data_kwargs.keys(): 

1322 grab_data_kwargs = self.grab_data_kwargs[label] 

1323 else: 

1324 grab_data_kwargs = self.grab_data_kwargs 

1325 

1326 if self.is_pesummary_metafile(file): 

1327 data = self.grab_data_from_metafile( 

1328 file, self.webdir, compare=self.compare_results, 

1329 nsamples=self.nsamples, reweight_samples=self.reweight_samples, 

1330 disable_injection=self.disable_injection, 

1331 keep_nan_likelihood_samples=self.keep_nan_likelihood_samples, 

1332 **grab_data_kwargs 

1333 ) 

1334 else: 

1335 data = self.grab_data_from_file( 

1336 file, label, self.webdir, config=config, injection=injection, 

1337 file_format=file_format, nsamples=self.nsamples, 

1338 disable_prior_sampling=self.disable_prior_sampling, 

1339 nsamples_for_prior=self.nsamples_for_prior, 

1340 path_to_samples=self.path_to_samples[label], 

1341 reweight_samples=self.reweight_samples, 

1342 keep_nan_likelihood_samples=self.keep_nan_likelihood_samples, 

1343 num_processes=self.multi_process, 

1344 **grab_data_kwargs 

1345 ) 

1346 self._open_result_files.update({file: data["open_file"]}) 

1347 return data 

1348 

1349 @property 

1350 def email(self): 

1351 return self._email 

1352 

1353 @email.setter 

1354 def email(self, email): 

1355 if email is not None and "@" not in email: 

1356 raise InputError("Please provide a valid email address") 

1357 self._email = email 

1358 

1359 @property 

1360 def dump(self): 

1361 return self._dump 

1362 

1363 @dump.setter 

1364 def dump(self, dump): 

1365 self._dump = dump 

1366 

1367 @property 

1368 def palette(self): 

1369 return self._palette 

1370 

1371 @palette.setter 

1372 def palette(self, palette): 

1373 self._palette = palette 

1374 if palette is not conf.palette: 

1375 from pesummary.core.plots.palette import ( 

1376 color_palette, AVAILABLE_PALETTES 

1377 ) 

1378 try: 

1379 color_palette(palette, n_colors=1) 

1380 logger.info( 

1381 conf.overwrite.format("palette", conf.palette, palette) 

1382 ) 

1383 conf.palette = palette 

1384 except ValueError as e: 

1385 raise InputError( 

1386 "Unrecognised palette. Please choose from one of the " 

1387 "following {}".format( 

1388 ", ".join(AVAILABLE_PALETTES) 

1389 ) 

1390 ) 

1391 

1392 @property 

1393 def include_prior(self): 

1394 return self._include_prior 

1395 

1396 @include_prior.setter 

1397 def include_prior(self, include_prior): 

1398 self._include_prior = include_prior 

1399 if include_prior != conf.include_prior: 

1400 conf.overwrite.format("prior", conf.include_prior, include_prior) 

1401 conf.include_prior = include_prior 

1402 

1403 @property 

1404 def colors(self): 

1405 return self._colors 

1406 

1407 @colors.setter 

1408 def colors(self, colors): 

1409 if colors is not None: 

1410 number = len(self.labels) 

1411 if self.existing: 

1412 number += len(self.existing_labels) 

1413 if len(colors) != number and len(colors) > number: 

1414 logger.info( 

1415 "You have passed {} colors for {} result files. Setting " 

1416 "colors = {}".format( 

1417 len(colors), number, colors[:number] 

1418 ) 

1419 ) 

1420 self._colors = colors[:number] 

1421 return 

1422 elif len(colors) != number: 

1423 logger.warning( 

1424 "Number of colors does not match the number of labels. " 

1425 "Using default colors" 

1426 ) 

1427 from pesummary.core.plots.palette import color_palette 

1428 number = len(self.labels) 

1429 if self.existing: 

1430 number += len(self.existing_labels) 

1431 colors = color_palette( 

1432 palette=conf.palette, n_colors=number 

1433 ).as_hex() 

1434 self._colors = colors 

1435 

1436 @property 

1437 def linestyles(self): 

1438 return self._linestyles 

1439 

1440 @linestyles.setter 

1441 def linestyles(self, linestyles): 

1442 if linestyles is not None: 

1443 if len(linestyles) != len(self.colors): 

1444 if len(linestyles) > len(self.colors): 

1445 logger.info( 

1446 "You have passed {} linestyles for {} result files. " 

1447 "Setting linestyles = {}".format( 

1448 len(linestyles), len(self.colors), 

1449 linestyles[:len(self.colors)] 

1450 ) 

1451 ) 

1452 self._linestyles = linestyles[:len(self.colors)] 

1453 return 

1454 else: 

1455 logger.warning( 

1456 "Number of linestyles does not match the number of " 

1457 "labels. Using default linestyles" 

1458 ) 

1459 available_linestyles = ["-", "--", ":", "-."] 

1460 linestyles = ["-"] * len(self.colors) 

1461 unique_colors = np.unique(self.colors) 

1462 for color in unique_colors: 

1463 indicies = [num for num, i in enumerate(self.colors) if i == color] 

1464 for idx, j in enumerate(indicies): 

1465 linestyles[j] = available_linestyles[ 

1466 np.mod(idx, len(available_linestyles)) 

1467 ] 

1468 self._linestyles = linestyles 

1469 

1470 @property 

1471 def disable_corner(self): 

1472 return self._disable_corner 

1473 

1474 @disable_corner.setter 

1475 def disable_corner(self, disable_corner): 

1476 self._disable_corner = disable_corner 

1477 if disable_corner: 

1478 logger.warning( 

1479 "No corner plot will be produced. This will reduce overall " 

1480 "runtime but does mean that the interactive corner plot feature " 

1481 "on the webpages will no longer work" 

1482 ) 

1483 

1484 @property 

1485 def add_to_corner(self): 

1486 return self._add_to_corner 

1487 

1488 @add_to_corner.setter 

1489 def add_to_corner(self, add_to_corner): 

1490 self._add_to_corner = self._set_corner_params(add_to_corner) 

1491 

1492 def _set_corner_params(self, corner_params): 

1493 cls = self.__class__.__name__ 

1494 if corner_params is not None: 

1495 for label in self.labels: 

1496 _not_included = [ 

1497 param for param in corner_params if param not in 

1498 self.samples[label].keys() 

1499 ] 

1500 if len(_not_included) == len(corner_params) and cls == "Input": 

1501 logger.warning( 

1502 "None of the chosen corner parameters are " 

1503 "included in the posterior table for '{}'. Using " 

1504 "all available parameters for the corner plot".format( 

1505 label 

1506 ) 

1507 ) 

1508 corner_params = None 

1509 break 

1510 elif len(_not_included): 

1511 logger.warning( 

1512 "The following parameters are not included in the " 

1513 "posterior table for '{}': {}. Not adding to corner " 

1514 "plot".format(label, ", ".join(_not_included)) 

1515 ) 

1516 elif cls == "Input": 

1517 logger.debug( 

1518 "Using all parameters stored in the result file for the " 

1519 "corner plots. This may take some time." 

1520 ) 

1521 return corner_params 

1522 

1523 @property 

1524 def pe_algorithm(self): 

1525 return self._pe_algorithm 

1526 

1527 @pe_algorithm.setter 

1528 def pe_algorithm(self, pe_algorithm): 

1529 self._pe_algorithm = pe_algorithm 

1530 if pe_algorithm is None: 

1531 return 

1532 if len(pe_algorithm) != len(self.labels): 

1533 raise ValueError("Please provide an algorithm for each result file") 

1534 for num, (label, _algorithm) in enumerate(zip(self.labels, pe_algorithm)): 

1535 if "pe_algorithm" in self.file_kwargs[label]["sampler"].keys(): 

1536 _stored = self.file_kwargs[label]["sampler"]["pe_algorithm"] 

1537 if _stored != _algorithm: 

1538 logger.warning( 

1539 "Overwriting the pe_algorithm extracted from the file " 

1540 "'{}': {} with the algorithm provided from the command " 

1541 "line: {}".format( 

1542 self.result_files[num], _stored, _algorithm 

1543 ) 

1544 ) 

1545 self.file_kwargs[label]["sampler"]["pe_algorithm"] = _algorithm 

1546 

1547 @property 

1548 def notes(self): 

1549 return self._notes 

1550 

1551 @notes.setter 

1552 def notes(self, notes): 

1553 self._notes = notes 

1554 if notes is not None: 

1555 if not os.path.isfile(notes): 

1556 raise InputError( 

1557 "No such file or directory called {}".format(notes) 

1558 ) 

1559 try: 

1560 with open(notes, "r") as f: 

1561 self._notes = f.read() 

1562 except FileNotFoundError: 

1563 logger.warning( 

1564 "No such file or directory called {}. Custom notes will " 

1565 "not be added to the summarypages".format(notes) 

1566 ) 

1567 except IOError as e: 

1568 logger.warning( 

1569 "Failed to read {}. Unable to put notes on " 

1570 "summarypages".format(notes) 

1571 ) 

1572 

1573 @property 

1574 def descriptions(self): 

1575 return self._descriptions 

1576 

1577 @descriptions.setter 

1578 def descriptions(self, descriptions): 

1579 import json 

1580 if hasattr(self, "_descriptions") and not len(descriptions): 

1581 return 

1582 elif not len(descriptions): 

1583 self._descriptions = None 

1584 return 

1585 

1586 if len(descriptions) and isinstance(descriptions, dict): 

1587 data = descriptions 

1588 elif len(descriptions): 

1589 descriptions = descriptions[0] 

1590 _is_file = not isinstance(descriptions, dict) 

1591 if hasattr(self, "_descriptions"): 

1592 logger.warning( 

1593 "Ignoring descriptions found in result file and using " 

1594 "descriptions in '{}'".format(descriptions) 

1595 ) 

1596 self._descriptions = None 

1597 if _is_file and not os.path.isfile(descriptions): 

1598 logger.warning( 

1599 "No such file called {}. Unable to add descriptions".format( 

1600 descriptions 

1601 ) 

1602 ) 

1603 return 

1604 if _is_file: 

1605 try: 

1606 with open(descriptions, "r") as f: 

1607 data = json.load(f) 

1608 except json.decoder.JSONDecodeError: 

1609 logger.warning( 

1610 "Unable to open file '{}'. Not storing descriptions".format( 

1611 descriptions 

1612 ) 

1613 ) 

1614 return 

1615 if not all(label in data.keys() for label in self.labels): 

1616 not_included = [ 

1617 label for label in self.labels if label not in data.keys() 

1618 ] 

1619 logger.debug( 

1620 "No description found for '{}'. Using default " 

1621 "description".format(", ".join(not_included)) 

1622 ) 

1623 for label in not_included: 

1624 data[label] = "No description found" 

1625 if len(data.keys()) > len(self.labels): 

1626 logger.warning( 

1627 "Descriptions file contains descriptions for analyses other " 

1628 "than {}. Ignoring other descriptions".format( 

1629 ", ".join(self.labels) 

1630 ) 

1631 ) 

1632 other = [ 

1633 analysis for analysis in data.keys() if analysis not in 

1634 self.labels 

1635 ] 

1636 for analysis in other: 

1637 _ = data.pop(analysis) 

1638 _remove = [] 

1639 for key, desc in data.items(): 

1640 if not isinstance(desc, (str, bytes)): 

1641 logger.warning( 

1642 "Unknown description '{}' for '{}'. The description should " 

1643 "be a string or bytes object" 

1644 ) 

1645 _remove.append(key) 

1646 if len(_remove): 

1647 for analysis in _remove: 

1648 _ = data.pop(analysis) 

1649 self._descriptions = data 

1650 

1651 @property 

1652 def preferred(self): 

1653 return self._preferred 

1654 

1655 @preferred.setter 

1656 def preferred(self, preferred): 

1657 if preferred is not None and preferred not in self.labels: 

1658 logger.warning( 

1659 "'{}' not in list of labels. Unable to stored as the " 

1660 "preferred analysis".format(preferred) 

1661 ) 

1662 self._preferred = None 

1663 elif preferred is not None: 

1664 logger.debug( 

1665 "Setting '{}' as the preferred analysis".format(preferred) 

1666 ) 

1667 self._preferred = preferred 

1668 elif len(self.labels) == 1: 

1669 self._preferred = self.labels[0] 

1670 else: 

1671 self._preferred = None 

1672 if self._preferred is not None: 

1673 try: 

1674 self.file_kwargs[self._preferred]["other"].update( 

1675 {"preferred": "True"} 

1676 ) 

1677 except KeyError: 

1678 self.file_kwargs[self._preferred].update( 

1679 {"other": {"preferred": "True"}} 

1680 ) 

1681 for _label in self.labels: 

1682 if self._preferred is not None and _label == self._preferred: 

1683 continue 

1684 try: 

1685 self.file_kwargs[_label]["other"].update( 

1686 {"preferred": "False"} 

1687 ) 

1688 except KeyError: 

1689 self.file_kwargs[_label].update( 

1690 {"other": {"preferred": "False"}} 

1691 ) 

1692 return 

1693 

1694 @property 

1695 def public(self): 

1696 return self._public 

1697 

1698 @public.setter 

1699 def public(self, public): 

1700 self._public = public 

1701 if public != conf.public: 

1702 logger.info( 

1703 conf.overwrite.format("public", conf.public, public) 

1704 ) 

1705 

1706 @property 

1707 def multi_process(self): 

1708 return self._multi_process 

1709 

1710 @multi_process.setter 

1711 def multi_process(self, multi_process): 

1712 self._multi_process = int(multi_process) 

1713 if multi_process is not None and int(multi_process) != int(conf.multi_process): 

1714 logger.info( 

1715 conf.overwrite.format( 

1716 "multi_process", conf.multi_process, multi_process 

1717 ) 

1718 ) 

1719 

1720 @property 

1721 def publication_kwargs(self): 

1722 return self._publication_kwargs 

1723 

1724 @publication_kwargs.setter 

1725 def publication_kwargs(self, publication_kwargs): 

1726 self._publication_kwargs = publication_kwargs 

1727 if publication_kwargs != {}: 

1728 allowed_kwargs = ["gridsize"] 

1729 if not any(i in publication_kwargs.keys() for i in allowed_kwargs): 

1730 logger.warning( 

1731 "Currently the only allowed publication kwargs are {}. " 

1732 "Ignoring other inputs.".format( 

1733 ", ".join(allowed_kwargs) 

1734 ) 

1735 ) 

1736 

1737 @property 

1738 def ignore_parameters(self): 

1739 return self._ignore_parameters 

1740 

1741 @ignore_parameters.setter 

1742 def ignore_parameters(self, ignore_parameters): 

1743 self._ignore_parameters = ignore_parameters 

1744 if ignore_parameters is not None: 

1745 for num, label in enumerate(self.labels): 

1746 removed_parameters = list_match( 

1747 list(self.samples[label].keys()), ignore_parameters 

1748 ) 

1749 if not len(removed_parameters): 

1750 logger.warning( 

1751 "Failed to remove any parameters from {}".format( 

1752 self.result_files[num] 

1753 ) 

1754 ) 

1755 else: 

1756 logger.warning( 

1757 "Removing parameters: {} from {}".format( 

1758 ", ".join(removed_parameters), 

1759 self.result_files[num] 

1760 ) 

1761 ) 

1762 for ignore in removed_parameters: 

1763 self.samples[label].pop(ignore) 

1764 

1765 @staticmethod 

1766 def _make_directories(webdir, dirs): 

1767 """Make the directories to store the information 

1768 """ 

1769 for i in dirs: 

1770 if not os.path.isdir(os.path.join(webdir, i)): 

1771 make_dir(os.path.join(webdir, i)) 

1772 

1773 def make_directories(self): 

1774 """Make the directories to store the information 

1775 """ 

1776 self._make_directories(self.webdir, self.default_directories) 

1777 

1778 @staticmethod 

1779 def _copy_files(paths): 

1780 """Copy the relevant file to the web directory 

1781 

1782 Parameters 

1783 ---------- 

1784 paths: nd list 

1785 list of files you wish to copy. First element is the path of the 

1786 file to copy and second element is the location of where you 

1787 wish the file to be put 

1788 

1789 Examples 

1790 -------- 

1791 >>> paths = [ 

1792 ... ["config/config.ini", "webdir/config.ini"], 

1793 ... ["samples/samples.h5", "webdir/samples.h5"] 

1794 ... ] 

1795 """ 

1796 import shutil 

1797 

1798 for ff in paths: 

1799 shutil.copyfile(ff[0], ff[1]) 

1800 

1801 def copy_files(self): 

1802 """Copy the relevant file to the web directory 

1803 """ 

1804 self._copy_files(self.default_files_to_copy) 

1805 

1806 def default_labels(self): 

1807 """Return a list of default labels. 

1808 """ 

1809 from time import time 

1810 

1811 def _default_label(file_name): 

1812 return "%s_%s" % (round(time()), file_name) 

1813 

1814 label_list = [] 

1815 if self.result_files is None or len(self.result_files) == 0: 

1816 raise InputError("Please provide a results file") 

1817 elif self.mcmc_samples: 

1818 f = self.result_files[0] 

1819 file_name = os.path.splitext(os.path.basename(f))[0] 

1820 label_list.append(_default_label(file_name)) 

1821 else: 

1822 for num, i in enumerate(self.result_files): 

1823 file_name = os.path.splitext(os.path.basename(i))[0] 

1824 label_list.append(_default_label(file_name)) 

1825 

1826 duplicates = dict(set( 

1827 (x, label_list.count(x)) for x in 

1828 filter(lambda rec: label_list.count(rec) > 1, label_list))) 

1829 

1830 for i in duplicates.keys(): 

1831 for j in range(duplicates[i]): 

1832 ind = label_list.index(i) 

1833 label_list[ind] += "_%s" % (j) 

1834 if self.add_to_existing: 

1835 for num, i in enumerate(label_list): 

1836 if i in self.existing_labels: 

1837 ind = label_list.index(i) 

1838 label_list[ind] += "_%s" % (num) 

1839 return label_list 

1840 

1841 @staticmethod 

1842 def get_package_information(): 

1843 """Return a dictionary of parameter information 

1844 """ 

1845 from pesummary._version_helper import PackageInformation 

1846 from operator import itemgetter 

1847 

1848 _package = PackageInformation() 

1849 package_info = _package.package_info 

1850 package_dir = _package.package_dir 

1851 if "build_string" in package_info[0]: # conda list 

1852 headings = ("name", "version", "channel", "build_string") 

1853 else: # pip list installed 

1854 headings = ("name", "version") 

1855 packages = np.array([ 

1856 tuple(pkg[col.lower()] for col in headings) for pkg in 

1857 sorted(package_info, key=itemgetter("name")) 

1858 ], dtype=[(col, "S20") for col in headings]).view(np.recarray) 

1859 return { 

1860 "packages": packages, "environment": [package_dir], 

1861 "manager": _package.package_manager 

1862 } 

1863 

1864 def grab_key_data_from_result_files(self): 

1865 """Grab the mean, median, maxL and standard deviation for all 

1866 parameters for all each result file 

1867 """ 

1868 key_data = { 

1869 key: samples.key_data for key, samples in self.samples.items() 

1870 } 

1871 for key, val in self.samples.items(): 

1872 for j in val.keys(): 

1873 _inj = self.injection_data[key][j] 

1874 key_data[key][j]["injected"] = ( 

1875 _inj[0] if not math.isnan(_inj) and isinstance( 

1876 _inj, (list, np.ndarray) 

1877 ) else _inj 

1878 ) 

1879 return key_data 

1880 

1881 

class BaseInput(_Input):
    """Class to handle and store base command line arguments

    Parameters
    ----------
    opts: object
        parsed command line options (presumably an argparse.Namespace -
        confirm); attributes such as ``samples``, ``webdir``, ``existing``
        and ``email`` are read from it
    ignore_copy: Bool, optional
        if True, ``copy_files`` is not called. Default False
    checkpoint: object, optional
        previously saved state. When given, every attribute in
        ``vars(checkpoint)`` is copied onto this instance and the rest of
        the initialisation is skipped
    gw: Bool, optional
        stored as ``self.gw``. Default False
    """
    def __init__(self, opts, ignore_copy=False, checkpoint=None, gw=False):
        self.opts = opts
        self.gw = gw
        self.restart_from_checkpoint = self.opts.restart_from_checkpoint
        if checkpoint is not None:
            # restore the saved state wholesale, then stop initialising
            for key, item in vars(checkpoint).items():
                setattr(self, key, item)
            logger.info(
                "Loaded command line arguments: {}".format(self.opts)
            )
            self.restart_from_checkpoint = True
            self._restarted_from_checkpoint = True
            return
        # NOTE(review): the assignments below may trigger property setters
        # defined elsewhere in this class hierarchy, so their order matters
        self.seed = self.opts.seed
        self.result_files = self.opts.samples
        self.user = self.opts.user
        self.existing = self.opts.existing
        self.add_to_existing = False
        if self.existing is not None:
            self.add_to_existing = True
            # NOTE(review): assigned True rather than a path; presumably a
            # setter elsewhere resolves the actual metafile - confirm
            self.existing_metafile = True
        self.webdir = self.opts.webdir
        self._restarted_from_checkpoint = False
        self.resume_file_dir = conf.checkpoint_dir(self.webdir)
        self.resume_file = conf.resume_file
        self._resume_file_path = os.path.join(
            self.resume_file_dir, self.resume_file
        )
        self.make_directories()
        self.email = self.opts.email
        # NOTE(review): the pe_algorithm setter reads self.labels whenever
        # opts.pe_algorithm is not None; self.labels is not assigned in this
        # method - confirm it is available at this point
        self.pe_algorithm = self.opts.pe_algorithm
        self.multi_process = self.opts.multi_process
        self.package_information = self.get_package_information()
        if not ignore_copy:
            self.copy_files()
        # persist the initial state so a later run can restart from it
        self.write_current_state()

    @property
    def default_directories(self):
        # sub-directories of webdir created by make_directories()
        return ["checkpoint"]

    @property
    def default_files_to_copy(self):
        # [source, destination] pairs copied by copy_files(); none here
        return []

    def write_current_state(self):
        """Write the current state of the input class to file

        The whole instance is written in "pickle" format to
        ``self.resume_file`` inside ``self.resume_file_dir``, overwriting
        any previous checkpoint.
        """
        from pesummary.io import write
        write(
            self, outdir=self.resume_file_dir, file_format="pickle",
            filename=self.resume_file, overwrite=True
        )
        logger.debug(
            "Written checkpoint file: {}".format(self._resume_file_path)
        )

1942 

1943class SamplesInput(BaseInput): 

1944 """Class to handle and store sample specific command line arguments 

1945 """ 

    def __init__(self, *args, extra_options=None, **kwargs):
        """Read sample-related options from ``self.opts`` and load data.

        Parameters
        ----------
        *args, **kwargs:
            forwarded to ``BaseInput.__init__``
        extra_options: list, optional
            names of additional attributes copied verbatim from
            ``self.opts`` onto this instance
        """
        super(SamplesInput, self).__init__(*args, **kwargs)
        # NOTE(review): BaseInput.__init__ returns early when restoring from
        # a checkpoint, but this method still runs in full afterwards and
        # re-assigns the attributes below - confirm this is intended
        if self.result_files is not None:
            # open handles are cached here by grab_data_from_input
            self._open_result_files = {path: None for path in self.result_files}
        self.meta_file = False
        if self.result_files is not None and len(self.result_files) == 1:
            self.meta_file = self.is_pesummary_metafile(self.result_files[0])
        self.compare_results = self.opts.compare_results
        self.disable_injection = self.opts.disable_injection
        if self.existing is not None:
            # load everything stored in the existing metafile
            self.existing_data = self.grab_data_from_metafile(
                self.existing_metafile, self.existing,
                compare=self.compare_results
            )
            self.existing_samples = self.existing_data["samples"]
            self.existing_injection_data = self.existing_data["injection_data"]
            self.existing_file_version = self.existing_data["file_version"]
            self.existing_file_kwargs = self.existing_data["file_kwargs"]
            self.existing_priors = self.existing_data["prior"]
            self.existing_config = self.existing_data["config"]
            self.existing_labels = self.existing_data["labels"]
            self.existing_weights = self.existing_data["weights"]
        else:
            self.existing_metafile = None
            self.existing_labels = None
            self.existing_weights = None
            self.existing_samples = None
            self.existing_file_version = None
            self.existing_file_kwargs = None
            self.existing_priors = None
            self.existing_config = None
            self.existing_injection_data = None
        # the remaining assignments may trigger property setters that
        # depend on earlier ones (e.g. priors and path_to_samples read
        # self.labels), so their order matters
        self.mcmc_samples = self.opts.mcmc_samples
        self.labels = self.opts.labels
        self.weights = {i: None for i in self.labels}
        self.config = self.opts.config
        self.injection_file = self.opts.inj_file
        self.regenerate = self.opts.regenerate
        if extra_options is not None:
            for opt in extra_options:
                setattr(self, opt, getattr(self.opts, opt))
        self.nsamples_for_prior = self.opts.nsamples_for_prior
        self.priors = self.opts.prior_file
        self.disable_prior_sampling = self.opts.disable_prior_sampling
        self.path_to_samples = self.opts.path_to_samples
        self.file_format = self.opts.file_format
        self.nsamples = self.opts.nsamples
        self.keep_nan_likelihood_samples = self.opts.keep_nan_likelihood_samples
        self.reweight_samples = self.opts.reweight_samples
        self.samples = self.opts.samples
        self.ignore_parameters = self.opts.ignore_parameters
        self.burnin_method = self.opts.burnin_method
        self.burnin = self.opts.burnin
        self.same_parameters = []
        if self.mcmc_samples:
            # every label is mapped to self.samples.T
            self._samples = {label: self.samples.T for label in self.labels}
        self.write_current_state()

2005 

2006 @property 

2007 def analytic_prior_dict(self): 

2008 return { 

2009 label: "\n".join( 

2010 [ 

2011 "{} = {}".format(key, value) for key, value in 

2012 self.priors["analytic"][label].items() 

2013 ] 

2014 ) if "analytic" in self.priors.keys() and label in 

2015 self.priors["analytic"].keys() else None for label in self.labels 

2016 } 

2017 

2018 @property 

2019 def same_parameters(self): 

2020 return self._same_parameters 

2021 

2022 @same_parameters.setter 

2023 def same_parameters(self, same_parameters): 

2024 self._same_parameters = self.intersect_samples_dict(self.samples) 

2025 

2026 def intersect_samples_dict(self, samples): 

2027 parameters = [ 

2028 list(samples[key].keys()) for key in samples.keys() 

2029 ] 

2030 params = list(set.intersection(*[set(l) for l in parameters])) 

2031 return params 

2032 

2033 

class PlottingInput(SamplesInput):
    """Store the plotting-specific command line arguments

    Extends ``SamplesInput`` with the options that control which plots are
    produced and how they are styled.
    """
    def __init__(self, *args, **kwargs):
        super(PlottingInput, self).__init__(*args, **kwargs)
        opts = self.opts
        self.style_file = opts.style_file
        self.publication = opts.publication
        self.publication_kwargs = opts.publication_kwargs
        self.kde_plot = opts.kde_plot
        self.custom_plotting = opts.custom_plotting
        self.add_to_corner = opts.add_to_corner
        # corner_params mirrors whatever add_to_corner resolved to
        self.corner_params = self.add_to_corner
        self.palette = opts.palette
        self.include_prior = opts.include_prior
        self.colors = opts.colors
        self.linestyles = opts.linestyles
        self.disable_corner = opts.disable_corner
        self.disable_comparison = opts.disable_comparison
        self.disable_interactive = opts.disable_interactive
        # The command line exposes an opt-in flag; store the inverse
        self.disable_expert = not opts.enable_expert
        self.multi_threading_for_plots = self.multi_process
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent-class directories plus the plot and sample directories"""
        directories = super(PlottingInput, self).default_directories
        directories.extend(
            ["plots", "plots/corner", "plots/publication", "samples"]
        )
        return directories

2062 

2063 

class WebpageInput(SamplesInput):
    """Store the webpage-specific command line arguments

    Extends ``SamplesInput`` with the options needed to build the summary
    pages. It also requests copies of the packaged javascript and CSS
    files.
    """
    def __init__(self, *args, **kwargs):
        super(WebpageInput, self).__init__(*args, **kwargs)
        opts = self.opts
        self.baseurl = opts.baseurl
        self.existing_plot = opts.existing_plot
        self.pe_algorithm = opts.pe_algorithm
        self.notes = opts.notes
        self.dump = opts.dump
        # HDF5 output is the default unless JSON was explicitly requested
        self.hdf5 = not opts.save_to_json
        self.external_hdf5_links = opts.external_hdf5_links
        self.file_kwargs["webpage_url"] = self.baseurl + "/home.html"
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent-class directories plus the js, html and css directories"""
        directories = super(WebpageInput, self).default_directories
        directories.extend(["js", "html", "css"])
        return directories

    @property
    def default_files_to_copy(self):
        """Parent-class files plus the packaged ``*.js`` and ``*.css`` files

        Each packaged file found under ``<pesummary.core>/js`` or
        ``<pesummary.core>/css`` is paired with a destination of the same
        name in the matching sub-directory of ``self.webdir``.
        """
        from pesummary import core
        files_to_copy = super(WebpageInput, self).default_files_to_copy
        package_dir = core.__path__[0]
        for subdir, pattern in (("js", "*.js"), ("css", "*.css")):
            for source in glob(os.path.join(package_dir, subdir, pattern)):
                destination = os.path.join(
                    self.webdir, subdir, os.path.basename(source)
                )
                files_to_copy.append([source, destination])
        return files_to_copy

2101 

2102 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Store both the webpage and the plotting command line arguments

    The directory and file lists are resolved cooperatively through the
    MRO (PlottingInput, then WebpageInput, then SamplesInput).
    """
    def __init__(self, *args, **kwargs):
        super(WebpagePlusPlottingInput, self).__init__(*args, **kwargs)
        # Copy again now that both parent classes have been initialised
        self.copy_files()

    @property
    def default_directories(self):
        """Directories collected from every class in the MRO"""
        directories = super(WebpagePlusPlottingInput, self).default_directories
        return directories

    @property
    def default_files_to_copy(self):
        """Files to copy collected from every class in the MRO"""
        to_copy = super(WebpagePlusPlottingInput, self).default_files_to_copy
        return to_copy

2118 

2119 

class MetaFileInput(SamplesInput):
    """Class to handle and store metafile specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # Skip the copy performed during base-class initialisation and copy
        # afterwards instead. default_files_to_copy reads config, labels,
        # result_files and mcmc_samples, which SamplesInput only sets after
        # BaseInput has run
        kwargs.update({"ignore_copy": True})
        super(MetaFileInput, self).__init__(*args, **kwargs)
        self.copy_files()
        self.filename = self.opts.filename
        # HDF5 output is the default unless JSON was explicitly requested
        self.hdf5 = not self.opts.save_to_json
        self.hdf5_compression = self.opts.hdf5_compression
        self.external_hdf5_links = self.opts.external_hdf5_links
        self.descriptions = self.opts.descriptions
        self.preferred = self.opts.preferred
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent-class directories plus the samples and config directories"""
        dirs = super(MetaFileInput, self).default_directories
        dirs += ["samples", "config"]
        return dirs

    @staticmethod
    def _is_within_directory(path, directory):
        """Return True if ``path`` is ``directory`` or lies beneath it

        Both paths are made absolute first. A plain substring test would
        wrongly report ``/a/bc/x.ini`` as being inside ``/a/b``.
        """
        path = os.path.abspath(path)
        directory = os.path.abspath(directory)
        try:
            return os.path.commonpath([path, directory]) == directory
        except ValueError:
            # e.g. paths on different drives on Windows
            return False

    @property
    def default_files_to_copy(self):
        """Parent-class files plus the config files and result files

        Each provided config file outside ``self.webdir`` is copied to
        ``<webdir>/config/<label>_config.ini``. Each result file is copied
        into ``<webdir>/samples``. Its name is prefixed with the label, or
        with ``chain_<index>`` when ``self.mcmc_samples`` is set. None-valued
        ``config`` or ``result_files`` are treated as empty.
        """
        files_to_copy = super(MetaFileInput, self).default_files_to_copy
        for num, config_file in enumerate(self.config or []):
            if config_file is None:
                continue
            # Skip files already in the web directory; they need no copy
            if self._is_within_directory(config_file, self.webdir):
                continue
            filename = "_".join([self.labels[num], "config.ini"])
            files_to_copy.append(
                [config_file, os.path.join(self.webdir, "config", filename)]
            )
        for num, _file in enumerate(self.result_files or []):
            if not self.mcmc_samples:
                filename = "{}_{}".format(self.labels[num], Path(_file).name)
            else:
                filename = "chain_{}_{}".format(num, Path(_file).name)
            files_to_copy.append(
                [_file, os.path.join(self.webdir, "samples", filename)]
            )
        return files_to_copy

2162 

2163 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Store the webpage, plotting and metafile command line arguments

    All behaviour is inherited. The directory and file lists are resolved
    cooperatively through the MRO, starting from MetaFileInput.
    """
    def __init__(self, *args, **kwargs):
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        parent.__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Directories collected from every class in the MRO"""
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        return parent.default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy collected from every class in the MRO"""
        parent = super(WebpagePlusPlottingPlusMetaFileInput, self)
        return parent.default_files_to_copy

2180 

2181 

@deprecation(
    "The Input class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class Input(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias of WebpagePlusPlottingPlusMetaFileInput"""

2189 

2190 

def load_current_state(resume_file):
    """Load a pickle file containing checkpoint information

    Parameters
    ----------
    resume_file: str
        path to a checkpoint file

    Returns
    -------
    object or None
        whatever ``pesummary.io.read`` returns for the checkpoint, or None
        when ``resume_file`` does not exist
    """
    from pesummary.io import read
    checkpoint_exists = os.path.isfile(resume_file)
    if not checkpoint_exists:
        logger.info(
            "Unable to find resume file. Not restarting from checkpoint"
        )
        return None
    logger.info("Reading checkpoint file: {}".format(resume_file))
    # NOTE(review): checkpoints are written with file_format="pickle", so
    # this presumably unpickles the file. Only resume from trusted files.
    return read(resume_file, checkpoint=True)