Skip to content

Commit ae57869

Browse files
committed
Store local and remote inputs/output in separate vectors
1 parent 50dce76 commit ae57869

File tree

1 file changed

+50
-39
lines changed

1 file changed

+50
-39
lines changed

lightning/src/ln/interactivetxs.rs

+50-39
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ struct NegotiationContext {
8787
holder_is_initiator: bool,
8888
received_tx_add_input_count: u16,
8989
received_tx_add_output_count: u16,
90-
/// The inputs to be contributed by the holder.
91-
inputs: HashMap<SerialId, InteractiveTxInput>,
90+
/// The inputs contributed by the holder
91+
local_inputs: HashMap<SerialId, InteractiveTxInput>,
92+
/// The inputs contributed by the counterparty
93+
remote_inputs: HashMap<SerialId, InteractiveTxInput>,
9294
/// The output intended to be the new funding output.
9395
/// When an output added to the same pubkey, it will be treated as the shared output.
9496
/// The script pubkey is used to discriminate which output is the funding output.
@@ -108,8 +110,10 @@ struct NegotiationContext {
108110
/// Note: this output is also included in `outputs`.
109111
actual_new_funding_output: Option<SharedOutput>,
110112
prevtx_outpoints: HashSet<OutPoint>,
111-
/// The outputs to be contributed by the holder (excluding the funding output)
112-
outputs: HashMap<SerialId, InteractiveTxOutput>,
113+
/// The outputs contributed by the holder
114+
local_outputs: HashMap<SerialId, InteractiveTxOutput>,
115+
/// The outputs contributed by the counterparty
116+
remote_outputs: HashMap<SerialId, InteractiveTxOutput>,
113117
/// The locktime of the funding transaction.
114118
tx_locktime: AbsoluteLockTime,
115119
/// The fee rate used for the transaction
@@ -129,12 +133,14 @@ impl NegotiationContext {
129133
holder_is_initiator,
130134
received_tx_add_input_count: 0,
131135
received_tx_add_output_count: 0,
132-
inputs: new_hash_map(),
136+
local_inputs: new_hash_map(),
137+
remote_inputs: new_hash_map(),
133138
intended_new_funding_output,
134139
intended_local_contribution_satoshis,
135140
actual_new_funding_output: None,
136141
prevtx_outpoints: new_hash_set(),
137-
outputs: new_hash_map(),
142+
local_outputs: new_hash_map(),
143+
remote_outputs: new_hash_map(),
138144
tx_locktime,
139145
feerate_sat_per_kw,
140146
}
@@ -191,24 +197,16 @@ impl NegotiationContext {
191197
self.holder_is_initiator == serial_id.is_for_non_initiator()
192198
}
193199

194-
fn total_input_and_output_count(&self) -> usize {
195-
self.inputs.len().saturating_add(self.outputs.len())
200+
fn total_input_count(&self) -> usize {
201+
self.local_inputs.len().saturating_add(self.remote_inputs.len())
196202
}
197203

198-
fn counterparty_inputs_contributed(&self) -> impl Iterator<Item = &InteractiveTxInput> + Clone {
199-
self.inputs
200-
.iter()
201-
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
202-
.map(|(_, input_with_prevout)| input_with_prevout)
204+
fn total_output_count(&self) -> usize {
205+
self.local_outputs.len().saturating_add(self.remote_outputs.len())
203206
}
204207

205-
fn counterparty_outputs_contributed(
206-
&self,
207-
) -> impl Iterator<Item = &InteractiveTxOutput> + Clone {
208-
self.outputs
209-
.iter()
210-
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
211-
.map(|(_, output)| output)
208+
fn total_input_and_output_count(&self) -> usize {
209+
self.total_input_count().saturating_add(self.total_output_count())
212210
}
213211

214212
fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
@@ -265,7 +263,7 @@ impl NegotiationContext {
265263
}
266264

267265
let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
268-
match self.inputs.entry(msg.serial_id) {
266+
match self.remote_inputs.entry(msg.serial_id) {
269267
hash_map::Entry::Occupied(_) => {
270268
// The receiving node:
271269
// - MUST fail the negotiation if:
@@ -303,7 +301,7 @@ impl NegotiationContext {
303301
return Err(AbortReason::IncorrectSerialIdParity);
304302
}
305303

306-
self.inputs
304+
self.remote_inputs
307305
.remove(&msg.serial_id)
308306
// The receiving node:
309307
// - MUST fail the negotiation if:
@@ -339,7 +337,7 @@ impl NegotiationContext {
339337
// Check that adding this output would not cause the total output value to exceed the total
340338
// bitcoin supply.
341339
let mut outputs_value: u64 = 0;
342-
for output in self.outputs.iter() {
340+
for output in self.local_outputs.iter().chain(self.remote_outputs.iter()) {
343341
outputs_value = outputs_value.saturating_add(output.1.value());
344342
}
345343
if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
@@ -377,7 +375,7 @@ impl NegotiationContext {
377375
} else {
378376
InteractiveTxOutput::Remote(RemoteOutput { serial_id: msg.serial_id, txout })
379377
};
380-
match self.outputs.entry(msg.serial_id) {
378+
match self.remote_outputs.entry(msg.serial_id) {
381379
hash_map::Entry::Occupied(_) => {
382380
// The receiving node:
383381
// - MUST fail the negotiation if:
@@ -395,7 +393,7 @@ impl NegotiationContext {
395393
if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
396394
return Err(AbortReason::IncorrectSerialIdParity);
397395
}
398-
if let Some(_) = self.outputs.remove(&msg.serial_id) {
396+
if let Some(_) = self.remote_outputs.remove(&msg.serial_id) {
399397
Ok(())
400398
} else {
401399
// The receiving node:
@@ -430,7 +428,7 @@ impl NegotiationContext {
430428
.ok_or(AbortReason::PrevTxOutInvalid)?
431429
.value,
432430
});
433-
self.inputs.insert(msg.serial_id, input);
431+
self.local_inputs.insert(msg.serial_id, input);
434432
Ok(())
435433
}
436434

@@ -443,17 +441,17 @@ impl NegotiationContext {
443441
} else {
444442
InteractiveTxOutput::Local(LocalOutput { serial_id: msg.serial_id, txout })
445443
};
446-
self.outputs.insert(msg.serial_id, output);
444+
self.local_outputs.insert(msg.serial_id, output);
447445
Ok(())
448446
}
449447

450448
fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
451-
self.inputs.remove(&msg.serial_id);
449+
self.local_inputs.remove(&msg.serial_id);
452450
Ok(())
453451
}
454452

455453
fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
456-
self.outputs.remove(&msg.serial_id);
454+
self.local_outputs.remove(&msg.serial_id);
457455
Ok(())
458456
}
459457

@@ -464,11 +462,19 @@ impl NegotiationContext {
464462
// - the peer's total input satoshis with its part of any shared input is less than their outputs
465463
// and proportion of any shared output
466464
let mut counterparty_value_in: u64 = 0;
467-
for (_, input) in &self.inputs {
465+
// Consider remote and local also, due to possible shared inputs
466+
for (_, input) in &self.remote_inputs {
467+
counterparty_value_in = counterparty_value_in.saturating_add(input.remote_value());
468+
}
469+
for (_, input) in &self.local_inputs {
468470
counterparty_value_in = counterparty_value_in.saturating_add(input.remote_value());
469471
}
470472
let mut counterparty_value_out: u64 = 0;
471-
for (_, output) in &self.outputs {
473+
// Consider both local and remote, due to possible shared inputs
474+
for (_, output) in &self.remote_outputs {
475+
counterparty_value_out = counterparty_value_out.saturating_add(output.remote_value());
476+
}
477+
for (_, output) in &self.local_outputs {
472478
counterparty_value_out = counterparty_value_out.saturating_add(output.remote_value());
473479
}
474480
if counterparty_value_in < counterparty_value_out {
@@ -477,8 +483,8 @@ impl NegotiationContext {
477483

478484
// - there are more than 252 inputs
479485
// - there are more than 252 outputs
480-
if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT
481-
|| self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT
486+
if self.total_input_count() > MAX_INPUTS_OUTPUTS_COUNT
487+
|| self.total_output_count() > MAX_INPUTS_OUTPUTS_COUNT
482488
{
483489
return Err(AbortReason::ExceededNumberOfInputsOrOutputs);
484490
}
@@ -488,14 +494,14 @@ impl NegotiationContext {
488494

489495
// - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
490496
let mut counterparty_weight_contributed: u64 = self
491-
.counterparty_outputs_contributed()
492-
.map(|output| {
497+
.remote_outputs
498+
.iter()
499+
.map(|(_, output)| {
493500
(8 /* value */ + output.script_pubkey().consensus_encode(&mut sink()).unwrap() as u64)
494501
* WITNESS_SCALE_FACTOR as u64
495502
})
496503
.sum();
497-
counterparty_weight_contributed +=
498-
self.counterparty_inputs_contributed().count() as u64 * INPUT_WEIGHT;
504+
counterparty_weight_contributed += self.remote_inputs.len() as u64 * INPUT_WEIGHT;
499505
let counterparty_fees_contributed =
500506
counterparty_value_in.saturating_sub(counterparty_value_out);
501507
let mut required_counterparty_contribution_fee =
@@ -516,8 +522,13 @@ impl NegotiationContext {
516522
}
517523

518524
// Inputs and outputs must be sorted by serial_id
519-
let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
520-
let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
525+
let mut inputs =
526+
self.local_inputs.into_iter().chain(self.remote_inputs.into_iter()).collect::<Vec<_>>();
527+
let mut outputs = self
528+
.local_outputs
529+
.into_iter()
530+
.chain(self.remote_outputs.into_iter())
531+
.collect::<Vec<_>>();
521532
inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
522533
outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
523534

0 commit comments

Comments
 (0)