Coverage for pesummary/gw/file/skymap.py: 55.2%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-05 13:38 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4from pesummary.utils.utils import check_file_exists_and_rename, Empty 

5from pesummary import conf 

6from pesummary.utils.dict import Dict 

7 

# Module author(s); kept as a list so further contributors can be appended.
__author__ = [
    "Charlie Hoy <charlie.hoy@ligo.org>",
]

9 

10 

class SkyMapDict(Dict):
    """Class to handle a dictionary of skymaps

    Parameters
    ----------
    labels: list
        list of labels for each skymap
    data: nd list
        list of skymap probabilities for each analysis
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property

    Attributes
    ----------
    labels: list
        list of labels stored in the dictionary

    Methods
    -------
    plot:
        Generate a plot based on the skymap samples stored

    Examples
    --------
    >>> skymap_1 = SkyMap.from_fits("skymap.fits")
    >>> skymap_2 = SkyMap.from_fits("skymap_2.fits")
    >>> skymap_dict = SkyMapDict(
    ...     ["one", "two"], [skymap_1, skymap_2]
    ... )
    >>> skymap_dict = SkyMapDict(
    ...     {"one": skymap_1, "two": skymap_2}
    ... )
    """
    def __init__(self, *args, **kwargs):
        # Storage/conversion of the values is delegated to
        # pesummary.utils.dict.Dict; skymaps are stored with
        # value_class=Empty and complex columns are not deconstructed.
        super(SkyMapDict, self).__init__(
            *args, value_class=Empty, deconstruct_complex_columns=False,
            **kwargs
        )

    @property
    def labels(self):
        """list of labels stored in the dictionary"""
        return list(self.keys())

    def plot(self, labels="all", colors=None, show_probability_map=False, **kwargs):
        """Generate a plot to compare the skymaps stored in the SkyMapDict

        Parameters
        ----------
        labels: str/list/tuple, optional
            list of analyses you wish to compare. A single analysis may also
            be given as a string. Default "all"
        colors: list, optional
            colors to use for each analysis, in the same order as labels.
            Default list(conf.colorcycle)
        show_probability_map: str, optional
            label of the analysis whose probability map you wish to show.
            Must be one of the analyses being plotted. Default False
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_comparion_plot_from_array

        Returns
        -------
        The output of
        pesummary.gw.plots.plot._ligo_skymap_comparion_plot_from_array

        Raises
        ------
        ValueError
            if a requested label is not stored in the dictionary, or if
            show_probability_map is not one of the analyses being plotted
        """
        from pesummary.gw.plots.plot import (
            _ligo_skymap_comparion_plot_from_array
        )

        # Compute the stored labels once rather than on every loop iteration
        available = self.labels
        # A single label given as a string (other than the "all" sentinel)
        # is treated as a one-element selection instead of being ignored
        if isinstance(labels, str) and labels != "all":
            labels = [labels]
        _labels = available
        if isinstance(labels, (list, tuple)):
            _labels = []
            for label in labels:
                if label not in available:
                    raise ValueError(
                        "No skymap for '{}' is stored in the dictionary. "
                        "The list of available analyses are: {}".format(
                            label, ", ".join(available)
                        )
                    )
                _labels.append(label)
        skymaps = [self[key] for key in _labels]
        if colors is None:
            colors = list(conf.colorcycle)

        if show_probability_map:
            if show_probability_map not in _labels:
                raise ValueError(
                    "Unable to show the probability map for '{}' as it is "
                    "not one of the analyses being plotted: {}".format(
                        show_probability_map,
                        ", ".join(str(label) for label in _labels)
                    )
                )
            # The plotting helper receives the position of the analysis.
            # NOTE(review): if the selected analysis is first, this index is
            # 0 (falsy); confirm the helper treats 0 as a valid index rather
            # than as "do not show".
            show_probability_map = _labels.index(show_probability_map)

        return _ligo_skymap_comparion_plot_from_array(
            skymaps, colors, _labels, show_probability_map=show_probability_map,
            **kwargs
        )

93 

94 

class SkyMap(np.ndarray):
    """Class to handle skymap data

    A ``np.ndarray`` subclass storing an array of skymap probabilities
    together with an optional ``meta_data`` attribute. ``meta_data`` is
    carried over to new arrays derived from an instance (see
    ``__array_finalize__``) and is preserved when pickling (see
    ``__reduce__`` and ``__setstate__``).

    Parameters
    ----------
    probabilities: np.ndarray
        array of probabilities
    meta_data: dict, optional
        optional meta data associated with the skymap

    Attributes
    ----------
    meta_data: dict
        dictionary containing meta data extracted from the skymap

    Methods
    -------
    from_fits:
        Initiate the class from a ligo.skymap fits file
    save_to_file:
        Save the skymap to a fits file with ligo.skymap
    plot:
        Generate a ligo.skymap plot based on the probabilities stored
    """
    __slots__ = ["meta_data"]

    def __new__(cls, probabilities, meta_data=None):
        # View (rather than copy) the input as a SkyMap, then attach meta data
        obj = np.asarray(probabilities).view(cls)
        obj.meta_data = meta_data
        return obj

    def __reduce__(self):
        """Extend numpy's pickle state with the values stored in __slots__"""
        pickled_state = super(SkyMap, self).__reduce__()
        # numpy's state tuple does not know about the slot attributes, so
        # append them so that __setstate__ can restore them
        new_state = pickled_state[2] + tuple(
            [getattr(self, i) for i in self.__slots__]
        )
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        """Restore meta_data and hand the remaining state back to numpy"""
        # meta_data is the last element appended by __reduce__; this assumes
        # __slots__ only contains "meta_data"
        self.meta_data = state[-1]
        super(SkyMap, self).__setstate__(state[0:-1])

    @classmethod
    def from_fits(cls, path, nest=None):
        """Initiate class with the path to a ligo.skymap fits file

        Parameters
        ----------
        path: str
            path to fits file you wish to load
        nest: bool, optional
            pixel ordering option passed directly to
            ligo.skymap.io.fits.read_sky_map. Default None
        """
        from ligo.skymap.io.fits import read_sky_map

        skymap, meta = read_sky_map(path, nest=nest)
        return cls(skymap, meta)

    def save_to_file(self, file_name):
        """Save the skymap to a fits file with ligo.skymap

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use
        """
        from ligo.skymap.io.fits import write_sky_map

        check_file_exists_and_rename(file_name)
        # Any stored meta data is forwarded as keyword arguments to
        # write_sky_map
        kwargs = {}
        if self.meta_data is not None:
            kwargs = self.meta_data
        write_sky_map(file_name, self, **kwargs)

    def plot(self, **kwargs):
        """Generate a plot with ligo.skymap

        Parameters
        ----------
        **kwargs: dict
            all kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_plot_from_array
        """
        from pesummary.gw.plots.plot import _ligo_skymap_plot_from_array

        return _ligo_skymap_plot_from_array(self, **kwargs)

    def __array_finalize__(self, obj):
        # Called by numpy whenever a new SkyMap is created (views, slices,
        # ufunc outputs); obj is None for construction from scratch
        if obj is None:
            return
        self.meta_data = getattr(obj, "meta_data", None)