Coverage for pesummary/utils/samples_dict.py: 60.5%

631 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import copy 

4import numpy as np 

5from pesummary.utils.utils import resample_posterior_distribution, logger 

6from pesummary.utils.decorators import docstring_subfunction 

7from pesummary.utils.array import Array, _2DArray 

8from pesummary.utils.dict import Dict 

9from pesummary.utils.parameters import Parameters 

10from pesummary.core.plots.latex_labels import latex_labels 

11from pesummary.gw.plots.latex_labels import GWlatex_labels 

12from pesummary import conf 

13import importlib 

14 

15__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

16 

17latex_labels.update(GWlatex_labels) 

18 

19 

20class SamplesDict(Dict): 

21 """Class to store the samples from a single run 

22 

23 Parameters 

24 ---------- 

25 parameters: list 

26 list of parameters 

27 samples: nd list 

28 list of samples for each parameter 

29 autoscale: Bool, optional 

30 If True, the posterior samples for each parameter are scaled to the 

31 same length 

32 

33 Attributes 

34 ---------- 

35 maxL: pesummary.utils.samples_dict.SamplesDict 

36 SamplesDict object containing the maximum likelihood sample keyed by 

37 the parameter 

38 minimum: pesummary.utils.samples_dict.SamplesDict 

39 SamplesDict object containing the minimum sample for each parameter 

40 maximum: pesummary.utils.samples_dict.SamplesDict 

41 SamplesDict object containing the maximum sample for each parameter 

42 median: pesummary.utils.samples_dict.SamplesDict 

43 SamplesDict object containining the median of each marginalized 

44 posterior distribution 

45 mean: pesummary.utils.samples_dict.SamplesDict 

46 SamplesDict object containing the mean of each marginalized posterior 

47 distribution 

48 key_data: dict 

49 dictionary containing the key data associated with each array 

50 number_of_samples: int 

51 Number of samples stored in the SamplesDict object 

52 latex_labels: dict 

53 Dictionary of latex labels for each parameter 

54 available_plots: list 

55 list of plots which the user may user to display the contained posterior 

56 samples 

57 

58 Methods 

59 ------- 

60 from_file: 

61 Initialize the SamplesDict class with the contents of a file 

62 to_pandas: 

63 Convert the SamplesDict object to a pandas DataFrame 

64 to_structured_array: 

65 Convert the SamplesDict object to a numpy structured array 

66 pop: 

67 Remove an entry from the SamplesDict object 

68 standardize_parameter_names: 

69 Modify keys in SamplesDict to use standard PESummary names 

70 downsample: 

71 Downsample the samples stored in the SamplesDict object. See the 

72 pesummary.utils.utils.resample_posterior_distribution method 

73 discard_samples: 

74 Remove the first N samples from each distribution 

75 plot: 

76 Generate a plot based on the posterior samples stored 

77 generate_all_posterior_samples: 

78 Convert the posterior samples in the SamplesDict object according to 

79 a conversion function 

80 debug_keys: list 

81 list of keys with an '_' as their first character 

82 reweight: 

83 Reweight the posterior samples according to a new prior 

84 write: 

85 Save the stored posterior samples to file 

86 

87 Examples 

88 -------- 

89 How the initialize the SamplesDict class 

90 

91 >>> from pesummary.utils.samples_dict import SamplesDict 

92 >>> data = { 

93 ... "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6], 

94 ... "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9] 

95 ... } 

96 >>> dataset = SamplesDict(data) 

97 >>> parameters = ["a", "b"] 

98 >>> samples = [ 

99 ... [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6], 

100 ... [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9] 

101 ... ]

102 >>> dataset = SamplesDict(parameters, samples) 

103 >>> fig = dataset.plot("a", type="hist", bins=30) 

104 >>> fig.show() 

105 """ 

106 def __init__(self, *args, logger_warn="warn", autoscale=True,**kwargs): 

107 super(SamplesDict, self).__init__( 

108 *args, value_class=Array, make_dict_kwargs={"autoscale": autoscale}, 

109 logger_warn=logger_warn, latex_labels=latex_labels, **kwargs 

110 ) 

111 

112 def __getitem__(self, key): 

113 """Return an object representing the specialization of SamplesDict 

114 by type arguments found in key. 

115 """ 

116 if isinstance(key, slice): 

117 return SamplesDict( 

118 self.parameters, np.array( 

119 [i[key.start:key.stop:key.step] for i in self.samples] 

120 ) 

121 ) 

122 elif isinstance(key, (list, np.ndarray)): 

123 return SamplesDict( 

124 self.parameters, np.array([i[key] for i in self.samples]) 

125 ) 

126 elif key[0] == "_": 

127 return self.samples[self.parameters.index(key)] 

128 return super(SamplesDict, self).__getitem__(key) 

129 

    def __setitem__(self, key, value):
        """Store `value` under `key`, keeping the cached `parameters` list and
        `samples` array in sync with the dictionary contents
        """
        _value = value
        if not isinstance(value, Array):
            # every stored posterior is wrapped in a pesummary Array
            _value = Array(value)
        super(SamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.parameters:
                # new parameter: extend the bookkeeping attributes as well
                self.parameters.append(key)
                try:
                    # a flat numeric `samples` container means a single
                    # sample is stored per parameter
                    cond = (
                        np.array(self.samples).ndim == 1 and isinstance(
                            self.samples[0], (float, int, np.number)
                        )
                    )
                except Exception:
                    cond = False
                if cond and isinstance(self.samples, np.ndarray):
                    self.samples = np.append(self.samples, value)
                elif cond and isinstance(self.samples, list):
                    self.samples.append(value)
                else:
                    # stack the new samples as an additional row
                    self.samples = np.vstack([self.samples, value])
                self._update_latex_labels()
        except (AttributeError, TypeError):
            # parameters/samples may not exist yet (e.g. mid-construction)
            pass

155 

    def __str__(self):
        """Print a summary of the information stored in the dictionary
        """
        def format_string(string, row):
            """Format a list into a table

            Parameters
            ----------
            string: str
                existing table
            row: list
                the row you wish to be written to a table
            """
            string += "{:<8}".format(row[0])
            for i in range(1, len(row)):
                if isinstance(row[i], str):
                    string += "{:<15}".format(row[i])
                elif isinstance(row[i], (float, int, np.int64, np.int32)):
                    string += "{:<15.6f}".format(row[i])
            string += "\n"
            return string

        string = ""
        # header row: index column followed by the parameter names
        string = format_string(string, ["idx"] + list(self.keys()))

        if self.number_of_samples < 8:
            # small enough to show every sample
            for i in range(self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        else:
            # show the first four and last two rows, separated by dots
            for i in range(4):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
            for i in range(2):
                string = format_string(string, ["."] * (len(self.keys()) + 1))
            for i in range(self.number_of_samples - 2, self.number_of_samples):
                string = format_string(
                    string, [i] + [item[i] for key, item in self.items()]
                )
        return string

198 

199 @classmethod 

200 def from_file(cls, filename, **kwargs): 

201 """Initialize the SamplesDict class with the contents of a result file 

202 

203 Parameters 

204 ---------- 

205 filename: str 

206 path to the result file you wish to load. 

207 **kwargs: dict 

208 all kwargs are passed to the pesummary.io.read function 

209 """ 

210 from pesummary.io import read 

211 

212 return read(filename, **kwargs).samples_dict 

213 

214 @property 

215 def key_data(self): 

216 return {param: value.key_data for param, value in self.items()} 

217 

218 @property 

219 def maxL(self): 

220 return SamplesDict( 

221 self.parameters, [[item.maxL] for key, item in self.items()] 

222 ) 

223 

224 @property 

225 def minimum(self): 

226 return SamplesDict( 

227 self.parameters, [[item.minimum] for key, item in self.items()] 

228 ) 

229 

230 @property 

231 def maximum(self): 

232 return SamplesDict( 

233 self.parameters, [[item.maximum] for key, item in self.items()] 

234 ) 

235 

236 @property 

237 def median(self): 

238 return SamplesDict( 

239 self.parameters, 

240 [[item.average(type="median")] for key, item in self.items()] 

241 ) 

242 

243 @property 

244 def mean(self): 

245 return SamplesDict( 

246 self.parameters, 

247 [[item.average(type="mean")] for key, item in self.items()] 

248 ) 

249 

250 @property 

251 def number_of_samples(self): 

252 return len(self[self.parameters[0]]) 

253 

254 @property 

255 def plotting_map(self): 

256 existing = super(SamplesDict, self).plotting_map 

257 modified = existing.copy() 

258 modified.update( 

259 { 

260 "marginalized_posterior": self._marginalized_posterior, 

261 "skymap": self._skymap, 

262 "hist": self._marginalized_posterior, 

263 "corner": self._corner, 

264 "spin_disk": self._spin_disk, 

265 "2d_kde": self._2d_kde, 

266 "triangle": self._triangle, 

267 "reverse_triangle": self._reverse_triangle, 

268 } 

269 ) 

270 return modified 

271 

272 def standardize_parameter_names(self, mapping=None): 

273 """Modify keys in SamplesDict to use standard PESummary names 

274 

275 Parameters 

276 ---------- 

277 mapping: dict, optional 

278 dictionary mapping existing keys to standard PESummary names. 

279 Default pesummary.gw.file.standard_names.standard_names 

280 

281 Returns 

282 ------- 

283 standard_dict: SamplesDict 

284 SamplesDict object with standard PESummary parameter names 

285 """ 

286 from pesummary.utils.utils import map_parameter_names 

287 if mapping is None: 

288 from pesummary.gw.file.standard_names import standard_names 

289 mapping = standard_names 

290 return SamplesDict(map_parameter_names(self, mapping)) 

291 

292 def debug_keys(self, *args, **kwargs): 

293 _keys = self.keys() 

294 _total = self.keys(remove_debug=False) 

295 return Parameters([key for key in _total if key not in _keys]) 

296 

297 def keys(self, *args, remove_debug=True, **kwargs): 

298 original = super(SamplesDict, self).keys(*args, **kwargs) 

299 if remove_debug: 

300 return Parameters([key for key in original if key[0] != "_"]) 

301 return Parameters(original) 

302 

303 def write(self, **kwargs): 

304 """Save the stored posterior samples to file 

305 

306 Parameters 

307 ---------- 

308 **kwargs: dict, optional 

309 all additional kwargs passed to the pesummary.io.write function 

310 """ 

311 from pesummary.io import write 

312 write(self.parameters, self.samples.T, **kwargs) 

313 

314 def items(self, *args, remove_debug=True, **kwargs): 

315 items = super(SamplesDict, self).items(*args, **kwargs) 

316 if remove_debug: 

317 return [item for item in items if item[0][0] != "_"] 

318 return items 

319 

320 def to_pandas(self, **kwargs): 

321 """Convert a SamplesDict object to a pandas dataframe 

322 """ 

323 from pandas import DataFrame 

324 

325 return DataFrame(self, **kwargs) 

326 

327 def to_structured_array(self, **kwargs): 

328 """Convert a SamplesDict object to a structured numpy array 

329 """ 

330 return self.to_pandas(**kwargs).to_records( 

331 index=False, column_dtypes=float 

332 ) 

333 

334 def pop(self, parameter): 

335 """Delete a parameter from the SamplesDict 

336 

337 Parameters 

338 ---------- 

339 parameter: str 

340 name of the parameter you wish to remove from the SamplesDict 

341 """ 

342 if parameter not in self.parameters: 

343 logger.info( 

344 "{} not in SamplesDict. Unable to remove {}".format( 

345 parameter, parameter 

346 ) 

347 ) 

348 return 

349 ind = self.parameters.index(parameter) 

350 self.parameters.remove(parameter) 

351 samples = self.samples 

352 self.samples = np.delete(samples, ind, axis=0) 

353 return super(SamplesDict, self).pop(parameter) 

354 

355 def downsample(self, number): 

356 """Downsample the samples stored in the SamplesDict class 

357 

358 Parameters 

359 ---------- 

360 number: int 

361 Number of samples you wish to downsample to 

362 """ 

363 self.samples = resample_posterior_distribution(self.samples, number) 

364 self.make_dictionary() 

365 return self 

366 

367 def discard_samples(self, number): 

368 """Remove the first n samples 

369 

370 Parameters 

371 ---------- 

372 number: int 

373 Number of samples that you wish to remove 

374 """ 

375 self.make_dictionary(discard_samples=number) 

376 return self 

377 

    def make_dictionary(self, discard_samples=None, autoscale=True):
        """Add the parameters and samples to the class

        Parameters
        ----------
        discard_samples: int, optional
            if given, drop the first `discard_samples` samples from every
            parameter. Default None (keep everything)
        autoscale: Bool, optional
            if True, truncate all posteriors to the shortest one when the
            per-parameter sample counts differ. Default True
        """
        lengths = [len(i) for i in self.samples]
        if len(np.unique(lengths)) > 1 and autoscale:
            # unequal lengths: truncate everything to the shortest posterior
            nsamples = np.min(lengths)
            getattr(logger, self.logger_warn)(
                "Unequal number of samples for each parameter. "
                "Restricting all posterior samples to have {} "
                "samples".format(nsamples)
            )
            self.samples = [
                dataset[:nsamples] for dataset in self.samples
            ]
        # pull out likelihood/prior/weight columns (if present) so they can
        # be attached to every Array built below
        if "log_likelihood" in self.parameters:
            likelihoods = self.samples[self.parameters.index("log_likelihood")]
            likelihoods = likelihoods[discard_samples:]
        else:
            likelihoods = None
        if "log_prior" in self.parameters:
            priors = self.samples[self.parameters.index("log_prior")]
            priors = priors[discard_samples:]
        else:
            priors = None
        if any(i in self.parameters for i in ["weights", "weight"]):
            # either spelling may be used; prefer "weights" when both exist
            ind = (
                self.parameters.index("weights") if "weights" in self.parameters
                else self.parameters.index("weight")
            )
            weights = self.samples[ind][discard_samples:]
        else:
            weights = None
        _2d_array = _2DArray(
            np.array(self.samples)[:, discard_samples:], likelihood=likelihoods,
            prior=priors, weights=weights
        )
        for key, val in zip(self.parameters, _2d_array):
            self[key] = val

416 

    @docstring_subfunction([
        'pesummary.core.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._1d_histogram_plot',
        'pesummary.gw.plots.plot._ligo_skymap_plot',
        'pesummary.gw.plots.publication.spin_distribution_plots',
        'pesummary.core.plots.plot._make_corner_plot',
        'pesummary.gw.plots.plot._make_corner_plot'
    ])
    def plot(self, *args, type="marginalized_posterior", **kwargs):
        """Generate a plot for the posterior samples stored in SamplesDict

        Parameters
        ----------
        *args: tuple
            all arguments are passed to the plotting function
        type: str
            name of the plot you wish to make; must be a key of
            `self.plotting_map`
        **kwargs: dict
            all additional kwargs are passed to the plotting function
        """
        # dispatch happens in the base class via self.plotting_map
        return super(SamplesDict, self).plot(*args, type=type, **kwargs)

438 

    def generate_all_posterior_samples(self, function=None, **kwargs):
        """Convert samples stored in the SamplesDict according to a conversion
        function

        Parameters
        ----------
        function: func, optional
            function to use when converting posterior samples. Must take a
            dictionary as input and return a dictionary of converted posterior
            samples. Default `pesummary.gw.conversions.convert`
        **kwargs: dict, optional
            All additional kwargs passed to function

        Returns
        -------
        extra_kwargs: dict or None
            metadata returned by the conversion function when
            `return_kwargs=True` is passed, otherwise None
        """
        if function is None:
            from pesummary.gw.conversions import convert
            function = convert
        _samples = self.copy()
        _keys = list(_samples.keys())
        # force the conversion function to hand back a dictionary
        kwargs.update({"return_dict": True})
        out = function(_samples, **kwargs)
        if kwargs.get("return_kwargs", False):
            converted_samples, extra_kwargs = out
        else:
            converted_samples, extra_kwargs = out, None
        for key, item in converted_samples.items():
            # only add parameters which were not already stored
            if key not in _keys:
                self[key] = item
        return extra_kwargs

467 

    def reweight(
        self, function, ignore_debug_params=["recalib", "spcal"], **kwargs
    ):
        """Reweight the posterior samples according to a new prior

        Parameters
        ----------
        function: func/str
            function to use when resampling
        ignore_debug_params: list, optional
            params to ignore when storing unweighted posterior distributions.
            Default any param with ['recalib', 'spcal'] in their name

        Returns
        -------
        SamplesDict
            reweighted samples; the original distributions are kept under
            hidden '_{param}_non_reweighted' keys
        """
        from pesummary.gw.reweight import options
        # the reweighting scheme may be given by name rather than callable
        if isinstance(function, str) and function in options.keys():
            function = options[function]
        elif isinstance(function, str):
            raise ValueError(
                "Unknown function '{}'. Please provide a function for "
                "reweighting or select one of the following: {}".format(
                    function, ", ".join(list(options.keys()))
                )
            )
        _samples = SamplesDict(self.copy())
        new_samples = function(_samples, **kwargs)
        # shrink the original set to match the reweighted sample count
        _samples.downsample(new_samples.number_of_samples)
        for key, item in new_samples.items():
            if not any(param in key for param in ignore_debug_params):
                # preserve the pre-reweighting distribution under a debug key
                _samples["_{}_non_reweighted".format(key)] = _samples[key]
            _samples[key] = item
        return SamplesDict(_samples)

499 

500 def _marginalized_posterior(self, parameter, module="core", **kwargs): 

501 """Wrapper for the `pesummary.core.plots.plot._1d_histogram_plot` or 

502 `pesummary.gw.plots.plot._1d_histogram_plot` 

503 

504 Parameters 

505 ---------- 

506 parameter: str 

507 name of the parameter you wish to plot 

508 module: str, optional 

509 module you wish to use for the plotting 

510 **kwargs: dict 

511 all additional kwargs are passed to the `_1d_histogram_plot` 

512 function 

513 """ 

514 module = importlib.import_module( 

515 "pesummary.{}.plots.plot".format(module) 

516 ) 

517 return getattr(module, "_1d_histogram_plot")( 

518 parameter, self[parameter], self.latex_labels[parameter], 

519 weights=self[parameter].weights, **kwargs 

520 ) 

521 

522 def _skymap(self, **kwargs): 

523 """Wrapper for the `pesummary.gw.plots.plot._ligo_skymap_plot` 

524 function 

525 

526 Parameters 

527 ---------- 

528 **kwargs: dict 

529 All kwargs are passed to the `_ligo_skymap_plot` function 

530 """ 

531 from pesummary.gw.plots.plot import _ligo_skymap_plot 

532 

533 if "luminosity_distance" in self.keys(): 

534 dist = self["luminosity_distance"] 

535 else: 

536 dist = None 

537 

538 return _ligo_skymap_plot(self["ra"], self["dec"], dist=dist, **kwargs) 

539 

540 def _spin_disk(self, **kwargs): 

541 """Wrapper for the `pesummary.gw.plots.publication.spin_distribution_plots` 

542 function 

543 """ 

544 from pesummary.gw.plots.publication import spin_distribution_plots 

545 

546 required = ["a_1", "a_2", "cos_tilt_1", "cos_tilt_2"] 

547 if not all(param in self.keys() for param in required): 

548 raise ValueError( 

549 "The spin disk plot requires samples for the following " 

550 "parameters: {}".format(", ".join(required)) 

551 ) 

552 samples = [self[param] for param in required] 

553 return spin_distribution_plots(required, samples, None, **kwargs) 

554 

555 def _corner(self, module="core", parameters=None, **kwargs): 

556 """Wrapper for the `pesummary.core.plots.plot._make_corner_plot` or 

557 `pesummary.gw.plots.plot._make_corner_plot` function 

558 

559 Parameters 

560 ---------- 

561 module: str, optional 

562 module you wish to use for the plotting 

563 **kwargs: dict 

564 all additional kwargs are passed to the `_make_corner_plot` 

565 function 

566 """ 

567 module = importlib.import_module( 

568 "pesummary.{}.plots.plot".format(module) 

569 ) 

570 return getattr(module, "_make_corner_plot")( 

571 self, self.latex_labels, corner_parameters=parameters, **kwargs 

572 )[0] 

573 

    def _2d_kde(self, parameters, module="core", **kwargs):
        """Wrapper for the `pesummary.gw.plots.publication.twod_contour_plot` or
        `pesummary.core.plots.publication.twod_contour_plot` function

        Parameters
        ----------
        parameters: list
            list of length 2 giving the parameters you wish to plot
        module: str, optional
            module you wish to use for the plotting
        **kwargs: dict, optional
            all additional kwargs are passed to the `twod_contour_plot` function
        """
        _module = importlib.import_module(
            "pesummary.{}.plots.publication".format(module)
        )
        if module == "gw":
            # the gw module exposes a multi-analysis variant with a
            # different signature (note the plural function name)
            return getattr(_module, "twod_contour_plots")(
                parameters, [[self[parameters[0]], self[parameters[1]]]],
                [None], {
                    parameters[0]: self.latex_labels[parameters[0]],
                    parameters[1]: self.latex_labels[parameters[1]]
                }, **kwargs
            )
        return getattr(_module, "twod_contour_plot")(
            self[parameters[0]], self[parameters[1]],
            xlabel=self.latex_labels[parameters[0]],
            ylabel=self.latex_labels[parameters[1]], **kwargs
        )

603 

604 def _triangle(self, parameters, module="core", **kwargs): 

605 """Wrapper for the `pesummary.core.plots.publication.triangle_plot` 

606 function 

607 

608 Parameters 

609 ---------- 

610 parameters: list 

611 list of parameters they wish to study 

612 **kwargs: dict 

613 all additional kwargs are passed to the `triangle_plot` function 

614 """ 

615 _module = importlib.import_module( 

616 "pesummary.{}.plots.publication".format(module) 

617 ) 

618 if module == "gw": 

619 kwargs["parameters"] = parameters 

620 return getattr(_module, "triangle_plot")( 

621 [self[parameters[0]]], [self[parameters[1]]], 

622 xlabel=self.latex_labels[parameters[0]], 

623 ylabel=self.latex_labels[parameters[1]], **kwargs 

624 ) 

625 

626 def _reverse_triangle(self, parameters, module="core", **kwargs): 

627 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot` 

628 function 

629 

630 Parameters 

631 ---------- 

632 parameters: list 

633 list of parameters they wish to study 

634 **kwargs: dict 

635 all additional kwargs are passed to the `triangle_plot` function 

636 """ 

637 _module = importlib.import_module( 

638 "pesummary.{}.plots.publication".format(module) 

639 ) 

640 if module == "gw": 

641 kwargs["parameters"] = parameters 

642 return getattr(_module, "reverse_triangle_plot")( 

643 [self[parameters[0]]], [self[parameters[1]]], 

644 xlabel=self.latex_labels[parameters[0]], 

645 ylabel=self.latex_labels[parameters[1]], **kwargs 

646 ) 

647 

648 def classification(self, dual=True, population=False): 

649 """Return the classification probabilities 

650 

651 Parameters 

652 ---------- 

653 dual: Bool, optional 

654 if True, return classification probabilities generated from the 

655 raw samples ('default') an samples reweighted to a population 

656 inferred prior ('population'). Default True. 

657 population: Bool, optional 

658 if True, reweight the samples to a population informed prior and 

659 then calculate classification probabilities. Default False. Only 

660 used when dual=False 

661 """ 

662 from pesummary.gw.classification import Classify 

663 if dual: 

664 probs = Classify(self).dual_classification() 

665 else: 

666 probs = Classify(self).classification(population=population) 

667 return probs 

668 

    def _waveform_args(self, f_ref=20., ind=0, longAscNodes=0., eccentricity=0.):
        """Arguments to be passed to waveform generation

        Parameters
        ----------
        f_ref: float, optional
            reference frequency to use when converting spherical spins to
            cartesian spins
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        """
        from lal import MSUN_SI, PC_SI

        # pick out a single sample for every parameter
        _samples = {key: value[ind] for key, value in self.items()}
        required = [
            "mass_1", "mass_2", "luminosity_distance"
        ]
        if not all(param in _samples.keys() for param in required):
            raise ValueError(
                "Unable to generate a waveform. Please add samples for "
                + ", ".join(required)
            )
        # masses are converted from solar masses to SI units
        waveform_args = [
            _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI
        ]
        spin_angles = [
            "theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2",
            "phase"
        ]
        spin_angles_condition = all(
            spin in _samples.keys() for spin in spin_angles
        )
        cartesian_spins = [
            "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"
        ]
        cartesian_spins_condition = any(
            spin in _samples.keys() for spin in cartesian_spins
        )
        if spin_angles_condition and not cartesian_spins_condition:
            # only spherical spin angles available: convert to cartesian
            from pesummary.gw.conversions import component_spins
            data = component_spins(
                _samples["theta_jn"], _samples["phi_jl"], _samples["tilt_1"],
                _samples["tilt_2"], _samples["phi_12"], _samples["a_1"],
                _samples["a_2"], _samples["mass_1"], _samples["mass_2"],
                f_ref, _samples["phase"]
            )
            iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = data.T
            spins = [spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z]
        else:
            # NOTE(review): this branch assumes an "iota" sample exists;
            # a KeyError is raised otherwise — confirm against callers
            iota = _samples["iota"]
            spins = [
                _samples[param] if param in _samples.keys() else 0. for param in
                ["spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"]
            ]
        waveform_args += spins
        phase = _samples["phase"] if "phase" in _samples.keys() else 0.
        # distance is converted from Mpc to SI units
        waveform_args += [
            _samples["luminosity_distance"] * PC_SI * 10**6, iota, phase
        ]
        waveform_args += [longAscNodes, eccentricity, 0.]
        return waveform_args, _samples

735 

    def antenna_response(self, ifo):
        """Wrapper for the `pesummary.gw.waveform.antenna_response` function

        Parameters
        ----------
        ifo: str
            name of the detector you wish to compute the antenna response for
        """
        from pesummary.gw.waveform import antenna_response
        return antenna_response(self, ifo)

741 

742 def _project_waveform(self, ifo, hp, hc, ra, dec, psi, time): 

743 """Project a waveform onto a given detector 

744 

745 Parameters 

746 ---------- 

747 ifo: str 

748 name of the detector you wish to project the waveform onto 

749 hp: np.ndarray 

750 plus gravitational wave polarization 

751 hc: np.ndarray 

752 cross gravitational wave polarization 

753 ra: float 

754 right ascension to be passed to antenna response function 

755 dec: float 

756 declination to be passed to antenna response function 

757 psi: float 

758 polarization to be passed to antenna response function 

759 time: float 

760 time to be passed to antenna response function 

761 """ 

762 import importlib 

763 

764 mod = importlib.import_module("pesummary.gw.plots.plot") 

765 func = getattr(mod, "__antenna_response") 

766 antenna = func(ifo, ra, dec, psi, time) 

767 ht = hp * antenna[0] + hc * antenna[1] 

768 return ht 

769 

    def fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs):
        """Generate a gravitational wave in the frequency domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.frequencyseries.FrequencySeries
            object
        """
        from pesummary.gw.waveform import fd_waveform
        return fd_waveform(self, approximant, delta_f, f_low, f_high, **kwargs)

803 

    def td_waveform(
        self, approximant, delta_t, f_low, **kwargs
    ):
        """Generate a gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        ind: int, optional
            index for the sample you wish to plot
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        pycbc: Bool, optional
            return the waveform as a pycbc.timeseries.TimeSeries object
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        from pesummary.gw.waveform import td_waveform
        return td_waveform(
            self, approximant, delta_t, f_low, **kwargs
        )

841 

842 def _maxL_waveform(self, func, *args, **kwargs): 

843 """Return the maximum likelihood waveform in a given domain 

844 

845 Parameters 

846 ---------- 

847 func: function 

848 function you wish to use when generating the maximum likelihood 

849 waveform 

850 *args: tuple 

851 all args passed to func 

852 **kwargs: dict 

853 all kwargs passed to func 

854 """ 

855 ind = np.argmax(self["log_likelihood"]) 

856 kwargs["ind"] = ind 

857 return func(*args, **kwargs) 

858 

    def maxL_td_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the time domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_t: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        level: list, optional
            the symmetric confidence interval of the time domain waveform. Level
            must be greater than 0 and less than 1
        """
        return self._maxL_waveform(self.td_waveform, *args, **kwargs)

887 

    def maxL_fd_waveform(self, *args, **kwargs):
        """Generate the maximum likelihood gravitational wave in the frequency
        domain

        Parameters
        ----------
        approximant: str
            name of the approximant to use when generating the waveform
        delta_f: float
            spacing between frequency samples
        f_low: float
            frequency to start evaluating the waveform
        f_high: float
            frequency to stop evaluating the waveform
        f_ref: float, optional
            reference frequency
        project: str, optional
            name of the detector to project the waveform onto. If None,
            the plus and cross polarizations are returned. Default None
        longAscNodes: float, optional
            longitude of ascending nodes, degenerate with the polarization
            angle. Default 0.
        eccentricity: float, optional
            eccentricity at reference frequency. Default 0.
        LAL_parameters: dict, optional
            LAL dictionary containing accessory parameters. Default None
        """
        return self._maxL_waveform(self.fd_waveform, *args, **kwargs)

916 

917 

class _MultiDimensionalSamplesDict(Dict):
    """Class to store multiple SamplesDict objects

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    label_prefix: str, optional
        prefix to use when distinguishing different analyses. The label is then
        '{label_prefix}_{num}' where num is the result file index. Default
        is 'dataset'
    transpose: Bool, optional
        True if the input is a transposed dictionary
    labels: list, optional
        the labels to use to distinguish different analyses. If provided
        label_prefix is ignored

    Attributes
    ----------
    T: pesummary.utils.samples_dict._MultiDimensionalSamplesDict
        Transposed _MultiDimensionalSamplesDict object keyed by parameters
        rather than label
    nsamples: int
        Total number of analyses stored in the _MultiDimensionalSamplesDict
        object
    number_of_samples: dict
        Number of samples stored in the _MultiDimensionalSamplesDict for each
        analysis
    total_number_of_samples: int
        Total number of samples stored across the multiple analyses
    minimum_number_of_samples: int
        The number of samples in the smallest analysis

    Methods
    -------
    samples:
        Return a list of samples stored in the _MultiDimensionalSamplesDict
        object for a given parameter
    """
    def __init__(
        self, *args, label_prefix="dataset", transpose=False, labels=None
    ):
        if labels is not None and len(np.unique(labels)) != len(labels):
            raise ValueError(
                "Please provide a unique set of labels for each analysis"
            )
        invalid_label_number_error = "Please provide a label for each analysis"
        self.labels = labels
        self.name = _MultiDimensionalSamplesDict
        self.transpose = transpose
        if len(args) == 1 and isinstance(args[0], dict):
            if transpose:
                # input is keyed by parameter, then analysis label
                parameters = list(args[0].keys())
                _labels = list(args[0][parameters[0]].keys())
                outer_iterator, inner_iterator = parameters, _labels
            else:
                # input is keyed by analysis label, then parameter
                _labels = list(args[0].keys())
                parameters = {
                    label: list(args[0][label].keys()) for label in _labels
                }
                outer_iterator, inner_iterator = _labels, parameters
            if labels is None:
                self.labels = _labels
            elif len(labels) != len(_labels):
                # BUGFIX: previously a label/dataset count mismatch surfaced
                # as an obscure KeyError further down; fail early instead
                raise ValueError(invalid_label_number_error)
            for num, dataset in enumerate(outer_iterator):
                if isinstance(inner_iterator, dict):
                    _params = inner_iterator[dataset]
                else:
                    _params = inner_iterator
                try:
                    samples = np.array(
                        [args[0][dataset][param] for param in _params]
                    )
                except ValueError:  # numpy deprecation error (ragged input)
                    samples = np.array(
                        [args[0][dataset][param] for param in _params],
                        dtype=object
                    )
                if transpose:
                    desc = parameters[num]
                    self[desc] = SamplesDict(
                        self.labels, samples, logger_warn="debug",
                        autoscale=False
                    )
                else:
                    if self.labels is not None:
                        desc = self.labels[num]
                    else:
                        desc = "{}_{}".format(label_prefix, num)
                    # BUGFIX: index by the input dict key ('dataset') rather
                    # than self.labels[num]; when custom labels are supplied
                    # they need not match the keys of the input dictionary
                    self[desc] = SamplesDict(parameters[dataset], samples)
        else:
            parameters, samples = args
            if labels is not None and len(labels) != len(samples):
                raise ValueError(invalid_label_number_error)
            for num, dataset in enumerate(samples):
                if labels is not None:
                    desc = labels[num]
                else:
                    desc = "{}_{}".format(label_prefix, num)
                self[desc] = SamplesDict(parameters, dataset)
            if self.labels is None:
                self.labels = [
                    "{}_{}".format(label_prefix, num) for num, _ in
                    enumerate(samples)
                ]
        self.parameters = parameters
        self._update_latex_labels()

    def _update_latex_labels(self):
        """Update the stored latex labels
        """
        _parameters = [
            list(value.keys()) for value in self.values()
        ]
        _parameters = [item for sublist in _parameters for item in sublist]
        # fall back to the raw parameter name when no latex label is known
        self._latex_labels = {
            param: latex_labels[param] if param in latex_labels.keys() else
            param for param in self.total_list_of_parameters + _parameters
        }

    def __setitem__(self, key, value):
        """Store 'value' under 'key', coercing it to a SamplesDict and
        keeping the label/parameter bookkeeping in sync
        """
        _value = value
        if not isinstance(value, SamplesDict):
            _value = SamplesDict(value)
        super(_MultiDimensionalSamplesDict, self).__setitem__(key, _value)
        try:
            if key not in self.labels:
                self.parameters[key] = list(value.keys())
                self.labels.append(key)
                # BUGFIX: previously `self.latex_labels = self._latex_labels()`
                # always raised TypeError (`_latex_labels` is a dict, not a
                # callable) which was silently swallowed below, so the latex
                # labels were never refreshed after adding a new analysis.
                # Also removed a dead `samples = np.array(...)` computation
                # whose result was never used.
                self._update_latex_labels()
        except (AttributeError, TypeError):
            # the bookkeeping attributes do not exist yet when items are
            # assigned during __init__
            pass

    @property
    def T(self):
        # flip between label-keyed and parameter-keyed representations
        _transpose = not self.transpose
        if not self.transpose:
            # transposing requires every analysis to share the same parameters
            _params = sorted([param for param in self[self.labels[0]].keys()])
            if not all(sorted(self[l].keys()) == _params for l in self.labels):
                raise ValueError(
                    "Unable to transpose as not all samples have the same "
                    "parameters"
                )
            transpose_dict = {
                param: {
                    label: dataset[param] for label, dataset in self.items()
                } for param in self[self.labels[0]].keys()
            }
        else:
            transpose_dict = {
                label: {
                    param: self[param][label] for param in self.keys()
                } for label in self.labels
            }
        return self.name(transpose_dict, transpose=_transpose)

    def _combine(
        self, labels=None, use_all=False, weights=None, shuffle=False,
        logger_level="debug"
    ):
        """Combine samples from a select number of analyses into a single
        SamplesDict object.

        Parameters
        ----------
        labels: list, optional
            analyses you wish to combine. Default use all labels stored in the
            dictionary
        use_all: Bool, optional
            if True, use all of the samples (do not weight). Default False
        weights: dict, optional
            dictionary of weights for each of the posteriors. Keys must be the
            labels you wish to combine and values are the weights you wish to
            assign to the posterior
        shuffle: Bool, optional
            shuffle the combined samples
        logger_level: str, optional
            logger level you wish to use. Default debug.
        """
        try:
            _logger = getattr(logger, logger_level)
        except AttributeError:
            raise ValueError(
                "Unknown logger level. Please choose either 'info' or 'debug'"
            )
        if labels is None:
            _provided_labels = False
            labels = self.labels
        else:
            _provided_labels = True
            if not all(label in self.labels for label in labels):
                raise ValueError(
                    "Not all of the provided labels exist in the dictionary. "
                    "The list of available labels are: {}".format(
                        ", ".join(self.labels)
                    )
                )
        _logger("Combining the following analyses: {}".format(labels))
        if use_all and weights is not None:
            raise ValueError(
                "Unable to use all samples and provide weights"
            )
        elif not use_all and weights is None:
            # no weights supplied: weight every analysis equally
            weights = {label: 1. for label in labels}
        elif not use_all and weights is not None:
            if len(weights) < len(labels):
                raise ValueError(
                    "Please provide weights for each set of samples: {}".format(
                        len(labels)
                    )
                )
            if not _provided_labels and not isinstance(weights, dict):
                raise ValueError(
                    "Weights must be provided as a dictionary keyed by the "
                    "analysis label. The available labels are: {}".format(
                        ", ".join(labels)
                    )
                )
            elif not isinstance(weights, dict):
                weights = {
                    label: weight for label, weight in zip(labels, weights)
                }
            if not all(label in labels for label in weights.keys()):
                for label in labels:
                    if label not in weights.keys():
                        weights[label] = 1.
                        logger.warning(
                            "No weight given for '{}'. Assigning a weight of "
                            "1".format(label)
                        )
            # normalise the weights so they sum to one
            sum_weights = np.sum([_weight for _weight in weights.values()])
            weights = {
                key: item / sum_weights for key, item in weights.items()
            }
        if weights is not None:
            _logger(
                "Using the following weights for each file, {}".format(
                    " ".join(
                        ["{}: {}".format(k, v) for k, v in weights.items()]
                    )
                )
            )
        _lengths = np.array(
            [self.number_of_samples[key] for key in labels]
        )
        if use_all:
            draw = _lengths
        else:
            draw = np.zeros(len(labels), dtype=int)
            _weights = np.array([weights[key] for key in labels])
            inds = np.argwhere(_weights > 0.)
            # The next 4 lines are inspired from the 'cbcBayesCombinePosteriors'
            # executable provided by LALSuite. Credit should go to the
            # authors of that code.
            initial = _weights[inds] * float(sum(_lengths[inds]))
            min_index = np.argmin(_lengths[inds] / initial)
            size = _lengths[inds][min_index] / _weights[inds][min_index]
            draw[inds] = np.around(_weights[inds] * size).astype(int)
        _logger(
            "Randomly drawing the following number of samples from each file, "
            "{}".format(
                " ".join(
                    [
                        "{}: {}/{}".format(l, draw[n], _lengths[n]) for n, l in
                        enumerate(labels)
                    ]
                )
            )
        )

        # work on a copy so the stored analyses are not downsampled in place
        if self.transpose:
            _data = self.T
        else:
            _data = copy.deepcopy(self)
        for num, label in enumerate(labels):
            if draw[num] > 0:
                _data[label].downsample(draw[num])
            else:
                _data[label] = {
                    param: np.array([]) for param in _data[label].keys()
                }
        try:
            # keep only the parameters common to all requested analyses
            intersection = set.intersection(
                *[
                    set(_params) for _key, _params in _data.parameters.items() if
                    _key in labels
                ]
            )
        except AttributeError:
            # parameters is a flat list shared by every analysis
            intersection = _data.parameters
        logger.debug(
            "Only including the parameters: {} as they are common to all "
            "analyses".format(", ".join(list(intersection)))
        )
        data = {
            param: np.concatenate([_data[key][param] for key in labels]) for
            param in intersection
        }
        if shuffle:
            # random permutation applied consistently across all parameters
            inds = np.random.choice(
                np.sum(draw), size=np.sum(draw), replace=False
            )
            data = {
                param: value[inds] for param, value in data.items()
            }
        return SamplesDict(data, logger_warn="debug")

    @property
    def nsamples(self):
        # number of analyses stored; in the transposed layout every entry
        # holds one dict per analysis, so inspect the first parameter
        if self.transpose:
            parameters = list(self.keys())
            return len(self[parameters[0]])
        return len(self)

    @property
    def number_of_samples(self):
        # map each analysis label to the number of posterior samples it holds
        if self.transpose:
            return {
                label: len(self[iterator][label]) for iterator, label in zip(
                    self.keys(), self.labels
                )
            }
        return {
            label: self[iterator].number_of_samples for iterator, label in zip(
                self.keys(), self.labels
            )
        }

    @property
    def total_number_of_samples(self):
        return np.sum([length for length in self.number_of_samples.values()])

    @property
    def minimum_number_of_samples(self):
        return np.min([length for length in self.number_of_samples.values()])

    @property
    def total_list_of_parameters(self):
        """Return the unique set of parameters stored across all analyses"""
        if isinstance(self.parameters, dict):
            _parameters = [item for item in self.parameters.values()]
            _flat_parameters = [
                item for sublist in _parameters for item in sublist
            ]
        elif isinstance(self.parameters, list):
            if np.array(self.parameters).ndim > 1:
                _flat_parameters = [
                    item for sublist in self.parameters for item in sublist
                ]
            else:
                _flat_parameters = self.parameters
        else:
            # BUGFIX: previously an unexpected `self.parameters` type raised
            # UnboundLocalError at the return statement below
            _flat_parameters = []
        return list(set(_flat_parameters))

    def samples(self, parameter):
        """Return the samples for a given parameter from every analysis

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return samples for
        """
        if self.transpose:
            samples = [self[parameter][label] for label in self.labels]
        else:
            samples = [self[label][parameter] for label in self.labels]
        return samples

1291 

1292 

class MCMCSamplesDict(_MultiDimensionalSamplesDict):
    """Class to store the mcmc chains from a single run

    Parameters
    ----------
    parameters: list
        list of parameters
    samples: nd list
        list of samples for each parameter for each chain
    transpose: Bool, optional
        True if the input is a transposed dictionary

    Attributes
    ----------
    T: pesummary.utils.samples_dict.MCMCSamplesDict
        Transposed MCMCSamplesDict object keyed by parameters rather than
        chain
    average: pesummary.utils.samples_dict.SamplesDict
        The mean of each sample across multiple chains. If the chains are of
        different lengths, all chains are resized to the minimum number of
        samples
    combine: pesummary.utils.samples_dict.SamplesDict
        Combine all samples from all chains into a single SamplesDict object
    nchains: int
        Total number of chains stored in the MCMCSamplesDict object
    number_of_samples: dict
        Number of samples stored in the MCMCSamplesDict for each chain
    total_number_of_samples: int
        Total number of samples stored across the multiple chains
    minimum_number_of_samples: int
        The number of samples in the smallest chain

    Methods
    -------
    discard_samples:
        Discard the first N samples for each chain
    burnin:
        Remove the first N samples as burnin. For different algorithms
        see pesummary.core.file.mcmc.algorithms
    gelman_rubin: float
        Return the Gelman-Rubin statistic between the chains for a given
        parameter. See pesummary.utils.utils.gelman_rubin
    samples:
        Return a list of samples stored in the MCMCSamplesDict object for a
        given parameter

    Examples
    --------
    Initializing the MCMCSamplesDict class

    >>> from pesummary.utils.samples_dict import MCMCSamplesDict
    >>> data = {
    ...     "chain_0": {
    ...         "a": [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         "b": [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     },
    ...     "chain_1": {
    ...         "a": [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         "b": [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     }
    ... }
    >>> dataset = MCMCSamplesDict(data)
    >>> parameters = ["a", "b"]
    >>> samples = [
    ...     [
    ...         [1, 1.2, 1.7, 1.1, 1.4, 0.8, 1.6],
    ...         [10.2, 11.3, 11.6, 9.5, 8.6, 10.8, 10.9]
    ...     ], [
    ...         [0.8, 0.5, 1.7, 1.4, 1.2, 1.7, 0.9],
    ...         [10, 10.5, 10.4, 9.6, 8.6, 11.6, 16.2]
    ...     ]
    ... ]
    >>> dataset = MCMCSamplesDict(parameters, samples)
    """
    def __init__(self, *args, transpose=False):
        """Initialise the dictionary of mcmc chains

        Raises
        ------
        ValueError
            if only a single chain is provided
        """
        # message raised below once the parent has worked out the labels
        single_chain_error = (
            "This class requires more than one mcmc chain to be passed. "
            "As only one dataset is available, please use the SamplesDict "
            "class."
        )
        super(MCMCSamplesDict, self).__init__(
            *args, transpose=transpose, label_prefix="chain"
        )
        self.name = MCMCSamplesDict
        if len(self.labels) == 1:
            raise ValueError(single_chain_error)
        # chains are simply the labels assigned by the parent class
        self.chains = self.labels
        self.nchains = self.nsamples

    @property
    def average(self):
        # Element-wise mean of the samples across chains; every chain is
        # truncated to the shortest chain so the mean is well defined.
        # NOTE(review): this iterates `self.parameters`, which the parent
        # class sets to a dict keyed by chain label when initialised from a
        # dictionary (transpose=False) -- confirm `parameters` is a flat
        # list of parameter names whenever this property is used
        if self.transpose:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[param][key][:self.minimum_number_of_samples] for
                        key in self[param].keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        else:
            data = SamplesDict({
                param: np.mean(
                    [
                        self[key][param][:self.minimum_number_of_samples] for
                        key in self.keys()
                    ], axis=0
                ) for param in self.parameters
            }, logger_warn="debug")
        return data

    @property
    def key_data(self):
        """Return the key data (summary statistics) for each parameter of
        the combined chains
        """
        data = {}
        for param, value in self.combine.items():
            data[param] = value.key_data
        return data

    @property
    def combine(self):
        # concatenate every chain into a single SamplesDict without weighting
        return self._combine(use_all=True, weights=None)

    def discard_samples(self, number):
        """Remove the first n samples

        Parameters
        ----------
        number: int/dict
            Number of samples that you wish to remove across all chains or a
            dictionary containing the number of samples to remove per chain
        """
        if isinstance(number, int):
            # discard the same number of samples from every chain
            number = {chain: number for chain in self.keys()}
        for chain in self.keys():
            self[chain].discard_samples(number[chain])
        return self

    def burnin(self, *args, algorithm="burnin_by_step_number", **kwargs):
        """Remove the first N samples as burnin

        Parameters
        ----------
        algorithm: str, optional
            The algorithm you wish to use to remove samples as burnin. Default
            is 'burnin_by_step_number'. See
            `pesummary.core.file.mcmc.algorithms` for list of available
            algorithms
        """
        from pesummary.core.file import mcmc

        if algorithm not in mcmc.algorithms:
            raise ValueError(
                "{} is not a valid algorithm for removing samples as "
                "burnin".format(algorithm)
            )
        # the chosen algorithm receives this dictionary as its first argument
        arguments = [self] + [i for i in args]
        return getattr(mcmc, algorithm)(*arguments, **kwargs)

    def gelman_rubin(self, parameter, decimal=5):
        """Return the gelman rubin statistic between chains for a given
        parameter

        Parameters
        ----------
        parameter: str
            name of the parameter you wish to return the gelman rubin statistic
            for
        decimal: int
            number of decimal places to keep when rounding
        """
        from pesummary.utils.utils import gelman_rubin as _gelman_rubin

        return _gelman_rubin(self.samples(parameter), decimal=decimal)

1466 

1467 

1468class MultiAnalysisSamplesDict(_MultiDimensionalSamplesDict): 

1469 	"""Class to store samples from multiple analyses 

1470 

1471 Parameters 

1472 ---------- 

1473 parameters: list 

1474 list of parameters 

1475 samples: nd list 

1476 list of samples for each parameter for each chain 

1477 labels: list, optional 

1478 the labels to use to distinguish different analyses. 

1479 transpose: Bool, optional 

1480 True if the input is a transposed dictionary 

1481 

1482 Attributes 

1483 ---------- 

1484 T: pesummary.utils.samples_dict.MultiAnalysisSamplesDict 

1485 Transposed MultiAnalysisSamplesDict object keyed by parameters 

1486 rather than label 

1487 nsamples: int 

1488 Total number of analyses stored in the MultiAnalysisSamplesDict 

1489 object 

1490 number_of_samples: dict 

1491 Number of samples stored in the MultiAnalysisSamplesDict for each 

1492 analysis 

1493 total_number_of_samples: int 

1494 Total number of samples stored across the multiple analyses 

1495 minimum_number_of_samples: int 

1496 The number of samples in the smallest analysis 

1497 available_plots: list 

1498 list of plots which the user may user to display the contained posterior 

1499 samples 

1500 

1501 Methods 

1502 ------- 

1503 from_files: 

1504 Initialize the MultiAnalysisSamplesDict class with the contents of 

1505 multiple files 

1506 combine: pesummary.utils.samples_dict.SamplesDict 

1507 Combine samples from a select number of analyses into a single 

1508 SamplesDict object. 

1509 js_divergence: float 

1510 Return the JS divergence between two posterior distributions for a 

1511 given parameter. See pesummary.utils.utils.jensen_shannon_divergence 

1512 ks_statistic: float 

1513 Return the KS statistic between two posterior distributions for a 

1514 given parameter. See pesummary.utils.utils.kolmogorov_smirnov_test 

1515 samples: 

1516 Return a list of samples stored in the MCMCSamplesDict object for a 

1517 given parameter 

1518 write: 

1519 Save the stored posterior samples to file 

1520 """ 

1521 def __init__(self, *args, labels=None, transpose=False): 

1522 if labels is None and not isinstance(args[0], dict): 

1523 raise ValueError( 

1524 "Please provide a unique label for each analysis" 

1525 ) 

1526 super(MultiAnalysisSamplesDict, self).__init__( 

1527 *args, labels=labels, transpose=transpose 

1528 ) 

1529 self.name = MultiAnalysisSamplesDict 

1530 

1531 @classmethod 

1532 def from_files(cls, filenames, **kwargs): 

1533 """Initialize the MultiAnalysisSamplesDict class with the contents of 

1534 multiple result files 

1535 

1536 Parameters 

1537 ---------- 

1538 filenames: dict 

1539 dictionary containing the path to the result file you wish to load 

1540 as the item and a label associated with each result file as the key. 

1541 If you are providing one or more PESummary metafiles, the key 

1542 is ignored and labels stored in the metafile are used. 

1543 **kwargs: dict 

1544 all kwargs are passed to the pesummary.io.read function 

1545 """ 

1546 from pesummary.io import read 

1547 

1548 samples = {} 

1549 for label, filename in filenames.items(): 

1550 _kwargs = kwargs 

1551 if label in kwargs.keys(): 

1552 _kwargs = kwargs[label] 

1553 _file = read(filename, **_kwargs) 

1554 _samples = _file.samples_dict 

1555 if isinstance(_samples, MultiAnalysisSamplesDict): 

1556 _stored_labels = _samples.keys() 

1557 cond1 = any( 

1558 _label in filenames.keys() for _label in _stored_labels if 

1559 _label != label 

1560 ) 

1561 cond2 = any( 

1562 _label in samples.keys() for _label in _stored_labels 

1563 ) 

1564 if cond1 or cond2: 

1565 raise ValueError( 

1566 "The file '{}' contains the labels: {}. The " 

1567 "dictionary already contains the labels: {}. Please " 

1568 "provide unique labels for each dataset".format( 

1569 filename, ", ".join(_stored_labels), 

1570 ", ".join(samples.keys()) 

1571 ) 

1572 ) 

1573 samples.update(_samples) 

1574 else: 

1575 if label in samples.keys(): 

1576 raise ValueError( 

1577 "The label '{}' has alreadt been used. Please select " 

1578 "another label".format(label) 

1579 ) 

1580 samples[label] = _samples 

1581 return cls(samples) 

1582 

1583 @property 

1584 def plotting_map(self): 

1585 return { 

1586 "hist": self._marginalized_posterior, 

1587 "corner": self._corner, 

1588 "triangle": self._triangle, 

1589 "reverse_triangle": self._reverse_triangle, 

1590 "violin": self._violin, 

1591 "2d_kde": self._2d_kde 

1592 } 

1593 

1594 @property 

1595 def available_plots(self): 

1596 return list(self.plotting_map.keys()) 

1597 

1598 @docstring_subfunction([ 

1599 'pesummary.core.plots.plot._1d_comparison_histogram_plot', 

1600 'pesummary.gw.plots.plot._1d_comparison_histogram_plot', 

1601 'pesummary.core.plots.publication.triangle_plot', 

1602 'pesummary.core.plots.publication.reverse_triangle_plot' 

1603 ]) 

1604 def plot( 

1605 self, *args, type="hist", labels="all", colors=None, latex_friendly=True, 

1606 **kwargs 

1607 ): 

1608 """Generate a plot for the posterior samples stored in 

1609 MultiDimensionalSamplesDict 

1610 

1611 Parameters 

1612 ---------- 

1613 *args: tuple 

1614 all arguments are passed to the plotting function 

1615 type: str 

1616 name of the plot you wish to make 

1617 labels: list 

1618 list of analyses that you wish to include in the plot 

1619 colors: list 

1620 list of colors to use for each analysis 

1621 latex_friendly: Bool, optional 

1622 if True, make the labels latex friendly. Default True 

1623 **kwargs: dict 

1624 all additional kwargs are passed to the plotting function 

1625 """ 

1626 if type not in self.plotting_map.keys(): 

1627 raise NotImplementedError( 

1628 "The {} method is not currently implemented. The allowed " 

1629 "plotting methods are {}".format( 

1630 type, ", ".join(self.available_plots) 

1631 ) 

1632 ) 

1633 

1634 self._update_latex_labels() 

1635 if labels == "all": 

1636 labels = self.labels 

1637 elif isinstance(labels, list): 

1638 for label in labels: 

1639 if label not in self.labels: 

1640 raise ValueError( 

1641 "'{}' is not a stored analysis. The available analyses " 

1642 "are: '{}'".format(label, ", ".join(self.labels)) 

1643 ) 

1644 else: 

1645 raise ValueError( 

1646 "Please provide a list of analyses that you wish to plot" 

1647 ) 

1648 if colors is None: 

1649 colors = list(conf.colorcycle) 

1650 while len(colors) < len(labels): 

1651 colors += colors 

1652 

1653 kwargs["labels"] = labels 

1654 kwargs["colors"] = colors 

1655 kwargs["latex_friendly"] = latex_friendly 

1656 return self.plotting_map[type](*args, **kwargs) 

1657 

1658 def _marginalized_posterior( 

1659 self, parameter, module="core", labels="all", colors=None, **kwargs 

1660 ): 

1661 """Wrapper for the 

1662 `pesummary.core.plots.plot._1d_comparison_histogram_plot` or 

1663 `pesummary.gw.plots.plot._comparison_1d_histogram_plot` 

1664 

1665 Parameters 

1666 ---------- 

1667 parameter: str 

1668 name of the parameter you wish to plot 

1669 module: str, optional 

1670 module you wish to use for the plotting 

1671 labels: list 

1672 list of analyses that you wish to include in the plot 

1673 colors: list 

1674 list of colors to use for each analysis 

1675 **kwargs: dict 

1676 all additional kwargs are passed to the 

1677 `_1d_comparison_histogram_plot` function 

1678 """ 

1679 module = importlib.import_module( 

1680 "pesummary.{}.plots.plot".format(module) 

1681 ) 

1682 return getattr(module, "_1d_comparison_histogram_plot")( 

1683 parameter, [self[label][parameter] for label in labels], 

1684 colors, self.latex_labels[parameter], labels, **kwargs 

1685 ) 

1686 

1687 def _base_triangle(self, parameters, labels="all"): 

1688 """Check that the parameters are valid for the different triangle 

1689 plots available 

1690 

1691 Parameters 

1692 ---------- 

1693 parameters: list 

1694 list of parameters they wish to study 

1695 labels: list 

1696 list of analyses that you wish to include in the plot 

1697 """ 

1698 samples = [self[label] for label in labels] 

1699 if len(parameters) > 2: 

1700 raise ValueError("Function is only 2d") 

1701 condition = set( 

1702 label for num, label in enumerate(labels) for param in parameters if 

1703 param not in samples[num].keys() 

1704 ) 

1705 if len(condition): 

1706 raise ValueError( 

1707 "{} and {} are not available for the following " 

1708 " analyses: {}".format( 

1709 parameters[0], parameters[1], ", ".join(condition) 

1710 ) 

1711 ) 

1712 return samples 

1713 

1714 def _triangle(self, parameters, labels="all", module="core", **kwargs): 

1715 """Wrapper for the `pesummary.core.plots.publication.triangle_plot` 

1716 function 

1717 

1718 Parameters 

1719 ---------- 

1720 parameters: list 

1721 list of parameters they wish to study 

1722 labels: list 

1723 list of analyses that you wish to include in the plot 

1724 **kwargs: dict 

1725 all additional kwargs are passed to the `triangle_plot` function 

1726 """ 

1727 _module = importlib.import_module( 

1728 "pesummary.{}.plots.publication".format(module) 

1729 ) 

1730 samples = self._base_triangle(parameters, labels=labels) 

1731 if module == "gw": 

1732 kwargs["parameters"] = parameters 

1733 return getattr(_module, "triangle_plot")( 

1734 [_samples[parameters[0]] for _samples in samples], 

1735 [_samples[parameters[1]] for _samples in samples], 

1736 xlabel=self.latex_labels[parameters[0]], 

1737 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs 

1738 ) 

1739 

1740 def _reverse_triangle(self, parameters, labels="all", module="core", **kwargs): 

1741 """Wrapper for the `pesummary.core.plots.publication.reverse_triangle_plot` 

1742 function 

1743 

1744 Parameters 

1745 ---------- 

1746 parameters: list 

1747 list of parameters they wish to study 

1748 labels: list 

1749 list of analyses that you wish to include in the plot 

1750 **kwargs: dict 

1751 all additional kwargs are passed to the `triangle_plot` function 

1752 """ 

1753 _module = importlib.import_module( 

1754 "pesummary.{}.plots.publication".format(module) 

1755 ) 

1756 samples = self._base_triangle(parameters, labels=labels) 

1757 if module == "gw": 

1758 kwargs["parameters"] = parameters 

1759 return getattr(_module, "reverse_triangle_plot")( 

1760 [_samples[parameters[0]] for _samples in samples], 

1761 [_samples[parameters[1]] for _samples in samples], 

1762 xlabel=self.latex_labels[parameters[0]], 

1763 ylabel=self.latex_labels[parameters[1]], labels=labels, **kwargs 

1764 ) 

1765 

    def _violin(
        self, parameter, labels="all", priors=None, latex_labels=GWlatex_labels,
        **kwargs
    ):
        """Wrapper for the `pesummary.gw.plots.publication.violin_plots`
        function

        Parameters
        ----------
        parameter: str, optional
            name of the parameter you wish to generate a violin plot for
        labels: list
            list of analyses that you wish to include in the plot.
            NOTE(review): if the default "all" reaches this method directly
            (i.e. not via `self.plot`, which replaces it with `self.labels`),
            the comprehension below iterates the characters "a", "l", "l" --
            confirm this method is only called with an explicit list
        priors: MultiAnalysisSamplesDict, optional
            prior samples for each analysis. If provided, the right hand side
            of each violin will show the prior
        latex_labels: dict, optional
            dictionary containing the latex label associated with parameter
        **kwargs: dict
            all additional kwargs are passed to the `violin_plots` function

        Raises
        ------
        ValueError
            if no analysis contains the parameter, or if priors are supplied
            but incomplete
        """
        from pesummary.gw.plots.publication import violin_plots

        # keep only the analyses that actually contain this parameter
        _labels = [label for label in labels if parameter in self[label].keys()]
        if not len(_labels):
            raise ValueError(
                "{} is not in any of the posterior samples tables. Please "
                "choose another parameter to plot".format(parameter)
            )
        elif len(_labels) != len(labels):
            # some analyses are silently dropped; warn which ones
            no = list(set(labels) - set(_labels))
            logger.warning(
                "Unable to generate a violin plot for {} because {} is not "
                "in their posterior samples table".format(
                    " or ".join(no), parameter
                )
            )
        samples = [self[label][parameter] for label in _labels]
        # priors, when given, must cover every plotted analysis and contain
        # the plotted parameter
        if priors is not None and not all(
            label in priors.keys() for label in _labels
        ):
            raise ValueError("Please provide prior samples for all labels")
        elif priors is not None and not all(
            parameter in priors[label].keys() for label in _labels
        ):
            raise ValueError(
                "Please provide prior samples for {} for all labels".format(
                    parameter
                )
            )
        elif priors is not None:
            from pesummary.core.plots.seaborn.violin import split_dataframe

            # build split violins: posterior on the left, prior on the right
            priors = [priors[label][parameter] for label in _labels]
            samples = split_dataframe(samples, priors, _labels)
            palette = kwargs.get("palette", None)
            left, right = "color: white", "pastel"
            if palette is not None and not isinstance(palette, dict):
                # a single palette applies to the posterior (right) side only
                right = palette
            elif palette is not None and all(
                side in palette.keys() for side in ["left", "right"]
            ):
                left, right = palette["left"], palette["right"]
            kwargs.update(
                {
                    "split": True, "x": "label", "y": "data", "hue": "side",
                    "palette": {"right": right, "left": left}
                }
            )
        return violin_plots(
            parameter, samples, _labels, latex_labels, **kwargs
        )

1838 

1839 def _corner(self, module="core", labels="all", parameters=None, **kwargs): 

1840 """Wrapper for the `pesummary.core.plots.plot._make_comparison_corner_plot` 

1841 or `pesummary.gw.plots.plot._make_comparison_corner_plot` function 

1842 

1843 Parameters 

1844 ---------- 

1845 module: str, optional 

1846 module you wish to use for the plotting 

1847 labels: list 

1848 list of analyses that you wish to include in the plot 

1849 **kwargs: dict 

1850 all additional kwargs are passed to the `_make_comparison_corner_plot` 

1851 function 

1852 """ 

1853 module = importlib.import_module( 

1854 "pesummary.{}.plots.plot".format(module) 

1855 ) 

1856 _samples = {label: self[label] for label in labels} 

1857 _parameters = None 

1858 if parameters is not None: 

1859 _parameters = [ 

1860 param for param in parameters if all( 

1861 param in posterior for posterior in _samples.values() 

1862 ) 

1863 ] 

1864 if not len(_parameters): 

1865 raise ValueError( 

1866 "None of the chosen parameters are in all of the posterior " 

1867 "samples tables. Please choose other parameters to plot" 

1868 ) 

1869 return getattr(module, "_make_comparison_corner_plot")( 

1870 _samples, self.latex_labels, corner_parameters=_parameters, **kwargs 

1871 ) 

1872 

1873 def _2d_kde( 

1874 self, parameters, module="core", labels="all", plot_density=None, 

1875 **kwargs 

1876 ): 

1877 """Wrapper for the 

1878 `pesummary.gw.plots.publication.comparison_twod_contour_plot` or 

1879 `pesummary.core.plots.publication.comparison_twod_contour_plot` function 

1880 

1881 Parameters 

1882 ---------- 

1883 parameters: list 

1884 list of length 2 giving the parameters you wish to plot 

1885 module: str, optional 

1886 module you wish to use for the plotting 

1887 labels: list 

1888 list of analyses that you wish to include in the plot 

1889 **kwargs: dict, optional 

1890 all additional kwargs are passed to the 

1891 `comparison_twod_contour_plot` function 

1892 """ 

1893 _module = importlib.import_module( 

1894 "pesummary.{}.plots.publication".format(module) 

1895 ) 

1896 samples = self._base_triangle(parameters, labels=labels) 

1897 if plot_density is not None: 

1898 if isinstance(plot_density, str): 

1899 plot_density = [plot_density] 

1900 elif isinstance(plot_density, bool) and plot_density: 

1901 plot_density = labels 

1902 for i in plot_density: 

1903 if i not in labels: 

1904 raise ValueError( 

1905 "Unable to plot the density for '{}'. Please choose " 

1906 "from: {}".format(plot_density, ", ".join(labels)) 

1907 ) 

1908 if module == "gw": 

1909 return getattr(_module, "twod_contour_plots")( 

1910 parameters, [ 

1911 [self[label][param] for param in parameters] for label in 

1912 labels 

1913 ], labels, { 

1914 parameters[0]: self.latex_labels[parameters[0]], 

1915 parameters[1]: self.latex_labels[parameters[1]] 

1916 }, plot_density=plot_density, **kwargs 

1917 ) 

1918 return getattr(_module, "comparison_twod_contour_plot")( 

1919 [_samples[parameters[0]] for _samples in samples], 

1920 [_samples[parameters[1]] for _samples in samples], 

1921 xlabel=self.latex_labels[parameters[0]], 

1922 ylabel=self.latex_labels[parameters[1]], labels=labels, 

1923 plot_density=plot_density, **kwargs 

1924 ) 

1925 

1926 def combine(self, **kwargs): 

1927 """Combine samples from a select number of analyses into a single 

1928 SamplesDict object. 

1929 

1930 Parameters 

1931 ---------- 

1932 labels: list, optional 

1933 analyses you wish to combine. Default use all labels stored in the 

1934 dictionary 

1935 use_all: Bool, optional 

1936 if True, use all of the samples (do not weight). Default False 

1937 weights: dict, optional 

1938 dictionary of weights for each of the posteriors. Keys must be the 

1939 labels you wish to combine and values are the weights you wish to 

1940 assign to the posterior 

1941 logger_level: str, optional 

1942 logger level you wish to use. Default debug. 

1943 """ 

1944 return self._combine(**kwargs) 

1945 

1946 def write(self, labels=None, **kwargs): 

1947 """Save the stored posterior samples to file 

1948 

1949 Parameters 

1950 ---------- 

1951 labels: list, optional 

1952 list of analyses that you wish to save to file. Default save all 

1953 analyses to file 

1954 **kwargs: dict, optional 

1955 all additional kwargs passed to the pesummary.io.write function 

1956 """ 

1957 if labels is None: 

1958 labels = self.labels 

1959 elif not all(label in self.labels for label in labels): 

1960 for label in labels: 

1961 if label not in self.labels: 

1962 raise ValueError( 

1963 "Unable to find analysis: '{}'. The list of " 

1964 "available analyses are: {}".format( 

1965 label, ", ".join(self.labels) 

1966 ) 

1967 ) 

1968 for label in labels: 

1969 self[label].write(**kwargs) 

1970 

1971 def js_divergence(self, parameter, decimal=5): 

1972 """Return the JS divergence between the posterior samples for 

1973 a given parameter 

1974 

1975 Parameters 

1976 ---------- 

1977 parameter: str 

1978 name of the parameter you wish to return the gelman rubin statistic 

1979 for 

1980 decimal: int 

1981 number of decimal places to keep when rounding 

1982 """ 

1983 from pesummary.utils.utils import jensen_shannon_divergence 

1984 

1985 return jensen_shannon_divergence( 

1986 self.samples(parameter), decimal=decimal 

1987 ) 

1988 

1989 def ks_statistic(self, parameter, decimal=5): 

1990 """Return the KS statistic between the posterior samples for 

1991 a given parameter 

1992 

1993 Parameters 

1994 ---------- 

1995 parameter: str 

1996 name of the parameter you wish to return the gelman rubin statistic 

1997 for 

1998 decimal: int 

1999 number of decimal places to keep when rounding 

2000 """ 

2001 from pesummary.utils.utils import kolmogorov_smirnov_test 

2002 

2003 return kolmogorov_smirnov_test( 

2004 self.samples(parameter), decimal=decimal 

2005 )