Coverage for pesummary/core/file/formats/sql.py: 60.3%
73 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-09 22:34 +0000
1# Licensed under an MIT style license -- see LICENSE.md
3import sqlite3
4import numpy as np
5from pesummary.utils.samples_dict import MultiAnalysisSamplesDict
6from pesummary.utils.utils import logger, check_filename
8__author__ = ["Charlie Hoy <charlie.hoy@ligo.org>"]
def read_sql(path, path_to_samples=None, remove_row_column="ROW", **kwargs):
    """Grab the parameters and samples in an sql database file

    Parameters
    ----------
    path: str
        path to the result file you wish to read in
    path_to_samples: str/list/tuple/np.ndarray, optional
        table or list of tables that you wish to load. Default None, meaning
        that every table in the database is loaded
    remove_row_column: str, optional
        remove the column with name 'remove_row_column' which indicates the row.
        Default 'ROW'
    **kwargs: dict, optional
        all additional kwargs are ignored

    Returns
    -------
    parameters: list
        column names of the loaded table. If more than one table is loaded,
        a list containing the column names of each table
    samples: list
        2d list of samples (one entry per database row) of the loaded table.
        If more than one table is loaded, a list containing the 2d list of
        samples of each table
    tables: list
        names of the tables that were loaded

    Raises
    ------
    ValueError
        if a requested table is not in the database, or if path_to_samples
        is not a str, list, tuple or np.ndarray
    """
    db = sqlite3.connect(path)
    try:
        d = db.cursor()
        d.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
        )
        tables = [i[0] for i in d.fetchall()]
        if path_to_samples is not None:
            if isinstance(path_to_samples, str):
                if path_to_samples not in tables:
                    raise ValueError(
                        "{} not in list of tables".format(path_to_samples)
                    )
                tables = [path_to_samples]
            elif isinstance(path_to_samples, (np.ndarray, list, tuple)):
                names = [
                    name for name in path_to_samples if name not in tables
                ]
                if len(names):
                    raise ValueError(
                        "The tables: {} are not in the sql database".format(
                            ", ".join(names)
                        )
                    )
                tables = path_to_samples
            else:
                raise ValueError("{} not understood".format(path_to_samples))
        parameters, samples = [], []
        for table in tables:
            # Quote the table name so that names containing e.g. spaces,
            # dashes or SQL keywords are valid. Embedded double quotes are
            # escaped by doubling them
            d.execute(
                'SELECT * FROM "{}"'.format(str(table).replace('"', '""'))
            )
            _parameters = [i[0] for i in d.description]
            # reshape so that an empty table still gives a 2d array with one
            # column per parameter (np.array([]) alone is 1d)
            _samples = np.array(d.fetchall()).reshape(-1, len(_parameters))
            if remove_row_column in _parameters:
                ind = _parameters.index(remove_row_column)
                del _parameters[ind]
                _samples = np.delete(_samples, ind, axis=1)
            parameters.append(_parameters)
            samples.append(_samples)
    finally:
        db.close()
    if len(tables) == 1:
        return parameters[0], samples[0].tolist(), tables
    # convert each table separately: tables may have different numbers of
    # rows/columns, which cannot be stacked into a single numpy array
    return parameters, [_samples.tolist() for _samples in samples], tables
def write_sql(
    *args, table_name="MYTABLE", outdir="./", filename=None, overwrite=False,
    keys_as_table_name=True, delete_existing=False, **kwargs
):
    """Write a set of samples to an sql database

    Each table gets an integer ROW column holding the row index, followed by
    one DOUBLE column per parameter.

    Parameters
    ----------
    args: tuple, dict, MultiAnalysisSamplesDict
        the posterior samples you wish to save to file. Either a tuple
        of parameters and a 2d list of samples with columns corresponding to
        a given parameter, dict of parameters and samples, or a
        MultiAnalysisSamplesDict object with parameters and samples for
        multiple analyses
    table_name: str/dict, optional
        name of the table to store the samples. If a MultiAnalysisSamplesDict
        is provided, this must be either a str (which is ignored and the
        analysis labels are used as table names) or a dict mapping each
        analysis label to a table name (only used when
        `keys_as_table_name=False`). Default 'MYTABLE'
    outdir: str, optional
        directory to write the database file
    filename: str, optional
        The name of the file that you wish to write
    overwrite: Bool, optional
        If True, an existing file of the same name will be overwritten
    keys_as_table_name: Bool, optional
        if True, ignore table_name and use the keys of the
        MultiAnalysisSamplesDict as the table name. Default True
    delete_existing: Bool, optional
        passed directly to `check_filename`
    **kwargs: dict, optional
        all additional kwargs are ignored

    Raises
    ------
    ValueError
        if a MultiAnalysisSamplesDict is provided and table_name is neither
        a str nor a dict, or is a dict which does not contain every analysis
        label
    """
    def _quote(name):
        # Quote an SQL identifier so that names containing e.g. spaces,
        # dashes or SQL keywords are valid; embedded double quotes are
        # escaped by doubling them
        return '"{}"'.format(str(name).replace('"', '""'))

    def _as_table_rows(rows):
        # Wrap the rows of a single table in a list so that rows[i] always
        # gives the rows of table i
        try:
            ndim = np.ndim(rows)
        except ValueError:
            # ragged input (tables with different shapes) cannot be a single
            # table, so it is already a list of tables
            return rows
        if ndim == 1:
            # a single sample, i.e. a single row
            return [[rows]]
        elif ndim == 2:
            return [rows]
        return rows

    default_filename = "pesummary_{}.db"
    filename = check_filename(
        default_filename=default_filename, outdir=outdir, label=table_name,
        filename=filename, overwrite=overwrite, delete_existing=delete_existing
    )

    if isinstance(args[0], MultiAnalysisSamplesDict):
        samples = args[0]
        labels = list(samples.keys())
        if isinstance(table_name, str):
            logger.info(
                "Ignoring the table name: {} and using the labels in the "
                "MultiAnalysisSamplesDict".format(table_name)
            )
            table_name = labels
        elif isinstance(table_name, dict):
            if keys_as_table_name:
                logger.info(
                    "Ignoring table_name and using the labels in the "
                    "MultiAnalysisSamplesDict. To override this, set "
                    "`keys_as_table_name=False`"
                )
                table_name = labels
            elif not all(label in table_name for label in labels):
                raise ValueError("Please provide a table_name for all analyses")
            else:
                # honour the user supplied label -> table name mapping
                table_name = [table_name[label] for label in labels]
        else:
            raise ValueError(
                "Please provide table name as a dictionary which maps "
                "the analysis label to the table name"
            )
        # one list of parameters and one 2d array of samples per analysis.
        # Analyses may store different parameters/numbers of samples, so
        # these are kept as python lists rather than stacked with numpy
        columns = [list(samples[label].keys()) for label in labels]
        rows = [
            np.array([samples[label][param] for param in _columns]).T
            for label, _columns in zip(labels, columns)
        ]
    else:
        if isinstance(args[0], dict):
            columns = list(args[0].keys())
            rows = np.array([args[0][param] for param in columns]).T
        else:
            columns, rows = args
        # a flat list of parameter names describes a single table; otherwise
        # there is one list of parameter names per table
        if not len(columns) or isinstance(columns[0], str):
            columns = [list(columns)]
        else:
            columns = [list(_columns) for _columns in columns]
        rows = _as_table_rows(rows)

    table_names = list(np.atleast_1d(table_name))
    if len(table_names) != len(columns):
        table_names = [
            "{}_{}".format(table_names[0], idx) for idx in range(len(columns))
        ]

    db = sqlite3.connect(filename)
    try:
        # the connection context manager commits the inserted rows on
        # success and rolls back the pending transaction on error. It does
        # not close the connection, hence the finally clause
        with db:
            d = db.cursor()
            for num, _table_name in enumerate(table_names):
                _columns = ", ".join(_quote(col) for col in columns[num])
                d.execute(
                    "CREATE TABLE {} (ROW INT, {});".format(
                        _quote(_table_name), _columns
                    )
                )
                # Bind the values rather than formatting them into the SQL
                # string: str(value) gives invalid SQL for nan/inf and may
                # lose precision. Note that SQLite stores a bound NaN as NULL
                placeholders = ", ".join(["?"] * (len(columns[num]) + 1))
                d.executemany(
                    "INSERT INTO {} (ROW, {}) VALUES ({});".format(
                        _quote(_table_name), _columns, placeholders
                    ),
                    (
                        [idx] + [float(value) for value in row]
                        for idx, row in enumerate(rows[num])
                    )
                )
    finally:
        db.close()