Skip to content

Commit ee109e6

Browse files
committed
Refactor lightning-invoice/src/utils.rs to yield iterators
- two functions refatored: `select_phantom_hints`, `sort_and_filter_channels`
1 parent c8a2d79 commit ee109e6

File tree

1 file changed

+60
-59
lines changed

1 file changed

+60
-59
lines changed

lightning-invoice/src/utils.rs

+60-59
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ where
203203
invoice = invoice.amount_milli_satoshis(amt);
204204
}
205205

206-
for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger) {
206+
const MAX_HINTS: usize = 3;
207+
208+
for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_HINTS) {
207209
invoice = invoice.private_route(route_hint);
208210
}
209211

@@ -230,36 +232,48 @@ where
230232
///
231233
/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
232234
fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Vec<PhantomRouteHints>,
233-
logger: L) -> Vec<RouteHint>
235+
logger: L) -> impl Iterator<Item = RouteHint>
234236
where
235237
L::Target: Logger,
236238
{
237-
let mut phantom_hints: Vec<Vec<RouteHint>> = Vec::new();
239+
let mut phantom_hints: Vec<_> = Vec::new();
238240

239241
for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
240242
log_trace!(logger, "Generating phantom route hints for node {}",
241243
log_pubkey!(real_node_pubkey));
242-
let mut route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
244+
let route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
243245

244246
// If we have any public channel, the route hints from `sort_and_filter_channels` will be
245247
// empty. In that case we create a RouteHint on which we will push a single hop with the
246248
// phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
247249
// node by looking at our public channels.
248-
if route_hints.is_empty() {
249-
route_hints.push(RouteHint(vec![]))
250-
}
251-
for route_hint in &mut route_hints {
252-
route_hint.0.push(RouteHintHop {
253-
src_node_id: real_node_pubkey,
254-
short_channel_id: phantom_scid,
255-
fees: RoutingFees {
256-
base_msat: 0,
257-
proportional_millionths: 0,
258-
},
259-
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
260-
htlc_minimum_msat: None,
261-
htlc_maximum_msat: None,});
262-
}
250+
let empty_route_hints = route_hints.len() == 0;
251+
let mut have_pushed_empty = false;
252+
let route_hints = route_hints
253+
.chain(core::iter::from_fn(move || {
254+
if empty_route_hints && !have_pushed_empty {
255+
// set flag of having handled the empty route_hints and ensure empty vector
256+
// returned only once
257+
have_pushed_empty = true;
258+
Some(RouteHint(Vec::new()))
259+
} else {
260+
None
261+
}
262+
}))
263+
.map(move |mut hint| {
264+
hint.0.push(RouteHintHop {
265+
src_node_id: real_node_pubkey,
266+
short_channel_id: phantom_scid,
267+
fees: RoutingFees {
268+
base_msat: 0,
269+
proportional_millionths: 0,
270+
},
271+
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
272+
htlc_minimum_msat: None,
273+
htlc_maximum_msat: None,
274+
});
275+
hint
276+
});
263277

264278
phantom_hints.push(route_hints);
265279
}
@@ -268,29 +282,7 @@ where
268282
// the hints across our real nodes we add one hint from each in turn until no node has any hints
269283
// left (if one node has more hints than any other, these will accumulate at the end of the
270284
// vector).
271-
let mut invoice_hints: Vec<RouteHint> = Vec::new();
272-
let mut hint_idx = 0;
273-
274-
loop {
275-
let mut remaining_hints = false;
276-
277-
for hints in phantom_hints.iter() {
278-
if invoice_hints.len() == 3 {
279-
return invoice_hints
280-
}
281-
282-
if hint_idx < hints.len() {
283-
invoice_hints.push(hints[hint_idx].clone());
284-
remaining_hints = true
285-
}
286-
}
287-
288-
if !remaining_hints {
289-
return invoice_hints
290-
}
291-
292-
hint_idx +=1;
293-
}
285+
rotate_through_iterators(phantom_hints)
294286
}
295287

296288
/// Draw items iteratively from multiple iterators. The items are retrieved by index and
@@ -318,6 +310,8 @@ fn rotate_through_iterators<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl
318310
}
319311
}
320312
})
313+
}
314+
321315
#[cfg(feature = "std")]
322316
/// Utility to construct an invoice. Generally, unless you want to do something like a custom
323317
/// cltv_expiry, this is what you should be using to create an invoice. The reason being, this
@@ -601,15 +595,34 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
601595
/// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
602596
/// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
603597
fn sort_and_filter_channels<L: Deref>(
604-
channels: Vec<ChannelDetails>, min_inbound_capacity_msat: Option<u64>, logger: &L
605-
) -> Vec<RouteHint> where L::Target: Logger {
598+
channels: Vec<ChannelDetails>,
599+
min_inbound_capacity_msat: Option<u64>,
600+
logger: &L,
601+
) -> impl ExactSizeIterator<Item = RouteHint>
602+
where
603+
L::Target: Logger,
604+
{
606605
let mut filtered_channels: HashMap<PublicKey, ChannelDetails> = HashMap::new();
607606
let min_inbound_capacity = min_inbound_capacity_msat.unwrap_or(0);
608607
let mut min_capacity_channel_exists = false;
609608
let mut online_channel_exists = false;
610609
let mut online_min_capacity_channel_exists = false;
611610
let mut has_pub_unconf_chan = false;
612611

612+
let route_hint_from_channel = |channel: ChannelDetails| {
613+
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
614+
RouteHint(vec![RouteHintHop {
615+
src_node_id: channel.counterparty.node_id,
616+
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
617+
fees: RoutingFees {
618+
base_msat: forwarding_info.fee_base_msat,
619+
proportional_millionths: forwarding_info.fee_proportional_millionths,
620+
},
621+
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
622+
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
623+
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
624+
};
625+
613626
log_trace!(logger, "Considering {} channels for invoice route hints", channels.len());
614627
for channel in channels.into_iter().filter(|chan| chan.is_channel_ready) {
615628
if channel.get_inbound_payment_scid().is_none() || channel.counterparty.forwarding_info.is_none() {
@@ -628,7 +641,7 @@ fn sort_and_filter_channels<L: Deref>(
628641
// look at the public channels instead.
629642
log_trace!(logger, "Not including channels in invoice route hints on account of public channel {}",
630643
log_bytes!(channel.channel_id));
631-
return vec![]
644+
return vec![].into_iter().take(3).map(route_hint_from_channel);
632645
}
633646
}
634647

@@ -688,19 +701,6 @@ fn sort_and_filter_channels<L: Deref>(
688701
}
689702
}
690703

691-
let route_hint_from_channel = |channel: ChannelDetails| {
692-
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
693-
RouteHint(vec![RouteHintHop {
694-
src_node_id: channel.counterparty.node_id,
695-
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
696-
fees: RoutingFees {
697-
base_msat: forwarding_info.fee_base_msat,
698-
proportional_millionths: forwarding_info.fee_proportional_millionths,
699-
},
700-
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
701-
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
702-
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
703-
};
704704
// If all channels are private, prefer to return route hints which have a higher capacity than
705705
// the payment value and where we're currently connected to the channel counterparty.
706706
// Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
@@ -750,7 +750,8 @@ fn sort_and_filter_channels<L: Deref>(
750750
} else {
751751
b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat)
752752
}});
753-
eligible_channels.into_iter().take(3).map(route_hint_from_channel).collect::<Vec<RouteHint>>()
753+
754+
eligible_channels.into_iter().take(3).map(route_hint_from_channel)
754755
}
755756

756757
/// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate

0 commit comments

Comments
 (0)