Coverage for pesummary/core/cli/inputs.py: 80.5%

1264 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import socket 

5from glob import glob 

6from pathlib import Path 

7from getpass import getuser 

8 

9import math 

10import numpy as np 

11from pesummary.core.file.read import read as Read 

12from pesummary.utils.exceptions import InputError 

13from pesummary.utils.decorators import deprecation 

14from pesummary.utils.samples_dict import SamplesDict, MCMCSamplesDict 

15from pesummary.utils.utils import ( 

16 guess_url, logger, make_dir, make_cache_style_file, list_match 

17) 

18from pesummary import conf 

19 

# Author credit for this module
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

21 

22 

23class _Input(object): 

24 """Super class to handle the command line arguments 

25 """ 

@staticmethod
def is_pesummary_metafile(proposed_file):
    """Return whether ``proposed_file`` is a PESummary metafile.

    Only files whose extension is ``h5``, ``hdf5``, ``hdf`` or ``json`` are
    inspected; any other extension is rejected without opening the file.

    Parameters
    ----------
    proposed_file: str
        path to the file

    Returns
    -------
    bool
        True if either the current or the deprecated PESummary checker for
        the matching format accepts the file
    """
    suffix = proposed_file.split(".")[-1]
    if suffix in ("h5", "hdf5", "hdf"):
        from pesummary.core.file.read import (
            is_pesummary_hdf5_file as _current_check,
            is_pesummary_hdf5_file_deprecated as _legacy_check
        )
    elif suffix == "json":
        from pesummary.core.file.read import (
            is_pesummary_json_file as _current_check,
            is_pesummary_json_file_deprecated as _legacy_check
        )
    else:
        return False
    # the legacy checker is only consulted when the current one says no
    return any(
        check(proposed_file) for check in (_current_check, _legacy_check)
    )

62 

@staticmethod
def grab_data_from_metafile(
    existing_file, webdir, compare=None, read_function=Read,
    _replace_with_pesummary_kwargs=None, nsamples=None,
    disable_injection=False, keep_nan_likelihood_samples=False,
    reweight_samples=False, **kwargs
):
    """Grab data from an existing PESummary metafile

    Parameters
    ----------
    existing_file: str
        path to the existing metafile
    webdir: str
        the directory to store the existing configuration file
    compare: list, optional
        list of labels for events stored in an existing metafile that you
        wish to compare
    read_function: func, optional
        PESummary function to use to read in the existing file
    _replace_with_pesummary_kwargs: dict, optional
        dictionary of kwargs that you wish to replace with the data stored
        in the PESummary file. Each value is a template string which is
        formatted with ``file``, ``ind`` and ``label`` and then evaluated.
        Default None (no replacements)
    nsamples: int, optional
        Number of samples to use. Default all available samples
    disable_injection: bool, optional
        if True, ignore any injection data stored in the metafile and use
        NaN for every parameter. Default False
    keep_nan_likelihood_samples: bool, optional
        if True, samples with a NaN likelihood are kept when reading.
        Default False
    reweight_samples: str/bool, optional
        name of the reweighting to apply via ``f.reweight_samples``.
        Default False (no reweighting)
    kwargs: dict
        All kwargs are passed to the `generate_all_posterior_samples`
        method

    Returns
    -------
    dict
        dictionary with keys "samples", "injection_data", "file_version",
        "file_kwargs", "prior", "config", "labels", "weights", "indicies",
        "mcmc_samples", "open_file" and "descriptions"
    """
    if _replace_with_pesummary_kwargs is None:
        _replace_with_pesummary_kwargs = {}
    # NOTE: the local name ``f`` is referenced by the template strings in
    # _replace_with_pesummary_kwargs (formatted with file="f") and must not
    # be renamed
    f = read_function(
        existing_file,
        remove_nan_likelihood_samples=not keep_nan_likelihood_samples
    )
    for ind, label in enumerate(f.labels):
        kwargs[label] = kwargs.copy()
        for key, item in _replace_with_pesummary_kwargs.items():
            # NOTE(review): ``eval`` executes these template strings; they
            # must never originate from untrusted input
            try:
                kwargs[label][key] = eval(
                    item.format(file="f", ind=ind, label=label)
                )
            except TypeError:
                _item = item.split("['{label}']")[0]
                kwargs[label][key] = eval(
                    _item.format(file="f", ind=ind, label=label)
                )
            except (AttributeError, KeyError, NameError):
                pass

    if not f.mcmc_samples:
        labels = f.labels
    else:
        labels = list(f.samples_dict.keys())
    indicies = np.arange(len(labels))

    if compare:
        indicies = []
        for i in compare:
            if i not in labels:
                raise InputError(
                    "Label '%s' does not exist in the metafile. The list "
                    "of available labels are %s" % (i, labels)
                )
            indicies.append(labels.index(i))
        labels = compare

    if nsamples is not None:
        f.downsample(nsamples, labels=labels)
    if not f.mcmc_samples:
        f.generate_all_posterior_samples(labels=labels, **kwargs)
    if reweight_samples:
        f.reweight_samples(reweight_samples, labels=labels, **kwargs)

    parameters = f.parameters
    if not f.mcmc_samples:
        samples = [np.array(i).T for i in f.samples]
        DataFrame = {
            label: SamplesDict(parameters[ind], samples[ind])
            for label, ind in zip(labels, indicies)
        }
    else:
        # every chain is stored under the first label of the file
        DataFrame = {
            f.labels[0]: MCMCSamplesDict(
                {label: f.samples_dict[label] for label in labels}
            )
        }
        labels = f.labels
        indicies = np.arange(len(labels))

    if not disable_injection and f.injection_parameters != []:
        inj_values = f.injection_dict
        # fill missing parameters on the dictionary that is returned so the
        # result does not depend on whether injection_dict is cached
        for label in labels:
            for param in DataFrame[label].keys():
                if param not in inj_values[label].keys():
                    inj_values[label][param] = float("nan")
    else:
        inj_values = {
            i: {param: float("nan") for param in DataFrame[i].parameters}
            for i in labels
        }
    for i in inj_values.keys():
        for param in inj_values[i].keys():
            if inj_values[i][param] == "nan":
                inj_values[i][param] = float("nan")
            if isinstance(inj_values[i][param], bytes):
                inj_values[i][param] = inj_values[i][param].decode("utf-8")

    if hasattr(f, "priors") and f.priors is not None and f.priors != {}:
        priors = f.priors
    else:
        priors = {label: {} for label in labels}

    if f.config is not None and not all(i is None for i in f.config):
        config = []
        config_dir = os.path.join(webdir, "config")
        for i in labels:
            filename = f.write_config_to_file(
                i, outdir=config_dir, _raise=False,
                filename="{}_config.ini".format(i)
            )
            # with _raise=False the returned filename may be None; check it
            # before building a path from it
            _config = (
                os.path.join(config_dir, filename)
                if filename is not None else None
            )
            if _config is not None and os.path.isfile(_config):
                config.append(_config)
            else:
                config.append(None)
    else:
        config = [None] * len(labels)

    if f.weights is not None:
        weights = {i: f.weights[i] for i in labels}
    else:
        weights = {i: None for i in labels}

    return {
        "samples": DataFrame,
        "injection_data": inj_values,
        "file_version": {
            label: f.input_version[ind]
            for label, ind in zip(labels, indicies)
        },
        "file_kwargs": {
            label: f.extra_kwargs[ind]
            for label, ind in zip(labels, indicies)
        },
        "prior": priors,
        "config": config,
        "labels": labels,
        "weights": weights,
        "indicies": indicies,
        "mcmc_samples": f.mcmc_samples,
        "open_file": f,
        "descriptions": f.description
    }

222 } 

223 

@staticmethod
def grab_data_from_file(
    file, label, webdir, config=None, injection=None, read_function=Read,
    file_format=None, nsamples=None, disable_prior_sampling=False,
    nsamples_for_prior=None, path_to_samples=None,
    keep_nan_likelihood_samples=False, reweight_samples=False,
    **kwargs
):
    """Grab data from a result file containing posterior samples

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    webdir: str
        directory under which config data extracted from the file is saved
        (in a ``config`` subdirectory)
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    read_function: func, optional
        PESummary function to use to read in the file
    file_format, str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options
    nsamples: int, optional
        number of samples to downsample to. Default None (all samples)
    disable_prior_sampling: bool, optional
        passed to the read function as ``disable_prior``
    nsamples_for_prior: int, optional
        passed to the read function unchanged
    path_to_samples: str, optional
        passed to the read function unchanged
    keep_nan_likelihood_samples: bool, optional
        if True, samples with a NaN likelihood are kept when reading
    reweight_samples: str/bool, optional
        name of the reweighting to apply. Default False
    kwargs: dict
        Dictionary of keyword arguments fed to the
        `generate_all_posterior_samples` method

    Returns
    -------
    dict
        dictionary with keys "samples", "injection_data", "file_version",
        "file_kwargs", "prior", "weights", "open_file", "descriptions" and,
        when config data is extracted from the file, "config"
    """
    opened = read_function(
        file, file_format=file_format, disable_prior=disable_prior_sampling,
        nsamples_for_prior=nsamples_for_prior, path_to_samples=path_to_samples,
        remove_nan_likelihood_samples=not keep_nan_likelihood_samples
    )
    if config is not None:
        opened.add_fixed_parameters_from_config_file(config)
    if nsamples is not None:
        opened.downsample(nsamples)
    opened.generate_all_posterior_samples(**kwargs)
    if injection:
        opened.add_injection_parameters_from_file(
            injection, conversion_kwargs=kwargs
        )
    if reweight_samples:
        opened.reweight_samples(reweight_samples)

    parameters = opened.parameters
    sample_array = np.array(opened.samples).T
    samples_table = {label: SamplesDict(parameters, sample_array)}
    extra_kwargs = opened.extra_kwargs

    # any parameter without an injected value is reported as NaN
    injection_values = getattr(opened, "injection_parameters", None)
    if injection_values is None:
        injection_values = dict.fromkeys(parameters, float("nan"))
    else:
        for param in parameters:
            if param not in injection_values.keys():
                injection_values[param] = float("nan")

    version = opened.input_version
    stored_priors = getattr(opened, "priors", None)
    if stored_priors is None:
        priors = {label: []}
    else:
        priors = {key: {label: item} for key, item in stored_priors.items()}
    weights = getattr(opened, "weights", None)

    data = {
        "samples": samples_table,
        "injection_data": {label: injection_values},
        "file_version": {label: version},
        "file_kwargs": {label: extra_kwargs},
        "prior": priors,
        "weights": {label: weights},
        "open_file": opened,
        "descriptions": {label: opened.description}
    }
    if getattr(opened, "config", None) is None:
        return data
    if config is not None:
        logger.info(
            "Ignoring config data extracted from the input file and "
            "using the config file provided"
        )
        return data
    config_dir = os.path.join(webdir, "config")
    config_name = "{}_config.ini".format(label)
    logger.debug(
        "Successfully extracted config data from the provided "
        "input file. Saving the data to the file '{}'".format(
            os.path.join(config_dir, config_name)
        )
    )
    data["config"] = opened.write(
        filename=config_name, outdir=config_dir, file_format="ini",
        _raise=False
    )
    return data

326 

@property
def result_files(self):
    """List of result files to process (None if not provided)"""
    return self._result_files

@result_files.setter
def result_files(self, result_files):
    """Store the result files, resolving entries that are not local files.

    Entries that are not existing files are resolved as follows: paths
    containing '@' are fetched with ``_fetch_from_remote_server``, paths
    containing 'https://' with ``_fetch_from_url`` and paths containing
    '*' are expanded with ``glob_directory``. When a resolver returns
    several files they replace the original entry in order. The list is
    modified in place, so the caller's list object sees the resolved paths.

    Raises
    ------
    InputError
        if a resolver returns an empty list/array of files
    """
    self._result_files = result_files
    if self._result_files is None:
        return
    num = 0
    while num < len(self._result_files):
        ff = self._result_files[num]
        func = None
        if not os.path.isfile(ff):
            if "@" in ff:
                from pesummary.io.read import _fetch_from_remote_server
                func = _fetch_from_remote_server
            elif "https://" in ff:
                from pesummary.io.read import _fetch_from_url
                func = _fetch_from_url
            elif "*" in ff:
                from pesummary.utils.utils import glob_directory
                func = glob_directory
        if func is None:
            num += 1
            continue
        _data = func(ff)
        if isinstance(_data, (np.ndarray, list)):
            if len(_data) == 0:
                raise InputError(
                    "Unable to find any files matching '{}'".format(ff)
                )
            # splice the resolved files in place of the original entry
            self._result_files[num:num + 1] = list(_data)
            num += len(_data)
        else:
            self._result_files[num] = _data
            num += 1

361 

@property
def seed(self):
    """The value last used to seed numpy's global random number generator"""
    return self._seed

@seed.setter
def seed(self, seed):
    """Seed ``np.random`` with ``seed`` and remember the value used"""
    # seed first so that an invalid seed leaves ``_seed`` untouched
    np.random.seed(seed)
    self._seed = seed

370 

@property
def existing(self):
    """Absolute path to an existing web directory, or None"""
    return self._existing

@existing.setter
def existing(self, existing):
    """Store ``existing``; non-None values are made absolute"""
    self._existing = existing
    if existing is None:
        return
    self._existing = os.path.abspath(existing)

380 

@property
def existing_metafile(self):
    """Path to the metafile stored in the existing web directory"""
    return self._existing_metafile

@existing_metafile.setter
def existing_metafile(self, existing_metafile):
    """Locate the metafile inside ``<self.existing>/samples``.

    A ``posterior_samples*`` file is preferred. If none exists, the first
    ``*.h5``/``*.json``/``*.hdf5`` file in the samples directory is used and
    a warning is logged. Passing None stores None without any lookup.

    Raises
    ------
    InputError
        if the samples directory does not exist, contains no candidate
        file, or contains more than one ``posterior_samples*`` file
    """
    self._existing_metafile = existing_metafile
    if self._existing_metafile is None:
        return
    _dir = os.path.join(self.existing, "samples")
    if not os.path.isdir(_dir):
        raise InputError("Please provide a valid existing directory")
    files = glob(os.path.join(_dir, "posterior_samples*"))
    dir_content = []
    for pattern in ("*.h5", "*.json", "*.hdf5"):
        dir_content.extend(glob(os.path.join(_dir, pattern)))
    if len(files) == 0 and len(dir_content):
        files = dir_content
        logger.warning(
            "Unable to find a 'posterior_samples*' file in the existing "
            "directory. Using '{}' as the existing metafile".format(
                dir_content[0]
            )
        )
    elif len(files) == 0:
        raise InputError(
            "Unable to find an existing metafile in the existing webdir"
        )
    elif len(files) > 1:
        raise InputError(
            "Multiple metafiles in the existing directory. Please either "
            "run the `summarycombine_metafile` executable to combine the "
            "meta files or simple remove the unwanted meta file"
        )
    # glob already returns paths that include the samples directory, so
    # joining them onto it again would duplicate a relative prefix
    self._existing_metafile = files[0]

420 

@property
def style_file(self):
    """Path to the matplotlib style file used for plotting"""
    return self._style_file

@style_file.setter
def style_file(self, style_file):
    """Choose the matplotlib style file, falling back to ``conf.style_file``.

    The default is used when ``style_file`` is None or does not point to an
    existing file. The chosen file is passed to ``make_cache_style_file``
    before being stored.
    """
    default = conf.style_file
    if style_file is None:
        logger.debug(
            "Using the default matplotlib style file"
        )
        style_file = default
    elif os.path.isfile(style_file):
        logger.info(
            "Using the file '{}' as the matplotlib style file".format(
                style_file
            )
        )
    else:
        logger.warning(
            "The file '{}' does not exist. Resorting to default".format(
                style_file
            )
        )
        style_file = default
    make_cache_style_file(style_file)
    self._style_file = style_file

448 

@property
def filename(self):
    """Filename requested for the output metafile (None if not given)"""
    return self._filename

@filename.setter
def filename(self, filename):
    """Store ``filename`` and warn about suspicious values.

    The value is stored unchanged. A warning is logged when it contains a
    directory component, and another when a file with the same basename
    already exists in ``<self.webdir>/samples``.
    """
    self._filename = filename
    if filename is None:
        return
    basename = filename
    if "/" in filename:
        basename = filename.split("/")[-1]
        logger.warning(
            "The filename '{}' includes a directory. Only the basename "
            "'{}' is used when checking the samples directory for an "
            "existing file".format(filename, basename)
        )
    samples_dir = os.path.join(self.webdir, "samples")
    if os.path.isfile(os.path.join(samples_dir, basename)):
        logger.warning(
            "A file with filename '{}' already exists in the samples "
            "directory '{}'. This will be overwritten".format(
                basename, samples_dir
            )
        )

465 

@property
def user(self):
    """Name of the user running the job"""
    return self._user

@user.setter
def user(self, user):
    """Set the user name, preferring ``getpass.getuser()``.

    The ``user`` argument is only used as a fallback when the current user
    name cannot be determined. ``getuser`` raises KeyError on older Python
    versions and OSError from Python 3.13 in that case; both are handled.
    """
    try:
        self._user = getuser()
        logger.info(
            conf.overwrite.format("user", conf.user, self._user)
        )
    except (KeyError, OSError) as e:
        logger.info(
            "Failed to grab user information because {}. Default will be "
            "used".format(e)
        )
        self._user = user

483 

@property
def host(self):
    """Fully qualified domain name of the machine, from ``socket.getfqdn``"""
    fqdn = socket.getfqdn()
    return fqdn

487 

@property
def webdir(self):
    """Absolute path to the directory in which the webpages are written"""
    return self._webdir

@webdir.setter
def webdir(self, webdir):
    """Set the web directory.

    ``None``, ``"None"`` and ``"none"`` all mean "not provided". If no web
    directory is given, the existing web directory (``self.existing``) is
    used; it must exist and contain a ``home.html``. Otherwise the given
    directory is created if needed and stored as an absolute path.

    Raises
    ------
    InputError
        if neither a web directory nor an existing web directory is given,
        or if the existing web directory is missing or invalid
    """
    no_webdir = webdir is None or webdir in ("None", "none")
    no_existing = (
        self.existing is None or self.existing in ("None", "none")
    )
    if no_webdir and no_existing:
        raise InputError(
            "Please provide a web directory to store the webpages. If "
            "you wish to add to an existing webpage, then pass the "
            "existing web directory under the '--existing_webdir' command "
            "line argument. If this is a new set of webpages, then pass "
            "the web directory under the '--webdir' argument"
        )
    if no_webdir:
        # fall back to the existing web directory rather than creating a
        # directory literally called "None"
        if not os.path.isdir(self.existing):
            raise InputError(
                "The directory {} does not exist".format(self.existing)
            )
        if not os.path.lexists(os.path.join(self.existing, "home.html")):
            raise InputError(
                "Please give the base directory of an existing output"
            )
        self._webdir = self.existing
        return
    if not os.path.isdir(webdir):
        logger.debug(
            "Given web directory does not exist. Creating it now"
        )
        make_dir(webdir)
    self._webdir = os.path.abspath(webdir)

525 

@property
def baseurl(self):
    """Base URL of the webpages"""
    return self._baseurl

@baseurl.setter
def baseurl(self, baseurl):
    """Store ``baseurl``; when None, guess it from webdir, host and user"""
    self._baseurl = baseurl
    if baseurl is not None:
        return
    self._baseurl = guess_url(self.webdir, self.host, self.user)

535 

@property
def mcmc_samples(self):
    """Whether all result files are treated as chains of one analysis"""
    return self._mcmc_samples

@mcmc_samples.setter
def mcmc_samples(self, mcmc_samples):
    """Store the flag, logging a message when it is truthy"""
    self._mcmc_samples = mcmc_samples
    if not self._mcmc_samples:
        return
    logger.info(
        "Treating all samples as seperate mcmc chains for the same "
        "analysis."
    )

548 

@property
def labels(self):
    """Label used to identify each analysis"""
    return self._labels

@labels.setter
def labels(self, labels):
    """Validate and store the labels for each result file.

    When any result file is a PESummary metafile, the provided labels are
    ignored in favour of ``self.default_labels()``. The steps are:

    * ``None`` is replaced by ``self.default_labels()``.
    * '.' in a label is replaced by '_'. This modifies the list in place.
    * Labels are checked for uniqueness and against ``self.existing_labels``
      when adding to an existing page.
    * If fewer labels than result files are given (and the files are not
      MCMC chains), the labels are repeated with a numeric suffix.

    Raises
    ------
    InputError
        if more than one label is given for MCMC chains, labels are not
        unique, or a label already exists in the existing metafile
    """
    if self.result_files is not None:
        if any(self.is_pesummary_metafile(s) for s in self.result_files):
            logger.warning(
                "labels argument is ignored when a pesummary metafile is "
                "input. Stored analyses will use their stored labels. If "
                "you wish to change the labels, please use `summarymodify`"
            )
            labels = self.default_labels()
    if labels is None:
        labels = self.default_labels()
    elif self.mcmc_samples and len(labels) != 1:
        raise InputError(
            "Please provide a single label that corresponds to all "
            "mcmc samples"
        )
    elif len(np.unique(labels)) != len(labels):
        raise InputError(
            "Please provide unique labels for each result file"
        )
    for num, i in enumerate(labels):
        if "." in i:
            logger.warning(
                "Replacing the label {} by {} to make it compatible "
                "with the html pages".format(i, i.replace(".", "_"))
            )
            labels[num] = i.replace(".", "_")
    if self.add_to_existing:
        for i in labels:
            if i in self.existing_labels:
                raise InputError(
                    "The label '{}' already exists in the existing "
                    "metafile. Please pass another unique label".format(i)
                )

    if (
        self.result_files is not None
        and len(self.result_files) != len(labels)
        and not self.mcmc_samples
    ):
        # pad with "<label><n>" copies until there is one per result file
        _new_labels = list(labels)
        idx = 1
        while len(_new_labels) < len(self.result_files):
            _new_labels.extend(
                [_label + str(idx) for _label in labels]
            )
            idx += 1
        _new_labels = _new_labels[:len(self.result_files)]
        logger.info(
            "You have passed {} result files and {} labels. Setting "
            "labels = {}".format(
                len(self.result_files), len(labels), _new_labels
            )
        )
        labels = _new_labels
    self._labels = labels

608 

@property
def config(self):
    """Configuration file for each label (None where not provided)"""
    return self._config

@config.setter
def config(self, config):
    """Store one configuration file per label.

    Configs are ignored (all None) when ``self.meta_file`` is truthy or
    when ``config`` is None. The string "none" (any case) is converted to
    None in place.

    Raises
    ------
    InputError
        if a non-empty ``config`` does not have one entry per label
    """
    if config and len(config) != len(self.labels):
        raise InputError(
            "Please provide a configuration file for each label"
        )
    if self.meta_file or config is None:
        self._config = [None] * len(self.labels)
    else:
        self._config = config
    for idx, entry in enumerate(self._config):
        if isinstance(entry, str) and entry.lower() == "none":
            self._config[idx] = None

628 

@property
def injection_file(self):
    """Injection file for each label (None where not provided)"""
    return self._injection_file

@injection_file.setter
def injection_file(self, injection_file):
    """Store one injection file per label.

    A single injection file is reused for every label. None results in one
    None entry per label.

    Raises
    ------
    InputError
        if more than one injection file is given but not one per label
    """
    if injection_file and len(injection_file) != len(self.labels):
        if len(injection_file) == 1:
            logger.info(
                "Only one injection file passed. Assuming the same "
                "injection for all {} result files".format(len(self.labels))
            )
            # replicate so that every label can index its own entry
            injection_file = [injection_file[0]] * len(self.labels)
        else:
            raise InputError(
                "You have passed {} injection files for {} result files. "
                "Please provide an injection file for each result "
                "file".format(len(injection_file), len(self.labels))
            )
    if injection_file is None:
        injection_file = [None] * len(self.labels)
    self._injection_file = injection_file

651 

@property
def injection_data(self):
    """Injected values for each analysis, as extracted from the inputs"""
    injection_values = self._injection_data
    return injection_values

@property
def file_version(self):
    """Version information of each input file"""
    versions = self._file_version
    return versions

@property
def file_kwargs(self):
    """Extra keyword arguments extracted from each input file"""
    extracted_kwargs = self._file_kwargs
    return extracted_kwargs

663 

@property
def kde_plot(self):
    """Value of the kde_plot option"""
    return self._kde_plot

@kde_plot.setter
def kde_plot(self, kde_plot):
    """Store ``kde_plot``, logging when it differs from ``conf.kde_plot``"""
    self._kde_plot = kde_plot
    default = conf.kde_plot
    if kde_plot != default:
        message = conf.overwrite.format("kde_plot", default, kde_plot)
        logger.info(message)

675 

@property
def file_format(self):
    """File format for each result file; None lets the reader decide"""
    return self._file_format

@file_format.setter
def file_format(self, file_format):
    """Store one file format per label.

    None gives one None entry per label. A single format is applied to
    every label. The string "none" (any case) is converted to None. Entries
    that are already None are accepted, consistent with ``config``.

    Raises
    ------
    InputError
        if more than one format is given but not one per label
    """
    def _normalise(ff):
        # treat the command-line string "none" as "no format specified"
        if isinstance(ff, str) and ff.lower() == "none":
            return None
        return ff

    nlabels = len(self.labels)
    if file_format is None:
        self._file_format = [None] * nlabels
    elif len(file_format) == 1 and len(file_format) != nlabels:
        logger.warning(
            "Only one file format specified. Assuming all files are of "
            "this format"
        )
        self._file_format = [_normalise(file_format[0])] * nlabels
    elif len(file_format) != nlabels:
        raise InputError(
            "Please provide a file format for each result file. If you "
            "wish to specify the file format for the second result file "
            "and not for any of the others, for example, simply pass 'None "
            "{format} None'"
        )
    else:
        for num, ff in enumerate(file_format):
            file_format[num] = _normalise(ff)
        self._file_format = file_format

702 

@property
def samples(self):
    """Posterior samples loaded from the result files"""
    return self._samples

@samples.setter
def samples(self, samples):
    """Load samples from a list of result files via ``_set_samples``.

    Dictionaries are passed through untouched: nothing is stored and the
    dictionary is returned.
    """
    if not isinstance(samples, dict):
        self._set_samples(samples)
        return None
    return samples

712 

def _set_samples(
    self, samples,
    ignore_keys=("prior", "weights", "labels", "indicies", "open_file")
):
    """Extract the samples and store them as attributes of self

    Every key returned by ``self.grab_data_from_input`` that is not in
    ``ignore_keys`` is stored as ``self._<key>``. For later files the stored
    value is extended (dict ``update`` / list ``+=``). Priors are registered
    through ``self.add_to_prior_dict`` and weights are collected into
    ``self.weights``.

    Parameters
    ----------
    samples: list
        A list containing the paths to result files
    ignore_keys: list, optional
        A list containing properties of the read file that you do not want to be
        stored as attributes of self

    Raises
    ------
    InputError
        if no files are given, a file does not exist, PESummary metafiles
        are mixed with other files, or labels stored in the files clash
    """
    if not samples:
        raise InputError("Please provide a results file")
    # evaluate once into a list: sharing a single generator between any()
    # and all() lets any() consume items that all() then never checks
    _is_metafile = [self.is_pesummary_metafile(s) for s in samples]
    if any(_is_metafile) and not all(_is_metafile):
        raise InputError(
            "It seems that you have passed a combination of pesummary "
            "metafiles and non-pesummary metafiles. This is currently "
            "not supported."
        )
    labels, labels_dict = None, {}
    weights_dict = {}
    if self.mcmc_samples:
        nsamples = 0.
    for num, i in enumerate(samples):
        idx = num
        if not self.mcmc_samples:
            if not _is_metafile[num]:
                logger.info("Assigning {} to {}".format(self.labels[num], i))
        else:
            # all MCMC chains share the first (and only) label
            num = 0
        if not os.path.isfile(i):
            raise InputError("File %s does not exist" % (i))
        if _is_metafile[num]:
            data = self.grab_data_from_input(
                i, self.labels[num], config=None, injection=None
            )
            self.mcmc_samples = data["mcmc_samples"]
        else:
            data = self.grab_data_from_input(
                i, self.labels[num], config=self.config[num],
                injection=self.injection_file[num],
                file_format=self.file_format[num]
            )
            if "config" in data.keys():
                msg = (
                    "Overwriting the provided config file for '{}' with "
                    "the config information stored in the input "
                    "file".format(self.labels[num])
                )
                if self.config[num] is None:
                    logger.debug(msg)
                else:
                    logger.info(msg)
                self.config[num] = data.pop("config")
        if self.mcmc_samples:
            data["samples"] = {
                "{}_mcmc_chain_{}".format(key, idx): item for key, item
                in data["samples"].items()
            }
        for key, item in data.items():
            if key in ignore_keys:
                continue
            attr = "_{}".format(key)
            if idx == 0:
                setattr(self, attr, item)
            else:
                x = getattr(self, attr)
                if isinstance(x, dict):
                    x.update(item)
                elif isinstance(x, list):
                    x += item
                setattr(self, attr, x)
        if self.mcmc_samples:
            # nsamples is unbound when MCMC mode was switched on by a
            # metafile part-way through; the total is not tracked then
            try:
                nsamples += data["file_kwargs"][self.labels[num]]["sampler"][
                    "nsamples"
                ]
            except UnboundLocalError:
                pass
        if "labels" in data.keys():
            stored_labels = data["labels"]
        else:
            stored_labels = [self.labels[num]]
        if "weights" in data.keys():
            # accumulate so that weights from every file are kept
            weights_dict.update(data["weights"])
        if "prior" in data.keys():
            pp = data["prior"]
            for label in stored_labels:
                if pp != {} and label in pp.keys() and pp[label] == []:
                    if len(self.priors):
                        if label not in self.priors["samples"].keys():
                            self.add_to_prior_dict(
                                "samples/{}".format(label), []
                            )
                    else:
                        self.add_to_prior_dict(
                            "samples/{}".format(label), []
                        )
                elif pp != {} and label not in pp.keys():
                    for key in pp.keys():
                        if key in self.priors.keys():
                            if label in self.priors[key].keys():
                                logger.warning(
                                    "Replacing the prior file for {} "
                                    "with the prior file stored in "
                                    "the result file".format(
                                        label
                                    )
                                )
                        if pp[key] == {}:
                            self.add_to_prior_dict(
                                "{}/{}".format(key, label), []
                            )
                        elif label not in pp[key].keys():
                            self.add_to_prior_dict(
                                "{}/{}".format(key, label), {}
                            )
                        else:
                            self.add_to_prior_dict(
                                "{}/{}".format(key, label), pp[key][label]
                            )
                else:
                    self.add_to_prior_dict(
                        "samples/{}".format(label), []
                    )
        if "labels" in data.keys():
            _duplicated = (
                [_ for _ in data["labels"] if _ in labels] if num != 0
                else []
            )
            if num == 0:
                labels = data["labels"]
            elif len(_duplicated):
                raise InputError(
                    "The labels stored in the supplied files are not "
                    "unique. The label{}: '{}' appear{} in two or more "
                    "files. Please provide unique labels for each "
                    "analysis.".format(
                        "s" if len(_duplicated) > 1 else "",
                        ", ".join(_duplicated),
                        "" if len(_duplicated) > 1 else "s"
                    )
                )
            else:
                labels += data["labels"]
            labels_dict[num] = data["labels"]
    if self.mcmc_samples:
        try:
            self.file_kwargs[self.labels[0]]["sampler"].update(
                {"nsamples": nsamples, "nchains": len(self.result_files)}
            )
        except (KeyError, UnboundLocalError):
            pass
        _labels = list(self._samples.keys())
        if not isinstance(self._samples[_labels[0]], MCMCSamplesDict):
            self._samples = MCMCSamplesDict(self._samples)
        else:
            self._samples = self._samples[_labels[0]]
    if labels is not None:
        self._labels = labels
        if len(labels) != len(self.result_files):
            # repeat each result file once per label stored inside it
            self.result_files = [
                self.result_files[num]
                for num in range(len(samples))
                for _ in range(len(labels_dict[num]))
            ]
    if weights_dict:
        self.weights = weights_dict
    else:
        self.weights = {i: None for i in self.labels}

883 

@property
def burnin_method(self):
    """Algorithm used to remove burnin from MCMC chains (None otherwise)"""
    return self._burnin_method

@burnin_method.setter
def burnin_method(self, burnin_method):
    """Validate and store the burnin algorithm.

    For non-MCMC input any requested method is dropped (set to None). For
    MCMC input a missing or unrecognised method is replaced by
    ``conf.burnin_method``. The final method is recorded in each label's
    ``file_kwargs[label]["sampler"]["burnin_method"]``.
    """
    self._burnin_method = burnin_method
    if not self.mcmc_samples:
        if burnin_method is not None:
            logger.info(
                "The {} method will not be used to remove samples as "
                "burnin as this can only be used for mcmc chains.".format(
                    burnin_method
                )
            )
            self._burnin_method = None
    elif burnin_method is None:
        logger.info(
            "No burnin method provided. Using {} as default".format(
                conf.burnin_method
            )
        )
        self._burnin_method = conf.burnin_method
    else:
        from pesummary.core.file import mcmc

        if burnin_method not in mcmc.algorithms:
            logger.warning(
                "Unrecognised burnin method: {}. Resorting to the default: "
                "{}".format(burnin_method, conf.burnin_method)
            )
            self._burnin_method = conf.burnin_method
    if self._burnin_method is None:
        return
    for label in self.labels:
        sampler_kwargs = self.file_kwargs[label]["sampler"]
        sampler_kwargs["burnin_method"] = self._burnin_method

920 

@property
def burnin(self):
    """Number of samples requested to be removed as burnin"""
    return self._burnin

@burnin.setter
def burnin(self, burnin):
    """Remove burnin from the stored samples.

    None selects ``conf.burnin``. For MCMC chains the configured
    ``burnin_method`` is used, and the number of removed samples plus the
    remaining total are recorded in the first label's sampler kwargs.
    Otherwise the first ``burnin`` samples of each analysis are discarded.
    The effective burnin value is stored and returned by the getter.

    Raises
    ------
    InputError
        if ``burnin`` is not smaller than the length of every analysis
    """
    _name = "nsamples_removed_from_burnin"
    if burnin is not None:
        samples_lengths = [
            self.samples[key].number_of_samples for key in
            self.samples.keys()
        ]
        if not all(int(burnin) < i for i in samples_lengths):
            # the check requires burnin below *every* length, so the
            # shortest analysis sets the limit
            raise InputError(
                "The chosen burnin is larger than the number of samples. "
                "Please choose a value less than {}".format(
                    np.min(samples_lengths)
                )
            )
        logger.info(
            conf.overwrite.format("burnin", conf.burnin, burnin)
        )
        burnin = int(burnin)
    else:
        burnin = conf.burnin
    if self.burnin_method is not None:
        arguments, kwargs = [], {}
        if burnin != 0 and self.burnin_method == "burnin_by_step_number":
            logger.warning(
                "The first {} samples have been requested to be removed "
                "as burnin, but the burnin method has been chosen to be "
                "burnin_by_step_number. Changing method to "
                "burnin_by_first_n with keyword argument step_number="
                "True such that all samples with step number < {} are "
                "removed".format(burnin, burnin)
            )
            self.burnin_method = "burnin_by_first_n"
            arguments = [burnin]
            kwargs = {"step_number": True}
        elif self.burnin_method == "burnin_by_first_n":
            arguments = [burnin]
        initial = self.samples.total_number_of_samples
        self._samples = self.samples.burnin(
            *arguments, algorithm=self.burnin_method, **kwargs
        )
        sampler_kwargs = self.file_kwargs[self.labels[0]]["sampler"]
        sampler_kwargs[_name] = initial - self.samples.total_number_of_samples
        sampler_kwargs["nsamples"] = self._samples.total_number_of_samples
    else:
        for label in self.samples:
            self.samples[label] = self.samples[label].discard_samples(
                burnin
            )
            if burnin != conf.burnin:
                self.file_kwargs[label]["sampler"][_name] = burnin
    self._burnin = burnin

977 

@property
def nsamples(self):
    """Number of samples to use per result file (None for all)"""
    return self._nsamples

@nsamples.setter
def nsamples(self, nsamples):
    """Store ``nsamples``, converted to ``int`` when not None"""
    self._nsamples = nsamples
    if nsamples is None:
        return
    logger.info(
        "{} samples will be used for each result file".format(nsamples)
    )
    self._nsamples = int(nsamples)

990 

@property
def reweight_samples(self):
    """Requested reweighting option, or False if it was not recognised"""
    return self._reweight_samples

@reweight_samples.setter
def reweight_samples(self, reweight_samples):
    """Validate ``reweight_samples`` against ``pesummary.core.reweight.options``"""
    from pesummary.core.reweight import options as known_options
    checked = self._check_reweight_samples(reweight_samples, known_options)
    self._reweight_samples = checked

1001 

1002 def _check_reweight_samples(self, reweight_samples, options): 

1003 if reweight_samples and reweight_samples not in options.keys(): 

1004 logger.warning( 

1005 "Unknown reweight function: '{}'. Not reweighting posterior " 

1006 "and/or prior samples".format(reweight_samples) 

1007 ) 

1008 return False 

1009 return reweight_samples 

1010 

@property
def path_to_samples(self):
    """Dictionary mapping each label to the path of the samples inside
    its result file (``None`` when no path was requested)."""
    return self._path_to_samples

@path_to_samples.setter
def path_to_samples(self, path_to_samples):
    """Map the ``--path_to_samples`` entries onto the labels.

    Parameters
    ----------
    path_to_samples: list or None
        one entry per result file. An entry of ``None`` or the string
        "None" (any case) means no path for that result file

    Raises
    ------
    InputError
        if a list is given whose length differs from the number of labels
    """
    if path_to_samples is None:
        self._path_to_samples = {label: None for label in self.labels}
        return
    if len(path_to_samples) != len(self.labels):
        raise InputError(
            "Please provide a path for all result files passed. If "
            "two result files are passed, and only one requires the "
            "path_to_samples arguement, please pass --path_to_samples "
            "None path/to/samples"
        )
    # accept a real ``None`` as well as the "None" string placeholder
    self._path_to_samples = {
        label: (
            None if path is None or str(path).lower() == "none" else path
        )
        for label, path in zip(self.labels, path_to_samples)
    }

1036 

@property
def priors(self):
    """Prior information loaded from the prior files passed on the
    command line."""
    return self._priors

@priors.setter
def priors(self, priors):
    """Load ``priors`` via ``grab_priors_from_inputs`` and store the
    result."""
    loaded = self.grab_priors_from_inputs(priors)
    self._priors = loaded

1044 

@property
def custom_plotting(self):
    """``[directory, module_name]`` of a usable custom plotting module,
    or ``None`` when custom plotting is not in use."""
    return self._custom_plotting

@custom_plotting.setter
def custom_plotting(self, custom_plotting):
    """Import a user supplied python file containing custom plots.

    The module is kept only if it defines ``__single_plots__`` and/or
    ``__comparison_plots__``. When it defines neither, or cannot be
    imported, a warning is logged and ``None`` is stored.

    Parameters
    ----------
    custom_plotting: str or None
        path to the python file containing the custom plotting functions
    """
    self._custom_plotting = custom_plotting
    if custom_plotting is None:
        return
    import importlib
    import sys

    path_to_python_file = os.path.dirname(custom_plotting)
    python_file = os.path.splitext(os.path.basename(custom_plotting))[0]
    if path_to_python_file != "" and path_to_python_file not in sys.path:
        sys.path.append(path_to_python_file)
    # the file may have been created after the import system cached the
    # contents of its directory
    importlib.invalidate_caches()
    try:
        mod = importlib.import_module(python_file)
    except ModuleNotFoundError as e:
        logger.warning(
            "Failed to import {} because {}. No custom plotting will "
            "be done".format(python_file, e)
        )
        self._custom_plotting = None
        return
    methods = list(getattr(mod, "__single_plots__", []))
    methods += list(getattr(mod, "__comparison_plots__", []))
    # older versions looked for this misspelt name; keep honouring it
    methods += list(getattr(mod, "__comparion_plots__", []))
    if len(methods) > 0:
        self._custom_plotting = [path_to_python_file, python_file]
        return
    logger.warning(
        "No __single_plots__ or __comparison_plots__ in {}. "
        "If you wish to use custom plotting, then please "
        "add the variable __single_plots__ and/or "
        "__comparison_plots__ in future. No custom plotting "
        "will be done".format(custom_plotting)
    )
    self._custom_plotting = None

1080 

@property
def external_hdf5_links(self):
    """Whether external hdf5 links should be used in the metafile."""
    return self._external_hdf5_links

@external_hdf5_links.setter
def external_hdf5_links(self, external_hdf5_links):
    """Store the option, switching it off unless the metafile is hdf5."""
    self._external_hdf5_links = external_hdf5_links
    if self.hdf5 or not self.external_hdf5_links:
        return
    logger.warning(
        "You can only apply external hdf5 links when saving the meta "
        "file in hdf5 format. Turning external hdf5 links off."
    )
    self._external_hdf5_links = False

1094 

@property
def hdf5_compression(self):
    """Compression level for the hdf5 metafile (``None`` for none)."""
    return self._hdf5_compression

@hdf5_compression.setter
def hdf5_compression(self, hdf5_compression):
    """Store the compression level, dropping it unless the metafile is
    written in hdf5 format."""
    self._hdf5_compression = hdf5_compression
    if self.hdf5 or hdf5_compression is None:
        return
    logger.warning(
        "You can only apply compression when saving the meta "
        "file in hdf5 format. Turning compression off."
    )
    self._hdf5_compression = None

1108 

@property
def existing_plot(self):
    """Mapping of label to user supplied plot(s) copied into
    ``webdir/plots``, or ``None`` when there are none."""
    return self._existing_plot

@existing_plot.setter
def existing_plot(self, existing_plot):
    """Validate user supplied plots and copy them into ``webdir/plots``.

    A list is assigned to every label. Files that do not exist are
    skipped with a warning. Labels left without any plot are dropped, and
    ``None`` is stored if no label keeps a plot.
    """
    from pathlib import Path
    import shutil

    self._existing_plot = existing_plot
    if self._existing_plot is None:
        return
    if isinstance(self._existing_plot, list):
        logger.warning(
            "Assigning {} to all labels".format(
                ", ".join(self._existing_plot)
            )
        )
        self._existing_plot = {
            label: self._existing_plot for label in self.labels
        }
    missing_message = (
        "The plot {} does not exist. Not adding plot to summarypages."
    )

    def _destination(source, prefix=""):
        # location of ``source`` once copied into the web directory
        return os.path.join(self.webdir, "plots", prefix + Path(source).name)

    dropped = []
    for label, plot in list(self._existing_plot.items()):
        if not isinstance(plot, list):
            if not os.path.isfile(plot):
                logger.warning(missing_message.format(plot))
                dropped.append(label)
                continue
            target = _destination(plot)
            try:
                shutil.copyfile(plot, target)
            except shutil.SameFileError:
                # already in the web directory: keep a label-prefixed copy
                target = _destination(plot, prefix=label + "_")
                shutil.copyfile(plot, target)
            self._existing_plot[label] = target
            continue
        copied = []
        for subplot in plot:
            if not os.path.isfile(subplot):
                logger.warning(missing_message.format(subplot))
                continue
            target = _destination(subplot)
            try:
                shutil.copyfile(subplot, target)
            except shutil.SameFileError:
                pass
            copied.append(target)
        if not copied:
            dropped.append(label)
        elif len(copied) == 1:
            self._existing_plot[label] = copied[0]
        else:
            self._existing_plot[label] = copied
    for label in dropped:
        del self._existing_plot[label]
    if not self._existing_plot:
        self._existing_plot = None

1173 

def add_to_prior_dict(self, path, data):
    """Add priors to the prior dictionary

    Parameters
    ----------
    path: str
        the location where you wish to store the prior. If this is inside
        a nested dictionary, then please pass the path as 'a/b'. Missing
        intermediate dictionaries are created
    data: np.ndarray
        the prior samples

    Examples
    --------
    >>> self._priors = {"samples": {"label": {"mass_1": [1, 2]}}}
    >>> self.add_to_prior_dict("samples/label/mass_2", [3, 4])
    >>> self._priors
    {'samples': {'label': {'mass_1': [1, 2], 'mass_2': [3, 4]}}}
    """
    *parents, final_key = path.split("/")
    nested = self._priors
    for key in parents:
        # create each missing level, otherwise descend into the existing one
        nested = nested.setdefault(key, {})
    nested[final_key] = data

1236 

def grab_priors_from_inputs(self, priors, read_func=None, read_kwargs=None):
    """Load the prior files passed from the command line.

    Parameters
    ----------
    priors: list or None
        paths to prior files. Either a single file, used for every label,
        or one entry per label. An entry of "None" (any case) means no
        prior for the corresponding label
    read_func: func, optional
        function used to load each prior file. Default
        ``pesummary.core.file.read.read``
    read_kwargs: dict, optional
        extra kwargs passed to ``read_func`` when one prior file is given
        per label. May be keyed by label to give label specific kwargs

    Returns
    -------
    prior_dict: dict
        empty when ``priors`` is None, otherwise with keys "samples" and
        "analytic", each mapping label to the loaded prior data

    Raises
    ------
    InputError
        if a prior file does not exist or the number of prior files does
        not match the number of labels
    """
    if read_func is None:
        from pesummary.core.file.read import read as Read
        read_func = Read
    if read_kwargs is None:
        read_kwargs = {}

    def _is_none(entry):
        # the "None" placeholder lets users skip the prior for one label
        return isinstance(entry, str) and entry.lower() == "none"

    if priors is None:
        return {}
    prior_dict = {"samples": {}, "analytic": {}}
    for i in priors:
        if not _is_none(i) and not os.path.isfile(i):
            raise InputError("The file {} does not exist".format(i))
    if len(priors) != len(self.labels) and len(priors) == 1:
        if _is_none(priors[0]):
            return prior_dict
        logger.warning(
            "You have only specified a single prior file for {} result "
            "files. Assuming the same prior file for all result "
            "files".format(len(self.labels))
        )
        data = read_func(priors[0], nsamples=self.nsamples_for_prior)
        for label in self.labels:
            prior_dict["samples"][label] = data.samples_dict
            analytic = getattr(data, "analytic", None)
            if analytic is not None:
                prior_dict["analytic"][label] = analytic
        return prior_dict
    if len(priors) != len(self.labels):
        raise InputError(
            "Please provide a prior file for each result file"
        )
    for label, prior_file in zip(self.labels, priors):
        if _is_none(prior_file):
            continue
        logger.info("Assigning {} to {}".format(prior_file, label))
        grab_data_kwargs = read_kwargs.get(label, read_kwargs)
        data = read_func(
            prior_file, nsamples=self.nsamples_for_prior, **grab_data_kwargs
        )
        prior_dict["samples"][label] = data.samples_dict
        analytic = getattr(data, "analytic", None)
        if analytic is not None:
            prior_dict["analytic"][label] = analytic
    return prior_dict

1292 

@property
def grab_data_kwargs(self):
    """Per-label keyword arguments passed to the read functions."""
    kwargs_by_label = {}
    for label in self.labels:
        kwargs_by_label[label] = {"regenerate": self.regenerate}
    return kwargs_by_label

1298 

def grab_data_from_input(
    self, file, label, config=None, injection=None, file_format=None
):
    """Wrapper function for the grab_data_from_metafile and
    grab_data_from_file functions

    Parameters
    ----------
    file: str
        path to the result file
    label: str
        label that you wish to use for the result file
    config: str, optional
        path to a configuration file used in the analysis
    injection: str, optional
        path to an injection file used in the analysis
    file_format, str, optional
        the file format you wish to use when loading. Default None.
        If None, the read function loops through all possible options

    Returns
    -------
    data:
        whatever the selected read function returns. Its "open_file"
        entry is recorded in ``self._open_result_files`` under ``file``
    """
    # ``grab_data_kwargs`` is a property that builds a fresh dictionary on
    # every access, so look it up only once
    all_kwargs = self.grab_data_kwargs
    if label in all_kwargs.keys():
        grab_data_kwargs = all_kwargs[label]
    else:
        grab_data_kwargs = all_kwargs

    if self.is_pesummary_metafile(file):
        data = self.grab_data_from_metafile(
            file, self.webdir, compare=self.compare_results,
            nsamples=self.nsamples, reweight_samples=self.reweight_samples,
            disable_injection=self.disable_injection,
            keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
            **grab_data_kwargs
        )
    else:
        data = self.grab_data_from_file(
            file, label, self.webdir, config=config, injection=injection,
            file_format=file_format, nsamples=self.nsamples,
            disable_prior_sampling=self.disable_prior_sampling,
            nsamples_for_prior=self.nsamples_for_prior,
            path_to_samples=self.path_to_samples[label],
            reweight_samples=self.reweight_samples,
            keep_nan_likelihood_samples=self.keep_nan_likelihood_samples,
            **grab_data_kwargs
        )
    self._open_result_files.update({file: data["open_file"]})
    return data

1347 

@property
def email(self):
    """Email address supplied on the command line, or ``None``."""
    return self._email

@email.setter
def email(self, email):
    """Store ``email`` after checking that it contains an "@"."""
    acceptable = email is None or "@" in email
    if not acceptable:
        raise InputError("Please provide a valid email address")
    self._email = email

1357 

@property
def dump(self):
    """The stored ``dump`` command line option."""
    return self._dump

@dump.setter
def dump(self, dump):
    """Store the ``dump`` option unchanged."""
    setattr(self, "_dump", dump)

1365 

@property
def palette(self):
    """Name of the seaborn colour palette used for plotting."""
    return self._palette

@palette.setter
def palette(self, palette):
    """Validate ``palette`` with seaborn and make it the global default.

    Raises
    ------
    InputError
        if seaborn does not recognise the palette
    """
    self._palette = palette
    # compare by value: an equal string need not be the same object
    if palette == conf.palette:
        return
    import seaborn

    try:
        seaborn.color_palette(palette, n_colors=1)
    except ValueError as e:
        raise InputError(
            "Unrecognised palette. Please choose from one of the "
            "following {}".format(
                ", ".join(seaborn.palettes.SEABORN_PALETTES.keys())
            )
        ) from e
    logger.info(
        conf.overwrite.format("palette", conf.palette, palette)
    )
    conf.palette = palette

1389 

@property
def include_prior(self):
    """Whether the prior should be included in the plots."""
    return self._include_prior

@include_prior.setter
def include_prior(self, include_prior):
    """Store the option and, when it differs from ``conf.include_prior``,
    log the change and update the global configuration."""
    self._include_prior = include_prior
    if include_prior != conf.include_prior:
        logger.info(
            conf.overwrite.format("prior", conf.include_prior, include_prior)
        )
        conf.include_prior = include_prior

1400 

@property
def colors(self):
    """Colours, one per analysis (including analyses already stored in an
    existing metafile)."""
    return self._colors

@colors.setter
def colors(self, colors):
    """Store user supplied colours, falling back to the default palette.

    User colours are kept when exactly one is given per analysis, and
    truncated when too many are given. When none, or too few, are given,
    the colours are drawn from ``conf.palette`` with seaborn.
    """
    number = len(self.labels)
    if self.existing:
        number += len(self.existing_labels)
    if colors is not None:
        if len(colors) == number:
            self._colors = colors
            return
        if len(colors) > number:
            logger.info(
                "You have passed {} colors for {} result files. Setting "
                "colors = {}".format(
                    len(colors), number, colors[:number]
                )
            )
            self._colors = colors[:number]
            return
        logger.warning(
            "Number of colors does not match the number of labels. "
            "Using default colors"
        )
    import seaborn

    self._colors = seaborn.color_palette(
        palette=conf.palette, n_colors=number
    ).as_hex()

1434 

@property
def linestyles(self):
    """Linestyles, one per entry in ``self.colors``."""
    return self._linestyles

@linestyles.setter
def linestyles(self, linestyles):
    """Store user linestyles or build defaults from ``self.colors``.

    User linestyles are kept when exactly one is given per colour, and
    truncated when too many are given. Otherwise, analyses sharing a
    colour are distinguished by cycling through "-", "--", ":" and "-."
    in order of appearance.
    """
    n_colors = len(self.colors)
    if linestyles is not None:
        if len(linestyles) == n_colors:
            self._linestyles = linestyles
            return
        if len(linestyles) > n_colors:
            logger.info(
                "You have passed {} linestyles for {} result files. "
                "Setting linestyles = {}".format(
                    len(linestyles), n_colors, linestyles[:n_colors]
                )
            )
            self._linestyles = linestyles[:n_colors]
            return
        logger.warning(
            "Number of linestyles does not match the number of "
            "labels. Using default linestyles"
        )
    available_linestyles = ["-", "--", ":", "-."]
    defaults = ["-"] * n_colors
    for color in np.unique(self.colors):
        indices = [num for num, i in enumerate(self.colors) if i == color]
        for rank, index in enumerate(indices):
            defaults[index] = available_linestyles[
                rank % len(available_linestyles)
            ]
    self._linestyles = defaults

1468 

@property
def disable_corner(self):
    """Whether corner plot generation has been switched off."""
    return self._disable_corner

@disable_corner.setter
def disable_corner(self, disable_corner):
    """Store the flag and warn about its consequences when it is set."""
    self._disable_corner = disable_corner
    if not disable_corner:
        return
    logger.warning(
        "No corner plot will be produced. This will reduce overall "
        "runtime but does mean that the interactive corner plot feature "
        "on the webpages will no longer work"
    )

1482 

@property
def add_to_corner(self):
    """Parameters requested for the corner plot after validation."""
    return self._add_to_corner

@add_to_corner.setter
def add_to_corner(self, add_to_corner):
    """Validate the requested corner parameters and store them."""
    checked = self._set_corner_params(add_to_corner)
    self._add_to_corner = checked

1490 

1491 def _set_corner_params(self, corner_params): 

1492 cls = self.__class__.__name__ 

1493 if corner_params is not None: 

1494 for label in self.labels: 

1495 _not_included = [ 

1496 param for param in corner_params if param not in 

1497 self.samples[label].keys() 

1498 ] 

1499 if len(_not_included) == len(corner_params) and cls == "Input": 

1500 logger.warning( 

1501 "None of the chosen corner parameters are " 

1502 "included in the posterior table for '{}'. Using " 

1503 "all available parameters for the corner plot".format( 

1504 label 

1505 ) 

1506 ) 

1507 corner_params = None 

1508 break 

1509 elif len(_not_included): 

1510 logger.warning( 

1511 "The following parameters are not included in the " 

1512 "posterior table for '{}': {}. Not adding to corner " 

1513 "plot".format(label, ", ".join(_not_included)) 

1514 ) 

1515 elif cls == "Input": 

1516 logger.debug( 

1517 "Using all parameters stored in the result file for the " 

1518 "corner plots. This may take some time." 

1519 ) 

1520 return corner_params 

1521 

@property
def pe_algorithm(self):
    """Parameter estimation algorithm for each result file, as given on
    the command line."""
    return self._pe_algorithm

@pe_algorithm.setter
def pe_algorithm(self, pe_algorithm):
    """Record the algorithm for every label in ``file_kwargs``.

    Raises
    ------
    ValueError
        if the number of algorithms differs from the number of labels
    """
    self._pe_algorithm = pe_algorithm
    if pe_algorithm is None:
        return
    if len(pe_algorithm) != len(self.labels):
        raise ValueError("Please provide an algorithm for each result file")
    pairs = zip(self.labels, pe_algorithm)
    for index, (label, requested) in enumerate(pairs):
        sampler_kwargs = self.file_kwargs[label]["sampler"]
        if "pe_algorithm" in sampler_kwargs.keys():
            stored = sampler_kwargs["pe_algorithm"]
            if stored != requested:
                logger.warning(
                    "Overwriting the pe_algorithm extracted from the file "
                    "'{}': {} with the algorithm provided from the command "
                    "line: {}".format(
                        self.result_files[index], stored, requested
                    )
                )
        sampler_kwargs["pe_algorithm"] = requested

1545 

@property
def notes(self):
    """Custom notes text for the summary pages, or ``None``."""
    return self._notes

@notes.setter
def notes(self, notes):
    """Read the custom notes from a text file.

    Parameters
    ----------
    notes: str or None
        path to a file containing the notes

    Raises
    ------
    InputError
        if ``notes`` is given but is not an existing file
    """
    self._notes = notes
    if notes is None:
        return
    if not os.path.isfile(notes):
        raise InputError(
            "No such file or directory called {}".format(notes)
        )
    try:
        with open(notes, "r") as f:
            self._notes = f.read()
    except FileNotFoundError:
        # the file may disappear between the isfile check and open
        logger.warning(
            "No such file or directory called {}. Custom notes will "
            "not be added to the summarypages".format(notes)
        )
        self._notes = None
    except IOError:
        logger.warning(
            "Failed to read {}. Unable to put notes on "
            "summarypages".format(notes)
        )
        # do not leave the file path behind as if it were the notes text
        self._notes = None

1571 

@property
def descriptions(self):
    """Dictionary mapping each label to a description of the analysis, or
    ``None`` when no descriptions are stored."""
    return self._descriptions

@descriptions.setter
def descriptions(self, descriptions):
    """Store a description for every analysis.

    Parameters
    ----------
    descriptions: dict or list
        either a dictionary mapping label -> description, or a list whose
        first element is a path to a JSON file containing such a mapping
        (or is itself such a dictionary). An empty input keeps
        descriptions already stored on the instance, otherwise ``None`` is
        stored. Labels without a description get "No description found",
        unknown labels are dropped and non string/bytes descriptions are
        discarded
    """
    import json

    if descriptions is None or not len(descriptions):
        if not hasattr(self, "_descriptions"):
            self._descriptions = None
        return

    if isinstance(descriptions, dict):
        data = descriptions
    else:
        descriptions = descriptions[0]
        _is_file = not isinstance(descriptions, dict)
        if hasattr(self, "_descriptions"):
            logger.warning(
                "Ignoring descriptions found in result file and using "
                "descriptions in '{}'".format(descriptions)
            )
            self._descriptions = None
        if not _is_file:
            data = descriptions
        elif not os.path.isfile(descriptions):
            logger.warning(
                "No such file called {}. Unable to add descriptions".format(
                    descriptions
                )
            )
            return
        else:
            try:
                with open(descriptions, "r") as f:
                    data = json.load(f)
            except json.decoder.JSONDecodeError:
                logger.warning(
                    "Unable to open file '{}'. Not storing descriptions".format(
                        descriptions
                    )
                )
                return
    # work on a copy so that the caller's dictionary is left untouched
    data = dict(data)
    not_included = [label for label in self.labels if label not in data]
    if not_included:
        logger.debug(
            "No description found for '{}'. Using default "
            "description".format(", ".join(not_included))
        )
        for label in not_included:
            data[label] = "No description found"
    other = [analysis for analysis in data if analysis not in self.labels]
    if other:
        logger.warning(
            "Descriptions file contains descriptions for analyses other "
            "than {}. Ignoring other descriptions".format(
                ", ".join(self.labels)
            )
        )
        for analysis in other:
            data.pop(analysis)
    invalid = [
        key for key, desc in data.items() if not isinstance(desc, (str, bytes))
    ]
    for key in invalid:
        logger.warning(
            "Unknown description '{}' for '{}'. The description should "
            "be a string or bytes object".format(data[key], key)
        )
        data.pop(key)
    self._descriptions = data

1649 

@property
def preferred(self):
    """Label of the preferred analysis, or ``None``."""
    return self._preferred

@preferred.setter
def preferred(self, preferred):
    """Choose the preferred analysis and flag every label in
    ``file_kwargs[label]["other"]["preferred"]`` with "True" or "False".

    An unknown label is ignored with a warning. With no explicit choice
    and a single analysis, that analysis is preferred.
    """
    if preferred is None:
        chosen = self.labels[0] if len(self.labels) == 1 else None
    elif preferred in self.labels:
        logger.debug(
            "Setting '{}' as the preferred analysis".format(preferred)
        )
        chosen = preferred
    else:
        logger.warning(
            "'{}' not in list of labels. Unable to stored as the "
            "preferred analysis".format(preferred)
        )
        chosen = None
    self._preferred = chosen

    def _flag(label, value):
        # create the "other" section on demand
        try:
            self.file_kwargs[label]["other"].update({"preferred": value})
        except KeyError:
            self.file_kwargs[label].update({"other": {"preferred": value}})

    if chosen is not None:
        _flag(chosen, "True")
    for label in self.labels:
        if chosen is None or label != chosen:
            _flag(label, "False")

1692 

@property
def public(self):
    """The ``public`` command line option."""
    return self._public

@public.setter
def public(self, public):
    """Store the option, logging when it differs from ``conf.public``."""
    self._public = public
    if public == conf.public:
        return
    logger.info(conf.overwrite.format("public", conf.public, public))

1704 

@property
def multi_process(self):
    """Number of processes to use."""
    return self._multi_process

@multi_process.setter
def multi_process(self, multi_process):
    """Store the number of processes as an ``int``.

    ``None`` falls back to ``conf.multi_process``. A value different from
    the configured default is logged.
    """
    if multi_process is None:
        multi_process = conf.multi_process
    self._multi_process = int(multi_process)
    if self._multi_process != int(conf.multi_process):
        logger.info(
            conf.overwrite.format(
                "multi_process", conf.multi_process, multi_process
            )
        )

1718 

@property
def publication_kwargs(self):
    """Keyword arguments for the publication plots."""
    return self._publication_kwargs

@publication_kwargs.setter
def publication_kwargs(self, publication_kwargs):
    """Store the publication kwargs, dropping (with a warning) any key
    that is not supported."""
    self._publication_kwargs = publication_kwargs
    if not publication_kwargs:
        return
    allowed_kwargs = ["gridsize"]
    unknown = [
        key for key in publication_kwargs.keys() if key not in allowed_kwargs
    ]
    if unknown:
        logger.warning(
            "Currently the only allowed publication kwargs are {}. "
            "Ignoring other inputs.".format(
                ", ".join(allowed_kwargs)
            )
        )
        self._publication_kwargs = {
            key: value for key, value in publication_kwargs.items()
            if key in allowed_kwargs
        }

1735 

@property
def ignore_parameters(self):
    """Parameter patterns to remove from the posterior samples."""
    return self._ignore_parameters

@ignore_parameters.setter
def ignore_parameters(self, ignore_parameters):
    """Remove every posterior parameter matched by ``ignore_parameters``
    (via ``list_match``) from each label's samples, logging the outcome
    per result file."""
    self._ignore_parameters = ignore_parameters
    if ignore_parameters is None:
        return
    for index, label in enumerate(self.labels):
        to_remove = list_match(
            list(self.samples[label].keys()), ignore_parameters
        )
        if len(to_remove):
            logger.warning(
                "Removing parameters: {} from {}".format(
                    ", ".join(to_remove),
                    self.result_files[index]
                )
            )
        else:
            logger.warning(
                "Failed to remove any parameters from {}".format(
                    self.result_files[index]
                )
            )
        for name in to_remove:
            self.samples[label].pop(name)

1763 

1764 @staticmethod 

1765 def _make_directories(webdir, dirs): 

1766 """Make the directories to store the information 

1767 """ 

1768 for i in dirs: 

1769 if not os.path.isdir(os.path.join(webdir, i)): 

1770 make_dir(os.path.join(webdir, i)) 

1771 

def make_directories(self):
    """Create the default sub-directories inside the web directory."""
    make = self._make_directories
    make(self.webdir, self.default_directories)

1776 

1777 @staticmethod 

1778 def _copy_files(paths): 

1779 """Copy the relevant file to the web directory 

1780 

1781 Parameters 

1782 ---------- 

1783 paths: nd list 

1784 list of files you wish to copy. First element is the path of the 

1785 file to copy and second element is the location of where you 

1786 wish the file to be put 

1787 

1788 Examples 

1789 -------- 

1790 >>> paths = [ 

1791 ... ["config/config.ini", "webdir/config.ini"], 

1792 ... ["samples/samples.h5", "webdir/samples.h5"] 

1793 ... ] 

1794 """ 

1795 import shutil 

1796 

1797 for ff in paths: 

1798 shutil.copyfile(ff[0], ff[1]) 

1799 

def copy_files(self):
    """Copy every entry of ``default_files_to_copy`` into the web
    directory."""
    copier = self._copy_files
    copier(self.default_files_to_copy)

1804 

def default_labels(self):
    """Return a list of default labels.

    Each label is ``<timestamp>_<file name without extension>`` (one label
    in total when ``mcmc_samples`` is set). Labels that would otherwise be
    repeated get ``_0``, ``_1``, ... appended in order of appearance, and
    when adding to an existing metafile a label that clashes with an
    existing label gets ``_<position>`` appended.

    Raises
    ------
    InputError
        if no result files were provided
    """
    from collections import Counter
    from time import time

    if self.result_files is None or len(self.result_files) == 0:
        raise InputError("Please provide a results file")
    # one timestamp for all labels so identical file names always collide
    # and are de-duplicated consistently
    timestamp = round(time())
    files = self.result_files[:1] if self.mcmc_samples else self.result_files
    label_list = [
        "%s_%s" % (timestamp, os.path.splitext(os.path.basename(f))[0])
        for f in files
    ]
    totals = Counter(label_list)
    seen = Counter()
    for index, label in enumerate(label_list):
        if totals[label] > 1:
            label_list[index] = "%s_%s" % (label, seen[label])
            seen[label] += 1
    if self.add_to_existing:
        for index, label in enumerate(label_list):
            if label in self.existing_labels:
                label_list[index] += "_%s" % (index)
    return label_list

1839 

@staticmethod
def get_package_information():
    """Describe the software environment of this run.

    Returns
    -------
    dict
        "packages": record array of the installed packages sorted by name,
        with conda columns when ``build_string`` is present in the first
        entry, otherwise pip columns; "environment": list holding the
        package directory; "manager": the package manager reported by
        ``PackageInformation``
    """
    from pesummary._version_helper import PackageInformation
    from operator import itemgetter

    info = PackageInformation()
    installed = info.package_info
    environment_dir = info.package_dir
    if "build_string" in installed[0]:  # conda list
        columns = ("name", "version", "channel", "build_string")
    else:  # pip list installed
        columns = ("name", "version")
    rows = [
        tuple(entry[column.lower()] for column in columns)
        for entry in sorted(installed, key=itemgetter("name"))
    ]
    record_dtype = [(column, "S20") for column in columns]
    packages = np.array(rows, dtype=record_dtype).view(np.recarray)
    return {
        "packages": packages,
        "environment": [environment_dir],
        "manager": info.package_manager,
    }

1862 

def grab_key_data_from_result_files(self):
    """Grab the mean, median, maxL and standard deviation for all
    parameters for each result file, together with the injected value.

    Returns
    -------
    dict
        ``key_data`` of each entry in ``self.samples``, with an extra
        "injected" entry per parameter taken from ``self.injection_data``.
        Injections stored as a non-empty list, tuple or array are reduced
        to their first element
    """
    key_data = {
        key: samples.key_data for key, samples in self.samples.items()
    }
    for key, val in self.samples.items():
        for j in val.keys():
            _inj = self.injection_data[key][j]
            # check the type first: math.isnan rejects lists and None
            if (
                isinstance(_inj, (list, tuple, np.ndarray))
                and np.ndim(_inj) > 0 and len(_inj) > 0
            ):
                _inj = _inj[0]
            key_data[key][j]["injected"] = _inj
    return key_data

1879 

1880 

class BaseInput(_Input):
    """Store the command line arguments needed by every executable.

    Parameters
    ----------
    opts: argparse.Namespace
        parsed command line arguments
    ignore_copy: Bool, optional
        if True, ``copy_files`` is not called. Default False
    checkpoint: object, optional
        previously saved instance. When given, its attributes are copied
        onto this instance and initialisation stops there
    gw: Bool, optional
        stored as ``self.gw``. Default False
    """
    def __init__(self, opts, ignore_copy=False, checkpoint=None, gw=False):
        self.opts = opts
        self.gw = gw
        self.restart_from_checkpoint = self.opts.restart_from_checkpoint
        if checkpoint is not None:
            # restore the saved state wholesale and skip all other setup
            for attribute, value in vars(checkpoint).items():
                setattr(self, attribute, value)
            logger.info(
                "Loaded command line arguments: {}".format(self.opts)
            )
            self.restart_from_checkpoint = True
            self._restarted_from_checkpoint = True
            return
        # NOTE(review): assignments below may trigger property setters
        # defined elsewhere, so their order is kept as-is
        self.seed = self.opts.seed
        self.result_files = self.opts.samples
        self.user = self.opts.user
        self.existing = self.opts.existing
        self.add_to_existing = False
        if self.existing is not None:
            self.add_to_existing = True
            self.existing_metafile = True
        self.webdir = self.opts.webdir
        self._restarted_from_checkpoint = False
        self.resume_file_dir = conf.checkpoint_dir(self.webdir)
        self.resume_file = conf.resume_file
        self._resume_file_path = os.path.join(
            self.resume_file_dir, self.resume_file
        )
        self.make_directories()
        self.email = self.opts.email
        self.pe_algorithm = self.opts.pe_algorithm
        self.multi_process = self.opts.multi_process
        self.package_information = self.get_package_information()
        if not ignore_copy:
            self.copy_files()
        self.write_current_state()

    @property
    def default_directories(self):
        """Sub-directories of the web directory created on start up."""
        return ["checkpoint"]

    @property
    def default_files_to_copy(self):
        """``[source, destination]`` pairs copied into the web directory;
        none for the base class."""
        return []

    def write_current_state(self):
        """Pickle this instance into the checkpoint directory, overwriting
        any previous checkpoint file."""
        from pesummary.io import write

        write(
            self, outdir=self.resume_file_dir, file_format="pickle",
            filename=self.resume_file, overwrite=True
        )
        logger.debug(
            "Written checkpoint file: {}".format(self._resume_file_path)
        )

1940 

1941 

1942class SamplesInput(BaseInput): 

1943 """Class to handle and store sample specific command line arguments 

1944 """ 

def __init__(self, *args, extra_options=None, **kwargs):
    """Store the command line arguments describing the samples.

    Parameters
    ----------
    *args, **kwargs
        forwarded unchanged to ``BaseInput.__init__``
    extra_options: list, optional
        names of additional options to copy from ``self.opts`` onto this
        instance
    """
    super(SamplesInput, self).__init__(*args, **kwargs)
    # NOTE(review): this runs even when the parent restored a checkpoint
    # and returned early; confirm that re-parsing is intended
    if self.result_files is not None:
        self._open_result_files = {path: None for path in self.result_files}
    self.meta_file = False
    if self.result_files is not None and len(self.result_files) == 1:
        self.meta_file = self.is_pesummary_metafile(self.result_files[0])
    self.compare_results = self.opts.compare_results
    self.disable_injection = self.opts.disable_injection
    if self.existing is not None:
        self.existing_data = self.grab_data_from_metafile(
            self.existing_metafile, self.existing,
            compare=self.compare_results
        )
        attribute_to_key = (
            ("existing_samples", "samples"),
            ("existing_injection_data", "injection_data"),
            ("existing_file_version", "file_version"),
            ("existing_file_kwargs", "file_kwargs"),
            ("existing_priors", "prior"),
            ("existing_config", "config"),
            ("existing_labels", "labels"),
            ("existing_weights", "weights"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, self.existing_data[key])
    else:
        unset = (
            "existing_metafile", "existing_labels", "existing_weights",
            "existing_samples", "existing_file_version",
            "existing_file_kwargs", "existing_priors", "existing_config",
            "existing_injection_data",
        )
        for attribute in unset:
            setattr(self, attribute, None)
    self.mcmc_samples = self.opts.mcmc_samples
    self.labels = self.opts.labels
    self.weights = {label: None for label in self.labels}
    self.config = self.opts.config
    self.injection_file = self.opts.inj_file
    self.regenerate = self.opts.regenerate
    if extra_options is not None:
        for option in extra_options:
            setattr(self, option, getattr(self.opts, option))
    # the following setters depend on one another, so keep this order
    self.nsamples_for_prior = self.opts.nsamples_for_prior
    self.priors = self.opts.prior_file
    self.disable_prior_sampling = self.opts.disable_prior_sampling
    self.path_to_samples = self.opts.path_to_samples
    self.file_format = self.opts.file_format
    self.nsamples = self.opts.nsamples
    self.keep_nan_likelihood_samples = self.opts.keep_nan_likelihood_samples
    self.reweight_samples = self.opts.reweight_samples
    self.samples = self.opts.samples
    self.ignore_parameters = self.opts.ignore_parameters
    self.burnin_method = self.opts.burnin_method
    self.burnin = self.opts.burnin
    self.same_parameters = []
    if self.mcmc_samples:
        self._samples = {label: self.samples.T for label in self.labels}
    self.write_current_state()

2004 

2005 @property 

2006 def analytic_prior_dict(self): 

2007 return { 

2008 label: "\n".join( 

2009 [ 

2010 "{} = {}".format(key, value) for key, value in 

2011 self.priors["analytic"][label].items() 

2012 ] 

2013 ) if "analytic" in self.priors.keys() and label in 

2014 self.priors["analytic"].keys() else None for label in self.labels 

2015 } 

2016 

2017 @property 

2018 def same_parameters(self): 

2019 return self._same_parameters 

2020 

2021 @same_parameters.setter 

2022 def same_parameters(self, same_parameters): 

2023 self._same_parameters = self.intersect_samples_dict(self.samples) 

2024 

2025 def intersect_samples_dict(self, samples): 

2026 parameters = [ 

2027 list(samples[key].keys()) for key in samples.keys() 

2028 ] 

2029 params = list(set.intersection(*[set(l) for l in parameters])) 

2030 return params 

2031 

2032 

class PlottingInput(SamplesInput):
    """Input class that additionally stores the plotting related options
    read from ``self.opts``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Options copied verbatim from the command line. The order below
        # matches the historical assignment order in case any property
        # setter depends on an earlier attribute.
        for name in (
            "style_file", "publication", "publication_kwargs", "kde_plot",
            "custom_plotting", "add_to_corner",
        ):
            setattr(self, name, getattr(self.opts, name))
        self.corner_params = self.add_to_corner
        for name in (
            "palette", "include_prior", "colors", "linestyles",
            "disable_corner", "disable_comparison", "disable_interactive",
        ):
            setattr(self, name, getattr(self.opts, name))
        # expert plots are opt-in on the command line
        self.disable_expert = not self.opts.enable_expert
        self.multi_threading_for_plots = self.multi_process
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent directories plus the ``plots`` hierarchy"""
        dirs = super().default_directories
        dirs.extend(["plots", "plots/corner", "plots/publication"])
        return dirs

2061 

2062 

class WebpageInput(SamplesInput):
    """Input class that additionally stores the webpage related options
    read from ``self.opts``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for name in (
            "baseurl", "existing_plot", "pe_algorithm", "notes", "dump",
        ):
            setattr(self, name, getattr(self.opts, name))
        # results are stored as HDF5 unless JSON output was requested
        self.hdf5 = not self.opts.save_to_json
        self.external_hdf5_links = self.opts.external_hdf5_links
        self.file_kwargs["webpage_url"] = self.baseurl + "/home.html"
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent directories plus ``js``, ``html`` and ``css``"""
        dirs = super().default_directories
        dirs.extend(["js", "html", "css"])
        return dirs

    @property
    def default_files_to_copy(self):
        """Parent files plus every bundled ``.js`` and ``.css`` asset

        Each asset shipped in the ``pesummary.core`` package's ``js`` and
        ``css`` folders is paired with the same-named file in the matching
        sub-directory of ``self.webdir``.
        """
        from pesummary import core

        files_to_copy = super().default_files_to_copy
        package_dir = core.__path__[0]
        for subdir, pattern in (("js", "*.js"), ("css", "*.css")):
            for source in glob(os.path.join(package_dir, subdir, pattern)):
                destination = os.path.join(
                    self.webdir, subdir, os.path.basename(source)
                )
                files_to_copy.append([source, destination])
        return files_to_copy

2100 

2101 

class WebpagePlusPlottingInput(PlottingInput, WebpageInput):
    """Input class combining the plotting and the webpage options
    (``PlottingInput`` first, then ``WebpageInput`` in the MRO).
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # copy_files() is invoked again once every parent constructor has
        # run; BaseInput may already have copied once unless ignore_copy
        # was passed.
        self.copy_files()

    @property
    def default_directories(self):
        """Directories resolved through the MRO"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy resolved through the MRO"""
        return super().default_files_to_copy

2117 

2118 

class MetaFileInput(SamplesInput):
    """Class to handle and store metafile specific command line arguments
    """
    def __init__(self, *args, **kwargs):
        # BaseInput skips its own copy_files() call when ignore_copy is set.
        # default_files_to_copy (below) reads self.config, self.labels and
        # self.mcmc_samples, which are only assigned in SamplesInput.__init__,
        # so the copy is performed here once the parents have finished.
        kwargs.update({"ignore_copy": True})
        super(MetaFileInput, self).__init__(*args, **kwargs)
        self.copy_files()
        self.filename = self.opts.filename
        self.hdf5 = not self.opts.save_to_json
        self.hdf5_compression = self.opts.hdf5_compression
        self.external_hdf5_links = self.opts.external_hdf5_links
        self.descriptions = self.opts.descriptions
        self.preferred = self.opts.preferred
        self.write_current_state()

    @property
    def default_directories(self):
        """Parent directories plus ``samples`` and ``config``"""
        dirs = super(MetaFileInput, self).default_directories
        dirs += ["samples", "config"]
        return dirs

    @property
    def default_files_to_copy(self):
        """Parent files plus the config files and result files

        Returns
        -------
        list
            ``[source, destination]`` pairs. Config files are renamed to
            ``<label>_config.ini`` under ``<webdir>/config``. Result files are
            renamed to ``<label>_<name>`` (or ``chain_<num>_<name>`` for MCMC
            samples) under ``<webdir>/samples``. A ``None`` ``config`` or
            ``result_files`` simply contributes no entries.
        """
        files_to_copy = super(MetaFileInput, self).default_files_to_copy
        config_files = self.config if self.config is not None else []
        for num, config_file in enumerate(config_files):
            # skip labels without a config file, and config files whose path
            # already contains the web directory (substring match)
            if config_file is None or self.webdir in config_file:
                continue
            filename = "_".join([self.labels[num], "config.ini"])
            files_to_copy.append(
                [config_file, os.path.join(self.webdir, "config", filename)]
            )
        # result_files may legitimately be None (SamplesInput guards for it)
        result_files = (
            self.result_files if self.result_files is not None else []
        )
        for num, _file in enumerate(result_files):
            if not self.mcmc_samples:
                filename = "{}_{}".format(self.labels[num], Path(_file).name)
            else:
                filename = "chain_{}_{}".format(num, Path(_file).name)
            files_to_copy.append(
                [_file, os.path.join(self.webdir, "samples", filename)]
            )
        return files_to_copy

2161 

2162 

class WebpagePlusPlottingPlusMetaFileInput(MetaFileInput, WebpagePlusPlottingInput):
    """Input class combining the metafile, webpage and plotting options
    (``MetaFileInput`` first, then ``WebpagePlusPlottingInput`` in the MRO).
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def default_directories(self):
        """Directories resolved through the MRO"""
        return super().default_directories

    @property
    def default_files_to_copy(self):
        """Files to copy resolved through the MRO"""
        return super().default_files_to_copy

2179 

2180 

@deprecation(
    "The Input class is deprecated. Please use either the BaseInput, "
    "SamplesInput, PlottingInput, WebpageInput, WebpagePlusPlottingInput, "
    "MetaFileInput or the WebpagePlusPlottingPlusMetaFileInput class"
)
class Input(WebpagePlusPlottingPlusMetaFileInput):
    """Deprecated alias of ``WebpagePlusPlottingPlusMetaFileInput``

    Kept only for backwards compatibility; use one of the classes listed
    in the deprecation message instead.
    """

2188 

2189 

def load_current_state(resume_file):
    """Restore an input object from a pickle checkpoint file

    Parameters
    ----------
    resume_file: str
        path to a checkpoint file

    Returns
    -------
    object or None
        the object read back with ``pesummary.io.read(..., checkpoint=True)``,
        or None when ``resume_file`` does not exist
    """
    from pesummary.io import read

    if os.path.isfile(resume_file):
        logger.info(
            "Reading checkpoint file: {}".format(resume_file)
        )
        # NOTE: checkpoints are pickles; only resume from files you trust.
        return read(resume_file, checkpoint=True)
    logger.info(
        "Unable to find resume file. Not restarting from checkpoint"
    )
    return None