Skip to content

Commit 3ab11f2

Browse files
committed
Move Conduit decryption into read()
This gets rid of the complexity required to handle Iterators that return errors and makes way for fewer copies in the decryption path.
1 parent bf76e7c commit 3ab11f2

File tree

3 files changed

+57
-152
lines changed

3 files changed

+57
-152
lines changed

lightning/src/ln/peers/conduit.rs

Lines changed: 28 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
use ln::peers::{chacha, hkdf5869rfc};
1313
use util::byte_utils;
14+
use std::collections::VecDeque;
1415

1516
pub(super) type SymmetricKey = [u8; 32];
1617

@@ -47,29 +48,14 @@ pub(super) struct Decryptor {
4748

4849
pending_message_length: Option<usize>,
4950
read_buffer: Option<Vec<u8>>,
50-
poisoned: bool, // signal an error has occurred so None is returned on iteration after failure
51+
decrypted_payloads: VecDeque<Vec<u8>>,
5152
}
5253

5354
impl Iterator for Decryptor {
54-
type Item = Result<Option<Vec<u8>>, String>;
55+
type Item = Vec<u8>;
5556

5657
fn next(&mut self) -> Option<Self::Item> {
57-
if self.poisoned {
58-
return None;
59-
}
60-
61-
match self.decrypt_single_message(None) {
62-
Ok(Some(result)) => {
63-
Some(Ok(Some(result)))
64-
},
65-
Ok(None) => {
66-
None
67-
}
68-
Err(e) => {
69-
self.poisoned = true;
70-
Some(Err(e))
71-
}
72-
}
58+
self.decrypted_payloads.pop_front()
7359
}
7460
}
7561

@@ -88,7 +74,7 @@ impl Conduit {
8874
receiving_nonce: 0,
8975
read_buffer: None,
9076
pending_message_length: None,
91-
poisoned: false
77+
decrypted_payloads: VecDeque::new(),
9278
}
9379
}
9480
}
@@ -98,7 +84,7 @@ impl Conduit {
9884
self.encryptor.encrypt(buffer)
9985
}
10086

101-
pub(super) fn read(&mut self, data: &[u8]) {
87+
pub(super) fn read(&mut self, data: &[u8]) -> Result<(), String>{
10288
self.decryptor.read(data)
10389
}
10490

@@ -152,9 +138,25 @@ impl Encryptor {
152138
}
153139

154140
impl Decryptor {
155-
pub(super) fn read(&mut self, data: &[u8]) {
156-
let read_buffer = self.read_buffer.get_or_insert(Vec::new());
157-
read_buffer.extend_from_slice(data);
141+
pub(super) fn read(&mut self, data: &[u8]) -> Result<(), String> {
142+
let mut input_data = Some(data);
143+
144+
loop {
145+
match self.decrypt_single_message(input_data) {
146+
Ok(Some(result)) => {
147+
self.decrypted_payloads.push_back(result);
148+
},
149+
Ok(None) => {
150+
break;
151+
}
152+
Err(e) => {
153+
return Err(e);
154+
}
155+
}
156+
input_data = None;
157+
}
158+
159+
Ok(())
158160
}
159161

160162
/// Decrypt a single message. If data containing more than one message has been received,
@@ -341,7 +343,7 @@ mod tests {
341343
let encrypted = remote_peer.encrypt(&[1]);
342344

343345
connected_peer.decryptor.receiving_key = [0; 32];
344-
assert_eq!(connected_peer.decrypt_single_message(Some(&encrypted)), Err("invalid hmac".to_string()));
346+
assert_eq!(connected_peer.read(&encrypted), Err("invalid hmac".to_string()));
345347
}
346348

347349
// Test next()::None
@@ -357,86 +359,9 @@ mod tests {
357359
fn decryptor_iterator_one_item_valid() {
358360
let (mut connected_peer, mut remote_peer) = setup_peers();
359361
let encrypted = remote_peer.encrypt(&[1]);
360-
connected_peer.read(&encrypted);
361-
362-
assert_eq!(connected_peer.decryptor.next(), Some(Ok(Some(vec![1]))));
363-
assert_eq!(connected_peer.decryptor.next(), None);
364-
}
365-
366-
// Test next()::err -> next()::None
367-
#[test]
368-
fn decryptor_iterator_error() {
369-
let (mut connected_peer, mut remote_peer) = setup_peers();
370-
let encrypted = remote_peer.encrypt(&[1]);
371-
connected_peer.read(&encrypted);
372-
373-
connected_peer.decryptor.receiving_key = [0; 32];
374-
assert_eq!(connected_peer.decryptor.next(), Some(Err("invalid hmac".to_string())));
375-
assert_eq!(connected_peer.decryptor.next(), None);
376-
}
377-
378-
// Test next()::Some -> next()::err -> next()::None
379-
#[test]
380-
fn decryptor_iterator_error_after_success() {
381-
let (mut connected_peer, mut remote_peer) = setup_peers();
382-
let encrypted = remote_peer.encrypt(&[1]);
383-
connected_peer.read(&encrypted);
384-
let encrypted = remote_peer.encrypt(&[2]);
385-
connected_peer.read(&encrypted);
386-
387-
assert_eq!(connected_peer.decryptor.next(), Some(Ok(Some(vec![1]))));
388-
connected_peer.decryptor.receiving_key = [0; 32];
389-
assert_eq!(connected_peer.decryptor.next(), Some(Err("invalid hmac".to_string())));
390-
assert_eq!(connected_peer.decryptor.next(), None);
391-
}
392-
393-
// Test that next()::Some -> next()::err -> next()::None
394-
// Error should poison decryptor
395-
#[test]
396-
fn decryptor_iterator_next_after_error_returns_none() {
397-
let (mut connected_peer, mut remote_peer) = setup_peers();
398-
let encrypted = remote_peer.encrypt(&[1]);
399-
connected_peer.read(&encrypted);
400-
let encrypted = remote_peer.encrypt(&[2]);
401-
connected_peer.read(&encrypted);
402-
let encrypted = remote_peer.encrypt(&[3]);
403-
connected_peer.read(&encrypted);
404-
405-
// Get one valid value
406-
assert_eq!(connected_peer.decryptor.next(), Some(Ok(Some(vec![1]))));
407-
let valid_receiving_key = connected_peer.decryptor.receiving_key;
408-
409-
// Corrupt the receiving key and ensure we get a failure
410-
connected_peer.decryptor.receiving_key = [0; 32];
411-
assert_eq!(connected_peer.decryptor.next(), Some(Err("invalid hmac".to_string())));
412-
413-
// Restore the receiving key, do a read and ensure None is returned (poisoned)
414-
connected_peer.decryptor.receiving_key = valid_receiving_key;
415-
assert_eq!(connected_peer.decryptor.next(), None);
416-
}
417-
418-
// Test next()::Some -> next()::err -> read() -> next()::None
419-
// Error should poison decryptor even after future reads
420-
#[test]
421-
fn decryptor_iterator_read_next_after_error_returns_none() {
422-
let (mut connected_peer, mut remote_peer) = setup_peers();
423-
let encrypted = remote_peer.encrypt(&[1]);
424-
connected_peer.read(&encrypted);
425-
let encrypted = remote_peer.encrypt(&[2]);
426-
connected_peer.read(&encrypted);
427-
428-
// Get one valid value
429-
assert_eq!(connected_peer.decryptor.next(), Some(Ok(Some(vec![1]))));
430-
let valid_receiving_key = connected_peer.decryptor.receiving_key;
431-
432-
// Corrupt the receiving key and ensure we get a failure
433-
connected_peer.decryptor.receiving_key = [0; 32];
434-
assert_eq!(connected_peer.decryptor.next(), Some(Err("invalid hmac".to_string())));
362+
connected_peer.read(&encrypted).unwrap();
435363

436-
// Restore the receiving key, do a read and ensure None is returned (poisoned)
437-
let encrypted = remote_peer.encrypt(&[3]);
438-
connected_peer.read(&encrypted);
439-
connected_peer.decryptor.receiving_key = valid_receiving_key;
364+
assert_eq!(connected_peer.decryptor.next(), Some(vec![1]));
440365
assert_eq!(connected_peer.decryptor.next(), None);
441366
}
442367

lightning/src/ln/peers/handshake/states.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ impl IHandshakeState for ResponderAwaitingActThreeState {
395395

396396
// Any remaining data in the read buffer would be encrypted, so transfer ownership
397397
// to the Conduit for future use.
398-
conduit.read(&input[bytes_read..]);
398+
conduit.read(&input[bytes_read..])?;
399399

400400
Ok((
401401
None,
@@ -778,7 +778,7 @@ mod test {
778778
let test_ctx = TestCtx::new();
779779
let (_act2, awaiting_act_three_state) = do_next_or_panic!(test_ctx.responder, &test_ctx.valid_act1);
780780
let mut act3 = test_ctx.valid_act3;
781-
act3.extend_from_slice(&[2; 100]);
781+
act3.extend_from_slice(&[2; 16]);
782782

783783
let (conduit, remote_pubkey) = if let (None, Complete(Some((conduit, remote_pubkey)))) = awaiting_act_three_state.next(&act3).unwrap() {
784784
(conduit, remote_pubkey)
@@ -787,7 +787,7 @@ mod test {
787787
};
788788

789789
assert_eq!(remote_pubkey, test_ctx.initiator_public_key);
790-
assert_eq!(100, conduit.decryptor.read_buffer_length());
790+
assert_eq!(16, conduit.decryptor.read_buffer_length());
791791
}
792792

793793
// Responder::AwaitingActThree -> Error (bad version bytes)

lightning/src/ln/peers/transport.rs

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeIm
8888
}
8989
}
9090
Some(ref mut conduit) => {
91-
conduit.read(input);
91+
conduit.read(input)?;
9292
Ok(false)
9393
}
9494
}
@@ -117,53 +117,33 @@ impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeIm
117117
match self.conduit {
118118
None => {}
119119
Some(ref mut conduit) => {
120-
// Using Iterators that can error requires special handling
121-
// The item returned from next() has type Option<Result<Option<Vec>, String>>
122-
// The Some wrapper is stripped for each item inside the loop
123-
// There are 3 valid match cases:
124-
// 1) Some(Ok(Some(msg_data))) => Indicates a valid decrypted msg accessed via msg_data
125-
// 2) Some(Err(_)) => Indicates an error during decryption that should be handled
126-
// 3) None -> Indicates there were no messages available to decrypt
127-
// Invalid Cases
128-
// 1) Some(Ok(None)) => Translated to None case above so users of iterators can stop correctly
129-
for msg_data_result in &mut conduit.decryptor {
130-
match msg_data_result {
131-
Ok(Some(msg_data)) => {
132-
let mut reader = ::std::io::Cursor::new(&msg_data[..]);
133-
let message_result = wire::read(&mut reader);
134-
let message = match message_result {
135-
Ok(x) => x,
136-
Err(e) => {
137-
match e {
138-
msgs::DecodeError::UnknownVersion => return Err(PeerHandleError { no_connection_possible: false }),
139-
msgs::DecodeError::UnknownRequiredFeature => {
140-
log_debug!(logger, "Got a channel/node announcement with an known required feature flag, you may want to update!");
141-
continue;
142-
}
143-
msgs::DecodeError::InvalidValue => {
144-
log_debug!(logger, "Got an invalid value while deserializing message");
145-
return Err(PeerHandleError { no_connection_possible: false });
146-
}
147-
msgs::DecodeError::ShortRead => {
148-
log_debug!(logger, "Deserialization failed due to shortness of message");
149-
return Err(PeerHandleError { no_connection_possible: false });
150-
}
151-
msgs::DecodeError::BadLengthDescriptor => return Err(PeerHandleError { no_connection_possible: false }),
152-
msgs::DecodeError::Io(_) => return Err(PeerHandleError { no_connection_possible: false }),
153-
}
154-
}
155-
};
156-
157-
received_messages.push(message);
158-
},
120+
for msg_data in &mut conduit.decryptor {
121+
let mut reader = ::std::io::Cursor::new(&msg_data[..]);
122+
let message_result = wire::read(&mut reader);
123+
let message = match message_result {
124+
Ok(x) => x,
159125
Err(e) => {
160-
log_trace!(logger, "Message decryption failed due to: {}", e);
161-
return Err(PeerHandleError { no_connection_possible: false });
162-
}
163-
Ok(None) => {
164-
panic!("Invalid behavior. Conduit iterator should never return this match.")
126+
match e {
127+
msgs::DecodeError::UnknownVersion => return Err(PeerHandleError { no_connection_possible: false }),
128+
msgs::DecodeError::UnknownRequiredFeature => {
129+
log_debug!(logger, "Got a channel/node announcement with an known required feature flag, you may want to update!");
130+
continue;
131+
}
132+
msgs::DecodeError::InvalidValue => {
133+
log_debug!(logger, "Got an invalid value while deserializing message");
134+
return Err(PeerHandleError { no_connection_possible: false });
135+
}
136+
msgs::DecodeError::ShortRead => {
137+
log_debug!(logger, "Deserialization failed due to shortness of message");
138+
return Err(PeerHandleError { no_connection_possible: false });
139+
}
140+
msgs::DecodeError::BadLengthDescriptor => return Err(PeerHandleError { no_connection_possible: false }),
141+
msgs::DecodeError::Io(_) => return Err(PeerHandleError { no_connection_possible: false }),
142+
}
165143
}
166-
}
144+
};
145+
146+
received_messages.push(message);
167147
}
168148
}
169149
}

0 commit comments

Comments
 (0)