Coverage for datasette/utils/asgi.py : 90%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import json
2from datasette.utils import MultiParams
3from mimetypes import guess_type
4from urllib.parse import parse_qs, urlunparse, parse_qsl
5from pathlib import Path
6from html import escape
7from http.cookies import SimpleCookie, Morsel
8import re
9import aiofiles
# Python releases before 3.8 do not know about the SameSite cookie
# attribute, so register it in Morsel's table that maps attribute keys to
# their rendered names before any cookies are built.
# Approach borrowed from Starlette:
# https://github.com/encode/starlette/blob/519f575/starlette/responses.py#L17
Morsel._reserved.update(samesite="SameSite")
class NotFound(Exception):
    """Raised by a routed view so that AsgiRouter answers with its 404 page."""
class Forbidden(Exception):
    """Exception signalling denied access.

    AsgiRouter does not special-case it, so unless something else catches
    it, it is handled like any other exception raised by a view.
    """
# Values accepted for the ``samesite`` argument of Response.set_cookie().
SAMESITE_VALUES = (
    "strict",
    "lax",
    "none",
)
28class Request:
29 def __init__(self, scope, receive):
30 self.scope = scope
31 self.receive = receive
33 @property
34 def method(self):
35 return self.scope["method"]
37 @property
38 def url(self):
39 return urlunparse(
40 (self.scheme, self.host, self.path, None, self.query_string, None)
41 )
43 @property
44 def url_vars(self):
45 return (self.scope.get("url_route") or {}).get("kwargs") or {}
47 @property
48 def scheme(self):
49 return self.scope.get("scheme") or "http"
51 @property
52 def headers(self):
53 return dict(
54 [
55 (k.decode("latin-1").lower(), v.decode("latin-1"))
56 for k, v in self.scope.get("headers") or []
57 ]
58 )
60 @property
61 def host(self):
62 return self.headers.get("host") or "localhost"
64 @property
65 def cookies(self):
66 cookies = SimpleCookie()
67 cookies.load(self.headers.get("cookie", ""))
68 return {key: value.value for key, value in cookies.items()}
70 @property
71 def path(self):
72 if self.scope.get("raw_path") is not None:
73 return self.scope["raw_path"].decode("latin-1")
74 else:
75 path = self.scope["path"]
76 if isinstance(path, str):
77 return path
78 else:
79 return path.decode("utf-8")
81 @property
82 def query_string(self):
83 return (self.scope.get("query_string") or b"").decode("latin-1")
85 @property
86 def args(self):
87 return MultiParams(parse_qs(qs=self.query_string))
89 @property
90 def actor(self):
91 return self.scope.get("actor", None)
93 async def post_vars(self):
94 body = []
95 body = b""
96 more_body = True
97 while more_body:
98 message = await self.receive()
99 assert message["type"] == "http.request", message
100 body += message.get("body", b"")
101 more_body = message.get("more_body", False)
103 return dict(parse_qsl(body.decode("utf-8"), keep_blank_values=True))
105 @classmethod
106 def fake(cls, path_with_query_string, method="GET", scheme="http"):
107 "Useful for constructing Request objects for tests"
108 path, _, query_string = path_with_query_string.partition("?")
109 scope = {
110 "http_version": "1.1",
111 "method": method,
112 "path": path,
113 "raw_path": path.encode("latin-1"),
114 "query_string": query_string.encode("latin-1"),
115 "scheme": scheme,
116 "type": "http",
117 }
118 return cls(scope, None)
class AsgiRouter:
    """Minimal ASGI router dispatching on regular expressions.

    ``routes`` is a list of ``(pattern, view)`` pairs; string patterns are
    compiled with ``re.compile``. The first pattern whose ``match`` succeeds
    against the request path wins, and its named groups are passed to the
    view as ``scope["url_route"]["kwargs"]``.
    """

    def __init__(self, routes=None):
        self.routes = [
            # Compile any strings to regular expressions
            ((re.compile(pattern) if isinstance(pattern, str) else pattern), view)
            for pattern, view in routes or []
        ]

    async def __call__(self, scope, receive, send):
        # Because we care about "foo/bar" v.s. "foo%2Fbar" we route on the
        # undecoded raw_path when the server provides it. latin-1 matches
        # Request.path and, unlike ascii, cannot fail on stray non-ASCII bytes.
        path = scope["path"]
        raw_path = scope.get("raw_path")
        if raw_path:
            path = raw_path.decode("latin-1")
        return await self.route_path(scope, receive, send, path)

    async def route_path(self, scope, receive, send, path):
        """Run the first view whose pattern matches ``path``.

        A view raising NotFound produces the 404 page; any other exception
        produces the 500 page. No match at all also produces the 404 page.
        """
        for regex, view in self.routes:
            match = regex.match(path)
            if match is None:
                continue
            new_scope = dict(scope, url_route={"kwargs": match.groupdict()})
            try:
                return await view(new_scope, receive, send)
            except NotFound as exception:
                return await self.handle_404(scope, receive, send, exception)
            except Exception as exception:
                return await self.handle_500(scope, receive, send, exception)
        return await self.handle_404(scope, receive, send)

    async def handle_404(self, scope, receive, send, exception=None):
        """Send a minimal HTML 404 response."""
        await send(
            {
                "type": "http.response.start",
                "status": 404,
                "headers": [[b"content-type", b"text/html; charset=utf-8"]],
            }
        )
        await send({"type": "http.response.body", "body": b"<h1>404</h1>"})

    async def handle_500(self, scope, receive, send, exception):
        """Send a minimal HTML 500 response showing the escaped exception repr."""
        await send(
            {
                "type": "http.response.start",
                "status": 500,
                "headers": [[b"content-type", b"text/html; charset=utf-8"]],
            }
        )
        html = "<h1>500</h1><pre>{}</pre>".format(escape(repr(exception)))
        await send({"type": "http.response.body", "body": html.encode("utf-8")})
class AsgiLifespan:
    """ASGI middleware that answers lifespan events itself.

    Startup and shutdown hooks are awaited without arguments. Every
    non-lifespan scope is forwarded unchanged to the wrapped ``app``.
    """

    def __init__(self, app, on_startup=None, on_shutdown=None):
        self.app = app
        self.on_startup = self._as_list(on_startup)
        self.on_shutdown = self._as_list(on_shutdown)

    @staticmethod
    def _as_list(hooks):
        """Normalize a falsy value, a single hook, or a list of hooks to a list."""
        if not hooks:
            return []
        return hooks if isinstance(hooks, list) else [hooks]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "lifespan":
            await self.app(scope, receive, send)
            return
        while True:
            message = await receive()
            event = message["type"]
            if event == "lifespan.startup":
                for hook in self.on_startup:
                    await hook()
                await send({"type": "lifespan.startup.complete"})
            elif event == "lifespan.shutdown":
                for hook in self.on_shutdown:
                    await hook()
                await send({"type": "lifespan.shutdown.complete"})
                return
class AsgiView:
    """Base class for class-based ASGI views.

    Subclasses implement one coroutine method per supported HTTP method
    (``get``, ``post``, ...). Each receives a :class:`Request` plus the
    keyword arguments captured by the router, and returns an object with an
    ``asgi_send(send)`` coroutine.
    """

    # HTTP methods that may be dispatched to a handler method of the same
    # lower-cased name. Anything else is answered with 405, so a
    # client-chosen method such as "DISPATCH_REQUEST" or "AS_ASGI" can never
    # invoke an arbitrary attribute of the view. Subclasses may extend this.
    http_method_names = (
        "get",
        "post",
        "put",
        "patch",
        "delete",
        "head",
        "options",
        "trace",
    )

    async def dispatch_request(self, request, *args, **kwargs):
        """Await the handler for ``request.method`` and return its response.

        Returns a 405 response when the method is not in
        ``http_method_names`` or this view defines no handler for it.
        """
        method = request.method.lower()
        handler = None
        if method in self.http_method_names:
            handler = getattr(self, method, None)
        if handler is None:
            return self._method_not_allowed()
        return await handler(request, *args, **kwargs)

    def _method_not_allowed(self):
        """Build a 405 response whose Allow header lists implemented methods."""
        allowed = ", ".join(
            name.upper()
            for name in self.http_method_names
            if getattr(self, name, None) is not None
        )
        return Response.text(
            "Method not allowed", status=405, headers={"Allow": allowed}
        )

    @classmethod
    def as_asgi(cls, *class_args, **class_kwargs):
        """Return an ASGI callable that instantiates this view per request."""

        async def view(scope, receive, send):
            # Uses scope to create a request object, then dispatches that to
            # self.get(...) or self.options(...) along with keyword arguments
            # that were already tucked into scope["url_route"]["kwargs"] by
            # the router, similar to how Django Channels works:
            # https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter
            request = Request(scope, receive)
            self = view.view_class(*class_args, **class_kwargs)
            response = await self.dispatch_request(
                request, **scope["url_route"]["kwargs"]
            )
            await response.asgi_send(send)

        view.view_class = cls
        view.__doc__ = cls.__doc__
        view.__module__ = cls.__module__
        view.__name__ = cls.__name__
        return view
class AsgiStream:
    """Response whose body is produced incrementally by ``stream_fn``.

    ``stream_fn`` is an async callable that is given an AsgiWriter and
    writes string chunks to it.
    """

    def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"):
        self.stream_fn = stream_fn
        self.status = status
        self.headers = headers or {}
        self.content_type = content_type

    async def asgi_send(self, send):
        # The content_type attribute wins over any content-type header the
        # caller supplied, whatever its capitalisation.
        merged = {
            name: value
            for name, value in self.headers.items()
            if name.lower() != "content-type"
        }
        merged["content-type"] = self.content_type
        raw_headers = [
            [name.encode("utf-8"), value.encode("utf-8")]
            for name, value in merged.items()
        ]
        start = {
            "type": "http.response.start",
            "status": self.status,
            "headers": raw_headers,
        }
        await send(start)
        await self.stream_fn(AsgiWriter(send))
        # Final empty chunk without a more_body flag ends the streamed body.
        await send({"type": "http.response.body", "body": b""})
class AsgiWriter:
    """Writer that forwards text chunks to an ASGI ``send`` callable."""

    def __init__(self, send):
        self.send = send

    async def write(self, chunk):
        """Encode ``chunk`` as UTF-8 and send it as a non-final body message."""
        message = {
            "type": "http.response.body",
            "body": chunk.encode("utf-8"),
            "more_body": True,
        }
        await self.send(message)
async def asgi_send_json(send, info, status=200, headers=None):
    """Serialize ``info`` with ``json.dumps`` and send it as a JSON response."""
    payload = json.dumps(info)
    await asgi_send(
        send,
        payload,
        status=status,
        headers=headers or {},
        content_type="application/json; charset=utf-8",
    )
async def asgi_send_html(send, html, status=200, headers=None):
    """Send the ``html`` string as a complete text/html response."""
    html_type = "text/html; charset=utf-8"
    await asgi_send(
        send, html, status=status, headers=headers or {}, content_type=html_type
    )
async def asgi_send_redirect(send, location, status=302):
    """Send an empty-bodied redirect whose Location header is ``location``."""
    redirect_headers = {"Location": location}
    await asgi_send(
        send,
        "",
        status=status,
        headers=redirect_headers,
        content_type="text/html; charset=utf-8",
    )
async def asgi_send(send, content, status, headers=None, content_type="text/plain"):
    """Send a complete response with a ``str`` body in a single message.

    Headers go out first via asgi_start(); the body is then encoded as UTF-8.
    """
    await asgi_start(send, status, headers, content_type)
    body = content.encode("utf-8")
    await send({"type": "http.response.body", "body": body})
async def asgi_start(send, status, headers=None, content_type="text/plain"):
    """Send the ``http.response.start`` message.

    Any ``content-type`` entry in ``headers`` (matched case-insensitively)
    is dropped and replaced by ``content_type``, which is always emitted
    last. Header names and values are encoded as latin-1.
    """
    raw_headers = [
        [name.encode("latin1"), value.encode("latin1")]
        for name, value in (headers or {}).items()
        if name.lower() != "content-type"
    ]
    raw_headers.append([b"content-type", content_type.encode("latin1")])
    await send(
        {"type": "http.response.start", "status": status, "headers": raw_headers}
    )
async def asgi_send_file(
    send, filepath, filename=None, content_type=None, chunk_size=4096
):
    """Stream the file at ``filepath`` as a 200 response.

    :param filename: if given, sent as
        ``Content-Disposition: attachment; filename="..."``.
    :param content_type: explicit content type; otherwise guessed from the
        path with ``mimetypes.guess_type``, falling back to ``text/plain``.
    :param chunk_size: number of bytes read and sent per body message.

    Errors from opening the file (e.g. FileNotFoundError) propagate before
    any response message has been sent, so callers can still respond
    differently.
    """
    headers = {}
    if filename:
        # Backslash-escape characters that would otherwise terminate the
        # quoted-string early and corrupt the header.
        quoted = filename.replace("\\", "\\\\").replace('"', '\\"')
        headers["Content-Disposition"] = 'attachment; filename="{}"'.format(quoted)
    async with aiofiles.open(str(filepath), mode="rb") as fp:
        # Start the response only once the file is open.
        await asgi_start(
            send,
            200,
            headers,
            content_type or guess_type(str(filepath))[0] or "text/plain",
        )
        more_body = True
        while more_body:
            chunk = await fp.read(chunk_size)
            # A short read means end of file; that final message carries
            # more_body=False (with an empty chunk if the size divides evenly).
            more_body = len(chunk) == chunk_size
            await send(
                {"type": "http.response.body", "body": chunk, "more_body": more_body}
            )
def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None):
    """Return an ASGI view that serves files from below ``root_path``.

    The requested file comes from ``scope["url_route"]["kwargs"]["path"]``.
    Anything resolving outside ``root_path`` or missing gets a 404;
    directories inside it get a 403.

    :param chunk_size: bytes per body message when streaming a file.
    :param content_type: if given, used for every file instead of guessing.
    :param headers: accepted for API compatibility; currently unused.
    """

    async def inner_static(scope, receive, send):
        path = scope["url_route"]["kwargs"]["path"]
        try:
            # Resolve the root the same way as the target, so a relative or
            # symlinked root_path does not make every file look out of bounds.
            root = Path(root_path).resolve()
            full_path = (root / path).resolve()
        except FileNotFoundError:
            await asgi_send_html(send, "404", 404)
            return
        # Ensure full_path is within root_path to avoid weird "../" tricks.
        # Checked before is_dir() so directories outside the root get a 404
        # rather than a 403 that would reveal that they exist.
        try:
            full_path.relative_to(root)
        except ValueError:
            await asgi_send_html(send, "404", 404)
            return
        if full_path.is_dir():
            await asgi_send_html(send, "403: Directory listing is not allowed", 403)
            return
        try:
            await asgi_send_file(
                send, full_path, content_type=content_type, chunk_size=chunk_size
            )
        except FileNotFoundError:
            await asgi_send_html(send, "404", 404)
            return

    return inner_static
class Response:
    """A complete (non-streaming) HTTP response.

    ``body`` may be ``str`` (encoded as UTF-8 when sent), ``bytes``, or
    ``None`` for an empty body.
    """

    def __init__(self, body=None, status=200, headers=None, content_type="text/plain"):
        self.body = body
        self.status = status
        self.headers = headers or {}
        self._set_cookie_headers = []
        self.content_type = content_type

    async def asgi_send(self, send):
        """Send this response through the ASGI ``send`` callable."""
        # content_type is authoritative: drop any content-type the caller put
        # in headers (any capitalisation) so exactly one is sent, matching
        # AsgiStream and asgi_start.
        headers = {
            key: value
            for key, value in self.headers.items()
            if key.lower() != "content-type"
        }
        headers["content-type"] = self.content_type
        raw_headers = [
            [key.encode("utf-8"), value.encode("utf-8")]
            for key, value in headers.items()
        ]
        # Several Set-Cookie headers cannot share one dict key, so they are
        # appended to the raw header list individually.
        for set_cookie in self._set_cookie_headers:
            raw_headers.append([b"set-cookie", set_cookie.encode("utf-8")])
        await send(
            {
                "type": "http.response.start",
                "status": self.status,
                "headers": raw_headers,
            }
        )
        body = self.body
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode("utf-8")
        await send({"type": "http.response.body", "body": body})

    def set_cookie(
        self,
        key,
        value="",
        max_age=None,
        expires=None,
        path="/",
        domain=None,
        secure=False,
        httponly=False,
        samesite="lax",
    ):
        """Queue a Set-Cookie header to be sent with this response.

        ``samesite`` must be one of SAMESITE_VALUES (checked with ``assert``,
        so the check is skipped under ``python -O``). Attributes whose value
        is ``None`` are omitted; ``secure``/``httponly`` are only emitted when
        truthy.
        """
        assert samesite in SAMESITE_VALUES, "samesite should be one of {}".format(
            SAMESITE_VALUES
        )
        cookie = SimpleCookie()
        cookie[key] = value
        for prop_name, prop_value in (
            ("max_age", max_age),
            ("expires", expires),
            ("path", path),
            ("domain", domain),
            ("samesite", samesite),
        ):
            if prop_value is not None:
                # Morsel keys use dashes, e.g. "max-age".
                cookie[key][prop_name.replace("_", "-")] = prop_value
        for prop_name, prop_value in (("secure", secure), ("httponly", httponly)):
            if prop_value:
                cookie[key][prop_name] = True
        self._set_cookie_headers.append(cookie.output(header="").strip())

    @classmethod
    def html(cls, body, status=200, headers=None):
        """Build a text/html response."""
        return cls(
            body,
            status=status,
            headers=headers,
            content_type="text/html; charset=utf-8",
        )

    @classmethod
    def text(cls, body, status=200, headers=None):
        """Build a text/plain response; ``body`` is converted with ``str``."""
        return cls(
            str(body),
            status=status,
            headers=headers,
            content_type="text/plain; charset=utf-8",
        )

    @classmethod
    def json(cls, body, status=200, headers=None):
        """Build an application/json response from ``json.dumps(body)``."""
        return cls(
            json.dumps(body),
            status=status,
            headers=headers,
            content_type="application/json; charset=utf-8",
        )

    @classmethod
    def redirect(cls, path, status=302, headers=None):
        """Build an empty-bodied redirect to ``path``."""
        # Copy so that adding Location never mutates the caller's dict.
        headers = dict(headers or {})
        headers["Location"] = path
        return cls("", status=status, headers=headers)
class AsgiFileDownload:
    """Response-like object that streams a file from disk via asgi_send_file().

    ``filename``, when given, is passed on so it is sent as the attachment
    name in a Content-Disposition header.
    """

    def __init__(
        self, filepath, filename=None, content_type="application/octet-stream"
    ):
        self.filepath = filepath
        self.filename = filename
        self.content_type = content_type

    async def asgi_send(self, send):
        """Stream the file, including the download filename if one was set."""
        return await asgi_send_file(
            send,
            self.filepath,
            filename=self.filename,
            content_type=self.content_type,
        )