Coverage for pesummary/tests/ligo_skymap_test.py: 100.0%
47 statements
« prev ^ index » next — coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import os
4import shutil
6import argparse
8from pesummary.core.plots import plot
9from pesummary.gw.plots import plot as gwplot
10from pesummary.utils.array import Array
11from subprocess import CalledProcessError
13from astropy.coordinates import (CartesianRepresentation, SkyCoord,
14 SphericalRepresentation)
15from astropy.table import Table, setdiff
16from astropy.utils.misc import NumpyRNGContext
17from astropy import units as u
18import numpy as np
19from scipy import stats
20import pytest
21import matplotlib
23from ligo.skymap.io.hdf5 import read_samples, write_samples
24from ligo.skymap.tool.tests import run_entry_point
27import numpy as np
28import matplotlib
29from matplotlib import rcParams
30import pytest
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]

# Turn off LaTeX text rendering for every figure made in this module
# (presumably so plotting does not require a TeX installation).
rcParams.update({"text.usetex": False})

# Seed NumPy's global RNG once, at import time. scipy.stats ``rvs`` calls
# made without an explicit ``random_state`` draw from this global state.
np.random.seed(seed=150914)
@pytest.fixture
def samples(tmpdir):
    """Write a synthetic sky-localisation posterior to an HDF5 file.

    The fixture builds the posterior in four steps:

    1. Draw 200 points from a 3-D Gaussian in Cartesian coordinates. The
       mean is a random sky position and distance, and the covariance is a
       random correlation matrix scaled by 100.
    2. Convert the points to spherical ``ra``/``dec``/``distance``.
    3. Add a ``time`` column.
    4. Write the table to ``/posterior_samples`` in
       ``tmpdir/samples.hdf5`` with ``ligo.skymap.io.hdf5.write_samples``.

    Parameters
    ----------
    tmpdir : py.path.local
        pytest temporary directory the file is written into.

    Returns
    -------
    str
        Path to the written HDF5 file.
    """
    # scipy.stats draws from NumPy's global RNG when no ``random_state`` is
    # given, so without a fixed seed here the samples would depend on
    # whatever else has consumed or re-seeded that state. NumpyRNGContext
    # seeds it for the duration of the block and restores it afterwards.
    with NumpyRNGContext(150914):
        # NB: scipy.stats.uniform(loc, scale) samples [loc, loc + scale].
        mean = SkyCoord(
            ra=stats.uniform(0, 360).rvs() * u.deg,
            # arcsin of U[-1, 1] gives a declination that is uniform on the
            # sphere (both hemispheres).
            dec=np.arcsin(stats.uniform(-1, 2).rvs()) * u.radian,
            # U[100, 300]
            distance=stats.uniform(100, 200).rvs(),
        ).cartesian.xyz.value
        # random_correlation requires non-negative eigenvalues that sum to
        # the dimension, hence the rescaling.
        eigvals = stats.uniform(0, 1).rvs(3)
        eigvals *= len(eigvals) / eigvals.sum()
        cov = stats.random_correlation.rvs(eigvals) * 100
        pts = stats.multivariate_normal(mean, cov).rvs(200)
        # Symmetric +/- 0.01 window about 1e9.
        time = stats.uniform(-0.01, 0.02).rvs(200) + 1e9

    pts = SkyCoord(pts, representation_type=CartesianRepresentation)
    pts.representation_type = SphericalRepresentation
    table = Table({
        'ra': pts.ra.rad, 'dec': pts.dec.rad, 'distance': pts.distance.value,
        'time': time,
    })
    filename = str(tmpdir / 'samples.hdf5')
    write_samples(table, filename, path='/posterior_samples')
    return filename
@pytest.mark.ligoskymaptest
def test_ligo_skymap(samples, tmpdir):
    """Check PESummary's skymap against ``ligo-skymap-from-samples``.

    Both tools run on the same posterior samples and write their output to
    ``tmpdir``:

    - ``ligo-skymap-from-samples`` (seed 150914, objid ``S1234``) writes
      ``skymap.fits``.
    - ``gwplot._ligo_skymap_plot`` with ``label="pesummary"`` writes
      ``pesummary_skymap.fits``.

    The two FITS tables must contain the same rows.
    """
    from matplotlib import pyplot as plt

    run_entry_point(
        'ligo-skymap-from-samples', '--seed', '150914', samples,
        '-o', str(tmpdir), '--instruments', 'H1', 'L1', 'V1',
        '--objid', 'S1234',
    )
    table = Table.read(str(tmpdir / 'skymap.fits'), format='fits')

    _samples = Table.read(samples)
    gwplot._ligo_skymap_plot(
        _samples["ra"], _samples["dec"], dist=_samples["distance"],
        savedir=str(tmpdir), label="pesummary"
    )
    # Only the FITS file written as a side effect is compared. Close open
    # pyplot figures so the returned figure does not leak into later tests.
    plt.close("all")
    pesummary_table = Table.read(
        str(tmpdir / 'pesummary_skymap.fits'), format='fits'
    )

    # ``setdiff`` only reports rows of ``table`` that are missing from
    # ``pesummary_table``. Also compare row counts so that extra rows in
    # the PESummary map are caught.
    diff = setdiff(table, pesummary_table)
    assert len(diff) == 0
    assert len(pesummary_table) == len(table)