@@ -2,17 +2,21 @@ use std::error::Error as StdError;
2
2
#[ cfg( feature = "runtime" ) ]
3
3
use std:: time:: Duration ;
4
4
5
+ use bytes:: Bytes ;
5
6
use futures_channel:: { mpsc, oneshot} ;
6
7
use futures_util:: future:: { self , Either , FutureExt as _, TryFutureExt as _} ;
7
8
use futures_util:: stream:: StreamExt as _;
8
9
use h2:: client:: { Builder , SendRequest } ;
10
+ use http:: { Method , StatusCode } ;
9
11
use tokio:: io:: { AsyncRead , AsyncWrite } ;
10
12
11
- use super :: { decode_content_length , ping , PipeToSendStream , SendBuf } ;
13
+ use super :: { ping , H2Upgraded , PipeToSendStream , SendBuf } ;
12
14
use crate :: body:: HttpBody ;
13
15
use crate :: common:: { exec:: Exec , task, Future , Never , Pin , Poll } ;
14
16
use crate :: headers;
17
+ use crate :: proto:: h2:: UpgradedSendStream ;
15
18
use crate :: proto:: Dispatched ;
19
+ use crate :: upgrade:: Upgraded ;
16
20
use crate :: { Body , Request , Response } ;
17
21
18
22
type ClientRx < B > = crate :: client:: dispatch:: Receiver < Request < B > , Response < Body > > ;
@@ -233,8 +237,25 @@ where
233
237
headers:: set_content_length_if_missing ( req. headers_mut ( ) , len) ;
234
238
}
235
239
}
240
+
241
+ let is_connect = req. method ( ) == Method :: CONNECT ;
236
242
let eos = body. is_end_stream ( ) ;
237
- let ( fut, body_tx) = match self . h2_tx . send_request ( req, eos) {
243
+ let ping = self . ping . clone ( ) ;
244
+
245
+ if is_connect {
246
+ if headers:: content_length_parse_all ( req. headers ( ) )
247
+ . map_or ( false , |len| len != 0 )
248
+ {
249
+ warn ! ( "h2 connect request with non-zero body not supported" ) ;
250
+ cb. send ( Err ( (
251
+ crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) ,
252
+ None ,
253
+ ) ) ) ;
254
+ continue ;
255
+ }
256
+ }
257
+
258
+ let ( fut, body_tx) = match self . h2_tx . send_request ( req, !is_connect && eos) {
238
259
Ok ( ok) => ok,
239
260
Err ( err) => {
240
261
debug ! ( "client send request error: {}" , err) ;
@@ -243,45 +264,81 @@ where
243
264
}
244
265
} ;
245
266
246
- let ping = self . ping . clone ( ) ;
247
- if !eos {
248
- let mut pipe = Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
249
- if let Err ( e) = res {
250
- debug ! ( "client request body error: {}" , e) ;
251
- }
252
- } ) ;
253
-
254
- // eagerly see if the body pipe is ready and
255
- // can thus skip allocating in the executor
256
- match Pin :: new ( & mut pipe) . poll ( cx) {
257
- Poll :: Ready ( _) => ( ) ,
258
- Poll :: Pending => {
259
- let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
260
- // keep the ping recorder's knowledge of an
261
- // "open stream" alive while this body is
262
- // still sending...
263
- let ping = ping. clone ( ) ;
264
- let pipe = pipe. map ( move |x| {
265
- drop ( conn_drop_ref) ;
266
- drop ( ping) ;
267
- x
267
+ let send_stream = if !is_connect {
268
+ if !eos {
269
+ let mut pipe =
270
+ Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
271
+ if let Err ( e) = res {
272
+ debug ! ( "client request body error: {}" , e) ;
273
+ }
268
274
} ) ;
269
- self . executor . execute ( pipe) ;
275
+
276
+ // eagerly see if the body pipe is ready and
277
+ // can thus skip allocating in the executor
278
+ match Pin :: new ( & mut pipe) . poll ( cx) {
279
+ Poll :: Ready ( _) => ( ) ,
280
+ Poll :: Pending => {
281
+ let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
282
+ // keep the ping recorder's knowledge of an
283
+ // "open stream" alive while this body is
284
+ // still sending...
285
+ let ping = ping. clone ( ) ;
286
+ let pipe = pipe. map ( move |x| {
287
+ drop ( conn_drop_ref) ;
288
+ drop ( ping) ;
289
+ x
290
+ } ) ;
291
+ self . executor . execute ( pipe) ;
292
+ }
270
293
}
271
294
}
272
- }
295
+
296
+ None
297
+ } else {
298
+ Some ( body_tx)
299
+ } ;
273
300
274
301
let fut = fut. map ( move |result| match result {
275
302
Ok ( res) => {
276
303
// record that we got the response headers
277
304
ping. record_non_data ( ) ;
278
305
279
- let content_length = decode_content_length ( res. headers ( ) ) ;
280
- let res = res. map ( |stream| {
281
- let ping = ping. for_stream ( & stream) ;
282
- crate :: Body :: h2 ( stream, content_length, ping)
283
- } ) ;
284
- Ok ( res)
306
+ let content_length = headers:: content_length_parse_all ( res. headers ( ) ) ;
307
+ if let ( Some ( mut send_stream) , StatusCode :: OK ) =
308
+ ( send_stream, res. status ( ) )
309
+ {
310
+ if content_length. map_or ( false , |len| len != 0 ) {
311
+ warn ! ( "h2 connect response with non-zero body not supported" ) ;
312
+
313
+ send_stream. send_reset ( h2:: Reason :: INTERNAL_ERROR ) ;
314
+ return Err ( (
315
+ crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) ,
316
+ None ,
317
+ ) ) ;
318
+ }
319
+ let ( parts, recv_stream) = res. into_parts ( ) ;
320
+ let mut res = Response :: from_parts ( parts, Body :: empty ( ) ) ;
321
+
322
+ let ( pending, on_upgrade) = crate :: upgrade:: pending ( ) ;
323
+ let io = H2Upgraded {
324
+ ping,
325
+ send_stream : unsafe { UpgradedSendStream :: new ( send_stream) } ,
326
+ recv_stream,
327
+ buf : Bytes :: new ( ) ,
328
+ } ;
329
+ let upgraded = Upgraded :: new ( io, Bytes :: new ( ) ) ;
330
+
331
+ pending. fulfill ( upgraded) ;
332
+ res. extensions_mut ( ) . insert ( on_upgrade) ;
333
+
334
+ Ok ( res)
335
+ } else {
336
+ let res = res. map ( |stream| {
337
+ let ping = ping. for_stream ( & stream) ;
338
+ crate :: Body :: h2 ( stream, content_length. into ( ) , ping)
339
+ } ) ;
340
+ Ok ( res)
341
+ }
285
342
}
286
343
Err ( err) => {
287
344
ping. ensure_not_timed_out ( ) . map_err ( |e| ( e, None ) ) ?;
0 commit comments