Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

import asyncio
import html
import json
import time
import traceback
from contextlib import contextmanager

6 

# Active trace collectors, keyed by the value get_task_id() returned when
# capture_traces() registered them; each value is the list that trace()
# appends its records to.
tracers = dict()

# Keys that trace() fills in itself; callers may not pass these as **kwargs.
TRACE_RESERVED_KEYS = set(("type", "start", "end", "duration_ms", "traceback"))

10 

11 

# asyncio.current_task was introduced in Python 3.7; earlier versions only
# offer the asyncio.Task.current_task classmethod (removed in Python 3.9).
# Resolved with explicit lookups so no loop variable leaks into the module
# namespace.
current_task = getattr(asyncio, "current_task", None)
if current_task is None:
    current_task = getattr(asyncio.Task, "current_task", None)

17 

18 

def get_task_id():
    """Return an identifier for the currently running asyncio task, or None.

    The identifier is ``id()`` of the current task object; it is used as the
    key into the module-level ``tracers`` dict. ``None`` is returned when no
    event loop is available in this thread, or when the code is not running
    inside a task (``current_task`` returned None), so task-less callers are
    never mistaken for a traced task.
    """
    # Prefer get_running_loop (3.7+): unlike get_event_loop it never creates
    # a loop or warns when called outside a running loop; it just raises.
    get_loop = getattr(asyncio, "get_running_loop", None) or asyncio.get_event_loop
    try:
        loop = get_loop()
    except RuntimeError:
        return None
    task = current_task(loop=loop)
    if task is None:
        # id(None) is a process-wide constant; returning it would make every
        # task-less caller share a single entry in ``tracers``.
        return None
    return id(task)

25 

26 

@contextmanager
def trace(type, **kwargs):
    """Time the enclosed block and record it for the current task's tracer.

    :param type: label stored under the ``"type"`` key of the record.
    :param kwargs: extra fields merged into the record; must not use any of
        ``TRACE_RESERVED_KEYS``.

    If a tracer list is registered in ``tracers`` for the current task (see
    ``capture_traces``), one dict is appended to it when the block exits.
    The dict holds ``type``, ``start``/``end`` (``time.time()`` values),
    ``duration_ms``, a short formatted ``traceback`` and the extra
    ``kwargs``. The record is also made when the block exits with an
    exception, which is then re-raised unchanged. Otherwise the block simply
    runs untraced.
    """
    assert not TRACE_RESERVED_KEYS.intersection(
        kwargs.keys()
    ), ".trace() keyword parameters cannot include {}".format(TRACE_RESERVED_KEYS)
    task_id = get_task_id()
    if task_id is None:
        yield
        return
    tracer = tracers.get(task_id)
    if tracer is None:
        yield
        return
    start = time.time()
    try:
        yield
    finally:
        # Recorded in ``finally`` so a traced operation that fails still
        # shows up (with its elapsed time) instead of vanishing silently.
        end = time.time()
        trace_info = {
            "type": type,
            "start": start,
            "end": end,
            "duration_ms": (end - start) * 1000,
            # Keep a few caller frames; [:-3] drops the three innermost
            # frames, which include this generator and contextlib's __exit__.
            "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]),
        }
        trace_info.update(kwargs)
        tracer.append(trace_info)

52 

53 

@contextmanager
def capture_traces(tracer):
    """Register ``tracer`` to collect trace() records for the current task.

    :param tracer: a list; trace() blocks run by this task while the context
        is active append their record dicts to it.

    When get_task_id() returns None (no task can be identified) this is a
    no-op. The registration is always removed on exit, even if the body
    raises. Otherwise a failed request would leave its list pinned in the
    module-level ``tracers`` dict, where a later task that happens to get
    the same ``id()`` would write into it.
    """
    task_id = get_task_id()
    if task_id is None:
        yield
        return
    tracers[task_id] = tracer
    try:
        yield
    finally:
        tracers.pop(task_id, None)

64 

65 

class AsgiTracer:
    """ASGI middleware that appends collected trace() records to responses.

    Tracing is opt-in per request. Only requests whose query string contains
    a ``_trace=1`` parameter are traced; all others go straight to the
    wrapped app. For traced requests the response body is buffered (up to
    ``max_body_bytes``). Once the body is complete, a summary of the traces
    is injected:

    * ``text/html`` bodies containing ``</body>``: an HTML-escaped JSON dump
      inside ``<pre>``, placed just before the last ``</body>``;
    * ``json`` bodies holding a JSON object: a ``"_trace"`` key, unless the
      object already has one.

    Any other body, and any body larger than ``max_body_bytes``, is sent
    unchanged.
    """

    # If the body is larger than this we don't attempt to append the trace
    max_body_bytes = 1024 * 256  # 256 KB

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if b"_trace=1" not in scope.get("query_string", b"").split(b"&"):
            await self.app(scope, receive, send)
            return
        traces = []
        wrapped_send = self._wrap_send(send, traces, time.time())
        with capture_traces(traces):
            await self.app(scope, receive, wrapped_send)

    def _wrap_send(self, send, traces, trace_start):
        """Return an ASGI ``send`` callable that injects ``traces`` into the body.

        :param send: the downstream ASGI ``send`` callable.
        :param traces: list of trace dicts, read when the final body chunk
            arrives.
        :param trace_start: ``time.time()`` at which the request started.

        The ``http.response.start`` message is held back until the body is
        complete, or until it is known to be too large to rewrite. This lets
        an existing ``Content-Length`` header be corrected for the injected
        bytes.
        """
        accumulated_body = b""
        size_limit_exceeded = False
        pending_start = None

        async def flush_start():
            # Forward the held-back response start message, at most once.
            nonlocal pending_start
            if pending_start is not None:
                message, pending_start = pending_start, None
                await send(message)

        async def wrapped_send(message):
            nonlocal accumulated_body, size_limit_exceeded, pending_start
            if message["type"] == "http.response.start":
                pending_start = message
                return

            if message["type"] != "http.response.body" or size_limit_exceeded:
                await flush_start()
                await send(message)
                return

            # Accumulate body until the end or until size is exceeded
            accumulated_body += message.get("body", b"")
            more_body = message.get("more_body", False)
            if len(accumulated_body) > self.max_body_bytes:
                # Too large to rewrite: forward the buffer untouched and keep
                # the sender's more_body flag, so a single oversized final
                # chunk still ends the response. Later messages pass through.
                size_limit_exceeded = True
                await flush_start()
                await send(
                    {
                        "type": "http.response.body",
                        "body": accumulated_body,
                        "more_body": more_body,
                    }
                )
                accumulated_body = b""
                return

            if more_body:
                return

            # We have all the body - modify it and send the result
            trace_info = {
                "request_duration_ms": 1000 * (time.time() - trace_start),
                "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces),
                "num_traces": len(traces),
                "traces": traces,
            }
            headers = list(pending_start.get("headers", [])) if pending_start else []
            body = self._inject_trace(
                accumulated_body, self._content_type(headers), trace_info
            )
            if pending_start is not None and len(body) != len(accumulated_body):
                pending_start = dict(
                    pending_start,
                    headers=self._set_content_length(headers, len(body)),
                )
            await flush_start()
            await send({"type": "http.response.body", "body": body})

        return wrapped_send

    @staticmethod
    def _content_type(headers):
        """Return the first Content-Type header value as a str, or ""."""
        for key, value in headers:
            if key.lower() == b"content-type":
                # ASGI header values are bytes; latin-1 decodes any byte
                # sequence, whereas utf-8 could raise on odd values.
                return value.decode("latin-1")
        return ""

    @staticmethod
    def _set_content_length(headers, length):
        """Return ``headers`` with any Content-Length value replaced by ``length``."""
        new_value = str(length).encode("ascii")
        return [
            (key, new_value) if key.lower() == b"content-length" else (key, value)
            for key, value in headers
        ]

    @staticmethod
    def _inject_trace(body, content_type, trace_info):
        """Return ``body`` with ``trace_info`` injected, or ``body`` unchanged.

        :param body: the complete response body (bytes).
        :param content_type: the response Content-Type ("" if absent).
        :param trace_info: JSON-able summary dict to inject.
        """
        if "text/html" in content_type and b"</body>" in body:
            # default=repr keeps non-JSON values (e.g. bytes params) from
            # crashing the response; html.escape stops traced SQL/params from
            # being interpreted as markup (reflected XSS).
            extra = json.dumps(trace_info, indent=2, default=repr)
            extra_html = "<pre>{}</pre></body>".format(html.escape(extra)).encode("utf8")
            # Insert only at the last </body>; an earlier occurrence is
            # presumably part of the page content, not the closing tag.
            head, _, tail = body.rpartition(b"</body>")
            return head + extra_html + tail
        if "json" in content_type and body.startswith(b"{"):
            try:
                data = json.loads(body.decode("utf8"))
            except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
                return body
            if "_trace" not in data:
                data["_trace"] = trace_info
                return json.dumps(data, default=repr).encode("utf8")
        return body