Skip to content

Allow nodes to be avoided during pathfinding #1550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,7 @@ mod tests {
use routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
use routing::scoring::{ChannelUsage, Score};
use routing::scoring::{ChannelUsage, Score, ProbabilisticScorer, ProbabilisticScoringParameters};
use chain::transaction::OutPoint;
use chain::keysinterface::KeysInterface;
use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
Expand Down Expand Up @@ -5713,6 +5713,33 @@ mod tests {
}
}
}

#[test]
fn avoids_banned_nodes() {
let (secp_ctx, network_graph, _, _, logger) = build_line_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);

let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
let random_seed_bytes = keys_manager.get_secure_random_bytes();

let scorer_params = ProbabilisticScoringParameters::default();
let mut scorer = ProbabilisticScorer::new(scorer_params, Arc::clone(&network_graph), Arc::clone(&logger));

// First check we can get a route.
let payment_params = PaymentParameters::from_node_id(nodes[10]);
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
assert!(route.is_ok());

// Then check that we can't get a route if we ban an intermediate node.
scorer.add_banned(&NodeId::from_pubkey(&nodes[3]));
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
assert!(route.is_err());

// Finally make sure we can route again, when we remove the ban.
scorer.remove_banned(&NodeId::from_pubkey(&nodes[3]));
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
assert!(route.is_ok());
}
}

#[cfg(all(test, not(feature = "no-std")))]
Expand Down
53 changes: 44 additions & 9 deletions lightning/src/routing/scoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ where L::Target: Logger {
///
/// Used to configure base, liquidity, and amount penalties, the sum of which comprises the channel
/// penalty (i.e., the amount in msats willing to be paid to avoid routing through the channel).
#[derive(Clone, Copy)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, adding the HashSet here makes this only moveable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave Clone, no? Just have to drop Copy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I of course left Clone. Sorry if that was confusing.

#[derive(Clone)]
pub struct ProbabilisticScoringParameters {
/// A fixed penalty in msats to apply to each channel.
///
Expand Down Expand Up @@ -361,6 +361,11 @@ pub struct ProbabilisticScoringParameters {
///
/// Default value: 256 msat
pub amount_penalty_multiplier_msat: u64,

/// A list of nodes that won't be considered during path finding.
///
/// (C-not exported)
pub banned_nodes: HashSet<NodeId>,
}

/// Accounting for channel liquidity balance uncertainty.
Expand Down Expand Up @@ -451,6 +456,22 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
}
None
}

/// Marks the node with the given `node_id` as banned, i.e.,
/// it will be avoided during path finding.
pub fn add_banned(&mut self, node_id: &NodeId) {
self.params.banned_nodes.insert(*node_id);
}

/// Removes the node with the given `node_id` from the list of nodes to avoid.
pub fn remove_banned(&mut self, node_id: &NodeId) {
self.params.banned_nodes.remove(node_id);
}

/// Clears the list of nodes that are avoided during path finding.
pub fn clear_banned(&mut self) {
self.params.banned_nodes = HashSet::new();
}
}

impl ProbabilisticScoringParameters {
Expand All @@ -461,6 +482,15 @@ impl ProbabilisticScoringParameters {
liquidity_penalty_multiplier_msat: 0,
liquidity_offset_half_life: Duration::from_secs(3600),
amount_penalty_multiplier_msat: 0,
banned_nodes: HashSet::new(),
}
}

/// Marks all nodes in the given list as banned, i.e.,
/// they will be avoided during path finding.
pub fn add_banned_from_list(&mut self, node_ids: Vec<NodeId>) {
for id in node_ids {
self.banned_nodes.insert(id);
}
}
}
Expand All @@ -472,6 +502,7 @@ impl Default for ProbabilisticScoringParameters {
liquidity_penalty_multiplier_msat: 40_000,
liquidity_offset_half_life: Duration::from_secs(3600),
amount_penalty_multiplier_msat: 256,
banned_nodes: HashSet::new(),
}
}
}
Expand Down Expand Up @@ -543,7 +574,7 @@ const AMOUNT_PENALTY_DIVISOR: u64 = 1 << 20;
impl<L: Deref<Target = u64>, T: Time, U: Deref<Target = T>> DirectedChannelLiquidity<L, T, U> {
/// Returns a penalty for routing the given HTLC `amount_msat` through the channel in this
/// direction.
fn penalty_msat(&self, amount_msat: u64, params: ProbabilisticScoringParameters) -> u64 {
fn penalty_msat(&self, amount_msat: u64, params: &ProbabilisticScoringParameters) -> u64 {
let max_liquidity_msat = self.max_liquidity_msat();
let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat);
if amount_msat <= min_liquidity_msat {
Expand Down Expand Up @@ -580,7 +611,7 @@ impl<L: Deref<Target = u64>, T: Time, U: Deref<Target = T>> DirectedChannelLiqui
#[inline(always)]
fn combined_penalty_msat(
&self, amount_msat: u64, negative_log10_times_2048: u64,
params: ProbabilisticScoringParameters
params: &ProbabilisticScoringParameters
) -> u64 {
let liquidity_penalty_msat = {
// Upper bound the liquidity penalty to ensure some channel is selected.
Expand Down Expand Up @@ -672,6 +703,10 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
fn channel_penalty_msat(
&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
) -> u64 {
if self.params.banned_nodes.contains(source) || self.params.banned_nodes.contains(target) {
return u64::max_value();
}

if let EffectiveCapacity::ExactLiquidity { liquidity_msat } = usage.effective_capacity {
if usage.amount_msat > liquidity_msat {
return u64::max_value();
Expand All @@ -688,7 +723,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
.get(&short_channel_id)
.unwrap_or(&ChannelLiquidity::new())
.as_directed(source, target, capacity_msat, liquidity_offset_half_life)
.penalty_msat(amount_msat, self.params)
.penalty_msat(amount_msat, &self.params)
}

fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
Expand Down Expand Up @@ -1072,7 +1107,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Writeable for Probab
#[inline]
fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
write_tlv_fields!(w, {
(0, self.channel_liquidities, required)
(0, self.channel_liquidities, required),
});
Ok(())
}
Expand All @@ -1087,7 +1122,7 @@ ReadableArgs<(ProbabilisticScoringParameters, G, L)> for ProbabilisticScorerUsin
let (params, network_graph, logger) = args;
let mut channel_liquidities = HashMap::new();
read_tlv_fields!(r, {
(0, channel_liquidities, required)
(0, channel_liquidities, required),
});
Ok(Self {
params,
Expand Down Expand Up @@ -1862,7 +1897,7 @@ mod tests {
liquidity_offset_half_life: Duration::from_secs(10),
..ProbabilisticScoringParameters::zero_penalty()
};
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
let source = source_node_id();
let target = target_node_id();
let usage = ChannelUsage {
Expand Down Expand Up @@ -1898,7 +1933,7 @@ mod tests {
liquidity_offset_half_life: Duration::from_secs(10),
..ProbabilisticScoringParameters::zero_penalty()
};
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
let source = source_node_id();
let target = target_node_id();
let usage = ChannelUsage {
Expand Down Expand Up @@ -2086,7 +2121,7 @@ mod tests {
let logger = TestLogger::new();
let network_graph = network_graph(&logger);
let params = ProbabilisticScoringParameters::default();
let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
let scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
let source = source_node_id();
let target = target_node_id();

Expand Down