Coverage for pesummary/tests/ligo_skymap_test.py: 100.0%

47 statements  

Report created by coverage.py v7.4.4 at 2024-05-02 08:42 +0000.

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import os 

4import shutil 

5 

6import argparse 

7 

8from pesummary.core.plots import plot 

9from pesummary.gw.plots import plot as gwplot 

10from pesummary.utils.array import Array 

11from subprocess import CalledProcessError 

12 

13from astropy.coordinates import (CartesianRepresentation, SkyCoord, 

14 SphericalRepresentation) 

15from astropy.table import Table, setdiff 

16from astropy.utils.misc import NumpyRNGContext 

17from astropy import units as u 

18import numpy as np 

19from scipy import stats 

20import pytest 

21import matplotlib 

22 

23from ligo.skymap.io.hdf5 import read_samples, write_samples 

24from ligo.skymap.tool.tests import run_entry_point 

25 

26 

27import numpy as np 

28import matplotlib 

29from matplotlib import rcParams 

30import pytest 

31 

__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
# Turn off LaTeX text rendering so figures can be drawn without a TeX install.
rcParams.update({"text.usetex": False})
# Seed NumPy's global random number generator for this module.
np.random.seed(150914)

35 

36 

@pytest.fixture
def samples(tmpdir):
    """Write a synthetic sky-position posterior to an HDF5 file.

    Samples are drawn from a 3-D multivariate normal in Cartesian
    coordinates (random mean position, random covariance). They are then
    converted to right ascension, declination and distance, together with
    a narrow spread of times around GPS 1e9.

    Parameters
    ----------
    tmpdir: py.path.local
        pytest-provided temporary directory the file is written into

    Returns
    -------
    str
        path to the HDF5 file; the table is stored under
        '/posterior_samples'
    """
    # Use a private generator rather than NumPy's global one. pytest imports
    # (and may seed) other test modules, and other tests may consume the
    # global state before this fixture runs. The module-level seed alone
    # therefore does not make these samples reproducible.
    rng = np.random.RandomState(150914)
    # NOTE: scipy's ``uniform(loc, scale)`` spans [loc, loc + scale]. The
    # draws below therefore cover ra in [0, 360] hourangle, sin(dec) in
    # [-1, 0], distance in [100, 300] and a time offset in [-0.01, 0]. These
    # ranges are kept as they were in the original fixture.
    mean = SkyCoord(
        ra=stats.uniform(0, 360).rvs(random_state=rng) * u.hourangle,
        dec=np.arcsin(stats.uniform(-1, 1).rvs(random_state=rng)) * u.radian,
        distance=stats.uniform(100, 200).rvs(random_state=rng)
    ).cartesian.xyz.value
    # ``random_correlation`` needs positive eigenvalues that sum to the
    # dimension (3), hence the rescaling.
    eigvals = stats.uniform(0, 1).rvs(3, random_state=rng)
    eigvals *= len(eigvals) / eigvals.sum()
    cov = stats.random_correlation.rvs(eigvals, random_state=rng) * 100
    pts = stats.multivariate_normal(mean, cov).rvs(200, random_state=rng)
    pts = SkyCoord(pts, representation_type=CartesianRepresentation)
    pts.representation_type = SphericalRepresentation
    time = stats.uniform(-0.01, 0.01).rvs(200, random_state=rng) + 1e9
    table = Table({
        'ra': pts.ra.rad, 'dec': pts.dec.rad, 'distance': pts.distance.value,
        'time': time
    })
    filename = str(tmpdir / 'samples.hdf5')
    write_samples(table, filename, path='/posterior_samples')
    return filename

56 

57 

@pytest.mark.ligoskymaptest
def test_ligo_skymap(samples, tmpdir):
    """Check that pesummary's sky map matches ligo-skymap-from-samples.

    Both tools are run on the same posterior samples and write a FITS sky
    map into ``tmpdir``. The two FITS tables must then contain the same
    rows.
    """
    import matplotlib.pyplot as plt

    run_entry_point('ligo-skymap-from-samples', '--seed', '150914',
                    samples, '-o', str(tmpdir),
                    '--instruments', 'H1', 'L1', 'V1', '--objid', 'S1234')
    table = Table.read(str(tmpdir / 'skymap.fits'), format='fits')
    # Read from the exact path the fixture wrote to, rather than relying on
    # astropy auto-detecting the only table in the file.
    _samples = Table.read(samples, path='/posterior_samples')
    fig = gwplot._ligo_skymap_plot(
        _samples["ra"], _samples["dec"], dist=_samples["distance"],
        savedir=str(tmpdir), label="pesummary"
    )
    # The figure itself is not inspected. Close it so open figures do not
    # accumulate across the test session.
    plt.close(fig)
    pesummary_table = Table.read(
        str(tmpdir / 'pesummary_skymap.fits'), format='fits'
    )
    # ``setdiff(a, b)`` only reports rows of ``a`` missing from ``b``.
    # Check the lengths and the reverse difference as well, so that extra
    # or differing rows in pesummary's map are also caught.
    assert len(table) == len(pesummary_table)
    assert not len(setdiff(table, pesummary_table))
    assert not len(setdiff(pesummary_table, table, keys=table.colnames))