Coverage for pesummary/gw/file/skymap.py: 55.2%
58 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import numpy as np
4from pesummary.utils.utils import check_file_exists_and_rename, Empty
5from pesummary import conf
6from pesummary.utils.dict import Dict
8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
class SkyMapDict(Dict):
    """Class to handle a dictionary of skymaps

    Parameters
    ----------
    labels: list
        list of labels for each skymap
    data: nd list
        list of skymap probabilities for each analysis
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property

    Attributes
    ----------
    labels: list
        list of labels stored in the dictionary

    Methods
    -------
    plot:
        Generate a plot based on the skymap samples stored

    Examples
    --------
    >>> skymap_1 = SkyMap.from_fits("skymap.fits")
    >>> skymap_2 = SkyMap.from_fits("skymap_2.fits")
    >>> skymap_dict = SkyMapDict(
    ...     ["one", "two"], [skymap_1, skymap_2]
    ... )
    >>> skymap_dict = SkyMapDict(
    ...     {"one": skymap_1, "two": skymap_2}
    ... )
    """
    def __init__(self, *args, **kwargs):
        super(SkyMapDict, self).__init__(*args, value_class=Empty, **kwargs)

    @property
    def labels(self):
        """List of the keys (analysis labels) stored in the dictionary"""
        return list(self.keys())

    def plot(
        self, labels="all", colors=None, show_probability_map=False, **kwargs
    ):
        """Generate a plot to compare the skymaps stored in the SkyMapDict

        Parameters
        ----------
        labels: list, optional
            list of analyses you wish to compare. Default all. Any value
            which is not a list is treated the same as "all"
        colors: list, optional
            colors to use for the skymaps. Default is a list built from
            pesummary.conf.colorcycle
        show_probability_map: optional
            label of the analysis whose probability map should be shown.
            It is converted to the position of that label in the list of
            plotted analyses before being handed to the plotting function.
            Default False
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_comparion_plot_from_array

        Raises
        ------
        ValueError
            if a requested label, or the label given for
            show_probability_map, is not among the available analyses
        """
        from pesummary.gw.plots.plot import (
            _ligo_skymap_comparion_plot_from_array
        )

        _labels = self.labels
        if labels != "all" and isinstance(labels, list):
            available = self.labels
            for label in labels:
                if label not in available:
                    raise ValueError(
                        "No skymap for '{}' is stored in the dictionary. "
                        "The list of available analyses are: {}".format(
                            label, ", ".join(available)
                        )
                    )
            _labels = list(labels)
        skymaps = [self[key] for key in _labels]
        if colors is None:
            colors = list(conf.colorcycle)

        if show_probability_map:
            # Validate explicitly so the caller gets a descriptive error
            # rather than the bare "x is not in list" from list.index
            if show_probability_map not in _labels:
                raise ValueError(
                    "Unable to show the probability map for '{}'. The list "
                    "of analyses being plotted are: {}".format(
                        show_probability_map, ", ".join(map(str, _labels))
                    )
                )
            show_probability_map = _labels.index(show_probability_map)

        return _ligo_skymap_comparion_plot_from_array(
            skymaps, colors, _labels,
            show_probability_map=show_probability_map, **kwargs
        )
class SkyMap(np.ndarray):
    """Class to handle skymap probabilities

    Parameters
    ----------
    probabilities: np.ndarray
        array of probabilities
    meta_data: dict, optional
        optional meta data associated with the skymap

    Attributes
    ----------
    meta_data: dict
        dictionary containing meta data extracted from the skymap

    Methods
    -------
    from_fits:
        Initiate the class from a ligo.skymap fits file
    save_to_file:
        Write the skymap to file with ligo.skymap
    plot:
        Generate a ligo.skymap plot based on the probabilities stored
    """
    __slots__ = ["meta_data"]

    def __new__(cls, probabilities, meta_data=None):
        obj = np.asarray(probabilities).view(cls)
        obj.meta_data = meta_data
        return obj

    def __reduce__(self):
        # Append the slot values to the ndarray pickle state so that they
        # survive a pickle round trip; __setstate__ strips them off again
        pickled_state = super(SkyMap, self).__reduce__()
        new_state = pickled_state[2] + tuple(
            getattr(self, name) for name in self.__slots__
        )
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        # The final entry is the meta data appended by __reduce__; the rest
        # is the untouched ndarray state
        self.meta_data = state[-1]
        super(SkyMap, self).__setstate__(state[0:-1])

    @classmethod
    def from_fits(cls, path, nest=None):
        """Initiate class with the path to a ligo.skymap fits file

        Parameters
        ----------
        path: str
            path to fits file you wish to load
        nest: optional
            passed unchanged to ligo.skymap.io.fits.read_sky_map.
            Default None
        """
        from ligo.skymap.io.fits import read_sky_map

        skymap, meta = read_sky_map(path, nest=nest)
        return cls(skymap, meta)

    def save_to_file(self, file_name):
        """Save the skymap to file with ligo.skymap

        Any stored meta data is passed to
        ligo.skymap.io.fits.write_sky_map as keyword arguments.

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        """
        from ligo.skymap.io.fits import write_sky_map

        # NOTE(review): presumably moves an existing file out of the way
        # before writing -- confirm against pesummary.utils.utils
        check_file_exists_and_rename(file_name)
        kwargs = self.meta_data if self.meta_data is not None else {}
        write_sky_map(file_name, self, **kwargs)

    def plot(self, **kwargs):
        """Generate a plot with ligo.skymap

        Parameters
        ----------
        **kwargs: dict
            all kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_plot_from_array
        """
        from pesummary.gw.plots.plot import _ligo_skymap_plot_from_array

        return _ligo_skymap_plot_from_array(self, **kwargs)

    def __array_finalize__(self, obj):
        # Always initialise the slot. When obj is None (explicit ndarray
        # construction) there is nothing to inherit, so fall back to None
        # rather than leaving the slot unset, which would make any later
        # read of meta_data raise AttributeError
        self.meta_data = getattr(obj, "meta_data", None)