Coverage for pesummary/core/file/formats/sql.py: 71.2%

73 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-02 08:42 +0000

1# Licensed under an MIT style license -- see LICENSE.md 

2 

3import sqlite3 

4import numpy as np 

5from pesummary.utils.samples_dict import MultiAnalysisSamplesDict 

6from pesummary.utils.utils import logger, check_filename 

7 

8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"] 

9 

10 

def read_sql(path, path_to_samples=None, remove_row_column="ROW", **kwargs):
    """Grab the parameters and samples in an sql database file

    Parameters
    ----------
    path: str
        path to the result file you wish to read in
    path_to_samples: str/list, optional
        table or list of tables that you wish to load. If None, every table
        in the database is loaded (in alphabetical order)
    remove_row_column: str, optional
        remove the column with name 'remove_row_column' which indicates the row.
        Default 'ROW'
    **kwargs: dict, optional
        accepted for interface compatibility; not used

    Returns
    -------
    parameters: list
        the column names of the table. If more than one table is read, a
        list containing one list of column names per table
    samples: list
        the rows of the table as nested lists. If more than one table is
        read, a list containing one such nested list per table (tables may
        have different numbers of rows and columns)
    tables: list
        the names of the tables that were read

    Raises
    ------
    ValueError
        if a requested table is not in the database, or if
        `path_to_samples` is not a str, list or np.ndarray
    """
    db = sqlite3.connect(path)
    try:
        cursor = db.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
        )
        tables = [row[0] for row in cursor.fetchall()]
        if path_to_samples is not None:
            if isinstance(path_to_samples, str):
                if path_to_samples not in tables:
                    raise ValueError(
                        "{} not in list of tables".format(path_to_samples)
                    )
                tables = [path_to_samples]
            elif isinstance(path_to_samples, (np.ndarray, list)):
                names = [
                    name for name in path_to_samples if name not in tables
                ]
                if names:
                    raise ValueError(
                        "The tables: {} are not in the sql database".format(
                            ", ".join(names)
                        )
                    )
                tables = path_to_samples
            else:
                raise ValueError("{} not understood".format(path_to_samples))

        parameters, samples = [], []
        for table in tables:
            # Quote the identifier (doubling embedded quotes) so table names
            # containing spaces, dashes or SQL keywords can still be read
            cursor.execute(
                'SELECT * FROM "{}"'.format(str(table).replace('"', '""'))
            )
            rows = cursor.fetchall()
            _parameters = [column[0] for column in cursor.description]
            if rows:
                _samples = np.array(rows)
            else:
                # np.array([]) is 1-D; keep a 2-D (0, ncols) array so the
                # row column can still be dropped along axis 1
                _samples = np.empty((0, len(_parameters)))
            if remove_row_column in _parameters:
                ind = _parameters.index(remove_row_column)
                del _parameters[ind]
                _samples = np.delete(_samples, ind, axis=1)
            parameters.append(_parameters)
            # Convert each table separately: stacking tables of different
            # shapes into one numpy array would fail
            samples.append(_samples.tolist())
    finally:
        db.close()

    if len(tables) == 1:
        return parameters[0], samples[0], tables
    return parameters, samples, tables

67 

68 

def _quote_sql_identifier(name):
    """Return `name` as a double-quoted SQLite identifier.

    Embedded double quotes are escaped by doubling them, so table or
    parameter names containing spaces, dashes or SQL keywords are accepted.
    """
    return '"{}"'.format(str(name).replace('"', '""'))


def _write_tables(filename, table_names, columns, rows):
    """Create one table per entry of `table_names` in the database `filename`.

    `columns[num]` holds the column names and `rows[num]` the rows of samples
    for `table_names[num]`. Every table gets an extra integer `ROW` column
    storing the row index. Inserted rows are committed once at the end; the
    connection is always closed, and rows still pending when an error is
    raised are not committed.
    """
    db = sqlite3.connect(filename)
    try:
        cursor = db.cursor()
        for num, name in enumerate(table_names):
            table = _quote_sql_identifier(name)
            table_columns = [
                _quote_sql_identifier(col) for col in columns[num]
            ]
            cursor.execute(
                "CREATE TABLE {} (ROW INT, {});".format(
                    table,
                    ", ".join("{} DOUBLE".format(col) for col in table_columns)
                )
            )
            insert = "INSERT INTO {} (ROW, {}) VALUES ({});".format(
                table, ", ".join(table_columns),
                ", ".join(["?"] * (len(table_columns) + 1))
            )
            # Bind values as parameters rather than splicing str(value) into
            # the SQL (which breaks for nan/inf). float() also converts numpy
            # scalars such as np.int64/np.float32 that sqlite3 cannot bind.
            cursor.executemany(
                insert,
                (
                    [idx] + [float(value) for value in row]
                    for idx, row in enumerate(rows[num])
                )
            )
        db.commit()
    finally:
        db.close()


def write_sql(
    *args, table_name="MYTABLE", outdir="./", filename=None, overwrite=False,
    keys_as_table_name=True, delete_existing=False, **kwargs
):
    """Write a set of samples to an sql database

    Every table gets an additional integer `ROW` column containing the row
    index, and all samples are stored in DOUBLE columns. SQLite stores NaN
    values as NULL.

    Parameters
    ----------
    args: tuple, dict, MultiAnalysisSamplesDict
        the posterior samples you wish to save to file. Either a tuple
        of parameters and a 2d list of samples with columns corresponding to
        a given parameter, dict of parameters and samples, or a
        MultiAnalysisSamplesDict object with parameters and samples for
        multiple analyses
    table_name: str/dict, optional
        name of the table to store the samples. If a MultiAnalysisSamplesDict
        is provided, a str is ignored and the analysis labels are used as the
        table names; a dict mapping each analysis label to a table name is
        used when `keys_as_table_name=False`
    outdir: str, optional
        directory to write the db file
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    keys_as_table_name: Bool, optional
        if True, ignore table_name and use the keys of the
        MultiAnalysisSamplesDict as the table name. Default True
    delete_existing: Bool, optional
        passed to `check_filename` when resolving the output file
    **kwargs: dict, optional
        accepted for interface compatibility; not used

    Raises
    ------
    ValueError
        if a MultiAnalysisSamplesDict is provided and table_name is neither a
        str nor a dict, or is a dict missing a label while
        `keys_as_table_name=False`
    """
    default_filename = "pesummary_{}.db"
    filename = check_filename(
        default_filename=default_filename, outdir=outdir, label=table_name,
        filename=filename, overwrite=overwrite, delete_existing=delete_existing
    )

    if isinstance(args[0], MultiAnalysisSamplesDict):
        analyses = args[0]
        labels = list(analyses.keys())
        if isinstance(table_name, str):
            logger.info(
                "Ignoring the table name: {} and using the labels in the "
                "MultiAnalysisSamplesDict".format(table_name)
            )
            table_names = labels
        elif isinstance(table_name, dict):
            if keys_as_table_name:
                logger.info(
                    "Ignoring table_name and using the labels in the "
                    "MultiAnalysisSamplesDict. To override this, set "
                    "`keys_as_table_name=False`"
                )
                table_names = labels
            elif not all(label in table_name for label in labels):
                raise ValueError("Please provide a table_name for all analyses")
            else:
                table_names = [table_name[label] for label in labels]
        else:
            raise ValueError(
                "Please provide table name as a dictionary which maps "
                "the analysis label to the table name"
            )
        # Keep per-analysis lists: analyses may have different numbers of
        # samples/parameters, so they must not be stacked into one ndarray
        columns = [list(analyses[label].keys()) for label in labels]
        rows = [
            np.array([analyses[label][param] for param in label_columns]).T
            for label, label_columns in zip(labels, columns)
        ]
    else:
        if isinstance(args[0], dict):
            columns = list(args[0].keys())
            rows = np.array([args[0][param] for param in columns]).T
        else:
            columns, rows = args

        table_names = np.atleast_1d(table_name)
        columns = np.atleast_2d(columns)
        ndim = np.array(rows).ndim
        if ndim == 1:
            # a single row of samples for a single table
            rows = [[rows]]
        elif ndim == 2:
            # a single table
            rows = [rows]

        if len(table_names) != len(columns):
            table_names = [
                "{}_{}".format(table_names[0], idx)
                for idx in range(len(columns))
            ]

    _write_tables(filename, table_names, columns, rows)