Skip to content

Commit d8df37d

Browse files
committed
Use BaseEventHandler to expose async event handling on InvoicePayer
We introduce a new sealed trait BaseEventHandler that has a blanket implementation for any T. Since the trait cannot be implemented outside of the crate, this allow us to expose specific implementations of InvoicePayer that allow for synchronous and asynchronous event handling.
1 parent 67c2932 commit d8df37d

File tree

1 file changed

+51
-15
lines changed

1 file changed

+51
-15
lines changed

lightning-invoice/src/payment.rs

+51-15
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ use secp256k1::PublicKey;
157157

158158
use core::fmt;
159159
use core::fmt::{Debug, Display, Formatter};
160+
use core::future::Future;
160161
use core::ops::Deref;
161162
use core::time::Duration;
162163
#[cfg(feature = "std")]
@@ -176,8 +177,15 @@ use crate::time_utils;
176177
#[cfg(feature = "no-std")]
177178
type ConfiguredTime = time_utils::Eternity;
178179

180+
/// Sealed trait with a blanket implementation to allow both sync and async implementations of event
181+
/// handling to exist within the InvoicePayer.
182+
mod sealed {
183+
pub trait BaseEventHandler {}
184+
impl<T> BaseEventHandler for T {}
185+
}
186+
179187
/// (C-not exported) generally all users should use the [`InvoicePayer`] type alias.
180-
pub struct InvoicePayerUsingTime<P: Deref, R: ScoringRouter, L: Deref, E: EventHandler, T: Time>
188+
pub struct InvoicePayerUsingTime<P: Deref, R: ScoringRouter, L: Deref, E: sealed::BaseEventHandler, T: Time>
181189
where
182190
P::Target: Payer,
183191
L::Target: Logger,
@@ -342,7 +350,7 @@ pub enum PaymentError {
342350
Sending(PaymentSendFailure),
343351
}
344352

345-
impl<P: Deref, R: ScoringRouter, L: Deref, E: EventHandler, T: Time> InvoicePayerUsingTime<P, R, L, E, T>
353+
impl<P: Deref, R: ScoringRouter, L: Deref, E: sealed::BaseEventHandler, T: Time> InvoicePayerUsingTime<P, R, L, E, T>
346354
where
347355
P::Target: Payer,
348356
L::Target: Logger,
@@ -744,12 +752,12 @@ fn has_expired(route_params: &RouteParameters) -> bool {
744752
} else { false }
745753
}
746754

747-
impl<P: Deref, R: ScoringRouter, L: Deref, E: EventHandler, T: Time> EventHandler for InvoicePayerUsingTime<P, R, L, E, T>
755+
impl<P: Deref, R: ScoringRouter, L: Deref, E: sealed::BaseEventHandler, T: Time> InvoicePayerUsingTime<P, R, L, E, T>
748756
where
749757
P::Target: Payer,
750758
L::Target: Logger,
751759
{
752-
fn handle_event(&self, event: Event) {
760+
fn handle_event_internal(&self, event: Event) -> Option<Event> {
753761
match event {
754762
Event::PaymentPathFailed { payment_hash, ref path, .. }
755763
| Event::PaymentPathSuccessful { ref path, payment_hash: Some(payment_hash), .. }
@@ -779,7 +787,7 @@ where
779787
self.payer.abandon_payment(payment_id.unwrap());
780788
} else if self.retry_payment(payment_id.unwrap(), payment_hash, retry.as_ref().unwrap()).is_ok() {
781789
// We retried at least somewhat, don't provide the PaymentPathFailed event to the user.
782-
return;
790+
return None;
783791
} else {
784792
self.payer.abandon_payment(payment_id.unwrap());
785793
}
@@ -814,7 +822,35 @@ where
814822
}
815823

816824
// Delegate to the decorated event handler unless the payment is retried.
817-
self.event_handler.handle_event(event)
825+
Some(event)
826+
}
827+
}
828+
829+
impl<P: Deref, R: ScoringRouter, L: Deref, E: EventHandler, T: Time>
830+
EventHandler for InvoicePayerUsingTime<P, R, L, E, T>
831+
where
832+
P::Target: Payer,
833+
L::Target: Logger,
834+
{
835+
fn handle_event(&self, event: Event) {
836+
if let Some(event) = self.handle_event_internal(event) {
837+
self.event_handler.handle_event(event)
838+
}
839+
}
840+
}
841+
842+
impl<P: Deref, R: ScoringRouter, L: Deref, T: Time, F: Future, H: Fn(Event) -> F>
843+
InvoicePayerUsingTime<P, R, L, H, T>
844+
where
845+
P::Target: Payer,
846+
L::Target: Logger,
847+
{
848+
/// Intercepts events required by the [`InvoicePayer`] and forwards them to the underlying event
849+
/// handler, if necessary, to handle them asynchronously.
850+
pub async fn handle_event_async(&self, event: Event) {
851+
if let Some(event) = self.handle_event_internal(event) {
852+
(self.event_handler)(event).await;
853+
}
818854
}
819855
}
820856

@@ -980,7 +1016,7 @@ mod tests {
9801016

9811017
#[test]
9821018
fn pays_invoice_on_partial_failure() {
983-
let event_handler = |_: _| { panic!() };
1019+
let event_handler = |_: Event| { panic!() };
9841020

9851021
let payment_preimage = PaymentPreimage([1; 32]);
9861022
let invoice = invoice(payment_preimage);
@@ -1184,7 +1220,7 @@ mod tests {
11841220
#[test]
11851221
fn fails_paying_invoice_after_expiration() {
11861222
let event_handled = core::cell::RefCell::new(false);
1187-
let event_handler = |_: _| { *event_handled.borrow_mut() = true; };
1223+
let event_handler = |_: Event| { *event_handled.borrow_mut() = true; };
11881224

11891225
let payer = TestPayer::new();
11901226
let router = TestRouter::new(TestScorer::new());
@@ -1360,7 +1396,7 @@ mod tests {
13601396
let router = FailingRouter {};
13611397
let logger = TestLogger::new();
13621398
let invoice_payer =
1363-
InvoicePayer::new(&payer, router, &logger, |_: _| {}, Retry::Attempts(0));
1399+
InvoicePayer::new(&payer, router, &logger, |_: Event| {}, Retry::Attempts(0));
13641400

13651401
let payment_preimage = PaymentPreimage([1; 32]);
13661402
let invoice = invoice(payment_preimage);
@@ -1383,7 +1419,7 @@ mod tests {
13831419
let router = TestRouter::new(TestScorer::new());
13841420
let logger = TestLogger::new();
13851421
let invoice_payer =
1386-
InvoicePayer::new(&payer, router, &logger, |_: _| {}, Retry::Attempts(0));
1422+
InvoicePayer::new(&payer, router, &logger, |_: Event| {}, Retry::Attempts(0));
13871423

13881424
match invoice_payer.pay_invoice(&invoice) {
13891425
Err(PaymentError::Sending(_)) => {},
@@ -1422,7 +1458,7 @@ mod tests {
14221458
#[test]
14231459
fn fails_paying_zero_value_invoice_with_amount() {
14241460
let event_handled = core::cell::RefCell::new(false);
1425-
let event_handler = |_: _| { *event_handled.borrow_mut() = true; };
1461+
let event_handler = |_: Event| { *event_handled.borrow_mut() = true; };
14261462

14271463
let payer = TestPayer::new();
14281464
let router = TestRouter::new(TestScorer::new());
@@ -1732,7 +1768,7 @@ mod tests {
17321768
#[test]
17331769
fn accounts_for_some_inflight_htlcs_sent_during_partial_failure() {
17341770
let event_handled = core::cell::RefCell::new(false);
1735-
let event_handler = |_: _| { *event_handled.borrow_mut() = true; };
1771+
let event_handler = |_: Event| { *event_handled.borrow_mut() = true; };
17361772

17371773
let payment_preimage = PaymentPreimage([1; 32]);
17381774
let invoice_to_pay = invoice(payment_preimage);
@@ -1763,7 +1799,7 @@ mod tests {
17631799
#[test]
17641800
fn accounts_for_all_inflight_htlcs_sent_during_partial_failure() {
17651801
let event_handled = core::cell::RefCell::new(false);
1766-
let event_handler = |_: _| { *event_handled.borrow_mut() = true; };
1802+
let event_handler = |_: Event| { *event_handled.borrow_mut() = true; };
17671803

17681804
let payment_preimage = PaymentPreimage([1; 32]);
17691805
let invoice_to_pay = invoice(payment_preimage);
@@ -2260,7 +2296,7 @@ mod tests {
22602296
route.paths[1][0].fee_msat = 50_000_000;
22612297
router.expect_find_route(Ok(route.clone()));
22622298

2263-
let event_handler = |_: _| { panic!(); };
2299+
let event_handler = |_: Event| { panic!(); };
22642300
let invoice_payer = InvoicePayer::new(nodes[0].node, router, nodes[0].logger, event_handler, Retry::Attempts(1));
22652301

22662302
assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
@@ -2305,7 +2341,7 @@ mod tests {
23052341
route.paths[1][0].fee_msat = 50_000_001;
23062342
router.expect_find_route(Ok(route.clone()));
23072343

2308-
let event_handler = |_: _| { panic!(); };
2344+
let event_handler = |_: Event| { panic!(); };
23092345
let invoice_payer = InvoicePayer::new(nodes[0].node, router, nodes[0].logger, event_handler, Retry::Attempts(1));
23102346

23112347
assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(

0 commit comments

Comments
 (0)