Skip to content

Commit 398a69f

Browse files
committed
Move banned_nodes to params struct
1 parent 6c52118 commit 398a69f

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

lightning/src/routing/router.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5722,21 +5722,23 @@ mod tests {
57225722
let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
57235723
let random_seed_bytes = keys_manager.get_secure_random_bytes();
57245724

5725-
let scorer_params = ProbabilisticScoringParameters::default();
5726-
let mut scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
5725+
let payment_params = PaymentParameters::from_node_id(nodes[10]);
5726+
let mut scorer_params = ProbabilisticScoringParameters::default();
57275727

57285728
// First check we can get a route.
5729-
let payment_params = PaymentParameters::from_node_id(nodes[10]);
5729+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57305730
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57315731
assert!(route.is_ok());
57325732

57335733
// Then check that we can't get a route if we ban an intermediate node.
5734-
scorer.add_banned(&NodeId::from_pubkey(&nodes[3]));
5734+
scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3]));
5735+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57355736
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57365737
assert!(route.is_err());
57375738

57385739
// Finally make sure we can route again, when we remove the ban.
5739-
scorer.remove_banned(&NodeId::from_pubkey(&nodes[3]));
5740+
scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3]));
5741+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57405742
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57415743
assert!(route.is_ok());
57425744
}

lightning/src/routing/scoring.rs

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,13 @@ where L::Target: Logger {
300300
logger: L,
301301
// TODO: Remove entries of closed channels.
302302
channel_liquidities: HashMap<u64, ChannelLiquidity<T>>,
303-
banned_nodes: HashSet<NodeId>,
304303
}
305304

306305
/// Parameters for configuring [`ProbabilisticScorer`].
307306
///
308307
/// Used to configure base, liquidity, and amount penalties, the sum of which comprises the channel
309308
/// penalty (i.e., the amount in msats willing to be paid to avoid routing through the channel).
310-
#[derive(Clone, Copy)]
309+
#[derive(Clone)]
311310
pub struct ProbabilisticScoringParameters {
312311
/// A fixed penalty in msats to apply to each channel.
313312
///
@@ -362,6 +361,11 @@ pub struct ProbabilisticScoringParameters {
362361
///
363362
/// Default value: 256 msat
364363
pub amount_penalty_multiplier_msat: u64,
364+
365+
/// A list of nodes that won't be considered during path finding.
366+
///
367+
/// (C-not exported)
368+
pub banned_nodes: HashSet<NodeId>,
365369
}
366370

367371
/// Accounting for channel liquidity balance uncertainty.
@@ -400,7 +404,6 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
400404
network_graph,
401405
logger,
402406
channel_liquidities: HashMap::new(),
403-
banned_nodes: HashSet::new(),
404407
}
405408
}
406409

@@ -410,22 +413,6 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
410413
self
411414
}
412415

413-
/// Marks the node with the given `node_id` as banned, i.e.,
414-
/// it will be avoided during path finding.
415-
pub fn add_banned(&mut self, node_id: &NodeId) {
416-
self.banned_nodes.insert(*node_id);
417-
}
418-
419-
/// Removes the node with the given `node_id` from the list of nodes to avoid.
420-
pub fn remove_banned(&mut self, node_id: &NodeId) {
421-
self.banned_nodes.remove(node_id);
422-
}
423-
424-
/// Clears the list of nodes that are avoided during path finding.
425-
pub fn clear_banned(&mut self) {
426-
self.banned_nodes = HashSet::new();
427-
}
428-
429416
/// Dump the contents of this scorer into the configured logger.
430417
///
431418
/// Note that this writes roughly one line per channel for which we have a liquidity estimate,
@@ -479,8 +466,33 @@ impl ProbabilisticScoringParameters {
479466
liquidity_penalty_multiplier_msat: 0,
480467
liquidity_offset_half_life: Duration::from_secs(3600),
481468
amount_penalty_multiplier_msat: 0,
469+
banned_nodes: HashSet::new(),
470+
}
471+
}
472+
473+
/// Marks the node with the given `node_id` as banned, i.e.,
474+
/// it will be avoided during path finding.
475+
pub fn add_banned(&mut self, node_id: &NodeId) {
476+
self.banned_nodes.insert(*node_id);
477+
}
478+
479+
/// Marks all nodes in the given list as banned, i.e.,
480+
/// they will be avoided during path finding.
481+
pub fn add_banned_from_list(&mut self, node_ids: Vec<NodeId>) {
482+
for id in node_ids {
483+
self.banned_nodes.insert(id);
482484
}
483485
}
486+
487+
/// Removes the node with the given `node_id` from the list of nodes to avoid.
488+
pub fn remove_banned(&mut self, node_id: &NodeId) {
489+
self.banned_nodes.remove(node_id);
490+
}
491+
492+
/// Clears the list of nodes that are avoided during path finding.
493+
pub fn clear_banned(&mut self) {
494+
self.banned_nodes = HashSet::new();
495+
}
484496
}
485497

486498
impl Default for ProbabilisticScoringParameters {
@@ -490,6 +502,7 @@ impl Default for ProbabilisticScoringParameters {
490502
liquidity_penalty_multiplier_msat: 40_000,
491503
liquidity_offset_half_life: Duration::from_secs(3600),
492504
amount_penalty_multiplier_msat: 256,
505+
banned_nodes: HashSet::new(),
493506
}
494507
}
495508
}
@@ -690,7 +703,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
690703
fn channel_penalty_msat(
691704
&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
692705
) -> u64 {
693-
if self.banned_nodes.contains(source) || self.banned_nodes.contains(target) {
706+
if self.params.banned_nodes.contains(source) || self.params.banned_nodes.contains(target) {
694707
return u64::max_value();
695708
}
696709

@@ -710,7 +723,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
710723
.get(&short_channel_id)
711724
.unwrap_or(&ChannelLiquidity::new())
712725
.as_directed(source, target, capacity_msat, liquidity_offset_half_life)
713-
.penalty_msat(amount_msat, self.params)
726+
.penalty_msat(amount_msat, self.params.clone())
714727
}
715728

716729
fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
@@ -1116,7 +1129,6 @@ ReadableArgs<(ProbabilisticScoringParameters, G, L)> for ProbabilisticScorerUsin
11161129
network_graph,
11171130
logger,
11181131
channel_liquidities,
1119-
banned_nodes: HashSet::new(),
11201132
})
11211133
}
11221134
}
@@ -1885,7 +1897,7 @@ mod tests {
18851897
liquidity_offset_half_life: Duration::from_secs(10),
18861898
..ProbabilisticScoringParameters::zero_penalty()
18871899
};
1888-
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
1900+
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
18891901
let source = source_node_id();
18901902
let target = target_node_id();
18911903
let usage = ChannelUsage {
@@ -1921,7 +1933,7 @@ mod tests {
19211933
liquidity_offset_half_life: Duration::from_secs(10),
19221934
..ProbabilisticScoringParameters::zero_penalty()
19231935
};
1924-
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
1936+
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
19251937
let source = source_node_id();
19261938
let target = target_node_id();
19271939
let usage = ChannelUsage {
@@ -2109,7 +2121,7 @@ mod tests {
21092121
let logger = TestLogger::new();
21102122
let network_graph = network_graph(&logger);
21112123
let params = ProbabilisticScoringParameters::default();
2112-
let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
2124+
let scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
21132125
let source = source_node_id();
21142126
let target = target_node_id();
21152127

0 commit comments

Comments
 (0)