Skip to content

Commit d3273d3

Browse files
committed
Changes to Router
1 parent bb49c2f commit d3273d3

File tree

1 file changed

+67
-57
lines changed

1 file changed

+67
-57
lines changed

lightning/src/routing/router.rs

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -33,34 +33,36 @@ use core::cmp;
3333
use core::ops::Deref;
3434

3535
/// A [`Router`] implemented using [`find_route`].
36-
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> where
36+
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref> where
3737
L::Target: Logger,
38-
S::Target: for <'a> LockableScore<'a>,
3938
{
4039
network_graph: G,
4140
logger: L,
42-
random_seed_bytes: Mutex<[u8; 32]>,
43-
scorer: S
41+
random_seed_bytes: Mutex<[u8; 32]>
4442
}
4543

46-
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> DefaultRouter<G, L, S> where
44+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> DefaultRouter<G, L> where
4745
L::Target: Logger,
48-
S::Target: for <'a> LockableScore<'a>,
4946
{
5047
/// Creates a new router.
51-
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S) -> Self {
48+
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32]) -> Self {
5249
let random_seed_bytes = Mutex::new(random_seed_bytes);
53-
Self { network_graph, logger, random_seed_bytes, scorer }
50+
Self { network_graph, logger, random_seed_bytes }
5451
}
5552
}
5653

57-
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> Router for DefaultRouter<G, L, S> where
54+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> Router<S> for DefaultRouter<G, L> where
5855
L::Target: Logger,
5956
S::Target: for <'a> LockableScore<'a>,
6057
{
6158
fn find_route(
62-
&self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
63-
inflight_htlcs: &InFlightHtlcs
59+
&self,
60+
payer: &PublicKey,
61+
params: &RouteParameters,
62+
first_hops: Option<&[&ChannelDetails]>,
63+
inflight_htlcs: &InFlightHtlcs,
64+
scorer: S,
65+
score_params: &<<S as Deref>::Target as Score>::ScoreParams
6466
) -> Result<Route, LightningError> {
6567
let random_seed_bytes = {
6668
let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap();
@@ -70,27 +72,27 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> Router for DefaultR
7072

7173
find_route(
7274
payer, params, &self.network_graph, first_hops, &*self.logger,
73-
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
75+
&ScorerAccountingForInFlightHtlcs::new(scorer.lock(), inflight_htlcs), score_params,
7476
&random_seed_bytes
7577
)
7678
}
7779
}
7880

7981
/// A trait defining behavior for routing a payment.
80-
pub trait Router {
82+
pub trait Router<S: Deref>{
8183
/// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
82-
fn find_route(
84+
fn find_route (
8385
&self, payer: &PublicKey, route_params: &RouteParameters,
84-
first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs
86+
first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs, scorer:S ,score_params: &<<S as Deref>::Target as Score>::ScoreParams
8587
) -> Result<Route, LightningError>;
8688
/// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. Includes
8789
/// `PaymentHash` and `PaymentId` to be able to correlate the request with a specific payment.
88-
fn find_route_with_id(
90+
fn find_route_with_id (
8991
&self, payer: &PublicKey, route_params: &RouteParameters,
9092
first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs,
91-
_payment_hash: PaymentHash, _payment_id: PaymentId
93+
_payment_hash: PaymentHash, _payment_id: PaymentId, scorer:S ,score_params: &<<S as Deref>::Target as Score>::ScoreParams
9294
) -> Result<Route, LightningError> {
93-
self.find_route(payer, route_params, first_hops, inflight_htlcs)
95+
self.find_route(payer, route_params, first_hops, inflight_htlcs, score_params)
9496
}
9597
}
9698

@@ -122,7 +124,8 @@ impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
122124
}
123125

124126
impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
125-
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage) -> u64 {
127+
type ScoreParams = S::ScoreParams;
128+
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &S::ScoreParams) -> u64 {
126129
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
127130
source, target, short_channel_id
128131
) {
@@ -131,26 +134,26 @@ impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
131134
..usage
132135
};
133136

134-
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage)
137+
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
135138
} else {
136-
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage)
139+
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
137140
}
138141
}
139142

140-
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
141-
self.scorer.payment_path_failed(path, short_channel_id)
143+
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64, score_params: &S::ScoreParams) {
144+
self.scorer.payment_path_failed(path, short_channel_id, score_params)
142145
}
143146

144-
fn payment_path_successful(&mut self, path: &Path) {
145-
self.scorer.payment_path_successful(path)
147+
fn payment_path_successful(&mut self, path: &Path, score_params: &S::ScoreParams) {
148+
self.scorer.payment_path_successful(path, score_params)
146149
}
147150

148-
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
149-
self.scorer.probe_failed(path, short_channel_id)
151+
fn probe_failed(&mut self, path: &Path, short_channel_id: u64, score_params: &S::ScoreParams) {
152+
self.scorer.probe_failed(path, short_channel_id, score_params)
150153
}
151154

152-
fn probe_successful(&mut self, path: &Path) {
153-
self.scorer.probe_successful(path)
155+
fn probe_successful(&mut self, path: &Path, score_params: &S::ScoreParams) {
156+
self.scorer.probe_successful(path, score_params)
154157
}
155158
}
156159

@@ -1097,15 +1100,20 @@ fn default_node_features() -> NodeFeatures {
10971100
/// [`Event::PaymentPathFailed`]: crate::events::Event::PaymentPathFailed
10981101
/// [`NetworkGraph`]: crate::routing::gossip::NetworkGraph
10991102
pub fn find_route<L: Deref, GL: Deref, S: Score>(
1100-
our_node_pubkey: &PublicKey, route_params: &RouteParameters,
1101-
network_graph: &NetworkGraph<GL>, first_hops: Option<&[&ChannelDetails]>, logger: L,
1102-
scorer: &S, random_seed_bytes: &[u8; 32]
1103+
our_node_pubkey: &PublicKey,
1104+
route_params: &RouteParameters,
1105+
network_graph: &NetworkGraph<GL>,
1106+
first_hops: Option<&[&ChannelDetails]>,
1107+
logger: L,
1108+
scorer: &S,
1109+
score_params: &S::ScoreParams,
1110+
random_seed_bytes: &[u8; 32]
11031111
) -> Result<Route, LightningError>
11041112
where L::Target: Logger, GL::Target: Logger {
11051113
let graph_lock = network_graph.read_only();
11061114
let final_cltv_expiry_delta = route_params.payment_params.final_cltv_expiry_delta;
11071115
let mut route = get_route(our_node_pubkey, &route_params.payment_params, &graph_lock, first_hops,
1108-
route_params.final_value_msat, final_cltv_expiry_delta, logger, scorer,
1116+
route_params.final_value_msat, final_cltv_expiry_delta, logger, scorer, score_params,
11091117
random_seed_bytes)?;
11101118
add_random_cltv_offset(&mut route, &route_params.payment_params, &graph_lock, random_seed_bytes);
11111119
Ok(route)
@@ -1114,7 +1122,7 @@ where L::Target: Logger, GL::Target: Logger {
11141122
pub(crate) fn get_route<L: Deref, S: Score>(
11151123
our_node_pubkey: &PublicKey, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph,
11161124
first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, final_cltv_expiry_delta: u32,
1117-
logger: L, scorer: &S, _random_seed_bytes: &[u8; 32]
1125+
logger: L, scorer: &S, score_params: &S::ScoreParams, _random_seed_bytes: &[u8; 32]
11181126
) -> Result<Route, LightningError>
11191127
where L::Target: Logger {
11201128
let payee_node_id = NodeId::from_pubkey(&payment_params.payee_pubkey);
@@ -1475,7 +1483,7 @@ where L::Target: Logger {
14751483
effective_capacity,
14761484
};
14771485
let channel_penalty_msat = scorer.channel_penalty_msat(
1478-
short_channel_id, &$src_node_id, &$dest_node_id, channel_usage
1486+
short_channel_id, &$src_node_id, &$dest_node_id, channel_usage, score_params
14791487
);
14801488
let path_penalty_msat = $next_hops_path_penalty_msat
14811489
.saturating_add(channel_penalty_msat);
@@ -1723,7 +1731,7 @@ where L::Target: Logger {
17231731
effective_capacity: candidate.effective_capacity(),
17241732
};
17251733
let channel_penalty_msat = scorer.channel_penalty_msat(
1726-
hop.short_channel_id, &source, &target, channel_usage
1734+
hop.short_channel_id, &source, &target, channel_usage, score_params
17271735
);
17281736
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
17291737
.saturating_add(channel_penalty_msat);
@@ -2222,7 +2230,7 @@ fn build_route_from_hops_internal<L: Deref>(
22222230

22232231
impl Score for HopScorer {
22242232
fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
2225-
_usage: ChannelUsage) -> u64
2233+
_usage: ChannelUsage, scorer_params: &Self::ScoreParams) -> u64
22262234
{
22272235
let mut cur_id = self.our_node_id;
22282236
for i in 0..self.hop_ids.len() {
@@ -2238,13 +2246,13 @@ fn build_route_from_hops_internal<L: Deref>(
22382246
u64::max_value()
22392247
}
22402248

2241-
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
2249+
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
22422250

2243-
fn payment_path_successful(&mut self, _path: &Path) {}
2251+
fn payment_path_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
22442252

2245-
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
2253+
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
22462254

2247-
fn probe_successful(&mut self, _path: &Path) {}
2255+
fn probe_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
22482256
}
22492257

22502258
impl<'a> Writeable for HopScorer {
@@ -2267,7 +2275,7 @@ fn build_route_from_hops_internal<L: Deref>(
22672275
let scorer = HopScorer { our_node_id, hop_ids };
22682276

22692277
get_route(our_node_pubkey, payment_params, network_graph, None, final_value_msat,
2270-
final_cltv_expiry_delta, logger, &scorer, random_seed_bytes)
2278+
final_cltv_expiry_delta, logger, &scorer, &0, random_seed_bytes)
22712279
}
22722280

22732281
#[cfg(test)]
@@ -5254,14 +5262,15 @@ mod tests {
52545262
fn write<W: Writer>(&self, _w: &mut W) -> Result<(), crate::io::Error> { unimplemented!() }
52555263
}
52565264
impl Score for BadChannelScorer {
5257-
fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage) -> u64 {
5265+
type ScoreParams = i32;
5266+
fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, scorer_params: &Self::ScoreParams) -> u64 {
52585267
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
52595268
}
52605269

5261-
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
5262-
fn payment_path_successful(&mut self, _path: &Path) {}
5263-
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
5264-
fn probe_successful(&mut self, _path: &Path) {}
5270+
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
5271+
fn payment_path_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
5272+
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
5273+
fn probe_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
52655274
}
52665275

52675276
struct BadNodeScorer {
@@ -5274,14 +5283,15 @@ mod tests {
52745283
}
52755284

52765285
impl Score for BadNodeScorer {
5277-
fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage) -> u64 {
5286+
type ScoreParams = i32;
5287+
fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, scorer_params: &Self::ScoreParams) -> u64 {
52785288
if *target == self.node_id { u64::max_value() } else { 0 }
52795289
}
52805290

5281-
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
5282-
fn payment_path_successful(&mut self, _path: &Path) {}
5283-
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
5284-
fn probe_successful(&mut self, _path: &Path) {}
5291+
fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
5292+
fn payment_path_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
5293+
fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64, scorer_params: &Self::ScoreParams) {}
5294+
fn probe_successful(&mut self, _path: &Path, scorer_params: &Self::ScoreParams) {}
52855295
}
52865296

52875297
#[test]
@@ -5756,22 +5766,22 @@ mod tests {
57565766
inflight_htlc_msat: 0,
57575767
effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000, htlc_maximum_msat: 1_000 },
57585768
};
5759-
scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123);
5760-
scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456);
5761-
assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage), 456);
5769+
scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123, &scorer_params);
5770+
scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456, &scorer_params);
5771+
assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage, &scorer_params), 456);
57625772

57635773
// Then check we can get a normal route
57645774
let payment_params = PaymentParameters::from_node_id(nodes[10], 42);
57655775
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57665776
assert!(route.is_ok());
57675777

57685778
// Then check that we can't get a route if we ban an intermediate node.
5769-
scorer.add_banned(&NodeId::from_pubkey(&nodes[3]));
5779+
scorer.add_banned(&NodeId::from_pubkey(&nodes[3]), &scorer_params);
57705780
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57715781
assert!(route.is_err());
57725782

57735783
// Finally make sure we can route again, when we remove the ban.
5774-
scorer.remove_banned(&NodeId::from_pubkey(&nodes[3]));
5784+
scorer.remove_banned(&NodeId::from_pubkey(&nodes[3]), &scorer_params);
57755785
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57765786
assert!(route.is_ok());
57775787
}

0 commit comments

Comments
 (0)