Coverage for pesummary/core/file/formats/base_read.py: 57.4%

427 statements  

coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

# Licensed under an MIT style license -- see LICENSE.md

import os
import numpy as np
import h5py
from pesummary.utils.parameters import MultiAnalysisParameters, Parameters
from pesummary.utils.samples_dict import (
    MultiAnalysisSamplesDict, SamplesDict, MCMCSamplesDict, Array
)
from pesummary.utils.utils import logger

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]


def _downsample(samples, number, extra_kwargs=None):
    """Downsample a posterior table

    Parameters
    ----------
    samples: 2d list
        list of posterior samples where the columns correspond to a given
        parameter
    number: int
        number of posterior samples you wish to downsample to
    extra_kwargs: dict, optional
        dictionary of kwargs to modify
    """
    from pesummary.utils.utils import resample_posterior_distribution
    import copy

    _samples = np.array(samples).T
    if number > len(_samples[0]):
        raise ValueError(
            "Failed to downsample the posterior samples to {} because "
            "there are only {} samples stored in the file.".format(
                number, len(_samples[0])
            )
        )
    _samples = np.array(resample_posterior_distribution(_samples, number))
    if extra_kwargs is None:
        return _samples.T.tolist()
    _extra_kwargs = copy.deepcopy(extra_kwargs)
    _extra_kwargs["sampler"]["nsamples"] = number
    return _samples.T.tolist(), _extra_kwargs

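# Note: _downsample treats each row of `samples` as one posterior draw; it
# transposes to parameter-major order, calls resample_posterior_distribution,
# and transposes back. A minimal illustrative call (the target size of 5000 is
# arbitrary, not a pesummary default):
#
#     samples, kwargs = _downsample(samples, 5000, extra_kwargs=kwargs)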

class Read(object):
    """Base class to read in a results file

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the result file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file keyed by parameters
    input_version: str
        version of the result file passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file
    pe_algorithm: str
        name of the algorithm used to generate the posterior samples

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    """
    def __init__(
        self, path_to_results_file, remove_nan_likelihood_samples=True, **kwargs
    ):
        self.path_to_results_file = path_to_results_file
        self.mcmc_samples = False
        self.remove_nan_likelihood_samples = remove_nan_likelihood_samples
        self.extension = self.extension_from_path(self.path_to_results_file)
        self.converted_parameters = []

    @classmethod
    def load_file(cls, path, **kwargs):
        """Initialize the class with a file

        Parameters
        ----------
        path: str
            path to the result file you wish to load
        **kwargs: dict, optional
            all kwargs passed to the class
        """
        if not os.path.isfile(path):
            raise FileNotFoundError("%s does not exist" % (path))
        return cls(path, **kwargs)
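    # `load_file` is the usual construction path: it checks that `path` exists
    # and then simply calls `cls(path, **kwargs)`. An illustrative call, with a
    # hypothetical filename:
    #
    #     f = Read.load_file("posterior_samples.dat")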

    @staticmethod
    def load_from_function(function, path_to_file, **kwargs):
        """Load a file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your file
        path_to_file: str
            path to the file that you wish to load
        kwargs: dict
            all kwargs are passed to the function
        """
        return function(path_to_file, **kwargs)

    @staticmethod
    def check_for_nan_likelihoods(parameters, samples, remove=False):
        """Check to see if there are any samples with log_likelihood='nan' in
        the posterior table and remove if requested

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter
        remove: Bool, optional
            if True, remove samples with log_likelihood='nan' from samples
        """
        import math
        if "log_likelihood" not in parameters:
            return parameters, samples
        ind = parameters.index("log_likelihood")
        likelihoods = np.array(samples).T[ind]
        inds = np.array(
            [math.isnan(_) for _ in likelihoods], dtype=bool
        )
        if not sum(inds):
            return parameters, samples
        msg = (
            "Posterior table contains {} samples with 'nan' log likelihood. "
        )
        if remove:
            msg += "Removing samples from posterior table."
            samples = np.array(samples)[~inds].tolist()
        else:
            msg += "This may cause problems when analysing posterior samples."
        logger.warning(msg.format(sum(inds)))
        return parameters, samples
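    # check_for_nan_likelihoods drops whole rows of the posterior table when
    # remove=True; when remove=False it only logs a warning and returns the
    # samples unchanged.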

    @staticmethod
    def check_for_weights(parameters, samples):
        """Check to see if the samples are weighted

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter
        """
        likely_names = ["weights", "weight"]
        if any(i in parameters for i in likely_names):
            ind = (
                parameters.index("weights") if "weights" in parameters else
                parameters.index("weight")
            )
            return Array(np.array(samples).T[ind])
        return None
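    # The 'pe_algorithm' entry is optional sampler metadata; the property below
    # simply returns extra_kwargs["sampler"]["pe_algorithm"] when it is present
    # and None otherwise.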

    @property
    def pe_algorithm(self):
        try:
            return self.extra_kwargs["sampler"]["pe_algorithm"]
        except KeyError:
            return None

    def __repr__(self):
        return self.summary()

    def _parameter_summary(self, parameters, parameters_to_show=4):
        """Return a summary of the parameters stored

        Parameters
        ----------
        parameters: list
            list of parameters to create a summary for
        parameters_to_show: int, optional
            number of parameters to show. Default 4.
        """
        params = parameters
        if len(parameters) > parameters_to_show:
            params = parameters[:2] + ["..."] + parameters[-2:]
        return ", ".join(params)

    def summary(
        self, parameters_to_show=4, show_parameters=True, show_nsamples=True
    ):
        """Return a summary of the contents of the file

        Parameters
        ----------
        parameters_to_show: int, optional
            number of parameters to show. Default 4
        show_parameters: Bool, optional
            if True print a list of the parameters stored
        show_nsamples: Bool, optional
            if True print how many samples are stored in the file
        """
        string = ""
        if self.path_to_results_file is not None:
            string += "file: {}\n".format(self.path_to_results_file)
        string += "cls: {}.{}\n".format(
            self.__class__.__module__, self.__class__.__name__
        )
        if show_nsamples:
            string += "nsamples: {}\n".format(len(self.samples))
        if show_parameters:
            string += "parameters: {}".format(
                self._parameter_summary(
                    self.parameters, parameters_to_show=parameters_to_show
                )
            )
        return string
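    # `attrs` maps attribute names set on the Read object (keys) to the keys
    # that `load` looks up in the loaded data dictionary (values); entries
    # missing from the data are set to None, and `load` then falls back to the
    # _default_* properties for version, kwargs and description.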

    attrs = {
        "input_version": "version", "extra_kwargs": "kwargs",
        "priors": "prior", "analytic": "analytic", "labels": "labels",
        "config": "config", "weights": "weights", "history": "history",
        "description": "description"
    }

    def _load(self, function, **kwargs):
        """Extract the data from a file using a given function

        Parameters
        ----------
        function: func
            function you wish to use to extract the data
        **kwargs: dict, optional
            optional kwargs to pass to the load function
        """
        return self.load_from_function(
            function, self.path_to_results_file, **kwargs
        )

    def load(self, function, _data=None, **kwargs):
        """Load a results file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your results file
        """
        self.data = _data
        if _data is None:
            self.data = self._load(function, **kwargs)
        if isinstance(self.data["parameters"][0], list):
            _cls = MultiAnalysisParameters
        else:
            _cls = Parameters
        self.parameters = _cls(self.data["parameters"])
        self.converted_parameters = []
        self.samples = self.data["samples"]
        self.parameters, self.samples = self.check_for_nan_likelihoods(
            self.parameters, self.samples,
            remove=self.remove_nan_likelihood_samples
        )
        if "mcmc_samples" in self.data.keys():
            self.mcmc_samples = self.data["mcmc_samples"]
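        # Injection tables read from HDF5 can carry byte-string keys; the
        # blocks below decode them to str so that downstream lookups by
        # parameter name work consistently.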

        if "injection" in self.data.keys():
            if isinstance(self.data["injection"], dict):
                self.injection_parameters = {
                    key.decode("utf-8") if isinstance(key, bytes) else key: val
                    for key, val in self.data["injection"].items()
                }
            elif isinstance(self.data["injection"], list):
                self.injection_parameters = [
                    {
                        key.decode("utf-8") if isinstance(key, bytes) else
                        key: val for key, val in i.items()
                    } for i in self.data["injection"]
                ]
            else:
                self.injection_parameters = self.data["injection"]
        for new_attr, existing_attr in self.attrs.items():
            if existing_attr in self.data.keys():
                setattr(self, new_attr, self.data[existing_attr])
            else:
                setattr(self, new_attr, None)
        if self.input_version is None:
            self.input_version = self._default_version
        if self.extra_kwargs is None:
            self.extra_kwargs = self._default_kwargs
        if self.description is None:
            self.description = self._default_description
        if self.weights is None:
            self.weights = self.check_for_weights(self.parameters, self.samples)

    @staticmethod
    def extension_from_path(path):
        """Return the extension of the file from the file path

        Parameters
        ----------
        path: str
            path to the results file
        """
        extension = path.split(".")[-1]
        return extension

    @staticmethod
    def guess_path_to_samples(path):
        """Guess the path to the posterior samples stored in an hdf5 object

        Parameters
        ----------
        path: str
            path to the results file
        """
        def _find_name(name, item):
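            # Heuristics used below: c1 -- the group/dataset name looks like a
            # posterior table; c2 -- the item itself is a dataset/array;
            # c3 -- the item is a group indexable like an array of numbers;
            # c4 -- the item is a group of datasets keyed by parameter name;
            # c5 -- the item is a group with 'parameter_names' and 'samples'
            # entries.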

            c1 = "posterior_samples" in name or "posterior" in name
            c2 = isinstance(item, (h5py._hl.dataset.Dataset, np.ndarray))
            _group = isinstance(item, h5py._hl.group.Group)
            c3, c4 = False, False
            if _group:
                try:
                    if isinstance(item[0], (float, int, np.number)):
                        c3 = True
                except (TypeError, AttributeError):
                    c3 = False
                try:
                    keys = list(item.keys())
                    if isinstance(item[keys[0]], (h5py._hl.dataset.Dataset, np.ndarray)):
                        c4 = True
                except (TypeError, IndexError, AttributeError):
                    c4 = False
            c5 = (
                _group and "parameter_names" in item.keys() and "samples" in item.keys()
            )
            if c1 and c4:
                paths.append(name)
            elif c1 and c3:
                paths.append(name)
            elif c1 and c5:
                paths.append(name)
            elif c1 and c2:
                if "/".join(name.split("/")[:-1]) not in paths:
                    paths.append(name)

        f = h5py.File(path, 'r')
        paths = []
        f.visititems(_find_name)
        f.close()
        if len(paths) == 1:
            return paths[0]
        elif len(paths) > 1:
            raise ValueError(
                "Found multiple posterior sample tables in '{}': {}. Not sure "
                "which to load.".format(
                    path, ", ".join(paths)
                )
            )
        else:
            raise ValueError(
                "Unable to find a posterior samples table in '{}'".format(path)
            )

    def generate_all_posterior_samples(self, **kwargs):
        """Empty function
        """
        pass

    def add_fixed_parameters_from_config_file(self, config_file):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        """
        pass

    def add_injection_parameters_from_file(self, injection_file, **kwargs):
        """Populate the 'injection_parameters' property with data from a file

        Parameters
        ----------
        injection_file: str
            path to injection file
        """
        self.injection_parameters = self._grab_injection_parameters_from_file(
            injection_file, **kwargs
        )

    def _grab_injection_parameters_from_file(
        self, path, cls=None, add_nans=True, **kwargs
    ):
        """Extract data from an injection file

        Parameters
        ----------
        path: str
            path to injection file
        cls: class, optional
            class to read in injection file. The class must have a read class
            method and a samples_dict property. Default None which means that
            the pesummary.core.file.injection.Injection class is used
        """
        if cls is None:
            from pesummary.core.file.injection import Injection
            cls = Injection
        data = cls.read(path, **kwargs).samples_dict
        for i in self.parameters:
            if i not in data.keys():
                data[i] = float("nan")
        return data

    def write(
        self, package="core", file_format="dat", extra_kwargs=None,
        file_versions=None, **kwargs
    ):
        """Save the data to file

        Parameters
        ----------
        package: str, optional
            package you wish to use when writing the data
        kwargs: dict, optional
            all additional kwargs are passed to the pesummary.io.write function
        """
        from pesummary.io import write

        if file_format == "pesummary" and np.array(self.parameters).ndim > 1:
            args = [self.samples_dict]
        else:
            args = [self.parameters, self.samples]
        if extra_kwargs is None:
            extra_kwargs = self.extra_kwargs
        if file_versions is None:
            file_versions = self.input_version
        if file_format == "ini":
            kwargs["file_format"] = "ini"
            return write(getattr(self, "config", None), **kwargs)
        else:
            return write(
                *args, package=package, file_versions=file_versions,
                file_kwargs=extra_kwargs, file_format=file_format, **kwargs
            )
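    # Illustrative use of `write` (the file format and kwarg names shown here
    # are assumptions passed straight through to pesummary.io.write, not values
    # required by this class):
    #
    #     f.write(file_format="json", outdir="./", filename="posterior.json")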

    def downsample(self, number):
        """Downsample the posterior samples stored in the result file
        """
        self.samples, self.extra_kwargs = _downsample(
            self.samples, number, extra_kwargs=self.extra_kwargs
        )

    @staticmethod
    def latex_table(samples, parameter_dict=None, labels=None):
        """Return a latex table displaying the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        labels: list, optional
            list of labels to use as column headings, one per entry in
            `samples`. If None, no heading row is added
        """
        table = (
            "\\begin{table}[hptb]\n\\begin{ruledtabular}\n\\begin{tabular}"
            "{l %s}\n" % ("c " * len(samples))
        )
        if labels:
            table += (
                " & " + " & ".join(labels)
            )
        table += "\\\\ \n\\hline \\\\ \n"
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not adding to "
                        "table".format(param)
                    )
                    data.pop(param)

        for param, desc in data.items():
            table += "{}".format(desc)
            for samples_dict in samples:
                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                table += (
                    " & $%s^{+%s}_{-%s}$" % (
                        np.round(median, 2),
                        np.round(confidence[1] - median, 2),
                        np.round(median - confidence[0], 2)
                    )
                )
            table += "\\\\ \n"
        table += (
            "\\end{tabular}\n\\end{ruledtabular}\n\\caption{}\n\\end{table}"
        )
        return table
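    # latex_table (above) builds one tabular row per parameter, with a
    # "$median^{+upper}_{-lower}$" cell per analysis, wrapped in a
    # table/ruledtabular environment (REVTeX style).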

    @staticmethod
    def latex_macros(
        samples, parameter_dict=None, labels=None, rounding="smart"
    ):
        """Return a set of latex macros for the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        labels: list, optional
            list of labels, one per entry in `samples`, appended to each macro
            name. If None, the macro name alone is used
        rounding: int or str, optional
            decimal place for rounding. Default 'smart', which uses the
            `pesummary.utils.utils.smart_round` function to round according to
            the uncertainty
        """
        macros = ""
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not generating "
                        "macro".format(param)
                    )
                    data.pop(param)
        for param, desc in data.items():
            for num, samples_dict in enumerate(samples):
                if labels:
                    description = "{}{}".format(desc, labels[num])
                else:
                    description = desc

                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                if rounding == "smart":
                    from pesummary.utils.utils import smart_round

                    median, upper, low = smart_round([
                        median, confidence[1] - median, median - confidence[0]
                    ])
                else:
                    median = np.round(median, rounding)
                    low = np.round(median - confidence[0], rounding)
                    upper = np.round(confidence[1] - median, rounding)
                macros += (
                    "\\def\\%s{$%s_{-%s}^{+%s}$}\n" % (
                        description, median, low, upper
                    )
                )
                macros += (
                    "\\def\\%smedian{$%s$}\n" % (description, median)
                )
                macros += (
                    "\\def\\%supper{$%s$}\n" % (
                        description, np.round(median + upper, 9)
                    )
                )
                macros += (
                    "\\def\\%slower{$%s$}\n" % (
                        description, np.round(median - low, 9)
                    )
                )
        return macros

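# Read.latex_macros (above) emits four `\def` macros per parameter/analysis
# pair: the full "$median_{-low}^{+upper}$" value plus separate ...median,
# ...upper and ...lower macros.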

class SingleAnalysisRead(Read):
    """Base class to read in a results file which contains a single analysis

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file
    input_version: str
        version of the result file passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    reweight_samples:
        reweight the posterior and/or prior samples according to a new prior
    """
    def __init__(self, *args, **kwargs):
        super(SingleAnalysisRead, self).__init__(*args, **kwargs)

    @property
    def samples_dict(self):
        if self.mcmc_samples:
            return MCMCSamplesDict(
                self.parameters, [np.array(i).T for i in self.samples]
            )
        return SamplesDict(self.parameters, np.array(self.samples).T)

    @property
    def _default_version(self):
        return "No version information found"

    @property
    def _default_kwargs(self):
        _kwargs = {"sampler": {}, "meta_data": {}}
        _kwargs["sampler"]["nsamples"] = len(self.data["samples"])
        return _kwargs

    @property
    def _default_description(self):
        return "No description found"

    def _add_fixed_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def _add_marginalized_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add marginalized parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def to_latex_table(self, parameter_dict=None, save_to_file=None):
        """Make a latex table displaying the data in the result file.

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        table = self.latex_table([self.samples_dict], parameter_dict)
        if save_to_file is None:
            print(table)
        elif os.path.isfile("{}".format(save_to_file)):
            logger.warning(
                "File {} already exists. Printing to stdout".format(save_to_file)
            )
            print(table)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([table])

    def generate_latex_macros(
        self, parameter_dict=None, save_to_file=None, rounding="smart"
    ):
        """Generate a list of latex macros for each parameter in the result
        file

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        rounding: int or str, optional
            number of decimal points to round the latex macros to. Default
            'smart', which rounds according to the uncertainty
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        macros = self.latex_macros(
            [self.samples_dict], parameter_dict, rounding=rounding
        )
        if save_to_file is None:
            print(macros)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([macros])

    def to_dat(self, **kwargs):
        """Save the PESummary results file object to a dat file

        Parameters
        ----------
        kwargs: dict
            all kwargs passed to the pesummary.core.file.formats.dat.write_dat
            function
        """
        return self.write(file_format="dat", **kwargs)

    def reweight_samples(self, function, **kwargs):
        """Reweight the posterior and/or prior samples according to a new prior
        """
        if self.mcmc_samples:
            raise ValueError("Cannot currently reweight MCMC chains")
        _samples = self.samples_dict
        new_samples = _samples.reweight(function, **kwargs)
        self.parameters = Parameters(new_samples.parameters)
        self.samples = np.array(new_samples.samples).T
        self.extra_kwargs["sampler"].update(
            {
                "nsamples": new_samples.number_of_samples,
                "nsamples_before_reweighting": _samples.number_of_samples
            }
        )
        self.extra_kwargs["meta_data"]["reweighting"] = function
        if not hasattr(self, "priors"):
            return
        if (self.priors is None) or ("samples" not in self.priors.keys()):
            return
        prior_samples = self.priors["samples"]
        if not len(prior_samples):
            return
        new_prior_samples = prior_samples.reweight(function, **kwargs)
        self.priors["samples"] = new_prior_samples


class MultiAnalysisRead(Read):
    """Base class to read in a results file which contains multiple analyses

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: 2d list
        list of parameters for each analysis
    samples: 3d list
        list of samples stored in the result file for each analysis
    samples_dict: dict
        dictionary of samples stored in the result file keyed by analysis label
    input_version: str
        version of the result files passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file

    Methods
    -------
    samples_dict_for_label: dict
        dictionary of samples for a specific analysis
    reduced_samples_dict: dict
        dictionary of samples for one or more analyses
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    reweight_samples:
        reweight the posterior and/or prior samples according to a new prior
    """
    def __init__(self, *args, **kwargs):
        super(MultiAnalysisRead, self).__init__(*args, **kwargs)

    @staticmethod
    def check_for_nan_likelihoods(parameters, samples, remove=False):
        import copy
        _parameters = copy.deepcopy(parameters)
        _samples = copy.deepcopy(samples)
        for num, (params, samps) in enumerate(zip(_parameters, _samples)):
            _parameters[num], _samples[num] = Read.check_for_nan_likelihoods(
                params, samps, remove=remove
            )
        return _parameters, _samples

    def samples_dict_for_label(self, label):
        """Return the posterior samples for a specific label

        Parameters
        ----------
        label: str
            label you wish to get posterior samples for

        Returns
        -------
        outdict: SamplesDict
            Returns a SamplesDict containing the requested posterior samples
        """
        if label not in self.labels:
            raise ValueError("Unrecognised label: '{}'".format(label))
        idx = self.labels.index(label)
        return SamplesDict(self.parameters[idx], np.array(self.samples[idx]).T)

    def reduced_samples_dict(self, labels):
        """Return the posterior samples for one or more labels

        Parameters
        ----------
        labels: str, list
            label(s) you wish to get posterior samples for

        Returns
        -------
        outdict: MultiAnalysisSamplesDict
            Returns a MultiAnalysisSamplesDict containing the requested
            posterior samples
        """
        if not isinstance(labels, list):
            labels = [labels]
        not_allowed = [_label for _label in labels if _label not in self.labels]
        if len(not_allowed):
            raise ValueError(
                "Unrecognised label(s) '{}'. The list of available labels are "
                "{}.".format(", ".join(not_allowed), ", ".join(self.labels))
            )
        return MultiAnalysisSamplesDict(
            {
                label: self.samples_dict_for_label(label) for label in labels
            }
        )
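    # Illustrative use (labels are hypothetical): reduced_samples_dict(
    # ["label1", "label2"]) returns a MultiAnalysisSamplesDict keyed by label,
    # while the samples_dict property below does the same for every stored
    # analysis.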

    @property
    def samples_dict(self):
        if self.mcmc_samples:
            outdict = MCMCSamplesDict(
                self.parameters[0], [np.array(i).T for i in self.samples[0]]
            )
        else:
            outdict = self.reduced_samples_dict(self.labels)
        return outdict

    @property
    def _default_version(self):
        return ["No version information found"] * len(self.parameters)

    @property
    def _default_kwargs(self):
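        # Build one independent kwargs dict per analysis; multiplying a single
        # dict with `*` would alias one shared dict and every entry would end
        # up with the nsamples of the last analysis.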

        _kwargs = [{"sampler": {}, "meta_data": {}} for _ in self.parameters]
        for num, ss in enumerate(self.data["samples"]):
            _kwargs[num]["sampler"]["nsamples"] = len(ss)
        return _kwargs

    @property
    def _default_description(self):
        return {label: "No description found" for label in self.labels}

    def write(self, package="core", file_format="dat", **kwargs):
        """Save the data to file

        Parameters
        ----------
        package: str, optional
            package you wish to use when writing the data
        kwargs: dict, optional
            all additional kwargs are passed to the pesummary.io.write function
        """
        return super(MultiAnalysisRead, self).write(
            package=package, file_format=file_format,
            extra_kwargs=self.kwargs_dict, file_versions=self.version_dict,
            **kwargs
        )

    @property
    def kwargs_dict(self):
        return {
            label: kwarg for label, kwarg in zip(self.labels, self.extra_kwargs)
        }

    @property
    def version_dict(self):
        return {
            label: version for label, version in zip(self.labels, self.input_version)
        }

    def summary(self, *args, parameters_to_show=4, **kwargs):
        """Return a summary of the contents of the file

        Parameters
        ----------
        parameters_to_show: int, optional
            number of parameters to show. Default 4
        """
        string = super(MultiAnalysisRead, self).summary(
            show_parameters=False, show_nsamples=False
        )
        string += "analyses: {}\n\n".format(", ".join(self.labels))
        for num, label in enumerate(self.labels):
            string += "{}\n".format(label)
            string += "-" * len(label) + "\n"
            string += "description: {}\n".format(self.description[label])
            string += "nsamples: {}\n".format(len(self.samples[num]))
            string += "parameters: {}\n\n".format(
                self._parameter_summary(
                    self.parameters[num], parameters_to_show=parameters_to_show
                )
            )
        return string[:-2]

    def downsample(self, number, labels=None):
        """Downsample the posterior samples stored in the result file
        """
        for num, ss in enumerate(self.samples):
            if labels is not None and self.labels[num] not in labels:
                continue
            self.samples[num], self.extra_kwargs[num] = _downsample(
                ss, number, extra_kwargs=self.extra_kwargs[num]
            )

    def to_latex_table(self, labels="all", parameter_dict=None, save_to_file=None):
        """Make a latex table displaying the data in the result file.

        Parameters
        ----------
        labels: list, optional
            list of labels that you want to include in the table
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )
        if labels != "all" and isinstance(labels, str) and labels not in self.labels:
            raise ValueError("The label %s does not exist." % (labels))
        elif labels == "all":
            labels = list(self.labels)
        elif isinstance(labels, str):
            labels = [labels]
        elif isinstance(labels, list):
            for ll in labels:
                if ll not in list(self.labels):
                    raise ValueError("The label %s does not exist." % (ll))

        table = self.latex_table(
            [self.samples_dict[label] for label in labels], parameter_dict,
            labels=labels
        )
        if save_to_file is None:
            print(table)
        elif os.path.isfile("{}".format(save_to_file)):
            logger.warning(
                "File {} already exists. Printing to stdout".format(save_to_file)
            )
            print(table)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([table])

    def generate_latex_macros(
        self, labels="all", parameter_dict=None, save_to_file=None,
        rounding=2
    ):
        """Generate a list of latex macros for each parameter in the result
        file

        Parameters
        ----------
        labels: list, optional
            list of labels that you want to include in the table
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        rounding: int, optional
            number of decimal points to round the latex macros
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )
        if labels != "all" and isinstance(labels, str) and labels not in self.labels:
            raise ValueError("The label %s does not exist." % (labels))
        elif labels == "all":
            labels = list(self.labels)
        elif isinstance(labels, str):
            labels = [labels]
        elif isinstance(labels, list):
            for ll in labels:
                if ll not in list(self.labels):
                    raise ValueError("The label %s does not exist." % (ll))

        macros = self.latex_macros(
            [self.samples_dict[i] for i in labels], parameter_dict,
            labels=labels, rounding=rounding
        )
        if save_to_file is None:
            print(macros)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([macros])

    def reweight_samples(self, function, labels=None, **kwargs):
        """Reweight the posterior and/or prior samples according to a new prior

        Parameters
        ----------
        labels: list, optional
            list of analyses you wish to reweight. Default reweight all
            analyses
        """
        _samples_dict = self.samples_dict
        for idx, label in enumerate(self.labels):
            if labels is not None and label not in labels:
                continue
            new_samples = _samples_dict[label].reweight(function, **kwargs)
            self.parameters[idx] = Parameters(new_samples.parameters)
            self.samples[idx] = np.array(new_samples.samples).T
            self.extra_kwargs[idx]["sampler"].update(
                {
                    "nsamples": new_samples.number_of_samples,
                    "nsamples_before_reweighting": (
                        _samples_dict[label].number_of_samples
                    )
                }
            )
            self.extra_kwargs[idx]["meta_data"]["reweighting"] = function
            if not hasattr(self, "priors"):
                continue
            if "samples" not in self.priors.keys():
                continue
            prior_samples = self.priors["samples"][label]
            if not len(prior_samples):
                continue
            new_prior_samples = prior_samples.reweight(function, **kwargs)
            self.priors["samples"][label] = new_prior_samples