Coverage for pesummary/core/file/formats/base_read.py: 56.5%

414 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import numpy as np 

5import h5py 

6from pesummary.utils.parameters import MultiAnalysisParameters, Parameters 

7from pesummary.utils.samples_dict import ( 

8 MultiAnalysisSamplesDict, SamplesDict, MCMCSamplesDict, Array 

9) 

10from pesummary.utils.utils import logger 

11 

12__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

13 

14 

def _downsample(samples, number, extra_kwargs=None):
    """Downsample a posterior table

    Parameters
    ----------
    samples: 2d list
        list of posterior samples where the columns correspond to a given
        parameter
    number: int
        number of posterior samples you wish to downsample to
    extra_kwargs: dict, optional
        dictionary of kwargs to modify. A deep copy is updated so the
        dictionary passed in is left untouched

    Returns
    -------
    samples: 2d list
        the downsampled posterior table. If `extra_kwargs` is provided a
        tuple `(samples, extra_kwargs)` is returned, where the copied
        kwargs have `["sampler"]["nsamples"]` set to `number`

    Raises
    ------
    ValueError
        if `number` is larger than the number of samples available
    """
    import copy

    _samples = np.array(samples).T
    # an empty table has no first row to take the length of, so report zero
    # available samples rather than failing with an IndexError
    available = len(_samples[0]) if _samples.size else 0
    if number > available:
        raise ValueError(
            "Failed to downsample the posterior samples to {} because "
            "there are only {} samples stored in the file.".format(
                number, available
            )
        )
    # imported only once the request has been validated
    from pesummary.utils.utils import resample_posterior_distribution

    _samples = np.array(resample_posterior_distribution(_samples, number))
    if extra_kwargs is None:
        return _samples.T.tolist()
    _extra_kwargs = copy.deepcopy(extra_kwargs)
    # tolerate kwargs which do not yet contain a "sampler" section
    _extra_kwargs.setdefault("sampler", {})["nsamples"] = number
    return _samples.T.tolist(), _extra_kwargs

45 

46 

class Read(object):
    """Base class to read in a results file

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the result file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file keyed by parameters
    input_version: str
        version of the result file passed.
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file
    pe_algorithm: str
        name of the algorithm used to generate the posterior samples

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    """
    def __init__(
        self, path_to_results_file, remove_nan_likelihood_samples=True, **kwargs
    ):
        # additional kwargs are accepted but not used by the base class
        self.path_to_results_file = path_to_results_file
        self.mcmc_samples = False
        self.remove_nan_likelihood_samples = remove_nan_likelihood_samples
        self.extension = self.extension_from_path(self.path_to_results_file)
        self.converted_parameters = []

    @classmethod
    def load_file(cls, path, **kwargs):
        """Initialize the class with a file

        Parameters
        ----------
        path: str
            path to the result file you wish to load
        **kwargs: dict, optional
            all kwargs passed to the class

        Raises
        ------
        FileNotFoundError
            if `path` does not point to an existing file
        """
        if not os.path.isfile(path):
            raise FileNotFoundError("%s does not exist" % (path))
        return cls(path, **kwargs)

    @staticmethod
    def load_from_function(function, path_to_file, **kwargs):
        """Load a file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your file
        path_to_file: str
            path to the file that you wish to load
        kwargs: dict
            all kwargs are passed to the function
        """
        return function(path_to_file, **kwargs)

    @staticmethod
    def check_for_nan_likelihoods(parameters, samples, remove=False):
        """Check to see if there are any samples with log_likelihood='nan' in
        the posterior table and remove if requested

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter
        remove: Bool, optional
            if True, remove samples with log_likelihood='nan' from samples

        Returns
        -------
        parameters, samples: tuple
            the parameters as passed and the (possibly filtered) samples. The
            inputs are returned untouched when there is no log_likelihood
            column or no 'nan' entry in it
        """
        import math
        if "log_likelihood" not in parameters:
            return parameters, samples
        ind = parameters.index("log_likelihood")
        likelihoods = np.array(samples).T[ind]
        inds = np.array(
            [math.isnan(_) for _ in likelihoods], dtype=bool
        )
        # count once and reuse for both the early exit and the warning
        n_nan = int(inds.sum())
        if not n_nan:
            return parameters, samples
        msg = (
            "Posterior table contains {} samples with 'nan' log likelihood. "
        )
        if remove:
            msg += "Removing samples from posterior table."
            samples = np.array(samples)[~inds].tolist()
        else:
            msg += "This may cause problems when analysing posterior samples."
        logger.warning(msg.format(n_nan))
        return parameters, samples

    @staticmethod
    def check_for_weights(parameters, samples):
        """Check to see if the samples are weighted

        Parameters
        ----------
        parameters: list
            list of parameters stored in the result file
        samples: np.ndarray
            array of samples for each parameter

        Returns
        -------
        weights: Array or None
            the column named 'weights' (preferred) or 'weight', or None when
            neither name is in `parameters`
        """
        for name in ("weights", "weight"):
            if name in parameters:
                ind = parameters.index(name)
                return Array(np.array(samples).T[ind])
        return None

    @property
    def pe_algorithm(self):
        """Value of extra_kwargs["sampler"]["pe_algorithm"], or None if a key
        is missing
        """
        try:
            return self.extra_kwargs["sampler"]["pe_algorithm"]
        except KeyError:
            return None

    def __repr__(self):
        return self.summary()

    def _parameter_summary(self, parameters, parameters_to_show=4):
        """Return a summary of the parameter stored

        Parameters
        ----------
        parameters: list
            list of parameters to create a summary for
        parameters_to_show: int, optional
            number of parameters to show. Default 4.
        """
        params = parameters
        if len(parameters) > parameters_to_show:
            # abbreviate to the first two and last two entries
            params = parameters[:2] + ["..."] + parameters[-2:]
        return ", ".join(params)

    def summary(
        self, parameters_to_show=4, show_parameters=True, show_nsamples=True
    ):
        """Return a summary of the contents of the file

        Parameters
        ----------
        parameters_to_show: int, optional
            number of parameters to show. Default 4
        show_parameters: Bool, optional
            if True print a list of the parameters stored
        show_nsamples: Bool, optional
            if True print how many samples are stored in the file
        """
        string = ""
        if self.path_to_results_file is not None:
            string += "file: {}\n".format(self.path_to_results_file)
        string += "cls: {}.{}\n".format(
            self.__class__.__module__, self.__class__.__name__
        )
        if show_nsamples:
            string += "nsamples: {}\n".format(len(self.samples))
        if show_parameters:
            string += "parameters: {}".format(
                self._parameter_summary(
                    self.parameters, parameters_to_show=parameters_to_show
                )
            )
        return string

    # maps the attribute set on the instance by `load` to the key it is read
    # from in the loaded data dictionary
    attrs = {
        "input_version": "version", "extra_kwargs": "kwargs",
        "priors": "prior", "analytic": "analytic", "labels": "labels",
        "config": "config", "weights": "weights", "history": "history",
        "description": "description"
    }

    def _load(self, function, **kwargs):
        """Extract the data from a file using a given function

        Parameters
        ----------
        function: func
            function you wish to use to extract the data
        **kwargs: dict, optional
            optional kwargs to pass to the load function
        """
        return self.load_from_function(
            function, self.path_to_results_file, **kwargs
        )

    def load(self, function, _data=None, **kwargs):
        """Load a results file according to a given function

        Parameters
        ----------
        function: func
            callable function that will load in your results file
        _data: dict, optional
            already extracted data. If provided, `function` is not called
        **kwargs: dict, optional
            all kwargs passed to `function`
        """
        self.data = _data
        if _data is None:
            self.data = self._load(function, **kwargs)
        _parameters = self.data["parameters"]
        # a list of lists indicates one parameter list per analysis. Guard
        # against an empty list, which has no first element to inspect
        if len(_parameters) and isinstance(_parameters[0], list):
            _cls = MultiAnalysisParameters
        else:
            _cls = Parameters
        self.parameters = _cls(_parameters)
        self.converted_parameters = []
        self.samples = self.data["samples"]
        self.parameters, self.samples = self.check_for_nan_likelihoods(
            self.parameters, self.samples,
            remove=self.remove_nan_likelihood_samples
        )
        if "mcmc_samples" in self.data.keys():
            self.mcmc_samples = self.data["mcmc_samples"]
        if "injection" in self.data.keys():
            # keys may be stored as bytes; decode them to str
            if isinstance(self.data["injection"], dict):
                self.injection_parameters = {
                    key.decode("utf-8") if isinstance(key, bytes) else key: val
                    for key, val in self.data["injection"].items()
                }
            elif isinstance(self.data["injection"], list):
                self.injection_parameters = [
                    {
                        key.decode("utf-8") if isinstance(key, bytes) else
                        key: val for key, val in i.items()
                    } for i in self.data["injection"]
                ]
            else:
                self.injection_parameters = self.data["injection"]
        for new_attr, existing_attr in self.attrs.items():
            if existing_attr in self.data.keys():
                setattr(self, new_attr, self.data[existing_attr])
            else:
                setattr(self, new_attr, None)
        # fall back to the defaults which subclasses provide
        if self.input_version is None:
            self.input_version = self._default_version
        if self.extra_kwargs is None:
            self.extra_kwargs = self._default_kwargs
        if self.description is None:
            self.description = self._default_description
        if self.weights is None:
            self.weights = self.check_for_weights(self.parameters, self.samples)

    @staticmethod
    def extension_from_path(path):
        """Return the extension of the file from the file path

        The extension is everything after the final '.' in `path`; a path
        without a '.' is therefore returned unchanged

        Parameters
        ----------
        path: str
            path to the results file
        """
        extension = path.split(".")[-1]
        return extension

    @staticmethod
    def guess_path_to_samples(path):
        """Guess the path to the posterior samples stored in an hdf5 object

        Parameters
        ----------
        path: str
            path to the results file

        Returns
        -------
        name: str
            path within the file to the only candidate posterior table

        Raises
        ------
        ValueError
            if no candidate, or more than one candidate, is found
        """
        paths = []

        def _find_name(name, item):
            c1 = "posterior_samples" in name or "posterior" in name
            c2 = isinstance(item, (h5py._hl.dataset.Dataset, np.ndarray))
            try:
                c3 = isinstance(item, h5py._hl.group.Group) and isinstance(
                    item[0], (float, int, np.number)
                )
            except (TypeError, AttributeError):
                c3 = False
            c4 = (
                isinstance(item, h5py._hl.group.Group) and "parameter_names" in
                item.keys() and "samples" in item.keys()
            )
            if c1 and (c3 or c4 or c2):
                paths.append(name)
            # implicitly return None so that the traversal carries on

        # the context manager closes the file even if the traversal fails
        with h5py.File(path, 'r') as f:
            f.visititems(_find_name)
        if len(paths) == 1:
            return paths[0]
        elif len(paths) > 1:
            raise ValueError(
                "Found multiple posterior sample tables in '{}': {}. Not sure "
                "which to load.".format(
                    path, ", ".join(paths)
                )
            )
        else:
            raise ValueError(
                "Unable to find a posterior samples table in '{}'".format(path)
            )

    def generate_all_posterior_samples(self, **kwargs):
        """Empty function
        """
        pass

    def add_fixed_parameters_from_config_file(self, config_file):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples. The base class does nothing.

        Parameters
        ----------
        config_file: str
            path to the configuration file
        """
        pass

    def add_injection_parameters_from_file(self, injection_file, **kwargs):
        """Populate the 'injection_parameters' property with data from a file

        Parameters
        ----------
        injection_file: str
            path to injection file
        **kwargs: dict, optional
            all kwargs passed to `_grab_injection_parameters_from_file`
        """
        self.injection_parameters = self._grab_injection_parameters_from_file(
            injection_file, **kwargs
        )

    def _grab_injection_parameters_from_file(
        self, path, cls=None, add_nans=True, **kwargs
    ):
        """Extract data from an injection file

        Parameters
        ----------
        path: str
            path to injection file
        cls: class, optional
            class to read in injection file. The class must have a read class
            method and a samples_dict property. Default None which means that
            the pesummary.core.file.injection.Injection class is used
        add_nans: Bool, optional
            if True, every parameter stored in this result file which is
            missing from the injection file is added with value 'nan'.
            Default True
        **kwargs: dict, optional
            all kwargs passed to `cls.read`
        """
        if cls is None:
            from pesummary.core.file.injection import Injection
            cls = Injection
        data = cls.read(path, **kwargs).samples_dict
        if not add_nans:
            return data
        for i in self.parameters:
            if i not in data.keys():
                data[i] = float("nan")
        return data

    def write(
        self, package="core", file_format="dat", extra_kwargs=None,
        file_versions=None, **kwargs
    ):
        """Save the data to file

        Parameters
        ----------
        package: str, optional
            package you wish to use when writing the data
        file_format: str, optional
            format to write. When 'ini', the stored config (if any) is
            written instead of the samples. Default 'dat'
        extra_kwargs: dict, optional
            kwargs to store in the file. Default is the kwargs extracted from
            the result file
        file_versions: str, optional
            version information to store. Default is the version extracted
            from the result file
        kwargs: dict, optional
            all additional kwargs are passed to the pesummary.io.write function
        """
        from pesummary.io import write

        if file_format == "pesummary" and np.array(self.parameters).ndim > 1:
            args = [self.samples_dict]
        else:
            args = [self.parameters, self.samples]
        if extra_kwargs is None:
            extra_kwargs = self.extra_kwargs
        if file_versions is None:
            file_versions = self.input_version
        if file_format == "ini":
            kwargs["file_format"] = "ini"
            return write(getattr(self, "config", None), **kwargs)
        else:
            return write(
                *args, package=package, file_versions=file_versions,
                file_kwargs=extra_kwargs, file_format=file_format, **kwargs
            )

    def downsample(self, number):
        """Downsample the posterior samples stored in the result file

        Parameters
        ----------
        number: int
            number of posterior samples you wish to downsample to
        """
        self.samples, self.extra_kwargs = _downsample(
            self.samples, number, extra_kwargs=self.extra_kwargs
        )

    @staticmethod
    def latex_table(samples, parameter_dict=None, labels=None):
        """Return a latex table displaying the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        labels: list, optional
            column headings, one for each entry in `samples`
        """
        table = (
            "\\begin{table}[hptb]\n\\begin{ruledtabular}\n\\begin{tabular}"
            "{l %s}\n" % ("c " * len(samples))
        )
        if labels:
            table += (
                " & " + " & ".join(labels)
            )
        # "\\\\ " is the LaTeX row terminator; written with fully escaped
        # backslashes to avoid the invalid escape sequence "\ "
        table += "\\\\ \n\\hline \\\\ \n"
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not adding to "
                        "table".format(param)
                    )
                    data.pop(param)

        for param, desc in data.items():
            table += "{}".format(desc)
            for samples_dict in samples:
                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                table += (
                    " & $%s^{+%s}_{-%s}$" % (
                        np.round(median, 2),
                        np.round(confidence[1] - median, 2),
                        np.round(median - confidence[0], 2)
                    )
                )
            table += "\\\\ \n"
        table += (
            "\\end{tabular}\n\\end{ruledtabular}\n\\caption{}\n\\end{table}"
        )
        return table

    @staticmethod
    def latex_macros(
        samples, parameter_dict=None, labels=None, rounding="smart"
    ):
        """Return a set of latex macros for the passed data.

        Parameters
        ----------
        samples: list
            list of pesummary.utils.utils.SamplesDict objects
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        labels: list, optional
            one label for each entry in `samples`; appended to the macro name
        rounding: int, optional
            decimal place for rounding. Default uses the
            `pesummary.utils.utils.smart_round` function to round according to
            the uncertainty
        """
        macros = ""
        data = {i: i for i in samples[0].keys()}
        if parameter_dict is not None:
            import copy

            data = copy.deepcopy(parameter_dict)
            for param in parameter_dict.keys():
                if not all(param in samples_dict.keys() for samples_dict in samples):
                    logger.warning(
                        "{} not in list of parameters. Not generating "
                        "macro".format(param)
                    )
                    data.pop(param)
        for param, desc in data.items():
            for num, samples_dict in enumerate(samples):
                if labels:
                    description = "{}{}".format(desc, labels[num])
                else:
                    description = desc

                median = samples_dict[param].average(type="median")
                confidence = samples_dict[param].credible_interval()
                # measure the offsets from the unrounded median so that the
                # error bars are only rounded once
                upper = confidence[1] - median
                low = median - confidence[0]
                if rounding == "smart":
                    from pesummary.utils.utils import smart_round

                    median, upper, low = smart_round([median, upper, low])
                else:
                    median = np.round(median, rounding)
                    low = np.round(low, rounding)
                    upper = np.round(upper, rounding)
                macros += (
                    "\\def\\%s{$%s_{-%s}^{+%s}$}\n" % (
                        description, median, low, upper
                    )
                )
                macros += (
                    "\\def\\%smedian{$%s$}\n" % (description, median)
                )
                macros += (
                    "\\def\\%supper{$%s$}\n" % (
                        description, np.round(median + upper, 9)
                    )
                )
                macros += (
                    "\\def\\%slower{$%s$}\n" % (
                        description, np.round(median - low, 9)
                    )
                )
        return macros

578 

579 

class SingleAnalysisRead(Read):
    """Base class to read in a results file which contains a single analyses

    Parameters
    ----------
    path_to_results_file: str
        path to the results file you wish to load
    remove_nan_likelihood_samples: Bool, optional
        if True, remove samples which have log_likelihood='nan'. Default True

    Attributes
    ----------
    parameters: list
        list of parameters stored in the file
    samples: 2d list
        list of samples stored in the result file
    samples_dict: dict
        dictionary of samples stored in the result file
    input_version: str
        version of the result file passed
    extra_kwargs: dict
        dictionary of kwargs that were extracted from the result file

    Methods
    -------
    downsample:
        downsample the posterior samples stored in the result file
    to_dat:
        save the posterior samples to a .dat file
    to_latex_table:
        convert the posterior samples to a latex table
    generate_latex_macros:
        generate a set of latex macros for the stored posterior samples
    reweight_samples:
        reweight the posterior and/or samples according to a new prior
    """
    def __init__(self, *args, **kwargs):
        super(SingleAnalysisRead, self).__init__(*args, **kwargs)

    @property
    def samples_dict(self):
        """Samples keyed by parameter; an MCMCSamplesDict when the file
        stores mcmc chains, otherwise a SamplesDict
        """
        if self.mcmc_samples:
            return MCMCSamplesDict(
                self.parameters, [np.array(i).T for i in self.samples]
            )
        return SamplesDict(self.parameters, np.array(self.samples).T)

    @property
    def _default_version(self):
        return "No version information found"

    @property
    def _default_kwargs(self):
        _kwargs = {"sampler": {}, "meta_data": {}}
        _kwargs["sampler"]["nsamples"] = len(self.data["samples"])
        return _kwargs

    @property
    def _default_description(self):
        return "No description found"

    def _add_fixed_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add fixed parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        # NOTE(review): the result is stored under the integer keys 0 and 1
        # of self.data, not in self.parameters / self.samples - confirm that
        # this is intended
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def _add_marginalized_parameters_from_config_file(self, config_file, function):
        """Search the configuration file and add marginalized parameters to the
        list of parameters and samples

        Parameters
        ----------
        config_file: str
            path to the configuration file
        function: func
            function you wish to use to extract the information from the
            configuration file
        """
        # NOTE(review): the result is stored under the integer keys 0 and 1
        # of self.data, not in self.parameters / self.samples - confirm that
        # this is intended
        self.data[0], self.data[1] = function(self.parameters, self.samples, config_file)

    def to_latex_table(self, parameter_dict=None, save_to_file=None):
        """Make a latex table displaying the data in the result file.

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to include in the latex
            table. The keys are the name of the parameters and the items are
            the descriptive text. If None, all parameters are included
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout

        Raises
        ------
        FileExistsError
            if `save_to_file` already exists
        """
        import os

        # fail before doing any work rather than overwrite an existing file
        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        table = self.latex_table([self.samples_dict], parameter_dict)
        if save_to_file is None:
            print(table)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([table])

    def generate_latex_macros(
        self, parameter_dict=None, save_to_file=None, rounding="smart"
    ):
        """Generate a list of latex macros for each parameter in the result
        file

        Parameters
        ----------
        parameter_dict: dict, optional
            dictionary of parameters that you wish to generate macros for. The
            keys are the name of the parameters and the items are the latex
            macros name you wish to use. If None, all parameters are included.
        save_to_file: str, optional
            name of the file you wish to save the latex table to. If None, print
            to stdout
        rounding: int, optional
            number of decimal points to round the latex macros. Default
            'smart', which is passed through to `latex_macros`

        Raises
        ------
        FileExistsError
            if `save_to_file` already exists
        """
        import os

        if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
            raise FileExistsError(
                "The file {} already exists.".format(save_to_file)
            )

        macros = self.latex_macros(
            [self.samples_dict], parameter_dict, rounding=rounding
        )
        if save_to_file is None:
            print(macros)
        else:
            with open(save_to_file, "w") as f:
                f.writelines([macros])

    def to_dat(self, **kwargs):
        """Save the PESummary results file object to a dat file

        Parameters
        ----------
        kwargs: dict
            all kwargs passed to the pesummary.core.file.formats.dat.write_dat
            function
        """
        return self.write(file_format="dat", **kwargs)

    def reweight_samples(self, function, **kwargs):
        """Reweight the posterior and/or prior samples according to a new prior

        Parameters
        ----------
        function:
            passed to the `reweight` method of the stored samples and
            recorded in extra_kwargs["meta_data"]["reweighting"]
        **kwargs: dict, optional
            all kwargs passed to the `reweight` method

        Raises
        ------
        ValueError
            if the result file stores mcmc chains
        """
        if self.mcmc_samples:
            # raise (rather than return) so that the caller cannot miss it
            raise ValueError("Cannot currently reweight MCMC chains")
        _samples = self.samples_dict
        new_samples = _samples.reweight(function, **kwargs)
        self.parameters = Parameters(new_samples.parameters)
        self.samples = np.array(new_samples.samples).T
        self.extra_kwargs["sampler"].update(
            {
                "nsamples": new_samples.number_of_samples,
                "nsamples_before_reweighting": _samples.number_of_samples
            }
        )
        self.extra_kwargs["meta_data"]["reweighting"] = function
        # prior samples are reweighted only when they are available
        if not hasattr(self, "priors"):
            return
        if (self.priors is None) or ("samples" not in self.priors.keys()):
            return
        prior_samples = self.priors["samples"]
        if not len(prior_samples):
            return
        new_prior_samples = prior_samples.reweight(function, **kwargs)
        self.priors["samples"] = new_prior_samples

773 

774 

775class MultiAnalysisRead(Read): 

776 """Base class to read in a results file which contains multiple analyses 

777 

778 Parameters 

779 ---------- 

780 path_to_results_file: str 

781 path to the results file you wish to load 

782 remove_nan_likelihood_samples: Bool, optional 

783 if True, remove samples which have log_likelihood='nan'. Default True 

784 

785 Attributes 

786 ---------- 

787 parameters: 2d list 

788 list of parameters for each analysis 

789 samples: 3d list 

790 list of samples stored in the result file for each analysis 

791 samples_dict: dict 

792 dictionary of samples stored in the result file keyed by analysis label 

793 input_version: str 

794 version of the result files passed 

795 extra_kwargs: dict 

796 dictionary of kwargs that were extracted from the result file 

797 

798 Methods 

799 ------- 

800 samples_dict_for_label: dict 

801 dictionary of samples for a specific analysis 

802 reduced_samples_dict: dict 

803 dictionary of samples for one or more analyses 

804 downsample: 

805 downsample the posterior samples stored in the result file 

806 to_dat: 

807 save the posterior samples to a .dat file 

808 to_latex_table: 

809 convert the posterior samples to a latex table 

810 generate_latex_macros: 

811 generate a set of latex macros for the stored posterior samples 

812 reweight_samples: 

813 reweight the posterior and/or samples according to a new prior 

814 """ 

def __init__(self, *args, **kwargs):
    # all arguments are forwarded unchanged to the parent initialiser
    super(MultiAnalysisRead, self).__init__(*args, **kwargs)

817 

@staticmethod
def check_for_nan_likelihoods(parameters, samples, remove=False):
    """Apply `Read.check_for_nan_likelihoods` to each analysis in turn

    Parameters
    ----------
    parameters: 2d list
        list of parameters for each analysis
    samples: 3d list
        list of samples for each analysis
    remove: Bool, optional
        if True, remove samples with log_likelihood='nan' from samples

    Returns
    -------
    parameters, samples: tuple
        deep copies of the inputs with every analysis checked; the inputs
        themselves are not modified
    """
    from copy import deepcopy

    checked_parameters = deepcopy(parameters)
    checked_samples = deepcopy(samples)
    analyses = zip(checked_parameters, checked_samples)
    for position, (analysis_parameters, analysis_samples) in enumerate(analyses):
        cleaned = Read.check_for_nan_likelihoods(
            analysis_parameters, analysis_samples, remove=remove
        )
        checked_parameters[position], checked_samples[position] = cleaned
    return checked_parameters, checked_samples

828 

def samples_dict_for_label(self, label):
    """Return the posterior samples for a specific label

    Parameters
    ----------
    label: str
        label you wish to get posterior samples for

    Returns
    -------
    outdict: SamplesDict
        Returns a SamplesDict containing the requested posterior samples

    Raises
    ------
    ValueError
        if `label` is not one of the stored labels
    """
    if label in self.labels:
        position = self.labels.index(label)
        analysis_samples = np.array(self.samples[position]).T
        return SamplesDict(self.parameters[position], analysis_samples)
    raise ValueError("Unrecognised label: '{}'".format(label))

846 

def reduced_samples_dict(self, labels):
    """Return the posterior samples for one or more labels

    Parameters
    ----------
    labels: str, list
        label(s) you wish to get posterior samples for. Anything which is
        not a list is treated as a single label

    Returns
    -------
    outdict: MultiAnalysisSamplesDict
        Returns a MultiAnalysisSamplesDict containing the requested
        posterior samples

    Raises
    ------
    ValueError
        if any requested label is not one of the stored labels
    """
    requested = labels if isinstance(labels, list) else [labels]
    unknown = [name for name in requested if name not in self.labels]
    if unknown:
        raise ValueError(
            "Unrecognised label(s) '{}'. The list of available labels are "
            "{}.".format(", ".join(unknown), ", ".join(self.labels))
        )
    per_label = {}
    for name in requested:
        per_label[name] = self.samples_dict_for_label(name)
    return MultiAnalysisSamplesDict(per_label)

874 

@property
def samples_dict(self):
    """Samples stored in the result file

    When mcmc chains are stored, an MCMCSamplesDict built from the first
    entry of `parameters` and `samples`; otherwise the samples for every
    label as returned by `reduced_samples_dict`
    """
    if not self.mcmc_samples:
        return self.reduced_samples_dict(self.labels)
    chains = [np.array(chain).T for chain in self.samples[0]]
    return MCMCSamplesDict(self.parameters[0], chains)

884 

@property
def _default_version(self):
    """Placeholder version string, one for each entry in `parameters`"""
    placeholder = "No version information found"
    return [placeholder for _ in range(len(self.parameters))]

888 

@property
def _default_kwargs(self):
    """Default kwargs, one dictionary for each entry in `parameters`

    Each dictionary has an empty "meta_data" section and a "sampler"
    section whose "nsamples" is the length of the matching entry in
    data["samples"]
    """
    # build an independent dictionary per analysis. Multiplying a list
    # holding a single dictionary would repeat references to that one
    # dictionary, so every analysis would report the final nsamples
    _kwargs = [
        {"sampler": {}, "meta_data": {}} for _ in range(len(self.parameters))
    ]
    for num, ss in enumerate(self.data["samples"]):
        _kwargs[num]["sampler"]["nsamples"] = len(ss)
    return _kwargs

895 

@property
def _default_description(self):
    """Placeholder description keyed by each stored label"""
    return dict.fromkeys(self.labels, "No description found")

899 

def write(self, package="core", file_format="dat", **kwargs):
    """Save the data to file

    The kwargs and versions stored for each analysis are passed to the
    parent `write` keyed by label

    Parameters
    ----------
    package: str, optional
        package you wish to use when writing the data
    file_format: str, optional
        format to write. Default 'dat'
    kwargs: dict, optional
        all additional kwargs are passed to the pesummary.io.write function
    """
    parent_write = super(MultiAnalysisRead, self).write
    kwargs_by_label = self.kwargs_dict
    versions_by_label = self.version_dict
    return parent_write(
        package=package, file_format=file_format,
        extra_kwargs=kwargs_by_label, file_versions=versions_by_label,
        **kwargs
    )

915 

@property
def kwargs_dict(self):
    """Entries of `extra_kwargs` keyed by the label at the same position"""
    return dict(zip(self.labels, self.extra_kwargs))

921 

@property
def version_dict(self):
    """Entries of `input_version` keyed by the label at the same position"""
    return dict(zip(self.labels, self.input_version))

927 

def summary(self, *args, parameters_to_show=4, **kwargs):
    """Return a summary of the contents of the file

    The parent summary (without parameters or sample count) is followed by
    the list of analyses and one section per analysis. Additional
    positional and keyword arguments are accepted but not used.

    Parameters
    ----------
    parameters_to_show: int, optional
        number of parameters to show. Default 4
    """
    pieces = [
        super(MultiAnalysisRead, self).summary(
            show_parameters=False, show_nsamples=False
        ),
        "analyses: {}\n\n".format(", ".join(self.labels)),
    ]
    for position, label in enumerate(self.labels):
        pieces.append("{}\n".format(label))
        # underline the label with a matching number of dashes
        pieces.append("-" * len(label) + "\n")
        pieces.append("description: {}\n".format(self.description[label]))
        pieces.append("nsamples: {}\n".format(len(self.samples[position])))
        shown = self._parameter_summary(
            self.parameters[position], parameters_to_show=parameters_to_show
        )
        pieces.append("parameters: {}\n\n".format(shown))
    # drop the final two characters (the trailing blank line)
    return "".join(pieces)[:-2]

951 

def downsample(self, number, labels=None):
    """Downsample the posterior samples stored in the result file

    Parameters
    ----------
    number: int
        number of posterior samples you wish to downsample to
    labels: list, optional
        only analyses whose label is in `labels` are downsampled. Default
        None which means that every analysis is downsampled
    """
    for position, analysis_samples in enumerate(self.samples):
        wanted = labels is None or self.labels[position] in labels
        if wanted:
            self.samples[position], self.extra_kwargs[position] = _downsample(
                analysis_samples, number,
                extra_kwargs=self.extra_kwargs[position]
            )

961 

def to_latex_table(self, labels="all", parameter_dict=None, save_to_file=None):
    """Make a latex table displaying the data in the result file.

    Parameters
    ----------
    labels: str, list, optional
        label or list of labels that you want to include in the table.
        Default 'all' which means that every stored label is included
    parameter_dict: dict, optional
        dictionary of parameters that you wish to include in the latex
        table. The keys are the name of the parameters and the items are
        the descriptive text. If None, all parameters are included
    save_to_file: str, optional
        name of the file you wish to save the latex table to. If None, print
        to stdout

    Raises
    ------
    FileExistsError
        if `save_to_file` already exists
    ValueError
        if a requested label is not stored in the result file
    """
    import os

    # fail before doing any work rather than overwrite an existing file
    if save_to_file is not None and os.path.isfile("{}".format(save_to_file)):
        raise FileExistsError(
            "The file {} already exists.".format(save_to_file)
        )
    if labels != "all" and isinstance(labels, str) and labels not in self.labels:
        raise ValueError("The label %s does not exist." % (labels))
    elif labels == "all":
        labels = list(self.labels)
    elif isinstance(labels, str):
        labels = [labels]
    elif isinstance(labels, list):
        for ll in labels:
            if ll not in list(self.labels):
                raise ValueError("The label %s does not exist." % (ll))

    # evaluate the property once; it is rebuilt on every access
    samples_dict = self.samples_dict
    table = self.latex_table(
        [samples_dict[label] for label in labels], parameter_dict,
        labels=labels
    )
    if save_to_file is None:
        print(table)
    else:
        with open(save_to_file, "w") as f:
            f.writelines([table])

1008 

def generate_latex_macros(
    self, labels="all", parameter_dict=None, save_to_file=None,
    rounding=2
):
    """Generate a list of latex macros for each parameter in the result
    file

    Parameters
    ----------
    labels: str or list, optional
        label, or list of labels, that you want to generate macros for.
        Default "all", which includes every analysis in the file
    parameter_dict: dict, optional
        dictionary of parameters that you wish to generate macros for. The
        keys are the name of the parameters and the items are the latex
        macros name you wish to use. If None, all parameters are included.
    save_to_file: str, optional
        name of the file you wish to save the latex macros to. If None,
        print to stdout
    rounding: int, optional
        number of decimal points to round the latex macros

    Raises
    ------
    FileExistsError
        if `save_to_file` points to a file that already exists
    ValueError
        if a requested label is not stored in the result file
    """
    import os

    target_exists = (
        save_to_file is not None and os.path.isfile(str(save_to_file))
    )
    if target_exists:
        raise FileExistsError(
            "The file {} already exists.".format(save_to_file)
        )

    if labels == "all":
        labels = list(self.labels)
    elif isinstance(labels, str):
        if labels not in self.labels:
            raise ValueError("The label %s does not exist." % (labels))
        labels = [labels]
    elif isinstance(labels, list):
        for requested in labels:
            if requested not in list(self.labels):
                raise ValueError("The label %s does not exist." % (requested))

    selected_samples = [self.samples_dict[i] for i in labels]
    macros = self.latex_macros(
        selected_samples, parameter_dict, labels=labels, rounding=rounding
    )
    if save_to_file is None:
        print(macros)
        return
    with open(save_to_file, "w") as f:
        f.write(macros)

1056 

def reweight_samples(self, function, labels=None, **kwargs):
    """Reweight the posterior and/or prior samples according to a new prior

    For every selected analysis the stored parameters, samples and sampler
    metadata are replaced in place by the reweighted values. When prior
    samples are stored for the analysis they are reweighted as well.

    Parameters
    ----------
    function:
        passed unchanged as the first argument to the `reweight` method of
        the samples of each analysis. It is also recorded under the
        "reweighting" key of the analysis meta data
    labels: list, optional
        list of analyses you wish to reweight. Default reweight all
        analyses
    **kwargs: dict
        all additional kwargs are forwarded to the `reweight` method
    """
    posterior = self.samples_dict
    for idx, label in enumerate(self.labels):
        if labels is not None and label not in labels:
            continue
        reweighted = posterior[label].reweight(function, **kwargs)
        self.parameters[idx] = Parameters(reweighted.parameters)
        self.samples[idx] = np.array(reweighted.samples).T
        sampler_info = {
            "nsamples": reweighted.number_of_samples,
            # number of samples held before the reweighting was applied
            "nsamples_before_reweighting": posterior[label].number_of_samples,
        }
        self.extra_kwargs[idx]["sampler"].update(sampler_info)
        self.extra_kwargs[idx]["meta_data"]["reweighting"] = function
        # only continue to the priors when prior samples are available
        if not hasattr(self, "priors") or "samples" not in self.priors.keys():
            continue
        prior_samples = self.priors["samples"][label]
        if len(prior_samples):
            self.priors["samples"][label] = prior_samples.reweight(
                function, **kwargs
            )

1090 self.priors["samples"][label] = new_prior_samples