Skip to content

Commit 94721bd

Browse files
committed
refactor: Make Transport the source of their_node_id
This patch ontinues to separate state the exists before NOISE is complete and after it is complete to unlock future refactoring. Most callers immediately unwrapped the value from Peer and can just call Transport::get_their_node_id(). The duplicate connection disconnect path has been rewritten to determine whether or not to remove & send a disconnect event without needing to use a None value for Option<PublicKey> All other users are in contexts where they either exit early or continue if !transport.is_connected() so it is also safe to call Transport::get_their_node_id()
1 parent f781569 commit 94721bd

File tree

2 files changed

+90
-55
lines changed

2 files changed

+90
-55
lines changed

lightning/src/ln/peers/handler.rs

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ pub(super) trait ITransport {
5151
/// Returns true if the connection is established and encrypted messages can be sent.
5252
fn is_connected(&self) -> bool;
5353

54+
/// Returns the node_id of the remote node. Panics if not connected.
55+
fn get_their_node_id(&self) -> PublicKey;
56+
5457
/// Returns all Messages that have been received and can be parsed by the Transport
5558
fn drain_messages<L: Deref>(&mut self, logger: L) -> Result<Vec<Message>, PeerHandleError> where L::Target: Logger;
5659

@@ -184,7 +187,6 @@ enum InitSyncTracker{
184187
struct Peer {
185188
transport: Transport,
186189
outbound: bool,
187-
their_node_id: Option<PublicKey>,
188190
their_features: Option<InitFeatures>,
189191

190192
pending_outbound_buffer: OutboundQueue,
@@ -330,7 +332,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
330332
if !p.transport.is_connected() || p.their_features.is_none() {
331333
return None;
332334
}
333-
p.their_node_id
335+
Some(p.transport.get_their_node_id())
334336
}).collect()
335337
}
336338

@@ -363,7 +365,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
363365
if peers.peers.insert(descriptor, Peer {
364366
transport,
365367
outbound: true,
366-
their_node_id: Some(their_node_id.clone()),
367368
their_features: None,
368369

369370
pending_outbound_buffer: OutboundQueue::new(MSG_BUFF_SIZE),
@@ -391,7 +392,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
391392
if peers.peers.insert(descriptor, Peer {
392393
transport: Transport::new_inbound(&self.our_node_secret, &self.get_ephemeral_key()),
393394
outbound: false,
394-
their_node_id: None,
395395
their_features: None,
396396

397397
pending_outbound_buffer: OutboundQueue::new(MSG_BUFF_SIZE),
@@ -530,22 +530,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
530530

531531
// If the transport is newly connected, do the appropriate set up for the connection
532532
if peer.transport.is_connected() {
533-
let their_node_id = peer.transport.their_node_id.unwrap();
533+
let their_node_id = peer.transport.get_their_node_id();
534534

535535
match peers.node_id_to_descriptor.entry(their_node_id.clone()) {
536536
hash_map::Entry::Occupied(entry) => {
537537
if entry.get() != peer_descriptor {
538538
// Existing entry in map is from a different descriptor, this is a duplicate
539539
log_trace!(self.logger, "Got second connection with {}, closing", log_pubkey!(&their_node_id));
540-
peer.their_node_id = None;
541540
return Err(PeerHandleError { no_connection_possible: false });
542541
} else {
543542
// read_event for existing peer
544543
}
545544
},
546545
hash_map::Entry::Vacant(entry) => {
547546
log_trace!(self.logger, "Finished noise handshake for connection with {}", log_pubkey!(&their_node_id));
548-
peer.their_node_id = Some(their_node_id.clone());
549547

550548
if peer.outbound {
551549
let mut features = InitFeatures::known();
@@ -616,12 +614,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
616614

617615
/// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
618616
fn handle_message(&self, peers_needing_send: &mut HashSet<Descriptor>, peer: &mut Peer, peer_descriptor: Descriptor, message: wire::Message) -> Result<(), MessageHandlingError> {
619-
log_trace!(self.logger, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.their_node_id.unwrap()));
617+
log_trace!(self.logger, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.transport.get_their_node_id()));
620618

621619
// Need an Init as first message
622620
if let wire::Message::Init(_) = message {
623621
} else if peer.their_features.is_none() {
624-
log_trace!(self.logger, "Peer {} sent non-Init first message", log_pubkey!(peer.their_node_id.unwrap()));
622+
log_trace!(self.logger, "Peer {} sent non-Init first message", log_pubkey!(peer.transport.get_their_node_id()));
625623
return Err(PeerHandleError{ no_connection_possible: false }.into());
626624
}
627625

@@ -654,21 +652,21 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
654652
peers_needing_send.insert(peer_descriptor.clone());
655653
}
656654
if !msg.features.supports_static_remote_key() {
657-
log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting with no_connection_possible", log_pubkey!(peer.their_node_id.unwrap()));
655+
log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting with no_connection_possible", log_pubkey!(peer.transport.get_their_node_id()));
658656
return Err(PeerHandleError{ no_connection_possible: true }.into());
659657
}
660658

661659
if !peer.outbound {
662660
let mut features = InitFeatures::known();
663-
if !self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
661+
if !self.message_handler.route_handler.should_request_full_sync(&peer.transport.get_their_node_id()) {
664662
features.clear_initial_routing_sync();
665663
}
666664

667665
let resp = msgs::Init { features };
668666
self.enqueue_message(peers_needing_send, &mut peer.transport, &mut peer.pending_outbound_buffer, &peer_descriptor, &resp);
669667
}
670668

671-
self.message_handler.chan_handler.peer_connected(&peer.their_node_id.unwrap(), &msg);
669+
self.message_handler.chan_handler.peer_connected(&peer.transport.get_their_node_id(), &msg);
672670
peer.their_features = Some(msg.features);
673671
},
674672
wire::Message::Error(msg) => {
@@ -681,11 +679,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
681679
}
682680

683681
if data_is_printable {
684-
log_debug!(self.logger, "Got Err message from {}: {}", log_pubkey!(peer.their_node_id.unwrap()), msg.data);
682+
log_debug!(self.logger, "Got Err message from {}: {}", log_pubkey!(peer.transport.get_their_node_id()), msg.data);
685683
} else {
686-
log_debug!(self.logger, "Got Err message from {} with non-ASCII error message", log_pubkey!(peer.their_node_id.unwrap()));
684+
log_debug!(self.logger, "Got Err message from {} with non-ASCII error message", log_pubkey!(peer.transport.get_their_node_id()));
687685
}
688-
self.message_handler.chan_handler.handle_error(&peer.their_node_id.unwrap(), &msg);
686+
self.message_handler.chan_handler.handle_error(&peer.transport.get_their_node_id(), &msg);
689687
if msg.channel_id == [0; 32] {
690688
return Err(PeerHandleError{ no_connection_possible: true }.into());
691689
}
@@ -703,59 +701,59 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
703701

704702
// Channel messages:
705703
wire::Message::OpenChannel(msg) => {
706-
self.message_handler.chan_handler.handle_open_channel(&peer.their_node_id.unwrap(), peer.their_features.clone().unwrap(), &msg);
704+
self.message_handler.chan_handler.handle_open_channel(&peer.transport.get_their_node_id(), peer.their_features.clone().unwrap(), &msg);
707705
},
708706
wire::Message::AcceptChannel(msg) => {
709-
self.message_handler.chan_handler.handle_accept_channel(&peer.their_node_id.unwrap(), peer.their_features.clone().unwrap(), &msg);
707+
self.message_handler.chan_handler.handle_accept_channel(&peer.transport.get_their_node_id(), peer.their_features.clone().unwrap(), &msg);
710708
},
711709

712710
wire::Message::FundingCreated(msg) => {
713-
self.message_handler.chan_handler.handle_funding_created(&peer.their_node_id.unwrap(), &msg);
711+
self.message_handler.chan_handler.handle_funding_created(&peer.transport.get_their_node_id(), &msg);
714712
},
715713
wire::Message::FundingSigned(msg) => {
716-
self.message_handler.chan_handler.handle_funding_signed(&peer.their_node_id.unwrap(), &msg);
714+
self.message_handler.chan_handler.handle_funding_signed(&peer.transport.get_their_node_id(), &msg);
717715
},
718716
wire::Message::FundingLocked(msg) => {
719-
self.message_handler.chan_handler.handle_funding_locked(&peer.their_node_id.unwrap(), &msg);
717+
self.message_handler.chan_handler.handle_funding_locked(&peer.transport.get_their_node_id(), &msg);
720718
},
721719

722720
wire::Message::Shutdown(msg) => {
723-
self.message_handler.chan_handler.handle_shutdown(&peer.their_node_id.unwrap(), &msg);
721+
self.message_handler.chan_handler.handle_shutdown(&peer.transport.get_their_node_id(), &msg);
724722
},
725723
wire::Message::ClosingSigned(msg) => {
726-
self.message_handler.chan_handler.handle_closing_signed(&peer.their_node_id.unwrap(), &msg);
724+
self.message_handler.chan_handler.handle_closing_signed(&peer.transport.get_their_node_id(), &msg);
727725
},
728726

729727
// Commitment messages:
730728
wire::Message::UpdateAddHTLC(msg) => {
731-
self.message_handler.chan_handler.handle_update_add_htlc(&peer.their_node_id.unwrap(), &msg);
729+
self.message_handler.chan_handler.handle_update_add_htlc(&peer.transport.get_their_node_id(), &msg);
732730
},
733731
wire::Message::UpdateFulfillHTLC(msg) => {
734-
self.message_handler.chan_handler.handle_update_fulfill_htlc(&peer.their_node_id.unwrap(), &msg);
732+
self.message_handler.chan_handler.handle_update_fulfill_htlc(&peer.transport.get_their_node_id(), &msg);
735733
},
736734
wire::Message::UpdateFailHTLC(msg) => {
737-
self.message_handler.chan_handler.handle_update_fail_htlc(&peer.their_node_id.unwrap(), &msg);
735+
self.message_handler.chan_handler.handle_update_fail_htlc(&peer.transport.get_their_node_id(), &msg);
738736
},
739737
wire::Message::UpdateFailMalformedHTLC(msg) => {
740-
self.message_handler.chan_handler.handle_update_fail_malformed_htlc(&peer.their_node_id.unwrap(), &msg);
738+
self.message_handler.chan_handler.handle_update_fail_malformed_htlc(&peer.transport.get_their_node_id(), &msg);
741739
},
742740

743741
wire::Message::CommitmentSigned(msg) => {
744-
self.message_handler.chan_handler.handle_commitment_signed(&peer.their_node_id.unwrap(), &msg);
742+
self.message_handler.chan_handler.handle_commitment_signed(&peer.transport.get_their_node_id(), &msg);
745743
},
746744
wire::Message::RevokeAndACK(msg) => {
747-
self.message_handler.chan_handler.handle_revoke_and_ack(&peer.their_node_id.unwrap(), &msg);
745+
self.message_handler.chan_handler.handle_revoke_and_ack(&peer.transport.get_their_node_id(), &msg);
748746
},
749747
wire::Message::UpdateFee(msg) => {
750-
self.message_handler.chan_handler.handle_update_fee(&peer.their_node_id.unwrap(), &msg);
748+
self.message_handler.chan_handler.handle_update_fee(&peer.transport.get_their_node_id(), &msg);
751749
},
752750
wire::Message::ChannelReestablish(msg) => {
753-
self.message_handler.chan_handler.handle_channel_reestablish(&peer.their_node_id.unwrap(), &msg);
751+
self.message_handler.chan_handler.handle_channel_reestablish(&peer.transport.get_their_node_id(), &msg);
754752
},
755753

756754
// Routing messages:
757755
wire::Message::AnnouncementSignatures(msg) => {
758-
self.message_handler.chan_handler.handle_announcement_signatures(&peer.their_node_id.unwrap(), &msg);
756+
self.message_handler.chan_handler.handle_announcement_signatures(&peer.transport.get_their_node_id(), &msg);
759757
},
760758
wire::Message::ChannelAnnouncement(msg) => {
761759
let should_forward = match self.message_handler.route_handler.handle_channel_announcement(&msg) {
@@ -1000,13 +998,10 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
1000998
!peer.should_forward_channel_announcement(msg.contents.short_channel_id) {
1001999
continue
10021000
}
1003-
match peer.their_node_id {
1004-
None => continue,
1005-
Some(their_node_id) => {
1006-
if their_node_id == msg.contents.node_id_1 || their_node_id == msg.contents.node_id_2 {
1007-
continue
1008-
}
1009-
}
1001+
1002+
let their_node_id = peer.transport.get_their_node_id();
1003+
if their_node_id == msg.contents.node_id_1 || their_node_id == msg.contents.node_id_2 {
1004+
continue
10101005
}
10111006
if peer.transport.is_connected() {
10121007
peer.transport.enqueue_message(msg, &mut peer.pending_outbound_buffer, &*self.logger);
@@ -1119,12 +1114,17 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
11191114
match peer_option {
11201115
None => panic!("Descriptor for disconnect_event is not already known to PeerManager"),
11211116
Some(peer) => {
1122-
match peer.their_node_id {
1123-
Some(node_id) => {
1117+
if peer.transport.is_connected() {
1118+
let node_id = peer.transport.get_their_node_id();
1119+
1120+
if peers.node_id_to_descriptor.get(&node_id).unwrap() == descriptor {
11241121
peers.node_id_to_descriptor.remove(&node_id);
11251122
self.message_handler.chan_handler.peer_disconnected(&node_id, no_connection_possible);
1126-
},
1127-
None => {}
1123+
} else {
1124+
// This must have been generated from a duplicate connection error
1125+
}
1126+
} else {
1127+
// Unconnected nodes never make it into node_id_to_descriptor
11281128
}
11291129
}
11301130
};
@@ -1147,18 +1147,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
11471147
if peer.awaiting_pong {
11481148
peers_needing_send.remove(descriptor);
11491149
descriptors_needing_disconnect.push(descriptor.clone());
1150-
match peer.their_node_id {
1151-
Some(node_id) => {
1152-
log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", node_id);
1153-
node_id_to_descriptor.remove(&node_id);
1154-
self.message_handler.chan_handler.peer_disconnected(&node_id, false);
1155-
}
1156-
None => {
1157-
// This can't actually happen as we should have hit
1158-
// is_connected() previously on this same peer.
1159-
unreachable!();
1160-
},
1161-
}
1150+
let their_node_id = peer.transport.get_their_node_id();
1151+
log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", their_node_id);
1152+
node_id_to_descriptor.remove(&their_node_id);
1153+
self.message_handler.chan_handler.peer_disconnected(&their_node_id, false);
1154+
11621155
return false;
11631156
}
11641157

lightning/src/ln/peers/transport.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub trait IPeerHandshake {
3131
pub(super) struct Transport<PeerHandshakeImpl: IPeerHandshake=PeerHandshake> {
3232
pub(super) conduit: Option<Conduit>,
3333
handshake: PeerHandshakeImpl,
34-
pub(super) their_node_id: Option<PublicKey>,
34+
their_node_id: Option<PublicKey>,
3535
}
3636

3737
impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeImpl> {
@@ -160,6 +160,11 @@ impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeIm
160160
}
161161
}
162162
}
163+
164+
fn get_their_node_id(&self) -> PublicKey {
165+
assert!(self.is_connected(), "Retrieving the remote node_id is only supported after transport is connected");
166+
self.their_node_id.unwrap()
167+
}
163168
}
164169

165170
#[cfg(test)]
@@ -254,6 +259,42 @@ mod tests {
254259
assert!(transport.is_connected());
255260
}
256261

262+
// Test get_their_node_id() in unconnected and connected scenarios
263+
#[test]
264+
#[should_panic(expected = "Retrieving the remote node_id is only supported after transport is connected")]
265+
fn inbound_unconnected_get_their_node_id_panics() {
266+
let transport = create_inbound_for_test::<PeerHandshakeTestStubFail>();
267+
268+
let _should_panic = transport.get_their_node_id();
269+
}
270+
271+
#[test]
272+
#[should_panic(expected = "Retrieving the remote node_id is only supported after transport is connected")]
273+
fn outbound_unconnected_get_their_node_id_panics() {
274+
let mut transport = create_outbound_for_test::<PeerHandshakeTestStubFail>();
275+
transport.set_up_outbound();
276+
277+
let _should_panic = transport.get_their_node_id();
278+
}
279+
280+
#[test]
281+
fn inbound_unconnected_get_their_node_id() {
282+
let mut transport = create_inbound_for_test::<PeerHandshakeTestStubComplete>();
283+
let mut spy = Vec::new();
284+
285+
transport.process_input(&[], &mut spy).unwrap();
286+
let _no_panic = transport.get_their_node_id();
287+
}
288+
289+
#[test]
290+
fn outbound_unconnected_get_their_node_id() {
291+
let mut transport = create_inbound_for_test::<PeerHandshakeTestStubComplete>();
292+
let mut spy = Vec::new();
293+
294+
transport.process_input(&[], &mut spy).unwrap();
295+
let _no_panic = transport.get_their_node_id();
296+
}
297+
257298
// Test that when a handshake completes is_connected() is correct
258299
#[test]
259300
fn outbound_handshake_complete_ready_for_encryption() {
@@ -262,6 +303,7 @@ mod tests {
262303

263304
transport.process_input(&[], &mut spy).unwrap();
264305
assert!(transport.is_connected());
306+
let _no_panic = transport.get_their_node_id();
265307
}
266308

267309
#[test]

0 commit comments

Comments
 (0)