Skip to content

Commit 486c16a

Browse files
authored
Merge pull request #2290 from upjohnc/2240_replace_vectors_with_iterators
Set return type to Iterator for functions in file: `lightning-invoice/src/utils.rs` : issue #2240
2 parents 166f326 + ec67ee7 commit 486c16a

File tree

1 file changed

+190
-55
lines changed

1 file changed

+190
-55
lines changed

lightning-invoice/src/utils.rs

+190-55
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
1818
use secp256k1::PublicKey;
1919
use core::ops::Deref;
2020
use core::time::Duration;
21+
use core::iter::Iterator;
2122

2223
/// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
2324
/// See [`PhantomKeysManager`] for more information on phantom node payments.
@@ -132,6 +133,8 @@ where
132133
)
133134
}
134135

136+
const MAX_CHANNEL_HINTS: usize = 3;
137+
135138
fn _create_phantom_invoice<ES: Deref, NS: Deref, L: Deref>(
136139
amt_msat: Option<u64>, payment_hash: Option<PaymentHash>, description: InvoiceDescription,
137140
invoice_expiry_delta_secs: u32, phantom_route_hints: Vec<PhantomRouteHints>, entropy_source: ES,
@@ -202,7 +205,8 @@ where
202205
invoice = invoice.amount_milli_satoshis(amt);
203206
}
204207

205-
for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger) {
208+
209+
for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_CHANNEL_HINTS) {
206210
invoice = invoice.private_route(route_hint);
207211
}
208212

@@ -229,36 +233,48 @@ where
229233
///
230234
/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
231235
fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Vec<PhantomRouteHints>,
232-
logger: L) -> Vec<RouteHint>
236+
logger: L) -> impl Iterator<Item = RouteHint>
233237
where
234238
L::Target: Logger,
235239
{
236-
let mut phantom_hints: Vec<Vec<RouteHint>> = Vec::new();
240+
let mut phantom_hints: Vec<_> = Vec::new();
237241

238242
for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
239243
log_trace!(logger, "Generating phantom route hints for node {}",
240244
log_pubkey!(real_node_pubkey));
241-
let mut route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
245+
let route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
242246

243247
// If we have any public channel, the route hints from `sort_and_filter_channels` will be
244248
// empty. In that case we create a RouteHint on which we will push a single hop with the
245249
// phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
246250
// node by looking at our public channels.
247-
if route_hints.is_empty() {
248-
route_hints.push(RouteHint(vec![]))
249-
}
250-
for route_hint in &mut route_hints {
251-
route_hint.0.push(RouteHintHop {
252-
src_node_id: real_node_pubkey,
253-
short_channel_id: phantom_scid,
254-
fees: RoutingFees {
255-
base_msat: 0,
256-
proportional_millionths: 0,
257-
},
258-
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
259-
htlc_minimum_msat: None,
260-
htlc_maximum_msat: None,});
261-
}
251+
let empty_route_hints = route_hints.len() == 0;
252+
let mut have_pushed_empty = false;
253+
let route_hints = route_hints
254+
.chain(core::iter::from_fn(move || {
255+
if empty_route_hints && !have_pushed_empty {
256+
// set flag of having handled the empty route_hints and ensure empty vector
257+
// returned only once
258+
have_pushed_empty = true;
259+
Some(RouteHint(Vec::new()))
260+
} else {
261+
None
262+
}
263+
}))
264+
.map(move |mut hint| {
265+
hint.0.push(RouteHintHop {
266+
src_node_id: real_node_pubkey,
267+
short_channel_id: phantom_scid,
268+
fees: RoutingFees {
269+
base_msat: 0,
270+
proportional_millionths: 0,
271+
},
272+
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
273+
htlc_minimum_msat: None,
274+
htlc_maximum_msat: None,
275+
});
276+
hint
277+
});
262278

263279
phantom_hints.push(route_hints);
264280
}
@@ -267,29 +283,34 @@ where
267283
// the hints across our real nodes we add one hint from each in turn until no node has any hints
268284
// left (if one node has more hints than any other, these will accumulate at the end of the
269285
// vector).
270-
let mut invoice_hints: Vec<RouteHint> = Vec::new();
271-
let mut hint_idx = 0;
286+
rotate_through_iterators(phantom_hints)
287+
}
272288

273-
loop {
274-
let mut remaining_hints = false;
289+
/// Draw items iteratively from multiple iterators. The items are retrieved by index and
290+
/// rotates through the iterators - first the zero index then the first index then second index, etc.
291+
fn rotate_through_iterators<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl Iterator<Item = T> {
292+
let mut iterations = 0;
275293

276-
for hints in phantom_hints.iter() {
277-
if invoice_hints.len() == 3 {
278-
return invoice_hints
294+
core::iter::from_fn(move || {
295+
let mut exhausted_iterators = 0;
296+
loop {
297+
if vecs.is_empty() {
298+
return None;
279299
}
280-
281-
if hint_idx < hints.len() {
282-
invoice_hints.push(hints[hint_idx].clone());
283-
remaining_hints = true
300+
let next_idx = iterations % vecs.len();
301+
iterations += 1;
302+
if let Some(item) = vecs[next_idx].next() {
303+
return Some(item);
304+
}
305+
// exhausted_vectors increase when the "next_idx" vector is exhausted
306+
exhausted_iterators += 1;
307+
// The check for exhausted iterators gets reset to 0 after each yield of `Some()`
308+
// The loop will return None when all of the nested iterators are exhausted
309+
if exhausted_iterators == vecs.len() {
310+
return None;
284311
}
285312
}
286-
287-
if !remaining_hints {
288-
return invoice_hints
289-
}
290-
291-
hint_idx +=1;
292-
}
313+
})
293314
}
294315

295316
#[cfg(feature = "std")]
@@ -575,15 +596,34 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
575596
/// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
576597
/// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
577598
fn sort_and_filter_channels<L: Deref>(
578-
channels: Vec<ChannelDetails>, min_inbound_capacity_msat: Option<u64>, logger: &L
579-
) -> Vec<RouteHint> where L::Target: Logger {
599+
channels: Vec<ChannelDetails>,
600+
min_inbound_capacity_msat: Option<u64>,
601+
logger: &L,
602+
) -> impl ExactSizeIterator<Item = RouteHint>
603+
where
604+
L::Target: Logger,
605+
{
580606
let mut filtered_channels: HashMap<PublicKey, ChannelDetails> = HashMap::new();
581607
let min_inbound_capacity = min_inbound_capacity_msat.unwrap_or(0);
582608
let mut min_capacity_channel_exists = false;
583609
let mut online_channel_exists = false;
584610
let mut online_min_capacity_channel_exists = false;
585611
let mut has_pub_unconf_chan = false;
586612

613+
let route_hint_from_channel = |channel: ChannelDetails| {
614+
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
615+
RouteHint(vec![RouteHintHop {
616+
src_node_id: channel.counterparty.node_id,
617+
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
618+
fees: RoutingFees {
619+
base_msat: forwarding_info.fee_base_msat,
620+
proportional_millionths: forwarding_info.fee_proportional_millionths,
621+
},
622+
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
623+
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
624+
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
625+
};
626+
587627
log_trace!(logger, "Considering {} channels for invoice route hints", channels.len());
588628
for channel in channels.into_iter().filter(|chan| chan.is_channel_ready) {
589629
if channel.get_inbound_payment_scid().is_none() || channel.counterparty.forwarding_info.is_none() {
@@ -602,7 +642,7 @@ fn sort_and_filter_channels<L: Deref>(
602642
// look at the public channels instead.
603643
log_trace!(logger, "Not including channels in invoice route hints on account of public channel {}",
604644
log_bytes!(channel.channel_id));
605-
return vec![]
645+
return vec![].into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel);
606646
}
607647
}
608648

@@ -662,19 +702,6 @@ fn sort_and_filter_channels<L: Deref>(
662702
}
663703
}
664704

665-
let route_hint_from_channel = |channel: ChannelDetails| {
666-
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
667-
RouteHint(vec![RouteHintHop {
668-
src_node_id: channel.counterparty.node_id,
669-
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
670-
fees: RoutingFees {
671-
base_msat: forwarding_info.fee_base_msat,
672-
proportional_millionths: forwarding_info.fee_proportional_millionths,
673-
},
674-
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
675-
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
676-
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
677-
};
678705
// If all channels are private, prefer to return route hints which have a higher capacity than
679706
// the payment value and where we're currently connected to the channel counterparty.
680707
// Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
@@ -724,7 +751,8 @@ fn sort_and_filter_channels<L: Deref>(
724751
} else {
725752
b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat)
726753
}});
727-
eligible_channels.into_iter().take(3).map(route_hint_from_channel).collect::<Vec<RouteHint>>()
754+
755+
eligible_channels.into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel)
728756
}
729757

730758
/// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate
@@ -777,7 +805,7 @@ mod test {
777805
use lightning::routing::router::{PaymentParameters, RouteParameters};
778806
use lightning::util::test_utils;
779807
use lightning::util::config::UserConfig;
780-
use crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch;
808+
use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators};
781809
use std::collections::HashSet;
782810

783811
#[test]
@@ -1886,4 +1914,111 @@ mod test {
18861914
_ => panic!(),
18871915
}
18881916
}
1917+
1918+
#[test]
1919+
fn test_rotate_through_iterators() {
1920+
// two nested vectors
1921+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1", "b1"].into_iter()];
1922+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1923+
1924+
let expected = vec!["a0", "a1", "b0", "b1", "c0"];
1925+
assert_eq!(expected, result);
1926+
1927+
// test single nested vector
1928+
let a = vec![vec!["a0", "b0", "c0"].into_iter()];
1929+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1930+
1931+
let expected = vec!["a0", "b0", "c0"];
1932+
assert_eq!(expected, result);
1933+
1934+
// test second vector with only one element
1935+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1"].into_iter()];
1936+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1937+
1938+
let expected = vec!["a0", "a1", "b0", "c0"];
1939+
assert_eq!(expected, result);
1940+
1941+
// test three nestend vectors
1942+
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()];
1943+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1944+
1945+
let expected = vec!["a0", "a1", "a2", "b1", "c1"];
1946+
assert_eq!(expected, result);
1947+
1948+
// test single nested vector with a single value
1949+
let a = vec![vec!["a0"].into_iter()];
1950+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1951+
1952+
let expected = vec!["a0"];
1953+
assert_eq!(expected, result);
1954+
1955+
// test single empty nested vector
1956+
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter()];
1957+
let result = rotate_through_iterators(a).collect::<Vec<&str>>();
1958+
let expected:Vec<&str> = vec![];
1959+
1960+
assert_eq!(expected, result);
1961+
1962+
// test first nested vector is empty
1963+
let a:Vec<std::vec::IntoIter<&str>>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()];
1964+
let result = rotate_through_iterators(a).collect::<Vec<&str>>();
1965+
1966+
let expected = vec!["a1", "b1", "c1"];
1967+
assert_eq!(expected, result);
1968+
1969+
// test two empty vectors
1970+
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter(), vec![].into_iter()];
1971+
let result = rotate_through_iterators(a).collect::<Vec<&str>>();
1972+
1973+
let expected:Vec<&str> = vec![];
1974+
assert_eq!(expected, result);
1975+
1976+
// test an empty vector amongst other filled vectors
1977+
let a = vec![
1978+
vec!["a0", "b0", "c0"].into_iter(),
1979+
vec![].into_iter(),
1980+
vec!["a1", "b1", "c1"].into_iter(),
1981+
vec!["a2", "b2", "c2"].into_iter(),
1982+
];
1983+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1984+
1985+
let expected = vec!["a0", "a1", "a2", "b0", "b1", "b2", "c0", "c1", "c2"];
1986+
assert_eq!(expected, result);
1987+
1988+
// test a filled vector between two empty vectors
1989+
let a = vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec![].into_iter()];
1990+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1991+
1992+
let expected = vec!["a1", "b1", "c1"];
1993+
assert_eq!(expected, result);
1994+
1995+
// test an empty vector at the end of the vectors
1996+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec![].into_iter()];
1997+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
1998+
1999+
let expected = vec!["a0", "b0", "c0"];
2000+
assert_eq!(expected, result);
2001+
2002+
// test multiple empty vectors amongst multiple filled vectors
2003+
let a = vec![
2004+
vec![].into_iter(),
2005+
vec!["a1", "b1", "c1"].into_iter(),
2006+
vec![].into_iter(),
2007+
vec!["a3", "b3"].into_iter(),
2008+
vec![].into_iter(),
2009+
];
2010+
2011+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
2012+
2013+
let expected = vec!["a1", "a3", "b1", "b3", "c1"];
2014+
assert_eq!(expected, result);
2015+
2016+
// test one element in the first nested vectore and two elements in the second nested
2017+
// vector
2018+
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1"].into_iter()];
2019+
let result = rotate_through_iterators(a).collect::<Vec<_>>();
2020+
2021+
let expected = vec!["a0", "a1", "b1"];
2022+
assert_eq!(expected, result);
2023+
}
18892024
}

0 commit comments

Comments
 (0)