Skip to content

Commit 6ffc4b3

Browse files
committed
Exhaustively handle adding/removing inputs and outputs
1 parent 2643fce commit 6ffc4b3

File tree

1 file changed

+76
-21
lines changed

1 file changed

+76
-21
lines changed

lightning/src/ln/interactivetxs.rs

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -262,27 +262,50 @@ impl<S> InteractiveTxStateMachine<S>
262262
}
263263
}
264264

265-
pub(crate) fn receive_tx_abort(mut self) -> InteractiveTxStateMachine<NegotiationAborted> {
266-
todo!();
265+
fn receive_tx_remove_output(mut self, serial_id: SerialId) ->
266+
Result<InteractiveTxStateMachine<Negotiating>, InteractiveTxStateMachine<NegotiationAborted>> {
267+
if !self.is_valid_counterparty_serial_id(serial_id) {
268+
return self.abort_negotiation(AbortReason::IncorrectSerialIdParity);
269+
}
270+
271+
if let Some(output) = self.context.outputs.remove(&serial_id) {
272+
Ok(InteractiveTxStateMachine { context: self.context, state: Negotiating {} })
273+
} else {
274+
self.abort_negotiation(AbortReason::SerialIdUnknown)
275+
}
267276
}
268277

269-
fn send_tx_add_input(mut self, serial_id: u64, input: TxIn) -> InteractiveTxStateMachine<Negotiating> {
278+
pub(crate) fn send_tx_add_input(mut self, serial_id: u64, input: TxIn) -> InteractiveTxStateMachine<Negotiating> {
270279
self.context.inputs.insert(serial_id, input);
271280
InteractiveTxStateMachine { context: self.context, state: Negotiating {} }
272281
}
273282

274-
pub(crate) fn send_tx_add_output(mut self, serial_id: u64, output: TxOut) -> InteractiveTxStateMachine<Negotiating> {
283+
pub(crate) fn send_tx_add_output(mut self, serial_id: SerialId, output: TxOut) -> InteractiveTxStateMachine<Negotiating> {
275284
self.context.outputs.insert(serial_id, output);
276285
InteractiveTxStateMachine { context: self.context, state: Negotiating {} }
277286
}
278287

288+
pub(crate) fn send_tx_remove_input(mut self, serial_id: SerialId) -> InteractiveTxStateMachine<Negotiating> {
289+
self.context.inputs.remove(&serial_id);
290+
InteractiveTxStateMachine { context: self.context, state: Negotiating {} }
291+
}
292+
293+
pub(crate) fn send_tx_remove_output(mut self, serial_id: SerialId) -> InteractiveTxStateMachine<Negotiating> {
294+
self.context.outputs.remove(&serial_id);
295+
InteractiveTxStateMachine { context: self.context, state: Negotiating {} }
296+
}
297+
279298
pub(crate) fn send_tx_abort(mut self) -> InteractiveTxStateMachine<NegotiationAborted> {
280299
// A sending node:
281300
// - MUST NOT have already transmitted tx_signatures
282301
// - SHOULD forget the current negotiation and reset their state.
283302
todo!();
284303
}
285304

305+
pub(crate) fn receive_tx_abort(mut self) -> InteractiveTxStateMachine<NegotiationAborted> {
306+
todo!();
307+
}
308+
286309
fn is_valid_counterparty_serial_id(&self, serial_id: SerialId) -> bool {
287310
// A received `SerialId`'s parity must match the role of the counterparty.
288311
self.context.holder_is_initiator == !serial_id.is_valid_for_initiator()
@@ -358,7 +381,43 @@ impl InteractiveTxConstructor {
358381
}
359382

360383
// Functions that handle the case where mode is [`ChannelMode::Negotiating`]
361-
fn handle_negotiating<F>(&mut self, f: F)
384+
fn abort_negotation(&mut self, reason: AbortReason) {
385+
self.handle_negotiating_receive(|state_machine| state_machine.abort_negotiation(reason))
386+
}
387+
388+
fn receive_tx_add_input(&mut self, serial_id: SerialId, transaction_input: TxAddInput, confirmed: bool) {
389+
self.handle_negotiating_receive(|state_machine| state_machine.receive_tx_add_input(serial_id, transaction_input, confirmed))
390+
}
391+
392+
fn receive_tx_remove_input(&mut self, serial_id: SerialId) {
393+
self.handle_negotiating_receive(|state_machine| state_machine.receive_tx_remove_input(serial_id))
394+
}
395+
396+
fn receive_tx_add_output(&mut self, serial_id: SerialId, output: TxOut) {
397+
self.handle_negotiating_receive(|state_machine| state_machine.receive_tx_add_output(serial_id, output))
398+
}
399+
400+
fn receive_tx_remove_output(&mut self, serial_id: SerialId) {
401+
self.handle_negotiating_receive(|state_machine| state_machine.receive_tx_remove_output(serial_id))
402+
}
403+
404+
fn send_tx_add_input(&mut self, serial_id: SerialId, transaction_input: TxIn) {
405+
self.handle_negotiating_send(|state_machine| state_machine.send_tx_add_input(serial_id, transaction_input))
406+
}
407+
408+
fn send_tx_remove_input(&mut self, serial_id: SerialId) {
409+
self.handle_negotiating_send(|state_machine| state_machine.send_tx_remove_input(serial_id))
410+
}
411+
412+
fn send_tx_add_output(&mut self, serial_id: SerialId, transaction_output: TxOut) {
413+
self.handle_negotiating_send(|state_machine| state_machine.send_tx_add_output(serial_id, transaction_output))
414+
}
415+
416+
fn send_tx_remove_output(&mut self, serial_id: SerialId) {
417+
self.handle_negotiating_send(|state_machine| state_machine.send_tx_remove_output(serial_id))
418+
}
419+
420+
fn handle_negotiating_receive<F>(&mut self, f: F)
362421
where F: FnOnce(InteractiveTxStateMachine<Negotiating>) -> Result<InteractiveTxStateMachine<Negotiating>, InteractiveTxStateMachine<NegotiationAborted>> {
363422
// We use mem::take here because we want to update `self.mode` based on its value and
364423
// avoid cloning `ChannelMode`.
@@ -372,23 +431,19 @@ impl InteractiveTxConstructor {
372431
} else {
373432
mode
374433
}
375-
376434
}
377435

378-
fn abort_negotation(&mut self, reason: AbortReason) {
379-
self.handle_negotiating(|state_machine| state_machine.abort_negotiation(reason))
380-
}
381-
382-
fn add_tx_input(&mut self, serial_id: SerialId, transaction_input: TxAddInput, confirmed: bool) {
383-
self.handle_negotiating(|state_machine| state_machine.receive_tx_add_input(serial_id, transaction_input, confirmed))
384-
}
385-
386-
fn remove_tx_input(&mut self, serial_id: SerialId) {
387-
self.handle_negotiating(|state_machine| state_machine.receive_tx_remove_input(serial_id))
388-
}
389-
390-
fn add_tx_output(&mut self, serial_id: SerialId, output: TxOut) {
391-
self.handle_negotiating(|state_machine| state_machine.receive_tx_add_output(serial_id, output))
436+
fn handle_negotiating_send<F>(&mut self, f: F)
437+
where F: FnOnce(InteractiveTxStateMachine<Negotiating>) -> InteractiveTxStateMachine<Negotiating> {
438+
// We use mem::take here because we want to update `self.mode` based on its value and
439+
// avoid cloning `ChannelMode`.
440+
// By moving the value out of the struct, we can now safely modify it in this scope.
441+
let mut mode = core::mem::take(&mut self.mode);
442+
self.mode = if let ChannelMode::Negotiating(constructor) = mode {
443+
ChannelMode::Negotiating(f(constructor))
444+
} else {
445+
mode
446+
}
392447
}
393448
}
394449

@@ -431,7 +486,7 @@ mod tests {
431486
}
432487

433488
fn handle_add_tx_input(&mut self) {
434-
self.tx_constructor.add_tx_input(1234, get_sample_tx_add_input(), true)
489+
self.tx_constructor.receive_tx_add_input(1234, get_sample_tx_add_input(), true)
435490
}
436491
}
437492

0 commit comments

Comments
 (0)