1
+ use std:: error:: Error as StdError ;
1
2
use std:: fmt;
2
3
use std:: io;
3
4
//use std::net::SocketAddr;
@@ -42,6 +43,7 @@ where T: Service<Request=Uri, Error=io::Error> + 'static,
42
43
#[ derive( Clone ) ]
43
44
pub struct HttpConnector {
44
45
dns : dns:: Dns ,
46
+ enforce_http : bool ,
45
47
handle : Handle ,
46
48
}
47
49
@@ -50,15 +52,26 @@ impl HttpConnector {
50
52
/// Construct a new HttpConnector.
51
53
///
52
54
/// Takes number of DNS worker threads.
55
+ #[ inline]
53
56
pub fn new ( threads : usize , handle : & Handle ) -> HttpConnector {
54
57
HttpConnector {
55
58
dns : dns:: Dns :: new ( threads) ,
59
+ enforce_http : true ,
56
60
handle : handle. clone ( ) ,
57
61
}
58
62
}
63
+
64
+ /// Option to enforce all `Uri`s have the `http` scheme.
65
+ ///
66
+ /// Enabled by default.
67
+ #[ inline]
68
+ pub fn enforce_http ( & mut self , is_enforced : bool ) {
69
+ self . enforce_http = is_enforced;
70
+ }
59
71
}
60
72
61
73
impl fmt:: Debug for HttpConnector {
74
+ #[ inline]
62
75
fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
63
76
f. debug_struct ( "HttpConnector" )
64
77
. finish ( )
@@ -73,12 +86,18 @@ impl Service for HttpConnector {
73
86
74
87
fn call ( & self , uri : Uri ) -> Self :: Future {
75
88
debug ! ( "Http::connect({:?})" , uri) ;
89
+
90
+ if self . enforce_http {
91
+ if uri. scheme ( ) != Some ( "http" ) {
92
+ return invalid_url ( InvalidUrl :: NotHttp , & self . handle ) ;
93
+ }
94
+ } else if uri. scheme ( ) . is_none ( ) {
95
+ return invalid_url ( InvalidUrl :: MissingScheme , & self . handle ) ;
96
+ }
97
+
76
98
let host = match uri. host ( ) {
77
99
Some ( s) => s,
78
- None => return HttpConnecting {
79
- state : State :: Error ( Some ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , "invalid url" ) ) ) ,
80
- handle : self . handle . clone ( ) ,
81
- } ,
100
+ None => return invalid_url ( InvalidUrl :: MissingAuthority , & self . handle ) ,
82
101
} ;
83
102
let port = match uri. port ( ) {
84
103
Some ( port) => port,
@@ -94,7 +113,37 @@ impl Service for HttpConnector {
94
113
handle : self . handle . clone ( ) ,
95
114
}
96
115
}
116
+ }
117
+
118
+ #[ inline]
119
+ fn invalid_url ( err : InvalidUrl , handle : & Handle ) -> HttpConnecting {
120
+ HttpConnecting {
121
+ state : State :: Error ( Some ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , err) ) ) ,
122
+ handle : handle. clone ( ) ,
123
+ }
124
+ }
125
+
126
+ #[ derive( Debug , Clone , Copy ) ]
127
+ enum InvalidUrl {
128
+ MissingScheme ,
129
+ NotHttp ,
130
+ MissingAuthority ,
131
+ }
132
+
133
+ impl fmt:: Display for InvalidUrl {
134
+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
135
+ f. write_str ( self . description ( ) )
136
+ }
137
+ }
97
138
139
+ impl StdError for InvalidUrl {
140
+ fn description ( & self ) -> & str {
141
+ match * self {
142
+ InvalidUrl :: MissingScheme => "invalid URL, missing scheme" ,
143
+ InvalidUrl :: NotHttp => "invalid URL, scheme must be http" ,
144
+ InvalidUrl :: MissingAuthority => "invalid URL, missing domain" ,
145
+ }
146
+ }
98
147
}
99
148
100
149
/// A Future representing work to connect to a URL.
@@ -195,12 +244,30 @@ mod tests {
195
244
use super :: { Connect , HttpConnector } ;
196
245
197
246
#[ test]
198
- fn test_non_http_url ( ) {
247
+ fn test_errors_missing_authority ( ) {
199
248
let mut core = Core :: new ( ) . unwrap ( ) ;
200
249
let url = "/foo/bar?baz" . parse ( ) . unwrap ( ) ;
201
250
let connector = HttpConnector :: new ( 1 , & core. handle ( ) ) ;
202
251
203
252
assert_eq ! ( core. run( connector. connect( url) ) . unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidInput ) ;
204
253
}
205
254
255
+ #[ test]
256
+ fn test_errors_enforce_http ( ) {
257
+ let mut core = Core :: new ( ) . unwrap ( ) ;
258
+ let url = "https://example.domain/foo/bar?baz" . parse ( ) . unwrap ( ) ;
259
+ let connector = HttpConnector :: new ( 1 , & core. handle ( ) ) ;
260
+
261
+ assert_eq ! ( core. run( connector. connect( url) ) . unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidInput ) ;
262
+ }
263
+
264
+
265
+ #[ test]
266
+ fn test_errors_missing_scheme ( ) {
267
+ let mut core = Core :: new ( ) . unwrap ( ) ;
268
+ let url = "example.domain" . parse ( ) . unwrap ( ) ;
269
+ let connector = HttpConnector :: new ( 1 , & core. handle ( ) ) ;
270
+
271
+ assert_eq ! ( core. run( connector. connect( url) ) . unwrap_err( ) . kind( ) , io:: ErrorKind :: InvalidInput ) ;
272
+ }
206
273
}
0 commit comments