Skip to content

Commit 1d0c8de

Browse files
committed
htlcswitch: continue threading context through
Here, we address one `context.TODO` in the `memoryMailBox` by adding a context to the FailAdd method. The rest of the PR ensures that we thread context to any call to FailAdd. We add another `context.TODO` in the `interceptedForward` Fail method.
1 parent c6022e4 commit 1d0c8de

File tree

6 files changed

+108
-64
lines changed

6 files changed

+108
-64
lines changed

htlcswitch/interceptable_switch.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -751,9 +751,11 @@ func (f *interceptedForward) ResumeModified(
751751
// Fail notifies the intention to Fail an existing hold forward with an
752752
// encrypted failure reason.
753753
func (f *interceptedForward) Fail(reason []byte) error {
754+
ctx := context.TODO()
755+
754756
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)
755757

756-
return f.resolve(&lnwire.UpdateFailHTLC{
758+
return f.resolve(ctx, &lnwire.UpdateFailHTLC{
757759
Reason: obfuscatedReason,
758760
})
759761
}
@@ -824,24 +826,29 @@ func (f *interceptedForward) FailWithCode(ctx context.Context,
824826
return fmt.Errorf("failed to encrypt failure reason %w", err)
825827
}
826828

827-
return f.resolve(&lnwire.UpdateFailHTLC{
829+
return f.resolve(ctx, &lnwire.UpdateFailHTLC{
828830
Reason: reason,
829831
})
830832
}
831833

832834
// Settle forwards a settled packet to the switch.
833835
func (f *interceptedForward) Settle(preimage lntypes.Preimage) error {
836+
ctx := context.TODO()
837+
834838
if !preimage.Matches(f.htlc.PaymentHash) {
835839
return errors.New("preimage does not match hash")
836840
}
837-
return f.resolve(&lnwire.UpdateFulfillHTLC{
841+
842+
return f.resolve(ctx, &lnwire.UpdateFulfillHTLC{
838843
PaymentPreimage: preimage,
839844
})
840845
}
841846

842847
// resolve is used for both Settle and Fail and forwards the message to the
843848
// switch.
844-
func (f *interceptedForward) resolve(message lnwire.Message) error {
849+
func (f *interceptedForward) resolve(ctx context.Context,
850+
message lnwire.Message) error {
851+
845852
pkt := &htlcPacket{
846853
incomingChanID: f.packet.incomingChanID,
847854
incomingHTLCID: f.packet.incomingHTLCID,
@@ -853,5 +860,8 @@ func (f *interceptedForward) resolve(message lnwire.Message) error {
853860
obfuscator: f.packet.obfuscator,
854861
sourceRef: f.packet.sourceRef,
855862
}
856-
return f.htlcSwitch.mailOrchestrator.Deliver(pkt.incomingChanID, pkt)
863+
864+
return f.htlcSwitch.mailOrchestrator.Deliver(
865+
ctx, pkt.incomingChanID, pkt,
866+
)
857867
}

htlcswitch/link.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -1695,7 +1695,7 @@ func (l *channelLink) handleDownstreamUpdateAdd(ctx context.Context,
16951695
// already sent Stfu, then we can't add new htlcs to the link and we
16961696
// need to bounce it.
16971697
if l.IsFlushing(Outgoing) || !l.quiescer.CanSendUpdates() {
1698-
l.mailBox.FailAdd(pkt)
1698+
l.mailBox.FailAdd(ctx, pkt)
16991699

17001700
return NewDetailedLinkError(
17011701
&lnwire.FailTemporaryChannelFailure{},
@@ -1718,7 +1718,7 @@ func (l *channelLink) handleDownstreamUpdateAdd(ctx context.Context,
17181718
l.log.Debugf("Unable to handle downstream HTLC - max fee " +
17191719
"exposure exceeded")
17201720

1721-
l.mailBox.FailAdd(pkt)
1721+
l.mailBox.FailAdd(ctx, pkt)
17221722

17231723
return NewDetailedLinkError(
17241724
lnwire.NewTemporaryChannelFailure(nil),
@@ -1752,7 +1752,7 @@ func (l *channelLink) handleDownstreamUpdateAdd(ctx context.Context,
17521752
// the switch, since the circuit was never fully opened,
17531753
// and the forwarding package shows it as
17541754
// unacknowledged.
1755-
l.mailBox.FailAdd(pkt)
1755+
l.mailBox.FailAdd(ctx, pkt)
17561756

17571757
return NewDetailedLinkError(
17581758
lnwire.NewTemporaryChannelFailure(nil),

htlcswitch/mailbox.go

+37-19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/lightningnetwork/lnd/clock"
13+
"github.com/lightningnetwork/lnd/fn/v2"
1314
"github.com/lightningnetwork/lnd/lntypes"
1415
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
1516
"github.com/lightningnetwork/lnd/lnwire"
@@ -53,7 +54,7 @@ type MailBox interface {
5354
// packet from being delivered after the link restarts if the switch has
5455
// remained online. The generated LinkError will show an
5556
// OutgoingFailureDownstreamHtlcAdd FailureDetail.
56-
FailAdd(pkt *htlcPacket)
57+
FailAdd(ctx context.Context, pkt *htlcPacket)
5758

5859
// MessageOutBox returns a channel that any new messages ready for
5960
// delivery will be sent on.
@@ -82,7 +83,7 @@ type MailBox interface {
8283

8384
// Start starts the mailbox and any goroutines it needs to operate
8485
// properly.
85-
Start()
86+
Start(ctx context.Context)
8687

8788
// Stop signals the mailbox and its goroutines for a graceful shutdown.
8889
Stop()
@@ -148,7 +149,9 @@ type memoryMailBox struct {
148149

149150
wireShutdown chan struct{}
150151
pktShutdown chan struct{}
151-
quit chan struct{}
152+
153+
cancel fn.Option[context.CancelFunc]
154+
quit chan struct{}
152155

153156
// feeRate is set when the link receives or sends out fee updates. It
154157
// is refreshed when AttachMailBox is called in case a fee update did
@@ -205,10 +208,13 @@ const (
205208
// Start starts the mailbox and any goroutines it needs to operate properly.
206209
//
207210
// NOTE: This method is part of the MailBox interface.
208-
func (m *memoryMailBox) Start() {
211+
func (m *memoryMailBox) Start(ctx context.Context) {
209212
m.started.Do(func() {
210-
go m.wireMailCourier()
211-
go m.pktMailCourier()
213+
ctx, cancel := context.WithCancel(ctx)
214+
m.cancel = fn.Some(cancel)
215+
216+
go m.wireMailCourier(ctx)
217+
go m.pktMailCourier(ctx)
212218
})
213219
}
214220

@@ -324,6 +330,7 @@ func (m *memoryMailBox) HasPacket(inKey CircuitKey) bool {
324330
// NOTE: This method is part of the MailBox interface.
325331
func (m *memoryMailBox) Stop() {
326332
m.stopped.Do(func() {
333+
m.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
327334
close(m.quit)
328335

329336
m.signalUntilShutdown(wireCourier)
@@ -372,7 +379,7 @@ func (p *pktWithExpiry) deadline(clock clock.Clock) <-chan time.Time {
372379

373380
// wireMailCourier is a dedicated goroutine whose job is to reliably deliver
374381
// wire messages.
375-
func (m *memoryMailBox) wireMailCourier() {
382+
func (m *memoryMailBox) wireMailCourier(ctx context.Context) {
376383
defer close(m.wireShutdown)
377384

378385
for {
@@ -389,6 +396,9 @@ func (m *memoryMailBox) wireMailCourier() {
389396
case <-m.quit:
390397
m.wireCond.L.Unlock()
391398
return
399+
case <-ctx.Done():
400+
m.wireCond.L.Unlock()
401+
return
392402
default:
393403
}
394404
}
@@ -418,13 +428,15 @@ func (m *memoryMailBox) wireMailCourier() {
418428
close(msgDone)
419429
case <-m.quit:
420430
return
431+
case <-ctx.Done():
432+
return
421433
}
422434
}
423435
}
424436

425437
// pktMailCourier is a dedicated goroutine whose job is to reliably deliver
426438
// packet messages.
427-
func (m *memoryMailBox) pktMailCourier() {
439+
func (m *memoryMailBox) pktMailCourier(ctx context.Context) {
428440
defer close(m.pktShutdown)
429441

430442
for {
@@ -447,6 +459,11 @@ func (m *memoryMailBox) pktMailCourier() {
447459
case <-m.quit:
448460
m.pktCond.L.Unlock()
449461
return
462+
463+
case <-ctx.Done():
464+
m.pktCond.L.Unlock()
465+
return
466+
450467
default:
451468
}
452469
}
@@ -541,7 +558,7 @@ func (m *memoryMailBox) pktMailCourier() {
541558
case <-deadline:
542559
log.Debugf("Expiring add htlc with "+
543560
"keystone=%v", add.keystone())
544-
m.FailAdd(add)
561+
m.FailAdd(ctx, add)
545562

546563
case pktDone := <-m.pktReset:
547564
m.pktCond.L.Lock()
@@ -553,6 +570,9 @@ func (m *memoryMailBox) pktMailCourier() {
553570

554571
case <-m.quit:
555572
return
573+
574+
case <-ctx.Done():
575+
return
556576
}
557577
}
558578
}
@@ -688,9 +708,7 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
688708
// delivered after the link restarts if the switch has remained online. The
689709
// generated LinkError will show an OutgoingFailureDownstreamHtlcAdd
690710
// FailureDetail.
691-
func (m *memoryMailBox) FailAdd(pkt *htlcPacket) {
692-
ctx := context.TODO()
693-
711+
func (m *memoryMailBox) FailAdd(ctx context.Context, pkt *htlcPacket) {
694712
// First, remove the packet from mailbox. If we didn't find the packet
695713
// because it has already been acked, we'll exit early to avoid sending
696714
// a duplicate fail message through the switch.
@@ -844,8 +862,8 @@ func (mo *mailOrchestrator) Stop() {
844862

845863
// GetOrCreateMailBox returns an existing mailbox belonging to `chanID`, or
846864
// creates and returns a new mailbox if none is found.
847-
func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID,
848-
shortChanID lnwire.ShortChannelID) MailBox {
865+
func (mo *mailOrchestrator) GetOrCreateMailBox(ctx context.Context,
866+
chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID) MailBox {
849867

850868
// First, try lookup the mailbox directly using only the shared mutex.
851869
mo.mu.RLock()
@@ -859,7 +877,7 @@ func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID,
859877
// Otherwise, we will try again with exclusive lock, creating a mailbox
860878
// if one still has not been created.
861879
mo.mu.Lock()
862-
mailbox = mo.exclusiveGetOrCreateMailBox(chanID, shortChanID)
880+
mailbox = mo.exclusiveGetOrCreateMailBox(ctx, chanID, shortChanID)
863881
mo.mu.Unlock()
864882

865883
return mailbox
@@ -870,7 +888,7 @@ func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID,
870888
// recorded.
871889
//
872890
// NOTE: This method MUST be invoked with the mailOrchestrator's exclusive lock.
873-
func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox(
891+
func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox(ctx context.Context,
874892
chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID) MailBox {
875893

876894
mailbox, ok := mo.mailboxes[chanID]
@@ -882,7 +900,7 @@ func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox(
882900
expiry: mo.cfg.expiry,
883901
failMailboxUpdate: mo.cfg.failMailboxUpdate,
884902
})
885-
mailbox.Start()
903+
mailbox.Start(ctx)
886904
mo.mailboxes[chanID] = mailbox
887905
}
888906

@@ -916,7 +934,7 @@ func (mo *mailOrchestrator) BindLiveShortChanID(mailbox MailBox,
916934
// to channel_id. If the mailbox is found, the message is delivered directly.
917935
// Otherwise the packet is recorded as unclaimed, and will be delivered to the
918936
// mailbox upon the subsequent call to BindLiveShortChanID.
919-
func (mo *mailOrchestrator) Deliver(
937+
func (mo *mailOrchestrator) Deliver(ctx context.Context,
920938
sid lnwire.ShortChannelID, pkt *htlcPacket) error {
921939

922940
var (
@@ -961,7 +979,7 @@ func (mo *mailOrchestrator) Deliver(
961979
// index should only be set if the mailbox had been initialized
962980
// beforehand. However, this does ensure that this case is
963981
// handled properly in the event that it could happen.
964-
mailbox = mo.exclusiveGetOrCreateMailBox(chanID, sid)
982+
mailbox = mo.exclusiveGetOrCreateMailBox(ctx, chanID, sid)
965983
mo.mu.Unlock()
966984

967985
// Deliver the packet to the mailbox if it was found or created.

htlcswitch/mailbox_test.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func newMailboxContextWithClock(t *testing.T,
218218
forwardPackets: ctx.forward,
219219
clock: clock,
220220
})
221-
ctx.mailbox.Start()
221+
ctx.mailbox.Start(context.Background())
222222
t.Cleanup(ctx.mailbox.Stop)
223223

224224
return ctx
@@ -245,7 +245,7 @@ func newMailboxContext(t *testing.T, startTime time.Time,
245245
clock: ctx.clock,
246246
expiry: expiry,
247247
})
248-
ctx.mailbox.Start()
248+
ctx.mailbox.Start(context.Background())
249249
t.Cleanup(ctx.mailbox.Stop)
250250

251251
return ctx
@@ -334,6 +334,7 @@ func (c *mailboxContext) checkFails(adds []*htlcPacket) {
334334
// TestMailBoxFailAdd asserts that FailAdd returns a response to the switch
335335
// under various interleavings with other operations on the mailbox.
336336
func TestMailBoxFailAdd(t *testing.T) {
337+
t.Parallel()
337338
var (
338339
batchDelay = time.Second
339340
expiry = time.Minute
@@ -346,7 +347,7 @@ func TestMailBoxFailAdd(t *testing.T) {
346347

347348
failAdds := func(adds []*htlcPacket) {
348349
for _, add := range adds {
349-
ctx.mailbox.FailAdd(add)
350+
ctx.mailbox.FailAdd(context.Background(), add)
350351
}
351352
}
352353

@@ -541,7 +542,7 @@ func TestMailBoxDuplicateAddPacket(t *testing.T) {
541542
t.Parallel()
542543

543544
ctx := newMailboxContext(t, time.Now(), testExpiry)
544-
ctx.mailbox.Start()
545+
ctx.mailbox.Start(context.Background())
545546

546547
addTwice := func(t *testing.T, pkt *htlcPacket) {
547548
// The first add should succeed.
@@ -697,6 +698,7 @@ func testMailBoxDust(t *testing.T, chantype channeldb.ChannelType) {
697698
// readily to mailboxes for channels that are already in the live state.
698699
func TestMailOrchestrator(t *testing.T) {
699700
t.Parallel()
701+
ctx := context.Background()
700702

701703
failMailboxUpdate := func(_ context.Context, outScid,
702704
mboxScid lnwire.ShortChannelID) lnwire.FailureMessage {
@@ -738,11 +740,11 @@ func TestMailOrchestrator(t *testing.T) {
738740
}
739741
sentPackets[i] = pkt
740742

741-
mo.Deliver(pkt.outgoingChanID, pkt)
743+
mo.Deliver(ctx, pkt.outgoingChanID, pkt)
742744
}
743745

744746
// Now, initialize a new mailbox for Alice's chanid.
745-
mailbox := mo.GetOrCreateMailBox(chanID1, aliceChanID)
747+
mailbox := mo.GetOrCreateMailBox(ctx, chanID1, aliceChanID)
746748

747749
// Verify that no messages are received, since Alice's mailbox has not
748750
// been made live.
@@ -787,7 +789,7 @@ func TestMailOrchestrator(t *testing.T) {
787789

788790
// For the second half of the test, create a new mailbox for Bob and
789791
// immediately make it live with an assigned short chan id.
790-
mailbox = mo.GetOrCreateMailBox(chanID2, bobChanID)
792+
mailbox = mo.GetOrCreateMailBox(ctx, chanID2, bobChanID)
791793
mo.BindLiveShortChanID(mailbox, chanID2, bobChanID)
792794

793795
// Create the second half of our htlcs, and deliver them via the
@@ -806,7 +808,7 @@ func TestMailOrchestrator(t *testing.T) {
806808
}
807809
sentPackets[i] = pkt
808810

809-
mo.Deliver(pkt.incomingChanID, pkt)
811+
mo.Deliver(ctx, pkt.incomingChanID, pkt)
810812

811813
timeout := time.After(50 * time.Millisecond)
812814
select {

0 commit comments

Comments
 (0)