Coverage for datasette/tracer.py : 81%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2from contextlib import contextmanager
3import time
4import json
5import traceback
# Maps the id of an asyncio task to the list that trace() appends its records
# to. Entries are added and removed by capture_traces().
tracers = {}

# Keys that trace() fills in itself; callers may not pass them as kwargs.
TRACE_RESERVED_KEYS = {"type", "start", "end", "duration_ms", "traceback"}

# asyncio.current_task was introduced in Python 3.7; older versions expose the
# same functionality as the asyncio.Task.current_task classmethod.
for obj in (asyncio, asyncio.Task):
    current_task = getattr(obj, "current_task", None)
    if current_task is not None:
        break


def get_task_id():
    """Return an identifier for the asyncio task currently running, or None.

    None is returned when ``asyncio.get_event_loop()`` raises RuntimeError
    (no usable event loop) or when no task is currently running on the loop.
    Returning None in the latter case, rather than ``id(None)``, stops every
    task-less caller from sharing a single key in ``tracers``.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        return None
    task = current_task(loop=loop)
    if task is None:
        # A loop exists but nothing is running on it (e.g. synchronous code).
        return None
    return id(task)
@contextmanager
def trace(type, **kwargs):
    """Time the wrapped block and record it for the current task's tracer.

    A record is appended only when the current task has a tracer list
    registered in ``tracers`` (as done by ``capture_traces``). Otherwise the
    block simply runs and nothing is recorded.

    The record is appended even when the wrapped block raises, so failing
    operations still appear in the captured traces. The exception then
    propagates unchanged.

    :param type: label stored under the record's ``"type"`` key.
    :param kwargs: extra fields merged into the record. They must not use any
        of the names in ``TRACE_RESERVED_KEYS``.
    """
    assert not TRACE_RESERVED_KEYS.intersection(
        kwargs.keys()
    ), ".trace() keyword parameters cannot include {}".format(TRACE_RESERVED_KEYS)
    task_id = get_task_id()
    tracer = tracers.get(task_id) if task_id is not None else None
    if tracer is None:
        # Not inside capture_traces() for this task: nothing to record.
        yield
        return
    start = time.time()
    try:
        yield
    finally:
        end = time.time()
        trace_info = {
            "type": type,
            "start": start,
            "end": end,
            "duration_ms": (end - start) * 1000,
            # Drop the three innermost frames: this generator, contextlib's
            # __exit__, and (presumably) the frame holding the `with trace(...)`
            # statement, so the stack shows who called that code.
            "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
        }
        trace_info.update(kwargs)
        tracer.append(trace_info)
@contextmanager
def capture_traces(tracer):
    """Collect trace() records produced by the current task into ``tracer``.

    :param tracer: a list. While the block runs, records from trace() calls
        made in the same task are appended to it.

    If there is no current task (``get_task_id()`` returns None), the block
    runs without registering anything.
    """
    task_id = get_task_id()
    if task_id is None:
        yield
        return
    tracers[task_id] = tracer
    try:
        yield
    finally:
        # Always unregister, even if the body raised. Otherwise the entry
        # leaks, and because task ids come from id(), a later task that
        # happens to reuse the same id could append into this stale list.
        tracers.pop(task_id, None)
class AsgiTracer:
    """ASGI middleware that attaches captured trace() records to a response.

    Requests whose query string lacks a ``_trace=1`` parameter are passed
    straight through to the wrapped app. For traced requests, the app runs
    inside ``capture_traces`` and the response body is buffered. When the
    final body message arrives, a summary of the traces is added:

    * ``text/html`` bodies containing ``</body>`` get an HTML-escaped
      ``<pre>`` block inserted in front of ``</body>``;
    * JSON bodies that start with ``{`` get a ``"_trace"`` key, unless they
      already have one.

    Bodies larger than ``max_body_bytes`` are forwarded without modification.
    """

    # If the body is larger than this we don't attempt to append the trace
    max_body_bytes = 1024 * 256  # 256 KB

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if b"_trace=1" not in scope.get("query_string", b"").split(b"&"):
            await self.app(scope, receive, send)
            return
        trace_start = time.time()
        traces = []

        accumulated_body = b""
        size_limit_exceeded = False
        response_headers = []

        async def wrapped_send(message):
            nonlocal accumulated_body, size_limit_exceeded, response_headers
            if message["type"] == "http.response.start":
                # "headers" is optional in ASGI response-start messages.
                response_headers = message.get("headers", [])
                # The body may grow when the trace is injected, so any
                # Content-Length the app declared could become wrong. Drop it
                # and let the server choose the framing.
                await send(
                    dict(
                        message,
                        headers=[
                            (key, value)
                            for key, value in response_headers
                            if key.lower() != b"content-length"
                        ],
                    )
                )
                return

            if message["type"] != "http.response.body" or size_limit_exceeded:
                await send(message)
                return

            # Accumulate body until the end or until size is exceeded
            accumulated_body += message.get("body", b"")
            more_body = bool(message.get("more_body", False))
            if len(accumulated_body) > self.max_body_bytes:
                # Too large to rewrite: flush what we have and stream the rest
                # through untouched. more_body must mirror the incoming
                # message; otherwise a single oversized final message would
                # leave the response unterminated.
                await send(
                    {
                        "type": "http.response.body",
                        "body": accumulated_body,
                        "more_body": more_body,
                    }
                )
                size_limit_exceeded = True
                return

            if more_body:
                # Keep buffering until the final body message arrives.
                return

            trace_info = {
                "request_duration_ms": 1000 * (time.time() - trace_start),
                "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
                "num_traces": len(traces),
                "traces": traces,
            }
            content_type = self._get_content_type(response_headers)
            body = self._inject_trace(accumulated_body, content_type, trace_info)
            await send({"type": "http.response.body", "body": body})

        with capture_traces(traces):
            await self.app(scope, receive, wrapped_send)

    @staticmethod
    def _get_content_type(headers):
        """Return the first Content-Type header value, or "" if there is none."""
        for key, value in headers:
            if key.lower() == b"content-type":
                # latin-1 maps every byte, so decoding can never fail here.
                return value.decode("latin-1")
        return ""

    @staticmethod
    def _inject_trace(body, content_type, trace_info):
        """Return ``body`` with ``trace_info`` embedded, or ``body`` unchanged."""
        import html

        if "text/html" in content_type and b"</body>" in body:
            # Trace records carry arbitrary trace() keyword arguments, which
            # may contain user-controlled text, so escape before emitting HTML.
            # default=repr keeps non-JSON values from aborting the response.
            extra = html.escape(json.dumps(trace_info, indent=2, default=repr))
            extra_html = "<pre>{}</pre></body>".format(extra).encode("utf8")
            return body.replace(b"</body>", extra_html)
        if "json" in content_type and body.startswith(b"{"):
            try:
                data = json.loads(body.decode("utf8"))
            except ValueError:
                # Not valid JSON (or not UTF-8) after all: leave it untouched
                # rather than failing to send the response at all.
                return body
            if "_trace" not in data:
                data["_trace"] = trace_info
            return json.dumps(data, default=repr).encode("utf8")
        return body