Skip to content

Set return type to Iterator for functions in file: lightning-invoice/src/utils.rs : issue #2240 #2290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 190 additions & 55 deletions lightning-invoice/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
use secp256k1::PublicKey;
use core::ops::Deref;
use core::time::Duration;
use core::iter::Iterator;

/// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
/// See [`PhantomKeysManager`] for more information on phantom node payments.
Expand Down Expand Up @@ -132,6 +133,8 @@ where
)
}

const MAX_CHANNEL_HINTS: usize = 3;

fn _create_phantom_invoice<ES: Deref, NS: Deref, L: Deref>(
amt_msat: Option<u64>, payment_hash: Option<PaymentHash>, description: InvoiceDescription,
invoice_expiry_delta_secs: u32, phantom_route_hints: Vec<PhantomRouteHints>, entropy_source: ES,
Expand Down Expand Up @@ -202,7 +205,8 @@ where
invoice = invoice.amount_milli_satoshis(amt);
}

for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger) {

for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_CHANNEL_HINTS) {
invoice = invoice.private_route(route_hint);
}

Expand All @@ -229,36 +233,48 @@ where
///
/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Vec<PhantomRouteHints>,
logger: L) -> Vec<RouteHint>
logger: L) -> impl Iterator<Item = RouteHint>
where
L::Target: Logger,
{
let mut phantom_hints: Vec<Vec<RouteHint>> = Vec::new();
let mut phantom_hints: Vec<_> = Vec::new();

for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
log_trace!(logger, "Generating phantom route hints for node {}",
log_pubkey!(real_node_pubkey));
let mut route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
let route_hints = sort_and_filter_channels(channels, amt_msat, &logger);

// If we have any public channel, the route hints from `sort_and_filter_channels` will be
// empty. In that case we create a RouteHint on which we will push a single hop with the
// phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
// node by looking at our public channels.
if route_hints.is_empty() {
route_hints.push(RouteHint(vec![]))
}
for route_hint in &mut route_hints {
route_hint.0.push(RouteHintHop {
src_node_id: real_node_pubkey,
short_channel_id: phantom_scid,
fees: RoutingFees {
base_msat: 0,
proportional_millionths: 0,
},
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
htlc_minimum_msat: None,
htlc_maximum_msat: None,});
}
let empty_route_hints = route_hints.len() == 0;
let mut have_pushed_empty = false;
let route_hints = route_hints
.chain(core::iter::from_fn(move || {
if empty_route_hints && !have_pushed_empty {
// set flag of having handled the empty route_hints and ensure empty vector
// returned only once
have_pushed_empty = true;
Some(RouteHint(Vec::new()))
} else {
None
}
}))
.map(move |mut hint| {
hint.0.push(RouteHintHop {
src_node_id: real_node_pubkey,
short_channel_id: phantom_scid,
fees: RoutingFees {
base_msat: 0,
proportional_millionths: 0,
},
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
htlc_minimum_msat: None,
htlc_maximum_msat: None,
});
hint
});

phantom_hints.push(route_hints);
}
Expand All @@ -267,29 +283,34 @@ where
// the hints across our real nodes we add one hint from each in turn until no node has any hints
// left (if one node has more hints than any other, these will accumulate at the end of the
// vector).
let mut invoice_hints: Vec<RouteHint> = Vec::new();
let mut hint_idx = 0;
rotate_through_iterators(phantom_hints)
}

loop {
let mut remaining_hints = false;
/// Draw items iteratively from multiple iterators. The items are retrieved by index and
/// rotates through the iterators - first the zero index then the first index then second index, etc.
fn rotate_through_iterators<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl Iterator<Item = T> {
let mut iterations = 0;

for hints in phantom_hints.iter() {
if invoice_hints.len() == 3 {
return invoice_hints
core::iter::from_fn(move || {
let mut exhausted_iterators = 0;
loop {
if vecs.is_empty() {
return None;
}

if hint_idx < hints.len() {
invoice_hints.push(hints[hint_idx].clone());
remaining_hints = true
let next_idx = iterations % vecs.len();
iterations += 1;
if let Some(item) = vecs[next_idx].next() {
return Some(item);
}
// exhausted_vectors increase when the "next_idx" vector is exhausted
exhausted_iterators += 1;
// The check for exhausted iterators gets reset to 0 after each yield of `Some()`
// The loop will return None when all of the nested iterators are exhausted
if exhausted_iterators == vecs.len() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work if, for example, there's one iterator with zero elements and another iterator with 5? On each loop iteration/function call we check the first iterator, see its empty, increment exhausted_iterators, and then return Some for the next iterator. Once we've done that once we'll return None, even though we have more elements from another iterator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we yield Some, the next call will reset exhausted_iterators to zero since it is a local variable, IIUC.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops indeed, you're right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add to the comment to help with the explanation of how the loop function. Each of us has needed to think through it, so maybe a comment with help with future reading of the code.

return None;
}
}

if !remaining_hints {
return invoice_hints
}

hint_idx +=1;
}
})
}

#[cfg(feature = "std")]
Expand Down Expand Up @@ -575,15 +596,34 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
/// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
/// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
fn sort_and_filter_channels<L: Deref>(
channels: Vec<ChannelDetails>, min_inbound_capacity_msat: Option<u64>, logger: &L
) -> Vec<RouteHint> where L::Target: Logger {
channels: Vec<ChannelDetails>,
min_inbound_capacity_msat: Option<u64>,
logger: &L,
) -> impl ExactSizeIterator<Item = RouteHint>
where
L::Target: Logger,
{
let mut filtered_channels: HashMap<PublicKey, ChannelDetails> = HashMap::new();
let min_inbound_capacity = min_inbound_capacity_msat.unwrap_or(0);
let mut min_capacity_channel_exists = false;
let mut online_channel_exists = false;
let mut online_min_capacity_channel_exists = false;
let mut has_pub_unconf_chan = false;

let route_hint_from_channel = |channel: ChannelDetails| {
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
RouteHint(vec![RouteHintHop {
src_node_id: channel.counterparty.node_id,
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
fees: RoutingFees {
base_msat: forwarding_info.fee_base_msat,
proportional_millionths: forwarding_info.fee_proportional_millionths,
},
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
};

log_trace!(logger, "Considering {} channels for invoice route hints", channels.len());
for channel in channels.into_iter().filter(|chan| chan.is_channel_ready) {
if channel.get_inbound_payment_scid().is_none() || channel.counterparty.forwarding_info.is_none() {
Expand All @@ -602,7 +642,7 @@ fn sort_and_filter_channels<L: Deref>(
// look at the public channels instead.
log_trace!(logger, "Not including channels in invoice route hints on account of public channel {}",
log_bytes!(channel.channel_id));
return vec![]
return vec![].into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel);
}
}

Expand Down Expand Up @@ -662,19 +702,6 @@ fn sort_and_filter_channels<L: Deref>(
}
}

let route_hint_from_channel = |channel: ChannelDetails| {
let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
RouteHint(vec![RouteHintHop {
src_node_id: channel.counterparty.node_id,
short_channel_id: channel.get_inbound_payment_scid().unwrap(),
fees: RoutingFees {
base_msat: forwarding_info.fee_base_msat,
proportional_millionths: forwarding_info.fee_proportional_millionths,
},
cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
};
// If all channels are private, prefer to return route hints which have a higher capacity than
// the payment value and where we're currently connected to the channel counterparty.
// Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
Expand Down Expand Up @@ -724,7 +751,8 @@ fn sort_and_filter_channels<L: Deref>(
} else {
b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat)
}});
eligible_channels.into_iter().take(3).map(route_hint_from_channel).collect::<Vec<RouteHint>>()

eligible_channels.into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel)
}

/// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate
Expand Down Expand Up @@ -777,7 +805,7 @@ mod test {
use lightning::routing::router::{PaymentParameters, RouteParameters};
use lightning::util::test_utils;
use lightning::util::config::UserConfig;
use crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch;
use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators};
use std::collections::HashSet;

#[test]
Expand Down Expand Up @@ -1886,4 +1914,111 @@ mod test {
_ => panic!(),
}
}

#[test]
fn test_rotate_through_iterators() {
// two nested vectors
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1", "b1"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "a1", "b0", "b1", "c0"];
assert_eq!(expected, result);

// test single nested vector
let a = vec![vec!["a0", "b0", "c0"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "b0", "c0"];
assert_eq!(expected, result);

// test second vector with only one element
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "a1", "b0", "c0"];
assert_eq!(expected, result);

// test three nestend vectors
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "a1", "a2", "b1", "c1"];
assert_eq!(expected, result);

// test single nested vector with a single value
let a = vec![vec!["a0"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0"];
assert_eq!(expected, result);

// test single empty nested vector
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<&str>>();
let expected:Vec<&str> = vec![];

assert_eq!(expected, result);

// test first nested vector is empty
let a:Vec<std::vec::IntoIter<&str>>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<&str>>();

let expected = vec!["a1", "b1", "c1"];
assert_eq!(expected, result);

// test two empty vectors
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter(), vec![].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<&str>>();

let expected:Vec<&str> = vec![];
assert_eq!(expected, result);

// test an empty vector amongst other filled vectors
let a = vec![
vec!["a0", "b0", "c0"].into_iter(),
vec![].into_iter(),
vec!["a1", "b1", "c1"].into_iter(),
vec!["a2", "b2", "c2"].into_iter(),
];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "a1", "a2", "b0", "b1", "b2", "c0", "c1", "c2"];
assert_eq!(expected, result);

// test a filled vector between two empty vectors
let a = vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec![].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a1", "b1", "c1"];
assert_eq!(expected, result);

// test an empty vector at the end of the vectors
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec![].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "b0", "c0"];
assert_eq!(expected, result);

// test multiple empty vectors amongst multiple filled vectors
let a = vec![
vec![].into_iter(),
vec!["a1", "b1", "c1"].into_iter(),
vec![].into_iter(),
vec!["a3", "b3"].into_iter(),
vec![].into_iter(),
];

let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a1", "a3", "b1", "b3", "c1"];
assert_eq!(expected, result);

// test one element in the first nested vectore and two elements in the second nested
// vector
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1"].into_iter()];
let result = rotate_through_iterators(a).collect::<Vec<_>>();

let expected = vec!["a0", "a1", "b1"];
assert_eq!(expected, result);
}
}