Skip to content

Commit d4d8b81

Browse files
committed
Changing sort_and_filter_channels return to ExactSizeIterator and
removing built up vector in `select_phantom_hints` - refactoring `rotate_nested_vectors` from Matt Corallo's suggestion
1 parent eee4f6a commit d4d8b81

File tree

1 file changed

+78
-67
lines changed

1 file changed

+78
-67
lines changed

lightning-invoice/src/utils.rs

+78-67
Original file line numberDiff line numberDiff line change
@@ -236,69 +236,80 @@ fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Ve
236236
where
237237
L::Target: Logger,
238238
{
239-
let mut phantom_hints: Vec<Vec<RouteHint>> = Vec::new();
239+
let mut phantom_hints: Vec<_> = Vec::new();
240240

241241
for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
242242
log_trace!(logger, "Generating phantom route hints for node {}",
243243
log_pubkey!(real_node_pubkey));
244-
let mut route_hints: Vec<RouteHint> = sort_and_filter_channels(channels, amt_msat, &logger).collect();
244+
let route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
245245

246246
// If we have any public channel, the route hints from `sort_and_filter_channels` will be
247247
// empty. In that case we create a RouteHint on which we will push a single hop with the
248248
// phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
249249
// node by looking at our public channels.
250-
if route_hints.is_empty() {
251-
route_hints.push(RouteHint(vec![]))
252-
}
253-
for route_hint in &mut route_hints {
254-
route_hint.0.push(RouteHintHop {
255-
src_node_id: real_node_pubkey,
256-
short_channel_id: phantom_scid,
257-
fees: RoutingFees {
258-
base_msat: 0,
259-
proportional_millionths: 0,
260-
},
261-
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
262-
htlc_minimum_msat: None,
263-
htlc_maximum_msat: None,});
264-
}
250+
let empty_route_hints = route_hints.len() == 0;
251+
let mut have_pushed_empty = false;
252+
let route_hints = route_hints
253+
.chain(core::iter::from_fn(move || {
254+
if empty_route_hints && !have_pushed_empty {
255+
// set flag of having handled the empty route_hints and ensure empty vector
256+
// returned only once
257+
have_pushed_empty = true;
258+
Some(RouteHint(Vec::new()))
259+
} else {
260+
None
261+
}
262+
}))
263+
.map(move |mut hint| {
264+
hint.0.push(RouteHintHop {
265+
src_node_id: real_node_pubkey,
266+
short_channel_id: phantom_scid,
267+
fees: RoutingFees {
268+
base_msat: 0,
269+
proportional_millionths: 0,
270+
},
271+
cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
272+
htlc_minimum_msat: None,
273+
htlc_maximum_msat: None,
274+
});
275+
hint
276+
});
265277

266278
phantom_hints.push(route_hints);
279+
267280
}
268281

269282
// We have one vector per real node involved in creating the phantom invoice. To distribute
270283
// the hints across our real nodes we add one hint from each in turn until no node has any hints
271284
// left (if one node has more hints than any other, these will accumulate at the end of the
272285
// vector).
273286
rotate_nested_vectors(phantom_hints)
274-
275287
}
276288

277-
// Draw items iteratively from multiple nested vectors. The items are retrieved by index and
278-
// rotates through the vectors - first the zero index then the first index then second index, etc.
279-
fn rotate_nested_vectors<T: Clone>(vecs: Vec<Vec<T>>) -> impl Iterator<Item = T> {
280-
let max_vector_length: usize = vecs.iter().map(|x| x.len()).max().unwrap();
281-
let mut hint_index = 0;
282-
let mut vector_index = 0;
283-
let number_inner_vectors: usize = vecs.len();
284-
285-
core::iter::from_fn(move || loop {
286-
if hint_index == max_vector_length {
287-
return None;
288-
};
289-
let hint_value = if vecs[vector_index].len() != 0 && vecs[vector_index].len() > hint_index {
290-
Some(vecs[vector_index][hint_index].clone())
291-
} else {
292-
None // no value retrieved - continue looping
293-
};
294-
vector_index += 1;
295-
if hint_index < max_vector_length && vector_index == number_inner_vectors {
296-
vector_index = 0;
297-
hint_index += 1;
298-
};
299-
if !hint_value.is_none() {
300-
return hint_value;
301-
};
289+
/// Draw items iteratively from multiple nested vectors. The items are retrieved by index and
290+
/// rotates through the vectors - first the zero index then the first index then second index, etc.
291+
fn rotate_nested_vectors<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl Iterator<Item = T> {
292+
let mut idx = 0;
293+
294+
core::iter::from_fn(move || {
295+
let mut exhausted_vectors = 0;
296+
loop {
297+
if vecs.is_empty() {
298+
return None;
299+
}
300+
let next_idx = idx % vecs.len();
301+
let hint_opt = vecs[next_idx].next();
302+
idx += 1;
303+
if let Some(hint) = hint_opt {
304+
return Some(hint);
305+
}
306+
// exhausted_vectors increase when the "next_idx" vector is exhausted
307+
exhausted_vectors += 1;
308+
// return None when all of the nested vectors are exhausted
309+
if exhausted_vectors > vecs.len() {
310+
return None;
311+
}
312+
}
302313
})
303314
}
304315

@@ -588,7 +599,7 @@ fn sort_and_filter_channels<L: Deref>(
588599
channels: Vec<ChannelDetails>,
589600
min_inbound_capacity_msat: Option<u64>,
590601
logger: &L,
591-
) -> impl Iterator<Item = RouteHint>
602+
) -> impl ExactSizeIterator<Item = RouteHint>
592603
where
593604
L::Target: Logger,
594605
{
@@ -1908,94 +1919,94 @@ mod test {
19081919
#[test]
19091920
fn test_zip_nested_vectors() {
19101921
// two nested vectors
1911-
let a = vec![vec!["a0", "b0", "c0"], vec!["a1", "b1"]];
1922+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1", "b1"].into_iter()];
19121923
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19131924

19141925
let expected = vec!["a0", "a1", "b0", "b1", "c0"];
19151926
assert_eq!(expected, result);
19161927

19171928
// test single nested vector
1918-
let a = vec![vec!["a0", "b0", "c0"]];
1929+
let a = vec![vec!["a0", "b0", "c0"].into_iter()];
19191930
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19201931

19211932
let expected = vec!["a0", "b0", "c0"];
19221933
assert_eq!(expected, result);
19231934

19241935
// test second vector with only one element
1925-
let a = vec![vec!["a0", "b0", "c0"], vec!["a1"]];
1936+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1"].into_iter()];
19261937
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19271938

19281939
let expected = vec!["a0", "a1", "b0", "c0"];
19291940
assert_eq!(expected, result);
19301941

19311942
// test three nestend vectors
1932-
let a = vec![vec!["a0"], vec!["a1", "b1", "c1"], vec!["a2"]];
1943+
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()];
19331944
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19341945

19351946
let expected = vec!["a0", "a1", "a2", "b1", "c1"];
19361947
assert_eq!(expected, result);
19371948

19381949
// test single nested vector with a single value
1939-
let a = vec![vec!["a0"]];
1950+
let a = vec![vec!["a0"].into_iter()];
19401951
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19411952

19421953
let expected = vec!["a0"];
19431954
assert_eq!(expected, result);
19441955

19451956
// test single empty nested vector
1946-
let a:Vec<Vec<&str>> = vec![vec![]];
1947-
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
1957+
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter()];
1958+
let result = rotate_nested_vectors(a).collect::<Vec<&str>>();
19481959
let expected:Vec<&str> = vec![];
19491960

19501961
assert_eq!(expected, result);
19511962

19521963
// test first nested vector is empty
1953-
let a = vec![vec![], vec!["a1", "b1", "c1"]];
1954-
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
1964+
let a:Vec<std::vec::IntoIter<&str>>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()];
1965+
let result = rotate_nested_vectors(a).collect::<Vec<&str>>();
19551966

19561967
let expected = vec!["a1", "b1", "c1"];
19571968
assert_eq!(expected, result);
19581969

19591970
// test two empty vectors
1960-
let a:Vec<Vec<&str>> = vec![vec![], vec![]];
1961-
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
1971+
let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter(), vec![].into_iter()];
1972+
let result = rotate_nested_vectors(a).collect::<Vec<&str>>();
19621973

19631974
let expected:Vec<&str> = vec![];
19641975
assert_eq!(expected, result);
19651976

19661977
// test an empty vector amongst other filled vectors
19671978
let a = vec![
1968-
vec!["a0", "b0", "c0"],
1969-
vec![],
1970-
vec!["a1", "b1", "c1"],
1971-
vec!["a2", "b2", "c2"],
1979+
vec!["a0", "b0", "c0"].into_iter(),
1980+
vec![].into_iter(),
1981+
vec!["a1", "b1", "c1"].into_iter(),
1982+
vec!["a2", "b2", "c2"].into_iter(),
19721983
];
19731984
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19741985

19751986
let expected = vec!["a0", "a1", "a2", "b0", "b1", "b2", "c0", "c1", "c2"];
19761987
assert_eq!(expected, result);
19771988

19781989
// test a filled vector between two empty vectors
1979-
let a = vec![vec![], vec!["a1", "b1", "c1"], vec![]];
1990+
let a = vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec![].into_iter()];
19801991
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19811992

19821993
let expected = vec!["a1", "b1", "c1"];
19831994
assert_eq!(expected, result);
19841995

19851996
// test an empty vector at the end of the vectors
1986-
let a = vec![vec!["a0", "b0", "c0"], vec![]];
1997+
let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec![].into_iter()];
19871998
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
19881999

19892000
let expected = vec!["a0", "b0", "c0"];
19902001
assert_eq!(expected, result);
19912002

19922003
// test multiple empty vectors amongst multiple filled vectors
19932004
let a = vec![
1994-
vec![],
1995-
vec!["a1", "b1", "c1"],
1996-
vec![],
1997-
vec!["a3", "b3"],
1998-
vec![],
2005+
vec![].into_iter(),
2006+
vec!["a1", "b1", "c1"].into_iter(),
2007+
vec![].into_iter(),
2008+
vec!["a3", "b3"].into_iter(),
2009+
vec![].into_iter(),
19992010
];
20002011

20012012
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
@@ -2005,7 +2016,7 @@ mod test {
20052016

20062017
// test one element in the first nested vectore and two elements in the second nested
20072018
// vector
2008-
let a = vec![vec!["a0"], vec!["a1", "b1"]];
2019+
let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1"].into_iter()];
20092020
let result = rotate_nested_vectors(a).collect::<Vec<_>>();
20102021

20112022
let expected = vec!["a0", "a1", "b1"];

0 commit comments

Comments
 (0)