@@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
18
18
use secp256k1:: PublicKey ;
19
19
use core:: ops:: Deref ;
20
20
use core:: time:: Duration ;
21
+ use core:: iter:: Iterator ;
21
22
22
23
/// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
23
24
/// See [`PhantomKeysManager`] for more information on phantom node payments.
@@ -132,6 +133,8 @@ where
132
133
)
133
134
}
134
135
136
+ const MAX_CHANNEL_HINTS : usize = 3 ;
137
+
135
138
fn _create_phantom_invoice < ES : Deref , NS : Deref , L : Deref > (
136
139
amt_msat : Option < u64 > , payment_hash : Option < PaymentHash > , description : InvoiceDescription ,
137
140
invoice_expiry_delta_secs : u32 , phantom_route_hints : Vec < PhantomRouteHints > , entropy_source : ES ,
@@ -202,7 +205,8 @@ where
202
205
invoice = invoice. amount_milli_satoshis ( amt) ;
203
206
}
204
207
205
- for route_hint in select_phantom_hints ( amt_msat, phantom_route_hints, logger) {
208
+
209
+ for route_hint in select_phantom_hints ( amt_msat, phantom_route_hints, logger) . take ( MAX_CHANNEL_HINTS ) {
206
210
invoice = invoice. private_route ( route_hint) ;
207
211
}
208
212
@@ -229,36 +233,48 @@ where
229
233
///
230
234
/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
231
235
fn select_phantom_hints < L : Deref > ( amt_msat : Option < u64 > , phantom_route_hints : Vec < PhantomRouteHints > ,
232
- logger : L ) -> Vec < RouteHint >
236
+ logger : L ) -> impl Iterator < Item = RouteHint >
233
237
where
234
238
L :: Target : Logger ,
235
239
{
236
- let mut phantom_hints: Vec < Vec < RouteHint > > = Vec :: new ( ) ;
240
+ let mut phantom_hints: Vec < _ > = Vec :: new ( ) ;
237
241
238
242
for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
239
243
log_trace ! ( logger, "Generating phantom route hints for node {}" ,
240
244
log_pubkey!( real_node_pubkey) ) ;
241
- let mut route_hints = sort_and_filter_channels ( channels, amt_msat, & logger) ;
245
+ let route_hints = sort_and_filter_channels ( channels, amt_msat, & logger) ;
242
246
243
247
// If we have any public channel, the route hints from `sort_and_filter_channels` will be
244
248
// empty. In that case we create a RouteHint on which we will push a single hop with the
245
249
// phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
246
250
// node by looking at our public channels.
247
- if route_hints. is_empty ( ) {
248
- route_hints. push ( RouteHint ( vec ! [ ] ) )
249
- }
250
- for route_hint in & mut route_hints {
251
- route_hint. 0 . push ( RouteHintHop {
252
- src_node_id : real_node_pubkey,
253
- short_channel_id : phantom_scid,
254
- fees : RoutingFees {
255
- base_msat : 0 ,
256
- proportional_millionths : 0 ,
257
- } ,
258
- cltv_expiry_delta : MIN_CLTV_EXPIRY_DELTA ,
259
- htlc_minimum_msat : None ,
260
- htlc_maximum_msat : None , } ) ;
261
- }
251
+ let empty_route_hints = route_hints. len ( ) == 0 ;
252
+ let mut have_pushed_empty = false ;
253
+ let route_hints = route_hints
254
+ . chain ( core:: iter:: from_fn ( move || {
255
+ if empty_route_hints && !have_pushed_empty {
256
+ // set flag of having handled the empty route_hints and ensure empty vector
257
+ // returned only once
258
+ have_pushed_empty = true ;
259
+ Some ( RouteHint ( Vec :: new ( ) ) )
260
+ } else {
261
+ None
262
+ }
263
+ } ) )
264
+ . map ( move |mut hint| {
265
+ hint. 0 . push ( RouteHintHop {
266
+ src_node_id : real_node_pubkey,
267
+ short_channel_id : phantom_scid,
268
+ fees : RoutingFees {
269
+ base_msat : 0 ,
270
+ proportional_millionths : 0 ,
271
+ } ,
272
+ cltv_expiry_delta : MIN_CLTV_EXPIRY_DELTA ,
273
+ htlc_minimum_msat : None ,
274
+ htlc_maximum_msat : None ,
275
+ } ) ;
276
+ hint
277
+ } ) ;
262
278
263
279
phantom_hints. push ( route_hints) ;
264
280
}
@@ -267,29 +283,34 @@ where
267
283
// the hints across our real nodes we add one hint from each in turn until no node has any hints
268
284
// left (if one node has more hints than any other, these will accumulate at the end of the
269
285
// vector).
270
- let mut invoice_hints : Vec < RouteHint > = Vec :: new ( ) ;
271
- let mut hint_idx = 0 ;
286
+ rotate_through_iterators ( phantom_hints )
287
+ }
272
288
273
- loop {
274
- let mut remaining_hints = false ;
289
+ /// Draw items iteratively from multiple iterators. The items are retrieved by index and
290
+ /// rotates through the iterators - first the zero index then the first index then second index, etc.
291
+ fn rotate_through_iterators < T , I : Iterator < Item = T > > ( mut vecs : Vec < I > ) -> impl Iterator < Item = T > {
292
+ let mut iterations = 0 ;
275
293
276
- for hints in phantom_hints. iter ( ) {
277
- if invoice_hints. len ( ) == 3 {
278
- return invoice_hints
294
+ core:: iter:: from_fn ( move || {
295
+ let mut exhausted_iterators = 0 ;
296
+ loop {
297
+ if vecs. is_empty ( ) {
298
+ return None ;
279
299
}
280
-
281
- if hint_idx < hints. len ( ) {
282
- invoice_hints. push ( hints[ hint_idx] . clone ( ) ) ;
283
- remaining_hints = true
300
+ let next_idx = iterations % vecs. len ( ) ;
301
+ iterations += 1 ;
302
+ if let Some ( item) = vecs[ next_idx] . next ( ) {
303
+ return Some ( item) ;
304
+ }
305
+ // exhausted_vectors increase when the "next_idx" vector is exhausted
306
+ exhausted_iterators += 1 ;
307
+ // The check for exhausted iterators gets reset to 0 after each yield of `Some()`
308
+ // The loop will return None when all of the nested iterators are exhausted
309
+ if exhausted_iterators == vecs. len ( ) {
310
+ return None ;
284
311
}
285
312
}
286
-
287
- if !remaining_hints {
288
- return invoice_hints
289
- }
290
-
291
- hint_idx +=1 ;
292
- }
313
+ } )
293
314
}
294
315
295
316
#[ cfg( feature = "std" ) ]
@@ -575,15 +596,34 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
575
596
/// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
576
597
/// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
577
598
fn sort_and_filter_channels < L : Deref > (
578
- channels : Vec < ChannelDetails > , min_inbound_capacity_msat : Option < u64 > , logger : & L
579
- ) -> Vec < RouteHint > where L :: Target : Logger {
599
+ channels : Vec < ChannelDetails > ,
600
+ min_inbound_capacity_msat : Option < u64 > ,
601
+ logger : & L ,
602
+ ) -> impl ExactSizeIterator < Item = RouteHint >
603
+ where
604
+ L :: Target : Logger ,
605
+ {
580
606
let mut filtered_channels: HashMap < PublicKey , ChannelDetails > = HashMap :: new ( ) ;
581
607
let min_inbound_capacity = min_inbound_capacity_msat. unwrap_or ( 0 ) ;
582
608
let mut min_capacity_channel_exists = false ;
583
609
let mut online_channel_exists = false ;
584
610
let mut online_min_capacity_channel_exists = false ;
585
611
let mut has_pub_unconf_chan = false ;
586
612
613
+ let route_hint_from_channel = |channel : ChannelDetails | {
614
+ let forwarding_info = channel. counterparty . forwarding_info . as_ref ( ) . unwrap ( ) ;
615
+ RouteHint ( vec ! [ RouteHintHop {
616
+ src_node_id: channel. counterparty. node_id,
617
+ short_channel_id: channel. get_inbound_payment_scid( ) . unwrap( ) ,
618
+ fees: RoutingFees {
619
+ base_msat: forwarding_info. fee_base_msat,
620
+ proportional_millionths: forwarding_info. fee_proportional_millionths,
621
+ } ,
622
+ cltv_expiry_delta: forwarding_info. cltv_expiry_delta,
623
+ htlc_minimum_msat: channel. inbound_htlc_minimum_msat,
624
+ htlc_maximum_msat: channel. inbound_htlc_maximum_msat, } ] )
625
+ } ;
626
+
587
627
log_trace ! ( logger, "Considering {} channels for invoice route hints" , channels. len( ) ) ;
588
628
for channel in channels. into_iter ( ) . filter ( |chan| chan. is_channel_ready ) {
589
629
if channel. get_inbound_payment_scid ( ) . is_none ( ) || channel. counterparty . forwarding_info . is_none ( ) {
@@ -602,7 +642,7 @@ fn sort_and_filter_channels<L: Deref>(
602
642
// look at the public channels instead.
603
643
log_trace ! ( logger, "Not including channels in invoice route hints on account of public channel {}" ,
604
644
log_bytes!( channel. channel_id) ) ;
605
- return vec ! [ ]
645
+ return vec ! [ ] . into_iter ( ) . take ( MAX_CHANNEL_HINTS ) . map ( route_hint_from_channel ) ;
606
646
}
607
647
}
608
648
@@ -662,19 +702,6 @@ fn sort_and_filter_channels<L: Deref>(
662
702
}
663
703
}
664
704
665
- let route_hint_from_channel = |channel : ChannelDetails | {
666
- let forwarding_info = channel. counterparty . forwarding_info . as_ref ( ) . unwrap ( ) ;
667
- RouteHint ( vec ! [ RouteHintHop {
668
- src_node_id: channel. counterparty. node_id,
669
- short_channel_id: channel. get_inbound_payment_scid( ) . unwrap( ) ,
670
- fees: RoutingFees {
671
- base_msat: forwarding_info. fee_base_msat,
672
- proportional_millionths: forwarding_info. fee_proportional_millionths,
673
- } ,
674
- cltv_expiry_delta: forwarding_info. cltv_expiry_delta,
675
- htlc_minimum_msat: channel. inbound_htlc_minimum_msat,
676
- htlc_maximum_msat: channel. inbound_htlc_maximum_msat, } ] )
677
- } ;
678
705
// If all channels are private, prefer to return route hints which have a higher capacity than
679
706
// the payment value and where we're currently connected to the channel counterparty.
680
707
// Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
@@ -724,7 +751,8 @@ fn sort_and_filter_channels<L: Deref>(
724
751
} else {
725
752
b. inbound_capacity_msat . cmp ( & a. inbound_capacity_msat )
726
753
} } ) ;
727
- eligible_channels. into_iter ( ) . take ( 3 ) . map ( route_hint_from_channel) . collect :: < Vec < RouteHint > > ( )
754
+
755
+ eligible_channels. into_iter ( ) . take ( MAX_CHANNEL_HINTS ) . map ( route_hint_from_channel)
728
756
}
729
757
730
758
/// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate
@@ -777,7 +805,7 @@ mod test {
777
805
use lightning:: routing:: router:: { PaymentParameters , RouteParameters } ;
778
806
use lightning:: util:: test_utils;
779
807
use lightning:: util:: config:: UserConfig ;
780
- use crate :: utils:: create_invoice_from_channelmanager_and_duration_since_epoch;
808
+ use crate :: utils:: { create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators } ;
781
809
use std:: collections:: HashSet ;
782
810
783
811
#[ test]
@@ -1886,4 +1914,111 @@ mod test {
1886
1914
_ => panic ! ( ) ,
1887
1915
}
1888
1916
}
1917
+
1918
+ #[ test]
1919
+ fn test_rotate_through_iterators ( ) {
1920
+ // two nested vectors
1921
+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ "a1" , "b1" ] . into_iter( ) ] ;
1922
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1923
+
1924
+ let expected = vec ! [ "a0" , "a1" , "b0" , "b1" , "c0" ] ;
1925
+ assert_eq ! ( expected, result) ;
1926
+
1927
+ // test single nested vector
1928
+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) ] ;
1929
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1930
+
1931
+ let expected = vec ! [ "a0" , "b0" , "c0" ] ;
1932
+ assert_eq ! ( expected, result) ;
1933
+
1934
+ // test second vector with only one element
1935
+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ "a1" ] . into_iter( ) ] ;
1936
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1937
+
1938
+ let expected = vec ! [ "a0" , "a1" , "b0" , "c0" ] ;
1939
+ assert_eq ! ( expected, result) ;
1940
+
1941
+ // test three nestend vectors
1942
+ let a = vec ! [ vec![ "a0" ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) , vec![ "a2" ] . into_iter( ) ] ;
1943
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1944
+
1945
+ let expected = vec ! [ "a0" , "a1" , "a2" , "b1" , "c1" ] ;
1946
+ assert_eq ! ( expected, result) ;
1947
+
1948
+ // test single nested vector with a single value
1949
+ let a = vec ! [ vec![ "a0" ] . into_iter( ) ] ;
1950
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1951
+
1952
+ let expected = vec ! [ "a0" ] ;
1953
+ assert_eq ! ( expected, result) ;
1954
+
1955
+ // test single empty nested vector
1956
+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) ] ;
1957
+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1958
+ let expected: Vec < & str > = vec ! [ ] ;
1959
+
1960
+ assert_eq ! ( expected, result) ;
1961
+
1962
+ // test first nested vector is empty
1963
+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) ] ;
1964
+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1965
+
1966
+ let expected = vec ! [ "a1" , "b1" , "c1" ] ;
1967
+ assert_eq ! ( expected, result) ;
1968
+
1969
+ // test two empty vectors
1970
+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1971
+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1972
+
1973
+ let expected: Vec < & str > = vec ! [ ] ;
1974
+ assert_eq ! ( expected, result) ;
1975
+
1976
+ // test an empty vector amongst other filled vectors
1977
+ let a = vec ! [
1978
+ vec![ "a0" , "b0" , "c0" ] . into_iter( ) ,
1979
+ vec![ ] . into_iter( ) ,
1980
+ vec![ "a1" , "b1" , "c1" ] . into_iter( ) ,
1981
+ vec![ "a2" , "b2" , "c2" ] . into_iter( ) ,
1982
+ ] ;
1983
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1984
+
1985
+ let expected = vec ! [ "a0" , "a1" , "a2" , "b0" , "b1" , "b2" , "c0" , "c1" , "c2" ] ;
1986
+ assert_eq ! ( expected, result) ;
1987
+
1988
+ // test a filled vector between two empty vectors
1989
+ let a = vec ! [ vec![ ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1990
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1991
+
1992
+ let expected = vec ! [ "a1" , "b1" , "c1" ] ;
1993
+ assert_eq ! ( expected, result) ;
1994
+
1995
+ // test an empty vector at the end of the vectors
1996
+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1997
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1998
+
1999
+ let expected = vec ! [ "a0" , "b0" , "c0" ] ;
2000
+ assert_eq ! ( expected, result) ;
2001
+
2002
+ // test multiple empty vectors amongst multiple filled vectors
2003
+ let a = vec ! [
2004
+ vec![ ] . into_iter( ) ,
2005
+ vec![ "a1" , "b1" , "c1" ] . into_iter( ) ,
2006
+ vec![ ] . into_iter( ) ,
2007
+ vec![ "a3" , "b3" ] . into_iter( ) ,
2008
+ vec![ ] . into_iter( ) ,
2009
+ ] ;
2010
+
2011
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
2012
+
2013
+ let expected = vec ! [ "a1" , "a3" , "b1" , "b3" , "c1" ] ;
2014
+ assert_eq ! ( expected, result) ;
2015
+
2016
+ // test one element in the first nested vectore and two elements in the second nested
2017
+ // vector
2018
+ let a = vec ! [ vec![ "a0" ] . into_iter( ) , vec![ "a1" , "b1" ] . into_iter( ) ] ;
2019
+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
2020
+
2021
+ let expected = vec ! [ "a0" , "a1" , "b1" ] ;
2022
+ assert_eq ! ( expected, result) ;
2023
+ }
1889
2024
}
0 commit comments