Coverage for pesummary/utils/samples_dict.py: 67.0%

636 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import copy 

4import numpy as np 

5from pesummary.utils.utils import resample_posterior_distribution, logger 

6from pesummary.utils.decorators import docstring_subfunction 

7from pesummary.utils.array import Array, _2DArray 

8from pesummary.utils.dict import Dict 

9from pesummary.utils.parameters import Parameters 

10from pesummary.core.plots.latex_labels import latex_labels 

11from pesummary.gw.plots.latex_labels import GWlatex_labels 

12from pesummary import conf 

13import importlib 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17latex_labels.update(GWlatex_labels) 

18 

19 

class SamplesDict(Dict):
    """Class to store the samples from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter
    autoscale: Bool, optional
        If True, the posterior samples for each parameter are scaled to the
        same length

    Attributes
    ----------
    maxL: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the maximum likelihood sample keyed by
        the parameter
    minimum: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the minimum sample for each parameter
    maximum: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the maximum sample for each parameter
    median: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the median of each marginalized
        posterior distribution
    mean: pesummary.utils.samples_dict.SamplesDict
        SamplesDict object containing the mean of each marginalized posterior
        distribution
    key_data: dict
        dictionary containing the key data associated with each array
    number_of_samples: int
        Number of samples stored in the SamplesDict object
    latex_labels: dict
        Dictionary of latex labels for each parameter
    available_plots: list
        list of plots which the user may use to display the contained
        posterior samples

    Methods
    -------
    from_file:
        Initialize the SamplesDict class with the contents of a file
    to_pandas:
        Convert the SamplesDict object to a pandas DataFrame
    to_structured_array:
        Convert the SamplesDict object to a numpy structured array
    pop:
        Remove an entry from the SamplesDict object
    standardize_parameter_names:
        Modify keys in SamplesDict to use standard PESummary names
    downsample:
        Downsample the samples stored in the SamplesDict object. See the
        pesummary.utils.utils.resample_posterior_distribution method
    discard_samples:
        Remove the first N samples from each distribution
    plot:
        Generate a plot based on the posterior samples stored
    generate_all_posterior_samples:
        Convert the posterior samples in the SamplesDict object according to
        a conversion function
    debug_keys: list
        list of keys with an '_' as their first character
    reweight:
        Reweight the posterior samples according to a new prior
    write:
        Save the stored posterior samples to file

    Examples
    --------
    How to initialize the SamplesDict class

    >>> from pesummary.utils.samples_dict import SamplesDict
    >>> data = {
    ...     "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...     "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ... }
    >>> dataset = SamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...     [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ... ]
    >>> dataset = SamplesDict(parameters, samples)
    >>> fig = dataset.plot("a", type="hist", bins=30)
    >>> fig.show()
    """

    def __init__(self, *args, logger_warn="warn", autoscale=True):
        """Initialize the SamplesDict object

        Parameters
        ----------
        *args: tuple
            either a dictionary of samples keyed by parameter, or a list of
            parameters and an nd list of samples; forwarded to the Dict base
            class
        logger_warn: str, optional
            name of the logger method used when warning the user, e.g.
            'warn' or 'info'. Default 'warn'
        autoscale: Bool, optional
            if True, the posterior samples for each parameter are scaled to
            the same length. Default True
        """
        super(SamplesDict, self).__init__(
            *args, value_class=Array, make_dict_kwargs={"autoscale": autoscale},
            logger_warn=logger_warn, latex_labels=latex_labels
        )

111 

112 def __getitem__(self, key): 

113 """Return an object representing the specialization of SamplesDict 

114 by type arguments found in key. 

115 """ 

116 if isinstance(key, slice): 

117 return SamplesDict( 

118 self.parameters, np.array( 

119 [i[key.start:key.stop:key.step] for i in self.samples] 

120 ) 

121 ) 

122 elif isinstance(key, (list, np.ndarray)): 

123 return SamplesDict( 

124 self.parameters, np.array([i[key] for i in self.samples]) 

125 ) 

126 elif key[0] == "_": 

127 return self.samples[self.parameters.index(key)] 

128 return super(SamplesDict, self).__getitem__(key) 

129 

    def __setitem__(self, key, value):
        """Store a new set of posterior samples under `key`

        Parameters
        ----------
        key: str
            name of the parameter being added/updated
        value: list/np.ndarray
            posterior samples to store; cast to an Array if required
        """
        _value = value
        if not isinstance(value, Array):
            _value = Array(value)
        super(SamplesDict, self).__setitem__(key, _value)
        try:
            # keep self.parameters/self.samples in sync with the new entry
            if key not in self.parameters:
                self.parameters.append(key)
                try:
                    # a 1d numeric array means only a single parameter's
                    # samples are currently stored
                    cond = (
                        np.array(self.samples).ndim == 1 and isinstance(
                            self.samples[0], (float, int, np.number)
                        )
                    )
                except Exception:
                    cond = False
                if cond and isinstance(self.samples, np.ndarray):
                    self.samples = np.append(self.samples, value)
                elif cond and isinstance(self.samples, list):
                    self.samples.append(value)
                else:
                    self.samples = np.vstack([self.samples, value])
            self._update_latex_labels()
        except (AttributeError, TypeError):
            # self.parameters/self.samples may not exist yet, e.g. while the
            # object is still being constructed
            pass

155 

156 def __str__(self): 

157 """Print a summary of the information stored in the dictionary 

158 """ 

159 def format_string(string, row): 

160 """Format a list into a table 

161 

162 Parameters 

163 ---------- 

164 string: str 

165 existing table 

166 row: list 

167 the row you wish to be written to a table 

168 """ 

169 string += "{:<8}".format(row[0]) 

170 for i in range(1, len(row)): 

171 if isinstance(row[i], str): 

172 string += "{:<15}".format(row[i]) 

173 elif isinstance(row[i], (float, int, np.int64, np.int32)): 

174 string += "{:<15.6f}".format(row[i]) 

175 string += "\n" 

176 return string 

177 

178 string = "" 

179 string = format_string(string, ["idx"] + list(self.keys())) 

180 

181 if self.number_of_samples < 8: 

182 for i in range(self.number_of_samples): 

183 string = format_string( 

184 string, [i] + [item[i] for key, item in self.items()] 

185 ) 

186 else: 

187 for i in range(4): 

188 string = format_string( 

189 string, [i] + [item[i] for key, item in self.items()] 

190 ) 

191 for i in range(2): 

192 string = format_string(string, ["."] * (len(self.keys()) + 1)) 

193 for i in range(self.number_of_samples - 2, self.number_of_samples): 

194 string = format_string( 

195 string, [i] + [item[i] for key, item in self.items()] 

196 ) 

197 return string 

198 

    @classmethod
    def from_file(cls, filename, **kwargs):
        """Initialize the SamplesDict class with the contents of a result file

        Parameters
        ----------
        filename: str
            path to the result file you wish to load.
        **kwargs: dict
            all kwargs are passed to the pesummary.io.read function

        Returns
        -------
        SamplesDict
            the posterior samples extracted from the result file
        """
        from pesummary.io import read

        return read(filename, **kwargs).samples_dict

213 

214 @property 

215 def key_data(self): 

216 return {param: value.key_data for param, value in self.items()} 

217 

218 @property 

219 def maxL(self): 

220 return SamplesDict( 

221 self.parameters, [[item.maxL] for key, item in self.items()] 

222 ) 

223 

224 @property 

225 def minimum(self): 

226 return SamplesDict( 

227 self.parameters, [[item.minimum] for key, item in self.items()] 

228 ) 

229 

230 @property 

231 def maximum(self): 

232 return SamplesDict( 

233 self.parameters, [[item.maximum] for key, item in self.items()] 

234 ) 

235 

236 @property 

237 def median(self): 

238 return SamplesDict( 

239 self.parameters, 

240 [[item.average(type="median")] for key, item in self.items()] 

241 ) 

242 

243 @property 

244 def mean(self): 

245 return SamplesDict( 

246 self.parameters, 

247 [[item.average(type="mean")] for key, item in self.items()] 

248 ) 

249 

250 @property 

251 def number_of_samples(self): 

252 return len(self[self.parameters[0]]) 

253 

254 @property 

255 def plotting_map(self): 

256 existing = super(SamplesDict, self).plotting_map 

257 modified = existing.copy() 

258 modified.update( 

259 { 

260 "marginalized_posterior": self._marginalized_posterior, 

261 "skymap": self._skymap, 

262 "hist": self._marginalized_posterior, 

263 "corner": self._corner, 

264 "spin_disk": self._spin_disk, 

265 "2d_kde": self._2d_kde, 

266 "triangle": self._triangle, 

267 "reverse_triangle": self._reverse_triangle, 

268 } 

269 ) 

270 return modified 

271 

272 def standardize_parameter_names(self, mapping=None): 

273 """Modify keys in SamplesDict to use standard PESummary names 

274 

275 Parameters 

276 ---------- 

277 mapping: dict, optional 

278 dictionary mapping existing keys to standard PESummary names. 

279 Default pesummary.gw.file.standard_names.standard_names 

280 

281 Returns 

282 ------- 

283 standard_dict: SamplesDict 

284 SamplesDict object with standard PESummary parameter names 

285 """ 

286 from pesummary.utils.utils import map_parameter_names 

287 if mapping is None: 

288 from pesummary.gw.file.standard_names import standard_names 

289 mapping = standard_names 

290 return SamplesDict(map_parameter_names(self, mapping)) 

291 

292 def debug_keys(self, *args, **kwargs): 

293 _keys = self.keys() 

294 _total = self.keys(remove_debug=False) 

295 return Parameters([key for key in _total if key not in _keys]) 

296 

297 def keys(self, *args, remove_debug=True, **kwargs): 

298 original = super(SamplesDict, self).keys(*args, **kwargs) 

299 if remove_debug: 

300 return Parameters([key for key in original if key[0] != "_"]) 

301 return Parameters(original) 

302 

303 def write(self, **kwargs): 

304 """Save the stored posterior samples to file 

305 

306 Parameters 

307 ---------- 

308 **kwargs: dict, optional 

309 all additional kwargs passed to the pesummary.io.write function 

310 """ 

311 from pesummary.io import write 

312 write(self.parameters, self.samples.T, **kwargs) 

313 

314 def items(self, *args, remove_debug=True, **kwargs): 

315 items = super(SamplesDict, self).items(*args, **kwargs) 

316 if remove_debug: 

317 return [item for item in items if item[0][0] != "_"] 

318 return items 

319 

320 def to_pandas(self, **kwargs): 

321 """Convert a SamplesDict object to a pandas dataframe 

322 """ 

323 from pandas import DataFrame 

324 

325 return DataFrame(self, **kwargs) 

326 

327 def to_structured_array(self, **kwargs): 

328 """Convert a SamplesDict object to a structured numpy array 

329 """ 

330 return self.to_pandas(**kwargs).to_records( 

331 index=False, column_dtypes=float 

332 ) 

333 

334 def pop(self, parameter): 

335 """Delete a parameter from the SamplesDict 

336 

337 Parameters 

338 ---------- 

339 parameter: str 

340 name of the parameter you wish to remove from the SamplesDict 

341 """ 

342 if parameter not in self.parameters: 

343 logger.info( 

344 "{} not in SamplesDict. Unable to remove {}".format( 

345 parameter, parameter 

346 ) 

347 ) 

348 return 

349 ind = self.parameters.index(parameter) 

350 self.parameters.remove(parameter) 

351 samples = self.samples 

352 self.samples = np.delete(samples, ind, axis=0) 

353 return super(SamplesDict, self).pop(parameter) 

354 

355 def downsample(self, number): 

356 """Downsample the samples stored in the SamplesDict class 

357 

358 Parameters 

359 ---------- 

360 number: int 

361 Number of samples you wish to downsample to 

362 """ 

363 self.samples = resample_posterior_distribution(self.samples, number) 

364 self.make_dictionary() 

365 return self 

366 

367 def discard_samples(self, number): 

368 """Remove the first n samples 

369 

370 Parameters 

371 ---------- 

372 number: int 

373 Number of samples that you wish to remove 

374 """ 

375 self.make_dictionary(discard_samples=number) 

376 return self 

377 

    def make_dictionary(self, discard_samples=None, autoscale=True):
        """Add the parameters and samples to the class

        Parameters
        ----------
        discard_samples: int, optional
            number of samples to discard from the start of each posterior
            distribution. Default None
        autoscale: Bool, optional
            if True and the parameters have unequal numbers of samples,
            truncate every posterior to the length of the shortest one.
            Default True
        """
        lengths = [len(i) for i in self.samples]
        if len(np.unique(lengths)) > 1 and autoscale:
            nsamples = np.min(lengths)
            getattr(logger, self.logger_warn)(
                "Unequal number of samples for each parameter. "
                "Restricting all posterior samples to have {} "
                "samples".format(nsamples)
            )
            self.samples = [
                dataset[:nsamples] for dataset in self.samples
            ]
        # pull out the likelihood, prior and weights columns (when present)
        # so they can be attached to every parameter's Array below
        if "log_likelihood" in self.parameters:
            likelihoods = self.samples[self.parameters.index("log_likelihood")]
            likelihoods = likelihoods[discard_samples:]
        else:
            likelihoods = None
        if "log_prior" in self.parameters:
            priors = self.samples[self.parameters.index("log_prior")]
            priors = priors[discard_samples:]
        else:
            priors = None
        if any(i in self.parameters for i in ["weights", "weight"]):
            # either spelling may be present; prefer "weights"
            ind = (
                self.parameters.index("weights") if "weights" in self.parameters
                else self.parameters.index("weight")
            )
            weights = self.samples[ind][discard_samples:]
        else:
            weights = None
        # slicing with discard_samples=None keeps everything
        _2d_array = _2DArray(
            np.array(self.samples)[:, discard_samples:], likelihood=likelihoods,
            prior=priors, weights=weights
        )
        for key, val in zip(self.parameters, _2d_array):
            self[key] = val

416 

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._ligo_skymap_plot',
        'pesummary.gw.plots.publication.spin_distribution_plots',
        'pesummary.core.plots.plot._make_corner_plot',
        'pesummary.gw.plots.plot._make_corner_plot'
    ])
    def plot(self, *args, type="marginalized_posterior", **kwargs):
        """Generate a plot for the posterior samples stored in SamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make; must be a key of
            self.plotting_map
        **kwargs: dict
            all additional kwargs are passed to the plotting function
        """
        return super(SamplesDict, self).plot(*args, type=type, **kwargs)

438 

    def generate_all_posterior_samples(self, function=None, **kwargs):
        """Convert samples stored in the SamplesDict according to a conversion
        function

        Parameters
        ----------
        function: func, optional
            function to use when converting posterior samples. Must take a
            dictionary as input and return a dictionary of converted posterior
            samples. Default `pesummary.gw.conversions.convert`
        **kwargs: dict, optional
            All additional kwargs passed to function

        Returns
        -------
        extra kwargs returned by the conversion function when
        `return_kwargs=True` is passed, otherwise None
        """
        if function is None:
            from pesummary.gw.conversions import convert
            function = convert
        _samples = self.copy()
        _keys = list(_samples.keys())
        # always ask the conversion function for a dictionary back
        kwargs.update({"return_dict": True})
        out = function(_samples, **kwargs)
        if kwargs.get("return_kwargs", False):
            converted_samples, extra_kwargs = out
        else:
            converted_samples, extra_kwargs = out, None
        # only store parameters that were not already present
        for key, item in converted_samples.items():
            if key not in _keys:
                self[key] = item
        return extra_kwargs

467 

468 def reweight( 

469 self, function, ignore_debug_params=["recalib", "spcal"], **kwargs 

470 ): 

471 """Reweight the posterior samples according to a new prior 

472 

473 Parameters 

474 ---------- 

475 function: func/str 

476 function to use when resampling 

477 ignore_debug_params: list, optional 

478 params to ignore when storing unweighted posterior distributions. 

479 Default any param with ['recalib', 'spcal'] in their name 

480 """ 

481 from pesummary.gw.reweight import options 

482 if isinstance(function, str) and function in options.keys(): 

483 function = options[function] 

484 elif isinstance(function, str): 

485 raise ValueError( 

486 "Unknown function '{}'. Please provide a function for " 

487 "reweighting or select one of the following: {}".format( 

488 function, ", ".join(list(options.keys())) 

489 ) 

490 ) 

491 _samples = SamplesDict(self.copy()) 

492 new_samples = function(_samples, **kwargs) 

493 _samples.downsample(new_samples.number_of_samples) 

494 for key, item in new_samples.items(): 

495 if not any(param in key for param in ignore_debug_params): 

496 _samples["_{}_non_reweighted".format(key)] = _samples[key] 

497 _samples[key] = item 

498 return SamplesDict(_samples) 

499 

500 def _marginalized_posterior(self, parameter, module="core", **kwargs): 

501 """Wrapper for the `pesummary.core.plots.plot._1d_histogram_plot` or 

502 `pesummary.gw.plots.plot._1d_histogram_plot` 

503 

504 Parameters 

505 ---------- 

506 parameter: str 

507 name of the parameter you wish to plot 

508 module: str, optional 

509 module you wish to use for the plotting 

510 **kwargs: dict 

511 all additional kwargs are passed to the `_1d_histogram_plot` 

512 function 

513 """ 

514 module = importlib.import_module( 

515 "pesummary.{}.plots.plot".format(module) 

516 ) 

517 return getattr(module, "_1d_histogram_plot")( 

518 parameter, self[parameter], self.latex_labels[parameter], 

519 weights=self[parameter].weights, **kwargs 

520 ) 

521 

522 def _skymap(self, **kwargs): 

523 """Wrapper for the `pesummary.gw.plots.plot._ligo_skymap_plot` 

524 function 

525 

526 Parameters 

527 ---------- 

528 **kwargs: dict 

529 All kwargs are passed to the `_ligo_skymap_plot` function 

530 """ 

531 from pesummary.gw.plots.plot import _ligo_skymap_plot 

532 

533 if "luminosity_distance" in self.keys(): 

534 dist = self["luminosity_distance"] 

535 else: 

536 dist = None 

537 

538 return _ligo_skymap_plot(self["ra"], self["dec"], dist=dist, **kwargs) 

539 

540 def _spin_disk(self, **kwargs): 

541 """Wrapper for the `pesummary.gw.plots.publication.spin_distribution_plots` 

542 function 

543 """ 

544 from pesummary.gw.plots.publication import spin_distribution_plots 

545 

546 required = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"] 

547 if not all(param in self.keys() for param in required): 

548 raise ValueError( 

549 "The spin disk plot requires samples for the following " 

550 "parameters: {}".format(", ".join(required)) 

551 ) 

552 samples = [self[param] for param in required] 

553 return spin_distribution_plots(required, samples, None, **kwargs) 

554 

555 def _corner(self, module="core", parameters=None, **kwargs): 

556 """Wrapper for the `pesummary.core.plots.plot._make_corner_plot` or 

557 `pesummary.gw.plots.plot._make_corner_plot` function 

558 

559 Parameters 

560 ---------- 

561 module: str, optional 

562 module you wish to use for the plotting 

563 **kwargs: dict 

564 all additional kwargs are passed to the `_make_corner_plot` 

565 function 

566 """ 

567 module = importlib.import_module( 

568 "pesummary.{}.plots.plot".format(module) 

569 ) 

570 _parameters = None 

571 if parameters is not None: 

572 _parameters = [param for param in parameters if param in self.keys()] 

573 if not len(_parameters): 

574 raise ValueError( 

575 "None of the chosen parameters are in the posterior " 

576 "samples table. Please choose other parameters to plot" 

577 ) 

578 return getattr(module, "_make_corner_plot")( 

579 self, self.latex_labels, corner_parameters=_parameters, **kwargs 

580 )[0] 

581 

582 def _2d_kde(self, parameters, module="core", **kwargs): 

583 """Wrapper for the `pesummary.gw.plots.publication.twod_contour_plot` or 

584 `pesummary.core.plots.publication.twod_contour_plot` function 

585 

586 Parameters 

587 ---------- 

588 parameters: list 

589 list of length 2 giving the parameters you wish to plot 

590 module: str, optional 

591 module you wish to use for the plotting 

592 **kwargs: dict, optional 

593 all additional kwargs are passed to the `twod_contour_plot` function 

594 """ 

595 _module = importlib.import_module( 

596 "pesummary.{}.plots.publication".format(module) 

597 ) 

598 if module == "gw": 

599 return getattr(_module, "twod_contour_plots")( 

600 parameters, [[self[parameters[0]], self[parameters[1]]]], 

601 [None], { 

602 parameters[0]: self.latex_labels[parameters[0]], 

603 parameters[1]: self.latex_labels[parameters[1]] 

604 }, **kwargs 

605 ) 

606 return getattr(_module, "twod_contour_plot")( 

607 self[parameters[0]], self[parameters[1]], 

608 xlabel=self.latex_labels[parameters[0]], 

609 ylabel=self.latex_labels[parameters[1]], **kwargs 

610 ) 

611 

612 def _triangle(self, parameters, module="core", **kwargs): 

613 """Wrapper for the `pesummary.core.plots.publication.triangle_plot` 

614 function 

615 

616 Parameters 

617 ---------- 

618 parameters: list 

619 list of parameters they wish to study 

620 **kwargs: dict 

621 all additional kwargs are passed to the `triangle_plot` function 

622 """ 

623 _module = importlib.import_module( 

624 "pesummary.{}.plots.publication".format(module) 

625 ) 

626 if module == "gw": 

627 kwargs["parameters"] = parameters 

628 return getattr(_module, "triangle_plot")( 

629 [self[parameters[0]]], [self[parameters[1]]], 

630 xlabel=self.latex_labels[parameters[0]], 

631 ylabel=self.latex_labels[parameters[1]], **kwargs 

632 ) 

633 

634 def _reverse_triangle(self, parameters, module="core", **kwargs): 

635 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot` 

636 function 

637 

638 Parameters 

639 ---------- 

640 parameters: list 

641 list of parameters they wish to study 

642 **kwargs: dict 

643 all additional kwargs are passed to the `triangle_plot` function 

644 """ 

645 _module = importlib.import_module( 

646 "pesummary.{}.plots.publication".format(module) 

647 ) 

648 if module == "gw": 

649 kwargs["parameters"] = parameters 

650 return getattr(_module, "reverse_triangle_plot")( 

651 [self[parameters[0]]], [self[parameters[1]]], 

652 xlabel=self.latex_labels[parameters[0]], 

653 ylabel=self.latex_labels[parameters[1]], **kwargs 

654 ) 

655 

656 def classification(self, dual=True, population=False): 

657 """Return the classification probabilities 

658 

659 Parameters 

660 ---------- 

661 dual: Bool, optional 

662 if True, return classification probabilities generated from the 

663 raw samples ('default') an samples reweighted to a population 

664 inferred prior ('population'). Default True. 

665 population: Bool, optional 

666 if True, reweight the samples to a population informed prior and 

667 then calculate classification probabilities. Default False. Only 

668 used when dual=False 

669 """ 

670 from pesummary.gw.classification import Classify 

671 if dual: 

672 probs = Classify(self).dual_classification() 

673 else: 

674 probs = Classify(self).classification(population=population) 

675 return probs 

676 

    def _waveform_args(self, f_ref=20., ind=0, longAscNodes=0., eccentricity=0.):
        """Arguments to be passed to waveform generation

        Parameters
        ----------
        f_ref: float, optional
            reference frequency to use when converting spherical spins to
            cartesian spins
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.

        Returns
        -------
        tuple
            (waveform_args, _samples) where waveform_args is the list of
            positional arguments built from the sample at index `ind` and
            _samples is the dictionary of single-sample values
        """
        from lal import MSUN_SI, PC_SI

        # pick out a single sample for every stored parameter
        _samples = {key: value[ind] for key, value in self.items()}
        required = [
            "mass_1", "mass_2", "luminosity_distance"
        ]
        if not all(param in _samples.keys() for param in required):
            raise ValueError(
                "Unable to generate a waveform. Please add samples for "
                + ", ".join(required)
            )
        # masses are converted from solar masses to SI units
        waveform_args = [
            _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI
        ]
        spin_angles = [
            "theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2",
            "phase"
        ]
        spin_angles_condition = all(
            spin in _samples.keys() for spin in spin_angles
        )
        cartesian_spins = [
            "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
        ]
        cartesian_spins_condition = any(
            spin in _samples.keys() for spin in cartesian_spins
        )
        if spin_angles_condition and not cartesian_spins_condition:
            # only spherical spin angles are available: convert them to
            # cartesian component spins at the reference frequency
            from pesummary.gw.conversions import component_spins
            data = component_spins(
                _samples["theta_jn"], _samples["phi_jl"], _samples["tilt_1"],
                _samples["tilt_2"], _samples["phi_12"], _samples["a_1"],
                _samples["a_2"], _samples["mass_1"], _samples["mass_2"],
                f_ref, _samples["phase"]
            )
            iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = data.T
            spins = [spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z]
        else:
            # cartesian spins (or no spins at all): missing components
            # default to 0.
            iota = _samples["iota"]
            spins = [
                _samples[param] if param in _samples.keys() else 0. for param in
                ["spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"]
            ]
        waveform_args += spins
        phase = _samples["phase"] if "phase" in _samples.keys() else 0.
        # distance is converted from Mpc to SI units
        waveform_args += [
            _samples["luminosity_distance"] * PC_SI * 10**6, iota, phase
        ]
        waveform_args += [longAscNodes, eccentricity, 0.]
        return waveform_args, _samples

743 

    def antenna_response(self, ifo):
        """Wrapper for the `pesummary.gw.waveform.antenna_response` function

        Parameters
        ----------
        ifo: str
            name of the detector you wish to calculate the antenna response
            for
        """
        from pesummary.gw.waveform import antenna_response
        return antenna_response(self, ifo)

749 

750 def _project_waveform(self, ifo, hp, hc, ra, dec, psi, time): 

751 """Project a waveform onto a given detector 

752 

753 Parameters 

754 ---------- 

755 ifo: str 

756 name of the detector you wish to project the waveform onto 

757 hp: np.ndarray 

758 plus gravitational wave polarization 

759 hc: np.ndarray 

760 cross gravitational wave polarization 

761 ra: float 

762 right ascension to be passed to antenna response function 

763 dec: float 

764 declination to be passed to antenna response function 

765 psi: float 

766 polarization to be passed to antenna response function 

767 time: float 

768 time to be passed to antenna response function 

769 """ 

770 import importlib 

771 

772 mod = importlib.import_module("pesummary.gw.plots.plot") 

773 func = getattr(mod, "__antenna_response") 

774 antenna = func(ifo, ra, dec, psi, time) 

775 ht = hp * antenna[0] + hc * antenna[1] 

776 return ht 

777 

    def fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs):
        """Generate a gravitational wave in the frequency domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.frequencyseries.FrequencySeries
            object
        """
        from pesummary.gw.waveform import fd_waveform
        return fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs)

    def td_waveform(
        self, approximant, delta_t, f_low, **kwargs
    ):
        """Generate a gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between time samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.timeseries.TimeSeries object
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        from pesummary.gw.waveform import td_waveform
        return td_waveform(
            self, approximant, delta_t, f_low, **kwargs
        )

849 

850 def _maxL_waveform(self, func, *args, **kwargs): 

851 """Return the maximum likelihood waveform in a given domain 

852 

853 Parameters 

854 ---------- 

855 func: function 

856 function you wish to use when generating the maximum likelihood 

857 waveform 

858 *args: tuple 

859 all args passed to func 

860 **kwargs: dict 

861 all kwargs passed to func 

862 """ 

863 ind = np.argmax(self["log_likelihood"]) 

864 kwargs["ind"] = ind 

865 return func(*args, **kwargs) 

866 

    def maxL_td_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between time samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        return self._maxL_waveform(self.td_waveform, *args, **kwargs)

    def maxL_fd_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the frequency
        domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        """
        return self._maxL_waveform(self.fd_waveform, *args, **kwargs)

924 

925 

class _MultiDimensionalSamplesDict(Dict):
    """Class to store multiple SamplesDict objects

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    label_prefix: str, optional
        prefix to use when distinguishing different analyses. The label is then
        '{label_prefix}_{num}' where num is the result file index. Default
        is 'dataset'
    transpose: Bool, optional
        True if the input is a transposed dictionary
    labels: list, optional
        the labels to use to distinguish different analyses. If provided
        label_prefix is ignored

    Attributes
    ----------
    T: pesummary.utils.samples_dict._MultiDimensionalSamplesDict
        Transposed _MultiDimensionalSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the _MultiDimensionalSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the _MultiDimensionalSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis

    Methods
    -------
    samples:
        Return a list of samples stored in the _MultiDimensionalSamplesDict
        object for a given parameter
    """
    def __init__(
        self, *args, label_prefix="dataset", transpose=False, labels=None
    ):
        # labels key the stored SamplesDict objects, so they must be unique
        if labels is not None and len(np.unique(labels)) != len(labels):
            raise ValueError(
                "Please provide a unique set of labels for each analysis"
            )
        invalid_label_number_error = "Please provide a label for each analysis"
        self.labels = labels
        # subclasses overwrite self.name so that derived objects (e.g. the
        # transpose returned by the T property) keep the subclass type
        self.name = _MultiDimensionalSamplesDict
        self.transpose = transpose
        if len(args) == 1 and isinstance(args[0], dict):
            # initialised from a (possibly transposed) nested dictionary
            if transpose:
                # outer keys are parameters, inner keys are analysis labels
                parameters = list(args[0].keys())
                _labels = list(args[0][parameters[0]].keys())
                outer_iterator, inner_iterator = parameters, _labels
            else:
                # outer keys are analysis labels, inner keys are parameters
                _labels = list(args[0].keys())
                parameters = {
                    label: list(args[0][label].keys()) for label in _labels
                }
                outer_iterator, inner_iterator = _labels, parameters
            if labels is None:
                self.labels = _labels
            for num, dataset in enumerate(outer_iterator):
                if isinstance(inner_iterator, dict):
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]]
                        )
                    except ValueError:  # numpy deprecation error
                        # ragged sample lengths need an explicit object dtype
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator[dataset]],
                            dtype=object
                        )
                else:
                    try:
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator]
                        )
                    except ValueError:  # numpy deprecation error
                        samples = np.array(
                            [args[0][dataset][param] for param in inner_iterator],
                            dtype=object
                        )
                if transpose:
                    desc = parameters[num]
                    self[desc] = SamplesDict(
                        self.labels, samples, logger_warn="debug",
                        autoscale=False
                    )
                else:
                    if self.labels is not None:
                        desc = self.labels[num]
                    else:
                        desc = "{}_{}".format(label_prefix, num)
                    self[desc] = SamplesDict(parameters[self.labels[num]], samples)
        else:
            # initialised from (parameters, samples) positional arguments
            parameters, samples = args
            if labels is not None and len(labels) != len(samples):
                raise ValueError(invalid_label_number_error)
            for num, dataset in enumerate(samples):
                if labels is not None:
                    desc = labels[num]
                else:
                    desc = "{}_{}".format(label_prefix, num)
                self[desc] = SamplesDict(parameters, dataset)
            if self.labels is None:
                self.labels = [
                    "{}_{}".format(label_prefix, num) for num, _ in
                    enumerate(samples)
                ]
        self.parameters = parameters
        self._update_latex_labels()

    def _update_latex_labels(self):
        """Update the stored latex labels
        """
        _parameters = [
            list(value.keys()) for value in self.values()
        ]
        _parameters = [item for sublist in _parameters for item in sublist]
        # fall back to the raw parameter name when no latex label is known
        self._latex_labels = {
            param: latex_labels[param] if param in latex_labels.keys() else
            param for param in self.total_list_of_parameters + _parameters
        }

    def __setitem__(self, key, value):
        # coerce raw dictionaries to SamplesDict before storing
        _value = value
        if not isinstance(value, SamplesDict):
            _value = SamplesDict(value)
        super(_MultiDimensionalSamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.labels:
                # a new analysis is being added: register its label and
                # parameters
                parameters = list(value.keys())
                try:
                    samples = np.array([value[param] for param in parameters])
                except ValueError:  # numpy deprecation error
                    samples = np.array(
                        [value[param] for param in parameters], dtype=object
                    )
                self.parameters[key] = parameters
                self.labels.append(key)
                # NOTE(review): self._latex_labels is a dict, so calling it
                # raises TypeError which the except below swallows — this
                # looks like it was meant to be self._update_latex_labels();
                # confirm
                self.latex_labels = self._latex_labels()
        except (AttributeError, TypeError):
            # the attributes above do not exist yet while the object is
            # still being constructed by __init__
            pass

    @property
    def T(self):
        # toggle between label-keyed and parameter-keyed representations
        _transpose = not self.transpose
        if not self.transpose:
            # every analysis must share the same parameters for the
            # transpose to be well defined
            _params = sorted([param for param in self[self.labels[0]].keys()])
            if not all(sorted(self[l].keys()) == _params for l in self.labels):
                raise ValueError(
                    "Unable to transpose as not all samples have the same "
                    "parameters"
                )
            transpose_dict = {
                param: {
                    label: dataset[param] for label, dataset in self.items()
                } for param in self[self.labels[0]].keys()
            }
        else:
            transpose_dict = {
                label: {
                    param: self[param][label] for param in self.keys()
                } for label in self.labels
            }
        # self.name keeps the subclass type of the returned object
        return self.name(transpose_dict, transpose=_transpose)

    def _combine(
        self, labels=None, use_all=False, weights=None, shuffle=False,
        logger_level="debug"
    ):
        """Combine samples from a select number of analyses into a single
        SamplesDict object.

        Parameters
        ----------
        labels: list, optional
            analyses you wish to combine. Default use all labels stored in the
            dictionary
        use_all: Bool, optional
            if True, use all of the samples (do not weight). Default False
        weights: dict, optional
            dictionary of weights for each of the posteriors. Keys must be the
            labels you wish to combine and values are the weights you wish to
            assign to the posterior
        shuffle: Bool, optional
            shuffle the combined samples
        logger_level: str, optional
            logger level you wish to use. Default debug.
        """
        try:
            _logger = getattr(logger, logger_level)
        except AttributeError:
            raise ValueError(
                "Unknown logger level. Please choose either 'info' or 'debug'"
            )
        if labels is None:
            _provided_labels = False
            labels = self.labels
        else:
            _provided_labels = True
            if not all(label in self.labels for label in labels):
                raise ValueError(
                    "Not all of the provided labels exist in the dictionary. "
                    "The list of available labels are: {}".format(
                        ", ".join(self.labels)
                    )
                )
        _logger("Combining the following analyses: {}".format(labels))
        if use_all and weights is not None:
            raise ValueError(
                "Unable to use all samples and provide weights"
            )
        elif not use_all and weights is None:
            # equal weighting by default
            weights = {label: 1. for label in labels}
        elif not use_all and weights is not None:
            if len(weights) < len(labels):
                raise ValueError(
                    "Please provide weights for each set of samples: {}".format(
                        len(labels)
                    )
                )
            if not _provided_labels and not isinstance(weights, dict):
                raise ValueError(
                    "Weights must be provided as a dictionary keyed by the "
                    "analysis label. The available labels are: {}".format(
                        ", ".join(labels)
                    )
                )
            elif not isinstance(weights, dict):
                # weights given positionally; pair them up with labels
                weights = {
                    label: weight for label, weight in zip(labels, weights)
                }
            if not all(label in labels for label in weights.keys()):
                for label in labels:
                    if label not in weights.keys():
                        weights[label] = 1.
                        logger.warning(
                            "No weight given for '{}'. Assigning a weight of "
                            "1".format(label)
                        )
            # normalise so the weights sum to 1
            sum_weights = np.sum([_weight for _weight in weights.values()])
            weights = {
                key: item / sum_weights for key, item in weights.items()
            }
        if weights is not None:
            _logger(
                "Using the following weights for each file, {}".format(
                    " ".join(
                        ["{}: {}".format(k, v) for k, v in weights.items()]
                    )
                )
            )
        _lengths = np.array(
            [self.number_of_samples[key] for key in labels]
        )
        if use_all:
            draw = _lengths
        else:
            draw = np.zeros(len(labels), dtype=int)
            _weights = np.array([weights[key] for key in labels])
            inds = np.argwhere(_weights > 0.)
            # The next 4 lines are inspired from the 'cbcBayesCombinePosteriors'
            # executable provided by LALSuite. Credit should go to the
            # authors of that code.
            initial = _weights[inds] * float(sum(_lengths[inds]))
            min_index = np.argmin(_lengths[inds] / initial)
            size = _lengths[inds][min_index] / _weights[inds][min_index]
            draw[inds] = np.around(_weights[inds] * size).astype(int)
        _logger(
            "Randomly drawing the following number of samples from each file, "
            "{}".format(
                " ".join(
                    [
                        "{}: {}/{}".format(l, draw[n], _lengths[n]) for n, l in
                        enumerate(labels)
                    ]
                )
            )
        )

        if self.transpose:
            _data = self.T
        else:
            # deepcopy so that downsampling does not mutate the stored
            # analyses
            _data = copy.deepcopy(self)
        for num, label in enumerate(labels):
            if draw[num] > 0:
                _data[label].downsample(draw[num])
            else:
                # zero draw: keep the analysis but with empty sample arrays
                _data[label] = {
                    param: np.array([]) for param in _data[label].keys()
                }
        try:
            # only parameters common to every selected analysis can be
            # concatenated
            intersection = set.intersection(
                *[
                    set(_params) for _key, _params in _data.parameters.items() if
                    _key in labels
                ]
            )
        except AttributeError:
            # parameters is a list (not a dict) when all analyses share the
            # same parameters
            intersection = _data.parameters
        logger.debug(
            "Only including the parameters: {} as they are common to all "
            "analyses".format(", ".join(list(intersection)))
        )
        data = {
            param: np.concatenate([_data[key][param] for key in labels]) for
            param in intersection
        }
        if shuffle:
            # permute the combined samples (same permutation for every
            # parameter, preserving cross-parameter correspondence)
            inds = np.random.choice(
                np.sum(draw), size=np.sum(draw), replace=False
            )
            data = {
                param: value[inds] for param, value in data.items()
            }
        return SamplesDict(data, logger_warn="debug")

    @property
    def nsamples(self):
        # number of stored analyses; in the transposed layout analyses are
        # the inner keys of the first parameter entry
        if self.transpose:
            parameters = list(self.keys())
            return len(self[parameters[0]])
        return len(self)

    @property
    def number_of_samples(self):
        # NOTE(review): in the transposed branch the zip pairs parameters
        # with labels and truncates to the shorter list — presumably all
        # parameters hold equally sized entries per label; confirm
        if self.transpose:
            return {
                label: len(self[iterator][label]) for iterator, label in zip(
                    self.keys(), self.labels
                )
            }
        return {
            label: self[iterator].number_of_samples for iterator, label in zip(
                self.keys(), self.labels
            )
        }

    @property
    def total_number_of_samples(self):
        return np.sum([length for length in self.number_of_samples.values()])

    @property
    def minimum_number_of_samples(self):
        return np.min([length for length in self.number_of_samples.values()])

    @property
    def total_list_of_parameters(self):
        # flatten self.parameters, which is either a dict of per-label
        # parameter lists or a (possibly nested) list
        if isinstance(self.parameters, dict):
            _parameters = [item for item in self.parameters.values()]
            _flat_parameters = [
                item for sublist in _parameters for item in sublist
            ]
        elif isinstance(self.parameters, list):
            if np.array(self.parameters).ndim > 1:
                _flat_parameters = [
                    item for sublist in self.parameters for item in sublist
                ]
            else:
                _flat_parameters = self.parameters
        # NOTE(review): if self.parameters is neither a dict nor a list this
        # raises UnboundLocalError — assumes those are the only stored types;
        # confirm
        return list(set(_flat_parameters))

    def samples(self, parameter):
        """Return the samples stored for a given parameter, one array per
        analysis

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return samples for
        """
        if self.transpose:
            samples = [self[parameter][label] for label in self.labels]
        else:
            samples = [self[label][parameter] for label in self.labels]
        return samples

1299 

1300 

class MCMCSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store the mcmc chains from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MCMCSamplesDict
        Transposed MCMCSamplesDict object keyed by parameters rather than
        chain
    average: pesummary.utils.samples_dict.SamplesDict
        The mean of each sample across multiple chains. If the chains are of
        different lengths, all chains are resized to the minimum number of
        samples
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine all samples from all chains into a single SamplesDict object
    nchains: int
        Total number of chains stored in the MCMCSamplesDict object
    number_of_samples: dict
        Number of samples stored in the MCMCSamplesDict for each chain
    total_number_of_samples: int
        Total number of samples stored across the multiple chains
    minimum_number_of_samples: int
        The number of samples in the smallest chain

    Methods
    -------
    discard_samples:
        Discard the first N samples for each chain
    burnin:
        Remove the first N samples as burnin. For different algorithms
        see pesummary.core.file.mcmc.algorithms
    gelman_rubin: float
        Return the Gelman-Rubin statistic between the chains for a given
        parameter. See pesummary.utils.utils.gelman_rubin
    samples:
        Return a list of samples stored in the MCMCSamplesDict object for a
        given parameter

    Examples
    --------
    Initializing the MCMCSamplesDict class

    >>> from pesummary.utils.samples_dict import MCMCSamplesDict
    >>> data = {
    ...     "chain_0": {
    ...         "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     },
    ...     "chain_1": {
    ...         "a": [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         "b": [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     }
    ... }
    >>> dataset = MCMCSamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [
    ...         [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     ], [
    ...         [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     ]
    ... ]
    >>> dataset = MCMCSamplesDict(parameters, samples)
    """
    def __init__(self, *args, transpose=False):
        single_chain_error = (
            "This class requires more than one mcmc chain to be passed. "
            "As only one dataset is available, please use the SamplesDict "
            "class."
        )
        super(MCMCSamplesDict, self).__init__(
            *args, transpose=transpose, label_prefix="chain"
        )
        # ensure derived objects (e.g. the T property) keep this subclass
        self.name = MCMCSamplesDict
        if len(self.labels) == 1:
            raise ValueError(single_chain_error)
        # chain-flavoured aliases for the generic label-based attributes
        self.chains = self.labels
        self.nchains = self.nsamples

    @property
    def average(self):
        # per-sample mean across chains; chains are truncated to the length
        # of the shortest chain so the arrays align
        if self.transpose:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[param][key][:self.minimum_number_of_samples] for
                        key in self[param].keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        else:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[key][param][:self.minimum_number_of_samples] for
                        key in self.keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        return data

    @property
    def key_data(self):
        # key data (summary statistics) for the chains combined into a
        # single posterior
        data = {}
        for param, value in self.combine.items():
            data[param] = value.key_data
        return data

    @property
    def combine(self):
        # concatenate every sample from every chain without weighting
        return self._combine(use_all=True, weights=None)

    def discard_samples(self, number):
        """Remove the first n samples

        Parameters
        ----------
        number: int/dict
            Number of samples that you wish to remove across all chains or a
            dictionary containing the number of samples to remove per chain
        """
        if isinstance(number, int):
            # apply the same number of discarded samples to every chain
            number = {chain: number for chain in self.keys()}
        for chain in self.keys():
            self[chain].discard_samples(number[chain])
        return self

    def burnin(self, *args, algorithm="burnin_by_step_number", **kwargs):
        """Remove the first N samples as burnin

        Parameters
        ----------
        algorithm: str, optional
            The algorithm you wish to use to remove samples as burnin. Default
            is 'burnin_by_step_number'. See
            `pesummary.core.file.mcmc.algorithms` for list of available
            algorithms
        """
        from pesummary.core.file import mcmc

        if algorithm not in mcmc.algorithms:
            raise ValueError(
                "{} is not a valid algorithm for removing samples as "
                "burnin".format(algorithm)
            )
        arguments = [self] + [i for i in args]
        return getattr(mcmc, algorithm)(*arguments, **kwargs)

    def gelman_rubin(self, parameter, decimal=5):
        """Return the gelman rubin statistic between chains for a given
        parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the gelman rubin statistic
            for
        decimal: int
            number of decimal places to keep when rounding
        """
        from pesummary.utils.utils import gelman_rubin as _gelman_rubin

        return _gelman_rubin(self.samples(parameter), decimal=decimal)

1474 

1475 

1476class MultiAnalysisSamplesDict(_MultiDimensionalSamplesDict): 

1477 """Class to samples from multiple analyses 

1478 

1479 Parameters 

1480 ---------- 

1481 parameters: list 

1482 list of parameters 

1483 samples: nd list 

1484 list of samples for each parameter for each chain 

1485 labels: list, optional 

1486 the labels to use to distinguish different analyses. 

1487 transpose: Bool, optional 

1488 True if the input is a transposed dictionary 

1489 

1490 Attributes 

1491 ---------- 

1492 T: pesummary.utils.samples_dict.MultiAnalysisSamplesDict 

1493 Transposed MultiAnalysisSamplesDict object keyed by parameters 

1494 rather than label 

1495 nsamples: int 

1496 Total number of analyses stored in the MultiAnalysisSamplesDict 

1497 object 

1498 number_of_samples: dict 

1499 Number of samples stored in the MultiAnalysisSamplesDict for each 

1500 analysis 

1501 total_number_of_samples: int 

1502 Total number of samples stored across the multiple analyses 

1503 minimum_number_of_samples: int 

1504 The number of samples in the smallest analysis 

1505 available_plots: list 

1506 list of plots which the user may user to display the contained posterior 

1507 samples 

1508 

1509 Methods 

1510 ------- 

1511 from_files: 

1512 Initialize the MultiAnalysisSamplesDict class with the contents of 

1513 multiple files 

1514 combine: pesummary.utils.samples_dict.SamplesDict 

1515 Combine samples from a select number of analyses into a single 

1516 SamplesDict object. 

1517 js_divergence: float 

1518 Return the JS divergence between two posterior distributions for a 

1519 given parameter. See pesummary.utils.utils.jensen_shannon_divergence 

1520 ks_statistic: float 

1521 Return the KS statistic between two posterior distributions for a 

1522 given parameter. See pesummary.utils.utils.kolmogorov_smirnov_test 

1523 samples: 

1524 Return a list of samples stored in the MCMCSamplesDict object for a 

1525 given parameter 

1526 write: 

1527 Save the stored posterior samples to file 

1528 """ 

    def __init__(self, *args, labels=None, transpose=False):
        # when samples are passed positionally as (parameters, samples)
        # there is no way to infer analysis names, so labels are mandatory
        # in that case
        if labels is None and not isinstance(args[0], dict):
            raise ValueError(
                "Please provide a unique label for each analysis"
            )
        super(MultiAnalysisSamplesDict, self).__init__(
            *args, labels=labels, transpose=transpose
        )
        # ensure derived objects (e.g. the T property) keep this subclass
        self.name = MultiAnalysisSamplesDict

1538 

1539 @classmethod 

1540 def from_files(cls, filenames, **kwargs): 

1541 """Initialize the MultiAnalysisSamplesDict class with the contents of 

1542 multiple result files 

1543 

1544 Parameters 

1545 ---------- 

1546 filenames: dict 

1547 dictionary containing the path to the result file you wish to load 

1548 as the item and a label associated with each result file as the key. 

1549 If you are providing one or more PESummary metafiles, the key 

1550 is ignored and labels stored in the metafile are used. 

1551 **kwargs: dict 

1552 all kwargs are passed to the pesummary.io.read function 

1553 """ 

1554 from pesummary.io import read 

1555 

1556 samples = {} 

1557 for label, filename in filenames.items(): 

1558 _kwargs = kwargs 

1559 if label in kwargs.keys(): 

1560 _kwargs = kwargs[label] 

1561 _file = read(filename, **_kwargs) 

1562 _samples = _file.samples_dict 

1563 if isinstance(_samples, MultiAnalysisSamplesDict): 

1564 _stored_labels = _samples.keys() 

1565 cond1 = any( 

1566 _label in filenames.keys() for _label in _stored_labels if 

1567 _label != label 

1568 ) 

1569 cond2 = any( 

1570 _label in samples.keys() for _label in _stored_labels 

1571 ) 

1572 if cond1 or cond2: 

1573 raise ValueError( 

1574 "The file '{}' contains the labels: {}. The " 

1575 "dictionary already contains the labels: {}. Please " 

1576 "provide unique labels for each dataset".format( 

1577 filename, ", ".join(_stored_labels), 

1578 ", ".join(samples.keys()) 

1579 ) 

1580 ) 

1581 samples.update(_samples) 

1582 else: 

1583 if label in samples.keys(): 

1584 raise ValueError( 

1585 "The label '{}' has alreadt been used. Please select " 

1586 "another label".format(label) 

1587 ) 

1588 samples[label] = _samples 

1589 return cls(samples) 

1590 

1591 @property 

1592 def plotting_map(self): 

1593 return { 

1594 "hist": self._marginalized_posterior, 

1595 "corner": self._corner, 

1596 "triangle": self._triangle, 

1597 "reverse_triangle": self._reverse_triangle, 

1598 "violin": self._violin, 

1599 "2d_kde": self._2d_kde 

1600 } 

1601 

1602 @property 

1603 def available_plots(self): 

1604 return list(self.plotting_map.keys()) 

1605 

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_comparison_histogram_plot',
        'pesummary.gw.plots.plot._1d_comparison_histogram_plot',
        'pesummary.core.plots.publication.triangle_plot',
        'pesummary.core.plots.publication.reverse_triangle_plot'
    ])
    def plot(
        self, *args, type="hist", labels="all", colors=None, latex_friendly=True,
        **kwargs
    ):
        """Generate a plot for the posterior samples stored in
        MultiAnalysisSamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make
        labels: list
            list of analyses that you wish to include in the plot
        colors: list
            list of colors to use for each analysis
        latex_friendly: Bool, optional
            if True, make the labels latex friendly. Default True
        **kwargs: dict
            all additional kwargs are passed to the plotting function

        Raises
        ------
        NotImplementedError
            if `type` is not one of the available plot names
        ValueError
            if `labels` is neither 'all' nor a list of stored analyses
        """
        if type not in self.plotting_map.keys():
            raise NotImplementedError(
                "The {} method is not currently implemented. The allowed "
                "plotting methods are {}".format(
                    type, ", ".join(self.available_plots)
                )
            )

        self._update_latex_labels()
        # resolve the 'all' shorthand and validate any explicit label list
        if labels == "all":
            labels = self.labels
        elif isinstance(labels, list):
            for label in labels:
                if label not in self.labels:
                    raise ValueError(
                        "'{}' is not a stored analysis. The available analyses "
                        "are: '{}'".format(label, ", ".join(self.labels))
                    )
        else:
            raise ValueError(
                "Please provide a list of analyses that you wish to plot"
            )
        if colors is None:
            # repeat the default color cycle until there is one color per
            # label
            colors = list(conf.colorcycle)
            while len(colors) < len(labels):
                colors += colors

        kwargs["labels"] = labels
        kwargs["colors"] = colors
        kwargs["latex_friendly"] = latex_friendly
        # dispatch to the requested plotting method
        return self.plotting_map[type](*args, **kwargs)

1665 

1666 def _marginalized_posterior( 

1667 self, parameter, module="core", labels="all", colors=None, **kwargs 

1668 ): 

1669 """Wrapper for the 

1670 `pesummary.core.plots.plot._1d_comparison_histogram_plot` or 

1671 `pesummary.gw.plots.plot._comparison_1d_histogram_plot` 

1672 

1673 Parameters 

1674 ---------- 

1675 parameter: str 

1676 name of the parameter you wish to plot 

1677 module: str, optional 

1678 module you wish to use for the plotting 

1679 labels: list 

1680 list of analyses that you wish to include in the plot 

1681 colors: list 

1682 list of colors to use for each analysis 

1683 **kwargs: dict 

1684 all additional kwargs are passed to the 

1685 `_1d_comparison_histogram_plot` function 

1686 """ 

1687 module = importlib.import_module( 

1688 "pesummary.{}.plots.plot".format(module) 

1689 ) 

1690 return getattr(module, "_1d_comparison_histogram_plot")( 

1691 parameter, [self[label][parameter] for label in labels], 

1692 colors, self.latex_labels[parameter], labels, **kwargs 

1693 ) 

1694 

1695 def _base_triangle(self, parameters, labels="all"): 

1696 """Check that the parameters are valid for the different triangle 

1697 plots available 

1698 

1699 Parameters 

1700 ---------- 

1701 parameters: list 

1702 list of parameters they wish to study 

1703 labels: list 

1704 list of analyses that you wish to include in the plot 

1705 """ 

1706 samples = [self[label] for label in labels] 

1707 if len(parameters) > 2: 

1708 raise ValueError("Function is only 2d") 

1709 condition = set( 

1710 label for num, label in enumerate(labels) for param in parameters if 

1711 param not in samples[num].keys() 

1712 ) 

1713 if len(condition): 

1714 raise ValueError( 

1715 "{} and {} are not available for the following " 

1716 " analyses: {}".format( 

1717 parameters[0], parameters[1], ", ".join(condition) 

1718 ) 

1719 ) 

1720 return samples 

1721 

1722 def _triangle(self, parameters, labels="all", module="core", **kwargs): 

1723 """Wrapper for the `pesummary.core.plots.publication.triangle_plot` 

1724 function 

1725 

1726 Parameters 

1727 ---------- 

1728 parameters: list 

1729 list of parameters they wish to study 

1730 labels: list 

1731 list of analyses that you wish to include in the plot 

1732 **kwargs: dict 

1733 all additional kwargs are passed to the `triangle_plot` function 

1734 """ 

1735 _module = importlib.import_module( 

1736 "pesummary.{}.plots.publication".format(module) 

1737 ) 

1738 samples = self._base_triangle(parameters, labels=labels) 

1739 if module == "gw": 

1740 kwargs["parameters"] = parameters 

1741 return getattr(_module, "triangle_plot")( 

1742 [_samples[parameters[0]] for _samples in samples], 

1743 [_samples[parameters[1]] for _samples in samples], 

1744 xlabel=self.latex_labels[parameters[0]], 

1745 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs 

1746 ) 

1747 

1748 def _reverse_triangle(self, parameters, labels="all", module="core", **kwargs): 

1749 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot` 

1750 function 

1751 

1752 Parameters 

1753 ---------- 

1754 parameters: list 

1755 list of parameters they wish to study 

1756 labels: list 

1757 list of analyses that you wish to include in the plot 

1758 **kwargs: dict 

1759 all additional kwargs are passed to the `triangle_plot` function 

1760 """ 

1761 _module = importlib.import_module( 

1762 "pesummary.{}.plots.publication".format(module) 

1763 ) 

1764 samples = self._base_triangle(parameters, labels=labels) 

1765 if module == "gw": 

1766 kwargs["parameters"] = parameters 

1767 return getattr(_module, "reverse_triangle_plot")( 

1768 [_samples[parameters[0]] for _samples in samples], 

1769 [_samples[parameters[1]] for _samples in samples], 

1770 xlabel=self.latex_labels[parameters[0]], 

1771 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs 

1772 ) 

1773 

    def _violin(
        self, parameter, labels="all", priors=None, latex_labels=GWlatex_labels,
        **kwargs
    ):
        """Wrapper for the `pesummary.gw.plots.publication.violin_plots`
        function

        Parameters
        ----------
        parameter: str, optional
            name of the parameter you wish to generate a violin plot for
        labels: list
            list of analyses that you wish to include in the plot
        priors: MultiAnalysisSamplesDict, optional
            prior samples for each analysis. If provided, the right hand side
            of each violin will show the prior
        latex_labels: dict, optional
            dictionary containing the latex label associated with parameter
        **kwargs: dict
            all additional kwargs are passed to the `violin_plots` function

        Raises
        ------
        ValueError
            if no analysis contains the parameter, or if priors are provided
            but are incomplete
        """
        from pesummary.gw.plots.publication import violin_plots

        # only keep the analyses which actually contain the parameter
        _labels = [label for label in labels if parameter in self[label].keys()]
        if not len(_labels):
            raise ValueError(
                "{} is not in any of the posterior samples tables. Please "
                "choose another parameter to plot".format(parameter)
            )
        elif len(_labels) != len(labels):
            # warn about (but do not fail on) analyses that were dropped
            no = list(set(labels) - set(_labels))
            logger.warning(
                "Unable to generate a violin plot for {} because {} is not "
                "in their posterior samples table".format(
                    " or ".join(no), parameter
                )
            )
        samples = [self[label][parameter] for label in _labels]
        if priors is not None and not all(
            label in priors.keys() for label in _labels
        ):
            raise ValueError("Please provide prior samples for all labels")
        elif priors is not None and not all(
            parameter in priors[label].keys() for label in _labels
        ):
            raise ValueError(
                "Please provide prior samples for {} for all labels".format(
                    parameter
                )
            )
        elif priors is not None:
            from pesummary.core.plots.seaborn.violin import split_dataframe

            # show posterior (left half) against prior (right half) on each
            # violin
            priors = [priors[label][parameter] for label in _labels]
            samples = split_dataframe(samples, priors, _labels)
            palette = kwargs.get("palette", None)
            left, right = "color: white", "pastel"
            if palette is not None and not isinstance(palette, dict):
                right = palette
            elif palette is not None and all(
                side in palette.keys() for side in ["left", "right"]
            ):
                left, right = palette["left"], palette["right"]
            # NOTE(review): a dict palette missing either 'left' or 'right'
            # silently falls back to the defaults above — confirm intended
            kwargs.update(
                {
                    "split": True, "x": "label", "y": "data", "hue": "side",
                    "palette": {"right": right, "left": left}
                }
            )
        return violin_plots(
            parameter, samples, _labels, latex_labels, **kwargs
        )

1846 

1847 def _corner(self, module="core", labels="all", parameters=None, **kwargs): 

1848 """Wrapper for the `pesummary.core.plots.plot._make_comparison_corner_plot` 

1849 or `pesummary.gw.plots.plot._make_comparison_corner_plot` function 

1850 

1851 Parameters 

1852 ---------- 

1853 module: str, optional 

1854 module you wish to use for the plotting 

1855 labels: list 

1856 list of analyses that you wish to include in the plot 

1857 **kwargs: dict 

1858 all additional kwargs are passed to the `_make_comparison_corner_plot` 

1859 function 

1860 """ 

1861 module = importlib.import_module( 

1862 "pesummary.{}.plots.plot".format(module) 

1863 ) 

1864 _samples = {label: self[label] for label in labels} 

1865 _parameters = None 

1866 if parameters is not None: 

1867 _parameters = [ 

1868 param for param in parameters if all( 

1869 param in posterior for posterior in _samples.values() 

1870 ) 

1871 ] 

1872 if not len(_parameters): 

1873 raise ValueError( 

1874 "None of the chosen parameters are in all of the posterior " 

1875 "samples tables. Please choose other parameters to plot" 

1876 ) 

1877 return getattr(module, "_make_comparison_corner_plot")( 

1878 _samples, self.latex_labels, corner_parameters=_parameters, **kwargs 

1879 ) 

1880 

1881 def _2d_kde( 

1882 self, parameters, module="core", labels="all", plot_density=None, 

1883 **kwargs 

1884 ): 

1885 """Wrapper for the 

1886 `pesummary.gw.plots.publication.comparison_twod_contour_plot` or 

1887 `pesummary.core.plots.publication.comparison_twod_contour_plot` function 

1888 

1889 Parameters 

1890 ---------- 

1891 parameters: list 

1892 list of length 2 giving the parameters you wish to plot 

1893 module: str, optional 

1894 module you wish to use for the plotting 

1895 labels: list 

1896 list of analyses that you wish to include in the plot 

1897 **kwargs: dict, optional 

1898 all additional kwargs are passed to the 

1899 `comparison_twod_contour_plot` function 

1900 """ 

1901 _module = importlib.import_module( 

1902 "pesummary.{}.plots.publication".format(module) 

1903 ) 

1904 samples = self._base_triangle(parameters, labels=labels) 

1905 if plot_density is not None: 

1906 if isinstance(plot_density, str): 

1907 plot_density = [plot_density] 

1908 elif isinstance(plot_density, bool) and plot_density: 

1909 plot_density = labels 

1910 for i in plot_density: 

1911 if i not in labels: 

1912 raise ValueError( 

1913 "Unable to plot the density for '{}'. Please choose " 

1914 "from: {}".format(plot_density, ", ".join(labels)) 

1915 ) 

1916 if module == "gw": 

1917 return getattr(_module, "twod_contour_plots")( 

1918 parameters, [ 

1919 [self[label][param] for param in parameters] for label in 

1920 labels 

1921 ], labels, { 

1922 parameters[0]: self.latex_labels[parameters[0]], 

1923 parameters[1]: self.latex_labels[parameters[1]] 

1924 }, plot_density=plot_density, **kwargs 

1925 ) 

1926 return getattr(_module, "comparison_twod_contour_plot")( 

1927 [_samples[parameters[0]] for _samples in samples], 

1928 [_samples[parameters[1]] for _samples in samples], 

1929 xlabel=self.latex_labels[parameters[0]], 

1930 ylabel=self.latex_labels[parameters[1]], labels=labels, 

1931 plot_density=plot_density, **kwargs 

1932 ) 

1933 

1934 def combine(self, **kwargs): 

1935 """Combine samples from a select number of analyses into a single 

1936 SamplesDict object. 

1937 

1938 Parameters 

1939 ---------- 

1940 labels: list, optional 

1941 analyses you wish to combine. Default use all labels stored in the 

1942 dictionary 

1943 use_all: Bool, optional 

1944 if True, use all of the samples (do not weight). Default False 

1945 weights: dict, optional 

1946 dictionary of weights for each of the posteriors. Keys must be the 

1947 labels you wish to combine and values are the weights you wish to 

1948 assign to the posterior 

1949 logger_level: str, optional 

1950 logger level you wish to use. Default debug. 

1951 """ 

1952 return self._combine(**kwargs) 

1953 

1954 def write(self, labels=None, **kwargs): 

1955 """Save the stored posterior samples to file 

1956 

1957 Parameters 

1958 ---------- 

1959 labels: list, optional 

1960 list of analyses that you wish to save to file. Default save all 

1961 analyses to file 

1962 **kwargs: dict, optional 

1963 all additional kwargs passed to the pesummary.io.write function 

1964 """ 

1965 if labels is None: 

1966 labels = self.labels 

1967 elif not all(label in self.labels for label in labels): 

1968 for label in labels: 

1969 if label not in self.labels: 

1970 raise ValueError( 

1971 "Unable to find analysis: '{}'. The list of " 

1972 "available analyses are: {}".format( 

1973 label, ", ".join(self.labels) 

1974 ) 

1975 ) 

1976 for label in labels: 

1977 self[label].write(**kwargs) 

1978 

1979 def js_divergence(self, parameter, decimal=5): 

1980 """Return the JS divergence between the posterior samples for 

1981 a given parameter 

1982 

1983 Parameters 

1984 ---------- 

1985 parameter: str 

1986 name of the parameter you wish to return the gelman rubin statistic 

1987 for 

1988 decimal: int 

1989 number of decimal places to keep when rounding 

1990 """ 

1991 from pesummary.utils.utils import jensen_shannon_divergence 

1992 

1993 return jensen_shannon_divergence( 

1994 self.samples(parameter), decimal=decimal 

1995 ) 

1996 

1997 def ks_statistic(self, parameter, decimal=5): 

1998 """Return the KS statistic between the posterior samples for 

1999 a given parameter 

2000 

2001 Parameters 

2002 ---------- 

2003 parameter: str 

2004 name of the parameter you wish to return the gelman rubin statistic 

2005 for 

2006 decimal: int 

2007 number of decimal places to keep when rounding 

2008 """ 

2009 from pesummary.utils.utils import kolmogorov_smirnov_test 

2010 

2011 return kolmogorov_smirnov_test( 

2012 self.samples(parameter), decimal=decimal 

2013 )