Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import asyncio 

2import contextlib 

3from pathlib import Path 

4import janus 

5import queue 

6import threading 

7import uuid 

8 

9from .tracer import trace 

10from .utils import ( 

11 detect_fts, 

12 detect_primary_keys, 

13 detect_spatialite, 

14 get_all_foreign_keys, 

15 get_outbound_foreign_keys, 

16 sqlite_timelimit, 

17 sqlite3, 

18 table_columns, 

19) 

20from .inspect import inspect_hash 

21 

# Thread-local storage: Database.execute_fn caches one connection per database
# name as an attribute on this object, separately for each thread.
connections = threading.local()

23 

24 

class Database:
    """A single SQLite database (file-backed or in-memory) known to ``ds``.

    Reads run through ``execute_fn`` using one cached connection per thread
    (stored on the module-level ``connections`` thread-local). Writes are
    queued to a dedicated daemon thread that owns the only write connection.
    """

    def __init__(self, ds, path=None, is_mutable=False, is_memory=False):
        """Record configuration; for immutable files also cache hash and size.

        ``ds`` is the owning application object; this class reads
        ``inspect_data``, ``executor``, ``page_size``, ``sql_time_limit_ms``,
        ``max_returned_rows`` and calls ``_prepare_connection``, ``metadata``
        and ``table_metadata`` on it.
        """
        self.ds = ds
        self.path = path
        self.is_mutable = is_mutable
        self.is_memory = is_memory
        self.hash = None
        self.cached_size = None
        self.cached_table_counts = None
        # Both are created lazily by the first call to execute_write_fn()
        self._write_thread = None
        self._write_queue = None
        if not self.is_mutable and not self.is_memory:
            p = Path(path)
            self.hash = inspect_hash(p)
            self.cached_size = p.stat().st_size
            # Maybe use self.ds.inspect_data to populate cached_table_counts
            if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
                self.cached_table_counts = {
                    key: value["count"]
                    for key, value in self.ds.inspect_data[self.name]["tables"].items()
                }

    def connect(self, write=False):
        """Open and return a new sqlite3 connection to this database.

        In-memory databases always get a fresh ``:memory:`` connection.
        File databases are opened read-only (``mode=ro`` when mutable,
        ``immutable=1`` otherwise) unless ``write`` is true.

        Raises AssertionError if ``write`` is requested for a file database
        that is not mutable.
        """
        if self.is_memory:
            return sqlite3.connect(":memory:")
        if write and not self.is_mutable:
            # Explicit raise rather than ``assert`` so the guard survives -O
            raise AssertionError(
                "Cannot open a write connection to an immutable database"
            )
        # mode=ro or immutable=1?
        if write:
            qs = ""
        elif self.is_mutable:
            qs = "?mode=ro"
        else:
            qs = "?immutable=1"
        return sqlite3.connect(
            "file:{}{}".format(self.path, qs), uri=True, check_same_thread=False
        )

    async def execute_write(self, sql, params=None, block=False):
        """Queue ``sql`` to run on the write connection inside a transaction.

        Return value and ``block`` semantics are those of execute_write_fn().
        """

        def _inner(conn):
            # "with conn" commits on success and rolls back on exception
            with conn:
                return conn.execute(sql, params or [])

        return await self.execute_write_fn(_inner, block=block)

    async def execute_write_fn(self, fn, block=False):
        """Queue ``fn(conn)`` to be called on the dedicated write thread.

        With ``block=False`` the task's UUID is returned immediately. With
        ``block=True`` this waits for the task to finish and returns whatever
        ``fn`` returned, or re-raises the exception it raised.
        """
        # uuid4 gives every task a distinct ID; the previous uuid5() of a
        # constant name produced the identical UUID for every task.
        task_id = uuid.uuid4()
        if self._write_queue is None:
            self._write_queue = queue.Queue()
        if self._write_thread is None:
            self._write_thread = threading.Thread(
                target=self._execute_writes, daemon=True
            )
            self._write_thread.start()
        reply_queue = janus.Queue()
        self._write_queue.put(WriteTask(fn, task_id, reply_queue))
        if not block:
            return task_id
        result = await reply_queue.async_q.get()
        if isinstance(result, Exception):
            raise result
        return result

    def _execute_writes(self):
        """Loop forever, running queued write tasks on one write connection.

        Each task's result (or the exception it raised) is put on that task's
        reply queue.
        """
        # Infinite looping thread that protects the single write connection
        # to this database
        conn = None
        conn_exception = None
        try:
            conn = self.connect(write=True)
        except Exception as e:
            # Keep the thread alive and report the failure to every task;
            # otherwise callers using block=True would wait forever.
            conn_exception = e
        while True:
            task = self._write_queue.get()
            if conn_exception is not None:
                result = conn_exception
            else:
                try:
                    result = task.fn(conn)
                except Exception as e:
                    print(e)
                    result = e
            task.reply_queue.sync_q.put(result)

    async def execute_fn(self, fn):
        """Run ``fn(conn)`` on ``self.ds.executor`` and return its result.

        The connection is created on first use in each thread, passed to
        ``self.ds._prepare_connection`` and cached on the ``connections``
        thread-local under this database's name.
        """

        def in_thread():
            conn = getattr(connections, self.name, None)
            if not conn:
                conn = self.connect()
                self.ds._prepare_connection(conn, self.name)
                setattr(connections, self.name, conn)
            return fn(conn)

        return await asyncio.get_event_loop().run_in_executor(
            self.ds.executor, in_thread
        )

    async def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Executes sql against db_name in a thread

        Returns a Results object. When ``truncate`` is true at most
        ``self.ds.max_returned_rows`` rows are returned and
        ``Results.truncated`` records whether more were available.
        ``custom_time_limit`` can only lower ``self.ds.sql_time_limit_ms``.

        Raises QueryInterrupted when SQLite reports "interrupted"; other
        sqlite3 database errors are printed (if ``log_sql_errors``) and
        re-raised.
        """
        page_size = page_size or self.ds.page_size

        def sql_operation_in_thread(conn):
            time_limit_ms = self.ds.sql_time_limit_ms
            if custom_time_limit and custom_time_limit < time_limit_ms:
                time_limit_ms = custom_time_limit

            with sqlite_timelimit(conn, time_limit_ms):
                try:
                    cursor = conn.cursor()
                    cursor.execute(sql, params or {})
                    max_returned_rows = self.ds.max_returned_rows
                    if max_returned_rows == page_size:
                        max_returned_rows += 1
                    if max_returned_rows and truncate:
                        # Fetch one extra row to detect truncation
                        rows = cursor.fetchmany(max_returned_rows + 1)
                        truncated = len(rows) > max_returned_rows
                        rows = rows[:max_returned_rows]
                    else:
                        rows = cursor.fetchall()
                        truncated = False
                except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                    if e.args == ("interrupted",):
                        raise QueryInterrupted(e, sql, params)
                    if log_sql_errors:
                        print(
                            "ERROR: conn={}, sql = {}, params = {}: {}".format(
                                conn, repr(sql), params, e
                            )
                        )
                    raise

            # truncated is already False whenever truncation was not applied
            return Results(rows, truncated, cursor.description)

        with trace("sql", database=self.name, sql=sql.strip(), params=params):
            results = await self.execute_fn(sql_operation_in_thread)
        return results

    @property
    def size(self):
        """Size in bytes: 0 for memory, else cached or current file size."""
        if self.is_memory:
            return 0
        if self.cached_size is not None:
            return self.cached_size
        return Path(self.path).stat().st_size

    async def table_counts(self, limit=10):
        """Return ``{table_name: row_count}``; count is None if it failed.

        ``limit`` is passed as the per-query ``custom_time_limit``. For
        databases that are not mutable the result is cached and reused.
        """
        if not self.is_mutable and self.cached_table_counts is not None:
            return self.cached_table_counts
        # Try to get counts for each table, $limit timeout for each count
        counts = {}
        for table in await self.table_names():
            try:
                table_count = (
                    await self.execute(
                        "select count(*) from [{}]".format(table),
                        custom_time_limit=limit,
                    )
                ).rows[0][0]
                counts[table] = table_count
            # In some cases I saw "SQL Logic Error" here in addition to
            # QueryInterrupted - so we catch that too:
            except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError):
                counts[table] = None
        if not self.is_mutable:
            self.cached_table_counts = counts
        return counts

    @property
    def mtime_ns(self):
        """File modification time in nanoseconds, or None for memory."""
        if self.is_memory:
            return None
        return Path(self.path).stat().st_mtime_ns

    @property
    def name(self):
        """``":memory:"`` for memory databases, else the file name's stem."""
        if self.is_memory:
            return ":memory:"
        return Path(self.path).stem

    async def table_exists(self, table):
        """Return True if sqlite_master lists a table called ``table``."""
        results = await self.execute(
            "select 1 from sqlite_master where type='table' and name=?", params=(table,)
        )
        return bool(results.rows)

    async def table_names(self):
        """Return the names of all entries of type 'table' in sqlite_master."""
        results = await self.execute(
            "select name from sqlite_master where type='table'"
        )
        return [r[0] for r in results.rows]

    async def table_columns(self, table):
        """Return the result of the ``table_columns`` helper for ``table``."""
        return await self.execute_fn(lambda conn: table_columns(conn, table))

    async def primary_keys(self, table):
        """Return the result of ``detect_primary_keys`` for ``table``."""
        return await self.execute_fn(lambda conn: detect_primary_keys(conn, table))

    async def fts_table(self, table):
        """Return the result of ``detect_fts`` for ``table``."""
        return await self.execute_fn(lambda conn: detect_fts(conn, table))

    async def label_column_for_table(self, table):
        """Pick the column used to label rows of ``table``, or None.

        Order of preference: ``label_column`` from table metadata, then a
        column called "name" or "title", then the non-key column of a
        two-column table whose other column is "id" or "pk".
        """
        explicit_label_column = self.ds.table_metadata(self.name, table).get(
            "label_column"
        )
        if explicit_label_column:
            return explicit_label_column
        column_names = await self.execute_fn(lambda conn: table_columns(conn, table))
        # Is there a name or title column?
        name_or_title = [c for c in column_names if c in ("name", "title")]
        if name_or_title:
            return name_or_title[0]
        # If a table has two columns, one of which is ID, then label_column is the other one
        if (
            column_names
            and len(column_names) == 2
            and ("id" in column_names or "pk" in column_names)
        ):
            return [c for c in column_names if c not in ("id", "pk")][0]
        # Couldn't find a label:
        return None

    async def foreign_keys_for_table(self, table):
        """Return the result of ``get_outbound_foreign_keys`` for ``table``."""
        return await self.execute_fn(
            lambda conn: get_outbound_foreign_keys(conn, table)
        )

    async def hidden_table_names(self):
        """Return the names of tables that should be treated as hidden.

        Covers FTS virtual tables, Spatialite internals (when
        ``detect_spatialite`` reports it), tables flagged ``hidden`` in
        metadata, and any table whose name starts with one of those names.
        """
        # Mark tables 'hidden' if they relate to FTS virtual tables
        hidden_tables = [
            r[0]
            for r in (
                await self.execute(
                    """
                select name from sqlite_master
                where rootpage = 0
                and sql like '%VIRTUAL TABLE%USING FTS%'
            """
                )
            ).rows
        ]
        has_spatialite = await self.execute_fn(detect_spatialite)
        if has_spatialite:
            # Also hide Spatialite internal tables
            hidden_tables += [
                "ElementaryGeometries",
                "SpatialIndex",
                "geometry_columns",
                "spatial_ref_sys",
                "spatialite_history",
                "sql_statements_log",
                "sqlite_sequence",
                "views_geometry_columns",
                "virts_geometry_columns",
            ] + [
                r[0]
                for r in (
                    await self.execute(
                        """
                        select name from sqlite_master
                        where name like "idx_%"
                        and type = "table"
                    """
                    )
                ).rows
            ]
        # Add any from metadata.json
        db_metadata = self.ds.metadata(database=self.name)
        if "tables" in db_metadata:
            hidden_tables += [
                t
                for t in db_metadata["tables"]
                if db_metadata["tables"][t].get("hidden")
            ]
        # Also mark as hidden any tables which start with the name of a hidden table
        # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
        for table_name in await self.table_names():
            if table_name in hidden_tables:
                # Already listed; do not append a duplicate entry
                continue
            if any(table_name.startswith(prefix) for prefix in hidden_tables):
                hidden_tables.append(table_name)

        return hidden_tables

    async def view_names(self):
        """Return the names of all entries of type 'view' in sqlite_master."""
        results = await self.execute("select name from sqlite_master where type='view'")
        return [r[0] for r in results.rows]

    async def get_all_foreign_keys(self):
        """Return the result of the ``get_all_foreign_keys`` helper."""
        return await self.execute_fn(get_all_foreign_keys)

    async def get_table_definition(self, table, type_="table"):
        """Return the stored SQL for ``table`` plus its indexes, or None.

        Each statement ends with ";" and statements are joined by newlines.
        None is returned when sqlite_master has no row matching the name and
        ``type_``.
        """
        table_definition_rows = list(
            await self.execute(
                "select sql from sqlite_master where name = :n and type=:t",
                {"n": table, "t": type_},
            )
        )
        if not table_definition_rows:
            return None
        bits = [table_definition_rows[0][0] + ";"]
        # Add on any indexes
        index_rows = list(
            await self.execute(
                "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
                {"n": table},
            )
        )
        for index_row in index_rows:
            bits.append(index_row[0] + ";")
        return "\n".join(bits)

    async def get_view_definition(self, view):
        """Return get_table_definition() for ``view`` with type "view"."""
        return await self.get_table_definition(view, "view")

    def __repr__(self):
        tags = []
        if self.is_mutable:
            tags.append("mutable")
        if self.is_memory:
            tags.append("memory")
        if self.hash:
            tags.append("hash={}".format(self.hash))
        if self.size is not None:
            tags.append("size={}".format(self.size))
        tags_str = ""
        if tags:
            tags_str = " ({})".format(", ".join(tags))
        return "<Database: {}{}>".format(self.name, tags_str)

360 

361 

class WriteTask:
    """One queued write: the callable, its task ID and its reply queue.

    ``fn`` is called with the write connection; the outcome is put on
    ``reply_queue.sync_q``.
    """

    __slots__ = ("fn", "task_id", "reply_queue")

    def __init__(self, fn, task_id, reply_queue):
        self.fn, self.task_id, self.reply_queue = fn, task_id, reply_queue

369 

370 

class QueryInterrupted(Exception):
    """Raised when SQLite reports that a query was "interrupted".

    Database.execute raises it with ``(original_error, sql, params)`` as args.
    """

373 

374 

class MultipleValues(Exception):
    """Raised by Results.single_value() unless there is exactly one row with one column."""

377 

378 

class Results:
    """Rows from a query, a truncation flag and the cursor description.

    Iterating yields the rows and ``len()`` gives the number of rows.
    """

    def __init__(self, rows, truncated, description):
        self.rows = rows
        self.truncated = truncated
        # Expected to be a sqlite3 cursor.description: a sequence of tuples
        # whose first item is the column name, or None.
        self.description = description

    @property
    def columns(self):
        """Column names from the description; empty list if there is none."""
        # sqlite3 sets cursor.description to None for statements that return
        # no result set, so treat None as "no columns" rather than crashing.
        return [d[0] for d in self.description or ()]

    def first(self):
        """Return the first row, or None when there are no rows."""
        if self.rows:
            return self.rows[0]
        return None

    def single_value(self):
        """Return the only value of the only row.

        Raises MultipleValues unless there is exactly one row with exactly
        one column.
        """
        if self.rows and 1 == len(self.rows) and 1 == len(self.rows[0]):
            return self.rows[0][0]
        raise MultipleValues

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)