Coverage for pesummary/gw/waveform.py: 60.9%

169 statements  

coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4import lalsimulation as lalsim 

5from lalsimulation import ( 

6 SimInspiralGetSpinFreqFromApproximant, SIM_INSPIRAL_SPINS_CASEBYCASE, 

7 SIM_INSPIRAL_SPINS_FLOW 

8) 

9from pesummary.utils.utils import iterator, logger 

10from pesummary.utils.exceptions import EvolveSpinError 

11 

# Module authorship metadata.
__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]

13 

14 

def _get_spin_freq_from_approximant(approximant):
    """Return the lalsimulation enum describing at which frequency the spins
    of a given approximant are defined.

    Approximants that are attributes of ``lalsimulation`` are passed straight
    to ``SimInspiralGetSpinFreqFromApproximant``. Any other name is looked up
    in gwsignal, and the model's metadata is translated into the equivalent
    lalsimulation enum.

    Parameters
    ----------
    approximant: str
        Name of the approximant you wish to check

    Returns
    -------
    int
        For LAL approximants, the value returned by
        ``SimInspiralGetSpinFreqFromApproximant``. For gwsignal models it is
        ``SIM_INSPIRAL_SPINS_NONPRECESSING`` for ``"aligned_spin"`` models.
        For ``"precessing_spin"`` models it is ``SIM_INSPIRAL_SPINS_F_REF``
        when the metadata flag ``f_ref_spin`` is truthy, and
        ``SIM_INSPIRAL_SPINS_FLOW`` otherwise.

    Raises
    ------
    EvolveSpinError
        if a gwsignal model's metadata ``type`` is neither
        ``"aligned_spin"`` nor ``"precessing_spin"``. Errors raised by the
        gwsignal lookup itself (e.g. for an unknown name) propagate
        unchanged.
    """
    try:
        # LAL approximants are attributes of lalsimulation: ask LAL directly
        return SimInspiralGetSpinFreqFromApproximant(
            getattr(lalsim, approximant)
        )
    except AttributeError:
        # Not a LAL approximant: fall back to the gwsignal model metadata
        from lalsimulation import (
            SIM_INSPIRAL_SPINS_NONPRECESSING, SIM_INSPIRAL_SPINS_F_REF
        )
        from lalsimulation.gwsignal.models import gwsignal_get_waveform_generator
        metadata = gwsignal_get_waveform_generator(approximant).metadata
        spin_type = metadata["type"]
        if spin_type == "aligned_spin":
            return SIM_INSPIRAL_SPINS_NONPRECESSING
        if spin_type == "precessing_spin":
            return (
                SIM_INSPIRAL_SPINS_F_REF if metadata["f_ref_spin"]
                else SIM_INSPIRAL_SPINS_FLOW
            )
        raise EvolveSpinError(
            "Unable to evolve spins as '{}' does not have a set frequency "
            "at which the spins are defined".format(approximant)
        )

46 

47 

def _get_start_freq_from_approximant(approximant, f_low, f_ref):
    """Determine the starting frequency to use when evolving the spins for
    a given approximant string.

    Parameters
    ----------
    approximant: str
        Name of the approximant you wish to check
    f_low: float
        Low frequency used when generating the posterior samples
    f_ref: float
        Reference frequency used when generating the posterior samples

    Returns
    -------
    float
        ``f_low`` if the approximant defines its spins at the starting
        frequency (``SIM_INSPIRAL_SPINS_FLOW``), otherwise ``f_ref``.

    Raises
    ------
    EvolveSpinError
        if the approximant is unknown to both lalsimulation and gwsignal, or
        if it has no fixed frequency at which the spins are defined
        (``SIM_INSPIRAL_SPINS_CASEBYCASE``).
    """
    try:
        spinfreq_enum = _get_spin_freq_from_approximant(approximant)
    except ValueError as exc:  # raised when approximant is not in gwsignal
        raise EvolveSpinError(
            "Unable to evolve spins as '{}' is unknown to lalsimulation "
            "and gwsignal".format(approximant)
        ) from exc
    if spinfreq_enum == SIM_INSPIRAL_SPINS_CASEBYCASE:
        _msg = (
            "Unable to evolve spins as '{}' does not have a set frequency "
            "at which the spins are defined".format(approximant)
        )
        logger.warning(_msg)
        raise EvolveSpinError(_msg)
    # Spins are defined at f_low only for SPINS_FLOW; any other remaining
    # enum means they are defined at the reference frequency.
    return float(f_low if spinfreq_enum == SIM_INSPIRAL_SPINS_FLOW else f_ref)

78 

79 

def _check_approximant_from_string(approximant):
    """Report whether an approximant name is known to lalsimulation or
    gwsignal.

    lalsimulation is checked first (as a module attribute). Only when that
    fails is gwsignal asked to build a waveform generator for the name.

    Parameters
    ----------
    approximant: str
        approximant you wish to check

    Returns
    -------
    bool
        False if gwsignal raises ``ValueError`` or ``NameError`` for the
        name, True otherwise.
    """
    if not hasattr(lalsim, approximant):
        from lalsimulation.gwsignal.models import gwsignal_get_waveform_generator
        try:
            gwsignal_get_waveform_generator(approximant)
        except (ValueError, NameError):
            return False
    return True

98 

99 

def _lal_approximant_from_string(approximant):
    """Translate an approximant name into its lalsimulation approximant
    number.

    Parameters
    ----------
    approximant: str
        approximant you wish to convert

    Returns
    -------
    the value of ``lalsimulation.GetApproximantFromString(approximant)``
    """
    lal_approximant = lalsim.GetApproximantFromString(approximant)
    return lal_approximant

109 

110 

def _insert_mode_array(modes, LAL_parameters=None):
    """Store a mode array containing ``modes`` in a LAL dictionary

    Parameters
    ----------
    modes: 2d list
        2d list of modes you wish to add to a LAL dictionary. Must be of the
        form [[l1, m1], [l2, m2]]
    LAL_parameters: LALDict, optional
        An existing LAL dictionary to add mode array to. If not provided, a new
        LAL dictionary is created. Default None.

    Returns
    -------
    LALDict
        ``LAL_parameters`` (modified in place) or the newly created dictionary
    """
    if LAL_parameters is None:
        import lal
        LAL_parameters = lal.CreateDict()
    mode_array = lalsim.SimInspiralCreateModeArray()
    # Activate every requested (ell, m) mode before inserting the array
    for ell, emm in modes:
        lalsim.SimInspiralModeArrayActivateMode(mode_array, ell, emm)
    lalsim.SimInspiralWaveformParamsInsertModeArray(LAL_parameters, mode_array)
    return LAL_parameters

131 

132 

133def _waveform_args(samples, f_ref=20., ind=0, longAscNodes=0., eccentricity=0.): 

134 """Arguments to be passed to waveform generation 

135 

136 Parameters 

137 ---------- 

138 f_ref: float, optional 

139 reference frequency to use when converting spherical spins to 

140 cartesian spins 

141 ind: int, optional 

142 index for the sample you wish to plot 

143 longAscNodes: float, optional 

144 longitude of ascending nodes, degenerate with the polarization 

145 angle. Default 0. 

146 eccentricity: float, optional 

147 eccentricity at reference frequency. Default 0. 

148 """ 

149 from lal import MSUN_SI, PC_SI 

150 

151 key = list(samples.keys())[0] 

152 if isinstance(samples[key], (list, np.ndarray)): 

153 _samples = {key: value[ind] for key, value in samples.items()} 

154 else: 

155 _samples = samples.copy() 

156 required = [ 

157 "mass_1", "mass_2", "luminosity_distance" 

158 ] 

159 if not all(param in _samples.keys() for param in required): 

160 raise ValueError( 

161 "Unable to generate a waveform. Please add samples for " 

162 + ", ".join(required) 

163 ) 

164 waveform_args = [ 

165 _samples["mass_1"] * MSUN_SI, _samples["mass_2"] * MSUN_SI 

166 ] 

167 spin_angles = [ 

168 "theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2", 

169 "phase" 

170 ] 

171 spin_angles_condition = all( 

172 spin in _samples.keys() for spin in spin_angles 

173 ) 

174 cartesian_spins = [ 

175 "spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z" 

176 ] 

177 cartesian_spins_condition = any( 

178 spin in _samples.keys() for spin in cartesian_spins 

179 ) 

180 if spin_angles_condition and not cartesian_spins_condition: 

181 from pesummary.gw.conversions import component_spins 

182 data = component_spins( 

183 _samples["theta_jn"], _samples["phi_jl"], _samples["tilt_1"], 

184 _samples["tilt_2"], _samples["phi_12"], _samples["a_1"], 

185 _samples["a_2"], _samples["mass_1"], _samples["mass_2"], 

186 f_ref, _samples["phase"] 

187 ) 

188 iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = data.T 

189 spins = [spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z] 

190 else: 

191 iota = _samples["iota"] 

192 spins = [ 

193 _samples[param] if param in _samples.keys() else 0. for param in 

194 ["spin_1x", "spin_1y", "spin_1z", "spin_2x", "spin_2y", "spin_2z"] 

195 ] 

196 _zero_spins = np.isclose(spins, 0.) 

197 if sum(_zero_spins): 

198 spins = np.array(spins) 

199 spins[_zero_spins] = 0. 

200 spins = list(spins) 

201 waveform_args += spins 

202 phase = _samples["phase"] if "phase" in _samples.keys() else 0. 

203 waveform_args += [ 

204 _samples["luminosity_distance"] * PC_SI * 10**6, iota, phase 

205 ] 

206 waveform_args += [longAscNodes, eccentricity, 0.] 

207 return waveform_args, _samples 

208 

209 

def antenna_response(samples, ifo):
    """Return the antenna response of a detector for the given samples

    Parameters
    ----------
    samples: dict
        dictionary containing ``ra``, ``dec``, ``psi`` and ``geocent_time``
        (single values when called from ``_project_waveform``)
    ifo: str
        name of the detector

    Returns
    -------
    antenna:
        whatever ``pesummary.gw.plots.plot.__antenna_response`` returns for
        these arguments. ``_project_waveform`` in this module uses element 0
        as the plus-polarization weight and element 1 as the
        cross-polarization weight.
    """
    import importlib

    # The helper is resolved by name at call time rather than imported at
    # module level; presumably this avoids loading the plotting module (or
    # an import cycle) when this module is imported -- NOTE(review): confirm.
    mod = importlib.import_module("pesummary.gw.plots.plot")
    func = getattr(mod, "__antenna_response")
    antenna = func(
        ifo, samples["ra"], samples["dec"], samples["psi"],
        samples["geocent_time"]
    )
    return antenna

222 

223 

def _project_waveform(ifo, hp, hc, ra, dec, psi, time):
    """Combine the two polarizations into the strain seen by one detector

    Parameters
    ----------
    ifo: str
        name of the detector you wish to project the waveform onto
    hp: np.ndarray
        plus gravitational wave polarization
    hc: np.ndarray
        cross gravitational wave polarization
    ra: float
        right ascension to be passed to antenna response function
    dec: float
        declination to be passed to antenna response function
    psi: float
        polarization to be passed to antenna response function
    time: float
        time to be passed to antenna response function

    Returns
    -------
    ``hp * F[0] + hc * F[1]``, where ``F`` is the result of
    ``antenna_response`` for this detector and sky position
    """
    response = antenna_response(
        {"ra": ra, "dec": dec, "psi": psi, "geocent_time": time}, ifo
    )
    return hp * response[0] + hc * response[1]

250 

251 

def fd_waveform(
    samples, approximant, delta_f, f_low, f_high, f_ref=20., project=None,
    ind=0, longAscNodes=0., eccentricity=0., LAL_parameters=None,
    mode_array=None, pycbc=False, flen=None
):
    """Generate a gravitational wave in the frequency domain

    Parameters
    ----------
    samples: dict
        posterior samples (a dictionary of arrays, from which the ``ind``-th
        sample is taken, or a dictionary of single values). ``ra``, ``dec``,
        ``psi`` and ``geocent_time`` are also required when ``project`` is
        given.
    approximant: str
        name of the approximant to use when generating the waveform
    delta_f: float
        spacing between frequency samples
    f_low: float
        frequency to start evaluating the waveform
    f_high: float
        frequency to stop evaluating the waveform
    f_ref: float, optional
        reference frequency
    project: str, optional
        name of the detector to project the waveform onto. If None,
        the plus and cross polarizations are returned. Default None
    ind: int, optional
        index for the sample you wish to plot
    longAscNodes: float, optional
        longitude of ascending nodes, degenerate with the polarization
        angle. Default 0.
    eccentricity: float, optional
        eccentricity at reference frequency. Default 0.
    LAL_parameters: LALDict, optional
        LAL dictionary containing accessory parameters. Default None. When
        ``mode_array`` is given the modes are inserted into this dictionary
        (a new one is created if it is None).
    mode_array: 2d list
        2d list of modes you wish to include in waveform. Must be of the form
        [[l1, m1], [l2, m2]]
    pycbc: Bool, optional
        return a the waveform as a pycbc.frequencyseries.FrequencySeries
        object
    flen: int
        Length of the frequency series in samples. Default is None. Only used
        when pycbc=True

    Returns
    -------
    dict or FrequencySeries
        ``{"h_plus": hp, "h_cross": hc}`` when ``project`` is None, otherwise
        the waveform projected onto ``project``
    """
    from gwpy.frequencyseries import FrequencySeries

    lal_args, single_sample = _waveform_args(
        samples, f_ref=f_ref, ind=ind, longAscNodes=longAscNodes,
        eccentricity=eccentricity
    )
    lal_approximant = _lal_approximant_from_string(approximant)
    if mode_array is not None:
        LAL_parameters = _insert_mode_array(
            mode_array, LAL_parameters=LAL_parameters
        )
    lal_hp, lal_hc = lalsim.SimInspiralChooseFDWaveform(
        *lal_args, delta_f, f_low, f_high, f_ref, LAL_parameters,
        lal_approximant
    )
    # Wrap the raw LAL series (plus first, then cross) starting at 0 Hz
    polarizations = [
        FrequencySeries(series.data.data, df=series.deltaF, f0=0.)
        for series in (lal_hp, lal_hc)
    ]
    if pycbc:
        polarizations = [series.to_pycbc() for series in polarizations]
        if flen is not None:
            for series in polarizations:
                series.resize(flen)
    hp, hc = polarizations
    if project is None:
        return {"h_plus": hp, "h_cross": hc}
    return _project_waveform(
        project, hp, hc, single_sample["ra"], single_sample["dec"],
        single_sample["psi"], single_sample["geocent_time"]
    )

321 

322 

def _wrapper_for_td_waveform(args):
    """Unpack a single argument sequence and forward it to td_waveform

    Pool.imap hands each task exactly one object, so the positional
    arguments for td_waveform arrive bundled together.

    Parameters
    ----------
    args: tuple
        All args passed to td_waveform
    """
    positional_args = tuple(args)
    return td_waveform(*positional_args)

332 

333 

def td_waveform(
    samples, approximant, delta_t, f_low, f_ref=20., project=None, ind=0,
    longAscNodes=0., eccentricity=0., LAL_parameters=None, mode_array=None,
    pycbc=False, level=None, multi_process=1
):
    """Generate a gravitational wave in the time domain

    Parameters
    ----------
    samples: dict
        posterior samples: a dictionary of arrays (the ``ind``-th sample is
        used) or a dictionary of single values. When ``level`` is given, the
        length of the first parameter sets the number of waveforms N, and
        every parameter is indexed with 0..N-1. ``ra``, ``dec``, ``psi`` and
        ``geocent_time`` are also required when ``project`` is given.
    approximant: str
        name of the approximant to use when generating the waveform
    delta_t: float
        spacing between time samples
    f_low: float
        frequency to start evaluating the waveform
    f_ref: float, optional
        reference frequency
    project: str, optional
        name of the detector to project the waveform onto. If None,
        the plus and cross polarizations are returned. Default None
    ind: int, optional
        index for the sample you wish to plot
    longAscNodes: float, optional
        longitude of ascending nodes, degenerate with the polarization
        angle. Default 0.
    eccentricity: float, optional
        eccentricity at reference frequency. Default 0.
    LAL_parameters: LALDict, optional
        LAL dictionary containing accessory parameters. Default None
    mode_array: 2d list
        2d list of modes you wish to include in waveform. Must be of the form
        [[l1, m1], [l2, m2]]
    pycbc: Bool, optional
        return a the waveform as a pycbc.timeseries.TimeSeries object
    level: list, optional
        the symmetric confidence interval of the time domain waveform. Level
        must be greater than 0 and less than 1
    multi_process: int, optional
        number of cores to run on when generating waveforms. Only used when
        level is not None

    Returns
    -------
    waveform:
        ``{"h_plus": ..., "h_cross": ...}`` when ``project`` is None,
        otherwise the waveform projected onto ``project``. Always built from
        the ``ind``-th sample.
    upper, lower, new_t:
        only returned when ``level`` is not None. ``upper`` and ``lower``
        are the upper and lower percentiles, across all samples, of the
        waveforms interpolated onto the common time grid ``new_t``. They are
        dicts keyed by polarization when ``project`` is None, and single
        arrays otherwise.
    """
    # _td_waveform expects the lalsimulation approximant number, not the name
    approx = _lal_approximant_from_string(approximant)
    if mode_array is not None:
        # Insert the requested modes into LAL_parameters (a new LALDict is
        # created when LAL_parameters is None)
        LAL_parameters = _insert_mode_array(
            mode_array, LAL_parameters=LAL_parameters
        )
    if level is not None:
        # Confidence band: generate one waveform per posterior sample, put
        # them on a common time grid and take percentiles across samples
        import multiprocessing
        from pesummary.core.plots.interpolate import Bounded_interp1d
        td_waveform_list = []  # reassigned inside the pool block below
        # The length of the first parameter sets the number of waveforms
        _key = list(samples.keys())[0]
        N = len(samples[_key])
        with multiprocessing.Pool(multi_process) as pool:
            # Row i of ``args`` is the positional argument list for a
            # td_waveform call with ``ind=i``. The trailing ``None`` fills
            # the ``level`` slot so that workers do not re-enter this branch.
            # NOTE(review): mode_array has already been inserted into
            # LAL_parameters above and is passed again, so each worker
            # re-inserts it -- presumably harmless, confirm.
            args = np.array([
                [samples] * N, [approximant] * N, [delta_t] * N, [f_low] * N,
                [f_ref] * N, [project] * N, np.arange(N), [longAscNodes] * N,
                [eccentricity] * N, [LAL_parameters] * N, [mode_array] * N,
                [pycbc] * N, [None] * N
            ], dtype="object").T
            td_waveform_list = list(
                iterator(
                    pool.imap(_wrapper_for_td_waveform, args),
                    tqdm=True, logger=logger, total=N,
                    desc="Generating waveforms"
                )
            )
        td_waveform_array = np.array(td_waveform_list, dtype=object)
        # Upper quantile of the symmetric interval, e.g. level=0.9 -> 0.95
        _level = (1 + np.array(level)) / 2
        # NOTE(review): the ``.times[...].value`` accesses below look
        # gwpy-specific; with pycbc=True the workers return pycbc objects,
        # so confirm that combination is supported.
        if project is None:
            # Each element is a {"h_plus": ..., "h_cross": ...} dict. The
            # common grid spans the earliest start and the latest end over
            # every sample and both polarizations.
            mint = np.min(
                [
                    np.min([_.times[0].value for _ in waveform.values()]) for
                    waveform in td_waveform_array
                ]
            )
            maxt = np.max(
                [
                    np.max([_.times[-1].value for _ in waveform.values()]) for
                    waveform in td_waveform_array
                ]
            )
            new_t = np.arange(mint, maxt, delta_t)
            # Resample every waveform onto new_t with Bounded_interp1d,
            # giving one (samples x times) array per polarization
            td_waveform_array = {
                polarization: np.array(
                    [
                        Bounded_interp1d(
                            np.array(waveform[polarization].times, dtype=np.float64),
                            waveform[polarization], xlow=mint, xhigh=maxt
                        )(new_t) for waveform in td_waveform_array
                    ]
                ) for polarization in ["h_plus", "h_cross"]
            }
        else:
            # Each element is a single projected time series
            mint = np.min([_.times[0].value for _ in td_waveform_array])
            maxt = np.max([_.times[-1].value for _ in td_waveform_array])
            new_t = np.arange(mint, maxt, delta_t)
            td_waveform_array = {
                "h_t": [
                    Bounded_interp1d(
                        np.array(waveform.times, dtype=np.float64), waveform,
                        xlow=mint, xhigh=maxt
                    )(new_t) for waveform in td_waveform_array
                ]
            }

        # Percentiles across samples (axis=0) at every point of new_t
        upper = {
            polarization: np.percentile(
                td_waveform_array[polarization], _level * 100, axis=0
            ) for polarization in td_waveform_array.keys()
        }
        lower = {
            polarization: np.percentile(
                td_waveform_array[polarization], (1 - _level) * 100, axis=0
            ) for polarization in td_waveform_array.keys()
        }
        if len(upper) == 1:
            # Projected case: only "h_t" exists, so return bare arrays
            upper = upper["h_t"]
            lower = lower["h_t"]

    # The returned waveform itself is always the ``ind``-th sample,
    # generated in this process
    waveform_args, _samples = _waveform_args(
        samples, ind=ind, longAscNodes=longAscNodes, eccentricity=eccentricity,
        f_ref=f_ref
    )
    waveform = _td_waveform(
        waveform_args, approx, delta_t, f_low, f_ref, LAL_parameters, _samples,
        pycbc=pycbc, project=project
    )
    if level is not None:
        return waveform, upper, lower, new_t
    return waveform

464 

465 

def _td_waveform(
    waveform_args, approximant, delta_t, f_low, f_ref, LAL_parameters, samples,
    pycbc=False, project=None
):
    """Generate a time-domain waveform from pre-built lalsimulation arguments

    Parameters
    ----------
    waveform_args: list
        leading positional arguments for
        lalsimulation.SimInspiralChooseTDWaveform (masses, spins, distance,
        inclination, phase, longAscNodes, eccentricity, ...), e.g. as built
        by ``_waveform_args``
    approximant: int
        lalsimulation approximant number (not the approximant name)
    delta_t: float
        spacing between time samples
    f_low: float
        frequency to start evaluating the waveform
    f_ref: float
        reference frequency
    LAL_parameters: LALDict
        LAL dictionary containing accessory parameters (may be None)
    samples: dict
        a single posterior sample. ``ra``, ``dec``, ``psi`` and
        ``geocent_time`` are read when ``project`` is given, and so is
        ``{project}_time`` if it is present.
    pycbc: Bool, optional
        return a the waveform as a pycbc.timeseries.TimeSeries object
    project: str, optional
        name of the detector to project the waveform onto. If None,
        the plus and cross polarizations are returned. Default None

    Returns
    -------
    dict or TimeSeries
        ``{"h_plus": hp, "h_cross": hc}`` when ``project`` is None. Otherwise
        the projected strain, with its times shifted by the detector time
        ``{project}_time``. That time is taken from ``samples`` or computed
        with ``time_in_each_ifo``. If it cannot be computed, a warning is
        logged and the unshifted strain is returned.
    """
    from gwpy.timeseries import TimeSeries
    from astropy.units import Quantity

    lal_hp, lal_hc = lalsim.SimInspiralChooseTDWaveform(
        *waveform_args, delta_t, f_low, f_ref, LAL_parameters, approximant
    )
    hp, hc = [
        TimeSeries(series.data.data, dt=series.deltaT, t0=series.epoch)
        for series in (lal_hp, lal_hc)
    ]
    if pycbc:
        hp, hc = hp.to_pycbc(), hc.to_pycbc()
    if project is None:
        return {"h_plus": hp, "h_cross": hc}
    ht = _project_waveform(
        project, hp, hc, samples["ra"], samples["dec"], samples["psi"],
        samples["geocent_time"]
    )
    time_key = "{}_time".format(project)
    if time_key in samples.keys():
        detector_time = samples[time_key]
    else:
        from pesummary.gw.conversions import time_in_each_ifo
        try:
            detector_time = time_in_each_ifo(
                project, samples["ra"], samples["dec"], samples["geocent_time"]
            )
        except Exception:
            # Best effort: keep the unshifted strain rather than failing
            logger.warning(
                "Unable to calculate samples for '{}_time' using the provided "
                "posterior samples. Unable to shift merger to merger time in "
                "the detector".format(project)
            )
            return ht
    # Shift the waveform so that it is expressed in detector time
    ht.times = (
        Quantity(ht.times, unit="s") + Quantity(detector_time, unit="s")
    )
    return ht