24
24
//! The call site should, thus, look something like this:
25
25
//! ```
26
26
//! use tokio::sync::mpsc;
27
- //! use tokio ::net::TcpStream;
27
+ //! use std ::net::TcpStream;
28
28
//! use bitcoin::secp256k1::key::PublicKey;
29
29
//! use lightning::util::events::EventsProvider;
30
30
//! use std::net::SocketAddr;
@@ -86,6 +86,7 @@ use lightning::util::logger::Logger;
86
86
87
87
use std:: { task, thread} ;
88
88
use std:: net:: SocketAddr ;
89
+ use std:: net:: TcpStream as StdTcpStream ;
89
90
use std:: sync:: { Arc , Mutex , MutexGuard } ;
90
91
use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
91
92
use std:: time:: Duration ;
@@ -156,7 +157,7 @@ impl Connection {
156
157
// In this case, we do need to call peer_manager.socket_disconnected() to inform
157
158
// Rust-Lightning that the socket is gone.
158
159
PeerDisconnected
159
- } ;
160
+ }
160
161
let disconnect_type = loop {
161
162
macro_rules! shutdown_socket {
162
163
( $err: expr, $need_disconnect: expr) => { {
@@ -218,7 +219,7 @@ impl Connection {
218
219
}
219
220
}
220
221
221
- fn new ( event_notify : mpsc:: Sender < ( ) > , stream : TcpStream ) -> ( io:: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
222
+ fn new ( event_notify : mpsc:: Sender < ( ) > , stream : StdTcpStream ) -> ( io:: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
222
223
// We only ever need a channel of depth 1 here: if we returned a non-full write to the
223
224
// PeerManager, we will eventually get notified that there is room in the socket to write
224
225
// new bytes, which will generate an event. That event will be popped off the queue before
@@ -229,7 +230,8 @@ impl Connection {
229
230
// we shove a value into the channel which comes after we've reset the read_paused bool to
230
231
// false.
231
232
let ( read_waker, read_receiver) = mpsc:: channel ( 1 ) ;
232
- let ( reader, writer) = io:: split ( stream) ;
233
+ stream. set_nonblocking ( true ) . unwrap ( ) ;
234
+ let ( reader, writer) = io:: split ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
233
235
234
236
( reader, write_receiver, read_receiver,
235
237
Arc :: new ( Mutex :: new ( Self {
@@ -248,7 +250,7 @@ impl Connection {
248
250
/// not need to poll the provided future in order to make progress.
249
251
///
250
252
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
251
- pub fn setup_inbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , stream : TcpStream ) -> impl std:: future:: Future < Output =( ) > where
253
+ pub fn setup_inbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , stream : StdTcpStream ) -> impl std:: future:: Future < Output =( ) > where
252
254
CMH : ChannelMessageHandler + ' static ,
253
255
RMH : RoutingMessageHandler + ' static ,
254
256
L : Logger + ' static + ?Sized {
@@ -290,7 +292,7 @@ pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<So
290
292
/// not need to poll the provided future in order to make progress.
291
293
///
292
294
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
293
- pub fn setup_outbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , their_node_id : PublicKey , stream : TcpStream ) -> impl std:: future:: Future < Output =( ) > where
295
+ pub fn setup_outbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , their_node_id : PublicKey , stream : StdTcpStream ) -> impl std:: future:: Future < Output =( ) > where
294
296
CMH : ChannelMessageHandler + ' static ,
295
297
RMH : RoutingMessageHandler + ' static ,
296
298
L : Logger + ' static + ?Sized {
@@ -366,7 +368,7 @@ pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerM
366
368
CMH : ChannelMessageHandler + ' static ,
367
369
RMH : RoutingMessageHandler + ' static ,
368
370
L : Logger + ' static + ?Sized {
369
- if let Ok ( Ok ( stream) ) = time:: timeout ( Duration :: from_secs ( 10 ) , TcpStream :: connect ( & addr) ) . await {
371
+ if let Ok ( Ok ( stream) ) = time:: timeout ( Duration :: from_secs ( 10 ) , async { TcpStream :: connect ( & addr) . await . map ( |s| s . into_std ( ) . unwrap ( ) ) } ) . await {
370
372
Some ( setup_outbound ( peer_manager, event_notify, their_node_id, stream) )
371
373
} else { None }
372
374
}
@@ -388,7 +390,7 @@ fn wake_socket_waker(orig_ptr: *const ()) {
388
390
}
389
391
fn wake_socket_waker_by_ref ( orig_ptr : * const ( ) ) {
390
392
let sender_ptr = orig_ptr as * const mpsc:: Sender < ( ) > ;
391
- let mut sender = unsafe { ( * sender_ptr) . clone ( ) } ;
393
+ let sender = unsafe { ( * sender_ptr) . clone ( ) } ;
392
394
let _ = sender. try_send ( ( ) ) ;
393
395
}
394
396
fn drop_socket_waker ( orig_ptr : * const ( ) ) {
@@ -512,6 +514,7 @@ mod tests {
512
514
use tokio:: sync:: mpsc;
513
515
514
516
use std:: mem;
517
+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
515
518
use std:: sync:: { Arc , Mutex } ;
516
519
use std:: time:: Duration ;
517
520
@@ -526,6 +529,7 @@ mod tests {
526
529
expected_pubkey : PublicKey ,
527
530
pubkey_connected : mpsc:: Sender < ( ) > ,
528
531
pubkey_disconnected : mpsc:: Sender < ( ) > ,
532
+ disconnected_flag : AtomicBool ,
529
533
msg_events : Mutex < Vec < MessageSendEvent > > ,
530
534
}
531
535
impl RoutingMessageHandler for MsgHandler {
@@ -559,6 +563,7 @@ mod tests {
559
563
fn handle_announcement_signatures ( & self , _their_node_id : & PublicKey , _msg : & AnnouncementSignatures ) { }
560
564
fn peer_disconnected ( & self , their_node_id : & PublicKey , _no_connection_possible : bool ) {
561
565
if * their_node_id == self . expected_pubkey {
566
+ self . disconnected_flag . store ( true , Ordering :: SeqCst ) ;
562
567
self . pubkey_disconnected . clone ( ) . try_send ( ( ) ) . unwrap ( ) ;
563
568
}
564
569
}
@@ -591,6 +596,7 @@ mod tests {
591
596
expected_pubkey : b_pub,
592
597
pubkey_connected : a_connected_sender,
593
598
pubkey_disconnected : a_disconnected_sender,
599
+ disconnected_flag : AtomicBool :: new ( false ) ,
594
600
msg_events : Mutex :: new ( Vec :: new ( ) ) ,
595
601
} ) ;
596
602
let a_manager = Arc :: new ( PeerManager :: new ( MessageHandler {
@@ -604,6 +610,7 @@ mod tests {
604
610
expected_pubkey : a_pub,
605
611
pubkey_connected : b_connected_sender,
606
612
pubkey_disconnected : b_disconnected_sender,
613
+ disconnected_flag : AtomicBool :: new ( false ) ,
607
614
msg_events : Mutex :: new ( Vec :: new ( ) ) ,
608
615
} ) ;
609
616
let b_manager = Arc :: new ( PeerManager :: new ( MessageHandler {
@@ -624,27 +631,29 @@ mod tests {
624
631
} else { panic ! ( "Failed to bind to v4 localhost on common ports" ) ; } ;
625
632
626
633
let ( sender, _receiver) = mpsc:: channel ( 2 ) ;
627
- let fut_a = super :: setup_outbound ( Arc :: clone ( & a_manager) , sender. clone ( ) , b_pub, tokio :: net :: TcpStream :: from_std ( conn_a) . unwrap ( ) ) ;
628
- let fut_b = super :: setup_inbound ( b_manager, sender, tokio :: net :: TcpStream :: from_std ( conn_b) . unwrap ( ) ) ;
634
+ let fut_a = super :: setup_outbound ( Arc :: clone ( & a_manager) , sender. clone ( ) , b_pub, conn_a) ;
635
+ let fut_b = super :: setup_inbound ( b_manager, sender, conn_b) ;
629
636
630
637
tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , a_connected. recv ( ) ) . await . unwrap ( ) ;
631
638
tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , b_connected. recv ( ) ) . await . unwrap ( ) ;
632
639
633
640
a_handler. msg_events . lock ( ) . unwrap ( ) . push ( MessageSendEvent :: HandleError {
634
641
node_id : b_pub, action : ErrorAction :: DisconnectPeer { msg : None }
635
642
} ) ;
636
- assert ! ( a_disconnected . try_recv ( ) . is_err ( ) ) ;
637
- assert ! ( b_disconnected . try_recv ( ) . is_err ( ) ) ;
643
+ assert ! ( !a_handler . disconnected_flag . load ( Ordering :: SeqCst ) ) ;
644
+ assert ! ( !b_handler . disconnected_flag . load ( Ordering :: SeqCst ) ) ;
638
645
639
646
a_manager. process_events ( ) ;
640
647
tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , a_disconnected. recv ( ) ) . await . unwrap ( ) ;
641
648
tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , b_disconnected. recv ( ) ) . await . unwrap ( ) ;
649
+ assert ! ( a_handler. disconnected_flag. load( Ordering :: SeqCst ) ) ;
650
+ assert ! ( b_handler. disconnected_flag. load( Ordering :: SeqCst ) ) ;
642
651
643
652
fut_a. await ;
644
653
fut_b. await ;
645
654
}
646
655
647
- #[ tokio:: test( threaded_scheduler ) ]
656
+ #[ tokio:: test( flavor = "multi_thread" ) ]
648
657
async fn basic_threaded_connection_test ( ) {
649
658
do_basic_connection_test ( ) . await ;
650
659
}
0 commit comments