@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
130
130
131
131
impl < ' a , S : Deref > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , S > where S :: Target : ScoreLookUp {
132
132
type ScoreParams = <S :: Target as ScoreLookUp >:: ScoreParams ;
133
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
133
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
134
+ let target = match candidate. target ( ) {
135
+ Some ( target) => target,
136
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
137
+ } ;
138
+ let short_channel_id = match candidate. short_channel_id ( ) {
139
+ Some ( short_channel_id) => short_channel_id,
140
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
141
+ } ;
142
+ let source = candidate. source ( ) ;
134
143
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
135
- source, target, short_channel_id
144
+ & source, & target, short_channel_id
136
145
) {
137
146
let usage = ChannelUsage {
138
147
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
139
148
..usage
140
149
} ;
141
150
142
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
151
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
143
152
} else {
144
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
153
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
145
154
}
146
155
}
147
156
}
@@ -1068,7 +1077,7 @@ impl<'a> CandidateRouteHop<'a> {
1068
1077
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
1069
1078
pub fn short_channel_id ( & self ) -> Option < u64 > {
1070
1079
match self {
1071
- CandidateRouteHop :: FirstHop { details, .. } => Some ( details. get_outbound_payment_scid ( ) . unwrap ( ) ) ,
1080
+ CandidateRouteHop :: FirstHop { details, .. } => details. get_outbound_payment_scid ( ) ,
1072
1081
CandidateRouteHop :: PublicHop { short_channel_id, .. } => Some ( * short_channel_id) ,
1073
1082
CandidateRouteHop :: PrivateHop { hint, .. } => Some ( hint. short_channel_id ) ,
1074
1083
CandidateRouteHop :: Blinded { .. } => None ,
@@ -1173,7 +1182,7 @@ impl<'a> CandidateRouteHop<'a> {
1173
1182
CandidateRouteHop :: PublicHop { info, .. } => * info. source ( ) ,
1174
1183
CandidateRouteHop :: PrivateHop { hint, .. } => hint. src_node_id . into ( ) ,
1175
1184
CandidateRouteHop :: Blinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1176
- CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( )
1185
+ CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1177
1186
}
1178
1187
}
1179
1188
/// Returns the target node id of this hop, if known.
@@ -2011,9 +2020,10 @@ where L::Target: Logger {
2011
2020
inflight_htlc_msat: used_liquidity_msat,
2012
2021
effective_capacity,
2013
2022
} ;
2014
- let channel_penalty_msat = scid_opt. map_or( 0 ,
2015
- |scid| scorer. channel_penalty_msat( scid, & src_node_id, & dest_node_id,
2016
- channel_usage, score_params) ) ;
2023
+ let channel_penalty_msat =
2024
+ scorer. channel_penalty_msat( $candidate,
2025
+ channel_usage,
2026
+ score_params) ;
2017
2027
let path_penalty_msat = $next_hops_path_penalty_msat
2018
2028
. saturating_add( channel_penalty_msat) ;
2019
2029
let new_graph_node = RouteGraphNode {
@@ -2324,7 +2334,7 @@ where L::Target: Logger {
2324
2334
effective_capacity : candidate. effective_capacity ( ) ,
2325
2335
} ;
2326
2336
let channel_penalty_msat = scorer. channel_penalty_msat (
2327
- hop . short_channel_id , & source , & target , channel_usage, score_params
2337
+ & candidate , channel_usage, score_params
2328
2338
) ;
2329
2339
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2330
2340
. saturating_add ( channel_penalty_msat) ;
@@ -2879,13 +2889,13 @@ fn build_route_from_hops_internal<L: Deref>(
2879
2889
2880
2890
impl ScoreLookUp for HopScorer {
2881
2891
type ScoreParams = ( ) ;
2882
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2892
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2883
2893
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2884
2894
{
2885
2895
let mut cur_id = self . our_node_id ;
2886
2896
for i in 0 ..self . hop_ids . len ( ) {
2887
2897
if let Some ( next_id) = self . hop_ids [ i] {
2888
- if cur_id == * source && next_id == * target {
2898
+ if cur_id == candidate . source ( ) && Some ( next_id) == candidate . target ( ) {
2889
2899
return 0 ;
2890
2900
}
2891
2901
cur_id = next_id;
@@ -2926,7 +2936,7 @@ mod tests {
2926
2936
use crate :: routing:: utxo:: UtxoResult ;
2927
2937
use crate :: routing:: router:: { get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
2928
2938
BlindedTail , InFlightHtlcs , Path , PaymentParameters , Route , RouteHint , RouteHintHop , RouteHop , RoutingFees ,
2929
- DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA , MAX_PATH_LENGTH_ESTIMATE , RouteParameters } ;
2939
+ DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA , MAX_PATH_LENGTH_ESTIMATE , RouteParameters , CandidateRouteHop } ;
2930
2940
use crate :: routing:: scoring:: { ChannelUsage , FixedPenaltyScorer , ScoreLookUp , ProbabilisticScorer , ProbabilisticScoringFeeParameters , ProbabilisticScoringDecayParameters } ;
2931
2941
use crate :: routing:: test_utils:: { add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel} ;
2932
2942
use crate :: chain:: transaction:: OutPoint ;
@@ -6231,8 +6241,8 @@ mod tests {
6231
6241
}
6232
6242
impl ScoreLookUp for BadChannelScorer {
6233
6243
type ScoreParams = ( ) ;
6234
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _ : & NodeId , _ : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6235
- if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6244
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6245
+ if candidate . short_channel_id ( ) == Some ( self . short_channel_id ) { u64:: max_value ( ) } else { 0 }
6236
6246
}
6237
6247
}
6238
6248
@@ -6247,8 +6257,8 @@ mod tests {
6247
6257
6248
6258
impl ScoreLookUp for BadNodeScorer {
6249
6259
type ScoreParams = ( ) ;
6250
- fn channel_penalty_msat ( & self , _ : u64 , _ : & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6251
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6260
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6261
+ if candidate . target ( ) == Some ( self . node_id ) { u64:: max_value ( ) } else { 0 }
6252
6262
}
6253
6263
}
6254
6264
@@ -6736,26 +6746,32 @@ mod tests {
6736
6746
} ;
6737
6747
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6738
6748
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6739
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6749
+ let network_graph = network_graph. read_only ( ) ;
6750
+ let channels = network_graph. channels ( ) ;
6751
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6752
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6753
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6754
+ info : info. 0 ,
6755
+ short_channel_id : 5 ,
6756
+ } ;
6757
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6740
6758
6741
6759
// Then check we can get a normal route
6742
6760
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6743
6761
let route_params = RouteParameters :: from_payment_params_and_value (
6744
6762
payment_params, 100 ) ;
6745
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6763
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6746
6764
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6747
6765
assert ! ( route. is_ok( ) ) ;
6748
6766
6749
6767
// Then check that we can't get a route if we ban an intermediate node.
6750
6768
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6751
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6752
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6769
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6753
6770
assert ! ( route. is_err( ) ) ;
6754
6771
6755
6772
// Finally make sure we can route again, when we remove the ban.
6756
6773
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6757
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6758
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6774
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6759
6775
assert ! ( route. is_ok( ) ) ;
6760
6776
}
6761
6777
0 commit comments