Skip to content

Commit b56a548

Browse files
committed
Enable progressive proxy via flag
1 parent 76a98c9 commit b56a548

File tree

4 files changed

+110
-6
lines changed

4 files changed

+110
-6
lines changed

jupyter_server_proxy/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,27 @@ def cats_only(response, path):
249249
""",
250250
).tag(config=True)
251251

252+
progressive = Union(
253+
[Bool(), Callable()],
254+
default_value=None,
255+
allow_none=True,
256+
help="""
257+
Makes the proxy progressive, meaning it won't buffer any requests from the server.
258+
Useful for applications streaming their data, where the buffering of requests can lead
259+
to a lagging, e.g. in video streams.
260+
261+
Must be either None (default), a bool, or a function. Setting it to a boolean will enable/disable
262+
progressive requests for all requests. Setting to None, jupyter-server-proxy will only enable progressive
263+
for somespecial types, like videos, images and binary data. A function must be taking the "Accept" header of
264+
the request from the client as input and returning a bool, whether this request should be made progressive.
265+
266+
Note: `progressive` and `rewrite_response` are mutually exclusive on the same request. When rewrite_response
267+
is given and progressive is None, the proxying will never be progressive. If progressive is a function,
268+
rewrite_response will only be called on requests where it returns False. Progressive takes precedence over
269+
rewrite_response when both are given!
270+
""",
271+
).tag(config=True)
272+
252273
update_last_activity = Bool(
253274
True, help="Will cause the proxy to report activity back to jupyter server."
254275
).tag(config=True)

jupyter_server_proxy/handlers.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from traitlets.traitlets import HasTraits
2323

2424
from .unixsock import UnixResolver
25-
from .utils import call_with_asked_args
25+
from .utils import call_with_asked_args, mime_types_match
2626
from .websocket import WebSocketHandlerMixin, pingable_ws_connect
2727

2828

@@ -95,6 +95,15 @@ def get(self, *args):
9595
self.redirect(urlunparse(dest))
9696

9797

98+
COMMON_BINARY_MIME_TYPES = [
99+
"image/*",
100+
"audio/*",
101+
"video/*",
102+
"application/*",
103+
"text/event-stream",
104+
]
105+
106+
98107
class ProxyHandler(WebSocketHandlerMixin, JupyterHandler):
99108
"""
100109
A tornado request handler that proxies HTTP and websockets from
@@ -117,10 +126,41 @@ def __init__(self, *args, **kwargs):
117126
"rewrite_response",
118127
tuple(),
119128
)
129+
self.progressive = kwargs.pop("progressive", None)
120130
self._requested_subprotocols = None
121131
self.update_last_activity = kwargs.pop("update_last_activity", True)
122132
super().__init__(*args, **kwargs)
123133

134+
@property
135+
def progressive(self):
136+
accept_header = self.request.headers.get("Accept")
137+
138+
if self._progressive is not None:
139+
if callable(self._progressive):
140+
return self._progressive(accept_header)
141+
else:
142+
return self._progressive
143+
144+
# Progressive and RewritableResponse are mutually exclusive
145+
if self.rewrite_response:
146+
return False
147+
148+
if accept_header is None:
149+
return False
150+
151+
# If the client can accept multiple types, we will not make the request progressive
152+
if "," in accept_header:
153+
return False
154+
155+
return any(
156+
mime_types_match(pattern, accept_header)
157+
for pattern in COMMON_BINARY_MIME_TYPES
158+
)
159+
160+
@progressive.setter
161+
def progressive(self, value):
162+
self._progressive = value
163+
124164
# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
125165
# to enable cross origin requests propagated by e.g. inverting proxies.
126166

@@ -376,16 +416,16 @@ async def proxy(self, host, port, proxied_path):
376416
)
377417
else:
378418
client = httpclient.AsyncHTTPClient(force_instance=True)
379-
# check if the request is stream request
380-
accept_header = self.request.headers.get("Accept")
381-
if accept_header == "text/event-stream":
419+
420+
if self.progressive:
382421
return await self._proxy_progressive(host, port, proxied_path, body, client)
383422
else:
384423
return await self._proxy_buffered(host, port, proxied_path, body, client)
385424

386425
async def _proxy_progressive(self, host, port, proxied_path, body, client):
387426
# Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila
388427
# Set up handlers so we can progressively flush result
428+
self.log.debug(f"Request to '{proxied_path}' will be proxied progressive")
389429

390430
headers_raw = []
391431

@@ -466,9 +506,10 @@ def streaming_callback(chunk):
466506
self.write(response.body)
467507

468508
async def _proxy_buffered(self, host, port, proxied_path, body, client):
469-
req = self._build_proxy_request(host, port, proxied_path, body)
509+
self.log.debug(f"Request to '{proxied_path}' will be proxied buffered")
470510

471-
self.log.debug(f"Proxying request to {req.url}")
511+
req = self._build_proxy_request(host, port, proxied_path, body)
512+
self.log.debug(f"Proxy request URL: {req.url}")
472513

473514
try:
474515
# Here, "response" is a tornado.httpclient.HTTPResponse object.

jupyter_server_proxy/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,20 @@ def call_with_asked_args(callback, args):
2828
)
2929
)
3030
return callback(*asked_arg_values)
31+
32+
33+
def mime_types_match(pattern: str, value: str) -> bool:
34+
"""
35+
Compare a MIME type pattern, possibly with wildcards, and a value
36+
"""
37+
value = value.split(";")[0] # Remove optional details
38+
if pattern == value:
39+
return True
40+
41+
type, subtype = value.split("/")
42+
pattern = pattern.split("/")
43+
44+
if pattern[0] == "*" or (pattern[0] == type and pattern[1] == "*"):
45+
return True
46+
47+
return False

tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,28 @@ def _test_func(a, b):
77
return c
88

99
assert utils.call_with_asked_args(_test_func, {"a": 5, "b": 4, "c": 8}) == 20
10+
11+
12+
def test_mime_types_match():
13+
# Exact match
14+
assert utils.mime_types_match("text/plain", "text/plain")
15+
assert not utils.mime_types_match("text/plain", "text/html")
16+
17+
# With optional parameters
18+
assert utils.mime_types_match("text/plain", "text/plain;charset=UTF-8")
19+
assert not utils.mime_types_match("text/plain", "text/html;charset=UTF-8")
20+
21+
# With a single widcard
22+
assert utils.mime_types_match("*", "text/plain")
23+
assert utils.mime_types_match("*", "text/plain;charset=UTF-8")
24+
25+
# With both components wildcard
26+
assert utils.mime_types_match("*/*", "text/plain")
27+
assert utils.mime_types_match("*/*", "text/plain;charset=UTF-8")
28+
29+
# With a subtype wildcard
30+
assert utils.mime_types_match("text/*", "text/plain")
31+
assert not utils.mime_types_match("image/*", "text/plain")
32+
33+
assert utils.mime_types_match("text/*", "text/plain;charset=UTF-8")
34+
assert not utils.mime_types_match("image/*", "text/plain;charset=UTF-8")

0 commit comments

Comments
 (0)