@@ -32,19 +32,15 @@ pub(super) struct Decryptor {
32
32
receiving_nonce : u32 ,
33
33
34
34
pending_message_length : Option < usize > ,
35
- read_buffer : Option < Vec < u8 > > ,
35
+ read_buffer : Vec < u8 > ,
36
36
poisoned : bool , // signal an error has occurred so None is returned on iteration after failure
37
37
}
38
38
39
39
impl Iterator for Decryptor {
40
40
type Item = Result < Option < Vec < u8 > > , String > ;
41
41
42
42
fn next ( & mut self ) -> Option < Self :: Item > {
43
- if self . poisoned {
44
- return None ;
45
- }
46
-
47
- match self . decrypt_single_message ( None ) {
43
+ match self . decrypt_single_message ( & [ ] ) {
48
44
Ok ( Some ( result) ) => {
49
45
Some ( Ok ( Some ( result) ) )
50
46
} ,
@@ -72,7 +68,7 @@ impl Conduit {
72
68
receiving_key,
73
69
receiving_chaining_key : chaining_key,
74
70
receiving_nonce : 0 ,
75
- read_buffer : None ,
71
+ read_buffer : Vec :: new ( ) ,
76
72
pending_message_length : None ,
77
73
poisoned : false
78
74
}
@@ -92,8 +88,9 @@ impl Conduit {
92
88
/// only the first message will be returned, and the rest stored in the internal buffer.
93
89
/// If a message pending in the buffer still hasn't been decrypted, that message will be
94
90
/// returned in lieu of anything new, even if new data is provided.
91
+ /// After a failure, all calls will return Ok(None)
95
92
#[ cfg( any( test, feature = "fuzztarget" ) ) ]
96
- pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Result < Option < Vec < u8 > > , String > {
93
+ pub fn decrypt_single_message ( & mut self , new_data : & [ u8 ] ) -> Result < Option < Vec < u8 > > , String > {
97
94
Ok ( self . decryptor . decrypt_single_message ( new_data) ?)
98
95
}
99
96
@@ -135,44 +132,32 @@ impl Encryptor {
135
132
136
133
impl Decryptor {
137
134
pub ( super ) fn read ( & mut self , data : & [ u8 ] ) {
138
- let read_buffer = self . read_buffer . get_or_insert ( Vec :: new ( ) ) ;
139
- read_buffer. extend_from_slice ( data) ;
135
+ self . read_buffer . extend ( data) ;
140
136
}
141
137
142
138
/// Decrypt a single message. If data containing more than one message has been received,
143
139
/// only the first message will be returned, and the rest stored in the internal buffer.
144
140
/// If a message pending in the buffer still hasn't been decrypted, that message will be
145
141
/// returned in lieu of anything new, even if new data is provided.
146
- pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Result < Option < Vec < u8 > > , String > {
147
- let mut read_buffer = if let Some ( buffer) = self . read_buffer . take ( ) {
148
- buffer
149
- } else {
150
- Vec :: new ( )
151
- } ;
152
-
153
- if let Some ( data) = new_data {
154
- read_buffer. extend_from_slice ( data) ;
142
+ /// After a failure, all calls will return Ok(None)
143
+ pub fn decrypt_single_message ( & mut self , new_data : & [ u8 ] ) -> Result < Option < Vec < u8 > > , String > {
144
+ if self . poisoned {
145
+ return Ok ( None ) ;
155
146
}
156
147
157
- let ( current_message, offset) = self . decrypt ( & read_buffer[ ..] ) ?;
158
- read_buffer. drain ( ..offset) ; // drain the read buffer
159
- self . read_buffer = Some ( read_buffer) ; // assign the new value to the built-in buffer
160
- Ok ( current_message)
161
- }
148
+ self . read ( new_data) ;
162
149
163
- fn decrypt ( & mut self , buffer : & [ u8 ] ) -> Result < ( Option < Vec < u8 > > , usize ) , String > {
164
150
let message_length = if let Some ( length) = self . pending_message_length {
165
151
// we have already decrypted the header
166
152
length
167
153
} else {
168
- if buffer . len ( ) < TAGGED_MESSAGE_LENGTH_HEADER_SIZE {
154
+ if self . read_buffer . len ( ) < TAGGED_MESSAGE_LENGTH_HEADER_SIZE {
169
155
// A message must be at least 18 bytes (2 for encrypted length, 16 for the tag)
170
- return Ok ( ( None , 0 ) ) ;
156
+ return Ok ( None ) ;
171
157
}
172
158
173
- let encrypted_length = & buffer[ 0 ..TAGGED_MESSAGE_LENGTH_HEADER_SIZE ] ;
174
159
let mut length_bytes = [ 0u8 ; MESSAGE_LENGTH_HEADER_SIZE ] ;
175
- chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_length , & mut length_bytes) ?;
160
+ chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , & self . read_buffer [ .. TAGGED_MESSAGE_LENGTH_HEADER_SIZE ] , & mut length_bytes) ?;
176
161
177
162
self . increment_nonce ( ) ;
178
163
@@ -182,21 +167,22 @@ impl Decryptor {
182
167
183
168
let message_end_index = TAGGED_MESSAGE_LENGTH_HEADER_SIZE + message_length + chacha:: TAG_SIZE ;
184
169
185
- if buffer . len ( ) < message_end_index {
170
+ if self . read_buffer . len ( ) < message_end_index {
186
171
self . pending_message_length = Some ( message_length) ;
187
- return Ok ( ( None , 0 ) ) ;
172
+ return Ok ( None ) ;
188
173
}
189
174
190
175
self . pending_message_length = None ;
191
176
192
- let encrypted_message = & buffer[ TAGGED_MESSAGE_LENGTH_HEADER_SIZE ..message_end_index] ;
193
177
let mut message = vec ! [ 0u8 ; message_length] ;
194
178
195
- chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_message , & mut message) ?;
179
+ chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , & self . read_buffer [ TAGGED_MESSAGE_LENGTH_HEADER_SIZE ..message_end_index ] , & mut message) ?;
196
180
197
181
self . increment_nonce ( ) ;
198
182
199
- Ok ( ( Some ( message) , message_end_index) )
183
+ self . read_buffer . drain ( ..message_end_index) ;
184
+
185
+ Ok ( Some ( message) )
200
186
}
201
187
202
188
fn increment_nonce ( & mut self ) {
@@ -207,10 +193,7 @@ impl Decryptor {
207
193
// infrastructure to properly encode it
208
194
#[ cfg( test) ]
209
195
pub fn read_buffer_length ( & self ) -> usize {
210
- match & self . read_buffer {
211
- & Some ( ref vec) => { vec. len ( ) }
212
- & None => 0
213
- }
196
+ self . read_buffer . len ( )
214
197
}
215
198
}
216
199
@@ -247,7 +230,7 @@ mod tests {
247
230
let encrypted_message = connected_peer. encrypt ( & message) ;
248
231
assert_eq ! ( encrypted_message. len( ) , 2 + 16 + 16 ) ;
249
232
250
- let decrypted_message = remote_peer. decrypt_single_message ( Some ( & encrypted_message) ) . unwrap ( ) . unwrap ( ) ;
233
+ let decrypted_message = remote_peer. decrypt_single_message ( & encrypted_message) . unwrap ( ) . unwrap ( ) ;
251
234
assert_eq ! ( decrypted_message, vec![ ] ) ;
252
235
}
253
236
@@ -300,13 +283,13 @@ mod tests {
300
283
let mut current_encrypted_message = encrypted_messages. remove ( 0 ) ;
301
284
let next_encrypted_message = encrypted_messages. remove ( 0 ) ;
302
285
current_encrypted_message. extend_from_slice ( & next_encrypted_message) ;
303
- let decrypted_message = remote_peer. decrypt_single_message ( Some ( & current_encrypted_message) ) . unwrap ( ) . unwrap ( ) ;
286
+ let decrypted_message = remote_peer. decrypt_single_message ( & current_encrypted_message) . unwrap ( ) . unwrap ( ) ;
304
287
assert_eq ! ( decrypted_message, message) ;
305
288
}
306
289
307
290
for _ in 0 ..501 {
308
291
// decrypt messages directly from buffer without adding to it
309
- let decrypted_message = remote_peer. decrypt_single_message ( None ) . unwrap ( ) . unwrap ( ) ;
292
+ let decrypted_message = remote_peer. decrypt_single_message ( & [ ] ) . unwrap ( ) . unwrap ( ) ;
310
293
assert_eq ! ( decrypted_message, message) ;
311
294
}
312
295
}
@@ -318,7 +301,7 @@ mod tests {
318
301
let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
319
302
320
303
connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
321
- assert_eq ! ( connected_peer. decrypt_single_message( Some ( & encrypted) ) , Err ( "invalid hmac" . to_string( ) ) ) ;
304
+ assert_eq ! ( connected_peer. decrypt_single_message( & encrypted) , Err ( "invalid hmac" . to_string( ) ) ) ;
322
305
}
323
306
324
307
// Test next()::None
0 commit comments