Coverage for datasette/database.py : 93%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import contextlib
3from pathlib import Path
4import janus
5import queue
6import threading
7import uuid
9from .tracer import trace
10from .utils import (
11 detect_fts,
12 detect_primary_keys,
13 detect_spatialite,
14 get_all_foreign_keys,
15 get_outbound_foreign_keys,
16 sqlite_timelimit,
17 sqlite3,
18 table_columns,
19)
20from .inspect import inspect_hash
# Thread-local cache of read connections: Database.execute_fn stores one
# connection per database on this object, as an attribute named after the
# database, so each executor thread reuses its own connection.
connections = threading.local()
class Database:
    """A SQLite database (on-disk file or in-memory) served by Datasette.

    Read queries run in ``ds.executor`` on per-thread cached connections;
    writes are serialized through one background thread that owns the only
    write connection.
    """

    def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
        """Wrap the database at ``path`` for the Datasette instance ``ds``.

        For immutable on-disk files the content hash and file size are
        computed once here, and per-table row counts are taken from
        ``ds.inspect_data`` when it has an entry for this database.
        """
        self.ds = ds
        self.path = path
        self.is_mutable = is_mutable
        self.is_memory = is_memory
        self.hash = None
        self.cached_size = None
        self.cached_table_counts = None
        self._write_thread = None
        self._write_queue = None
        if not self.is_mutable and not self.is_memory:
            # Immutable files are opened with immutable=1 (see connect), so
            # hash and size are computed once up front.
            p = Path(path)
            self.hash = inspect_hash(p)
            self.cached_size = p.stat().st_size
            # Maybe use self.ds.inspect_data to populate cached_table_counts
            if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
                self.cached_table_counts = {
                    key: value["count"]
                    for key, value in self.ds.inspect_data[self.name]["tables"].items()
                }

    def connect(self, write=False):
        """Open and return a new sqlite3 connection to this database.

        In-memory databases get a fresh ``:memory:`` connection (SQLite gives
        each such connection its own separate database). Mutable files are
        opened read-only (``mode=ro``) unless ``write`` is true; immutable
        files are opened with ``immutable=1`` and must not be opened for
        writing (AssertionError).
        """
        if self.is_memory:
            return sqlite3.connect(":memory:")
        assert not (write and not self.is_mutable)
        if write:
            qs = ""
        elif self.is_mutable:
            qs = "?mode=ro"
        else:
            qs = "?immutable=1"
        return sqlite3.connect(
            "file:{}{}".format(self.path, qs), uri=True, check_same_thread=False
        )

    async def execute_write(self, sql, params=None, block=False):
        """Queue ``sql`` to run inside a transaction on the write thread.

        See ``execute_write_fn`` for the meaning of ``block`` and the return
        value.
        """

        def _inner(conn):
            with conn:
                return conn.execute(sql, params or [])

        return await self.execute_write_fn(_inner, block=block)

    async def execute_write_fn(self, fn, block=False):
        """Queue ``fn(conn)`` to run on the dedicated write thread.

        The write queue and thread are created lazily on first use. If
        ``block`` is true, wait for ``fn`` to finish and return its result,
        re-raising any exception it raised. Otherwise return a task id
        immediately without waiting.
        """
        # uuid4 gives a distinct id per task; uuid5 of a constant name would
        # hand every caller the same id.
        task_id = uuid.uuid4()
        if self._write_queue is None:
            self._write_queue = queue.Queue()
        if self._write_thread is None:
            self._write_thread = threading.Thread(
                target=self._execute_writes, daemon=True
            )
            self._write_thread.start()
        reply_queue = janus.Queue()
        self._write_queue.put(WriteTask(fn, task_id, reply_queue))
        if not block:
            return task_id
        result = await reply_queue.async_q.get()
        if isinstance(result, Exception):
            raise result
        return result

    def _execute_writes(self):
        """Write-thread body: run queued write tasks, one at a time, forever.

        This thread owns the single write connection to this database. Each
        task's result, or the exception it raised, is put on the task's
        reply queue. If the write connection cannot be opened, every task is
        answered with that error, so callers never wait forever on a dead
        thread.
        """
        conn = None
        conn_exception = None
        try:
            conn = self.connect(write=True)
        except Exception as e:
            print(e)
            conn_exception = e
        while True:
            task = self._write_queue.get()
            if conn_exception is not None:
                result = conn_exception
            else:
                try:
                    result = task.fn(conn)
                except Exception as e:
                    print(e)
                    result = e
            task.reply_queue.sync_q.put(result)

    async def execute_fn(self, fn):
        """Run ``fn(conn)`` in ``ds.executor`` and return its result.

        Each executor thread lazily opens, prepares and caches one read
        connection for this database in the thread-local ``connections``.
        """

        def in_thread():
            # NOTE(review): the cache key is self.name (the file stem), so two
            # files with the same stem would share a cached connection.
            conn = getattr(connections, self.name, None)
            if conn is None:
                conn = self.connect()
                self.ds._prepare_connection(conn, self.name)
                setattr(connections, self.name, conn)
            return fn(conn)

        return await asyncio.get_event_loop().run_in_executor(
            self.ds.executor, in_thread
        )

    async def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Execute ``sql`` against this database in a thread; return Results.

        ``custom_time_limit`` (ms) can only lower ``ds.sql_time_limit_ms``.
        With ``truncate`` set, at most ``ds.max_returned_rows`` rows are
        returned and ``Results.truncated`` reports whether more existed.

        Raises QueryInterrupted when SQLite reports the query as
        interrupted; other SQLite errors are optionally logged and
        re-raised.
        """
        page_size = page_size or self.ds.page_size

        def sql_operation_in_thread(conn):
            time_limit_ms = self.ds.sql_time_limit_ms
            if custom_time_limit and custom_time_limit < time_limit_ms:
                time_limit_ms = custom_time_limit

            with sqlite_timelimit(conn, time_limit_ms):
                try:
                    cursor = conn.cursor()
                    cursor.execute(sql, params or {})
                    max_returned_rows = self.ds.max_returned_rows
                    if max_returned_rows == page_size:
                        max_returned_rows += 1
                    if max_returned_rows and truncate:
                        # Fetch one extra row to detect whether more exist.
                        rows = cursor.fetchmany(max_returned_rows + 1)
                        truncated = len(rows) > max_returned_rows
                        rows = rows[:max_returned_rows]
                    else:
                        rows = cursor.fetchall()
                        truncated = False
                except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                    if e.args == ("interrupted",):
                        raise QueryInterrupted(e, sql, params) from e
                    if log_sql_errors:
                        print(
                            "ERROR: conn={}, sql = {}, params = {}: {}".format(
                                conn, repr(sql), params, e
                            )
                        )
                    raise

            # truncated is always False when truncate was not requested.
            return Results(rows, truncated, cursor.description)

        with trace("sql", database=self.name, sql=sql.strip(), params=params):
            results = await self.execute_fn(sql_operation_in_thread)
        return results

    @property
    def size(self):
        """Size of the database file in bytes (0 for in-memory databases)."""
        if self.is_memory:
            return 0
        if self.cached_size is not None:
            return self.cached_size
        return Path(self.path).stat().st_size

    @staticmethod
    def _quote_identifier(name):
        """Quote ``name`` as an SQLite identifier, doubling embedded quotes."""
        # [name] quoting has no escape mechanism, so it breaks on names
        # containing "]"; "..." with "" escaping handles any name.
        return '"{}"'.format(name.replace('"', '""'))

    async def table_counts(self, limit=10):
        """Return ``{table_name: row_count}`` for every table.

        Each count runs with ``limit`` as its time limit (ms); counts that
        are interrupted or fail are reported as None. For immutable databases
        the result is cached on the instance.
        """
        if not self.is_mutable and self.cached_table_counts is not None:
            return self.cached_table_counts
        # Try to get counts for each table, $limit timeout for each count
        counts = {}
        for table in await self.table_names():
            try:
                table_count = (
                    await self.execute(
                        "select count(*) from {}".format(
                            self._quote_identifier(table)
                        ),
                        custom_time_limit=limit,
                    )
                ).rows[0][0]
                counts[table] = table_count
            # In some cases I saw "SQL Logic Error" here in addition to
            # QueryInterrupted - so we catch that too:
            except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError):
                counts[table] = None
        if not self.is_mutable:
            self.cached_table_counts = counts
        return counts

    @property
    def mtime_ns(self):
        """File modification time in nanoseconds, or None for in-memory."""
        if self.is_memory:
            return None
        return Path(self.path).stat().st_mtime_ns

    @property
    def name(self):
        """Database name: ``":memory:"`` or the file name without extension."""
        if self.is_memory:
            return ":memory:"
        return Path(self.path).stem

    async def table_exists(self, table):
        """Return True if a table called ``table`` exists."""
        results = await self.execute(
            "select 1 from sqlite_master where type='table' and name=?", params=(table,)
        )
        return bool(results.rows)

    async def table_names(self):
        """Return the names of all tables listed in sqlite_master."""
        results = await self.execute(
            "select name from sqlite_master where type='table'"
        )
        return [r[0] for r in results.rows]

    async def table_columns(self, table):
        """Return the column names of ``table``."""
        return await self.execute_fn(lambda conn: table_columns(conn, table))

    async def primary_keys(self, table):
        """Return the primary key columns detected for ``table``."""
        return await self.execute_fn(lambda conn: detect_primary_keys(conn, table))

    async def fts_table(self, table):
        """Return the FTS table detected for ``table``, if any."""
        return await self.execute_fn(lambda conn: detect_fts(conn, table))

    async def label_column_for_table(self, table):
        """Pick the column used to label rows of ``table``, or None.

        Priority: explicit ``label_column`` metadata, then a ``name`` or
        ``title`` column, then the non-id column of a two-column table that
        has an ``id`` or ``pk`` column.
        """
        explicit_label_column = self.ds.table_metadata(self.name, table).get(
            "label_column"
        )
        if explicit_label_column:
            return explicit_label_column
        column_names = await self.table_columns(table)
        # Is there a name or title column?
        name_or_title = [c for c in column_names if c in ("name", "title")]
        if name_or_title:
            return name_or_title[0]
        # If a table has two columns, one of which is ID, then label_column is the other one
        if (
            column_names
            and len(column_names) == 2
            and ("id" in column_names or "pk" in column_names)
        ):
            return [c for c in column_names if c not in ("id", "pk")][0]
        # Couldn't find a label:
        return None

    async def foreign_keys_for_table(self, table):
        """Return the outbound foreign keys of ``table``."""
        return await self.execute_fn(
            lambda conn: get_outbound_foreign_keys(conn, table)
        )

    async def hidden_table_names(self):
        """Return the names of tables that should be hidden, each once.

        Hidden tables are FTS virtual tables, SpatiaLite internal tables
        (when SpatiaLite is detected), tables marked ``hidden`` in metadata,
        and any table whose name starts with one of those names.
        """
        # Mark tables 'hidden' if they relate to FTS virtual tables
        hidden_tables = [
            r[0]
            for r in (
                await self.execute(
                    """
                    select name from sqlite_master
                    where rootpage = 0
                    and sql like '%VIRTUAL TABLE%USING FTS%'
                """
                )
            ).rows
        ]
        has_spatialite = await self.execute_fn(detect_spatialite)
        if has_spatialite:
            # Also hide Spatialite internal tables. Single-quoted literals:
            # double-quoted strings only work while SQLite's DQS misfeature
            # is enabled.
            hidden_tables += [
                "ElementaryGeometries",
                "SpatialIndex",
                "geometry_columns",
                "spatial_ref_sys",
                "spatialite_history",
                "sql_statements_log",
                "sqlite_sequence",
                "views_geometry_columns",
                "virts_geometry_columns",
            ] + [
                r[0]
                for r in (
                    await self.execute(
                        """
                        select name from sqlite_master
                        where name like 'idx_%'
                        and type = 'table'
                    """
                    )
                ).rows
            ]
        # Add any from metadata.json
        db_metadata = self.ds.metadata(database=self.name)
        if "tables" in db_metadata:
            hidden_tables += [
                t
                for t in db_metadata["tables"]
                if db_metadata["tables"][t].get("hidden")
            ]
        # Also mark as hidden any tables which start with the name of a hidden table
        # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
        return self._expand_hidden_tables(hidden_tables, await self.table_names())

    @staticmethod
    def _expand_hidden_tables(hidden_tables, table_names):
        """Return ``hidden_tables`` plus every name in ``table_names`` that
        starts with one of them, deduplicated, in first-seen order."""
        prefixes = tuple(hidden_tables)
        result = list(dict.fromkeys(hidden_tables))
        seen = set(result)
        for table_name in table_names:
            if table_name not in seen and table_name.startswith(prefixes):
                result.append(table_name)
                seen.add(table_name)
        return result

    async def view_names(self):
        """Return the names of all views."""
        results = await self.execute("select name from sqlite_master where type='view'")
        return [r[0] for r in results.rows]

    async def get_all_foreign_keys(self):
        """Return foreign key information for every table."""
        return await self.execute_fn(get_all_foreign_keys)

    async def get_table_definition(self, table, type_="table"):
        """Return the CREATE SQL for ``table`` plus its indexes, or None.

        Each statement is terminated with ``;`` and statements are joined
        with newlines. ``type_`` selects the sqlite_master type to match.
        """
        table_definition_rows = list(
            await self.execute(
                "select sql from sqlite_master where name = :n and type=:t",
                {"n": table, "t": type_},
            )
        )
        if not table_definition_rows:
            return None
        bits = [table_definition_rows[0][0] + ";"]
        # Add on any indexes
        index_rows = list(
            await self.execute(
                "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
                {"n": table},
            )
        )
        for index_row in index_rows:
            bits.append(index_row[0] + ";")
        return "\n".join(bits)

    async def get_view_definition(self, view):
        """Return the CREATE VIEW SQL for ``view``, or None."""
        return await self.get_table_definition(view, "view")

    def __repr__(self):
        tags = []
        if self.is_mutable:
            tags.append("mutable")
        if self.is_memory:
            tags.append("memory")
        if self.hash:
            tags.append("hash={}".format(self.hash))
        if self.size is not None:
            tags.append("size={}".format(self.size))
        tags_str = ""
        if tags:
            tags_str = " ({})".format(", ".join(tags))
        return "<Database: {}{}>".format(self.name, tags_str)
class WriteTask:
    """A single unit of work queued for a database's write thread.

    Bundles the callable to run, the id handed back to the caller, and the
    reply queue that will receive the callable's result or exception.
    """

    __slots__ = ("fn", "task_id", "reply_queue")

    def __init__(self, fn, task_id, reply_queue):
        self.fn, self.task_id, self.reply_queue = fn, task_id, reply_queue
class QueryInterrupted(Exception):
    """Raised when SQLite reports a query as "interrupted" (e.g. when a time limit cuts it off)."""
class MultipleValues(Exception):
    """Raised when a result is expected to hold exactly one value but does not."""
class Results:
    """Rows from a query, whether they were truncated, and the cursor description."""

    def __init__(self, rows, truncated, description):
        self.rows = rows
        self.truncated = truncated
        self.description = description

    @property
    def columns(self):
        """Column names taken from the cursor description ([] if there is none)."""
        # sqlite3 sets cursor.description to None for statements that
        # return no columns; iterating it would raise TypeError.
        if self.description is None:
            return []
        return [d[0] for d in self.description]

    def first(self):
        """Return the first row, or None when there are no rows."""
        return self.rows[0] if self.rows else None

    def single_value(self):
        """Return the sole value of a one-row, one-column result.

        Raises MultipleValues otherwise (including when there are no rows).
        """
        if self.rows and len(self.rows) == 1 and len(self.rows[0]) == 1:
            return self.rows[0][0]
        raise MultipleValues

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)