Coverage for datasette/views/base.py : 93%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import csv
3import itertools
4from itsdangerous import BadSignature
5import json
6import re
7import time
8import urllib
10import pint
12from datasette import __version__
13from datasette.plugins import pm
14from datasette.database import QueryInterrupted
15from datasette.utils import (
16 InvalidSql,
17 LimitedWriter,
18 call_with_supported_arguments,
19 is_url,
20 path_with_added_args,
21 path_with_removed_args,
22 path_with_format,
23 resolve_table_and_format,
24 sqlite3,
25 to_css_class,
26)
27from datasette.utils.asgi import (
28 AsgiStream,
29 AsgiWriter,
30 AsgiRouter,
31 AsgiView,
32 Forbidden,
33 NotFound,
34 Response,
35)
# Shared pint unit registry, used to convert/compare physical units in
# column values.
ureg = pint.UnitRegistry()

# Number of leading characters of a database file's content hash that is
# embedded in hashed URLs (e.g. /mydb-a1b2c3d/...).
HASH_LENGTH = 7
class DatasetteError(Exception):
    """An application error to be rendered as an HTTP error response.

    :param message: error text (may be HTML if ``messagge_is_html`` is True)
    :param title: optional page title for the error page
    :param error_dict: optional extra key/value details (defaults to ``{}``)
    :param status: HTTP status code to respond with (default 500)
    :param template: optional template name to render the error with
    :param messagge_is_html: if True, ``message`` is rendered unescaped.
        NOTE: the misspelling is kept deliberately — existing callers pass
        ``messagge_is_html=`` by keyword.
    """

    def __init__(
        self,
        message,
        title=None,
        error_dict=None,
        status=500,
        template=None,
        messagge_is_html=False,
    ):
        self.message = message
        self.title = title
        self.error_dict = error_dict or {}
        self.status = status
        # Bug fix: ``template`` was previously accepted but silently discarded.
        self.template = template
        self.messagge_is_html = messagge_is_html
class BaseView(AsgiView):
    """Common behavior shared by all Datasette views: message-cookie
    handling around dispatch, permission checks, and template rendering."""

    # Set by subclasses to the shared Datasette application instance.
    ds = None

    async def head(self, *args, **kwargs):
        """HEAD mirrors GET but strips the response body."""
        get_response = await self.get(*args, **kwargs)
        get_response.body = b""
        return get_response

    async def check_permission(self, request, action, resource=None):
        """Raise Forbidden unless the current actor may perform ``action``."""
        allowed = await self.ds.permission_allowed(
            request.actor, action, resource=resource, default=True,
        )
        if allowed:
            return
        raise Forbidden(action)

    def database_url(self, database):
        """Path for a database page, hash-qualified when hash_urls is on."""
        base_url = self.ds.config("base_url")
        db = self.ds.databases[database]
        if not (self.ds.config("hash_urls") and db.hash):
            return "{}{}".format(base_url, database)
        return "{}{}-{}".format(base_url, database, db.hash[:HASH_LENGTH])

    def database_color(self, database):
        # Placeholder: every database currently gets the same color.
        return "ff0000"

    async def dispatch_request(self, request, *args, **kwargs):
        """Wrap normal dispatch with ds_messages cookie handling."""
        # Populate request._messages from the signed ds_messages cookie
        if self.ds:
            signed_cookie = request.cookies.get("ds_messages", "")
            try:
                request._messages = self.ds.unsign(signed_cookie, "messages")
            except BadSignature:
                # Tampered or stale cookie — ignore it entirely
                pass
        response = await super().dispatch_request(request, *args, **kwargs)
        if self.ds:
            self.ds._write_messages_to_response(request, response)
        return response

    async def render(self, templates, request, context=None):
        """Render the first matching template to an HTML Response."""
        context = context or {}
        template = self.ds.jinja_env.select_template(templates)
        extra_context = {
            "database_url": self.database_url,
            "csrftoken": request.scope["csrftoken"],
            "database_color": self.database_color,
            "show_messages": lambda: self.ds._show_messages(request),
            # The candidate template list, with the one actually chosen
            # marked by a leading "*"
            "select_templates": [
                ("*" if candidate == template.name else "") + candidate
                for candidate in templates
            ],
        }
        body = await self.ds.render_template(
            template,
            {**context, **extra_context},
            request=request,
            view_name=self.name,
        )
        return Response.html(body)
class DataView(BaseView):
    # View name; overridden by concrete subclasses.
    name = ""
    # Matches ":param" style named parameters embedded in SQL text.
    re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")

    def __init__(self, datasette):
        # datasette: the shared Datasette application object.
        self.ds = datasette
131 def options(self, request, *args, **kwargs):
132 r = Response.text("ok")
133 if self.ds.cors:
134 r.headers["Access-Control-Allow-Origin"] = "*"
135 return r
137 def redirect(self, request, path, forward_querystring=True, remove_args=None):
138 if request.query_string and "?" not in path and forward_querystring:
139 path = "{}?{}".format(path, request.query_string)
140 if remove_args:
141 path = path_with_removed_args(request, remove_args, path=path)
142 r = Response.redirect(path)
143 r.headers["Link"] = "<{}>; rel=preload".format(path)
144 if self.ds.cors:
145 r.headers["Access-Control-Allow-Origin"] = "*"
146 return r
    async def data(self, request, database, hash, **kwargs):
        # Abstract: subclasses must return either a Response, or a
        # (data, extra_template_data, templates) tuple consumed by
        # view_get() / as_csv().
        raise NotImplementedError
151 async def resolve_db_name(self, request, db_name, **kwargs):
152 hash = None
153 name = None
154 if db_name not in self.ds.databases and "-" in db_name:
155 # No matching DB found, maybe it's a name-hash?
156 name_bit, hash_bit = db_name.rsplit("-", 1)
157 if name_bit not in self.ds.databases:
158 raise NotFound("Database not found: {}".format(name))
159 else:
160 name = name_bit
161 hash = hash_bit
162 else:
163 name = db_name
164 name = urllib.parse.unquote_plus(name)
165 try:
166 db = self.ds.databases[name]
167 except KeyError:
168 raise NotFound("Database not found: {}".format(name))
170 # Verify the hash
171 expected = "000"
172 if db.hash is not None:
173 expected = db.hash[:HASH_LENGTH]
174 correct_hash_provided = expected == hash
176 if not correct_hash_provided:
177 if "table_and_format" in kwargs:
179 async def async_table_exists(t):
180 return await db.table_exists(t)
182 table, _format = await resolve_table_and_format(
183 table_and_format=urllib.parse.unquote_plus(
184 kwargs["table_and_format"]
185 ),
186 table_exists=async_table_exists,
187 allowed_formats=self.ds.renderers.keys(),
188 )
189 kwargs["table"] = table
190 if _format:
191 kwargs["as_format"] = ".{}".format(_format)
192 elif kwargs.get("table"):
193 kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"])
195 should_redirect = "/{}-{}".format(name, expected)
196 if kwargs.get("table"):
197 should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"])
198 if kwargs.get("pk_path"):
199 should_redirect += "/" + kwargs["pk_path"]
200 if kwargs.get("as_format"):
201 should_redirect += kwargs["as_format"]
202 if kwargs.get("as_db"):
203 should_redirect += kwargs["as_db"]
205 if (
206 (self.ds.config("hash_urls") or "_hash" in request.args)
207 and
208 # Redirect only if database is immutable
209 not self.ds.databases[name].is_mutable
210 ):
211 return name, expected, correct_hash_provided, should_redirect
213 return name, expected, correct_hash_provided, None
215 def get_templates(self, database, table=None):
216 assert NotImplemented
218 async def get(self, request, db_name, **kwargs):
219 (
220 database,
221 hash,
222 correct_hash_provided,
223 should_redirect,
224 ) = await self.resolve_db_name(request, db_name, **kwargs)
225 if should_redirect:
226 return self.redirect(request, should_redirect, remove_args={"_hash"})
228 return await self.view_get(
229 request, database, hash, correct_hash_provided, **kwargs
230 )
232 async def as_csv(self, request, database, hash, **kwargs):
233 stream = request.args.get("_stream")
234 if stream:
235 # Some quick sanity checks
236 if not self.ds.config("allow_csv_stream"):
237 raise DatasetteError("CSV streaming is disabled", status=400)
238 if request.args.get("_next"):
239 raise DatasetteError("_next not allowed for CSV streaming", status=400)
240 kwargs["_size"] = "max"
241 # Fetch the first page
242 try:
243 response_or_template_contexts = await self.data(
244 request, database, hash, **kwargs
245 )
246 if isinstance(response_or_template_contexts, Response):
247 return response_or_template_contexts
248 else:
249 data, _, _ = response_or_template_contexts
250 except (sqlite3.OperationalError, InvalidSql) as e:
251 raise DatasetteError(str(e), title="Invalid SQL", status=400)
253 except (sqlite3.OperationalError) as e:
254 raise DatasetteError(str(e))
256 except DatasetteError:
257 raise
259 # Convert rows and columns to CSV
260 headings = data["columns"]
261 # if there are expanded_columns we need to add additional headings
262 expanded_columns = set(data.get("expanded_columns") or [])
263 if expanded_columns:
264 headings = []
265 for column in data["columns"]:
266 headings.append(column)
267 if column in expanded_columns:
268 headings.append("{}_label".format(column))
270 async def stream_fn(r):
271 nonlocal data
272 writer = csv.writer(LimitedWriter(r, self.ds.config("max_csv_mb")))
273 first = True
274 next = None
275 while first or (next and stream):
276 try:
277 if next:
278 kwargs["_next"] = next
279 if not first:
280 data, _, _ = await self.data(request, database, hash, **kwargs)
281 if first:
282 await writer.writerow(headings)
283 first = False
284 next = data.get("next")
285 for row in data["rows"]:
286 if not expanded_columns:
287 # Simple path
288 await writer.writerow(row)
289 else:
290 # Look for {"value": "label": } dicts and expand
291 new_row = []
292 for heading, cell in zip(data["columns"], row):
293 if heading in expanded_columns:
294 if cell is None:
295 new_row.extend(("", ""))
296 else:
297 assert isinstance(cell, dict)
298 new_row.append(cell["value"])
299 new_row.append(cell["label"])
300 else:
301 new_row.append(cell)
302 await writer.writerow(new_row)
303 except Exception as e:
304 print("caught this", e)
305 await r.write(str(e))
306 return
308 content_type = "text/plain; charset=utf-8"
309 headers = {}
310 if self.ds.cors:
311 headers["Access-Control-Allow-Origin"] = "*"
312 if request.args.get("_dl", None):
313 content_type = "text/csv; charset=utf-8"
314 disposition = 'attachment; filename="{}.csv"'.format(
315 kwargs.get("table", database)
316 )
317 headers["Content-Disposition"] = disposition
319 return AsgiStream(stream_fn, headers=headers, content_type=content_type)
321 async def get_format(self, request, database, args):
322 """ Determine the format of the response from the request, from URL
323 parameters or from a file extension.
325 `args` is a dict of the path components parsed from the URL by the router.
326 """
327 # If ?_format= is provided, use that as the format
328 _format = request.args.get("_format", None)
329 if not _format:
330 _format = (args.pop("as_format", None) or "").lstrip(".")
331 else:
332 args.pop("as_format", None)
333 if "table_and_format" in args:
334 db = self.ds.databases[database]
336 async def async_table_exists(t):
337 return await db.table_exists(t)
339 table, _ext_format = await resolve_table_and_format(
340 table_and_format=urllib.parse.unquote_plus(args["table_and_format"]),
341 table_exists=async_table_exists,
342 allowed_formats=self.ds.renderers.keys(),
343 )
344 _format = _format or _ext_format
345 args["table"] = table
346 del args["table_and_format"]
347 elif "table" in args:
348 args["table"] = urllib.parse.unquote_plus(args["table"])
349 return _format, args
351 async def view_get(self, request, database, hash, correct_hash_provided, **kwargs):
352 _format, kwargs = await self.get_format(request, database, kwargs)
354 if _format == "csv":
355 return await self.as_csv(request, database, hash, **kwargs)
357 if _format is None:
358 # HTML views default to expanding all foreign key labels
359 kwargs["default_labels"] = True
361 extra_template_data = {}
362 start = time.time()
363 status_code = 200
364 templates = []
365 try:
366 response_or_template_contexts = await self.data(
367 request, database, hash, **kwargs
368 )
369 if isinstance(response_or_template_contexts, Response):
370 return response_or_template_contexts
372 else:
373 data, extra_template_data, templates = response_or_template_contexts
374 except QueryInterrupted:
375 raise DatasetteError(
376 """
377 SQL query took too long. The time limit is controlled by the
378 <a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
379 configuration option.
380 """,
381 title="SQL Interrupted",
382 status=400,
383 messagge_is_html=True,
384 )
385 except (sqlite3.OperationalError, InvalidSql) as e:
386 raise DatasetteError(str(e), title="Invalid SQL", status=400)
388 except (sqlite3.OperationalError) as e:
389 raise DatasetteError(str(e))
391 except DatasetteError:
392 raise
394 end = time.time()
395 data["query_ms"] = (end - start) * 1000
396 for key in ("source", "source_url", "license", "license_url"):
397 value = self.ds.metadata(key)
398 if value:
399 data[key] = value
401 # Special case for .jsono extension - redirect to _shape=objects
402 if _format == "jsono":
403 return self.redirect(
404 request,
405 path_with_added_args(
406 request,
407 {"_shape": "objects"},
408 path=request.path.rsplit(".jsono", 1)[0] + ".json",
409 ),
410 forward_querystring=False,
411 )
413 if _format in self.ds.renderers.keys():
414 # Dispatch request to the correct output format renderer
415 # (CSV is not handled here due to streaming)
416 result = call_with_supported_arguments(
417 self.ds.renderers[_format][0],
418 datasette=self.ds,
419 columns=data.get("columns") or [],
420 rows=data.get("rows") or [],
421 sql=data.get("query", {}).get("sql", None),
422 query_name=data.get("query_name"),
423 database=database,
424 table=data.get("table"),
425 request=request,
426 view_name=self.name,
427 # These will be deprecated in Datasette 1.0:
428 args=request.args,
429 data=data,
430 )
431 if asyncio.iscoroutine(result):
432 result = await result
433 if result is None:
434 raise NotFound("No data")
436 r = Response(
437 body=result.get("body"),
438 status=result.get("status_code", 200),
439 content_type=result.get("content_type", "text/plain"),
440 headers=result.get("headers"),
441 )
442 else:
443 extras = {}
444 if callable(extra_template_data):
445 extras = extra_template_data()
446 if asyncio.iscoroutine(extras):
447 extras = await extras
448 else:
449 extras = extra_template_data
450 url_labels_extra = {}
451 if data.get("expandable_columns"):
452 url_labels_extra = {"_labels": "on"}
454 renderers = {}
455 for key, (_, can_render) in self.ds.renderers.items():
456 it_can_render = call_with_supported_arguments(
457 can_render,
458 datasette=self.ds,
459 columns=data.get("columns") or [],
460 rows=data.get("rows") or [],
461 sql=data.get("query", {}).get("sql", None),
462 query_name=data.get("query_name"),
463 database=database,
464 table=data.get("table"),
465 request=request,
466 view_name=self.name,
467 )
468 if asyncio.iscoroutine(it_can_render):
469 it_can_render = await it_can_render
470 if it_can_render:
471 renderers[key] = path_with_format(
472 request, key, {**url_labels_extra}
473 )
475 url_csv_args = {"_size": "max", **url_labels_extra}
476 url_csv = path_with_format(request, "csv", url_csv_args)
477 url_csv_path = url_csv.split("?")[0]
478 context = {
479 **data,
480 **extras,
481 **{
482 "renderers": renderers,
483 "url_csv": url_csv,
484 "url_csv_path": url_csv_path,
485 "url_csv_hidden_args": [
486 (key, value)
487 for key, value in urllib.parse.parse_qsl(request.query_string)
488 if key not in ("_labels", "_facet", "_size")
489 ]
490 + [("_size", "max")],
491 "datasette_version": __version__,
492 "config": self.ds.config_dict(),
493 },
494 }
495 if "metadata" not in context:
496 context["metadata"] = self.ds.metadata
497 r = await self.render(templates, request=request, context=context)
498 r.status = status_code
500 ttl = request.args.get("_ttl", None)
501 if ttl is None or not ttl.isdigit():
502 if correct_hash_provided:
503 ttl = self.ds.config("default_cache_ttl_hashed")
504 else:
505 ttl = self.ds.config("default_cache_ttl")
507 return self.set_response_headers(r, ttl)
509 def set_response_headers(self, response, ttl):
510 # Set far-future cache expiry
511 if self.ds.cache_headers and response.status == 200:
512 ttl = int(ttl)
513 if ttl == 0:
514 ttl_header = "no-cache"
515 else:
516 ttl_header = "max-age={}".format(ttl)
517 response.headers["Cache-Control"] = ttl_header
518 response.headers["Referrer-Policy"] = "no-referrer"
519 if self.ds.cors:
520 response.headers["Access-Control-Allow-Origin"] = "*"
521 return response