Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import asyncio 

2import csv 

3import itertools 

4from itsdangerous import BadSignature 

5import json 

6import re 

7import time 

8import urllib 

9 

10import pint 

11 

12from datasette import __version__ 

13from datasette.plugins import pm 

14from datasette.database import QueryInterrupted 

15from datasette.utils import ( 

16 InvalidSql, 

17 LimitedWriter, 

18 call_with_supported_arguments, 

19 is_url, 

20 path_with_added_args, 

21 path_with_removed_args, 

22 path_with_format, 

23 resolve_table_and_format, 

24 sqlite3, 

25 to_css_class, 

26) 

27from datasette.utils.asgi import ( 

28 AsgiStream, 

29 AsgiWriter, 

30 AsgiRouter, 

31 AsgiView, 

32 Forbidden, 

33 NotFound, 

34 Response, 

35) 

36 

# Module-level pint UnitRegistry. It is not referenced anywhere in this
# module's visible code; presumably other modules import it for unit
# handling — TODO confirm before removing.
ureg = pint.UnitRegistry()

# Number of leading characters of a database's hash that appear in hashed
# database URLs (used by BaseView.database_url and DataView.resolve_db_name).
HASH_LENGTH = 7

40 

41 

class DatasetteError(Exception):
    """Error raised by views, carrying what is needed to build an error response.

    :param message: the error text; treated as HTML when ``messagge_is_html``
        is true.
    :param title: optional heading for the error.
    :param error_dict: optional extra details; stored as ``{}`` when omitted.
    :param status: HTTP status code, 500 by default.
    :param template: optional template name — presumably used by the error
        handler to render this error (not read in this module; TODO confirm).
    :param messagge_is_html: whether ``message`` contains HTML. The
        misspelled name is kept for backwards compatibility with callers.
    """

    def __init__(
        self,
        message,
        title=None,
        error_dict=None,
        status=500,
        template=None,
        messagge_is_html=False,
    ):
        # Pass only the message to Exception so that str(exc) is the message
        # itself, rather than a tuple repr when extra positional args are used.
        super().__init__(message)
        self.message = message
        self.title = title
        self.error_dict = error_dict or {}
        self.status = status
        # Previously accepted but silently discarded.
        self.template = template
        self.messagge_is_html = messagge_is_html

57 

58 

class BaseView(AsgiView):
    """Shared behaviour for Datasette's views.

    Provides HEAD handling, permission checks, database URL building,
    ``ds_messages`` cookie handling around each dispatched request, and
    Jinja template rendering with a standard set of context helpers.
    """

    # The Datasette application instance; ``None`` on the class itself and
    # expected to be assigned by subclasses/instances.
    ds = None

    async def head(self, *args, **kwargs):
        """Serve HEAD by running the GET handler and emptying the body."""
        resp = await self.get(*args, **kwargs)
        resp.body = b""
        return resp

    async def check_permission(self, request, action, resource=None):
        """Raise ``Forbidden(action)`` unless ``ds.permission_allowed`` agrees.

        ``default=True`` is passed, so the check allows the action unless the
        permission machinery returns a falsey answer.
        """
        allowed = await self.ds.permission_allowed(
            request.actor, action, resource=resource, default=True,
        )
        if allowed:
            return None
        raise Forbidden(action)

    def database_url(self, database):
        """Return the URL path for ``database``.

        The path is ``base_url`` followed by the database name; when the
        ``hash_urls`` setting is on and the database has a hash, ``-`` plus
        the first HASH_LENGTH characters of that hash are appended.
        """
        db = self.ds.databases[database]
        prefix = self.ds.config("base_url")
        hashed = self.ds.config("hash_urls") and db.hash
        if not hashed:
            return "{}{}".format(prefix, database)
        return "{}{}-{}".format(prefix, database, db.hash[:HASH_LENGTH])

    def database_color(self, database):
        """Return the hex colour used for ``database`` (currently a fixed value)."""
        return "ff0000"

    async def dispatch_request(self, request, *args, **kwargs):
        """Wrap the parent dispatcher with ``ds_messages`` cookie handling.

        Before dispatch, the ``ds_messages`` cookie is unsigned into
        ``request._messages``; a BadSignature (including a missing cookie
        read as "") is ignored. After dispatch, the request and response are
        handed to ``ds._write_messages_to_response``.
        """
        if self.ds:
            cookie_value = request.cookies.get("ds_messages", "")
            try:
                request._messages = self.ds.unsign(cookie_value, "messages")
            except BadSignature:
                pass
        response = await super().dispatch_request(request, *args, **kwargs)
        if self.ds:
            self.ds._write_messages_to_response(request, response)
        return response

    async def render(self, templates, request, context=None):
        """Render ``templates`` to an HTML Response.

        The template is chosen with ``jinja_env.select_template(templates)``.
        ``context`` is merged with helper values: the URL and colour helpers,
        the request's CSRF token, a ``show_messages`` callable, and
        ``select_templates`` (every candidate name, with ``*`` prefixed to the
        one selected). Helper keys override same-named keys in ``context``.
        """
        template = self.ds.jinja_env.select_template(templates)
        merged = {**(context or {})}
        merged.update(
            {
                "database_url": self.database_url,
                "csrftoken": request.scope["csrftoken"],
                "database_color": self.database_color,
                "show_messages": lambda: self.ds._show_messages(request),
                "select_templates": [
                    "{}{}".format(
                        "*" if candidate == template.name else "", candidate
                    )
                    for candidate in templates
                ],
            }
        )
        html = await self.ds.render_template(
            template, merged, request=request, view_name=self.name
        )
        return Response.html(html)

122 

123 

class DataView(BaseView):
    """Base class for views that serve data from a database.

    Handles resolution of the database URL segment (optionally ``name-hash``,
    with redirects to the hashed URL), output-format detection, CSV export
    (optionally streamed across pages), dispatch to registered output
    renderers or HTML templates, and cache/CORS response headers.
    Subclasses implement ``data()``.
    """

    name = ""
    # Pattern for ":name"-style named parameters (not used within this class).
    re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")

    def __init__(self, datasette):
        self.ds = datasette

    async def options(self, request, *args, **kwargs):
        """Answer an OPTIONS request with "ok", plus a CORS header if enabled.

        Declared ``async`` like the other request handlers (``get``,
        ``head``) so it can be awaited the same way they are; as a plain
        ``def`` it returned a bare Response that cannot be awaited.
        """
        r = Response.text("ok")
        if self.ds.cors:
            r.headers["Access-Control-Allow-Origin"] = "*"
        return r

    def redirect(self, request, path, forward_querystring=True, remove_args=None):
        """Build a redirect Response to ``path``.

        The request's query string is appended when ``forward_querystring``
        is true and ``path`` has no query string of its own. Query arguments
        named in ``remove_args`` are then stripped. A ``Link: <path>;
        rel=preload`` header is always added, plus a CORS header if enabled.
        """
        if request.query_string and "?" not in path and forward_querystring:
            path = "{}?{}".format(path, request.query_string)
        if remove_args:
            path = path_with_removed_args(request, remove_args, path=path)
        r = Response.redirect(path)
        r.headers["Link"] = "<{}>; rel=preload".format(path)
        if self.ds.cors:
            r.headers["Access-Control-Allow-Origin"] = "*"
        return r

    async def data(self, request, database, hash, **kwargs):
        """Produce the view's data; must be implemented by subclasses.

        As consumed by ``as_csv`` and ``view_get``, it returns either a
        ``Response`` (passed through unchanged) or a
        ``(data, extra_template_data, templates)`` tuple.
        """
        raise NotImplementedError

    async def resolve_db_name(self, request, db_name, **kwargs):
        """Resolve the database URL segment.

        ``db_name`` is either a database name or ``<name>-<hash>``. ``kwargs``
        are the router's path components (``table``/``table_and_format``,
        ``pk_path``, ``as_format``, ``as_db``) and are used to rebuild the
        path for a redirect.

        Returns ``(name, expected_hash, correct_hash_provided,
        redirect_path_or_None)``. The redirect path is only returned when the
        hash is wrong or missing, hashed URLs are requested (the ``hash_urls``
        setting or ``?_hash``), and the database is immutable.

        Raises NotFound when no matching database exists.
        """
        hash_ = None
        if db_name not in self.ds.databases and "-" in db_name:
            # No matching DB found, maybe it's a name-hash?
            name_bit, hash_bit = db_name.rsplit("-", 1)
            if name_bit not in self.ds.databases:
                # Report the name actually looked up (this previously
                # formatted the still-unset ``name`` and said "None").
                raise NotFound("Database not found: {}".format(name_bit))
            name, hash_ = name_bit, hash_bit
        else:
            name = db_name
        name = urllib.parse.unquote_plus(name)
        try:
            db = self.ds.databases[name]
        except KeyError:
            raise NotFound("Database not found: {}".format(name))

        # Verify the hash; databases without one expect the placeholder "000"
        expected = "000"
        if db.hash is not None:
            expected = db.hash[:HASH_LENGTH]
        correct_hash_provided = expected == hash_

        if not correct_hash_provided:
            if "table_and_format" in kwargs:

                async def async_table_exists(t):
                    return await db.table_exists(t)

                table, _format = await resolve_table_and_format(
                    table_and_format=urllib.parse.unquote_plus(
                        kwargs["table_and_format"]
                    ),
                    table_exists=async_table_exists,
                    allowed_formats=self.ds.renderers.keys(),
                )
                kwargs["table"] = table
                if _format:
                    kwargs["as_format"] = ".{}".format(_format)
            elif kwargs.get("table"):
                kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])

            # Prefix base_url, consistent with BaseView.database_url, so the
            # redirect also works when Datasette is mounted under a prefix.
            should_redirect = "{}{}-{}".format(
                self.ds.config("base_url"), name, expected
            )
            if kwargs.get("table"):
                should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
            if kwargs.get("pk_path"):
                should_redirect += "/" + kwargs["pk_path"]
            if kwargs.get("as_format"):
                should_redirect += kwargs["as_format"]
            if kwargs.get("as_db"):
                should_redirect += kwargs["as_db"]

            if (
                (self.ds.config("hash_urls") or "_hash" in request.args)
                # Redirect only if database is immutable
                and not db.is_mutable
            ):
                return name, expected, correct_hash_provided, should_redirect

        return name, expected, correct_hash_provided, None

    def get_templates(self, database, table=None):
        """Hook for subclasses; not implemented by this base class.

        This previously used ``assert NotImplemented``, which always passed
        (``NotImplemented`` is truthy) and silently returned None.
        """
        raise NotImplementedError

    async def get(self, request, db_name, **kwargs):
        """Handle GET: resolve the database, redirect if needed, else ``view_get``."""
        (
            database,
            hash_,
            correct_hash_provided,
            should_redirect,
        ) = await self.resolve_db_name(request, db_name, **kwargs)
        if should_redirect:
            return self.redirect(request, should_redirect, remove_args={"_hash"})

        return await self.view_get(
            request, database, hash_, correct_hash_provided, **kwargs
        )

    async def as_csv(self, request, database, hash, **kwargs):
        """Return the view's rows as CSV.

        With ``?_stream=`` set (which requires the ``allow_csv_stream``
        setting and forbids ``?_next=``), pages are fetched at ``_size=max``
        and followed via each page's ``next`` value; otherwise only the first
        page is written. Columns listed in ``expanded_columns`` are written as
        two columns: the ``value`` and a ``<column>_label`` column. ``?_dl=``
        serves the output as a ``text/csv`` attachment. Output size is capped
        via LimitedWriter using the ``max_csv_mb`` setting.

        Raises DatasetteError for disallowed streaming options or invalid SQL
        on the first page. Errors after streaming has begun are logged and
        their text is written into the response body.
        """
        stream = request.args.get("_stream")
        if stream:
            # Some quick sanity checks
            if not self.ds.config("allow_csv_stream"):
                raise DatasetteError("CSV streaming is disabled", status=400)
            if request.args.get("_next"):
                raise DatasetteError("_next not allowed for CSV streaming", status=400)
            kwargs["_size"] = "max"
        # Fetch the first page
        try:
            response_or_template_contexts = await self.data(
                request, database, hash, **kwargs
            )
        except (sqlite3.OperationalError, InvalidSql) as e:
            raise DatasetteError(str(e), title="Invalid SQL", status=400) from e
        if isinstance(response_or_template_contexts, Response):
            return response_or_template_contexts
        data, _, _ = response_or_template_contexts

        # Convert rows and columns to CSV
        headings = data["columns"]
        # if there are expanded_columns we need to add additional headings
        expanded_columns = set(data.get("expanded_columns") or [])
        if expanded_columns:
            headings = []
            for column in data["columns"]:
                headings.append(column)
                if column in expanded_columns:
                    headings.append("{}_label".format(column))

        async def stream_fn(r):
            nonlocal data
            writer = csv.writer(LimitedWriter(r, self.ds.config("max_csv_mb")))
            first = True
            next_token = None
            while first or (next_token and stream):
                try:
                    if next_token:
                        kwargs["_next"] = next_token
                    if not first:
                        data, _, _ = await self.data(request, database, hash, **kwargs)
                    if first:
                        await writer.writerow(headings)
                        first = False
                    next_token = data.get("next")
                    for row in data["rows"]:
                        if not expanded_columns:
                            # Simple path
                            await writer.writerow(row)
                        else:
                            # Look for {"value": ..., "label": ...} dicts and expand
                            new_row = []
                            for heading, cell in zip(data["columns"], row):
                                if heading in expanded_columns:
                                    if cell is None:
                                        new_row.extend(("", ""))
                                    else:
                                        assert isinstance(cell, dict)
                                        new_row.append(cell["value"])
                                        new_row.append(cell["label"])
                                else:
                                    new_row.append(cell)
                            await writer.writerow(new_row)
                except Exception as e:
                    # Headers are already sent, so the error can only be
                    # reported in the body; log it (instead of a debug print)
                    # so operators also see the traceback.
                    import logging

                    logging.getLogger(__name__).exception("Error while streaming CSV")
                    await r.write(str(e))
                    return

        content_type = "text/plain; charset=utf-8"
        headers = {}
        if self.ds.cors:
            headers["Access-Control-Allow-Origin"] = "*"
        if request.args.get("_dl", None):
            content_type = "text/csv; charset=utf-8"
            disposition = 'attachment; filename="{}.csv"'.format(
                kwargs.get("table", database)
            )
            headers["Content-Disposition"] = disposition

        return AsgiStream(stream_fn, headers=headers, content_type=content_type)

    async def get_format(self, request, database, args):
        """ Determine the format of the response from the request, from URL
        parameters or from a file extension.

        `args` is a dict of the path components parsed from the URL by the router.
        It is modified in place (``as_format`` and ``table_and_format`` are
        removed, ``table`` is set/unquoted) and returned as the second element
        of ``(format, args)``. ``?_format=`` takes precedence over an extension.
        """
        # If ?_format= is provided, use that as the format
        _format = request.args.get("_format", None)
        if not _format:
            _format = (args.pop("as_format", None) or "").lstrip(".")
        else:
            args.pop("as_format", None)
        if "table_and_format" in args:
            db = self.ds.databases[database]

            async def async_table_exists(t):
                return await db.table_exists(t)

            table, _ext_format = await resolve_table_and_format(
                table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
                table_exists=async_table_exists,
                allowed_formats=self.ds.renderers.keys(),
            )
            _format = _format or _ext_format
            args["table"] = table
            del args["table_and_format"]
        elif "table" in args:
            args["table"] = urllib.parse.unquote_plus(args["table"])
        return _format, args

    async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
        """Produce the response for a resolved database in the requested format.

        CSV is delegated to ``as_csv``; ``.jsono`` redirects to
        ``.json?_shape=objects``; formats registered in ``ds.renderers`` are
        dispatched to their render callback; anything else is rendered as HTML
        from the templates returned by ``data()``. Cache-Control uses ``?_ttl``
        when it is a decimal integer, otherwise the ``default_cache_ttl_hashed``
        or ``default_cache_ttl`` setting depending on ``correct_hash_provided``.

        Raises DatasetteError on interrupted queries or invalid SQL, and
        NotFound when a renderer returns None.
        """
        _format, kwargs = await self.get_format(request, database, kwargs)

        if _format == "csv":
            return await self.as_csv(request, database, hash, **kwargs)

        if _format is None:
            # HTML views default to expanding all foreign key labels
            kwargs["default_labels"] = True

        start = time.time()
        status_code = 200
        try:
            response_or_template_contexts = await self.data(
                request, database, hash, **kwargs
            )
        except QueryInterrupted:
            raise DatasetteError(
                """
                SQL query took too long. The time limit is controlled by the
                <a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
                configuration option.
            """,
                title="SQL Interrupted",
                status=400,
                messagge_is_html=True,
            )
        except (sqlite3.OperationalError, InvalidSql) as e:
            raise DatasetteError(str(e), title="Invalid SQL", status=400) from e

        if isinstance(response_or_template_contexts, Response):
            return response_or_template_contexts
        data, extra_template_data, templates = response_or_template_contexts

        end = time.time()
        data["query_ms"] = (end - start) * 1000
        for key in ("source", "source_url", "license", "license_url"):
            value = self.ds.metadata(key)
            if value:
                data[key] = value

        # Special case for .jsono extension - redirect to _shape=objects
        if _format == "jsono":
            return self.redirect(
                request,
                path_with_added_args(
                    request,
                    {"_shape": "objects"},
                    path=request.path.rsplit(".jsono", 1)[0] + ".json",
                ),
                forward_querystring=False,
            )

        if _format in self.ds.renderers.keys():
            # Dispatch request to the correct output format renderer
            # (CSV is not handled here due to streaming)
            result = call_with_supported_arguments(
                self.ds.renderers[_format][0],
                datasette=self.ds,
                columns=data.get("columns") or [],
                rows=data.get("rows") or [],
                sql=data.get("query", {}).get("sql", None),
                query_name=data.get("query_name"),
                database=database,
                table=data.get("table"),
                request=request,
                view_name=self.name,
                # These will be deprecated in Datasette 1.0:
                args=request.args,
                data=data,
            )
            if asyncio.iscoroutine(result):
                result = await result
            if result is None:
                raise NotFound("No data")

            r = Response(
                body=result.get("body"),
                status=result.get("status_code", 200),
                content_type=result.get("content_type", "text/plain"),
                headers=result.get("headers"),
            )
        else:
            extras = {}
            if callable(extra_template_data):
                extras = extra_template_data()
                if asyncio.iscoroutine(extras):
                    extras = await extras
            else:
                extras = extra_template_data
            url_labels_extra = {}
            if data.get("expandable_columns"):
                url_labels_extra = {"_labels": "on"}

            # Offer links only to renderers whose can_render hook agrees
            renderers = {}
            for key, (_, can_render) in self.ds.renderers.items():
                it_can_render = call_with_supported_arguments(
                    can_render,
                    datasette=self.ds,
                    columns=data.get("columns") or [],
                    rows=data.get("rows") or [],
                    sql=data.get("query", {}).get("sql", None),
                    query_name=data.get("query_name"),
                    database=database,
                    table=data.get("table"),
                    request=request,
                    view_name=self.name,
                )
                if asyncio.iscoroutine(it_can_render):
                    it_can_render = await it_can_render
                if it_can_render:
                    renderers[key] = path_with_format(
                        request, key, {**url_labels_extra}
                    )

            url_csv_args = {"_size": "max", **url_labels_extra}
            url_csv = path_with_format(request, "csv", url_csv_args)
            url_csv_path = url_csv.split("?")[0]
            context = {
                **data,
                **extras,
                **{
                    "renderers": renderers,
                    "url_csv": url_csv,
                    "url_csv_path": url_csv_path,
                    "url_csv_hidden_args": [
                        (key, value)
                        for key, value in urllib.parse.parse_qsl(request.query_string)
                        if key not in ("_labels", "_facet", "_size")
                    ]
                    + [("_size", "max")],
                    "datasette_version": __version__,
                    "config": self.ds.config_dict(),
                },
            }
            if "metadata" not in context:
                context["metadata"] = self.ds.metadata
            r = await self.render(templates, request=request, context=context)
            r.status = status_code

        ttl = request.args.get("_ttl", None)
        # isdecimal(), not isdigit(): characters such as "²" satisfy isdigit()
        # but make int() raise ValueError in set_response_headers.
        if ttl is None or not ttl.isdecimal():
            if correct_hash_provided:
                ttl = self.ds.config("default_cache_ttl_hashed")
            else:
                ttl = self.ds.config("default_cache_ttl")

        return self.set_response_headers(r, ttl)

    def set_response_headers(self, response, ttl):
        """Add caching, referrer and CORS headers to ``response`` and return it.

        Cache-Control is only set when ``ds.cache_headers`` is on and the
        status is 200: ``no-cache`` for a TTL of 0, else ``max-age=<ttl>``.
        """
        # Set far-future cache expiry
        if self.ds.cache_headers and response.status == 200:
            ttl = int(ttl)
            if ttl == 0:
                ttl_header = "no-cache"
            else:
                ttl_header = "max-age={}".format(ttl)
            response.headers["Cache-Control"] = ttl_header
        response.headers["Referrer-Policy"] = "no-referrer"
        if self.ds.cors:
            response.headers["Access-Control-Allow-Origin"] = "*"
        return response