31
31
32
32
use bitcoin:: secp256k1:: PublicKey ;
33
33
34
- use tokio:: net:: TcpStream ;
34
+ use tokio:: net:: { tcp , TcpStream } ;
35
35
use tokio:: { io, time} ;
36
36
use tokio:: sync:: mpsc;
37
- use tokio:: io:: { AsyncReadExt , AsyncWrite , AsyncWriteExt } ;
37
+ use tokio:: io:: AsyncWrite ;
38
38
39
39
use lightning:: ln:: peer_handler;
40
40
use lightning:: ln:: peer_handler:: SocketDescriptor as LnSocketTrait ;
@@ -59,7 +59,7 @@ static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
59
59
// define a trivial two- and three- select macro with the specific types we need and just use that.
60
60
61
61
pub ( crate ) enum SelectorOutput {
62
- A ( Option < ( ) > ) , B ( Option < ( ) > ) , C ( tokio:: io:: Result < usize > ) ,
62
+ A ( Option < ( ) > ) , B ( Option < ( ) > ) , C ( tokio:: io:: Result < ( ) > ) ,
63
63
}
64
64
65
65
pub ( crate ) struct TwoSelector <
@@ -87,15 +87,15 @@ impl<
87
87
}
88
88
89
89
pub ( crate ) struct ThreeSelector <
90
- A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < usize > > + Unpin
90
+ A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < ( ) > > + Unpin
91
91
> {
92
92
pub a : A ,
93
93
pub b : B ,
94
94
pub c : C ,
95
95
}
96
96
97
97
impl <
98
- A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < usize > > + Unpin
98
+ A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < ( ) > > + Unpin
99
99
> Future for ThreeSelector < A , B , C > {
100
100
type Output = SelectorOutput ;
101
101
fn poll ( mut self : Pin < & mut Self > , ctx : & mut task:: Context < ' _ > ) -> Poll < SelectorOutput > {
@@ -119,7 +119,7 @@ impl<
119
119
/// Connection object (in an Arc<Mutex<>>) in each SocketDescriptor we create as well as in the
120
120
/// read future (which is returned by schedule_read).
121
121
struct Connection {
122
- writer : Option < io :: WriteHalf < TcpStream > > ,
122
+ writer : Option < Arc < TcpStream > > ,
123
123
// Because our PeerManager is templated by user-provided types, and we can't (as far as I can
124
124
// tell) have a const RawWakerVTable built out of templated functions, we need some indirection
125
125
// between being woken up with write-ready and calling PeerManager::write_buffer_space_avail.
@@ -156,7 +156,7 @@ impl Connection {
156
156
async fn schedule_read < PM : Deref + ' static + Send + Sync + Clone > (
157
157
peer_manager : PM ,
158
158
us : Arc < Mutex < Self > > ,
159
- mut reader : io :: ReadHalf < TcpStream > ,
159
+ reader : Arc < TcpStream > ,
160
160
mut read_wake_receiver : mpsc:: Receiver < ( ) > ,
161
161
mut write_avail_receiver : mpsc:: Receiver < ( ) > ,
162
162
) where PM :: Target : APeerManager < Descriptor = SocketDescriptor > {
@@ -200,7 +200,7 @@ impl Connection {
200
200
ThreeSelector {
201
201
a : Box :: pin ( write_avail_receiver. recv ( ) ) ,
202
202
b : Box :: pin ( read_wake_receiver. recv ( ) ) ,
203
- c : Box :: pin ( reader. read ( & mut buf ) ) ,
203
+ c : Box :: pin ( reader. readable ( ) ) ,
204
204
} . await
205
205
} ;
206
206
match select_result {
@@ -211,8 +211,9 @@ impl Connection {
211
211
}
212
212
} ,
213
213
SelectorOutput :: B ( _) => { } ,
214
- SelectorOutput :: C ( read) => {
215
- match read {
214
+ SelectorOutput :: C ( res) => {
215
+ if res. is_err ( ) { break Disconnect :: PeerDisconnected ; }
216
+ match reader. try_read ( & mut buf) {
216
217
Ok ( 0 ) => break Disconnect :: PeerDisconnected ,
217
218
Ok ( len) => {
218
219
let read_res = peer_manager. as_ref ( ) . read_event ( & mut our_descriptor, & buf[ 0 ..len] ) ;
@@ -226,7 +227,11 @@ impl Connection {
226
227
Err ( _) => break Disconnect :: CloseConnection ,
227
228
}
228
229
} ,
229
- Err ( _) => break Disconnect :: PeerDisconnected ,
230
+ Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: WouldBlock => {
231
+ // readable() is allowed to spuriously wake, so we have to handle
232
+ // WouldBlock here.
233
+ } ,
234
+ Err ( e) => break Disconnect :: PeerDisconnected ,
230
235
}
231
236
} ,
232
237
}
@@ -239,18 +244,14 @@ impl Connection {
239
244
// here.
240
245
let _ = tokio:: task:: yield_now ( ) . await ;
241
246
} ;
242
- let writer_option = us. lock ( ) . unwrap ( ) . writer . take ( ) ;
243
- if let Some ( mut writer) = writer_option {
244
- // If the socket is already closed, shutdown() will fail, so just ignore it.
245
- let _ = writer. shutdown ( ) . await ;
246
- }
247
+ us. lock ( ) . unwrap ( ) . writer . take ( ) ;
247
248
if let Disconnect :: PeerDisconnected = disconnect_type {
248
249
peer_manager. as_ref ( ) . socket_disconnected ( & our_descriptor) ;
249
250
peer_manager. as_ref ( ) . process_events ( ) ;
250
251
}
251
252
}
252
253
253
- fn new ( stream : StdTcpStream ) -> ( io :: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
254
+ fn new ( stream : StdTcpStream ) -> ( Arc < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
254
255
// We only ever need a channel of depth 1 here: if we returned a non-full write to the
255
256
// PeerManager, we will eventually get notified that there is room in the socket to write
256
257
// new bytes, which will generate an event. That event will be popped off the queue before
@@ -262,11 +263,11 @@ impl Connection {
262
263
// false.
263
264
let ( read_waker, read_receiver) = mpsc:: channel ( 1 ) ;
264
265
stream. set_nonblocking ( true ) . unwrap ( ) ;
265
- let ( reader , writer ) = io :: split ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
266
+ let tokio_stream = Arc :: new ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
266
267
267
- ( reader , write_receiver, read_receiver,
268
+ ( Arc :: clone ( & tokio_stream ) , write_receiver, read_receiver,
268
269
Arc :: new ( Mutex :: new ( Self {
269
- writer : Some ( writer ) , write_avail, read_waker, read_paused : false ,
270
+ writer : Some ( tokio_stream ) , write_avail, read_waker, read_paused : false ,
270
271
rl_requested_disconnect : false ,
271
272
id : ID_COUNTER . fetch_add ( 1 , Ordering :: AcqRel )
272
273
} ) ) )
@@ -462,9 +463,9 @@ impl SocketDescriptor {
462
463
}
463
464
impl peer_handler:: SocketDescriptor for SocketDescriptor {
464
465
fn send_data ( & mut self , data : & [ u8 ] , resume_read : bool ) -> usize {
465
- // To send data, we take a lock on our Connection to access the WriteHalf of the TcpStream,
466
- // writing to it if there's room in the kernel buffer, or otherwise create a new Waker with
467
- // a SocketDescriptor in it which can wake up the write_avail Sender, waking up the
466
+ // To send data, we take a lock on our Connection to access the TcpStream, writing to it if
467
+ // there's room in the kernel buffer, or otherwise create a new Waker with a
468
+ // SocketDescriptor in it which can wake up the write_avail Sender, waking up the
468
469
// processing future which will call write_buffer_space_avail and we'll end up back here.
469
470
let mut us = self . conn . lock ( ) . unwrap ( ) ;
470
471
if us. writer . is_none ( ) {
@@ -484,24 +485,18 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
484
485
let mut ctx = task:: Context :: from_waker ( & waker) ;
485
486
let mut written_len = 0 ;
486
487
loop {
487
- match std:: pin:: Pin :: new ( us. writer . as_mut ( ) . unwrap ( ) ) . poll_write ( & mut ctx, & data[ written_len..] ) {
488
- task:: Poll :: Ready ( Ok ( res) ) => {
489
- // The tokio docs *seem* to indicate this can't happen, and I certainly don't
490
- // know how to handle it if it does (cause it should be a Poll::Pending
491
- // instead):
492
- assert_ne ! ( res, 0 ) ;
493
- written_len += res;
494
- if written_len == data. len ( ) { return written_len; }
495
- } ,
496
- task:: Poll :: Ready ( Err ( e) ) => {
497
- // The tokio docs *seem* to indicate this can't happen, and I certainly don't
498
- // know how to handle it if it does (cause it should be a Poll::Pending
499
- // instead):
500
- assert_ne ! ( e. kind( ) , io:: ErrorKind :: WouldBlock ) ;
501
- // Probably we've already been closed, just return what we have and let the
502
- // read thread handle closing logic.
503
- return written_len;
488
+ match us. writer . as_ref ( ) . unwrap ( ) . poll_write_ready ( & mut ctx) {
489
+ task:: Poll :: Ready ( Ok ( ( ) ) ) => {
490
+ match us. writer . as_ref ( ) . unwrap ( ) . try_write ( & data[ written_len..] ) {
491
+ Ok ( res) => {
492
+ debug_assert_ne ! ( res, 0 ) ;
493
+ written_len += res;
494
+ if written_len == data. len ( ) { return written_len; }
495
+ } ,
496
+ Err ( e) => return written_len,
497
+ }
504
498
} ,
499
+ task:: Poll :: Ready ( Err ( e) ) => return written_len,
505
500
task:: Poll :: Pending => {
506
501
// We're queued up for a write event now, but we need to make sure we also
507
502
// pause read given we're now waiting on the remote end to ACK (and in
0 commit comments