Skip to content

Commit 7f78e1b

Browse files
authored
Merge pull request #479 from ganisback/support-stream
feat: support stream api
2 parents b689a4d + e201ffa commit 7f78e1b

File tree

4 files changed

+168
-4
lines changed

4 files changed

+168
-4
lines changed

jupyter_server_proxy/handlers.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import os
8+
import re
89
import socket
910
from asyncio import Lock
1011
from copy import copy
@@ -287,7 +288,7 @@ def get_client_uri(self, protocol, host, port, proxied_path):
287288

288289
return client_uri
289290

290-
def _build_proxy_request(self, host, port, proxied_path, body):
291+
def _build_proxy_request(self, host, port, proxied_path, body, **extra_opts):
291292
headers = self.proxy_request_headers()
292293

293294
client_uri = self.get_client_uri("http", host, port, proxied_path)
@@ -307,6 +308,7 @@ def _build_proxy_request(self, host, port, proxied_path, body):
307308
decompress_response=False,
308309
headers=headers,
309310
**self.proxy_request_options(),
311+
**extra_opts,
310312
)
311313
return req
312314

@@ -365,7 +367,6 @@ async def proxy(self, host, port, proxied_path):
365367
body = b""
366368
else:
367369
body = None
368-
369370
if self.unix_socket is not None:
370371
# Port points to a Unix domain socket
371372
self.log.debug("Making client for Unix socket %r", self.unix_socket)
@@ -374,8 +375,97 @@ async def proxy(self, host, port, proxied_path):
374375
force_instance=True, resolver=UnixResolver(self.unix_socket)
375376
)
376377
else:
377-
client = httpclient.AsyncHTTPClient()
378+
client = httpclient.AsyncHTTPClient(force_instance=True)
379+
# check if the request is stream request
380+
accept_header = self.request.headers.get("Accept")
381+
if accept_header == "text/event-stream":
382+
return await self._proxy_progressive(host, port, proxied_path, body, client)
383+
else:
384+
return await self._proxy_buffered(host, port, proxied_path, body, client)
385+
386+
async def _proxy_progressive(self, host, port, proxied_path, body, client):
    """Proxy a request, progressively flushing each chunk to the client.

    Used for streaming responses (e.g. ``Accept: text/event-stream``),
    where buffering the whole upstream response would defeat the purpose.
    Potentially slower than buffered proxying, but delivers results to
    the client as soon as the upstream server flushes them.

    Parameters mirror :meth:`_proxy_buffered`; ``client`` is the
    ``AsyncHTTPClient`` instance to fetch with.
    """
    # Raw header lines accumulated by header_callback; emptied (and
    # applied to our own response) the first time dump_headers runs.
    headers_raw = []

    header_re = re.compile(r"^([a-zA-Z0-9\-_]+)\s*\:\s*([^\r\n]+)[\r\n]*$")
    status_re = re.compile(r"^HTTP[^\s]* ([0-9]+)")

    def dump_headers(headers_raw):
        for line in headers_raw:
            m = header_re.match(line)
            if m:
                # group(1, 2) returns both captures as a tuple.
                # (The previous ``r.groups([1, 2])`` misused the
                # ``default`` argument of Match.groups and only
                # produced the right value by coincidence.)
                k, v = m.group(1, 2)
                # Strip hop-by-hop / encoding headers (compare
                # case-insensitively, per RFC 9110): tornado manages
                # chunking itself, and we requested an undecompressed
                # body upstream.
                if k.lower() not in (
                    "content-length",
                    "transfer-encoding",
                    "content-encoding",
                    "connection",
                ):
                    # Some headers appear multiple times, e.g.
                    # 'Set-Cookie'; add_header preserves duplicates
                    # where set_header would overwrite them.
                    self.add_header(k, v)
            else:
                m = status_re.match(line)
                if m:
                    self.set_status(int(m.group(1)))
        headers_raw.clear()

    # Clear tornado's default headers so only the upstream server's
    # headers (plus the status line) reach the client.
    self._headers = httputil.HTTPHeaders()

    def header_callback(line):
        headers_raw.append(line)

    def streaming_callback(chunk):
        # Record activity at start and end of requests
        self._record_activity()
        # Do this here, not in header_callback, so we can be sure the
        # headers are out of the way first; the list is empty if this
        # already ran for an earlier chunk.
        dump_headers(headers_raw)
        self.write(chunk)
        self.flush()

    # Now make the request
    req = self._build_proxy_request(
        host,
        port,
        proxied_path,
        body,
        streaming_callback=streaming_callback,
        header_callback=header_callback,
    )

    # Streaming responses legitimately stay open for a long time:
    # allow up to 2 h per request and 10 min to connect.
    req.request_timeout = 7200
    req.connect_timeout = 600

    try:
        response = await client.fetch(req, raise_error=False)
    except httpclient.HTTPError as err:
        # 599 is tornado's pseudo-code for connection failure/timeout.
        if err.code == 599:
            self._record_activity()
            self.set_status(599)
            self.write(str(err))
            return
        else:
            raise

    # For all non http errors...
    if response.error and type(response.error) is not httpclient.HTTPError:
        self.set_status(500)
        self.write(str(response.error))
    else:
        self.set_status(
            response.code, response.reason
        )  # Should already have been set

    dump_headers(headers_raw)  # Should already have been emptied

    if response.body:  # Likewise, should already be chunked out and flushed
        self.write(response.body)
378467

468+
async def _proxy_buffered(self, host, port, proxied_path, body, client):
379469
req = self._build_proxy_request(host, port, proxied_path, body)
380470

381471
self.log.debug(f"Proxying request to {req.url}")

tests/resources/eventstream.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import asyncio
2+
3+
import tornado.escape
4+
import tornado.ioloop
5+
import tornado.options
6+
import tornado.web
7+
import tornado.websocket
8+
from tornado.options import define, options
9+
10+
11+
class Application(tornado.web.Application):
    """Minimal tornado app exposing a single streaming test endpoint."""

    def __init__(self):
        routes = [
            (r"/stream/(\d+)", StreamHandler),
        ]
        super().__init__(routes)
17+
18+
19+
class StreamHandler(tornado.web.RequestHandler):
    """Emit server-sent-event style lines, one every 0.5 s, then finish."""

    async def get(self, seconds):
        count = int(seconds)
        for n in range(count):
            await asyncio.sleep(0.5)
            self.write(f"data: {n}\n\n")
            # Flush after each entry so the client receives it immediately.
            await self.flush()
25+
26+
27+
def main():
    """Parse ``--port`` and serve the streaming app until interrupted."""
    define("port", default=8888, help="run on the given port", type=int)
    options.parse_command_line()
    Application().listen(options.port)
    tornado.ioloop.IOLoop.current().start()


if __name__ == "__main__":
    main()

tests/resources/jupyter_server_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def my_env():
7979
"X-Custom-Header": "pytest-23456",
8080
},
8181
},
82+
"python-eventstream": {
83+
"command": [sys.executable, "./tests/resources/eventstream.py", "--port={port}"]
84+
},
8285
"python-unix-socket-true": {
8386
"command": [
8487
sys.executable,

tests/test_proxies.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import gzip
22
import json
33
import sys
4+
import time
45
from http.client import HTTPConnection
56
from io import BytesIO
67
from typing import Tuple
78
from urllib.parse import quote
89

910
import pytest
10-
from tornado.httpclient import HTTPClientError
11+
from tornado.httpclient import AsyncHTTPClient, HTTPClientError
1112
from tornado.websocket import websocket_connect
1213

1314
# use ipv4 for CI, etc.
@@ -343,6 +344,40 @@ def test_server_content_encoding_header(
343344
assert f.read() == b"this is a test"
344345

345346

347+
async def test_eventstream(a_server_port_and_token: Tuple[int, str]) -> None:
    PORT, TOKEN = a_server_port_and_token
    # The test server under eventstream.py sends back monotonically increasing
    # numbers starting at 0 up to the requested limit, flushing after each one
    # with a 500ms gap between them. We check that:
    # 1. We receive one streaming callback per entry (the server flushes
    #    after every write), for `limit` callbacks in total.
    # 2. Each entry arrives (with some error margin) around the 500ms mark,
    #    proving the response is *actually* streamed rather than buffered.
    limit = 3
    intervals = []
    chunks = []
    previous = time.perf_counter()

    def on_chunk(data):
        nonlocal previous
        now = time.perf_counter()
        intervals.append(now - previous)
        previous = now
        chunks.append(data)

    url = f"http://{LOCALHOST}:{PORT}/python-eventstream/stream/{limit}?token={TOKEN}"
    await AsyncHTTPClient().fetch(
        url,
        headers={"Accept": "text/event-stream"},
        streaming_callback=on_chunk,
    )
    assert len(chunks) == limit
    assert all(0.45 < t < 3.0 for t in intervals)
    assert chunks == [b"data: 0\n\n", b"data: 1\n\n", b"data: 2\n\n"]
379+
380+
346381
async def test_server_proxy_websocket_messages(
347382
a_server_port_and_token: Tuple[int, str]
348383
) -> None:

0 commit comments

Comments
 (0)