Coverage for pesummary/gw/file/skymap.py: 55.2%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import numpy as np 

4from pesummary.utils.utils import check_file_exists_and_rename, Empty 

5from pesummary import conf 

6from pesummary.utils.dict import Dict 

7 

8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

9 

10 

class SkyMapDict(Dict):
    """Class to handle a dictionary of skymaps

    Parameters
    ----------
    labels: list
        list of labels for each skymap
    data: nd list
        list of skymap probabilities for each analysis
    **kwargs: dict
        All other kwargs are turned into properties of the class. Key
        is the name of the property

    Attributes
    ----------
    labels: list
        list of labels stored in the dictionary

    Methods
    -------
    plot:
        Generate a plot comparing the skymaps stored

    Examples
    --------
    >>> skymap_1 = SkyMap.from_fits("skymap.fits")
    >>> skymap_2 = SkyMap.from_fits("skymap_2.fits")
    >>> skymap_dict = SkyMapDict(
    ...     ["one", "two"], [skymap_1, skymap_2]
    ... )
    >>> skymap_dict = SkyMapDict(
    ...     {"one": skymap_1, "two": skymap_2}
    ... )
    """
    def __init__(self, *args, **kwargs):
        # value_class=Empty is forwarded to pesummary's Dict base class;
        # NOTE(review): presumably this stops stored values being wrapped in
        # another type -- confirm against pesummary.utils.dict.Dict
        super(SkyMapDict, self).__init__(*args, value_class=Empty, **kwargs)

    @property
    def labels(self):
        """list: the labels (keys) of the skymaps stored in the dictionary"""
        return list(self.keys())

    def plot(
        self, labels="all", colors=None, show_probability_map=False, **kwargs
    ):
        """Generate a plot to compare the skymaps stored in the SkyMapDict

        Parameters
        ----------
        labels: str/list, optional
            list (or tuple) of analyses you wish to compare. A single
            analysis may also be given as a string. Default "all", which
            compares every stored skymap
        colors: list, optional
            colors to use for the skymaps. Default list(conf.colorcycle)
        show_probability_map: str, optional
            label of the analysis whose probability map should be shown.
            It must be one of the analyses being plotted; it is converted
            to its position in the list of plotted labels before being
            passed on. Default False
        **kwargs: dict
            all additional kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_comparion_plot_from_array

        Returns
        -------
        The object returned by
        pesummary.gw.plots.plot._ligo_skymap_comparion_plot_from_array

        Raises
        ------
        ValueError
            if a requested label is not stored in the dictionary, or if
            show_probability_map is not one of the analyses being plotted
        """
        from pesummary.gw.plots.plot import (
            _ligo_skymap_comparion_plot_from_array
        )

        available = self.labels
        if isinstance(labels, str) and labels != "all":
            # a single analysis was requested: treat it as a 1-element list
            labels = [labels]
        if isinstance(labels, (list, tuple)):
            for label in labels:
                if label not in available:
                    raise ValueError(
                        "No skymap for '{}' is stored in the dictionary. "
                        "The list of available analyses are: {}".format(
                            label, ", ".join(available)
                        )
                    )
            _labels = list(labels)
        else:
            # "all" (or any other non-sequence value) compares every skymap
            _labels = available
        skymaps = [self[key] for key in _labels]
        if colors is None:
            colors = list(conf.colorcycle)

        if show_probability_map:
            if show_probability_map not in _labels:
                raise ValueError(
                    "Unable to show the probability map for '{}' as it is "
                    "not one of the analyses being plotted: {}".format(
                        show_probability_map, ", ".join(map(str, _labels))
                    )
                )
            # the label is forwarded as its index within the plotted labels
            show_probability_map = _labels.index(show_probability_map)

        return _ligo_skymap_comparion_plot_from_array(
            skymaps, colors, _labels, show_probability_map=show_probability_map,
            **kwargs
        )

90 

91 

class SkyMap(np.ndarray):
    """Class to handle skymap data

    Parameters
    ----------
    probabilities: np.ndarray
        array of probabilities
    meta_data: dict, optional
        optional meta data associated with the skymap

    Attributes
    ----------
    meta_data: dict
        dictionary containing meta data extracted from the skymap

    Methods
    -------
    plot:
        Generate a ligo.skymap plot based on the probabilities stored
    save_to_file:
        Save the skymap to a fits file with ligo.skymap
    """
    __slots__ = ["meta_data"]

    def __new__(cls, probabilities, meta_data=None):
        obj = np.asarray(probabilities).view(cls)
        obj.meta_data = meta_data
        return obj

    def __reduce__(self):
        # ndarray's pickled state knows nothing about the extra slots, so the
        # value of every slot is appended to the end of the state tuple;
        # __setstate__ strips them off again before handing the rest to numpy
        pickled_state = super(SkyMap, self).__reduce__()
        extra = tuple(getattr(self, name) for name in self.__slots__)
        return (pickled_state[0], pickled_state[1], pickled_state[2] + extra)

    def __setstate__(self, state):
        # mirror __reduce__: one trailing entry per slot, in __slots__ order
        n_slots = len(self.__slots__)
        for name, value in zip(self.__slots__, state[-n_slots:]):
            setattr(self, name, value)
        super(SkyMap, self).__setstate__(state[:-n_slots])

    @classmethod
    def from_fits(cls, path, nest=None):
        """Initiate class with the path to a ligo.skymap fits file

        Parameters
        ----------
        path: str
            path to fits file you wish to load
        nest: bool, optional
            passed straight through to ligo.skymap.io.fits.read_sky_map
            (see its documentation for the pixel ordering it selects).
            Default None

        Returns
        -------
        SkyMap
            the probabilities read from the file, with the meta data
            returned by read_sky_map stored as meta_data
        """
        from ligo.skymap.io.fits import read_sky_map

        skymap, meta = read_sky_map(path, nest=nest)
        return cls(skymap, meta)

    def save_to_file(self, file_name):
        """Save the skymap to a fits file with ligo.skymap

        Parameters
        ----------
        file_name: str
            name of the file name that you wish to use

        Notes
        -----
        check_file_exists_and_rename is called on file_name first; any
        stored meta_data is passed as keyword arguments to
        ligo.skymap.io.fits.write_sky_map
        """
        from ligo.skymap.io.fits import write_sky_map

        check_file_exists_and_rename(file_name)
        kwargs = {}
        if self.meta_data is not None:
            kwargs = self.meta_data
        write_sky_map(file_name, self, **kwargs)

    def plot(self, **kwargs):
        """Generate a plot with ligo.skymap

        Parameters
        ----------
        **kwargs: dict
            all kwargs are passed to
            pesummary.gw.plots.plot._ligo_skymap_plot_from_array

        Returns
        -------
        The object returned by
        pesummary.gw.plots.plot._ligo_skymap_plot_from_array
        """
        from pesummary.gw.plots.plot import _ligo_skymap_plot_from_array

        return _ligo_skymap_plot_from_array(self, **kwargs)

    def __array_finalize__(self, obj):
        # called by numpy for views, slices and ufunc outputs: propagate the
        # meta data from the source array (None if it has none)
        if obj is None:
            return
        self.meta_data = getattr(obj, "meta_data", None)