Skip to content

Fix overflow in historical scoring model point count summation #3616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions lightning/src/routing/scoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2060,15 +2060,17 @@ mod bucketed_history {
}

fn recalculate_valid_point_count(&mut self) {
let mut total_valid_points_tracked = 0;
let mut total_valid_points_tracked = 0u128;
for (min_idx, min_bucket) in self.min_liquidity_offset_history.buckets.iter().enumerate() {
for max_bucket in self.max_liquidity_offset_history.buckets.iter().take(32 - min_idx) {
// In testing, raising the weights of buckets to a high power led to better
// scoring results. Thus, we raise the bucket weights to the 4th power here (by
// squaring the result of multiplying the weights).
// squaring the result of multiplying the weights). This results in
// bucket_weight having at max 64 bits, which means we have to do our summation
// in 128-bit math.
let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64);
bucket_weight *= bucket_weight;
total_valid_points_tracked += bucket_weight;
total_valid_points_tracked += bucket_weight as u128;
}
}
self.total_valid_points_tracked = total_valid_points_tracked as f64;
Expand Down Expand Up @@ -2161,12 +2163,12 @@ mod bucketed_history {

let total_valid_points_tracked = self.tracker.total_valid_points_tracked;
#[cfg(debug_assertions)] {
let mut actual_valid_points_tracked = 0;
let mut actual_valid_points_tracked = 0u128;
for (min_idx, min_bucket) in min_liquidity_offset_history_buckets.iter().enumerate() {
for max_bucket in max_liquidity_offset_history_buckets.iter().take(32 - min_idx) {
let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64);
bucket_weight *= bucket_weight;
actual_valid_points_tracked += bucket_weight;
actual_valid_points_tracked += bucket_weight as u128;
}
}
assert_eq!(total_valid_points_tracked, actual_valid_points_tracked as f64);
Expand All @@ -2193,7 +2195,7 @@ mod bucketed_history {
// max-bucket with at least BUCKET_FIXED_POINT_ONE.
let mut highest_max_bucket_with_points = 0;
let mut highest_max_bucket_with_full_points = None;
let mut total_weight = 0;
let mut total_weight = 0u128;
for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate() {
if *max_bucket >= BUCKET_FIXED_POINT_ONE {
highest_max_bucket_with_full_points = Some(cmp::max(highest_max_bucket_with_full_points.unwrap_or(0), max_idx));
Expand All @@ -2206,7 +2208,7 @@ mod bucketed_history {
// squaring the result of multiplying the weights), matching the logic in
// `recalculate_valid_point_count`.
let bucket_weight = (*max_bucket as u64) * (min_liquidity_offset_history_buckets[0] as u64);
total_weight += bucket_weight * bucket_weight;
total_weight += (bucket_weight * bucket_weight) as u128;
}
debug_assert!(total_weight as f64 <= total_valid_points_tracked);
// Use the highest max-bucket with at least BUCKET_FIXED_POINT_ONE, but if none is
Expand Down Expand Up @@ -2343,6 +2345,26 @@ mod bucketed_history {

assert_ne!(probability1, probability);
}

#[test]
fn historical_heavy_buckets_operations() {
// Checks that we don't hit overflows when working with tons of data (even an
// impossible-to-reach amount of data).
let mut tracker = HistoricalLiquidityTracker::new();
tracker.min_liquidity_offset_history.buckets = [0xffff; 32];
tracker.max_liquidity_offset_history.buckets = [0xffff; 32];
tracker.recalculate_valid_point_count();
tracker.merge(&tracker.clone());
assert_eq!(tracker.min_liquidity_offset_history.buckets, [0xffff; 32]);
assert_eq!(tracker.max_liquidity_offset_history.buckets, [0xffff; 32]);

let mut directed = tracker.as_directed_mut(true);
let default_params = ProbabilisticScoringFeeParameters::default();
directed.calculate_success_probability_times_billion(&default_params, 42, 1000);
directed.track_datapoint(42, 52, 1000);

tracker.decay_buckets(1.0);
}
}
}

Expand Down
Loading