Skip to content

Commit 0d1072b

Browse files
authored
Merge pull request #2383 from henghonglee/fix-dyn
Fix DefaultRouter type restrained to only MutexGuard
2 parents 86fd9e7 + 54bcb6e commit 0d1072b

File tree

5 files changed

+91
-55
lines changed

5 files changed

+91
-55
lines changed

lightning-background-processor/src/lib.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,22 @@ mod tests {
885885
fn disconnect_socket(&mut self) {}
886886
}
887887

888-
type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
888+
type ChannelManager =
889+
channelmanager::ChannelManager<
890+
Arc<ChainMonitor>,
891+
Arc<test_utils::TestBroadcaster>,
892+
Arc<KeysManager>,
893+
Arc<KeysManager>,
894+
Arc<KeysManager>,
895+
Arc<test_utils::TestFeeEstimator>,
896+
Arc<DefaultRouter<
897+
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
898+
Arc<test_utils::TestLogger>,
899+
Arc<Mutex<TestScorer>>,
900+
(),
901+
TestScorer>
902+
>,
903+
Arc<test_utils::TestLogger>>;
889904

890905
type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;
891906

lightning/src/ln/channelmanager.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
752752
/// of [`KeysManager`] and [`DefaultRouter`].
753753
///
754754
/// This is not exported to bindings users as Arcs don't make sense in bindings
755-
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
755+
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
756+
ChannelManager<
757+
&'a M,
758+
&'b T,
759+
&'c KeysManager,
760+
&'c KeysManager,
761+
&'c KeysManager,
762+
&'d F,
763+
&'e DefaultRouter<
764+
&'f NetworkGraph<&'g L>,
765+
&'g L,
766+
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
767+
ProbabilisticScoringFeeParameters,
768+
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
769+
>,
770+
&'g L
771+
>;
756772

757773
macro_rules! define_test_pub_trait { ($vis: vis) => {
758774
/// A trivial trait which describes any [`ChannelManager`] used in testing.

lightning/src/routing/router.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;
2727

2828
use crate::io;
2929
use crate::prelude::*;
30-
use crate::sync::{Mutex, MutexGuard};
30+
use crate::sync::{Mutex};
3131
use alloc::collections::BinaryHeap;
3232
use core::{cmp, fmt};
33-
use core::ops::Deref;
33+
use core::ops::{Deref, DerefMut};
3434

3535
/// A [`Router`] implemented using [`find_route`].
3636
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
3737
L::Target: Logger,
38-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
38+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
3939
{
4040
network_graph: G,
4141
logger: L,
@@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,
4646

4747
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
4848
L::Target: Logger,
49-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
49+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
5050
{
5151
/// Creates a new router.
5252
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
@@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
5555
}
5656
}
5757

58-
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
58+
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
5959
L::Target: Logger,
60-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
60+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
6161
{
6262
fn find_route(
6363
&self,
@@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
7373
};
7474
find_route(
7575
payer, params, &self.network_graph, first_hops, &*self.logger,
76-
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
76+
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
7777
&self.score_params,
7878
&random_seed_bytes
7979
)
@@ -104,15 +104,15 @@ pub trait Router {
104104
/// [`find_route`].
105105
///
106106
/// [`Score`]: crate::routing::scoring::Score
107-
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
108-
scorer: S,
107+
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
108+
scorer: &'a mut S,
109109
// Maps a channel's short channel id and its direction to the liquidity used up.
110110
inflight_htlcs: &'a InFlightHtlcs,
111111
}
112112

113-
impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
113+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
114114
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
115-
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
115+
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
116116
ScorerAccountingForInFlightHtlcs {
117117
scorer,
118118
inflight_htlcs
@@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
121121
}
122122

123123
#[cfg(c_bindings)]
124-
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
124+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
125125
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
126126
}
127127

128-
impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
128+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
129129
type ScoreParams = S::ScoreParams;
130130
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
131131
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(

lightning/src/routing/scoring.rs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,11 @@ define_score!();
157157
///
158158
/// [`find_route`]: crate::routing::router::find_route
159159
pub trait LockableScore<'a> {
160+
/// The [`Score`] type.
161+
type Score: 'a + Score;
162+
160163
/// The locked [`Score`] type.
161-
type Locked: 'a + Score;
164+
type Locked: DerefMut<Target = Self::Score> + Sized;
162165

163166
/// Returns the locked scorer.
164167
fn lock(&'a self) -> Self::Locked;
@@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
174177
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
175178
/// This is not exported to bindings users
176179
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
180+
type Score = T;
177181
type Locked = MutexGuard<'a, T>;
178182

179-
fn lock(&'a self) -> MutexGuard<'a, T> {
183+
fn lock(&'a self) -> Self::Locked {
180184
Mutex::lock(self).unwrap()
181185
}
182186
}
183187

184188
impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
189+
type Score = T;
185190
type Locked = RefMut<'a, T>;
186191

187-
fn lock(&'a self) -> RefMut<'a, T> {
192+
fn lock(&'a self) -> Self::Locked {
188193
self.borrow_mut()
189194
}
190195
}
191196

192197
#[cfg(c_bindings)]
193198
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
194-
pub struct MultiThreadedLockableScore<S: Score> {
195-
score: Mutex<S>,
196-
}
197-
#[cfg(c_bindings)]
198-
/// A locked `MultiThreadedLockableScore`.
199-
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
200-
#[cfg(c_bindings)]
201-
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
202-
type ScoreParams = <T as Score>::ScoreParams;
203-
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
204-
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
205-
}
206-
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
207-
self.0.payment_path_failed(path, short_channel_id)
208-
}
209-
fn payment_path_successful(&mut self, path: &Path) {
210-
self.0.payment_path_successful(path)
211-
}
212-
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
213-
self.0.probe_failed(path, short_channel_id)
214-
}
215-
fn probe_successful(&mut self, path: &Path) {
216-
self.0.probe_successful(path)
217-
}
218-
}
219-
#[cfg(c_bindings)]
220-
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
221-
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
222-
self.0.write(writer)
223-
}
199+
pub struct MultiThreadedLockableScore<T: Score> {
200+
score: Mutex<T>,
224201
}
225202

226203
#[cfg(c_bindings)]
227-
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
204+
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
205+
type Score = T;
228206
type Locked = MultiThreadedScoreLock<'a, T>;
229207

230-
fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
208+
fn lock(&'a self) -> Self::Locked {
231209
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
232210
}
233211
}
@@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
240218
}
241219

242220
#[cfg(c_bindings)]
243-
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
221+
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
244222

245223
#[cfg(c_bindings)]
246224
impl<T: Score> MultiThreadedLockableScore<T> {
@@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
250228
}
251229
}
252230

231+
#[cfg(c_bindings)]
232+
/// A locked `MultiThreadedLockableScore`.
233+
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);
234+
235+
#[cfg(c_bindings)]
236+
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
237+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
238+
self.0.write(writer)
239+
}
240+
}
241+
242+
#[cfg(c_bindings)]
243+
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
244+
fn deref_mut(&mut self) -> &mut Self::Target {
245+
self.0.deref_mut()
246+
}
247+
}
248+
249+
#[cfg(c_bindings)]
250+
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
251+
type Target = T;
252+
253+
fn deref(&self) -> &Self::Target {
254+
self.0.deref()
255+
}
256+
}
257+
253258
#[cfg(c_bindings)]
254259
/// This is not exported to bindings users
255260
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {

lightning/src/util/test_utils.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ use regex;
5151
use crate::io;
5252
use crate::prelude::*;
5353
use core::cell::RefCell;
54+
use core::ops::DerefMut;
5455
use core::time::Duration;
5556
use crate::sync::{Mutex, Arc};
5657
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
113114
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
114115
assert_eq!(find_route_query, *params);
115116
if let Ok(ref route) = find_route_res {
116-
let locked_scorer = self.scorer.lock().unwrap();
117-
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
117+
let mut binding = self.scorer.lock().unwrap();
118+
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
118119
for path in &route.paths {
119120
let mut aggregate_msat = 0u64;
120121
for (idx, hop) in path.hops.iter().rev().enumerate() {
@@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
139140
return find_route_res;
140141
}
141142
let logger = TestLogger::new();
142-
let scorer = self.scorer.lock().unwrap();
143143
find_route(
144144
payer, params, &self.network_graph, first_hops, &logger,
145-
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
145+
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
146146
&[42; 32]
147147
)
148148
}

0 commit comments

Comments
 (0)