@@ -87,8 +87,10 @@ struct NegotiationContext {
87
87
holder_is_initiator : bool ,
88
88
received_tx_add_input_count : u16 ,
89
89
received_tx_add_output_count : u16 ,
90
- /// The inputs to be contributed by the holder.
91
- inputs : HashMap < SerialId , InteractiveTxInput > ,
90
+ /// The inputs contributed by the holder
91
+ local_inputs : HashMap < SerialId , InteractiveTxInput > ,
92
+ /// The inputs contributed by the counterparty
93
+ remote_inputs : HashMap < SerialId , InteractiveTxInput > ,
92
94
/// The output intended to be the new funding output.
93
95
/// When an output added to the same pubkey, it will be treated as the shared output.
94
96
/// The script pubkey is used to discriminate which output is the funding output.
@@ -108,8 +110,10 @@ struct NegotiationContext {
108
110
/// Note: this output is also included in `outputs`.
109
111
actual_new_funding_output : Option < SharedOutput > ,
110
112
prevtx_outpoints : HashSet < OutPoint > ,
111
- /// The outputs to be contributed by the holder (excluding the funding output)
112
- outputs : HashMap < SerialId , InteractiveTxOutput > ,
113
+ /// The outputs contributed by the holder
114
+ local_outputs : HashMap < SerialId , InteractiveTxOutput > ,
115
+ /// The outputs contributed by the counterparty
116
+ remote_outputs : HashMap < SerialId , InteractiveTxOutput > ,
113
117
/// The locktime of the funding transaction.
114
118
tx_locktime : AbsoluteLockTime ,
115
119
/// The fee rate used for the transaction
@@ -129,12 +133,14 @@ impl NegotiationContext {
129
133
holder_is_initiator,
130
134
received_tx_add_input_count : 0 ,
131
135
received_tx_add_output_count : 0 ,
132
- inputs : new_hash_map ( ) ,
136
+ local_inputs : new_hash_map ( ) ,
137
+ remote_inputs : new_hash_map ( ) ,
133
138
intended_new_funding_output,
134
139
intended_local_contribution_satoshis,
135
140
actual_new_funding_output : None ,
136
141
prevtx_outpoints : new_hash_set ( ) ,
137
- outputs : new_hash_map ( ) ,
142
+ local_outputs : new_hash_map ( ) ,
143
+ remote_outputs : new_hash_map ( ) ,
138
144
tx_locktime,
139
145
feerate_sat_per_kw,
140
146
}
@@ -191,24 +197,16 @@ impl NegotiationContext {
191
197
self . holder_is_initiator == serial_id. is_for_non_initiator ( )
192
198
}
193
199
194
- fn total_input_and_output_count ( & self ) -> usize {
195
- self . inputs . len ( ) . saturating_add ( self . outputs . len ( ) )
200
+ fn total_input_count ( & self ) -> usize {
201
+ self . local_inputs . len ( ) . saturating_add ( self . remote_inputs . len ( ) )
196
202
}
197
203
198
- fn counterparty_inputs_contributed ( & self ) -> impl Iterator < Item = & InteractiveTxInput > + Clone {
199
- self . inputs
200
- . iter ( )
201
- . filter ( move |( serial_id, _) | self . is_serial_id_valid_for_counterparty ( serial_id) )
202
- . map ( |( _, input_with_prevout) | input_with_prevout)
204
+ fn total_output_count ( & self ) -> usize {
205
+ self . local_outputs . len ( ) . saturating_add ( self . remote_outputs . len ( ) )
203
206
}
204
207
205
- fn counterparty_outputs_contributed (
206
- & self ,
207
- ) -> impl Iterator < Item = & InteractiveTxOutput > + Clone {
208
- self . outputs
209
- . iter ( )
210
- . filter ( move |( serial_id, _) | self . is_serial_id_valid_for_counterparty ( serial_id) )
211
- . map ( |( _, output) | output)
208
+ fn total_input_and_output_count ( & self ) -> usize {
209
+ self . total_input_count ( ) . saturating_add ( self . total_output_count ( ) )
212
210
}
213
211
214
212
fn received_tx_add_input ( & mut self , msg : & msgs:: TxAddInput ) -> Result < ( ) , AbortReason > {
@@ -265,7 +263,7 @@ impl NegotiationContext {
265
263
}
266
264
267
265
let prev_outpoint = OutPoint { txid, vout : msg. prevtx_out } ;
268
- match self . inputs . entry ( msg. serial_id ) {
266
+ match self . remote_inputs . entry ( msg. serial_id ) {
269
267
hash_map:: Entry :: Occupied ( _) => {
270
268
// The receiving node:
271
269
// - MUST fail the negotiation if:
@@ -303,7 +301,7 @@ impl NegotiationContext {
303
301
return Err ( AbortReason :: IncorrectSerialIdParity ) ;
304
302
}
305
303
306
- self . inputs
304
+ self . remote_inputs
307
305
. remove ( & msg. serial_id )
308
306
// The receiving node:
309
307
// - MUST fail the negotiation if:
@@ -339,7 +337,7 @@ impl NegotiationContext {
339
337
// Check that adding this output would not cause the total output value to exceed the total
340
338
// bitcoin supply.
341
339
let mut outputs_value: u64 = 0 ;
342
- for output in self . outputs . iter ( ) {
340
+ for output in self . local_outputs . iter ( ) . chain ( self . remote_outputs . iter ( ) ) {
343
341
outputs_value = outputs_value. saturating_add ( output. 1 . value ( ) ) ;
344
342
}
345
343
if outputs_value. saturating_add ( msg. sats ) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
@@ -377,7 +375,7 @@ impl NegotiationContext {
377
375
} else {
378
376
InteractiveTxOutput :: Remote ( RemoteOutput { serial_id : msg. serial_id , txout } )
379
377
} ;
380
- match self . outputs . entry ( msg. serial_id ) {
378
+ match self . remote_outputs . entry ( msg. serial_id ) {
381
379
hash_map:: Entry :: Occupied ( _) => {
382
380
// The receiving node:
383
381
// - MUST fail the negotiation if:
@@ -395,7 +393,7 @@ impl NegotiationContext {
395
393
if !self . is_serial_id_valid_for_counterparty ( & msg. serial_id ) {
396
394
return Err ( AbortReason :: IncorrectSerialIdParity ) ;
397
395
}
398
- if let Some ( _) = self . outputs . remove ( & msg. serial_id ) {
396
+ if let Some ( _) = self . remote_outputs . remove ( & msg. serial_id ) {
399
397
Ok ( ( ) )
400
398
} else {
401
399
// The receiving node:
@@ -430,7 +428,7 @@ impl NegotiationContext {
430
428
. ok_or ( AbortReason :: PrevTxOutInvalid ) ?
431
429
. value ,
432
430
} ) ;
433
- self . inputs . insert ( msg. serial_id , input) ;
431
+ self . local_inputs . insert ( msg. serial_id , input) ;
434
432
Ok ( ( ) )
435
433
}
436
434
@@ -443,17 +441,17 @@ impl NegotiationContext {
443
441
} else {
444
442
InteractiveTxOutput :: Local ( LocalOutput { serial_id : msg. serial_id , txout } )
445
443
} ;
446
- self . outputs . insert ( msg. serial_id , output) ;
444
+ self . local_outputs . insert ( msg. serial_id , output) ;
447
445
Ok ( ( ) )
448
446
}
449
447
450
448
fn sent_tx_remove_input ( & mut self , msg : & msgs:: TxRemoveInput ) -> Result < ( ) , AbortReason > {
451
- self . inputs . remove ( & msg. serial_id ) ;
449
+ self . local_inputs . remove ( & msg. serial_id ) ;
452
450
Ok ( ( ) )
453
451
}
454
452
455
453
fn sent_tx_remove_output ( & mut self , msg : & msgs:: TxRemoveOutput ) -> Result < ( ) , AbortReason > {
456
- self . outputs . remove ( & msg. serial_id ) ;
454
+ self . local_outputs . remove ( & msg. serial_id ) ;
457
455
Ok ( ( ) )
458
456
}
459
457
@@ -464,11 +462,19 @@ impl NegotiationContext {
464
462
// - the peer's total input satoshis with its part of any shared input is less than their outputs
465
463
// and proportion of any shared output
466
464
let mut counterparty_value_in: u64 = 0 ;
467
- for ( _, input) in & self . inputs {
465
+ // Consider remote and local also, due to possible shared inputs
466
+ for ( _, input) in & self . remote_inputs {
467
+ counterparty_value_in = counterparty_value_in. saturating_add ( input. remote_value ( ) ) ;
468
+ }
469
+ for ( _, input) in & self . local_inputs {
468
470
counterparty_value_in = counterparty_value_in. saturating_add ( input. remote_value ( ) ) ;
469
471
}
470
472
let mut counterparty_value_out: u64 = 0 ;
471
- for ( _, output) in & self . outputs {
473
+ // Consider both local and remote, due to possible shared inputs
474
+ for ( _, output) in & self . remote_outputs {
475
+ counterparty_value_out = counterparty_value_out. saturating_add ( output. remote_value ( ) ) ;
476
+ }
477
+ for ( _, output) in & self . local_outputs {
472
478
counterparty_value_out = counterparty_value_out. saturating_add ( output. remote_value ( ) ) ;
473
479
}
474
480
if counterparty_value_in < counterparty_value_out {
@@ -477,8 +483,8 @@ impl NegotiationContext {
477
483
478
484
// - there are more than 252 inputs
479
485
// - there are more than 252 outputs
480
- if self . inputs . len ( ) > MAX_INPUTS_OUTPUTS_COUNT
481
- || self . outputs . len ( ) > MAX_INPUTS_OUTPUTS_COUNT
486
+ if self . total_input_count ( ) > MAX_INPUTS_OUTPUTS_COUNT
487
+ || self . total_output_count ( ) > MAX_INPUTS_OUTPUTS_COUNT
482
488
{
483
489
return Err ( AbortReason :: ExceededNumberOfInputsOrOutputs ) ;
484
490
}
@@ -488,14 +494,14 @@ impl NegotiationContext {
488
494
489
495
// - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
490
496
let mut counterparty_weight_contributed: u64 = self
491
- . counterparty_outputs_contributed ( )
492
- . map ( |output| {
497
+ . remote_outputs
498
+ . iter ( )
499
+ . map ( |( _, output) | {
493
500
( 8 /* value */ + output. script_pubkey ( ) . consensus_encode ( & mut sink ( ) ) . unwrap ( ) as u64 )
494
501
* WITNESS_SCALE_FACTOR as u64
495
502
} )
496
503
. sum ( ) ;
497
- counterparty_weight_contributed +=
498
- self . counterparty_inputs_contributed ( ) . count ( ) as u64 * INPUT_WEIGHT ;
504
+ counterparty_weight_contributed += self . remote_inputs . len ( ) as u64 * INPUT_WEIGHT ;
499
505
let counterparty_fees_contributed =
500
506
counterparty_value_in. saturating_sub ( counterparty_value_out) ;
501
507
let mut required_counterparty_contribution_fee =
@@ -516,8 +522,13 @@ impl NegotiationContext {
516
522
}
517
523
518
524
// Inputs and outputs must be sorted by serial_id
519
- let mut inputs = self . inputs . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
520
- let mut outputs = self . outputs . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
525
+ let mut inputs =
526
+ self . local_inputs . into_iter ( ) . chain ( self . remote_inputs . into_iter ( ) ) . collect :: < Vec < _ > > ( ) ;
527
+ let mut outputs = self
528
+ . local_outputs
529
+ . into_iter ( )
530
+ . chain ( self . remote_outputs . into_iter ( ) )
531
+ . collect :: < Vec < _ > > ( ) ;
521
532
inputs. sort_unstable_by_key ( |( serial_id, _) | * serial_id) ;
522
533
outputs. sort_unstable_by_key ( |( serial_id, _) | * serial_id) ;
523
534
0 commit comments